update basic framework

This commit is contained in:
hofee
2024-08-21 17:11:56 +08:00
parent 73dcd592df
commit f977fd4b8e
29 changed files with 1393 additions and 719 deletions

12
losses/gf_loss.py Normal file
View File

@@ -0,0 +1,12 @@
import torch
import PytorchBoot.stereotype as stereotype
@stereotype.loss_function("gf_loss")
def compute_loss(output, data):
estimated_score = output['estimated_score']
target_score = output['target_score']
std = output['std']
bs = estimated_score.shape[0]
loss_weighting = std ** 2
loss = torch.mean(torch.sum((loss_weighting * (estimated_score - target_score) ** 2).view(bs, -1), dim=-1))
return loss

View File

@@ -1,12 +0,0 @@
class LossFunctionFactory:
@staticmethod
def create(function_name):
raise ValueError("Unknown loss function {}".format(function_name))
''' ------------ Debug ------------ '''
if __name__ == "__main__":
from configs.config import ConfigManager
ConfigManager.load_config_with('../configs/local_train_config.yaml')
ConfigManager.print_config()