add transformers compatability for Vision2seq (#1107)

* add vision2seq
---------

Co-authored-by: Yingda Chen <yingda.chen@alibaba-inc.com>
This commit is contained in:
Yingda Chen
2024-11-29 10:35:50 +08:00
committed by GitHub
parent 3e13cc899b
commit 95cad91c21
3 changed files with 6 additions and 3 deletions

View File

@@ -36,7 +36,7 @@ if TYPE_CHECKING:
AutoModel, AutoProcessor, AutoFeatureExtractor, GenerationConfig,
AutoConfig, GPTQConfig, AwqConfig, BitsAndBytesConfig,
AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForVision2Seq, AutoModelForSequenceClassification,
AutoModelForTokenClassification, AutoModelForImageClassification,
AutoModelForImageToImage, AutoModelForImageSegmentation,
AutoModelForQuestionAnswering, AutoModelForMaskedLM, AutoTokenizer,
@@ -99,7 +99,8 @@ else:
'AutoModel', 'AutoProcessor', 'AutoFeatureExtractor',
'GenerationConfig', 'AutoConfig', 'GPTQConfig', 'AwqConfig',
'BitsAndBytesConfig', 'AutoModelForCausalLM',
'AutoModelForSeq2SeqLM', 'AutoModelForSequenceClassification',
'AutoModelForSeq2SeqLM', 'AutoModelForVision2Seq',
'AutoModelForSequenceClassification',
'AutoModelForTokenClassification',
'AutoModelForImageClassification', 'AutoModelForImageToImage',
'AutoModelForImageSegmentation', 'AutoModelForQuestionAnswering',

View File

@@ -571,7 +571,7 @@ class HubApi:
revision_detail = self.get_branch_tag_detail(all_tags_detail, revision)
if revision_detail is None:
revision_detail = self.get_branch_tag_detail(all_branches_detail, revision)
logger.info('Development mode use revision: %s' % revision)
logger.debug('Development mode use revision: %s' % revision)
else:
if revision is not None and revision in all_branches:
revision_detail = self.get_branch_tag_detail(all_branches_detail, revision)

View File

@@ -26,6 +26,7 @@ from transformers import \
from transformers import AutoModelForTextEncoding as AutoModelForTextEncodingHF
from transformers import \
AutoModelForTokenClassification as AutoModelForTokenClassificationHF
from transformers import AutoModelForVision2Seq as AutoModelForVision2SeqHF
from transformers import AutoProcessor as AutoProcessorHF
from transformers import AutoTokenizer as AutoTokenizerHF
from transformers import BatchFeature as BatchFeatureHF
@@ -321,6 +322,7 @@ def get_wrapped_class(module_class,
AutoModel = get_wrapped_class(AutoModelHF)
AutoModelForCausalLM = get_wrapped_class(AutoModelForCausalLMHF)
AutoModelForSeq2SeqLM = get_wrapped_class(AutoModelForSeq2SeqLMHF)
AutoModelForVision2Seq = get_wrapped_class(AutoModelForVision2SeqHF)
AutoModelForSequenceClassification = get_wrapped_class(
AutoModelForSequenceClassificationHF)
AutoModelForTokenClassification = get_wrapped_class(