# nbv_grasping/modules/pts_encoder/pts_encoder_factory.py
# 2024-10-09 16:13:22 +00:00
#
# 57 lines
# 2.1 KiB
# Python
# Executable File

import sys
import os
# Repository root: this file sits at <root>/modules/pts_encoder/<this file>,
# so strip three path components from its absolute location.
path = os.path.abspath(__file__)
for _ in range(3):
    path = os.path.dirname(path)
PROJECT_ROOT = path
# Make the absolute `modules.*` / `configs.*` imports resolvable when this file
# is run directly as a script. Guarded so that executing this module body more
# than once (script run plus package import, or a reload) adds no duplicates.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
from modules.pts_encoder.abstract_pts_encoder import PointsEncoder
from modules.pts_encoder.pointnet_encoder import PointNetEncoder
from modules.pts_encoder.pointnet2_encoder import PointNet2Encoder
from modules.pts_encoder.pointnet3_encoder import PointNet3Encoder
class PointsEncoderFactory:
    """Factory that maps an encoder name to a configured point-cloud encoder."""

    # Names accepted by ``create``. They are checked before any config lookup,
    # so an unsupported name reports ValueError rather than an incidental KeyError.
    SUPPORTED_NAMES = ("pointnet", "pointnet++", "pointnet++rgb")

    @staticmethod
    def create(name, config) -> "PointsEncoder":
        """Instantiate the point encoder registered under ``name``.

        Args:
            name: One of ``"pointnet"``, ``"pointnet++"`` or ``"pointnet++rgb"``.
            config: Mapping with a ``"general"`` section (``pts_channels``,
                ``feature_dim``, ``per_point_feature``) and, for the PointNet++
                variants, a ``"pts_encoder"`` section keyed by encoder name
                (``params_name``; plus ``target_layer`` and ``rgb_feat_dim``
                for ``"pointnet++rgb"``).

        Returns:
            The constructed encoder instance.

        Raises:
            ValueError: If ``name`` is not a supported encoder.
            KeyError: If a config entry required by the chosen encoder is missing.
        """
        if name not in PointsEncoderFactory.SUPPORTED_NAMES:
            raise ValueError(f"Unknown encoder name: {name}")

        general_config = config["general"]
        if name == "pointnet":
            # Plain PointNet reads only the shared "general" settings and does
            # not need a per-encoder section.
            return PointNetEncoder(
                in_dim=general_config["pts_channels"],
                out_dim=general_config["feature_dim"],
                global_feat=not general_config["per_point_feature"],
            )

        # The PointNet++ variants read their own section under "pts_encoder".
        pts_encoder_config = config["pts_encoder"][name]
        # NOTE(review): "- 3" presumably strips the xyz coordinates so that only
        # the extra per-point channels count as input_channels. Confirm.
        input_channels = general_config["pts_channels"] - 3
        if name == "pointnet++":
            return PointNet2Encoder(
                input_channels=input_channels,
                params_name=pts_encoder_config["params_name"],
            )
        return PointNet3Encoder(
            input_channels=input_channels,
            params_name=pts_encoder_config["params_name"],
            target_layer=pts_encoder_config["target_layer"],
            rgb_feat_dim=pts_encoder_config["rgb_feat_dim"],
        )
# ------------ Debug ------------
# Smoke test: build the PointNet++ encoder from the local training config and
# push a random batch through it on the GPU.
if __name__ == "__main__":
    from configs.config import ConfigManager
    import torch

    pts = torch.rand(32, 1200, 3)  # BxNxC
    # Resolve the config against the repository root instead of the current
    # working directory, so the script works regardless of where it is launched.
    ConfigManager.load_config_with(
        os.path.join(PROJECT_ROOT, "configs/local_train_config.yaml")
    )
    ConfigManager.print_config()
    pts_encoder = PointsEncoderFactory.create(
        name="pointnet++", config=ConfigManager.get("modules")
    )
    print(pts_encoder)
    pts = pts.to("cuda")
    pts_encoder = pts_encoder.to("cuda")
    # Inference only: skip autograd bookkeeping for this forward pass.
    with torch.no_grad():
        pts_feat = pts_encoder.encode_points(pts)
    print(pts_feat.shape)