From b98753bfbbee6a1a32a6f2cdf4dbfd9c214f6164 Mon Sep 17 00:00:00 2001 From: hofee Date: Tue, 13 May 2025 09:03:38 +0800 Subject: [PATCH] first commit --- .gitignore | 2 + Readme.md | 192 ++++++++ app_generate_strategy.py | 9 + app_generate_view.py | 16 + app_inference.py | 75 +++ app_sim.py | 11 + app_split.py | 9 + app_train.py | 8 + .../__pycache__/predict_result.cpython-39.pyc | Bin 0 -> 6306 bytes beans/predict_result.py | 162 +++++++ .../local/global_only_inference_config.yaml | 115 +++++ ...obal_points_and_pose_inference_config.yaml | 115 +++++ .../local/global_pts_and_local_pts_pose.yaml | 117 +++++ .../local/local_only_inference_config.yaml | 129 +++++ configs/local/simulation_config.yaml | 36 ++ configs/local/split_dataset_config.yaml | 22 + configs/local/strategy_generate_config.yaml | 27 ++ configs/local/train_config.yaml | 105 ++++ .../uncertainty_guide_evaluation_config.yaml | 130 +++++ configs/local/view_generate_config.yaml | 52 ++ configs/server/server_inference_config.yaml | 92 ++++ .../server_inference_server_config.yaml | 53 ++ .../server/server_split_dataset_config.yaml | 22 + configs/server/server_train_config.yaml | 167 +++++++ ...ab_global_only_pts_pipeline.cpython-39.pyc | Bin 0 -> 3220 bytes .../ab_local_only_pts_pipeline.cpython-39.pyc | Bin 0 -> 3584 bytes .../ab_mlp_pipeline.cpython-39.pyc | Bin 0 -> 2933 bytes core/__pycache__/evaluation.cpython-39.pyc | Bin 0 -> 4258 bytes .../global_pts_pipeline.cpython-39.pyc | Bin 0 -> 3506 bytes .../local_pts_pipeline.cpython-39.pyc | Bin 0 -> 3655 bytes core/__pycache__/loss.cpython-39.pyc | Bin 0 -> 1367 bytes core/__pycache__/nbv_dataset.cpython-39.pyc | Bin 0 -> 7923 bytes .../old_seq_dataset.cpython-39.pyc | Bin 0 -> 5768 bytes core/__pycache__/pipeline.cpython-39.pyc | Bin 0 -> 4130 bytes core/__pycache__/seq_dataset.cpython-39.pyc | Bin 0 -> 6628 bytes .../seq_dataset_preprocessed.cpython-39.pyc | Bin 0 -> 2610 bytes core/ab_global_only_pts_pipeline.py | 85 ++++ core/ab_local_only_pts_pipeline.py | 95 ++++ core/ab_mlp_pipeline.py | 81 ++++ core/evaluation.py | 109 +++++ core/global_pts_pipeline.py | 98 ++++ core/local_pts_pipeline.py | 99 ++++ core/loss.py | 27 ++ core/nbv_dataset.py | 282 +++++++++++ core/old_seq_dataset.py | 154 ++++++ core/pipeline.py | 140 ++++++ core/seq_dataset.py | 209 ++++++++ core/seq_dataset_preprocessed.py | 82 ++++ .../__pycache__/gf_view_finder.cpython-39.pyc | Bin 0 -> 4128 bytes .../mlp_view_finder.cpython-39.pyc | Bin 0 -> 2513 bytes .../pointnet++_encoder.cpython-39.pyc | Bin 0 -> 3994 bytes .../pointnet_encoder.cpython-39.pyc | Bin 0 -> 3446 bytes .../__pycache__/pose_encoder.cpython-39.pyc | Bin 0 -> 961 bytes .../pts_num_encoder.cpython-39.pyc | Bin 0 -> 971 bytes .../transformer_seq_encoder.cpython-39.pyc | Bin 0 -> 2302 bytes modules/func_lib/__init__.py | 6 + .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 302 bytes .../__pycache__/samplers.cpython-39.pyc | Bin 0 -> 2731 bytes .../func_lib/__pycache__/sde.cpython-39.pyc | Bin 0 -> 3454 bytes modules/func_lib/samplers.py | 95 ++++ modules/func_lib/sde.py | 121 +++++ modules/gf_view_finder.py | 167 +++++++ modules/mlp_view_finder.py | 91 ++++ modules/module_lib/__init__.py | 2 + .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 337 bytes ...gaussian_fourier_projection.cpython-39.pyc | Bin 0 -> 1091 bytes .../__pycache__/linear.cpython-39.pyc | Bin 0 -> 1522 bytes .../pointnet2_modules.cpython-39.pyc | Bin 0 -> 6108 bytes .../pointnet2_utils.cpython-39.pyc | Bin 0 -> 10030 bytes .../__pycache__/pytorch_utils.cpython-39.pyc | 
Bin 0 -> 5163 bytes .../module_lib/gaussian_fourier_projection.py | 17 + modules/module_lib/linear.py | 30 ++ modules/module_lib/pointnet2_modules.py | 162 +++++++ modules/module_lib/pointnet2_utils.py | 291 +++++++++++ modules/module_lib/pytorch_utils.py | 236 +++++++++ modules/pointnet++_encoder.py | 149 ++++++ modules/pointnet_encoder.py | 107 ++++ modules/pose_encoder.py | 21 + modules/pts_num_encoder.py | 20 + modules/transformer_seq_encoder.py | 63 +++ .../clean_preprocessed_data.cpython-39.pyc | Bin 0 -> 1268 bytes .../pack_preprocessed_data.cpython-39.pyc | Bin 0 -> 1525 bytes .../pack_upload_data.cpython-39.pyc | Bin 0 -> 1275 bytes .../__pycache__/preprocessor.cpython-39.pyc | Bin 0 -> 6134 bytes preprocess/clean_preprocessed_data.py | 43 ++ preprocess/pack_preprocessed_data.py | 48 ++ preprocess/pack_upload_data.py | 41 ++ preprocess/preprocessor.py | 185 +++++++ .../__pycache__/data_spliter.cpython-39.pyc | Bin 0 -> 2563 bytes .../evaluate_uncertainty_guide.cpython-39.pyc | Bin 0 -> 12622 bytes ...and_local_points_inferencer.cpython-39.pyc | Bin 0 -> 12294 bytes .../global_points_inferencer.cpython-39.pyc | Bin 0 -> 12035 bytes .../inference_server.cpython-39.pyc | Bin 0 -> 4831 bytes runners/__pycache__/inferencer.cpython-39.pyc | Bin 0 -> 11860 bytes .../local_points_inferencer.cpython-39.pyc | Bin 0 -> 12137 bytes runners/__pycache__/simulator.cpython-39.pyc | Bin 0 -> 10374 bytes .../strategy_generator.cpython-39.pyc | Bin 0 -> 5213 bytes .../__pycache__/view_generator.cpython-39.pyc | Bin 0 -> 1312 bytes runners/data_spliter.py | 57 +++ runners/evaluate_uncertainty_guide.py | 360 ++++++++++++++ runners/global_and_local_points_inferencer.py | 352 ++++++++++++++ runners/global_points_inferencer.py | 348 +++++++++++++ runners/inference_server.py | 116 +++++ runners/local_points_inferencer.py | 350 ++++++++++++++ runners/simulator.py | 456 ++++++++++++++++++ runners/strategy_generator.py | 154 ++++++ runners/view_generator.py | 19 + utils/__pycache__/control.cpython-39.pyc | Bin 0 -> 1849 bytes utils/__pycache__/data_load.cpython-39.pyc | Bin 0 -> 11518 bytes utils/__pycache__/pose.cpython-39.pyc | Bin 0 -> 6579 bytes utils/__pycache__/pts.cpython-39.pyc | Bin 0 -> 4387 bytes .../__pycache__/reconstruction.cpython-39.pyc | Bin 0 -> 7367 bytes utils/__pycache__/render.cpython-39.pyc | Bin 0 -> 4798 bytes utils/__pycache__/vis.cpython-39.pyc | Bin 0 -> 7216 bytes utils/control.py | 59 +++ utils/data_load.py | 391 +++++++++++++++ utils/pose.py | 253 ++++++++++ utils/pts.py | 117 +++++ utils/reconstruction.py | 267 ++++++++++ utils/render.py | 136 ++++++ utils/vis.py | 208 ++++++++ 121 files changed, 8665 insertions(+) create mode 100644 .gitignore create mode 100644 Readme.md create mode 100644 app_generate_strategy.py create mode 100644 app_generate_view.py create mode 100644 app_inference.py create mode 100644 app_sim.py create mode 100644 app_split.py create mode 100644 app_train.py create mode 100644 beans/__pycache__/predict_result.cpython-39.pyc create mode 100644 beans/predict_result.py create mode 100644 configs/local/global_only_inference_config.yaml create mode 100644 configs/local/global_points_and_pose_inference_config.yaml create mode 100644 configs/local/global_pts_and_local_pts_pose.yaml create mode 100644 configs/local/local_only_inference_config.yaml create mode 100644 configs/local/simulation_config.yaml create mode 100644 configs/local/split_dataset_config.yaml create mode 100644 configs/local/strategy_generate_config.yaml create mode 100644 
configs/local/train_config.yaml create mode 100644 configs/local/uncertainty_guide_evaluation_config.yaml create mode 100644 configs/local/view_generate_config.yaml create mode 100644 configs/server/server_inference_config.yaml create mode 100644 configs/server/server_inference_server_config.yaml create mode 100644 configs/server/server_split_dataset_config.yaml create mode 100644 configs/server/server_train_config.yaml create mode 100644 core/__pycache__/ab_global_only_pts_pipeline.cpython-39.pyc create mode 100644 core/__pycache__/ab_local_only_pts_pipeline.cpython-39.pyc create mode 100644 core/__pycache__/ab_mlp_pipeline.cpython-39.pyc create mode 100644 core/__pycache__/evaluation.cpython-39.pyc create mode 100644 core/__pycache__/global_pts_pipeline.cpython-39.pyc create mode 100644 core/__pycache__/local_pts_pipeline.cpython-39.pyc create mode 100644 core/__pycache__/loss.cpython-39.pyc create mode 100644 core/__pycache__/nbv_dataset.cpython-39.pyc create mode 100644 core/__pycache__/old_seq_dataset.cpython-39.pyc create mode 100644 core/__pycache__/pipeline.cpython-39.pyc create mode 100644 core/__pycache__/seq_dataset.cpython-39.pyc create mode 100644 core/__pycache__/seq_dataset_preprocessed.cpython-39.pyc create mode 100644 core/ab_global_only_pts_pipeline.py create mode 100644 core/ab_local_only_pts_pipeline.py create mode 100644 core/ab_mlp_pipeline.py create mode 100644 core/evaluation.py create mode 100644 core/global_pts_pipeline.py create mode 100644 core/local_pts_pipeline.py create mode 100644 core/loss.py create mode 100644 core/nbv_dataset.py create mode 100644 core/old_seq_dataset.py create mode 100644 core/pipeline.py create mode 100644 core/seq_dataset.py create mode 100644 core/seq_dataset_preprocessed.py create mode 100644 modules/__pycache__/gf_view_finder.cpython-39.pyc create mode 100644 modules/__pycache__/mlp_view_finder.cpython-39.pyc create mode 100644 modules/__pycache__/pointnet++_encoder.cpython-39.pyc create mode 100644 modules/__pycache__/pointnet_encoder.cpython-39.pyc create mode 100644 modules/__pycache__/pose_encoder.cpython-39.pyc create mode 100644 modules/__pycache__/pts_num_encoder.cpython-39.pyc create mode 100644 modules/__pycache__/transformer_seq_encoder.cpython-39.pyc create mode 100644 modules/func_lib/__init__.py create mode 100644 modules/func_lib/__pycache__/__init__.cpython-39.pyc create mode 100644 modules/func_lib/__pycache__/samplers.cpython-39.pyc create mode 100644 modules/func_lib/__pycache__/sde.cpython-39.pyc create mode 100644 modules/func_lib/samplers.py create mode 100644 modules/func_lib/sde.py create mode 100644 modules/gf_view_finder.py create mode 100644 modules/mlp_view_finder.py create mode 100644 modules/module_lib/__init__.py create mode 100644 modules/module_lib/__pycache__/__init__.cpython-39.pyc create mode 100644 modules/module_lib/__pycache__/gaussian_fourier_projection.cpython-39.pyc create mode 100644 modules/module_lib/__pycache__/linear.cpython-39.pyc create mode 100644 modules/module_lib/__pycache__/pointnet2_modules.cpython-39.pyc create mode 100644 modules/module_lib/__pycache__/pointnet2_utils.cpython-39.pyc create mode 100644 modules/module_lib/__pycache__/pytorch_utils.cpython-39.pyc create mode 100644 modules/module_lib/gaussian_fourier_projection.py create mode 100644 modules/module_lib/linear.py create mode 100644 modules/module_lib/pointnet2_modules.py create mode 100644 modules/module_lib/pointnet2_utils.py create mode 100644 modules/module_lib/pytorch_utils.py create mode 100644 
modules/pointnet++_encoder.py
 create mode 100644 modules/pointnet_encoder.py
 create mode 100644 modules/pose_encoder.py
 create mode 100644 modules/pts_num_encoder.py
 create mode 100644 modules/transformer_seq_encoder.py
 create mode 100644 preprocess/__pycache__/clean_preprocessed_data.cpython-39.pyc
 create mode 100644 preprocess/__pycache__/pack_preprocessed_data.cpython-39.pyc
 create mode 100644 preprocess/__pycache__/pack_upload_data.cpython-39.pyc
 create mode 100644 preprocess/__pycache__/preprocessor.cpython-39.pyc
 create mode 100644 preprocess/clean_preprocessed_data.py
 create mode 100644 preprocess/pack_preprocessed_data.py
 create mode 100644 preprocess/pack_upload_data.py
 create mode 100644 preprocess/preprocessor.py
 create mode 100644 runners/__pycache__/data_spliter.cpython-39.pyc
 create mode 100644 runners/__pycache__/evaluate_uncertainty_guide.cpython-39.pyc
 create mode 100644 runners/__pycache__/global_and_local_points_inferencer.cpython-39.pyc
 create mode 100644 runners/__pycache__/global_points_inferencer.cpython-39.pyc
 create mode 100644 runners/__pycache__/inference_server.cpython-39.pyc
 create mode 100644 runners/__pycache__/inferencer.cpython-39.pyc
 create mode 100644 runners/__pycache__/local_points_inferencer.cpython-39.pyc
 create mode 100644 runners/__pycache__/simulator.cpython-39.pyc
 create mode 100644 runners/__pycache__/strategy_generator.cpython-39.pyc
 create mode 100644 runners/__pycache__/view_generator.cpython-39.pyc
 create mode 100644 runners/data_spliter.py
 create mode 100644 runners/evaluate_uncertainty_guide.py
 create mode 100644 runners/global_and_local_points_inferencer.py
 create mode 100644 runners/global_points_inferencer.py
 create mode 100644 runners/inference_server.py
 create mode 100644 runners/local_points_inferencer.py
 create mode 100644 runners/simulator.py
 create mode 100644 runners/strategy_generator.py
 create mode 100644 runners/view_generator.py
 create mode 100644 utils/__pycache__/control.cpython-39.pyc
 create mode 100644 utils/__pycache__/data_load.cpython-39.pyc
 create mode 100644 utils/__pycache__/pose.cpython-39.pyc
 create mode 100644 utils/__pycache__/pts.cpython-39.pyc
 create mode 100644 utils/__pycache__/reconstruction.cpython-39.pyc
 create mode 100644 utils/__pycache__/render.cpython-39.pyc
 create mode 100644 utils/__pycache__/vis.cpython-39.pyc
 create mode 100644 utils/control.py
 create mode 100644 utils/data_load.py
 create mode 100644 utils/pose.py
 create mode 100644 utils/pts.py
 create mode 100644 utils/reconstruction.py
 create mode 100644 utils/render.py
 create mode 100644 utils/vis.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..913a66c
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+/experiments
+/__pycache__
\ No newline at end of file
diff --git a/Readme.md b/Readme.md
new file mode 100644
index 0000000..6ad7dfc
--- /dev/null
+++ b/Readme.md
@@ -0,0 +1,192 @@
+# Next Best View for Reconstruction
+
+## 1. Setup Environment
+### 1.1 Install Main Project
+```bash
+mkdir nbv_rec
+cd nbv_rec
+git clone https://git.hofee.top/hofee/nbv_reconstruction.git
+```
+### 1.2 Install PytorchBoot
+The environment is based on PytorchBoot. Clone and install it from [PytorchBoot](https://git.hofee.top/hofee/PyTorchBoot.git):
+```bash
+git clone https://git.hofee.top/hofee/PyTorchBoot.git
+cd PyTorchBoot
+pip install .
+cd ..
+```
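+PytorchBoot applications are registered with a decorator and launched by name. For orientation, the training entry point shipped in this repository (`app_train.py`, included later in this patch) looks like this:
+```python
+from PytorchBoot.application import PytorchBootApplication
+from PytorchBoot.runners.trainer import DefaultTrainer
+
+@PytorchBootApplication("train")
+class TrainApp:
+    @staticmethod
+    def start():
+        # "ptb run train" looks up the name registered above and calls start()
+        DefaultTrainer("configs/server/server_train_config.yaml").run()
+```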
+### 1.3 Install Blender (Optional)
+If you want to render your own dataset as described in [section 2. Render Datasets](#2-render-datasets), you'll need to install Blender version 4.0 from [Blender Release](https://download.blender.org/release/Blender4.0/). Here is an example of installing Blender on Ubuntu:
+```bash
+wget https://download.blender.org/release/Blender4.0/blender-4.0.2-linux-x64.tar.xz
+tar -xvf blender-4.0.2-linux-x64.tar.xz
+```
+If Blender is not in your PATH, you can add it by:
+```bash
+export PATH=$PATH:/path/to/blender/blender-4.0.2-linux-x64
+```
+To run the Blender script, you need to install the `pyyaml` and `scipy` packages into your Blender Python environment. Run the following command to print the Python path of your Blender installation:
+```bash
+./blender -b --python-expr "import sys; print(sys.executable)"
+```
+Then copy the Python path `/path/to/blender_python` shown in the output and run the following command to install the packages:
+```bash
+/path/to/blender_python -m pip install pyyaml scipy
+```
+### 1.4 Install Blender Render Script (Optional)
+Clone the script from [nbv_rec_blender_render](https://git.hofee.top/hofee/nbv_rec_blender_render.git) and rename it to `blender`:
+```bash
+git clone https://git.hofee.top/hofee/nbv_rec_blender_render.git
+mv nbv_rec_blender_render blender
+```
+
+### 1.5 Check Dependencies
+Switch to the project root directory and run `pytorch-boot scan` or `ptb scan` to check whether all dependencies are installed:
+```bash
+cd nbv_reconstruction
+pytorch-boot scan
+# or
+ptb scan
+```
+If you see project structure information in the output, all dependencies are correctly installed. Otherwise, you may need to run `pip install xxx` to install the missing packages.
+
+## 2. Render Datasets (Optional)
+### 2.1 Download Object Mesh Models
+Download the mesh models, divided into three parts, from:
+ - [object_meshes_part1.zip](None)
+ - [object_meshes_part2.zip](https://pan.baidu.com/s/1pBPhrFtBwEGp1g4vwsLIxA?pwd=1234)
+ - [object_meshes_part3.zip](https://pan.baidu.com/s/1peE8HqFFL0qNFhM5OC69gA?pwd=1234)
+
+or download the whole dataset from [object_meshes.zip](https://pan.baidu.com/s/1ilWWgzg_l7_pPBv64eSgzA?pwd=1234).
+
+Download the table model from [table.obj](https://pan.baidu.com/s/1sjjiID25Es_kmcdUIjU_Dw?pwd=1234).
+
+### 2.2 Set Render Configurations
+Open the file `configs/local/view_generate_config.yaml` and modify the parameters to fit your needs. You are required to at least set the following parameters in `runner-generate` (see the sketch after this list):
+ - `object_dir`: the directory of the downloaded object mesh models
+ - `output_dir`: the directory to save the rendered dataset
+ - `table_model_path`: the path of the downloaded table model
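+
+For orientation, the relevant part of that file might look like the sketch below (the paths are placeholders, and any other keys present in the shipped config are omitted here):
+```yaml
+runner-generate:
+  object_dir: /path/to/object_meshes       # downloaded mesh models
+  output_dir: /path/to/rendered_dataset    # where rendered views are written
+  table_model_path: /path/to/table.obj     # downloaded table model
+```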
+
+### 2.3 Render Dataset
+
+There are two ways to render the dataset:
+
+#### 2.3.1 Render with Visual Monitoring
+
+If you want to visually monitor the rendering progress and machine resource usage:
+
+1. In the terminal, run:
+   ```
+   ptb ui
+   ```
+2. Open your browser and visit http://localhost:5000
+3. Navigate to `Project Dashboard - Project Structure - Applications - generate_view`
+4. Click the `Run` button to execute the rendering script
+
+#### 2.3.2 Render in Terminal
+
+If you don't need visual monitoring and prefer to run the rendering process directly in the terminal, simply run:
+
+```
+ptb run generate_view
+```
+
+This command will start the rendering process without launching the UI.
+
+## 3. Preprocess
+
+⚠️ The preprocessing code is currently not managed by `PytorchBoot`. To run the preprocessing:
+
+1. Open the `./preprocess/preprocessor.py` file.
+2. Locate the `if __name__ == "__main__":` block at the bottom of the file.
+3. Specify the dataset folder by setting `root = "path/to/your/dataset"`.
+4. Run the preprocessing script directly:
+
+   ```
+   python ./preprocess/preprocessor.py
+   ```
+
+This will preprocess the data in the specified dataset folder.
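+
+For orientation, the `__main__` block you are editing is organized roughly like the minimal sketch below, assuming one sub-folder per scene under `root`; the actual preprocessing calls live in `preprocessor.py` itself:
+```python
+import os
+
+if __name__ == "__main__":
+    # Set this to the dataset folder rendered in section 2
+    root = "path/to/your/dataset"
+    for scene_name in sorted(os.listdir(root)):
+        scene_dir = os.path.join(root, scene_name)
+        if not os.path.isdir(scene_dir):
+            continue
+        print(f"preprocessing scene: {scene_name}")
+        # the real file invokes its per-scene preprocessing routine here
+```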
+
+## 4. Generate Strategy Label
+
+### 4.1 Set Configuration
+
+Open the file `configs/local/strategy_generate_config.yaml` and modify the parameters to fit your needs. You are required to at least set the following parameter:
+
+- `datasets.OmniObject3d.root_dir`: the directory of your dataset
+
+### 4.2 Generate Strategy Label
+
+There are two ways to generate the strategy label:
+
+#### 4.2.1 Generate with Visual Monitoring
+
+If you want to visually monitor the generation progress and machine resource usage:
+
+1. In the terminal, run:
+   ```
+   ptb ui
+   ```
+2. Open your browser and visit http://localhost:5000
+3. Navigate to Project Dashboard - Project Structure - Applications - generate_strategy
+4. Click the `Run` button to execute the generation script
+
+#### 4.2.2 Generate in Terminal
+
+If you don't need visual monitoring and prefer to run the generation process directly in the terminal, simply run:
+
+```
+ptb run generate_strategy
+```
+
+This command will start the strategy label generation process without launching the UI.
+
+## 5. Train
+
+### 5.1 Set Configuration
+
+Open the file `configs/local/train_config.yaml` and modify the parameters to fit your needs. You are required to at least set the following parameters in the `experiment` section:
+
+```yaml
+experiment:
+  name: your_experiment_name
+  root_dir: path/to/your/experiment_dir
+  use_checkpoint: False # if True, the checkpoint will be loaded
+  epoch: 600 # specific epoch to load, -1 stands for the last epoch
+  max_epochs: 5000 # maximum epochs to train
+  save_checkpoint_interval: 1 # save checkpoint interval
+  test_first: True # if True, the test process runs before training at each epoch
+```
+
+Adjust these parameters according to your training requirements.
+
+
+### 5.2 Start Training
+
+There are two ways to start the training process:
+
+#### 5.2.1 Train with Visual Monitoring
+
+If you want to visually monitor the training progress and machine resource usage:
+
+1. In the terminal, run:
+   ```
+   ptb ui
+   ```
+2. Open your browser and visit http://localhost:5000
+3. Navigate to Project Dashboard - Project Structure - Applications - train
+4. Click the `Run` button to start the training process
+
+#### 5.2.2 Train in Terminal
+
+If you don't need visual monitoring and prefer to run the training process directly in the terminal, simply run:
+
+```
+ptb run train
+```
+
+This command will start the training process without launching the UI.
+
+## 6. Evaluation
+...
diff --git a/app_generate_strategy.py b/app_generate_strategy.py
new file mode 100644
index 0000000..28905e5
--- /dev/null
+++ b/app_generate_strategy.py
@@ -0,0 +1,9 @@
+from PytorchBoot.application import PytorchBootApplication
+from runners.strategy_generator import StrategyGenerator
+
+@PytorchBootApplication("generate_strategy")
+class DataGenerateApp:
+    @staticmethod
+    def start():
+        StrategyGenerator("configs/local/strategy_generate_config.yaml").run()
+    
\ No newline at end of file
diff --git a/app_generate_view.py b/app_generate_view.py
new file mode 100644
index 0000000..a48d0ba
--- /dev/null
+++ b/app_generate_view.py
@@ -0,0 +1,16 @@
+from PytorchBoot.application import PytorchBootApplication
+from runners.view_generator import ViewGenerator
+
+@PytorchBootApplication("generate_view")
+class ViewGenerateApp:
+    @staticmethod
+    def start():
+        '''
+        call default or your custom runners here; the code will be executed
+        automatically when you type "pytorch-boot run" or "ptb run" in the terminal
+
+        example:
+            Trainer("path_to_your_train_config").run()
+            Evaluator("path_to_your_eval_config").run()
+        '''
+        ViewGenerator("configs/local/view_generate_config.yaml").run()
diff --git a/app_inference.py b/app_inference.py
new file mode 100644
index 0000000..bb3ccf7
--- /dev/null
+++ b/app_inference.py
@@ -0,0 +1,75 @@
+from PytorchBoot.application import PytorchBootApplication
+from runners.global_points_inferencer import GlobalPointsInferencer
+from runners.global_and_local_points_inferencer import GlobalAndLocalPointsInferencer
+from runners.local_points_inferencer import LocalPointsInferencer
+from runners.inference_server import InferencerServer
+from runners.evaluate_uncertainty_guide import EvaluateUncertaintyGuide
+@PytorchBootApplication("global_points_inference")
+class GlobalPointsInferenceApp:
+    @staticmethod
+    def start():
+        '''
+        call default or your custom runners here; the code will be executed
+        automatically when you type "pytorch-boot run" or "ptb run" in the terminal
+
+        example:
+            Trainer("path_to_your_train_config").run()
+            Evaluator("path_to_your_eval_config").run()
+        '''
+        GlobalPointsInferencer("./configs/local/global_only_inference_config.yaml").run()
+
+@PytorchBootApplication("global_and_local_points_inference")
+class GlobalAndLocalPointsInferenceApp:
+    @staticmethod
+    def start():
+        '''
+        call default or your custom runners here; the code will be executed
+        automatically when you type "pytorch-boot run" or "ptb run" in the terminal
+
+        example:
+            Trainer("path_to_your_train_config").run()
+            Evaluator("path_to_your_eval_config").run()
+        '''
+        GlobalAndLocalPointsInferencer("./configs/local/global_pts_and_local_pts_pose.yaml").run()
+
+@PytorchBootApplication("local_points_inference")
+class LocalPointsInferenceApp:
+    @staticmethod
+    def start():
+        '''
+        call default or your custom runners here; the code will be executed
+        automatically when you type "pytorch-boot run" or "ptb run" in the terminal
+
+        example:
+            Trainer("path_to_your_train_config").run()
+            Evaluator("path_to_your_eval_config").run()
+        '''
+        LocalPointsInferencer("./configs/local/local_only_inference_config.yaml").run()
+
+@PytorchBootApplication("server")
+class InferenceServerApp:
+    @staticmethod
+    def start():
+        '''
+        call default or your custom runners here; the code will be executed
+        automatically when you type "pytorch-boot run" or "ptb run" in the terminal
+
+        example:
+            Trainer("path_to_your_train_config").run()
+            Evaluator("path_to_your_eval_config").run()
+        '''
+        InferencerServer("./configs/server/server_inference_server_config.yaml").run()
+
+@PytorchBootApplication("evaluate_uncertainty_guide") +class EvaluateUncertaintyGuideApp: + @staticmethod + def start(): + ''' + call default or your custom runners here, code will be executed + automatically when type "pytorch-boot run" or "ptb run" in terminal + + example: + Trainer("path_to_your_train_config").run() + Evaluator("path_to_your_eval_config").run() + ''' + EvaluateUncertaintyGuide("./configs/local/uncertainty_guide_evaluation_config.yaml").run() \ No newline at end of file diff --git a/app_sim.py b/app_sim.py new file mode 100644 index 0000000..feb5687 --- /dev/null +++ b/app_sim.py @@ -0,0 +1,11 @@ +from PytorchBoot.application import PytorchBootApplication +from runners.simulator import Simulator + +@PytorchBootApplication("sim") +class SimulateApp: + @staticmethod + def start(): + simulator = Simulator("configs/local/simulation_config.yaml") + simulator.run("create") + simulator.run("simulate") + \ No newline at end of file diff --git a/app_split.py b/app_split.py new file mode 100644 index 0000000..b89923a --- /dev/null +++ b/app_split.py @@ -0,0 +1,9 @@ +from PytorchBoot.application import PytorchBootApplication +from runners.data_spliter import DataSpliter + +@PytorchBootApplication("split_data") +class DataSplitApp: + @staticmethod + def start(): + DataSpliter("configs/server/server_split_dataset_config.yaml").run() + \ No newline at end of file diff --git a/app_train.py b/app_train.py new file mode 100644 index 0000000..191853d --- /dev/null +++ b/app_train.py @@ -0,0 +1,8 @@ +from PytorchBoot.application import PytorchBootApplication +from PytorchBoot.runners.trainer import DefaultTrainer + +@PytorchBootApplication("train") +class TrainApp: + @staticmethod + def start(): + DefaultTrainer("configs/server/server_train_config.yaml").run() \ No newline at end of file diff --git a/beans/__pycache__/predict_result.cpython-39.pyc b/beans/__pycache__/predict_result.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..457c10235bcef0e0c560da34b3f90d66ac22f315 GIT binary patch literal 6306 zcma)A+ix6K8J{z^oxON{iQ~penr>T4*(%t*Kxm_ov`t7=39KloxC~Uo?40p>?b(_2 znQ>yTW&z<;@PHo>yiinR_bvPb2qE=v;Q2fu@qqjc4;6yK?>jSVdz}=qt2uMdH|P30 z-}n2zbNs2PnugyWZk=2E(|Jw%D-}-uDk!{(C;c4~u5lJ=F2h@o^uFO5OnXP;1~(sS z+}zb&Gqkw1p^Fc>=Gr0KuW&oGf2;AzN5%nj9Z0IB#8Hx3SPz}`#x7%8XogN$3#WLM z*LZzT=Z!tXZQ^@VeXWPnVRLwPx@0Np<0Qna~2w0r&a9ft2(#1{gA1ZRJij{+htgZMYB@*3a|6VL(O&g z6mQ~P<{ zxB7z}SrK8n<`2TQF7>6CrCrpe?$R(?eWd;KJIj55?=P<Lt(X#l6h)maauf;793oWX6r!O5gRRUPWT|Wh(N>y;!W;O)@27IQPzCP2 z!M#C}hVrRyn3ej$y%G`i;mo+^N#P_=T`#);@lgD07^0+FPtw6KA}M%$J;zan-26+Ev> zy@Pj$-meu{DZms~7fzZ9tft%|WiT@sDs3Rq^z&?n&EsjXVS5s||IvN|OJWJaQ(@Uc z2CHXdJYoO9vS5RKSpEUgx3MA!{5CtJc5S0rx)~&c9YIrTGeLWMYDHoiGE$9{xJc7a z&$nFQlL%sBzBDE|J2l_mBg5HG9L}^_-}1LxV}iKaQj%VLLej2^7f@APpyZ24WHrQ?p#V9EXGy4(+35S*;u-pE zz^_5YmBpLAcr}r%Co3f#$Oh^iF%(Mhf#cCn>42PtVis4RZW2(Wb0ZG<=@J=5300lB$v zl0Q43fI_FhV*Dl3_pPqRt&dq3%z>|V@ik}A3D^PCHfvOecm+FKSrJz-3qe8JA?-s! 
diff --git a/core/nbv_dataset.py b/core/nbv_dataset.py
new file mode 100644
--- /dev/null
+++ b/core/nbv_dataset.py
[the beginning of this diff was lost along with the binary blobs above; it resumes midway through get_datalist()]
+                if max_coverage_rate > scene_max_coverage_rate:
+                    scene_max_coverage_rate = max_coverage_rate
+                max_coverage_rate_list.append(max_coverage_rate)
+
+            if max_coverage_rate_list:
+                mean_coverage_rate = np.mean(max_coverage_rate_list)
+
+            for seq_idx in range(seq_num):
+                label_path = DataLoadUtil.get_label_path(
+                    self.root_dir, scene_name, seq_idx
+                )
+                label_data = DataLoadUtil.load_label(label_path)
+                if max_coverage_rate_list[seq_idx] > mean_coverage_rate - 0.1:
+                    for data_pair in label_data["data_pairs"]:
+                        scanned_views = data_pair[0]
+                        next_best_view = data_pair[1]
+                        datalist.append(
+                            {
+                                "scanned_views": scanned_views,
+                                "next_best_view": next_best_view,
+                                "seq_max_coverage_rate": max_coverage_rate,
+                                "scene_name": scene_name,
+                                "label_idx": seq_idx,
+                                "scene_max_coverage_rate": scene_max_coverage_rate,
+                            }
+                        )
+        return datalist
+
+    def preprocess_cache(self):
+        Log.info("preprocessing cache...")
+        for item_idx in range(len(self.datalist)):
+            self.__getitem__(item_idx)
+        Log.success("finish preprocessing cache.")
+
+    def load_from_cache(self, scene_name, curr_frame_idx):
+        cache_name = f"{scene_name}_{curr_frame_idx}.txt"
+        cache_path = os.path.join(self.cache_dir, cache_name)
+        if os.path.exists(cache_path):
+            data = np.loadtxt(cache_path)
+            return data
+        else:
+            return None
+
+    def save_to_cache(self, scene_name, curr_frame_idx, data):
+        cache_name = f"{scene_name}_{curr_frame_idx}.txt"
+        cache_path = os.path.join(self.cache_dir, cache_name)
+        try:
+            np.savetxt(cache_path, data)
+        except Exception as e:
+            Log.error(f"Save cache failed: {e}")
+
+    def voxel_downsample_with_mapping(self, point_cloud, voxel_size=0.003):
+        voxel_indices = np.floor(point_cloud / voxel_size).astype(np.int32)
+        unique_voxels, inverse, counts = np.unique(voxel_indices, axis=0, return_inverse=True, return_counts=True)
+        idx_sort = np.argsort(inverse)
+        idx_unique = idx_sort[np.cumsum(counts)-counts]
+        downsampled_points = point_cloud[idx_unique]
+        return downsampled_points, inverse
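+
+    # How the mapping above works: np.unique(..., return_inverse=True, return_counts=True)
+    # yields, for every input point, the index of its voxel ("inverse") and each voxel's
+    # population ("counts"). Sorting point indices by voxel id (idx_sort) and stepping to
+    # each group's first element via cumsum(counts) - counts picks one representative
+    # point per occupied voxel, so "inverse" maps every original point onto the
+    # downsampled cloud.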
+
+    def __getitem__(self, index):
+        data_item_info = self.datalist[index]
+        scanned_views = data_item_info["scanned_views"]
+        nbv = data_item_info["next_best_view"]
+        max_coverage_rate = data_item_info["seq_max_coverage_rate"]
+        scene_name = data_item_info["scene_name"]
+        (
+            scanned_views_pts,
+            scanned_coverages_rate,
+            scanned_n_to_world_pose,
+        ) = ([], [], [])
+        #start_time = time.time()
+        start_indices = [0]
+        total_points = 0
+        for view in scanned_views:
+            frame_idx = view[0]
+            coverage_rate = view[1]
+            view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx)
+            cam_info = DataLoadUtil.load_cam_info(view_path, binocular=True)
+
+            n_to_world_pose = cam_info["cam_to_world"]
+            target_point_cloud = (
+                DataLoadUtil.load_from_preprocessed_pts(view_path)
+            )
+            downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(
+                target_point_cloud, self.pts_num
+            )
+            scanned_views_pts.append(downsampled_target_point_cloud)
+            scanned_coverages_rate.append(coverage_rate)
+            n_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy(
+                np.asarray(n_to_world_pose[:3, :3])
+            )
+            n_to_world_trans = n_to_world_pose[:3, 3]
+            n_to_world_9d = np.concatenate([n_to_world_6d, n_to_world_trans], axis=0)
+            scanned_n_to_world_pose.append(n_to_world_9d)
+            total_points += len(downsampled_target_point_cloud)
+            start_indices.append(total_points)
+
+
+        #end_time = time.time()
+        #Log.info(f"load data time: {end_time - start_time}")
+        nbv_idx, nbv_coverage_rate = nbv[0], nbv[1]
+        nbv_path = DataLoadUtil.get_path(self.root_dir, scene_name, nbv_idx)
+        cam_info = DataLoadUtil.load_cam_info(nbv_path)
+        best_frame_to_world = cam_info["cam_to_world"]
+
+        best_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy(
+            np.asarray(best_frame_to_world[:3, :3])
+        )
+        best_to_world_trans = best_frame_to_world[:3, 3]
+        best_to_world_9d = np.concatenate(
+            [best_to_world_6d, best_to_world_trans], axis=0
+        )
+
+        combined_scanned_views_pts = np.concatenate(scanned_views_pts, axis=0)
+        voxel_downsampled_combined_scanned_pts_np, inverse = self.voxel_downsample_with_mapping(combined_scanned_views_pts, 0.003)
+        random_downsampled_combined_scanned_pts_np, random_downsample_idx = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, self.pts_num, require_idx=True)
+
+        # all_idx_unique = np.arange(len(voxel_downsampled_combined_scanned_pts_np))
+        # all_random_downsample_idx = all_idx_unique[random_downsample_idx]
+        # scanned_pts_mask = []
+        # for idx, start_idx in enumerate(start_indices):
+        #     if idx == len(start_indices) - 1:
+        #         break
+        #     end_idx = start_indices[idx+1]
+        #     view_inverse = inverse[start_idx:end_idx]
+        #     view_unique_downsampled_idx = np.unique(view_inverse)
+        #     view_unique_downsampled_idx_set = set(view_unique_downsampled_idx)
+        #     mask = np.array([idx in view_unique_downsampled_idx_set for idx in all_random_downsample_idx])
+        #     #scanned_pts_mask.append(mask)
+        data_item = {
+            "scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32),  # Ndarray(S x Nv x 3)
+            "combined_scanned_pts": np.asarray(random_downsampled_combined_scanned_pts_np, dtype=np.float32),  # Ndarray(N x 3)
+            #"scanned_pts_mask": np.asarray(scanned_pts_mask, dtype=np.bool),  # Ndarray(N)
+            "scanned_coverage_rate": scanned_coverages_rate,  # List(S): Float, range(0, 1)
+            "scanned_n_to_world_pose_9d": np.asarray(scanned_n_to_world_pose, dtype=np.float32),  # Ndarray(S x 9)
+            "best_coverage_rate": nbv_coverage_rate,  # Float, range(0, 1)
+            "best_to_world_pose_9d": np.asarray(best_to_world_9d, dtype=np.float32),  # Ndarray(9)
+            "seq_max_coverage_rate": max_coverage_rate,  # Float, range(0, 1)
+            "scene_name": scene_name,  # String
+        }
+
+        return data_item
+
+    def __len__(self):
+        return len(self.datalist)
+
+    def get_collate_fn(self):
+        def collate_fn(batch):
+            collate_data = {}
+
+            ''' ------ Variable Length ------ '''
+
+            collate_data["scanned_pts"] = [
+                torch.tensor(item["scanned_pts"]) for item in batch
+            ]
+            collate_data["scanned_n_to_world_pose_9d"] = [
+                torch.tensor(item["scanned_n_to_world_pose_9d"]) for item in batch
+            ]
+            # collate_data["scanned_pts_mask"] = [
+            #     torch.tensor(item["scanned_pts_mask"]) for item in batch
+            # ]
+            ''' ------ Fixed Length ------ '''
+
+            collate_data["best_to_world_pose_9d"] = torch.stack(
+                [torch.tensor(item["best_to_world_pose_9d"]) for item in batch]
+            )
+            collate_data["combined_scanned_pts"] = torch.stack(
+                [torch.tensor(item["combined_scanned_pts"]) for item in batch]
+            )
+
+            for key in batch[0].keys():
+                if key not in [
+                    "scanned_pts",
+                    "scanned_n_to_world_pose_9d",
+                    "best_to_world_pose_9d",
+                    "combined_scanned_pts",
+                    "scanned_pts_mask",
+                ]:
+                    collate_data[key] = [item[key] for item in batch]
+            return collate_data
+
+        return collate_fn
+
+
+# -------------- Debug ---------------- #
+if __name__ == "__main__":
+    import torch
+
+    seed = 0
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    config = {
+        "root_dir": "/data/hofee/nbv_rec_part2_preprocessed",
+        "source": "nbv_reconstruction_dataset",
+        "split_file": "/data/hofee/data/sample.txt",
+        "load_from_preprocess": True,
+        "ratio": 0.5,
+        "batch_size": 2,
+        "filter_degree": 75,
+        "num_workers": 0,
+        "pts_num": 4096,
+        "type": namespace.Mode.TRAIN,
+    }
+    ds = NBVReconstructionDataset(config)
+    print(len(ds))
+    # ds.__getitem__(10)
+    dl = ds.get_loader(shuffle=True)
+    for idx, data in enumerate(dl):
+        data = ds.process_batch(data, "cuda:0")
+        print(data)
+        # ------ Debug Start ------
+        import ipdb
+
+        ipdb.set_trace()
+        # ------ Debug End ------
diff --git a/core/old_seq_dataset.py b/core/old_seq_dataset.py
new file mode 100644
index 0000000..753636e
--- /dev/null
+++ b/core/old_seq_dataset.py
@@ -0,0 +1,154 @@
+import numpy as np
+from PytorchBoot.dataset import BaseDataset
+import PytorchBoot.namespace as namespace
+import PytorchBoot.stereotype as stereotype
+from PytorchBoot.utils.log_util import Log
+import torch
+import os
+import sys
+sys.path.append(r"/home/data/hofee/project/nbv_rec/nbv_reconstruction") + +from utils.data_load import DataLoadUtil +from utils.pose import PoseUtil +from utils.pts import PtsUtil + +@stereotype.dataset("old_seq_nbv_reconstruction_dataset") +class SeqNBVReconstructionDataset(BaseDataset): + def __init__(self, config): + super(SeqNBVReconstructionDataset, self).__init__(config) + self.type = config["type"] + if self.type != namespace.Mode.TEST: + Log.error("Dataset Only support test mode", terminate=True) + self.config = config + self.root_dir = config["root_dir"] + self.split_file_path = config["split_file"] + self.scene_name_list = self.load_scene_name_list() + self.datalist = self.get_datalist() + self.pts_num = config["pts_num"] + + self.model_dir = config["model_dir"] + self.filter_degree = config["filter_degree"] + self.load_from_preprocess = config.get("load_from_preprocess", False) + + + def load_scene_name_list(self): + scene_name_list = [] + with open(self.split_file_path, "r") as f: + for line in f: + scene_name = line.strip() + scene_name_list.append(scene_name) + return scene_name_list + + def get_datalist(self): + datalist = [] + for scene_name in self.scene_name_list: + seq_num = DataLoadUtil.get_label_num(self.root_dir, scene_name) + scene_max_coverage_rate = 0 + scene_max_cr_idx = 0 + + for seq_idx in range(seq_num): + label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name, seq_idx) + label_data = DataLoadUtil.load_label(label_path) + max_coverage_rate = label_data["max_coverage_rate"] + if max_coverage_rate > scene_max_coverage_rate: + scene_max_coverage_rate = max_coverage_rate + scene_max_cr_idx = seq_idx + + label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name, scene_max_cr_idx) + label_data = DataLoadUtil.load_label(label_path) + first_frame = label_data["best_sequence"][0] + best_seq_len = len(label_data["best_sequence"]) + datalist.append({ + "scene_name": scene_name, + "first_frame": first_frame, + "max_coverage_rate": scene_max_coverage_rate, + "best_seq_len": best_seq_len, + "label_idx": scene_max_cr_idx, + }) + return datalist + + def __getitem__(self, index): + data_item_info = self.datalist[index] + first_frame_idx = data_item_info["first_frame"][0] + first_frame_coverage = data_item_info["first_frame"][1] + max_coverage_rate = data_item_info["max_coverage_rate"] + scene_name = data_item_info["scene_name"] + first_cam_info = DataLoadUtil.load_cam_info(DataLoadUtil.get_path(self.root_dir, scene_name, first_frame_idx), binocular=True) + first_view_path = DataLoadUtil.get_path(self.root_dir, scene_name, first_frame_idx) + first_left_cam_pose = first_cam_info["cam_to_world"] + first_center_cam_pose = first_cam_info["cam_to_world_O"] + first_target_point_cloud = DataLoadUtil.load_from_preprocessed_pts(first_view_path) + first_pts_num = first_target_point_cloud.shape[0] + first_downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud(first_target_point_cloud, self.pts_num) + first_to_world_rot_6d = PoseUtil.matrix_to_rotation_6d_numpy(np.asarray(first_left_cam_pose[:3,:3])) + first_to_world_trans = first_left_cam_pose[:3,3] + first_to_world_9d = np.concatenate([first_to_world_rot_6d, first_to_world_trans], axis=0) + diag = DataLoadUtil.get_bbox_diag(self.model_dir, scene_name) + voxel_threshold = diag*0.02 + first_O_to_first_L_pose = np.dot(np.linalg.inv(first_left_cam_pose), first_center_cam_pose) + scene_path = os.path.join(self.root_dir, scene_name) + model_points_normals = DataLoadUtil.load_points_normals(self.root_dir, 
scene_name) + + data_item = { + "first_pts_num": np.asarray( + first_pts_num, dtype=np.int32 + ), + "first_pts": np.asarray([first_downsampled_target_point_cloud],dtype=np.float32), + "combined_scanned_pts": np.asarray(first_downsampled_target_point_cloud,dtype=np.float32), + "first_to_world_9d": np.asarray([first_to_world_9d],dtype=np.float32), + "scene_name": scene_name, + "max_coverage_rate": max_coverage_rate, + "voxel_threshold": voxel_threshold, + "filter_degree": self.filter_degree, + "O_to_L_pose": first_O_to_first_L_pose, + "first_frame_coverage": first_frame_coverage, + "scene_path": scene_path, + "model_points_normals": model_points_normals, + "best_seq_len": data_item_info["best_seq_len"], + "first_frame_id": first_frame_idx, + } + return data_item + + def __len__(self): + return len(self.datalist) + + def get_collate_fn(self): + def collate_fn(batch): + collate_data = {} + collate_data["first_pts"] = [torch.tensor(item['first_pts']) for item in batch] + collate_data["first_to_world_9d"] = [torch.tensor(item['first_to_world_9d']) for item in batch] + collate_data["combined_scanned_pts"] = torch.stack([torch.tensor(item['combined_scanned_pts']) for item in batch]) + for key in batch[0].keys(): + if key not in ["first_pts", "first_to_world_9d", "combined_scanned_pts"]: + collate_data[key] = [item[key] for item in batch] + return collate_data + return collate_fn + +# -------------- Debug ---------------- # +if __name__ == "__main__": + import torch + seed = 0 + torch.manual_seed(seed) + np.random.seed(seed) + config = { + "root_dir": "/home/data/hofee/project/nbv_rec/data/nbv_rec_data_512_preproc_npy", + "split_file": "/home/data/hofee/project/nbv_rec/data/OmniObject3d_train.txt", + "model_dir": "/home/data/hofee/project/nbv_rec/data/scaled_object_meshes", + "ratio": 0.005, + "batch_size": 2, + "filter_degree": 75, + "num_workers": 0, + "pts_num": 32684, + "type": namespace.Mode.TEST, + "load_from_preprocess": True + } + ds = SeqNBVReconstructionDataset(config) + print(len(ds)) + #ds.__getitem__(10) + dl = ds.get_loader(shuffle=True) + for idx, data in enumerate(dl): + data = ds.process_batch(data, "cuda:0") + print(data) + # ------ Debug Start ------ + import ipdb;ipdb.set_trace() + # ------ Debug End ------+ \ No newline at end of file diff --git a/core/pipeline.py b/core/pipeline.py new file mode 100644 index 0000000..8996dc4 --- /dev/null +++ b/core/pipeline.py @@ -0,0 +1,140 @@ +import torch +import time +from torch import nn +import PytorchBoot.namespace as namespace +import PytorchBoot.stereotype as stereotype +from PytorchBoot.factory.component_factory import ComponentFactory +from PytorchBoot.utils import Log + + +@stereotype.pipeline("nbv_reconstruction_pipeline") +class NBVReconstructionPipeline(nn.Module): + def __init__(self, config): + super(NBVReconstructionPipeline, self).__init__() + self.config = config + self.module_config = config["modules"] + + self.pts_encoder = ComponentFactory.create( + namespace.Stereotype.MODULE, self.module_config["pts_encoder"] + ) + self.pose_encoder = ComponentFactory.create( + namespace.Stereotype.MODULE, self.module_config["pose_encoder"] + ) + self.seq_encoder = ComponentFactory.create( + namespace.Stereotype.MODULE, self.module_config["seq_encoder"] + ) + self.view_finder = ComponentFactory.create( + namespace.Stereotype.MODULE, self.module_config["view_finder"] + ) + + + self.eps = float(self.config["eps"]) + + def forward(self, data): + mode = data["mode"] + + if mode == namespace.Mode.TRAIN: + return self.forward_train(data) + 
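        # test mode: sample candidate next-best-view poses by integrating the
+        # reverse-time ODE of the learned score field (see forward_test)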
+        elif mode == namespace.Mode.TEST:
+            return self.forward_test(data)
+        else:
+            Log.error("Unknown mode: {}".format(mode), True)
+
+    def perturb_data(self, gt_delta_9d):
+        bs = gt_delta_9d.shape[0]
+        random_t = (
+            torch.rand(bs, device=gt_delta_9d.device) * (1.0 - self.eps) + self.eps
+        )
+        random_t = random_t.unsqueeze(-1)
+        mu, std = self.view_finder.marginal_prob(gt_delta_9d, random_t)
+        std = std.view(-1, 1)
+        z = torch.randn_like(gt_delta_9d)
+        perturbed_x = mu + z * std
+        target_score = -z * std / (std**2)  # equals -z / std
+        return perturbed_x, random_t, target_score, std
+
+    def forward_train(self, data):
+        main_feat = self.get_main_feat(data)
+        # perturb the gt pose to a random diffusion timestep and get the target score
+        best_to_world_pose_9d_batch = data["best_to_world_pose_9d"]
+        perturbed_x, random_t, target_score, std = self.perturb_data(
+            best_to_world_pose_9d_batch
+        )
+        input_data = {
+            "sampled_pose": perturbed_x,
+            "t": random_t,
+            "main_feat": main_feat,
+        }
+        estimated_score = self.view_finder(input_data)
+        output = {
+            "estimated_score": estimated_score,
+            "target_score": target_score,
+            "std": std,
+        }
+        return output
+
+    def forward_test(self, data):
+        main_feat = self.get_main_feat(data)
+        repeat_num = data.get("repeat_num", 1)
+        main_feat = main_feat.repeat(repeat_num, 1)
+        estimated_delta_rot_9d, in_process_sample = self.view_finder.next_best_view(
+            main_feat
+        )
+        result = {
+            "pred_pose_9d": estimated_delta_rot_9d,
+            "in_process_sample": in_process_sample,
+        }
+        return result
+
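+    # main_feat layout: [seq_feat | global_scanned_feat]. Each scanned view is
+    # embedded by mean-pooling the per-point features it contributed (selected via
+    # scanned_pts_mask) and concatenating its pose embedding; the sequence encoder
+    # then fuses the per-view embeddings into seq_feat.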
Log.info(f"scanned_pts_mask sum: {scanned_pts_mask.sum()}") + import ipdb + ipdb.set_trace() + Log.error("nan in main_feat", True) + + return main_feat \ No newline at end of file diff --git a/core/seq_dataset.py b/core/seq_dataset.py new file mode 100644 index 0000000..c7332b2 --- /dev/null +++ b/core/seq_dataset.py @@ -0,0 +1,209 @@ +import numpy as np +from PytorchBoot.dataset import BaseDataset +import PytorchBoot.namespace as namespace +import PytorchBoot.stereotype as stereotype +from PytorchBoot.config import ConfigManager +from PytorchBoot.utils.log_util import Log +import torch +import os +import sys + +sys.path.append(r"/media/hofee/data/project/python/nbv_reconstruction/nbv_reconstruction") + +from utils.data_load import DataLoadUtil +from utils.pose import PoseUtil +from utils.pts import PtsUtil + + +@stereotype.dataset("seq_reconstruction_dataset") +class SeqReconstructionDataset(BaseDataset): + def __init__(self, config): + super(SeqReconstructionDataset, self).__init__(config) + self.config = config + self.root_dir = config["root_dir"] + self.split_file_path = config["split_file"] + self.scene_name_list = self.load_scene_name_list() + self.datalist = self.get_datalist() + + self.pts_num = config["pts_num"] + self.type = config["type"] + self.cache = config.get("cache") + self.load_from_preprocess = config.get("load_from_preprocess", False) + + if self.type == namespace.Mode.TEST: + #self.model_dir = config["model_dir"] + self.filter_degree = config["filter_degree"] + if self.type == namespace.Mode.TRAIN: + scale_ratio = 1 + self.datalist = self.datalist*scale_ratio + if self.cache: + expr_root = ConfigManager.get("runner", "experiment", "root_dir") + expr_name = ConfigManager.get("runner", "experiment", "name") + self.cache_dir = os.path.join(expr_root, expr_name, "cache") + # self.preprocess_cache() + + def load_scene_name_list(self): + scene_name_list = [] + with open(self.split_file_path, "r") as f: + for line in f: + scene_name = line.strip() + if not os.path.exists(os.path.join(self.root_dir, scene_name)): + continue + scene_name_list.append(scene_name) + return scene_name_list + + def get_scene_name_list(self): + return self.scene_name_list + + + def get_datalist(self): + datalist = [] + total = len(self.scene_name_list) + for idx, scene_name in enumerate(self.scene_name_list): + print(f"processing {scene_name} ({idx}/{total})") + scene_max_cr_idx = 0 + frame_len = DataLoadUtil.get_scene_seq_length(self.root_dir, scene_name) + + for i in range(10,frame_len): + path = DataLoadUtil.get_path(self.root_dir, scene_name, i) + pts = DataLoadUtil.load_from_preprocessed_pts(path, "npy") + print(pts.shape) + if pts.shape[0] == 0: + continue + else: + break + print(i) + datalist.append({ + "scene_name": scene_name, + "first_frame": i, + "best_seq_len": -1, + "max_coverage_rate": 1.0, + "label_idx": scene_max_cr_idx, + }) + return datalist + + def preprocess_cache(self): + Log.info("preprocessing cache...") + for item_idx in range(len(self.datalist)): + self.__getitem__(item_idx) + Log.success("finish preprocessing cache.") + + def load_from_cache(self, scene_name, curr_frame_idx): + cache_name = f"{scene_name}_{curr_frame_idx}.txt" + cache_path = os.path.join(self.cache_dir, cache_name) + if os.path.exists(cache_path): + data = np.loadtxt(cache_path) + return data + else: + return None + + def save_to_cache(self, scene_name, curr_frame_idx, data): + cache_name = f"{scene_name}_{curr_frame_idx}.txt" + cache_path = os.path.join(self.cache_dir, cache_name) + try: + 
np.savetxt(cache_path, data) + except Exception as e: + Log.error(f"Save cache failed: {e}") + + def seq_combined_pts(self, scene, frame_idx_list): + all_combined_pts = [] + for i in frame_idx_list: + path = DataLoadUtil.get_path(self.root_dir, scene, i) + pts = DataLoadUtil.load_from_preprocessed_pts(path,"npy") + if pts.shape[0] == 0: + continue + all_combined_pts.append(pts) + all_combined_pts = np.vstack(all_combined_pts) + downsampled_all_pts = PtsUtil.voxel_downsample_point_cloud(all_combined_pts, 0.003) + return downsampled_all_pts + + def __getitem__(self, index): + data_item_info = self.datalist[index] + max_coverage_rate = data_item_info["max_coverage_rate"] + best_seq_len = data_item_info["best_seq_len"] + scene_name = data_item_info["scene_name"] + ( + scanned_views_pts, + scanned_coverages_rate, + scanned_n_to_world_pose, + ) = ([], [], []) + view = data_item_info["first_frame"] + frame_idx = view + view_path = DataLoadUtil.get_path(self.root_dir, scene_name, frame_idx) + cam_info = DataLoadUtil.load_cam_info(view_path, binocular=True) + + n_to_world_pose = cam_info["cam_to_world"] + target_point_cloud = ( + DataLoadUtil.load_from_preprocessed_pts(view_path) + ) + downsampled_target_point_cloud = PtsUtil.random_downsample_point_cloud( + target_point_cloud, self.pts_num + ) + scanned_views_pts.append(downsampled_target_point_cloud) + + n_to_world_6d = PoseUtil.matrix_to_rotation_6d_numpy( + np.asarray(n_to_world_pose[:3, :3]) + ) + first_left_cam_pose = cam_info["cam_to_world"] + first_center_cam_pose = cam_info["cam_to_world_O"] + first_O_to_first_L_pose = np.dot(np.linalg.inv(first_left_cam_pose), first_center_cam_pose) + n_to_world_trans = n_to_world_pose[:3, 3] + n_to_world_9d = np.concatenate([n_to_world_6d, n_to_world_trans], axis=0) + scanned_n_to_world_pose.append(n_to_world_9d) + + frame_list = [] + for i in range(DataLoadUtil.get_scene_seq_length(self.root_dir, scene_name)): + frame_list.append(i) + gt_pts = self.seq_combined_pts(scene_name, frame_list) + data_item = { + "first_scanned_pts": np.asarray(scanned_views_pts, dtype=np.float32), # Ndarray(S x Nv x 3) + "first_scanned_n_to_world_pose_9d": np.asarray(scanned_n_to_world_pose, dtype=np.float32), # Ndarray(S x 9) + "seq_max_coverage_rate": max_coverage_rate, # Float, range(0, 1) + "best_seq_len": best_seq_len, # Int + "scene_name": scene_name, # String + "gt_pts": gt_pts, # Ndarray(N x 3) + "scene_path": os.path.join(self.root_dir, scene_name), # String + "O_to_L_pose": first_O_to_first_L_pose, + } + + return data_item + + def __len__(self): + return len(self.datalist) + + +# -------------- Debug ---------------- # +if __name__ == "__main__": + import torch + from tqdm import tqdm + import pickle + import os + + seed = 0 + torch.manual_seed(seed) + np.random.seed(seed) + + config = { + "root_dir": "/media/hofee/data/data/test_bottle/view", + "source": "seq_reconstruction_dataset", + "split_file": "/media/hofee/data/data/test_bottle/test_bottle.txt", + "load_from_preprocess": True, + "filter_degree": 75, + "num_workers": 0, + "pts_num": 8192, + "type": namespace.Mode.TEST, + } + + output_dir = "/media/hofee/data/data/test_bottle/preprocessed_dataset" + os.makedirs(output_dir, exist_ok=True) + + ds = SeqReconstructionDataset(config) + for i in tqdm(range(len(ds)), desc="processing dataset"): + output_path = os.path.join(output_dir, f"item_{i}.pkl") + item = ds.__getitem__(i) + for key, value in item.items(): + if isinstance(value, np.ndarray): + item[key] = value.tolist() + #import ipdb; ipdb.set_trace() + with 
open(output_path, "wb") as f: + pickle.dump(item, f) diff --git a/core/seq_dataset_preprocessed.py b/core/seq_dataset_preprocessed.py new file mode 100644 index 0000000..8b0ef99 --- /dev/null +++ b/core/seq_dataset_preprocessed.py @@ -0,0 +1,82 @@ +import numpy as np +from PytorchBoot.dataset import BaseDataset +import PytorchBoot.namespace as namespace +import PytorchBoot.stereotype as stereotype +from PytorchBoot.config import ConfigManager +from PytorchBoot.utils.log_util import Log +import pickle +import torch +import os +import sys + +sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction") + +from utils.data_load import DataLoadUtil +from utils.pose import PoseUtil +from utils.pts import PtsUtil + +@stereotype.dataset("seq_reconstruction_dataset_preprocessed") +class SeqReconstructionDatasetPreprocessed(BaseDataset): + def __init__(self, config): + super(SeqReconstructionDatasetPreprocessed, self).__init__(config) + self.config = config + self.root_dir = config["root_dir"] + self.real_root_dir = r"/media/hofee/repository/final_test_set/view" + self.item_list = os.listdir(self.root_dir) + + def __getitem__(self, index): + data = pickle.load(open(os.path.join(self.root_dir, self.item_list[index]), "rb")) + data_item = { + "first_scanned_pts": np.asarray(data["first_scanned_pts"], dtype=np.float32), # Ndarray(S x Nv x 3) + "first_scanned_n_to_world_pose_9d": np.asarray(data["first_scanned_n_to_world_pose_9d"], dtype=np.float32), # Ndarray(S x 9) + "seq_max_coverage_rate": data["seq_max_coverage_rate"], # Float, range(0, 1) + "best_seq_len": data["best_seq_len"], # Int + "scene_name": data["scene_name"], # String + "gt_pts": np.asarray(data["gt_pts"], dtype=np.float32), # Ndarray(N x 3) + "scene_path": os.path.join(self.real_root_dir, data["scene_name"]), # String + "O_to_L_pose": np.asarray(data["O_to_L_pose"], dtype=np.float32), + } + return data_item + + def __len__(self): + return len(self.item_list) + +# -------------- Debug ---------------- # +if __name__ == "__main__": + import torch + + seed = 0 + torch.manual_seed(seed) + np.random.seed(seed) + ''' + OmniObject3d_test: + root_dir: "H:\\AI\\Datasets\\packed_test_data" + model_dir: "H:\\AI\\Datasets\\scaled_object_meshes" + source: seq_reconstruction_dataset + split_file: "H:\\AI\\Datasets\\data_list\\OmniObject3d_test.txt" + type: test + filter_degree: 75 + eval_list: + - pose_diff + - coverage_rate_increase + ratio: 0.1 + batch_size: 1 + num_workers: 12 + pts_num: 8192 + load_from_preprocess: True + ''' + config = { + "root_dir": "/media/hofee/data/data/test_bottle/preprocessed_dataset", + "source": "seq_reconstruction_dataset", + "split_file": "H:\\AI\\Datasets\\data_list\\OmniObject3d_test.txt", + "load_from_preprocess": True, + "ratio": 1, + "filter_degree": 75, + "num_workers": 0, + "pts_num": 8192, + "type": "test", + } + ds = SeqReconstructionDataset(config) + print(len(ds)) + print(ds.__getitem__(10)) + diff --git a/modules/__pycache__/gf_view_finder.cpython-39.pyc b/modules/__pycache__/gf_view_finder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cfac241af2fd55b42a9372763767c350ce8b0d1 GIT binary patch literal 4128 zcmbVP&668P6`!6NjYe8Yt6lG>?W`3;2m-Tf!e^OaK9bl0!eX3FfFY@=S#@i7WNBvI z({gqxieo7b+?V6vH8?D;`y%$IvJK?S_S-Mdy$74ERc3nO;j;vO&FwRln3o!MjOm<4q)$NgLO zn@rlg#LIk!S9q1z`0QI=FfXj(8D8geBYU{O=SM79981njB}-$;!c?*>>SF$wjaRIQ zMX@B7`Qlqnuqsx>D)u}p&WW>~^T&)?VrBT4SRAoe>|iap0Qp+yytsfdIxTzAYA&77 z4Q)1=_HU$_cuB?G7H((y*x($rXV9`WkbyO{M%J^|tMwx`U_&;thW5x})@^@aK_5GO 
diff --git a/modules/__pycache__/mlp_view_finder.cpython-39.pyc b/modules/__pycache__/mlp_view_finder.cpython-39.pyc
new file mode 100644
Binary files /dev/null and b/modules/__pycache__/mlp_view_finder.cpython-39.pyc differ
diff --git a/modules/__pycache__/pointnet++_encoder.cpython-39.pyc b/modules/__pycache__/pointnet++_encoder.cpython-39.pyc
new file mode 100644
Binary files /dev/null and b/modules/__pycache__/pointnet++_encoder.cpython-39.pyc differ
diff --git a/modules/__pycache__/pointnet_encoder.cpython-39.pyc b/modules/__pycache__/pointnet_encoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..49a399098fdbfe8315ca813e7c0b4a67d6267762
Binary files /dev/null and b/modules/__pycache__/pointnet_encoder.cpython-39.pyc differ
diff --git a/modules/__pycache__/pose_encoder.cpython-39.pyc b/modules/__pycache__/pose_encoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c03aab8e23e22196d97bfc553093cd87cc212295
Binary files /dev/null and b/modules/__pycache__/pose_encoder.cpython-39.pyc differ
diff --git a/modules/__pycache__/pts_num_encoder.cpython-39.pyc b/modules/__pycache__/pts_num_encoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1d49be6b58850c8fd8b1a5f1ded5228ca0fdf275
Binary files /dev/null and b/modules/__pycache__/pts_num_encoder.cpython-39.pyc differ
diff --git a/modules/__pycache__/transformer_seq_encoder.cpython-39.pyc b/modules/__pycache__/transformer_seq_encoder.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9fb466432ed66fc282530ea6fa0e54e172c2448
Binary files /dev/null and b/modules/__pycache__/transformer_seq_encoder.cpython-39.pyc differ
diff --git a/modules/func_lib/__init__.py b/modules/func_lib/__init__.py
new file mode 100644
index 0000000..d139a11
--- /dev/null
+++ b/modules/func_lib/__init__.py
@@ -0,0 +1,6 @@
+from modules.func_lib.samplers import (
+    cond_ode_sampler
+)
+from modules.func_lib.sde import (
+    init_sde
+)
diff --git a/modules/func_lib/__pycache__/__init__.cpython-39.pyc b/modules/func_lib/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7ca34e313d34ecb51be837eda7e3ebab4135df94
Binary files /dev/null and b/modules/func_lib/__pycache__/__init__.cpython-39.pyc differ
diff --git a/modules/func_lib/__pycache__/samplers.cpython-39.pyc b/modules/func_lib/__pycache__/samplers.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..40b9251a5cc33bef8554f437ee38c736b28cae38
Binary files /dev/null and b/modules/func_lib/__pycache__/samplers.cpython-39.pyc differ
diff --git a/modules/func_lib/__pycache__/sde.cpython-39.pyc b/modules/func_lib/__pycache__/sde.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7bdebc38ab9e748bdee7ee39a3e2ece57207cbba
Binary files /dev/null and b/modules/func_lib/__pycache__/sde.cpython-39.pyc differ
diff --git a/modules/func_lib/samplers.py b/modules/func_lib/samplers.py
new file mode 100644
--- /dev/null
+++ b/modules/func_lib/samplers.py
+import numpy as np
+import torch
+from scipy import integrate
+
+from utils.pose import PoseUtil
+
+
+def cond_ode_sampler(
+    score_model,
+    data,
+    prior,
+    sde_coeff,
+    atol=1e-5,
+    rtol=1e-5,
+    device="cuda",
+    eps=1e-5,
+    T=1.0,
+    num_steps=None,
+    pose_mode="rot_matrix",
+    denoise=True,
+    init_x=None,
+):
+    # signature and setup follow the call site in GradientFieldViewFinder.sample
+    pose_dim = PoseUtil.get_pose_dim(pose_mode)
+    batch_size = data["main_feat"].shape[0]
+    init_x = (
+        prior((batch_size, pose_dim), T=T).to(device)
+        if init_x is None
+        else init_x + prior((batch_size, pose_dim), T=T).to(device)
+    )
+    shape = init_x.shape
+
+    def score_eval_wrapper(data):
+        """A wrapper of the score-based model for use by the ODE solver."""
+        with torch.no_grad():
+            score = score_model(data)
+        return score.cpu().numpy().reshape((-1,))
+
+    def ode_func(t, x):
+        """The ODE function for use by the ODE solver."""
+        x = torch.tensor(x.reshape(-1, pose_dim), dtype=torch.float32, device=device)
+        time_steps = torch.ones(batch_size, device=device).unsqueeze(-1) * t
+        drift, diffusion = sde_coeff(torch.tensor(t))
+        drift = drift.cpu().numpy()
+        diffusion = diffusion.cpu().numpy()
+        data["sampled_pose"] = x
+        data["t"] = time_steps
+        grad = score_eval_wrapper(data)
+        # probability-flow ODE: dx/dt = f(x, t) - 0.5 * g(t)^2 * score
+        return drift - 0.5 * diffusion**2 * grad
+
+    # run the black-box ODE solver from t=T down to t=eps
+    t_eval = None
+    if num_steps is not None:
+        # evaluate at num_steps points, from T -> eps
+        t_eval = np.linspace(T, eps, num_steps)
+    res = integrate.solve_ivp(
+        ode_func,
+        (T, eps),
+        init_x.reshape(-1).cpu().numpy(),
+        rtol=rtol,
+        atol=atol,
+        method="RK45",
+        t_eval=t_eval,
+    )
+    xs = torch.tensor(res.y, device=device).T.view(
+        -1, batch_size, pose_dim
+    )  # [num_steps, bs, pose_dim]
+    x = torch.tensor(res.y[:, -1], device=device).reshape(shape)  # [bs, pose_dim]
+    # denoise, using the predictor step in P-C sampler
+    if denoise:
+        # Reverse diffusion predictor for denoising
+        vec_eps = torch.ones((x.shape[0], 1), device=x.device) * eps
+        drift, diffusion = sde_coeff(vec_eps)
+        data["sampled_pose"] = x.float()
+        data["t"] = vec_eps
+        grad = score_model(data)
+        drift = drift - diffusion**2 * grad  # R-SDE
+        mean_x = x + drift * ((1 - eps) / (1000 if num_steps is None else num_steps))
+        x = mean_x
+
+    num_steps = xs.shape[0]
+    xs = xs.reshape(batch_size * num_steps, -1)
+    xs[:, :-3] = PoseUtil.normalize_rotation(xs[:, :-3], pose_mode)
+    xs = xs.reshape(num_steps, batch_size, -1)
+    x[:, :-3] = PoseUtil.normalize_rotation(x[:, :-3], pose_mode)
+    return xs.permute(1, 0, 2), x
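+# typical call site: GradientFieldViewFinder.sample in modules/gf_view_finder.py,
+# which receives (trajectory, final_pose); after the final permute the trajectory
+# has shape (bs, num_steps, pose_dim) and final_pose has shape (bs, pose_dim)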
diff --git a/modules/func_lib/sde.py b/modules/func_lib/sde.py
new file mode 100644
index 0000000..d93c999
--- /dev/null
+++ b/modules/func_lib/sde.py
@@ -0,0 +1,121 @@
+import functools
+import torch
+import numpy as np
+
+
+# ----- VE SDE -----
+# ------------------
+def ve_marginal_prob(x, t, sigma_min=0.01, sigma_max=90):
+    std = sigma_min * (sigma_max / sigma_min) ** t
+    mean = x
+    return mean, std
+
+
+def ve_sde(t, sigma_min=0.01, sigma_max=90):
+    sigma = sigma_min * (sigma_max / sigma_min) ** t
+    drift_coeff = torch.tensor(0)
+    diffusion_coeff = sigma * torch.sqrt(torch.tensor(2 * (np.log(sigma_max) - np.log(sigma_min)), device=t.device))
+    return drift_coeff, diffusion_coeff
+
+
+def ve_prior(shape, sigma_min=0.01, sigma_max=90, T=1.0):
+    _, sigma_max_prior = ve_marginal_prob(None, T, sigma_min=sigma_min, sigma_max=sigma_max)
+    return torch.randn(*shape) * sigma_max_prior
+
+
+# ----- VP SDE -----
+# ------------------
+def vp_marginal_prob(x, t, beta_0=0.1, beta_1=20):
+    log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
+    mean = torch.exp(log_mean_coeff) * x
+    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
+    return mean, std
+
+
+def vp_sde(t, beta_0=0.1, beta_1=20):
+    beta_t = beta_0 + t * (beta_1 - beta_0)
+    drift_coeff = -0.5 * beta_t
+    diffusion_coeff = torch.sqrt(beta_t)
+    return drift_coeff, diffusion_coeff
+
+
+def vp_prior(shape, beta_0=0.1, beta_1=20):
+    return torch.randn(*shape)
+
+
+# ----- sub-VP SDE -----
+# ----------------------
+def subvp_marginal_prob(x, t, beta_0, beta_1):
+    log_mean_coeff = -0.25 * t ** 2 * (beta_1 - beta_0) - 0.5 * t * beta_0
+    mean = torch.exp(log_mean_coeff) * x
+    std = 1 - torch.exp(2. * log_mean_coeff)
+    return mean, std
+
+
+def subvp_sde(t, beta_0, beta_1):
+    beta_t = beta_0 + t * (beta_1 - beta_0)
+    drift_coeff = -0.5 * beta_t
+    discount = 1. - torch.exp(-2 * beta_0 * t - (beta_1 - beta_0) * t ** 2)
+    diffusion_coeff = torch.sqrt(beta_t * discount)
+    return drift_coeff, diffusion_coeff
+
+
+def subvp_prior(shape, beta_0=0.1, beta_1=20):
+    return torch.randn(*shape)
+
+
+# ----- EDM SDE -----
+# ------------------
+def edm_marginal_prob(x, t, sigma_min=0.002, sigma_max=80):
+    std = t
+    mean = x
+    return mean, std
+
+
+def edm_sde(t, sigma_min=0.002, sigma_max=80):
+    drift_coeff = torch.tensor(0)
+    diffusion_coeff = torch.sqrt(2 * t)
+    return drift_coeff, diffusion_coeff
+
+
+def edm_prior(shape, sigma_min=0.002, sigma_max=80):
+    return torch.randn(*shape) * sigma_max
+
+
+def init_sde(sde_mode):
+    # the SDE-related hyperparameters are copied from https://github.com/yang-song/score_sde_pytorch
+    # note: 'edm' integrates over t in [eps, sigma_max], the others over t in [eps, 1.0]
+    if sde_mode == 'edm':
+        sigma_min = 0.002
+        sigma_max = 80
+        eps = 0.002
+        prior_fn = functools.partial(edm_prior, sigma_min=sigma_min, sigma_max=sigma_max)
+        marginal_prob_fn = functools.partial(edm_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
+        sde_fn = functools.partial(edm_sde, sigma_min=sigma_min, sigma_max=sigma_max)
+        T = sigma_max
+    elif sde_mode == 've':
+        sigma_min = 0.01
+        sigma_max = 50
+        eps = 1e-5
+        marginal_prob_fn = functools.partial(ve_marginal_prob, sigma_min=sigma_min, sigma_max=sigma_max)
+        sde_fn = functools.partial(ve_sde, sigma_min=sigma_min, sigma_max=sigma_max)
+        T = 1.0
+        prior_fn = functools.partial(ve_prior, sigma_min=sigma_min, sigma_max=sigma_max)
+    elif sde_mode == 'vp':
+        beta_0 = 0.1
+        beta_1 = 20
+        eps = 1e-3
+        prior_fn = functools.partial(vp_prior, beta_0=beta_0, beta_1=beta_1)
+        marginal_prob_fn = functools.partial(vp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
+        sde_fn = functools.partial(vp_sde, beta_0=beta_0, beta_1=beta_1)
+        T = 1.0
+    elif sde_mode == 'subvp':
+        beta_0 = 0.1
+        beta_1 = 20
+        eps = 1e-3
+        prior_fn = functools.partial(subvp_prior, beta_0=beta_0, beta_1=beta_1)
+        marginal_prob_fn = functools.partial(subvp_marginal_prob, beta_0=beta_0, beta_1=beta_1)
+        sde_fn = functools.partial(subvp_sde, beta_0=beta_0, beta_1=beta_1)
+        T = 1.0
+    else:
+        raise NotImplementedError
+    return prior_fn, marginal_prob_fn, sde_fn, eps, T
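+# quick sanity check for the 've' branch above: init_sde('ve') returns a
+# marginal_prob_fn with std = 0.01 * (50 / 0.01) ** t, so std(0.0) == 0.01 and
+# std(1.0) == 50, matching the prior scale used when sampling starts at t = T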
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +@stereotype.module("gf_view_finder") +class GradientFieldViewFinder(nn.Module): + def __init__(self, config): + + super(GradientFieldViewFinder, self).__init__() + + self.regression_head = config["regression_head"] + self.per_point_feature = config["per_point_feature"] + self.act = nn.ReLU(True) + self.sample_mode = config["sample_mode"] + self.pose_mode = config["pose_mode"] + pose_dim = PoseUtil.get_pose_dim(self.pose_mode) + self.prior_fn, self.marginal_prob_fn, self.sde_fn, self.sampling_eps, self.T = flib.init_sde(config["sde_mode"]) + self.sampling_steps = config["sampling_steps"] + self.t_feat_dim = config["t_feat_dim"] + self.pose_feat_dim = config["pose_feat_dim"] + self.main_feat_dim = config["main_feat_dim"] + + ''' encode pose ''' + self.pose_encoder = nn.Sequential( + nn.Linear(pose_dim, self.pose_feat_dim ), + self.act, + nn.Linear(self.pose_feat_dim , self.pose_feat_dim ), + self.act, + ) + + ''' encode t ''' + self.t_encoder = nn.Sequential( + mlib.GaussianFourierProjection(embed_dim=self.t_feat_dim ), + nn.Linear(self.t_feat_dim , self.t_feat_dim ), + self.act, + ) + + ''' fusion tail ''' + if self.regression_head == 'Rx_Ry_and_T': + if self.pose_mode != 'rot_matrix': + raise NotImplementedError + if not self.per_point_feature: + ''' rotation_x_axis regress head ''' + self.fusion_tail_rot_x = nn.Sequential( + nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256), + self.act, + zero_module(nn.Linear(256, 3)), + ) + self.fusion_tail_rot_y = nn.Sequential( + nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256), + self.act, + zero_module(nn.Linear(256, 3)), + ) + ''' tranalation regress head ''' + self.fusion_tail_trans = nn.Sequential( + nn.Linear(self.t_feat_dim + self.pose_feat_dim + self.main_feat_dim, 256), + self.act, + zero_module(nn.Linear(256, 3)), + ) + else: + raise NotImplementedError + else: + raise NotImplementedError + + def forward(self, data): + """ + Args: + data, dict { + 'main_feat': [bs, c] + 'pose_sample': [bs, pose_dim] + 't': [bs, 1] + } + """ + + main_feat = data['main_feat'] + sampled_pose = data['sampled_pose'] + t = data['t'] + t_feat = self.t_encoder(t.squeeze(1)) + pose_feat = self.pose_encoder(sampled_pose) + + if self.per_point_feature: + raise NotImplementedError + else: + total_feat = torch.cat([main_feat, t_feat, pose_feat], dim=-1) + _, std = self.marginal_prob_fn(total_feat, t) + + if self.regression_head == 'Rx_Ry_and_T': + rot_x = self.fusion_tail_rot_x(total_feat) + rot_y = self.fusion_tail_rot_y(total_feat) + trans = self.fusion_tail_trans(total_feat) + out_score = torch.cat([rot_x, rot_y, trans], dim=-1) / (std+1e-7) # normalisation + else: + raise NotImplementedError + + return out_score + + def marginal_prob(self, x, t): + return self.marginal_prob_fn(x,t) + + def sample(self, data, atol=1e-5, rtol=1e-5, snr=0.16, denoise=True, init_x=None, T0=None): + + if self.sample_mode == 'ode': + T0 = self.T if T0 is None else T0 + in_process_sample, res = flib.cond_ode_sampler( + score_model=self, + data=data, + prior=self.prior_fn, + sde_coeff=self.sde_fn, + atol=atol, + rtol=rtol, + eps=self.sampling_eps, + T=T0, + num_steps=self.sampling_steps, + pose_mode=self.pose_mode, + denoise=denoise, + init_x=init_x + ) + else: + raise NotImplementedError + + return in_process_sample, res + + def next_best_view(self, main_feat): + data = { + 'main_feat': main_feat, + } + in_process_sample, res = self.sample(data) + return 
+    def next_best_view(self, main_feat):
+        data = {
+            'main_feat': main_feat,
+        }
+        in_process_sample, res = self.sample(data)
+        return res.to(dtype=torch.float32), in_process_sample
+
+
+''' ----------- DEBUG -----------'''
+if __name__ == "__main__":
+    config = {
+        "regression_head": "Rx_Ry_and_T",
+        "per_point_feature": False,
+        "pose_mode": "rot_matrix",
+        "sde_mode": "ve",
+        "sampling_steps": 500,
+        "sample_mode": "ode",
+        # the three feature dims are required by __init__; the values here are
+        # arbitrary smoke-test sizes matching the random tensors below
+        "t_feat_dim": 128,
+        "pose_feat_dim": 256,
+        "main_feat_dim": 2048,
+    }
+    test_main_feat = torch.rand(32, 2048).to("cuda:0")
+    test_pose = torch.rand(32, 9).to("cuda:0")
+    test_t = torch.rand(32, 1).to("cuda:0")
+    view_finder = GradientFieldViewFinder(config).to("cuda:0")
+    test_data = {
+        'main_feat': test_main_feat,
+        'sampled_pose': test_pose,
+        't': test_t
+    }
+    score = view_finder(test_data)
+    print(score.shape)
+    res, in_process = view_finder.next_best_view(test_main_feat)
+    print(res.shape, in_process.shape)
diff --git a/modules/mlp_view_finder.py b/modules/mlp_view_finder.py
new file mode 100644
index 0000000..7aab553
--- /dev/null
+++ b/modules/mlp_view_finder.py
@@ -0,0 +1,91 @@
+import torch
+import torch.nn as nn
+import PytorchBoot.stereotype as stereotype
+
+from utils.pose import PoseUtil
+import modules.module_lib as mlib
+import modules.func_lib as flib
+
+
+def zero_module(module):
+    """
+    Zero out the parameters of a module and return it.
+    """
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+@stereotype.module("mlp_view_finder")
+class MLPViewFinder(nn.Module):
+    def __init__(self, config):
+
+        super(MLPViewFinder, self).__init__()
+
+        self.regression_head = 'Rx_Ry_and_T'
+        self.per_point_feature = False
+        self.act = nn.ReLU(True)
+        self.main_feat_dim = config["main_feat_dim"]
+
+        ''' rotation_x_axis regress head '''
+        self.fusion_tail_rot_x = nn.Sequential(
+            nn.Linear(self.main_feat_dim, 256),
+            self.act,
+            zero_module(nn.Linear(256, 3)),
+        )
+        self.fusion_tail_rot_y = nn.Sequential(
+            nn.Linear(self.main_feat_dim, 256),
+            self.act,
+            zero_module(nn.Linear(256, 3)),
+        )
+        ''' translation regress head '''
+        self.fusion_tail_trans = nn.Sequential(
+            nn.Linear(self.main_feat_dim, 256),
+            self.act,
+            zero_module(nn.Linear(256, 3)),
+        )
+
+    def forward(self, data):
+        """
+        Args:
+            data, dict {
+                'main_feat': [bs, c]
+            }
+        """
+
+        total_feat = data['main_feat']
+        rot_x = self.fusion_tail_rot_x(total_feat)
+        rot_y = self.fusion_tail_rot_y(total_feat)
+        trans = self.fusion_tail_trans(total_feat)
+        output = torch.cat([rot_x, rot_y, trans], dim=-1)
+        return output
+
+    def next_best_view(self, main_feat):
+        data = {
+            'main_feat': main_feat,
+        }
+        res = self(data)
+        return res.to(dtype=torch.float32), None
+
+
+''' ----------- DEBUG -----------'''
+if __name__ == "__main__":
+    # smoke test: a single regression pass (no diffusion sampling for the MLP head)
+    config = {
+        "main_feat_dim": 2048,
+    }
+    test_main_feat = torch.rand(32, 2048).to("cuda:0")
+    view_finder = MLPViewFinder(config).to("cuda:0")
+    test_data = {
+        'main_feat': test_main_feat,
+    }
+    pose = view_finder(test_data)
+    print(pose.shape)
+    res, _ = view_finder.next_best_view(test_main_feat)
+    print(res.shape)
diff --git a/modules/module_lib/__init__.py b/modules/module_lib/__init__.py
new file mode 100644
index 0000000..0d9f7bf
--- /dev/null
+++ b/modules/module_lib/__init__.py
@@ -0,0 +1,2 @@
+from modules.module_lib.gaussian_fourier_projection import GaussianFourierProjection
+from modules.module_lib.linear import Linear
\ No newline at end of file
diff --git a/modules/module_lib/__pycache__/__init__.cpython-39.pyc b/modules/module_lib/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5f8d897c357aad6341bbf22c364a1b65c4ce67b9
Binary files /dev/null and b/modules/module_lib/__pycache__/__init__.cpython-39.pyc differ
diff --git a/modules/module_lib/__pycache__/gaussian_fourier_projection.cpython-39.pyc b/modules/module_lib/__pycache__/gaussian_fourier_projection.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b738059ba628a89521af1689e2c5f84c93605a6b
Binary files /dev/null and b/modules/module_lib/__pycache__/gaussian_fourier_projection.cpython-39.pyc differ
diff --git a/modules/module_lib/__pycache__/linear.cpython-39.pyc b/modules/module_lib/__pycache__/linear.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dc817db604a29c225b2d3455ee96e8966e8e6c8c
Binary files /dev/null and b/modules/module_lib/__pycache__/linear.cpython-39.pyc differ
diff --git a/modules/module_lib/__pycache__/pointnet2_modules.cpython-39.pyc b/modules/module_lib/__pycache__/pointnet2_modules.cpython-39.pyc
new file mode 100644
Binary files /dev/null and b/modules/module_lib/__pycache__/pointnet2_modules.cpython-39.pyc differ
diff --git a/modules/module_lib/__pycache__/pointnet2_utils.cpython-39.pyc b/modules/module_lib/__pycache__/pointnet2_utils.cpython-39.pyc
new file mode 100644
Binary files /dev/null and b/modules/module_lib/__pycache__/pointnet2_utils.cpython-39.pyc differ
diff --git a/modules/module_lib/__pycache__/pytorch_utils.cpython-39.pyc b/modules/module_lib/__pycache__/pytorch_utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b39b331e9aa94eecbb8367faad72a1d21cb286a9
Binary files /dev/null and b/modules/module_lib/__pycache__/pytorch_utils.cpython-39.pyc differ
diff --git a/modules/module_lib/pointnet2_modules.py b/modules/module_lib/pointnet2_modules.py
new file mode 100644
--- /dev/null
+++ b/modules/module_lib/pointnet2_modules.py
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List
+
+from . import pointnet2_utils
+from . import pytorch_utils as pt_utils
+
+
+class _PointnetSAModuleBase(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.npoint = None
+        self.groupers = None
+        self.mlps = None
+        self.pool_method = 'max_pool'
+
+    def forward(self, xyz: torch.Tensor, features: torch.Tensor = None, new_xyz=None) -> (torch.Tensor, torch.Tensor):
+        """
+        :param xyz: (B, N, 3) tensor of the xyz coordinates of the features
+        :param features: (B, N, C) tensor of the descriptors of the features
+        :param new_xyz:
+        :return:
+            new_xyz: (B, npoint, 3) tensor of the new features' xyz
+            new_features: (B, npoint, \sum_k(mlps[k][-1])) tensor of the new_features descriptors
+        """
+        new_features_list = []
+
+        xyz_flipped = xyz.transpose(1, 2).contiguous()
+        if new_xyz is None:
+            new_xyz = pointnet2_utils.gather_operation(
+                xyz_flipped,
+                pointnet2_utils.furthest_point_sample(xyz, self.npoint)
+            ).transpose(1, 2).contiguous() if self.npoint is not None else None
+
+        for i in range(len(self.groupers)):
+            new_features = self.groupers[i](xyz, new_xyz, features)  # (B, C, npoint, nsample)
+
+            new_features = self.mlps[i](new_features)  # (B, mlp[-1], npoint, nsample)
+
+            if self.pool_method == 'max_pool':
+                new_features = F.max_pool2d(
+                    new_features, kernel_size=[1, new_features.size(3)]
+                )  # (B, mlp[-1], npoint, 1)
+            elif self.pool_method == 'avg_pool':
+                new_features = F.avg_pool2d(
+                    new_features, kernel_size=[1, new_features.size(3)]
+                )  # (B, mlp[-1], npoint, 1)
+            else:
+                raise NotImplementedError
+
+            new_features = new_features.squeeze(-1)  # (B, mlp[-1], npoint)
+            new_features_list.append(new_features)
+
+        return new_xyz, torch.cat(new_features_list, dim=1)
+
+
+class PointnetSAModuleMSG(_PointnetSAModuleBase):
+    """Pointnet set abstraction layer with multiscale grouping"""
+
+    def __init__(self, *, npoint: int, radii: List[float], nsamples: List[int], mlps: List[List[int]], bn: bool = True,
+                 use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
+        """
+        :param npoint: int
+        :param radii: list of float, list of radii to group with
+        :param nsamples: list of int, number of samples in each ball query
+        :param mlps: list of list of int, spec of the pointnet before the global pooling for each scale
+        :param bn: whether to use batchnorm
+        :param use_xyz:
+        :param pool_method: max_pool / avg_pool
+        :param instance_norm: whether to use instance_norm
+        """
+        super().__init__()
+
+        assert len(radii) == len(nsamples) == len(mlps)
+
+        self.npoint = npoint
+        self.groupers = nn.ModuleList()
+        self.mlps = nn.ModuleList()
+        for i in range(len(radii)):
+            radius = radii[i]
+            nsample = nsamples[i]
+            self.groupers.append(
+                pointnet2_utils.QueryAndGroup(radius, nsample, use_xyz=use_xyz)
+                if npoint is not None else pointnet2_utils.GroupAll(use_xyz)
+            )
+            mlp_spec = mlps[i]
+            if use_xyz:
+                mlp_spec[0] += 3
+
+            self.mlps.append(pt_utils.SharedMLP(mlp_spec, bn=bn, instance_norm=instance_norm))
+        self.pool_method = pool_method
+
+
+class PointnetSAModule(PointnetSAModuleMSG):
+    """Pointnet set abstraction layer"""
+
+    def __init__(self, *, mlp: List[int], npoint: int = None, radius: float = None, nsample: int = None,
+                 bn: bool = True, use_xyz: bool = True, pool_method='max_pool', instance_norm=False):
+        """
+        :param mlp: list of int, spec of the pointnet before the global max_pool
+        :param npoint: int, number of features
+        :param radius: float, radius of ball
+        :param nsample: int, number of samples in the ball query
+        :param bn: whether to use batchnorm
+        :param use_xyz:
+        :param pool_method: max_pool / avg_pool
+        :param instance_norm: whether to use instance_norm
+        """
+        super().__init__(
+            mlps=[mlp], npoint=npoint, radii=[radius], nsamples=[nsample], bn=bn, use_xyz=use_xyz,
+            pool_method=pool_method, instance_norm=instance_norm
+        )
+
+
+class PointnetFPModule(nn.Module):
+    r"""Propagates the features of one set to another"""
+
+    def __init__(self, *, mlp: List[int], bn: bool = True):
+        """
+        :param mlp: list of int
+        :param bn: whether to use batchnorm
+        """
+        super().__init__()
+        self.mlp = pt_utils.SharedMLP(mlp, bn=bn)
+
+    def forward(
+            self, unknown: torch.Tensor, known: torch.Tensor, unknow_feats: torch.Tensor, known_feats: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        :param unknown: (B, n, 3) tensor of the xyz positions of the unknown features
+        :param known: (B, m, 3) tensor of the xyz positions of the known features
+        :param unknow_feats: (B, C1, n) tensor of the features to be propagated to
+        :param known_feats: (B, C2, m) tensor of features to be propagated
+        :return:
+            new_features: (B, mlp[-1], n) tensor of the features of the unknown features
+        """
+        if known is not None:
+            dist, idx = pointnet2_utils.three_nn(unknown, known)
+            dist_recip = 1.0 / (dist + 1e-8)
+            norm = torch.sum(dist_recip, dim=2, keepdim=True)
+            weight = dist_recip / norm
+
+            interpolated_feats = pointnet2_utils.three_interpolate(known_feats, idx, weight)
+        else:
+            interpolated_feats = known_feats.expand(*known_feats.size()[0:2], unknown.size(1))
+
+        if unknow_feats is not None:
+            new_features = torch.cat([interpolated_feats, unknow_feats], dim=1)  # (B, C2 + C1, n)
+        else:
+            new_features = interpolated_feats
+
+        new_features = new_features.unsqueeze(-1)
+        new_features = self.mlp(new_features)
+
+        return new_features.squeeze(-1)
+
+
+if __name__ == "__main__":
+    pass
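The set-abstraction stack above is driven entirely by the CUDA ops in pointnet2_utils, so it only runs on a GPU. A minimal smoke test, assuming the compiled pointnet2_cuda extension is importable and a CUDA device is available (batch and point counts are arbitrary; with features=None each mlps spec starts at 0 channels and use_xyz adds 3 in place):

    import torch
    from modules.module_lib.pointnet2_modules import PointnetSAModuleMSG

    sa = PointnetSAModuleMSG(
        npoint=128, radii=[0.1, 0.2], nsamples=[16, 32],
        mlps=[[0, 32], [0, 64]], use_xyz=True,  # specs become [3, 32] / [3, 64]
    ).cuda()
    xyz = torch.randn(2, 1024, 3).cuda()        # (B, N, 3) input coordinates
    new_xyz, feats = sa(xyz)                    # FPS centroids + pooled features
    print(new_xyz.shape, feats.shape)           # (2, 128, 3), (2, 96, 128)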
diff --git a/modules/module_lib/pointnet2_utils.py b/modules/module_lib/pointnet2_utils.py
new file mode 100644
index 0000000..97a5466
--- /dev/null
+++ b/modules/module_lib/pointnet2_utils.py
@@ -0,0 +1,291 @@
+import torch
+from torch.autograd import Variable
+from torch.autograd import Function
+import torch.nn as nn
+from typing import Tuple
+import sys
+
+import pointnet2_cuda as pointnet2
+
+
+class FurthestPointSampling(Function):
+    @staticmethod
+    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
+        """
+        Uses iterative furthest point sampling to select a set of npoint features that have the largest
+        minimum distance
+        :param ctx:
+        :param xyz: (B, N, 3) where N > npoint
+        :param npoint: int, number of features in the sampled set
+        :return:
+            output: (B, npoint) tensor containing the set
+        """
+        assert xyz.is_contiguous()
+
+        B, N, _ = xyz.size()
+        output = torch.cuda.IntTensor(B, npoint)
+        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)
+
+        pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, a=None):
+        return None, None
+
+
+furthest_point_sample = FurthestPointSampling.apply
+
+
+class GatherOperation(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param features: (B, C, N)
+        :param idx: (B, npoint) index tensor of the features to gather
+        :return:
+            output: (B, C, npoint)
+        """
+        assert features.is_contiguous()
+        assert idx.is_contiguous()
+
+        B, npoint = idx.size()
+        _, C, N = features.size()
+        output = torch.cuda.FloatTensor(B, C, npoint)
+
+        pointnet2.gather_points_wrapper(B, C, N, npoint, features, idx, output)
+
+        ctx.for_backwards = (idx, C, N)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        idx, C, N = ctx.for_backwards
+        B, npoint = idx.size()
+
+        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
+        grad_out_data = grad_out.data.contiguous()
+        pointnet2.gather_points_grad_wrapper(B, C, N, npoint, grad_out_data, idx, grad_features.data)
+        return grad_features, None
+
+
+gather_operation = GatherOperation.apply
+
+
+class ThreeNN(Function):
+
+    @staticmethod
+    def forward(ctx, unknown: torch.Tensor, known: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Find the three nearest neighbors of unknown in known
+        :param ctx:
+        :param unknown: (B, N, 3)
+        :param known: (B, M, 3)
+        :return:
+            dist: (B, N, 3) l2 distance to the three nearest neighbors
+            idx: (B, N, 3) index of 3 nearest neighbors
+        """
+        assert unknown.is_contiguous()
+        assert known.is_contiguous()
+
+        B, N, _ = unknown.size()
+        m = known.size(1)
+        dist2 = torch.cuda.FloatTensor(B, N, 3)
+        idx = torch.cuda.IntTensor(B, N, 3)
+
+        pointnet2.three_nn_wrapper(B, N, m, unknown, known, dist2, idx)
+        return torch.sqrt(dist2), idx
+
+    @staticmethod
+    def backward(ctx, a=None, b=None):
+        return None, None
+
+
+three_nn = ThreeNN.apply
+
+
+class ThreeInterpolate(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+        """
+        Performs weighted linear interpolation over the 3 nearest features
+        :param ctx:
+        :param features: (B, C, M) feature descriptors to be interpolated from
+        :param idx: (B, n, 3) three nearest neighbors of the target features in features
+        :param weight: (B, n, 3) weights
+        :return:
+            output: (B, C, N) tensor of the interpolated features
+        """
+        assert features.is_contiguous()
+        assert idx.is_contiguous()
+        assert weight.is_contiguous()
+
+        B, c, m = features.size()
+        n = idx.size(1)
+        ctx.three_interpolate_for_backward = (idx, weight, m)
+        output = torch.cuda.FloatTensor(B, c, n)
+
+        pointnet2.three_interpolate_wrapper(B, c, m, n, features, idx, weight, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        :param ctx:
+        :param grad_out: (B, C, N) tensor with gradients of outputs
+        :return:
+            grad_features: (B, C, M) tensor with gradients of features
+            None:
+            None:
+        """
+        idx, weight, m = ctx.three_interpolate_for_backward
+        B, c, n = grad_out.size()
+
+        grad_features = Variable(torch.cuda.FloatTensor(B, c, m).zero_())
+        grad_out_data = grad_out.data.contiguous()
+
+        pointnet2.three_interpolate_grad_wrapper(B, c, n, m, grad_out_data, idx, weight, grad_features.data)
+        return grad_features, None, None
+
+
+three_interpolate = ThreeInterpolate.apply
+
+
+class GroupingOperation(Function):
+
+    @staticmethod
+    def forward(ctx, features: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param features: (B, C, N) tensor of features to group
+        :param idx: (B, npoint, nsample) tensor containing the indices of features to group with
+        :return:
+            output: (B, C, npoint, nsample) tensor
+        """
+        assert features.is_contiguous()
+        assert idx.is_contiguous()
+
+        B, nfeatures, nsample = idx.size()
+        _, C, N = features.size()
+        output = torch.cuda.FloatTensor(B, C, nfeatures, nsample)
+
+        pointnet2.group_points_wrapper(B, C, N, nfeatures, nsample, features, idx, output)
+
+        ctx.for_backwards = (idx, N)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        :param ctx:
+        :param grad_out: (B, C, npoint, nsample) tensor of the gradients of the output from forward
+        :return:
+            grad_features: (B, C, N) gradient of the features
+        """
+        idx, N = ctx.for_backwards
+
+        B, C, npoint, nsample = grad_out.size()
+        grad_features = Variable(torch.cuda.FloatTensor(B, C, N).zero_())
+
+        grad_out_data = grad_out.data.contiguous()
+        pointnet2.group_points_grad_wrapper(B, C, N, npoint, nsample, grad_out_data, idx, grad_features.data)
+        return grad_features, None
+
+
+grouping_operation = GroupingOperation.apply
+
+
+class BallQuery(Function):
+
+    @staticmethod
+    def forward(ctx, radius: float, nsample: int, xyz: torch.Tensor, new_xyz: torch.Tensor) -> torch.Tensor:
+        """
+        :param ctx:
+        :param radius: float, radius of the balls
+        :param nsample: int, maximum number of features in the balls
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: (B, npoint, 3) centers of the ball query
+        :return:
+            idx: (B, npoint, nsample) tensor with the indices of the features that form the query balls
+        """
+        assert new_xyz.is_contiguous()
+        assert xyz.is_contiguous()
+
+        B, N, _ = xyz.size()
+        npoint = new_xyz.size(1)
+        idx = torch.cuda.IntTensor(B, npoint, nsample).zero_()
+
+        pointnet2.ball_query_wrapper(B, N, npoint, radius, nsample, new_xyz, xyz, idx)
+        return idx
+
+    @staticmethod
+    def backward(ctx, a=None):
+        return None, None, None, None
+
+
+ball_query = BallQuery.apply
+
+
+class QueryAndGroup(nn.Module):
+    def __init__(self, radius: float, nsample: int, use_xyz: bool = True):
+        """
+        :param radius: float, radius of ball
+        :param nsample: int, maximum number of features to gather in the ball
+        :param use_xyz:
+        """
+        super().__init__()
+        self.radius, self.nsample, self.use_xyz = radius, nsample, use_xyz
+
+    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None) -> Tuple[torch.Tensor]:
+        """
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: (B, npoint, 3) centroids
+        :param features: (B, C, N) descriptors of the features
+        :return:
+            new_features: (B, 3 + C, npoint, nsample)
+        """
+        idx = ball_query(self.radius, self.nsample, xyz, new_xyz)
+        xyz_trans = xyz.transpose(1, 2).contiguous()
+        grouped_xyz = grouping_operation(xyz_trans, idx)  # (B, 3, npoint, nsample)
+        grouped_xyz -= new_xyz.transpose(1, 2).unsqueeze(-1)
+
+        if features is not None:
+            grouped_features = grouping_operation(features, idx)
+            if self.use_xyz:
+                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, C + 3, npoint, nsample)
+            else:
+                new_features = grouped_features
+        else:
+            assert self.use_xyz, "use_xyz must be True when no features are given"
+            new_features = grouped_xyz
+
+        return new_features
+
+
+class GroupAll(nn.Module):
+    def __init__(self, use_xyz: bool = True):
+        super().__init__()
+        self.use_xyz = use_xyz
+
+    def forward(self, xyz: torch.Tensor, new_xyz: torch.Tensor, features: torch.Tensor = None):
+        """
+        :param xyz: (B, N, 3) xyz coordinates of the features
+        :param new_xyz: ignored
+        :param features: (B, C, N) descriptors of the features
+        :return:
+            new_features: (B, C + 3, 1, N)
+        """
+        grouped_xyz = xyz.transpose(1, 2).unsqueeze(2)
+        if features is not None:
+            grouped_features = features.unsqueeze(2)
+            if self.use_xyz:
+                new_features = torch.cat([grouped_xyz, grouped_features], dim=1)  # (B, 3 + C, 1, N)
+            else:
+                new_features = grouped_features
+        else:
+            new_features = grouped_xyz
+
+        return new_features
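Each Function above is exposed through .apply; composed together they give the sampling-and-grouping step that QueryAndGroup performs internally. A short sketch, assuming the compiled pointnet2_cuda extension and a CUDA device (sizes are illustrative):

    import torch
    from modules.module_lib import pointnet2_utils as p2u

    xyz = torch.randn(2, 1024, 3).cuda()                  # (B, N, 3)
    idx = p2u.furthest_point_sample(xyz, 128)             # (B, 128) int32 indices
    centers = p2u.gather_operation(                       # (B, 3, 128)
        xyz.transpose(1, 2).contiguous(), idx
    ).transpose(1, 2).contiguous()                        # (B, 128, 3)
    nbrs = p2u.ball_query(0.2, 32, xyz, centers)          # (B, 128, 32) per-ball indices
    grouped = p2u.grouping_operation(                     # (B, 3, 128, 32)
        xyz.transpose(1, 2).contiguous(), nbrs)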
diff --git a/modules/module_lib/pytorch_utils.py b/modules/module_lib/pytorch_utils.py
new file mode 100644
index 0000000..09cb7bc
--- /dev/null
+++ b/modules/module_lib/pytorch_utils.py
@@ -0,0 +1,236 @@
+import torch.nn as nn
+from typing import List, Tuple
+
+
+class SharedMLP(nn.Sequential):
+
+    def __init__(
+            self,
+            args: List[int],
+            *,
+            bn: bool = False,
+            activation=nn.ReLU(inplace=True),
+            preact: bool = False,
+            first: bool = False,
+            name: str = "",
+            instance_norm: bool = False,
+    ):
+        super().__init__()
+
+        for i in range(len(args) - 1):
+            self.add_module(
+                name + 'layer{}'.format(i),
+                Conv2d(
+                    args[i],
+                    args[i + 1],
+                    bn=(not first or not preact or (i != 0)) and bn,
+                    activation=activation
+                    if (not first or not preact or (i != 0)) else None,
+                    preact=preact,
+                    instance_norm=instance_norm
+                )
+            )
+
+
+class _ConvBase(nn.Sequential):
+
+    def __init__(
+            self,
+            in_size,
+            out_size,
+            kernel_size,
+            stride,
+            padding,
+            activation,
+            bn,
+            init,
+            conv=None,
+            batch_norm=None,
+            bias=True,
+            preact=False,
+            name="",
+            instance_norm=False,
+            instance_norm_func=None
+    ):
+        super().__init__()
+
+        bias = bias and (not bn)
+        conv_unit = conv(
+            in_size,
+            out_size,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            bias=bias
+        )
+        init(conv_unit.weight)
+        if bias:
+            nn.init.constant_(conv_unit.bias, 0)
+
+        if bn:
+            if not preact:
+                bn_unit = batch_norm(out_size)
+            else:
+                bn_unit = batch_norm(in_size)
+        if instance_norm:
+            if not preact:
+                in_unit = instance_norm_func(out_size, affine=False, track_running_stats=False)
+            else:
+                in_unit = instance_norm_func(in_size, affine=False, track_running_stats=False)
+
+        if preact:
+            if bn:
+                self.add_module(name + 'bn', bn_unit)
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
+            if not bn and instance_norm:
+                self.add_module(name + 'in', in_unit)
+
+        self.add_module(name + 'conv', conv_unit)
+
+        if not preact:
+            if bn:
+                self.add_module(name + 'bn', bn_unit)
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
+            if not bn and instance_norm:
+                self.add_module(name + 'in', in_unit)
+
+
+class _BNBase(nn.Sequential):
+
+    def __init__(self, in_size, batch_norm=None, name=""):
+        super().__init__()
+        self.add_module(name + "bn", batch_norm(in_size))
+
+        nn.init.constant_(self[0].weight, 1.0)
+        nn.init.constant_(self[0].bias, 0)
+
+
+class BatchNorm1d(_BNBase):
+
+    def __init__(self, in_size: int, *, name: str = ""):
+        super().__init__(in_size, batch_norm=nn.BatchNorm1d, name=name)
+
+
+class BatchNorm2d(_BNBase):
+
+    def __init__(self, in_size: int, name: str = ""):
+        super().__init__(in_size, batch_norm=nn.BatchNorm2d, name=name)
+
+
+class Conv1d(_ConvBase):
+
+    def __init__(
+            self,
+            in_size: int,
+            out_size: int,
+            *,
+            kernel_size: int = 1,
+            stride: int = 1,
+            padding: int = 0,
+            activation=nn.ReLU(inplace=True),
+            bn: bool = False,
+            init=nn.init.kaiming_normal_,
+            bias: bool = True,
+            preact: bool = False,
+            name: str = "",
+            instance_norm=False
+    ):
+        super().__init__(
+            in_size,
+            out_size,
+            kernel_size,
+            stride,
+            padding,
+            activation,
+            bn,
+            init,
+            conv=nn.Conv1d,
+            batch_norm=BatchNorm1d,
+            bias=bias,
+            preact=preact,
+            name=name,
+            instance_norm=instance_norm,
+            instance_norm_func=nn.InstanceNorm1d
+        )
+
+
+class Conv2d(_ConvBase):
+
+    def __init__(
+            self,
+            in_size: int,
+            out_size: int,
+            *,
+            kernel_size: Tuple[int, int] = (1, 1),
+            stride: Tuple[int, int] = (1, 1),
+            padding: Tuple[int, int] = (0, 0),
+            activation=nn.ReLU(inplace=True),
+            bn: bool = False,
+            init=nn.init.kaiming_normal_,
+            bias: bool = True,
+            preact: bool = False,
+            name: str = "",
+            instance_norm=False
+    ):
+        super().__init__(
+            in_size,
+            out_size,
+            kernel_size,
+            stride,
+            padding,
+            activation,
+            bn,
+            init,
+            conv=nn.Conv2d,
+            batch_norm=BatchNorm2d,
+            bias=bias,
+            preact=preact,
+            name=name,
+            instance_norm=instance_norm,
+            instance_norm_func=nn.InstanceNorm2d
+        )
+
+
+class FC(nn.Sequential):
+
+    def __init__(
+            self,
+            in_size: int,
+            out_size: int,
+            *,
+            activation=nn.ReLU(inplace=True),
+            bn: bool = False,
+            init=None,
+            preact: bool = False,
+            name: str = ""
+    ):
+        super().__init__()
+
+        fc = nn.Linear(in_size, out_size, bias=not bn)
+        if init is not None:
+            init(fc.weight)
+        if not bn:
+            nn.init.constant_(fc.bias, 0)
+
+        if preact:
+            if bn:
+                self.add_module(name + 'bn', BatchNorm1d(in_size))
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
+        self.add_module(name + 'fc', fc)
+
+        if not preact:
+            if bn:
+                self.add_module(name + 'bn', BatchNorm1d(out_size))
+
+            if activation is not None:
+                self.add_module(name + 'activation', activation)
+
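SharedMLP stacks 1x1 Conv2d blocks, so one weight set is applied at every (npoint, nsample) position. A quick shape check (CPU-only, no CUDA extension needed; assumes the project root is on sys.path; sizes are arbitrary):

    import torch
    from modules.module_lib.pytorch_utils import SharedMLP

    mlp = SharedMLP([6, 32, 64], bn=True)   # channels: 6 -> 32 -> 64
    x = torch.randn(2, 6, 128, 32)          # (B, C_in, npoint, nsample)
    print(mlp(x).shape)                     # torch.Size([2, 64, 128, 32])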
diff --git a/modules/pointnet++_encoder.py b/modules/pointnet++_encoder.py
new file mode 100644
index 0000000..861dce6
--- /dev/null
+++ b/modules/pointnet++_encoder.py
@@ -0,0 +1,149 @@
+import torch
+import torch.nn as nn
+import os
+import sys
+path = os.path.abspath(__file__)
+for i in range(2):
+    path = os.path.dirname(path)
+PROJECT_ROOT = path
+sys.path.append(PROJECT_ROOT)
+import PytorchBoot.stereotype as stereotype
+from modules.module_lib.pointnet2_modules import PointnetSAModuleMSG
+
+
+ClsMSG_CFG_Dense = {
+    'NPOINTS': [512, 256, 128, None],
+    'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
+    'NSAMPLE': [[32, 64], [16, 32], [8, 16], [None, None]],
+    'MLPS': [[[16, 16, 32], [32, 32, 64]],
+             [[64, 64, 128], [64, 96, 128]],
+             [[128, 196, 256], [128, 196, 256]],
+             [[256, 256, 512], [256, 384, 512]]],
+    'DP_RATIO': 0.5,
+}
+
+ClsMSG_CFG_Light = {
+    'NPOINTS': [512, 256, 128, None],
+    'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
+    'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]],
+    'MLPS': [[[16, 16, 32], [32, 32, 64]],
+             [[64, 64, 128], [64, 96, 128]],
+             [[128, 196, 256], [128, 196, 256]],
+             [[256, 256, 512], [256, 384, 512]]],
+    'DP_RATIO': 0.5,
+}
+
+ClsMSG_CFG_Light_2048 = {
+    'NPOINTS': [512, 256, 128, None],
+    'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16], [None, None]],
+    'NSAMPLE': [[16, 32], [16, 32], [16, 32], [None, None]],
+    'MLPS': [[[16, 16, 32], [32, 32, 64]],
+             [[64, 64, 128], [64, 96, 128]],
+             [[128, 196, 256], [128, 196, 256]],
+             [[256, 256, 1024], [256, 512, 1024]]],
+    'DP_RATIO': 0.5,
+}
+
+ClsMSG_CFG_Strong = {
+
'NPOINTS': [512, 256, 128, 64, None], + 'RADIUS': [[0.02, 0.04], [0.04, 0.08], [0.08, 0.16],[0.16, 0.32], [None, None]], + 'NSAMPLE': [[16, 32], [16, 32], [16, 32], [16, 32], [None, None]], + 'MLPS': [[[16, 16, 32], [32, 32, 64]], + [[64, 64, 128], [64, 96, 128]], + [[128, 196, 256], [128, 196, 256]], + [[256, 256, 512], [256, 512, 512]], + [[512, 512, 2048], [512, 1024, 2048]] + ], + 'DP_RATIO': 0.5, +} + +ClsMSG_CFG_Lighter = { + 'NPOINTS': [512, 256, 128, 64, None], + 'RADIUS': [[0.01], [0.02], [0.04], [0.08], [None]], + 'NSAMPLE': [[64], [32], [16], [8], [None]], + 'MLPS': [[[32, 32, 64]], + [[64, 64, 128]], + [[128, 196, 256]], + [[256, 256, 512]], + [[512, 512, 1024]]], + 'DP_RATIO': 0.5, +} + + +def select_params(name): + if name == 'light': + return ClsMSG_CFG_Light + elif name == 'lighter': + return ClsMSG_CFG_Lighter + elif name == 'dense': + return ClsMSG_CFG_Dense + elif name == 'light_2048': + return ClsMSG_CFG_Light_2048 + elif name == 'strong': + return ClsMSG_CFG_Strong + else: + raise NotImplementedError + + +def break_up_pc(pc): + xyz = pc[..., 0:3].contiguous() + features = ( + pc[..., 3:].transpose(1, 2).contiguous() + if pc.size(-1) > 3 else None + ) + + return xyz, features + + +@stereotype.module("pointnet++_encoder") +class PointNet2Encoder(nn.Module): + def encode_points(self, pts, require_per_point_feat=False): + return self.forward(pts) + + def __init__(self, config:dict): + super().__init__() + + channel_in = config.get("in_dim", 3) - 3 + params_name = config.get("params_name", "light") + + self.SA_modules = nn.ModuleList() + selected_params = select_params(params_name) + for k in range(selected_params['NPOINTS'].__len__()): + mlps = selected_params['MLPS'][k].copy() + channel_out = 0 + for idx in range(mlps.__len__()): + mlps[idx] = [channel_in] + mlps[idx] + channel_out += mlps[idx][-1] + + self.SA_modules.append( + PointnetSAModuleMSG( + npoint=selected_params['NPOINTS'][k], + radii=selected_params['RADIUS'][k], + nsamples=selected_params['NSAMPLE'][k], + mlps=mlps, + use_xyz=True, + bn=True + ) + ) + channel_in = channel_out + + def forward(self, point_cloud: torch.cuda.FloatTensor): + xyz, features = break_up_pc(point_cloud) + + l_xyz, l_features = [xyz], [features] + for i in range(len(self.SA_modules)): + li_xyz, li_features = self.SA_modules[i](l_xyz[i], l_features[i]) + l_xyz.append(li_xyz) + l_features.append(li_features) + return l_features[-1].squeeze(-1) + + +if __name__ == '__main__': + seed = 100 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + net = PointNet2Encoder(config={"in_dim": 3, "params_name": "strong"}).cuda() + pts = torch.randn(2, 2444, 3).cuda() + print(torch.mean(pts, dim=1)) + pre = net.encode_points(pts) + print(pre.shape) diff --git a/modules/pointnet_encoder.py b/modules/pointnet_encoder.py new file mode 100644 index 0000000..6e414f2 --- /dev/null +++ b/modules/pointnet_encoder.py @@ -0,0 +1,107 @@ +from __future__ import print_function +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.utils.data +from torch.autograd import Variable +import numpy as np +import torch.nn.functional as F + +import PytorchBoot.stereotype as stereotype +@stereotype.module("pointnet_encoder") +class PointNetEncoder(nn.Module): + + def __init__(self, config:dict): + super(PointNetEncoder, self).__init__() + + self.out_dim = config["out_dim"] + self.in_dim = config["in_dim"] + self.feature_transform = config.get("feature_transform", False) + self.stn = STNkd(k=self.in_dim) + self.conv1 = torch.nn.Conv1d(self.in_dim , 
64, 1) + self.conv2 = torch.nn.Conv1d(64, 128, 1) + self.conv3 = torch.nn.Conv1d(128, 512, 1) + self.conv4 = torch.nn.Conv1d(512, self.out_dim , 1) + if self.feature_transform: + self.f_stn = STNkd(k=64) + + def forward(self, x): + trans = self.stn(x) + x = x.transpose(2, 1) + x = torch.bmm(x, trans) + x = x.transpose(2, 1) + x = F.relu(self.conv1(x)) + + if self.feature_transform: + trans_feat = self.f_stn(x) + x = x.transpose(2, 1) + x = torch.bmm(x, trans_feat) + x = x.transpose(2, 1) + + point_feat = x + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + x = self.conv4(x) + x = torch.max(x, 2, keepdim=True)[0] + x = x.view(-1, self.out_dim) + return x, point_feat + + def encode_points(self, pts, require_per_point_feat=False): + pts = pts.transpose(2, 1) + global_pts_feature, per_point_feature = self(pts) + if require_per_point_feat: + return global_pts_feature, per_point_feature.transpose(2, 1) + else: + return global_pts_feature + +class STNkd(nn.Module): + def __init__(self, k=64): + super(STNkd, self).__init__() + self.conv1 = torch.nn.Conv1d(k, 64, 1) + self.conv2 = torch.nn.Conv1d(64, 128, 1) + self.conv3 = torch.nn.Conv1d(128, 1024, 1) + self.fc1 = nn.Linear(1024, 512) + self.fc2 = nn.Linear(512, 256) + self.fc3 = nn.Linear(256, k * k) + self.relu = nn.ReLU() + + self.k = k + + def forward(self, x): + batchsize = x.size()[0] + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + x = torch.max(x, 2, keepdim=True)[0] + x = x.view(-1, 1024) + + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + + iden = ( + Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))) + .view(1, self.k * self.k) + .repeat(batchsize, 1) + ) + if x.is_cuda: + iden = iden.to(x.get_device()) + x = x + iden + x = x.view(-1, self.k, self.k) + return x + +if __name__ == "__main__": + sim_data = Variable(torch.rand(32, 2500, 3)) + config = { + "in_dim": 3, + "out_dim": 1024, + "feature_transform": False + } + pointnet = PointNetEncoder(config) + out = pointnet.encode_points(sim_data) + + print("global feat", out.size()) + + out, per_point_out = pointnet.encode_points(sim_data, require_per_point_feat=True) + print("point feat", out.size()) + print("per point feat", per_point_out.size()) diff --git a/modules/pose_encoder.py b/modules/pose_encoder.py new file mode 100644 index 0000000..40b67fd --- /dev/null +++ b/modules/pose_encoder.py @@ -0,0 +1,21 @@ +from torch import nn +import PytorchBoot.stereotype as stereotype + +@stereotype.module("pose_encoder") +class PoseEncoder(nn.Module): + def __init__(self, config): + super(PoseEncoder, self).__init__() + self.config = config + pose_dim = config["pose_dim"] + out_dim = config["out_dim"] + self.act = nn.ReLU(True) + + self.pose_encoder = nn.Sequential( + nn.Linear(pose_dim, out_dim), + self.act, + nn.Linear(out_dim, out_dim), + self.act, + ) + + def encode_pose(self, pose): + return self.pose_encoder(pose) diff --git a/modules/pts_num_encoder.py b/modules/pts_num_encoder.py new file mode 100644 index 0000000..2210c21 --- /dev/null +++ b/modules/pts_num_encoder.py @@ -0,0 +1,20 @@ +from torch import nn +import PytorchBoot.stereotype as stereotype + +@stereotype.module("pts_num_encoder") +class PointsNumEncoder(nn.Module): + def __init__(self, config): + super(PointsNumEncoder, self).__init__() + self.config = config + out_dim = config["out_dim"] + self.act = nn.ReLU(True) + + self.pts_num_encoder = nn.Sequential( + nn.Linear(1, out_dim), + self.act, + nn.Linear(out_dim, out_dim), + self.act, + ) + 
+    def encode_pts_num(self, num_seq):
+        return self.pts_num_encoder(num_seq)
diff --git a/modules/transformer_seq_encoder.py b/modules/transformer_seq_encoder.py
new file mode 100644
index 0000000..13aabb2
--- /dev/null
+++ b/modules/transformer_seq_encoder.py
@@ -0,0 +1,63 @@
+import torch
+from torch import nn
+from torch.nn.utils.rnn import pad_sequence
+import PytorchBoot.stereotype as stereotype
+
+
+@stereotype.module("transformer_seq_encoder")
+class TransformerSequenceEncoder(nn.Module):
+    def __init__(self, config):
+        super(TransformerSequenceEncoder, self).__init__()
+        self.config = config
+        embed_dim = config["embed_dim"]
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=embed_dim,
+            nhead=config["num_heads"],
+            dim_feedforward=config["ffn_dim"],
+            batch_first=True,
+        )
+        self.transformer_encoder = nn.TransformerEncoder(
+            encoder_layer, num_layers=config["num_layers"]
+        )
+        self.fc = nn.Linear(embed_dim, config["output_dim"])
+
+    def encode_sequence(self, embedding_list_batch):
+        lengths = [len(embedding_list) for embedding_list in embedding_list_batch]
+
+        embedding_tensor = pad_sequence(embedding_list_batch, batch_first=True)  # (batch_size, max_seq_len, embed_dim)
+
+        max_len = max(lengths)
+        padding_mask = torch.tensor(
+            [[0] * length + [1] * (max_len - length) for length in lengths], dtype=torch.bool
+        ).to(embedding_tensor.device)  # True marks padded positions
+
+        transformer_output = self.transformer_encoder(embedding_tensor, src_key_padding_mask=padding_mask)
+        final_feature = transformer_output.mean(dim=1)
+        final_output = self.fc(final_feature)
+
+        return final_output
+
+
+if __name__ == "__main__":
+    config = {
+        "embed_dim": 256,
+        "num_heads": 4,
+        "ffn_dim": 256,
+        "num_layers": 3,
+        "output_dim": 1024,
+    }
+
+    encoder = TransformerSequenceEncoder(config)
+    seq_len = [5, 8, 9, 4]
+    batch_size = 4
+
+    embedding_list_batch = [
+        torch.randn(seq_len[idx], config["embed_dim"]) for idx in range(batch_size)
+    ]
+    output_feature = encoder.encode_sequence(embedding_list_batch)
+    print("Encoded Feature:", output_feature)
+    print("Feature Shape:", output_feature.shape)
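One detail worth noting in encode_sequence: transformer_output.mean(dim=1) averages over padded positions as well, so short sequences are diluted by slots the attention mask already excluded. A hedged sketch of a length-aware pooling alternative (not part of this patch; variable names follow the method above):

    import torch

    def masked_mean(transformer_output, padding_mask):
        # padding_mask: True at padded slots, as built in encode_sequence
        valid = (~padding_mask).unsqueeze(-1).float()       # (B, L, 1)
        summed = (transformer_output * valid).sum(dim=1)    # (B, D)
        return summed / valid.sum(dim=1).clamp(min=1.0)     # mean over real tokens only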
[GIT binary patch bodies omitted: compiled bytecode under preprocess/__pycache__/ (clean_preprocessed_data, pack_preprocessed_data, pack_upload_data, preprocessor; all .cpython-39.pyc).]
zgT;JSw`CnE*r1k%a$lsuBpyVt22JDap~^n~#u_~&QftVvnRBLWEU2wqEQ)jLP79{Y znXN&~pZo~6axZI{l@?S%Im&(F&~|_Yo!!hn?9Ovf+i@ak-)eV=QE5*c4U#m`{y2OH zQ^_P7yQf_l$EhwPeGv&*R)%8{4ET|3pslP>2RXa<`DkCe@NHv**UateD~2Eq z3{jb>JjWKq6trK={-Dp#<9zC*4lr&%0cYea@p)sWuRBCSu56*9LAxXXX3%zL*0lp| zuQ}81wbN=|p4Lq?OK5}^Ycjcw_uS&q?x+22#{OF*h|g9n+Q(oz9AQ7Tr{1)vio}4c zR@+jAC1OL+{<<~wm4~!5qj!zI1Pj>8S0z>0b9jM!Pwi*)E}q)ca_T-SsIqdAuArW$ zAfFXgW$1H%)#9Z+`;_h3S*s~=Yzfw@4$G(Zu#7{<`h*VT;wSbV#xSj^8d#&e``nf| zxvdvAM4mQ^=k^QI%Nv>g4N8-E@emtsYhkqiY(%#i%}e%2pNM}(y;$f z7o|uJ!ikV%PQ59^h{xlA$B&~V98W;!hC74)C_d!6f^(=JrXs>$yKy8EbVIGR!$DT; zy>49>xR|5|+8u|&1E1<5-UPWzJ~Q;b2N~F6hDt3DJ%@M?jQqdxN>med#*e ztYgc&DCwJXhU(xUfS*I{e~`(c=FLE@2uS(L$)V-}YEBNd5+La0P|Nfg)XD~GB>>8M z1!@PVATS54hvK~+?N`@Hnj0kmt=dg`g#6Zlv--hV|vR;x3(*IQzD^h%PIFQ+CM)*8~HdQbrY`L~q*8ie+;13~^BY5tzX*Gc>XiEohjM-ty8@fL}S@VF^tFJF8S5KXL(vtK3vZNc17i(^|qx-l%uyQ zQ+`@i4C(5sg&L3XH1+-I1HOm=X*Hs;eEbd;=pFy7?%nL+z zs~IJF2ZjnSfk=C!K@^VmvECTH@Zc%!_nno`oP- z`zjro(^`t-mAH+URE?LP5{hZ1y(~BZcNJCR3A>Ak>4X2CHdCL28R529l>=E*O_Odw z))j+nkmj>0#aQU2M7^lysWV-`t|+SoiZ!3w(-v|Sa9WGEfW8)A1S3+33*0VFmrz8UeaQCFsIwLD!_S5yH=ty8S}lug&v@?3sBZKy@HrdIHc74xMw zWbuiIUN@$jX;Uq!_4Be-Q&yO6sg2(EXIfAF8d3{Pe22iiq^z)7EfNZKy5m zu_j+LBkxSFsZDhaR^2ppey%#NwEF|vEvXLbwK$E1PM-=~mi4CvjHxr-POqrCYQs62 z8S#_C^t#cn&gd~(=x<>CmeqB211`}`eAP~tp;_S@SjWzJYCCObD}gV&e3N42bT!v( zpKI1~&CN5-M?X&2jjnEjtS#*Jt0#>P+>^b*{-F#hPVRJ_%zb)E_@sj-S7fXZLNT!Z z&9e4C|JUK~)BiYl>%~7owg2;<(euZ*_L0H!=eM*cLq5RfG+o$wVSa48YYU^EyHQe8X= z;D+rB`F#{VXC}bf?QTeiKD11vFh99Lv=`^DWSE3fhQ|_{+tN$KVX(ZIB@xmRtDq=I z#76tQmRyNXSt*nZVLH`{^3)C{h8_nkxarUAlr%2$_H-zs%(isWS?WCB=%8 z;v^GcbZ=e6Exl4UB{%~;L{K!lG)}`&7MCyNzUd70|1L=(n;SiPsXaC{{2h6vZAnWXl z_}5Y5GmqK-TQ0O%(?(6Z$(qgGct`R2+Y>i#SyOBFiKs{_AF&-`yk+@o3$U% zTs@E{`S?zse7tN|(-oA7mt`uKsB?3faY$MRGoSIK49?6Qod1n+%nwK6n{;X<--Gah s9r{(-4Z8|iwOhrON57`)H~b~PQT|aGm?{&FfJ3U=#DC35pvh?d4?W{JeE= 0) & (pixel_x < w) & (pixel_y >= 0) & (pixel_y < h) + mask_colors = mask[pixel_y[valid_indices], pixel_x[valid_indices]] + selected_points_indices = np.where((mask_colors == display_table_mask_label).all(axis=-1))[0] + selected_points_indices = np.where(valid_indices)[0][selected_points_indices] + return selected_points_indices + + +def save_scene_data(root, scene, scene_idx=0, scene_total=1,file_type="txt"): + + ''' configuration ''' + target_mask_label = (0, 255, 0) + display_table_mask_label=(0, 0, 255) + random_downsample_N = 32768 + voxel_size=0.003 + filter_degree = 75 + min_z = 0.2 + max_z = 0.5 + + ''' scan points ''' + display_table_info = DataLoadUtil.get_display_table_info(root, scene) + radius = display_table_info["radius"] + + scan_points = np.asarray(ReconstructionUtil.generate_scan_points(display_table_top=0,display_table_radius=radius)) + + ''' read frame data(depth|mask|normal) ''' + frame_num = DataLoadUtil.get_scene_seq_length(root, scene) + for frame_id in range(frame_num): + print(f"[scene({scene_idx}/{scene_total})|frame({frame_id}/{frame_num})]Processing {scene} frame {frame_id}") + path = DataLoadUtil.get_path(root, scene, frame_id) + cam_info = DataLoadUtil.load_cam_info(path, binocular=True) + depth_L, depth_R = DataLoadUtil.load_depth( + path, cam_info["near_plane"], + cam_info["far_plane"], + binocular=True + ) + mask_L, mask_R = DataLoadUtil.load_seg(path, binocular=True) + normal_L = DataLoadUtil.load_normal(path, binocular=True, left_only=True) + ''' target points ''' + mask_img_L = mask_L + mask_img_R = mask_R + + target_mask_img_L = (mask_L == target_mask_label).all(axis=-1) + target_mask_img_R = (mask_R == target_mask_label).all(axis=-1) + + + 
+        sampled_target_points_L, sampled_target_normal_L = get_world_points_and_normal(depth_L, target_mask_img_L, normal_L, cam_info["cam_intrinsic"], cam_info["cam_to_world"], random_downsample_N)
+        sampled_target_points_R = get_world_points(depth_R, target_mask_img_R, cam_info["cam_intrinsic"], cam_info["cam_to_world_R"], random_downsample_N)
+
+        has_points = sampled_target_points_L.shape[0] > 0 and sampled_target_points_R.shape[0] > 0
+        if has_points:
+            target_points, overlap_idx = PtsUtil.get_overlapping_points(
+                sampled_target_points_L, sampled_target_points_R, voxel_size, require_idx=True
+            )
+            sampled_target_normal_L = sampled_target_normal_L[overlap_idx]
+
+        if has_points:
+            has_points = target_points.shape[0] > 0
+
+        if has_points:
+            target_points, target_normals = PtsUtil.filter_points(
+                target_points, sampled_target_normal_L, cam_info["cam_to_world"], theta_limit=filter_degree, z_range=(min_z, max_z)
+            )
+
+        ''' scan points indices '''
+        scan_points_indices_L = get_scan_points_indices(scan_points, mask_img_L, display_table_mask_label, cam_info["cam_intrinsic"], cam_info["cam_to_world"])
+        scan_points_indices_R = get_scan_points_indices(scan_points, mask_img_R, display_table_mask_label, cam_info["cam_intrinsic"], cam_info["cam_to_world_R"])
+        scan_points_indices = np.intersect1d(scan_points_indices_L, scan_points_indices_R)
+
+        if not has_points:
+            target_points = np.zeros((0, 3))
+            target_normals = np.zeros((0, 3))
+
+        save_target_points(root, scene, frame_id, target_points, file_type=file_type)
+        save_target_normals(root, scene, frame_id, target_normals, file_type=file_type)
+        save_scan_points_indices(root, scene, frame_id, scan_points_indices, file_type=file_type)
+
+    save_scan_points(root, scene, scan_points)  # the "done" flag of scene preprocessing
+
+
+if __name__ == "__main__":
+    #root = "/media/hofee/repository/new_data_with_normal"
+    root = r"/media/hofee/data/data/test_bottle/view"
+    scene_list = os.listdir(root)
+    from_idx = 0  # 1000
+    to_idx = len(scene_list)  # 1500
+
+    cnt = 0
+    import time
+    total = to_idx - from_idx
+    for scene in scene_list[from_idx:to_idx]:
+        start = time.time()
+        if os.path.exists(os.path.join(root, scene, "scan_points.txt")):
+            print(f"Scene {scene} has been processed")
+            cnt += 1
+            continue
+        save_scene_data(root, scene, cnt, total, file_type="npy")
+        cnt += 1
+        end = time.time()
+        print(f"Time cost: {end-start}")
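The lifting helpers lost above are referenced throughout save_scene_data. A hypothetical sketch of what get_world_points must do, inferred only from its call site (depth map plus boolean mask to randomly downsampled world-frame points); the body is an assumption, not the original code:

    import numpy as np

    def get_world_points(depth, mask, cam_intrinsic, cam_to_world, random_downsample_N):
        # hypothetical reconstruction: back-project masked depth pixels to world frame
        h, w = depth.shape
        u, v = np.meshgrid(np.arange(w), np.arange(h))
        z = depth[mask]                                       # keep masked pixels only
        x = (u[mask] - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0]
        y = (v[mask] - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1]
        pts_cam = np.stack([x, y, z], axis=-1)                # (M, 3) camera frame
        if pts_cam.shape[0] > random_downsample_N:            # uniform random downsample
            idx = np.random.choice(pts_cam.shape[0], random_downsample_N, replace=False)
            pts_cam = pts_cam[idx]
        pts_h = np.concatenate([pts_cam, np.ones((pts_cam.shape[0], 1))], axis=-1)
        return (cam_to_world @ pts_h.T).T[:, :3]              # (M', 3) world frame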
ztUNLF)X-)xG)}2ZG)hh&yrrt_5Q11FBVoL9EDGbFN=Vqo85Y1!F6an0Vx;(Ck6|PkisK`2XR_4czTBY;GXH4^L_c?!RZDtLvqQ8TFvo!W#%r z12<5Uu~Eq zg*YRp9Yt=m83mEM;Y$40+ptuSk3cr}zCY8YOpm{UD|Byl->lveJ!^LA1FCzjx-z%x zA0|4>SGnYNyp_^E5m((e7F3HY~pA` Jv_Nqp`Y(3+ZuI~F literal 0 HcmV?d00001 diff --git a/runners/__pycache__/evaluate_uncertainty_guide.cpython-39.pyc b/runners/__pycache__/evaluate_uncertainty_guide.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23d32766a36c44b02552fe5686f17ef6fa3bbe9c GIT binary patch literal 12622 zcmc&)No*Y1d9GJgy;U#Vo2#fMwTxP#NJ%r2MzS)RMbgaJGqj=^d*Zgom1e&p*;H?O zRV|6_>Iq=VcHlr61qqOY$4E0V0?#F!f4nwRvj;lhxv(Wwvbi%L}ju%RhUwFt1?}kDa@#RvNBtpE6k~U zs&b?{Uzk_xB8(ZO;v}qj#KAz2+h9gih{6@XnsMnmDcdJ

S7yp$=8dnbED`XrLoaF z&00-$Hr}akpd#aXMX%}F)!r;A>itf=$Z35!P+nnlmCo3)Z7ykfcLZQC2oGIvUJ zxWP{}EVw^L#?k1i3;Or8b;cRjAH@p>H@NvoE0{dOqxg;R7?0yO%3t6aKJff`-a#0P2&Rg)N~F;$JlBW;Y{vV2Mtvl@?%Y2zB^ zOAhpnslF*My+bV-Z>Tej-l<)7L;LVT+wcvpc_Tc&!;;!}^l7badZV2&Z)j`0GeP}) ztiH>T%lPIs?cN{wS|`gB>J;5Rac-$KCe^4Z^qt&gPnbC8O;esHcl3|v+7WN2Gv-HV zE_?O^5VMKD9guZx^t|?jxyR3GO#OWgX)Ri&-Ol#U0eH=A9Z~DdW7c_|nx-B{Hz>`i zbirHP(b`e89>v*iwPSwFkKU)9eBY0Dj`=a#buHr^=jqM~D($duGw)>Q1$7svP@}iw z7@h0n{W$e~!tOP2-qX9-5!Z10nN>NqR#C)hyCqR>cy^=cZAz=&^cqdi=4BzR`}K!T z#dgarN2aRfnq6}q*hS$KfBCKPKYR0^T92+jB}9?K-7z2_Th!}b%d#7wba%Pd*p}(q z+I>5urn+b)i13!T+oTq}FFEeJKGvEg_!LEYR&vE)eWJ+BApNhbuHL*OlLN;TqZlQ| zC>f{mL*3IgJlC!@t74LBO=3SXIdD3u*IgN-y%Ky3z)$M6hBPZ+8`4w&G`H&InjE4U z+j)q+maAY)a+r>73&qO=1(D=HFKND8=Z-Yrzxm$#GD+~T!6m`? zWL6crkfT#8@X55Q2g3`B!>af|lmUyN#z31iikwU9tKe?VO(E(+M%U}2TJ!`Fx{Q`^ zzn&vwH(q=HwfAnmFQdf<_Jg9D6VYQ*Bxht4!+V5 z^)095tu(g1&3bKxa7^rxA}fTfVfBho+}~X}oT)E2w#6I(FpIBy4T+|w^idZ4nJjCJ zvN#(?dET&)4zmnf)}uz2rC3z`S=4>vo3e>wgg7n`%s>Ahd1zbPa*NZ zU^(j$GccR zy@iac#2;W1XRWzI&yCJ|j-15EG`{M{-&Rtk!*;=^K3*fKLI>XgUw^_rJoyd1&Dy$W z;^2^{-0+!4oDuE2P@jCg6Wi6=h8IVzz5`AOzMt?c-`ru`^o<{AKh#O?J~qHff8LI? zej6i_#J#~^qdW#V5%W@8=}rc74*7;3`vDZGAL_eCJFe=7iGzwi_eMISJVC8!1NcAC z?(yE9ELw*re`xp-o|@ErQ|(NHl!#)T1TJqJGpp77#4aQSjW#uJ;tAWn%Cxq{wAMK< z>svdR87oclv~TrS`Wkq4+xp15$3M{e`JnEcb`Eph%l7tSbfMGzrM(;=rMy8Y70Lv$ zN;hlua6rgd$Yxr}+oDd`K$!NFqg!4FgLK@IjMPf?iYw?QpP9>zyOq|NKI#P=FITL9 z8}RL1!1HoIzMSF=S97iC>qW1$`SoY)!>2?Ct(f9UOS$}%j6j(yZOT}!Zf^*1*n|V# z>LQ|8+i+ykw$TWdScU93L6e-O zGU5<> z%_kN0$Yhsam-$2SCQaRE*NP$HuwJnuSs7pjhg0yB6Z_W*Wk(>BBy;vrrMf`cxJ$<( z9yZRVSX@uBIi`a{#n}|{!_3s@QOY7mTd|}# zO@=ft!z1PeeT=j3=s~)UdeDOD8C#|gx(wSG)oLjQ$&gcr-pIbe3KmbNG59G>@$>=` zQA1K9-KhJyf=rL${xh!5hlIgZ#g(kl{dkw%kfKTk3y*;DnP9MtN6{~VUt*RPF4}A- zdAyxMO#(G({94msf5>ONNGHmnlDCJvSSRjhaLs8_#-S{erJ7ObfOK8IrlB>>m8#xr z%TRq;YY%$~7z2zC1@oemGJLVn7Y$yEqv_|3+Lf8E8o%V6y#YsMkwuK!A;Wp>zkRI`)Nay`Ul_z0nkubz* zUht=Jh7%BVr~DH%KX!Do{Q}A`LHy%5`zikvKZ+IeJPG1H-9GJ~?wnCN&61xFphr9E zoZZzr%e$Jl!gDZ>7ar?;@v+f97g6}$`2VIu8G;Ue|wwJU6F{fO?xyKvyWoNX_+SG?KIi~fwi0Q*&dK9q@v(tFCxrUyn@yk?eQziKcxSj)xneei|tweWJrydadO7b`exT2yOvKz9w*PNgV` zUx!eHOMz_EuI;Lf_~e&KPk2Wu5>J)26WxH432Pm;lz=7fmWuEvLVciq59*@QQ#t}^ zY+~Z4q#r3kEJyZ^!bH$rB3$_=#!L0;TA3#6PD<;*TM?{)g3@o`LFm^&+i>s7nQqPV zyO3s;oCi^8&E=*GRegk)-A1LjZF@zyeX#3>9N$|Nb~WTs;P&x0Nd>A`xJ<8?E967q z4s>h>4jl!ZS^yp#B+={IijuPdN_}D}Cmc^~U(G#ZE3M>GF0>`D=EOZfWdmP){{A|4 zW;hKTmmDx|vD&COv|CuAwi&GFvwO=hOfA*}H4E zueavVmg~0Vs0{!smSChWT+IQd1}|4zL)@u&MYRz0)T>L+*piq4+*-qX8gkciE#oQ{ zO|6lah*pMYDxVCDm##KU_2Ro2Ei=J(6*y2SN4on^d!&AHD#J)G^p=h9o30Xu&uly{s=!}m$n$_@!?5YnGkyvoB*EJ`vzKI z16rh>2Oj;vtEUut!`XI1<$qA+2gQcAK%o48jTVeT?|0!0i>AA>Rw$&xpTh^ zlP}S%xv>5m_(3BrN5BUd35|obdc7j%Y1>NcevxuLs}ja@S&(ffGPFdN^!B|p2ivMH zV@0kn!HT?JFYt5$DvVWW$(m6e4gBShFW8NCk}mhQ?U&y9Uh_BH{vMg!)*cu0elS#zIq z#O9qP}bBjsu%^g;$>K}4lFTs^R8Sx_!uh7a-uTCCXQKHQ`l#sgu{wjxKn?+9^fh&O5tM(zVS)?IfRzra8 z;yiNsD}m9sMYR(p0rQ{yq%iyXOy>_(8?of@()w*OPGm!@Q6?y-C;&VU*$}f86^S;c z5Fb(UyOexPGiGU*A#3C|arFuzfSe3O$~u@LvGY&}UEy;4pa$~kD7QX53o_Z`uJ)=l z$zd-?pT|q_JG8u#L*apBcF)v`=c&ES98g~)r*g%WM+e∓HEWz@gNaYxm1w)O56= zOJ@v>W?Y;5p6gU0ypffxSqPqaFWo$PZoE;4fFev z9W4F;iEP8qA1l|u%)zfJG@1mNA_JMB#Skm#`hq5VI83l*Zo!b|B7Z*&Htt%8NHQ1)I-z7 z2M93bH2meWkG6j5{r=|lJw7z@))4a~SCU&mLdw6-e0>&XUPt$o^B=x6I9y;%lJA?m zWG37JW;@c3qIX33u@N0a9=+flwnyX4Nh1MkMSV?9; zCQVTTi7AYh>9Aa?Hz996jRh{T2&5rJh(LZ6N{c`iK{Q62wKC2H!6h~zGd%>2npL-1 z{nQFwol0nhl7@n8G9A!Zm!8wX1h%Sm#V@E!G+0^2;Zn2JtQO_MrA7zQJCuoE#IOJf zNkbPuK{5)rpi9nM4fj8Vn^jSXONZXme$uxG?pv!*JrhVX* zNEftm1iFZ9QmnOnC}aF+gBtXraD9Rs 
zcwjIO(zej9Z6!g{XxkIK89uMM{yDhB>5m>5sL&0_`>WWI)E}%#_8_0?E1~;B8>EWd zg|HH?j5f-pyA^o2IkYyhg5lyUCFdwnT2J8W?g#0WE=$K(0y?dd}_wS|qG(|x(&h@*7IOQ|< z6kr>#ftl_yqQbnn+DhG|z^RSEX~Wxm`WZfm>Vxg&6Z=3Zhx$={6^v2pzf9bKsadMS zFqR?!c-#`y!1JU6o<0+Q zL-2&4(Y9hlCG{8x;Z;D^t#~PJbkgXZ;U+}@)FYL_-ljL?BhCPv zp8-t_V;72FK8No=8pL;mQ4x>8Q<;%4psLxUiU#7Vtx1Z7J3QBiUl0{p2kBK~51)kW2!9!9x z_hF78=@b}?W0V`GWSJ%%xrL=y>)tI0U#{E{3SZ=Ce;aQ2_#>cGMWTdFGSv;=t8Ij_ zsR_(1hx$NJk8)so00u$)FhOSxu$M>`-Xb*Z$Hd?bOMeywsx_D1NDvm1sy40^IGtra zGog5;*eMY7=Y*})lbHT=t9y`=h}a0fYLUTLWpS5#0^2P`YNKUM6h{N(2EVR zM9C}xl%VoyNMp%0ShLNB3I&U|sEguBN?o84Z_|ksiziujc)(5Vynh#mZ+sY&4sM(_ zE5N~t>S>>FwZLqIN!*`Ikl@+Bxf7^crOhcm$6BWkCHCj$+q1bYM3u;zQJV{hTD`U- zwWRlyB-o}Y$Qz=NS;iaXhEsuOGJMTM{uKufjFJZ>@>S|~8VUTSz50C^q^7w;#|U5y z>^O4hK?ufsUWkweM68O=;QkDW2l4rddI9ptY%`DH6{4wjl(jSn872HDtvNa-^3w(Bfz~V1;svGM5qRuw;J7h??LPmiXYyZFzwz)2hN;2 zcxHviX}pmYcW|}?PyQaA{2e63pz+229EeUU|{$r zdMQQZ-YK;vgk9oB@am4dvx#RqJZceXDY2m*n28UuX82HO4bQ}&_v5{Un~$iilBQ12 z(8MHeR_e}Tn z7G`?J&OBphMHbnXmAf4xY7Y7EXl|~h-dFVA5Q@{-M$0PsuF3b?jF- zIcx&VpowitT9kaNm#ZNMr1`ht2WZzX-O4lk1Ep5;qx3?Oq?T<<%O5?>_o>E@8YWx??~Vbs_MVo9F4i@tSaX3@gH9Xe#tgV0aUJnf zN={LN_xEsT-riP^kzR+lY&j%=IIS}=*g$YqFHi%eFgw^rL5%Q8Dp+POKQ`EBJ$Sz! zJjLA864H`FSvfN}CVbnt9Q=wTZE9q&g91cA&U;P~`peTaU?_kVtzr`=4m{S;a8A0A pfqPY#OZO_{*8TX%K+})DgZL8Trx15g4Rz8aLKFOkL()+(O zlz(-QP=*sDiUk-jlDKUY0WTuJNsO$GEV9TVi@?Yt2;L${ki>?Pl>=lE!9nNT`+p6j z*p_!GvR=J;xbNO`@4N36Gnu4@&rkmI@aF9=YTCb2V(=3|;`8{r&k+bs=oPIH{OScA zX`^CP&4Q^@o>{S~kwS#yRwY`E6=EEZRN~b{A;IxzC0R`sQXJPS=|URqVwFraQ^t~nE(BVuSr|AHol#mFN~j5r6k&27DKROlaQ z`O#-szqXp!Rq}nOCLH-?uUtVYezWd6VLW!zbHnK9`%bA|b3NHCdF6U7%+1`CjwqMB z_Z_!c@zC1x?usgYGjG*v>*bAii?!m0Bat!uR=wJ&*PNPny;$<<@*(n4Zx&tWYSBYT zZKLluYc;NH-l%UNBkOuauj$&=UN32?eWP9!G=5=Wqf%chR_tO;*p+$-u|~aI^IW@J zTX&>WD><@6iyZt!!-U&G#L;NW3&uCKbzSJfcoZ*~!W7mctze0Wh~hUQj)}C$JkkqM zaa;_EVZ>tMgcucLh{eUC7!UIfhzZnBh`g8*(uI%W~ckaTIeH`le=Sg<)*cX*Kr+UgH}zaf4QPlNb0-e|^Tl;$n!$@ip;v zKh`zP)3&va?&~|cu6ahs^sKGOR#fOahOgPNfgFQ#;sZG*=OhMlEY3-aNE@xUl0Jor zSxrR8v~dmnrI1eZ2$^Y0XT71$aF9kxmRpY?HqseIIJTps4$_@*5##wC*qZ1}q9iWj zTzAT&oN44t?dVT*8TV!=E)v_uw-2^WZ?<#LH@TNXA3|6y{B1*QYr|(X_m^iho&Uaq zuof-Td=B;3f?3UN9p+JT=rbphJmdKds<*)55%1`>*0xYOe^tAEy&dr*zIBi0vF2Ny zV}69@T+4dLMXGaxQrr6P>b(^&BFAV)(R!hi_oGzziGKHOtoGy%W+XK1ZF*IWtyNg` zY_}xK4bN^Ay-k&@H@!yFvqf2|r((P1mZQ?ua?P$e_wAx|ia-DQ_@95_Us{i@ zJR=5BncCVt+qqY)G>ed}R)Pd)dCRTWlrB|T1R`cv%C6Tse6`uAluJd=$$5_J<-)>T ztypy~<-Xp|pHdOG*>I$a+YqF(XWNpPN|_surxH}nZ4^t6N`Bx$$m-t1h6B<4^6Kif z_f=|OA#wz*100kHU#ChlDwHxoOUQ?Myx#3jGHOHo5WP+MS zOIWVwsCc(b8D3pQg>$crNeop02DbA6^DS2aF3PC8%3?ZNTaXnsOr)?S<5NK2sPh1H zK&^VisVR%8WKlPmeZQ`<-mMEqS+}lza7(3#Qg%4q2~KorN)u(msto4?%7WxDC-2KL zs1)RwC{t!p2$lQ|?38m&%DPn1^}4JUJr%q9#;rF#xOPiLiw#T+l^POYMp#zQ$CTwd zm35U0<`68zcZz_0wq2?eUDvjsYR{`HRX}%fWwXBS zI4iWAm4>WuIVEqU@zC3>*H(z9;GuX>neBIX&Xhzx? z)q_7vpD>2?aedfKB4=37>Ptq{oY2#Hl>d@?QcoLMlnm>w6Tcb_m%AKW0{Hbl6eq3! 
z7QXHx0uP`l^bP?S0EC1Sf@zUpq9h^=%3%>TS6dULZOXNc+`6n+b55gP+JulSx5j!H zoLh&KxSNhx&PQb$-7AasOn!!nn}gnW~t=3ZeEvHP*c9$3*MoI zCdkVZW`zaH#P!y~OID92yB3~8M;Uy%pFaQ@uwX2m7Yc2#15En%8Wv3`7unXYB7OV# z*NnE_HoRyjMjBC=zV5|?@*NmLzR^kSXh7d2a*b`EUfc3gUfQ>|bz%AD_q6XDq&Xj( zK+S(_M_Rv!78$}vpkY+RrnPp=%Wgq4pwAKC^kd(HN%VbV$85(rf7Flf=<=_;G3bQ^ zm7)yW@u7CN)|*q3#u2ISn|?&3CpF*VnQ5H^gtYAhHgf_!^Jsoz2ReaTTbeidM1Ocy z*V;*4Yn}0?{Ny%z#z@m5<0tzgeFdA}PJWcUTl-M!$Ai2x+8Ok9H`kks*`-eRm*%od zAB~;`UEPlcH4-ydc2(Y>olf?6LaJB@E3MQ`Stsrw&U+?MEpI4GIBrQrYNdL`RY|wx zggE)!T5jB_v`+OoGJv~Wu>$B64|4(F<-n6UZgnZwioRL&N}FGOuHSw};?asR7%k=U zGx8Ws$|7Z=B2Xr!O%?sVDmlBM8@x%Nj~ELZi}NS(ft4T6@?_z@*g0P|z<@DKY%m=%g`vW=HR2MG8IQ z9CKOwb!u)<-BGSPM6IT^tzj`XO>IYB?3Zhg_j<)Y=EULtAtHkPpDxI_w|t zQ-0b{is{F(ubJ(M_GEj?pV-jb)BY4fGJSLYw9x#FpAs{VEx`0_`{0JwK7{WazK2Ds zo%3f=cOJ0?Z@#n8r>CFVF+c&@u_N9Qe|}r{BX_hXQQ7iyel(0N_^}`c4mm1jgL(kk zgMJR-AwP<6&R^g-0dtao8S6UcPh&Ml0lUZjqcjrccC39I=@b4iN>2DE#9@rQ2&#ft z{Kc(|$UQdNd2i8Q?BsXYC=RGeGwkMgr*JgNiiO99IP%zRpY%^|EwxYir=X_H+keWv zC4K61A5KAOtaI951n+_00-juFuA~wC{0p49C{h7$QmtM&b7$*^yifUk_i3qpy1k5} z$g)4cJlr|!=erc?oZEq1k$QPz-2GdB24%}Zy5>W@oqu9I{7tC0 z^H6VV=<9-5^b^>vyni0MWl>G<6>-wfZyOvt<%5#w2aM_vVg9SE?+U)(UT+h-peJ*n>TFz*%f zaQhhIuVJS~uy3!Rhonzoz-x>~U2G!GW$d6aMtk`6sr7e_4vg|Iv}gTep>$rv%5`zh zzvv&C!g|Gd{~}Z22-eQW$e;vh@H)%t@g6NMfEGF=`gnKe{6qWgEAP!K)E@m_+}nHY z?LGb1cJ*(e|C+zRyt#m$E{RX=$-me>Li>wvyifZ)n0n9l>h=xzlU(Q7m#9*714a%`jj;A5&Vp{K2w#=3YgF%k zT~-9wwO<#3cDB&w8JX@B4v%5oNbgDL%BlBAgDGbwiEs&I+rVYAA3zc$+Zj zHtQ9kGVA3EonnOp8_$8WfvHmq$b*F>dR1FdI_|>0Of2Q3&8AV5#i|>& z{O2pT@}`=o-*aT8*kBil-J2CVA=I>l1zDExEszgI2%1FTiD;#2^<3C=zBRfLr0RGtY#%leL5=G7X=A0{4cWn zfSgkBRD6YML_MeG!aEaNhiAw;e@-5!sSvBGM6>4NDB{2+7il>X{-%h#S%eeJTC3M9 z@*q_tlU$yspm)T<*{&?vt(B!-Cg_N?TL59;*Hx_O7Nsmcl;p8c;~PLGVg|+r(2LyL zl3jjDxE0{FC;H06RQCXI1VZBM<}gV}qz(omE$a7cflviq!5vg@3VEBVL`b%4DtZrO zE8S6tUTAT@>}`TljRwx7d`OgpZ^|q+nmCja>2PquU4R+|a)~m|P!8ab_t;en`%^ri zvb{!!Q8f*2FWnhjsU0!}+$dbsO-3I&l}i1-{47l%>eb1yDe06f*C}v#UXV>p4xY6` zIC+YC)&7*#=cviM+yu-l&mxw8E%3?Qq1*!{iPI)|Azk2R7Zl|mav33TxGnuM6(^}6 zBvc8CF@por5DgiZg+zMjLN+OQM8P+x#|fHch`CLvMSXBPkU8avR=m)y~ug;cW8K4JK=)V#ICLx(s(X2`{dWiiwVEP;lZ*& z>Xg9~@M!ea*u64vn3guQiS5IU5U3nfNyAA=E(HcxmDz>W@=u6Ve@wymXu^?@xM6Ju zHQvmylK|lZoXhX!X@K`Wy4MC)#L6|mG?1uDBmF*&l;9y4%CZy%t9N}5lf5G1_#XR1 zS&^>kqpKPnNOvhVNqPHGYL;|PR|NK`G&R%bW>_J+M~QXiM7L#)rXCBaqfWjc3crE` zgxrv87!m9}-|OjEL(imr-G4)%u@__5jN><|TkP*h>I?dmX&EQ=C4Cw3G<>^h{KfTg zBgeko3DgY!sLhPIU@XBWGR$6*;q-DOswd$kS!T~k4qoD@L4K4ev?V{v1k&UeoI5;%n9%3oHBDM5GjiyWMqbb_85U~fj0J9=Lrz?uGt%ZkgS^cK{<^;$|Lylt}AzKjgC^Hmopmoll zQ_*IvjDd2?VYy*O{OLn{!?H4^#!0Ks2(7 z!!2iXuStr9LyZn{d8qw%JIdvc(F+{9noG_u4Qw!l4Q6-w<(F*NUNo==HC%YX=3Oa~ zBJO_3!A5D`yUU39a4s};IZpexp-*bV$x(pa z(N}R}Km?b6K>=xM`DYZ+5b{R|^6@|s{T7a@(moib=+te;1Jxf z_;h=K3#V=5pcdV;B!r3k?K#MK66M-f3i6GzT{jAJ!|k8p`)ysL8*aW&fUJKRGg8L= zHKpIrr~9Vo9`8DtCwHK~q^qKha_LS5rmKMcCYRA%E>W;T0b4?`K(>2#$ZWyhqGHmF zeXly*R;R+7YdS*9H>u7ZBA1YNgovC)L<4lE!ARqf*6hofXNJ}^z=!nk9~MPGLpLE(4}_`0K$MCPs4R{9#<@!7cJYj~T_KEX#>WRK(%dkCtA z`q$2UA)cDEaG!2=QEeAb{|&l^7?kLhFkH{!o*!5-hv2= zE(rL=M9@P}+wz8dTtk5WXCNQLm<1!~3%visLEgtbAFeC-ElF_Wr`e+{5AoI3BwgM* zBG;#4aE;qaR2{F#4 zC$m34CW-xI7TXmZP01wP!#|R=R&ZRz>oL@$;52nQavejj*1hXcje@-q!DaGje;RK1 z8X=Hb7FTvtrMvgdJdHqjB_5*T`9+@yHlG2Q2Ej{$1g$mTVG>(7k3gQi{1`m7==X9C zy}ESTC0-;;Zk+8qL(hI*n1L5B zlC;*|NghJpb($o@qux66l3jSA*IiTWZUJjOddR`n4IWPLOE9u;dQT>T$y)4031r|H zO*`Nna>zJvdXzl)cKFf5Nh(bC^Y`!^_wYqp2P)*DI$}63qy6mb9^JQ zZ@a{&v}3@e=N7jYaQg$Zo?G`GH4x;&oWz3)UAz0OeS1#t-!t=k22VQWO{{ZYDttgx zxPgE`9uGMf*70lbEijDOKD>9V9Sw7w3+G5mz~F}$rf;DZ8#Z@w!xpH`#MR)D8996t 
z&kZ=`A}YzMhaY*#-@=#-g>aMM?Gw%vcmm);7HO(-o`NB4w9@Y={oWonz`HPh3Kibq zzefxaC_y@I^24Cdzl1INZ>g3q^HWjxbIZFreMVybTtD?(KP?X++DhhbhL6#5h=;3l zb1i-~(0er}=P{X9@^tS5i=CzxK+NYTmZpHLp1^7LRZ7z939OxE-~=r@8Mx)?p`YBM z8rLY;13lE;DlMLTzM7tmC4st20H7t9P&n%vt^BV|2An+5lP2pFG$^>+i_w(@nCRbw zpZ^(5UC-GtLOnu9cMbtO1dH)#d@lLh+R{#c-jNo}fh6HkfV<-nU7x>kcM_CynGz~#$z$kj?fq{h2# z@EouABs%oit~te;3*H)Z5V5B;z(j%~`Ii*@fC3YXRiif_@JXxfX94c1u?-dzW;DtDEHa>b0RW(M1YZ~T^nUzwuCj|^5| zLIiTXYZamIJwpwK0&3AHc&1$xlB|DtmH7Q2Z^*!|_7ts(Uk+I6bHt}EX$q2d(zpC) RP#igM2nFfz7su1M{(rbMi6;O6 literal 0 HcmV?d00001 diff --git a/runners/__pycache__/global_points_inferencer.cpython-39.pyc b/runners/__pycache__/global_points_inferencer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00bb362edd678acea848b919763ef96794176349 GIT binary patch literal 12035 zcmc&)No*Y1d9JPZh0T4Fs3x_sC5qHGni*NPJfp=k_6)6PV*{P z9MfdKe*631|Ni~GVkVPR@VPrQv-$Bmit?XS82m&~_#(c}OC&-OYDF=FU)5Am)+$<6 zH+7Zj^omi9m=Vq!m1s3)#yB6T#H$H2!TD$U z&T&3dIabY^dCq4m^VJ1&K~?OD&T%octA0rlIWhc15ySS89eqbNPYCq`r7-dW`&U*A zs-L`T*Mu#<=9Vic#qZP|JIu%KxK5ZIxoel|HOH0Bl3T9V!rIIoX^V2ny=yzoii^>P z`#@Cjo4HZ1t(Q06FV>11wnRnlM!njo*X){myI6ATavOE2w~LN_v*;qFvC+HDT8%sF z@6|U@k#*dn+jOjIZw7xmDQK_#LD^{akuDOm?uC3eBu9a+AqOA>nqG7@5 zAY&_Z(5CiXWnC4j(4NFiUFgDiqL_w=h$wy|;-pB6%oEj&ic?}p->zVu?$~L>{G#m=_Bem*u+S;sjtA z`mSOqW)4Sl&Y$=a@9Yk5?A`vpw1G{_;PcU4t!wT|u@Tam4( zPj{_3g%0poaPlW(^Ss7L!Deu zMoX4Q4qj`sZ{ zZQY&i9PxCX<>=SJng;%Mz@)X@1;zR0c}3;F50TcQWrF8ue=mS)ZtEDYlE<8Rkz|PH zH>lqNr^nqBJ4)L?>-7x9dT1jm|ZbWVB^g1MG;Pl;6LG?jML-&T7&UPO)7 zj$-ser{G1Y?=$tmJJ{`+T|guhoN0R199^q0qghT#mK(0sD7u?|vfgwXP1h1-=_eo7 zAKMknDYtAtT`kwFn*GQsO1t=rZ;k!vOaIY&a{UDnfSIYSJ+$nH#Y(dXPHH8HO_sNu zdd*iQ2rXbJt5SB{*0Gz-Mx|UTx^~{R9XB5~=4-{OeKr5BcHyiaaheTV`f&>kRCX;( z5>fg3hVA+Z>gF_xCEHJa;DWj8?smflqke65_10ZKHLwwx!zejS$q0=f>cZauXV#ik zKML^yi2(WX?wwojy}NqL&os&nyHc*%7A+$aG%Q-ec3s<#ciVi;t@}}7KP&@?p$b^v zvL6H9auuS)*XoXMFdmI9@QR-!R9KRwQGnm5{TO&auX@9-`37UjplJYozprn+Ul+D- z+`IL`Ju^j^vI20;2`+STN`NwA{S4PbaBL)C#R4F=+Wj$A3mR70|)Wwy}`nqke&~{cDvc6@P+?B?*yIHTT z5KYNFGJb`~FRWgXESsH`!$j_KV_Qz*1ZVMe-aw*gX>CLe{tR_O%c)~(PEVpHr)Je9 zEviqbX*J4!Nj0gawJch4YU|j+vTwOdpCyR8-iPv}(Qn}EEFy6s0)^TkaRqTftOzzV zhz-#a5gOGnL+Y!o3DP9x+D3j|)~k8DQ7>(RC6-&GJ&BuN2TwSgwpcDiWg64_2Az<6 zjk1|iQ$pXk7LW83XbJR>A9I?elI=JJRlbRy@(aD>It|o;Oa_?M5-<{1TQi638eR8n zIE4u^`11JQ!}hUJY?(I*nZFCk^Zhk!mSkCEN4<&iN2k82wbi!fMmsUmcS846Hy*U_ z!npCYPGVPqj7_3e+ktFr8*a)?d&Z6`3{U@o@6@n74K*6(0MhNK{5UR1=U zm3GX{ZbANI&S6jYVn2YH^FwV{Z^yZQ#Eb8$@^9QxNdE-2q7BFJb>%^=2UC*P5vd>Q zUPPoP70+PEl+HNG({=)fIDwgYH7~IX$xovV#hrYnZr@auc2ZSZ=iMnUxr3Rp(zM8U z$^J?o;?Ud4kCP8-UswA1pzgeK9&`$=q*1Eg}`#h*CD7RSs#8x7Hb=}pRz18 zLiAOkJxFCy`CjlgQO9Q#3&3^sqS0XLk{g6B*^aV^0so_D!f$icf?KCgmC3G@+$dnZku$%$bRz(#368I%H zO@cvqa_Jkp{aBxv@^%%feg9d`FvBV0MuQll_hsr0d!> z1+5vOP;0L(OZ8=?opTc~7*#P8%!^V=ZLO`|0-J@G`-jE03vvtfH_flTkONyzdv2W__iS}fB%A43w z+tc0@QZiQa-n3A>jF%GAPYsCN+4hkQrF|6NIed?aR6Fm@qVGI%3+{Ypp^qakwW|RG zbPC7aP7A=&!Te7%X`r=kG0UWyGz#$l(Lbv3Y)OyiJmB=v8ta_%7C}#-p@7cT znbK$lukZ@BEs9h?Z`7;T&J@@Bl+USf;5jX|&$XA~JS=-N-ZGSm{!DjndU|`M7WZ^< zd`xMt@JTEYUfSnD0W02l@1(baQ#>I~K8*lZ7upxS3mEq)@P1yLVw)l_MqXENp2s>D zyh0Zvor}BR6;c(?v}Z@j%O?FN7deXO}*{L~Hg-wIq}qYDu|yO95?()P3!lA-Ayh6>*l< z`_*&vE|{*pr2)FW+`dSd@Gdj#vq+af$>*jO?1CL7?{ZKU%Tvnz2v+rtz5fwf(c#-q|yMRL1wrZ-DRP-U3tD0%p1@4xP!n(mqb71b6e_ z`WPIfN2+C62<#FjkWPE^!TGqKV@VVD=DefcG1Ohb3B3MEs{Oh*AG8HBM!fNfj=fy- z;*hYfi8q5bKF!zR+Fo5{z1kWXux~<}gjJ`!1=VFnxT=$35 zPSFjR9j;@-?w0VuoKg`HAz{_1-=n&$2<~gWDFVG|V9X0LqUAb`9^N|&6CsEc5$v;$ zmFm?s_?n_ODVpW`G%VL$hFsNZ(YbHVbZcH+*f*-=DMn#Bm77jMGlxalX;g~amRp2> 
z4?r8{*xstJtL_g4{=K^it7@}e5q@U9Tp@>1*f6wg`1OpCT7VJ2p6GRLMai>-k(XG? zOWT#(SMx8`l~!sgAG#V>^M0DlsxE@ydYr)KHt@ygXV+nC>UP61Q{JM#MDKujO0%)g_>MjY0&pGECz_%4CbSi+ZY; zRg5;X0l16~RQkhRI2cCiXAZvtx!jtOc0L{KHdNDr*#pnfE% z?;aiXmD>QTpvxA1X7Bf6#qeZ;H&82M_s$-AP^E;M-axnbWt8BP>59hM6>3=AFvVq zinMHr&`!kJEF#`ytkvrkNk+0bwgFFzh=Iic@I?`1$i0=S}U9)Q(EI-5;ePTpS9l``i2Nq<=6FfM<)eU~ z1G(81it-`1kpvEprQhbqi7QAFnhDA=fkUJr8xk%vi6q%2Hz{dQ@_=TXASgr0CoB@; zxHy2=n+%l3I%FP6?@-OWDa^5h8ccZh9Ik~oX{LI@(^~ZnS+C=WUZp%)r{xhV?QO`M z*fTYYG=^p7fchE*3@eU#Y_M%$GN2H|VLOO@hT8IB8FHAmHnfT3!&OVSi1vJ)l)D(4q`CtzWzLe$>9W8+m1fTL zsTqdHo>lxfquasjhN9Ev?YWkavZ40rI#a7HHkpQ zG6yg62+T$`3S~@TEQK;AP^M7q1WFXlh^r^iC&(GvggS?SWzs;Kme&xr#ovN%pgpfG z@aQ~XimP*4T0f#ucvr(;_m}hk9nEsh|D+>Xt#$g45J{I3pqe;F^oy4tZ~esmgU#zD z(lp%<1x-j=5@k9hsK~PSR8O0QC)UwijshdJiMR*+Jqmsz{0Z+|ZyW7MJBr@8>VXd& z^J2&&oQcRpbV`9Y9LFqi{3hU6M%J|tl;_4OTb*PN8AaSG_*t+>Q`W#@5C zLf0y0CSbCV-OWrefyK3s`~fY6I~H2mj3XXq(X0?x$OSzin|fM`v-w{owSNa=Pz6xr&h&|Bg`2V7P*#Z3WQ%e4Xk^{(QvcJ1N*_7xwx}O*_boR`vgb;9LcvFK*YYmqU zR5oy~s&k4t6FEdef?kHF!IxHXb)bPVaFyXvm>xGeJE%b~y3j}n9oNNk;OivXm8}%` z8Etz)_jFj*^ZaZIny*D=<$5 zY&3Zm!)1XIiV_FbPAmvh_pW5wO8qi5llH?4f$*vp)0wI8BA7g5d6oL?L-L2HJ4Q$* zGgpCtPJ@8Pz@6E{iRi$v`t1_gx!n)td#6H{b^lbl4+t2SaiQJU<)WuLi@{ztGJ|y|0Wc{;E8!#hbvp`5o(td$7 zTt$ASebzd22-h7jd=Iwv(C|+&C8DVGKcvgiUA|Wh>^! zw-B!a9cy77v=U1EV55_wXAVxqO$!~DYgM(m(pts=;d5H6a-e%mHA%Ojw#fIf7u<`s5>?wR^813;h1L3Cd#?fU znu<+6y1QxPYN~rbkK0t=V4@!fs{ACIpL|K>9L1}2k;z79EAgYSSqKpI^Y;m9e3N4LT{7@3S9h}KZ_;Z}{< zSN=H$BW^?(C`9B0e&k6?{+1HLnEVGM&jS`TQ)_T%n+?7bl%&Q39MBCn5BofAnQSu! z+k>L}m7D7y*5P~0VXDEs5%ls$sDw};U=d#7$0Gv|MC3oBem`);wr>+~EC$uq;vran zWwt$_by0o;vpF+J*jmAh1AZ_?_Db)qK>(z|0g!;!7tlX^%|gKx8v%=wi@402sOQU5 z#7TSg`)G@2xJ~f0awWQ%J!A(&H+xKWa|7&|Y5r%(o5>)APVe8OMLj-8$v)9S(|wsXO#mutHkO1m zU4npFf++;gRBO%upWBDWfEg5iCn^QyY)Uw{jvR>QkAt5-9W320sQ3p-trmvqy&P#a z%kq;Jg_D|43YMR?tOw0vC1~LXQSv5D{Wc}!*t3hpZi)PmimdfWtOla?r<7xLNc;3N z;Tx_J!g6?9(YFX0DJ*J-&~F-FhsXwg{9-&BKbrh4Wj+~CCX(@ZG8st@Wg}T5Zb<4| zI2ou`QiEo?iFdTlauD$h!=FJFM~g{Jn2E3g;ZPDd=2&+OUZ?fm1BX%7HM>}IKvRQ3 z9FArNvX4ZMd`8K)Dfu%b=E$9Gev$Gv)a-Kb;4OH8I6l}w2<3-}fxS2}*hZHY;mbrw z&2HWt9qh9nJUZq_hI?B=K4OA3X9mZFPu`Y;Uzs9khX*?_A_6|%vx_i3o}mFl0k&ur zi^Q=c>m89Feon^+GH|MVzI3X7{8~U#Um!Yl$j~9~COyi38s$;LfgjX#f02u;{|lo? 
B0^$Gw literal 0 HcmV?d00001 diff --git a/runners/__pycache__/inference_server.cpython-39.pyc b/runners/__pycache__/inference_server.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bc45812fca018751983653337bbcf6ed58914d2 GIT binary patch literal 4831 zcmZ`-&2Qtz73Yu?Ny)M-f2^~1f>cd|WCgT#J}xb=P1@|PlSMa*yPGtPT7cDZD4Q}x zDl?m-|FLRW|E9+LuY|_WaLd1faEr6lY8Ow|W~OboZS!>6E}nK;>U(Xk z@3;Ma&<&%L)lb3jK%09GsnOB}#yb{-sonzL%&e;R2x$rGywOUO_mv+*QJkY*~zZ}L= zX}>4)Ea@ILnXcc;`-41-Gqux+R4%B`yFbjbSb%fx<$HMeyGnkhk`(OXu84UOsr#`U zrb;g-SvMAO7R4fzv3L}Vh(0?1bEgfdz{D1*t8IT}bs1;eej2nL?r`_11+|s9hi8fV zJiyZ{Hr4j|OT5Y#pR#tquk#vT0$1i6W`$+G^0eHp@J-WN?^nOFT&ukZ zMZB!*chAIpP)KRdiAX$2>EMPV8cSYtSXa* zqK%RAsg1rWxawpPw027GR5{TlPGj}4HCg5kt;sD&<+c*5VOe|sd-k7yef;ou?+Zd; zb48cra1aX}gkh2;DhvgQN6oC$kHb(qoxwnRQJ!^^Jzb_wIp{>Ot~^k%M6M18u?}|c z?%ug~w{=G^4w6BfCRuD=7ldQHEaHyBtl5~hRjxfAKT5EUKTtBvhJ9Umls}2nP$r{T zFQs{hhw&3gmh=IZrmbBWr(L~J2oerDYX2G2TC~0-isr(JvdbRdA6NBd=!cpOGqV#kxHm5ejHA_w}ru>+&%-@6k}V1L$QdRz*8+* z2eDYysn~V4#_G&to-^9GZ1Am_I-_~Zm+>m8<~ecX`mgVVLWSi#$7vXN+##I~5r0$epb zz2~lj*7J%R0A_1!^t~$(yST~)piy`sEgPWS1~UbjP2f_1#>6?bChjRhgDP>GLgIv_7Nb3dBLfbFo4qmWc-${4iJxLxY)zO`2F+!Zz+_)N1b>e^+G?%!os8y zSgA$nP{QbNS<&fB@diyTizXi8M?}aev=a@6x`Z$=IE3*K@bwC?$rIR)5NN(A#THHV zRGubMHQmD03;9sZG%4OO6NE_?m$hCYeHBCRVwyA=*bdka?EtzSz5Jc}E#_a)E$#Kp ziMwc!6O&#G9fL#o?K#{&arU6WsbkWpvq3I;anvy`N~7+@(Fqb1g1XJ!sl$3L1LCGF z-o!`6bs_|4af66)(ULA9r$+nGQJl$K7=@Xcx;)GzvT;0$n~qVOcnh!T)hJI>=$Oyp z@v`C;r*n!aH}P7bJD~JcT(7awt5{e-w*6$%_B-vi0EAoCL9^}z}%Lv~;`^}_UMUVV32pRRr`r4YnWT)oKh{cfbh-7sqrLhg(!MHbu zLP^PY&BGiEZM^FyaFu zq-*h0kY}v0S%W%E3@InG5@QE_vI=6^9;@K0;vT(u#U#jXTQ}1@>ZJ0$t+TgB{##q- zO1{RvDbMcbDB0XWE>zPOK9V8M*#b_u{f~nA`SS-OlL64CxD(J!+Lg zQM4aN2Sv`KFW<*|Q@^Uiv(}&6&quFc(ad+ectKHQuQNsc9Su~0bern3s4piDRf$E< zpR7Lv!wl&`{EZX|BqRK*pjYlyP;c1NK2#ED8C6df2*~YIhl&Z@3jxy#}q( zOXS}%Z=CC`@bpm`&fM7|M;pD}%IC+=1{&QYjT=eUQ2TM?+_s1h+NF|KGodSFsRGox zZ1@CetGNo0;KQN~Wftz>%V8A7Qi=!E^(m3hi2RBO0Y!XHkArLI@d)g@?2@73WHc( zrY#pN>uQ%&aT$h1m3_Vji9Ibu>TDe#Ttg{RWj3olcWvTsP-JOb0VU~Y4$iF$h$Ppg zP)N6=1U`S-0orV;esNxLz|)xriSqz%dvbnm(WN5N{1PuWnV_8`R8qtfFB6G~kfy`| z$g?8;=!J)!=wLV~&LOYTXoHA}L-ZPDZs~!bn8)>^ZJpca#k;Cwk@;N_Vg-d|!~qgh zFBd1ExdV;pqhDMSD@Pw~L6XiaHhS%frmw7XL8zIUZ_yf5K~rii5*8^M$;&wVxp=O4 zAG1VAgfi3g+$9xNT~vk5Y6SwhuAKK@LXug5q?eS(EJDe(E3y+xnrz4XwVFtgHBmFm zhG9SFLx45Z)iC^W*h!0{r#W*KxzprG+LlQng)fq(3INcTjR<;idLE48G{sT2K?8;u zO&i-UppTqXu7L!0btza5%7Gs&RZ5kmTB+s=n$g@SyxVx8Q4$pdrfSutE~3MMIVYPW zZ<4pZwtHx5uMgmdTW81QB5|zGk5Mr^JvT$6XZ-5?oNn>Yk>dQkIX^Nzqiz*XK>*gv z^Am>n$B8TgRoU>!cCt8ftv!r8ncSih9=4s?Eg$1?#uuH literal 0 HcmV?d00001 diff --git a/runners/__pycache__/inferencer.cpython-39.pyc b/runners/__pycache__/inferencer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60f13e32323218bcc06f19863d2e4b239d92cf25 GIT binary patch literal 11860 zcmc&)S!`U_dA{513y0)zr>G&dm?Kf7BwLoPD3KRQu^pRHZKY{$V%!erT#`f0rgQH| z;^dBNlnPR$QKO=7jROpU^g%%3KC~#hq(xB_EzqJs+Na{8D7riZddx##gcR=g|M$*9 zitcNNoHJ)X|M}0qpHs?Yk_tXQPt9&`yrn4rMTOx{1cle|b)F#+icqVH8T_iIin3PK zYPzYbRHs*sTEvWS-l#@vF*C;bNHtzdm|9uS#yNf z$(cF)#;c>XF>@?z8?WWfe3+l8O`4OOPgJLB)8;hilhv8ptU1g1RCTU4Z_aZ*T|HJ? 
zFc&zVsUEKt%mU}L)y3M9xuh!gWcP#^*;l`;h@2RGsEAQ}eox<1&67g?Kq-zr#r~D` zqUtB#x9h@|Uvn!}l;XDw_GR8j=gV}8+FH(t+HEb)Wh1$EoqBN*?r%3T2&XL z4fn37;WzVUqrO?$daqP3ZP^kPxi=fNW}|M`-J7Md+mO4cOTAHY?CT{LDUFT3->TQS zv;J;l3l&+%Ex9ens`Y0{Q}1^hB|+=^$x40GmUg{t%Q9_d_!A8aP8S(lp`$gm?<$+B zP=)p|Zt6l8#zVz4L_|dK8xf~OT4WxoW>lOOBO-@fOuQh*#5i(su`Kdo-GrD#|AZ)t zX)%Mgq&OpH#T;@eah6v&CKethKnN>bIxY$*WyGRb!niEgoe(Dh!^n3PLostWjB|eS z?ZIBOp^e0lkK?Q3>wKarimU7?UDZ?fRaJ4duI?H;k)5be_cc$kVna0=*TjcvbgoGZ z)filp6p;=_?<74+6LX4)jw^Wu^QBNu^9q?6Dreo1ZZ0UJCCj5nksIxfAsydW(Ff&j zUc?x_iJi&r6k6gU&V8p{s+mE}^uGF7m2r2L@*=UP{b0VMyK~)nPv=>VeH{#F;BODC zSkGNhoL`(*RQ~%0(t5N)@EjZL1yC*Q9OqREn6n^~4DsR?^;_cfgnM#N=@@8TysmtB zvlHH>%7|mQ$9MrfW4z?zW$7wA^OPwM0ew$$O0lcGYq!ZQD=RDs`)F-?vKA zF8%D=`A@(6AMJSoK!Tk3vX*?;%CrzIEf>x7Tm@ znP#PFS1WbfqGe=)hDFQRu50`8UYoDE4L>UEdldjNQU!xs_5;9MsX<`)TEp=T#-p** zsMP%&p~8|Za{~NE?FYaEdNrDM-8UFZ22BI#2Yr3xy@s%Tlofz$PI94# zQv#F`>u0zg;)#o9j*ItY1y~9?=xFoxk`R9KmvK_|4JjMak8U<(t>pT#>#yH_{ev5~ z{b;EPz|g5FD|Ht~u-Pcae8aJ;n`SD2fj44mZ@%ZNn`XkXZ4d_FFh`#Q&5!q>dZPLX z%PLn(j$>I*lxJUCt3e!>*0vj)w!KE%S!>G1j$L-wn!E0HqrOHoB@amQH6p*TdQGwf zcGd=hb+x%Gr*USp_&Tp3QM9x+rUrk8I;rK9u<}$(E1j| zsMNO#o3c?W*v&?H8-%|)v}$1!l<#cYVzn5RX-w}MbPn=m%4W(f2~FZ!Jkn3}5#Yz1 zR=I3DPEnOt&{JOPC$G^!9VmMP8HW`PunR8Joj4c~g+) z`;aL=*uVx!N=5e6>nMMC`deB@?PzYa8zb!{bWe5TLHj<86Hn_V_7zCaBxi=xLzpQ) z()RUEoa@KD_`WLt(jA9fPf#n`fcmd1ck6walC+LU{YdvBB0Z&e21BNFCrDCu5;(X? z%*?BKiG9d)8f_@<)MIt`x~g=Ns?t91PJ78c%#4+0M8-=FR{91Gx|94UdAI&`Wsnc* z&MW6J*WE%N7QKg^-Y>y&K%b0V0A8JshBXQkQg&U^!N?ILek_y+?bI#VAQB+*dMeOV z)O#`_v0DfvRRM%o;89l!?dTgN zx4iw$XX=Mfi3i#-mLe;K;;cMDfEc6&{0Ou_dE1ZG8`hSDWJ}l(xVD5Gl^4-V5>@$W zl2zdOEv)MJI#iJ#hm2=Y;_Hw-k_-wzB1`ox+fP{*8X?VUPzW#3B*b>INC`{&vy>xV zk*_21bIriqvmE57H66l?@=|`E@Xv@2>8nYAFCpYehK^%PbR^QMkgiWoiu6^sL9|L zR*gs@pLHYMs7QmiM%-98?q%UxWWJ-#D6pQ$-N=4N3(`&Ps)E*xP^h)vmZkcN(#g3A z*nO%P3Fbv9C326nZW^m+_SJ4yq%b2Kh^q?PMkqIW2%?aX-7NXh0bcr!xrGG0o| zJTf3(=Q{ISO6M593-}%vsZPP0L*GT@mfXee(f~(ZYF`5e=oC)4C%nZy)r;Iw9!KR# zui!<)+>#dyav~*^>dDSY?_~E>sJc#z*!JEPE zPCwDVD}ug2Qvuy=GUd?< z-kIm9uLvA4{ZX%eJ5yl$$9!&Qjy$(>oz>17e1h}dthWY@qCeK1+n(MzUypmbIFVO6 z=lNvL5nehMKoJ+b3*Kq(JWlnbIQ1w3TwUzE=v~CPKLYn%6sOsUD2TDk3eNX<_eJka z45%6x1AbLp@Gg5NrUA9M=w0T$oWKs)D+^k{<6i&=M&wWXJ9-g2 zQn{u_0mqIy9meoWG#>L_Iy`&*tpoE%<@AC14ak9MZ;2^v2{T<4pF5LxxpRU}30~yC z4lp=Ok5tRD5!fY6Af53RgY$8}$dV`SEqKSgjtN6s!nJbBC4_T?Ri}RU8?q|6ul1@3^r(R`PswVQn>2d#;3!Ol2v9_@E1EAi zY8!ANMSoH>%lBwluD1+%ptX{7$DHleJioAS)W`vh!hEW$mUb<-1h*Z4 zHqHFOs<5l>j|6VJyA8`~yHOQ>X0uWyKTy~(wrsfTjF5VO5x}14cWp<>VT7TVSb;U@ z%H1o4XX;uzwNeOuiz@{`O{P^3LGU$BVRKvf;`7sMus3zP>6j^Ux1Cb0S+xmzcqtTA zAe?8*)s{ohG)uBnbHb7Te(iQq_a__oY*{TeImlr3p<*}QoRKi)DiYo~MInSBAOc^$ zovJqq;n2nQ2-~EEN5O=j0hE4WY2E>;>pkd`Mr%>*V^-FEA-k5)CL5W%5dTquM~h| zT~uoA5n)%|5-$Xc_{z#Nbw%cZw|4G8L*Z(ntzTIIsyE2L*%t2R?y~pjM=9 zOGIoU&UOi5Bx9q|s7eYS$n%t3prr3>z{#pe4h&^cZ%&f0a*zRGeK-7A$tg)$+LaU^ zF!NjRwul&58~|VNYaGczvm~5&#Igy$@)-L16A&F>68j*9NfJK(e8AG8alalgRWKC- zDvg$qA5a&g(i9sMj&UtlQiLo*%wG4S_khRp9sk&KQ{1n(+dy5j32%|zH%T(i*UQZo z+(p7ATpEPT%`wPFQGs`3LnQK{JLb~SM-ZUh+8{U4oB{cl?+kC<3aP}*1WjSlkM7#l zYU947uno?xL9s$nYJZ+LPbC`d(yR$yhlG}yr7Y+@3=No5{U##&J`X^Jkp>#bTyt!w+(X+G)g6H@ zbB;7oj}HzhI&*eF(lAF3tm4Od%%WOygFugkC^FAH#}A*20t^Ze2-AY2(uV^Pzk-=< ze4YP9qHvHQr^oRdRSgbDB-JH#S~s*a>MC++HH)xB8tr+lz>(BR^b7uI#H_Z2FvSW& z7C8=L_|8ihHgF`B|0lPFUhb`m8Dbi~z@=o91&ZBku8-}>6zoS{M`JZ$otF`Bb!Y94+hnnJ; z(a)ZJwDS}9cebyUN$d1J6i^{S$x#!MQDo(Ns;AAtE9+`5hky~|L`(#}A4NeC2ZiUZ zcZ^P?6Gd;_=D;70c`@V>*F@AJI<3Gnj$@WMeiQI9Bb(X>$`fOqEl{$Jj1n#m{4BVm zCF|fUCS*SomdlM6G|Q8i{1^cF5okRM?t|;-ek1~|^WsH6+NxKuF{miH1>NZ4a-&sq 
zTD2!h7LH&!Qcs~yp_>yk6R=ju>1HOFz~Wj*enLy(8iiIiC&K5)F=mQ$c~SW98sf|b&Iu7_^5$_8>N;ZaA7d8zTa=Ji zmcK~}Eg`>$q!x8J-6JS;bv}25P|Fg-3CDT-)rS2EFJ8 zBO!F${4RihlW13VQs7&(9f;M_-LH2Zmp=4IcfTB|2M?3e>1+2leD!`YJuoQ`g|ErV zxC0F&9Y5Nvlqwk48{cZgUeSAv9tN^ zC$7U!EH_|B_~UQBb>q#i-1_R<>$k1zZ@+J?zy970zDSdA0Jx$SpfOf)ELz|YV*GKk z^4U}kn2s>GOo&R-et{_5D}JVZ)?OImt|NvVz|a9&{aZ|f2q}H}r36zYMceT{A*eyD zi0`mGRsJ<$5E?_-iMjC|#G*j2T381ygA%{7=%(mpf>U+VLdWeHUOxzfqV4=fAebTO zZMY*IV)~%#SliZ#Y`9$#-y(bc9c3ImJRqYhYScHFxF>-sKgpIPUwxVh&XZMM9G9zT zl|P|`^az{(hiFCq2{o{L_otNmGfL>%$RE9lrPmwoO{g-#(SzWsb!-5J6W+fEyvS@w z^kb%bx3LUHz?AaOsbyq<1dB{`pK%b}10`s$0lgACBG?1;9HzbCalv4gb3}LZ=uUJ; zR#qPDnTBhWXM2lhGGpjWyGfPk1-f=Alh{?BB$^pPgC+&ln|EW-ynGsu$r0LfH&G|@;#sY zBDJzGdW~{44591Z6*U=m&u?gOD4)Nd93~px4?!&d0F_{P0uUh*e%CSdc0&>?9tP;| zQ1?0kMA%Z)_WTfbpPS(TI6c&|Mk1nPD+5mh`27;u8U05A0bm9PE&@6}f&Ss+5(vv_zlMPco@i1x}5R=S8x@%04t1>JC{{>X8O&l$%8 zjvf?zxP)sFn9V%8e{D}95+)p;45-T8?;SaF`skUNZZmjDC~sn)M`D3ODbBk{NGRi} z0*hh%YJ3F>!*L*vyT-l(7#BhqNoZ^Qv_SU^^kM_!?sCX0?Zov!AyYs#@k~R|E8-_v zRq!(;c^hl8=t39@uXA9tb5>b*lpj+aaPSSQr^@6h1~3W^*z ze$uisL;{1CJ*!u%jWV44FRdP!;3=*7Og;NdJtrrSZ6^!2!lzLM_up{I zf8f67`X5-t1V$hd=(0JG&HqR_B60akl+fKY_Q|yUPsp3eAjD12&ZI>>$ap`@nKa$o zv{8zmDrz>CghX6{z*vFV0uNDZPaL^rSn(L8@H0#)C}&f`xlQCi9NpmO)8W$1qKbcz z&U$f_o~DsTvMfJoQADQ&VQ%?p%evbtRf86Oi6n1eir{X7+-i1p*j12Za0S|mL}MT> zN!A4HPkZw-;d7}nB4c>-FtFHIX1+$9Ny<&*>rliCKJ;=t8b6l&HDxgwPbQM_crqDD zj$|WQBW`ft;;BF-lKL{!Ej&zhR)Z*H7?KRCID$*!z)XY{2vL$SG4s7KcnH>iU>im{ zH|$c~0UZs8P&l9%$Tbo%^0z7Zn36w4VvgO~<(DFFKs~MoZ_9$meiOqDgiwC-7ua@_ z!)w3!7E;V&v&pT6}1G@-g z))^Ww5@3r~u}JJnvaS(#;g@TCAVa6xSAu^0YCuw7B06=>AR#U$UC4hLu~Ff(0ZW+hUM7NQ&vS7Oz8A@^~#|mRX*?4uLFcHKjt5bz3j>jw0)tSN!#}k#=>Re%t@pR>AHCM=SJX2Y!E*F+{&7SHU6C?Zjmo$+Tqfaz3YA@`C_VmJWp?|35 z$6jFl+D2Yii4W|Wu;o|Xas{c_ow{QO@#r1b38G^k*rj^Sab>gQmg}`3H+@IiqFi!6 zu$^YbMQhW&FRJ)Wzgw?umA7vfYsGC_A|v~5z1pbP?3#PCSaR#~G4hh{6dn6|(M3pY zBOf$tHLe`GRo_NN#&L^o)3K_(UQ$&1R=p@_{K9mlUMg0sM!j5f9jjd1vZY-s*|J0{ z8~j9qgwsL9)@Yv##<#UCUFgDi5-Ws6NSIHwf+@lxg5R(>AyOj!L@z|dNiia_h(*OI zF($?li-{F65#&vZDb$aPyqFQQC`*XbVouB>mK0}rgd<|{NgQ0U%BiCwhg4cDiDk6Q zaNaR-9CH}?wq|ODEOzFsn!H5|yTj{xzrQSFU`^p`|B#UrF= zDV=dgI$1xBk_@*VMQpS)hH!jeM;)X)6C%p0N6>n?llLN2_nChGJ*@WhK4v5|Y-wskjc-<1%q*uQ%MI6R6x|(_s5jk4 z)3rods>Fl(BfDZb<(92d)pE_M*$=Ivw2MFg=ERS`^e?R^*Ip0e~ zQ7cYDvbO8gYf6_YC46zRDrLuQ9lhRcRLZ5IYv)|sadSaou2!ttS90HM=g+9H(`?vM z#Vp8A*|jW5Or=8GwyWY)&1n=%wn}{HLbmGehR5%)Vb& znYZi0R_5IsAKq0-qLdX(cZw5Tn$kp>uqw^@fGt0n<>W(I29^Ar5Xw}jD1=J<26oE6 zA!S{v$W~oei>``Z|Ki;*et6@qiWD1|7%DX+FpMCqo{uWiu`63D>CeGmi1F@irEjUY zW7}X1%tMX7GBp+J&g!|Y;+9pa6dlL1o@*~R*QJxf4lt4~a z&*-a0Bs8U`^a%eY^n{)=GAPOFt;M0px7Ov(5`eDvp*U&tcky*r5V!z9p?3(Z02U;O z5JHoL5G7$@P!5Y@XrnbnTBKat&TYwhHD@>Kr5y;wT5G(Ax4A7ygtKFdwR}XT(7iHg zf8^&Vnl3dZl#Od~OBF|nuY6S0X_iX1{NbAS_ymQ()^mRYin@gxmo$fEq<$%f=I}f@#pA0G;=Bdnv zyhR(F=&^!S(EvGG$vd)6OhF9yLZDjSP^Pe*k_y*K^@^hsPRS0S@TIxdxL0YN=`*5_ zXt`nq=qMiNd`!!M6?5F`O0E@or|6b;zW!4G_ytKtE6RAYn#<406ErE4^o0sTUzB!K zv{tvaB``2<1Db6K#Fdv(OA>Rb6rmc>VH=}5Dg>paV$cRY=!Ae$l1vX3mc`n(t&*07 zLf~E%sv}RGoTlIm1*{X!QjDZW{yGAcZTJSG<=FR|WUQ%4r})6`8*c%8tN4vaCA&eU zrV2ASs<2HMnoqJ=sbm+F%i@u|L0$JTneh>eb+Kl>3_=6&fQt-ZPbfd3h2P{ z;sOO^&6P+4cRvKCL_L=IUjQ>aARMGF##m=~u`ZvJ!9!N82t(pbNVYB_s29gCNofM| zX||Ih)=nWOj+``pVG4;P;u$yGiHH=WZN!asVqOO4l=QcZSq&x^on12DGW>AMxT>Kv zEi@|am1QWuthKXl9LAz9M*Mz}N{Z}Lqm#nu>3zME5lQrj)5ulr|59_C>W*>U5o$H7 z?PkUJEVUhdwO_71(d!v&J3Kzum_&`S-SOQC^f-mp4D{O`WixadD5JN>+!+|8gkMH` z+#46upwk5CH0RFmE{LQzA!eQe%i5Ffk=;da!~;(49`z=@q?hs%V)kkD8=>}8d%8X2 zO>OJ#S#Jg*8MQfYR%l+@ONzOtCeV4ly|AsdkKntA?@^I#=e&8;T|#WxUFt0N>FFi+ z4N!n~?3jDZTiVmT@ICEWM7F$~7YSm^Ueu3)Lyn7izaH>)!OI~$;zbZHddnOqJWdcE 
zV_he_S*+$b(DtNvoJPXjPP9)Veag$CEFBdG8{2 z%cPp_8{)K=-!nLN#+#>@w>WGJFPK4~`7VBPk;d$&Y6<8;7TiG_S+vlGeJ*n>TFz*%faQg(}Z(*lKv2SmnhlEF= z&uffEU2G!GW$d6a#(Mbm+WH4Z2S$09+VkFtKsqmD<+`}wUG|R6V7=m^cbTbh3~T2D zqF(|uc$;PQWRDh?KnoobeX_fA-jQMZ%6sz$wMV~~5BFYs?@<4>1O1!mzveA7Z!V*! zE8?|1d6(PAXn*mI^=UuI$&|||ihUs(BAoS>{GD|_#~2m!7QG|hQRKY|y1w;kvi+vF z((teh#dRHm#I=@8@3J3gRt!-PI*qL2#=GnYE zH!F58s96cau`J;=An%9}Gzs6w&`Q?oxuEG%YlJP}T)F1k$hnrWnrmf(Q(KNW0_)>6 zj9qk6qV-Oo2lPg*(0vY8BaB<3LbdqlN^Ug=W7WQoU3&!o0OknSrbX8C_Aj6R$KT)H zymPI!fU;b-EJtM^V6lWl6)=^#g2bfkqB>xCM@{ehXKGOs%X=MU#Vmg65pYNX& z{G&*Mjv~Gf*x!7k)&OHYwk6rgYF@&^rxSv8S&#t3vm(on$lU~Y#Mh`s#I#k(*ak8YGcY!QUgWQq?95BRi2&a_(N`X&x|6^W z2#I}=!z3Y*TJS|$)F0M-q4K+eqo>{!@?)wJCfTm3$ODkAbWa_5rNzUty8}u!8aRsb z*-#R`sZgoW#95R`hw~CV`f3cwCCWHD*?>ddV@E9yr?^jLYm?5PY8KpHx;MB|D_{yZ zO}MC=j6Ak0mHI>Zd741Pt&_`A(&1QcQDF1DAe)#Rd}c>*;1u<$^_tb^sL5?^0%n%y z5zD{jdtmNS?qo^gkV*bU2e{b71?z98zg&uFuV|LJp1+?aGO6$r?>P8c;`_enk|11qlebA=fY>*e`zA!Lf;+N&7ng zhCpNIMK%<}Z$vlQ&5_WT^_h@qoYq(MHN;bJ;HK~w(OnwHA)(2{=dA*kO`0vUwi2kqqbjcPgAQS|^7)X1dG(?Zl3$ zpTGQM_pjXF-?>&IWz+qTtA+rR{V@eZ$kO+8&zQ%Vu4B0Dd4@j~-WHr3$>R#2E6(|$ zw%HE1BdCqbC7h3RDSQ&SngD!A35kyQ*%IE#BtP%Ic~Y>>GFrFYO9 zF8>0(#NY^X$j7B&_h+!L>=VB@e3uT(zBQ=Nf{QZuT8Y$d_e1V9O7oUoMZ|+Ep{dIW z+O}GU$Nv}S%b#QV^4}@=1qDNxX12e$KGT$FH@crFBDD3q?F9Xh$>44X zm*fWSIVf#nTXp9o%Ozs)sQ9%^&%nuA$8CiH#=z}~M>hkwL)t?QYSA4^T!e5vy$GpJ zpj_KcLYh%_;I@D+sJ(N1J*{hWLCtpsknFEwM#^}&sq}~WRNuZlwC=UvN;)CJkvQ@zC(2m5xIoC zB_cA}xf<{}1^$_UY-V1|Lqi?;H%nyac0UyFZOS<$`kT^yufU{?3FBT!u6Vk$0>Z{> z0MUJ&C9eVjbY%zeSJm#nP7_c`&b%&mT`K{2cg}_-@V9<=+r7puZ64$_dt zEEpzV;q)I3ayoA8a09^)Km5Bp%^G9wk8QN3>GstYxjymyd)QXIYP&^#&Coi(Q6DVt z6~K2>vFV8J?%22s>%IoS?X5DI<)=X^m0;^LxSi!VS@QXDNn9sKuuJ}F^EM@W$2(vA z@R&>UPkttv>6Lkis7X@*HME|6N4CmviXVp#v>uE2fP1*b zBjJ^JJEy^GLrKq6>9yb`5od z3w1IOf92ZxTlM;-Wzq5A(rEJXzmN$;pt-`g!Vgsj9-~OS%F_c2YV3D9(3*gcZc?z8s0NAb{g*}1^=D%?Uq0FLJt4B7ZK_|g_u zXdk*e=DvnGE(UWXfH(MgM94Ey%cObU$Ay@$3FFuO*IwkGjlVSDLkp_}s}+82C4UQJ zGR(oNg-1v@3*e&OfxyvJWkRR*3Ih-5% zZ<*GvR_Gz4^M$noeZ3$NeyN{%sh^dTh_({BJHcbH9OA*EoLq~q^n0(?aN}TO| zU}@8geE9eR#poCzkE5FJ+WI;r>1hL&%pAW$t4;XMb9!hecd5n=3J$>vb$5%_OrBXy z&qNad(p7-VDr_K}NR3wRe_K5c3g|$SbqX33-0#KcJ^?KAGymu1VCrUG$3GZqBR@*- z@<^vymP%OUM`}VHSSn>%_nXCvU&1f1gG@QN|h9#N;UCX*IDyjoq^xe&tf+*fdu@qT!;Ligf?oT+Xk=TdasTHcj~5HtT|w* zLGKN_F@4}8;E{hx0iVHDv8z5yeS!GYA$35qO)8ZC42mNM-Woq0{Ngkj)Bgu% COHCdC literal 0 HcmV?d00001 diff --git a/runners/__pycache__/simulator.cpython-39.pyc b/runners/__pycache__/simulator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44422a8d8048029e9312a1b4a7341859ec3818c4 GIT binary patch literal 10374 zcma)Ci*p@UdEeLFyZ5PUS=P&rqmTr+5M+~t&@yfmTb3Lpwj^XZDNEALwa)75%J=27 zdu>as7^t0+5=;#P=@1e|N?RJYl+w&}nhtIA2XuyMJEaWU8JdBKgLxEahd#hA{e9o= zeMo7T>T16|`29TGgl}?1UQ8n6;#x7aNwp7wH^yV2$zaWC5J`2zxPUc{9pg^ znUQC|i@siOvw+N?#Go9ZemoKGnwgzhbdD{|%+F6PIP-HSF~#fP4n=lA`Pt9W-*t2q zA#>ZR+gPTjs(8FtC$7fyot`9Bhqo=BT9{oteB}6)^XS~d@uSY-oHPG{Bd3lZpIVGC zSQ!Y%zW&OOCaCbh-|Js}g$f_}>ApW>%oE~!FVXr5_xnHjOU5ZrfA||*c=apK{E*3C z{g>y$nCgl8OD}Qd*gt>zhg^8zi~kdmE939^$;6U|uG)V(3MMYPv#k5dMP*YB)Gbw2 z0yR(^NIM{89nfHy@XgJnP=&UoZD@fe^brNnV3=TNGE9yr!g?-&c2?A4i~36m1o_4QkXbj5&lYOuDF%*iXZCBjdP(sP|72TO$J;xy?adPx(26HzB6dF6TS zV*Kt`)iauWN1$wKTMD>A+0^M??URfx@;h zK$vMt7hCE&Zv9-Cjx-3BCD|gG0w`*7j41V`nvk~>mcbp2LyuD^cv%2N&8unE!j)Fn zbDdGgL9;|M*#1kRKo_s_k5Isjm3TEO-FI@A)sDJRe8fXOt zl%_~rRF5i;KenNxRZ^I!HAFHngawn5!aaS6MzI34L;t48F#RCtQ=4fXGXp!ArP{L> zK$+W89@91w8)m!y?1Q|YxmoC&?FxvMQmlop{{nwq)wNAGy#p(as`^l>6yByN*A2o^ z#Z|4_O6GwmYS)PyuImIyuF5yA{^TD{oSvU3=EH<^J&1%sLM2I6glXZPZk3&8!*fHe 
z-VBo^uOww@4Wy-d*$s7f%?*>a#yQvV8==|^^PcZE<99^TZupg&3wf=&ZWH2II^#Oc z$`ZuVEhhQ>%e!eDv$4EPL#vLR?GEYO9j1ia^jDlptqi;4mTOen!7a>9C(w(xA3y=8 zSP+x90A`tQz&To$o&bAA#fyZ--Nr=3P~;FX6Wxb`4;ASvs~W#xhy}6FmJ&aiP>uL) z>3I1H7=)?#h3d1s@$2bkwN!V%IYWzT$e_pW3qIv;;)5PwNVnM7LHC|Rih1MF=9Zn zB;N|~1TK%{l9Wguzx6#`IqH$rZCgQ-N*8a%kFha6czso6je!jj>Ih^67)@v>Xz2axKgk)S+*^? z5e4#^e?b*}myHbrb?Ke;qHtN?NC2A==^zmpXuAs*sh>y60#74ldq&#LBjr}}n|*;U z1~85;_MpY!mJZn@d?A3eUr{~dAw>)YW~3{BQ$=4iyIp~a7Q;NVkrz~yM^RSA-T_7I zyPViaf`(=p(RJ#8=@++P71RgK)e4gFT+_T(n)Uwns^gklN0cF@j9uMK_`6pJ#I1W3 z+OM*X-5k42Hg^wtqZRIX!wU7)!D!`CE$D`@m$$LL>p=0A%hpB;b;ChQ+%E38ti@XX z&RNJg*-|^zF{GgxSy4MhJk18zE|9{gl&YjBn33(>ZA?Q}Tv%Y@?!}^NXIq{-UAf?j zBPGve4aH(oOxUSvrG93nF5C;DQG;0xb3i~Ej+Sc8O1&KBp>N>-cb@wY3&4%2PD`j0 z(ln+liAu|3j{)A{GXGthkPm}}BS8AKXdseUDK*{2HFyvY0TY^4w_b+Zkdo!orO8IM zA&a^^POYe#`?LF&unJdp7Km=F(^$d@n&}Xg(6CzZFp-ngd$f=Ns(7U0wIJACY?iR| z@lHXNPtz#J2s}t&2Ea}>POmzm0vCr0b$CqcR?YQR#__m5G9_h0zHe;0Qgz4b4S#I8 z(W;BFl0RmrX$|mScoDBh9_}|M2cjXaK$w#SWV!2u>dy(CG7ZB)ijwkd#{p2jv{UZ&rcB)=E+j3`w-524Gw;99UwX{;HifuG`Ms-P2s}1-d^m6T}c=$ra zvv)^5O~b|O?7ZF2O}F8$3!05;TeZ@KX+)ZCaYWt;6r5J}FXgRxU@b0>5PA?w`TZ?o z#5&&mv~1L-TG(Al{+f-ck&QgMhSyhF@-Tajqk8zk&K;YYJK;=z;P}k!(WwP{AfnjL zO5i}=#| zN1J08I!Yq;pzU1jamNIG=U5gT7C^Y+cm-{#>Ef8ceuue( z`=c!5^1ck9u+s>~Q8U% z@`-g^-nZb%!LGuWeG=AqIkLj-XT~y4XBP%mTIjG-h~0+B+QBaK9aecrxd_KaB=;(0 zT?6ByDy%cg6;Wm_9 z1c{8Ykw8s`Yv3nDHJI~oV4{w!GyDd?d{5nngCy#b^#^zK*%ufC3Rb_f$2NXvZ*6=M zzJ>`4ZY+~Uhhz8@{29ausj{+>#+AX9#g+4uL5d@T`Hej2ETRkYh^hO6^kypP69w{( z5O-zx30?p{W!p~?#c+hpZKfhe^NM=*%f7LirrtzF&OL?I(VcXJ2MpQ z2?m2bct1k};FpczU;uL)rdeLm&Mvn5etbW#FB`27bISz-0p7rt!E~R*+VdP08Jm4X zgR8KmUD4zbf0q~`&k~-;=t=7Hh}etV!M<5GRgt~m_~vq9G+nVv2vzwZ+V zzx~Ozsry02H8B<47y&$+AF6EAuoa4Xjmu2Kjot~u>*PIOr?~Yx#hCH~0Ofa0?)(0? z?;D*cns%nzIPc2im0ATlIMZsvxoAJIMOq+J0PTzj%ZKepu>c!0+p5iw7efvV3XhM+KR$>!6(I_C5^?JmGB_j3r?j&M_D+plE1m6? 
zor9@*h-~9Ycw;VGsG=5X;|D{vBFV+?3BS{a;L7+*jau{m0=D7bQLC&U>?$XY_afEr zH#j&PKU8hNNAT`Pq22U4MRqNUL9YGYN3Q+()2}`Cm1}?f+-uLh`1%W9x%TNV{OGSf z@!Ch9yY?q9y#9w@eeL5fPA?YwBdcV1XC?AM*q%@?o%W*XXEA;37O%kL*-Ofd*o0@~ zOntc#nk8&}UD%_T3;8fU9IOJ1n2SDN-71W?@UL<1RR7&gi?-)$F= zDV$lHnVThhxwtTQT>dHzQotZYQ(D5i!xnmMg4r{GgD?h?ux;ruh^$-)t+po|stcid zK2$G+>RPC-hnXIS*iP45HD{H3w6n}WPBr47EOSujiD?`Txjp!J;33|1Pr5X{bW~lz z4%nHfjN!pzxuupQuPP#j?}ykNUIX3ieyqcloVD^1lM_86(}-S9(UMf(&h`>T48p8J zFE#9cBO81jcuxW-T0tMs@6z+yT`GS1sJc(xuMeyEX;~d+GWwxJ(P&;BRP(Vx?HK0B zi2jF}A5;x33j^K{^I8C1)S3+^OtYo;BeqX#>-&Bge{M6PlqVq}rlolMz~Cf62#{a| zq4+ceu!!>mgjbaI323d;;4{^*-sUL)%lmsFX)Ph8u4WJ%V z+@pzbka|plGKn6O=rI`}t@FmhKrC zNQ8=sO@3rMYNT*&w@_3(_d%sE+_W2RSJ9e7uTOkPj?c1!pph*+`9&KD^rN zxwBr)N8))4B;-g>SaY~8dK`Wh7n_`Ac8|$V?t0CkK^h|x{UpV3rMf3Sh{5bz!RWuV zTk>7QM3j)ym*5cV9z0=PH2h*;%(;Hb3#5kWuoef@hCIlHs6;+Mfk=p6i@}3F>0LfS6sgo_Q^&R3_FpcunRQYHcFle zn(T0F!vZ%&o~@nIK7cpk9l{ns*dPEFgoylH_C9YKN#JzGL9hn&H1LD!fNI_}lj*Lo z-CQ5-ip|aewpkxKWFhOrLi0%!gc?~MW_ws3J|7?-R7c)|@<0YNeIy?6jUX86rYQ_z zfEG`i5CzV};vjAdzYSB6pb#aPl>938G%R>ECYUqT(*nnJ+=B8asUW?q!`fN1d@$JK zDNu|nKLqgl%ln`F{`+74n+XJ-Y|x37BjXsQ1S|9H?8=NPKa3v2 zBw1+$`|_iBDrO>%+XegMxXa5(@HHe9c>N6NU05OG;F}yc4Fe=WSNV$N1%j zJYSMR{u+(ZCMF(gA3;RC7p7wRR&yt-^kME!`&{WpkH=rBP67v!NXE$~A>6f`MS zOCQkIN4i|FqwO{(pmbnIbN(G3q7?gMNUO8nRLKc}4%CH;gq-24t2)@jpadb`>G+7A z5rpiGN_2=s7z7g{iNsg}ZIIFeriT^@J;Iwv9VQ+_@*Kb922xh-mV8G~1D(Eh*1&PK z1`s5`m4zb=U2Tz$hv%@_lTrKxt+(Gp5%qj)NtZ%kKm*< zKCXQ5@WRnY4=+qPbMsTP$BuV-J`VgPDFFF91g=n58Jcp;$KkH-oMV){pNNvkO>z1j zA0a660Rp#o0H>-+nYD;l$#cZ9540J(12-uyw}@9&>xPN=ol^W%WBj~e^{ZpQFu>1n z0S3EV+?jn0G9z;Eg5OG4h>Th9iI)MhuD-RW}$)d7_&!9u`Z$rBg)C8_MR&PV!@m?)@kSJ(Tx8^*SF+fLEU zbX=3TyaE6TJH)Va$Z=Jr^}H7~d{`u@`<_sOXTk6F-Kcn;CO{#KqV^mej#Hi(-+ma- zZ|*iHjYjT_PPz++FgUZsxg1U!HsNTscU68Cm6!(^CrQ%`3-NgrsX+E@Q9DMp&j8rB z)GBp{tS;43_|iURIp2`gH>8tjb(=JG;u{5Qw!W*oXYn|1*fm-v?;7n0_=Ha%$(f4g zs_{F<*;>UzVvLT0{;ypNlMeirDvp&0=J_MpL&PF=! 
z1sE`#(bfH{@qujSbZO~KtI1hsc^haWe&RRzEFQnqGoQWP&1h$=I{m%*#ON(`Cl(1u zduw=deRQVPGd1~nn#PW~JP+KT_ROWnhV5t$2FpbHo7rL8#3-)s=mQd=|4&AZC0&6w z&^W_}^CkocIvE(sA>jlG?_7}QlO#VT&d|YQdJ^=|3W*L5Sv%4u6cQl&69)+}Ly0)= zK{*Qf+XP5OK-}-4GOL3-sl-aain#j=gdoLiTMM!E50kMof#XA2p`-+DfSp=!;zxKs zUA((far`bidx>^<00r_g2lRrGHAveIK!I3VulM){p@Vk7|9<*PgAOq8l}6feY7K!5 zIm%gw9Or6O6MZwFBMb=p9H&&rp+=Mwah&Lb0bVvMKoY4iSr_c}v%K>OGJh)}i>Tz| zU_SMflndr9l21{I%_Z3wNg7d-R*JOO=cq)iEs3WjEj!G`pLZ;|Rmfn^&M+6UFp$Qf z>3GKh=!}^n)A+-To7yk(NFnz;!GxL7-w(~vYypw$Zp*Y9mTeudhO9A)WD6-v&*R6s zjLzecAHZiIy!gj3hI}7DI6A*ZU%{;$X*7KL2m)WvdGTG6{D*_R1PJz}+wgH9f^S>8 z<^EopC_T>~#n`zPzP9tmfufX%B(;a5H6A5WvA5Gix`1{fv<^k@^?uSlY%9n@(*ZpH H5QqFP9$iNY literal 0 HcmV?d00001 diff --git a/runners/__pycache__/strategy_generator.cpython-39.pyc b/runners/__pycache__/strategy_generator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83bc904da137a4ea5c37824d1e89951ef26a6017 GIT binary patch literal 5213 zcmai2&2JmW72hwC%jK6yecQ54J82sxs+E4A1>Dq4V>M}uL^a~HMYl0FYtBldOzu*% zLt7SxCppL|sZYHm0jOgxxfbZ3Q1sZ-TniNNzhE>-e{Xh0(jP#z+TnXMZ)V=Sd7m5B zYMz0=ziqGW|L2ln{F{*ZzXFgq@uq)8!whEj4L|!eeN*F>Z{yeMJA;B>Flo&07Y8N3 zq+zG;4$6L6!)D(bRQ!sD3;pV#=GQb_?AHekzoFq$e__z{n;Lfet-+$dXc|#vvc$@# z<~s)SSmmL?DpBFYJ~91eX5KP7)yLq?*y)(6`2Hx4BMzu`BZ+r=d+&vDxTibpcauHz z*HaOSQ5p=U1J&(NgzqLH`@QJ(LD70Y>Lzh2_^2y-Nu2kVZil0(wWG1s7x)7MnoQH<*u)6LoveLt{UtkT5xvUJVm>tzG zu_~*9*8RkA48M%yE~w_MoQHR&ELd~<6HNDJ<4upy2;;<wuP_6wR(4b~7gH3R367V+Qpe`R zqH@FGFo?rJq$*tBU4|4hM7d|PRXMpA@%y|dB2`MG!yq0F)C%?W!(kBeC=A3tkJA05 z&(vZ!84O1v3hwvBeh?>o5cbt_+707inDkI92Vy^ox8ghZ0{)aPZpG34An)GFB$RHQ ziFKxrSBDSy3RobnQUVOKYI>Gs;`zocESTAIeC|0lx_Tyqu1VAKzZ!n4cyHiMKf{QC zHf6eVYLCq+<^Z#VC2VP);Mq1FnYXvcz@)uyj2+=j3Z!4GB<+mm6#hcbcBKRRw#S7j z1}z#Fr&tA8Nfaj~SvWPvu7pwCwfG~f;+~k}6042N1k0idZED(W|EHTd%u9tFYBCZM}1ut^L#(mk<95J8aN=X`R}z%4z@H zPF}Xt;VaX=QYN>V^t6?&@9;XVhBwftDydQ$iJ;$0g=*YOj-vid-Fbxw$pBR;Csb|@ zq{`( zPs~@2Yn?gTKR+?Ij_upm+OV`P9M+T4E&6G{alCT9{mVW~mpy3z>Im$no9$b7d&9@H zU1vcX8DA$}3p|YXB4ziZI5Wm7oe=i22a`{gm87af=LPGWoUd}Ul@S@IN7gRVtY?u6aDqdahUxxnc1H)uUAE|L`g1)A%4Q_=+Tf;jt2&!7D{ zzMc{U35A591wrn9UJ&}I=1#}bN93y*224!yZEBujh{kW=O|_)W(JeSTFH97@GQC8qJ2TMPH$0CsY?APdOLjE^m}&d1g+d?dVU;h+eQX_HU> z)c(*Km*80o(nS^kN=;MVGsYg?3f?L!jccML9rEnA|23{-Z&ybFSHjO~pIAO1uZaCw zS=Qn8yf4AkxFKf!mAU?fx&G>0e^ZbzZNSsFGEd6aWb2^8O6vwI-L>Fh9Z)R5^P1zu z*wfSt1k1{#DHpY$^#HZB=ccuq7+sQ8^eySp*=i0yk~LO`Pp%yv&7H`}PedFn%G!xF zUWN=-AOlw{$(0jRF5g9b+B9JZZ*Cn#GgEhSCc+aus zr|yUL`22Wdyb-?&y7Qnz#6p~T5v_Mdo9}QzY_OKRAkVEChnA=OMp~@}fL1 z&$C6`LLKkY7xwtlRR5M|BU%qHPA;((X;I{xjR0H2sm`C6XOS2>+9rLGtB9g)aap$6 z`ib?Ijqzo1W%81|Brl&@rtwGUu7l?Ukuq-mz<6Y(|9#ys_46mRa#hQ96>|N@OoEyG zW>VCqdTwVH0{N>jrN?B}A0Y3qaQOZnkH}kR0k1};NYQSWQ-W#_MSAR9jfW4E9rJ-& zoKGlwG3y~X!uBb6@wchjq~(8z=D3w5q_h@t*DS&m)!GXd$YMiTXQKCd(S22(rnNe) z%(BciTD159veRpA{yiWaTX__c;4k4O{o0)<6~Wn7syPUcf-WV?D24(mUpQ4J{54wU z()IS+*e>0V*k=2^@Th$wCv3kDLd5m8&GuVZu6--#yp1TlC-xD~&HJ5~vv}-LJkAni zuA}Y}Mcpi_QMR5pGoVU*WoMNKYHRnB`ws#Zv#K_*ca%=&FgO2g)ndsZ_>r*Yy zv|iVBWPKmfg!gDd0XoHnzu4~$Q2DW7lt#OwexNPG43v%Q<-Y@=Un4C`!+VjIGJk_6 zo4e0sPW&h6Cezeav4QNUYEe$G1jxxA;4O2F9*goQqVSM2We;UfEl?2c24;B}bwK4I zzj7@X`9=NPy1jt53apjaT;4au;+px2Rky3=CTL!iflPJe%x0g;K**w~TGnBo@)?98 zOWv7aig6#*lltZ9=(7U<2&+EAn^OH~Kx{1tfW&(Io#!*3 zsSsv{(m`s4e~+qjoq`9-3rLSg@DPEj2EpMd>}Mmo=H)-fK05EyxtP9j{t^0AEw8@2 zQNJGqbByW~O{a_WDU6CJ&8=%y7u*Z=Lfv+M%W0&ul6jvltyX~fq0IgR!RKy1pszst zZ^2Hl=4E;Q%PG%Ot$Bj%b3pb5p*cS=!sn$l^9(EV{VCjal!ym#Y~CFF)S|iT)t^)v zdZDXYh}tz~4ms<6`h!B|1LC9}Z)`tzsAxw1u z3n!e$#K+x;>B#U6y0n=i%eN>I*0s&e8+t^z#qBxawlHSqjQS3zC#2*20UxB_p@s7) 
zPZA;ht&2!}XJ-QBK&LXq^Yvj>iaG>A5TA158&~9%a>|Xl>zmx<)|~hjZiL+kvuKO|=VxCNpwPr}TzWGNn_(3{FnSg7!OR zVXHh7vTzxTl1MXFG{ZDGjm}sW=wZhwEG6PoSpzQ%TGSZ}<3K6K7Ub_2ZzOmWc*FEm z2#*Ik@G_a6i%@&nSP#?0O9mHA3h<}2%tIZa36fNqJy~W>d7GK*Wn<|;I1H)EtkD*o z>~7HQtp z*a@Wwv|!>AsveC*q8AmI<{${a<{3+Zk&tz`yM|~7)NPom3`!?2Hu3!D^I>q0z{?%D zWf3+2ZpUfB*Wk&=2<0B0CtySS3eWRRJoj95pwZ#yVv{2~Xz~VQqm<_{$}Pse= data["seq_max_coverage_rate"] - 1e-3: + print("max coverage rate reached!: ", pred_cr) + + + + pred_cr_seq.append(pred_cr) + scanned_view_pts.append(new_target_pts) + + input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)] + + combined_scanned_pts = np.vstack(scanned_view_pts) + voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, voxel_threshold) + random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N) + input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device) + input_data["scanned_pts"] = [torch.cat([input_data["scanned_pts"][0], torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)], dim=0)] + + last_pred_cr = pred_cr + pts_num = voxel_downsampled_combined_scanned_pts_np.shape[0] + Log.info(f"delta pts num:,{pts_num - last_pts_num },{pts_num}, {last_pts_num}") + + if pts_num - last_pts_num < self.min_new_pts_num and pred_cr <= data["seq_max_coverage_rate"] - 1e-2: + retry += 1 + retry_duplication_pose.append(pred_pose.cpu().numpy().tolist()) + Log.red(f"delta pts num < {self.min_new_pts_num}:, {pts_num}, {last_pts_num}") + elif pts_num - last_pts_num < self.min_new_pts_num and pred_cr > data["seq_max_coverage_rate"] - 1e-2: + success += 1 + Log.success(f"delta pts num < {self.min_new_pts_num}:, {pts_num}, {last_pts_num}") + + last_pts_num = pts_num + + + input_data["scanned_n_to_world_pose_9d"] = input_data["scanned_n_to_world_pose_9d"][0].cpu().numpy().tolist() + result = { + "pred_pose_9d_seq": input_data["scanned_n_to_world_pose_9d"], + "combined_scanned_pts": input_data["combined_scanned_pts"], + "target_pts_seq": scanned_view_pts, + "coverage_rate_seq": pred_cr_seq, + "max_coverage_rate": data["seq_max_coverage_rate"], + "pred_max_coverage_rate": max(pred_cr_seq), + "scene_name": scene_name, + "retry_no_pts_pose": retry_no_pts_pose, + "retry_duplication_pose": retry_duplication_pose, + "retry_overlap_pose": retry_overlap_pose, + "best_seq_len": data["best_seq_len"], + } + self.stat_result[scene_name] = { + "coverage_rate_seq": pred_cr_seq, + "pred_max_coverage_rate": max(pred_cr_seq), + "pred_seq_len": len(pred_cr_seq), + } + print('success rate: ', max(pred_cr_seq)) + + return result + + def voxel_downsample_with_mapping(self, point_cloud, voxel_size=0.003): + voxel_indices = np.floor(point_cloud / voxel_size).astype(np.int32) + unique_voxels, inverse, counts = np.unique(voxel_indices, axis=0, return_inverse=True, return_counts=True) + idx_sort = np.argsort(inverse) + idx_unique = idx_sort[np.cumsum(counts)-counts] + downsampled_points = point_cloud[idx_unique] + return downsampled_points, inverse + + def compute_coverage_rate(self, scanned_view_pts, new_pts, model_pts, threshold=0.005): + if new_pts is not None: + new_scanned_view_pts = scanned_view_pts + [new_pts] + else: + new_scanned_view_pts = scanned_view_pts + combined_point_cloud = np.vstack(new_scanned_view_pts) + 
+        down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud, threshold)
+        return ReconstructionUtil.compute_coverage_rate(model_pts, down_sampled_combined_point_cloud, threshold)
+
+    def save_inference_result(self, dataset_name, scene_name, output):
+        dataset_dir = os.path.join(self.output_dir, dataset_name)
+        if not os.path.exists(dataset_dir):
+            os.makedirs(dataset_dir)
+        output_path = os.path.join(dataset_dir, f"{scene_name}.pkl")
+        with open(output_path, "wb") as f:
+            pickle.dump(output, f)
+        with open(self.stat_result_path, "w") as f:
+            json.dump(self.stat_result, f)
+
+    def get_checkpoint_path(self, is_last=False):
+        # "Direcotry" is PytorchBoot's own (misspelled) constant name.
+        return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME,
+                            "Epoch_{}.pth".format(
+                                self.current_epoch if self.current_epoch != -1 and not is_last else "last"))
+
+    def load_checkpoint(self, is_last=False):
+        self.load(self.get_checkpoint_path(is_last))
+        Log.success(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}")
+        if is_last:
+            checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
+            meta_path = os.path.join(checkpoint_root, "meta.json")
+            if not os.path.exists(meta_path):
+                raise FileNotFoundError(
+                    "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
+            with open(meta_path, "r") as f:
+                meta = json.load(f)
+            self.current_epoch = meta["last_epoch"]
+            self.current_iter = meta["last_iter"]
+
+    def load_experiment(self, backup_name=None):
+        super().load_experiment(backup_name)
+        self.current_epoch = self.experiments_config["epoch"]
+        #self.load_checkpoint(is_last=(self.current_epoch == -1))
+
+    def create_experiment(self, backup_name=None):
+        super().create_experiment(backup_name)
+
+    def load(self, path):
+        state_dict = torch.load(path)
+        self.pipeline.load_state_dict(state_dict)
+
+    def print_info(self):
+        def print_dataset(dataset: BaseDataset):
+            config = dataset.get_config()
+            name = dataset.get_name()
+            Log.blue(f"Dataset: {name}")
+            for k, v in config.items():
+                Log.blue(f"\t{k}: {v}")
+
+        super().print_info()
+        table_size = 70
+        Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
+        #Log.blue(self.pipeline)
+        Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
+        for i, test_set in enumerate(self.test_set_list):
+            Log.blue(f"test dataset {i}: ")
+            print_dataset(test_set)
+
+        Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')
diff --git a/runners/global_and_local_points_inferencer.py b/runners/global_and_local_points_inferencer.py
new file mode 100644
index 0000000..26ad563
--- /dev/null
+++ b/runners/global_and_local_points_inferencer.py
@@ -0,0 +1,352 @@
+import os
+import json
+from utils.render import RenderUtil
+from utils.pose import PoseUtil
+from utils.pts import PtsUtil
+from utils.reconstruction import ReconstructionUtil
+from beans.predict_result import PredictResult
+
+import torch
+from tqdm import tqdm
+import numpy as np
+import pickle
+
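+# Next-best-view evaluation loop: each step feeds the merged scan ("global"
+# points) and the per-view clouds ("local" points) to the pipeline, which
+# predicts a 9D pose (a 6D rotation representation plus a 3D translation).
+# Pose candidates are clustered via PredictResult, rendered through the
+# external Blender script, and a view is accepted only if it overlaps the
+# current scan enough while still adding new surface.
+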
+from PytorchBoot.config import ConfigManager
+import PytorchBoot.namespace as namespace
+import PytorchBoot.stereotype as stereotype
+from PytorchBoot.factory import ComponentFactory
+
+from PytorchBoot.dataset import BaseDataset
+from PytorchBoot.runners.runner import Runner
+from PytorchBoot.utils import Log
+from PytorchBoot.status import status_manager
+from utils.data_load import DataLoadUtil
+
+@stereotype.runner("global_and_local_points_inferencer")
+class GlobalAndLocalPointsInferencer(Runner):
+    def __init__(self, config_path):
+
+        super().__init__(config_path)
+
+        self.script_path = ConfigManager.get(namespace.Stereotype.RUNNER, "blender_script_path")
+        self.output_dir = ConfigManager.get(namespace.Stereotype.RUNNER, "output_dir")
+        self.voxel_size = ConfigManager.get(namespace.Stereotype.RUNNER, "voxel_size")
+        self.min_new_area = ConfigManager.get(namespace.Stereotype.RUNNER, "min_new_area")
+        CM = 0.01
+        # min_new_area is configured in cm^2; (CM / voxel_size)^2 converts one
+        # square centimetre into a count of voxel-sized cells.
+        self.min_new_pts_num = self.min_new_area * (CM / self.voxel_size) ** 2
+        ''' Pipeline '''
+        self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
+        self.pipeline: torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
+        self.pipeline = self.pipeline.to(self.device)
+
+        ''' Experiment '''
+        self.load_experiment("nbv_evaluator")
+        self.stat_result_path = os.path.join(self.output_dir, "stat.json")
+        if os.path.exists(self.stat_result_path):
+            with open(self.stat_result_path, "r") as f:
+                self.stat_result = json.load(f)
+        else:
+            self.stat_result = {}
+
+        ''' Test '''
+        self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
+        self.test_dataset_name_list = self.test_config["dataset_list"]
+        self.test_set_list = []
+        self.test_writer_list = []
+        seen_name = set()
+        for test_dataset_name in self.test_dataset_name_list:
+            if test_dataset_name not in seen_name:
+                seen_name.add(test_dataset_name)
+            else:
+                raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
+            test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
+            self.test_set_list.append(test_set)
+        self.print_info()
+
+    def run(self):
+        Log.info("Loading from epoch {}.".format(self.current_epoch))
+        self.inference()
+        Log.success("Inference finished.")
+
+    def inference(self):
+        self.pipeline.eval()
+        with torch.no_grad():
+            test_set: BaseDataset
+            for dataset_idx, test_set in enumerate(self.test_set_list):
+                status_manager.set_progress("inference", "inferencer", "dataset", dataset_idx, len(self.test_set_list))
+                test_set_name = test_set.get_name()
+
+                total = len(test_set)
+                for i in tqdm(range(total), desc=f"Processing {test_set_name}", ncols=100):
+                    try:
+                        data = test_set[i]
+                        scene_name = data["scene_name"]
+                        inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl")
+
+                        if os.path.exists(inference_result_path):
+                            Log.info(f"Inference result already exists for scene: {scene_name}")
+                            continue
+
+                        status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
+                        output = self.predict_sequence(data)
+                        self.save_inference_result(test_set_name, data["scene_name"], output)
+                    except Exception as e:
+                        Log.error(f"Error, {e}")
+                        continue
+
+            status_manager.set_progress("inference", "inferencer", "dataset", len(self.test_set_list), len(self.test_set_list))
+
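+    # predict_sequence drives one scene: render the first view, then repeat
+    # (predict pose candidates, render, run overlap and coverage checks) until
+    # max_iter steps, max_retry rejected views, or max_success views that come
+    # within 1e-2 of the scene's precomputed best coverage
+    # (seq_max_coverage_rate) have been reached.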
+    def predict_sequence(self, data, cr_increase_threshold=0, overlap_area_threshold=25, scan_points_threshold=10, max_iter=50, max_retry=10, max_success=3):
+        scene_name = data["scene_name"]
+        Log.info(f"Processing scene: {scene_name}")
+        status_manager.set_status("inference", "inferencer", "scene", scene_name)
+
+        ''' data for rendering '''
+        scene_path = data["scene_path"]
+        O_to_L_pose = data["O_to_L_pose"]
+        voxel_threshold = self.voxel_size
+        filter_degree = 75
+        down_sampled_model_pts = data["gt_pts"]
+
+        first_frame_to_world_9d = data["first_scanned_n_to_world_pose_9d"][0]
+        first_frame_to_world = np.eye(4)
+        first_frame_to_world[:3,:3] = PoseUtil.rotation_6d_to_matrix_numpy(first_frame_to_world_9d[:6])
+        first_frame_to_world[:3,3] = first_frame_to_world_9d[6:]
+
+        ''' data for inference '''
+        input_data = {}
+
+        input_data["combined_scanned_pts"] = torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device).unsqueeze(0)
+        input_data["scanned_pts"] = [torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device).unsqueeze(0)]
+        input_data["scanned_pts_mask"] = [torch.zeros(input_data["combined_scanned_pts"].shape[1], dtype=torch.bool).to(self.device).unsqueeze(0)]
+        input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(data["first_scanned_n_to_world_pose_9d"], dtype=torch.float32).to(self.device)]
+        input_data["mode"] = namespace.Mode.TEST
+        input_pts_N = input_data["combined_scanned_pts"].shape[1]
+        root = os.path.dirname(scene_path)
+        display_table_info = DataLoadUtil.get_display_table_info(root, scene_name)
+        radius = display_table_info["radius"]
+        scan_points = np.asarray(ReconstructionUtil.generate_scan_points(display_table_top=0, display_table_radius=radius))
+
+        first_frame_target_pts, first_frame_target_normals, first_frame_scan_points_indices = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, scan_points, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
+        scanned_view_pts = [first_frame_target_pts]
+        history_indices = [first_frame_scan_points_indices]
+        last_pred_cr, added_pts_num = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold)
+        retry_duplication_pose = []
+        retry_no_pts_pose = []
+        retry_overlap_pose = []
+        retry = 0
+        pred_cr_seq = [last_pred_cr]
+        success = 0
+        last_pts_num = PtsUtil.voxel_downsample_point_cloud(data["first_scanned_pts"][0], voxel_threshold).shape[0]
+        while len(pred_cr_seq) < max_iter and retry < max_retry and success < max_success:
+            Log.green(f"iter: {len(pred_cr_seq)}, retry: {retry}/{max_retry}, success: {success}/{max_success}")
+            combined_scanned_pts = np.vstack(scanned_view_pts)
+            voxel_downsampled_combined_scanned_pts_np, inverse = self.voxel_downsample_with_mapping(combined_scanned_pts, voxel_threshold)
+
+            output = self.pipeline(input_data)
+            pred_pose_9d = output["pred_pose_9d"]
+            pred_pose = torch.eye(4, device=pred_pose_9d.device)
+            # # save pred_pose_9d ------
+            # root = "/media/hofee/data/project/python/nbv_reconstruction/nbv_reconstruction/temp_output_result"
+            # scene_dir = os.path.join(root, scene_name)
+            # if not os.path.exists(scene_dir):
+            #     os.makedirs(scene_dir)
+            # pred_9d_path = os.path.join(scene_dir, f"pred_pose_9d_{len(pred_cr_seq)}.npy")
+            # pts_path = os.path.join(scene_dir, f"combined_scanned_pts_{len(pred_cr_seq)}.txt")
+            # np_combined_scanned_pts = input_data["combined_scanned_pts"][0].cpu().numpy()
+            # np.save(pred_9d_path, pred_pose_9d.cpu().numpy())
+            # np.savetxt(pts_path, np_combined_scanned_pts)
+            # # ----- ----- -----
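+            # PredictResult clusters the sampled 9D poses into candidate modes;
+            # eps/min_samples are density-clustering parameters (DBSCAN-style
+            # naming; the actual implementation lives in beans/predict_result.py).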
+            predict_result = PredictResult(pred_pose_9d.cpu().numpy(), input_pts=input_data["combined_scanned_pts"][0].cpu().numpy(), cluster_params=dict(eps=0.25, min_samples=3))
+            # predict_result.visualize()
+            pred_pose_9d_candidates = predict_result.candidate_9d_poses
+            for pred_pose_9d in pred_pose_9d_candidates:
+                pred_pose_9d = torch.tensor(pred_pose_9d, dtype=torch.float32).to(self.device).unsqueeze(0)
+                pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_pose_9d[:,:6])[0]
+                pred_pose[:3,3] = pred_pose_9d[0,6:]
+                try:
+                    new_target_pts, new_target_normals, new_scan_points_indices = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, scan_points, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
+                    if not ReconstructionUtil.check_scan_points_overlap(history_indices, new_scan_points_indices, scan_points_threshold):
+                        curr_overlap_area_threshold = overlap_area_threshold
+                    else:
+                        curr_overlap_area_threshold = overlap_area_threshold * 0.5
+
+                    downsampled_new_target_pts = PtsUtil.voxel_downsample_point_cloud(new_target_pts, voxel_threshold)
+                    overlap, _ = ReconstructionUtil.check_overlap(downsampled_new_target_pts, voxel_downsampled_combined_scanned_pts_np, overlap_area_threshold=curr_overlap_area_threshold, voxel_size=voxel_threshold, require_new_added_pts_num=True)
+                    if not overlap:
+                        Log.yellow("no overlap!")
+                        retry += 1
+                        retry_overlap_pose.append(pred_pose.cpu().numpy().tolist())
+                        continue
+
+                    history_indices.append(new_scan_points_indices)
+                except Exception as e:
+                    Log.error(f"Error in scene {scene_path}, {e}")
+                    print("current pose: ", pred_pose)
+                    print("curr_pred_cr: ", last_pred_cr)
+                    retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
+                    retry += 1
+                    continue
+
+                if new_target_pts.shape[0] == 0:
+                    Log.red("no pts in new target")
+                    retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
+                    retry += 1
+                    continue
+
+                pred_cr, _ = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
+                Log.yellow(f"{pred_cr}, {last_pred_cr}, max: {data['seq_max_coverage_rate']}")
+                if pred_cr >= data["seq_max_coverage_rate"] - 1e-3:
+                    print("max coverage rate reached: ", pred_cr)
+
+                pred_cr_seq.append(pred_cr)
+                scanned_view_pts.append(new_target_pts)
+
+                input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]
+
+                combined_scanned_pts = np.vstack(scanned_view_pts)
+                voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, voxel_threshold)
+                random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N)
+                input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)
+                input_data["scanned_pts"] = [torch.cat([input_data["scanned_pts"][0], torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)], dim=0)]
+
+                last_pred_cr = pred_cr
+                pts_num = voxel_downsampled_combined_scanned_pts_np.shape[0]
+                Log.info(f"delta pts num: {pts_num - last_pts_num}, {pts_num}, {last_pts_num}")
+
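+                # Termination bookkeeping: a view adding fewer than
+                # min_new_pts_num new voxels counts as a retry while coverage
+                # trails seq_max_coverage_rate, and as a success once coverage
+                # is within 1e-2 of it; max_success successes stop the scan.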
+                if pts_num - last_pts_num < self.min_new_pts_num and pred_cr <= data["seq_max_coverage_rate"] - 1e-2:
+                    retry += 1
+                    retry_duplication_pose.append(pred_pose.cpu().numpy().tolist())
+                    Log.red(f"delta pts num < {self.min_new_pts_num}: {pts_num}, {last_pts_num}")
+                elif pts_num - last_pts_num < self.min_new_pts_num and pred_cr > data["seq_max_coverage_rate"] - 1e-2:
+                    success += 1
+                    Log.success(f"delta pts num < {self.min_new_pts_num}: {pts_num}, {last_pts_num}")
+
+                last_pts_num = pts_num
+
+        input_data["scanned_n_to_world_pose_9d"] = input_data["scanned_n_to_world_pose_9d"][0].cpu().numpy().tolist()
+        result = {
+            "pred_pose_9d_seq": input_data["scanned_n_to_world_pose_9d"],
+            "combined_scanned_pts": input_data["combined_scanned_pts"],
+            "target_pts_seq": scanned_view_pts,
+            "coverage_rate_seq": pred_cr_seq,
+            "max_coverage_rate": data["seq_max_coverage_rate"],
+            "pred_max_coverage_rate": max(pred_cr_seq),
+            "scene_name": scene_name,
+            "retry_no_pts_pose": retry_no_pts_pose,
+            "retry_duplication_pose": retry_duplication_pose,
+            "retry_overlap_pose": retry_overlap_pose,
+            "best_seq_len": data["best_seq_len"],
+        }
+        self.stat_result[scene_name] = {
+            "coverage_rate_seq": pred_cr_seq,
+            "pred_max_coverage_rate": max(pred_cr_seq),
+            "pred_seq_len": len(pred_cr_seq),
+        }
+        print('pred max coverage rate: ', max(pred_cr_seq))
+
+        return result
+
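+    # np.unique over integer voxel coordinates returns an `inverse` map from
+    # each point to its voxel; argsort(inverse) groups points per voxel, and
+    # cumsum(counts) - counts indexes the first point of each group, which is
+    # kept as that voxel's representative.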
+    def voxel_downsample_with_mapping(self, point_cloud, voxel_size=0.003):
+        voxel_indices = np.floor(point_cloud / voxel_size).astype(np.int32)
+        unique_voxels, inverse, counts = np.unique(voxel_indices, axis=0, return_inverse=True, return_counts=True)
+        idx_sort = np.argsort(inverse)
+        idx_unique = idx_sort[np.cumsum(counts) - counts]
+        downsampled_points = point_cloud[idx_unique]
+        return downsampled_points, inverse
+
+    def compute_coverage_rate(self, scanned_view_pts, new_pts, model_pts, threshold=0.005):
+        if new_pts is not None:
+            new_scanned_view_pts = scanned_view_pts + [new_pts]
+        else:
+            new_scanned_view_pts = scanned_view_pts
+        combined_point_cloud = np.vstack(new_scanned_view_pts)
+        down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud, threshold)
+        return ReconstructionUtil.compute_coverage_rate(model_pts, down_sampled_combined_point_cloud, threshold)
+
+    def save_inference_result(self, dataset_name, scene_name, output):
+        dataset_dir = os.path.join(self.output_dir, dataset_name)
+        if not os.path.exists(dataset_dir):
+            os.makedirs(dataset_dir)
+        output_path = os.path.join(dataset_dir, f"{scene_name}.pkl")
+        with open(output_path, "wb") as f:
+            pickle.dump(output, f)
+        with open(self.stat_result_path, "w") as f:
+            json.dump(self.stat_result, f)
+
+    def get_checkpoint_path(self, is_last=False):
+        return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME,
+                            "Epoch_{}.pth".format(
+                                self.current_epoch if self.current_epoch != -1 and not is_last else "last"))
+
+    def load_checkpoint(self, is_last=False):
+        self.load(self.get_checkpoint_path(is_last))
+        Log.success(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}")
+        if is_last:
+            checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
+            meta_path = os.path.join(checkpoint_root, "meta.json")
+            if not os.path.exists(meta_path):
+                raise FileNotFoundError(
+                    "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
+            with open(meta_path, "r") as f:
+                meta = json.load(f)
+            self.current_epoch = meta["last_epoch"]
+            self.current_iter = meta["last_iter"]
+
+    def load_experiment(self, backup_name=None):
+        super().load_experiment(backup_name)
+        self.current_epoch = self.experiments_config["epoch"]
+        self.load_checkpoint(is_last=(self.current_epoch == -1))
+
+    def create_experiment(self, backup_name=None):
+        super().create_experiment(backup_name)
+
+    def load(self, path):
+        state_dict = torch.load(path)
+        self.pipeline.load_state_dict(state_dict)
+
+    def print_info(self):
+        def print_dataset(dataset: BaseDataset):
+            config = dataset.get_config()
+            name = dataset.get_name()
+            Log.blue(f"Dataset: {name}")
+            for k, v in config.items():
+                Log.blue(f"\t{k}: {v}")
+
+        super().print_info()
+        table_size = 70
+        Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
+        Log.blue(self.pipeline)
+        Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
+        for i, test_set in enumerate(self.test_set_list):
+            Log.blue(f"test dataset {i}: ")
+            print_dataset(test_set)
+
+        Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')
diff --git a/runners/global_points_inferencer.py b/runners/global_points_inferencer.py
new file mode 100644
index 0000000..a5d6403
--- /dev/null
+++ b/runners/global_points_inferencer.py
@@ -0,0 +1,348 @@
+import os
+import json
+from utils.render import RenderUtil
+from utils.pose import PoseUtil
+from utils.pts import PtsUtil
+from utils.reconstruction import ReconstructionUtil
+from beans.predict_result import PredictResult
+
+import torch
+from tqdm import tqdm
+import numpy as np
+import pickle
+
+from PytorchBoot.config import ConfigManager
+import PytorchBoot.namespace as namespace
+import PytorchBoot.stereotype as stereotype
+from PytorchBoot.factory import ComponentFactory
+
+from PytorchBoot.dataset import BaseDataset
+from PytorchBoot.runners.runner import Runner
+from PytorchBoot.utils import Log
+from PytorchBoot.status import status_manager
+from utils.data_load import DataLoadUtil
+
+@stereotype.runner("global_points_inferencer")
+class GlobalPointsInferencer(Runner):
+    def __init__(self, config_path):
+
+        super().__init__(config_path)
+
+        self.script_path = ConfigManager.get(namespace.Stereotype.RUNNER, "blender_script_path")
+        self.output_dir = ConfigManager.get(namespace.Stereotype.RUNNER, "output_dir")
+        self.voxel_size = ConfigManager.get(namespace.Stereotype.RUNNER, "voxel_size")
+        self.min_new_area = ConfigManager.get(namespace.Stereotype.RUNNER, "min_new_area")
+        CM = 0.01
+        self.min_new_pts_num = self.min_new_area * (CM / self.voxel_size) ** 2
+        ''' Pipeline '''
+        self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
+        self.pipeline: torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
+        self.pipeline = self.pipeline.to(self.device)
+
+        ''' Experiment '''
+        self.load_experiment("nbv_evaluator")
+        self.stat_result_path = os.path.join(self.output_dir, "stat.json")
+        if os.path.exists(self.stat_result_path):
+            with open(self.stat_result_path, "r") as f:
+                self.stat_result = json.load(f)
+        else:
+            self.stat_result = {}
+
+        ''' Test '''
+        self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
+        self.test_dataset_name_list = self.test_config["dataset_list"]
+        self.test_set_list = []
+        self.test_writer_list = []
+        seen_name = set()
+        for test_dataset_name in self.test_dataset_name_list:
+            if test_dataset_name not in seen_name:
+                seen_name.add(test_dataset_name)
+            else:
+                raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
+            test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
+            self.test_set_list.append(test_set)
+        self.print_info()
+
+    def run(self):
+        Log.info("Loading from epoch {}.".format(self.current_epoch))
+        self.inference()
+        Log.success("Inference finished.")
+
+    def inference(self):
+        self.pipeline.eval()
+        with torch.no_grad():
+            test_set: BaseDataset
+            for dataset_idx, test_set in enumerate(self.test_set_list):
+                status_manager.set_progress("inference", "inferencer", "dataset", dataset_idx, len(self.test_set_list))
+                test_set_name = test_set.get_name()
+
+                total = len(test_set)
+                for i in tqdm(range(total), desc=f"Processing {test_set_name}", ncols=100):
+                    try:
+                        data = test_set[i]
+                        scene_name = data["scene_name"]
+                        inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl")
+
+                        if os.path.exists(inference_result_path):
+                            Log.info(f"Inference result already exists for scene: {scene_name}")
+                            continue
+
+                        status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i+1, total)
+                        output = self.predict_sequence(data)
+                        self.save_inference_result(test_set_name, data["scene_name"], output)
+                    except Exception as e:
+                        Log.error(f"Error, {e}")
+                        continue
+
+            status_manager.set_progress("inference", "inferencer", "dataset", len(self.test_set_list), len(self.test_set_list))
+
+    def predict_sequence(self, data, cr_increase_threshold=0, overlap_area_threshold=25, scan_points_threshold=10, max_iter=50, max_retry=10, max_success=3):
+        scene_name = data["scene_name"]
+        Log.info(f"Processing scene: {scene_name}")
+        status_manager.set_status("inference", "inferencer", "scene", scene_name)
+
+        ''' data for rendering '''
+        scene_path = data["scene_path"]
+        O_to_L_pose = data["O_to_L_pose"]
+        voxel_threshold = self.voxel_size
+        filter_degree = 75
+        down_sampled_model_pts = data["gt_pts"]
+
+        first_frame_to_world_9d = data["first_scanned_n_to_world_pose_9d"][0]
+        first_frame_to_world = np.eye(4)
+        first_frame_to_world[:3,:3] = PoseUtil.rotation_6d_to_matrix_numpy(first_frame_to_world_9d[:6])
+        first_frame_to_world[:3,3] = first_frame_to_world_9d[6:]
+
+        ''' data for inference '''
+        input_data = {}
+
+        input_data["combined_scanned_pts"] = torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device).unsqueeze(0)
+        input_data["scanned_pts_mask"] = [torch.zeros(input_data["combined_scanned_pts"].shape[1], dtype=torch.bool).to(self.device).unsqueeze(0)]
+        input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(data["first_scanned_n_to_world_pose_9d"], dtype=torch.float32).to(self.device)]
+        input_data["mode"] = namespace.Mode.TEST
+        input_pts_N = input_data["combined_scanned_pts"].shape[1]
+
+        root = os.path.dirname(scene_path)
+        display_table_info = DataLoadUtil.get_display_table_info(root, scene_name)
+        radius = display_table_info["radius"]
+        scan_points = np.asarray(ReconstructionUtil.generate_scan_points(display_table_top=0, display_table_radius=radius))
+
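+        # scan_points sample the display table surface; the per-view indices
+        # returned by render_pts feed check_scan_points_overlap below, which
+        # relaxes the point-cloud overlap threshold to half once a new view
+        # already shares enough scan points with the history.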
+        first_frame_target_pts, first_frame_target_normals, first_frame_scan_points_indices = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, scan_points, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
+        scanned_view_pts = [first_frame_target_pts]
+        history_indices = [first_frame_scan_points_indices]
+        last_pred_cr, added_pts_num = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold)
+        retry_duplication_pose = []
+        retry_no_pts_pose = []
+        retry_overlap_pose = []
+        retry = 0
+        pred_cr_seq = [last_pred_cr]
+        success = 0
+        last_pts_num = PtsUtil.voxel_downsample_point_cloud(data["first_scanned_pts"][0], voxel_threshold).shape[0]
+        while len(pred_cr_seq) < max_iter and retry < max_retry and success < max_success:
+            Log.green(f"iter: {len(pred_cr_seq)}, retry: {retry}/{max_retry}, success: {success}/{max_success}")
+            combined_scanned_pts = np.vstack(scanned_view_pts)
+            voxel_downsampled_combined_scanned_pts_np, inverse = self.voxel_downsample_with_mapping(combined_scanned_pts, voxel_threshold)
+            output = self.pipeline(input_data)
+            pred_pose_9d = output["pred_pose_9d"]
+            pred_pose = torch.eye(4, device=pred_pose_9d.device)
+            # # save pred_pose_9d ------
+            # root = "/media/hofee/data/project/python/nbv_reconstruction/nbv_reconstruction/temp_output_result"
+            # scene_dir = os.path.join(root, scene_name)
+            # if not os.path.exists(scene_dir):
+            #     os.makedirs(scene_dir)
+            # pred_9d_path = os.path.join(scene_dir, f"pred_pose_9d_{len(pred_cr_seq)}.npy")
+            # pts_path = os.path.join(scene_dir, f"combined_scanned_pts_{len(pred_cr_seq)}.txt")
+            # np_combined_scanned_pts = input_data["combined_scanned_pts"][0].cpu().numpy()
+            # np.save(pred_9d_path, pred_pose_9d.cpu().numpy())
+            # np.savetxt(pts_path, np_combined_scanned_pts)
+            # # ----- ----- -----
+            predict_result = PredictResult(pred_pose_9d.cpu().numpy(), input_pts=input_data["combined_scanned_pts"][0].cpu().numpy(), cluster_params=dict(eps=0.25, min_samples=3))
+            # predict_result.visualize()
+            pred_pose_9d_candidates = predict_result.candidate_9d_poses
+            for pred_pose_9d in pred_pose_9d_candidates:
+                pred_pose_9d = torch.tensor(pred_pose_9d, dtype=torch.float32).to(self.device).unsqueeze(0)
+                pred_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_pose_9d[:,:6])[0]
+                pred_pose[:3,3] = pred_pose_9d[0,6:]
+                try:
+                    new_target_pts, new_target_normals, new_scan_points_indices = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, scan_points, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
+                    if not ReconstructionUtil.check_scan_points_overlap(history_indices, new_scan_points_indices, scan_points_threshold):
+                        curr_overlap_area_threshold = overlap_area_threshold
+                    else:
+                        curr_overlap_area_threshold = overlap_area_threshold * 0.5
+
+                    downsampled_new_target_pts = PtsUtil.voxel_downsample_point_cloud(new_target_pts, voxel_threshold)
+                    overlap, _ = ReconstructionUtil.check_overlap(downsampled_new_target_pts, voxel_downsampled_combined_scanned_pts_np, overlap_area_threshold=curr_overlap_area_threshold, voxel_size=voxel_threshold, require_new_added_pts_num=True)
+                    if not overlap:
+                        Log.yellow("no overlap!")
+                        retry += 1
+                        retry_overlap_pose.append(pred_pose.cpu().numpy().tolist())
+                        continue
+
+                    history_indices.append(new_scan_points_indices)
+                except Exception as e:
+                    Log.error(f"Error in scene {scene_path}, {e}")
+                    print("current pose: ", pred_pose)
+                    print("curr_pred_cr: ", last_pred_cr)
+                    retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
+                    retry += 1
+                    continue
+
+                if new_target_pts.shape[0] == 0:
+                    Log.red("no pts in new target")
+                    retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
+                    retry += 1
+                    continue
+
+                pred_cr, _ = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
+                Log.yellow(f"{pred_cr}, {last_pred_cr}, max: {data['seq_max_coverage_rate']}")
+                if pred_cr >= data["seq_max_coverage_rate"] - 1e-3:
+                    print("max coverage rate reached: ", pred_cr)
+
+                pred_cr_seq.append(pred_cr)
+                scanned_view_pts.append(new_target_pts)
+
+                input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]
+
+                combined_scanned_pts = np.vstack(scanned_view_pts)
+                voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, voxel_threshold)
+                random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N)
+                input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)
+
+                last_pred_cr = pred_cr
+                pts_num = voxel_downsampled_combined_scanned_pts_np.shape[0]
+                Log.info(f"delta pts num: {pts_num - last_pts_num}, {pts_num}, {last_pts_num}")
+
+                if pts_num - last_pts_num < self.min_new_pts_num and pred_cr <= data["seq_max_coverage_rate"] - 1e-2:
+                    retry += 1
+                    retry_duplication_pose.append(pred_pose.cpu().numpy().tolist())
+                    Log.red(f"delta pts num < {self.min_new_pts_num}: {pts_num}, {last_pts_num}")
+                elif pts_num - last_pts_num < self.min_new_pts_num and pred_cr > data["seq_max_coverage_rate"] - 1e-2:
+                    success += 1
+                    Log.success(f"delta pts num < {self.min_new_pts_num}: {pts_num}, {last_pts_num}")
+
+                last_pts_num = pts_num
+
+        input_data["scanned_n_to_world_pose_9d"] = input_data["scanned_n_to_world_pose_9d"][0].cpu().numpy().tolist()
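+        # Per-scene record: the executed pose/coverage trajectory plus every
+        # rejected pose, grouped by failure mode (no overlap, no rendered
+        # points, too few new points).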
+        result = {
+            "pred_pose_9d_seq": input_data["scanned_n_to_world_pose_9d"],
+            "combined_scanned_pts": input_data["combined_scanned_pts"],
+            "target_pts_seq": scanned_view_pts,
+            "coverage_rate_seq": pred_cr_seq,
+            "max_coverage_rate": data["seq_max_coverage_rate"],
+            "pred_max_coverage_rate": max(pred_cr_seq),
+            "scene_name": scene_name,
+            "retry_no_pts_pose": retry_no_pts_pose,
+            "retry_duplication_pose": retry_duplication_pose,
+            "retry_overlap_pose": retry_overlap_pose,
+            "best_seq_len": data["best_seq_len"],
+        }
+        self.stat_result[scene_name] = {
+            "coverage_rate_seq": pred_cr_seq,
+            "pred_max_coverage_rate": max(pred_cr_seq),
+            "pred_seq_len": len(pred_cr_seq),
+        }
+        print('pred max coverage rate: ', max(pred_cr_seq))
+
+        return result
+
+    def voxel_downsample_with_mapping(self, point_cloud, voxel_size=0.003):
+        voxel_indices = np.floor(point_cloud / voxel_size).astype(np.int32)
+        unique_voxels, inverse, counts = np.unique(voxel_indices, axis=0, return_inverse=True, return_counts=True)
+        idx_sort = np.argsort(inverse)
+        idx_unique = idx_sort[np.cumsum(counts) - counts]
+        downsampled_points = point_cloud[idx_unique]
+        return downsampled_points, inverse
+
+    def compute_coverage_rate(self, scanned_view_pts, new_pts, model_pts, threshold=0.005):
+        if new_pts is not None:
+            new_scanned_view_pts = scanned_view_pts + [new_pts]
+        else:
+            new_scanned_view_pts = scanned_view_pts
+        combined_point_cloud = np.vstack(new_scanned_view_pts)
+        down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud, threshold)
+        return ReconstructionUtil.compute_coverage_rate(model_pts, down_sampled_combined_point_cloud, threshold)
+
+    def save_inference_result(self, dataset_name, scene_name, output):
+        dataset_dir = os.path.join(self.output_dir, dataset_name)
+        if not os.path.exists(dataset_dir):
+            os.makedirs(dataset_dir)
+        output_path = os.path.join(dataset_dir, f"{scene_name}.pkl")
+        with open(output_path, "wb") as f:
+            pickle.dump(output, f)
+        with open(self.stat_result_path, "w") as f:
+            json.dump(self.stat_result, f)
+
+    def get_checkpoint_path(self, is_last=False):
+        return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME,
+                            "Epoch_{}.pth".format(
+                                self.current_epoch if self.current_epoch != -1 and not is_last else "last"))
+
+    def load_checkpoint(self, is_last=False):
+        self.load(self.get_checkpoint_path(is_last))
+        Log.success(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}")
+        if is_last:
+            checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
+            meta_path = os.path.join(checkpoint_root, "meta.json")
+            if not os.path.exists(meta_path):
+                raise FileNotFoundError(
+                    "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
+            with open(meta_path, "r") as f:
+                meta = json.load(f)
+            self.current_epoch = meta["last_epoch"]
+            self.current_iter = meta["last_iter"]
+
+    def load_experiment(self, backup_name=None):
+        super().load_experiment(backup_name)
+        self.current_epoch = self.experiments_config["epoch"]
+        self.load_checkpoint(is_last=(self.current_epoch == -1))
+
+    def create_experiment(self, backup_name=None):
+        super().create_experiment(backup_name)
+
+    def load(self, path):
+        state_dict = torch.load(path)
+        self.pipeline.load_state_dict(state_dict)
+
+    def print_info(self):
+        def print_dataset(dataset: BaseDataset):
+            config = dataset.get_config()
+            name = dataset.get_name()
+            Log.blue(f"Dataset: {name}")
+            for k, v in config.items():
+                Log.blue(f"\t{k}: {v}")
+
+        super().print_info()
+        table_size = 70
+        Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
+        Log.blue(self.pipeline)
+        Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
+        for i, test_set in enumerate(self.test_set_list):
+            Log.blue(f"test dataset {i}: ")
+            print_dataset(test_set)
+
+        Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')
diff --git a/runners/inference_server.py b/runners/inference_server.py
new file mode 100644
index 0000000..35ec910
--- /dev/null
+++ b/runners/inference_server.py
@@ -0,0 +1,116 @@
+import os
+import json
+import torch
+import numpy as np
+from flask import Flask, request, jsonify
+
+import PytorchBoot.namespace as namespace
+import PytorchBoot.stereotype as stereotype
+from PytorchBoot.factory import ComponentFactory
+
+from PytorchBoot.runners.runner import Runner
+from PytorchBoot.utils import Log
+
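+# Minimal client sketch (illustrative only; assumes the server below is
+# reachable on port 5000 and that view clouds/poses are numpy arrays shaped
+# as expected by get_input_data):
+#
+#   import requests
+#   payload = {"scanned_pts": [pts.tolist() for pts in view_clouds],
+#              "scanned_n_to_world_pose_9d": [p.tolist() for p in poses_9d]}
+#   resp = requests.post("http://localhost:5000/inference", json=payload)
+#   pred_pose_9d = resp.json()["pred_pose_9d"]
+#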
+from utils.pts import PtsUtil +from beans.predict_result import PredictResult + +@stereotype.runner("inferencer_server") +class InferencerServer(Runner): + def __init__(self, config_path): + super().__init__(config_path) + + ''' Web Server ''' + self.app = Flask(__name__) + ''' Pipeline ''' + self.pipeline_name = self.config[namespace.Stereotype.PIPELINE] + self.pipeline:torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name) + self.pipeline = self.pipeline.to(self.device) + self.pts_num = 8192 + self.voxel_size = 0.002 + + ''' Experiment ''' + self.load_experiment("inferencer_server") + + def get_input_data(self, data): + input_data = {} + scanned_pts = data["scanned_pts"] + scanned_n_to_world_pose_9d = data["scanned_n_to_world_pose_9d"] + combined_scanned_views_pts = np.concatenate(scanned_pts, axis=0) + voxel_downsampled_combined_scanned_pts = PtsUtil.voxel_downsample_point_cloud( + combined_scanned_views_pts, self.voxel_size + ) + fps_downsampled_combined_scanned_pts, fps_idx = PtsUtil.fps_downsample_point_cloud( + voxel_downsampled_combined_scanned_pts, self.pts_num, require_idx=True + ) + + input_data["scanned_pts"] = scanned_pts + input_data["scanned_n_to_world_pose_9d"] = np.asarray(scanned_n_to_world_pose_9d, dtype=np.float32) + input_data["combined_scanned_pts"] = np.asarray(fps_downsampled_combined_scanned_pts, dtype=np.float32) + return input_data + + def get_result(self, output_data): + + pred_pose_9d = output_data["pred_pose_9d"] + pred_pose_9d = np.asarray(PredictResult(pred_pose_9d.cpu().numpy(), None, cluster_params=dict(eps=0.25, min_samples=3)).candidate_9d_poses, dtype=np.float32) + result = { + "pred_pose_9d": pred_pose_9d.tolist() + } + return result + + def collate_input(self, input_data): + collated_input_data = {} + collated_input_data["scanned_pts"] = [torch.tensor(input_data["scanned_pts"], dtype=torch.float32, device=self.device)] + collated_input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(input_data["scanned_n_to_world_pose_9d"], dtype=torch.float32, device=self.device)] + collated_input_data["combined_scanned_pts"] = torch.tensor(input_data["combined_scanned_pts"], dtype=torch.float32, device=self.device).unsqueeze(0) + return collated_input_data + + def run(self): + Log.info("Loading from epoch {}.".format(self.current_epoch)) + + @self.app.route("/inference", methods=["POST"]) + def inference(): + data = request.json + input_data = self.get_input_data(data) + collated_input_data = self.collate_input(input_data) + output_data = self.pipeline.forward_test(collated_input_data) + result = self.get_result(output_data) + return jsonify(result) + + + self.app.run(host="0.0.0.0", port=5000) + + def get_checkpoint_path(self, is_last=False): + return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME, + "Epoch_{}.pth".format( + self.current_epoch if self.current_epoch != -1 and not is_last else "last")) + + def load_checkpoint(self, is_last=False): + self.load(self.get_checkpoint_path(is_last)) + Log.success(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}") + if is_last: + checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME) + meta_path = os.path.join(checkpoint_root, "meta.json") + if not os.path.exists(meta_path): + raise FileNotFoundError( + "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"])) + file_path = os.path.join(checkpoint_root, "meta.json") + with open(file_path, "r") as f: + meta = json.load(f) 
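+            # meta.json is written by the training runner; restoring last_epoch
+            # and last_iter keeps checkpoint naming and progress logging consistent.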
+ self.current_epoch = meta["last_epoch"] + self.current_iter = meta["last_iter"] + + def load_experiment(self, backup_name=None): + super().load_experiment(backup_name) + self.current_epoch = self.experiments_config["epoch"] + self.load_checkpoint(is_last=(self.current_epoch == -1)) + + def create_experiment(self, backup_name=None): + super().create_experiment(backup_name) + + + def load(self, path): + state_dict = torch.load(path) + self.pipeline.load_state_dict(state_dict) + + + diff --git a/runners/local_points_inferencer.py b/runners/local_points_inferencer.py new file mode 100644 index 0000000..b5c3f50 --- /dev/null +++ b/runners/local_points_inferencer.py @@ -0,0 +1,350 @@ +import os +import json +from utils.render import RenderUtil +from utils.pose import PoseUtil +from utils.pts import PtsUtil +from utils.reconstruction import ReconstructionUtil +from beans.predict_result import PredictResult + +import torch +from tqdm import tqdm +import numpy as np +import pickle + +from PytorchBoot.config import ConfigManager +import PytorchBoot.namespace as namespace +import PytorchBoot.stereotype as stereotype +from PytorchBoot.factory import ComponentFactory + +from PytorchBoot.dataset import BaseDataset +from PytorchBoot.runners.runner import Runner +from PytorchBoot.utils import Log +from PytorchBoot.status import status_manager +from utils.data_load import DataLoadUtil + +@stereotype.runner("local_points_inferencer") +class LocalPointsInferencer(Runner): + def __init__(self, config_path): + + super().__init__(config_path) + + self.script_path = ConfigManager.get(namespace.Stereotype.RUNNER, "blender_script_path") + self.output_dir = ConfigManager.get(namespace.Stereotype.RUNNER, "output_dir") + self.voxel_size = ConfigManager.get(namespace.Stereotype.RUNNER, "voxel_size") + self.min_new_area = ConfigManager.get(namespace.Stereotype.RUNNER, "min_new_area") + CM = 0.01 + self.min_new_pts_num = self.min_new_area * (CM / self.voxel_size) ** 2 + + ''' Pipeline ''' + self.pipeline_name = self.config[namespace.Stereotype.PIPELINE] + self.pipeline:torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name) + self.pipeline = self.pipeline.to(self.device) + + ''' Experiment ''' + self.load_experiment("nbv_evaluator") + self.stat_result_path = os.path.join(self.output_dir, "stat.json") + if os.path.exists(self.stat_result_path): + with open(self.stat_result_path, "r") as f: + self.stat_result = json.load(f) + else: + self.stat_result = {} + + ''' Test ''' + self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST) + self.test_dataset_name_list = self.test_config["dataset_list"] + self.test_set_list = [] + self.test_writer_list = [] + seen_name = set() + for test_dataset_name in self.test_dataset_name_list: + if test_dataset_name not in seen_name: + seen_name.add(test_dataset_name) + else: + raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name)) + test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name) + self.test_set_list.append(test_set) + self.print_info() + + + def run(self): + Log.info("Loading from epoch {}.".format(self.current_epoch)) + self.inference() + Log.success("Inference finished.") + + + def inference(self): + self.pipeline.eval() + with torch.no_grad(): + test_set: BaseDataset + for dataset_idx, test_set in enumerate(self.test_set_list): + status_manager.set_progress("inference", "inferencer", f"dataset", dataset_idx, len(self.test_set_list)) + test_set_name = 
test_set.get_name()
+
+                total = len(test_set)
+                for i in tqdm(range(total), desc=f"Processing {test_set_name}", ncols=100):
+                    try:
+                        data = test_set[i]
+                        scene_name = data["scene_name"]
+                        inference_result_path = os.path.join(self.output_dir, test_set_name, f"{scene_name}.pkl")
+
+                        if os.path.exists(inference_result_path):
+                            Log.info(f"Inference result already exists for scene: {scene_name}")
+                            continue
+
+                        status_manager.set_progress("inference", "inferencer", f"Batch[{test_set_name}]", i + 1, total)
+                        output = self.predict_sequence(data)
+                        self.save_inference_result(test_set_name, data["scene_name"], output)
+                    except Exception as e:
+                        Log.error(f"Error: {e}")
+                        continue
+
+            status_manager.set_progress("inference", "inferencer", "dataset", len(self.test_set_list), len(self.test_set_list))
+
+    def predict_sequence(self, data, cr_increase_threshold=0, overlap_area_threshold=25, scan_points_threshold=10, max_iter=50, max_retry=10, max_success=3):
+        scene_name = data["scene_name"]
+        Log.info(f"Processing scene: {scene_name}")
+        status_manager.set_status("inference", "inferencer", "scene", scene_name)
+
+        ''' data for rendering '''
+        scene_path = data["scene_path"]
+        O_to_L_pose = data["O_to_L_pose"]
+        voxel_threshold = self.voxel_size
+        filter_degree = 75
+        down_sampled_model_pts = data["gt_pts"]
+
+        first_frame_to_world_9d = data["first_scanned_n_to_world_pose_9d"][0]
+        first_frame_to_world = np.eye(4)
+        first_frame_to_world[:3, :3] = PoseUtil.rotation_6d_to_matrix_numpy(first_frame_to_world_9d[:6])
+        first_frame_to_world[:3, 3] = first_frame_to_world_9d[6:]
+
+        ''' data for inference '''
+        input_data = {}
+        input_data["combined_scanned_pts"] = torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device).unsqueeze(0)
+        input_data["scanned_pts"] = [torch.tensor(data["first_scanned_pts"][0], dtype=torch.float32).to(self.device).unsqueeze(0)]
+        input_data["scanned_pts_mask"] = [torch.zeros(input_data["combined_scanned_pts"].shape[1], dtype=torch.bool).to(self.device).unsqueeze(0)]
+        input_data["scanned_n_to_world_pose_9d"] = [torch.tensor(data["first_scanned_n_to_world_pose_9d"], dtype=torch.float32).to(self.device)]
+        input_data["mode"] = namespace.Mode.TEST
+        input_pts_N = input_data["combined_scanned_pts"].shape[1]
+        root = os.path.dirname(scene_path)
+        display_table_info = DataLoadUtil.get_display_table_info(root, scene_name)
+        radius = display_table_info["radius"]
+        scan_points = np.asarray(ReconstructionUtil.generate_scan_points(display_table_top=0, display_table_radius=radius))
+
+        first_frame_target_pts, first_frame_target_normals, first_frame_scan_points_indices = RenderUtil.render_pts(first_frame_to_world, scene_path, self.script_path, scan_points, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
+        scanned_view_pts = [first_frame_target_pts]
+        history_indices = [first_frame_scan_points_indices]
+        last_pred_cr, added_pts_num = self.compute_coverage_rate(scanned_view_pts, None, down_sampled_model_pts, threshold=voxel_threshold)
+        retry_duplication_pose = []
+        retry_no_pts_pose = []
+        retry_overlap_pose = []
+        retry = 0
+        pred_cr_seq = [last_pred_cr]
+        success = 0
+        last_pts_num = PtsUtil.voxel_downsample_point_cloud(data["first_scanned_pts"][0], voxel_threshold).shape[0]
+        while len(pred_cr_seq) < max_iter and retry < max_retry and success < max_success:
+            Log.green(f"iter: {len(pred_cr_seq)}, retry: {retry}/{max_retry}, success: {success}/{max_success}")
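+            # One NBV iteration: fuse all views scanned so far, ask the pose
+            # pipeline for candidate next views, render each candidate offline,
+            # and accept it only if it overlaps the current scan enough while
+            # still adding new surface points.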
combined_scanned_pts = np.vstack(scanned_view_pts)
+            voxel_downsampled_combined_scanned_pts_np, inverse = self.voxel_downsample_with_mapping(combined_scanned_pts, voxel_threshold)
+
+            output = self.pipeline(input_data)
+            pred_pose_9d = output["pred_pose_9d"]
+            pred_pose = torch.eye(4, device=pred_pose_9d.device)
+            predict_result = PredictResult(pred_pose_9d.cpu().numpy(), input_pts=input_data["combined_scanned_pts"][0].cpu().numpy(), cluster_params=dict(eps=0.25, min_samples=3))
+            pred_pose_9d_candidates = predict_result.candidate_9d_poses
+            for pred_pose_9d in pred_pose_9d_candidates:
+                pred_pose_9d = torch.tensor(pred_pose_9d, dtype=torch.float32).to(self.device).unsqueeze(0)
+                pred_pose[:3, :3] = PoseUtil.rotation_6d_to_matrix_tensor_batch(pred_pose_9d[:, :6])[0]
+                pred_pose[:3, 3] = pred_pose_9d[0, 6:]
+                try:
+                    new_target_pts, new_target_normals, new_scan_points_indices = RenderUtil.render_pts(pred_pose, scene_path, self.script_path, scan_points, voxel_threshold=voxel_threshold, filter_degree=filter_degree, nO_to_nL_pose=O_to_L_pose)
+                    if not ReconstructionUtil.check_scan_points_overlap(history_indices, new_scan_points_indices, scan_points_threshold):
+                        curr_overlap_area_threshold = overlap_area_threshold
+                    else:
+                        curr_overlap_area_threshold = overlap_area_threshold * 0.5
+
+                    downsampled_new_target_pts = PtsUtil.voxel_downsample_point_cloud(new_target_pts, voxel_threshold)
+                    overlap, _ = ReconstructionUtil.check_overlap(downsampled_new_target_pts, voxel_downsampled_combined_scanned_pts_np, overlap_area_threshold=curr_overlap_area_threshold, voxel_size=voxel_threshold, require_new_added_pts_num=True)
+                    if not overlap:
+                        Log.yellow("no overlap!")
+                        retry += 1
+                        retry_overlap_pose.append(pred_pose.cpu().numpy().tolist())
+                        continue
+
+                    history_indices.append(new_scan_points_indices)
+                except Exception as e:
+                    Log.error(f"Error in scene {scene_path}, {e}")
+                    Log.error(f"current pose: {pred_pose}, curr_pred_cr: {last_pred_cr}")
+                    retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
+                    retry += 1
+                    continue
+
+                if new_target_pts.shape[0] == 0:
+                    Log.red("no pts in new target")
+                    retry_no_pts_pose.append(pred_pose.cpu().numpy().tolist())
+                    retry += 1
+                    continue
+
+                pred_cr, _ = self.compute_coverage_rate(scanned_view_pts, new_target_pts, down_sampled_model_pts, threshold=voxel_threshold)
+                Log.yellow(f"pred_cr: {pred_cr}, last_pred_cr: {last_pred_cr}, max: {data['seq_max_coverage_rate']}")
+                if pred_cr >= data["seq_max_coverage_rate"] - 1e-3:
+                    Log.success(f"max coverage rate reached: {pred_cr}")
+
+                pred_cr_seq.append(pred_cr)
+                scanned_view_pts.append(new_target_pts)
+
+                input_data["scanned_n_to_world_pose_9d"] = [torch.cat([input_data["scanned_n_to_world_pose_9d"][0], pred_pose_9d], dim=0)]
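+                # Rebuild the pipeline input from the enlarged scan set:
+                # voxel-downsample to deduplicate overlapping views, then
+                # randomly downsample back to the fixed input size input_pts_N.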
combined_scanned_pts = np.vstack(scanned_view_pts)
+                voxel_downsampled_combined_scanned_pts_np = PtsUtil.voxel_downsample_point_cloud(combined_scanned_pts, voxel_threshold)
+                random_downsampled_combined_scanned_pts_np = PtsUtil.random_downsample_point_cloud(voxel_downsampled_combined_scanned_pts_np, input_pts_N)
+                input_data["combined_scanned_pts"] = torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)
+                input_data["scanned_pts"] = [torch.cat([input_data["scanned_pts"][0], torch.tensor(random_downsampled_combined_scanned_pts_np, dtype=torch.float32).unsqueeze(0).to(self.device)], dim=0)]
+
+                last_pred_cr = pred_cr
+                pts_num = voxel_downsampled_combined_scanned_pts_np.shape[0]
+                Log.info(f"delta pts num: {pts_num - last_pts_num}, pts_num: {pts_num}, last_pts_num: {last_pts_num}")
+
+                if pts_num - last_pts_num < self.min_new_pts_num and pred_cr <= data["seq_max_coverage_rate"] - 1e-2:
+                    retry += 1
+                    retry_duplication_pose.append(pred_pose.cpu().numpy().tolist())
+                    Log.red(f"delta pts num < {self.min_new_pts_num}: {pts_num}, {last_pts_num}")
+                elif pts_num - last_pts_num < self.min_new_pts_num and pred_cr > data["seq_max_coverage_rate"] - 1e-2:
+                    success += 1
+                    Log.success(f"delta pts num < {self.min_new_pts_num}: {pts_num}, {last_pts_num}")
+
+                last_pts_num = pts_num
+
+        input_data["scanned_n_to_world_pose_9d"] = input_data["scanned_n_to_world_pose_9d"][0].cpu().numpy().tolist()
+        result = {
+            "pred_pose_9d_seq": input_data["scanned_n_to_world_pose_9d"],
+            "combined_scanned_pts": input_data["combined_scanned_pts"],
+            "target_pts_seq": scanned_view_pts,
+            "coverage_rate_seq": pred_cr_seq,
+            "max_coverage_rate": data["seq_max_coverage_rate"],
+            "pred_max_coverage_rate": max(pred_cr_seq),
+            "scene_name": scene_name,
+            "retry_no_pts_pose": retry_no_pts_pose,
+            "retry_duplication_pose": retry_duplication_pose,
+            "retry_overlap_pose": retry_overlap_pose,
+            "best_seq_len": data["best_seq_len"],
+        }
+        self.stat_result[scene_name] = {
+            "coverage_rate_seq": pred_cr_seq,
+            "pred_max_coverage_rate": max(pred_cr_seq),
+            "pred_seq_len": len(pred_cr_seq),
+        }
+        Log.info(f"pred_max_coverage_rate: {max(pred_cr_seq)}")
+
+        return result
+
+    def voxel_downsample_with_mapping(self, point_cloud, voxel_size=0.003):
+        voxel_indices = np.floor(point_cloud / voxel_size).astype(np.int32)
+        unique_voxels, inverse, counts = np.unique(voxel_indices, axis=0, return_inverse=True, return_counts=True)
+        idx_sort = np.argsort(inverse)
+        idx_unique = idx_sort[np.cumsum(counts) - counts]
+        downsampled_points = point_cloud[idx_unique]
+        return downsampled_points, inverse
+
+    def compute_coverage_rate(self, scanned_view_pts, new_pts, model_pts, threshold=0.005):
+        if new_pts is not None:
+            new_scanned_view_pts = scanned_view_pts + [new_pts]
+        else:
+            new_scanned_view_pts = scanned_view_pts
+        combined_point_cloud = np.vstack(new_scanned_view_pts)
+        down_sampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud, threshold)
+        return ReconstructionUtil.compute_coverage_rate(model_pts, down_sampled_combined_point_cloud, threshold)
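+    # How voxel_downsample_with_mapping above works, on illustrative values:
+    #   voxel_indices -> [[0,0,0], [0,0,0], [1,0,0]]
+    #   inverse       -> [0, 0, 1]   (voxel id of every original point)
+    #   counts        -> [2, 1], so np.cumsum(counts) - counts -> [0, 2]
+    #   idx_unique    -> [0, 2]      (one representative point per voxel)
+    # `inverse` lets callers scatter per-voxel results back onto the raw points.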
+    def save_inference_result(self, dataset_name, scene_name, output):
+        dataset_dir = os.path.join(self.output_dir, dataset_name)
+        if not os.path.exists(dataset_dir):
+            os.makedirs(dataset_dir)
+        output_path = os.path.join(dataset_dir, f"{scene_name}.pkl")
+        pickle.dump(output, open(output_path, "wb"))
+        with open(self.stat_result_path, "w") as f:
+            json.dump(self.stat_result, f)
+
+    def get_checkpoint_path(self, is_last=False):
+        return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME,
+                            "Epoch_{}.pth".format(
+                                self.current_epoch if self.current_epoch != -1 and not is_last else "last"))
+
+    def load_checkpoint(self, is_last=False):
+        self.load(self.get_checkpoint_path(is_last))
+        Log.success(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}")
+        if is_last:
+            checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
+            meta_path = os.path.join(checkpoint_root, "meta.json")
+            if not os.path.exists(meta_path):
+                raise FileNotFoundError(
+                    "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
+            with open(meta_path, "r") as f:
+                meta = json.load(f)
+            self.current_epoch = meta["last_epoch"]
+            self.current_iter = meta["last_iter"]
+
+    def load_experiment(self, backup_name=None):
+        super().load_experiment(backup_name)
+        self.current_epoch = self.experiments_config["epoch"]
+        self.load_checkpoint(is_last=(self.current_epoch == -1))
+
+    def create_experiment(self, backup_name=None):
+        super().create_experiment(backup_name)
+
+    def load(self, path):
+        state_dict = torch.load(path)
+        self.pipeline.load_state_dict(state_dict)
+
+    def print_info(self):
+        def print_dataset(dataset: BaseDataset):
+            config = dataset.get_config()
+            name = dataset.get_name()
+            Log.blue(f"Dataset: {name}")
+            for k, v in config.items():
+                Log.blue(f"\t{k}: {v}")
+
+        super().print_info()
+        table_size = 70
+        Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
+        Log.blue(self.pipeline)
+        Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
+        for i, test_set in enumerate(self.test_set_list):
+            Log.blue(f"test dataset {i}: ")
+            print_dataset(test_set)
+
+        Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')
+
diff --git a/runners/simulator.py b/runners/simulator.py
new file mode 100644
index 0000000..c38fe5d
--- /dev/null
+++ b/runners/simulator.py
@@ -0,0 +1,456 @@
+import pybullet as p
+import pybullet_data
+import numpy as np
+import os
+import time
+from PytorchBoot.runners.runner import Runner
+import PytorchBoot.stereotype as stereotype
+from PytorchBoot.config import ConfigManager
+from utils.control import ControlUtil
+
+
+@stereotype.runner("simulator")
+class Simulator(Runner):
+    CREATE: str = "create"
+    SIMULATE: str = "simulate"
+    INIT_GRIPPER_POSE: np.ndarray = np.asarray(
+        [[ 0.41869126,  0.87596275,  0.23951774,  0.36005292],
+         [ 0.70787907, -0.4800251 ,  0.51813998, -0.40499909],
+         [ 0.56884584, -0.04739109, -0.82107382,  0.76881103],
+         [ 0.        ,  0.        ,  0.        ,  1.        ]])
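+    # Maps poses from the turntable/world frame used by the planner into the
+    # pyBullet world frame (pure translation, hand-measured).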
+    TURNTABLE_WORLD_TO_PYBULLET_WORLD: np.ndarray = np.asarray(
+        [[1, 0, 0, 0.8],
+         [0, 1, 0, 0],
+         [0, 0, 1, 0.5],
+         [0, 0, 0, 1]])
+
+    debug_pose = np.asarray([
+        [0.992167055606842, -0.10552699863910675, 0.06684812903404236, -0.07388903945684433],
+        [0.10134342312812805, 0.3670985698699951, -0.9246448874473572, -0.41582486033439636],
+        [0.07303514331579208, 0.9241767525672913, 0.37491756677627563, 1.0754833221435547],
+        [0.0, 0.0, 0.0, 1.0]])
+
+    def __init__(self, config_path):
+        super().__init__(config_path)
+        self.config_path = config_path
+        self.robot_id = None
+        self.turntable_id = None
+        self.target_id = None
+        camera_config = ConfigManager.get("simulation", "camera")
+        self.camera_params = {
+            'width': camera_config["width"],
+            'height': camera_config["height"],
+            'fov': camera_config["fov"],
+            'near': camera_config["near"],
+            'far': camera_config["far"]
+        }
+        self.sim_config = ConfigManager.get("simulation")
+
+    def run(self, cmd):
+        print(f"Simulator run {cmd}")
+        if cmd == self.CREATE:
+            self.prepare_env()
+            self.create_env()
+        elif cmd == self.SIMULATE:
+            self.simulate()
+
+    def simulate(self):
+        self.reset()
+        self.init()
+        debug_pose = Simulator.debug_pose
+        offset = np.asarray([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
+        debug_pose = debug_pose @ offset
+        for _ in range(10000):
+            debug_pose_2 = np.eye(4)
+            debug_pose_2[0, 0] = -1
+            debug_pose_2[2, 3] = 0.5
+            self.move_to(debug_pose_2)
+            # Wait for the system to stabilize
+            for _ in range(20):  # simulate 20 steps to ensure stability
+                p.stepSimulation()
+                time.sleep(0.001)  # small delay so the physics can settle
+
+            depth_img, segm_img = self.take_picture()
+            p.stepSimulation()
+
+    def prepare_env(self):
+        p.connect(p.GUI)
+        p.setAdditionalSearchPath(pybullet_data.getDataPath())
+        p.setGravity(0, 0, 0)
+        p.loadURDF("plane.urdf")
+
+    def create_env(self):
+        print(self.config)
+        robot_config = self.sim_config["robot"]
+        turntable_config = self.sim_config["turntable"]
+        target_config = self.sim_config["target"]
+
+        self.robot_id = p.loadURDF(
+            robot_config["urdf_path"],
+            robot_config["initial_position"],
+            p.getQuaternionFromEuler(robot_config["initial_orientation"]),
+            useFixedBase=True
+        )
+
+        p.changeDynamics(
+            self.robot_id,
+            linkIndex=-1,
+            mass=0,
+            linearDamping=0,
+            angularDamping=0,
+            lateralFriction=0
+        )
+
+        visual_shape_id = p.createVisualShape(
+            shapeType=p.GEOM_CYLINDER,
+            radius=turntable_config["radius"],
+            length=turntable_config["height"],
+            rgbaColor=[0.7, 0.7, 0.7, 1]
+        )
+        collision_shape_id = p.createCollisionShape(
+            shapeType=p.GEOM_CYLINDER,
+            radius=turntable_config["radius"],
+            height=turntable_config["height"]
+        )
+        self.turntable_id = p.createMultiBody(
+            baseMass=0,  # mass 0 makes the turntable a static body
+            baseCollisionShapeIndex=collision_shape_id,
+            baseVisualShapeIndex=visual_shape_id,
+            basePosition=turntable_config["center_position"]
+        )
+
+        # disable dynamics on the turntable
+        p.changeDynamics(
+            self.turntable_id,
+            -1,  # -1 refers to the base
+            mass=0,
+            linearDamping=0,
+            angularDamping=0,
+            lateralFriction=0
+        )
+
+        obj_path = os.path.join(target_config["obj_dir"], target_config["obj_name"], "mesh.obj")
+        assert os.path.exists(obj_path), f"Error: File not found at {obj_path}"
+
+        # load the OBJ file as the target object
+        target_visual = p.createVisualShape(
+            shapeType=p.GEOM_MESH,
+            fileName=obj_path,
+            rgbaColor=target_config["rgba_color"],
+            specularColor=[0.4, 0.4, 0.4],
+            meshScale=[target_config["scale"]] * 3
+        )
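+        # pyBullet honors GEOM_FORCE_CONCAVE_TRIMESH only for static
+        # (mass == 0) bodies, which is why the target below is created with
+        # baseMass=0 and then pinned to the turntable with a fixed constraint.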
+        # use the mesh itself as the collision shape
+        target_collision = p.createCollisionShape(
+            shapeType=p.GEOM_MESH,
+            fileName=obj_path,
+            meshScale=[target_config["scale"]] * 3,
+            flags=p.GEOM_FORCE_CONCAVE_TRIMESH  # force a concave trimesh
+        )
+
+        # create the target object
+        self.target_id = p.createMultiBody(
+            baseMass=0,  # mass 0 makes the target a static body
+            baseCollisionShapeIndex=target_collision,
+            baseVisualShapeIndex=target_visual,
+            basePosition=[
+                turntable_config["center_position"][0],
+                turntable_config["center_position"][1],
+                turntable_config["height"] + turntable_config["center_position"][2]
+            ],
+            baseOrientation=p.getQuaternionFromEuler([np.pi / 2, 0, 0])
+        )
+
+        # disable dynamics on the target object
+        p.changeDynamics(
+            self.target_id,
+            -1,  # -1 refers to the base
+            mass=0,
+            linearDamping=0,
+            angularDamping=0,
+            lateralFriction=0
+        )
+
+        # fixed constraint pinning the target object onto the turntable
+        cid = p.createConstraint(
+            parentBodyUniqueId=self.turntable_id,
+            parentLinkIndex=-1,  # -1 refers to the base
+            childBodyUniqueId=self.target_id,
+            childLinkIndex=-1,  # -1 refers to the base
+            jointType=p.JOINT_FIXED,
+            jointAxis=[0, 0, 0],
+            parentFramePosition=[0, 0, 0],  # offset relative to the turntable center
+            childFramePosition=[0, 0, 0]  # offset relative to the object center
+        )
+
+        # cap the constraint force so the attachment stays stable
+        p.changeConstraint(cid, maxForce=100)
+
+    def move_robot_to_pose(self, target_matrix):
+        # position is the translation column of the 4x4 homogeneous matrix
+        position = target_matrix[:3, 3]
+
+        # orientation comes from the 3x3 rotation block
+        R = target_matrix[:3, :3]
+
+        # w component of the quaternion
+        w = np.sqrt(max(0, 1 + R[0, 0] + R[1, 1] + R[2, 2])) / 2
+
+        # avoid division by zero and handle the near-degenerate case
+        if abs(w) < 1e-8:
+            # special case when w is close to 0
+            x = np.sqrt(max(0, 1 + R[0, 0] - R[1, 1] - R[2, 2])) / 2
+            y = np.sqrt(max(0, 1 - R[0, 0] + R[1, 1] - R[2, 2])) / 2
+            z = np.sqrt(max(0, 1 - R[0, 0] - R[1, 1] + R[2, 2])) / 2
+
+            # fix the signs
+            if R[2, 1] - R[1, 2] < 0: x = -x
+            if R[0, 2] - R[2, 0] < 0: y = -y
+            if R[1, 0] - R[0, 1] < 0: z = -z
+        else:
+            # regular case
+            x = (R[2, 1] - R[1, 2]) / (4 * w)
+            y = (R[0, 2] - R[2, 0]) / (4 * w)
+            z = (R[1, 0] - R[0, 1]) / (4 * w)
+
+        orientation = (x, y, z, w)
+
+        # IK solver parameters
+        num_joints = p.getNumJoints(self.robot_id)
+        lower_limits = []
+        upper_limits = []
+        joint_ranges = []
+        rest_poses = []
+
+        # collect joint limits and a default rest pose
+        for i in range(num_joints):
+            joint_info = p.getJointInfo(self.robot_id, i)
+            lower_limits.append(joint_info[8])
+            upper_limits.append(joint_info[9])
+            joint_ranges.append(joint_info[9] - joint_info[8])
+            rest_poses.append(0)  # a better default pose could be used here
+
+        # null-space IK solve with joint limits
+        joint_poses = p.calculateInverseKinematics(
+            self.robot_id,
+            7,  # end effector link index
+            position,
+            orientation,
+            lowerLimits=lower_limits,
+            upperLimits=upper_limits,
+            jointRanges=joint_ranges,
+            restPoses=rest_poses,
+            maxNumIterations=100,
+            residualThreshold=1e-4
+        )
+
+        # move toward the target in small steps, checking for collisions
+        current_poses = [p.getJointState(self.robot_id, i)[0] for i in range(7)]
+        steps = 50  # move in 50 interpolation steps
+
+        for step in range(steps):
+            # linear interpolation toward the IK solution
+            intermediate_poses = []
+            for current, target in zip(current_poses, joint_poses):
+                t = (step + 1) / steps
+                intermediate = current + (target - current) * t
+                intermediate_poses.append(intermediate)
+
+            # command the joint positions
+            for i in range(7):
+                p.setJointMotorControl2(
+                    self.robot_id,
+                    i,
+                    p.POSITION_CONTROL,
+                    intermediate_poses[i]
+                )
+
+            # advance the simulation one step
+            p.stepSimulation()
+
+            # collision check against the turntable
+            if p.getContactPoints(self.robot_id, self.turntable_id):
+                print("potential collision detected, stopping motion")
+                return False
+
+        return True
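+    # The matrix-to-quaternion extraction in move_robot_to_pose above is the
+    # standard trace-based method with a fallback branch for w ~ 0. If SciPy
+    # is available, an equivalent cross-check (illustration only) would be:
+    #   from scipy.spatial.transform import Rotation
+    #   x, y, z, w = Rotation.from_matrix(R).as_quat()
+    # pyBullet uses the same (x, y, z, w) ordering.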
+    def rotate_turntable(self, angle_degrees):
+        # rotate the turntable base
+        current_pos, current_orn = p.getBasePositionAndOrientation(self.turntable_id)
+        current_orn = p.getEulerFromQuaternion(current_orn)
+
+        new_orn = list(current_orn)
+        new_orn[2] += np.radians(angle_degrees)
+        new_orn_quat = p.getQuaternionFromEuler(new_orn)
+
+        p.resetBasePositionAndOrientation(
+            self.turntable_id,
+            current_pos,
+            new_orn_quat
+        )
+
+        # rotate the target object together with the turntable
+        target_pos, target_orn = p.getBasePositionAndOrientation(self.target_id)
+        target_orn = p.getEulerFromQuaternion(target_orn)
+
+        # update the target object's orientation
+        target_orn = list(target_orn)
+        target_orn[2] += np.radians(angle_degrees)
+        target_orn_quat = p.getQuaternionFromEuler(target_orn)
+
+        # new position of the object (rotated about the turntable center)
+        turntable_center = current_pos
+        relative_pos = np.array(target_pos) - np.array(turntable_center)
+
+        # rotation matrix about the z axis
+        theta = np.radians(angle_degrees)
+        rotation_matrix = np.array([
+            [np.cos(theta), -np.sin(theta), 0],
+            [np.sin(theta), np.cos(theta), 0],
+            [0, 0, 1]
+        ])
+
+        # rotated relative position
+        new_relative_pos = rotation_matrix.dot(relative_pos)
+        new_pos = np.array(turntable_center) + new_relative_pos
+
+        # update the target object's position and orientation
+        p.resetBasePositionAndOrientation(
+            self.target_id,
+            new_pos,
+            target_orn_quat
+        )
+
+    def get_camera_pose(self):
+        end_effector_link = 7  # link index of the Franka end effector
+        state = p.getLinkState(self.robot_id, end_effector_link)
+        ee_pos = state[0]  # position in the world frame
+        camera_orn = state[1]  # orientation in the world frame (quaternion)
+
+        # rotation matrix of the end-effector link
+        rot_matrix = p.getMatrixFromQuaternion(camera_orn)
+        rot_matrix = np.array(rot_matrix).reshape(3, 3)
+
+        # camera forward vector (local z axis of the link)
+        camera_forward = rot_matrix.dot(np.array([0, 0, 1]))
+
+        # shift the camera 0.12 m forward along the viewing axis
+        offset = 0.12
+        camera_pos = np.array(ee_pos) + camera_forward * offset
+        camera_target = camera_pos + camera_forward
+
+        # camera up vector (local x axis of the link)
+        camera_up = rot_matrix.dot(np.array([1, 0, 0]))
+
+        return camera_pos, camera_target, camera_up
+
+    def take_picture(self):
+        camera_pos, camera_target, camera_up = self.get_camera_pose()
+
+        view_matrix = p.computeViewMatrix(
+            cameraEyePosition=camera_pos,
+            cameraTargetPosition=camera_target,
+            cameraUpVector=camera_up
+        )
+
+        projection_matrix = p.computeProjectionMatrixFOV(
+            fov=self.camera_params['fov'],
+            aspect=self.camera_params['width'] / self.camera_params['height'],
+            nearVal=self.camera_params['near'],
+            farVal=self.camera_params['far']
+        )
+
+        _, _, rgb_img, depth_img, segm_img = p.getCameraImage(
+            width=self.camera_params['width'],
+            height=self.camera_params['height'],
+            viewMatrix=view_matrix,
+            projectionMatrix=projection_matrix,
+            renderer=p.ER_BULLET_HARDWARE_OPENGL
+        )
+
+        # linearize the OpenGL depth buffer into metric depth
+        depth_img = self.camera_params['far'] * self.camera_params['near'] / (
+            self.camera_params['far'] - (self.camera_params['far'] - self.camera_params['near']) * depth_img)
+
+        depth_img = np.array(depth_img)
+        segm_img = np.array(segm_img)
+
+        return depth_img, segm_img
+
+    def reset(self):
+        target_pos = [0.5, 0, 1]
+        target_orn = p.getQuaternionFromEuler([np.pi, 0, 0])
+        target_matrix = np.eye(4)
+        target_matrix[:3, 3] = target_pos
+        target_matrix[:3, :3] = np.asarray(p.getMatrixFromQuaternion(target_orn)).reshape(3, 3)
+        self.move_robot_to_pose(target_matrix)
+
+    def init(self):
+        self.move_to(Simulator.INIT_GRIPPER_POSE)
+
+    def move_to(self, pose: np.ndarray):
+        # delta_degree, min_new_cam_to_world = ControlUtil.solve_display_table_rot_and_cam_to_world(pose)
+        min_new_cam_to_pybullet_world = Simulator.TURNTABLE_WORLD_TO_PYBULLET_WORLD @ pose
+        self.move_to_cam_pose(min_new_cam_to_pybullet_world)
+        # self.rotate_turntable(delta_degree)
+
+    def __del__(self):
+        p.disconnect()
+
+    def create_experiment(self, backup_name=None):
+        return super().create_experiment(backup_name)
+
+    def load_experiment(self, backup_name=None):
+        super().load_experiment(backup_name)
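+    # The camera is mounted 0.12 m ahead of the end effector along its viewing
+    # axis, so a desired camera pose must be pulled back by that offset before
+    # solving IK for the arm; move_to_cam_pose below does exactly that.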
+    def move_to_cam_pose(self, camera_pose: np.ndarray):
+        # split the camera pose into position and rotation
+        camera_pos = camera_pose[:3, 3]
+        R_camera = camera_pose[:3, :3]
+
+        # camera viewing direction (z axis)
+        forward = R_camera[:, 2]
+
+        # the camera sits 0.12 m in front of the end effector, so the gripper
+        # position is the camera position pulled back along the viewing axis
+        gripper_pos = camera_pos - forward * 0.12
+
+        # the gripper keeps the camera's rotation: the camera forward maps to
+        # the gripper z axis and the camera x axis to the gripper x axis
+        R_gripper = R_camera
+
+        # assemble the 4x4 homogeneous transform
+        gripper_pose = np.eye(4)
+        gripper_pose[:3, :3] = R_gripper
+        gripper_pose[:3, 3] = gripper_pos
+
+        # move the robot to the computed pose
+        return self.move_robot_to_pose(gripper_pose)
\ No newline at end of file
diff --git a/runners/strategy_generator.py b/runners/strategy_generator.py
new file mode 100644
index 0000000..6c63dc2
--- /dev/null
+++ b/runners/strategy_generator.py
@@ -0,0 +1,154 @@
+import os
+import json
+import time
+
+import numpy as np
+from PytorchBoot.runners.runner import Runner
+from PytorchBoot.config import ConfigManager
+from PytorchBoot.utils import Log
+import PytorchBoot.stereotype as stereotype
+from PytorchBoot.status import status_manager
+
+from utils.data_load import DataLoadUtil
+from utils.reconstruction import ReconstructionUtil
+from utils.pts import PtsUtil
+
+@stereotype.runner("strategy_generator")
+class StrategyGenerator(Runner):
+    def __init__(self, config):
+        super().__init__(config)
+        self.load_experiment("generate_strategy")
+        self.status_info = {
+            "status_manager": status_manager,
+            "app_name": "generate_strategy",
+            "runner_name": "strategy_generator"
+        }
+        self.overwrite = ConfigManager.get("runner", "generate", "overwrite")
+        self.seq_num = ConfigManager.get("runner", "generate", "seq_num")
+        self.overlap_area_threshold = ConfigManager.get("runner", "generate", "overlap_area_threshold")
+        self.compute_with_normal = ConfigManager.get("runner", "generate", "compute_with_normal")
+        self.scan_points_threshold = ConfigManager.get("runner", "generate", "scan_points_threshold")
+
+    def run(self):
+        dataset_name_list = ConfigManager.get("runner", "generate", "dataset_list")
+        voxel_threshold = ConfigManager.get("runner", "generate", "voxel_threshold")
+        for dataset_idx in range(len(dataset_name_list)):
+            dataset_name = dataset_name_list[dataset_idx]
+            status_manager.set_progress("generate_strategy", "strategy_generator", "dataset", dataset_idx, len(dataset_name_list))
+            root_dir = ConfigManager.get("datasets", dataset_name, "root_dir")
+            from_idx = ConfigManager.get("datasets", dataset_name, "from")
+            to_idx = ConfigManager.get("datasets", dataset_name, "to")
+            scene_name_list = os.listdir(root_dir)
+            if to_idx == -1:
+                to_idx = len(scene_name_list)
+            cnt = 0
+            total = len(scene_name_list[from_idx:to_idx])
+            Log.info(f"Processing Dataset: {dataset_name}, From: {from_idx}, To: {to_idx}")
+            for scene_name in scene_name_list[from_idx:to_idx]:
+                Log.info(f"({dataset_name})Processing [{cnt}/{total}]: {scene_name}")
+                status_manager.set_progress("generate_strategy", "strategy_generator", "scene", cnt, total)
+                output_label_path = DataLoadUtil.get_label_path(root_dir, scene_name, 0)
+                if os.path.exists(output_label_path) and not self.overwrite:
+                    Log.info(f"Scene <{scene_name}> Already Exists, Skip")
+                    cnt += 1
+                    continue
+
+                self.generate_sequence(root_dir, scene_name, voxel_threshold)
+                cnt += 1
+            status_manager.set_progress("generate_strategy", "strategy_generator", "scene", total, total)
+        status_manager.set_progress("generate_strategy", "strategy_generator", "dataset", len(dataset_name_list), len(dataset_name_list))
"output") + os.makedirs(output_dir) + + def load_experiment(self, backup_name=None): + super().load_experiment(backup_name) + + def generate_sequence(self, root, scene_name, voxel_threshold): + status_manager.set_status("generate_strategy", "strategy_generator", "scene", scene_name) + frame_num = DataLoadUtil.get_scene_seq_length(root, scene_name) + + model_points_normals = DataLoadUtil.load_points_normals(root, scene_name) + model_pts = model_points_normals[:,:3] + down_sampled_model_pts, idx = PtsUtil.voxel_downsample_point_cloud(model_pts, voxel_threshold, require_idx=True) + down_sampled_model_nrm = model_points_normals[idx, 3:] + pts_list = [] + nrm_list = [] + scan_points_indices_list = [] + non_zero_cnt = 0 + + for frame_idx in range(frame_num): + status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_idx, frame_num) + pts_path = os.path.join(root,scene_name, "pts", f"{frame_idx}.npy") + nrm_path = os.path.join(root,scene_name, "nrm", f"{frame_idx}.npy") + idx_path = os.path.join(root,scene_name, "scan_points_indices", f"{frame_idx}.npy") + + pts = np.load(pts_path) + if self.compute_with_normal: + if pts.shape[0] == 0: + nrm = np.zeros((0,3)) + else: + nrm = np.load(nrm_path) + nrm_list.append(nrm) + pts_list.append(pts) + indices = np.load(idx_path) + scan_points_indices_list.append(indices) + if pts.shape[0] > 0: + non_zero_cnt += 1 + status_manager.set_progress("generate_strategy", "strategy_generator", "loading frame", frame_num, frame_num) + + seq_num = min(self.seq_num, non_zero_cnt) + init_view_list = [] + idx = 0 + while len(init_view_list) < seq_num and idx < len(pts_list): + if pts_list[idx].shape[0] > 50: + init_view_list.append(idx) + idx += 1 + + seq_idx = 0 + import time + for init_view in init_view_list: + status_manager.set_progress("generate_strategy", "strategy_generator", "computing sequence", seq_idx, len(init_view_list)) + start = time.time() + + if not self.compute_with_normal: + limited_useful_view, _, _ = ReconstructionUtil.compute_next_best_view_sequence(down_sampled_model_pts, pts_list, scan_points_indices_list = scan_points_indices_list,init_view=init_view, + threshold=voxel_threshold, scan_points_threshold=self.scan_points_threshold, overlap_area_threshold=self.overlap_area_threshold, status_info=self.status_info) + else: + limited_useful_view, _, _ = ReconstructionUtil.compute_next_best_view_sequence_with_normal(down_sampled_model_pts, down_sampled_model_nrm, pts_list, nrm_list, scan_points_indices_list = scan_points_indices_list,init_view=init_view, + threshold=voxel_threshold, scan_points_threshold=self.scan_points_threshold, overlap_area_threshold=self.overlap_area_threshold, status_info=self.status_info) + end = time.time() + print(f"Time: {end-start}") + data_pairs = self.generate_data_pairs(limited_useful_view) + seq_save_data = { + "data_pairs": data_pairs, + "best_sequence": limited_useful_view, + "max_coverage_rate": limited_useful_view[-1][1] + } + + status_manager.set_status("generate_strategy", "strategy_generator", "max_coverage_rate", limited_useful_view[-1][1]) + Log.success(f"Scene <{scene_name}> Finished, Max Coverage Rate: {limited_useful_view[-1][1]}, Best Sequence length: {len(limited_useful_view)}") + + output_label_path = DataLoadUtil.get_label_path(root, scene_name, seq_idx) + + + with open(output_label_path, 'w') as f: + json.dump(seq_save_data, f) + seq_idx += 1 + status_manager.set_progress("generate_strategy", "strategy_generator", "computing sequence", len(init_view_list), 
+    def generate_data_pairs(self, useful_view):
+        data_pairs = []
+        for next_view_idx in range(1, len(useful_view)):
+            scanned_views = useful_view[:next_view_idx]
+            next_view = useful_view[next_view_idx]
+            data_pairs.append((scanned_views, next_view))
+        return data_pairs
\ No newline at end of file
diff --git a/runners/view_generator.py b/runners/view_generator.py
new file mode 100644
index 0000000..634ccbf
--- /dev/null
+++ b/runners/view_generator.py
@@ -0,0 +1,19 @@
+import subprocess
+from PytorchBoot.runners.runner import Runner
+import PytorchBoot.stereotype as stereotype
+
+@stereotype.runner("view_generator")
+class ViewGenerator(Runner):
+    def __init__(self, config_path):
+        super().__init__(config_path)
+        self.config_path = config_path
+
+    def run(self):
+        # render views by driving Blender headlessly with the repo's script
+        subprocess.run(['/home/hofee/blender-4.0.2-linux-x64/blender', '-b', '-P', '../blender/run_blender.py', '--', self.config_path])
+
+    def create_experiment(self, backup_name=None):
+        return super().create_experiment(backup_name)
+
+    def load_experiment(self, backup_name=None):
+        super().load_experiment(backup_name)
\ No newline at end of file
diff --git a/utils/__pycache__/control.cpython-39.pyc b/utils/__pycache__/control.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..752b92a3bfcec81731f85e5504aca71d7ec2bf0a
Binary files /dev/null and b/utils/__pycache__/control.cpython-39.pyc differ
diff --git a/utils/__pycache__/data_load.cpython-39.pyc b/utils/__pycache__/data_load.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6034bea7c38d2fe6f3f2de929f04c9a7844a8d47
Binary files /dev/null and b/utils/__pycache__/data_load.cpython-39.pyc differ
diff --git a/utils/__pycache__/pts.cpython-39.pyc b/utils/__pycache__/pts.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fbbf05aac88bb7b3bca183c259bf18b9b2b6edd3
Binary files /dev/null and b/utils/__pycache__/pts.cpython-39.pyc differ
diff --git a/utils/__pycache__/render.cpython-39.pyc b/utils/__pycache__/render.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..38f814ed4dcab61e92cc116af9326b1b98782e81
Binary files /dev/null and b/utils/__pycache__/render.cpython-39.pyc differ
diff --git a/utils/__pycache__/vis.cpython-39.pyc b/utils/__pycache__/vis.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..58cda0516ea85aae9f73b4daa324a369db03fa6e
Binary files /dev/null and b/utils/__pycache__/vis.cpython-39.pyc differ
diff --git a/utils/data_load.py b/utils/data_load.py
new file mode 100644
index 0000000..7a93fba
--- /dev/null
+++ b/utils/data_load.py
@@ -0,0 +1,391 @@
+import os
+import numpy as np
+import json
+import cv2
+import trimesh
+import torch
+import OpenEXR
+import Imath
+from utils.pts import PtsUtil
+
+
+class DataLoadUtil:
+    TABLE_POSITION = np.asarray([0, 0, 0.8215])
+
+    @staticmethod
+    def load_exr_image(file_path):
+        exr_file = OpenEXR.InputFile(file_path)
+        header = exr_file.header()
+        dw = header['dataWindow']
+        width = dw.max.x - dw.min.x + 1
+        height = dw.max.y - dw.min.y + 1
+        float_channels = ['R', 'G', 'B']
+        img_data = []
+        for channel in float_channels:
+            channel_data = exr_file.channel(channel)
+            img_data.append(np.frombuffer(channel_data, dtype=np.float16).reshape((height, width)))
+        img = np.stack(img_data, axis=-1)
+        return img
+
+    @staticmethod
+    def get_display_table_info(root, scene_name):
+        scene_info = DataLoadUtil.load_scene_info(root, scene_name)
+        display_table_info = scene_info["display_table"]
+        return display_table_info
+
+    @staticmethod
+    def get_display_table_top(root, scene_name):
+        display_table_height = DataLoadUtil.get_display_table_info(root, scene_name)[
+            "height"
+        ]
+        display_table_top = DataLoadUtil.TABLE_POSITION + np.asarray(
+            [0, 0, display_table_height]
+        )
+        return display_table_top
+
+    @staticmethod
+    def get_path(root, scene_name, frame_idx):
+        path = os.path.join(root, scene_name, f"{frame_idx}")
+        return path
+
+    @staticmethod
+    def get_label_num(root, scene_name):
+        label_dir = os.path.join(root, scene_name, "label")
+        if not os.path.exists(label_dir):
+            return 0
+        return len(os.listdir(label_dir))
+
+    @staticmethod
+    def get_label_path(root, scene_name, seq_idx):
+        label_dir = os.path.join(root, scene_name, "label")
+        if not os.path.exists(label_dir):
+            os.makedirs(label_dir)
+        path = os.path.join(label_dir, f"{seq_idx}.json")
+        return path
+
+    @staticmethod
+    def get_scene_seq_length(root, scene_name):
+        camera_params_path = os.path.join(root, scene_name, "camera_params")
+        return len(os.listdir(camera_params_path))
+
+    @staticmethod
+    def load_mesh_at(model_dir, object_name, world_object_pose):
+        model_path = os.path.join(model_dir, object_name, "mesh.obj")
+        mesh = trimesh.load(model_path)
+        mesh.apply_transform(world_object_pose)
+        return mesh
+
+    @staticmethod
+    def get_bbox_diag(model_dir, object_name):
+        model_path = os.path.join(model_dir, object_name, "mesh.obj")
+        mesh = trimesh.load(model_path)
+        bbox = mesh.bounding_box.extents
+        diagonal_length = np.linalg.norm(bbox)
+        return diagonal_length
+
+    @staticmethod
+    def load_scene_info(root, scene_name):
+        scene_info_path = os.path.join(root, scene_name, "scene_info.json")
+        with open(scene_info_path, "r") as f:
+            scene_info = json.load(f)
+        return scene_info
+
+    @staticmethod
+    def load_target_pts_num_dict(root, scene_name):
+        target_pts_num_path = os.path.join(root, scene_name, "target_pts_num.json")
+        with open(target_pts_num_path, "r") as f:
+            target_pts_num_dict = json.load(f)
+        return target_pts_num_dict
+
+    @staticmethod
+    def load_depth(path, min_depth=0.01, max_depth=5.0, binocular=False):
+
+        def load_depth_from_real_path(real_path, min_depth, max_depth):
+            depth = cv2.imread(real_path, cv2.IMREAD_UNCHANGED)
+            # 16-bit PNG depth is stored normalized; map [0, 65535] back to [min_depth, max_depth] meters
+            depth = depth.astype(np.float32) / 65535.0
+            depth_meters = min_depth + (max_depth - min_depth) * depth
+            return depth_meters
+
+        if binocular:
+            depth_path_L = os.path.join(
+                os.path.dirname(path), "depth", os.path.basename(path) + "_L.png"
+            )
+            depth_path_R = os.path.join(
+                os.path.dirname(path), "depth", os.path.basename(path) + "_R.png"
+            )
+            depth_meters_L = load_depth_from_real_path(
+                depth_path_L, min_depth, max_depth
+            )
+            depth_meters_R = load_depth_from_real_path(
+                depth_path_R, min_depth, max_depth
+            )
+            return depth_meters_L, depth_meters_R
+        else:
+            depth_path = os.path.join(
+                os.path.dirname(path), "depth", os.path.basename(path) + ".png"
+            )
+            depth_meters = load_depth_from_real_path(depth_path, min_depth, max_depth)
+            return depth_meters
+
+    @staticmethod
+    def load_seg(path, binocular=False, left_only=False):
+        if binocular and not left_only:
+
+            def clean_mask(mask_image):
+                # snap near-green and near-red pixels back to the exact label colors
+                green = [0, 255, 0]
+                red = [255, 0, 0]
+                threshold = 2
+                mask_image = np.where(
+                    np.abs(mask_image - green) <= threshold, green, mask_image
+                )
+                mask_image = np.where(
+                    np.abs(mask_image - red) <= threshold, red, mask_image
+                )
+                return mask_image
+
+            mask_path_L = os.path.join(
+                os.path.dirname(path), "mask", os.path.basename(path) + "_L.png"
+            )
+            mask_image_L = clean_mask(cv2.imread(mask_path_L, cv2.IMREAD_UNCHANGED))
+            mask_path_R = os.path.join(
+                os.path.dirname(path), "mask", os.path.basename(path) + "_R.png"
+            )
+            mask_image_R = clean_mask(cv2.imread(mask_path_R, cv2.IMREAD_UNCHANGED))
+            return mask_image_L, mask_image_R
+        else:
+            if binocular and left_only:
+                mask_path = os.path.join(
+                    os.path.dirname(path), "mask", os.path.basename(path) + "_L.png"
+                )
+            else:
+                mask_path = os.path.join(
+                    os.path.dirname(path), "mask", os.path.basename(path) + ".png"
+                )
+            mask_image = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
+            return mask_image
+
+    @staticmethod
+    def load_normal(path, binocular=False, left_only=False, file_type="exr"):
+        if binocular and not left_only:
+            normal_path_L = os.path.join(
+                os.path.dirname(path), "normal", os.path.basename(path) + f"_L.{file_type}"
+            )
+            normal_image_L = DataLoadUtil.load_exr_image(normal_path_L)
+
+            normal_path_R = os.path.join(
+                os.path.dirname(path), "normal", os.path.basename(path) + f"_R.{file_type}"
+            )
+            normal_image_R = DataLoadUtil.load_exr_image(normal_path_R)
+            normalized_normal_image_L = normal_image_L * 2.0 - 1.0
+            normalized_normal_image_R = normal_image_R * 2.0 - 1.0
+            return normalized_normal_image_L, normalized_normal_image_R
+        else:
+            if binocular and left_only:
+                normal_path = os.path.join(
+                    os.path.dirname(path), "normal", os.path.basename(path) + f"_L.{file_type}"
+                )
+            else:
+                normal_path = os.path.join(
+                    os.path.dirname(path), "normal", os.path.basename(path) + f".{file_type}"
+                )
+            normal_image = DataLoadUtil.load_exr_image(normal_path)
+            normalized_normal_image = normal_image * 2.0 - 1.0
+            return normalized_normal_image
+
+    @staticmethod
+    def load_label(path):
+        with open(path, "r") as f:
+            label_data = json.load(f)
+        return label_data
+
+    @staticmethod
+    def load_from_preprocessed_pts(path, file_type="npy"):
+        npy_path = os.path.join(
+            os.path.dirname(path), "pts", os.path.basename(path) + "."
+ file_type + ) + if file_type == "txt": + pts = np.loadtxt(npy_path) + else: + pts = np.load(npy_path) + return pts + + @staticmethod + def load_from_preprocessed_nrm(path, file_type="npy"): + npy_path = os.path.join( + os.path.dirname(path), "nrm", os.path.basename(path) + "." + file_type + ) + if file_type == "txt": + nrm = np.loadtxt(npy_path) + else: + nrm = np.load(npy_path) + return nrm + + @staticmethod + def cam_pose_transformation(cam_pose_before): + offset = np.asarray([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) + cam_pose_after = cam_pose_before @ offset + return cam_pose_after + + @staticmethod + def load_cam_info(path, binocular=False, display_table_as_world_space_origin=True): + scene_dir = os.path.dirname(path) + root_dir = os.path.dirname(scene_dir) + scene_name = os.path.basename(scene_dir) + camera_params_path = os.path.join( + os.path.dirname(path), "camera_params", os.path.basename(path) + ".json" + ) + with open(camera_params_path, "r") as f: + label_data = json.load(f) + cam_to_world = np.asarray(label_data["extrinsic"]) + cam_to_world = DataLoadUtil.cam_pose_transformation(cam_to_world) + + if display_table_as_world_space_origin: + world_to_display_table = np.eye(4) + world_to_display_table[:3, 3] = -DataLoadUtil.get_display_table_top( + root_dir, scene_name + ) + cam_to_world = np.dot(world_to_display_table, cam_to_world) + cam_intrinsic = np.asarray(label_data["intrinsic"]) + cam_info = { + "cam_to_world": cam_to_world, + "cam_intrinsic": cam_intrinsic, + "far_plane": label_data["far_plane"], + "near_plane": label_data["near_plane"], + } + if binocular: + cam_to_world_R = np.asarray(label_data["extrinsic_R"]) + cam_to_world_R = DataLoadUtil.cam_pose_transformation(cam_to_world_R) + cam_to_world_O = np.asarray(label_data["extrinsic_cam_object"]) + cam_to_world_O = DataLoadUtil.cam_pose_transformation(cam_to_world_O) + if display_table_as_world_space_origin: + cam_to_world_O = np.dot(world_to_display_table, cam_to_world_O) + cam_to_world_R = np.dot(world_to_display_table, cam_to_world_R) + cam_info["cam_to_world_O"] = cam_to_world_O + cam_info["cam_to_world_R"] = cam_to_world_R + return cam_info + + @staticmethod + def get_real_cam_O_from_cam_L( + cam_L, cam_O_to_cam_L, scene_path, display_table_as_world_space_origin=True + ): + root_dir = os.path.dirname(scene_path) + scene_name = os.path.basename(scene_path) + if isinstance(cam_L, torch.Tensor): + cam_L = cam_L.cpu().numpy() + nO_to_display_table_pose = cam_L @ cam_O_to_cam_L + if display_table_as_world_space_origin: + display_table_to_world = np.eye(4) + display_table_to_world[:3, 3] = DataLoadUtil.get_display_table_top( + root_dir, scene_name + ) + nO_to_world_pose = np.dot(display_table_to_world, nO_to_display_table_pose) + nO_to_world_pose = DataLoadUtil.cam_pose_transformation(nO_to_world_pose) + return nO_to_world_pose + + @staticmethod + def get_target_point_cloud( + depth, cam_intrinsic, cam_extrinsic, mask, target_mask_label=(0, 255, 0, 255), require_full_points=False + ): + h, w = depth.shape + i, j = np.meshgrid(np.arange(w), np.arange(h), indexing="xy") + + z = depth + x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0] + y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1] + + points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3) + mask = mask.reshape(-1, 4) + + target_mask = (mask == target_mask_label).all(axis=-1) + + target_points_camera = points_camera[target_mask] + target_points_camera_aug = np.concatenate( + [target_points_camera, 
np.ones((target_points_camera.shape[0], 1))], axis=-1 + ) + + target_points_world = np.dot(cam_extrinsic, target_points_camera_aug.T).T[:, :3] + data = { + "points_world": target_points_world, + "points_camera": target_points_camera, + } + return data + + @staticmethod + def get_point_cloud(depth, cam_intrinsic, cam_extrinsic): + h, w = depth.shape + i, j = np.meshgrid(np.arange(w), np.arange(h), indexing="xy") + + z = depth + x = (i - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0] + y = (j - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1] + + points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3) + points_camera_aug = np.concatenate( + [points_camera, np.ones((points_camera.shape[0], 1))], axis=-1 + ) + + points_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3] + return {"points_world": points_world, "points_camera": points_camera} + + @staticmethod + def get_target_point_cloud_world_from_path( + path, + binocular=False, + random_downsample_N=65536, + voxel_size=0.005, + target_mask_label=(0, 255, 0, 255), + display_table_mask_label=(0, 0, 255, 255), + get_display_table_pts=False, + require_normal=False, + ): + cam_info = DataLoadUtil.load_cam_info(path, binocular=binocular) + if binocular: + depth_L, depth_R = DataLoadUtil.load_depth( + path, cam_info["near_plane"], cam_info["far_plane"], binocular=True + ) + mask_L, mask_R = DataLoadUtil.load_seg(path, binocular=True) + point_cloud_L = DataLoadUtil.get_target_point_cloud( + depth_L, + cam_info["cam_intrinsic"], + cam_info["cam_to_world"], + mask_L, + target_mask_label, + )["points_world"] + point_cloud_R = DataLoadUtil.get_target_point_cloud( + depth_R, + cam_info["cam_intrinsic"], + cam_info["cam_to_world_R"], + mask_R, + target_mask_label, + )["points_world"] + point_cloud_L = PtsUtil.random_downsample_point_cloud( + point_cloud_L, random_downsample_N + ) + point_cloud_R = PtsUtil.random_downsample_point_cloud( + point_cloud_R, random_downsample_N + ) + overlap_points = PtsUtil.get_overlapping_points( + point_cloud_L, point_cloud_R, voxel_size + ) + return overlap_points + else: + depth = DataLoadUtil.load_depth( + path, cam_info["near_plane"], cam_info["far_plane"] + ) + mask = DataLoadUtil.load_seg(path) + point_cloud = DataLoadUtil.get_target_point_cloud( + depth, cam_info["cam_intrinsic"], cam_info["cam_to_world"], mask + )["points_world"] + return point_cloud + + @staticmethod + def load_points_normals(root, scene_name, display_table_as_world_space_origin=True): + points_path = os.path.join(root, scene_name, "points_and_normals.txt") + points_normals = np.loadtxt(points_path) + if display_table_as_world_space_origin: + points_normals[:, :3] = points_normals[ + :, :3 + ] - DataLoadUtil.get_display_table_top(root, scene_name) + return points_normals diff --git a/utils/pose.py b/utils/pose.py new file mode 100644 index 0000000..01db630 --- /dev/null +++ b/utils/pose.py @@ -0,0 +1,253 @@ +import numpy as np +import torch +import torch.nn.functional as F + + +class PoseUtil: + ROTATION = 1 + TRANSLATION = 2 + SCALE = 3 + + @staticmethod + def get_uniform_translation(trans_m_min, trans_m_max, trans_unit, debug=False): + if isinstance(trans_m_min, list): + x_min, y_min, z_min = trans_m_min + x_max, y_max, z_max = trans_m_max + else: + x_min, y_min, z_min = trans_m_min, trans_m_min, trans_m_min + x_max, y_max, z_max = trans_m_max, trans_m_max, trans_m_max + + x = np.random.uniform(x_min, x_max) + y = np.random.uniform(y_min, y_max) + z = np.random.uniform(z_min, z_max) + translation = np.array([x, y, z]) + if trans_unit == 
"cm": + translation = translation / 100 + if debug: + print("uniform translation:", translation) + return translation + + @staticmethod + def get_uniform_rotation(rot_degree_min=0, rot_degree_max=180, debug=False): + axis = np.random.randn(3) + axis /= np.linalg.norm(axis) + theta = np.random.uniform( + rot_degree_min / 180 * np.pi, rot_degree_max / 180 * np.pi + ) + + K = np.array( + [[0, -axis[2], axis[1]], [axis[2], 0, -axis[0]], [-axis[1], axis[0], 0]] + ) + R = np.eye(3) + np.sin(theta) * K + (1 - np.cos(theta)) * (K @ K) + if debug: + print("uniform rotation:", theta * 180 / np.pi) + return R + + @staticmethod + def get_uniform_pose( + trans_min, trans_max, rot_min=0, rot_max=180, trans_unit="cm", debug=False + ): + translation = PoseUtil.get_uniform_translation( + trans_min, trans_max, trans_unit, debug + ) + rotation = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug) + pose = np.eye(4) + pose[:3, :3] = rotation + pose[:3, 3] = translation + return pose + + @staticmethod + def get_n_uniform_pose( + trans_min, + trans_max, + rot_min=0, + rot_max=180, + n=1, + trans_unit="cm", + fix=None, + contain_canonical=True, + debug=False, + ): + if fix == PoseUtil.ROTATION: + translations = np.zeros((n, 3)) + for i in range(n): + translations[i] = PoseUtil.get_uniform_translation( + trans_min, trans_max, trans_unit, debug + ) + if contain_canonical: + translations[0] = np.zeros(3) + rotations = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug) + elif fix == PoseUtil.TRANSLATION: + rotations = np.zeros((n, 3, 3)) + for i in range(n): + rotations[i] = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug) + if contain_canonical: + rotations[0] = np.eye(3) + translations = PoseUtil.get_uniform_translation( + trans_min, trans_max, trans_unit, debug + ) + else: + translations = np.zeros((n, 3)) + rotations = np.zeros((n, 3, 3)) + for i in range(n): + translations[i] = PoseUtil.get_uniform_translation( + trans_min, trans_max, trans_unit, debug + ) + for i in range(n): + rotations[i] = PoseUtil.get_uniform_rotation(rot_min, rot_max, debug) + if contain_canonical: + translations[0] = np.zeros(3) + rotations[0] = np.eye(3) + + pose = np.eye(4, 4, k=0)[np.newaxis, :].repeat(n, axis=0) + pose[:, :3, :3] = rotations + pose[:, :3, 3] = translations + + return pose + + @staticmethod + def get_n_uniform_pose_batch( + trans_min, + trans_max, + rot_min=0, + rot_max=180, + n=1, + batch_size=1, + trans_unit="cm", + fix=None, + contain_canonical=False, + debug=False, + ): + + batch_poses = [] + for i in range(batch_size): + pose = PoseUtil.get_n_uniform_pose( + trans_min, + trans_max, + rot_min, + rot_max, + n, + trans_unit, + fix, + contain_canonical, + debug, + ) + batch_poses.append(pose) + pose_batch = np.stack(batch_poses, axis=0) + return pose_batch + + @staticmethod + def get_uniform_scale(scale_min, scale_max, debug=False): + if isinstance(scale_min, list): + x_min, y_min, z_min = scale_min + x_max, y_max, z_max = scale_max + else: + x_min, y_min, z_min = scale_min, scale_min, scale_min + x_max, y_max, z_max = scale_max, scale_max, scale_max + + x = np.random.uniform(x_min, x_max) + y = np.random.uniform(y_min, y_max) + z = np.random.uniform(z_min, z_max) + scale = np.array([x, y, z]) + if debug: + print("uniform scale:", scale) + return scale + + @staticmethod + def normalize_rotation(rotation, rotation_mode): + if rotation_mode == "quat_wxyz" or rotation_mode == "quat_xyzw": + rotation /= torch.norm(rotation, dim=-1, keepdim=True) + elif rotation_mode == "rot_matrix": + rot_matrix = 
PoseUtil.rotation_6d_to_matrix_tensor_batch(rotation) + rotation[:, :3] = rot_matrix[:, 0, :] + rotation[:, 3:6] = rot_matrix[:, 1, :] + elif rotation_mode == "euler_xyz_sx_cx": + rot_sin_theta = rotation[:, :3] + rot_cos_theta = rotation[:, 3:6] + theta = torch.atan2(rot_sin_theta, rot_cos_theta) + rotation[:, :3] = torch.sin(theta) + rotation[:, 3:6] = torch.cos(theta) + elif rotation_mode == "euler_xyz": + pass + else: + raise NotImplementedError + return rotation + + @staticmethod + def get_pose_dim(rot_mode): + assert rot_mode in [ + "quat_wxyz", + "quat_xyzw", + "euler_xyz", + "euler_xyz_sx_cx", + "rot_matrix", + ], f"the rotation mode {rot_mode} is not supported!" + + if rot_mode == "quat_wxyz" or rot_mode == "quat_xyzw": + pose_dim = 7 + elif rot_mode == "euler_xyz": + pose_dim = 6 + elif rot_mode == "euler_xyz_sx_cx" or rot_mode == "rot_matrix": + pose_dim = 9 + else: + raise NotImplementedError + return pose_dim + + @staticmethod + def rotation_6d_to_matrix_tensor_batch(d6: torch.Tensor) -> torch.Tensor: + + a1, a2 = d6[..., :3], d6[..., 3:] + b1 = F.normalize(a1, dim=-1) + b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1 + b2 = F.normalize(b2, dim=-1) + b3 = torch.cross(b1, b2, dim=-1) + return torch.stack((b1, b2, b3), dim=-2) + + @staticmethod + def matrix_to_rotation_6d_tensor_batch(matrix: torch.Tensor) -> torch.Tensor: + batch_dim = matrix.size()[:-2] + return matrix[..., :2, :].clone().reshape(batch_dim + (6,)) + + @staticmethod + def rotation_6d_to_matrix_numpy(d6): + a1, a2 = d6[:3], d6[3:] + b1 = a1 / np.linalg.norm(a1) + b2 = a2 - np.dot(b1, a2) * b1 + b2 = b2 / np.linalg.norm(b2) + b3 = np.cross(b1, b2) + return np.stack((b1, b2, b3), axis=-2) + + @staticmethod + def matrix_to_rotation_6d_numpy(matrix): + return np.copy(matrix[:2, :]).reshape((6,)) + + @staticmethod + def rotation_angle_distance(R1, R2): + R = torch.matmul(R1, R2.transpose(1, 2)) + trace = torch.diagonal(R, dim1=1, dim2=2).sum(-1) + angle = torch.acos(torch.clamp((trace - 1) / 2, -1.0, 1.0))/torch.pi*180 + return angle + + +""" ------------ Debug ------------ """ + +if __name__ == "__main__": + for _ in range(1): + PoseUtil.get_uniform_pose( + trans_min=[-25, -25, 10], + trans_max=[25, 25, 60], + rot_min=0, + rot_max=10, + debug=True, + ) + PoseUtil.get_uniform_scale(scale_min=0.25, scale_max=0.30, debug=True) + PoseUtil.get_n_uniform_pose_batch( + trans_min=[-25, -25, 10], + trans_max=[25, 25, 60], + rot_min=0, + rot_max=10, + batch_size=2, + n=2, + fix=PoseUtil.TRANSLATION, + debug=True, + ) diff --git a/utils/pts.py b/utils/pts.py new file mode 100644 index 0000000..66d3d84 --- /dev/null +++ b/utils/pts.py @@ -0,0 +1,117 @@ +import numpy as np +import open3d as o3d +import torch + +class PtsUtil: + + @staticmethod + def voxel_downsample_point_cloud(point_cloud, voxel_size=0.005, require_idx=False): + voxel_indices = np.floor(point_cloud / voxel_size).astype(np.int32) + if require_idx: + _, inverse, counts = np.unique(voxel_indices, axis=0, return_inverse=True, return_counts=True) + idx_sort = np.argsort(inverse) + idx_unique = idx_sort[np.cumsum(counts)-counts] + downsampled_points = point_cloud[idx_unique] + return downsampled_points, idx_unique + else: + unique_voxels = np.unique(voxel_indices, axis=0, return_inverse=True) + return unique_voxels[0]*voxel_size + + @staticmethod + def voxel_downsample_point_cloud_random(point_cloud, voxel_size=0.005, require_idx=False): + voxel_indices = np.floor(point_cloud / voxel_size).astype(np.int32) + unique_voxels, inverse, counts = np.unique(voxel_indices, 
axis=0, return_inverse=True, return_counts=True)
+        idx_sort = np.argsort(inverse)
+        idx_unique = idx_sort[np.cumsum(counts)-counts]
+        downsampled_points = point_cloud[idx_unique]
+        if require_idx:
+            return downsampled_points, inverse
+        return downsampled_points
+
+    @staticmethod
+    def random_downsample_point_cloud(point_cloud, num_points, require_idx=False):
+        if point_cloud.shape[0] == 0:
+            if require_idx:
+                return point_cloud, np.array([])
+            return point_cloud
+        idx = np.random.choice(len(point_cloud), num_points, replace=True)
+        if require_idx:
+            return point_cloud[idx], idx
+        return point_cloud[idx]
+
+    @staticmethod
+    def fps_downsample_point_cloud(point_cloud, num_points, require_idx=False):
+        N = point_cloud.shape[0]
+        sampled_indices = np.zeros(num_points, dtype=int)
+        sampled_indices[0] = np.random.randint(0, N)
+        distances = np.linalg.norm(point_cloud - point_cloud[sampled_indices[0]], axis=1)
+        for i in range(1, num_points):
+            # farthest point sampling: pick the point farthest from the current set, then relax distances
+            farthest_index = np.argmax(distances)
+            sampled_indices[i] = farthest_index
+            new_distances = np.linalg.norm(point_cloud - point_cloud[farthest_index], axis=1)
+            distances = np.minimum(distances, new_distances)
+
+        sampled_points = point_cloud[sampled_indices]
+        if require_idx:
+            return sampled_points, sampled_indices
+        return sampled_points
+
+    @staticmethod
+    def random_downsample_point_cloud_tensor(point_cloud, num_points):
+        idx = torch.randint(0, len(point_cloud), (num_points,))
+        return point_cloud[idx]
+
+    @staticmethod
+    def voxelize_points(points, voxel_size):
+        voxel_indices = np.floor(points / voxel_size).astype(np.int32)
+        unique_voxels = np.unique(voxel_indices, axis=0, return_inverse=True)
+        return unique_voxels
+
+    @staticmethod
+    def transform_point_cloud(points, pose_mat):
+        points_h = np.concatenate([points, np.ones((points.shape[0], 1))], axis=1)
+        points_h = np.dot(pose_mat, points_h.T).T
+        return points_h[:, :3]
+
+    @staticmethod
+    def get_overlapping_points(point_cloud_L, point_cloud_R, voxel_size=0.005, require_idx=False):
+        voxels_L, indices_L = PtsUtil.voxelize_points(point_cloud_L, voxel_size)
+        voxels_R, _ = PtsUtil.voxelize_points(point_cloud_R, voxel_size)
+
+        # view the int32 voxel triples as structured records so np.intersect1d can match whole rows
+        voxel_indices_L = voxels_L.view([("", voxels_L.dtype)] * 3)
+        voxel_indices_R = voxels_R.view([("", voxels_R.dtype)] * 3)
+        overlapping_voxels = np.intersect1d(voxel_indices_L, voxel_indices_R)
+        mask_L = np.isin(
+            indices_L, np.where(np.isin(voxel_indices_L, overlapping_voxels))[0]
+        )
+        overlapping_points = point_cloud_L[mask_L]
+        if require_idx:
+            return overlapping_points, mask_L
+        return overlapping_points
+
+    @staticmethod
+    def filter_points(points, normals, cam_pose, theta_limit=45, z_range=(0.2, 0.45)):
+
+        """ filter with normal """
+        normals_normalized = normals / np.linalg.norm(normals, axis=1, keepdims=True)
+        cos_theta = np.dot(normals_normalized, np.array([0, 0, 1]))
+        theta = np.arccos(cos_theta) * 180 / np.pi
+        idx = theta < theta_limit
+        filtered_sampled_points = points[idx]
+        filtered_normals = normals[idx]
+
+        """ filter with z range """
+        points_cam = PtsUtil.transform_point_cloud(filtered_sampled_points, np.linalg.inv(cam_pose))
+        idx = (points_cam[:, 2] > z_range[0]) & (points_cam[:, 2] < z_range[1])
+        z_filtered_points = filtered_sampled_points[idx]
+        z_filtered_normals = filtered_normals[idx]
+        return z_filtered_points[:, :3], z_filtered_normals
+
+    @staticmethod
+    def point_to_hash(point, voxel_size):
+        return tuple(np.floor(point / voxel_size).astype(int))
\ No newline at end of file
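A quick self-contained check of the voxel utilities above (toy data of my own, not from the repository): two noisy samplings of the same segment should mostly land in shared voxels, and the random voxel downsampler should keep one representative per occupied voxel.

    import numpy as np
    from utils.pts import PtsUtil  # assumes the repo root is on sys.path

    rng = np.random.default_rng(0)
    base = np.linspace([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], 2000)   # 2000 points along a 1 m segment
    cloud_L = base + rng.normal(0, 1e-4, base.shape)
    cloud_R = base + rng.normal(0, 1e-4, base.shape)
    down = PtsUtil.voxel_downsample_point_cloud_random(cloud_L, voxel_size=0.01)
    overlap = PtsUtil.get_overlapping_points(cloud_L, cloud_R, voxel_size=0.01)
    print(down.shape, overlap.shape)  # roughly 100 voxel representatives; nearly all of cloud_L overlaps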
diff --git a/utils/reconstruction.py b/utils/reconstruction.py
new file mode 100644
index 0000000..6645ee9
--- /dev/null
+++ b/utils/reconstruction.py
@@ -0,0 +1,267 @@
+import numpy as np
+from scipy.spatial import cKDTree
+from utils.pts import PtsUtil
+
+class ReconstructionUtil:
+
+    @staticmethod
+    def compute_coverage_rate(target_point_cloud, combined_point_cloud, threshold=0.01):
+        kdtree = cKDTree(combined_point_cloud)
+        distances, _ = kdtree.query(target_point_cloud)
+        covered_points_num = np.sum(distances < threshold*2)
+        coverage_rate = covered_points_num / target_point_cloud.shape[0]
+        return coverage_rate, covered_points_num
+
+    @staticmethod
+    def compute_coverage_rate_with_normal(target_point_cloud, combined_point_cloud, target_normal, combined_normal, threshold=0.01, normal_threshold=0.1):
+        # a target point counts as covered only if it is close to the reconstruction
+        # AND its normal agrees with the matched point's normal
+        kdtree = cKDTree(combined_point_cloud)
+        distances, indices = kdtree.query(target_point_cloud)
+        is_covered_by_distance = distances < threshold*2
+        normal_dots = np.einsum('ij,ij->i', target_normal, combined_normal[indices])
+        is_covered_by_normal = normal_dots > normal_threshold
+        covered_points_num = np.sum(is_covered_by_distance & is_covered_by_normal)
+        coverage_rate = covered_points_num / target_point_cloud.shape[0]
+
+        return coverage_rate, covered_points_num
+
+    @staticmethod
+    def check_overlap(new_point_cloud, combined_point_cloud, overlap_area_threshold=25, voxel_size=0.01, require_new_added_pts_num=False):
+        kdtree = cKDTree(combined_point_cloud)
+        distances, _ = kdtree.query(new_point_cloud)
+        overlapping_points_num = np.sum(distances < voxel_size*2)
+        cm = 0.01
+        voxel_size_cm = voxel_size / cm
+        overlap_area = overlapping_points_num * voxel_size_cm * voxel_size_cm
+        if require_new_added_pts_num:
+            return overlap_area > overlap_area_threshold, len(new_point_cloud)-np.sum(distances < voxel_size*1.2)
+        return overlap_area > overlap_area_threshold
+
+    @staticmethod
+    def get_new_added_points(old_combined_pts, new_pts, threshold=0.005):
+        if old_combined_pts.size == 0:
+            return new_pts
+        if new_pts.size == 0:
+            return np.array([])
+
+        tree = cKDTree(old_combined_pts)
+        distances, _ = tree.query(new_pts, k=1)
+        new_added_points = new_pts[distances > threshold]
+        return new_added_points
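A toy sanity check for the two predicates above (synthetic data of my own): a cloud is fully covered by itself, and a cloud shifted a meter away fails the overlap test.

    import numpy as np
    from utils.reconstruction import ReconstructionUtil  # assumes the repo root is on sys.path

    rng = np.random.default_rng(1)
    target = rng.uniform(0.0, 0.1, (1000, 3))
    rate, covered = ReconstructionUtil.compute_coverage_rate(target, target, threshold=0.005)
    print(rate, covered)   # 1.0 and 1000: every point has a neighbor (itself) within 2 * threshold
    shifted = target + 1.0
    print(ReconstructionUtil.check_overlap(shifted, target, overlap_area_threshold=25, voxel_size=0.005))  # False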
+    @staticmethod
+    def compute_next_best_view_sequence(target_point_cloud, point_cloud_list, scan_points_indices_list, threshold=0.01, overlap_area_threshold=25, init_view=0, scan_points_threshold=5, status_info=None):
+        selected_views = [init_view]
+        combined_point_cloud = point_cloud_list[init_view]
+        history_indices = [scan_points_indices_list[init_view]]
+
+        max_rec_pts = np.vstack(point_cloud_list)
+        downsampled_max_rec_pts = PtsUtil.voxel_downsample_point_cloud(max_rec_pts, threshold)
+        combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(combined_point_cloud, threshold)
+        max_rec_pts_num = downsampled_max_rec_pts.shape[0]
+        max_real_rec_pts_coverage, _ = ReconstructionUtil.compute_coverage_rate(target_point_cloud, downsampled_max_rec_pts, threshold)
+
+        new_coverage, new_covered_num = ReconstructionUtil.compute_coverage_rate(downsampled_max_rec_pts, combined_point_cloud, threshold)
+        current_coverage = new_coverage
+        current_covered_num = new_covered_num
+
+        remaining_views = list(range(len(point_cloud_list)))
+        view_sequence = [(init_view, current_coverage)]
+        cnt_processed_view = 0
+        remaining_views.remove(init_view)
+        curr_rec_pts_num = combined_point_cloud.shape[0]
+        drop_output_ratio = 0.4
+
+        while remaining_views:
+            best_view = None
+            best_coverage_increase = -1
+            best_combined_point_cloud = None
+            best_covered_num = 0
+
+            for view_index in remaining_views:
+                # randomly skip a fraction of candidate views to speed up the greedy search
+                if np.random.rand() < drop_output_ratio:
+                    continue
+                if point_cloud_list[view_index].shape[0] == 0:
+                    continue
+                if selected_views:
+                    new_scan_points_indices = scan_points_indices_list[view_index]
+                    if not ReconstructionUtil.check_scan_points_overlap(history_indices, new_scan_points_indices, scan_points_threshold):
+                        curr_overlap_area_threshold = overlap_area_threshold
+                    else:
+                        curr_overlap_area_threshold = overlap_area_threshold * 0.5
+
+                    if not ReconstructionUtil.check_overlap(point_cloud_list[view_index], combined_point_cloud, overlap_area_threshold=curr_overlap_area_threshold, voxel_size=threshold):
+                        continue
+
+                new_combined_point_cloud = np.vstack([combined_point_cloud, point_cloud_list[view_index]])
+                new_downsampled_combined_point_cloud = PtsUtil.voxel_downsample_point_cloud(new_combined_point_cloud, threshold)
+                new_coverage, new_covered_num = ReconstructionUtil.compute_coverage_rate(downsampled_max_rec_pts, new_downsampled_combined_point_cloud, threshold)
+                coverage_increase = new_coverage - current_coverage
+                if coverage_increase > best_coverage_increase:
+                    best_coverage_increase = coverage_increase
+                    best_view = view_index
+                    best_covered_num = new_covered_num
+                    best_combined_point_cloud = new_downsampled_combined_point_cloud
+
+            if best_view is not None:
+                if best_coverage_increase <= 1e-3 or best_covered_num - current_covered_num <= 5:
+                    break
+
+                selected_views.append(best_view)
+                best_rec_pts_num = best_combined_point_cloud.shape[0]
+                print(f"Current rec pts num: {curr_rec_pts_num}, Best rec pts num: {best_rec_pts_num}, Best cover pts: {best_covered_num}, Max rec pts num: {max_rec_pts_num}")
+                print(f"Current coverage: {current_coverage+best_coverage_increase}, Best coverage increase: {best_coverage_increase}, Max Real coverage: {max_real_rec_pts_coverage}")
+                current_covered_num = best_covered_num
+                curr_rec_pts_num = best_rec_pts_num
+                combined_point_cloud = best_combined_point_cloud
+                remaining_views.remove(best_view)
+                history_indices.append(scan_points_indices_list[best_view])
+                current_coverage += best_coverage_increase
+                cnt_processed_view += 1
+                if status_info is not None:
+                    sm = status_info["status_manager"]
+                    app_name = status_info["app_name"]
+                    runner_name = status_info["runner_name"]
+                    sm.set_status(app_name, runner_name, "current coverage", current_coverage)
+                    sm.set_progress(app_name, runner_name, "processed view", cnt_processed_view, len(point_cloud_list))
+
+                view_sequence.append((best_view, current_coverage))
+
+            else:
+                break
+        if status_info is not None:
+            sm = status_info["status_manager"]
+            app_name = status_info["app_name"]
+            runner_name = status_info["runner_name"]
+            sm.set_progress(app_name, runner_name, "processed view", len(point_cloud_list), len(point_cloud_list))
+        return view_sequence, remaining_views, combined_point_cloud
+
+    @staticmethod
+    def compute_next_best_view_sequence_with_normal(target_point_cloud, target_normal, point_cloud_list, normal_list, scan_points_indices_list, threshold=0.01, overlap_area_threshold=25, init_view=0, scan_points_threshold=5, status_info=None):
+        selected_views = [init_view]
+        combined_point_cloud = point_cloud_list[init_view]
+        combined_normal = normal_list[init_view]
+        history_indices = [scan_points_indices_list[init_view]]
+
+        max_rec_pts = np.vstack(point_cloud_list)
+        max_rec_nrm = np.vstack(normal_list)
+        downsampled_max_rec_pts, idx = PtsUtil.voxel_downsample_point_cloud(max_rec_pts, threshold, require_idx=True)
+        downsampled_max_rec_nrm = max_rec_nrm[idx]
+        max_rec_pts_num = downsampled_max_rec_pts.shape[0]
+        max_real_rec_pts_coverage, _ = ReconstructionUtil.compute_coverage_rate_with_normal(target_point_cloud, downsampled_max_rec_pts, target_normal, downsampled_max_rec_nrm, threshold)
+
+        new_coverage, new_covered_num = ReconstructionUtil.compute_coverage_rate_with_normal(downsampled_max_rec_pts, combined_point_cloud, downsampled_max_rec_nrm, combined_normal, threshold)
+        current_coverage = new_coverage
+        current_covered_num = new_covered_num
+
+        remaining_views = list(range(len(point_cloud_list)))
+        view_sequence = [(init_view, current_coverage)]
+        cnt_processed_view = 0
+        remaining_views.remove(init_view)
+        curr_rec_pts_num = combined_point_cloud.shape[0]
+
+        # same greedy loop as above, but a point only counts as covered if its normal also agrees
+        while remaining_views:
+            best_view = None
+            best_coverage_increase = -1
+            best_combined_point_cloud = None
+            best_combined_normal = None
+            best_covered_num = 0
+
+            for view_index in remaining_views:
+                if point_cloud_list[view_index].shape[0] == 0:
+                    continue
+                if selected_views:
+                    new_scan_points_indices = scan_points_indices_list[view_index]
+                    if not ReconstructionUtil.check_scan_points_overlap(history_indices, new_scan_points_indices, scan_points_threshold):
+                        curr_overlap_area_threshold = overlap_area_threshold
+                    else:
+                        curr_overlap_area_threshold = overlap_area_threshold * 0.5
+
+                    if not ReconstructionUtil.check_overlap(point_cloud_list[view_index], combined_point_cloud, overlap_area_threshold=curr_overlap_area_threshold, voxel_size=threshold):
+                        continue
+
+                new_combined_point_cloud = np.vstack([combined_point_cloud, point_cloud_list[view_index]])
+                new_combined_normal = np.vstack([combined_normal, normal_list[view_index]])
+                new_downsampled_combined_point_cloud, idx = PtsUtil.voxel_downsample_point_cloud(new_combined_point_cloud, threshold, require_idx=True)
+                new_downsampled_combined_normal = new_combined_normal[idx]
+                new_coverage, new_covered_num = ReconstructionUtil.compute_coverage_rate_with_normal(downsampled_max_rec_pts, new_downsampled_combined_point_cloud, downsampled_max_rec_nrm, new_downsampled_combined_normal, threshold)
+                coverage_increase = new_coverage - current_coverage
+                if coverage_increase > best_coverage_increase:
+                    best_coverage_increase = coverage_increase
+                    best_view = view_index
+                    best_covered_num = new_covered_num
+                    best_combined_point_cloud = new_downsampled_combined_point_cloud
+                    best_combined_normal = new_downsampled_combined_normal
+
+            if best_view is not None:
+                if best_coverage_increase <= 1e-3 or best_covered_num - current_covered_num <= 5:
+                    break
+
+                selected_views.append(best_view)
+                best_rec_pts_num = best_combined_point_cloud.shape[0]
+                print(f"Current rec pts num: {curr_rec_pts_num}, Best rec pts num: {best_rec_pts_num}, Best cover pts: {best_covered_num}, Max rec pts num: {max_rec_pts_num}")
+                print(f"Current coverage: {current_coverage}, Best coverage increase: {best_coverage_increase}, Max Real coverage: {max_real_rec_pts_coverage}")
+                current_covered_num = best_covered_num
+                curr_rec_pts_num = best_rec_pts_num
+                combined_point_cloud = best_combined_point_cloud
+                combined_normal = best_combined_normal
+
remaining_views.remove(best_view) + history_indices.append(scan_points_indices_list[best_view]) + current_coverage += best_coverage_increase + cnt_processed_view += 1 + if status_info is not None: + sm = status_info["status_manager"] + app_name = status_info["app_name"] + runner_name = status_info["runner_name"] + sm.set_status(app_name, runner_name, "current coverage", current_coverage) + sm.set_progress(app_name, runner_name, "processed view", cnt_processed_view, len(point_cloud_list)) + + view_sequence.append((best_view, current_coverage)) + + else: + break + if status_info is not None: + sm = status_info["status_manager"] + app_name = status_info["app_name"] + runner_name = status_info["runner_name"] + sm.set_progress(app_name, runner_name, "processed view", len(point_cloud_list), len(point_cloud_list)) + return view_sequence, remaining_views, combined_point_cloud + + + @staticmethod + def generate_scan_points(display_table_top, display_table_radius, min_distance=0.03, max_points_num = 500, max_attempts = 1000): + points = [] + attempts = 0 + while len(points) < max_points_num and attempts < max_attempts: + angle = np.random.uniform(0, 2 * np.pi) + r = np.random.uniform(0, display_table_radius) + x = r * np.cos(angle) + y = r * np.sin(angle) + z = display_table_top + new_point = (x, y, z) + if all(np.linalg.norm(np.array(new_point) - np.array(existing_point)) >= min_distance for existing_point in points): + points.append(new_point) + attempts += 1 + return points + + @staticmethod + def check_scan_points_overlap(history_indices, indices2, threshold=5): + for indices1 in history_indices: + if len(set(indices1).intersection(set(indices2))) >= threshold: + return True + return False + + \ No newline at end of file diff --git a/utils/render.py b/utils/render.py new file mode 100644 index 0000000..a47e171 --- /dev/null +++ b/utils/render.py @@ -0,0 +1,136 @@ + +import os +import json +import time +import subprocess +import tempfile +import shutil +import numpy as np +from utils.data_load import DataLoadUtil +from utils.reconstruction import ReconstructionUtil +from utils.pts import PtsUtil +class RenderUtil: + target_mask_label = (0, 255, 0) + display_table_mask_label = (0, 0, 255) + random_downsample_N = 32768 + min_z = 0.2 + max_z = 0.5 + + @staticmethod + def get_world_points_and_normal(depth, mask, normal, cam_intrinsic, cam_extrinsic, random_downsample_N): + z = depth[mask] + i, j = np.nonzero(mask) + x = (j - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0] + y = (i - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1] + + points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3) + normal_camera = normal[mask].reshape(-1, 3) + sampled_target_points, idx = PtsUtil.random_downsample_point_cloud( + points_camera, random_downsample_N, require_idx=True + ) + if len(sampled_target_points) == 0: + return np.zeros((0, 3)), np.zeros((0, 3)) + sampled_normal_camera = normal_camera[idx] + + points_camera_aug = np.concatenate((sampled_target_points, np.ones((sampled_target_points.shape[0], 1))), axis=-1) + points_camera_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3] + + return points_camera_world, sampled_normal_camera + + @staticmethod + def get_world_points(depth, mask, cam_intrinsic, cam_extrinsic, random_downsample_N): + z = depth[mask] + i, j = np.nonzero(mask) + x = (j - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0] + y = (i - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1] + + points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3) + sampled_target_points = 
PtsUtil.random_downsample_point_cloud(
+            points_camera, random_downsample_N
+        )
+        points_camera_aug = np.concatenate((sampled_target_points, np.ones((sampled_target_points.shape[0], 1))), axis=-1)
+        points_camera_world = np.dot(cam_extrinsic, points_camera_aug.T).T[:, :3]
+
+        return points_camera_world
+
+    @staticmethod
+    def get_scan_points_indices(scan_points, mask, display_table_mask_label, cam_intrinsic, cam_extrinsic):
+        # project world-space scan points into the image and keep those landing on display-table pixels
+        scan_points_homogeneous = np.hstack((scan_points, np.ones((scan_points.shape[0], 1))))
+        points_camera = np.dot(np.linalg.inv(cam_extrinsic), scan_points_homogeneous.T).T[:, :3]
+        points_image_homogeneous = np.dot(cam_intrinsic, points_camera.T).T
+        points_image_homogeneous /= points_image_homogeneous[:, 2:]
+        pixel_x = points_image_homogeneous[:, 0].astype(int)
+        pixel_y = points_image_homogeneous[:, 1].astype(int)
+        h, w = mask.shape[:2]
+        valid_indices = (pixel_x >= 0) & (pixel_x < w) & (pixel_y >= 0) & (pixel_y < h)
+        mask_colors = mask[pixel_y[valid_indices], pixel_x[valid_indices]]
+        selected_points_indices = np.where((mask_colors == display_table_mask_label).all(axis=-1))[0]
+        selected_points_indices = np.where(valid_indices)[0][selected_points_indices]
+        return selected_points_indices
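A toy projection check for get_scan_points_indices (the intrinsics and mask here are invented for illustration): two points placed 1 m in front of an identity camera should project inside a 640x480 image onto table-labeled pixels.

    import numpy as np
    from utils.render import RenderUtil  # assumes the repo root is on sys.path

    K = np.array([[500.0, 0.0, 320.0],
                  [0.0, 500.0, 240.0],
                  [0.0, 0.0, 1.0]])                      # invented pinhole intrinsics
    cam_to_world = np.eye(4)                             # camera at the world origin
    table_pts = np.array([[0.0, 0.0, 1.0], [0.1, 0.0, 1.0]])
    mask = np.full((480, 640, 3), RenderUtil.display_table_mask_label, dtype=np.uint8)  # every pixel labeled "table"
    idx = RenderUtil.get_scan_points_indices(table_pts, mask, RenderUtil.display_table_mask_label, K, cam_to_world)
    print(idx)  # [0 1]: the points land at pixels (320, 240) and (370, 240)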
+    @staticmethod
+    def render_pts(cam_pose, scene_path, script_path, scan_points, voxel_threshold=0.005, filter_degree=75, nO_to_nL_pose=None, require_full_scene=False):
+        nO_to_world_pose = DataLoadUtil.get_real_cam_O_from_cam_L(cam_pose, nO_to_nL_pose, scene_path=scene_path)
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            params = {
+                "cam_pose": nO_to_world_pose.tolist(),
+                "scene_path": scene_path
+            }
+            scene_info_path = os.path.join(scene_path, "scene_info.json")
+            shutil.copy(scene_info_path, os.path.join(temp_dir, "scene_info.json"))
+            params_data_path = os.path.join(temp_dir, "params.json")
+            with open(params_data_path, 'w') as f:
+                json.dump(params, f)
+            # NOTE: hardcoded Blender binary path; the render script writes its outputs into temp_dir
+            result = subprocess.run([
+                '/home/hofee/blender-4.0.2-linux-x64/blender', '-b', '-P', script_path, '--', temp_dir
+            ], capture_output=True, text=True)
+            path = os.path.join(temp_dir, "tmp")
+            cam_info = DataLoadUtil.load_cam_info(path, binocular=True)
+            depth_L, depth_R = DataLoadUtil.load_depth(
+                path, cam_info["near_plane"],
+                cam_info["far_plane"],
+                binocular=True
+            )
+            mask_L, mask_R = DataLoadUtil.load_seg(path, binocular=True)
+            normal_L = DataLoadUtil.load_normal(path, binocular=True, left_only=True)
+            ''' target points '''
+            mask_img_L = mask_L
+            mask_img_R = mask_R
+
+            target_mask_img_L = (mask_L == RenderUtil.target_mask_label).all(axis=-1)
+            target_mask_img_R = (mask_R == RenderUtil.target_mask_label).all(axis=-1)
+
+            sampled_target_points_L, sampled_target_normal_L = RenderUtil.get_world_points_and_normal(depth_L, target_mask_img_L, normal_L, cam_info["cam_intrinsic"], cam_info["cam_to_world"], RenderUtil.random_downsample_N)
+            sampled_target_points_R = RenderUtil.get_world_points(depth_R, target_mask_img_R, cam_info["cam_intrinsic"], cam_info["cam_to_world_R"], RenderUtil.random_downsample_N)
+
+            has_points = sampled_target_points_L.shape[0] > 0 and sampled_target_points_R.shape[0] > 0
+            if has_points:
+                target_points, overlap_idx = PtsUtil.get_overlapping_points(
+                    sampled_target_points_L, sampled_target_points_R, voxel_threshold, require_idx=True
+                )
+                sampled_target_normal_L = sampled_target_normal_L[overlap_idx]
+
+            if has_points:
+                has_points = target_points.shape[0] > 0
+
+            if has_points:
+                target_points, target_normals = PtsUtil.filter_points(
+                    target_points, sampled_target_normal_L, cam_info["cam_to_world"], theta_limit=filter_degree, z_range=(RenderUtil.min_z, RenderUtil.max_z)
+                )
+
+            scan_points_indices_L = RenderUtil.get_scan_points_indices(scan_points, mask_img_L, RenderUtil.display_table_mask_label, cam_info["cam_intrinsic"], cam_info["cam_to_world"])
+            scan_points_indices_R = RenderUtil.get_scan_points_indices(scan_points, mask_img_R, RenderUtil.display_table_mask_label, cam_info["cam_intrinsic"], cam_info["cam_to_world_R"])
+            scan_points_indices = np.intersect1d(scan_points_indices_L, scan_points_indices_R)
+            if not has_points:
+                target_points = np.zeros((0, 3))
+                target_normals = np.zeros((0, 3))
+            return target_points, target_normals, scan_points_indices
\ No newline at end of file
diff --git a/utils/vis.py b/utils/vis.py
new file mode 100644
index 0000000..d739b7d
--- /dev/null
+++ b/utils/vis.py
@@ -0,0 +1,208 @@
+import numpy as np
+import matplotlib.pyplot as plt
+import sys
+import os
+import trimesh
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from utils.data_load import DataLoadUtil
+from utils.pts import PtsUtil
+from utils.pose import PoseUtil
+
+class visualizeUtil:
+
+    @staticmethod
+    def save_all_cam_pos_and_cam_axis(root, scene, output_dir):
+        length = DataLoadUtil.get_scene_seq_length(root, scene)
+        all_cam_pos = []
+        all_cam_axis = []
+        for i in range(length):
+            path = DataLoadUtil.get_path(root, scene, i)
+            cam_info = DataLoadUtil.load_cam_info(path, binocular=True)
+            cam_pose = cam_info["cam_to_world"]
+            cam_pos = cam_pose[:3, 3]
+            cam_axis = cam_pose[:3, 2]
+
+            num_samples = 10
+            sample_points = [cam_pos + 0.02*t * cam_axis for t in range(num_samples)]
+            sample_points = np.array(sample_points)
+
+            all_cam_pos.append(cam_pos)
+            all_cam_axis.append(sample_points)
+
+        all_cam_pos = np.array(all_cam_pos)
+        all_cam_axis = np.array(all_cam_axis).reshape(-1, 3)
+        np.savetxt(os.path.join(output_dir, "all_cam_pos.txt"), all_cam_pos)
+        np.savetxt(os.path.join(output_dir, "all_cam_axis.txt"), all_cam_axis)
+
+    @staticmethod
+    def get_cam_pose_and_cam_axis(cam_pose, is_6d_pose):
+        if is_6d_pose:
+            matrix_cam_pose = np.eye(4)
+            matrix_cam_pose[:3,:3] = PoseUtil.rotation_6d_to_matrix_numpy(cam_pose[:6])
+            matrix_cam_pose[:3, 3] = cam_pose[6:]
+        else:
+            matrix_cam_pose = cam_pose
+        cam_pos = matrix_cam_pose[:3, 3]
+        cam_axis = matrix_cam_pose[:3, 2]
+        num_samples = 10
+        sample_points = [cam_pos + 0.02*t * cam_axis for t in range(num_samples)]
+        sample_points = np.array(sample_points)
+        return cam_pos, sample_points
+
+    @staticmethod
+    def save_all_combined_pts(root, scene, output_dir):
+        length = DataLoadUtil.get_scene_seq_length(root, scene)
+        all_combined_pts = []
+        for i in range(length):
+            path = DataLoadUtil.get_path(root, scene, i)
+            pts = DataLoadUtil.load_from_preprocessed_pts(path, "npy")
+            if pts.shape[0] == 0:
+                continue
+            all_combined_pts.append(pts)
+        all_combined_pts = np.vstack(all_combined_pts)
+        downsampled_all_pts = PtsUtil.voxel_downsample_point_cloud(all_combined_pts, 0.001)
+        np.savetxt(os.path.join(output_dir, "all_combined_pts.txt"), downsampled_all_pts)
+
+    @staticmethod
+    def save_seq_cam_pos_and_cam_axis(root, scene, frame_idx_list, output_dir):
+        all_cam_pos = []
+        all_cam_axis = []
+        for i in frame_idx_list:
+            path = DataLoadUtil.get_path(root, scene, i)
+            cam_info = DataLoadUtil.load_cam_info(path, binocular=True)
+            cam_pose = cam_info["cam_to_world"]
+            cam_pos = cam_pose[:3,
3] + cam_axis = cam_pose[:3, 2] + + num_samples = 10 + sample_points = [cam_pos + 0.02*t * cam_axis for t in range(num_samples)] + sample_points = np.array(sample_points) + + all_cam_pos.append(cam_pos) + all_cam_axis.append(sample_points) + + all_cam_pos = np.array(all_cam_pos) + all_cam_axis = np.array(all_cam_axis).reshape(-1, 3) + np.savetxt(os.path.join(output_dir, "seq_cam_pos.txt"), all_cam_pos) + np.savetxt(os.path.join(output_dir, "seq_cam_axis.txt"), all_cam_axis) + + @staticmethod + def save_seq_combined_pts(root, scene, frame_idx_list, output_dir): + all_combined_pts = [] + for i in frame_idx_list: + path = DataLoadUtil.get_path(root, scene, i) + pts = DataLoadUtil.load_from_preprocessed_pts(path,"npy") + if pts.shape[0] == 0: + continue + all_combined_pts.append(pts) + all_combined_pts = np.vstack(all_combined_pts) + downsampled_all_pts = PtsUtil.voxel_downsample_point_cloud(all_combined_pts, 0.001) + np.savetxt(os.path.join(output_dir, "seq_combined_pts.txt"), downsampled_all_pts) + + @staticmethod + def save_target_mesh_at_world_space( + root, model_dir, scene_name, display_table_as_world_space_origin=True + ): + scene_info = DataLoadUtil.load_scene_info(root, scene_name) + target_name = scene_info["target_name"] + transformation = scene_info[target_name] + if display_table_as_world_space_origin: + location = transformation["location"] - DataLoadUtil.get_display_table_top( + root, scene_name + ) + else: + location = transformation["location"] + rotation_euler = transformation["rotation_euler"] + pose_mat = trimesh.transformations.euler_matrix(*rotation_euler) + pose_mat[:3, 3] = location + + mesh = DataLoadUtil.load_mesh_at(model_dir, target_name, pose_mat) + mesh_dir = os.path.join(root, scene_name, "mesh") + if not os.path.exists(mesh_dir): + os.makedirs(mesh_dir) + model_path = os.path.join(mesh_dir, "world_target_mesh.obj") + mesh.export(model_path) + + @staticmethod + def save_points_and_normals(root, scene, frame_idx, output_dir, binocular=False): + target_mask_label = (0, 255, 0, 255) + path = DataLoadUtil.get_path(root, scene, frame_idx) + cam_info = DataLoadUtil.load_cam_info(path, binocular=binocular, display_table_as_world_space_origin=False) + depth = DataLoadUtil.load_depth( + path, cam_info["near_plane"], + cam_info["far_plane"], + binocular=binocular, + ) + if isinstance(depth, tuple): + depth = depth[0] + + mask = DataLoadUtil.load_seg(path, binocular=binocular, left_only=True) + normal = DataLoadUtil.load_normal(path, binocular=binocular, left_only=True) + ''' target points ''' + if mask is None: + target_mask_img = np.ones_like(depth, dtype=bool) + else: + target_mask_img = (mask == target_mask_label).all(axis=-1) + cam_intrinsic = cam_info["cam_intrinsic"] + z = depth[target_mask_img] + i, j = np.nonzero(target_mask_img) + x = (j - cam_intrinsic[0, 2]) * z / cam_intrinsic[0, 0] + y = (i - cam_intrinsic[1, 2]) * z / cam_intrinsic[1, 1] + + random_downsample_N = 1000 + + points_camera = np.stack((x, y, z), axis=-1).reshape(-1, 3) + normal_camera = normal[target_mask_img].reshape(-1, 3) + sampled_target_points, idx = PtsUtil.random_downsample_point_cloud( + points_camera, random_downsample_N, require_idx=True + ) + if len(sampled_target_points) == 0: + print("No target points") + + + sampled_normal_camera = normal_camera[idx] + sampled_visualized_normal = [] + sampled_normal_camera[:, 2] = -sampled_normal_camera[:, 2] + sampled_normal_camera[:, 1] = -sampled_normal_camera[:, 1] + num_samples = 10 + for i in range(len(sampled_target_points)): + 
sampled_visualized_normal.append([sampled_target_points[i] + 0.02*t * sampled_normal_camera[i] for t in range(num_samples)]) + + sampled_visualized_normal = np.array(sampled_visualized_normal).reshape(-1, 3) + np.savetxt(os.path.join(output_dir, "target_pts.txt"), sampled_target_points) + np.savetxt(os.path.join(output_dir, "target_normal.txt"), sampled_visualized_normal) + + @staticmethod + def save_pts_nrm(root, scene, frame_idx, output_dir, binocular=False): + path = DataLoadUtil.get_path(root, scene, frame_idx) + pts_world = DataLoadUtil.load_from_preprocessed_pts(path, "npy") + nrm_camera = DataLoadUtil.load_from_preprocessed_nrm(path, "npy") + cam_info = DataLoadUtil.load_cam_info(path, binocular=binocular) + cam_to_world = cam_info["cam_to_world"] + nrm_world = nrm_camera @ cam_to_world[:3, :3].T + visualized_nrm = [] + num_samples = 10 + for i in range(len(pts_world)): + for t in range(num_samples): + visualized_nrm.append(pts_world[i] - 0.02 * t * nrm_world[i]) + + visualized_nrm = np.array(visualized_nrm) + np.savetxt(os.path.join(output_dir, "nrm.txt"), visualized_nrm) + np.savetxt(os.path.join(output_dir, "pts.txt"), pts_world) + +# ------ Debug ------ + +if __name__ == "__main__": + root = r"C:\Document\Local Project\nbv_rec\nbv_reconstruction\temp" + model_dir = r"H:\\AI\\Datasets\\scaled_object_box_meshes" + scene = "box" + output_dir = r"C:\Document\Local Project\nbv_rec\nbv_reconstruction\test" + + #visualizeUtil.save_all_cam_pos_and_cam_axis(root, scene, output_dir) + # visualizeUtil.save_all_combined_pts(root, scene, output_dir) + # visualizeUtil.save_seq_combined_pts(root, scene, [0, 121, 286, 175, 111,366,45,230,232,225,255,17,199,78,60], output_dir) + # visualizeUtil.save_seq_cam_pos_and_cam_axis(root, scene, [0, 121, 286, 175, 111,366,45,230,232,225,255,17,199,78,60], output_dir) + # visualizeUtil.save_target_mesh_at_world_space(root, model_dir, scene) + #visualizeUtil.save_points_and_normals(root, scene,"10", output_dir, binocular=True) + visualizeUtil.save_pts_nrm(root, scene, "116", output_dir, binocular=True)
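For completeness, a hedged end-to-end sketch of the greedy selector on synthetic views (all data here is illustrative, not from the dataset): four overlapping partial scans of a random cloud are fed to compute_next_best_view_sequence. Note the selector randomly skips candidates via drop_output_ratio, so the chosen sequence varies between runs.

    import numpy as np
    from utils.reconstruction import ReconstructionUtil

    rng = np.random.default_rng(2)
    target = rng.uniform(0.0, 0.2, (5000, 3))                  # stand-in for a target surface
    views = [target[i::4] + rng.normal(0, 1e-4, target[i::4].shape) for i in range(4)]
    scan_idx = [list(range(10)) for _ in views]                # dummy scan-point indices per view
    seq, remaining, combined = ReconstructionUtil.compute_next_best_view_sequence(
        target, views, scan_idx, threshold=0.005, overlap_area_threshold=0, init_view=0)
    print(seq)  # [(0, cov_0), (view, cov_1), ...] with non-decreasing coverage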