nbv_grasping/modules/view_finder/view_finder_factory.py

45 lines
1.9 KiB
Python
Raw Permalink Normal View History

2024-10-09 16:13:22 +00:00
from modules.view_finder.abstract_view_finder import ViewFinder
from modules.view_finder.gf_view_finder import GradientFieldViewFinder
class ViewFinderFactory:
    """Factory that builds a next-best-view finder from its configured name."""

    @staticmethod
    def create(name: str, config) -> "ViewFinder":
        """Instantiate the view finder selected by ``name``.

        Args:
            name: Identifier of the view finder to build. Only
                ``"gradient_field"`` is currently supported.
            config: Mapping with a ``"general"`` section and a
                ``"view_finder"`` section that is keyed by finder name.

        Returns:
            The constructed view finder.

        Raises:
            ValueError: If ``name`` is not a supported view finder.
            KeyError: If a required configuration entry is missing for a
                supported ``name``.
        """
        # Validate the name before touching the config: otherwise an unknown
        # name fails on the config lookup with an unrelated KeyError and the
        # descriptive ValueError below is never reached.
        if name != "gradient_field":
            raise ValueError(f"Unknown next-best-view finder name: {name}")

        general_config = config["general"]
        view_finder_config = config["view_finder"][name]
        return GradientFieldViewFinder(
            pose_mode=view_finder_config["pose_mode"],
            regression_head=view_finder_config["regression_head"],
            per_point_feature=general_config["per_point_feature"],
            sample_mode=view_finder_config["sample_mode"],
            # Optional entry: None is passed through when it is not configured.
            sampling_steps=view_finder_config.get("sampling_steps", None),
            sde_mode=view_finder_config["sde_mode"],
        )
''' ------------ Debug ------------ '''
if __name__ == "__main__":
    # Manual smoke test: build the gradient-field view finder from the local
    # training config and push one batch of random inputs through it.
    from configs.config import ConfigManager
    import torch

    # Choose the device once; fall back to CPU so the script still runs on
    # hosts without CUDA instead of failing on the first `.to("cuda:0")`.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # NOTE(review): the relative path presumably resolves against the current
    # working directory — confirm where this script is expected to be run from.
    ConfigManager.load_config_with('../../configs/local_train_config.yaml')
    ConfigManager.print_config()
    view_finder = ViewFinderFactory.create(name="gradient_field", config=ConfigManager.get("modules"))

    # Random stand-in inputs; all tensors share a leading dimension of 32.
    test_scene_feat = torch.rand(32, 1024).to(device)
    test_target_feat = torch.rand(32, 1024).to(device)
    test_pose = torch.rand(32, 6).to(device)
    test_t = torch.rand(32, 1).to(device)
    view_finder = view_finder.to(device)

    test_data = {
        'target_feat': test_target_feat,
        'scene_feat': test_scene_feat,
        'sampled_pose': test_pose,
        't': test_t,
    }

    # Call the view finder on the batch and report the output shape.
    score = view_finder(test_data)
    print(score.shape)

    # Query the next best view from the same features and report its shape.
    pose_6d = view_finder.next_best_view(scene_pts_feat=test_data["scene_feat"], target_pts_feat=test_data["target_feat"])
    print(pose_6d.shape)