Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
Support csanmt exporting and refactor some code
1. Support CSANMT exporting to the SavedModel format.
2. Create a new base class for text-ranking preprocessors, and move some parameters of mgeo_ranking_preprocessor to the init method.
3. Avoid coupling the Model & Preprocessor classes to PyTorch.
4. Let the regression test support comparing only the model output.
5. Support zero-shot exporting to ONNX and TorchScript.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11522461
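For orientation, a minimal sketch of the exporter API this change enables. The model ids and keyword arguments mirror the tests added at the bottom of this diff; the output paths are placeholders.

```python
from modelscope.exporters import Exporter, TfModelExporter
from modelscope.models import Model

# TensorFlow path: export the CSANMT translation model to SavedModel
# (see tests/export/test_export_csanmt_model.py below).
tf_model = Model.from_pretrained('damo/nlp_csanmt_translation_en2zh_base')
print(TfModelExporter.from_model(tf_model).export_saved_model(
    output_dir='/tmp/csanmt'))

# PyTorch path: export the zero-shot classifier to ONNX and TorchScript
# (see tests/export/test_export_sbert_zero_shot_classification.py below).
pt_model = Model.from_pretrained(
    'damo/nlp_structbert_zero-shot-classification_chinese-base')
exporter = Exporter.from_model(pt_model)
print(exporter.export_onnx(
    candidate_labels=['文化', '体育', '娱乐'],
    hypothesis_template='这篇文章的标题是{}',
    output_dir='/tmp/zero_shot'))
print(exporter.export_torch_script(
    candidate_labels=['文化', '体育', '娱乐'],
    hypothesis_template='这篇文章的标题是{}',
    output_dir='/tmp/zero_shot'))
```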
@@ -1,5 +1,12 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from .base import Exporter
from .builder import build_exporter
from .nlp import SbertForSequenceClassificationExporter
from .tf_model_exporter import TfModelExporter
from .torch_model_exporter import TorchModelExporter

if is_tf_available():
    from .nlp import CsanmtForTranslationExporter
    from .tf_model_exporter import TfModelExporter
if is_torch_available():
    from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter
    from .torch_model_exporter import TorchModelExporter
@@ -6,9 +6,11 @@ from typing import Dict, Union
from modelscope.models import Model
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import ModelFile
from modelscope.utils.hub import snapshot_download
from modelscope.utils.logger import get_logger
from .builder import build_exporter

logger = get_logger(__name__)


class Exporter(ABC):
    """Exporter base class to output model to onnx, torch_script, graphdef, etc.
@@ -46,7 +48,12 @@ class Exporter(ABC):
        if hasattr(cfg, 'export'):
            export_cfg.update(cfg.export)
        export_cfg['model'] = model
        exporter = build_exporter(export_cfg, task_name, kwargs)
        try:
            exporter = build_exporter(export_cfg, task_name, kwargs)
        except KeyError as e:
            raise KeyError(
                f'The exporting of model \'{model_cfg.type}\' with task: \'{task_name}\' '
                f'is not supported currently.') from e
        return exporter

    @abstractmethod
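The try/except added above turns a registry miss into a readable error. A hedged sketch of what a caller now sees for a task without a registered exporter; the model passed in is a placeholder:

```python
from modelscope.exporters import Exporter

def try_export(model):
    # Building an exporter for a task/model pair that is absent from the
    # EXPORTERS registry now raises a descriptive KeyError, e.g.
    # "The exporting of model '<type>' with task: '<task>' is not supported currently."
    try:
        return Exporter.from_model(model)
    except KeyError as e:
        print(e)
        return None
```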
@@ -1,2 +1,11 @@
|
||||
from .sbert_for_sequence_classification_exporter import \
|
||||
SbertForSequenceClassificationExporter
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.utils.import_utils import is_tf_available, is_torch_available
|
||||
|
||||
if is_tf_available():
|
||||
from .csanmt_for_translation_exporter import CsanmtForTranslationExporter
|
||||
if is_torch_available():
|
||||
from .sbert_for_sequence_classification_exporter import \
|
||||
SbertForSequenceClassificationExporter
|
||||
from .sbert_for_zero_shot_classification_exporter import \
|
||||
SbertForZeroShotClassificationExporter
|
||||
|
||||
modelscope/exporters/nlp/csanmt_for_translation_exporter.py (new file, 185 lines)
@@ -0,0 +1,185 @@
import os
from typing import Any, Dict

import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import freeze_graph

from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.tf_model_exporter import TfModelExporter
from modelscope.metainfo import Models
from modelscope.pipelines.nlp.translation_pipeline import TranslationPipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import compare_arguments_nested

logger = get_logger(__name__)

if tf.__version__ >= '2.0':
    tf = tf.compat.v1
    tf.disable_eager_execution()

tf.logging.set_verbosity(tf.logging.INFO)


@EXPORTERS.register_module(Tasks.translation, module_name=Models.translation)
class CsanmtForTranslationExporter(TfModelExporter):

    def __init__(self, model=None):
        super().__init__(model)
        self.pipeline = TranslationPipeline(self.model)

    def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
        return_dict = self.pipeline.preprocess(
            "Alibaba Group's mission is to let the world have no difficult business"
        )
        return {'input_wids': return_dict['input_ids']}

    def export_saved_model(self, output_dir, rtol=None, atol=None, **kwargs):

        def _generate_signature():
            receiver_tensors = {
                'input_wids':
                tf.saved_model.utils.build_tensor_info(
                    self.pipeline.input_wids)
            }
            export_outputs = {
                'output_seqs':
                tf.saved_model.utils.build_tensor_info(
                    self.pipeline.output['output_seqs'])
            }

            signature_def = tf.saved_model.signature_def_utils.build_signature_def(
                receiver_tensors, export_outputs,
                tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

            return {'translation_signature': signature_def}

        with self.pipeline._session.as_default() as sess:
            builder = tf.saved_model.builder.SavedModelBuilder(output_dir)
            builder.add_meta_graph_and_variables(
                sess, [tag_constants.SERVING],
                signature_def_map=_generate_signature(),
                assets_collection=ops.get_collection(
                    ops.GraphKeys.ASSET_FILEPATHS),
                clear_devices=True)
            builder.save()

        dummy_inputs = self.generate_dummy_inputs()
        with tf.Session(graph=tf.Graph()) as sess:
            # Restore the model from the saved_model file exported by the TensorFlow estimator.
            MetaGraphDef = tf.saved_model.loader.load(sess, ['serve'],
                                                      output_dir)

            # SignatureDef protobuf
            SignatureDef_map = MetaGraphDef.signature_def
            SignatureDef = SignatureDef_map['translation_signature']
            # TensorInfo protobuf
            X_TensorInfo = SignatureDef.inputs['input_wids']
            y_TensorInfo = SignatureDef.outputs['output_seqs']
            X = tf.saved_model.utils.get_tensor_from_tensor_info(
                X_TensorInfo, sess.graph)
            y = tf.saved_model.utils.get_tensor_from_tensor_info(
                y_TensorInfo, sess.graph)
            outputs = sess.run(y, feed_dict={X: dummy_inputs['input_wids']})
            trans_result = self.pipeline.postprocess({'output_seqs': outputs})
            logger.info(trans_result)

        outputs_origin = self.pipeline.forward(
            {'input_ids': dummy_inputs['input_wids']})

        tols = {}
        if rtol is not None:
            tols['rtol'] = rtol
        if atol is not None:
            tols['atol'] = atol
        if not compare_arguments_nested('Output match failed', outputs,
                                        outputs_origin['output_seqs'], **tols):
            raise RuntimeError(
                'Export saved model failed because of validation error.')

        return {'model': output_dir}

    def export_frozen_graph_def(self,
                                output_dir: str,
                                rtol=None,
                                atol=None,
                                **kwargs):
        input_saver_def = self.pipeline.model_loader.as_saver_def()
        inference_graph_def = tf.get_default_graph().as_graph_def()
        for node in inference_graph_def.node:
            node.device = ''

        frozen_dir = os.path.join(output_dir, 'frozen')
        tf.gfile.MkDir(frozen_dir)
        frozen_graph_path = os.path.join(frozen_dir,
                                         'frozen_inference_graph.pb')

        outputs = {
            'output_trans_result':
            tf.identity(
                self.pipeline.output['output_seqs'],
                name='NmtModel/output_trans_result')
        }

        for output_key in outputs:
            tf.add_to_collection('inference_op', outputs[output_key])

        output_node_names = ','.join([
            '%s/%s' % ('NmtModel', output_key)
            for output_key in outputs.keys()
        ])
        print(output_node_names)
        _ = freeze_graph.freeze_graph_with_def_protos(
            input_graph_def=tf.get_default_graph().as_graph_def(),
            input_saver_def=input_saver_def,
            input_checkpoint=self.pipeline.model_path,
            output_node_names=output_node_names,
            restore_op_name='save/restore_all',
            filename_tensor_name='save/Const:0',
            output_graph=frozen_graph_path,
            clear_devices=True,
            initializer_nodes='')

        # test frozen.pb
        dummy_inputs = self.generate_dummy_inputs()
        with self.pipeline._session.as_default() as sess:
            sess.run(tf.tables_initializer())

        graph = tf.Graph()
        with tf.gfile.GFile(frozen_graph_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        with graph.as_default():
            tf.import_graph_def(graph_def, name='')
        graph.finalize()

        with tf.Session(graph=graph) as trans_sess:
            outputs = trans_sess.run(
                'NmtModel/strided_slice_9:0',
                feed_dict={'input_wids:0': dummy_inputs['input_wids']})
            trans_result = self.pipeline.postprocess(
                {'output_seqs': outputs})
            logger.info(trans_result)

        outputs_origin = self.pipeline.forward(
            {'input_ids': dummy_inputs['input_wids']})

        tols = {}
        if rtol is not None:
            tols['rtol'] = rtol
        if atol is not None:
            tols['atol'] = atol
        if not compare_arguments_nested('Output match failed', outputs,
                                        outputs_origin['output_seqs'], **tols):
            raise RuntimeError(
                'Export frozen graphdef failed because of validation error.')

        return {'model': frozen_graph_path}

    def export_onnx(self, output_dir: str, opset=13, **kwargs):
        raise NotImplementedError(
            'csanmt model does not support onnx format, consider using savedmodel instead.'
        )
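A hedged sketch of consuming the SavedModel written by `export_saved_model` above. The tag (`'serve'`), signature name, and tensor keys mirror the validation code in this file; the output directory and the token ids are placeholders.

```python
import tensorflow as tf

if tf.__version__ >= '2.0':
    tf = tf.compat.v1
    tf.disable_eager_execution()

with tf.Session(graph=tf.Graph()) as sess:
    meta_graph = tf.saved_model.loader.load(sess, ['serve'], '/tmp/csanmt')
    sig = meta_graph.signature_def['translation_signature']
    x = tf.saved_model.utils.get_tensor_from_tensor_info(
        sig.inputs['input_wids'], sess.graph)
    y = tf.saved_model.utils.get_tensor_from_tensor_info(
        sig.outputs['output_seqs'], sess.graph)
    # Feed already-tokenized source ids; real inputs come from the
    # translation pipeline's preprocess step.
    output_seqs = sess.run(y, feed_dict={x: [[5, 6, 7]]})
```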
@@ -1,4 +1,3 @@
import os
from collections import OrderedDict
from typing import Any, Dict, Mapping, Tuple

@@ -7,9 +6,7 @@ from torch.utils.data.dataloader import default_collate
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.preprocessors import (
    Preprocessor, TextClassificationTransformersPreprocessor,
    build_preprocessor)
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import ModeKeys, Tasks


@@ -17,8 +14,6 @@ from modelscope.utils.constant import ModeKeys, Tasks
@EXPORTERS.register_module(
    Tasks.text_classification, module_name=Models.structbert)
@EXPORTERS.register_module(Tasks.sentence_similarity, module_name=Models.bert)
@EXPORTERS.register_module(
    Tasks.zero_shot_classification, module_name=Models.bert)
@EXPORTERS.register_module(
    Tasks.sentiment_classification, module_name=Models.bert)
@EXPORTERS.register_module(Tasks.nli, module_name=Models.bert)
@@ -27,8 +22,6 @@ from modelscope.utils.constant import ModeKeys, Tasks
@EXPORTERS.register_module(
    Tasks.sentiment_classification, module_name=Models.structbert)
@EXPORTERS.register_module(Tasks.nli, module_name=Models.structbert)
@EXPORTERS.register_module(
    Tasks.zero_shot_classification, module_name=Models.structbert)
class SbertForSequenceClassificationExporter(TorchModelExporter):

    def generate_dummy_inputs(self,
@@ -0,0 +1,58 @@
from collections import OrderedDict
from typing import Any, Dict, Mapping

from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Tasks


@EXPORTERS.register_module(
    Tasks.zero_shot_classification, module_name=Models.bert)
@EXPORTERS.register_module(
    Tasks.zero_shot_classification, module_name=Models.structbert)
class SbertForZeroShotClassificationExporter(TorchModelExporter):

    def generate_dummy_inputs(self,
                              candidate_labels,
                              hypothesis_template,
                              max_length=128,
                              pair: bool = False,
                              **kwargs) -> Dict[str, Any]:
        """Generate dummy inputs for exporting the model to onnx or other formats by tracing.

        Args:
            candidate_labels(List): The labels of the prompt,
                like ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
            hypothesis_template(str): The template of the prompt, like '这篇文章的标题是{}'
            max_length(int): The max length of the sentence, default 128.
            pair(bool, `optional`): Whether to generate sentence pairs or single sentences.

        Returns:
            Dummy inputs.
        """

        assert hasattr(
            self.model, 'model_dir'
        ), 'model_dir attribute is required to build the preprocessor'
        preprocessor = Preprocessor.from_pretrained(
            self.model.model_dir, max_length=max_length)
        return preprocessor(
            preprocessor.nlp_tokenizer.tokenizer.unk_token,
            candidate_labels=candidate_labels,
            hypothesis_template=hypothesis_template)

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        dynamic_axis = {0: 'batch', 1: 'sequence'}
        return OrderedDict([
            ('input_ids', dynamic_axis),
            ('attention_mask', dynamic_axis),
            ('token_type_ids', dynamic_axis),
        ])

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict({'logits': {0: 'batch'}})
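Once exported, the ONNX file can be driven directly with onnxruntime. A hedged sketch: the input names follow the `inputs` property above, while the file name and the token ids are assumptions.

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('/tmp/zero_shot/model.onnx')
batch = {
    'input_ids': np.array([[101, 100, 102]], dtype=np.int64),
    'attention_mask': np.ones((1, 3), dtype=np.int64),
    'token_type_ids': np.zeros((1, 3), dtype=np.int64),
}
logits = sess.run(None, batch)[0]  # first dynamic axis is the batch
```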
@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from abc import abstractmethod
from typing import Any, Callable, Dict, Mapping

import tensorflow as tf
@@ -7,7 +8,7 @@ import tensorflow as tf
from modelscope.outputs import ModelOutputBase
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.regress_test_utils import compare_arguments_nested
from modelscope.utils.test_utils import compare_arguments_nested
from .base import Exporter

logger = get_logger()
@@ -29,6 +30,14 @@ class TfModelExporter(Exporter):
        self._tf2_export_onnx(model, onnx_file, opset=opset, **kwargs)
        return {'model': onnx_file}

    @abstractmethod
    def export_saved_model(self, output_dir: str, **kwargs):
        pass

    @abstractmethod
    def export_frozen_graph_def(self, output_dir: str, **kwargs):
        pass

    def _tf2_export_onnx(self,
                         model,
                         output: str,
@@ -59,56 +68,67 @@ class TfModelExporter(Exporter):
        onnx.save(onnx_model, output)

        if validation:
            try:
                import onnx
                import onnxruntime as ort
            except ImportError:
                logger.warn(
                    'Cannot validate the exported onnx file, because '
                    'the installation of onnx or onnxruntime cannot be found')
                return
            self._validate_model(dummy_inputs, model, output, rtol, atol,
                                 call_func)

            def tensor_nested_numpify(tensors):
                if isinstance(tensors, (list, tuple)):
                    return type(tensors)(
                        tensor_nested_numpify(t) for t in tensors)
                if isinstance(tensors, Mapping):
                    # return dict
                    return {
                        k: tensor_nested_numpify(t)
                        for k, t in tensors.items()
                    }
                if isinstance(tensors, tf.Tensor):
                    t = tensors.cpu()
                    return t.numpy()
                return tensors

    def _validate_model(
            self,
            dummy_inputs,
            model,
            output,
            rtol: float = None,
            atol: float = None,
            call_func: Callable = None,
    ):
        try:
            import onnx
            import onnxruntime as ort
        except ImportError:
            logger.warn(
                'Cannot validate the exported onnx file, because '
                'the installation of onnx or onnxruntime cannot be found')
            return

        onnx_model = onnx.load(output)
        onnx.checker.check_model(onnx_model)
        ort_session = ort.InferenceSession(output)
        outputs_origin = call_func(
            dummy_inputs) if call_func is not None else model(dummy_inputs)
        if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
            outputs_origin = list(
                tensor_nested_numpify(outputs_origin).values())
        elif isinstance(outputs_origin, (tuple, list)):
            outputs_origin = list(tensor_nested_numpify(outputs_origin))
        outputs = ort_session.run(
            None,
            tensor_nested_numpify(dummy_inputs),
        )
        outputs = tensor_nested_numpify(outputs)
        if isinstance(outputs, dict):
            outputs = list(outputs.values())
        elif isinstance(outputs, tuple):
            outputs = list(outputs)

        def tensor_nested_numpify(tensors):
            if isinstance(tensors, (list, tuple)):
                return type(tensors)(tensor_nested_numpify(t) for t in tensors)
            if isinstance(tensors, Mapping):
                # return dict
                return {
                    k: tensor_nested_numpify(t)
                    for k, t in tensors.items()
                }
            if isinstance(tensors, tf.Tensor):
                t = tensors.cpu()
                return t.numpy()
            return tensors

        tols = {}
        if rtol is not None:
            tols['rtol'] = rtol
        if atol is not None:
            tols['atol'] = atol
        if not compare_arguments_nested('Onnx model output match failed',
                                        outputs, outputs_origin, **tols):
            raise RuntimeError(
                'export onnx failed because of validation error.')
        onnx_model = onnx.load(output)
        onnx.checker.check_model(onnx_model, full_check=True)
        ort_session = ort.InferenceSession(output)
        outputs_origin = call_func(
            dummy_inputs) if call_func is not None else model(dummy_inputs)
        if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
            outputs_origin = list(
                tensor_nested_numpify(outputs_origin).values())
        elif isinstance(outputs_origin, (tuple, list)):
            outputs_origin = list(tensor_nested_numpify(outputs_origin))
        outputs = ort_session.run(
            None,
            tensor_nested_numpify(dummy_inputs),
        )
        outputs = tensor_nested_numpify(outputs)
        if isinstance(outputs, dict):
            outputs = list(outputs.values())
        elif isinstance(outputs, tuple):
            outputs = list(outputs)

        tols = {}
        if rtol is not None:
            tols['rtol'] = rtol
        if atol is not None:
            tols['atol'] = atol
        if not compare_arguments_nested('Onnx model output match failed',
                                        outputs, outputs_origin, **tols):
            raise RuntimeError(
                'export onnx failed because of validation error.')
@@ -1,15 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import pickle
from typing import Dict, Optional, Union
from urllib.parse import urlparse

from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import (FILE_HASH, MODEL_META_FILE_NAME,
                                      MODEL_META_MODEL_ID)
from modelscope.hub.constants import FILE_HASH
from modelscope.hub.git import GitCommandWrapper
from modelscope.hub.utils.caching import FileSystemCache, ModelFileSystemCache
from modelscope.hub.utils.caching import ModelFileSystemCache
from modelscope.hub.utils.utils import compute_hash
from modelscope.utils.logger import get_logger
@@ -9,4 +9,5 @@ from .base import Head, Model
from .builder import BACKBONES, HEADS, MODELS, build_model

if is_torch_available():
    from .base import TorchModel, TorchHead
    from .base.base_torch_model import TorchModel
    from .base.base_torch_head import TorchHead

@@ -1,6 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from modelscope.utils.import_utils import is_torch_available
from .base_head import *  # noqa F403
from .base_model import *  # noqa F403
from .base_torch_head import *  # noqa F403
from .base_torch_model import *  # noqa F403

if is_torch_available():
    from .base_torch_model import TorchModel
    from .base_torch_head import TorchHead
@@ -2,13 +2,11 @@
import os
import os.path as osp
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from modelscope.hub.check_model import check_local_model_is_latest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.builder import build_model
from modelscope.utils.checkpoint import (save_checkpoint, save_configuration,
                                         save_pretrained)
from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
from modelscope.utils.device import verify_device
@@ -150,9 +148,7 @@ class Model(ABC):
    def save_pretrained(self,
                        target_folder: Union[str, os.PathLike],
                        save_checkpoint_names: Union[str, List[str]] = None,
                        save_function: Callable = save_checkpoint,
                        config: Optional[dict] = None,
                        save_config_function: Callable = save_configuration,
                        **kwargs):
        """Save the pretrained model, its configuration and other related files to a directory,
        so that it can be re-loaded.
@@ -164,21 +160,8 @@
            save_checkpoint_names (Union[str, List[str]]):
                The checkpoint names to be saved in the target_folder

            save_function (Callable, optional):
                The function to use to save the state dictionary.

            config (Optional[dict], optional):
                The config for the configuration.json, might not be identical with model.config

            save_config_function (Callable, optional):
                The function to use to save the configuration.

        """
        if config is None and hasattr(self, 'cfg'):
            config = self.cfg

        if config is not None:
            save_config_function(target_folder, config)

        save_pretrained(self, target_folder, save_checkpoint_names,
                        save_function, **kwargs)
        raise NotImplementedError(
            'save_pretrained method needs to be implemented by the subclass.')
@@ -1,14 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
from copy import deepcopy
from typing import Any, Dict
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import nn
from torch.nn.parallel import DataParallel, DistributedDataParallel

from modelscope.utils.checkpoint import (save_checkpoint, save_configuration,
                                         save_pretrained)
from modelscope.utils.file_utils import func_receive_dict_inputs
from modelscope.utils.hub import parse_label_mapping
from modelscope.utils.logger import get_logger
from .base_model import Model

@@ -88,3 +90,39 @@ class TorchModel(Model, torch.nn.Module):
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def save_pretrained(self,
                        target_folder: Union[str, os.PathLike],
                        save_checkpoint_names: Union[str, List[str]] = None,
                        save_function: Callable = save_checkpoint,
                        config: Optional[dict] = None,
                        save_config_function: Callable = save_configuration,
                        **kwargs):
        """Save the pretrained model, its configuration and other related files to a directory,
        so that it can be re-loaded.

        Args:
            target_folder (Union[str, os.PathLike]):
                Directory to which to save. Will be created if it doesn't exist.

            save_checkpoint_names (Union[str, List[str]]):
                The checkpoint names to be saved in the target_folder

            save_function (Callable, optional):
                The function to use to save the state dictionary.

            config (Optional[dict], optional):
                The config for the configuration.json, might not be identical with model.config

            save_config_function (Callable, optional):
                The function to use to save the configuration.

        """
        if config is None and hasattr(self, 'cfg'):
            config = self.cfg

        if config is not None:
            save_config_function(target_folder, config)

        save_pretrained(self, target_folder, save_checkpoint_names,
                        save_function, **kwargs)
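The effect of the two hunks above is that `save_pretrained` now lives on `TorchModel`, while the torch-free base `Model` raises `NotImplementedError`. A hedged sketch; the model id, target path and checkpoint name are placeholders.

```python
from modelscope.models import Model

model = Model.from_pretrained(
    'damo/nlp_structbert_zero-shot-classification_chinese-base')
# Works because this model subclasses TorchModel; a non-torch Model would
# now raise NotImplementedError instead of pulling in torch machinery.
model.save_pretrained('/tmp/my_model',
                      save_checkpoint_names='pytorch_model.bin')
```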
@@ -3,7 +3,7 @@
from modelscope.utils.config import ConfigDict
from modelscope.utils.constant import Tasks
from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule
from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg
from modelscope.utils.registry import Registry, build_from_cfg

MODELS = Registry('models')
BACKBONES = MODELS
@@ -6,7 +6,6 @@ from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
                    Sequence, Union)

import numpy as np
import torch
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES
from datasets.utils.file_utils import is_relative_path
@@ -43,42 +42,6 @@ def format_list(para) -> List:
    return para


class MsMapDataset(torch.utils.data.Dataset):

    def __init__(self, dataset: Iterable, preprocessor_list, retained_columns,
                 columns, to_tensor):
        super(MsDataset).__init__()
        self.dataset = dataset
        self.preprocessor_list = preprocessor_list
        self.to_tensor = to_tensor
        self.retained_columns = retained_columns
        self.columns = columns

    def __len__(self):
        return len(self.dataset)

    def type_converter(self, x):
        if self.to_tensor:
            return torch.tensor(x)
        else:
            return x

    def __getitem__(self, index):
        item_dict = self.dataset[index]
        res = {
            k: self.type_converter(item_dict[k])
            for k in self.columns
            if (not self.to_tensor) or k in self.retained_columns
        }
        for preprocessor in self.preprocessor_list:
            res.update({
                k: self.type_converter(v)
                for k, v in preprocessor(item_dict).items()
                if (not self.to_tensor) or k in self.retained_columns
            })
        return res


class MsDataset:
    """
    ModelScope Dataset (aka, MsDataset) is backed by a huggingface Dataset to
@@ -303,6 +266,7 @@ class MsDataset:
            columns: Union[str, List[str]] = None,
            to_tensor: bool = True,
    ):
        import torch
        preprocessor_list = preprocessors if isinstance(
            preprocessors, list) else [preprocessors]

@@ -332,6 +296,42 @@ class MsDataset:
                continue
            retained_columns.append(k)

        class MsMapDataset(torch.utils.data.Dataset):

            def __init__(self, dataset: Iterable, preprocessor_list,
                         retained_columns, columns, to_tensor):
                super(MsDataset).__init__()
                self.dataset = dataset
                self.preprocessor_list = preprocessor_list
                self.to_tensor = to_tensor
                self.retained_columns = retained_columns
                self.columns = columns

            def __len__(self):
                return len(self.dataset)

            def type_converter(self, x):
                import torch
                if self.to_tensor:
                    return torch.tensor(x)
                else:
                    return x

            def __getitem__(self, index):
                item_dict = self.dataset[index]
                res = {
                    k: self.type_converter(item_dict[k])
                    for k in self.columns
                    if (not self.to_tensor) or k in self.retained_columns
                }
                for preprocessor in self.preprocessor_list:
                    res.update({
                        k: self.type_converter(v)
                        for k, v in preprocessor(item_dict).items()
                        if (not self.to_tensor) or k in self.retained_columns
                    })
                return res

        return MsMapDataset(self._hf_ds, preprocessor_list, retained_columns,
                            columns, to_tensor)
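Moving `MsMapDataset` inside the method keeps `torch` out of module scope: the import now happens only when a torch dataset is actually requested. A hedged usage sketch; the method name `to_torch_dataset` and the dataset id are assumptions for illustration.

```python
from modelscope.msdatasets import MsDataset

ds = MsDataset.load('afqmc_small', split='train')  # placeholder dataset id
torch_ds = ds.to_torch_dataset(
    columns=['sentence1', 'sentence2'], to_tensor=False)
print(torch_ds[0])  # plain values, since to_tensor=False skips torch.tensor()
```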
@@ -22,7 +22,6 @@ from modelscope.utils.device import (create_device, device_placement,
from modelscope.utils.hub import read_config, snapshot_download
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import _find_free_port, _is_free_port
from .util import is_model, is_official_hub_path

if is_torch_available():
@@ -426,6 +425,7 @@ class DistributedPipeline(Pipeline):
            'master_ip']
        master_port = '29500' if 'master_port' not in kwargs else kwargs[
            'master_port']
        from modelscope.utils.torch_utils import _find_free_port, _is_free_port
        if not _is_free_port(int(master_port)):
            master_port = str(_find_free_port())
        self.model_pool.map(
@@ -44,7 +44,7 @@ class TranslationPipeline(Pipeline):
        model = self.model.model_dir
        tf.reset_default_graph()

        model_path = osp.join(
        self.model_path = osp.join(
            osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER), 'ckpt-0')

        self.cfg = Config.from_file(osp.join(model, ModelFile.CONFIGURATION))
@@ -88,10 +88,10 @@ class TranslationPipeline(Pipeline):
        self.output.update(output)

        with self._session.as_default() as sess:
            logger.info(f'loading model from {model_path}')
            logger.info(f'loading model from {self.model_path}')
            # load model
            model_loader = tf.train.Saver(tf.global_variables())
            model_loader.restore(sess, model_path)
            self.model_loader = tf.train.Saver(tf.global_variables())
            self.model_loader.restore(sess, self.model_path)

    def preprocess(self, input: str) -> Dict[str, Any]:
        input = input.split('<SENT_SPLIT>')
@@ -2,8 +2,6 @@
import os.path as osp
from typing import List, Optional, Union

import torch

from modelscope.hub.api import HubApi
from modelscope.hub.file_download import model_file_download
from modelscope.utils.config import Config
@@ -86,6 +84,7 @@ def is_model(path: Union[str, List]):


def batch_process(model, data):
    import torch
    if model.__class__.__name__ == 'OfaForAllTasks':
        # collate batch data due to the nested data structure
        assert isinstance(data, list)
@@ -5,7 +5,6 @@ from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Sequence, Union

from modelscope.metainfo import Models, Preprocessors
from modelscope.utils.checkpoint import save_configuration
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke,
                                       ModeKeys, Tasks)
@@ -312,7 +311,7 @@ class Preprocessor(ABC):
    def save_pretrained(self,
                        target_folder: Union[str, os.PathLike],
                        config: Optional[dict] = None,
                        save_config_function: Callable = save_configuration):
                        save_config_function: Callable = None):
        """Save the preprocessor, its configuration and other related files to a directory,
        so that it can be re-loaded.

@@ -341,4 +340,7 @@ class Preprocessor(ABC):
                'preprocessor']['val']:
            config['preprocessor']['val']['mode'] = 'inference'

        if save_config_function is None:
            from modelscope.utils.checkpoint import save_configuration
            save_config_function = save_configuration
        save_config_function(target_folder, config)
@@ -10,6 +10,7 @@ from modelscope.preprocessors import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.constant import Fields, ModeKeys
from modelscope.utils.type_assert import type_assert
from .text_ranking_preprocessor import TextRankingPreprocessorBase


class GisUtt:
@@ -113,7 +114,7 @@ class GisUtt:

@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.mgeo_ranking)
class MGeoRankingTransformersPreprocessor(Preprocessor):
class MGeoRankingTransformersPreprocessor(TextRankingPreprocessorBase):

    def __init__(self,
                 model_dir: str,
@@ -125,39 +126,39 @@ class MGeoRankingTransformersPreprocessor(Preprocessor):
                 label='labels',
                 qid='qid',
                 max_length=None,
                 padding='longest',
                 truncation=True,
                 use_fast=True,
                 **kwargs):
        """The tokenizer preprocessor class for the text ranking preprocessor.

        Args:
            model_dir(str, `optional`): The model dir used to parse the label mapping, can be None.
            first_sequence(str, `optional`): The key of the first sequence.
            second_sequence(str, `optional`): The key of the second sequence.
            label(str, `optional`): The keys of the label columns, default `labels`.
            qid(str, `optional`): The qid info.
            mode: The mode for the preprocessor.
            max_length: The max sequence length which the model supports,
                will be passed into the tokenizer as the 'max_length' param.
            padding: The padding method.
            truncation: The truncation method.
        """
        super().__init__(mode)
        super().__init__(
            mode=mode,
            first_sequence=first_sequence,
            second_sequence=second_sequence,
            label=label,
            qid=qid)
        self.model_dir = model_dir
        self.first_sequence = first_sequence
        self.second_sequence = second_sequence
        self.first_sequence_gis = first_sequence_gis
        self.second_sequence_gis = second_sequence_gis

        self.label = label
        self.qid = qid
        self.sequence_length = max_length if max_length is not None else kwargs.get(
            'sequence_length', 128)
        kwargs.pop('sequence_length', None)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        self.tokenize_kwargs = kwargs
        self.tokenize_kwargs['padding'] = padding
        self.tokenize_kwargs['truncation'] = truncation
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_dir, use_fast=use_fast)

    @type_assert(object, dict)
    def __call__(self,
                 data: Dict,
                 padding='longest',
                 truncation=True,
                 **kwargs) -> Dict[str, Any]:
    def __call__(self, data: Dict, **kwargs) -> Dict[str, Any]:
        sentence1 = data.get(self.first_sequence)
        sentence2 = data.get(self.second_sequence)
        labels = data.get(self.label)
@@ -176,12 +177,9 @@ class MGeoRankingTransformersPreprocessor(Preprocessor):
            'max_length', kwargs.pop('sequence_length', self.sequence_length))
        if 'return_tensors' not in kwargs:
            kwargs['return_tensors'] = 'pt'
        feature = self.tokenizer(
            sentence1,
            sentence2,
            padding=padding,
            truncation=truncation,
            **kwargs)

        self.tokenize_kwargs.update(kwargs)
        feature = self.tokenizer(sentence1, sentence2, **self.tokenize_kwargs)
        if labels is not None:
            feature['labels'] = labels
        if qid is not None:
@@ -11,9 +11,33 @@ from modelscope.utils.constant import Fields, ModeKeys
from modelscope.utils.type_assert import type_assert


class TextRankingPreprocessorBase(Preprocessor):

    def __init__(self,
                 mode: str = ModeKeys.INFERENCE,
                 first_sequence='source_sentence',
                 second_sequence='sentences_to_compare',
                 label='labels',
                 qid='qid'):
        """The base class for text ranking preprocessors.

        Args:
            first_sequence(str, `optional`): The key of the first sequence.
            second_sequence(str, `optional`): The key of the second sequence.
            label(str, `optional`): The keys of the label columns, default `labels`.
            qid(str, `optional`): The qid info.
            mode: The mode for the preprocessor.
        """
        super().__init__(mode)
        self.first_sequence = first_sequence
        self.second_sequence = second_sequence
        self.label = label
        self.qid = qid


@PREPROCESSORS.register_module(
    Fields.nlp, module_name=Preprocessors.text_ranking)
class TextRankingTransformersPreprocessor(Preprocessor):
class TextRankingTransformersPreprocessor(TextRankingPreprocessorBase):

    def __init__(self,
                 model_dir: str,
@@ -23,36 +47,33 @@ class TextRankingTransformersPreprocessor(Preprocessor):
                 label='labels',
                 qid='qid',
                 max_length=None,
                 padding='max_length',
                 truncation=True,
                 **kwargs):
        """The tokenizer preprocessor class for the text ranking preprocessor.

        Args:
            model_dir(str, `optional`): The model dir used to parse the label mapping, can be None.
            first_sequence(str, `optional`): The key of the first sequence.
            second_sequence(str, `optional`): The key of the second sequence.
            label(str, `optional`): The keys of the label columns, default `labels`.
            qid(str, `optional`): The qid info.
            mode: The mode for the preprocessor.
            max_length: The max sequence length which the model supports,
                will be passed into the tokenizer as the 'max_length' param.
        """
        super().__init__(mode)
        super().__init__(
            mode=mode,
            first_sequence=first_sequence,
            second_sequence=second_sequence,
            label=label,
            qid=qid)
        self.model_dir = model_dir
        self.first_sequence = first_sequence
        self.second_sequence = second_sequence
        self.label = label
        self.qid = qid
        self.sequence_length = max_length if max_length is not None else kwargs.get(
            'sequence_length', 128)
        kwargs.pop('sequence_length', None)
        self.tokenize_kwargs = kwargs
        self.tokenize_kwargs['padding'] = padding
        self.tokenize_kwargs['truncation'] = truncation
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)

    @type_assert(object, dict)
    def __call__(self,
                 data: Dict,
                 padding='max_length',
                 truncation=True,
                 **kwargs) -> Dict[str, Any]:
    def __call__(self, data: Dict, **kwargs) -> Dict[str, Any]:
        sentence1 = data.get(self.first_sequence)
        sentence2 = data.get(self.second_sequence)
        labels = data.get(self.label)
@@ -67,12 +88,9 @@ class TextRankingTransformersPreprocessor(Preprocessor):
            'max_length', kwargs.pop('sequence_length', self.sequence_length))
        if 'return_tensors' not in kwargs:
            kwargs['return_tensors'] = 'pt'
        feature = self.tokenizer(
            sentence1,
            sentence2,
            padding=padding,
            truncation=truncation,
            **kwargs)

        self.tokenize_kwargs.update(kwargs)
        feature = self.tokenizer(sentence1, sentence2, **self.tokenize_kwargs)
        if labels is not None:
            feature['labels'] = labels
        if qid is not None:
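With the new base class, key mapping lives in `TextRankingPreprocessorBase` and tokenizer kwargs are fixed at construction time, with per-call kwargs merged in. A hedged sketch; the import path and model dir are assumptions.

```python
from modelscope.preprocessors.nlp.text_ranking_preprocessor import \
    TextRankingTransformersPreprocessor

preprocessor = TextRankingTransformersPreprocessor(
    model_dir='/path/to/model', max_length=256, padding='max_length')
feature = preprocessor({
    'source_sentence': ['how to export a model'],
    'sentences_to_compare': ['use the Exporter class'],
})
# feature holds input_ids / attention_mask (plus labels/qid when present)
```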
@@ -3,7 +3,6 @@
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py

import copy
import dataclasses
import os
import os.path as osp
import platform
@@ -2,7 +2,6 @@

import io

import cv2
import json

from modelscope.outputs import OutputKeys
@@ -265,6 +264,7 @@ def postprocess(req, resp):
    new_resp.get(output_key)
    if file_type == 'png' or file_type == 'jpg':
        content = new_resp.get(output_key)
        import cv2
        _, img_encode = cv2.imencode('.' + file_type, content)
        img_bytes = img_encode.tobytes()
        return type(img_bytes)
@@ -3,7 +3,6 @@ import os
from contextlib import contextmanager

from modelscope.utils.constant import Devices, Frameworks
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from modelscope.utils.logger import get_logger

logger = get_logger()
@@ -1,7 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import inspect
import os
from pathlib import Path
@@ -20,7 +20,7 @@ import torch
import torch.optim
from torch import nn

from modelscope.utils.service_utils import NumpyEncoder
from .test_utils import compare_arguments_nested


class RegressTool:
@@ -71,6 +71,7 @@ class RegressTool:
                                      module: nn.Module,
                                      file_name: str,
                                      compare_fn=None,
                                      compare_model_output=True,
                                      **kwargs):
        """Monitor a pytorch module in a single forward.

@@ -78,6 +79,7 @@ class RegressTool:
            module: A torch module
            file_name: The file_name to store or load file
            compare_fn: A custom fn used to compare the results manually.
            compare_model_output: Only compare the input module's output, skip all other tensors

            >>> def compare_fn(v1, v2, key, type):
            >>>     return None
@@ -120,17 +122,46 @@ class RegressTool:
        with open(baseline, 'rb') as f:
            base = pickle.load(f)

        class SafeNumpyEncoder(NumpyEncoder):
        class SafeNumpyEncoder(json.JSONEncoder):

            def parse_default(self, obj):
                if isinstance(obj, np.ndarray):
                    return obj.tolist()

                if isinstance(obj, np.floating):
                    return float(obj)

                if isinstance(obj, np.integer):
                    return int(obj)

                return json.JSONEncoder.default(self, obj)

            def default(self, obj):
                try:
                    return super().default(obj)
                    return self.parse_default(obj)
                except Exception:
                    print(
                        f'Type {obj.__class__} cannot be serialized and printed'
                    )
                    return None

        if compare_model_output:
            print(
                'Ignore inner modules, only the output of the model will be verified.'
            )
            base = {
                key: value
                for key, value in base.items() if key == file_name
            }
            for key, value in base.items():
                value['input'] = {'args': None, 'kwargs': None}
            io_json = {
                key: value
                for key, value in io_json.items() if key == file_name
            }
            for key, value in io_json.items():
                value['input'] = {'args': None, 'kwargs': None}

        print(f'baseline: {json.dumps(base, cls=SafeNumpyEncoder)}')
        print(f'latest : {json.dumps(io_json, cls=SafeNumpyEncoder)}')
        if not compare_io_and_print(base, io_json, compare_fn, **kwargs):
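Point 4 of the commit in action: with `compare_model_output=True` the regression tool checks only the module's own output against the baseline, skipping inputs and inner-module tensors. A hedged sketch; the method name `monitor_module_single_forward` and the constructor argument are assumptions based on the signature and docstring above.

```python
from modelscope.utils.regress_test_utils import MsRegressTool

def check_output_only(model, dummy_inputs):
    regress_tool = MsRegressTool(baseline=False)
    # Only the forward output of `model` is compared with the stored baseline.
    with regress_tool.monitor_module_single_forward(
            model, 'my_baseline', compare_model_output=True):
        model(**dummy_inputs)
```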
@@ -326,10 +357,75 @@ class MsRegressTool(RegressTool):

        def lazy_stop_callback():

            from modelscope.trainers.hooks.hook import Hook, Priority
            class EarlyStopHook:
                PRIORITY = 90

            class EarlyStopHook(Hook):
                PRIORITY = Priority.VERY_LOW
                def before_run(self, trainer):
                    pass

                def after_run(self, trainer):
                    pass

                def before_epoch(self, trainer):
                    pass

                def after_epoch(self, trainer):
                    pass

                def before_iter(self, trainer):
                    pass

                def before_train_epoch(self, trainer):
                    self.before_epoch(trainer)

                def before_val_epoch(self, trainer):
                    self.before_epoch(trainer)

                def after_train_epoch(self, trainer):
                    self.after_epoch(trainer)

                def after_val_epoch(self, trainer):
                    self.after_epoch(trainer)

                def before_train_iter(self, trainer):
                    self.before_iter(trainer)

                def before_val_iter(self, trainer):
                    self.before_iter(trainer)

                def after_train_iter(self, trainer):
                    self.after_iter(trainer)

                def after_val_iter(self, trainer):
                    self.after_iter(trainer)

                def every_n_epochs(self, trainer, n):
                    return (trainer.epoch + 1) % n == 0 if n > 0 else False

                def every_n_inner_iters(self, runner, n):
                    return (runner.inner_iter
                            + 1) % n == 0 if n > 0 else False

                def every_n_iters(self, trainer, n):
                    return (trainer.iter + 1) % n == 0 if n > 0 else False

                def end_of_epoch(self, trainer):
                    return trainer.inner_iter + 1 == trainer.iters_per_epoch

                def is_last_epoch(self, trainer):
                    return trainer.epoch + 1 == trainer.max_epochs

                def is_last_iter(self, trainer):
                    return trainer.iter + 1 == trainer.max_iters

                def get_triggered_stages(self):
                    return []

                def state_dict(self):
                    return {}

                def load_state_dict(self, state_dict):
                    pass

                def after_iter(self, trainer):
                    raise MsRegressTool.EarlyStopError('Test finished.')
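The rewritten EarlyStopHook deliberately reimplements the hook surface instead of inheriting from `modelscope.trainers.hooks.hook.Hook`, so importing the regress tool no longer drags in the trainer (and torch) stack. This works because trainers dispatch to hook methods by name rather than by isinstance checks; a toy demonstration of the pattern:

```python
class ToyTrainer:
    """Stand-in trainer: calls hook methods by name, never checks a base class."""

    def __init__(self, hooks):
        self.hooks = hooks

    def run_iter(self):
        for hook in self.hooks:
            hook.after_train_iter(self)

class DuckHook:
    def after_train_iter(self, trainer):
        print('invoked without inheriting any Hook base class')

ToyTrainer([DuckHook()]).run_iter()
```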
@@ -526,92 +622,6 @@ def intercept_module(module: nn.Module,
            intercept_module(module, io_json, full_name, restore)


def compare_arguments_nested(print_content,
                             arg1,
                             arg2,
                             rtol=1.e-3,
                             atol=1.e-8,
                             ignore_unknown_type=True):
    type1 = type(arg1)
    type2 = type(arg2)
    if type1.__name__ != type2.__name__:
        if print_content is not None:
            print(
                f'{print_content}, type not equal:{type1.__name__} and {type2.__name__}'
            )
        return False

    if arg1 is None:
        return True
    elif isinstance(arg1, (int, str, bool, np.bool, np.integer, np.str)):
        if arg1 != arg2:
            if print_content is not None:
                print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
            return False
        return True
    elif isinstance(arg1, (float, np.floating)):
        if not np.isclose(arg1, arg2, rtol=rtol, atol=atol, equal_nan=True):
            if print_content is not None:
                print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
            return False
        return True
    elif isinstance(arg1, (tuple, list)):
        if len(arg1) != len(arg2):
            if print_content is not None:
                print(
                    f'{print_content}, length is not equal:{len(arg1)}, {len(arg2)}'
                )
            return False
        if not all([
                compare_arguments_nested(
                    None, sub_arg1, sub_arg2, rtol=rtol, atol=atol)
                for sub_arg1, sub_arg2 in zip(arg1, arg2)
        ]):
            if print_content is not None:
                print(f'{print_content}')
            return False
        return True
    elif isinstance(arg1, Mapping):
        keys1 = arg1.keys()
        keys2 = arg2.keys()
        if len(keys1) != len(keys2):
            if print_content is not None:
                print(
                    f'{print_content}, key length is not equal:{len(keys1)}, {len(keys2)}'
                )
            return False
        if len(set(keys1) - set(keys2)) > 0:
            if print_content is not None:
                print(f'{print_content}, key diff:{set(keys1) - set(keys2)}')
            return False
        if not all([
                compare_arguments_nested(
                    None, arg1[key], arg2[key], rtol=rtol, atol=atol)
                for key in keys1
        ]):
            if print_content is not None:
                print(f'{print_content}')
            return False
        return True
    elif isinstance(arg1, np.ndarray):
        arg1 = np.where(np.equal(arg1, None), np.NaN,
                        arg1).astype(dtype=np.float)
        arg2 = np.where(np.equal(arg2, None), np.NaN,
                        arg2).astype(dtype=np.float)
        if not all(
                np.isclose(arg1, arg2, rtol=rtol, atol=atol,
                           equal_nan=True).flatten()):
            if print_content is not None:
                print(f'{print_content}')
            return False
        return True
    else:
        if ignore_unknown_type:
            return True
        else:
            raise ValueError(f'type not supported: {type1}')


def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs):
    if compare_fn is None:
@@ -5,12 +5,9 @@ from io import BytesIO
import json
import numpy as np
import requests
from PIL import Image

from modelscope.outputs import TASK_OUTPUTS, OutputKeys
from modelscope.pipeline_inputs import TASK_INPUTS, InputType
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks, TasksIODescriptions


# service data decoder func: decodes data from the network and converts it to the pipeline's input
@@ -91,12 +88,14 @@ def decode_base64_to_binary(encoding):


def decode_base64_to_image(encoding):
    from PIL import Image
    content = encoding.split(';')[1]
    image_encoded = content.split(',')[1]
    return Image.open(BytesIO(base64.b64decode(image_encoded)))


def encode_array_to_img_base64(image_array):
    from PIL import Image
    with BytesIO() as output_bytes:
        pil_image = Image.fromarray(image_array.astype(np.uint8))
        pil_image.save(output_bytes, 'PNG')
@@ -12,13 +12,12 @@ import tarfile
import tempfile
import unittest
from collections import OrderedDict
from collections.abc import Mapping

import numpy as np
import requests
import torch
from torch.utils.data import Dataset

from modelscope.utils.import_utils import is_tf_available, is_torch_available
from modelscope.utils.torch_utils import _find_free_port

TEST_LEVEL = 2
TEST_LEVEL_STR = 'TEST_LEVEL'
@@ -49,7 +48,7 @@ def set_test_level(level: int):
    TEST_LEVEL = level


class DummyTorchDataset(Dataset):
class DummyTorchDataset:

    def __init__(self, feat, label, num) -> None:
        self.feat = feat
@@ -57,6 +56,7 @@ class DummyTorchDataset(Dataset):
        self.num = num

    def __getitem__(self, index):
        import torch
        return {
            'feat': torch.Tensor(self.feat),
            'labels': torch.Tensor(self.label)
@@ -119,6 +119,92 @@ def get_case_model_info():
    return model_cases


def compare_arguments_nested(print_content,
                             arg1,
                             arg2,
                             rtol=1.e-3,
                             atol=1.e-8,
                             ignore_unknown_type=True):
    type1 = type(arg1)
    type2 = type(arg2)
    if type1.__name__ != type2.__name__:
        if print_content is not None:
            print(
                f'{print_content}, type not equal:{type1.__name__} and {type2.__name__}'
            )
        return False

    if arg1 is None:
        return True
    elif isinstance(arg1, (int, str, bool, np.bool, np.integer, np.str)):
        if arg1 != arg2:
            if print_content is not None:
                print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
            return False
        return True
    elif isinstance(arg1, (float, np.floating)):
        if not np.isclose(arg1, arg2, rtol=rtol, atol=atol, equal_nan=True):
            if print_content is not None:
                print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
            return False
        return True
    elif isinstance(arg1, (tuple, list)):
        if len(arg1) != len(arg2):
            if print_content is not None:
                print(
                    f'{print_content}, length is not equal:{len(arg1)}, {len(arg2)}'
                )
            return False
        if not all([
                compare_arguments_nested(
                    None, sub_arg1, sub_arg2, rtol=rtol, atol=atol)
                for sub_arg1, sub_arg2 in zip(arg1, arg2)
        ]):
            if print_content is not None:
                print(f'{print_content}')
            return False
        return True
    elif isinstance(arg1, Mapping):
        keys1 = arg1.keys()
        keys2 = arg2.keys()
        if len(keys1) != len(keys2):
            if print_content is not None:
                print(
                    f'{print_content}, key length is not equal:{len(keys1)}, {len(keys2)}'
                )
            return False
        if len(set(keys1) - set(keys2)) > 0:
            if print_content is not None:
                print(f'{print_content}, key diff:{set(keys1) - set(keys2)}')
            return False
        if not all([
                compare_arguments_nested(
                    None, arg1[key], arg2[key], rtol=rtol, atol=atol)
                for key in keys1
        ]):
            if print_content is not None:
                print(f'{print_content}')
            return False
        return True
    elif isinstance(arg1, np.ndarray):
        arg1 = np.where(np.equal(arg1, None), np.NaN,
                        arg1).astype(dtype=np.float)
        arg2 = np.where(np.equal(arg2, None), np.NaN,
                        arg2).astype(dtype=np.float)
        if not all(
                np.isclose(arg1, arg2, rtol=rtol, atol=atol,
                           equal_nan=True).flatten()):
            if print_content is not None:
                print(f'{print_content}')
            return False
        return True
    else:
        if ignore_unknown_type:
            return True
        else:
            raise ValueError(f'type not supported: {type1}')

_DIST_SCRIPT_TEMPLATE = """
import ast
import argparse
@@ -263,6 +349,7 @@ class DistributedTestCase(unittest.TestCase):
                      save_all_ranks=False,
                      *args,
                      **kwargs):
        from .torch_utils import _find_free_port
        ip = socket.gethostbyname(socket.gethostname())
        dist_start_cmd = '%s -m torch.distributed.launch --nproc_per_node=%d --master_addr=\'%s\' --master_port=%s' % (
            sys.executable, num_gpus, ip, _find_free_port())
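`compare_arguments_nested` now lives in `test_utils`, so exporters can validate outputs without importing the torch-heavy regress module. A self-contained usage sketch with made-up values:

```python
import numpy as np
from modelscope.utils.test_utils import compare_arguments_nested

ok = compare_arguments_nested(
    'Output match failed',
    {'logits': np.array([1.0, 2.0])},
    {'logits': np.array([1.0, 2.0001])},
    rtol=1e-3, atol=1e-8)
assert ok  # within tolerance, so nothing is printed and True is returned
```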
tests/export/test_export_csanmt_model.py (new file, 40 lines)
@@ -0,0 +1,40 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest

import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image

from modelscope.exporters import TfModelExporter
from modelscope.models import Model
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import compare_arguments_nested, test_level


class TestExportTfModel(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 2,
                         'test with numpy version == 1.18.1')
    def test_export_csanmt(self):
        model = Model.from_pretrained('damo/nlp_csanmt_translation_en2zh_base')
        print(
            TfModelExporter.from_model(model).export_saved_model(
                output_dir=self.tmp_dir))


if __name__ == '__main__':
    unittest.main()
tests/export/test_export_sbert_zero_shot_classification.py (new file, 47 lines)
@@ -0,0 +1,47 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
from collections import OrderedDict

from modelscope.exporters import Exporter, TorchModelExporter
from modelscope.models import Model
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class TestExportSbertZeroShotClassification(unittest.TestCase):

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)
        self.model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base'

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_export_sbert_sequence_classification(self):
        model = Model.from_pretrained(self.model_id)
        print(
            Exporter.from_model(model).export_onnx(
                candidate_labels=[
                    '文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'
                ],
                hypothesis_template='这篇文章的标题是{}',
                output_dir=self.tmp_dir))
        print(
            Exporter.from_model(model).export_torch_script(
                candidate_labels=[
                    '文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'
                ],
                hypothesis_template='这篇文章的标题是{}',
                output_dir=self.tmp_dir))


if __name__ == '__main__':
    unittest.main()
@@ -10,7 +10,6 @@ from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image

from modelscope.exporters import TfModelExporter
from modelscope.utils.regress_test_utils import compare_arguments_nested
from modelscope.utils.test_utils import test_level
@@ -11,7 +11,7 @@ import torch
from torch import nn

from modelscope.metainfo import Trainers
from modelscope.models.base import Model
from modelscope.models.base import TorchModel
from modelscope.trainers import build_trainer
from modelscope.utils.constant import LogKeys, ModelFile
from modelscope.utils.test_utils import create_dummy_test_dataset
@@ -20,7 +20,7 @@ dummy_dataset = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)


class DummyModel(nn.Module, Model):
class DummyModel(TorchModel):

    def __init__(self):
        super().__init__()

@@ -11,7 +11,7 @@ from torch import nn

from modelscope.metainfo import Trainers
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.models.base import Model
from modelscope.models.base import TorchModel
from modelscope.trainers import build_trainer
from modelscope.utils.constant import LogKeys, ModelFile
from modelscope.utils.registry import default_group
@@ -42,7 +42,7 @@ dummy_dataset = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)


class DummyModel(nn.Module, Model):
class DummyModel(TorchModel):

    def __init__(self):
        super().__init__()

@@ -11,7 +11,7 @@ from torch import nn

from modelscope.metainfo import Trainers
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.models.base import Model
from modelscope.models.base import TorchModel
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile
from modelscope.utils.registry import default_group
@@ -35,7 +35,7 @@ dummy_dataset = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)


class DummyModel(nn.Module, Model):
class DummyModel(TorchModel):

    def __init__(self):
        super().__init__()

@@ -13,7 +13,7 @@ from torch.optim.lr_scheduler import MultiStepLR

from modelscope.metainfo import Trainers
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.models.base import Model
from modelscope.models.base import TorchModel
from modelscope.trainers import build_trainer
from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages
from modelscope.utils.registry import default_group
@@ -41,7 +41,7 @@ def create_dummy_metric():
        return {MetricKeys.ACCURACY: self._fake_acc_by_epoch[_global_iter]}


class DummyModel(nn.Module, Model):
class DummyModel(TorchModel):

    def __init__(self):
        super().__init__()

@@ -16,7 +16,7 @@ from torch.utils.data import IterableDataset

from modelscope.metainfo import Metrics, Trainers
from modelscope.metrics.builder import MetricKeys
from modelscope.models.base import Model
from modelscope.models.base import TorchModel
from modelscope.trainers import build_trainer
from modelscope.trainers.base import DummyTrainer
from modelscope.trainers.builder import TRAINERS
@@ -41,7 +41,7 @@ dummy_dataset_big = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40)


class DummyModel(nn.Module, Model):
class DummyModel(TorchModel):

    def __init__(self):
        super().__init__()

@@ -15,7 +15,7 @@ from torch.utils.data import IterableDataset

from modelscope.metainfo import Metrics, Trainers
from modelscope.metrics.builder import MetricKeys
from modelscope.models.base import Model
from modelscope.models.base import Model, TorchModel
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile, Tasks
from modelscope.utils.test_utils import (DistributedTestCase,
@@ -38,7 +38,7 @@ dummy_dataset_big = create_dummy_test_dataset(
    np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40)


class DummyModel(nn.Module, Model):
class DummyModel(TorchModel):

    def __init__(self):
        super().__init__()