mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 04:29:22 +01:00
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9061073 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9061073 * [to #41669377] docs and tools refinement and release 1. add build_doc linter script 2. add sphinx-docs support 3. add development doc and api doc 4. change version to 0.1.0 for the first internal release version Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8775307 * [to #41669377] add pipeline tutorial and fix bugs 1. add pipleine tutorial 2. fix bugs when using pipeline with certain model and preprocessor Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8814301 * refine doc * refine doc * merge remote release/0.1 and fix conflict * Merge branch 'release/0.1' into 'nls/tts' Release/0.1 See merge request !1700968 * [Add] add tts preprocessor without requirements. finish requirements build later * [Add] add requirements and frd submodule * [Fix] remove models submodule * [Add] add am module * [Update] update am and vocoder * [Update] remove submodule * [Update] add models * [Fix] fix init error * [Fix] fix bugs with tts pipeline * merge master * [Update] merge from master * remove frd subdmoule and using wheel from oss * change scripts * [Fix] fix bugs in am and vocoder * [Merge] merge from master * Merge branch 'master' into nls/tts * [Fix] fix bugs * [Fix] fix pep8 * Merge branch 'master' into nls/tts * [Update] remove hparams and import configuration from kwargs * Merge branch 'master' into nls/tts * upgrade tf113 to tf115 * Merge branch 'nls/tts' of gitlab.alibaba-inc.com:Ali-MaaS/MaaS-lib into nls/tts * add multiple versions of ttsfrd * merge master * [Fix] fix cr comments * Merge branch 'master' into nls/tts * [Fix] fix cr comments 0617 * Merge branch 'master' into nls/tts * [Fix] remove comment out codes * [Merge] merge from master * [Fix] fix crash for incompatible tf and pytorch version, and frd using zip file resource * Merge branch 'master' into nls/tts * [Add] add cuda 
support
68 lines
2.3 KiB
Python
68 lines
2.3 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
import os.path as osp
|
|
from abc import ABC, abstractmethod
|
|
from typing import Dict, Union
|
|
|
|
from maas_hub.snapshot_download import snapshot_download
|
|
|
|
from modelscope.models.builder import build_model
|
|
from modelscope.utils.config import Config
|
|
from modelscope.utils.constant import ModelFile
|
|
from modelscope.utils.hub import get_model_cache_dir
|
|
|
|
# Framework-agnostic tensor annotation. Both members are string forward
# references, so neither torch nor tensorflow has to be importable for this
# alias to be defined.
Tensor = Union[('torch.Tensor', 'tf.Tensor')]
|
|
|
|
|
|
class Model(ABC):
    """Abstract base class for models.

    Subclasses implement :meth:`forward`. Calling an instance runs
    :meth:`forward` and then passes its result through :meth:`postprocess`.
    """

    def __init__(self, model_dir, *args, **kwargs):
        """Remember the directory this model was created from.

        Args:
            model_dir: model directory, stored as ``self.model_dir``.
            *args: accepted for subclass compatibility; ignored here.
            **kwargs: accepted for subclass compatibility; ignored here.
        """
        self.model_dir = model_dir

    def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run :meth:`forward` on ``input`` and post-process the result.

        Args:
            input: dict of named model inputs.

        Returns:
            The dict returned by :meth:`postprocess`.
        """
        return self.postprocess(self.forward(input))

    @abstractmethod
    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Run the model on a dict of named inputs.

        Must be implemented by subclasses.

        Args:
            input: dict of named model inputs.

        Returns:
            dict of raw model outputs, later passed to :meth:`postprocess`.
        """
        pass

    def postprocess(self, input: Dict[str, Tensor],
                    **kwargs) -> Dict[str, Tensor]:
        """ Model specific postprocess and convert model output to
        standard model outputs.

        The default implementation returns ``input`` unchanged; subclasses
        override it when their raw outputs need converting.

        Args:
            input: output dict produced by :meth:`forward`.
            **kwargs: unused by the default implementation.

        Returns:
            dict of results: a dict containing outputs of model, each
            output should have the standard output name.
        """
        return input

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
        """Instantiate a model from a local directory or remote model repo.

        Args:
            model_name_or_path: either an existing local directory, or a
                model id. For a model id, the directory returned by
                ``get_model_cache_dir`` is used if it exists; otherwise the
                repo is fetched with ``snapshot_download``.
            *model_args: currently unused.
            **kwargs: key/value overrides written onto the ``model`` section
                of the configuration before the model is built.

        Returns:
            Whatever ``build_model`` returns for the model config and the
            configured task (presumably a ``Model`` instance).
        """
        if osp.exists(model_name_or_path):
            local_model_dir = model_name_or_path
        else:
            cache_path = get_model_cache_dir(model_name_or_path)
            local_model_dir = cache_path if osp.exists(
                cache_path) else snapshot_download(model_name_or_path)

        cfg = Config.from_file(
            osp.join(local_model_dir, ModelFile.CONFIGURATION))
        task_name = cfg.task
        model_cfg = cfg.model
        # TODO @wenmeng.zwm may should manually initialize model after model building
        # Some configs name the model with ``model_type`` only; mirror it into
        # ``type`` (presumably the key build_model dispatches on).
        if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
            model_cfg.type = model_cfg.model_type
        model_cfg.model_dir = local_model_dir
        # Apply caller overrides under their own key names. A plain
        # ``model_cfg.k = v`` would set a literal attribute named "k" and
        # keep only the last override.
        for k, v in kwargs.items():
            setattr(model_cfg, k, v)
        return build_model(model_cfg, task_name)
|