From 437ea688c5dbe0429e04e08f6bd209284a75501b Mon Sep 17 00:00:00 2001
From: Yingda Chen
Date: Fri, 22 Nov 2024 17:47:58 +0800
Subject: [PATCH] support subfolder

---
 modelscope/__init__.py      |  5 +++--
 modelscope/utils/hf_util.py | 31 +++++++++++++++++++++++++------
 2 files changed, 28 insertions(+), 8 deletions(-)

diff --git a/modelscope/__init__.py b/modelscope/__init__.py
index 14269075..d60a8c79 100644
--- a/modelscope/__init__.py
+++ b/modelscope/__init__.py
@@ -37,7 +37,8 @@ if TYPE_CHECKING:
         AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM,
         AutoModelForSequenceClassification, AutoModelForTokenClassification,
         AutoModelForImageSegmentation,
-        AutoTokenizer, GenerationConfig, AutoImageProcessor, BatchFeature)
+        AutoTokenizer, GenerationConfig, AutoImageProcessor, BatchFeature,
+        T5EncoderModel)
 else:
     print(
         'transformer is not installed, please install it if you want to use related modules'
@@ -96,7 +97,7 @@ else:
         'AutoModelForSeq2SeqLM', 'AutoTokenizer',
         'AutoModelForSequenceClassification',
         'AutoModelForTokenClassification', 'AutoModelForImageSegmentation',
-        'AutoImageProcessor', 'BatchFeature'
+        'AutoImageProcessor', 'BatchFeature', 'T5EncoderModel'
     ]
 
     import sys
diff --git a/modelscope/utils/hf_util.py b/modelscope/utils/hf_util.py
index 471ce519..9d517724 100644
--- a/modelscope/utils/hf_util.py
+++ b/modelscope/utils/hf_util.py
@@ -23,6 +23,7 @@ from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF
 from transformers import GenerationConfig as GenerationConfigHF
 from transformers import (PretrainedConfig, PreTrainedModel,
                           PreTrainedTokenizerBase)
+from transformers import T5EncoderModel as T5EncoderModelHF
 
 from modelscope import snapshot_download
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
@@ -248,7 +249,10 @@ def patch_hub():
     _patch_pretrained_class()
 
 
-def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
+def get_wrapped_class(module_class,
+                      ignore_file_pattern=[],
+                      file_filter=None,
+                      **kwargs):
     """Get a custom wrapper class for auto classes to download the models from the ModelScope hub
     Args:
         module_class: The actual module class
@@ -258,6 +262,7 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
         The wrapper
     """
     default_ignore_file_pattern = ignore_file_pattern
+    default_file_filter = file_filter
 
     class ClassWrapper(module_class):
 
@@ -266,13 +271,26 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
                             **kwargs):
             ignore_file_pattern = kwargs.pop('ignore_file_pattern',
                                              default_ignore_file_pattern)
+            subfolder = kwargs.pop('subfolder', default_file_filter)
+            file_filter = None
+            if subfolder:
+                file_filter = f'{subfolder}/*'
 
             if not os.path.exists(pretrained_model_name_or_path):
                 revision = kwargs.pop('revision', DEFAULT_MODEL_REVISION)
-                model_dir = snapshot_download(
-                    pretrained_model_name_or_path,
-                    revision=revision,
-                    ignore_file_pattern=ignore_file_pattern,
-                    user_agent=user_agent())
+                if file_filter is None:
+                    model_dir = snapshot_download(
+                        pretrained_model_name_or_path,
+                        revision=revision,
+                        ignore_file_pattern=ignore_file_pattern,
+                        user_agent=user_agent())
+                else:
+                    model_dir = os.path.join(
+                        snapshot_download(
+                            pretrained_model_name_or_path,
+                            revision=revision,
+                            ignore_file_pattern=ignore_file_pattern,
+                            allow_file_pattern=file_filter,
+                            user_agent=user_agent()), subfolder)
             else:
                 model_dir = pretrained_model_name_or_path
@@ -297,6 +315,7 @@ AutoModelForTokenClassification = get_wrapped_class(
     AutoModelForTokenClassificationHF)
 AutoModelForImageSegmentation = get_wrapped_class(
     AutoModelForImageSegmentationHF)
+T5EncoderModel = get_wrapped_class(T5EncoderModelHF)
 AutoTokenizer = get_wrapped_class(
     AutoTokenizerHF,