train pointnet++

This commit is contained in:
hofee 2024-12-30 14:00:53 +00:00
parent 34548c64a3
commit 88d44f020e
3 changed files with 20 additions and 17 deletions

View File

@ -7,7 +7,7 @@ runner:
parallel: False parallel: False
experiment: experiment:
name: train_ab_global_only_pointnet++ name: train_ab_global_only_with_accept_probability
root_dir: "experiments" root_dir: "experiments"
use_checkpoint: False use_checkpoint: False
epoch: -1 # -1 stands for last epoch epoch: -1 # -1 stands for last epoch
@ -80,7 +80,7 @@ dataset:
pipeline: pipeline:
nbv_reconstruction_pipeline: nbv_reconstruction_pipeline:
modules: modules:
pts_encoder: pointnet++_encoder pts_encoder: pointnet_encoder
seq_encoder: transformer_seq_encoder seq_encoder: transformer_seq_encoder
pose_encoder: pose_encoder pose_encoder: pose_encoder
view_finder: gf_view_finder view_finder: gf_view_finder

View File

@ -4,6 +4,7 @@ import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype import PytorchBoot.stereotype as stereotype
from PytorchBoot.config import ConfigManager from PytorchBoot.config import ConfigManager
from PytorchBoot.utils.log_util import Log from PytorchBoot.utils.log_util import Log
import torch import torch
import os import os
import sys import sys
@ -50,7 +51,7 @@ class NBVReconstructionDataset(BaseDataset):
scene_name_list.append(scene_name) scene_name_list.append(scene_name)
return scene_name_list return scene_name_list
def get_datalist(self): def get_datalist(self, bias=False):
datalist = [] datalist = []
for scene_name in self.scene_name_list: for scene_name in self.scene_name_list:
seq_num = DataLoadUtil.get_label_num(self.root_dir, scene_name) seq_num = DataLoadUtil.get_label_num(self.root_dir, scene_name)
@ -79,16 +80,18 @@ class NBVReconstructionDataset(BaseDataset):
for data_pair in label_data["data_pairs"]: for data_pair in label_data["data_pairs"]:
scanned_views = data_pair[0] scanned_views = data_pair[0]
next_best_view = data_pair[1] next_best_view = data_pair[1]
datalist.append( accept_probability = scanned_views[-1][1]
{ if accept_probability > np.random.rand():
"scanned_views": scanned_views, datalist.append(
"next_best_view": next_best_view, {
"seq_max_coverage_rate": max_coverage_rate, "scanned_views": scanned_views,
"scene_name": scene_name, "next_best_view": next_best_view,
"label_idx": seq_idx, "seq_max_coverage_rate": max_coverage_rate,
"scene_max_coverage_rate": scene_max_coverage_rate, "scene_name": scene_name,
} "label_idx": seq_idx,
) "scene_max_coverage_rate": scene_max_coverage_rate,
}
)
return datalist return datalist
def preprocess_cache(self): def preprocess_cache(self):
@ -227,9 +230,10 @@ if __name__ == "__main__":
torch.manual_seed(seed) torch.manual_seed(seed)
np.random.seed(seed) np.random.seed(seed)
config = { config = {
"root_dir": "/data/hofee/data/packed_preprocessed_data", "root_dir": "/data/hofee/data/new_full_data",
"model_dir": "../data/scaled_object_meshes",
"source": "nbv_reconstruction_dataset", "source": "nbv_reconstruction_dataset",
"split_file": "/data/hofee/data/OmniObject3d_train.txt", "split_file": "/data/hofee/data/new_full_data_list/OmniObject3d_train.txt",
"load_from_preprocess": True, "load_from_preprocess": True,
"ratio": 0.5, "ratio": 0.5,
"batch_size": 2, "batch_size": 2,

View File

@ -75,11 +75,10 @@ class PointNet2Encoder(nn.Module):
def __init__(self, config:dict): def __init__(self, config:dict):
super().__init__() super().__init__()
input_channels = config.get("in_dim", 3) - 3 channel_in = config.get("in_dim", 3) - 3
params_name = config.get("params_name", "light") params_name = config.get("params_name", "light")
self.SA_modules = nn.ModuleList() self.SA_modules = nn.ModuleList()
channel_in = input_channels
selected_params = select_params(params_name) selected_params = select_params(params_name)
for k in range(selected_params['NPOINTS'].__len__()): for k in range(selected_params['NPOINTS'].__len__()):
mlps = selected_params['MLPS'][k].copy() mlps = selected_params['MLPS'][k].copy()