Add AutoModelForImageSegmentation and T5EncoderModel; Support from subfolder option (#1096)

Co-authored-by: Yingda Chen <yingda.chen@alibaba-inc.com>
This commit is contained in:
Yingda Chen
2024-11-22 19:43:47 +08:00
committed by GitHub
parent d687c9c514
commit ddc5fab311
2 changed files with 38 additions and 16 deletions

View File

@@ -33,12 +33,12 @@ if TYPE_CHECKING:
from .utils.constant import Tasks
if is_transformers_available():
from .utils.hf_util import AutoConfig, GPTQConfig, AwqConfig, BitsAndBytesConfig
from .utils.hf_util import (AutoModel, AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification,
AutoTokenizer, GenerationConfig,
AutoImageProcessor, BatchFeature)
from .utils.hf_util import (
AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification, AutoModelForImageSegmentation,
AutoTokenizer, GenerationConfig, AutoImageProcessor, BatchFeature,
T5EncoderModel)
else:
print(
'transformer is not installed, please install it if you want to use related modules'
@@ -96,8 +96,8 @@ else:
'AwqConfig', 'BitsAndBytesConfig', 'AutoModelForCausalLM',
'AutoModelForSeq2SeqLM', 'AutoTokenizer',
'AutoModelForSequenceClassification',
'AutoModelForTokenClassification', 'AutoImageProcessor',
'BatchFeature'
'AutoModelForTokenClassification', 'AutoModelForImageSegmentation',
'AutoImageProcessor', 'BatchFeature', 'T5EncoderModel'
]
import sys

View File

@@ -1,15 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import importlib
import os
from pathlib import Path
from types import MethodType
from typing import Dict, Literal, Optional, Union
from typing import Optional, Union
from transformers import AutoConfig as AutoConfigHF
from transformers import AutoFeatureExtractor as AutoFeatureExtractorHF
from transformers import AutoImageProcessor as AutoImageProcessorHF
from transformers import AutoModel as AutoModelHF
from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
from transformers import \
AutoModelForImageSegmentation as AutoModelForImageSegmentationHF
from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
from transformers import \
AutoModelForSequenceClassification as AutoModelForSequenceClassificationHF
@@ -22,6 +23,7 @@ from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF
from transformers import GenerationConfig as GenerationConfigHF
from transformers import (PretrainedConfig, PreTrainedModel,
PreTrainedTokenizerBase)
from transformers import T5EncoderModel as T5EncoderModelHF
from modelscope import snapshot_download
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
@@ -247,7 +249,10 @@ def patch_hub():
_patch_pretrained_class()
def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
def get_wrapped_class(module_class,
ignore_file_pattern=[],
file_filter=None,
**kwargs):
"""Get a custom wrapper class for auto classes to download the models from the ModelScope hub
Args:
module_class: The actual module class
@@ -257,6 +262,7 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
The wrapper
"""
default_ignore_file_pattern = ignore_file_pattern
default_file_filter = file_filter
class ClassWrapper(module_class):
@@ -265,13 +271,26 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
**kwargs):
ignore_file_pattern = kwargs.pop('ignore_file_pattern',
default_ignore_file_pattern)
subfolder = kwargs.pop('subfolder', default_file_filter)
if subfolder:
file_filter = f'{subfolder}/*'
if not os.path.exists(pretrained_model_name_or_path):
revision = kwargs.pop('revision', DEFAULT_MODEL_REVISION)
model_dir = snapshot_download(
pretrained_model_name_or_path,
revision=revision,
ignore_file_pattern=ignore_file_pattern,
user_agent=user_agent())
if file_filter is None:
model_dir = snapshot_download(
pretrained_model_name_or_path,
revision=revision,
ignore_file_pattern=ignore_file_pattern,
user_agent=user_agent())
else:
model_dir = os.path.join(
snapshot_download(
pretrained_model_name_or_path,
revision=revision,
ignore_file_pattern=ignore_file_pattern,
allow_file_pattern=file_filter,
user_agent=user_agent()), subfolder)
else:
model_dir = pretrained_model_name_or_path
@@ -294,6 +313,9 @@ AutoModelForSequenceClassification = get_wrapped_class(
AutoModelForSequenceClassificationHF)
AutoModelForTokenClassification = get_wrapped_class(
AutoModelForTokenClassificationHF)
AutoModelForImageSegmentation = get_wrapped_class(
AutoModelForImageSegmentationHF)
T5EncoderModel = get_wrapped_class(T5EncoderModelHF)
AutoTokenizer = get_wrapped_class(
AutoTokenizerHF,