72 lines
2.2 KiB
Python
Executable File
72 lines
2.2 KiB
Python
Executable File
import os
|
|
from abc import ABC, abstractmethod
|
|
import shutil
|
|
|
|
from configs.config import ConfigManager
|
|
from runners.runner import Runner
|
|
|
|
|
|
class Preprocessor(Runner, ABC):
    """Abstract runner that lays out an experiment's ``data`` directory tree.

    After the ``Runner`` load/create step, it creates
    ``<experiment_path>/data/<source>/<data_type>`` for every entry of the
    preprocess settings' ``dataset_list``. The pipeline steps themselves are
    declared as abstract methods and must be supplied by subclasses.
    """

    # Name of the sub-directory (under the experiment path) that holds data.
    DATA = "data"

    def __init__(self, config_path):
        """Initialise the base runner and store the preprocess settings.

        Args:
            config_path: Passed through unchanged to ``Runner.__init__``.
        """
        super().__init__(config_path)
        self.preprocess_config = ConfigManager.get("settings", "preprocess")

    def load_experiment(self, backup_name=None):
        """Load an experiment and make sure its data directories exist.

        When ``experiments_config["keep_exists"]`` is falsy, the ``data``
        directory is deleted and recreated empty first.

        Args:
            backup_name: Passed through to ``Runner.load_experiment``.
        """
        super().load_experiment(backup_name)
        if not self.experiments_config["keep_exists"]:
            data_dir = os.path.join(str(self.experiment_path), Preprocessor.DATA)
            # ignore_errors=True: a data directory that is absent is fine here.
            shutil.rmtree(data_dir, ignore_errors=True)
            os.makedirs(data_dir)
        # NOTE(review): run unconditionally, mirroring create_experiment; the
        # call is idempotent, so it is harmless when the tree already exists.
        self.create_dataset_list()

    def create_experiment(self, backup_name=None):
        """Create an experiment, its ``data`` directory and dataset folders.

        Args:
            backup_name: Passed through to ``Runner.create_experiment``.

        Raises:
            FileExistsError: If the ``data`` directory already exists.
        """
        super().create_experiment(backup_name)
        data_dir = os.path.join(str(self.experiment_path), Preprocessor.DATA)
        os.makedirs(data_dir)
        self.create_dataset_list()

    def create_dataset_list(self):
        """Create ``data/<source>/<data_type>`` for every configured dataset.

        Each entry of ``preprocess_config["dataset_list"]`` must provide the
        keys ``"source"`` and ``"data_type"``. Directories that already exist
        are left untouched.
        """
        data_root = os.path.join(str(self.experiment_path), Preprocessor.DATA)
        for dataset in self.preprocess_config["dataset_list"]:
            source = dataset["source"]
            data_type = dataset["data_type"]
            dataset_dir = os.path.join(data_root, source, data_type)
            # makedirs also creates the intermediate <source> directory.
            # exist_ok=True replaces the former exists()-then-create checks,
            # so several datasets sharing a source, or a re-run, cannot fail
            # on (or race with) a directory that is already there.
            os.makedirs(dataset_dir, exist_ok=True)

    @abstractmethod
    def get_dataloader(self, dataset_config):
        """Subclass hook taking one dataset configuration.

        Presumably returns the dataloader for ``dataset_config``; the exact
        contract is defined by the implementing subclass.
        """

    @abstractmethod
    def get_model(self, model_config):
        """Subclass hook taking a model configuration.

        Presumably returns the model described by ``model_config``; the exact
        contract is defined by the implementing subclass.
        """

    @abstractmethod
    def prediction(self, model, dataloader):
        """Subclass hook taking a model and a dataloader.

        Presumably runs ``model`` over ``dataloader`` and returns the
        predictions; the exact contract is defined by the subclass.
        """

    @abstractmethod
    def preprocess(self, predicted_data):
        """Subclass hook taking predicted data.

        Presumably returns the processed form of ``predicted_data``; the exact
        contract is defined by the implementing subclass.
        """

    @abstractmethod
    def save_processed_data(self, processed_data, data_config=None):
        """Subclass hook taking processed data and an optional data config.

        Presumably persists ``processed_data``; the exact contract is defined
        by the implementing subclass.
        """
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: read the config path and build the runner.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--config",
        type=str,
        default="../configs/local_gsnet_preprocess_config.yaml",
    )
    args = arg_parser.parse_args()
    # NOTE(review): Preprocessor declares abstract methods, so instantiating
    # it directly is expected to raise TypeError -- confirm whether a concrete
    # subclass should be used here instead.
    preproc = Preprocessor(args.config)
|