mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 04:29:22 +01:00
@@ -26,7 +26,9 @@ def can_load_by_ms(model_dir: str, task_name: Optional[str],
|
||||
def fix_upgrade(module_obj: Any):
|
||||
from transformers import PreTrainedModel
|
||||
if hasattr(module_obj, '_set_gradient_checkpointing') \
|
||||
and 'value' in inspect.signature(module_obj._set_gradient_checkpointing).parameters.keys():
|
||||
and 'value' in inspect.signature(
|
||||
module_obj._set_gradient_checkpointing).parameters.keys() \
|
||||
and 'modelscope.' in str(module_obj.__class__):
|
||||
module_obj._set_gradient_checkpointing = MethodType(
|
||||
PreTrainedModel._set_gradient_checkpointing, module_obj)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user