add vision2seq

This commit is contained in:
Yingda Chen
2024-11-29 09:39:36 +08:00
parent 3e13cc899b
commit b0ecab8f5a
2 changed files with 5 additions and 2 deletions

View File

@@ -36,7 +36,7 @@ if TYPE_CHECKING:
AutoModel, AutoProcessor, AutoFeatureExtractor, GenerationConfig,
AutoConfig, GPTQConfig, AwqConfig, BitsAndBytesConfig,
AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForVision2Seq, AutoModelForSequenceClassification,
AutoModelForTokenClassification, AutoModelForImageClassification,
AutoModelForImageToImage, AutoModelForImageSegmentation,
AutoModelForQuestionAnswering, AutoModelForMaskedLM, AutoTokenizer,
@@ -99,7 +99,8 @@ else:
'AutoModel', 'AutoProcessor', 'AutoFeatureExtractor',
'GenerationConfig', 'AutoConfig', 'GPTQConfig', 'AwqConfig',
'BitsAndBytesConfig', 'AutoModelForCausalLM',
'AutoModelForSeq2SeqLM', 'AutoModelForSequenceClassification',
'AutoModelForSeq2SeqLM', 'AutoModelForVision2Seq',
'AutoModelForSequenceClassification',
'AutoModelForTokenClassification',
'AutoModelForImageClassification', 'AutoModelForImageToImage',
'AutoModelForImageSegmentation', 'AutoModelForQuestionAnswering',

View File

@@ -26,6 +26,7 @@ from transformers import \
from transformers import AutoModelForTextEncoding as AutoModelForTextEncodingHF
from transformers import \
AutoModelForTokenClassification as AutoModelForTokenClassificationHF
from transformers import AutoModelForVision2Seq as AutoModelForVision2SeqHF
from transformers import AutoProcessor as AutoProcessorHF
from transformers import AutoTokenizer as AutoTokenizerHF
from transformers import BatchFeature as BatchFeatureHF
@@ -321,6 +322,7 @@ def get_wrapped_class(module_class,
AutoModel = get_wrapped_class(AutoModelHF)
AutoModelForCausalLM = get_wrapped_class(AutoModelForCausalLMHF)
AutoModelForSeq2SeqLM = get_wrapped_class(AutoModelForSeq2SeqLMHF)
AutoModelForVision2Seq = get_wrapped_class(AutoModelForVision2SeqHF)
AutoModelForSequenceClassification = get_wrapped_class(
AutoModelForSequenceClassificationHF)
AutoModelForTokenClassification = get_wrapped_class(