From c8dcdd93daaa6db7eaffff93535bcf5fd877a0b2 Mon Sep 17 00:00:00 2001 From: "wenmeng.zwm" Date: Thu, 8 Dec 2022 10:22:47 +0800 Subject: [PATCH] broadcase metric values across all workers for distribution Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10980488 --- modelscope/trainers/trainer.py | 9 ++++++--- tests/trainers/test_trainer_gpu.py | 1 + 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index df2dc25f..aa4818d9 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -36,9 +36,9 @@ from modelscope.utils.device import create_device from modelscope.utils.file_utils import func_receive_dict_inputs from modelscope.utils.logger import get_logger from modelscope.utils.registry import build_from_cfg -from modelscope.utils.torch_utils import (get_dist_info, get_local_rank, - init_dist, is_dist, is_master, - set_random_seed) +from modelscope.utils.torch_utils import (broadcast, get_dist_info, + get_local_rank, init_dist, is_dist, + is_master, set_random_seed) from .base import BaseTrainer from .builder import TRAINERS from .default_config import merge_cfg @@ -982,6 +982,9 @@ class EpochBasedTrainer(BaseTrainer): for metric_cls in metric_classes: metric_values.update(metric_cls.evaluate()) + _, world_size = get_dist_info() + if world_size > 1: + metric_values = broadcast(metric_values, 0) return metric_values def visualization(self, results, dataset, **kwargs): diff --git a/tests/trainers/test_trainer_gpu.py b/tests/trainers/test_trainer_gpu.py index c003f3c9..ca0a15f0 100644 --- a/tests/trainers/test_trainer_gpu.py +++ b/tests/trainers/test_trainer_gpu.py @@ -67,6 +67,7 @@ def train_func(work_dir, **kwargs): json_cfg = { 'task': Tasks.image_classification, + 'model': {}, 'train': { 'work_dir': work_dir, 'dataloader': {