hf_tool add AutoImageProcessor and BatchFeature (#650)

* add AutoImageProcessor and BatchFeature
This commit is contained in:
mushenL
2023-12-01 15:14:06 +08:00
committed by GitHub
parent 7b80a87b5a
commit 1b6723ea24
2 changed files with 8 additions and 2 deletions

View File

@@ -33,7 +33,8 @@ if TYPE_CHECKING:
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification, AutoTokenizer,
GenerationConfig)
GenerationConfig, AutoImageProcessor,
BatchFeature)
from .utils.hub import create_model_if_not_exist, read_config
from .utils.logger import get_logger
from .version import __release_datetime__, __version__
@@ -81,7 +82,8 @@ else:
'BitsAndBytesConfig', 'AutoModelForCausalLM',
'AutoModelForSeq2SeqLM', 'AutoTokenizer',
'AutoModelForSequenceClassification',
'AutoModelForTokenClassification'
'AutoModelForTokenClassification', 'AutoImageProcessor',
'BatchFeature'
],
'msdatasets': ['MsDataset']
}

View File

@@ -2,6 +2,7 @@
import os
from transformers import AutoConfig as AutoConfigHF
from transformers import AutoImageProcessor as AutoImageProcessorHF
from transformers import AutoModel as AutoModelHF
from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
@@ -10,6 +11,7 @@ from transformers import \
from transformers import \
AutoModelForTokenClassification as AutoModelForTokenClassificationHF
from transformers import AutoTokenizer as AutoTokenizerHF
from transformers import BatchFeature as BatchFeatureHF
from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF
from transformers import GenerationConfig as GenerationConfigHF
from transformers import PreTrainedModel, PreTrainedTokenizerBase
@@ -136,3 +138,5 @@ GenerationConfig = get_wrapped_class(
GenerationConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])
GPTQConfig = GPTQConfigHF
BitsAndBytesConfig = BitsAndBytesConfigHF
AutoImageProcessor = get_wrapped_class(AutoImageProcessorHF)
BatchFeature = get_wrapped_class(BatchFeatureHF)