update gf_view_finder

This commit is contained in:
hofee 2024-09-02 23:47:52 +08:00
parent 2fcfcd1966
commit e0fb9a7617
7 changed files with 78 additions and 31 deletions

View File

@ -13,13 +13,13 @@ runner:
generate: generate:
voxel_threshold: 0.005 voxel_threshold: 0.005
overlap_threshold: 0.5 overlap_threshold: 0.5
save_points: True save_points: False
dataset_list: dataset_list:
- OmniObject3d - OmniObject3d
datasets: datasets:
OmniObject3d: OmniObject3d:
model_dir: "/media/hofee/data/data/scaled_object_meshes" model_dir: "H:\\AI\\Datasets\\scaled_object_meshes"
root_dir: "/media/hofee/data/data/nbv_rec/sample" root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample"

View File

@ -88,8 +88,8 @@ def cond_ode_sampler(
x = mean_x x = mean_x
num_steps = xs.shape[0] num_steps = xs.shape[0]
xs = xs.reshape(batch_size * num_steps, -1) xs = xs.reshape(batch_size*num_steps, -1)
xs = PoseUtil.normalize_rotation(xs, pose_mode) xs[:, :-3] = PoseUtil.normalize_rotation(xs[:, :-3], pose_mode)
xs = xs.reshape(num_steps, batch_size, -1) xs = xs.reshape(num_steps, batch_size, -1)
x = PoseUtil.normalize_rotation(x, pose_mode) x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode)
return xs.permute(1, 0, 2), x return xs.permute(1, 0, 2), x

View File

@ -2,6 +2,9 @@ import torch
import torch.nn as nn import torch.nn as nn
import PytorchBoot.stereotype as stereotype import PytorchBoot.stereotype as stereotype
import sys
sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction")
from utils.pose import PoseUtil from utils.pose import PoseUtil
import modules.module_lib as mlib import modules.module_lib as mlib
import modules.func_lib as flib import modules.func_lib as flib
@ -47,7 +50,7 @@ class GradientFieldViewFinder(nn.Module):
) )
''' fusion tail ''' ''' fusion tail '''
if self.regression_head == 'Rx_Ry': if self.regression_head == 'Rx_Ry_and_T':
if self.pose_mode != 'rot_matrix': if self.pose_mode != 'rot_matrix':
raise NotImplementedError raise NotImplementedError
if not self.per_point_feature: if not self.per_point_feature:
@ -62,6 +65,12 @@ class GradientFieldViewFinder(nn.Module):
self.act, self.act,
zero_module(nn.Linear(256, 3)), zero_module(nn.Linear(256, 3)),
) )
''' translation regression head '''
self.fusion_tail_trans = nn.Sequential(
nn.Linear(128 + 256 + 2048, 256),
self.act,
zero_module(nn.Linear(256, 3)),
)
else: else:
raise NotImplementedError raise NotImplementedError
else: else:
@ -89,10 +98,11 @@ class GradientFieldViewFinder(nn.Module):
total_feat = torch.cat([seq_feat, t_feat, pose_feat], dim=-1) total_feat = torch.cat([seq_feat, t_feat, pose_feat], dim=-1)
_, std = self.marginal_prob_fn(total_feat, t) _, std = self.marginal_prob_fn(total_feat, t)
if self.regression_head == 'Rx_Ry': if self.regression_head == 'Rx_Ry_and_T':
rot_x = self.fusion_tail_rot_x(total_feat) rot_x = self.fusion_tail_rot_x(total_feat)
rot_y = self.fusion_tail_rot_y(total_feat) rot_y = self.fusion_tail_rot_y(total_feat)
out_score = torch.cat([rot_x, rot_y], dim=-1) / (std + 1e-7) # normalisation trans = self.fusion_tail_trans(total_feat)
out_score = torch.cat([rot_x, rot_y, trans], dim=-1) / (std+1e-7) # normalisation
else: else:
raise NotImplementedError raise NotImplementedError
@ -134,18 +144,24 @@ class GradientFieldViewFinder(nn.Module):
''' ----------- DEBUG -----------''' ''' ----------- DEBUG -----------'''
if __name__ == "__main__": if __name__ == "__main__":
test_scene_feat = torch.rand(32, 1024).to("cuda:0") config = {
test_target_feat = torch.rand(32, 1024).to("cuda:0") "regression_head": "Rx_Ry_and_T",
test_pose = torch.rand(32, 6).to("cuda:0") "per_point_feature": False,
"pose_mode": "rot_matrix",
"sde_mode": "ve",
"sampling_steps": 500,
"sample_mode": "ode"
}
test_seq_feat = torch.rand(32, 2048).to("cuda:0")
test_pose = torch.rand(32, 9).to("cuda:0")
test_t = torch.rand(32, 1).to("cuda:0") test_t = torch.rand(32, 1).to("cuda:0")
view_finder = GradientFieldViewFinder().to("cuda:0") view_finder = GradientFieldViewFinder(config).to("cuda:0")
test_data = { test_data = {
'target_feat': test_target_feat, 'seq_feat': test_seq_feat,
'scene_feat': test_scene_feat,
'sampled_pose': test_pose, 'sampled_pose': test_pose,
't': test_t 't': test_t
} }
score = view_finder(test_data) score = view_finder(test_data)
print(score.shape)
result = view_finder.next_best_view(test_scene_feat, test_target_feat) res, inprocess = view_finder.next_best_view(test_seq_feat)
print(result) print(res.shape, inprocess.shape)

