update rdrop

This commit is contained in:
行嗔
2022-10-25 11:55:37 +08:00
parent 0c64d3fca5
commit cc8b78eac8
3 changed files with 5 additions and 6 deletions

View File

@@ -131,7 +131,7 @@ class OFATrainer(EpochBasedTrainer):
model.train()
# model_outputs = model.forward(inputs)
loss, sample_size, logging_output = self.criterion(model, inputs)
train_outputs = {'loss': loss / 100}
train_outputs = {'loss': loss}
# add model output info to log
if 'log_vars' not in train_outputs:
default_keys_pattern = ['loss']

View File

@@ -144,7 +144,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
sample_size = (
sample['target'].size(0) if self.sentence_avg else ntokens)
logging_output = {
'loss': loss.data / 100,
'loss': loss.data,
'nll_loss': nll_loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['nsentences'],

View File

@@ -1,7 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os
import os.path as osp
import shutil
import unittest
@@ -98,8 +96,9 @@ class TestOfaTrainer(unittest.TestCase):
trainer = build_trainer(name=Trainers.ofa, default_args=args)
trainer.train()
self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,
os.listdir(os.path.join(WORKSPACE, 'output')))
self.assertIn(
ModelFile.TORCH_MODEL_BIN_FILE,
os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))
shutil.rmtree(WORKSPACE)