import os
import shutil
from abc import ABC, abstractmethod

from configs.config import ConfigManager
from runners.runner import Runner


class Preprocessor(Runner, ABC):
    """Abstract base for preprocessing runners.

    Prepares the on-disk layout ``<experiment_path>/data/<source>/<data_type>/``
    for every entry of the ``dataset_list`` in the ``settings.preprocess``
    config section, and declares the hooks a concrete preprocessor must
    implement (data loading, model construction, prediction, postprocessing
    and saving).
    """

    # Name of the sub-directory of the experiment path holding all datasets.
    DATA = "data"

    def __init__(self, config_path):
        """Initialise the runner and cache the preprocess config section.

        Args:
            config_path: path to the config file, forwarded to ``Runner``.
        """
        super().__init__(config_path)
        # NOTE(review): assumes ConfigManager.get returns a mapping containing
        # a "dataset_list" of dicts with "source" and "data_type" keys
        # (that is how create_dataset_list reads it) — confirm against config.
        self.preprocess_config = ConfigManager.get("settings", "preprocess")

    def _data_dir(self):
        """Return the ``<experiment_path>/data`` directory as a string path."""
        return os.path.join(str(self.experiment_path), Preprocessor.DATA)

    def load_experiment(self, backup_name=None):
        """Load an existing experiment and (re)create its dataset directories.

        Unless ``experiments_config["keep_exists"]`` is truthy, the whole data
        directory is wiped and recreated empty before the per-dataset
        directories are created.

        Args:
            backup_name: forwarded unchanged to ``Runner.load_experiment``.
        """
        # presumably sets self.experiment_path / self.experiments_config — TODO confirm in Runner
        super().load_experiment(backup_name)
        if not self.experiments_config["keep_exists"]:
            data_dir = self._data_dir()
            # Start from a clean slate: discard anything a previous run left.
            shutil.rmtree(data_dir, ignore_errors=True)
            os.makedirs(data_dir)
        self.create_dataset_list()

    def create_experiment(self, backup_name=None):
        """Create a new experiment plus its data and dataset directories.

        Args:
            backup_name: forwarded unchanged to ``Runner.create_experiment``.

        Raises:
            FileExistsError: if the data directory already exists.
        """
        super().create_experiment(backup_name)
        os.makedirs(self._data_dir())
        self.create_dataset_list()

    def create_dataset_list(self):
        """Create ``data/<source>/<data_type>/`` for every configured dataset.

        Directories that already exist are left untouched.
        """
        data_dir = self._data_dir()
        for dataset in self.preprocess_config["dataset_list"]:
            dataset_dir = os.path.join(
                data_dir, dataset["source"], dataset["data_type"]
            )
            # makedirs also creates the intermediate <source> directory;
            # exist_ok=True because several datasets can share a source and
            # directories may survive from an earlier run (keep_exists).
            os.makedirs(dataset_dir, exist_ok=True)

    @abstractmethod
    def get_dataloader(self, dataset_config):
        """Subclass hook: build a data loader from ``dataset_config``."""

    @abstractmethod
    def get_model(self, model_config):
        """Subclass hook: build the model described by ``model_config``."""

    @abstractmethod
    def prediction(self, model, dataloader):
        """Subclass hook: run ``model`` over ``dataloader``."""

    @abstractmethod
    def preprocess(self, predicted_data):
        """Subclass hook: turn ``predicted_data`` into processed data."""

    @abstractmethod
    def save_processed_data(self, processed_data, data_config=None):
        """Subclass hook: persist ``processed_data`` (optionally per ``data_config``)."""


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        type=str,
        default="../configs/local_gsnet_preprocess_config.yaml",
    )
    args = parser.parse_args()
    # Preprocessor declares abstract methods, so Preprocessor(args.config)
    # would always raise TypeError. Fail with an actionable message instead.
    raise SystemExit(
        "Preprocessor is abstract and cannot be run directly; "
        f"run a concrete subclass with --config {args.config}"
    )