deploy pointnet++ finished

This commit is contained in:
hofee 2024-12-28 19:50:22 +00:00
parent 47ea0ac434
commit 34548c64a3
2 changed files with 4 additions and 2 deletions

View File

@ -80,7 +80,7 @@ dataset:
pipeline: pipeline:
nbv_reconstruction_pipeline: nbv_reconstruction_pipeline:
modules: modules:
pts_encoder: pointnet_encoder pts_encoder: pointnet++_encoder
seq_encoder: transformer_seq_encoder seq_encoder: transformer_seq_encoder
pose_encoder: pose_encoder pose_encoder: pose_encoder
view_finder: gf_view_finder view_finder: gf_view_finder

View File

@ -7,6 +7,7 @@ for i in range(2):
path = os.path.dirname(path) path = os.path.dirname(path)
PROJECT_ROOT = path PROJECT_ROOT = path
sys.path.append(PROJECT_ROOT) sys.path.append(PROJECT_ROOT)
import PytorchBoot.stereotype as stereotype
from modules.module_lib.pointnet2_modules import PointnetSAModuleMSG from modules.module_lib.pointnet2_modules import PointnetSAModuleMSG
@ -66,8 +67,9 @@ def break_up_pc(pc):
return xyz, features return xyz, features
@stereotype.module("pointnet++_encoder")
class PointNet2Encoder(nn.Module): class PointNet2Encoder(nn.Module):
def encode_points(self, pts): def encode_points(self, pts, require_per_point_feat=False):
return self.forward(pts) return self.forward(pts)
def __init__(self, config:dict): def __init__(self, config:dict):