Merge remote-tracking branch 'origin' into release/1.0
@@ -7,9 +7,9 @@ from typing import Any, Dict, Mapping
 import torch
 from torch import nn
 from torch.onnx import export as onnx_export
-from torch.onnx.utils import _decide_input_format
 
 from modelscope.models import TorchModel
+from modelscope.outputs import ModelOutputBase
 from modelscope.pipelines.base import collate_fn
 from modelscope.utils.constant import ModelFile
 from modelscope.utils.logger import get_logger
@@ -102,6 +102,53 @@ class TorchModelExporter(Exporter):
         """
         return None
 
+    @staticmethod
+    def _decide_input_format(model, args):
+        import inspect
+
+        def _signature(model) -> inspect.Signature:
+            should_be_callable = getattr(model, 'forward', model)
+            if callable(should_be_callable):
+                return inspect.signature(should_be_callable)
+            raise ValueError('model has no forward method and is not callable')
+
+        try:
+            sig = _signature(model)
+        except ValueError as e:
+            logger.warn('%s, skipping _decide_input_format' % e)
+            return args
+        try:
+            ordered_list_keys = list(sig.parameters.keys())
+            if ordered_list_keys[0] == 'self':
+                ordered_list_keys = ordered_list_keys[1:]
+            args_dict: Dict = {}
+            if isinstance(args, list):
+                args_list = args
+            elif isinstance(args, tuple):
+                args_list = list(args)
+            else:
+                args_list = [args]
+            if isinstance(args_list[-1], dict):
+                args_dict = args_list[-1]
+                args_list = args_list[:-1]
+            n_nonkeyword = len(args_list)
+            for optional_arg in ordered_list_keys[n_nonkeyword:]:
+                if optional_arg in args_dict:
+                    args_list.append(args_dict[optional_arg])
+                # Check if this arg has a default value
+                else:
+                    param = sig.parameters[optional_arg]
+                    if param.default != param.empty:
+                        args_list.append(param.default)
+            args = args_list if isinstance(args, list) else tuple(args_list)
+        # Cases of models with no input args
+        except IndexError:
+            logger.warn('No input args, skipping _decide_input_format')
+        except Exception as e:
+            logger.warn('Skipping _decide_input_format\n {}'.format(e.args[0]))
+
+        return args
+
     def _torch_export_onnx(self,
                            model: nn.Module,
                            output: str,
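
The new _decide_input_format helper turns a dict of dummy inputs into the positional argument list that forward expects, reading the signature order and falling back to parameter defaults. A minimal sketch of the behavior (TinyModel is a hypothetical stand-in; the helper is the one added above):

    import torch
    from torch import nn

    class TinyModel(nn.Module):  # hypothetical model for illustration
        def forward(self, input_ids, attention_mask=None, labels=None):
            return input_ids

    dummy = {'attention_mask': torch.ones(1, 4), 'input_ids': torch.zeros(1, 4)}
    # A single dict is treated as keyword arguments: the result follows the
    # signature order (input_ids, attention_mask, labels), not the dict order,
    # and labels is filled with its default None.
    args = TorchModelExporter._decide_input_format(TinyModel(), dummy)
    assert torch.equal(args[0], dummy['input_ids']) and args[2] is None
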
@@ -179,16 +226,21 @@ class TorchModelExporter(Exporter):
             with torch.no_grad():
                 model.eval()
                 outputs_origin = model.forward(
-                    *_decide_input_format(model, dummy_inputs))
-            if isinstance(outputs_origin, Mapping):
-                outputs_origin = numpify_tensor_nested(
-                    list(outputs_origin.values()))
+                    *self._decide_input_format(model, dummy_inputs))
+            if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
+                outputs_origin = list(
+                    numpify_tensor_nested(outputs_origin).values())
             elif isinstance(outputs_origin, (tuple, list)):
-                outputs_origin = numpify_tensor_nested(outputs_origin)
+                outputs_origin = list(numpify_tensor_nested(outputs_origin))
             outputs = ort_session.run(
                 onnx_outputs,
                 numpify_tensor_nested(dummy_inputs),
             )
             outputs = numpify_tensor_nested(outputs)
+            if isinstance(outputs, dict):
+                outputs = list(outputs.values())
+            elif isinstance(outputs, tuple):
+                outputs = list(outputs)
+
             tols = {}
             if rtol is not None:
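
Validation now flattens both sides before the tolerance check: ModelOutputBase/Mapping results from PyTorch and dict/tuple results from onnxruntime all end up as plain lists of numpy arrays. A hedged sketch of the comparison that follows (the actual check in this file may be phrased differently):

    import numpy as np

    def compare(onnx_outputs, torch_outputs, rtol=None, atol=None):
        # Both arguments are lists of numpy arrays by this point.
        tols = {}
        if rtol is not None:
            tols['rtol'] = rtol
        if atol is not None:
            tols['atol'] = atol
        for onnx_out, torch_out in zip(onnx_outputs, torch_outputs):
            np.testing.assert_allclose(onnx_out, torch_out, **tols)
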
@@ -232,12 +284,26 @@ class TorchModelExporter(Exporter):
                 'Model property dummy_inputs must be set.')
         dummy_inputs = collate_fn(dummy_inputs, device)
         if isinstance(dummy_inputs, Mapping):
-            dummy_inputs = tuple(dummy_inputs.values())
+            dummy_inputs = self._decide_input_format(model, dummy_inputs)
+            dummy_inputs_filter = []
+            for _input in dummy_inputs:
+                if _input is not None:
+                    dummy_inputs_filter.append(_input)
+                else:
+                    break
+
+            if len(dummy_inputs) != len(dummy_inputs_filter):
+                logger.warn(
+                    f'Dummy inputs is not continuous in the forward method, '
+                    f'origin length: {len(dummy_inputs)}, '
+                    f'the length after filtering: {len(dummy_inputs_filter)}')
+            dummy_inputs = dummy_inputs_filter
+
         with torch.no_grad():
             model.eval()
             with replace_call():
                 traced_model = torch.jit.trace(
-                    model, dummy_inputs, strict=strict)
+                    model, tuple(dummy_inputs), strict=strict)
         torch.jit.save(traced_model, output)
 
         if validation:
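
torch.jit.trace requires example inputs as a tuple of positional values, which is why the filtered dummy inputs are wrapped in tuple(...) before tracing. A self-contained sketch (Add is a hypothetical module):

    import torch
    from torch import nn

    class Add(nn.Module):  # hypothetical module for illustration
        def forward(self, a, b):
            return a + b

    dummy_inputs = [torch.ones(2), torch.ones(2)]  # a list after filtering
    traced = torch.jit.trace(Add(), tuple(dummy_inputs), strict=False)
    torch.jit.save(traced, 'add.pt')
    print(torch.jit.load('add.pt')(torch.ones(2), torch.ones(2)))  # tensor([2., 2.])
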
@@ -249,6 +315,10 @@ class TorchModelExporter(Exporter):
             outputs = numpify_tensor_nested(outputs)
             outputs_origin = model.forward(*dummy_inputs)
             outputs_origin = numpify_tensor_nested(outputs_origin)
+            if isinstance(outputs, dict):
+                outputs = list(outputs.values())
+            if isinstance(outputs_origin, dict):
+                outputs_origin = list(outputs_origin.values())
             tols = {}
             if rtol is not None:
                 tols['rtol'] = rtol
@@ -161,5 +161,12 @@ class Model(ABC):
         assert config is not None, 'Cannot save the model because the model config is empty.'
         if isinstance(config, Config):
             config = config.to_dict()
+        if 'preprocessor' in config and config['preprocessor'] is not None:
+            if 'mode' in config['preprocessor']:
+                config['preprocessor']['mode'] = 'inference'
+            elif 'val' in config['preprocessor'] and 'mode' in config[
+                    'preprocessor']['val']:
+                config['preprocessor']['val']['mode'] = 'inference'
+
         save_pretrained(self, target_folder, save_checkpoint_names,
                         save_function, config, **kwargs)
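
The added block rewrites the preprocessor mode to 'inference' in the saved configuration, whether the mode sits at the top level or under a 'val' sub-config, so a re-loaded model does not start with a training-mode preprocessor. The same rewrite, restated as a standalone function for illustration:

    def force_inference_mode(config: dict) -> dict:
        # Mirrors the save-time rewrite above.
        preprocessor = config.get('preprocessor')
        if preprocessor is not None:
            if 'mode' in preprocessor:
                preprocessor['mode'] = 'inference'
            elif 'val' in preprocessor and 'mode' in preprocessor['val']:
                preprocessor['val']['mode'] = 'inference'
        return config

    print(force_inference_mode({'preprocessor': {'mode': 'train'}}))
    # {'preprocessor': {'mode': 'inference'}}
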
@@ -36,6 +36,7 @@ class BertForTextRanking(BertForSequenceClassification):
                 output_attentions=None,
                 output_hidden_states=None,
                 return_dict=None,
+                *args,
                 **kwargs) -> AttentionTextClassificationModelOutput:
         outputs = self.base_model.forward(
             input_ids=input_ids,
@@ -109,6 +109,7 @@ class SbertForSequenceClassification(SbertPreTrainedModel):
                 output_attentions=None,
                 output_hidden_states=None,
                 return_dict=None,
+                *args,
                 **kwargs):
         r"""
         Args:
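
Adding *args to these forward signatures makes the models tolerant of surplus positional inputs, e.g. when an exporter calls forward positionally with every value that _decide_input_format produced. A toy illustration (Head is hypothetical):

    import torch
    from torch import nn

    class Head(nn.Module):  # hypothetical module for illustration
        def forward(self, input_ids, attention_mask=None, *args, **kwargs):
            return input_ids

    head = Head()
    x = torch.zeros(1, 4)
    # Without *args this positional call with a surplus third value
    # would raise a TypeError; with it, the extra input is ignored.
    head(x, None, torch.ones(1))
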
@@ -672,7 +672,7 @@ class EpochBasedTrainer(BaseTrainer):
                     self.model, cfg=cfg, default_args=default_args)
             except KeyError as e:
                 self.logger.error(
-                    f'Build optimizer error, the optimizer {cfg} is native torch optimizer, '
+                    f'Build optimizer error, the optimizer {cfg} is a torch native component, '
                     f'please check if your torch with version: {torch.__version__} matches the config.'
                 )
                 raise e
@@ -682,7 +682,7 @@ class EpochBasedTrainer(BaseTrainer):
             return build_lr_scheduler(cfg=cfg, default_args=default_args)
         except KeyError as e:
             self.logger.error(
-                f'Build lr_scheduler error, the lr_scheduler {cfg} is native torch lr_scheduler, '
+                f'Build lr_scheduler error, the lr_scheduler {cfg} is a torch native component, '
                 f'please check if your torch with version: {torch.__version__} matches the config.'
             )
             raise e
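
Both reworded messages cover the fallback where the configured optimizer or lr_scheduler is a torch-native class, so a KeyError usually means the config keys do not match the constructor signature of the installed torch version. Conceptually the native path looks like this (a hedged sketch with an assumed 'type' key, not the modelscope implementation):

    import torch

    def build_native_optimizer(cfg: dict, params):
        cfg = dict(cfg)
        cls = getattr(torch.optim, cfg.pop('type'))  # e.g. 'SGD', 'AdamW'
        # Mismatched keys for this torch version surface here.
        return cls(params, **cfg)

    model = torch.nn.Linear(4, 2)
    opt = build_native_optimizer({'type': 'SGD', 'lr': 0.1}, model.parameters())
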
@@ -23,7 +23,7 @@ class TestExportSbertSequenceClassification(unittest.TestCase):
         shutil.rmtree(self.tmp_dir)
         super().tearDown()
 
-    @unittest.skip
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_export_sbert_sequence_classification(self):
         model = Model.from_pretrained(self.model_id)
         print(
@@ -38,7 +38,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
         shutil.rmtree(self.tmp_dir)
         super().tearDown()
 
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    @unittest.skip
     def test_trainer_cfg_class(self):
         dataset = MsDataset.load('clue', subset_name='tnews')
         train_dataset = dataset['train']
@@ -72,7 +72,7 @@ class TestTrainerWithNlp(unittest.TestCase):
         output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)
         pipeline_sentence_similarity(output_dir)
 
-    @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
+    @unittest.skip
     def test_trainer_with_backbone_head(self):
         model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
         kwargs = dict(