deploy pointnet++ again
This commit is contained in:
@@ -5,9 +5,10 @@ import sys
|
||||
path = os.path.abspath(__file__)
|
||||
for i in range(2):
|
||||
path = os.path.dirname(path)
|
||||
PROJECT_ROOT = path
|
||||
PROJECT_ROOT = path
|
||||
sys.path.append(PROJECT_ROOT)
|
||||
from modules.module_lib.pointnet2_utils.pointnet2.pointnet2_modules import PointnetSAModuleMSG
|
||||
from modules.module_lib.pointnet2_modules import PointnetSAModuleMSG
|
||||
|
||||
|
||||
ClsMSG_CFG_Dense = {
|
||||
'NPOINTS': [512, 256, 128, None],
|
||||
@@ -72,7 +73,7 @@ class PointNet2Encoder(nn.Module):
|
||||
def __init__(self, config:dict):
|
||||
super().__init__()
|
||||
|
||||
input_channels = config.get("in_dim", 0)
|
||||
input_channels = config.get("in_dim", 3) - 3
|
||||
params_name = config.get("params_name", "light")
|
||||
|
||||
self.SA_modules = nn.ModuleList()
|
||||
@@ -112,7 +113,7 @@ if __name__ == '__main__':
|
||||
seed = 100
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
net = PointNet2Encoder(config={"in_dim": 0, "params_name": "light"}).cuda()
|
||||
net = PointNet2Encoder(config={"in_dim": 3, "params_name": "light"}).cuda()
|
||||
pts = torch.randn(2, 1024, 3).cuda()
|
||||
print(torch.mean(pts, dim=1))
|
||||
pre = net.encode_points(pts)
|
||||
|
Reference in New Issue
Block a user