diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py index 71494768..e27c23fd 100644 --- a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py @@ -20,6 +20,7 @@ from modelscope.preprocessors.ofa.utils.collate import collate_fn from modelscope.trainers import EpochBasedTrainer from modelscope.trainers.builder import TRAINERS from modelscope.trainers.optimizer.builder import build_optimizer +from modelscope.trainers.parallel.utils import is_parallel from modelscope.utils.config import Config from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, ModeKeys) @@ -137,6 +138,7 @@ class OFATrainer(EpochBasedTrainer): return cfg def train_step(self, model, inputs): + model = model.module if self._dist or is_parallel(model) else model model.train() loss, sample_size, logging_output = self.criterion(model, inputs) train_outputs = {'loss': loss}