Init ptb
This commit is contained in:
61
PytorchBoot/runners/runner.py
Normal file
61
PytorchBoot/runners/runner.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
import time
|
||||
from abc import abstractmethod, ABC
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from PytorchBoot.config import ConfigManager
|
||||
from PytorchBoot.utils.log_util import Log
|
||||
|
||||
class Runner(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config_path):
|
||||
ConfigManager.load_config_with(config_path)
|
||||
ConfigManager.print_config()
|
||||
self.config = ConfigManager.get("runner")
|
||||
self.seed = self.config["general"]["seed"]
|
||||
self.device = self.config["general"]["device"]
|
||||
self.cuda_visible_devices = self.config["general"]["cuda_visible_devices"]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = self.cuda_visible_devices
|
||||
self.experiments_config = self.config["experiment"]
|
||||
self.experiment_path = os.path.join(self.experiments_config["root_dir"], self.experiments_config["name"])
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
lt = time.localtime()
|
||||
self.file_name = f"{lt.tm_year}_{lt.tm_mon}_{lt.tm_mday}_{lt.tm_hour}h{lt.tm_min}m{lt.tm_sec}s"
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_experiment(self, backup_name=None):
|
||||
if not os.path.exists(self.experiment_path):
|
||||
Log.info(f"experiments environment {self.experiments_config['name']} does not exists.")
|
||||
self.create_experiment(backup_name)
|
||||
else:
|
||||
Log.info(f"experiments environment {self.experiments_config['name']}")
|
||||
backup_config_dir = os.path.join(str(self.experiment_path), "configs")
|
||||
if not os.path.exists(backup_config_dir):
|
||||
os.makedirs(backup_config_dir)
|
||||
ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name)
|
||||
|
||||
@abstractmethod
|
||||
def create_experiment(self, backup_name=None):
|
||||
Log.info("creating experiment: " + self.experiments_config["name"])
|
||||
os.makedirs(self.experiment_path)
|
||||
backup_config_dir = os.path.join(str(self.experiment_path), "configs")
|
||||
os.makedirs(backup_config_dir)
|
||||
ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name)
|
||||
log_dir = os.path.join(str(self.experiment_path), "log")
|
||||
os.makedirs(log_dir)
|
||||
cache_dir = os.path.join(str(self.experiment_path), "cache")
|
||||
os.makedirs(cache_dir)
|
||||
|
||||
def print_info(self):
|
||||
table_size = 80
|
||||
Log.blue("+" + "-" * table_size + "+")
|
||||
Log.blue(f"| Experiment <{self.experiments_config['name']}>")
|
||||
Log.blue("+" + "-" * table_size + "+")
|
Reference in New Issue
Block a user