From 34548c64a3f29da2fa5e703d093ac31116b23d61 Mon Sep 17 00:00:00 2001 From: hofee Date: Sat, 28 Dec 2024 19:50:22 +0000 Subject: [PATCH] deploy pointnet++ finished --- configs/server/server_train_config.yaml | 2 +- modules/pointnet++_encoder.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/configs/server/server_train_config.yaml b/configs/server/server_train_config.yaml index 14cfb4c..198596d 100644 --- a/configs/server/server_train_config.yaml +++ b/configs/server/server_train_config.yaml @@ -80,7 +80,7 @@ dataset: pipeline: nbv_reconstruction_pipeline: modules: - pts_encoder: pointnet_encoder + pts_encoder: pointnet++_encoder seq_encoder: transformer_seq_encoder pose_encoder: pose_encoder view_finder: gf_view_finder diff --git a/modules/pointnet++_encoder.py b/modules/pointnet++_encoder.py index 7c50e1a..e223319 100644 --- a/modules/pointnet++_encoder.py +++ b/modules/pointnet++_encoder.py @@ -7,6 +7,7 @@ for i in range(2): path = os.path.dirname(path) PROJECT_ROOT = path sys.path.append(PROJECT_ROOT) +import PytorchBoot.stereotype as stereotype from modules.module_lib.pointnet2_modules import PointnetSAModuleMSG @@ -66,8 +67,9 @@ def break_up_pc(pc): return xyz, features +@stereotype.module("pointnet++_encoder") class PointNet2Encoder(nn.Module): - def encode_points(self, pts): + def encode_points(self, pts, require_per_point_feat=False): return self.forward(pts) def __init__(self, config:dict):