broadcase metric values across all workers for distribution

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10980488
This commit is contained in:
wenmeng.zwm
2022-12-08 10:22:47 +08:00
parent 8284d2d366
commit c8dcdd93da
2 changed files with 7 additions and 3 deletions

View File

@@ -36,9 +36,9 @@ from modelscope.utils.device import create_device
from modelscope.utils.file_utils import func_receive_dict_inputs
from modelscope.utils.logger import get_logger
from modelscope.utils.registry import build_from_cfg
from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
init_dist, is_dist, is_master,
set_random_seed)
from modelscope.utils.torch_utils import (broadcast, get_dist_info,
get_local_rank, init_dist, is_dist,
is_master, set_random_seed)
from .base import BaseTrainer
from .builder import TRAINERS
from .default_config import merge_cfg
@@ -982,6 +982,9 @@ class EpochBasedTrainer(BaseTrainer):
for metric_cls in metric_classes:
metric_values.update(metric_cls.evaluate())
_, world_size = get_dist_info()
if world_size > 1:
metric_values = broadcast(metric_values, 0)
return metric_values
def visualization(self, results, dataset, **kwargs):

View File

@@ -67,6 +67,7 @@ def train_func(work_dir,
**kwargs):
json_cfg = {
'task': Tasks.image_classification,
'model': {},
'train': {
'work_dir': work_dir,
'dataloader': {