Support csanmt exporting and refactor some code

1. Support exporting the csanmt translation model to the SavedModel format (usage sketch below)
2. Create a new base class for text-ranking preprocessors, and move some parameters of mgeo_ranking_preprocessor into the init method
3. Avoid coupling the Model & Preprocessor base classes with PyTorch
4. Let the regression test support comparing only the model output
5. Support exporting the zero-shot classification model to ONNX and TorchScript

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11522461
Author: yuze.zyz
Date: 2023-02-10 05:15:04 +00:00
Committed by: wenmeng.zwm
Parent: 9f1b767ecd
Commit: 4dca4773db
36 changed files with 802 additions and 304 deletions
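
A minimal usage sketch of the new export entry points, mirroring the tests added in this commit (the model ids come from those tests; the output directory is illustrative):

    import tempfile

    from modelscope.exporters import Exporter, TfModelExporter
    from modelscope.models import Model

    output_dir = tempfile.mkdtemp()

    # csanmt translation model -> TensorFlow SavedModel
    csanmt = Model.from_pretrained('damo/nlp_csanmt_translation_en2zh_base')
    print(TfModelExporter.from_model(csanmt).export_saved_model(output_dir=output_dir))

    # zero-shot classification model -> ONNX and TorchScript
    zero_shot = Model.from_pretrained(
        'damo/nlp_structbert_zero-shot-classification_chinese-base')
    exporter = Exporter.from_model(zero_shot)
    print(exporter.export_onnx(
        candidate_labels=['文化', '体育', '娱乐'],
        hypothesis_template='这篇文章的标题是{}',
        output_dir=output_dir))
    print(exporter.export_torch_script(
        candidate_labels=['文化', '体育', '娱乐'],
        hypothesis_template='这篇文章的标题是{}',
        output_dir=output_dir))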

View File

@@ -1,5 +1,12 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from .base import Exporter
from .builder import build_exporter
from .nlp import SbertForSequenceClassificationExporter
from .tf_model_exporter import TfModelExporter
from .torch_model_exporter import TorchModelExporter
if is_tf_available():
from .nlp import CsanmtForTranslationExporter
from .tf_model_exporter import TfModelExporter
if is_torch_available():
from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter
from .torch_model_exporter import TorchModelExporter

View File

@@ -6,9 +6,11 @@ from typing import Dict, Union
from modelscope.models import Model
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import ModelFile
from modelscope.utils.hub import snapshot_download
from modelscope.utils.logger import get_logger
from .builder import build_exporter
logger = get_logger(__name__)
class Exporter(ABC):
"""Exporter base class to output model to onnx, torch_script, graphdef, etc.
@@ -46,7 +48,12 @@ class Exporter(ABC):
if hasattr(cfg, 'export'):
export_cfg.update(cfg.export)
export_cfg['model'] = model
exporter = build_exporter(export_cfg, task_name, kwargs)
try:
exporter = build_exporter(export_cfg, task_name, kwargs)
except KeyError as e:
raise KeyError(
f'The exporting of model \'{model_cfg.type}\' with task: \'{task_name}\' '
f'is not supported currently.') from e
return exporter
@abstractmethod

View File

@@ -1,2 +1,11 @@
from .sbert_for_sequence_classification_exporter import \
SbertForSequenceClassificationExporter
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.import_utils import is_tf_available, is_torch_available
if is_tf_available():
from .csanmt_for_translation_exporter import CsanmtForTranslationExporter
if is_torch_available():
from .sbert_for_sequence_classification_exporter import \
SbertForSequenceClassificationExporter
from .sbert_for_zero_shot_classification_exporter import \
SbertForZeroShotClassificationExporter

View File

@@ -0,0 +1,185 @@
import os
from typing import Any, Dict
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import freeze_graph
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.tf_model_exporter import TfModelExporter
from modelscope.metainfo import Models
from modelscope.pipelines.nlp.translation_pipeline import TranslationPipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import compare_arguments_nested
logger = get_logger(__name__)
if tf.__version__ >= '2.0':
tf = tf.compat.v1
tf.disable_eager_execution()
tf.logging.set_verbosity(tf.logging.INFO)
@EXPORTERS.register_module(Tasks.translation, module_name=Models.translation)
class CsanmtForTranslationExporter(TfModelExporter):
def __init__(self, model=None):
super().__init__(model)
self.pipeline = TranslationPipeline(self.model)
def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
return_dict = self.pipeline.preprocess(
"Alibaba Group's mission is to let the world have no difficult business"
)
return {'input_wids': return_dict['input_ids']}
def export_saved_model(self, output_dir, rtol=None, atol=None, **kwargs):
def _generate_signature():
receiver_tensors = {
'input_wids':
tf.saved_model.utils.build_tensor_info(
self.pipeline.input_wids)
}
export_outputs = {
'output_seqs':
tf.saved_model.utils.build_tensor_info(
self.pipeline.output['output_seqs'])
}
signature_def = tf.saved_model.signature_def_utils.build_signature_def(
receiver_tensors, export_outputs,
tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
return {'translation_signature': signature_def}
with self.pipeline._session.as_default() as sess:
builder = tf.saved_model.builder.SavedModelBuilder(output_dir)
builder.add_meta_graph_and_variables(
sess, [tag_constants.SERVING],
signature_def_map=_generate_signature(),
assets_collection=ops.get_collection(
ops.GraphKeys.ASSET_FILEPATHS),
clear_devices=True)
builder.save()
dummy_inputs = self.generate_dummy_inputs()
with tf.Session(graph=tf.Graph()) as sess:
# Restore the model from the SavedModel exported above.
MetaGraphDef = tf.saved_model.loader.load(sess, ['serve'],
output_dir)
# SignatureDef protobuf
SignatureDef_map = MetaGraphDef.signature_def
SignatureDef = SignatureDef_map['translation_signature']
# TensorInfo protobuf
X_TensorInfo = SignatureDef.inputs['input_wids']
y_TensorInfo = SignatureDef.outputs['output_seqs']
X = tf.saved_model.utils.get_tensor_from_tensor_info(
X_TensorInfo, sess.graph)
y = tf.saved_model.utils.get_tensor_from_tensor_info(
y_TensorInfo, sess.graph)
outputs = sess.run(y, feed_dict={X: dummy_inputs['input_wids']})
trans_result = self.pipeline.postprocess({'output_seqs': outputs})
logger.info(trans_result)
outputs_origin = self.pipeline.forward(
{'input_ids': dummy_inputs['input_wids']})
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested('Output match failed', outputs,
outputs_origin['output_seqs'], **tols):
raise RuntimeError(
'Export saved model failed because of validation error.')
return {'model': output_dir}
def export_frozen_graph_def(self,
output_dir: str,
rtol=None,
atol=None,
**kwargs):
input_saver_def = self.pipeline.model_loader.as_saver_def()
inference_graph_def = tf.get_default_graph().as_graph_def()
for node in inference_graph_def.node:
node.device = ''
frozen_dir = os.path.join(output_dir, 'frozen')
tf.gfile.MkDir(frozen_dir)
frozen_graph_path = os.path.join(frozen_dir,
'frozen_inference_graph.pb')
outputs = {
'output_trans_result':
tf.identity(
self.pipeline.output['output_seqs'],
name='NmtModel/output_trans_result')
}
for output_key in outputs:
tf.add_to_collection('inference_op', outputs[output_key])
output_node_names = ','.join([
'%s/%s' % ('NmtModel', output_key)
for output_key in outputs.keys()
])
logger.info(output_node_names)
_ = freeze_graph.freeze_graph_with_def_protos(
input_graph_def=tf.get_default_graph().as_graph_def(),
input_saver_def=input_saver_def,
input_checkpoint=self.pipeline.model_path,
output_node_names=output_node_names,
restore_op_name='save/restore_all',
filename_tensor_name='save/Const:0',
output_graph=frozen_graph_path,
clear_devices=True,
initializer_nodes='')
# Test the exported frozen graph (frozen_inference_graph.pb)
dummy_inputs = self.generate_dummy_inputs()
with self.pipeline._session.as_default() as sess:
sess.run(tf.tables_initializer())
graph = tf.Graph()
with tf.gfile.GFile(frozen_graph_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with graph.as_default():
tf.import_graph_def(graph_def, name='')
graph.finalize()
with tf.Session(graph=graph) as trans_sess:
outputs = trans_sess.run(
'NmtModel/strided_slice_9:0',
feed_dict={'input_wids:0': dummy_inputs['input_wids']})
trans_result = self.pipeline.postprocess(
{'output_seqs': outputs})
logger.info(trans_result)
outputs_origin = self.pipeline.forward(
{'input_ids': dummy_inputs['input_wids']})
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested('Output match failed', outputs,
outputs_origin['output_seqs'], **tols):
raise RuntimeError(
'Export frozen graphdef failed because of validation error.')
return {'model': frozen_graph_path}
def export_onnx(self, output_dir: str, opset=13, **kwargs):
raise NotImplementedError(
'csanmt model does not support onnx format, consider using savedmodel instead.'
)
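
For reference, a hedged sketch of serving the resulting SavedModel outside ModelScope, based on the signature the exporter builds above ('translation_signature', input 'input_wids', output 'output_seqs', tag 'serve'); the directory and the token-id array are placeholders:

    import numpy as np
    import tensorflow as tf

    if tf.__version__ >= '2.0':
        tf = tf.compat.v1
        tf.disable_eager_execution()

    saved_model_dir = '/tmp/csanmt_saved_model'  # directory produced by export_saved_model

    with tf.Session(graph=tf.Graph()) as sess:
        meta_graph = tf.saved_model.loader.load(sess, ['serve'], saved_model_dir)
        sig = meta_graph.signature_def['translation_signature']
        x = tf.saved_model.utils.get_tensor_from_tensor_info(
            sig.inputs['input_wids'], sess.graph)
        y = tf.saved_model.utils.get_tensor_from_tensor_info(
            sig.outputs['output_seqs'], sess.graph)
        # input_wids holds tokenized source-sentence ids, shape [batch, seq_len];
        # in practice they would come from TranslationPipeline.preprocess(...)
        dummy_ids = np.zeros((1, 8), dtype=np.int64)
        print(sess.run(y, feed_dict={x: dummy_ids}))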

View File

@@ -1,4 +1,3 @@
import os
from collections import OrderedDict
from typing import Any, Dict, Mapping, Tuple
@@ -7,9 +6,7 @@ from torch.utils.data.dataloader import default_collate
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.preprocessors import (
Preprocessor, TextClassificationTransformersPreprocessor,
build_preprocessor)
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import ModeKeys, Tasks
@@ -17,8 +14,6 @@ from modelscope.utils.constant import ModeKeys, Tasks
@EXPORTERS.register_module(
Tasks.text_classification, module_name=Models.structbert)
@EXPORTERS.register_module(Tasks.sentence_similarity, module_name=Models.bert)
@EXPORTERS.register_module(
Tasks.zero_shot_classification, module_name=Models.bert)
@EXPORTERS.register_module(
Tasks.sentiment_classification, module_name=Models.bert)
@EXPORTERS.register_module(Tasks.nli, module_name=Models.bert)
@@ -27,8 +22,6 @@ from modelscope.utils.constant import ModeKeys, Tasks
@EXPORTERS.register_module(
Tasks.sentiment_classification, module_name=Models.structbert)
@EXPORTERS.register_module(Tasks.nli, module_name=Models.structbert)
@EXPORTERS.register_module(
Tasks.zero_shot_classification, module_name=Models.structbert)
class SbertForSequenceClassificationExporter(TorchModelExporter):
def generate_dummy_inputs(self,

View File

@@ -0,0 +1,58 @@
from collections import OrderedDict
from typing import Any, Dict, Mapping
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Tasks
@EXPORTERS.register_module(
Tasks.zero_shot_classification, module_name=Models.bert)
@EXPORTERS.register_module(
Tasks.zero_shot_classification, module_name=Models.structbert)
class SbertForZeroShotClassificationExporter(TorchModelExporter):
def generate_dummy_inputs(self,
candidate_labels,
hypothesis_template,
max_length=128,
pair: bool = False,
**kwargs) -> Dict[str, Any]:
"""Generate dummy inputs for model exportation to onnx or other formats by tracing.
Args:
max_length(int): The max length of sentence, default 128.
hypothesis_template(str): The template of prompt, like '这篇文章的标题是{}'
candidate_labels(List): The labels of prompt,
like ['文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事']
pair(bool, `optional`): Whether to generate sentence pairs or single sentences.
Returns:
Dummy inputs.
"""
assert hasattr(
self.model, 'model_dir'
), 'model_dir attribute is required to build the preprocessor'
preprocessor = Preprocessor.from_pretrained(
self.model.model_dir, max_length=max_length)
return preprocessor(
preprocessor.nlp_tokenizer.tokenizer.unk_token,
candidate_labels=candidate_labels,
hypothesis_template=hypothesis_template)
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
dynamic_axis = {0: 'batch', 1: 'sequence'}
return OrderedDict([
('input_ids', dynamic_axis),
('attention_mask', dynamic_axis),
('token_type_ids', dynamic_axis),
])
@property
def outputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict({'logits': {0: 'batch'}})

View File

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from abc import abstractmethod
from typing import Any, Callable, Dict, Mapping
import tensorflow as tf
@@ -7,7 +8,7 @@ import tensorflow as tf
from modelscope.outputs import ModelOutputBase
from modelscope.utils.constant import ModelFile
from modelscope.utils.logger import get_logger
from modelscope.utils.regress_test_utils import compare_arguments_nested
from modelscope.utils.test_utils import compare_arguments_nested
from .base import Exporter
logger = get_logger()
@@ -29,6 +30,14 @@ class TfModelExporter(Exporter):
self._tf2_export_onnx(model, onnx_file, opset=opset, **kwargs)
return {'model': onnx_file}
@abstractmethod
def export_saved_model(self, output_dir: str, **kwargs):
pass
@abstractmethod
def export_frozen_graph_def(self, output_dir: str, **kwargs):
pass
def _tf2_export_onnx(self,
model,
output: str,
@@ -59,56 +68,67 @@ class TfModelExporter(Exporter):
onnx.save(onnx_model, output)
if validation:
try:
import onnx
import onnxruntime as ort
except ImportError:
logger.warn(
'Cannot validate the exported onnx file, because '
'the installation of onnx or onnxruntime cannot be found')
return
self._validate_model(dummy_inputs, model, output, rtol, atol,
call_func)
def tensor_nested_numpify(tensors):
if isinstance(tensors, (list, tuple)):
return type(tensors)(
tensor_nested_numpify(t) for t in tensors)
if isinstance(tensors, Mapping):
# return dict
return {
k: tensor_nested_numpify(t)
for k, t in tensors.items()
}
if isinstance(tensors, tf.Tensor):
t = tensors.cpu()
return t.numpy()
return tensors
def _validate_model(
self,
dummy_inputs,
model,
output,
rtol: float = None,
atol: float = None,
call_func: Callable = None,
):
try:
import onnx
import onnxruntime as ort
except ImportError:
logger.warn(
'Cannot validate the exported onnx file, because '
'the installation of onnx or onnxruntime cannot be found')
return
onnx_model = onnx.load(output)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(output)
outputs_origin = call_func(
dummy_inputs) if call_func is not None else model(dummy_inputs)
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
outputs_origin = list(
tensor_nested_numpify(outputs_origin).values())
elif isinstance(outputs_origin, (tuple, list)):
outputs_origin = list(tensor_nested_numpify(outputs_origin))
outputs = ort_session.run(
None,
tensor_nested_numpify(dummy_inputs),
)
outputs = tensor_nested_numpify(outputs)
if isinstance(outputs, dict):
outputs = list(outputs.values())
elif isinstance(outputs, tuple):
outputs = list(outputs)
def tensor_nested_numpify(tensors):
if isinstance(tensors, (list, tuple)):
return type(tensors)(tensor_nested_numpify(t) for t in tensors)
if isinstance(tensors, Mapping):
# return dict
return {
k: tensor_nested_numpify(t)
for k, t in tensors.items()
}
if isinstance(tensors, tf.Tensor):
t = tensors.cpu()
return t.numpy()
return tensors
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested('Onnx model output match failed',
outputs, outputs_origin, **tols):
raise RuntimeError(
'export onnx failed because of validation error.')
onnx_model = onnx.load(output)
onnx.checker.check_model(onnx_model, full_check=True)
ort_session = ort.InferenceSession(output)
outputs_origin = call_func(
dummy_inputs) if call_func is not None else model(dummy_inputs)
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
outputs_origin = list(
tensor_nested_numpify(outputs_origin).values())
elif isinstance(outputs_origin, (tuple, list)):
outputs_origin = list(tensor_nested_numpify(outputs_origin))
outputs = ort_session.run(
None,
tensor_nested_numpify(dummy_inputs),
)
outputs = tensor_nested_numpify(outputs)
if isinstance(outputs, dict):
outputs = list(outputs.values())
elif isinstance(outputs, tuple):
outputs = list(outputs)
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested('Onnx model output match failed',
outputs, outputs_origin, **tols):
raise RuntimeError(
'export onnx failed because of validation error.')

View File

@@ -1,15 +1,13 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import pickle
from typing import Dict, Optional, Union
from urllib.parse import urlparse
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import (FILE_HASH, MODEL_META_FILE_NAME,
MODEL_META_MODEL_ID)
from modelscope.hub.constants import FILE_HASH
from modelscope.hub.git import GitCommandWrapper
from modelscope.hub.utils.caching import FileSystemCache, ModelFileSystemCache
from modelscope.hub.utils.caching import ModelFileSystemCache
from modelscope.hub.utils.utils import compute_hash
from modelscope.utils.logger import get_logger

View File

@@ -9,4 +9,5 @@ from .base import Head, Model
from .builder import BACKBONES, HEADS, MODELS, build_model
if is_torch_available():
from .base import TorchModel, TorchHead
from .base.base_torch_model import TorchModel
from .base.base_torch_head import TorchHead

View File

@@ -1,6 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.import_utils import is_torch_available
from .base_head import * # noqa F403
from .base_model import * # noqa F403
from .base_torch_head import * # noqa F403
from .base_torch_model import * # noqa F403
if is_torch_available():
from .base_torch_model import TorchModel
from .base_torch_head import TorchHead

View File

@@ -2,13 +2,11 @@
import os
import os.path as osp
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from modelscope.hub.check_model import check_local_model_is_latest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models.builder import build_model
from modelscope.utils.checkpoint import (save_checkpoint, save_configuration,
save_pretrained)
from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
from modelscope.utils.device import verify_device
@@ -150,9 +148,7 @@ class Model(ABC):
def save_pretrained(self,
target_folder: Union[str, os.PathLike],
save_checkpoint_names: Union[str, List[str]] = None,
save_function: Callable = save_checkpoint,
config: Optional[dict] = None,
save_config_function: Callable = save_configuration,
**kwargs):
"""save the pretrained model, its configuration and other related files to a directory,
so that it can be re-loaded
@@ -164,21 +160,8 @@ class Model(ABC):
save_checkpoint_names (Union[str, List[str]]):
The checkpoint names to be saved in the target_folder
save_function (Callable, optional):
The function to use to save the state dictionary.
config (Optional[dict], optional):
The config for the configuration.json, might not be identical with model.config
save_config_function (Callable, optional):
The function to use to save the configuration.
"""
if config is None and hasattr(self, 'cfg'):
config = self.cfg
if config is not None:
save_config_function(target_folder, config)
save_pretrained(self, target_folder, save_checkpoint_names,
save_function, **kwargs)
raise NotImplementedError(
'The save_pretrained method needs to be implemented by the subclass.')

View File

@@ -1,14 +1,16 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from copy import deepcopy
from typing import Any, Dict
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from torch import nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
from modelscope.utils.checkpoint import (save_checkpoint, save_configuration,
save_pretrained)
from modelscope.utils.file_utils import func_receive_dict_inputs
from modelscope.utils.hub import parse_label_mapping
from modelscope.utils.logger import get_logger
from .base_model import Model
@@ -88,3 +90,39 @@ class TorchModel(Model, torch.nn.Module):
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def save_pretrained(self,
target_folder: Union[str, os.PathLike],
save_checkpoint_names: Union[str, List[str]] = None,
save_function: Callable = save_checkpoint,
config: Optional[dict] = None,
save_config_function: Callable = save_configuration,
**kwargs):
"""save the pretrained model, its configuration and other related files to a directory,
so that it can be re-loaded
Args:
target_folder (Union[str, os.PathLike]):
Directory to which to save. Will be created if it doesn't exist.
save_checkpoint_names (Union[str, List[str]]):
The checkpoint names to be saved in the target_folder
save_function (Callable, optional):
The function to use to save the state dictionary.
config (Optional[dict], optional):
The config for the configuration.json, might not be identical with model.config
save_config_function (Callable, optional):
The function to use to save the configuration.
"""
if config is None and hasattr(self, 'cfg'):
config = self.cfg
if config is not None:
save_config_function(target_folder, config)
save_pretrained(self, target_folder, save_checkpoint_names,
save_function, **kwargs)
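
A small usage sketch of the relocated method, assuming a TorchModel subclass instance with a populated self.cfg; the checkpoint file name is an assumption, not taken from this diff:

    import tempfile

    from modelscope.models import Model

    model = Model.from_pretrained(
        'damo/nlp_structbert_zero-shot-classification_chinese-base')
    target = tempfile.mkdtemp()
    # Writes the checkpoint plus a configuration.json derived from model.cfg
    model.save_pretrained(target, save_checkpoint_names='pytorch_model.bin')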

View File

@@ -3,7 +3,7 @@
from modelscope.utils.config import ConfigDict
from modelscope.utils.constant import Tasks
from modelscope.utils.import_utils import INDEX_KEY, LazyImportModule
from modelscope.utils.registry import TYPE_NAME, Registry, build_from_cfg
from modelscope.utils.registry import Registry, build_from_cfg
MODELS = Registry('models')
BACKBONES = MODELS

View File

@@ -6,7 +6,6 @@ from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
Sequence, Union)
import numpy as np
import torch
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES
from datasets.utils.file_utils import is_relative_path
@@ -43,42 +42,6 @@ def format_list(para) -> List:
return para
class MsMapDataset(torch.utils.data.Dataset):
def __init__(self, dataset: Iterable, preprocessor_list, retained_columns,
columns, to_tensor):
super(MsDataset).__init__()
self.dataset = dataset
self.preprocessor_list = preprocessor_list
self.to_tensor = to_tensor
self.retained_columns = retained_columns
self.columns = columns
def __len__(self):
return len(self.dataset)
def type_converter(self, x):
if self.to_tensor:
return torch.tensor(x)
else:
return x
def __getitem__(self, index):
item_dict = self.dataset[index]
res = {
k: self.type_converter(item_dict[k])
for k in self.columns
if (not self.to_tensor) or k in self.retained_columns
}
for preprocessor in self.preprocessor_list:
res.update({
k: self.type_converter(v)
for k, v in preprocessor(item_dict).items()
if (not self.to_tensor) or k in self.retained_columns
})
return res
class MsDataset:
"""
ModelScope Dataset (aka, MsDataset) is backed by a huggingface Dataset to
@@ -303,6 +266,7 @@ class MsDataset:
columns: Union[str, List[str]] = None,
to_tensor: bool = True,
):
import torch
preprocessor_list = preprocessors if isinstance(
preprocessors, list) else [preprocessors]
@@ -332,6 +296,42 @@ class MsDataset:
continue
retained_columns.append(k)
class MsMapDataset(torch.utils.data.Dataset):
def __init__(self, dataset: Iterable, preprocessor_list,
retained_columns, columns, to_tensor):
super().__init__()
self.dataset = dataset
self.preprocessor_list = preprocessor_list
self.to_tensor = to_tensor
self.retained_columns = retained_columns
self.columns = columns
def __len__(self):
return len(self.dataset)
def type_converter(self, x):
import torch
if self.to_tensor:
return torch.tensor(x)
else:
return x
def __getitem__(self, index):
item_dict = self.dataset[index]
res = {
k: self.type_converter(item_dict[k])
for k in self.columns
if (not self.to_tensor) or k in self.retained_columns
}
for preprocessor in self.preprocessor_list:
res.update({
k: self.type_converter(v)
for k, v in preprocessor(item_dict).items()
if (not self.to_tensor) or k in self.retained_columns
})
return res
return MsMapDataset(self._hf_ds, preprocessor_list, retained_columns,
columns, to_tensor)

View File

@@ -22,7 +22,6 @@ from modelscope.utils.device import (create_device, device_placement,
from modelscope.utils.hub import read_config, snapshot_download
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import _find_free_port, _is_free_port
from .util import is_model, is_official_hub_path
if is_torch_available():
@@ -426,6 +425,7 @@ class DistributedPipeline(Pipeline):
'master_ip']
master_port = '29500' if 'master_port' not in kwargs else kwargs[
'master_port']
from modelscope.utils.torch_utils import _find_free_port, _is_free_port
if not _is_free_port(int(master_port)):
master_port = str(_find_free_port())
self.model_pool.map(

View File

@@ -44,7 +44,7 @@ class TranslationPipeline(Pipeline):
model = self.model.model_dir
tf.reset_default_graph()
model_path = osp.join(
self.model_path = osp.join(
osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER), 'ckpt-0')
self.cfg = Config.from_file(osp.join(model, ModelFile.CONFIGURATION))
@@ -88,10 +88,10 @@ class TranslationPipeline(Pipeline):
self.output.update(output)
with self._session.as_default() as sess:
logger.info(f'loading model from {model_path}')
logger.info(f'loading model from {self.model_path}')
# load model
model_loader = tf.train.Saver(tf.global_variables())
model_loader.restore(sess, model_path)
self.model_loader = tf.train.Saver(tf.global_variables())
self.model_loader.restore(sess, self.model_path)
def preprocess(self, input: str) -> Dict[str, Any]:
input = input.split('<SENT_SPLIT>')

View File

@@ -2,8 +2,6 @@
import os.path as osp
from typing import List, Optional, Union
import torch
from modelscope.hub.api import HubApi
from modelscope.hub.file_download import model_file_download
from modelscope.utils.config import Config
@@ -86,6 +84,7 @@ def is_model(path: Union[str, List]):
def batch_process(model, data):
import torch
if model.__class__.__name__ == 'OfaForAllTasks':
# collate batch data due to the nested data structure
assert isinstance(data, list)

View File

@@ -5,7 +5,6 @@ from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Sequence, Union
from modelscope.metainfo import Models, Preprocessors
from modelscope.utils.checkpoint import save_configuration
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke,
ModeKeys, Tasks)
@@ -312,7 +311,7 @@ class Preprocessor(ABC):
def save_pretrained(self,
target_folder: Union[str, os.PathLike],
config: Optional[dict] = None,
save_config_function: Callable = save_configuration):
save_config_function: Callable = None):
"""Save the preprocessor, its configuration and other related files to a directory,
so that it can be re-loaded
@@ -341,4 +340,7 @@ class Preprocessor(ABC):
'preprocessor']['val']:
config['preprocessor']['val']['mode'] = 'inference'
if save_config_function is None:
from modelscope.utils.checkpoint import save_configuration
save_config_function = save_configuration
save_config_function(target_folder, config)

View File

@@ -10,6 +10,7 @@ from modelscope.preprocessors import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.constant import Fields, ModeKeys
from modelscope.utils.type_assert import type_assert
from .text_ranking_preprocessor import TextRankingPreprocessorBase
class GisUtt:
@@ -113,7 +114,7 @@ class GisUtt:
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.mgeo_ranking)
class MGeoRankingTransformersPreprocessor(Preprocessor):
class MGeoRankingTransformersPreprocessor(TextRankingPreprocessorBase):
def __init__(self,
model_dir: str,
@@ -125,39 +126,39 @@ class MGeoRankingTransformersPreprocessor(Preprocessor):
label='labels',
qid='qid',
max_length=None,
padding='longest',
truncation=True,
use_fast=True,
**kwargs):
"""The tokenizer preprocessor class for the text ranking preprocessor.
Args:
model_dir(str, `optional`): The model dir used to parse the label mapping, can be None.
first_sequence(str, `optional`): The key of the first sequence.
second_sequence(str, `optional`): The key of the second sequence.
label(str, `optional`): The keys of the label columns, default `labels`.
qid(str, `optional`): The qid info.
mode: The mode for the preprocessor.
max_length: The max sequence length which the model supported,
will be passed into tokenizer as the 'max_length' param.
padding: The padding method
truncation: The truncation method
"""
super().__init__(mode)
super().__init__(
mode=mode,
first_sequence=first_sequence,
second_sequence=second_sequence,
label=label,
qid=qid)
self.model_dir = model_dir
self.first_sequence = first_sequence
self.second_sequence = second_sequence
self.first_sequence_gis = first_sequence_gis
self.second_sequence_gis = second_sequence_gis
self.label = label
self.qid = qid
self.sequence_length = max_length if max_length is not None else kwargs.get(
'sequence_length', 128)
kwargs.pop('sequence_length', None)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
self.tokenize_kwargs = kwargs
self.tokenize_kwargs['padding'] = padding
self.tokenize_kwargs['truncation'] = truncation
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_dir, use_fast=use_fast)
@type_assert(object, dict)
def __call__(self,
data: Dict,
padding='longest',
truncation=True,
**kwargs) -> Dict[str, Any]:
def __call__(self, data: Dict, **kwargs) -> Dict[str, Any]:
sentence1 = data.get(self.first_sequence)
sentence2 = data.get(self.second_sequence)
labels = data.get(self.label)
@@ -176,12 +177,9 @@ class MGeoRankingTransformersPreprocessor(Preprocessor):
'max_length', kwargs.pop('sequence_length', self.sequence_length))
if 'return_tensors' not in kwargs:
kwargs['return_tensors'] = 'pt'
feature = self.tokenizer(
sentence1,
sentence2,
padding=padding,
truncation=truncation,
**kwargs)
self.tokenize_kwargs.update(kwargs)
feature = self.tokenizer(sentence1, sentence2, **self.tokenize_kwargs)
if labels is not None:
feature['labels'] = labels
if qid is not None:

View File

@@ -11,9 +11,33 @@ from modelscope.utils.constant import Fields, ModeKeys
from modelscope.utils.type_assert import type_assert
class TextRankingPreprocessorBase(Preprocessor):
def __init__(self,
mode: str = ModeKeys.INFERENCE,
first_sequence='source_sentence',
second_sequence='sentences_to_compare',
label='labels',
qid='qid'):
"""The tokenizer preprocessor class for the text ranking preprocessor.
Args:
first_sequence(str, `optional`): The key of the first sequence.
second_sequence(str, `optional`): The key of the second sequence.
label(str, `optional`): The keys of the label columns, default `labels`.
qid(str, `optional`): The qid info.
mode: The mode for the preprocessor.
"""
super().__init__(mode)
self.first_sequence = first_sequence
self.second_sequence = second_sequence
self.label = label
self.qid = qid
@PREPROCESSORS.register_module(
Fields.nlp, module_name=Preprocessors.text_ranking)
class TextRankingTransformersPreprocessor(Preprocessor):
class TextRankingTransformersPreprocessor(TextRankingPreprocessorBase):
def __init__(self,
model_dir: str,
@@ -23,36 +47,33 @@ class TextRankingTransformersPreprocessor(Preprocessor):
label='labels',
qid='qid',
max_length=None,
padding='max_length',
truncation=True,
**kwargs):
"""The tokenizer preprocessor class for the text ranking preprocessor.
Args:
model_dir(str, `optional`): The model dir used to parse the label mapping, can be None.
first_sequence(str, `optional`): The key of the first sequence.
second_sequence(str, `optional`): The key of the second sequence.
label(str, `optional`): The keys of the label columns, default `labels`.
qid(str, `optional`): The qid info.
mode: The mode for the preprocessor.
max_length: The max sequence length which the model supported,
will be passed into tokenizer as the 'max_length' param.
"""
super().__init__(mode)
super().__init__(
mode=mode,
first_sequence=first_sequence,
second_sequence=second_sequence,
label=label,
qid=qid)
self.model_dir = model_dir
self.first_sequence = first_sequence
self.second_sequence = second_sequence
self.label = label
self.qid = qid
self.sequence_length = max_length if max_length is not None else kwargs.get(
'sequence_length', 128)
kwargs.pop('sequence_length', None)
self.tokenize_kwargs = kwargs
self.tokenize_kwargs['padding'] = padding
self.tokenize_kwargs['truncation'] = truncation
self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
@type_assert(object, dict)
def __call__(self,
data: Dict,
padding='max_length',
truncation=True,
**kwargs) -> Dict[str, Any]:
def __call__(self, data: Dict, **kwargs) -> Dict[str, Any]:
sentence1 = data.get(self.first_sequence)
sentence2 = data.get(self.second_sequence)
labels = data.get(self.label)
@@ -67,12 +88,9 @@ class TextRankingTransformersPreprocessor(Preprocessor):
'max_length', kwargs.pop('sequence_length', self.sequence_length))
if 'return_tensors' not in kwargs:
kwargs['return_tensors'] = 'pt'
feature = self.tokenizer(
sentence1,
sentence2,
padding=padding,
truncation=truncation,
**kwargs)
self.tokenize_kwargs.update(kwargs)
feature = self.tokenizer(sentence1, sentence2, **self.tokenize_kwargs)
if labels is not None:
feature['labels'] = labels
if qid is not None:
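
With this refactor the tokenizer arguments (padding, truncation, max_length) are fixed at construction time instead of being passed to __call__. A hedged sketch of the new calling convention; the import path and the local model_dir are assumptions:

    from modelscope.preprocessors import TextRankingTransformersPreprocessor

    preprocessor = TextRankingTransformersPreprocessor(
        model_dir='/path/to/local/model',  # any directory containing a HF tokenizer
        max_length=128,
        padding='max_length',
        truncation=True)

    # The source sentence is repeated so each pair lines up with one candidate.
    features = preprocessor({
        'source_sentence': ['functions of the liver'] * 2,
        'sentences_to_compare': [
            'The liver filters blood coming from the digestive tract.',
            'The Eiffel Tower is in Paris.'
        ]})
    print(features['input_ids'].shape)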

View File

@@ -3,7 +3,6 @@
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
import copy
import dataclasses
import os
import os.path as osp
import platform

View File

@@ -2,7 +2,6 @@
import io
import cv2
import json
from modelscope.outputs import OutputKeys
@@ -265,6 +264,7 @@ def postprocess(req, resp):
new_resp.get(output_key)
if file_type == 'png' or file_type == 'jpg':
content = new_resp.get(output_key)
import cv2
_, img_encode = cv2.imencode('.' + file_type, content)
img_bytes = img_encode.tobytes()
return type(img_bytes)

View File

@@ -3,7 +3,6 @@ import os
from contextlib import contextmanager
from modelscope.utils.constant import Devices, Frameworks
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from modelscope.utils.logger import get_logger
logger = get_logger()

View File

@@ -1,7 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
import os
from pathlib import Path

View File

@@ -20,7 +20,7 @@ import torch
import torch.optim
from torch import nn
from modelscope.utils.service_utils import NumpyEncoder
from .test_utils import compare_arguments_nested
class RegressTool:
@@ -71,6 +71,7 @@ class RegressTool:
module: nn.Module,
file_name: str,
compare_fn=None,
compare_model_output=True,
**kwargs):
"""Monitor a pytorch module in a single forward.
@@ -78,6 +79,7 @@ class RegressTool:
module: A torch module
file_name: The file_name to store or load file
compare_fn: A custom fn used to compare the results manually.
compare_model_output: Only compare the output of the monitored module, skipping all other intermediate tensors
>>> def compare_fn(v1, v2, key, type):
>>> return None
@@ -120,17 +122,46 @@ class RegressTool:
with open(baseline, 'rb') as f:
base = pickle.load(f)
class SafeNumpyEncoder(NumpyEncoder):
class SafeNumpyEncoder(json.JSONEncoder):
def parse_default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, np.floating):
return float(obj)
if isinstance(obj, np.integer):
return int(obj)
return json.JSONEncoder.default(self, obj)
def default(self, obj):
try:
return super().default(obj)
return self.parse_default(obj)
except Exception:
print(
f'Type {obj.__class__} cannot be serialized and printed'
)
return None
if compare_model_output:
print(
'Ignore inner modules, only the output of the model will be verified.'
)
base = {
key: value
for key, value in base.items() if key == file_name
}
for key, value in base.items():
value['input'] = {'args': None, 'kwargs': None}
io_json = {
key: value
for key, value in io_json.items() if key == file_name
}
for key, value in io_json.items():
value['input'] = {'args': None, 'kwargs': None}
print(f'baseline: {json.dumps(base, cls=SafeNumpyEncoder)}')
print(f'latest : {json.dumps(io_json, cls=SafeNumpyEncoder)}')
if not compare_io_and_print(base, io_json, compare_fn, **kwargs):
@@ -326,10 +357,75 @@ class MsRegressTool(RegressTool):
def lazy_stop_callback():
from modelscope.trainers.hooks.hook import Hook, Priority
class EarlyStopHook:
PRIORITY = 90
class EarlyStopHook(Hook):
PRIORITY = Priority.VERY_LOW
def before_run(self, trainer):
pass
def after_run(self, trainer):
pass
def before_epoch(self, trainer):
pass
def after_epoch(self, trainer):
pass
def before_iter(self, trainer):
pass
def before_train_epoch(self, trainer):
self.before_epoch(trainer)
def before_val_epoch(self, trainer):
self.before_epoch(trainer)
def after_train_epoch(self, trainer):
self.after_epoch(trainer)
def after_val_epoch(self, trainer):
self.after_epoch(trainer)
def before_train_iter(self, trainer):
self.before_iter(trainer)
def before_val_iter(self, trainer):
self.before_iter(trainer)
def after_train_iter(self, trainer):
self.after_iter(trainer)
def after_val_iter(self, trainer):
self.after_iter(trainer)
def every_n_epochs(self, trainer, n):
return (trainer.epoch + 1) % n == 0 if n > 0 else False
def every_n_inner_iters(self, runner, n):
return (runner.inner_iter
+ 1) % n == 0 if n > 0 else False
def every_n_iters(self, trainer, n):
return (trainer.iter + 1) % n == 0 if n > 0 else False
def end_of_epoch(self, trainer):
return trainer.inner_iter + 1 == trainer.iters_per_epoch
def is_last_epoch(self, trainer):
return trainer.epoch + 1 == trainer.max_epochs
def is_last_iter(self, trainer):
return trainer.iter + 1 == trainer.max_iters
def get_triggered_stages(self):
return []
def state_dict(self):
return {}
def load_state_dict(self, state_dict):
pass
def after_iter(self, trainer):
raise MsRegressTool.EarlyStopError('Test finished.')
@@ -526,92 +622,6 @@ def intercept_module(module: nn.Module,
intercept_module(module, io_json, full_name, restore)
def compare_arguments_nested(print_content,
arg1,
arg2,
rtol=1.e-3,
atol=1.e-8,
ignore_unknown_type=True):
type1 = type(arg1)
type2 = type(arg2)
if type1.__name__ != type2.__name__:
if print_content is not None:
print(
f'{print_content}, type not equal:{type1.__name__} and {type2.__name__}'
)
return False
if arg1 is None:
return True
elif isinstance(arg1, (int, str, bool, np.bool, np.integer, np.str)):
if arg1 != arg2:
if print_content is not None:
print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
return False
return True
elif isinstance(arg1, (float, np.floating)):
if not np.isclose(arg1, arg2, rtol=rtol, atol=atol, equal_nan=True):
if print_content is not None:
print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
return False
return True
elif isinstance(arg1, (tuple, list)):
if len(arg1) != len(arg2):
if print_content is not None:
print(
f'{print_content}, length is not equal:{len(arg1)}, {len(arg2)}'
)
return False
if not all([
compare_arguments_nested(
None, sub_arg1, sub_arg2, rtol=rtol, atol=atol)
for sub_arg1, sub_arg2 in zip(arg1, arg2)
]):
if print_content is not None:
print(f'{print_content}')
return False
return True
elif isinstance(arg1, Mapping):
keys1 = arg1.keys()
keys2 = arg2.keys()
if len(keys1) != len(keys2):
if print_content is not None:
print(
f'{print_content}, key length is not equal:{len(keys1)}, {len(keys2)}'
)
return False
if len(set(keys1) - set(keys2)) > 0:
if print_content is not None:
print(f'{print_content}, key diff:{set(keys1) - set(keys2)}')
return False
if not all([
compare_arguments_nested(
None, arg1[key], arg2[key], rtol=rtol, atol=atol)
for key in keys1
]):
if print_content is not None:
print(f'{print_content}')
return False
return True
elif isinstance(arg1, np.ndarray):
arg1 = np.where(np.equal(arg1, None), np.NaN,
arg1).astype(dtype=np.float)
arg2 = np.where(np.equal(arg2, None), np.NaN,
arg2).astype(dtype=np.float)
if not all(
np.isclose(arg1, arg2, rtol=rtol, atol=atol,
equal_nan=True).flatten()):
if print_content is not None:
print(f'{print_content}')
return False
return True
else:
if ignore_unknown_type:
return True
else:
raise ValueError(f'type not supported: {type1}')
def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs):
if compare_fn is None:
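
Since compare_arguments_nested now lives in modelscope.utils.test_utils (see the diff below), regression checks can call it directly. A small sketch based on the signature shown in this commit; note the helper relies on deprecated numpy aliases (np.bool, np.float), so it assumes the older numpy versions this code targets:

    import numpy as np

    from modelscope.utils.test_utils import compare_arguments_nested

    baseline = {'logits': np.array([0.1, 0.2, 0.3])}
    latest = {'logits': np.array([0.1, 0.2, 0.3000001])}
    assert compare_arguments_nested(
        'Output match failed', baseline, latest, rtol=1e-3, atol=1e-8)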

View File

@@ -5,12 +5,9 @@ from io import BytesIO
import json
import numpy as np
import requests
from PIL import Image
from modelscope.outputs import TASK_OUTPUTS, OutputKeys
from modelscope.pipeline_inputs import TASK_INPUTS, InputType
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks, TasksIODescriptions
# service data decoder func decodes data from network and convert it to pipeline's input
@@ -91,12 +88,14 @@ def decode_base64_to_binary(encoding):
def decode_base64_to_image(encoding):
from PIL import Image
content = encoding.split(';')[1]
image_encoded = content.split(',')[1]
return Image.open(BytesIO(base64.b64decode(image_encoded)))
def encode_array_to_img_base64(image_array):
from PIL import Image
with BytesIO() as output_bytes:
pil_image = Image.fromarray(image_array.astype(np.uint8))
pil_image.save(output_bytes, 'PNG')

View File

@@ -12,13 +12,12 @@ import tarfile
import tempfile
import unittest
from collections import OrderedDict
from collections.abc import Mapping
import numpy as np
import requests
import torch
from torch.utils.data import Dataset
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from modelscope.utils.torch_utils import _find_free_port
TEST_LEVEL = 2
TEST_LEVEL_STR = 'TEST_LEVEL'
@@ -49,7 +48,7 @@ def set_test_level(level: int):
TEST_LEVEL = level
class DummyTorchDataset(Dataset):
class DummyTorchDataset:
def __init__(self, feat, label, num) -> None:
self.feat = feat
@@ -57,6 +56,7 @@ class DummyTorchDataset(Dataset):
self.num = num
def __getitem__(self, index):
import torch
return {
'feat': torch.Tensor(self.feat),
'labels': torch.Tensor(self.label)
@@ -119,6 +119,92 @@ def get_case_model_info():
return model_cases
def compare_arguments_nested(print_content,
arg1,
arg2,
rtol=1.e-3,
atol=1.e-8,
ignore_unknown_type=True):
type1 = type(arg1)
type2 = type(arg2)
if type1.__name__ != type2.__name__:
if print_content is not None:
print(
f'{print_content}, type not equal:{type1.__name__} and {type2.__name__}'
)
return False
if arg1 is None:
return True
elif isinstance(arg1, (int, str, bool, np.bool, np.integer, np.str)):
if arg1 != arg2:
if print_content is not None:
print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
return False
return True
elif isinstance(arg1, (float, np.floating)):
if not np.isclose(arg1, arg2, rtol=rtol, atol=atol, equal_nan=True):
if print_content is not None:
print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
return False
return True
elif isinstance(arg1, (tuple, list)):
if len(arg1) != len(arg2):
if print_content is not None:
print(
f'{print_content}, length is not equal:{len(arg1)}, {len(arg2)}'
)
return False
if not all([
compare_arguments_nested(
None, sub_arg1, sub_arg2, rtol=rtol, atol=atol)
for sub_arg1, sub_arg2 in zip(arg1, arg2)
]):
if print_content is not None:
print(f'{print_content}')
return False
return True
elif isinstance(arg1, Mapping):
keys1 = arg1.keys()
keys2 = arg2.keys()
if len(keys1) != len(keys2):
if print_content is not None:
print(
f'{print_content}, key length is not equal:{len(keys1)}, {len(keys2)}'
)
return False
if len(set(keys1) - set(keys2)) > 0:
if print_content is not None:
print(f'{print_content}, key diff:{set(keys1) - set(keys2)}')
return False
if not all([
compare_arguments_nested(
None, arg1[key], arg2[key], rtol=rtol, atol=atol)
for key in keys1
]):
if print_content is not None:
print(f'{print_content}')
return False
return True
elif isinstance(arg1, np.ndarray):
arg1 = np.where(np.equal(arg1, None), np.NaN,
arg1).astype(dtype=np.float)
arg2 = np.where(np.equal(arg2, None), np.NaN,
arg2).astype(dtype=np.float)
if not all(
np.isclose(arg1, arg2, rtol=rtol, atol=atol,
equal_nan=True).flatten()):
if print_content is not None:
print(f'{print_content}')
return False
return True
else:
if ignore_unknown_type:
return True
else:
raise ValueError(f'type not supported: {type1}')
_DIST_SCRIPT_TEMPLATE = """
import ast
import argparse
@@ -263,6 +349,7 @@ class DistributedTestCase(unittest.TestCase):
save_all_ranks=False,
*args,
**kwargs):
from .torch_utils import _find_free_port
ip = socket.gethostbyname(socket.gethostname())
dist_start_cmd = '%s -m torch.distributed.launch --nproc_per_node=%d --master_addr=\'%s\' --master_port=%s' % (
sys.executable, num_gpus, ip, _find_free_port())

View File

@@ -0,0 +1,40 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image
from modelscope.exporters import TfModelExporter
from modelscope.models import Model
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import compare_arguments_nested, test_level
class TestExportTfModel(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
def tearDown(self):
shutil.rmtree(self.tmp_dir)
super().tearDown()
@unittest.skipUnless(test_level() >= 2,
'test with numpy version == 1.18.1')
def test_export_csanmt(self):
model = Model.from_pretrained('damo/nlp_csanmt_translation_en2zh_base')
print(
TfModelExporter.from_model(model).export_saved_model(
output_dir=self.tmp_dir))
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,47 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
from collections import OrderedDict
from modelscope.exporters import Exporter, TorchModelExporter
from modelscope.models import Model
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class TestExportSbertZeroShotClassification(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
self.model_id = 'damo/nlp_structbert_zero-shot-classification_chinese-base'
def tearDown(self):
shutil.rmtree(self.tmp_dir)
super().tearDown()
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_export_sbert_sequence_classification(self):
model = Model.from_pretrained(self.model_id)
print(
Exporter.from_model(model).export_onnx(
candidate_labels=[
'文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'
],
hypothesis_template='这篇文章的标题是{}',
output_dir=self.tmp_dir))
print(
Exporter.from_model(model).export_torch_script(
candidate_labels=[
'文化', '体育', '娱乐', '财经', '家居', '汽车', '教育', '科技', '军事'
],
hypothesis_template='这篇文章的标题是{}',
output_dir=self.tmp_dir))
if __name__ == '__main__':
unittest.main()

View File

@@ -10,7 +10,6 @@ from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image
from modelscope.exporters import TfModelExporter
from modelscope.utils.regress_test_utils import compare_arguments_nested
from modelscope.utils.test_utils import test_level

View File

@@ -11,7 +11,7 @@ import torch
from torch import nn
from modelscope.metainfo import Trainers
from modelscope.models.base import Model
from modelscope.models.base import TorchModel
from modelscope.trainers import build_trainer
from modelscope.utils.constant import LogKeys, ModelFile
from modelscope.utils.test_utils import create_dummy_test_dataset
@@ -20,7 +20,7 @@ dummy_dataset = create_dummy_test_dataset(
np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)
class DummyModel(nn.Module, Model):
class DummyModel(TorchModel):
def __init__(self):
super().__init__()

View File

@@ -11,7 +11,7 @@ from torch import nn
from modelscope.metainfo import Trainers
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.models.base import Model
from modelscope.models.base import TorchModel
from modelscope.trainers import build_trainer
from modelscope.utils.constant import LogKeys, ModelFile
from modelscope.utils.registry import default_group
@@ -42,7 +42,7 @@ dummy_dataset = create_dummy_test_dataset(
np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)
class DummyModel(nn.Module, Model):
class DummyModel(TorchModel):
def __init__(self):
super().__init__()

View File

@@ -11,7 +11,7 @@ from torch import nn
from modelscope.metainfo import Trainers
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.models.base import Model
from modelscope.models.base import TorchModel
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile
from modelscope.utils.registry import default_group
@@ -35,7 +35,7 @@ dummy_dataset = create_dummy_test_dataset(
np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 20)
class DummyModel(nn.Module, Model):
class DummyModel(TorchModel):
def __init__(self):
super().__init__()

View File

@@ -13,7 +13,7 @@ from torch.optim.lr_scheduler import MultiStepLR
from modelscope.metainfo import Trainers
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.models.base import Model
from modelscope.models.base import TorchModel
from modelscope.trainers import build_trainer
from modelscope.utils.constant import LogKeys, ModelFile, TrainerStages
from modelscope.utils.registry import default_group
@@ -41,7 +41,7 @@ def create_dummy_metric():
return {MetricKeys.ACCURACY: self._fake_acc_by_epoch[_global_iter]}
class DummyModel(nn.Module, Model):
class DummyModel(TorchModel):
def __init__(self):
super().__init__()

View File

@@ -16,7 +16,7 @@ from torch.utils.data import IterableDataset
from modelscope.metainfo import Metrics, Trainers
from modelscope.metrics.builder import MetricKeys
from modelscope.models.base import Model
from modelscope.models.base import TorchModel
from modelscope.trainers import build_trainer
from modelscope.trainers.base import DummyTrainer
from modelscope.trainers.builder import TRAINERS
@@ -41,7 +41,7 @@ dummy_dataset_big = create_dummy_test_dataset(
np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40)
class DummyModel(nn.Module, Model):
class DummyModel(TorchModel):
def __init__(self):
super().__init__()

View File

@@ -15,7 +15,7 @@ from torch.utils.data import IterableDataset
from modelscope.metainfo import Metrics, Trainers
from modelscope.metrics.builder import MetricKeys
from modelscope.models.base import Model
from modelscope.models.base import Model, TorchModel
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.utils.constant import LogKeys, ModeKeys, ModelFile, Tasks
from modelscope.utils.test_utils import (DistributedTestCase,
@@ -38,7 +38,7 @@ dummy_dataset_big = create_dummy_test_dataset(
np.random.random(size=(5, )), np.random.randint(0, 4, (1, )), 40)
class DummyModel(nn.Module, Model):
class DummyModel(TorchModel):
def __init__(self):
super().__init__()