add split dataset
commit 2fcfcd1966
parent f58360c0c0
app_split.py (new file)
@@ -0,0 +1,9 @@
+from PytorchBoot.application import PytorchBootApplication
+from runners.data_splitor import DataSplitor
+
+
+@PytorchBootApplication("split")
+class DataSplitApp:
+    @staticmethod
+    def start():
+        DataSplitor(r"configs\split_dataset_config.yaml").run()
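The decorator registers this entry point under the application name "split". How PytorchBoot launches registered applications is not shown in this diff, but the static method can be driven directly; a minimal sketch, assuming the working directory is the repository root so the relative config path resolves:

    # Minimal sketch: drive the registered entry point by hand.
    # Assumes the current working directory is the repo root, so that
    # r"configs\split_dataset_config.yaml" resolves (note the Windows-style path).
    from app_split import DataSplitApp

    DataSplitApp.start()  # builds DataSplitor from the YAML config and runs it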
configs/split_dataset_config.yaml (new file)
@@ -0,0 +1,22 @@
+
+runner:
+  general:
+    seed: 0
+    device: cpu
+    cuda_visible_devices: "0,1,2,3,4,5,6,7"
+
+  experiment:
+    name: debug
+    root_dir: "experiments"
+
+  split:
+    root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample"
+    type: "unseen_instance" # "unseen_category"
+    datasets:
+      OmniObject3d_train:
+        path: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt"
+        ratio: 0.5
+
+      OmniObject3d_test:
+        path: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_test.txt"
+        ratio: 0.5
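With both ratios at 0.5, the shuffled scene listing is cut into two equal halves. Because the splitter (runners/data_splitor.py below) sizes each slice with int() truncation, ratios that do not sum to 1.0 silently leave a tail of scenes unassigned. A hedged sanity-check sketch, using PyYAML (an assumption; the repo itself loads config through ConfigManager):

    # Hypothetical sanity check: warn when split ratios don't cover the dataset.
    # Assumes PyYAML is installed; not part of the committed code.
    import yaml

    with open("configs/split_dataset_config.yaml") as f:
        cfg = yaml.safe_load(f)

    ratios = {name: d["ratio"] for name, d in cfg["runner"]["split"]["datasets"].items()}
    total = sum(ratios.values())
    print(ratios, "sum =", total)
    if abs(total - 1.0) > 1e-9:
        print("warning: ratios leave part of the dataset unassigned (or overlap)")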
@@ -11,10 +11,14 @@ runner:
 
   train:
     dataset_list:
-      - OmniObject3d
+      - OmniObject3d_train
 
   datasets:
-    OmniObject3d:
-      root_dir: "/media/hofee/data/data/nbv_rec/sample"
+    OmniObject3d_train:
+      root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample"
+      split_file: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt"
+
+    OmniObject3d_test:
+      root_dir: "C:\\Document\\Local Project\\nbv_rec\\data\\sample"
+      split_file: "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_test.txt"
 
@@ -1,10 +1,9 @@
-import os
 import numpy as np
 from PytorchBoot.dataset import BaseDataset
 import PytorchBoot.stereotype as stereotype
 
 import sys
-sys.path.append(r"/media/hofee/data/project/python/nbv_reconstruction/nbv_reconstruction")
+sys.path.append(r"C:\Document\Local Project\nbv_rec\nbv_reconstruction")
 
 from utils.data_load import DataLoadUtil
 from utils.pose import PoseUtil
@@ -16,13 +15,22 @@ class NBVReconstructionDataset(BaseDataset):
         super(NBVReconstructionDataset, self).__init__(config)
         self.config = config
         self.root_dir = config["root_dir"]
+        self.split_file_path = config["split_file"]
+        self.scene_name_list = self.load_scene_name_list()
         self.datalist = self.get_datalist()
         self.pts_num = 1024
 
+    def load_scene_name_list(self):
+        scene_name_list = []
+        with open(self.split_file_path, "r") as f:
+            for line in f:
+                scene_name = line.strip()
+                scene_name_list.append(scene_name)
+        return scene_name_list
+
     def get_datalist(self):
         datalist = []
-        scene_name_list = os.listdir(self.root_dir)
-        for scene_name in scene_name_list:
+        for scene_name in self.scene_name_list:
             label_path = DataLoadUtil.get_label_path(self.root_dir, scene_name)
             label_data = DataLoadUtil.load_label(label_path)
             for data_pair in label_data["data_pairs"]:
@@ -97,8 +105,12 @@ class NBVReconstructionDataset(BaseDataset):
 
 if __name__ == "__main__":
     import torch
+    seed = 0
+    torch.manual_seed(seed)
+    np.random.seed(seed)
     config = {
-        "root_dir": "/media/hofee/data/data/nbv_rec/sample",
+        "root_dir": "C:\\Document\\Local Project\\nbv_rec\\data\\sample",
+        "split_file": "C:\\Document\\Local Project\\nbv_rec\\data\\OmniObject3d_train.txt",
         "ratio": 0.05,
         "batch_size": 1,
         "num_workers": 0,
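save_split_files (below) joins scene names with newlines and load_scene_name_list strips each line back out, so a split file is simply one scene directory name per line. A round-trip sketch of the format, with made-up scene names:

    # Round-trip sketch of the split-file format (scene names are hypothetical).
    names = ["scene_0001", "scene_0002", "scene_0003"]

    with open("OmniObject3d_train.txt", "w") as f:
        f.write("\n".join(names))              # what save_split_files writes

    with open("OmniObject3d_train.txt", "r") as f:
        loaded = [line.strip() for line in f]  # what load_scene_name_list reads

    assert loaded == names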
runners/data_splitor.py (new file)
@@ -0,0 +1,55 @@
+import os
+import random
+from PytorchBoot.runners.runner import Runner
+from PytorchBoot.config import ConfigManager
+from PytorchBoot.utils import Log
+import PytorchBoot.stereotype as stereotype
+
+
+@stereotype.runner("data_splitor", comment="unfinished")
+class DataSplitor(Runner):
+    def __init__(self, config):
+        super().__init__(config)
+        self.load_experiment("data_split")
+        self.root_dir = ConfigManager.get("runner", "split", "root_dir")
+        self.type = ConfigManager.get("runner", "split", "type")
+        self.datasets = ConfigManager.get("runner", "split", "datasets")
+        self.datapath_list = self.load_all_datapath()
+
+    def run(self):
+        self.split_dataset()
+
+    def split_dataset(self):
+
+        random.shuffle(self.datapath_list)
+        start_idx = 0
+        for dataset in self.datasets:
+            ratio = self.datasets[dataset]["ratio"]
+            path = self.datasets[dataset]["path"]
+            split_size = int(len(self.datapath_list) * ratio)
+            split_files = self.datapath_list[start_idx:start_idx + split_size]
+            start_idx += split_size
+            self.save_split_files(path, split_files)
+            Log.success(f"save {dataset} split files to {path}")
+
+    def save_split_files(self, path, split_files):
+        os.makedirs(os.path.dirname(path), exist_ok=True)
+        with open(path, "w") as f:
+            f.write("\n".join(split_files))
+
+    def load_all_datapath(self):
+        return os.listdir(self.root_dir)
+
+    def create_experiment(self, backup_name=None):
+        super().create_experiment(backup_name)
+
+    def load_experiment(self, backup_name=None):
+        super().load_experiment(backup_name)
+
+
+
+
+
+
+