add more hf aliases

This commit is contained in:
Yingda Chen
2024-11-25 12:14:47 +08:00
parent e3f63fd1ea
commit 27bf8fab1e
2 changed files with 52 additions and 8 deletions

View File

@@ -36,9 +36,12 @@ if TYPE_CHECKING:
from .utils.hf_util import (
AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification, AutoModelForImageSegmentation,
AutoTokenizer, GenerationConfig, AutoImageProcessor, BatchFeature,
T5EncoderModel)
AutoModelForTokenClassification, AutoModelForImageClassification,
AutoModelForImageTextToText, AutoModelForImageToImage,
AutoModelForImageSegmentation, AutoModelForQuestionAnswering,
AutoModelForMaskedLM, AutoTokenizer, AutoModelForMaskGeneration,
AutoModelForPreTraining, AutoModelForTextEncoding,
GenerationConfig, AutoImageProcessor, BatchFeature, T5EncoderModel)
else:
print(
'transformers is not installed, please install it if you want to use related modules'
@@ -96,6 +99,11 @@ else:
'AwqConfig', 'BitsAndBytesConfig', 'AutoModelForCausalLM',
'AutoModelForSeq2SeqLM', 'AutoTokenizer',
'AutoModelForSequenceClassification',
'AutoModelForTokenClassification',
'AutoModelForImageClassification', 'AutoModelForImageTextToText',
'AutoModelForImageToImage', 'AutoModelForQuestionAnswering',
'AutoModelForMaskedLM', 'AutoModelForMaskGeneration',
'AutoModelForPreTraining', 'AutoModelForTextEncoding',
'AutoModelForImageSegmentation',
'AutoImageProcessor', 'BatchFeature', 'T5EncoderModel'
]

View File

@@ -9,11 +9,23 @@ from transformers import AutoFeatureExtractor as AutoFeatureExtractorHF
from transformers import AutoImageProcessor as AutoImageProcessorHF
from transformers import AutoModel as AutoModelHF
from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
from transformers import \
AutoModelForImageClassification as AutoModelForImageClassificationHF
from transformers import \
AutoModelForImageSegmentation as AutoModelForImageSegmentationHF
from transformers import \
AutoModelForImageTextToText as AutoModelForImageTextToTextHF
from transformers import AutoModelForImageToImage as AutoModelForImageToImageHF
from transformers import AutoModelForMaskedLM as AutoModelForMaskedLMHF
from transformers import \
AutoModelForMaskGeneration as AutoModelForMaskGenerationHF
from transformers import AutoModelForPreTraining as AutoModelForPreTrainingHF
from transformers import \
AutoModelForQuestionAnswering as AutoModelForQuestionAnsweringHF
from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
from transformers import \
AutoModelForSequenceClassification as AutoModelForSequenceClassificationHF
from transformers import AutoModelForTextEncoding as AutoModelForTextEncodingHF
from transformers import \
AutoModelForTokenClassification as AutoModelForTokenClassificationHF
from transformers import AutoProcessor as AutoProcessorHF
@@ -315,25 +327,49 @@ AutoModelForTokenClassification = get_wrapped_class(
AutoModelForTokenClassificationHF)
AutoModelForImageSegmentation = get_wrapped_class(
AutoModelForImageSegmentationHF)
AutoModelForImageClassification = get_wrapped_class(
AutoModelForImageClassificationHF)
AutoModelForImageTextToText = get_wrapped_class(AutoModelForImageTextToTextHF)
AutoModelForImageToImage = get_wrapped_class(AutoModelForImageToImageHF)
AutoModelForQuestionAnswering = get_wrapped_class(
AutoModelForQuestionAnsweringHF)
AutoModelForMaskedLM = get_wrapped_class(AutoModelForMaskedLMHF)
AutoModelForMaskGeneration = get_wrapped_class(AutoModelForMaskGenerationHF)
AutoModelForPreTraining = get_wrapped_class(AutoModelForPreTrainingHF)
AutoModelForTextEncoding = get_wrapped_class(AutoModelForTextEncodingHF)
T5EncoderModel = get_wrapped_class(T5EncoderModelHF)
AutoTokenizer = get_wrapped_class(
AutoTokenizerHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
])
AutoProcessor = get_wrapped_class(
AutoProcessorHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
])
AutoConfig = get_wrapped_class(
AutoConfigHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
])
GenerationConfig = get_wrapped_class(
GenerationConfigHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
])
BitsAndBytesConfig = get_wrapped_class(
BitsAndBytesConfigHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
])
AutoImageProcessor = AutoImageProcessorHF(
BitsAndBytesConfigHF,
ignore_file_pattern=[
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
])
GPTQConfig = GPTQConfigHF
AwqConfig = AwqConfigHF
BitsAndBytesConfig = BitsAndBytesConfigHF
AutoImageProcessor = get_wrapped_class(AutoImageProcessorHF)
BatchFeature = get_wrapped_class(BatchFeatureHF)