mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
Merge branch 'master-github' into release/1.10
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
from .utils.automodel_utils import fix_transformers_upgrade
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .exporters import Exporter, TfModelExporter, TorchModelExporter
|
||||
@@ -33,7 +34,8 @@ if TYPE_CHECKING:
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification, AutoTokenizer,
|
||||
GenerationConfig)
|
||||
GenerationConfig, AutoImageProcessor,
|
||||
BatchFeature)
|
||||
from .utils.hub import create_model_if_not_exist, read_config
|
||||
from .utils.logger import get_logger
|
||||
from .version import __release_datetime__, __version__
|
||||
@@ -81,7 +83,8 @@ else:
|
||||
'BitsAndBytesConfig', 'AutoModelForCausalLM',
|
||||
'AutoModelForSeq2SeqLM', 'AutoTokenizer',
|
||||
'AutoModelForSequenceClassification',
|
||||
'AutoModelForTokenClassification'
|
||||
'AutoModelForTokenClassification', 'AutoImageProcessor',
|
||||
'BatchFeature'
|
||||
],
|
||||
'msdatasets': ['MsDataset']
|
||||
}
|
||||
@@ -95,3 +98,5 @@ else:
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
fix_transformers_upgrade()
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import Tasks
|
||||
from modelscope.models.builder import build_backbone, build_model
|
||||
from modelscope.utils.automodel_utils import (can_load_by_ms, fix_upgrade,
|
||||
from modelscope.utils.automodel_utils import (can_load_by_ms,
|
||||
try_to_load_hf_model)
|
||||
from modelscope.utils.config import Config, ConfigDict
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
|
||||
@@ -192,7 +192,6 @@ class Model(ABC):
|
||||
model_cfg.pop('model_dir', None)
|
||||
model.name = model_name_or_path
|
||||
model.model_dir = local_model_dir
|
||||
fix_upgrade(model)
|
||||
return model
|
||||
|
||||
def save_pretrained(self,
|
||||
|
||||
@@ -28,6 +28,8 @@ class DocumentGroundedDialogRetrievalModel(TorchModel):
|
||||
map_location='cpu')
|
||||
compatible_position_ids(state_dict,
|
||||
'ctx_encoder.encoder.embeddings.position_ids')
|
||||
compatible_position_ids(state_dict,
|
||||
'qry_encoder.encoder.embeddings.position_ids')
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
def forward(self, input: Dict[str, Tensor], gck_segment=32):
|
||||
|
||||
@@ -25,15 +25,27 @@ def can_load_by_ms(model_dir: str, task_name: Optional[str],
|
||||
|
||||
|
||||
def fix_upgrade(module_obj: Any):
|
||||
from transformers import PreTrainedModel
|
||||
if hasattr(module_obj, '_set_gradient_checkpointing') \
|
||||
and 'value' in inspect.signature(module_obj._set_gradient_checkpointing).parameters.keys():
|
||||
module_obj._set_gradient_checkpointing = MethodType(
|
||||
PreTrainedModel._set_gradient_checkpointing, module_obj)
|
||||
|
||||
|
||||
def post_init(self, *args, **kwargs):
|
||||
fix_upgrade(self)
|
||||
self.post_init_origin(*args, **kwargs)
|
||||
|
||||
|
||||
def fix_transformers_upgrade():
|
||||
if is_transformers_available():
|
||||
# from 4.35.0, transformers changes its arguments of _set_gradient_checkpointing
|
||||
import transformers
|
||||
from transformers import PreTrainedModel
|
||||
if version.parse(transformers.__version__) >= version.parse('4.35.0'):
|
||||
if isinstance(module_obj, PreTrainedModel) and hasattr(module_obj, '_set_gradient_checkpointing') \
|
||||
and 'value' in inspect.signature(module_obj._set_gradient_checkpointing).parameters.keys():
|
||||
module_obj._set_gradient_checkpointing = MethodType(
|
||||
PreTrainedModel._set_gradient_checkpointing, module_obj)
|
||||
if version.parse(transformers.__version__) >= version.parse('4.35.0') \
|
||||
and not hasattr(PreTrainedModel, 'post_init_origin'):
|
||||
PreTrainedModel.post_init_origin = PreTrainedModel.post_init
|
||||
PreTrainedModel.post_init = post_init
|
||||
|
||||
|
||||
def _can_load_by_hf_automodel(automodel_class: type, config) -> bool:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import os
|
||||
|
||||
from transformers import AutoConfig as AutoConfigHF
|
||||
from transformers import AutoImageProcessor as AutoImageProcessorHF
|
||||
from transformers import AutoModel as AutoModelHF
|
||||
from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
|
||||
from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
|
||||
@@ -10,12 +11,12 @@ from transformers import \
|
||||
from transformers import \
|
||||
AutoModelForTokenClassification as AutoModelForTokenClassificationHF
|
||||
from transformers import AutoTokenizer as AutoTokenizerHF
|
||||
from transformers import BatchFeature as BatchFeatureHF
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF
|
||||
from transformers import GenerationConfig as GenerationConfigHF
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
||||
|
||||
from modelscope import snapshot_download
|
||||
from modelscope.utils.automodel_utils import fix_upgrade
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
|
||||
|
||||
try:
|
||||
@@ -112,7 +113,6 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
|
||||
|
||||
if module_class.__name__.startswith('AutoModel'):
|
||||
module_obj.model_dir = model_dir
|
||||
fix_upgrade(module_obj)
|
||||
return module_obj
|
||||
|
||||
ClassWrapper.__name__ = module_class.__name__
|
||||
@@ -136,3 +136,5 @@ GenerationConfig = get_wrapped_class(
|
||||
GenerationConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])
|
||||
GPTQConfig = GPTQConfigHF
|
||||
BitsAndBytesConfig = BitsAndBytesConfigHF
|
||||
AutoImageProcessor = get_wrapped_class(AutoImageProcessorHF)
|
||||
BatchFeature = get_wrapped_class(BatchFeatureHF)
|
||||
|
||||
@@ -115,6 +115,9 @@ def get_case_model_info():
|
||||
elements = line.split(':')
|
||||
test_file = elements[0]
|
||||
model_pos = line.find('damo')
|
||||
if model_pos == -1 or (model_pos - 1) > len(line):
|
||||
print('Processing line: %s failed' % line)
|
||||
continue
|
||||
left_quote = line[model_pos - 1]
|
||||
rquote_idx = line.rfind(left_quote)
|
||||
model_name = line[model_pos:rquote_idx]
|
||||
|
||||
Reference in New Issue
Block a user