fix recursion (#1024)

* fix recursion

* lint code
This commit is contained in:
tastelikefeet
2024-10-16 08:42:46 +08:00
committed by GitHub
parent eab95f909f
commit 7ceac5a359

View File

@@ -26,7 +26,9 @@ def can_load_by_ms(model_dir: str, task_name: Optional[str],
def fix_upgrade(module_obj: Any):
from transformers import PreTrainedModel
if hasattr(module_obj, '_set_gradient_checkpointing') \
and 'value' in inspect.signature(module_obj._set_gradient_checkpointing).parameters.keys():
and 'value' in inspect.signature(
module_obj._set_gradient_checkpointing).parameters.keys() \
and 'modelscope.' in str(module_obj.__class__):
module_obj._set_gradient_checkpointing = MethodType(
PreTrainedModel._set_gradient_checkpointing, module_obj)