From 030bf551929631dabdeb78502a586ec029391306 Mon Sep 17 00:00:00 2001
From: hofee
Date: Wed, 25 Sep 2024 09:31:22 +0000
Subject: [PATCH] add global_pts_pipeline and pose_seq_encoder

---
 app_train.py                                  |  2 +-
 ....yaml => server_split_dataset_config.yaml} |  0
 ...l => server_strategy_generate_config.yaml} |  0
 ...n_config.yaml => server_train_config.yaml} | 28 ++++--
 ....yaml => server_view_generate_config.yaml} |  0
 core/global_pts_pipeline.py                   | 95 +++++++++++++++++++
 core/{pipeline.py => local_pts_pipeline.py}   |  8 +-
 core/nbv_dataset.py                           |  3 +-
 modules/transformer_pose_seq_encoder.py       | 63 ++++++++++++
 9 files changed, 186 insertions(+), 13 deletions(-)
 rename configs/server/{split_dataset_config.yaml => server_split_dataset_config.yaml} (100%)
 rename configs/server/{strategy_generate_config.yaml => server_strategy_generate_config.yaml} (100%)
 rename configs/server/{train_config.yaml => server_train_config.yaml} (76%)
 rename configs/server/{view_generate_config.yaml => server_view_generate_config.yaml} (100%)
 create mode 100644 core/global_pts_pipeline.py
 rename core/{pipeline.py => local_pts_pipeline.py} (95%)
 create mode 100644 modules/transformer_pose_seq_encoder.py

diff --git a/app_train.py b/app_train.py
index 071398f..191853d 100644
--- a/app_train.py
+++ b/app_train.py
@@ -5,4 +5,4 @@ from PytorchBoot.runners.trainer import DefaultTrainer
 class TrainApp:
     @staticmethod
     def start():
-        DefaultTrainer("configs/server/train_config.yaml").run()
\ No newline at end of file
+        DefaultTrainer("configs/server/server_train_config.yaml").run()
\ No newline at end of file
diff --git a/configs/server/split_dataset_config.yaml b/configs/server/server_split_dataset_config.yaml
similarity index 100%
rename from configs/server/split_dataset_config.yaml
rename to configs/server/server_split_dataset_config.yaml
diff --git a/configs/server/strategy_generate_config.yaml b/configs/server/server_strategy_generate_config.yaml
similarity index 100%
rename from configs/server/strategy_generate_config.yaml
rename to configs/server/server_strategy_generate_config.yaml
diff --git a/configs/server/train_config.yaml b/configs/server/server_train_config.yaml
similarity index 76%
rename from configs/server/train_config.yaml
rename to configs/server/server_train_config.yaml
index d3d70bd..662c017 100644
--- a/configs/server/train_config.yaml
+++ b/configs/server/server_train_config.yaml
@@ -3,11 +3,11 @@ runner:
   general:
     seed: 0
     device: cuda
-    cuda_visible_devices: "1"
+    cuda_visible_devices: "0"
     parallel: False
 
   experiment:
-    name: overfit_w_global_feat
+    name: overfit_w_global_feat_wo_local_pts_feat_small
     root_dir: "experiments"
     use_checkpoint: False
     epoch: -1 # -1 stands for last epoch
@@ -28,7 +28,7 @@ runner:
       #- OmniObject3d_test
       - OmniObject3d_val
 
-    pipeline: nbv_reconstruction_pipeline
+    pipeline: nbv_reconstruction_global_pts_pipeline
 
 dataset:
   OmniObject3d_train:
@@ -70,7 +70,7 @@ dataset:
     filter_degree: 75
     eval_list:
       - pose_diff
-    ratio: 0.005
+    ratio: 1
     batch_size: 1
     num_workers: 12
     pts_num: 4096
@@ -78,7 +78,7 @@
 
 
 pipeline:
-  nbv_reconstruction_pipeline:
+  nbv_reconstruction_local_pts_pipeline:
     modules:
       pts_encoder: pointnet_encoder
       seq_encoder: transformer_seq_encoder
@@ -87,6 +87,15 @@
       eps: 1e-5
       global_scanned_feat: True
 
+  nbv_reconstruction_global_pts_pipeline:
+    modules:
+      pts_encoder: pointnet_encoder
+      pose_seq_encoder: transformer_pose_seq_encoder
+      pose_encoder: pose_encoder
+      view_finder: gf_view_finder
+      eps: 1e-5
+      global_scanned_feat: True
+
 
 module:
 
@@ -105,10 +114,17 @@ module:
     num_layers: 3
     output_dim: 2048
 
+  transformer_pose_seq_encoder:
+    pose_embed_dim: 256
+    num_heads: 4
+    ffn_dim: 256
+    num_layers: 3
+    output_dim: 1024
+
   gf_view_finder:
     t_feat_dim: 128
     pose_feat_dim: 256
-    main_feat_dim: 3072
+    main_feat_dim: 2048
     regression_head: Rx_Ry_and_T
     pose_mode: rot_matrix
     per_point_feature: False
diff --git a/configs/server/view_generate_config.yaml b/configs/server/server_view_generate_config.yaml
similarity index 100%
rename from configs/server/view_generate_config.yaml
rename to configs/server/server_view_generate_config.yaml
diff --git a/core/global_pts_pipeline.py b/core/global_pts_pipeline.py
new file mode 100644
index 0000000..ed2ea42
--- /dev/null
+++ b/core/global_pts_pipeline.py
@@ -0,0 +1,95 @@
+import torch
+from torch import nn
+import PytorchBoot.namespace as namespace
+import PytorchBoot.stereotype as stereotype
+from PytorchBoot.factory.component_factory import ComponentFactory
+from PytorchBoot.utils import Log
+
+
+@stereotype.pipeline("nbv_reconstruction_global_pts_pipeline")
+class NBVReconstructionGlobalPointsPipeline(nn.Module):
+    def __init__(self, config):
+        super(NBVReconstructionGlobalPointsPipeline, self).__init__()
+        self.config = config
+        self.module_config = config["modules"]
+        self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_encoder"])
+        self.pose_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_encoder"])
+        self.pose_seq_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pose_seq_encoder"])
+        self.view_finder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["view_finder"])
+        self.eps = float(self.config["eps"])
+        self.enable_global_scanned_feat = self.config["global_scanned_feat"]
+
+    def forward(self, data):
+        mode = data["mode"]
+
+        if mode == namespace.Mode.TRAIN:
+            return self.forward_train(data)
+        elif mode == namespace.Mode.TEST:
+            return self.forward_test(data)
+        else:
+            Log.error("Unknown mode: {}".format(mode), True)
+
+    def perturb_data(self, gt_delta_9d):
+        bs = gt_delta_9d.shape[0]
+        random_t = torch.rand(bs, device=gt_delta_9d.device) * (1. - self.eps) + self.eps
+        random_t = random_t.unsqueeze(-1)
+        mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t)
+        std = std.view(-1, 1)
+        z = torch.randn_like(gt_delta_9d)
+        perturbed_x = mu + z * std
+        target_score = - z * std / (std ** 2)
+        return perturbed_x, random_t, target_score, std
+
+    def forward_train(self, data):
+        main_feat = self.get_main_feat(data)
+        ''' perturb the ground-truth pose for score matching '''
+        best_to_world_pose_9d_batch = data["best_to_world_pose_9d"]
+        perturbed_x, random_t, target_score, std = self.perturb_data(best_to_world_pose_9d_batch)
+        input_data = {
+            "sampled_pose": perturbed_x,
+            "t": random_t,
+            "main_feat": main_feat,
+        }
+        estimated_score = self.view_finder(input_data)
+        output = {
+            "estimated_score": estimated_score,
+            "target_score": target_score,
+            "std": std
+        }
+        return output
+
+    def forward_test(self, data):
+        main_feat = self.get_main_feat(data)
+        estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(main_feat)
+        result = {
+            "pred_pose_9d": estimated_delta_rot_9d,
+            "in_process_sample": in_process_sample
+        }
+        return result
+
+
+    def get_main_feat(self, data):
+        scanned_n_to_world_pose_9d_batch = data['scanned_n_to_world_pose_9d']
+
+        device = next(self.parameters()).device
+
+        pts_feat_seq_list = []
+        pose_feat_seq_list = []
+
+        for scanned_n_to_world_pose_9d in scanned_n_to_world_pose_9d_batch:
+            scanned_n_to_world_pose_9d = scanned_n_to_world_pose_9d.to(device)
+            pose_feat_seq_list.append(self.pose_encoder.encode_pose(scanned_n_to_world_pose_9d))
+
+        main_feat = self.pose_seq_encoder.encode_sequence(pose_feat_seq_list)
+
+        if self.enable_global_scanned_feat:
+            combined_scanned_pts_batch = data['combined_scanned_pts']
+            global_scanned_feat = self.pts_encoder.encode_points(combined_scanned_pts_batch)
+            main_feat = torch.cat([main_feat, global_scanned_feat], dim=-1)
+
+
+        if torch.isnan(main_feat).any():
+            Log.error("nan in main_feat", True)
+
+        return main_feat
+
diff --git a/core/pipeline.py b/core/local_pts_pipeline.py
similarity index 95%
rename from core/pipeline.py
rename to core/local_pts_pipeline.py
index 8079706..8827dc9 100644
--- a/core/pipeline.py
+++ b/core/local_pts_pipeline.py
@@ -5,12 +5,10 @@ import PytorchBoot.stereotype as stereotype
 from PytorchBoot.factory.component_factory import ComponentFactory
 from PytorchBoot.utils import Log
 
-from utils.pts import PtsUtil
-
-@stereotype.pipeline("nbv_reconstruction_pipeline")
-class NBVReconstructionPipeline(nn.Module):
+@stereotype.pipeline("nbv_reconstruction_local_pts_pipeline")
+class NBVReconstructionLocalPointsPipeline(nn.Module):
     def __init__(self, config):
-        super(NBVReconstructionPipeline, self).__init__()
+        super(NBVReconstructionLocalPointsPipeline, self).__init__()
         self.config = config
         self.module_config = config["modules"]
         self.pts_encoder = ComponentFactory.create(namespace.Stereotype.MODULE, self.module_config["pts_encoder"])
diff --git a/core/nbv_dataset.py b/core/nbv_dataset.py
index ebe45da..269da83 100644
--- a/core/nbv_dataset.py
+++ b/core/nbv_dataset.py
@@ -34,7 +34,7 @@ class NBVReconstructionDataset(BaseDataset):
         self.model_dir = config["model_dir"]
         self.filter_degree = config["filter_degree"]
         if self.type == namespace.Mode.TRAIN:
-            scale_ratio = 10
+            scale_ratio = 100
             self.datalist = self.datalist*scale_ratio
         if self.cache:
             expr_root = ConfigManager.get("runner", "experiment", "root_dir")
@@ -83,6 +83,7 @@ class NBVReconstructionDataset(BaseDataset):
                 "label_idx": seq_idx,
                 "scene_max_coverage_rate": scene_max_coverage_rate
             })
+            break # TODO: for small version debug
         return datalist
 
     def preprocess_cache(self):
diff --git a/modules/transformer_pose_seq_encoder.py b/modules/transformer_pose_seq_encoder.py
new file mode 100644
index 0000000..926a0e4
--- /dev/null
+++ b/modules/transformer_pose_seq_encoder.py
@@ -0,0 +1,63 @@
+import torch
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
+import PytorchBoot.stereotype as stereotype
+
+
+@stereotype.module("transformer_pose_seq_encoder")
+class TransformerPoseSequenceEncoder(nn.Module):
+    def __init__(self, config):
+        super(TransformerPoseSequenceEncoder, self).__init__()
+        self.config = config
+        embed_dim = config["pose_embed_dim"]
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=config["num_heads"],
+            dim_feedforward=config["ffn_dim"],
+            batch_first=True,
+        )
+        self.transformer_encoder = nn.TransformerEncoder(
+            encoder_layer, num_layers=config["num_layers"]
+        )
+        self.fc = nn.Linear(embed_dim, config["output_dim"])
+
+    def encode_sequence(self, pose_embedding_list_batch):
+
+        lengths = []
+
+        for pose_embedding_list in pose_embedding_list_batch:
+            lengths.append(len(pose_embedding_list))
+
+        combined_tensor = pad_sequence(pose_embedding_list_batch, batch_first=True)  # Shape: [batch_size, max_seq_len, embed_dim]
+
+        max_len = max(lengths)
+        padding_mask = torch.tensor([([0] * length + [1] * (max_len - length)) for length in lengths], dtype=torch.bool).to(combined_tensor.device)
+
+        transformer_output = self.transformer_encoder(combined_tensor, src_key_padding_mask=padding_mask)
+        final_feature = transformer_output.mean(dim=1)
+        final_output = self.fc(final_feature)
+
+        return final_output
+
+
+if __name__ == "__main__":
+    config = {
+        "pose_embed_dim": 256,
+        "num_heads": 4,
+        "ffn_dim": 256,
+        "num_layers": 3,
+        "output_dim": 1024,
+    }
+
+    encoder = TransformerPoseSequenceEncoder(config)
+    seq_len = [5, 8, 9, 4]
+    batch_size = 4
+
+    pose_embedding_list_batch = [
+        torch.randn(seq_len[idx], config["pose_embed_dim"]) for idx in range(batch_size)
+    ]
+    output_feature = encoder.encode_sequence(
+        pose_embedding_list_batch
+    )
+    print("Encoded Feature:", output_feature)
+    print("Feature Shape:", output_feature.shape)
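
A note on the pooling in TransformerPoseSequenceEncoder.encode_sequence: transformer_output.mean(dim=1) averages over every position of the padded tensor, so outputs produced at padded steps leak into the feature of shorter sequences. Below is a minimal out-of-patch sketch of a length-aware alternative; masked_mean is a hypothetical helper, not part of this commit, and it assumes the same [batch_size, max_seq_len, embed_dim] layout and boolean padding mask (True = padded) built in encode_sequence above:

import torch

def masked_mean(transformer_output: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    # transformer_output: [batch_size, max_seq_len, embed_dim]
    # padding_mask: [batch_size, max_seq_len], True at padded positions
    valid = (~padding_mask).unsqueeze(-1).float()     # 1.0 at real steps, 0.0 at padding
    summed = (transformer_output * valid).sum(dim=1)  # padded steps contribute nothing to the sum
    counts = valid.sum(dim=1).clamp(min=1.0)          # per-sequence lengths; guards against div-by-zero
    return summed / counts                            # [batch_size, embed_dim]

Swapping this in for the plain mean keeps final_feature at [batch_size, embed_dim] while making the pooled feature independent of how much padding a given batch happens to contain.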
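
A quick shape check (not part of the commit) for the features this patch wires together: transformer_pose_seq_encoder is configured with output_dim: 1024 and gf_view_finder with main_feat_dim: 2048, so the concatenation in NBVReconstructionGlobalPointsPipeline.get_main_feat only lines up if pointnet_encoder's global feature is 1024-dim; that size is an assumption here, since the pointnet_encoder config falls outside this diff:

import torch

batch_size = 4
pose_seq_feat = torch.randn(batch_size, 1024)    # transformer_pose_seq_encoder output_dim (from the config above)
global_pts_feat = torch.randn(batch_size, 1024)  # assumed pointnet_encoder global feature size
main_feat = torch.cat([pose_seq_feat, global_pts_feat], dim=-1)
assert main_feat.shape == (batch_size, 2048)     # must equal gf_view_finder's main_feat_dim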