34 lines
1.1 KiB
Python
34 lines
1.1 KiB
Python
import os
|
|
import json
|
|
from PytorchBoot.runners.runner import Runner
|
|
from PytorchBoot.config import ConfigManager
|
|
from PytorchBoot.utils import Log
|
|
import PytorchBoot.stereotype as stereotype
|
|
|
|
|
|
@stereotype.runner("data_generator", comment="unfinished")
class DataGenerator(Runner):
    """Runner registered as ``data_generator`` (flagged "unfinished").

    For every dataset named in the ``runner / generate / dataset_list``
    configuration entry, the dataset's configuration is looked up and its
    ``model_dir`` and ``output_dir`` values are debug-logged. No data is
    actually generated yet.
    """

    def __init__(self, config):
        """Initialise the base runner and load the ``"generate"`` experiment.

        Args:
            config: Forwarded unchanged to ``Runner.__init__``.
        """
        super().__init__(config)
        self.load_experiment("generate")

    def run(self):
        """Call :meth:`generate` once per configured dataset name."""
        dataset_name_list = ConfigManager.get("runner", "generate", "dataset_list")
        for dataset_name in dataset_name_list:
            self.generate(dataset_name)

    def generate(self, dataset_name):
        """Look up one dataset's configuration and log its directories.

        Args:
            dataset_name: Key used to fetch the dataset entry via
                ``ConfigManager.get("datasets", dataset_name)``.

        Raises:
            KeyError: If the dataset configuration is a mapping that lacks
                ``"model_dir"`` or ``"output_dir"``.
        """
        dataset_config = ConfigManager.get("datasets", dataset_name)
        model_dir = dataset_config["model_dir"]
        output_dir = dataset_config["output_dir"]
        # Placeholder behaviour: only report the directories for now.
        Log.debug(model_dir)
        Log.debug(output_dir)

    def create_experiment(self, backup_name=None):
        """Create the experiment and its ``output`` sub-directory.

        Args:
            backup_name: Forwarded unchanged to
                ``Runner.create_experiment``.
        """
        super().create_experiment(backup_name)
        output_dir = os.path.join(str(self.experiment_path), "output")
        # exist_ok=True: creating an experiment whose output directory is
        # already present must not abort with FileExistsError.
        os.makedirs(output_dir, exist_ok=True)

    def load_experiment(self, backup_name=None):
        """Load an experiment by delegating to ``Runner.load_experiment``.

        Args:
            backup_name: Forwarded unchanged to the base implementation.
        """
        super().load_experiment(backup_name)
|
|
|