[to #42322933] Add early stop hook

pangda
2022-12-06 10:54:47 +08:00
parent a3a942352e
commit 5fd3e7bb43
4 changed files with 114 additions and 0 deletions

View File

@@ -522,6 +522,7 @@ class Hooks(object):
    ClipClampLogitScaleHook = 'ClipClampLogitScaleHook'
    # train
    EarlyStopHook = 'EarlyStopHook'
    DeepspeedHook = 'DeepspeedHook'

View File

@@ -6,6 +6,7 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .builder import HOOKS, build_hook
    from .checkpoint_hook import BestCkptSaverHook, CheckpointHook
    from .early_stop_hook import EarlyStopHook
    from .compression import SparsityHook
    from .evaluation_hook import EvaluationHook
    from .hook import Hook

View File

@@ -0,0 +1,109 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
from modelscope.metainfo import Hooks
from modelscope.utils.logger import get_logger
from .builder import HOOKS
from .hook import Hook
from .priority import Priority


@HOOKS.register_module(module_name=Hooks.EarlyStopHook)
class EarlyStopHook(Hook):
"""Early stop when a specific metric stops improving.
Args:
metric_key (str): Metric key to be monitored.
rule (str): Comparison rule for best score. Support "max" and "min".
If rule is "max", the training will stop when `metric_key` has stopped increaing.
If rule is "min", the training will stop when `metric_key` has stopped decreasing.
patience (int): Trainer will stop if the monitored metric did not improve for the last `patience` times.
min_delta (float): Minimum change in the monitored metric to quailfy as an improvement.
check_finite (bool): If true, stops training when the metric becomes NaN or infinite.
by_epoch (int): Saving checkpoints by epoch or by iteration.
interval (int): The frequency to trigger early stop check. If `by_epoch=True`,
it means the number of epochs, else means the number of iterations.
"""
    PRIORITY = Priority.VERY_LOW
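
    # Map each rule name to an "x is better than y" comparator.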
    rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}

    def __init__(self,
                 metric_key: str,
                 rule: str = 'max',
                 patience: int = 3,
                 min_delta: float = 0.0,
                 check_finite: bool = True,
                 by_epoch: bool = True,
                 interval: int = 1):
        self.metric_key = metric_key
        self.rule = rule
        self.patience = patience
        self.min_delta = min_delta
        self.check_finite = check_finite
        self.by_epoch = by_epoch
        self.interval = interval
        self.wait_count = 0
        self.best_score = float('inf') if rule == 'min' else -float('inf')

    def before_run(self, trainer):
        if not hasattr(trainer, 'logger'):
            self.logger = get_logger(__name__)
        else:
            self.logger = trainer.logger

    def _should_stop(self, trainer):
        metric_values = trainer.metric_values
        if metric_values is None:
            return False
        if self.metric_key not in metric_values:
            raise ValueError(
                f'Metric not found: {self.metric_key} not in {metric_values}')

        should_stop = False
        current_score = metric_values[self.metric_key]
        if self.check_finite and not np.isfinite(current_score):
            should_stop = True
            self.logger.warning(
                f'Metric {self.metric_key} = {current_score} is not finite. '
                f'Previous best metric: {self.best_score:.4f}.')
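        # Only a change larger than min_delta counts as an improvement, so the
        # score is shifted against the rule's direction before the comparison.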
        elif self.rule_map[self.rule](
                current_score - self.min_delta
                if self.rule == 'max' else current_score + self.min_delta,
                self.best_score):
            self.best_score = current_score
            self.wait_count = 0
        else:
            self.wait_count += 1
            if self.wait_count >= self.patience:
                should_stop = True
                self.logger.info(
                    f'Metric {self.metric_key} did not improve in the last '
                    f'{self.wait_count} checks. '
                    f'Best score: {self.best_score:.4f}.')
        return should_stop

    def _stop_training(self, trainer):
        self.logger.info('Early Stopping!')
        trainer._stop_training = True

    def after_train_epoch(self, trainer):
        if not self.by_epoch:
            return
        if not self.every_n_epochs(trainer, self.interval):
            return
        if self._should_stop(trainer):
            self._stop_training(trainer)

    def after_train_iter(self, trainer):
        if self.by_epoch:
            return
        if not self.every_n_iters(trainer, self.interval):
            return
        if self._should_stop(trainer):
            self._stop_training(trainer)
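
For reference, a hook like this is normally switched on through the trainer
config rather than instantiated by hand. A minimal sketch of such a config
(the exact cfg layout and the metric name 'accuracy' are assumptions based on
how other ModelScope hooks such as CheckpointHook are configured):

    # hypothetical snippet from a trainer configuration
    cfg.train.hooks = [
        dict(type='CheckpointHook', interval=1),
        dict(
            type='EarlyStopHook',    # module_name registered above
            metric_key='accuracy',   # metric produced during evaluation
            rule='max',              # stop when accuracy stops increasing
            patience=3,
            by_epoch=True),
    ]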

View File

@@ -112,6 +112,7 @@ class EpochBasedTrainer(BaseTrainer):
        self._epoch = 0
        self._iter = 0
        self._inner_iter = 0
        self._stop_training = False

        if isinstance(model, str):
            self.model_dir = self.get_or_download_model_dir(
@@ -910,6 +911,8 @@ class EpochBasedTrainer(BaseTrainer):
            # Value changed after the hooks are invoked, do not move them above the invoke_hook code.
            self._inner_iter = 0
            self._epoch += 1
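            # EarlyStopHook._stop_training() sets this flag; leaving the epoch
            # loop here lets the after_run hooks below still fire.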
            if self._stop_training:
                break

        self.invoke_hook(TrainerStages.after_run)
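
To see how `patience` and `min_delta` interact, here is a standalone
walk-through of the comparison logic with hypothetical scores (rule='max',
min_delta=0.01, patience=2); it mirrors `_should_stop` without the trainer:

    # hypothetical metric sequence; not part of this commit
    best, wait = -float('inf'), 0
    for score in [0.70, 0.72, 0.721, 0.719]:
        if score - 0.01 > best:   # improvement must exceed best by min_delta
            best, wait = score, 0
        else:
            wait += 1             # 0.721 and 0.719 do not qualify
        if wait >= 2:             # patience exhausted: stop at 0.719
            print(f'early stop; best score {best:.4f}')
            break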