Merge branch 'master-github' into release/1.10

This commit is contained in:
mulin.lyh
2023-12-02 21:16:24 +08:00
6 changed files with 34 additions and 11 deletions

View File

@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
from .utils.automodel_utils import fix_transformers_upgrade
if TYPE_CHECKING:
from .exporters import Exporter, TfModelExporter, TorchModelExporter
@@ -33,7 +34,8 @@ if TYPE_CHECKING:
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification, AutoTokenizer,
GenerationConfig)
GenerationConfig, AutoImageProcessor,
BatchFeature)
from .utils.hub import create_model_if_not_exist, read_config
from .utils.logger import get_logger
from .version import __release_datetime__, __version__
@@ -81,7 +83,8 @@ else:
'BitsAndBytesConfig', 'AutoModelForCausalLM',
'AutoModelForSeq2SeqLM', 'AutoTokenizer',
'AutoModelForSequenceClassification',
'AutoModelForTokenClassification'
'AutoModelForTokenClassification', 'AutoImageProcessor',
'BatchFeature'
],
'msdatasets': ['MsDataset']
}
@@ -95,3 +98,5 @@ else:
module_spec=__spec__,
extra_objects={},
)
fix_transformers_upgrade()

View File

@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Tasks
from modelscope.models.builder import build_backbone, build_model
from modelscope.utils.automodel_utils import (can_load_by_ms, fix_upgrade,
from modelscope.utils.automodel_utils import (can_load_by_ms,
try_to_load_hf_model)
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
@@ -192,7 +192,6 @@ class Model(ABC):
model_cfg.pop('model_dir', None)
model.name = model_name_or_path
model.model_dir = local_model_dir
fix_upgrade(model)
return model
def save_pretrained(self,

View File

@@ -28,6 +28,8 @@ class DocumentGroundedDialogRetrievalModel(TorchModel):
map_location='cpu')
compatible_position_ids(state_dict,
'ctx_encoder.encoder.embeddings.position_ids')
compatible_position_ids(state_dict,
'qry_encoder.encoder.embeddings.position_ids')
self.model.load_state_dict(state_dict)
def forward(self, input: Dict[str, Tensor], gck_segment=32):

View File

@@ -25,15 +25,27 @@ def can_load_by_ms(model_dir: str, task_name: Optional[str],
def fix_upgrade(module_obj: Any):
from transformers import PreTrainedModel
if hasattr(module_obj, '_set_gradient_checkpointing') \
and 'value' in inspect.signature(module_obj._set_gradient_checkpointing).parameters.keys():
module_obj._set_gradient_checkpointing = MethodType(
PreTrainedModel._set_gradient_checkpointing, module_obj)
def post_init(self, *args, **kwargs):
fix_upgrade(self)
self.post_init_origin(*args, **kwargs)
def fix_transformers_upgrade():
if is_transformers_available():
# Since transformers 4.35.0, the arguments of _set_gradient_checkpointing have changed
import transformers
from transformers import PreTrainedModel
if version.parse(transformers.__version__) >= version.parse('4.35.0'):
if isinstance(module_obj, PreTrainedModel) and hasattr(module_obj, '_set_gradient_checkpointing') \
and 'value' in inspect.signature(module_obj._set_gradient_checkpointing).parameters.keys():
module_obj._set_gradient_checkpointing = MethodType(
PreTrainedModel._set_gradient_checkpointing, module_obj)
if version.parse(transformers.__version__) >= version.parse('4.35.0') \
and not hasattr(PreTrainedModel, 'post_init_origin'):
PreTrainedModel.post_init_origin = PreTrainedModel.post_init
PreTrainedModel.post_init = post_init
def _can_load_by_hf_automodel(automodel_class: type, config) -> bool:

View File

@@ -2,6 +2,7 @@
import os
from transformers import AutoConfig as AutoConfigHF
from transformers import AutoImageProcessor as AutoImageProcessorHF
from transformers import AutoModel as AutoModelHF
from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
@@ -10,12 +11,12 @@ from transformers import \
from transformers import \
AutoModelForTokenClassification as AutoModelForTokenClassificationHF
from transformers import AutoTokenizer as AutoTokenizerHF
from transformers import BatchFeature as BatchFeatureHF
from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF
from transformers import GenerationConfig as GenerationConfigHF
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from modelscope import snapshot_download
from modelscope.utils.automodel_utils import fix_upgrade
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
try:
@@ -112,7 +113,6 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
if module_class.__name__.startswith('AutoModel'):
module_obj.model_dir = model_dir
fix_upgrade(module_obj)
return module_obj
ClassWrapper.__name__ = module_class.__name__
@@ -136,3 +136,5 @@ GenerationConfig = get_wrapped_class(
GenerationConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])
GPTQConfig = GPTQConfigHF
BitsAndBytesConfig = BitsAndBytesConfigHF
AutoImageProcessor = get_wrapped_class(AutoImageProcessorHF)
BatchFeature = get_wrapped_class(BatchFeatureHF)

View File

@@ -115,6 +115,9 @@ def get_case_model_info():
elements = line.split(':')
test_file = elements[0]
model_pos = line.find('damo')
if model_pos == -1 or (model_pos - 1) > len(line):
print('Processing line: %s failed' % line)
continue
left_quote = line[model_pos - 1]
rquote_idx = line.rfind(left_quote)
model_name = line[model_pos:rquote_idx]