Mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-24 12:09:22 +01:00
fix dist training
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10185634

* fix dist training
@@ -37,8 +37,8 @@ from modelscope.utils.device import create_device, verify_device
 from modelscope.utils.file_utils import func_receive_dict_inputs
 from modelscope.utils.logger import get_logger
 from modelscope.utils.registry import build_from_cfg
-from modelscope.utils.torch_utils import (get_dist_info, init_dist,
-                                          set_random_seed)
+from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
+                                          init_dist, set_random_seed)
 from .base import BaseTrainer
 from .builder import TRAINERS
 from .default_config import DEFAULT_CONFIG

@@ -155,8 +155,17 @@ class EpochBasedTrainer(BaseTrainer):
         if self.eval_preprocessor is not None:
             self.eval_preprocessor.mode = ModeKeys.EVAL
 
+        if kwargs.get('launcher', None) is not None:
+            init_dist(kwargs['launcher'])
+
+        _, world_size = get_dist_info()
+        self._dist = world_size > 1
+
         device_name = kwargs.get('device', 'gpu')
         verify_device(device_name)
+        if self._dist:
+            local_rank = get_local_rank()
+            device_name = f'cuda:{local_rank}'
         self.device = create_device(device_name)
 
         self.train_dataset = self.to_task_dataset(

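Note: the point of the new block is that every worker in a multi-process job has to bind to its own GPU before the model is placed. A minimal sketch of the same pattern outside the trainer, assuming a torchrun-style launcher that exports RANK, WORLD_SIZE and LOCAL_RANK (setup_device is an illustrative helper, not part of modelscope):

    import os

    import torch
    import torch.distributed as dist


    def setup_device() -> torch.device:
        # torchrun and `torch.distributed.launch --use_env` export RANK,
        # WORLD_SIZE and LOCAL_RANK for every worker process.
        world_size = int(os.environ.get('WORLD_SIZE', 1))
        if world_size > 1:
            dist.init_process_group(backend='nccl')  # reads the env:// settings
            local_rank = int(os.environ.get('LOCAL_RANK', 0))
            device = torch.device(f'cuda:{local_rank}')
            torch.cuda.set_device(device)  # keep this rank's work off cuda:0
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return device
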
@@ -219,11 +228,6 @@ class EpochBasedTrainer(BaseTrainer):
 
         self.use_fp16 = kwargs.get('use_fp16', False)
 
-        if kwargs.get('launcher', None) is not None:
-            init_dist(kwargs['launcher'])
-
-        self._dist = get_dist_info()[1] > 1
-
         # model placement
         if self.device.type == 'cuda':
             self.model.to(self.device)

@@ -531,8 +535,14 @@ class EpochBasedTrainer(BaseTrainer):
         model.train()
         self._mode = ModeKeys.TRAIN
         # call model forward but not __call__ to skip postprocess
-        if isinstance(inputs,
-                      Mapping) and not func_receive_dict_inputs(model.forward):
+
+        if is_parallel(model):
+            receive_dict_inputs = func_receive_dict_inputs(
+                model.module.forward)
+        else:
+            receive_dict_inputs = func_receive_dict_inputs(model.forward)
+
+        if isinstance(inputs, Mapping) and not receive_dict_inputs:
             train_outputs = model.forward(**inputs)
         else:
             train_outputs = model.forward(inputs)

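Note: the is_parallel branch is needed because DataParallel and DistributedDataParallel wrap the user model behind a generic forward(*inputs, **kwargs), so signature inspection has to look at model.module.forward instead. A rough sketch of the unwrapping idea (unwrap_model is a hypothetical helper, not the modelscope implementation of is_parallel):

    import torch.nn as nn
    from torch.nn.parallel import DataParallel, DistributedDataParallel


    def unwrap_model(model: nn.Module) -> nn.Module:
        # (D)DP wrappers expose the user's model as `.module`; anything that
        # inspects forward's argument names must look there, because the
        # wrapper's own forward(*inputs, **kwargs) hides them.
        if isinstance(model, (DataParallel, DistributedDataParallel)):
            return model.module
        return model
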
@@ -11,6 +11,7 @@ import torch
 from torch import distributed as dist
 from tqdm import tqdm
 
+from modelscope.trainers.parallel.utils import is_parallel
 from modelscope.utils.data_utils import to_device
 from modelscope.utils.file_utils import func_receive_dict_inputs
 from modelscope.utils.torch_utils import (broadcast, get_dist_info, is_master,

@@ -134,7 +135,10 @@ def multi_gpu_test(model,
     data_len = data_loader_iters_per_gpu * world_size
     desc = 'Total test iterations with multi gpus'
 
     time.sleep(2)  # This line can prevent deadlock problem in some cases.
+    if is_parallel(model):
+        receive_dict_inputs = func_receive_dict_inputs(model.module.forward)
+    else:
+        receive_dict_inputs = func_receive_dict_inputs(model.forward)
 
     count = 0
     with tqdm(total=data_len, desc=desc) as pbar:

@@ -142,8 +146,7 @@ def multi_gpu_test(model,
             data = to_device(data, device)
             data_list.append(data)
             with torch.no_grad():
-                if isinstance(data, Mapping) and not func_receive_dict_inputs(
-                        model.forward):
+                if isinstance(data, Mapping) and not receive_dict_inputs:
                     result = model.forward(**data)
                 else:
                     result = model.forward(data)

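Note: func_receive_dict_inputs lives in modelscope.utils.file_utils and is not shown in this diff; hoisting it out of the test loop (and out of _train_step's hot path) avoids re-inspecting the signature on every batch. As a rough, hypothetical stand-in, a helper of this kind can be built on inspect.signature to decide between model.forward(**data) and model.forward(data):

    import inspect


    def takes_single_dict_argument(forward) -> bool:
        # Hypothetical stand-in for func_receive_dict_inputs: if forward
        # declares a single positional parameter (e.g. `inputs`), pass the
        # whole batch dict as one argument; otherwise unpack it as kwargs.
        params = [
            p for p in inspect.signature(forward).parameters.values()
            if p.name != 'self' and p.kind is not inspect.Parameter.VAR_KEYWORD
        ]
        return len(params) == 1 and params[0].kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD)
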
@@ -115,6 +115,10 @@ def get_dist_info() -> Tuple[int, int]:
     return rank, world_size
 
 
+def get_local_rank():
+    return int(os.environ.get('LOCAL_RANK', 0))
+
+
 def is_master():
     rank, _ = get_dist_info()
     return rank == 0

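Note: get_local_rank relies on the LOCAL_RANK environment variable, which torchrun and `torch.distributed.launch --use_env` export for every worker; in a single-process run it is absent, hence the default of 0. A quick illustration:

    import os

    # Under `torchrun --nproc_per_node=2 train.py` each worker sees roughly:
    #   worker 0: RANK=0 LOCAL_RANK=0 WORLD_SIZE=2
    #   worker 1: RANK=1 LOCAL_RANK=1 WORLD_SIZE=2
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    print(f'this process should use cuda:{local_rank}')
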
@@ -53,7 +53,18 @@ class DummyModel(nn.Module, Model):
         return dict(logits=x, loss=loss)
 
 
-def train_func(work_dir, dist=False, iterable_dataset=False, **kwargs):
+class DummyModelForwardInputs(DummyModel):
+
+    def forward(self, inputs):
+        feat, labels = inputs['feat'], inputs['labels']
+        return super().forward(feat, labels)
+
+
+def train_func(work_dir,
+               dist=False,
+               iterable_dataset=False,
+               forward_inputs=False,
+               **kwargs):
     json_cfg = {
         'task': Tasks.image_classification,
         'train': {

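Note: DummyModelForwardInputs exists to exercise the trainer's other dispatch branch: a forward that takes the whole batch dict is called as model.forward(inputs), while one with named tensor arguments is called as model.forward(**inputs). A minimal, trainer-independent illustration of the two call styles (forward_named and forward_dict are illustrative names):

    import torch

    batch = {'feat': torch.randn(4, 2), 'labels': torch.zeros(4)}


    def forward_named(feat, labels):
        # "named arguments" style: the trainer calls model.forward(**batch)
        return {'loss': ((feat.mean(dim=1) - labels)**2).mean()}


    def forward_dict(inputs):
        # "whole dict" style: the trainer calls model.forward(batch)
        return forward_named(inputs['feat'], inputs['labels'])


    assert torch.equal(forward_named(**batch)['loss'], forward_dict(batch)['loss'])
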
@@ -81,7 +92,10 @@ def train_func(work_dir, dist=False, iterable_dataset=False, **kwargs):
     with open(config_path, 'w') as f:
         json.dump(json_cfg, f)
 
-    model = DummyModel()
+    if forward_inputs:
+        model = DummyModelForwardInputs()
+    else:
+        model = DummyModel()
     optimmizer = SGD(model.parameters(), lr=0.01)
     lr_scheduler = StepLR(optimmizer, 2)
     trainer_name = Trainers.default

@@ -273,6 +287,22 @@ class TrainerTestMultiGpus(DistributedTestCase):
         for i in [1, 3, 5]:
             self.assertIn(MetricKeys.ACCURACY, lines[i])
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_multi_gpus_forward_inputs(self):
+        self.start(
+            train_func,
+            num_gpus=2,
+            work_dir=self.tmp_dir,
+            dist=True,
+            forward_inputs=True)
+
+        results_files = os.listdir(self.tmp_dir)
+        json_files = glob.glob(os.path.join(self.tmp_dir, '*.log.json'))
+        self.assertEqual(len(json_files), 1)
+        self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
+        self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
+        self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
+
     # TODO: support iters_per_epoch for dist mode
     @unittest.skipIf(True, 'need to adapt to DistributedSampler')
     def test_multi_gpus_with_iters_per_epoch(self):

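Note: the new test drives the same code path a user hits when launching the trainer under a distributed launcher. A hedged sketch of such a launch; the launcher and device kwargs come from the hunks above, while the build_trainer call shape, the model placeholder, and the remaining default_args a real job needs are illustrative rather than taken from this diff:

    # launch with: torchrun --nproc_per_node=2 train_dist.py
    from modelscope.trainers import build_trainer

    # launcher='pytorch' makes EpochBasedTrainer call init_dist() and then bind
    # this process to cuda:{LOCAL_RANK}, as added in the hunks above.
    trainer = build_trainer(
        default_args=dict(
            model='<model id or local path>',  # placeholder, not a real id
            work_dir='./work_dir',
            launcher='pytorch',
            device='gpu',
        ))
    trainer.train()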