hofee 2024-11-25 09:41:28 +00:00
parent 04d3a359e1
commit 155b655938
2 changed files with 37 additions and 12 deletions

@@ -48,14 +48,17 @@ class SeqReconstructionDataset(BaseDataset):
             for line in f:
                 scene_name = line.strip()
                 scene_name_list.append(scene_name)
         return scene_name_list
 
     def get_scene_name_list(self):
         return self.scene_name_list
 
     def get_datalist(self):
         datalist = []
-        for scene_name in self.scene_name_list:
+        total = len(self.scene_name_list)
+        for idx, scene_name in enumerate(self.scene_name_list):
+            print(f"processing {scene_name} ({idx}/{total})")
             seq_num = DataLoadUtil.get_label_num(self.root_dir, scene_name)
             scene_max_coverage_rate = 0
             max_coverage_rate_list = []
@@ -178,23 +181,41 @@ class SeqReconstructionDataset(BaseDataset):
 # -------------- Debug ---------------- #
 if __name__ == "__main__":
     import torch
+    from tqdm import tqdm
+    import pickle
+    import os
 
     seed = 0
     torch.manual_seed(seed)
     np.random.seed(seed)
     config = {
         "root_dir": "/data/hofee/data/new_full_data",
         "source": "seq_reconstruction_dataset",
-        "split_file": "/data/hofee/data/sample.txt",
+        "split_file": "/data/hofee/data/new_full_data_list/OmniObject3d_test.txt",
         "load_from_preprocess": True,
-        "ratio": 0.5,
-        "batch_size": 2,
         "filter_degree": 75,
         "num_workers": 0,
-        "pts_num": 4096,
-        "type": namespace.Mode.TRAIN,
+        "pts_num": 8192,
+        "type": namespace.Mode.TEST,
     }
-    ds = SeqReconstructionDataset(config)
-    print(len(ds))
-    print(ds.__getitem__(10))
+
+    output_dir = "/data/hofee/trash_can/output_inference_test"
+    new_output_dir = "/data/hofee/inference_test"
+    os.makedirs(output_dir, exist_ok=True)
+    os.makedirs(new_output_dir, exist_ok=True)
+
+    ds = SeqReconstructionDataset(config)
+    for i in tqdm(range(len(ds)), desc="processing dataset"):
+        output_path = os.path.join(output_dir, f"item_{i}.pkl")
+        if os.path.exists(output_path):
+            item = pickle.load(open(output_path, "rb"))
+        else:
+            item = ds.__getitem__(i)
+        for key, value in item.items():
+            if isinstance(value, np.ndarray):
+                item[key] = value.tolist()
+        new_output_path = os.path.join(new_output_dir, f"item_{i}.pkl")
+        with open(new_output_path, "wb") as f:
+            pickle.dump(item, f)
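
The new debug loop caches each dataset item under output_dir as item_{i}.pkl, converts every numpy array field to a plain list, and re-dumps the converted item to new_output_dir. A minimal sketch of a consumer that loads a dumped item back and restores the arrays; the item schema and array dtypes are not shown in this diff, so converting every list-valued field is an assumption:

import pickle
import numpy as np

# Load one converted item back from new_output_dir (path taken from the diff above).
with open("/data/hofee/inference_test/item_0.pkl", "rb") as f:
    item = pickle.load(f)

# Undo the ndarray -> list conversion done before dumping. Which keys actually
# hold point data is an assumption, so restore every list-valued field.
for key, value in item.items():
    if isinstance(value, list):
        item[key] = np.asarray(value)

Storing lists instead of ndarrays makes the pickles loadable without numpy installed, at the cost of larger files and the round-trip conversion shown here.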

@@ -25,6 +25,7 @@ class InferencerServer(Runner):
         self.pipeline:torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
         self.pipeline = self.pipeline.to(self.device)
         self.pts_num = 8192
+        self.voxel_size = 0.002
 
         ''' Experiment '''
         self.load_experiment("inferencer_server")
@@ -34,8 +35,11 @@ class InferencerServer(Runner):
         scanned_pts = data["scanned_pts"]
         scanned_n_to_world_pose_9d = data["scanned_n_to_world_pose_9d"]
         combined_scanned_views_pts = np.concatenate(scanned_pts, axis=0)
+        voxel_downsampled_combined_scanned_pts = PtsUtil.voxel_downsample_point_cloud(
+            combined_scanned_views_pts, self.voxel_size
+        )
         fps_downsampled_combined_scanned_pts, fps_idx = PtsUtil.fps_downsample_point_cloud(
-            combined_scanned_views_pts, self.pts_num, require_idx=True
+            voxel_downsampled_combined_scanned_pts, self.pts_num, require_idx=True
         )
         # combined_scanned_views_pts_mask = np.zeros(len(scanned_pts), dtype=np.uint8)
         # start_idx = 0
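
This change inserts a voxel-grid downsample (voxel_size = 0.002) before farthest-point sampling, so FPS selects its pts_num points from a spatially deduplicated cloud rather than the raw concatenation of all scanned views, where repeatedly scanned regions are overrepresented. The diff does not show PtsUtil's implementation; a minimal sketch of the assumed behavior, keeping one point per occupied voxel:

import numpy as np

def voxel_downsample_point_cloud(points: np.ndarray, voxel_size: float) -> np.ndarray:
    # Quantize each point to the index of the voxel containing it ...
    voxel_indices = np.floor(points / voxel_size).astype(np.int64)
    # ... then keep the first point encountered in each occupied voxel.
    _, unique_idx = np.unique(voxel_indices, axis=0, return_index=True)
    return points[np.sort(unique_idx)]

# Example: a dense random cloud collapses to roughly one point per 2 mm voxel.
pts = np.random.rand(100000, 3)
down = voxel_downsample_point_cloud(pts, 0.002)

Because voxel downsampling equalizes point density first, the subsequent FPS pass spends its point budget on coverage of the surface instead of on densely rescanned regions.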