From 898e3a42eb7714656dea0b0898468ccae6a97e26 Mon Sep 17 00:00:00 2001 From: "hemu.zp" Date: Wed, 31 May 2023 21:29:36 +0800 Subject: [PATCH] Fix bug for DeepspeedHook.register_processor Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12686921 --- modelscope/trainers/hooks/distributed/deepspeed_hook.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modelscope/trainers/hooks/distributed/deepspeed_hook.py b/modelscope/trainers/hooks/distributed/deepspeed_hook.py index 7dddc5d9..d0a6eb9b 100644 --- a/modelscope/trainers/hooks/distributed/deepspeed_hook.py +++ b/modelscope/trainers/hooks/distributed/deepspeed_hook.py @@ -17,7 +17,7 @@ from modelscope.trainers.hooks.priority import Priority from modelscope.utils.checkpoint import save_checkpoint from modelscope.utils.logger import get_logger from ..checkpoint.checkpoint_processor import CheckpointProcessor -from ..lr_scheduler_hook import LrSchedulerProcessor +from ..lr_scheduler_hook import LrSchedulerHook, LrSchedulerProcessor from ..optimizer.base import OptimizerHook, OptimizerProcessor @@ -158,6 +158,10 @@ class DeepspeedHook(Hook): if len(optimizer_hook) > 0 and not isinstance( optimizer_hook[0].processor, DeepspeedProcessor): optimizer_hook[0].set_processor(processor) + lr_schedular_hook = trainer.get_hook(LrSchedulerHook) + if len(lr_schedular_hook) > 0 and not isinstance( + lr_schedular_hook[0].processor, DeepspeedProcessor): + lr_schedular_hook[0].set_processor(processor) ckpt_hook = trainer.get_hook(CheckpointHook) if len(ckpt_hook) > 0 and not isinstance(ckpt_hook[0].processor, DeepspeedProcessor):