This commit is contained in:
hofee 2024-11-04 23:49:12 +08:00
parent 2b7243d1be
commit 5bcd0fc6e3
2 changed files with 14 additions and 18 deletions

View File

@ -20,18 +20,18 @@ runner:
voxel_size: 0.003 voxel_size: 0.003
dataset: dataset:
OmniObject3d_train: # OmniObject3d_train:
root_dir: "C:\\Document\\Datasets\\inference_test" # root_dir: "C:\\Document\\Datasets\\inference_test1"
model_dir: "C:\\Document\\Datasets\\scaled_object_meshes" # model_dir: "C:\\Document\\Datasets\\scaled_object_meshes"
source: seq_reconstruction_dataset_preprocessed # source: seq_reconstruction_dataset_preprocessed
split_file: "C:\\Document\\Datasets\\data_list\\sample.txt" # split_file: "C:\\Document\\Datasets\\data_list\\sample.txt"
type: test # type: test
filter_degree: 75 # filter_degree: 75
ratio: 1 # ratio: 1
batch_size: 1 # batch_size: 1
num_workers: 12 # num_workers: 12
pts_num: 8192 # pts_num: 8192
load_from_preprocess: True # load_from_preprocess: True
OmniObject3d_test: OmniObject3d_test:
root_dir: "C:\\Document\\Datasets\\inference_test" root_dir: "C:\\Document\\Datasets\\inference_test"

View File

@ -87,7 +87,7 @@ class Inferencer(Runner):
status_manager.set_progress("inference", "inferencer", f"dataset", len(self.test_set_list), len(self.test_set_list)) status_manager.set_progress("inference", "inferencer", f"dataset", len(self.test_set_list), len(self.test_set_list))
def predict_sequence(self, data, cr_increase_threshold=0.001, overlap_area_threshold=25, scan_points_threshold=10, max_iter=50, max_retry = 7): def predict_sequence(self, data, cr_increase_threshold=0, overlap_area_threshold=25, scan_points_threshold=10, max_iter=50, max_retry = 7):
scene_name = data["scene_name"] scene_name = data["scene_name"]
Log.info(f"Processing scene: {scene_name}") Log.info(f"Processing scene: {scene_name}")
status_manager.set_status("inference", "inferencer", "scene", scene_name) status_manager.set_status("inference", "inferencer", "scene", scene_name)
@ -106,7 +106,6 @@ class Inferencer(Runner):
''' data for inference ''' ''' data for inference '''
input_data = {} input_data = {}
scanned_pts = []
input_data["combined_scanned_pts"] = torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device).unsqueeze(0) input_data["combined_scanned_pts"] = torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device).unsqueeze(0)
input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(data["first_scanned_n_to_world_pose_9d"], dtype=torch.float32).to(self.device)] input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(data["first_scanned_n_to_world_pose_9d"], dtype=torch.float32).to(self.device)]
input_data["mode"] = namespace.Mode.TEST input_data["mode"] = namespace.Mode.TEST
@ -119,7 +118,6 @@ class Inferencer(Runner):
scanned_view_pts = [first_frame_target_pts] scanned_view_pts = [first_frame_target_pts]
history_indices = [first_frame_scan_points_indices] history_indices = [first_frame_scan_points_indices]
last_pred_cr, added_pts_num = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold) last_pred_cr, added_pts_num = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold)
scanned_pts.append(first_frame_target_pts)
retry_duplication_pose = [] retry_duplication_pose = []
retry_no_pts_pose = [] retry_no_pts_pose = []
retry_overlap_pose = [] retry_overlap_pose = []
@ -154,7 +152,6 @@ class Inferencer(Runner):
retry_overlap_pose.append(pred_pose.cpu().numpy().tolist()) retry_overlap_pose.append(pred_pose.cpu().numpy().tolist())
continue continue
scanned_pts.append(new_target_pts)
history_indices.append(new_scan_points_indices) history_indices.append(new_scan_points_indices)
end_time = time.time() end_time = time.time()
print(f"Time taken for rendering: {end_time - start_time} seconds") print(f"Time taken for rendering: {end_time - start_time} seconds")
@ -182,7 +179,7 @@ class Inferencer(Runner):
print("max coverage rate reached!: ", pred_cr) print("max coverage rate reached!: ", pred_cr)
success += 1 success += 1
elif new_added_pts_num < 5: elif new_added_pts_num < 5:
success += 1 #success += 1
print("min added pts num reached!: ", new_added_pts_num) print("min added pts num reached!: ", new_added_pts_num)
if pred_cr <= last_pred_cr + cr_increase_threshold: if pred_cr <= last_pred_cr + cr_increase_threshold:
retry += 1 retry += 1
@ -208,7 +205,6 @@ class Inferencer(Runner):
input_data["scanned_n_to_world_pose_9d"] = input_data["scanned_n_to_world_pose_9d"][0].cpu().numpy().tolist() input_data["scanned_n_to_world_pose_9d"] = input_data["scanned_n_to_world_pose_9d"][0].cpu().numpy().tolist()
result = { result = {
"scanned_pts": scanned_pts,
"pred_pose_9d_seq": input_data["scanned_n_to_world_pose_9d"], "pred_pose_9d_seq": input_data["scanned_n_to_world_pose_9d"],
"combined_scanned_pts": input_data["combined_scanned_pts"], "combined_scanned_pts": input_data["combined_scanned_pts"],
"target_pts_seq": scanned_view_pts, "target_pts_seq": scanned_view_pts,