support subfolder

This commit is contained in:
Yingda Chen
2024-11-22 17:47:58 +08:00
parent d453e3a240
commit 437ea688c5
2 changed files with 28 additions and 8 deletions

View File

@@ -37,7 +37,8 @@ if TYPE_CHECKING:
AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification, AutoModelForImageSegmentation,
AutoTokenizer, GenerationConfig, AutoImageProcessor, BatchFeature)
AutoTokenizer, GenerationConfig, AutoImageProcessor, BatchFeature,
T5EncoderModel)
else:
print(
'transformer is not installed, please install it if you want to use related modules'
@@ -96,7 +97,7 @@ else:
'AutoModelForSeq2SeqLM', 'AutoTokenizer',
'AutoModelForSequenceClassification',
'AutoModelForTokenClassification', 'AutoModelForImageSegmentation',
'AutoImageProcessor', 'BatchFeature'
'AutoImageProcessor', 'BatchFeature', 'T5EncoderModel'
]
import sys

View File

@@ -23,6 +23,7 @@ from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF
from transformers import GenerationConfig as GenerationConfigHF
from transformers import (PretrainedConfig, PreTrainedModel,
PreTrainedTokenizerBase)
from transformers import T5EncoderModel as T5EncoderModelHF
from modelscope import snapshot_download
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
@@ -248,7 +249,10 @@ def patch_hub():
_patch_pretrained_class()
def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
def get_wrapped_class(module_class,
ignore_file_pattern=[],
file_filter=None,
**kwargs):
"""Get a custom wrapper class for auto classes to download the models from the ModelScope hub
Args:
module_class: The actual module class
@@ -258,6 +262,7 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
The wrapper
"""
default_ignore_file_pattern = ignore_file_pattern
default_file_filter = file_filter
class ClassWrapper(module_class):
@@ -266,13 +271,26 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
**kwargs):
ignore_file_pattern = kwargs.pop('ignore_file_pattern',
default_ignore_file_pattern)
subfolder = kwargs.pop('subfolder', default_file_filter)
if subfolder:
file_filter = f'{subfolder}/*'
if not os.path.exists(pretrained_model_name_or_path):
revision = kwargs.pop('revision', DEFAULT_MODEL_REVISION)
model_dir = snapshot_download(
pretrained_model_name_or_path,
revision=revision,
ignore_file_pattern=ignore_file_pattern,
user_agent=user_agent())
if file_filter is None:
model_dir = snapshot_download(
pretrained_model_name_or_path,
revision=revision,
ignore_file_pattern=ignore_file_pattern,
user_agent=user_agent())
else:
model_dir = os.path.join(
snapshot_download(
pretrained_model_name_or_path,
revision=revision,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=file_filter,
user_agent=user_agent()), subfolder)
else:
model_dir = pretrained_model_name_or_path
@@ -297,6 +315,7 @@ AutoModelForTokenClassification = get_wrapped_class(
AutoModelForTokenClassificationHF)
AutoModelForImageSegmentation = get_wrapped_class(
AutoModelForImageSegmentationHF)
T5EncoderModel = get_wrapped_class(T5EncoderModelHF)
AutoTokenizer = get_wrapped_class(
AutoTokenizerHF,