View File

@ -4,9 +4,9 @@ from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager from PytorchBoot.config import ConfigManager
from PytorchBoot.utils import Log from PytorchBoot.utils import Log
import PytorchBoot.stereotype as stereotype import PytorchBoot.stereotype as stereotype
from PytorchBoot.status import status_manager
@stereotype.runner("data_splitor")
@stereotype.runner("data_splitor", comment="unfinished")
class DataSplitor(Runner): class DataSplitor(Runner):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@ -23,15 +23,17 @@ class DataSplitor(Runner):
random.shuffle(self.datapath_list) random.shuffle(self.datapath_list)
start_idx = 0 start_idx = 0
for dataset in self.datasets: for dataset_idx in range(len(self.datasets)):
dataset = list(self.datasets.keys())[dataset_idx]
ratio = self.datasets[dataset]["ratio"] ratio = self.datasets[dataset]["ratio"]
path = self.datasets[dataset]["path"] path = self.datasets[dataset]["path"]
split_size = int(len(self.datapath_list) * ratio) split_size = int(len(self.datapath_list) * ratio)
split_files = self.datapath_list[start_idx:start_idx + split_size] split_files = self.datapath_list[start_idx:start_idx + split_size]
start_idx += split_size start_idx += split_size
self.save_split_files(path, split_files) self.save_split_files(path, split_files)
status_manager.set_progress("split", "data_splitor", "split dataset", dataset_idx, len(self.datasets))
Log.success(f"save {dataset} split files to {path}") Log.success(f"save {dataset} split files to {path}")
status_manager.set_progress("split", "data_splitor", "split dataset", len(self.datasets), len(self.datasets))
def save_split_files(self, path, split_files): def save_split_files(self, path, split_files):
os.makedirs(os.path.dirname(path), exist_ok=True) os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w") as f: with open(path, "w") as f:

View File

