nbv_grasping/runners/preprocessor.py
2024-10-09 16:13:22 +00:00

72 lines
2.2 KiB
Python
Executable File

import os
from abc import ABC, abstractmethod
import shutil
from configs.config import ConfigManager
from runners.runner import Runner
class Preprocessor(Runner, ABC):
    """Abstract runner that prepares the on-disk data layout for preprocessing.

    On top of ``Runner``'s experiment handling, it maintains a ``data``
    directory under the experiment path containing one
    ``<source>/<data_type>`` sub-directory per entry of the ``dataset_list``
    preprocess setting. Concrete subclasses supply the hooks declared as
    abstract methods below.
    """

    # Name of the sub-directory of the experiment path that holds the data.
    DATA = "data"

    def __init__(self, config_path):
        """Initialise the base runner and cache the preprocess settings.

        Args:
            config_path: Forwarded unchanged to ``Runner.__init__``.
        """
        super().__init__(config_path)
        self.preprocess_config = ConfigManager.get("settings", "preprocess")

    def load_experiment(self, backup_name=None):
        """Load an existing experiment and make sure its data tree exists.

        When the ``keep_exists`` experiment setting is falsy, the ``data``
        directory is deleted and recreated empty. Otherwise its content is
        kept and the directory is only created if it is missing. The
        per-dataset directories are created in both cases.

        Args:
            backup_name: Forwarded unchanged to ``Runner.load_experiment``.

        Raises:
            FileExistsError: If ``keep_exists`` is falsy and the old ``data``
                directory could not be removed.
        """
        super().load_experiment(backup_name)
        data_dir = os.path.join(str(self.experiment_path), Preprocessor.DATA)
        if self.experiments_config["keep_exists"]:
            # Keep whatever is there; just guarantee the directory exists,
            # even when the dataset list turns out to be empty.
            os.makedirs(data_dir, exist_ok=True)
        else:
            shutil.rmtree(data_dir, ignore_errors=True)
            # No exist_ok here: rmtree ignores errors, so a failed wipe
            # surfaces as FileExistsError instead of silently reusing
            # the old data.
            os.makedirs(data_dir)
        self.create_dataset_list()

    def create_experiment(self, backup_name=None):
        """Create a new experiment together with its ``data`` directory tree.

        Args:
            backup_name: Forwarded unchanged to ``Runner.create_experiment``.

        Raises:
            FileExistsError: If the ``data`` directory already exists.
        """
        super().create_experiment(backup_name)
        data_dir = os.path.join(str(self.experiment_path), Preprocessor.DATA)
        os.makedirs(data_dir)
        self.create_dataset_list()

    def create_dataset_list(self):
        """Create ``data/<source>/<data_type>`` for every configured dataset.

        Each entry of the ``dataset_list`` preprocess setting must provide the
        keys ``"source"`` and ``"data_type"``. Directories that already exist
        are left untouched.

        Raises:
            KeyError: If ``dataset_list`` or one of the per-entry keys is
                missing.
        """
        data_root = os.path.join(str(self.experiment_path), Preprocessor.DATA)
        for dataset in self.preprocess_config["dataset_list"]:
            dataset_dir = os.path.join(
                data_root, dataset["source"], dataset["data_type"]
            )
            # makedirs creates the intermediate <source> directory as well,
            # and exist_ok=True avoids a check-then-create race.
            os.makedirs(dataset_dir, exist_ok=True)

    @abstractmethod
    def get_dataloader(self, dataset_config):
        """Subclass hook taking one dataset configuration.

        NOTE(review): by its name this should build the data loader for
        ``dataset_config``; argument and return types are not defined in
        this file -- confirm against the concrete subclasses.
        """

    @abstractmethod
    def get_model(self, model_config):
        """Subclass hook taking a model configuration.

        NOTE(review): by its name this should build or load the model
        described by ``model_config``; the exact contract is not defined in
        this file -- confirm against the concrete subclasses.
        """

    @abstractmethod
    def prediction(self, model, dataloader):
        """Subclass hook taking a model and a data loader.

        NOTE(review): presumably runs ``model`` over ``dataloader`` and
        returns the predictions; the return format is not defined in this
        file -- confirm against the concrete subclasses.
        """

    @abstractmethod
    def preprocess(self, predicted_data):
        """Subclass hook taking predicted data.

        NOTE(review): presumably turns the output of ``prediction`` into the
        processed data handed to ``save_processed_data``; not defined in
        this file -- confirm against the concrete subclasses.
        """

    @abstractmethod
    def save_processed_data(self, processed_data, data_config=None):
        """Subclass hook taking processed data and an optional data config.

        NOTE(review): presumably persists ``processed_data``, with
        ``data_config`` selecting the target dataset; not defined in this
        file -- confirm against the concrete subclasses.
        """
if __name__ == "__main__":
    # Script entry point: read the config path from the command line and
    # construct the preprocessor from it.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        type=str,
        default="../configs/local_gsnet_preprocess_config.yaml",
    )
    config_path = cli.parse_args().config
    # NOTE(review): Preprocessor declares abstract methods, so this
    # instantiation is expected to raise TypeError -- confirm whether a
    # concrete subclass was meant to be constructed here.
    preproc = Preprocessor(config_path)