Handle unsupported Transformers class, and add more auto classes (#1113)

* optimize unsupported transformer class and add more automodel

Co-authored-by: Yingda Chen <yingda.chen@alibaba-inc.com>
This commit is contained in:
Yingda Chen
2024-12-02 15:17:17 +08:00
committed by GitHub
parent a4582012ff
commit 3017a70262
2 changed files with 116 additions and 7 deletions

View File

@@ -38,11 +38,23 @@ if TYPE_CHECKING:
AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForVision2Seq, AutoModelForSequenceClassification,
AutoModelForTokenClassification, AutoModelForImageClassification,
AutoModelForImageToImage, AutoModelForImageSegmentation,
AutoModelForQuestionAnswering, AutoModelForMaskedLM, AutoTokenizer,
AutoModelForMaskGeneration, AutoModelForPreTraining,
AutoModelForTextEncoding, AutoImageProcessor, BatchFeature,
Qwen2VLForConditionalGeneration, T5EncoderModel)
AutoModelForImageTextToText,
AutoModelForZeroShotImageClassification,
AutoModelForKeypointDetection,
AutoModelForDocumentQuestionAnswering,
AutoModelForSemanticSegmentation,
AutoModelForUniversalSegmentation,
AutoModelForInstanceSegmentation, AutoModelForObjectDetection,
AutoModelForZeroShotObjectDetection,
AutoModelForAudioClassification, AutoModelForSpeechSeq2Seq,
AutoModelForMaskedImageModeling,
AutoModelForVisualQuestionAnswering,
AutoModelForTableQuestionAnswering, AutoModelForImageToImage,
AutoModelForImageSegmentation, AutoModelForQuestionAnswering,
AutoModelForMaskedLM, AutoTokenizer, AutoModelForMaskGeneration,
AutoModelForPreTraining, AutoModelForTextEncoding,
AutoImageProcessor, BatchFeature, Qwen2VLForConditionalGeneration,
T5EncoderModel)
else:
print(
'transformer is not installed, please install it if you want to use related modules'
@@ -103,6 +115,18 @@ else:
'AutoModelForSequenceClassification',
'AutoModelForTokenClassification',
'AutoModelForImageClassification', 'AutoModelForImageToImage',
'AutoModelForImageTextToText',
'AutoModelForZeroShotImageClassification',
'AutoModelForKeypointDetection',
'AutoModelForDocumentQuestionAnswering',
'AutoModelForSemanticSegmentation',
'AutoModelForUniversalSegmentation',
'AutoModelForInstanceSegmentation', 'AutoModelForObjectDetection',
'AutoModelForZeroShotObjectDetection',
'AutoModelForAudioClassification', 'AutoModelForSpeechSeq2Seq',
'AutoModelForMaskedImageModeling',
'AutoModelForVisualQuestionAnswering',
'AutoModelForTableQuestionAnswering',
'AutoModelForImageSegmentation', 'AutoModelForQuestionAnswering',
'AutoModelForMaskedLM', 'AutoTokenizer',
'AutoModelForMaskGeneration', 'AutoModelForPreTraining',

View File

@@ -8,24 +8,52 @@ from transformers import AutoConfig as AutoConfigHF
from transformers import AutoFeatureExtractor as AutoFeatureExtractorHF
from transformers import AutoImageProcessor as AutoImageProcessorHF
from transformers import AutoModel as AutoModelHF
from transformers import \
AutoModelForAudioClassification as AutoModelForAudioClassificationHF
from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
from transformers import \
AutoModelForDocumentQuestionAnswering as \
AutoModelForDocumentQuestionAnsweringHF
from transformers import \
AutoModelForImageClassification as AutoModelForImageClassificationHF
from transformers import \
AutoModelForImageSegmentation as AutoModelForImageSegmentationHF
from transformers import \
AutoModelForInstanceSegmentation as AutoModelForInstanceSegmentationHF
from transformers import \
AutoModelForMaskedImageModeling as AutoModelForMaskedImageModelingHF
from transformers import AutoModelForMaskedLM as AutoModelForMaskedLMHF
from transformers import \
AutoModelForMaskGeneration as AutoModelForMaskGenerationHF
from transformers import \
AutoModelForObjectDetection as AutoModelForObjectDetectionHF
from transformers import AutoModelForPreTraining as AutoModelForPreTrainingHF
from transformers import \
AutoModelForQuestionAnswering as AutoModelForQuestionAnsweringHF
from transformers import \
AutoModelForSemanticSegmentation as AutoModelForSemanticSegmentationHF
from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
from transformers import \
AutoModelForSequenceClassification as AutoModelForSequenceClassificationHF
from transformers import \
AutoModelForSpeechSeq2Seq as AutoModelForSpeechSeq2SeqHF
from transformers import \
AutoModelForTableQuestionAnswering as AutoModelForTableQuestionAnsweringHF
from transformers import AutoModelForTextEncoding as AutoModelForTextEncodingHF
from transformers import \
AutoModelForTokenClassification as AutoModelForTokenClassificationHF
from transformers import \
AutoModelForUniversalSegmentation as AutoModelForUniversalSegmentationHF
from transformers import AutoModelForVision2Seq as AutoModelForVision2SeqHF
from transformers import \
AutoModelForVisualQuestionAnswering as \
AutoModelForVisualQuestionAnsweringHF
from transformers import \
AutoModelForZeroShotImageClassification as \
AutoModelForZeroShotImageClassificationHF
from transformers import \
AutoModelForZeroShotObjectDetection as \
AutoModelForZeroShotObjectDetectionHF
from transformers import AutoProcessor as AutoProcessorHF
from transformers import AutoTokenizer as AutoTokenizerHF
from transformers import BatchFeature as BatchFeatureHF
@@ -34,6 +62,7 @@ from transformers import GenerationConfig as GenerationConfigHF
from transformers import (PretrainedConfig, PreTrainedModel,
PreTrainedTokenizerBase)
from transformers import T5EncoderModel as T5EncoderModelHF
from transformers import __version__ as transformers_version
from modelscope import snapshot_download
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
@@ -49,6 +78,21 @@ except ImportError:
logger = get_logger()
class UnsupportedAutoClass:
def __init__(self, name: str):
self.error_msg =\
f'{name} is not supported with your installed Transformers version {transformers_version}. ' + \
'Please update your Transformers by "pip install transformers -U".'
def from_pretrained(self, pretrained_model_name_or_path, *model_args,
**kwargs):
raise ImportError(self.error_msg)
def from_config(self, cls, config):
raise ImportError(self.error_msg)
def user_agent(invoked_by=None):
if invoked_by is None:
invoked_by = Invoke.PRETRAINED
@@ -328,14 +372,54 @@ AutoModelForImageSegmentation = get_wrapped_class(
AutoModelForImageSegmentationHF)
AutoModelForImageClassification = get_wrapped_class(
AutoModelForImageClassificationHF)
AutoModelForZeroShotImageClassification = get_wrapped_class(
AutoModelForZeroShotImageClassificationHF)
try:
from transformers import AutoModelForImageToImage as AutoModelForImageToImageHF
AutoModelForImageToImage = get_wrapped_class(AutoModelForImageToImageHF)
except ImportError:
AutoModelForImageToImage = None
AutoModelForImageToImage = UnsupportedAutoClass('AutoModelForImageToImage')
try:
from transformers import AutoModelForImageTextToText as AutoModelForImageTextToTextHF
AutoModelForImageTextToText = get_wrapped_class(
AutoModelForImageTextToTextHF)
except ImportError:
AutoModelForImageTextToText = UnsupportedAutoClass(
'AutoModelForImageTextToText')
try:
from transformers import AutoModelForKeypointDetection as AutoModelForKeypointDetectionHF
AutoModelForKeypointDetection = get_wrapped_class(
AutoModelForKeypointDetectionHF)
except ImportError:
AutoModelForKeypointDetection = UnsupportedAutoClass(
'AutoModelForKeypointDetection')
AutoModelForQuestionAnswering = get_wrapped_class(
AutoModelForQuestionAnsweringHF)
AutoModelForTableQuestionAnswering = get_wrapped_class(
AutoModelForTableQuestionAnsweringHF)
AutoModelForVisualQuestionAnswering = get_wrapped_class(
AutoModelForVisualQuestionAnsweringHF)
AutoModelForKeypointDetection = get_wrapped_class(
AutoModelForKeypointDetectionHF)
AutoModelForDocumentQuestionAnswering = get_wrapped_class(
AutoModelForDocumentQuestionAnsweringHF)
AutoModelForSemanticSegmentation = get_wrapped_class(
AutoModelForSemanticSegmentationHF)
AutoModelForUniversalSegmentation = get_wrapped_class(
AutoModelForUniversalSegmentationHF)
AutoModelForInstanceSegmentation = get_wrapped_class(
AutoModelForInstanceSegmentationHF)
AutoModelForObjectDetection = get_wrapped_class(AutoModelForObjectDetectionHF)
AutoModelForZeroShotObjectDetection = get_wrapped_class(
AutoModelForZeroShotObjectDetectionHF)
AutoModelForAudioClassification = get_wrapped_class(
AutoModelForAudioClassificationHF)
AutoModelForSpeechSeq2Seq = get_wrapped_class(AutoModelForSpeechSeq2SeqHF)
AutoModelForMaskedImageModeling = get_wrapped_class(
AutoModelForMaskedImageModelingHF)
AutoModelForMaskedLM = get_wrapped_class(AutoModelForMaskedLMHF)
AutoModelForMaskGeneration = get_wrapped_class(AutoModelForMaskGenerationHF)
AutoModelForPreTraining = get_wrapped_class(AutoModelForPreTrainingHF)
@@ -347,7 +431,8 @@ try:
Qwen2VLForConditionalGeneration = get_wrapped_class(
Qwen2VLForConditionalGenerationHF)
except ImportError:
Qwen2VLForConditionalGeneration = None
Qwen2VLForConditionalGeneration = UnsupportedAutoClass(
'Qwen2VLForConditionalGeneration')
AutoTokenizer = get_wrapped_class(
AutoTokenizerHF,