diff --git a/modelscope/__init__.py b/modelscope/__init__.py
index 11f28767..9f1d5df1 100644
--- a/modelscope/__init__.py
+++ b/modelscope/__init__.py
@@ -33,7 +33,8 @@ if TYPE_CHECKING:
                             AutoModelForSeq2SeqLM,
                             AutoModelForSequenceClassification,
                             AutoModelForTokenClassification, AutoTokenizer,
-                            GenerationConfig)
+                            GenerationConfig, AutoImageProcessor,
+                            BatchFeature)
     from .utils.hub import create_model_if_not_exist, read_config
     from .utils.logger import get_logger
     from .version import __release_datetime__, __version__
@@ -81,7 +82,8 @@ else:
             'BitsAndBytesConfig', 'AutoModelForCausalLM',
             'AutoModelForSeq2SeqLM', 'AutoTokenizer',
             'AutoModelForSequenceClassification',
-            'AutoModelForTokenClassification'
+            'AutoModelForTokenClassification', 'AutoImageProcessor',
+            'BatchFeature'
         ],
         'msdatasets': ['MsDataset']
     }
diff --git a/modelscope/utils/hf_util.py b/modelscope/utils/hf_util.py
index f98dbea2..4816b95e 100644
--- a/modelscope/utils/hf_util.py
+++ b/modelscope/utils/hf_util.py
@@ -2,6 +2,7 @@
 import os
 
 from transformers import AutoConfig as AutoConfigHF
+from transformers import AutoImageProcessor as AutoImageProcessorHF
 from transformers import AutoModel as AutoModelHF
 from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
 from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
@@ -10,6 +11,7 @@ from transformers import \
 from transformers import \
     AutoModelForTokenClassification as AutoModelForTokenClassificationHF
 from transformers import AutoTokenizer as AutoTokenizerHF
+from transformers import BatchFeature as BatchFeatureHF
 from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF
 from transformers import GenerationConfig as GenerationConfigHF
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
@@ -136,3 +138,5 @@ GenerationConfig = get_wrapped_class(
     GenerationConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])
 GPTQConfig = GPTQConfigHF
 BitsAndBytesConfig = BitsAndBytesConfigHF
+AutoImageProcessor = get_wrapped_class(AutoImageProcessorHF)
+BatchFeature = get_wrapped_class(BatchFeatureHF)