[to #42322933] Add early stop hook
modelscope/metainfo.py
@@ -522,6 +522,7 @@ class Hooks(object):
     ClipClampLogitScaleHook = 'ClipClampLogitScaleHook'

     # train
+    EarlyStopHook = 'EarlyStopHook'
     DeepspeedHook = 'DeepspeedHook'
modelscope/trainers/hooks/__init__.py
@@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule

 if TYPE_CHECKING:
     from .builder import HOOKS, build_hook
     from .checkpoint_hook import BestCkptSaverHook, CheckpointHook
+    from .early_stop_hook import EarlyStopHook
     from .compression import SparsityHook
     from .evaluation_hook import EvaluationHook
     from .hook import Hook
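The TYPE_CHECKING branch only serves static analysis; at runtime this package resolves names lazily through LazyImportModule, so the commit presumably adds a matching entry to the lazy-import table as well (that part is not visible in this excerpt). For orientation, a hedged sketch of building the hook from a plain config dict, assuming build_hook follows the usual registry convention of resolving 'type' against HOOKS and forwarding the remaining keys to the hook's __init__:

from modelscope.trainers.hooks import build_hook

# Assumed usage; the exact build_hook signature is not shown in this diff,
# and 'accuracy' is an illustrative metric name.
hook = build_hook({
    'type': 'EarlyStopHook',
    'metric_key': 'accuracy',
    'patience': 5,
})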
modelscope/trainers/hooks/early_stop_hook.py (new file, 109 lines)
@@ -0,0 +1,109 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import numpy as np

from modelscope.metainfo import Hooks
from modelscope.utils.logger import get_logger
from .builder import HOOKS
from .hook import Hook
from .priority import Priority


@HOOKS.register_module(module_name=Hooks.EarlyStopHook)
class EarlyStopHook(Hook):
    """Early stop when a specific metric stops improving.

    Args:
        metric_key (str): Metric key to be monitored.
        rule (str): Comparison rule for the best score. Supports "max" and "min".
            If rule is "max", training stops when `metric_key` has stopped
            increasing; if rule is "min", training stops when `metric_key` has
            stopped decreasing.
        patience (int): Training stops if the monitored metric has not improved
            in the last `patience` checks.
        min_delta (float): Minimum change in the monitored metric to qualify as
            an improvement.
        check_finite (bool): If True, stop training when the metric becomes NaN
            or infinite.
        by_epoch (bool): Whether to run the early-stop check per epoch or per
            iteration.
        interval (int): The frequency of early-stop checks. If `by_epoch=True`,
            it is a number of epochs, otherwise a number of iterations.
    """

    PRIORITY = Priority.VERY_LOW
    rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}

    def __init__(self,
                 metric_key: str,
                 rule: str = 'max',
                 patience: int = 3,
                 min_delta: float = 0.0,
                 check_finite: bool = True,
                 by_epoch: bool = True,
                 interval: int = 1):
        self.metric_key = metric_key
        self.rule = rule
        self.patience = patience
        self.min_delta = min_delta
        self.check_finite = check_finite
        self.by_epoch = by_epoch
        self.interval = interval

        self.wait_count = 0
        self.best_score = float('inf') if rule == 'min' else -float('inf')

    def before_run(self, trainer):
        if not hasattr(trainer, 'logger'):
            self.logger = get_logger(__name__)
        else:
            self.logger = trainer.logger

    def _should_stop(self, trainer):
        metric_values = trainer.metric_values

        # No evaluation has produced metrics yet; nothing to check.
        if metric_values is None:
            return False

        if self.metric_key not in metric_values:
            raise ValueError(
                f'Metric not found: {self.metric_key} not in {metric_values}')

        should_stop = False
        current_score = metric_values[self.metric_key]
        if self.check_finite and not np.isfinite(current_score):
            should_stop = True
            self.logger.warning(
                f'Metric {self.metric_key} = {current_score} is not finite. '
                f'Previous best metric: {self.best_score:.4f}.')
        elif self.rule_map[self.rule](current_score - self.min_delta,
                                      self.best_score):
            # Counts as an improvement under the configured rule: record the
            # new best score and reset the patience counter.
            self.best_score = current_score
            self.wait_count = 0
        else:
            self.wait_count += 1
            if self.wait_count >= self.patience:
                should_stop = True
                self.logger.info(
                    f'Metric {self.metric_key} did not improve in the last '
                    f'{self.wait_count} checks. Best score: {self.best_score:.4f}.')
        return should_stop

    def _stop_training(self, trainer):
        self.logger.info('Early Stopping!')
        trainer._stop_training = True

    def after_train_epoch(self, trainer):
        if not self.by_epoch:
            return

        if not self.every_n_epochs(trainer, self.interval):
            return

        if self._should_stop(trainer):
            self._stop_training(trainer)

    def after_train_iter(self, trainer):
        if self.by_epoch:
            return

        if not self.every_n_iters(trainer, self.interval):
            return

        if self._should_stop(trainer):
            self._stop_training(trainer)
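End to end, the hook is typically enabled through the trainer configuration rather than constructed by hand. A hedged sketch of a config fragment, assuming the usual ModelScope convention of listing hooks under train.hooks (the metric name and surrounding keys are illustrative):

cfg = {
    'train': {
        'max_epochs': 50,
        'hooks': [{
            'type': 'EarlyStopHook',
            'metric_key': 'accuracy',  # must match a key in trainer.metric_values
            'rule': 'max',
            'patience': 3,
            'min_delta': 0.01,
        }],
    }
}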
modelscope/trainers/trainer.py
@@ -112,6 +112,7 @@ class EpochBasedTrainer(BaseTrainer):
         self._epoch = 0
         self._iter = 0
         self._inner_iter = 0
+        self._stop_training = False

         if isinstance(model, str):
             self.model_dir = self.get_or_download_model_dir(
@@ -910,6 +911,8 @@ class EpochBasedTrainer(BaseTrainer):
             # Value changed after the hooks are invoked, do not move them above the invoke_hook code.
             self._inner_iter = 0
             self._epoch += 1
+            if self._stop_training:
+                break

         self.invoke_hook(TrainerStages.after_run)
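Taken together, these two hunks implement a cooperative-stop pattern: the hook only flips trainer._stop_training, and the epoch loop checks the flag after all hooks have run, so the current epoch still finishes cleanly. A minimal self-contained sketch of that pattern (toy stand-ins, not the actual EpochBasedTrainer code):

class MiniEarlyStop:
    """Toy stand-in for EarlyStopHook: stop after `patience` non-improvements."""

    def __init__(self, patience=2):
        self.patience = patience
        self.best = -float('inf')
        self.wait = 0

    def after_train_epoch(self, trainer):
        score = trainer.metric_values['accuracy']
        if score > self.best:
            self.best, self.wait = score, 0
        else:
            self.wait += 1
        if self.wait >= self.patience:
            trainer._stop_training = True  # only set the flag; never raise


class MiniTrainer:
    """Toy stand-in for the epoch loop shown in the diff above."""

    def __init__(self, scores, hooks):
        self.scores = scores  # one precomputed metric value per epoch
        self.hooks = hooks
        self._stop_training = False
        self.metric_values = {}

    def train(self):
        for epoch, score in enumerate(self.scores):
            self.metric_values = {'accuracy': score}
            for hook in self.hooks:
                hook.after_train_epoch(self)
            # Checked after the hooks are invoked, mirroring the diff above.
            if self._stop_training:
                return epoch + 1
        return len(self.scores)


# Stops after epoch 4: 0.5 and 0.6 improve; 0.6 and 0.55 do not (patience=2).
epochs_run = MiniTrainer([0.5, 0.6, 0.6, 0.55, 0.7], [MiniEarlyStop()]).train()
assert epochs_run == 4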