mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
[to #42322933] fix bug: checkpoint hook and bestckpthook exists at the same time
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10227608
This commit is contained in:
@@ -1,4 +1,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.utils.config import Config
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
'train': {
|
||||
'hooks': [{
|
||||
@@ -12,3 +15,19 @@ DEFAULT_CONFIG = {
|
||||
}]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def merge_cfg(cfg: Config):
|
||||
"""Merge the default config into the input cfg.
|
||||
|
||||
This function will pop the default CheckpointHook when the BestCkptSaverHook exists in the input cfg.
|
||||
|
||||
@param cfg: The input cfg to be merged into.
|
||||
"""
|
||||
cfg.merge_from_dict(DEFAULT_CONFIG, force=False)
|
||||
# pop duplicate hook
|
||||
|
||||
if any(['BestCkptSaverHook' == hook['type'] for hook in cfg.train.hooks]):
|
||||
cfg.train.hooks = list(
|
||||
filter(lambda hook: hook['type'] != 'CheckpointHook',
|
||||
cfg.train.hooks))
|
||||
|
||||
@@ -41,7 +41,7 @@ from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
|
||||
init_dist, set_random_seed)
|
||||
from .base import BaseTrainer
|
||||
from .builder import TRAINERS
|
||||
from .default_config import DEFAULT_CONFIG
|
||||
from .default_config import merge_cfg
|
||||
from .hooks.hook import Hook
|
||||
from .parallel.builder import build_parallel
|
||||
from .parallel.utils import is_parallel
|
||||
@@ -114,7 +114,7 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
super().__init__(cfg_file, arg_parse_fn)
|
||||
|
||||
# add default config
|
||||
self.cfg.merge_from_dict(self._get_default_config(), force=False)
|
||||
merge_cfg(self.cfg)
|
||||
self.cfg = self.rebuild_config(self.cfg)
|
||||
|
||||
if 'cfg_options' in kwargs:
|
||||
@@ -951,9 +951,6 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
stage_hook_infos.append(info)
|
||||
return '\n'.join(stage_hook_infos)
|
||||
|
||||
def _get_default_config(self):
|
||||
return DEFAULT_CONFIG
|
||||
|
||||
|
||||
def worker_init_fn(worker_id, num_workers, rank, seed):
|
||||
# The seed of each worker equals to
|
||||
|
||||
@@ -204,9 +204,6 @@ class BestCkptSaverHookTest(unittest.TestCase):
|
||||
trainer = build_trainer(trainer_name, kwargs)
|
||||
trainer.train()
|
||||
results_files = os.listdir(self.tmp_dir)
|
||||
self.assertIn(f'{LogKeys.EPOCH}_1.pth', results_files)
|
||||
self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
|
||||
self.assertIn(f'{LogKeys.EPOCH}_3.pth', results_files)
|
||||
self.assertIn(f'best_{LogKeys.EPOCH}1_{MetricKeys.ACCURACY}0.1.pth',
|
||||
results_files)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user