Init ptb
This commit is contained in:
0
PytorchBoot/__init__.py
Normal file
0
PytorchBoot/__init__.py
Normal file
21
PytorchBoot/application.py
Normal file
21
PytorchBoot/application.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from PytorchBoot.utils.log_util import Log
|
||||
|
||||
application_class = {}
|
||||
def PytorchBootApplication(arg=None):
|
||||
if callable(arg):
|
||||
cls = arg
|
||||
if "default" in application_class:
|
||||
Log.error("Multiple classes annotated with default @PytorchBootApplication, require a 'name' parameter.", True)
|
||||
application_class["default"] = cls
|
||||
return cls
|
||||
|
||||
else:
|
||||
name = arg
|
||||
def decorator(cls):
|
||||
if name is None:
|
||||
raise Log.error("The 'name' parameter is required when using @PytorchBootApplication with arguments.", True)
|
||||
if name in application_class:
|
||||
raise Log.error(f"Multiple classes annotated with @PytorchBootApplication with the same name '{name}' found.", True)
|
||||
application_class[name] = cls
|
||||
return cls
|
||||
return decorator
|
108
PytorchBoot/boot.py
Normal file
108
PytorchBoot/boot.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from PytorchBoot.application import application_class
|
||||
from PytorchBoot.stereotype import get_all_component_classes, get_all_component_comments
|
||||
from PytorchBoot.utils.log_util import Log
|
||||
from PytorchBoot.utils.timer_util import Timer
|
||||
from PytorchBoot.utils.project_util import ProjectUtil
|
||||
from PytorchBoot.templates.application import template as app_template
|
||||
from PytorchBoot.templates.config import template as config_template
|
||||
from PytorchBoot.ui.server.app import app
|
||||
|
||||
|
||||
def run():
|
||||
root_path = os.getcwd()
|
||||
ProjectUtil.scan_project(root_path)
|
||||
|
||||
app_name = "default"
|
||||
if len(application_class) == 0:
|
||||
Log.error("No class annotated with @PytorchBootApplication found.", True)
|
||||
if len(sys.argv) < 3 and "default" not in application_class:
|
||||
Log.error("No default @PytorchBootApplication found. Please specify the 'name' parameter.", True)
|
||||
if len(sys.argv) == 3:
|
||||
app_name = sys.argv[2]
|
||||
|
||||
app_cls = application_class.get(app_name)
|
||||
|
||||
if app_cls is None:
|
||||
Log.error(f"No class annotated with @PytorchBootApplication found with the name '{app_name}'.", True)
|
||||
|
||||
if not hasattr(app_cls, "start"):
|
||||
Log.error("The class annotated with @PytorchBootApplication should have a 'start' method.", True)
|
||||
|
||||
Log.info(f"Application '{app_cls.__name__}' started.")
|
||||
timer = Timer("Application")
|
||||
|
||||
timer.start()
|
||||
app_cls.start()
|
||||
timer.stop()
|
||||
Log.info(timer.get_elasped_time_str(Timer.HOURS))
|
||||
Log.success("Application finished.")
|
||||
|
||||
|
||||
def init():
|
||||
Log.info("Initializing PytorchBoot project.")
|
||||
root_path = os.getcwd()
|
||||
if len(os.listdir(root_path)) > 0:
|
||||
Log.error("Current directory is not empty. Please provide an empty directory.")
|
||||
else:
|
||||
with open(os.path.join(root_path, "application.py"), "w") as file:
|
||||
file.write(app_template)
|
||||
with open(os.path.join(root_path, "config.yaml"), "w") as file:
|
||||
file.write(config_template)
|
||||
|
||||
Log.success("PytorchBoot project initialized.")
|
||||
Log.info("Now you can create your components and run the application.")
|
||||
|
||||
def scan():
|
||||
root_path = os.getcwd()
|
||||
ProjectUtil.scan_project(root_path)
|
||||
comments = get_all_component_comments()
|
||||
Log.info("Components detected in the project:")
|
||||
for stereotype, classes in get_all_component_classes().items():
|
||||
Log.info(f" {stereotype}:")
|
||||
for name, cls in classes.items():
|
||||
comment = comments[stereotype].get(name)
|
||||
if comment is not None:
|
||||
Log.warning(f" - {name}: {cls.__module__}.{cls.__name__} ({comment})")
|
||||
else:
|
||||
Log.success(f" - {name}: {cls.__module__}.{cls.__name__}")
|
||||
|
||||
Log.info("Applications detected in the project:")
|
||||
for app_name, app_cls in application_class.items():
|
||||
Log.success(f" - {app_name}: {app_cls.__module__}.{app_cls.__name__}")
|
||||
Log.success("Scan completed.")
|
||||
|
||||
def ui():
|
||||
port = 5000
|
||||
if len(sys.argv) == 3:
|
||||
port = int(sys.argv[2])
|
||||
Log.success(f"PytorchBoot UI server started at http://localhost:{port}")
|
||||
app.run(port=port, host="0.0.0.0")
|
||||
|
||||
|
||||
def help():
|
||||
Log.info("PytorchBoot commands:")
|
||||
Log.info(" init: Initialize a new PytorchBoot project in the current directory.")
|
||||
Log.info(" run [name]: Run the PytorchBoot application with the specified name. If no name is provided, the default application will be run.")
|
||||
Log.info(" scan: Scan the project for PytorchBoot components.")
|
||||
Log.info(" ui [port]: Start the PytorchBoot UI server. If no port is provided, the default port 5000 will be used.")
|
||||
Log.info(" help: Display this help message.")
|
||||
|
||||
def main():
|
||||
if len(sys.argv) > 1:
|
||||
if sys.argv[1] == "init":
|
||||
init()
|
||||
elif sys.argv[1] == "run":
|
||||
run()
|
||||
elif sys.argv[1] == "scan":
|
||||
scan()
|
||||
elif sys.argv[1] == "ui":
|
||||
ui()
|
||||
elif sys.argv[1] == "help":
|
||||
help()
|
||||
else:
|
||||
Log.error("Invalid command: " + sys.argv[1] + ". Use 'pytorch-boot help' for help.")
|
||||
else:
|
||||
Log.error("Please provide a command to run the application.")
|
21
PytorchBoot/component.py
Normal file
21
PytorchBoot/component.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from PytorchBoot.utils.log_util import Log
|
||||
|
||||
class Component:
|
||||
TYPE: str
|
||||
NAME: str
|
||||
|
||||
def get_name(self):
|
||||
return self.NAME
|
||||
|
||||
def get_type(self):
|
||||
return self.TYPE
|
||||
|
||||
def get_config(self):
|
||||
return self.config
|
||||
|
||||
def print(self):
|
||||
Log.blue("Component Information")
|
||||
Log.blue(f"- Type: {self.TYPE}")
|
||||
Log.blue(f"- Name: {self.NAME}")
|
||||
Log.blue(f"- Config: \n\t{self.config}")
|
||||
|
59
PytorchBoot/config.py
Normal file
59
PytorchBoot/config.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import argparse
|
||||
import os.path
|
||||
import shutil
|
||||
import yaml
|
||||
from PytorchBoot.utils.log_util import Log
|
||||
|
||||
class ConfigManager:
|
||||
config = None
|
||||
config_path = None
|
||||
|
||||
@staticmethod
|
||||
def get(*args):
|
||||
result = ConfigManager.config
|
||||
for arg in args:
|
||||
result = result[arg]
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def load_config_with(config_file_path):
|
||||
ConfigManager.config_path = config_file_path
|
||||
if not os.path.exists(ConfigManager.config_path):
|
||||
raise ValueError(f"Config file <{config_file_path}> does not exist")
|
||||
with open(config_file_path, 'r') as file:
|
||||
ConfigManager.config = yaml.safe_load(file)
|
||||
|
||||
@staticmethod
|
||||
def backup_config_to(target_config_dir, file_name, prefix="config"):
|
||||
file_name = f"__{prefix}_{file_name}.yaml"
|
||||
target_config_file_path = str(os.path.join(target_config_dir, file_name))
|
||||
shutil.copy(ConfigManager.config_path, target_config_file_path)
|
||||
|
||||
@staticmethod
|
||||
def load_config():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--config', type=str, default='', help='config file path')
|
||||
args = parser.parse_args()
|
||||
if args.config:
|
||||
ConfigManager.load_config_with(args.config)
|
||||
|
||||
@staticmethod
|
||||
def print_config(key: str = None, group: dict = None, level=0):
|
||||
table_size = 80
|
||||
if key and group:
|
||||
value = group[key]
|
||||
if type(value) is dict:
|
||||
Log.blue("\t" * level + f"+-{key}:")
|
||||
for k in value:
|
||||
ConfigManager.print_config(k, value, level=level + 1)
|
||||
else:
|
||||
Log.blue("\t" * level + f"| {key}: {value}")
|
||||
elif key:
|
||||
ConfigManager.print_config(key, ConfigManager.config, level=level)
|
||||
else:
|
||||
Log.blue("+" + "-" * table_size + "+")
|
||||
Log.blue(f"| Configurations in <{ConfigManager.config_path}>:")
|
||||
Log.blue("+" + "-" * table_size + "+")
|
||||
for key in ConfigManager.config:
|
||||
ConfigManager.print_config(key, level=level + 1)
|
||||
Log.blue("+" + "-" * table_size + "+")
|
1
PytorchBoot/dataset/__init__.py
Normal file
1
PytorchBoot/dataset/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from PytorchBoot.dataset.base_dataset import BaseDataset
|
BIN
PytorchBoot/dataset/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
PytorchBoot/dataset/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/dataset/__pycache__/base_dataset.cpython-39.pyc
Normal file
BIN
PytorchBoot/dataset/__pycache__/base_dataset.cpython-39.pyc
Normal file
Binary file not shown.
44
PytorchBoot/dataset/base_dataset.py
Normal file
44
PytorchBoot/dataset/base_dataset.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from abc import ABC
|
||||
|
||||
import numpy as np
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader, Subset
|
||||
|
||||
from PytorchBoot.component import Component
|
||||
|
||||
class BaseDataset(ABC, Dataset, Component):
|
||||
def __init__(self, config):
|
||||
super(BaseDataset, self).__init__()
|
||||
self.config = config
|
||||
|
||||
@staticmethod
|
||||
def process_batch(batch, device):
|
||||
for key in batch.keys():
|
||||
if isinstance(batch[key], list):
|
||||
continue
|
||||
batch[key] = batch[key].to(device)
|
||||
return batch
|
||||
|
||||
def get_collate_fn(self):
|
||||
return None
|
||||
|
||||
def get_loader(self, shuffle=False):
|
||||
ratio = self.config["ratio"]
|
||||
if ratio > 1 or ratio <= 0:
|
||||
raise ValueError(
|
||||
f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}"
|
||||
)
|
||||
subset_size = max(1,int(len(self) * ratio))
|
||||
indices = np.random.permutation(len(self))[:subset_size]
|
||||
subset = Subset(self, indices)
|
||||
return DataLoader(
|
||||
subset,
|
||||
batch_size=self.config["batch_size"],
|
||||
num_workers=self.config["num_workers"],
|
||||
shuffle=shuffle,
|
||||
collate_fn=self.get_collate_fn(),
|
||||
)
|
||||
|
||||
|
||||
|
2
PytorchBoot/factory/__init__.py
Normal file
2
PytorchBoot/factory/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from PytorchBoot.factory.component_factory import ComponentFactory
|
||||
from PytorchBoot.factory.optimizer_factory import OptimizerFactory
|
BIN
PytorchBoot/factory/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
PytorchBoot/factory/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/factory/__pycache__/component_factory.cpython-39.pyc
Normal file
BIN
PytorchBoot/factory/__pycache__/component_factory.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/factory/__pycache__/optimizer_factory.cpython-39.pyc
Normal file
BIN
PytorchBoot/factory/__pycache__/optimizer_factory.cpython-39.pyc
Normal file
Binary file not shown.
27
PytorchBoot/factory/component_factory.py
Normal file
27
PytorchBoot/factory/component_factory.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from PytorchBoot.component import Component
|
||||
from PytorchBoot.stereotype import *
|
||||
from PytorchBoot.utils.log_util import Log
|
||||
from PytorchBoot.config import ConfigManager
|
||||
|
||||
class ComponentFactory:
|
||||
@staticmethod
|
||||
def create(component_type: str, name: str) -> Component:
|
||||
component_classes = get_component_classes(component_type=component_type)
|
||||
if component_classes is None:
|
||||
Log.error(f"Unsupported component type: {component_type}", True)
|
||||
|
||||
if component_type == namespace.Stereotype.DATASET:
|
||||
config = ConfigManager.get(component_type, name)
|
||||
cls = dataset_classes[config["source"]]
|
||||
dataset_obj = cls(config)
|
||||
dataset_obj.NAME = name
|
||||
dataset_obj.TYPE = component_type
|
||||
return dataset_obj
|
||||
|
||||
if name not in component_classes:
|
||||
Log.error(f"Unsupported component name: {name}", True)
|
||||
|
||||
cls = component_classes[name]
|
||||
config = ConfigManager.get(component_type, name)
|
||||
return cls(config)
|
||||
|
67
PytorchBoot/factory/optimizer_factory.py
Normal file
67
PytorchBoot/factory/optimizer_factory.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import torch.optim as optim
|
||||
|
||||
class OptimizerFactory:
|
||||
@staticmethod
|
||||
def create(config: dict, params) -> optim.Optimizer:
|
||||
optim_type = config["type"]
|
||||
lr = config.get("lr", 1e-3)
|
||||
|
||||
if optim_type == "SGD":
|
||||
return optim.SGD(
|
||||
params,
|
||||
lr=lr,
|
||||
momentum=config.get("momentum", 0.9),
|
||||
weight_decay=config.get("weight_decay", 1e-4),
|
||||
)
|
||||
elif optim_type == "Adam":
|
||||
return optim.Adam(
|
||||
params,
|
||||
lr=lr,
|
||||
betas=config.get("betas", (0.9, 0.999)),
|
||||
eps=config.get("eps", 1e-8),
|
||||
)
|
||||
elif optim_type == "AdamW":
|
||||
return optim.AdamW(
|
||||
params,
|
||||
lr=lr,
|
||||
betas=config.get("betas", (0.9, 0.999)),
|
||||
eps=config.get("eps", 1e-8),
|
||||
weight_decay=config.get("weight_decay", 1e-2),
|
||||
)
|
||||
elif optim_type == "RMSprop":
|
||||
return optim.RMSprop(
|
||||
params,
|
||||
lr=lr,
|
||||
alpha=config.get("alpha", 0.99),
|
||||
eps=config.get("eps", 1e-8),
|
||||
weight_decay=config.get("weight_decay", 1e-4),
|
||||
momentum=config.get("momentum", 0.9),
|
||||
)
|
||||
elif optim_type == "Adagrad":
|
||||
return optim.Adagrad(
|
||||
params,
|
||||
lr=lr,
|
||||
lr_decay=config.get("lr_decay", 0),
|
||||
weight_decay=config.get("weight_decay", 0),
|
||||
)
|
||||
elif optim_type == "Adamax":
|
||||
return optim.Adamax(
|
||||
params,
|
||||
lr=lr,
|
||||
betas=config.get("betas", (0.9, 0.999)),
|
||||
eps=config.get("eps", 1e-8),
|
||||
weight_decay=config.get("weight_decay", 0),
|
||||
)
|
||||
elif optim_type == "LBFGS":
|
||||
return optim.LBFGS(
|
||||
params,
|
||||
lr=lr,
|
||||
max_iter=config.get("max_iter", 20),
|
||||
max_eval=config.get("max_eval", None),
|
||||
tolerance_grad=config.get("tolerance_grad", 1e-7),
|
||||
tolerance_change=config.get("tolerance_change", 1e-9),
|
||||
history_size=config.get("history_size", 100),
|
||||
line_search_fn=config.get("line_search_fn", None),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Unknown optimizer: {}".format(optim_type))
|
33
PytorchBoot/namespace.py
Normal file
33
PytorchBoot/namespace.py
Normal file
@@ -0,0 +1,33 @@
|
||||
|
||||
class Stereotype:
|
||||
DATASET:str = "dataset"
|
||||
MODULE:str = "module"
|
||||
PIPELINE:str = "pipeline"
|
||||
RUNNER:str = "runner"
|
||||
FACTORY:str = "factory"
|
||||
EVALUATION_METHOD:str = "evaluation_method"
|
||||
LOSS_FUNCTION:str = "loss_function"
|
||||
|
||||
class Mode:
|
||||
TRAIN:str = "train"
|
||||
TEST:str = "test"
|
||||
EVALUATION:str = "evaluation"
|
||||
|
||||
class Direcotry:
|
||||
CHECKPOINT_DIR_NAME: str = 'checkpoints'
|
||||
TENSORBOARD_DIR_NAME: str = 'tensorboard'
|
||||
LOG_DIR_NAME: str = 'log'
|
||||
RESULT_DIR_NAME: str = 'results'
|
||||
|
||||
class TensorBoard:
|
||||
SCALAR: str = "scalar"
|
||||
IMAGE: str = "image"
|
||||
POINT: str = "point"
|
||||
|
||||
class LogType:
|
||||
INFO:str = "info"
|
||||
ERROR:str = "error"
|
||||
WARNING:str = "warning"
|
||||
SUCCESS:str = "success"
|
||||
DEBUG:str = "debug"
|
||||
TERMINATE:str = "terminate"
|
4
PytorchBoot/runners/__init__.py
Normal file
4
PytorchBoot/runners/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from PytorchBoot.runners.trainer import DefaultTrainer
|
||||
from PytorchBoot.runners.evaluator import DefaultEvaluator
|
||||
from PytorchBoot.runners.predictor import DefaultPredictor
|
||||
from PytorchBoot.runners.runner import Runner
|
BIN
PytorchBoot/runners/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
PytorchBoot/runners/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/runners/__pycache__/evaluator.cpython-39.pyc
Normal file
BIN
PytorchBoot/runners/__pycache__/evaluator.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/runners/__pycache__/predictor.cpython-39.pyc
Normal file
BIN
PytorchBoot/runners/__pycache__/predictor.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/runners/__pycache__/runner.cpython-39.pyc
Normal file
BIN
PytorchBoot/runners/__pycache__/runner.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/runners/__pycache__/trainer.cpython-39.pyc
Normal file
BIN
PytorchBoot/runners/__pycache__/trainer.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/runners/__pycache__/web_runner.cpython-39.pyc
Normal file
BIN
PytorchBoot/runners/__pycache__/web_runner.cpython-39.pyc
Normal file
Binary file not shown.
132
PytorchBoot/runners/evaluator.py
Normal file
132
PytorchBoot/runners/evaluator.py
Normal file
@@ -0,0 +1,132 @@
|
||||
import os
|
||||
import json
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from PytorchBoot.config import ConfigManager
|
||||
import PytorchBoot.namespace as namespace
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
from PytorchBoot.factory import ComponentFactory
|
||||
|
||||
from PytorchBoot.dataset import BaseDataset
|
||||
from PytorchBoot.runners.runner import Runner
|
||||
from PytorchBoot.utils import Log
|
||||
|
||||
@stereotype.runner("default_evaluator")
|
||||
class DefaultEvaluator(Runner):
|
||||
def __init__(self, config_path):
|
||||
super().__init__(config_path)
|
||||
''' Pipeline '''
|
||||
self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
|
||||
self.pipeline = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
|
||||
self.pipeline:torch.nn.Module = self.pipeline.to(self.device)
|
||||
|
||||
''' Experiment '''
|
||||
self.model_path = self.config["experiment"]["model_path"]
|
||||
self.load_experiment("default_evaluator")
|
||||
|
||||
|
||||
''' Test '''
|
||||
self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
|
||||
self.test_dataset_name_list = self.test_config["dataset_list"]
|
||||
self.test_set_list = []
|
||||
self.test_writer_list = []
|
||||
seen_name = set()
|
||||
for test_dataset_name in self.test_dataset_name_list:
|
||||
if test_dataset_name not in seen_name:
|
||||
seen_name.add(test_dataset_name)
|
||||
else:
|
||||
raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
|
||||
test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
|
||||
self.test_set_list.append(test_set)
|
||||
|
||||
self.print_info()
|
||||
|
||||
def run(self):
|
||||
eval_result = self.test()
|
||||
self.save_eval_result(eval_result)
|
||||
|
||||
def test(self):
|
||||
self.pipeline.eval()
|
||||
eval_result = {}
|
||||
with torch.no_grad():
|
||||
test_set: BaseDataset
|
||||
for dataset_idx, test_set in enumerate(self.test_set_list):
|
||||
test_set_config = test_set.get_config()
|
||||
eval_list = test_set_config["eval_list"]
|
||||
ratio = test_set_config["ratio"]
|
||||
test_set_name = test_set.get_name()
|
||||
output_list = []
|
||||
data_list = []
|
||||
test_loader = test_set.get_loader()
|
||||
loop = tqdm(enumerate(test_loader), total=int(len(test_loader)))
|
||||
for _, data in loop:
|
||||
test_set.process_batch(data, self.device)
|
||||
data["mode"] = namespace.Mode.TEST
|
||||
output = self.pipeline(data)
|
||||
output_list.append(output)
|
||||
data_list.append(data)
|
||||
loop.set_description(
|
||||
f'Evaluating [{dataset_idx+1}/{len(self.test_set_list)}] (Test: {test_set_name}, ratio={ratio})')
|
||||
result_dict = self.eval_fn(output_list, data_list, eval_list)
|
||||
eval_result[test_set_name] = result_dict
|
||||
return eval_result
|
||||
|
||||
def save_eval_result(self, eval_result):
|
||||
result_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.RESULT_DIR_NAME)
|
||||
eval_result_path = os.path.join(result_dir, self.file_name + "_eval_result.json")
|
||||
with open(eval_result_path, "w") as f:
|
||||
json.dump(eval_result, f, indent=4)
|
||||
Log.success(f"Saved evaluation result to {eval_result_path}")
|
||||
|
||||
@staticmethod
|
||||
def eval_fn(output_list, data_list, eval_list):
|
||||
collected_result = {}
|
||||
for eval_method_name in eval_list:
|
||||
eval_method = ComponentFactory.create(namespace.Stereotype.EVALUATION_METHOD, eval_method_name)
|
||||
eval_results:dict = eval_method.evaluate(output_list, data_list)
|
||||
for data_type, eval_result in eval_results.items():
|
||||
if data_type not in collected_result:
|
||||
collected_result[data_type] = {}
|
||||
for name, value in eval_result.items():
|
||||
collected_result[data_type][name] = value
|
||||
|
||||
return collected_result
|
||||
|
||||
def load_checkpoint(self):
|
||||
self.load(self.model_path)
|
||||
Log.success(f"Loaded checkpoint from {self.model_path}")
|
||||
|
||||
def load_experiment(self, backup_name=None):
|
||||
super().load_experiment(backup_name)
|
||||
self.load_checkpoint()
|
||||
|
||||
def create_experiment(self, backup_name=None):
|
||||
super().create_experiment(backup_name)
|
||||
result_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.RESULT_DIR_NAME)
|
||||
os.makedirs(result_dir)
|
||||
|
||||
def load(self, path):
|
||||
state_dict = torch.load(path)
|
||||
self.pipeline.load_state_dict(state_dict)
|
||||
|
||||
def print_info(self):
|
||||
def print_dataset(dataset: BaseDataset):
|
||||
config = dataset.get_config()
|
||||
name = dataset.get_name()
|
||||
Log.blue(f"Dataset: {name}")
|
||||
for k,v in config.items():
|
||||
Log.blue(f"\t{k}: {v}")
|
||||
|
||||
super().print_info()
|
||||
table_size = 70
|
||||
Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
|
||||
Log.blue(self.pipeline)
|
||||
Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
|
||||
for i, test_set in enumerate(self.test_set_list):
|
||||
Log.blue(f"test dataset {i}: ")
|
||||
print_dataset(test_set)
|
||||
|
||||
Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')
|
||||
|
128
PytorchBoot/runners/predictor.py
Normal file
128
PytorchBoot/runners/predictor.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import os
|
||||
import json
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from PytorchBoot.config import ConfigManager
|
||||
import PytorchBoot.namespace as namespace
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
from PytorchBoot.factory import ComponentFactory
|
||||
|
||||
from PytorchBoot.dataset import BaseDataset
|
||||
from PytorchBoot.runners.runner import Runner
|
||||
from PytorchBoot.utils import Log
|
||||
|
||||
@stereotype.runner("default_predictor")
|
||||
class DefaultPredictor(Runner):
|
||||
def __init__(self, config_path):
|
||||
super().__init__(config_path)
|
||||
''' Pipeline '''
|
||||
self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
|
||||
self.pipeline = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
|
||||
self.pipeline:torch.nn.Module = self.pipeline.to(self.device)
|
||||
|
||||
''' Experiment '''
|
||||
self.model_path = self.config["experiment"]["model_path"]
|
||||
self.load_experiment("default_predictor")
|
||||
self.save_original_data = self.config["experiment"]["save_original_data"]
|
||||
|
||||
''' Testset '''
|
||||
self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
|
||||
self.test_dataset_name_list = self.test_config["dataset_list"]
|
||||
self.test_set_list = []
|
||||
self.test_writer_list = []
|
||||
seen_name = set()
|
||||
for test_dataset_name in self.test_dataset_name_list:
|
||||
if test_dataset_name not in seen_name:
|
||||
seen_name.add(test_dataset_name)
|
||||
else:
|
||||
raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
|
||||
test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
|
||||
self.test_set_list.append(test_set)
|
||||
|
||||
self.print_info()
|
||||
|
||||
def run(self):
|
||||
predict_result = self.predict()
|
||||
self.save_predict_result(predict_result)
|
||||
|
||||
def predict(self):
|
||||
self.pipeline.eval()
|
||||
predict_result = {}
|
||||
with torch.no_grad():
|
||||
test_set: BaseDataset
|
||||
for dataset_idx, test_set in enumerate(self.test_set_list):
|
||||
test_set_config = test_set.get_config()
|
||||
ratio = test_set_config["ratio"]
|
||||
test_set_name = test_set.get_name()
|
||||
output_list = []
|
||||
data_list = []
|
||||
test_loader = test_set.get_loader()
|
||||
loop = tqdm(enumerate(test_loader), total=int(len(test_loader)))
|
||||
for _, data in loop:
|
||||
test_set.process_batch(data, self.device)
|
||||
data["mode"] = namespace.Mode.TEST
|
||||
output = self.pipeline(data)
|
||||
output_list.append(output)
|
||||
data_list.append(data)
|
||||
loop.set_description(
|
||||
f'Predicting [{dataset_idx+1}/{len(self.test_set_list)}] (Test: {test_set_name}, ratio={ratio})')
|
||||
predict_result[test_set_name] = {
|
||||
"output": output_list,
|
||||
"data": data_list
|
||||
}
|
||||
return predict_result
|
||||
|
||||
def save_predict_result(self, predict_result):
|
||||
result_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.RESULT_DIR_NAME, self.file_name+"_predict_result")
|
||||
os.makedirs(result_dir)
|
||||
for test_set_name in predict_result.keys():
|
||||
os.mkdir(os.path.join(result_dir, test_set_name))
|
||||
idx = 0
|
||||
for output, data in zip(predict_result[test_set_name]["output"], predict_result[test_set_name]["data"]):
|
||||
output_path = os.path.join(result_dir, test_set_name, f"output_{idx}.pth")
|
||||
torch.save(output, output_path)
|
||||
if self.save_original_data:
|
||||
data_path = os.path.join(result_dir, test_set_name, f"data_{idx}.pth")
|
||||
torch.save(data, data_path)
|
||||
idx += 1
|
||||
Log.success(f"Saved predict result of {test_set_name} to {result_dir}")
|
||||
Log.success(f"Saved all predict result to {result_dir}")
|
||||
|
||||
def load_checkpoint(self):
|
||||
self.load(self.model_path)
|
||||
Log.success(f"Loaded checkpoint from {self.model_path}")
|
||||
|
||||
def load_experiment(self, backup_name=None):
|
||||
super().load_experiment(backup_name)
|
||||
self.load_checkpoint()
|
||||
|
||||
def create_experiment(self, backup_name=None):
|
||||
super().create_experiment(backup_name)
|
||||
result_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.RESULT_DIR_NAME)
|
||||
os.makedirs(result_dir)
|
||||
|
||||
def load(self, path):
|
||||
state_dict = torch.load(path)
|
||||
self.pipeline.load_state_dict(state_dict)
|
||||
|
||||
def print_info(self):
|
||||
def print_dataset(dataset: BaseDataset):
|
||||
config = dataset.get_config()
|
||||
name = dataset.get_name()
|
||||
Log.blue(f"Dataset: {name}")
|
||||
for k,v in config.items():
|
||||
Log.blue(f"\t{k}: {v}")
|
||||
|
||||
super().print_info()
|
||||
table_size = 70
|
||||
Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
|
||||
Log.blue(self.pipeline)
|
||||
Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
|
||||
for i, test_set in enumerate(self.test_set_list):
|
||||
Log.blue(f"test dataset {i}: ")
|
||||
print_dataset(test_set)
|
||||
|
||||
Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')
|
||||
|
61
PytorchBoot/runners/runner.py
Normal file
61
PytorchBoot/runners/runner.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
import time
|
||||
from abc import abstractmethod, ABC
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from PytorchBoot.config import ConfigManager
|
||||
from PytorchBoot.utils.log_util import Log
|
||||
|
||||
class Runner(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config_path):
|
||||
ConfigManager.load_config_with(config_path)
|
||||
ConfigManager.print_config()
|
||||
self.config = ConfigManager.get("runner")
|
||||
self.seed = self.config["general"]["seed"]
|
||||
self.device = self.config["general"]["device"]
|
||||
self.cuda_visible_devices = self.config["general"]["cuda_visible_devices"]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = self.cuda_visible_devices
|
||||
self.experiments_config = self.config["experiment"]
|
||||
self.experiment_path = os.path.join(self.experiments_config["root_dir"], self.experiments_config["name"])
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
lt = time.localtime()
|
||||
self.file_name = f"{lt.tm_year}_{lt.tm_mon}_{lt.tm_mday}_{lt.tm_hour}h{lt.tm_min}m{lt.tm_sec}s"
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_experiment(self, backup_name=None):
|
||||
if not os.path.exists(self.experiment_path):
|
||||
Log.info(f"experiments environment {self.experiments_config['name']} does not exists.")
|
||||
self.create_experiment(backup_name)
|
||||
else:
|
||||
Log.info(f"experiments environment {self.experiments_config['name']}")
|
||||
backup_config_dir = os.path.join(str(self.experiment_path), "configs")
|
||||
if not os.path.exists(backup_config_dir):
|
||||
os.makedirs(backup_config_dir)
|
||||
ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name)
|
||||
|
||||
@abstractmethod
|
||||
def create_experiment(self, backup_name=None):
|
||||
Log.info("creating experiment: " + self.experiments_config["name"])
|
||||
os.makedirs(self.experiment_path)
|
||||
backup_config_dir = os.path.join(str(self.experiment_path), "configs")
|
||||
os.makedirs(backup_config_dir)
|
||||
ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name)
|
||||
log_dir = os.path.join(str(self.experiment_path), "log")
|
||||
os.makedirs(log_dir)
|
||||
cache_dir = os.path.join(str(self.experiment_path), "cache")
|
||||
os.makedirs(cache_dir)
|
||||
|
||||
def print_info(self):
|
||||
table_size = 80
|
||||
Log.blue("+" + "-" * table_size + "+")
|
||||
Log.blue(f"| Experiment <{self.experiments_config['name']}>")
|
||||
Log.blue("+" + "-" * table_size + "+")
|
266
PytorchBoot/runners/trainer.py
Normal file
266
PytorchBoot/runners/trainer.py
Normal file
@@ -0,0 +1,266 @@
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from PytorchBoot.config import ConfigManager
|
||||
import PytorchBoot.namespace as namespace
|
||||
import PytorchBoot.stereotype as stereotype
|
||||
from PytorchBoot.factory import ComponentFactory
|
||||
from PytorchBoot.factory import OptimizerFactory
|
||||
|
||||
from PytorchBoot.dataset import BaseDataset
|
||||
from PytorchBoot.runners.runner import Runner
|
||||
from PytorchBoot.utils.tensorboard_util import TensorboardWriter
|
||||
from PytorchBoot.stereotype import EXTERNAL_FRONZEN_MODULES
|
||||
from PytorchBoot.utils import Log
|
||||
from PytorchBoot.status import status_manager
|
||||
|
||||
@stereotype.runner("default_trainer")
|
||||
class DefaultTrainer(Runner):
|
||||
def __init__(self, config_path):
|
||||
super().__init__(config_path)
|
||||
tensorboard_path = os.path.join(self.experiment_path, namespace.Direcotry.TENSORBOARD_DIR_NAME)
|
||||
|
||||
''' Pipeline '''
|
||||
self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
|
||||
self.parallel = self.config["general"]["parallel"]
|
||||
self.pipeline:torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
|
||||
if self.parallel and self.device == "cuda":
|
||||
self.pipeline = torch.nn.DataParallel(self.pipeline)
|
||||
self.pipeline = self.pipeline.to(self.device)
|
||||
|
||||
''' Experiment '''
|
||||
self.current_epoch = 0
|
||||
self.current_iter = 0
|
||||
self.max_epochs = self.experiments_config["max_epochs"]
|
||||
self.test_first = self.experiments_config["test_first"]
|
||||
self.load_experiment("default_trainer")
|
||||
|
||||
''' Train '''
|
||||
self.train_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TRAIN)
|
||||
self.train_dataset_name= self.train_config["dataset"]
|
||||
self.train_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, self.train_dataset_name)
|
||||
self.optimizer = OptimizerFactory.create(self.train_config["optimizer"], self.pipeline.parameters())
|
||||
self.train_writer = SummaryWriter(
|
||||
log_dir=os.path.join(tensorboard_path, f"[{namespace.Mode.TRAIN}]{self.train_dataset_name}"))
|
||||
|
||||
''' Test '''
|
||||
self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
|
||||
self.test_dataset_name_list = self.test_config["dataset_list"]
|
||||
self.test_set_list = []
|
||||
self.test_writer_list = []
|
||||
seen_name = set()
|
||||
for test_dataset_name in self.test_dataset_name_list:
|
||||
if test_dataset_name not in seen_name:
|
||||
seen_name.add(test_dataset_name)
|
||||
else:
|
||||
raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
|
||||
test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
|
||||
test_writer = SummaryWriter(
|
||||
log_dir=os.path.join(tensorboard_path, f"[test]{test_dataset_name}"))
|
||||
self.test_set_list.append(test_set)
|
||||
self.test_writer_list.append(test_writer)
|
||||
|
||||
self.print_info()
|
||||
|
||||
def run(self):
|
||||
save_interval = self.experiments_config["save_checkpoint_interval"]
|
||||
if self.current_epoch != 0:
|
||||
Log.info("Continue training from epoch {}.".format(self.current_epoch))
|
||||
else:
|
||||
Log.info("Start training from initial model.")
|
||||
if self.test_first:
|
||||
Log.info("Do test first.")
|
||||
self.test()
|
||||
while self.current_epoch < self.max_epochs:
|
||||
self.current_epoch += 1
|
||||
status_manager.set_progress("train", "default_trainer", "Epoch", self.current_epoch, self.max_epochs)
|
||||
self.train()
|
||||
self.test()
|
||||
if self.current_epoch % save_interval == 0:
|
||||
self.save_checkpoint()
|
||||
self.save_checkpoint(is_last=True)
|
||||
|
||||
def train(self):
|
||||
self.pipeline.train()
|
||||
train_set_name = self.train_dataset_name
|
||||
config = self.train_set.get_config()
|
||||
train_loader = self.train_set.get_loader(shuffle=True)
|
||||
|
||||
total=len(train_loader)
|
||||
loop = tqdm(enumerate(train_loader), total=total)
|
||||
|
||||
for i, data in loop:
|
||||
status_manager.set_progress("train", "default_trainer", f"(train) Batch[{train_set_name}]", i+1, total)
|
||||
self.train_set.process_batch(data, self.device)
|
||||
loss_dict = self.train_step(data)
|
||||
loop.set_description(
|
||||
f'Epoch [{self.current_epoch}/{self.max_epochs}] (Train: {train_set_name}, ratio={config["ratio"]})')
|
||||
loop.set_postfix(loss=loss_dict)
|
||||
for loss_name, loss in loss_dict.items():
|
||||
status_manager.set_status("train", "default_trainer", f"[loss]{loss_name}", loss)
|
||||
TensorboardWriter.write_tensorboard(self.train_writer, "iter", loss_dict, self.current_iter, simple_scalar=True)
|
||||
self.current_iter += 1
|
||||
|
||||
|
||||
|
||||
def train_step(self, data):
|
||||
self.optimizer.zero_grad()
|
||||
data["mode"] = namespace.Mode.TRAIN
|
||||
output = self.pipeline(data)
|
||||
total_loss, loss_dict = self.loss_fn(output, data)
|
||||
total_loss.backward()
|
||||
self.optimizer.step()
|
||||
for k, v in loss_dict.items():
|
||||
loss_dict[k] = round(v, 5)
|
||||
return loss_dict
|
||||
|
||||
def loss_fn(self, output, data):
|
||||
loss_name_list = self.train_config["losses"]
|
||||
loss_dict = {}
|
||||
total_loss = torch.tensor(0.0, dtype=torch.float32, device=self.device)
|
||||
for loss_name in loss_name_list:
|
||||
target_loss_fn = ComponentFactory.create(namespace.Stereotype.LOSS_FUNCTION, loss_name)
|
||||
loss = target_loss_fn.compute(output, data)
|
||||
loss_dict[loss_name] = loss.item()
|
||||
total_loss += loss
|
||||
|
||||
loss_dict['total_loss'] = total_loss.item()
|
||||
return total_loss, loss_dict
|
||||
|
||||
def test(self):
|
||||
self.pipeline.eval()
|
||||
with torch.no_grad():
|
||||
test_set: BaseDataset
|
||||
for dataset_idx, test_set in enumerate(self.test_set_list):
|
||||
test_set_config = test_set.get_config()
|
||||
eval_list = test_set_config["eval_list"]
|
||||
ratio = test_set_config["ratio"]
|
||||
test_set_name = test_set.get_name()
|
||||
writer = self.test_writer_list[dataset_idx]
|
||||
output_list = []
|
||||
data_list = []
|
||||
test_loader = test_set.get_loader()
|
||||
total=int(len(test_loader))
|
||||
loop = tqdm(enumerate(test_loader), total=total)
|
||||
for i, data in loop:
|
||||
status_manager.set_progress("train", "default_trainer", f"(test) Batch[{test_set_name}]", i+1, total)
|
||||
test_set.process_batch(data, self.device)
|
||||
data["mode"] = namespace.Mode.TEST
|
||||
output = self.pipeline(data)
|
||||
output_list.append(output)
|
||||
data_list.append(data)
|
||||
loop.set_description(
|
||||
f'Epoch [{self.current_epoch}/{self.max_epochs}] (Test: {test_set_name}, ratio={ratio})')
|
||||
result_dict = self.eval_fn(output_list, data_list, eval_list)
|
||||
TensorboardWriter.write_tensorboard(writer, "epoch", result_dict, self.current_epoch - 1)
|
||||
|
||||
@staticmethod
|
||||
def eval_fn(output_list, data_list, eval_list):
|
||||
collected_result = {}
|
||||
for eval_method_name in eval_list:
|
||||
eval_method = ComponentFactory.create(namespace.Stereotype.EVALUATION_METHOD, eval_method_name)
|
||||
eval_results:dict = eval_method.evaluate(output_list, data_list)
|
||||
for data_type, eval_result in eval_results.items():
|
||||
if data_type not in collected_result:
|
||||
collected_result[data_type] = {}
|
||||
for name, value in eval_result.items():
|
||||
collected_result[data_type][name] = value
|
||||
status_manager.set_status("train", "default_trainer", f"[eval]{name}", value)
|
||||
|
||||
return collected_result
|
||||
|
||||
def get_checkpoint_path(self, is_last=False):
|
||||
return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME,
|
||||
"Epoch_{}.pth".format(
|
||||
self.current_epoch if self.current_epoch != -1 and not is_last else "last"))
|
||||
|
||||
def load_checkpoint(self, is_last=False):
|
||||
self.load(self.get_checkpoint_path(is_last))
|
||||
Log.success(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}")
|
||||
if is_last:
|
||||
checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
|
||||
meta_path = os.path.join(checkpoint_root, "meta.json")
|
||||
if not os.path.exists(meta_path):
|
||||
raise FileNotFoundError(
|
||||
"No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
|
||||
file_path = os.path.join(checkpoint_root, "meta.json")
|
||||
with open(file_path, "r") as f:
|
||||
meta = json.load(f)
|
||||
self.current_epoch = meta["last_epoch"]
|
||||
self.current_iter = meta["last_iter"]
|
||||
|
||||
def save_checkpoint(self, is_last=False):
|
||||
self.save(self.get_checkpoint_path(is_last))
|
||||
if not is_last:
|
||||
Log.success(f"Checkpoint at epoch {self.current_epoch} saved to {self.get_checkpoint_path(is_last)}")
|
||||
else:
|
||||
meta = {
|
||||
"last_epoch": self.current_epoch,
|
||||
"last_iter": self.current_iter,
|
||||
"time": str(datetime.now())
|
||||
}
|
||||
checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
|
||||
file_path = os.path.join(checkpoint_root, "meta.json")
|
||||
with open(file_path, "w") as f:
|
||||
json.dump(meta, f)
|
||||
|
||||
def load_experiment(self, backup_name=None):
|
||||
super().load_experiment(backup_name)
|
||||
if self.experiments_config["use_checkpoint"]:
|
||||
self.current_epoch = self.experiments_config["epoch"]
|
||||
self.load_checkpoint(is_last=(self.current_epoch == -1))
|
||||
|
||||
def create_experiment(self, backup_name=None):
|
||||
super().create_experiment(backup_name)
|
||||
ckpt_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.CHECKPOINT_DIR_NAME)
|
||||
os.makedirs(ckpt_dir)
|
||||
tensorboard_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.TENSORBOARD_DIR_NAME)
|
||||
os.makedirs(tensorboard_dir)
|
||||
|
||||
def load(self, path):
|
||||
state_dict = torch.load(path)
|
||||
if self.parallel:
|
||||
self.pipeline.module.load_state_dict(state_dict)
|
||||
else:
|
||||
self.pipeline.load_state_dict(state_dict)
|
||||
|
||||
def save(self, path):
|
||||
if self.parallel:
|
||||
state_dict = self.pipeline.module.state_dict()
|
||||
else:
|
||||
state_dict = self.pipeline.state_dict()
|
||||
|
||||
for name, module in self.pipeline.named_modules():
|
||||
if module.__class__ in EXTERNAL_FRONZEN_MODULES:
|
||||
if name in state_dict:
|
||||
del state_dict[name]
|
||||
|
||||
torch.save(state_dict, path)
|
||||
|
||||
|
||||
def print_info(self):
|
||||
def print_dataset(dataset: BaseDataset):
|
||||
config = dataset.get_config()
|
||||
name = dataset.get_name()
|
||||
Log.blue(f"Dataset: {name}")
|
||||
for k,v in config.items():
|
||||
Log.blue(f"\t{k}: {v}")
|
||||
|
||||
super().print_info()
|
||||
table_size = 70
|
||||
Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
|
||||
Log.blue(self.pipeline)
|
||||
Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
|
||||
Log.blue("train dataset: ")
|
||||
print_dataset(self.train_set)
|
||||
for i, test_set in enumerate(self.test_set_list):
|
||||
Log.blue(f"test dataset {i}: ")
|
||||
print_dataset(test_set)
|
||||
|
||||
Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')
|
||||
|
17
PytorchBoot/runners/web_runner.py
Normal file
17
PytorchBoot/runners/web_runner.py
Normal file
@@ -0,0 +1,17 @@
|
||||
|
||||
from abc import abstractmethod, ABC
|
||||
|
||||
from PytorchBoot.config import ConfigManager
|
||||
from PytorchBoot.runners import Runner
|
||||
from PytorchBoot.utils import Log
|
||||
|
||||
class WebRunner(ABC, Runner):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config_path):
|
||||
ConfigManager.load_config_with(config_path)
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
pass
|
||||
|
56
PytorchBoot/status.py
Normal file
56
PytorchBoot/status.py
Normal file
@@ -0,0 +1,56 @@
|
||||
|
||||
class StatusManager:
|
||||
def __init__(self):
|
||||
self.running_app = {}
|
||||
self.last_status = {}
|
||||
self.curr_status = {}
|
||||
self.progress = {}
|
||||
self.log = []
|
||||
|
||||
def is_running(self):
|
||||
return len(self.running_app) > 0
|
||||
|
||||
def run_app(self, app_name, app):
|
||||
self.running_app[app_name] = app
|
||||
|
||||
def end_app(self, app_name):
|
||||
self.running_app.pop(app_name)
|
||||
|
||||
def set_status(self, app_name, runner_name, key, value):
|
||||
self.last_status = self.curr_status
|
||||
if app_name not in self.curr_status:
|
||||
self.curr_status[app_name] = {}
|
||||
if runner_name not in self.curr_status[app_name]:
|
||||
self.curr_status[app_name][runner_name] = {}
|
||||
self.curr_status[app_name][runner_name][key] = value
|
||||
|
||||
def set_progress(self, app_name, runner_name, key, curr_value, max_value):
|
||||
if app_name not in self.progress:
|
||||
self.progress[app_name] = {}
|
||||
if runner_name not in self.progress[app_name]:
|
||||
self.progress[app_name][runner_name] = {}
|
||||
self.progress[app_name][runner_name][key] = (curr_value, max_value)
|
||||
|
||||
def get_status(self):
|
||||
return self.curr_status
|
||||
|
||||
def get_progress(self):
|
||||
return self.progress
|
||||
|
||||
def add_log(self, time_str, log_type, message):
|
||||
self.log.append((time_str, log_type, message))
|
||||
|
||||
def get_log(self):
|
||||
return self.log
|
||||
|
||||
def get_running_apps(self):
|
||||
return list(self.running_app.keys())
|
||||
|
||||
def get_last_status(self):
|
||||
return self.last_status
|
||||
|
||||
def reset_status(self):
|
||||
self.last_status = {}
|
||||
self.curr_status = {}
|
||||
|
||||
status_manager = StatusManager()
|
149
PytorchBoot/stereotype.py
Normal file
149
PytorchBoot/stereotype.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import inspect
|
||||
|
||||
from PytorchBoot.component import Component
|
||||
from PytorchBoot.utils.log_util import Log
|
||||
import PytorchBoot.namespace as namespace
|
||||
|
||||
|
||||
def ensure_component_subclass(cls, type_name, name):
|
||||
if not issubclass(cls, Component):
|
||||
new_cls = type(cls.__name__, (Component, cls), {
|
||||
**cls.__dict__,
|
||||
"TYPE": type_name,
|
||||
"NAME": name
|
||||
})
|
||||
new_cls.__original_class__ = cls
|
||||
else:
|
||||
new_cls = cls
|
||||
for method_name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
|
||||
if getattr(method, "__isabstractmethod__", False):
|
||||
Log.error(f"Component <{name}> contains abstract method <{method_name}>.", True)
|
||||
|
||||
return cls
|
||||
|
||||
|
||||
# --- Classes --- #
|
||||
dataset_classes = {}
|
||||
dataset_comments = {}
|
||||
def dataset(dataset_name, comment=None):
|
||||
def decorator(cls):
|
||||
if not hasattr(cls, "get_loader") or not callable(getattr(cls, "get_loader")):
|
||||
Log.error(f"dataset <{cls.__name__}> must implement a 'get_loader' method", True)
|
||||
cls = ensure_component_subclass(cls, namespace.Stereotype.DATASET, dataset_name)
|
||||
dataset_comments[dataset_name] = comment
|
||||
dataset_classes[dataset_name] = cls
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
module_classes = {}
|
||||
module_comments = {}
|
||||
def module(module_name, comment=None):
|
||||
def decorator(cls):
|
||||
cls = ensure_component_subclass(cls, namespace.Stereotype.MODULE, module_name)
|
||||
module_comments[module_name] = comment
|
||||
module_classes[module_name] = cls
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
pipeline_classes = {}
|
||||
pipline_comments = {}
|
||||
def pipeline(pipeline_name, comment=None):
|
||||
def decorator(cls):
|
||||
if not hasattr(cls, 'forward') or not callable(getattr(cls, 'forward')):
|
||||
Log.error(f"pipeline <{cls.__name__}> must implement a 'forward' method", True)
|
||||
cls = ensure_component_subclass(cls, namespace.Stereotype.PIPELINE, pipeline_name)
|
||||
pipeline_classes[pipeline_name] = cls
|
||||
pipline_comments[pipeline_name] = comment
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
runner_classes = {}
|
||||
runner_comments = {}
|
||||
def runner(runner_name, comment=None):
|
||||
def decorator(cls):
|
||||
if not hasattr(cls, 'run') or not callable(getattr(cls, 'run')):
|
||||
Log.error(f"runner <{cls.__name__}> must implement a 'run' method", True)
|
||||
cls = ensure_component_subclass(cls, namespace.Stereotype.RUNNER, runner_name)
|
||||
runner_classes[runner_name] = cls
|
||||
runner_comments[runner_name] = comment
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
factory_classes = {}
|
||||
factory_comments = {}
|
||||
def factory(factory_name, comment=None):
|
||||
def decorator(cls):
|
||||
if not hasattr(cls, 'create') or not callable(getattr(cls, 'create')):
|
||||
Log.error(f"factory <{cls.__name__}> must implement a 'create' method", True)
|
||||
|
||||
cls = ensure_component_subclass(cls, namespace.Stereotype.FACTORY, factory_name)
|
||||
factory_classes[factory_name] = cls
|
||||
factory_comments[factory_name] = comment
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
loss_classes = {}
|
||||
loss_comments = {}
|
||||
def loss_function(loss_name, comment=None):
|
||||
def decorator(cls):
|
||||
|
||||
if not hasattr(cls, 'compute') or not callable(getattr(cls, 'compute')):
|
||||
Log.error(f"loss function <{cls.__name__}> must implement a 'compute' method", True)
|
||||
cls = ensure_component_subclass(cls, namespace.Stereotype.LOSS_FUNCTION, loss_name)
|
||||
loss_classes[loss_name] = cls
|
||||
loss_comments[loss_name] = comment
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
evaluation_classes = {}
|
||||
evaluation_comments = {}
|
||||
def evaluation_method(evaluation_name, comment=None):
|
||||
def decorator(cls):
|
||||
|
||||
if not hasattr(cls, 'evaluate') or not callable(getattr(cls, 'evaluate')):
|
||||
Log.error(f"evaluation method <{cls.__name__}> must implement a 'evaluate' method", True)
|
||||
cls = ensure_component_subclass(cls, namespace.Stereotype.EVALUATION_METHOD, evaluation_name)
|
||||
evaluation_classes[evaluation_name] = cls
|
||||
evaluation_comments[evaluation_name] = comment
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
# --- Others --- #
|
||||
EXTERNAL_FRONZEN_MODULES = set()
|
||||
|
||||
def external_frozen_module(cls):
|
||||
if not hasattr(cls, 'load') or not callable(getattr(cls, 'load')):
|
||||
Log.error(f"external module <{cls.__name__}> must implement a 'load' method", True)
|
||||
EXTERNAL_FRONZEN_MODULES.add(cls)
|
||||
return cls
|
||||
|
||||
# --- Utils --- #
|
||||
|
||||
all_component_classes = {
|
||||
namespace.Stereotype.DATASET: dataset_classes,
|
||||
namespace.Stereotype.MODULE: module_classes,
|
||||
namespace.Stereotype.PIPELINE: pipeline_classes,
|
||||
namespace.Stereotype.RUNNER: runner_classes,
|
||||
namespace.Stereotype.LOSS_FUNCTION: loss_classes,
|
||||
namespace.Stereotype.EVALUATION_METHOD: evaluation_classes,
|
||||
namespace.Stereotype.FACTORY: factory_classes
|
||||
}
|
||||
|
||||
all_component_comments = {
|
||||
namespace.Stereotype.DATASET: dataset_comments,
|
||||
namespace.Stereotype.MODULE: module_comments,
|
||||
namespace.Stereotype.PIPELINE: pipline_comments,
|
||||
namespace.Stereotype.RUNNER: runner_comments,
|
||||
namespace.Stereotype.LOSS_FUNCTION: loss_comments,
|
||||
namespace.Stereotype.EVALUATION_METHOD: evaluation_comments,
|
||||
namespace.Stereotype.FACTORY: factory_comments
|
||||
}
|
||||
|
||||
def get_all_component_classes():
|
||||
return all_component_classes
|
||||
|
||||
def get_all_component_comments():
|
||||
return all_component_comments
|
||||
|
||||
def get_component_classes(component_type):
|
||||
return all_component_classes.get(component_type, None)
|
BIN
PytorchBoot/templates/__pycache__/application.cpython-39.pyc
Normal file
BIN
PytorchBoot/templates/__pycache__/application.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/templates/__pycache__/config.cpython-39.pyc
Normal file
BIN
PytorchBoot/templates/__pycache__/config.cpython-39.pyc
Normal file
Binary file not shown.
15
PytorchBoot/templates/application.py
Normal file
15
PytorchBoot/templates/application.py
Normal file
@@ -0,0 +1,15 @@
|
||||
template = """from PytorchBoot.application import PytorchBootApplication
|
||||
|
||||
@PytorchBootApplication
|
||||
class Application:
|
||||
@staticmethod
|
||||
def start():
|
||||
'''
|
||||
call default or your custom runners here, code will be executed
|
||||
automatically when type "pytorch-boot run" or "ptb run" in terminal
|
||||
|
||||
example:
|
||||
Trainer("path_to_your_train_config").run()
|
||||
Evaluator("path_to_your_eval_config").run()
|
||||
'''
|
||||
"""
|
58
PytorchBoot/templates/config.py
Normal file
58
PytorchBoot/templates/config.py
Normal file
@@ -0,0 +1,58 @@
|
||||
template = """
|
||||
runners:
|
||||
general:
|
||||
seed: 0
|
||||
device: cuda
|
||||
cuda_visible_devices: "0,1,2,3,4,5,6,7"
|
||||
parallel: False
|
||||
|
||||
experiment:
|
||||
name: experiment_name
|
||||
root_dir: "experiments"
|
||||
use_checkpoint: False
|
||||
epoch: -1 # -1 stands for last epoch
|
||||
max_epochs: 5000
|
||||
save_checkpoint_interval: 1
|
||||
test_first: True
|
||||
|
||||
train:
|
||||
optimizer:
|
||||
type: adam
|
||||
lr: 0.0001
|
||||
losses: # loss type : weight
|
||||
loss_type_0: 1.0
|
||||
dataset:
|
||||
name: train_set_name
|
||||
source: train_set_source_name
|
||||
ratio: 1.0
|
||||
batch_size: 1
|
||||
num_workers: 1
|
||||
|
||||
test:
|
||||
frequency: 3 # test frequency
|
||||
dataset_list:
|
||||
- name: test_set_name_0
|
||||
source: train_set_source_name
|
||||
eval_list:
|
||||
- eval_func_name_0
|
||||
- eval_func_name_1
|
||||
ratio: 1.0
|
||||
batch_size: 1
|
||||
num_workers: 1
|
||||
|
||||
pipeline: pipeline_name
|
||||
|
||||
pipelines:
|
||||
pipeline_name_0:
|
||||
- module_name_0
|
||||
- module_name_1
|
||||
|
||||
datasets:
|
||||
dataset_source_name_0:
|
||||
dataset_source_name_1:
|
||||
|
||||
modules:
|
||||
module_name_0:
|
||||
module_name_1:
|
||||
|
||||
"""
|
1
PytorchBoot/ui/client/index.html
Normal file
1
PytorchBoot/ui/client/index.html
Normal file
@@ -0,0 +1 @@
|
||||
<!DOCTYPE html><html><head><meta charset=utf-8><meta name=viewport content="width=device-width,initial-scale=1"><title>PyTorchBoot Project</title><link href=/static/css/app.5383ee564f9a1a656786665504aa6b98.css rel=stylesheet></head><body><div id=app></div><script type=text/javascript src=/static/js/manifest.2ae2e69a05c33dfc65f8.js></script><script type=text/javascript src=/static/js/vendor.9f7b4785a30f0533ee08.js></script><script type=text/javascript src=/static/js/app.230235873e25a72eeacb.js></script></body></html>
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
BIN
PytorchBoot/ui/client/static/fonts/ionicons.143146f.woff2
Normal file
BIN
PytorchBoot/ui/client/static/fonts/ionicons.143146f.woff2
Normal file
Binary file not shown.
BIN
PytorchBoot/ui/client/static/fonts/ionicons.99ac330.woff
Normal file
BIN
PytorchBoot/ui/client/static/fonts/ionicons.99ac330.woff
Normal file
Binary file not shown.
BIN
PytorchBoot/ui/client/static/fonts/ionicons.d535a25.ttf
Normal file
BIN
PytorchBoot/ui/client/static/fonts/ionicons.d535a25.ttf
Normal file
Binary file not shown.
870
PytorchBoot/ui/client/static/img/ionicons.a2c4a26.svg
Normal file
870
PytorchBoot/ui/client/static/img/ionicons.a2c4a26.svg
Normal file
File diff suppressed because one or more lines are too long
After Width: | Height: | Size: 542 KiB |
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -0,0 +1,2 @@
|
||||
!function(r){var n=window.webpackJsonp;window.webpackJsonp=function(e,u,c){for(var f,i,p,a=0,l=[];a<e.length;a++)i=e[a],o[i]&&l.push(o[i][0]),o[i]=0;for(f in u)Object.prototype.hasOwnProperty.call(u,f)&&(r[f]=u[f]);for(n&&n(e,u,c);l.length;)l.shift()();if(c)for(a=0;a<c.length;a++)p=t(t.s=c[a]);return p};var e={},o={2:0};function t(n){if(e[n])return e[n].exports;var o=e[n]={i:n,l:!1,exports:{}};return r[n].call(o.exports,o,o.exports,t),o.l=!0,o.exports}t.m=r,t.c=e,t.d=function(r,n,e){t.o(r,n)||Object.defineProperty(r,n,{configurable:!1,enumerable:!0,get:e})},t.n=function(r){var n=r&&r.__esModule?function(){return r.default}:function(){return r};return t.d(n,"a",n),n},t.o=function(r,n){return Object.prototype.hasOwnProperty.call(r,n)},t.p="/",t.oe=function(r){throw console.error(r),r}}([]);
|
||||
//# sourceMappingURL=manifest.2ae2e69a05c33dfc65f8.js.map
|
@@ -0,0 +1 @@
|
||||
{"version":3,"sources":["webpack:///webpack/bootstrap def2f39c04517bb0de2d"],"names":["parentJsonpFunction","window","chunkIds","moreModules","executeModules","moduleId","chunkId","result","i","resolves","length","installedChunks","push","Object","prototype","hasOwnProperty","call","modules","shift","__webpack_require__","s","installedModules","2","exports","module","l","m","c","d","name","getter","o","defineProperty","configurable","enumerable","get","n","__esModule","object","property","p","oe","err","console","error"],"mappings":"aACA,IAAAA,EAAAC,OAAA,aACAA,OAAA,sBAAAC,EAAAC,EAAAC,GAIA,IADA,IAAAC,EAAAC,EAAAC,EAAAC,EAAA,EAAAC,KACQD,EAAAN,EAAAQ,OAAoBF,IAC5BF,EAAAJ,EAAAM,GACAG,EAAAL,IACAG,EAAAG,KAAAD,EAAAL,GAAA,IAEAK,EAAAL,GAAA,EAEA,IAAAD,KAAAF,EACAU,OAAAC,UAAAC,eAAAC,KAAAb,EAAAE,KACAY,EAAAZ,GAAAF,EAAAE,IAIA,IADAL,KAAAE,EAAAC,EAAAC,GACAK,EAAAC,QACAD,EAAAS,OAAAT,GAEA,GAAAL,EACA,IAAAI,EAAA,EAAYA,EAAAJ,EAAAM,OAA2BF,IACvCD,EAAAY,IAAAC,EAAAhB,EAAAI,IAGA,OAAAD,GAIA,IAAAc,KAGAV,GACAW,EAAA,GAIA,SAAAH,EAAAd,GAGA,GAAAgB,EAAAhB,GACA,OAAAgB,EAAAhB,GAAAkB,QAGA,IAAAC,EAAAH,EAAAhB,IACAG,EAAAH,EACAoB,GAAA,EACAF,YAUA,OANAN,EAAAZ,GAAAW,KAAAQ,EAAAD,QAAAC,IAAAD,QAAAJ,GAGAK,EAAAC,GAAA,EAGAD,EAAAD,QAKAJ,EAAAO,EAAAT,EAGAE,EAAAQ,EAAAN,EAGAF,EAAAS,EAAA,SAAAL,EAAAM,EAAAC,GACAX,EAAAY,EAAAR,EAAAM,IACAhB,OAAAmB,eAAAT,EAAAM,GACAI,cAAA,EACAC,YAAA,EACAC,IAAAL,KAMAX,EAAAiB,EAAA,SAAAZ,GACA,IAAAM,EAAAN,KAAAa,WACA,WAA2B,OAAAb,EAAA,SAC3B,WAAiC,OAAAA,GAEjC,OADAL,EAAAS,EAAAE,EAAA,IAAAA,GACAA,GAIAX,EAAAY,EAAA,SAAAO,EAAAC,GAAsD,OAAA1B,OAAAC,UAAAC,eAAAC,KAAAsB,EAAAC,IAGtDpB,EAAAqB,EAAA,IAGArB,EAAAsB,GAAA,SAAAC,GAA8D,MAApBC,QAAAC,MAAAF,GAAoBA","file":"static/js/manifest.2ae2e69a05c33dfc65f8.js","sourcesContent":[" \t// install a JSONP callback for chunk loading\n \tvar parentJsonpFunction = window[\"webpackJsonp\"];\n \twindow[\"webpackJsonp\"] = function webpackJsonpCallback(chunkIds, moreModules, executeModules) {\n \t\t// add \"moreModules\" to the modules object,\n \t\t// then flag all \"chunkIds\" as loaded and fire callback\n \t\tvar moduleId, chunkId, i = 0, resolves = [], result;\n \t\tfor(;i < chunkIds.length; i++) {\n \t\t\tchunkId = chunkIds[i];\n \t\t\tif(installedChunks[chunkId]) {\n \t\t\t\tresolves.push(installedChunks[chunkId][0]);\n \t\t\t}\n \t\t\tinstalledChunks[chunkId] = 0;\n \t\t}\n \t\tfor(moduleId in moreModules) {\n \t\t\tif(Object.prototype.hasOwnProperty.call(moreModules, moduleId)) {\n \t\t\t\tmodules[moduleId] = moreModules[moduleId];\n \t\t\t}\n \t\t}\n \t\tif(parentJsonpFunction) parentJsonpFunction(chunkIds, moreModules, executeModules);\n \t\twhile(resolves.length) {\n \t\t\tresolves.shift()();\n \t\t}\n \t\tif(executeModules) {\n \t\t\tfor(i=0; i < executeModules.length; i++) {\n \t\t\t\tresult = __webpack_require__(__webpack_require__.s = executeModules[i]);\n \t\t\t}\n \t\t}\n \t\treturn result;\n \t};\n\n \t// The module cache\n \tvar installedModules = {};\n\n \t// objects to store loaded and loading chunks\n \tvar installedChunks = {\n \t\t2: 0\n \t};\n\n \t// The require function\n \tfunction __webpack_require__(moduleId) {\n\n \t\t// Check if module is in cache\n \t\tif(installedModules[moduleId]) {\n \t\t\treturn installedModules[moduleId].exports;\n \t\t}\n \t\t// Create a new module (and put it into the cache)\n \t\tvar module = installedModules[moduleId] = {\n \t\t\ti: moduleId,\n \t\t\tl: false,\n \t\t\texports: {}\n \t\t};\n\n \t\t// Execute the module function\n \t\tmodules[moduleId].call(module.exports, module, module.exports, __webpack_require__);\n\n \t\t// Flag the module as loaded\n \t\tmodule.l = true;\n\n \t\t// Return the exports of the module\n \t\treturn module.exports;\n \t}\n\n\n \t// expose the modules object (__webpack_modules__)\n \t__webpack_require__.m = modules;\n\n \t// expose the module cache\n \t__webpack_require__.c = installedModules;\n\n \t// define getter function for harmony exports\n \t__webpack_require__.d = function(exports, name, getter) {\n \t\tif(!__webpack_require__.o(exports, name)) {\n \t\t\tObject.defineProperty(exports, name, {\n \t\t\t\tconfigurable: false,\n \t\t\t\tenumerable: true,\n \t\t\t\tget: getter\n \t\t\t});\n \t\t}\n \t};\n\n \t// getDefaultExport function for compatibility with non-harmony modules\n \t__webpack_require__.n = function(module) {\n \t\tvar getter = module && module.__esModule ?\n \t\t\tfunction getDefault() { return module['default']; } :\n \t\t\tfunction getModuleExports() { return module; };\n \t\t__webpack_require__.d(getter, 'a', getter);\n \t\treturn getter;\n \t};\n\n \t// Object.prototype.hasOwnProperty.call\n \t__webpack_require__.o = function(object, property) { return Object.prototype.hasOwnProperty.call(object, property); };\n\n \t// __webpack_public_path__\n \t__webpack_require__.p = \"/\";\n\n \t// on error function for async loading\n \t__webpack_require__.oe = function(err) { console.error(err); throw err; };\n\n\n\n// WEBPACK FOOTER //\n// webpack/bootstrap def2f39c04517bb0de2d"],"sourceRoot":""}
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
233
PytorchBoot/ui/server/app.py
Normal file
233
PytorchBoot/ui/server/app.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import os
|
||||
import threading
|
||||
import socket
|
||||
import logging
|
||||
import psutil
|
||||
import GPUtil
|
||||
import platform
|
||||
|
||||
from flask import Flask, jsonify, request, send_from_directory
|
||||
from flask_cors import CORS
|
||||
from tensorboard import program
|
||||
from PytorchBoot.utils.project_util import ProjectUtil
|
||||
from PytorchBoot.stereotype import get_all_component_classes, get_all_component_comments
|
||||
from PytorchBoot.application import application_class
|
||||
from PytorchBoot.status import status_manager
|
||||
from PytorchBoot.utils.log_util import Log
|
||||
from PytorchBoot.utils.timer_util import Timer
|
||||
|
||||
app = Flask(__name__, static_folder="../client")
|
||||
app.logger.setLevel("WARNING")
|
||||
logging.getLogger("werkzeug").disabled = True
|
||||
CORS(app)
|
||||
root_path = os.getcwd()
|
||||
ProjectUtil.scan_project(root_path)
|
||||
configs = ProjectUtil.scan_configs(root_path)
|
||||
running_tensorboard = {}
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def serve_index():
|
||||
return send_from_directory(app.static_folder, "index.html")
|
||||
|
||||
|
||||
@app.route("/<path:path>")
|
||||
def serve_file(path):
|
||||
return send_from_directory(app.static_folder, path)
|
||||
|
||||
|
||||
@app.route("/test", methods=["POST"])
|
||||
def hello_world():
|
||||
return jsonify(message="Hello, World!")
|
||||
|
||||
|
||||
@app.route("/project/structure", methods=["POST"])
|
||||
def project_structure():
|
||||
component_info = {}
|
||||
for st, cls_dict in get_all_component_classes().items():
|
||||
component_info[st] = {k: v.__name__ for k, v in cls_dict.items()}
|
||||
comment_info = get_all_component_comments()
|
||||
app_info = {}
|
||||
for app_name, app_cls in application_class.items():
|
||||
app_info[app_name] = app_cls.__name__
|
||||
|
||||
return jsonify(
|
||||
components=component_info,
|
||||
comments=comment_info,
|
||||
applications=app_info,
|
||||
configs=configs,
|
||||
root_path=root_path,
|
||||
)
|
||||
|
||||
|
||||
@app.route("/project/run_app", methods=["POST"])
|
||||
def run_application():
|
||||
data = request.json
|
||||
app_name = data.get("app_name")
|
||||
app_cls = application_class.get(app_name)
|
||||
|
||||
if app_cls is None:
|
||||
Log.error(
|
||||
f"No class annotated with @PytorchBootApplication found with the name '{app_name}'.",
|
||||
True,
|
||||
)
|
||||
return jsonify(
|
||||
{
|
||||
"message": f"No application found with the name '{app_name}'",
|
||||
"status": "error",
|
||||
}
|
||||
)
|
||||
|
||||
if not hasattr(app_cls, "start"):
|
||||
Log.error(
|
||||
"The class annotated with @PytorchBootApplication should have a 'start' method.",
|
||||
True,
|
||||
)
|
||||
return jsonify(
|
||||
{"message": "The class should have a 'start' method", "status": "error"}
|
||||
)
|
||||
|
||||
def run_in_background():
|
||||
Log.info(f"Application '{app_cls.__name__}' started.")
|
||||
timer = Timer("Application")
|
||||
timer.start()
|
||||
status_manager.run_app(app_name, app_cls)
|
||||
app_cls.start()
|
||||
status_manager.end_app(app_name)
|
||||
timer.stop()
|
||||
Log.info(timer.get_elasped_time_str(Timer.HOURS))
|
||||
Log.success("Application finished.")
|
||||
|
||||
threading.Thread(target=run_in_background).start()
|
||||
|
||||
return jsonify(
|
||||
{"message": f"Application '{app_name}' is running now.", "status": "success"}
|
||||
)
|
||||
|
||||
|
||||
@app.route("/project/get_status", methods=["POST"])
|
||||
def get_status():
|
||||
cpu_info = {
|
||||
"model": platform.processor(),
|
||||
"usage_percent": psutil.cpu_percent(interval=1),
|
||||
}
|
||||
virtual_memory = psutil.virtual_memory()
|
||||
memory_info = {
|
||||
"used": round(virtual_memory.used / (1024**3), 3),
|
||||
"total": round(virtual_memory.total / (1024**3), 3),
|
||||
}
|
||||
|
||||
gpus = GPUtil.getGPUs()
|
||||
gpu_info = []
|
||||
for gpu in gpus:
|
||||
gpu_info.append(
|
||||
{
|
||||
"name": gpu.name,
|
||||
"memory_used": gpu.memoryUsed,
|
||||
"memory_total": gpu.memoryTotal,
|
||||
}
|
||||
)
|
||||
|
||||
return jsonify(
|
||||
curr_status=status_manager.get_status(),
|
||||
last_status=status_manager.get_last_status(),
|
||||
logs=status_manager.get_log(),
|
||||
progress=status_manager.get_progress(),
|
||||
running_apps=status_manager.get_running_apps(),
|
||||
cpu=cpu_info,
|
||||
memory=memory_info,
|
||||
gpus=gpu_info,
|
||||
)
|
||||
|
||||
|
||||
@app.route("/project/set_status", methods=["POST"])
|
||||
def set_status():
|
||||
status = request.json.get("status")
|
||||
progress = request.json.get("progress")
|
||||
if status:
|
||||
status_manager.set_status(
|
||||
app_name=status["app_name"],
|
||||
runner_name=status["runner_name"],
|
||||
key=status["key"],
|
||||
value=status["value"],
|
||||
)
|
||||
if progress:
|
||||
status_manager.set_progress(
|
||||
app_name=progress["app_name"],
|
||||
runner_name=progress["runner_name"],
|
||||
key=progress["key"],
|
||||
curr_value=progress["curr_value"],
|
||||
max_value=progress["max_value"],
|
||||
)
|
||||
return jsonify({"status": "success"})
|
||||
|
||||
@app.route("/project/add_log", methods=["POST"])
|
||||
def add_log():
|
||||
log = request.json.get("log")
|
||||
Log.log(log["message"], log["log_type"])
|
||||
return jsonify({"status": "success"})
|
||||
|
||||
def find_free_port(start_port):
|
||||
"""Find a free port starting from start_port."""
|
||||
port = start_port
|
||||
while True:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
result = sock.connect_ex(("localhost", port))
|
||||
if result != 0:
|
||||
return port
|
||||
port += 1
|
||||
|
||||
|
||||
def start_tensorboard(log_dir, port):
|
||||
"""Starts TensorBoard in a separate thread."""
|
||||
tb = program.TensorBoard()
|
||||
tb.configure(argv=[None, "--logdir", log_dir, "--port", str(port)])
|
||||
tb.launch()
|
||||
|
||||
|
||||
@app.route("/tensorboard/run", methods=["POST"])
|
||||
def run_tensorboard():
|
||||
data = request.json
|
||||
log_dir = data.get("log_dir")
|
||||
if log_dir in running_tensorboard:
|
||||
return jsonify(
|
||||
{
|
||||
"message": f"TensorBoard ({running_tensorboard[log_dir]}) is already running for <{log_dir}>",
|
||||
"url": running_tensorboard[log_dir],
|
||||
"status": "warning",
|
||||
}
|
||||
)
|
||||
|
||||
if not os.path.isdir(log_dir):
|
||||
return jsonify({"message": "Log directory does not exist", "status": "error"})
|
||||
|
||||
port = find_free_port(10000)
|
||||
|
||||
try:
|
||||
tb_thread = threading.Thread(target=start_tensorboard, args=(log_dir, port))
|
||||
tb_thread.start()
|
||||
except Exception as e:
|
||||
return jsonify(
|
||||
{"message": f"Error starting TensorBoard: {str(e)}", "status": "error"}
|
||||
)
|
||||
|
||||
url = f"http://localhost:{port}"
|
||||
running_tensorboard[log_dir] = url
|
||||
return jsonify(
|
||||
{"url": url, "message": f"TensorBoard is running at {url}", "status": "success"}
|
||||
)
|
||||
|
||||
|
||||
@app.route("/tensorboard/dirs", methods=["POST"])
|
||||
def get_tensorboard_dirs():
|
||||
tensorboard_dirs = []
|
||||
for root, dirs, _ in os.walk(root_path):
|
||||
for dir_name in dirs:
|
||||
if dir_name == "tensorboard":
|
||||
tensorboard_dirs.append(os.path.join(root, dir_name))
|
||||
return jsonify({"tensorboard_dirs": tensorboard_dirs})
|
||||
|
||||
|
||||
@app.route("/tensorboard/running_tensorboards", methods=["POST"])
|
||||
def get_running_tensorboards():
|
||||
return jsonify(running_tensorboards=running_tensorboard)
|
3
PytorchBoot/utils/__init__.py
Normal file
3
PytorchBoot/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from PytorchBoot.utils.log_util import Log
|
||||
from PytorchBoot.utils.tensorboard_util import TensorboardWriter
|
||||
from PytorchBoot.utils.timer_util import Timer
|
BIN
PytorchBoot/utils/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
PytorchBoot/utils/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/utils/__pycache__/log_util.cpython-39.pyc
Normal file
BIN
PytorchBoot/utils/__pycache__/log_util.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/utils/__pycache__/tensorboard_util.cpython-39.pyc
Normal file
BIN
PytorchBoot/utils/__pycache__/tensorboard_util.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/utils/__pycache__/timer_util.cpython-39.pyc
Normal file
BIN
PytorchBoot/utils/__pycache__/timer_util.cpython-39.pyc
Normal file
Binary file not shown.
81
PytorchBoot/utils/log_util.py
Normal file
81
PytorchBoot/utils/log_util.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import time
|
||||
import PytorchBoot.namespace as namespace
|
||||
from PytorchBoot.status import status_manager
|
||||
|
||||
class Log:
|
||||
MAX_TITLE_LENGTH:int = 7
|
||||
TYPE_COLOR_MAP = {
|
||||
namespace.LogType.INFO: "\033[94m",
|
||||
namespace.LogType.ERROR: "\033[91m",
|
||||
namespace.LogType.WARNING: "\033[93m",
|
||||
namespace.LogType.SUCCESS: "\033[92m",
|
||||
namespace.LogType.DEBUG: "\033[95m",
|
||||
namespace.LogType.TERMINATE: "\033[96m"
|
||||
}
|
||||
def get_time():
|
||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
def blue(message):
|
||||
# blue
|
||||
print(f"\033[94m{message}\033[0m")
|
||||
def red(message):
|
||||
# red
|
||||
print(f"\033[91m{message}\033[0m")
|
||||
def yellow(message):
|
||||
# yellow
|
||||
print(f"\033[93m{message}\033[0m")
|
||||
def green(message):
|
||||
# green
|
||||
print(f"\033[92m{message}\033[0m")
|
||||
|
||||
def log(message, log_type: str):
|
||||
time_str = Log.get_time()
|
||||
space = ""
|
||||
if len(log_type) < Log.MAX_TITLE_LENGTH:
|
||||
space = " " * (Log.MAX_TITLE_LENGTH - len(log_type))
|
||||
|
||||
print (f"\033[1m\033[4m({time_str})\033[0m \033[1m{Log.TYPE_COLOR_MAP[log_type]}[{log_type.capitalize()}]\033[0m{space} {Log.TYPE_COLOR_MAP[log_type]}{message}\033[0m")
|
||||
status_manager.add_log(time_str, log_type, message)
|
||||
|
||||
def bold(message):
|
||||
print(f"\033[1m{message}\033[0m")
|
||||
def underline(message):
|
||||
print(f"\033[4m{message}\033[0m")
|
||||
|
||||
def info(message):
|
||||
Log.log(message, namespace.LogType.INFO)
|
||||
|
||||
def error(message, terminate=False):
|
||||
Log.log(message, namespace.LogType.ERROR)
|
||||
if terminate:
|
||||
Log.terminate("Application Terminated.")
|
||||
|
||||
|
||||
def warning(message):
|
||||
Log.log(message, namespace.LogType.WARNING)
|
||||
def success(message):
|
||||
Log.log(message, namespace.LogType.SUCCESS)
|
||||
|
||||
def debug(message):
|
||||
Log.log(message, namespace.LogType.DEBUG)
|
||||
|
||||
def terminate(message):
|
||||
Log.log(message, namespace.LogType.TERMINATE)
|
||||
exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
Log.info("This is a info message")
|
||||
Log.error("This is a error message")
|
||||
Log.warning("This is a warning message")
|
||||
Log.success("This is a success message")
|
||||
Log.debug("This is a debug message")
|
||||
Log.blue("This is a blue message")
|
||||
Log.red("This is a red message")
|
||||
Log.yellow("This is a yellow message")
|
||||
Log.green("This is a green message")
|
||||
|
||||
Log.bold("This is a bold message")
|
||||
Log.underline("This is a underline message")
|
||||
Log.error("This is a terminate message", True)
|
||||
|
||||
|
50
PytorchBoot/utils/project_util.py
Normal file
50
PytorchBoot/utils/project_util.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
import sys
|
||||
import yaml
|
||||
import importlib
|
||||
|
||||
class ProjectUtil:
|
||||
@staticmethod
|
||||
def scan_project(root_path):
|
||||
sys.path.append(root_path)
|
||||
if not os.path.exists(root_path) or not os.path.isdir(root_path):
|
||||
raise ValueError(f"The provided root_path '{root_path}' is not a valid directory.")
|
||||
|
||||
parent_dir = os.path.dirname(root_path)
|
||||
sys.path.insert(0, parent_dir)
|
||||
|
||||
def import_all_modules(path, package_name):
|
||||
for root, dirs, files in os.walk(path):
|
||||
relative_path = os.path.relpath(root, root_path)
|
||||
if relative_path == '.':
|
||||
module_package = package_name
|
||||
else:
|
||||
module_package = f"{package_name}.{relative_path.replace(os.sep, '.')}"
|
||||
for file in files:
|
||||
if file.endswith(".py") and file != "__init__.py":
|
||||
module_name = file[:-3]
|
||||
full_module_name = f"{module_package}.{module_name}"
|
||||
if full_module_name not in sys.modules:
|
||||
importlib.import_module(full_module_name)
|
||||
dirs[:] = [d for d in dirs if not d.startswith('.')]
|
||||
|
||||
package_name = os.path.basename(root_path)
|
||||
import_all_modules(root_path, package_name)
|
||||
|
||||
@staticmethod
|
||||
def scan_configs(root_path):
|
||||
configs = {}
|
||||
for root, dirs, files in os.walk(root_path):
|
||||
for file in files:
|
||||
if file.endswith(('.yaml', '.yml')):
|
||||
if file.startswith('__'):
|
||||
continue
|
||||
file_path = os.path.join(root, file)
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
try:
|
||||
content = yaml.safe_load(f)
|
||||
configs[os.path.splitext(file)[0]] = content
|
||||
except yaml.YAMLError as e:
|
||||
print(f"Error reading {file_path}: {e}")
|
||||
|
||||
return configs
|
44
PytorchBoot/utils/tensorboard_util.py
Normal file
44
PytorchBoot/utils/tensorboard_util.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
import PytorchBoot.namespace as namespace
|
||||
|
||||
class TensorboardWriter:
|
||||
@staticmethod
|
||||
def write_tensorboard(writer, panel, data_dict, step, simple_scalar = False):
|
||||
|
||||
if simple_scalar:
|
||||
TensorboardWriter.write_scalar_tensorboard(writer, panel, data_dict, step)
|
||||
|
||||
if namespace.TensorBoard.SCALAR in data_dict:
|
||||
scalar_data_dict = data_dict[namespace.TensorBoard.SCALAR]
|
||||
TensorboardWriter.write_scalar_tensorboard(writer, panel, scalar_data_dict, step)
|
||||
if namespace.TensorBoard.IMAGE in data_dict:
|
||||
image_data_dict = data_dict[namespace.TensorBoard.IMAGE]
|
||||
TensorboardWriter.write_image_tensorboard(writer, panel, image_data_dict, step)
|
||||
if namespace.TensorBoard.POINT in data_dict:
|
||||
point_data_dict = data_dict[namespace.TensorBoard.POINT]
|
||||
TensorboardWriter.write_points_tensorboard(writer, panel, point_data_dict, step)
|
||||
|
||||
@staticmethod
|
||||
def write_scalar_tensorboard(writer, panel, data_dict, step):
|
||||
for key, value in data_dict.items():
|
||||
if isinstance(value, dict):
|
||||
writer.add_scalars(f'{panel}/{key}', value, step)
|
||||
else:
|
||||
writer.add_scalar(f'{panel}/{key}', value, step)
|
||||
|
||||
@staticmethod
|
||||
def write_image_tensorboard(writer, panel, data_dict, step):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def write_points_tensorboard(writer, panel, data_dict, step):
|
||||
for key, value in data_dict.items():
|
||||
if value.shape[-1] == 3:
|
||||
colors = torch.zeros_like(value)
|
||||
vertices = torch.cat([value, colors], dim=-1)
|
||||
elif value.shape[-1] == 6:
|
||||
vertices = value
|
||||
else:
|
||||
raise ValueError(f'Unexpected value shape: {value.shape}')
|
||||
faces = None
|
||||
writer.add_mesh(f'{panel}/{key}', vertices=vertices, faces=faces, global_step=step)
|
32
PytorchBoot/utils/timer_util.py
Normal file
32
PytorchBoot/utils/timer_util.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import time
|
||||
|
||||
class Timer:
|
||||
MILLI_SECONDS = "milliseconds"
|
||||
SECONDS = "seconds"
|
||||
MINUTES = "minutes"
|
||||
HOURS = "hours"
|
||||
def __init__(self, name=None):
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
self.name = name
|
||||
|
||||
def start(self):
|
||||
self.start_time = time.time()
|
||||
|
||||
def stop(self):
|
||||
self.end_time = time.time()
|
||||
|
||||
def elapsed_time(self):
|
||||
return int(self.end_time - self.start_time)
|
||||
|
||||
def get_elasped_time_str(self, format):
|
||||
if format == Timer.SECONDS:
|
||||
return f"Elapsed time in <{self.name}>: {self.elapsed_time()} seconds"
|
||||
elif format == Timer.MINUTES:
|
||||
return f"Elapsed time in <{self.name}>: {self.elapsed_time() // 60} minutes, {self.elapsed_time() % 60} seconds"
|
||||
elif format == Timer.HOURS:
|
||||
return f"Elapsed time in <{self.name}>: {self.elapsed_time() // 3600} hours, {(self.elapsed_time() % 3600)//60} minutes, {self.elapsed_time() % 60} seconds"
|
||||
elif format == Timer.MILLI_SECONDS:
|
||||
return f"Elapsed time in <{self.name}>: {(self.end_time - self.start_time) * 1000} milliseconds"
|
||||
else:
|
||||
return f"Invalid format: {format}"
|
Reference in New Issue
Block a user