Mirror of https://github.com/modelscope/modelscope.git (synced 2026-02-24 20:19:51 +01:00)
support subfolder
@@ -37,7 +37,8 @@ if TYPE_CHECKING:
         AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM,
         AutoModelForSequenceClassification,
         AutoModelForTokenClassification, AutoModelForImageSegmentation,
-        AutoTokenizer, GenerationConfig, AutoImageProcessor, BatchFeature)
+        AutoTokenizer, GenerationConfig, AutoImageProcessor, BatchFeature,
+        T5EncoderModel)
 else:
     print(
         'transformer is not installed, please install it if you want to use related modules'
@@ -96,7 +97,7 @@ else:
         'AutoModelForSeq2SeqLM', 'AutoTokenizer',
         'AutoModelForSequenceClassification',
         'AutoModelForTokenClassification', 'AutoModelForImageSegmentation',
-        'AutoImageProcessor', 'BatchFeature'
+        'AutoImageProcessor', 'BatchFeature', 'T5EncoderModel'
     ]

     import sys

@@ -23,6 +23,7 @@ from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF
 from transformers import GenerationConfig as GenerationConfigHF
 from transformers import (PretrainedConfig, PreTrainedModel,
                           PreTrainedTokenizerBase)
+from transformers import T5EncoderModel as T5EncoderModelHF

 from modelscope import snapshot_download
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
@@ -248,7 +249,10 @@ def patch_hub():
     _patch_pretrained_class()


-def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
+def get_wrapped_class(module_class,
+                      ignore_file_pattern=[],
+                      file_filter=None,
+                      **kwargs):
     """Get a custom wrapper class for auto classes to download the models from the ModelScope hub
     Args:
         module_class: The actual module class
@@ -258,6 +262,7 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
         The wrapper
     """
     default_ignore_file_pattern = ignore_file_pattern
+    default_file_filter = file_filter

     class ClassWrapper(module_class):

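The factory-level file_filter argument and the call-level subfolder kwarg work together: file_filter seeds a per-class default subfolder (note it holds a subfolder name, not a glob; the '{subfolder}/*' pattern is built later in from_pretrained), and callers may still override it per call. A sketch of a wrapper created with a baked-in default; the subfolder value here is illustrative, not from this commit:

    # Hypothetical: wrap transformers' T5EncoderModel so that, by default,
    # only the 'text_encoder' subfolder of a repo is fetched and loaded.
    T5EncoderModel = get_wrapped_class(
        T5EncoderModelHF, file_filter='text_encoder')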
@@ -266,13 +271,26 @@
                             **kwargs):
             ignore_file_pattern = kwargs.pop('ignore_file_pattern',
                                              default_ignore_file_pattern)
+            subfolder = kwargs.pop('subfolder', default_file_filter)
+            file_filter = None
+            if subfolder:
+                file_filter = f'{subfolder}/*'
             if not os.path.exists(pretrained_model_name_or_path):
                 revision = kwargs.pop('revision', DEFAULT_MODEL_REVISION)
-                model_dir = snapshot_download(
-                    pretrained_model_name_or_path,
-                    revision=revision,
-                    ignore_file_pattern=ignore_file_pattern,
-                    user_agent=user_agent())
+                if file_filter is None:
+                    model_dir = snapshot_download(
+                        pretrained_model_name_or_path,
+                        revision=revision,
+                        ignore_file_pattern=ignore_file_pattern,
+                        user_agent=user_agent())
+                else:
+                    model_dir = os.path.join(
+                        snapshot_download(
+                            pretrained_model_name_or_path,
+                            revision=revision,
+                            ignore_file_pattern=ignore_file_pattern,
+                            allow_file_pattern=file_filter,
+                            user_agent=user_agent()), subfolder)
             else:
                 model_dir = pretrained_model_name_or_path

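With this change, passing subfolder to any wrapped from_pretrained restricts the hub download to that subfolder (through snapshot_download's allow_file_pattern) and then loads from the joined local path. A minimal usage sketch; the model id and subfolder name below are hypothetical:

    from modelscope import AutoModel

    # Only files matching 'unet/*' are downloaded from the ModelScope hub;
    # the model is then loaded from the local '<snapshot>/unet' directory.
    model = AutoModel.from_pretrained(
        'damo/some-diffusion-model',  # hypothetical model id
        subfolder='unet')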
@@ -297,6 +315,7 @@ AutoModelForTokenClassification = get_wrapped_class(
     AutoModelForTokenClassificationHF)
 AutoModelForImageSegmentation = get_wrapped_class(
     AutoModelForImageSegmentationHF)
+T5EncoderModel = get_wrapped_class(T5EncoderModelHF)

 AutoTokenizer = get_wrapped_class(
     AutoTokenizerHF,
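The wrapped T5EncoderModel then becomes importable from modelscope alongside the existing auto classes. A minimal sketch, with a hypothetical model id:

    from modelscope import T5EncoderModel

    # If the argument is not an existing local path, the wrapper snapshots the
    # repo from the ModelScope hub, then delegates to the transformers class.
    encoder = T5EncoderModel.from_pretrained('AI-ModelScope/t5-base')  # hypothetical id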