mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 04:29:22 +01:00
update rdrop
This commit is contained in:
@@ -131,7 +131,7 @@ class OFATrainer(EpochBasedTrainer):
|
||||
model.train()
|
||||
# model_outputs = model.forward(inputs)
|
||||
loss, sample_size, logging_output = self.criterion(model, inputs)
|
||||
train_outputs = {'loss': loss / 100}
|
||||
train_outputs = {'loss': loss}
|
||||
# add model output info to log
|
||||
if 'log_vars' not in train_outputs:
|
||||
default_keys_pattern = ['loss']
|
||||
|
||||
@@ -144,7 +144,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
|
||||
sample_size = (
|
||||
sample['target'].size(0) if self.sentence_avg else ntokens)
|
||||
logging_output = {
|
||||
'loss': loss.data / 100,
|
||||
'loss': loss.data,
|
||||
'nll_loss': nll_loss.data,
|
||||
'ntokens': sample['ntokens'],
|
||||
'nsentences': sample['nsentences'],
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import glob
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
@@ -98,8 +96,9 @@ class TestOfaTrainer(unittest.TestCase):
|
||||
trainer = build_trainer(name=Trainers.ofa, default_args=args)
|
||||
trainer.train()
|
||||
|
||||
self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,
|
||||
os.listdir(os.path.join(WORKSPACE, 'output')))
|
||||
self.assertIn(
|
||||
ModelFile.TORCH_MODEL_BIN_FILE,
|
||||
os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))
|
||||
shutil.rmtree(WORKSPACE)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user