mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
fix
This commit is contained in:
@@ -436,6 +436,25 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
||||
**ignore_file_pattern_kwargs))
|
||||
|
||||
all_available_modules.append(var)
|
||||
|
||||
def get_class_from_dynamic_module(class_reference, *args, **kwargs):
|
||||
from transformers.dynamic_module_utils import origin_get_class_from_dynamic_module
|
||||
if '--' in class_reference:
|
||||
repo_id, class_reference = class_reference.split('--')
|
||||
if not os.path.exists(repo_id):
|
||||
from modelscope import snapshot_download
|
||||
repo_id = snapshot_download(repo_id)
|
||||
class_reference = repo_id + '--' + class_reference
|
||||
return origin_get_class_from_dynamic_module(class_reference, *args,
|
||||
**kwargs)
|
||||
|
||||
from transformers import dynamic_module_utils
|
||||
if not hasattr(dynamic_module_utils,
|
||||
'origin_get_class_from_dynamic_module'):
|
||||
dynamic_module_utils.origin_get_class_from_dynamic_module = dynamic_module_utils.get_class_from_dynamic_module
|
||||
dynamic_module_utils.get_class_from_dynamic_module = get_class_from_dynamic_module
|
||||
from transformers.models.auto import configuration_auto
|
||||
configuration_auto.get_class_from_dynamic_module = get_class_from_dynamic_module
|
||||
return all_available_modules
|
||||
|
||||
|
||||
@@ -469,6 +488,13 @@ def _unpatch_pretrained_class(all_imported_modules):
|
||||
except: # noqa
|
||||
pass
|
||||
|
||||
from transformers import dynamic_module_utils
|
||||
if hasattr(dynamic_module_utils, 'origin_get_class_from_dynamic_module'):
|
||||
dynamic_module_utils.get_class_from_dynamic_module = dynamic_module_utils.origin_get_class_from_dynamic_module
|
||||
from transformers.models.auto import configuration_auto
|
||||
configuration_auto.get_class_from_dynamic_module = dynamic_module_utils.origin_get_class_from_dynamic_module
|
||||
delattr(dynamic_module_utils, 'origin_get_class_from_dynamic_module')
|
||||
|
||||
|
||||
def _patch_hub():
|
||||
import huggingface_hub
|
||||
|
||||
Reference in New Issue
Block a user