From cc8b78eac8ae5c4a6288c04fdb9fc370527273e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=A1=8C=E5=97=94?= Date: Tue, 25 Oct 2022 11:55:37 +0800 Subject: [PATCH] update rdrop --- modelscope/trainers/multi_modal/ofa/ofa_trainer.py | 2 +- modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py | 2 +- tests/trainers/test_ofa_trainer.py | 7 +++---- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py index 34919fb2..c36a886e 100644 --- a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py @@ -131,7 +131,7 @@ class OFATrainer(EpochBasedTrainer): model.train() # model_outputs = model.forward(inputs) loss, sample_size, logging_output = self.criterion(model, inputs) - train_outputs = {'loss': loss / 100} + train_outputs = {'loss': loss} # add model output info to log if 'log_vars' not in train_outputs: default_keys_pattern = ['loss'] diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py index 3ba5c91f..3c38884c 100644 --- a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py @@ -144,7 +144,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss): sample_size = ( sample['target'].size(0) if self.sentence_avg else ntokens) logging_output = { - 'loss': loss.data / 100, + 'loss': loss.data, 'nll_loss': nll_loss.data, 'ntokens': sample['ntokens'], 'nsentences': sample['nsentences'], diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py index 5c252e0a..21ddce21 100644 --- a/tests/trainers/test_ofa_trainer.py +++ b/tests/trainers/test_ofa_trainer.py @@ -1,7 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -import glob import os -import os.path as osp import shutil import unittest @@ -98,8 +96,9 @@ class TestOfaTrainer(unittest.TestCase): trainer = build_trainer(name=Trainers.ofa, default_args=args) trainer.train() - self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, - os.listdir(os.path.join(WORKSPACE, 'output'))) + self.assertIn( + ModelFile.TORCH_MODEL_BIN_FILE, + os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR))) shutil.rmtree(WORKSPACE)