@ -6,6 +6,7 @@ from PytorchBoot.runners.runner import Runner
from PytorchBoot.config import ConfigManager from PytorchBoot.config import ConfigManager
from PytorchBoot.utils import Log from PytorchBoot.utils import Log
import PytorchBoot.stereotype as stereotype import PytorchBoot.stereotype as stereotype
from PytorchBoot.status import status_manager
from utils.data_load import DataLoadUtil from utils.data_load import DataLoadUtil
from utils.reconstruction import ReconstructionUtil from utils.reconstruction import ReconstructionUtil
@ -16,12 +17,19 @@ class StrategyGenerator(Runner):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
self.load_experiment("generate") self.load_experiment("generate")
self.status_info = {
"status_manager": status_manager,
"app_name": "generate",
"runner_name": "strategy_generator"
}
def run(self): def run(self):
dataset_name_list = ConfigManager.get("runner", "generate", "dataset_list") dataset_name_list = ConfigManager.get("runner", "generate", "dataset_list")
voxel_threshold, overlap_threshold = ConfigManager.get("runner","generate","voxel_threshold"), ConfigManager.get("runner","generate","overlap_threshold") voxel_threshold, overlap_threshold = ConfigManager.get("runner","generate","voxel_threshold"), ConfigManager.get("runner","generate","overlap_threshold")
self.save_pts = ConfigManager.get("runner","generate","save_points") self.save_pts = ConfigManager.get("runner","generate","save_points")
for dataset_name in dataset_name_list: for dataset_idx in range(len(dataset_name_list)):
dataset_name = dataset_name_list[dataset_idx]
status_manager.set_progress("generate", "strategy_generator", "dataset", dataset_idx, len(dataset_name_list))
root_dir = ConfigManager.get("datasets", dataset_name, "root_dir") root_dir = ConfigManager.get("datasets", dataset_name, "root_dir")
model_dir = ConfigManager.get("datasets", dataset_name, "model_dir") model_dir = ConfigManager.get("datasets", dataset_name, "model_dir")
scene_name_list = os.listdir(root_dir) scene_name_list = os.listdir(root_dir)
@ -29,8 +37,12 @@ class StrategyGenerator(Runner):
total = len(scene_name_list) total = len(scene_name_list)
for scene_name in scene_name_list: for scene_name in scene_name_list:
Log.info(f"({dataset_name})Processing [{cnt}/{total}]: {scene_name}") Log.info(f"({dataset_name})Processing [{cnt}/{total}]: {scene_name}")
status_manager.set_progress("generate", "strategy_generator", "scene", cnt, total)
self.generate_sequence(root_dir, model_dir, scene_name,voxel_threshold, overlap_threshold) self.generate_sequence(root_dir, model_dir, scene_name,voxel_threshold, overlap_threshold)
cnt += 1 cnt += 1
status_manager.set_progress("generate", "strategy_generator", "scene", total, total)
status_manager.set_progress("generate", "strategy_generator", "dataset", len(dataset_name_list), len(dataset_name_list))
def create_experiment(self, backup_name=None): def create_experiment(self, backup_name=None):
super().create_experiment(backup_name) super().create_experiment(backup_name)
@ -41,6 +53,7 @@ class StrategyGenerator(Runner):
super().load_experiment(backup_name) super().load_experiment(backup_name)
def generate_sequence(self, root, model_dir, scene_name, voxel_threshold, overlap_threshold): def generate_sequence(self, root, model_dir, scene_name, voxel_threshold, overlap_threshold):
status_manager.set_status("generate", "strategy_generator", "scene", scene_name)
frame_num = DataLoadUtil.get_scene_seq_length(root, scene_name) frame_num = DataLoadUtil.get_scene_seq_length(root, scene_name)
model_pts = DataLoadUtil.load_original_model_points(model_dir, scene_name) model_pts = DataLoadUtil.load_original_model_points(model_dir, scene_name)
down_sampled_model_pts = PtsUtil.voxel_downsample_point_cloud(model_pts, voxel_threshold) down_sampled_model_pts = PtsUtil.voxel_downsample_point_cloud(model_pts, voxel_threshold)
@ -50,7 +63,7 @@ class StrategyGenerator(Runner):
for frame_idx in range(frame_num): for frame_idx in range(frame_num):
path = DataLoadUtil.get_path(root, scene_name, frame_idx) path = DataLoadUtil.get_path(root, scene_name, frame_idx)
status_manager.set_progress("generate", "strategy_generator", "loading frame", frame_idx, frame_num)
point_cloud = DataLoadUtil.get_point_cloud_world_from_path(path) point_cloud = DataLoadUtil.get_point_cloud_world_from_path(path)
sampled_point_cloud = PtsUtil.voxel_downsample_point_cloud(point_cloud, voxel_threshold) sampled_point_cloud = PtsUtil.voxel_downsample_point_cloud(point_cloud, voxel_threshold)
if self.save_pts: if self.save_pts:
@ -59,13 +72,17 @@ class StrategyGenerator(Runner):
os.makedirs(pts_dir) os.makedirs(pts_dir)
np.savetxt(os.path.join(pts_dir, f"{frame_idx}.txt"), sampled_point_cloud) np.savetxt(os.path.join(pts_dir, f"{frame_idx}.txt"), sampled_point_cloud)
pts_list.append(sampled_point_cloud) pts_list.append(sampled_point_cloud)
limited_useful_view, _ = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(down_sampled_transformed_model_pts, pts_list, threshold=voxel_threshold, overlap_threshold=overlap_threshold) status_manager.set_progress("generate", "strategy_generator", "loading frame", frame_num, frame_num)
limited_useful_view, _ = ReconstructionUtil.compute_next_best_view_sequence_with_overlap(down_sampled_transformed_model_pts, pts_list, threshold=voxel_threshold, overlap_threshold=overlap_threshold, status_info=self.status_info)
data_pairs = self.generate_data_pairs(limited_useful_view) data_pairs = self.generate_data_pairs(limited_useful_view)
seq_save_data = { seq_save_data = {
"data_pairs": data_pairs, "data_pairs": data_pairs,
"best_sequence": limited_useful_view, "best_sequence": limited_useful_view,
"max_coverage_rate": limited_useful_view[-1][1] "max_coverage_rate": limited_useful_view[-1][1]
} }
status_manager.set_status("generate", "strategy_generator", "max_coverage_rate", limited_useful_view[-1][1])
Log.success(f"Scene <{scene_name}> Finished, Max Coverage Rate: {limited_useful_view[-1][1]}, Best Sequence length: {len(limited_useful_view)}") Log.success(f"Scene <{scene_name}> Finished, Max Coverage Rate: {limited_useful_view[-1][1]}, Best Sequence length: {len(limited_useful_view)}")
output_label_path = DataLoadUtil.get_label_path(root, scene_name) output_label_path = DataLoadUtil.get_label_path(root, scene_name)

