modelscope/modelscope/models/base.py
jiaqi.sjq b1490bfd7f [to #9061073] feat: merge tts to master
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9061073

    * [to #41669377] docs and tools refinement and release 

1. add build_doc linter script
2. add sphinx-docs support
3. add development doc and api doc
4. change version to 0.1.0 for the first internal release version

Link: https://code.aone.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8775307

* [to #41669377] add pipeline tutorial and fix bugs 

1. add pipeline tutorial
2. fix bugs when using pipeline with certain model and preprocessor

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/8814301

* refine doc

* refine doc

* merge remote release/0.1 and fix conflict

* Merge branch 'release/0.1' into 'nls/tts'

Release/0.1



See merge request !1700968

* [Add] add tts preprocessor without requirements. finish requirements build later

* [Add] add requirements and frd submodule

* [Fix] remove models submodule

* [Add] add am module

* [Update] update am and vocoder

* [Update] remove submodule

* [Update] add models

* [Fix] fix init error

* [Fix] fix bugs with tts pipeline

* merge master

* [Update] merge from master

* remove frd submodule and use wheel from oss

* change scripts

* [Fix] fix bugs in am and vocoder

* [Merge] merge from master

* Merge branch 'master' into nls/tts

* [Fix] fix bugs

* [Fix] fix pep8

* Merge branch 'master' into nls/tts

* [Update] remove hparams and import configuration from kwargs

* Merge branch 'master' into nls/tts

* upgrade tf113 to tf115

* Merge branch 'nls/tts' of gitlab.alibaba-inc.com:Ali-MaaS/MaaS-lib into nls/tts

* add multiple versions of ttsfrd

* merge master

* [Fix] fix cr comments

* Merge branch 'master' into nls/tts

* [Fix] fix cr comments 0617

* Merge branch 'master' into nls/tts

* [Fix] remove comment out codes

* [Merge] merge from master

* [Fix] fix crash for incompatible tf and pytorch version, and frd using zip file resource

* Merge branch 'master' into nls/tts

* [Add] add cuda support
2022-06-20 17:23:11 +08:00

68 lines
2.3 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
from abc import ABC, abstractmethod
from typing import Dict, Union

from maas_hub.snapshot_download import snapshot_download

from modelscope.models.builder import build_model
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile
from modelscope.utils.hub import get_model_cache_dir

Tensor = Union['torch.Tensor', 'tf.Tensor']


class Model(ABC):

    def __init__(self, model_dir, *args, **kwargs):
        self.model_dir = model_dir

    def __call__(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        return self.postprocess(self.forward(input))

    @abstractmethod
    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        pass

    def postprocess(self, input: Dict[str, Tensor],
                    **kwargs) -> Dict[str, Tensor]:
        """ Model specific postprocess and convert model output to
        standard model outputs.

        Args:
            input: input data

        Return:
            dict of results: a dict containing outputs of model, each
                output should have the standard output name.
        """
        return input

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
        """ Instantiate a model from local directory or remote model repo
        """
        if osp.exists(model_name_or_path):
            local_model_dir = model_name_or_path
        else:
            cache_path = get_model_cache_dir(model_name_or_path)
            local_model_dir = cache_path if osp.exists(
                cache_path) else snapshot_download(model_name_or_path)
            # else:
            #     raise ValueError(
            #         'Remote model repo {model_name_or_path} does not exist')
        cfg = Config.from_file(
            osp.join(local_model_dir, ModelFile.CONFIGURATION))
        task_name = cfg.task
        model_cfg = cfg.model
        # TODO @wenmeng.zwm may need to manually initialize model after model building
        if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
            model_cfg.type = model_cfg.model_type
        model_cfg.model_dir = local_model_dir
        for k, v in kwargs.items():
            # Use the value of k as the attribute name; `model_cfg.k = v`
            # would always write to an attribute literally named 'k'.
            setattr(model_cfg, k, v)
        return build_model(model_cfg, task_name)
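
The subclassing contract implied by the base class above can be illustrated with a small, self-contained sketch. It does not import modelscope: `Model` here is a trimmed stand-in that keeps only `__init__`, `__call__`, `forward` and `postprocess`, and `EchoLengthModel` and the `'length'` / `'output'` keys are hypothetical names made up for the example. Under those assumptions, it shows that a subclass must implement `forward`, may override `postprocess`, and that calling the instance chains the two.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class Model(ABC):
    """Trimmed stand-in for the base class above (no hub or config access)."""

    def __init__(self, model_dir, *args, **kwargs):
        self.model_dir = model_dir

    def __call__(self, input: Dict[str, Any]) -> Dict[str, Any]:
        # Calling the model runs forward, then postprocess on its result.
        return self.postprocess(self.forward(input))

    @abstractmethod
    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        pass

    def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        # Default behaviour: return the forward output unchanged.
        return input


class EchoLengthModel(Model):
    """Hypothetical subclass used only for illustration."""

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        # Raw model output: the length of the input text.
        return {'length': len(input['text'])}

    def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        # Rename the raw key to the name downstream code is assumed to expect.
        return {'output': input['length']}


model = EchoLengthModel(model_dir='unused-in-this-sketch')
result = model({'text': 'hello'})

# The abstract base cannot be instantiated because forward is not implemented.
try:
    Model(model_dir='unused-in-this-sketch')
    base_is_instantiable = True
except TypeError:
    base_is_instantiable = False
```

Keeping `postprocess` as a non-abstract method with a pass-through default means simple models only need to write `forward`, while models whose raw output needs renaming or conversion can override it without touching `__call__`.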