Compare commits
No commits in common. "04d3a359e18b944089231b9db2044ea0e5ccc728" and "982a3b9b60816cb8e528ec11fd90b0170cce3c8a" have entirely different histories.
04d3a359e1
...
982a3b9b60
@ -1,6 +1,5 @@
|
|||||||
from PytorchBoot.application import PytorchBootApplication
|
from PytorchBoot.application import PytorchBootApplication
|
||||||
from runners.inferencer import Inferencer
|
from runners.inferencer import Inferencer
|
||||||
from runners.inference_server import InferencerServer
|
|
||||||
|
|
||||||
@PytorchBootApplication("inference")
|
@PytorchBootApplication("inference")
|
||||||
class InferenceApp:
|
class InferenceApp:
|
||||||
@ -15,17 +14,3 @@ class InferenceApp:
|
|||||||
Evaluator("path_to_your_eval_config").run()
|
Evaluator("path_to_your_eval_config").run()
|
||||||
'''
|
'''
|
||||||
Inferencer("./configs/local/inference_config.yaml").run()
|
Inferencer("./configs/local/inference_config.yaml").run()
|
||||||
|
|
||||||
@PytorchBootApplication("server")
|
|
||||||
class InferenceServerApp:
|
|
||||||
@staticmethod
|
|
||||||
def start():
|
|
||||||
'''
|
|
||||||
call default or your custom runners here, code will be executed
|
|
||||||
automatically when type "pytorch-boot run" or "ptb run" in terminal
|
|
||||||
|
|
||||||
example:
|
|
||||||
Trainer("path_to_your_train_config").run()
|
|
||||||
Evaluator("path_to_your_eval_config").run()
|
|
||||||
'''
|
|
||||||
InferencerServer("./configs/server/server_inference_server_config.yaml").run()
|
|
@ -1,53 +0,0 @@
|
|||||||
|
|
||||||
runner:
|
|
||||||
general:
|
|
||||||
seed: 0
|
|
||||||
device: cuda
|
|
||||||
cuda_visible_devices: "0,1,2,3,4,5,6,7"
|
|
||||||
|
|
||||||
experiment:
|
|
||||||
name: train_ab_global_only
|
|
||||||
root_dir: "experiments"
|
|
||||||
epoch: -1 # -1 stands for last epoch
|
|
||||||
|
|
||||||
pipeline: nbv_reconstruction_pipeline
|
|
||||||
voxel_size: 0.003
|
|
||||||
|
|
||||||
pipeline:
|
|
||||||
nbv_reconstruction_pipeline:
|
|
||||||
modules:
|
|
||||||
pts_encoder: pointnet_encoder
|
|
||||||
seq_encoder: transformer_seq_encoder
|
|
||||||
pose_encoder: pose_encoder
|
|
||||||
view_finder: gf_view_finder
|
|
||||||
eps: 1e-5
|
|
||||||
global_scanned_feat: True
|
|
||||||
|
|
||||||
module:
|
|
||||||
pointnet_encoder:
|
|
||||||
in_dim: 3
|
|
||||||
out_dim: 1024
|
|
||||||
global_feat: True
|
|
||||||
feature_transform: False
|
|
||||||
transformer_seq_encoder:
|
|
||||||
embed_dim: 256
|
|
||||||
num_heads: 4
|
|
||||||
ffn_dim: 256
|
|
||||||
num_layers: 3
|
|
||||||
output_dim: 1024
|
|
||||||
|
|
||||||
gf_view_finder:
|
|
||||||
t_feat_dim: 128
|
|
||||||
pose_feat_dim: 256
|
|
||||||
main_feat_dim: 2048
|
|
||||||
regression_head: Rx_Ry_and_T
|
|
||||||
pose_mode: rot_matrix
|
|
||||||
per_point_feature: False
|
|
||||||
sample_mode: ode
|
|
||||||
sampling_steps: 500
|
|
||||||
sde_mode: ve
|
|
||||||
pose_encoder:
|
|
||||||
pose_dim: 9
|
|
||||||
out_dim: 256
|
|
||||||
pts_num_encoder:
|
|
||||||
out_dim: 64
|
|
@ -50,9 +50,6 @@ class SeqReconstructionDataset(BaseDataset):
|
|||||||
scene_name_list.append(scene_name)
|
scene_name_list.append(scene_name)
|
||||||
return scene_name_list
|
return scene_name_list
|
||||||
|
|
||||||
def get_scene_name_list(self):
|
|
||||||
return self.scene_name_list
|
|
||||||
|
|
||||||
def get_datalist(self):
|
def get_datalist(self):
|
||||||
datalist = []
|
datalist = []
|
||||||
for scene_name in self.scene_name_list:
|
for scene_name in self.scene_name_list:
|
||||||
|
@ -13,7 +13,7 @@ from PytorchBoot.utils import Log
|
|||||||
|
|
||||||
from utils.pts import PtsUtil
|
from utils.pts import PtsUtil
|
||||||
|
|
||||||
@stereotype.runner("inferencer_server")
|
@stereotype.runner("inferencer")
|
||||||
class InferencerServer(Runner):
|
class InferencerServer(Runner):
|
||||||
def __init__(self, config_path):
|
def __init__(self, config_path):
|
||||||
super().__init__(config_path)
|
super().__init__(config_path)
|
||||||
@ -24,10 +24,9 @@ class InferencerServer(Runner):
|
|||||||
self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
|
self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
|
||||||
self.pipeline:torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
|
self.pipeline:torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
|
||||||
self.pipeline = self.pipeline.to(self.device)
|
self.pipeline = self.pipeline.to(self.device)
|
||||||
self.pts_num = 8192
|
|
||||||
|
|
||||||
''' Experiment '''
|
''' Experiment '''
|
||||||
self.load_experiment("inferencer_server")
|
self.load_experiment("nbv_evaluator")
|
||||||
|
|
||||||
def get_input_data(self, data):
|
def get_input_data(self, data):
|
||||||
input_data = {}
|
input_data = {}
|
||||||
@ -37,36 +36,28 @@ class InferencerServer(Runner):
|
|||||||
fps_downsampled_combined_scanned_pts, fps_idx = PtsUtil.fps_downsample_point_cloud(
|
fps_downsampled_combined_scanned_pts, fps_idx = PtsUtil.fps_downsample_point_cloud(
|
||||||
combined_scanned_views_pts, self.pts_num, require_idx=True
|
combined_scanned_views_pts, self.pts_num, require_idx=True
|
||||||
)
|
)
|
||||||
# combined_scanned_views_pts_mask = np.zeros(len(scanned_pts), dtype=np.uint8)
|
combined_scanned_views_pts_mask = np.zeros(len(scanned_pts), dtype=np.uint8)
|
||||||
# start_idx = 0
|
start_idx = 0
|
||||||
# for i in range(len(scanned_pts)):
|
for i in range(len(scanned_pts)):
|
||||||
# end_idx = start_idx + len(scanned_pts[i])
|
end_idx = start_idx + len(scanned_pts[i])
|
||||||
# combined_scanned_views_pts_mask[start_idx:end_idx] = i
|
combined_scanned_views_pts_mask[start_idx:end_idx] = i
|
||||||
# start_idx = end_idx
|
start_idx = end_idx
|
||||||
|
|
||||||
# fps_downsampled_combined_scanned_pts_mask = combined_scanned_views_pts_mask[fps_idx]
|
fps_downsampled_combined_scanned_pts_mask = combined_scanned_views_pts_mask[fps_idx]
|
||||||
|
|
||||||
input_data["scanned_pts"] = scanned_pts
|
input_data["scanned_pts_mask"] = np.asarray(fps_downsampled_combined_scanned_pts_mask, dtype=np.uint8)
|
||||||
# input_data["scanned_pts_mask"] = np.asarray(fps_downsampled_combined_scanned_pts_mask, dtype=np.uint8)
|
|
||||||
input_data["scanned_n_to_world_pose_9d"] = np.asarray(scanned_n_to_world_pose_9d, dtype=np.float32)
|
input_data["scanned_n_to_world_pose_9d"] = np.asarray(scanned_n_to_world_pose_9d, dtype=np.float32)
|
||||||
input_data["combined_scanned_pts"] = np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32)
|
input_data["combined_scanned_pts"] = np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32)
|
||||||
return input_data
|
return input_data
|
||||||
|
|
||||||
def get_result(self, output_data):
|
def get_result(self, output_data):
|
||||||
|
|
||||||
pred_pose_9d = output_data["pred_pose_9d"]
|
estimated_delta_rot_9d = output_data["pred_pose_9d"]
|
||||||
result = {
|
result = {
|
||||||
"pred_pose_9d": pred_pose_9d.tolist()
|
"estimated_delta_rot_9d": estimated_delta_rot_9d.tolist()
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def collate_input(self, input_data):
|
|
||||||
collated_input_data = {}
|
|
||||||
collated_input_data["scanned_pts"] = [torch.tensor(input_data["scanned_pts"], dtype=torch.float32, device=self.device)]
|
|
||||||
collated_input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(input_data["scanned_n_to_world_pose_9d"], dtype=torch.float32, device=self.device)]
|
|
||||||
collated_input_data["combined_scanned_pts"] = torch.tensor(input_data["combined_scanned_pts"], dtype=torch.float32, device=self.device).unsqueeze(0)
|
|
||||||
return collated_input_data
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
Log.info("Loading from epoch {}.".format(self.current_epoch))
|
Log.info("Loading from epoch {}.".format(self.current_epoch))
|
||||||
|
|
||||||
@ -74,8 +65,7 @@ class InferencerServer(Runner):
|
|||||||
def inference():
|
def inference():
|
||||||
data = request.json
|
data = request.json
|
||||||
input_data = self.get_input_data(data)
|
input_data = self.get_input_data(data)
|
||||||
collated_input_data = self.collate_input(input_data)
|
output_data = self.pipeline.forward_test(input_data)
|
||||||
output_data = self.pipeline.forward_test(collated_input_data)
|
|
||||||
result = self.get_result(output_data)
|
result = self.get_result(output_data)
|
||||||
return jsonify(result)
|
return jsonify(result)
|
||||||
|
|
@ -68,16 +68,9 @@ class Inferencer(Runner):
|
|||||||
test_set_name = test_set.get_name()
|
test_set_name = test_set.get_name()
|
||||||
|
|
||||||
total=int(len(test_set))
|
total=int(len(test_set))
|
||||||
scene_name_list = test_set.get_scene_name_list()
|
|
||||||
for i in range(total):
|
for i in range(total):
|
||||||
scene_name = scene_name_list[i]
|
|
||||||
inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl")
|
|
||||||
if os.path.exists(inference_result_path):
|
|
||||||
Log.info(f"Inference result already exists for scene: {scene_name}")
|
|
||||||
continue
|
|
||||||
data = test_set.__getitem__(i)
|
data = test_set.__getitem__(i)
|
||||||
status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
|
status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
|
||||||
scene_name = data["scene_name"]
|
|
||||||
output = self.predict_sequence(data)
|
output = self.predict_sequence(data)
|
||||||
self.save_inference_result(test_set_name, data["scene_name"], output)
|
self.save_inference_result(test_set_name, data["scene_name"], output)
|
||||||
|
|
||||||
@ -115,12 +108,9 @@ class Inferencer(Runner):
|
|||||||
retry = 0
|
retry = 0
|
||||||
pred_cr_seq = [last_pred_cr]
|
pred_cr_seq = [last_pred_cr]
|
||||||
success = 0
|
success = 0
|
||||||
import time
|
|
||||||
while len(pred_cr_seq) < max_iter and retry < max_retry:
|
while len(pred_cr_seq) < max_iter and retry < max_retry:
|
||||||
start_time = time.time()
|
|
||||||
output = self.pipeline(input_data)
|
output = self.pipeline(input_data)
|
||||||
end_time = time.time()
|
|
||||||
print(f"Time taken for inference: {end_time - start_time} seconds")
|
|
||||||
pred_pose_9d = output["pred_pose_9d"]
|
pred_pose_9d = output["pred_pose_9d"]
|
||||||
pred_pose = torch.eye(4, device=pred_pose_9d.device)
|
pred_pose = torch.eye(4, device=pred_pose_9d.device)
|
||||||
|
|
||||||
@ -128,10 +118,7 @@ class Inferencer(Runner):
|
|||||||
pred_pose[:3,3] = pred_pose_9d[0,6:]
|
pred_pose[:3,3] = pred_pose_9d[0,6:]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
|
||||||
new_target_pts, new_target_normals = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
|
new_target_pts, new_target_normals = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
|
||||||
end_time = time.time()
|
|
||||||
print(f"Time taken for rendering: {end_time - start_time} seconds")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
Log.warning(f"Error in scene {scene_path}, {e}")
|
Log.warning(f"Error in scene {scene_path}, {e}")
|
||||||
print("current pose: ", pred_pose)
|
print("current pose: ", pred_pose)
|
||||||
@ -146,10 +133,8 @@ class Inferencer(Runner):
|
|||||||
retry += 1
|
retry += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
pred_cr, new_added_pts_num = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
|
pred_cr, new_added_pts_num = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
|
||||||
end_time = time.time()
|
|
||||||
print(f"Time taken for coverage rate computation: {end_time - start_time} seconds")
|
|
||||||
print(pred_cr, last_pred_cr, " max: ", data["seq_max_coverage_rate"])
|
print(pred_cr, last_pred_cr, " max: ", data["seq_max_coverage_rate"])
|
||||||
if pred_cr >= data["seq_max_coverage_rate"] - 1e-3:
|
if pred_cr >= data["seq_max_coverage_rate"] - 1e-3:
|
||||||
print("max coverage rate reached!: ", pred_cr)
|
print("max coverage rate reached!: ", pred_cr)
|
||||||
|
@ -24,6 +24,8 @@ class DataLoadUtil:
|
|||||||
for channel in float_channels:
|
for channel in float_channels:
|
||||||
channel_data = exr_file.channel(channel)
|
channel_data = exr_file.channel(channel)
|
||||||
img_data.append(np.frombuffer(channel_data, dtype=np.float16).reshape((height, width)))
|
img_data.append(np.frombuffer(channel_data, dtype=np.float16).reshape((height, width)))
|
||||||
|
|
||||||
|
# 将各通道组合成一个 (height, width, 3) 的 RGB 图像
|
||||||
img = np.stack(img_data, axis=-1)
|
img = np.stack(img_data, axis=-1)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
11
utils/pts.py
11
utils/pts.py
@ -17,17 +17,6 @@ class PtsUtil:
|
|||||||
unique_voxels = np.unique(voxel_indices, axis=0, return_inverse=True)
|
unique_voxels = np.unique(voxel_indices, axis=0, return_inverse=True)
|
||||||
return unique_voxels[0]*voxel_size
|
return unique_voxels[0]*voxel_size
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def voxel_downsample_point_cloud_random(point_cloud, voxel_size=0.005, require_idx=False):
|
|
||||||
voxel_indices = np.floor(point_cloud / voxel_size).astype(np.int32)
|
|
||||||
unique_voxels, inverse, counts = np.unique(voxel_indices, axis=0, return_inverse=True, return_counts=True)
|
|
||||||
idx_sort = np.argsort(inverse)
|
|
||||||
idx_unique = idx_sort[np.cumsum(counts)-counts]
|
|
||||||
downsampled_points = point_cloud[idx_unique]
|
|
||||||
if require_idx:
|
|
||||||
return downsampled_points, inverse
|
|
||||||
return downsampled_points
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def random_downsample_point_cloud(point_cloud, num_points, require_idx=False):
|
def random_downsample_point_cloud(point_cloud, num_points, require_idx=False):
|
||||||
if point_cloud.shape[0] == 0:
|
if point_cloud.shape[0] == 0:
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
import shutil
|
import shutil
|
||||||
@ -69,13 +68,9 @@ class RenderUtil:
|
|||||||
params_data_path = os.path.join(temp_dir, "params.json")
|
params_data_path = os.path.join(temp_dir, "params.json")
|
||||||
with open(params_data_path, 'w') as f:
|
with open(params_data_path, 'w') as f:
|
||||||
json.dump(params, f)
|
json.dump(params, f)
|
||||||
start_time = time.time()
|
|
||||||
result = subprocess.run([
|
result = subprocess.run([
|
||||||
'blender', '-b', '-P', script_path, '--', temp_dir
|
'blender', '-b', '-P', script_path, '--', temp_dir
|
||||||
], capture_output=True, text=True)
|
], capture_output=True, text=True)
|
||||||
end_time = time.time()
|
|
||||||
print(result)
|
|
||||||
print(f"-- Time taken for blender: {end_time - start_time} seconds")
|
|
||||||
if result.returncode != 0:
|
if result.returncode != 0:
|
||||||
print("Blender script failed:")
|
print("Blender script failed:")
|
||||||
print(result.stderr)
|
print(result.stderr)
|
||||||
@ -87,7 +82,6 @@ class RenderUtil:
|
|||||||
cam_info["far_plane"],
|
cam_info["far_plane"],
|
||||||
binocular=True
|
binocular=True
|
||||||
)
|
)
|
||||||
start_time = time.time()
|
|
||||||
mask_L, mask_R = DataLoadUtil.load_seg(path, binocular=True)
|
mask_L, mask_R = DataLoadUtil.load_seg(path, binocular=True)
|
||||||
normal_L = DataLoadUtil.load_normal(path, binocular=True, left_only=True)
|
normal_L = DataLoadUtil.load_normal(path, binocular=True, left_only=True)
|
||||||
''' target points '''
|
''' target points '''
|
||||||
@ -120,7 +114,6 @@ class RenderUtil:
|
|||||||
if not has_points:
|
if not has_points:
|
||||||
target_points = np.zeros((0, 3))
|
target_points = np.zeros((0, 3))
|
||||||
target_normals = np.zeros((0, 3))
|
target_normals = np.zeros((0, 3))
|
||||||
end_time = time.time()
|
|
||||||
print(f"-- Time taken for processing: {end_time - start_time} seconds")
|
|
||||||
#import ipdb; ipdb.set_trace()
|
#import ipdb; ipdb.set_trace()
|
||||||
return target_points, target_normals
|
return target_points, target_normals
|
Loading…
x
Reference in New Issue
Block a user