Init ptb
This commit is contained in:
2
PytorchBoot/factory/__init__.py
Normal file
2
PytorchBoot/factory/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from PytorchBoot.factory.component_factory import ComponentFactory
|
||||
from PytorchBoot.factory.optimizer_factory import OptimizerFactory
|
BIN
PytorchBoot/factory/__pycache__/__init__.cpython-39.pyc
Normal file
BIN
PytorchBoot/factory/__pycache__/__init__.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/factory/__pycache__/component_factory.cpython-39.pyc
Normal file
BIN
PytorchBoot/factory/__pycache__/component_factory.cpython-39.pyc
Normal file
Binary file not shown.
BIN
PytorchBoot/factory/__pycache__/optimizer_factory.cpython-39.pyc
Normal file
BIN
PytorchBoot/factory/__pycache__/optimizer_factory.cpython-39.pyc
Normal file
Binary file not shown.
27
PytorchBoot/factory/component_factory.py
Normal file
27
PytorchBoot/factory/component_factory.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from PytorchBoot.component import Component
|
||||
from PytorchBoot.stereotype import *
|
||||
from PytorchBoot.utils.log_util import Log
|
||||
from PytorchBoot.config import ConfigManager
|
||||
|
||||
class ComponentFactory:
|
||||
@staticmethod
|
||||
def create(component_type: str, name: str) -> Component:
|
||||
component_classes = get_component_classes(component_type=component_type)
|
||||
if component_classes is None:
|
||||
Log.error(f"Unsupported component type: {component_type}", True)
|
||||
|
||||
if component_type == namespace.Stereotype.DATASET:
|
||||
config = ConfigManager.get(component_type, name)
|
||||
cls = dataset_classes[config["source"]]
|
||||
dataset_obj = cls(config)
|
||||
dataset_obj.NAME = name
|
||||
dataset_obj.TYPE = component_type
|
||||
return dataset_obj
|
||||
|
||||
if name not in component_classes:
|
||||
Log.error(f"Unsupported component name: {name}", True)
|
||||
|
||||
cls = component_classes[name]
|
||||
config = ConfigManager.get(component_type, name)
|
||||
return cls(config)
|
||||
|
67
PytorchBoot/factory/optimizer_factory.py
Normal file
67
PytorchBoot/factory/optimizer_factory.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import torch.optim as optim
|
||||
|
||||
class OptimizerFactory:
|
||||
@staticmethod
|
||||
def create(config: dict, params) -> optim.Optimizer:
|
||||
optim_type = config["type"]
|
||||
lr = config.get("lr", 1e-3)
|
||||
|
||||
if optim_type == "SGD":
|
||||
return optim.SGD(
|
||||
params,
|
||||
lr=lr,
|
||||
momentum=config.get("momentum", 0.9),
|
||||
weight_decay=config.get("weight_decay", 1e-4),
|
||||
)
|
||||
elif optim_type == "Adam":
|
||||
return optim.Adam(
|
||||
params,
|
||||
lr=lr,
|
||||
betas=config.get("betas", (0.9, 0.999)),
|
||||
eps=config.get("eps", 1e-8),
|
||||
)
|
||||
elif optim_type == "AdamW":
|
||||
return optim.AdamW(
|
||||
params,
|
||||
lr=lr,
|
||||
betas=config.get("betas", (0.9, 0.999)),
|
||||
eps=config.get("eps", 1e-8),
|
||||
weight_decay=config.get("weight_decay", 1e-2),
|
||||
)
|
||||
elif optim_type == "RMSprop":
|
||||
return optim.RMSprop(
|
||||
params,
|
||||
lr=lr,
|
||||
alpha=config.get("alpha", 0.99),
|
||||
eps=config.get("eps", 1e-8),
|
||||
weight_decay=config.get("weight_decay", 1e-4),
|
||||
momentum=config.get("momentum", 0.9),
|
||||
)
|
||||
elif optim_type == "Adagrad":
|
||||
return optim.Adagrad(
|
||||
params,
|
||||
lr=lr,
|
||||
lr_decay=config.get("lr_decay", 0),
|
||||
weight_decay=config.get("weight_decay", 0),
|
||||
)
|
||||
elif optim_type == "Adamax":
|
||||
return optim.Adamax(
|
||||
params,
|
||||
lr=lr,
|
||||
betas=config.get("betas", (0.9, 0.999)),
|
||||
eps=config.get("eps", 1e-8),
|
||||
weight_decay=config.get("weight_decay", 0),
|
||||
)
|
||||
elif optim_type == "LBFGS":
|
||||
return optim.LBFGS(
|
||||
params,
|
||||
lr=lr,
|
||||
max_iter=config.get("max_iter", 20),
|
||||
max_eval=config.get("max_eval", None),
|
||||
tolerance_grad=config.get("tolerance_grad", 1e-7),
|
||||
tolerance_change=config.get("tolerance_change", 1e-9),
|
||||
history_size=config.get("history_size", 100),
|
||||
line_search_fn=config.get("line_search_fn", None),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Unknown optimizer: {}".format(optim_type))
|
Reference in New Issue
Block a user