View File

@ -184,11 +184,11 @@ class PoseUtil:
], f"the rotation mode {rot_mode} is not supported!" ], f"the rotation mode {rot_mode} is not supported!"
if rot_mode == "quat_wxyz" or rot_mode == "quat_xyzw": if rot_mode == "quat_wxyz" or rot_mode == "quat_xyzw":
pose_dim = 4 pose_dim = 7
elif rot_mode == "euler_xyz": elif rot_mode == "euler_xyz":
pose_dim = 3
elif rot_mode == "euler_xyz_sx_cx" or rot_mode == "rot_matrix":
pose_dim = 6 pose_dim = 6
elif rot_mode == "euler_xyz_sx_cx" or rot_mode == "rot_matrix":
pose_dim = 9
else: else:
raise NotImplementedError raise NotImplementedError
return pose_dim return pose_dim

View File

@ -45,12 +45,12 @@ class ReconstructionUtil:
@staticmethod @staticmethod
def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list, threshold=0.01, overlap_threshold=0.3): def compute_next_best_view_sequence_with_overlap(target_point_cloud, point_cloud_list, threshold=0.01, overlap_threshold=0.3, status_info=None):
selected_views = [] selected_views = []
current_coverage = 0.0 current_coverage = 0.0
remaining_views = list(range(len(point_cloud_list))) remaining_views = list(range(len(point_cloud_list)))
view_sequence = [] view_sequence = []
cnt_processed_view = 0
while remaining_views: while remaining_views:
best_view = None best_view = None
best_coverage_increase = -1 best_coverage_increase = -1
@ -74,6 +74,14 @@ class ReconstructionUtil:
if coverage_increase > best_coverage_increase: if coverage_increase > best_coverage_increase:
best_coverage_increase = coverage_increase best_coverage_increase = coverage_increase
best_view = view_index best_view = view_index
cnt_processed_view += 1
if status_info is not None:
sm = status_info["status_manager"]
app_name = status_info["app_name"]
runner_name = status_info["runner_name"]
sm.set_status(app_name, runner_name, "current coverage", current_coverage)
sm.set_progress(app_name, runner_name, "processed view", cnt_processed_view, len(point_cloud_list))
if best_view is not None: if best_view is not None:
if best_coverage_increase <=1e-3: if best_coverage_increase <=1e-3:
@ -87,7 +95,11 @@ class ReconstructionUtil:
else: else:
break break
if status_info is not None:
sm = status_info["status_manager"]
app_name = status_info["app_name"]
runner_name = status_info["runner_name"]
sm.set_progress(app_name, runner_name, "processed view", len(point_cloud_list), len(point_cloud_list))
return view_sequence, remaining_views return view_sequence, remaining_views