deploy pointnet++ again

This commit is contained in:
2024-12-28 19:38:27 +00:00
parent 91cabec977
commit 47ea0ac434
29 changed files with 10 additions and 1720 deletions

View File

@@ -5,9 +5,10 @@ import sys
path = os.path.abspath(__file__)
for i in range(2):
path = os.path.dirname(path)
PROJECT_ROOT = path
PROJECT_ROOT = path
sys.path.append(PROJECT_ROOT)
from modules.module_lib.pointnet2_utils.pointnet2.pointnet2_modules import PointnetSAModuleMSG
from modules.module_lib.pointnet2_modules import PointnetSAModuleMSG
ClsMSG_CFG_Dense = {
'NPOINTS': [512, 256, 128, None],
@@ -72,7 +73,7 @@ class PointNet2Encoder(nn.Module):
def __init__(self, config:dict):
super().__init__()
input_channels = config.get("in_dim", 0)
input_channels = config.get("in_dim", 3) - 3
params_name = config.get("params_name", "light")
self.SA_modules = nn.ModuleList()
@@ -112,7 +113,7 @@ if __name__ == '__main__':
seed = 100
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
net = PointNet2Encoder(config={"in_dim": 0, "params_name": "light"}).cuda()
net = PointNet2Encoder(config={"in_dim": 3, "params_name": "light"}).cuda()
pts = torch.randn(2, 1024, 3).cuda()
print(torch.mean(pts, dim=1))
pre = net.encode_points(pts)