57 lines
2.1 KiB
Python
Executable File
57 lines
2.1 KiB
Python
Executable File
import sys
|
|
import os
|
|
# Project root = the directory three levels above this file. It is appended
# to sys.path so the absolute project imports below can be resolved.
PROJECT_ROOT = os.path.dirname(
    os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))
    )
)
path = PROJECT_ROOT  # kept as a module-level alias of the resolved root
sys.path.append(PROJECT_ROOT)
|
|
|
|
from modules.pts_encoder.abstract_pts_encoder import PointsEncoder
|
|
from modules.pts_encoder.pointnet_encoder import PointNetEncoder
|
|
from modules.pts_encoder.pointnet2_encoder import PointNet2Encoder
|
|
from modules.pts_encoder.pointnet3_encoder import PointNet3Encoder
|
|
|
|
class PointsEncoderFactory:
    """Build a points encoder instance from its configured name."""

    @staticmethod
    def create(name, config) -> PointsEncoder:
        """Create the encoder registered under ``name``.

        Args:
            name: Encoder identifier; one of ``"pointnet"``, ``"pointnet++"``
                or ``"pointnet++rgb"``.
            config: Mapping with a ``"general"`` section (``"pts_channels"``
                is read for every encoder; ``"feature_dim"`` and
                ``"per_point_feature"`` for ``"pointnet"``) and, for the
                ``"pointnet++"`` variants, a ``"pts_encoder"`` section keyed
                by encoder name (``"params_name"``; additionally
                ``"target_layer"`` and ``"rgb_feat_dim"`` for
                ``"pointnet++rgb"``).

        Returns:
            The constructed encoder.

        Raises:
            ValueError: If ``name`` is not one of the supported identifiers.
            KeyError: If a configuration entry needed by the selected encoder
                is missing.
        """
        # Validate the name before touching the config, so an unsupported
        # name always yields ValueError instead of a KeyError from the
        # per-encoder lookup.
        if name not in ("pointnet", "pointnet++", "pointnet++rgb"):
            raise ValueError(f"Unknown encoder name: {name}")

        general_config = config["general"]

        if name == "pointnet":
            # This encoder is configured from the "general" section only, so
            # no config["pts_encoder"]["pointnet"] entry is required.
            return PointNetEncoder(
                in_dim=general_config["pts_channels"],
                out_dim=general_config["feature_dim"],
                global_feat=not general_config["per_point_feature"],
            )

        pts_encoder_config = config["pts_encoder"][name]
        # NOTE(review): the "- 3" presumably excludes the xyz coordinates from
        # the feature-channel count -- confirm against the encoder classes.
        input_channels = general_config["pts_channels"] - 3

        if name == "pointnet++":
            return PointNet2Encoder(
                input_channels=input_channels,
                params_name=pts_encoder_config["params_name"],
            )

        # Only "pointnet++rgb" can reach this point.
        return PointNet3Encoder(
            input_channels=input_channels,
            params_name=pts_encoder_config["params_name"],
            target_layer=pts_encoder_config["target_layer"],
            rgb_feat_dim=pts_encoder_config["rgb_feat_dim"],
        )
|
|
|
|
|
|
''' ------------ Debug ------------ '''
|
|
if __name__ == "__main__":
    # Manual smoke test: build the "pointnet++" encoder from the local
    # training config and run one random batch through it on the GPU.
    from configs.config import ConfigManager
    import torch

    device = "cuda"
    sample_pts = torch.rand(32, 1200, 3)  # BxNxC

    ConfigManager.load_config_with('configs/local_train_config.yaml')
    ConfigManager.print_config()

    modules_config = ConfigManager.get("modules")
    encoder = PointsEncoderFactory.create(name="pointnet++", config=modules_config)
    print(encoder)

    # Move the input first, then the model, before encoding.
    sample_pts = sample_pts.to(device)
    encoder = encoder.to(device)

    features = encoder.encode_points(sample_pts)
    print(features.shape)
|