From 1b6723ea24f8d751760267baac9f3987dcfc2db2 Mon Sep 17 00:00:00 2001 From: mushenL <125954878+mushenL@users.noreply.github.com> Date: Fri, 1 Dec 2023 15:14:06 +0800 Subject: [PATCH] hf_tool add AutoImageProcessor and BatchFeature (#650) * add AutoImageProcessor and BatchFeature --- modelscope/__init__.py | 6 ++++-- modelscope/utils/hf_util.py | 4 ++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/modelscope/__init__.py b/modelscope/__init__.py index 11f28767..9f1d5df1 100644 --- a/modelscope/__init__.py +++ b/modelscope/__init__.py @@ -33,7 +33,8 @@ if TYPE_CHECKING: AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoTokenizer, - GenerationConfig) + GenerationConfig, AutoImageProcessor, + BatchFeature) from .utils.hub import create_model_if_not_exist, read_config from .utils.logger import get_logger from .version import __release_datetime__, __version__ @@ -81,7 +82,8 @@ else: 'BitsAndBytesConfig', 'AutoModelForCausalLM', 'AutoModelForSeq2SeqLM', 'AutoTokenizer', 'AutoModelForSequenceClassification', - 'AutoModelForTokenClassification' + 'AutoModelForTokenClassification', 'AutoImageProcessor', + 'BatchFeature' ], 'msdatasets': ['MsDataset'] } diff --git a/modelscope/utils/hf_util.py b/modelscope/utils/hf_util.py index f98dbea2..4816b95e 100644 --- a/modelscope/utils/hf_util.py +++ b/modelscope/utils/hf_util.py @@ -2,6 +2,7 @@ import os from transformers import AutoConfig as AutoConfigHF +from transformers import AutoImageProcessor as AutoImageProcessorHF from transformers import AutoModel as AutoModelHF from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF @@ -10,6 +11,7 @@ from transformers import \ from transformers import \ AutoModelForTokenClassification as AutoModelForTokenClassificationHF from transformers import AutoTokenizer as AutoTokenizerHF +from transformers import BatchFeature as BatchFeatureHF from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF from transformers import GenerationConfig as GenerationConfigHF from transformers import PreTrainedModel, PreTrainedTokenizerBase @@ -136,3 +138,5 @@ GenerationConfig = get_wrapped_class( GenerationConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors']) GPTQConfig = GPTQConfigHF BitsAndBytesConfig = BitsAndBytesConfigHF +AutoImageProcessor = get_wrapped_class(AutoImageProcessorHF) +BatchFeature = get_wrapped_class(BatchFeatureHF)