add awqconfig (#761)

This commit is contained in:
Jintao
2024-02-21 14:54:48 +08:00
committed by GitHub
parent b037e9caf0
commit e168717f36
2 changed files with 5 additions and 2 deletions

View File

@@ -29,7 +29,7 @@ if TYPE_CHECKING:
from .trainers import (EpochBasedTrainer, Hook, Priority, TrainingArgs,
build_dataset_from_file)
from .utils.constant import Tasks
from .utils.hf_util import AutoConfig, GPTQConfig, BitsAndBytesConfig
from .utils.hf_util import AutoConfig, GPTQConfig, AwqConfig, BitsAndBytesConfig
from .utils.hf_util import (AutoModel, AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
@@ -80,7 +80,7 @@ else:
'utils.constant': ['Tasks'],
'utils.hf_util': [
'AutoConfig', 'GenerationConfig', 'AutoModel', 'GPTQConfig',
'BitsAndBytesConfig', 'AutoModelForCausalLM',
'AwqConfig', 'BitsAndBytesConfig', 'AutoModelForCausalLM',
'AutoModelForSeq2SeqLM', 'AutoTokenizer',
'AutoModelForSequenceClassification',
'AutoModelForTokenClassification', 'AutoImageProcessor',

View File

@@ -21,8 +21,10 @@ from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
try:
from transformers import GPTQConfig as GPTQConfigHF
from transformers import AwqConfig as AwqConfigHF
except ImportError:
GPTQConfigHF = None
AwqConfigHF = None
def user_agent(invoked_by=None):
@@ -135,6 +137,7 @@ AutoConfig = get_wrapped_class(
GenerationConfig = get_wrapped_class(
GenerationConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])
GPTQConfig = GPTQConfigHF
AwqConfig = AwqConfigHF
BitsAndBytesConfig = BitsAndBytesConfigHF
AutoImageProcessor = get_wrapped_class(AutoImageProcessorHF)
BatchFeature = get_wrapped_class(BatchFeatureHF)