From 31689a0139f74da7e712c7f89eb7700f3e7611f4 Mon Sep 17 00:00:00 2001 From: "liugao.lg" Date: Wed, 23 Nov 2022 19:08:39 +0800 Subject: [PATCH] merge master& add multi-gpu for ofa MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增对ofa多GPU训练的支持 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10838906 --- modelscope/trainers/multi_modal/ofa/ofa_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py index 71494768..e27c23fd 100644 --- a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py +++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py @@ -20,6 +20,7 @@ from modelscope.preprocessors.ofa.utils.collate import collate_fn from modelscope.trainers import EpochBasedTrainer from modelscope.trainers.builder import TRAINERS from modelscope.trainers.optimizer.builder import build_optimizer +from modelscope.trainers.parallel.utils import is_parallel from modelscope.utils.config import Config from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys, ModeKeys) @@ -137,6 +138,7 @@ class OFATrainer(EpochBasedTrainer): return cfg def train_step(self, model, inputs): + model = model.module if self._dist or is_parallel(model) else model model.train() loss, sample_size, logging_output = self.criterion(model, inputs) train_outputs = {'loss': loss}