diff --git a/modelscope/trainers/hooks/distributed/deepspeed_hook.py b/modelscope/trainers/hooks/distributed/deepspeed_hook.py index 7dddc5d9..d0a6eb9b 100644 --- a/modelscope/trainers/hooks/distributed/deepspeed_hook.py +++ b/modelscope/trainers/hooks/distributed/deepspeed_hook.py @@ -17,7 +17,7 @@ from modelscope.trainers.hooks.priority import Priority from modelscope.utils.checkpoint import save_checkpoint from modelscope.utils.logger import get_logger from ..checkpoint.checkpoint_processor import CheckpointProcessor -from ..lr_scheduler_hook import LrSchedulerProcessor +from ..lr_scheduler_hook import LrSchedulerHook, LrSchedulerProcessor from ..optimizer.base import OptimizerHook, OptimizerProcessor @@ -158,6 +158,10 @@ class DeepspeedHook(Hook): if len(optimizer_hook) > 0 and not isinstance( optimizer_hook[0].processor, DeepspeedProcessor): optimizer_hook[0].set_processor(processor) + lr_schedular_hook = trainer.get_hook(LrSchedulerHook) + if len(lr_schedular_hook) > 0 and not isinstance( + lr_schedular_hook[0].processor, DeepspeedProcessor): + lr_schedular_hook[0].set_processor(processor) ckpt_hook = trainer.get_hook(CheckpointHook) if len(ckpt_hook) > 0 and not isinstance(ckpt_hook[0].processor, DeepspeedProcessor):