diff --git a/configs/server/server_train_config.yaml b/configs/server/server_train_config.yaml index 198596d..a3d8bec 100644 --- a/configs/server/server_train_config.yaml +++ b/configs/server/server_train_config.yaml @@ -7,7 +7,7 @@ runner: parallel: False experiment: - name: train_ab_global_only_pointnet++ + name: train_ab_global_only_with_accept_probability root_dir: "experiments" use_checkpoint: False epoch: -1 # -1 stands for last epoch @@ -80,7 +80,7 @@ dataset: pipeline: nbv_reconstruction_pipeline: modules: - pts_encoder: pointnet++_encoder + pts_encoder: pointnet_encoder seq_encoder: transformer_seq_encoder pose_encoder: pose_encoder view_finder: gf_view_finder diff --git a/core/nbv_dataset.py b/core/nbv_dataset.py index ca9c0c7..5777602 100644 --- a/core/nbv_dataset.py +++ b/core/nbv_dataset.py @@ -4,6 +4,7 @@ import PytorchBoot.namespace as namespace import PytorchBoot.stereotype as stereotype from PytorchBoot.config import ConfigManager from PytorchBoot.utils.log_util import Log + import torch import os import sys @@ -50,7 +51,7 @@ class NBVReconstructionDataset(BaseDataset): scene_name_list.append(scene_name) return scene_name_list - def get_datalist(self): + def get_datalist(self, bias=False): datalist = [] for scene_name in self.scene_name_list: seq_num = DataLoadUtil.get_label_num(self.root_dir, scene_name) @@ -79,16 +80,18 @@ class NBVReconstructionDataset(BaseDataset): for data_pair in label_data["data_pairs"]: scanned_views = data_pair[0] next_best_view = data_pair[1] - datalist.append( - { - "scanned_views": scanned_views, - "next_best_view": next_best_view, - "seq_max_coverage_rate": max_coverage_rate, - "scene_name": scene_name, - "label_idx": seq_idx, - "scene_max_coverage_rate": scene_max_coverage_rate, - } - ) + accept_probability = scanned_views[-1][1] + if accept_probability > np.random.rand(): + datalist.append( + { + "scanned_views": scanned_views, + "next_best_view": next_best_view, + "seq_max_coverage_rate": max_coverage_rate, + "scene_name": scene_name, + "label_idx": seq_idx, + "scene_max_coverage_rate": scene_max_coverage_rate, + } + ) return datalist def preprocess_cache(self): @@ -227,9 +230,10 @@ if __name__ == "__main__": torch.manual_seed(seed) np.random.seed(seed) config = { - "root_dir": "/data/hofee/data/packed_preprocessed_data", + "root_dir": "/data/hofee/data/new_full_data", + "model_dir": "../data/scaled_object_meshes", "source": "nbv_reconstruction_dataset", - "split_file": "/data/hofee/data/OmniObject3d_train.txt", + "split_file": "/data/hofee/data/new_full_data_list/OmniObject3d_train.txt", "load_from_preprocess": True, "ratio": 0.5, "batch_size": 2, diff --git a/modules/pointnet++_encoder.py b/modules/pointnet++_encoder.py index e223319..c597fb5 100644 --- a/modules/pointnet++_encoder.py +++ b/modules/pointnet++_encoder.py @@ -75,11 +75,10 @@ class PointNet2Encoder(nn.Module): def __init__(self, config:dict): super().__init__() - input_channels = config.get("in_dim", 3) - 3 + channel_in = config.get("in_dim", 3) - 3 params_name = config.get("params_name", "light") self.SA_modules = nn.ModuleList() - channel_in = input_channels selected_params = select_params(params_name) for k in range(selected_params['NPOINTS'].__len__()): mlps = selected_params['MLPS'][k].copy()