# nbv_reconstruction/runners/inferencer.py
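"""Inference runner for the NBV reconstruction pipeline: rolls out predicted
next-best-view sequences scene by scene and records coverage statistics."""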

import os
import json
import time
import pickle

import numpy as np
import torch
from tqdm import tqdm

from PytorchBoot.config import ConfigManager
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory
from PytorchBoot.dataset import BaseDataset
from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils import Log
from PytorchBoot.status import status_manager

from utils.data_load import DataLoadUtil
from utils.pose import PoseUtil
from utils.pts import PtsUtil
from utils.reconstruction import ReconstructionUtil
from utils.render import RenderUtil

@stereotype.runner("inferencer")
class Inferencer(Runner):
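    """Evaluates a trained next-best-view (NBV) pipeline on the configured test
    datasets, saving per-scene rollouts and aggregate coverage statistics."""
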
def __init__(self, config_path):
super().__init__(config_path)
self.script_path = ConfigManager.get(namespace.Stereotype.RUNNER, "blender_script_path")
self.output_dir = ConfigManager.get(namespace.Stereotype.RUNNER, "output_dir")
self.voxel_size = ConfigManager.get(namespace.Stereotype.RUNNER, "voxel_size")
''' Pipeline '''
self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
        self.pipeline: torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
self.pipeline = self.pipeline.to(self.device)
''' Experiment '''
self.load_experiment("nbv_evaluator")
self.stat_result_path = os.path.join(self.output_dir, "stat.json")
if os.path.exists(self.stat_result_path):
with open(self.stat_result_path, "r") as f:
self.stat_result = json.load(f)
else:
self.stat_result = {}
''' Test '''
self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
self.test_dataset_name_list = self.test_config["dataset_list"]
self.test_set_list = []
self.test_writer_list = []
seen_name = set()
for test_dataset_name in self.test_dataset_name_list:
            if test_dataset_name in seen_name:
                raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
            seen_name.add(test_dataset_name)
test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
self.test_set_list.append(test_set)
self.print_info()
def run(self):
Log.info("Loading from epoch {}.".format(self.current_epoch))
self.inference()
Log.success("Inference finished.")
def inference(self):
self.pipeline.eval()
with torch.no_grad():
test_set: BaseDataset
for dataset_idx, test_set in enumerate(self.test_set_list):
                status_manager.set_progress("inference", "inferencer", "dataset", dataset_idx, len(self.test_set_list))
test_set_name = test_set.get_name()
                total = len(test_set)
for i in tqdm(range(total), desc=f"Processing {test_set_name}", ncols=100):
                    data = test_set[i]
scene_name = data["scene_name"]
inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl")
if os.path.exists(inference_result_path):
Log.info(f"Inference result already exists for scene: {scene_name}")
continue
status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
output = self.predict_sequence(data)
                    self.save_inference_result(test_set_name, scene_name, output)
            status_manager.set_progress("inference", "inferencer", "dataset", len(self.test_set_list), len(self.test_set_list))
    def predict_sequence(self, data, cr_increase_threshold=0.001, overlap_area_threshold=25, scan_points_threshold=10, max_iter=50, max_retry=7):
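        """Iteratively predict a next-best-view sequence for one scene.

        Starting from the given first frame, each iteration queries the pipeline for a
        new camera pose, renders the point cloud visible from that pose, and rejects
        views with insufficient overlap or negligible coverage gain. The loop ends when
        the ground-truth maximum coverage rate is (nearly) reached, after max_iter
        accepted views, after max_retry consecutive rejections, or once the rollout has
        saturated (success > 3)."""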
scene_name = data["scene_name"]
Log.info(f"Processing scene: {scene_name}")
status_manager.set_status("inference", "inferencer", "scene", scene_name)
''' data for rendering '''
scene_path = data["scene_path"]
O_to_L_pose = data["O_to_L_pose"]
voxel_threshold = self.voxel_size
filter_degree = 75
down_sampled_model_pts = data["gt_pts"]
first_frame_to_world_9d = data["first_scanned_n_to_world_pose_9d"][0]
first_frame_to_world = np.eye(4)
first_frame_to_world[:3,:3] = PoseUtil.rotation_6d_to_matrix_numpy(first_frame_to_world_9d[:6])
first_frame_to_world[:3,3] = first_frame_to_world_9d[6:]
''' data for inference '''
input_data = {}
scanned_pts = []
input_data["combined_scanned_pts"] = torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device).unsqueeze(0)
input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(data["first_scanned_n_to_world_pose_9d"], dtype=torch.float32).to(self.device)]
input_data["mode"] = namespace.Mode.TEST
input_pts_N = input_data["combined_scanned_pts"].shape[1]
root = os.path.dirname(scene_path)
display_table_info = DataLoadUtil.get_display_table_info(root, scene_name)
radius = display_table_info["radius"]
        scan_points = np.asarray(ReconstructionUtil.generate_scan_points(display_table_top=0, display_table_radius=radius))
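        # Render the given first frame to initialize the scanned cloud and the scan-point history.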
first_frame_target_pts, first_frame_target_normals, first_frame_scan_points_indices = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, scan_points, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
scanned_view_pts = [first_frame_target_pts]
history_indices = [first_frame_scan_points_indices]
last_pred_cr, added_pts_num = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold)
scanned_pts.append(first_frame_target_pts)
retry_duplication_pose = []
retry_no_pts_pose = []
retry_overlap_pose = []
retry = 0
pred_cr_seq = [last_pred_cr]
success = 0
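        # NBV loop: propose a pose, render it, and accept it only if it overlaps the
        # already-scanned region and still improves coverage.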
while len(pred_cr_seq) < max_iter and retry < max_retry:
start_time = time.time()
2024-09-19 00:14:26 +08:00
output = self.pipeline(input_data)
end_time = time.time()
print(f"Time taken for inference: {end_time - start_time} seconds")
pred_pose_9d = output["pred_pose_9d"]
pred_pose = torch.eye(4, device=pred_pose_9d.device)
pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_pose_9d[:,:6])[0]
pred_pose[:3,3] = pred_pose_9d[0,6:]
try:
                start_time = time.time()
                new_target_pts, new_target_normals, new_scan_points_indices = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, scan_points, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
                # Reject views that return no points before any bookkeeping.
                if new_target_pts.shape[0] == 0:
                    print("no pts in new target")
                    retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
                    retry += 1
                    continue
                # Relax the overlap-area threshold when the new view already shares
                # enough scan points with previously accepted views.
                if not ReconstructionUtil.check_scan_points_overlap(history_indices, new_scan_points_indices, scan_points_threshold):
                    curr_overlap_area_threshold = overlap_area_threshold
                else:
                    curr_overlap_area_threshold = overlap_area_threshold * 0.5
                downsampled_new_target_pts = PtsUtil.voxel_downsample_point_cloud(new_target_pts, voxel_threshold)
                overlap, new_added_pts_num = ReconstructionUtil.check_overlap(downsampled_new_target_pts, down_sampled_model_pts, overlap_area_threshold=curr_overlap_area_threshold, voxel_size=voxel_threshold, require_new_added_pts_num=True)
                if not overlap:
                    retry += 1
                    retry_overlap_pose.append(pred_pose.cpu().numpy().tolist())
                    continue
                scanned_pts.append(new_target_pts)
                history_indices.append(new_scan_points_indices)
                end_time = time.time()
                print(f"Time taken for rendering: {end_time - start_time} seconds")
            except Exception as e:
                Log.warning(f"Error in scene {scene_path}, {e}")
                print("current pose: ", pred_pose)
                print("curr_pred_cr: ", last_pred_cr)
                retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
                retry += 1
                continue
start_time = time.time()
pred_cr, covered_pts_num = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
end_time = time.time()
print(f"Time taken for coverage rate computation: {end_time - start_time} seconds")
            print("pred_cr: ", pred_cr, " last_pred_cr: ", last_pred_cr, " max: ", data["seq_max_coverage_rate"])
print("new added pts num: ", new_added_pts_num)
if pred_cr >= data["seq_max_coverage_rate"] - 1e-3:
print("max coverage rate reached!: ", pred_cr)
success += 1
elif new_added_pts_num < 5:
success += 1
print("min added pts num reached!: ", new_added_pts_num)
if pred_cr <= last_pred_cr + cr_increase_threshold:
retry += 1
retry_duplication_pose.append(pred_pose.cpu().numpy().tolist())
continue
retry = 0
pred_cr_seq.append(pred_cr)
            scanned_view_pts.append(new_target_pts)
            # Fuse the accepted view into the network input: concatenate, voxel-downsample
            # to drop duplicates, then randomly downsample back to the input size.
            down_sampled_new_pts_world = PtsUtil.random_downsample_point_cloud(new_target_pts, input_pts_N)
            input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]
            combined_scanned_pts = np.concatenate([input_data["combined_scanned_pts"][0].cpu().numpy(), down_sampled_new_pts_world], axis=0)
            voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, 0.002)
            random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N)
            input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)
if success > 3:
break
last_pred_cr = pred_cr
input_data["scanned_n_to_world_pose_9d"] = input_data["scanned_n_to_world_pose_9d"][0].cpu().numpy().tolist()
result = {
"scanned_pts": scanned_pts,
"pred_pose_9d_seq": input_data["scanned_n_to_world_pose_9d"],
"combined_scanned_pts": input_data["combined_scanned_pts"],
"target_pts_seq": scanned_view_pts,
"coverage_rate_seq": pred_cr_seq,
"max_coverage_rate": data["seq_max_coverage_rate"],
"pred_max_coverage_rate": max(pred_cr_seq),
"scene_name": scene_name,
"retry_no_pts_pose": retry_no_pts_pose,
"retry_duplication_pose": retry_duplication_pose,
"retry_overlap_pose": retry_overlap_pose,
"best_seq_len": data["best_seq_len"],
}
self.stat_result[scene_name] = {
"coverage_rate_seq": pred_cr_seq,
"pred_max_coverage_rate": max(pred_cr_seq),
"pred_seq_len": len(pred_cr_seq),
}
        print("pred max coverage rate: ", max(pred_cr_seq))
        return result

def compute_coverage_rate(self, scanned_view_pts, new_pts, model_pts, threshold=0.005):
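        """Coverage rate of the model points by the union of scanned views, optionally
        including new_pts, after voxel downsampling."""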
if new_pts is not None:
new_scanned_view_pts = scanned_view_pts + [new_pts]
else:
new_scanned_view_pts = scanned_view_pts
combined_point_cloud = np.vstack(new_scanned_view_pts)
        down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud, threshold)
return ReconstructionUtil.compute_coverage_rate(model_pts, down_sampled_combined_point_cloud, threshold)
def save_inference_result(self, dataset_name, scene_name, output):
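        """Pickle the per-scene rollout and refresh the aggregated stat.json file."""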
dataset_dir = os.path.join(self.output_dir, dataset_name)
        os.makedirs(dataset_dir, exist_ok=True)
output_path = os.path.join(dataset_dir, f"{scene_name}.pkl")
        with open(output_path, "wb") as f:
            pickle.dump(output, f)
with open(self.stat_result_path, "w") as f:
json.dump(self.stat_result, f)
def get_checkpoint_path(self, is_last=False):
return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME,
"Epoch_{}.pth".format(
                                self.current_epoch if self.current_epoch != -1 and not is_last else "last"))

def load_checkpoint(self, is_last=False):
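        """Load pipeline weights; when the last checkpoint is requested, also restore
        the epoch/iter counters from meta.json."""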
self.load(self.get_checkpoint_path(is_last))
Log.success(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}")
if is_last:
checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
            meta_path = os.path.join(checkpoint_root, "meta.json")
            if not os.path.exists(meta_path):
                raise FileNotFoundError(
                    "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
            with open(meta_path, "r") as f:
                meta = json.load(f)
self.current_epoch = meta["last_epoch"]
            self.current_iter = meta["last_iter"]

def load_experiment(self, backup_name=None):
super().load_experiment(backup_name)
self.current_epoch = self.experiments_config["epoch"]
self.load_checkpoint(is_last=(self.current_epoch == -1))
def create_experiment(self, backup_name=None):
super().create_experiment(backup_name)
def load(self, path):
state_dict = torch.load(path)
self.pipeline.load_state_dict(state_dict)
def print_info(self):
def print_dataset(dataset: BaseDataset):
config = dataset.get_config()
name = dataset.get_name()
Log.blue(f"Dataset: {name}")
            for k, v in config.items():
Log.blue(f"\t{k}: {v}")
super().print_info()
table_size = 70
Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
Log.blue(self.pipeline)
Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
for i, test_set in enumerate(self.test_set_list):
Log.blue(f"test dataset {i}: ")
print_dataset(test_set)
Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')