Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
Add AutoModelForImageSegmentation and T5EncoderModel; Support from subfolder option (#1096)
Co-authored-by: Yingda Chen <yingda.chen@alibaba-inc.com>
--- a/modelscope/__init__.py
+++ b/modelscope/__init__.py
@@ -33,12 +33,12 @@ if TYPE_CHECKING:
     from .utils.constant import Tasks
     if is_transformers_available():
         from .utils.hf_util import AutoConfig, GPTQConfig, AwqConfig, BitsAndBytesConfig
-        from .utils.hf_util import (AutoModel, AutoModelForCausalLM,
-                                    AutoModelForSeq2SeqLM,
-                                    AutoModelForSequenceClassification,
-                                    AutoModelForTokenClassification,
-                                    AutoTokenizer, GenerationConfig,
-                                    AutoImageProcessor, BatchFeature)
+        from .utils.hf_util import (
+            AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM,
+            AutoModelForSequenceClassification,
+            AutoModelForTokenClassification, AutoModelForImageSegmentation,
+            AutoTokenizer, GenerationConfig, AutoImageProcessor, BatchFeature,
+            T5EncoderModel)
     else:
         print(
             'transformer is not installed, please install it if you want to use related modules'
@@ -96,8 +96,8 @@ else:
             'AwqConfig', 'BitsAndBytesConfig', 'AutoModelForCausalLM',
             'AutoModelForSeq2SeqLM', 'AutoTokenizer',
             'AutoModelForSequenceClassification',
-            'AutoModelForTokenClassification', 'AutoImageProcessor',
-            'BatchFeature'
+            'AutoModelForTokenClassification', 'AutoModelForImageSegmentation',
+            'AutoImageProcessor', 'BatchFeature', 'T5EncoderModel'
         ]
 
     import sys
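The net effect of the __init__.py hunks above is that the two new classes become importable from the package top level. A minimal usage sketch, assuming transformers is installed; the model IDs below are placeholders, not taken from this commit:

# Placeholder model IDs. Each wrapper resolves the ID against the ModelScope
# hub, downloads a snapshot if the path does not exist locally, then delegates
# to the corresponding transformers class.
from modelscope import AutoModelForImageSegmentation, T5EncoderModel

seg_model = AutoModelForImageSegmentation.from_pretrained('org/some-segmentation-model')
text_encoder = T5EncoderModel.from_pretrained('org/some-t5-model')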
--- a/modelscope/utils/hf_util.py
+++ b/modelscope/utils/hf_util.py
@@ -1,15 +1,16 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import importlib
 import os
 from pathlib import Path
 from types import MethodType
-from typing import Dict, Literal, Optional, Union
+from typing import Optional, Union
 
 from transformers import AutoConfig as AutoConfigHF
 from transformers import AutoFeatureExtractor as AutoFeatureExtractorHF
 from transformers import AutoImageProcessor as AutoImageProcessorHF
 from transformers import AutoModel as AutoModelHF
 from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
+from transformers import \
+    AutoModelForImageSegmentation as AutoModelForImageSegmentationHF
 from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
 from transformers import \
     AutoModelForSequenceClassification as AutoModelForSequenceClassificationHF
@@ -22,6 +23,7 @@ from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF
 from transformers import GenerationConfig as GenerationConfigHF
 from transformers import (PretrainedConfig, PreTrainedModel,
                           PreTrainedTokenizerBase)
+from transformers import T5EncoderModel as T5EncoderModelHF
 
 from modelscope import snapshot_download
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
@@ -247,7 +249,10 @@ def patch_hub():
     _patch_pretrained_class()
 
 
-def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
+def get_wrapped_class(module_class,
+                      ignore_file_pattern=[],
+                      file_filter=None,
+                      **kwargs):
     """Get a custom wrapper class for auto classes to download the models from the ModelScope hub
     Args:
         module_class: The actual module class
@@ -257,6 +262,7 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
         The wrapper
     """
     default_ignore_file_pattern = ignore_file_pattern
+    default_file_filter = file_filter
 
     class ClassWrapper(module_class):
 
@@ -265,13 +271,26 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
                             **kwargs):
             ignore_file_pattern = kwargs.pop('ignore_file_pattern',
                                              default_ignore_file_pattern)
+            subfolder = kwargs.pop('subfolder', default_file_filter)
+            file_filter = None
+            if subfolder:
+                file_filter = f'{subfolder}/*'
             if not os.path.exists(pretrained_model_name_or_path):
                 revision = kwargs.pop('revision', DEFAULT_MODEL_REVISION)
-                model_dir = snapshot_download(
-                    pretrained_model_name_or_path,
-                    revision=revision,
-                    ignore_file_pattern=ignore_file_pattern,
-                    user_agent=user_agent())
+                if file_filter is None:
+                    model_dir = snapshot_download(
+                        pretrained_model_name_or_path,
+                        revision=revision,
+                        ignore_file_pattern=ignore_file_pattern,
+                        user_agent=user_agent())
+                else:
+                    model_dir = os.path.join(
+                        snapshot_download(
+                            pretrained_model_name_or_path,
+                            revision=revision,
+                            ignore_file_pattern=ignore_file_pattern,
+                            allow_file_pattern=file_filter,
+                            user_agent=user_agent()), subfolder)
             else:
                 model_dir = pretrained_model_name_or_path
 
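To make the new control flow concrete, here is a sketch of how a subfolder call resolves, with placeholder values that are not from the commit:

# Passing subfolder='text_encoder' for a hub ID that is not a local path:
#   subfolder   -> 'text_encoder'
#   file_filter -> 'text_encoder/*', forwarded as allow_file_pattern so that
#                  only files under that subfolder are downloaded
#   model_dir   -> os.path.join(<snapshot dir>, 'text_encoder')
from modelscope import T5EncoderModel

text_encoder = T5EncoderModel.from_pretrained(
    'org/some-diffusion-model',  # placeholder hub ID
    subfolder='text_encoder')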
@@ -294,6 +313,9 @@ AutoModelForSequenceClassification = get_wrapped_class(
     AutoModelForSequenceClassificationHF)
 AutoModelForTokenClassification = get_wrapped_class(
     AutoModelForTokenClassificationHF)
+AutoModelForImageSegmentation = get_wrapped_class(
+    AutoModelForImageSegmentationHF)
+T5EncoderModel = get_wrapped_class(T5EncoderModelHF)
 
 AutoTokenizer = get_wrapped_class(
     AutoTokenizerHF,
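The same pattern extends to any other transformers auto class; a hypothetical sketch (AutoModelForVision2Seq is not part of this commit). Note that the new file_filter argument of get_wrapped_class serves as the default for the subfolder keyword, as seen in the from_pretrained hunk above:

# Hypothetical: wrapping one more transformers auto class the same way,
# using the module-level helper defined in hf_util.py at this commit.
from modelscope.utils.hf_util import get_wrapped_class
from transformers import AutoModelForVision2Seq as AutoModelForVision2SeqHF

AutoModelForVision2Seq = get_wrapped_class(AutoModelForVision2SeqHF)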