diff --git a/modelscope/__init__.py b/modelscope/__init__.py index b7712b3b..996fd4c9 100644 --- a/modelscope/__init__.py +++ b/modelscope/__init__.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: AutoModel, AutoProcessor, AutoFeatureExtractor, GenerationConfig, AutoConfig, GPTQConfig, AwqConfig, BitsAndBytesConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, - AutoModelForSequenceClassification, + AutoModelForVision2Seq, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoModelForImageClassification, AutoModelForImageToImage, AutoModelForImageSegmentation, AutoModelForQuestionAnswering, AutoModelForMaskedLM, AutoTokenizer, @@ -99,7 +99,8 @@ else: 'AutoModel', 'AutoProcessor', 'AutoFeatureExtractor', 'GenerationConfig', 'AutoConfig', 'GPTQConfig', 'AwqConfig', 'BitsAndBytesConfig', 'AutoModelForCausalLM', - 'AutoModelForSeq2SeqLM', 'AutoModelForSequenceClassification', + 'AutoModelForSeq2SeqLM', 'AutoModelForVision2Seq', + 'AutoModelForSequenceClassification', 'AutoModelForTokenClassification', 'AutoModelForImageClassification', 'AutoModelForImageToImage', 'AutoModelForImageSegmentation', 'AutoModelForQuestionAnswering', diff --git a/modelscope/utils/hf_util.py b/modelscope/utils/hf_util.py index f6613f98..d838347a 100644 --- a/modelscope/utils/hf_util.py +++ b/modelscope/utils/hf_util.py @@ -26,6 +26,7 @@ from transformers import \ from transformers import AutoModelForTextEncoding as AutoModelForTextEncodingHF from transformers import \ AutoModelForTokenClassification as AutoModelForTokenClassificationHF +from transformers import AutoModelForVision2Seq as AutoModelForVision2SeqHF from transformers import AutoProcessor as AutoProcessorHF from transformers import AutoTokenizer as AutoTokenizerHF from transformers import BatchFeature as BatchFeatureHF @@ -321,6 +322,7 @@ def get_wrapped_class(module_class, AutoModel = get_wrapped_class(AutoModelHF) AutoModelForCausalLM = get_wrapped_class(AutoModelForCausalLMHF) AutoModelForSeq2SeqLM = get_wrapped_class(AutoModelForSeq2SeqLMHF) +AutoModelForVision2Seq = get_wrapped_class(AutoModelForVision2SeqHF) AutoModelForSequenceClassification = get_wrapped_class( AutoModelForSequenceClassificationHF) AutoModelForTokenClassification = get_wrapped_class(