Mirror of https://github.com/modelscope/modelscope.git
Merge master & add multi-GPU support for OFA
Adds support for multi-GPU training of OFA.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10838906
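With this change the OFA trainer can be used while the model is wrapped in a parallel container. As a rough, hedged illustration of how multi-GPU finetuning might be launched (the script name, model id, and work dir below are placeholders, not taken from this commit; only `build_trainer` and `Trainers` are existing modelscope APIs):

```python
# finetune_ofa.py -- illustrative sketch only; the model id and work_dir are placeholders.
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

# Build the OFA trainer from a (placeholder) model id and working directory.
trainer = build_trainer(
    name=Trainers.ofa,
    default_args=dict(
        model='<your-ofa-model-id>',
        work_dir='./ofa_work_dir',
    ))
trainer.train()
```

Under the usual PyTorch distributed launch this would be started with something like `torchrun --nproc_per_node=2 finetune_ofa.py`, so that each process holds one GPU and the model is wrapped in DistributedDataParallel.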
@@ -20,6 +20,7 @@ from modelscope.preprocessors.ofa.utils.collate import collate_fn
 from modelscope.trainers import EpochBasedTrainer
 from modelscope.trainers.builder import TRAINERS
 from modelscope.trainers.optimizer.builder import build_optimizer
+from modelscope.trainers.parallel.utils import is_parallel
 from modelscope.utils.config import Config
 from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigKeys,
                                        ModeKeys)
@@ -137,6 +138,7 @@ class OFATrainer(EpochBasedTrainer):
         return cfg
 
     def train_step(self, model, inputs):
+        model = model.module if self._dist or is_parallel(model) else model
         model.train()
         loss, sample_size, logging_output = self.criterion(model, inputs)
         train_outputs = {'loss': loss}
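The added line unwraps the parallel container before the model is handed to the criterion: DataParallel and DistributedDataParallel expose the original network as `.module`, so loss computation and attribute access must target that inner module. A minimal sketch of such a check, assuming plain PyTorch wrappers (an illustration only, not the actual body of `modelscope.trainers.parallel.utils.is_parallel`):

```python
# Illustrative sketch; modelscope's real is_parallel lives in
# modelscope.trainers.parallel.utils and may differ in detail.
from torch.nn.parallel import DataParallel, DistributedDataParallel


def is_parallel(model) -> bool:
    """Return True if `model` is wrapped by a (Distributed)DataParallel container."""
    return isinstance(model, (DataParallel, DistributedDataParallel))


def unwrap(model):
    """Return the underlying module so the plain model's methods are reachable."""
    return model.module if is_parallel(model) else model
```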