mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 03:59:23 +01:00
Merge remote-tracking branch 'origin' into ofa/finetune
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
repos:
|
||||
- repo: https://gitlab.com/pycqa/flake8.git
|
||||
rev: 3.8.3
|
||||
rev: 4.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
exclude: thirdparty/|examples/
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
repos:
|
||||
- repo: /home/admin/pre-commit/flake8
|
||||
rev: 3.8.3
|
||||
rev: 4.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
exclude: thirdparty/|examples/
|
||||
|
||||
3
data/test/audios/asr_example_8K.wav
Normal file
3
data/test/audios/asr_example_8K.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e999c247bfebb03d556a31722f0ce7145cac20a67fac9da813ad336e1f549f9f
|
||||
size 38954
|
||||
3
data/test/audios/asr_example_cn_dialect.wav
Normal file
3
data/test/audios/asr_example_cn_dialect.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:32eb8d4d537941bf0edea69cd6723e8ba489fa3df64e13e29f96e4fae0b856f4
|
||||
size 93676
|
||||
3
data/test/audios/asr_example_cn_en.wav
Normal file
3
data/test/audios/asr_example_cn_en.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f57aee13ade70be6b2c6e4f5e5c7404bdb03057b63828baefbaadcf23855a4cb
|
||||
size 472012
|
||||
3
data/test/audios/asr_example_en.wav
Normal file
3
data/test/audios/asr_example_en.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fee8e0460ca707f108782be0d93c555bf34fb6b1cb297e5fceed70192cc65f9b
|
||||
size 71244
|
||||
3
data/test/audios/asr_example_es.wav
Normal file
3
data/test/audios/asr_example_es.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:450e31f9df8c5b48c617900625f01cb64c484f079a9843179fe9feaa7d163e61
|
||||
size 181964
|
||||
3
data/test/audios/asr_example_id.wav
Normal file
3
data/test/audios/asr_example_id.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:255494c41bc1dfb0c954d827ec6ce775900e4f7a55fb0a7881bdf9d66a03b425
|
||||
size 112078
|
||||
3
data/test/audios/asr_example_ja.wav
Normal file
3
data/test/audios/asr_example_ja.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:22a55277908bbc3ef60a0cf56b230eb507b9e837574e8f493e93644b1d21c281
|
||||
size 200556
|
||||
3
data/test/audios/asr_example_ko.wav
Normal file
3
data/test/audios/asr_example_ko.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ee92191836c76412463d8b282a7ab4e1aa57386ba699ec011a3e2c4d64f32f4b
|
||||
size 162636
|
||||
3
data/test/audios/asr_example_ru.wav
Normal file
3
data/test/audios/asr_example_ru.wav
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:77d1537fc584c1505d8aa10ec8c86af57ab661199e4f28fd7ffee3c22d1e4e61
|
||||
size 160204
|
||||
3
data/test/regression/sbert-base-tnews.bin
Normal file
3
data/test/regression/sbert-base-tnews.bin
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2bce1341f4b55d536771dad6e2b280458579f46c3216474ceb8a926022ab53d0
|
||||
size 151572
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:44e3925c15d86d8596baeb6bd1d153d86f57b7489798b2cf988a1248e110fd62
|
||||
size 62231
|
||||
oid sha256:6af5024a26337a440c7ea2935fce84af558dd982ee97a2f027bb922cc874292b
|
||||
size 61741
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1ff17a0272752de4c88d4254b2e881f97f8ef022f03609d03ee1de0ae964368a
|
||||
size 62235
|
||||
oid sha256:bbce084781342ca7274c2e4d02ed5c5de43ba213a3b76328d5994404d6544c41
|
||||
size 61745
|
||||
|
||||
@@ -23,12 +23,14 @@ class SbertForSequenceClassificationExporter(TorchModelExporter):
|
||||
|
||||
def generate_dummy_inputs(self,
|
||||
shape: Tuple = None,
|
||||
pair: bool = False,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
"""Generate dummy inputs for model exportation to onnx or other formats by tracing.
|
||||
|
||||
@param shape: A tuple of input shape which should have at most two dimensions.
|
||||
shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor.
|
||||
shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor.
|
||||
@param pair: Generate sentence pairs or single sentences for dummy inputs.
|
||||
@return: Dummy inputs.
|
||||
"""
|
||||
|
||||
@@ -55,7 +57,7 @@ class SbertForSequenceClassificationExporter(TorchModelExporter):
|
||||
**sequence_length
|
||||
})
|
||||
preprocessor: Preprocessor = build_preprocessor(cfg, field_name)
|
||||
if preprocessor.pair:
|
||||
if pair:
|
||||
first_sequence = preprocessor.tokenizer.unk_token
|
||||
second_sequence = preprocessor.tokenizer.unk_token
|
||||
else:
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
# yapf: disable
|
||||
import datetime
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
from http.cookiejar import CookieJar
|
||||
@@ -16,17 +19,25 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
|
||||
API_RESPONSE_FIELD_GIT_ACCESS_TOKEN,
|
||||
API_RESPONSE_FIELD_MESSAGE,
|
||||
API_RESPONSE_FIELD_USERNAME,
|
||||
DEFAULT_CREDENTIALS_PATH)
|
||||
DEFAULT_CREDENTIALS_PATH, Licenses,
|
||||
ModelVisibility)
|
||||
from modelscope.hub.errors import (InvalidParameter, NotExistError,
|
||||
NotLoginException, RequestError,
|
||||
datahub_raise_on_error,
|
||||
handle_http_post_error,
|
||||
handle_http_response, is_ok, raise_on_error)
|
||||
from modelscope.hub.git import GitCommandWrapper
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.hub.utils.utils import (get_endpoint,
|
||||
model_id_to_group_owner_name)
|
||||
from modelscope.utils.config_ds import DOWNLOADED_DATASETS_PATH
|
||||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||
DEFAULT_MODEL_REVISION,
|
||||
DatasetFormations, DatasetMetaFormats,
|
||||
DownloadMode)
|
||||
DownloadMode, ModelFile)
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .errors import (InvalidParameter, NotExistError, RequestError,
|
||||
datahub_raise_on_error, handle_http_post_error,
|
||||
handle_http_response, is_ok, raise_on_error)
|
||||
from .utils.utils import get_endpoint, model_id_to_group_owner_name
|
||||
|
||||
# yapf: enable
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@@ -169,11 +180,106 @@ class HubApi:
|
||||
else:
|
||||
r.raise_for_status()
|
||||
|
||||
def list_model(self,
|
||||
owner_or_group: str,
|
||||
page_number=1,
|
||||
page_size=10) -> dict:
|
||||
"""List model in owner or group.
|
||||
def push_model(self,
|
||||
model_id: str,
|
||||
model_dir: str,
|
||||
visibility: int = ModelVisibility.PUBLIC,
|
||||
license: str = Licenses.APACHE_V2,
|
||||
chinese_name: Optional[str] = None,
|
||||
commit_message: Optional[str] = 'upload model',
|
||||
revision: Optional[str] = DEFAULT_MODEL_REVISION):
|
||||
"""
|
||||
Upload model from a given directory to given repository. A valid model directory
|
||||
must contain a configuration.json file.
|
||||
|
||||
This function upload the files in given directory to given repository. If the
|
||||
given repository is not exists in remote, it will automatically create it with
|
||||
given visibility, license and chinese_name parameters. If the revision is also
|
||||
not exists in remote repository, it will create a new branch for it.
|
||||
|
||||
This function must be called before calling HubApi's login with a valid token
|
||||
which can be obtained from ModelScope's website.
|
||||
|
||||
Args:
|
||||
model_id (`str`):
|
||||
The model id to be uploaded, caller must have write permission for it.
|
||||
model_dir(`str`):
|
||||
The Absolute Path of the finetune result.
|
||||
visibility(`int`, defaults to `0`):
|
||||
Visibility of the new created model(1-private, 5-public). If the model is
|
||||
not exists in ModelScope, this function will create a new model with this
|
||||
visibility and this parameter is required. You can ignore this parameter
|
||||
if you make sure the model's existence.
|
||||
license(`str`, defaults to `None`):
|
||||
License of the new created model(see License). If the model is not exists
|
||||
in ModelScope, this function will create a new model with this license
|
||||
and this parameter is required. You can ignore this parameter if you
|
||||
make sure the model's existence.
|
||||
chinese_name(`str`, *optional*, defaults to `None`):
|
||||
chinese name of the new created model.
|
||||
commit_message(`str`, *optional*, defaults to `None`):
|
||||
commit message of the push request.
|
||||
revision (`str`, *optional*, default to DEFAULT_MODEL_REVISION):
|
||||
which branch to push. If the branch is not exists, It will create a new
|
||||
branch and push to it.
|
||||
"""
|
||||
if model_id is None:
|
||||
raise InvalidParameter('model_id cannot be empty!')
|
||||
if model_dir is None:
|
||||
raise InvalidParameter('model_dir cannot be empty!')
|
||||
if not os.path.exists(model_dir) or os.path.isfile(model_dir):
|
||||
raise InvalidParameter('model_dir must be a valid directory.')
|
||||
cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
|
||||
if not os.path.exists(cfg_file):
|
||||
raise ValueError(f'{model_dir} must contain a configuration.json.')
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
if cookies is None:
|
||||
raise NotLoginException('Must login before upload!')
|
||||
files_to_save = os.listdir(model_dir)
|
||||
try:
|
||||
self.get_model(model_id=model_id)
|
||||
except Exception:
|
||||
if visibility is None or license is None:
|
||||
raise InvalidParameter(
|
||||
'visibility and license cannot be empty if want to create new repo'
|
||||
)
|
||||
logger.info('Create new model %s' % model_id)
|
||||
self.create_model(
|
||||
model_id=model_id,
|
||||
visibility=visibility,
|
||||
license=license,
|
||||
chinese_name=chinese_name)
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
git_wrapper = GitCommandWrapper()
|
||||
try:
|
||||
repo = Repository(model_dir=tmp_dir, clone_from=model_id)
|
||||
branches = git_wrapper.get_remote_branches(tmp_dir)
|
||||
if revision not in branches:
|
||||
logger.info('Create new branch %s' % revision)
|
||||
git_wrapper.new_branch(tmp_dir, revision)
|
||||
git_wrapper.checkout(tmp_dir, revision)
|
||||
for f in files_to_save:
|
||||
if f[0] != '.':
|
||||
src = os.path.join(model_dir, f)
|
||||
if os.path.isdir(src):
|
||||
shutil.copytree(src, os.path.join(tmp_dir, f))
|
||||
else:
|
||||
shutil.copy(src, tmp_dir)
|
||||
if not commit_message:
|
||||
date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
|
||||
commit_message = '[automsg] push model %s to hub at %s' % (
|
||||
model_id, date)
|
||||
repo.push(commit_message=commit_message, branch=revision)
|
||||
except Exception:
|
||||
raise
|
||||
finally:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
|
||||
def list_models(self,
|
||||
owner_or_group: str,
|
||||
page_number=1,
|
||||
page_size=10) -> dict:
|
||||
"""List models in owner or group.
|
||||
|
||||
Args:
|
||||
owner_or_group(`str`): owner or group.
|
||||
@@ -390,11 +496,13 @@ class HubApi:
|
||||
return resp['Data']
|
||||
|
||||
def list_oss_dataset_objects(self, dataset_name, namespace, max_limit,
|
||||
is_recursive, is_filter_dir, revision,
|
||||
cookies):
|
||||
is_recursive, is_filter_dir, revision):
|
||||
url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \
|
||||
f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}'
|
||||
cookies = requests.utils.dict_from_cookiejar(cookies)
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
if cookies:
|
||||
cookies = requests.utils.dict_from_cookiejar(cookies)
|
||||
|
||||
resp = requests.get(url=url, cookies=cookies)
|
||||
resp = resp.json()
|
||||
|
||||
@@ -11,13 +11,12 @@ from typing import Dict, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
from filelock import FileLock
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope import __version__
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .api import HubApi, ModelScopeConfig
|
||||
from .constants import FILE_HASH
|
||||
from .errors import FileDownloadError, NotExistError
|
||||
from .utils.caching import ModelFileSystemCache
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from typing import List
|
||||
from xmlrpc.client import Boolean
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .api import ModelScopeConfig
|
||||
from .errors import GitError
|
||||
|
||||
logger = get_logger()
|
||||
@@ -132,6 +129,7 @@ class GitCommandWrapper(metaclass=Singleton):
|
||||
return response
|
||||
|
||||
def add_user_info(self, repo_base_dir, repo_name):
|
||||
from modelscope.hub.api import ModelScopeConfig
|
||||
user_name, user_email = ModelScopeConfig.get_user_info()
|
||||
if user_name and user_email:
|
||||
# config user.name and user.email if exist
|
||||
@@ -184,8 +182,11 @@ class GitCommandWrapper(metaclass=Singleton):
|
||||
info = [
|
||||
line.strip()
|
||||
for line in rsp.stdout.decode('utf8').strip().split(os.linesep)
|
||||
][1:]
|
||||
return ['/'.join(line.split('/')[1:]) for line in info]
|
||||
]
|
||||
if len(info) == 1:
|
||||
return ['/'.join(info[0].split('/')[1:])]
|
||||
else:
|
||||
return ['/'.join(line.split('/')[1:]) for line in info[1:]]
|
||||
|
||||
def pull(self, repo_dir: str):
|
||||
cmds = ['-C', repo_dir, 'pull']
|
||||
|
||||
@@ -7,7 +7,6 @@ from modelscope.hub.errors import GitError, InvalidParameter, NotLoginException
|
||||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||
DEFAULT_MODEL_REVISION)
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .api import ModelScopeConfig
|
||||
from .git import GitCommandWrapper
|
||||
from .utils.utils import get_endpoint
|
||||
|
||||
@@ -47,6 +46,7 @@ class Repository:
|
||||
err_msg = 'a non-default value of revision cannot be empty.'
|
||||
raise InvalidParameter(err_msg)
|
||||
|
||||
from modelscope.hub.api import ModelScopeConfig
|
||||
if auth_token:
|
||||
self.auth_token = auth_token
|
||||
else:
|
||||
@@ -166,7 +166,7 @@ class DatasetRepository:
|
||||
err_msg = 'a non-default value of revision cannot be empty.'
|
||||
raise InvalidParameter(err_msg)
|
||||
self.revision = revision
|
||||
|
||||
from modelscope.hub.api import ModelScopeConfig
|
||||
if auth_token:
|
||||
self.auth_token = auth_token
|
||||
else:
|
||||
|
||||
@@ -5,9 +5,9 @@ import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .api import HubApi, ModelScopeConfig
|
||||
from .constants import FILE_HASH
|
||||
from .errors import NotExistError
|
||||
from .file_download import (get_file_download_url, http_get_file,
|
||||
|
||||
@@ -1,117 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import datetime
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from typing import Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from filelock import FileLock
|
||||
|
||||
from modelscope import __version__
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.hub.errors import InvalidParameter, NotLoginException
|
||||
from modelscope.hub.git import GitCommandWrapper
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def upload_folder(model_id: str,
|
||||
model_dir: str,
|
||||
visibility: int = 0,
|
||||
license: str = None,
|
||||
chinese_name: Optional[str] = None,
|
||||
commit_message: Optional[str] = None,
|
||||
revision: Optional[str] = DEFAULT_MODEL_REVISION):
|
||||
"""
|
||||
Upload model from a given directory to given repository. A valid model directory
|
||||
must contain a configuration.json file.
|
||||
|
||||
This function upload the files in given directory to given repository. If the
|
||||
given repository is not exists in remote, it will automatically create it with
|
||||
given visibility, license and chinese_name parameters. If the revision is also
|
||||
not exists in remote repository, it will create a new branch for it.
|
||||
|
||||
This function must be called before calling HubApi's login with a valid token
|
||||
which can be obtained from ModelScope's website.
|
||||
|
||||
Args:
|
||||
model_id (`str`):
|
||||
The model id to be uploaded, caller must have write permission for it.
|
||||
model_dir(`str`):
|
||||
The Absolute Path of the finetune result.
|
||||
visibility(`int`, defaults to `0`):
|
||||
Visibility of the new created model(1-private, 5-public). If the model is
|
||||
not exists in ModelScope, this function will create a new model with this
|
||||
visibility and this parameter is required. You can ignore this parameter
|
||||
if you make sure the model's existence.
|
||||
license(`str`, defaults to `None`):
|
||||
License of the new created model(see License). If the model is not exists
|
||||
in ModelScope, this function will create a new model with this license
|
||||
and this parameter is required. You can ignore this parameter if you
|
||||
make sure the model's existence.
|
||||
chinese_name(`str`, *optional*, defaults to `None`):
|
||||
chinese name of the new created model.
|
||||
commit_message(`str`, *optional*, defaults to `None`):
|
||||
commit message of the push request.
|
||||
revision (`str`, *optional*, default to DEFAULT_MODEL_REVISION):
|
||||
which branch to push. If the branch is not exists, It will create a new
|
||||
branch and push to it.
|
||||
"""
|
||||
if model_id is None:
|
||||
raise InvalidParameter('model_id cannot be empty!')
|
||||
if model_dir is None:
|
||||
raise InvalidParameter('model_dir cannot be empty!')
|
||||
if not os.path.exists(model_dir) or os.path.isfile(model_dir):
|
||||
raise InvalidParameter('model_dir must be a valid directory.')
|
||||
cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
|
||||
if not os.path.exists(cfg_file):
|
||||
raise ValueError(f'{model_dir} must contain a configuration.json.')
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
if cookies is None:
|
||||
raise NotLoginException('Must login before upload!')
|
||||
files_to_save = os.listdir(model_dir)
|
||||
api = HubApi()
|
||||
try:
|
||||
api.get_model(model_id=model_id)
|
||||
except Exception:
|
||||
if visibility is None or license is None:
|
||||
raise InvalidParameter(
|
||||
'visibility and license cannot be empty if want to create new repo'
|
||||
)
|
||||
logger.info('Create new model %s' % model_id)
|
||||
api.create_model(
|
||||
model_id=model_id,
|
||||
visibility=visibility,
|
||||
license=license,
|
||||
chinese_name=chinese_name)
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
git_wrapper = GitCommandWrapper()
|
||||
try:
|
||||
repo = Repository(model_dir=tmp_dir, clone_from=model_id)
|
||||
branches = git_wrapper.get_remote_branches(tmp_dir)
|
||||
if revision not in branches:
|
||||
logger.info('Create new branch %s' % revision)
|
||||
git_wrapper.new_branch(tmp_dir, revision)
|
||||
git_wrapper.checkout(tmp_dir, revision)
|
||||
for f in files_to_save:
|
||||
if f[0] != '.':
|
||||
src = os.path.join(model_dir, f)
|
||||
if os.path.isdir(src):
|
||||
shutil.copytree(src, os.path.join(tmp_dir, f))
|
||||
else:
|
||||
shutil.copy(src, tmp_dir)
|
||||
if not commit_message:
|
||||
date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
|
||||
commit_message = '[automsg] push model %s to hub at %s' % (
|
||||
model_id, date)
|
||||
repo.push(commit_message=commit_message, branch=revision)
|
||||
except Exception:
|
||||
raise
|
||||
finally:
|
||||
shutil.rmtree(tmp_dir, ignore_errors=True)
|
||||
@@ -9,7 +9,9 @@ class Models(object):
|
||||
|
||||
Model name should only contain model info but not task info.
|
||||
"""
|
||||
# tinynas models
|
||||
tinynas_detection = 'tinynas-detection'
|
||||
tinynas_damoyolo = 'tinynas-damoyolo'
|
||||
|
||||
# vision models
|
||||
detection = 'detection'
|
||||
@@ -454,9 +456,9 @@ class Datasets(object):
|
||||
""" Names for different datasets.
|
||||
"""
|
||||
ClsDataset = 'ClsDataset'
|
||||
Face2dKeypointsDataset = 'Face2dKeypointsDataset'
|
||||
Face2dKeypointsDataset = 'FaceKeypointDataset'
|
||||
HandCocoWholeBodyDataset = 'HandCocoWholeBodyDataset'
|
||||
HumanWholeBodyKeypointDataset = 'HumanWholeBodyKeypointDataset'
|
||||
HumanWholeBodyKeypointDataset = 'WholeBodyCocoTopDownDataset'
|
||||
SegDataset = 'SegDataset'
|
||||
DetDataset = 'DetDataset'
|
||||
DetImagesMixDataset = 'DetImagesMixDataset'
|
||||
|
||||
@@ -32,6 +32,7 @@ task_default_metrics = {
|
||||
Tasks.sentiment_classification: [Metrics.seq_cls_metric],
|
||||
Tasks.token_classification: [Metrics.token_cls_metric],
|
||||
Tasks.text_generation: [Metrics.text_gen_metric],
|
||||
Tasks.text_classification: [Metrics.seq_cls_metric],
|
||||
Tasks.image_denoising: [Metrics.image_denoise_metric],
|
||||
Tasks.image_color_enhancement: [Metrics.image_color_enhance_metric],
|
||||
Tasks.image_portrait_enhancement:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
import pickle as pkl
|
||||
from threading import Lock
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
@@ -27,6 +28,7 @@ class Voice:
|
||||
self.__am_config = AttrDict(**am_config)
|
||||
self.__voc_config = AttrDict(**voc_config)
|
||||
self.__model_loaded = False
|
||||
self.__lock = Lock()
|
||||
if 'am' not in self.__am_config:
|
||||
raise TtsModelConfigurationException(
|
||||
'modelscope error: am configuration invalid')
|
||||
@@ -71,34 +73,35 @@ class Voice:
|
||||
self.__generator.remove_weight_norm()
|
||||
|
||||
def __am_forward(self, symbol_seq):
|
||||
with torch.no_grad():
|
||||
inputs_feat_lst = self.__ling_unit.encode_symbol_sequence(
|
||||
symbol_seq)
|
||||
inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to(
|
||||
self.__device)
|
||||
inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to(
|
||||
self.__device)
|
||||
inputs_syllable = torch.from_numpy(inputs_feat_lst[2]).long().to(
|
||||
self.__device)
|
||||
inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to(
|
||||
self.__device)
|
||||
inputs_ling = torch.stack(
|
||||
[inputs_sy, inputs_tone, inputs_syllable, inputs_ws],
|
||||
dim=-1).unsqueeze(0)
|
||||
inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to(
|
||||
self.__device).unsqueeze(0)
|
||||
inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to(
|
||||
self.__device).unsqueeze(0)
|
||||
inputs_len = torch.zeros(1).to(self.__device).long(
|
||||
) + inputs_emo.size(1) - 1 # minus 1 for "~"
|
||||
res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1],
|
||||
inputs_spk[:, :-1], inputs_len)
|
||||
postnet_outputs = res['postnet_outputs']
|
||||
LR_length_rounded = res['LR_length_rounded']
|
||||
valid_length = int(LR_length_rounded[0].item())
|
||||
postnet_outputs = postnet_outputs[
|
||||
0, :valid_length, :].cpu().numpy()
|
||||
return postnet_outputs
|
||||
with self.__lock:
|
||||
with torch.no_grad():
|
||||
inputs_feat_lst = self.__ling_unit.encode_symbol_sequence(
|
||||
symbol_seq)
|
||||
inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to(
|
||||
self.__device)
|
||||
inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to(
|
||||
self.__device)
|
||||
inputs_syllable = torch.from_numpy(
|
||||
inputs_feat_lst[2]).long().to(self.__device)
|
||||
inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to(
|
||||
self.__device)
|
||||
inputs_ling = torch.stack(
|
||||
[inputs_sy, inputs_tone, inputs_syllable, inputs_ws],
|
||||
dim=-1).unsqueeze(0)
|
||||
inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to(
|
||||
self.__device).unsqueeze(0)
|
||||
inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to(
|
||||
self.__device).unsqueeze(0)
|
||||
inputs_len = torch.zeros(1).to(self.__device).long(
|
||||
) + inputs_emo.size(1) - 1 # minus 1 for "~"
|
||||
res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1],
|
||||
inputs_spk[:, :-1], inputs_len)
|
||||
postnet_outputs = res['postnet_outputs']
|
||||
LR_length_rounded = res['LR_length_rounded']
|
||||
valid_length = int(LR_length_rounded[0].item())
|
||||
postnet_outputs = postnet_outputs[
|
||||
0, :valid_length, :].cpu().numpy()
|
||||
return postnet_outputs
|
||||
|
||||
def __vocoder_forward(self, melspec):
|
||||
dim0 = list(melspec.shape)[-1]
|
||||
@@ -118,14 +121,15 @@ class Voice:
|
||||
return audio
|
||||
|
||||
def forward(self, symbol_seq):
|
||||
if not self.__model_loaded:
|
||||
torch.manual_seed(self.__am_config.seed)
|
||||
if torch.cuda.is_available():
|
||||
with self.__lock:
|
||||
if not self.__model_loaded:
|
||||
torch.manual_seed(self.__am_config.seed)
|
||||
self.__device = torch.device('cuda')
|
||||
else:
|
||||
self.__device = torch.device('cpu')
|
||||
self.__load_am()
|
||||
self.__load_vocoder()
|
||||
self.__model_loaded = True
|
||||
if torch.cuda.is_available():
|
||||
torch.manual_seed(self.__am_config.seed)
|
||||
self.__device = torch.device('cuda')
|
||||
else:
|
||||
self.__device = torch.device('cpu')
|
||||
self.__load_am()
|
||||
self.__load_vocoder()
|
||||
self.__model_loaded = True
|
||||
return self.__vocoder_forward(self.__am_forward(symbol_seq))
|
||||
|
||||
@@ -93,7 +93,7 @@ class TextDrivenSeg(TorchModel):
|
||||
"""
|
||||
with torch.no_grad():
|
||||
if self.device_id == -1:
|
||||
output = self.model(image)
|
||||
output = self.model(image, [text])
|
||||
else:
|
||||
device = torch.device('cuda', self.device_id)
|
||||
output = self.model(image.to(device), [text])
|
||||
|
||||
@@ -7,10 +7,12 @@ from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .tinynas_detector import Tinynas_detector
|
||||
from .tinynas_damoyolo import DamoYolo
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'tinynas_detector': ['TinynasDetector'],
|
||||
'tinynas_damoyolo': ['DamoYolo'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modelscope.utils.file_utils import read_file
|
||||
from ..core.base_ops import Focus, SPPBottleneck, get_activation
|
||||
from ..core.repvgg_block import RepVggBlock
|
||||
|
||||
@@ -49,12 +50,16 @@ class ResConvK1KX(nn.Module):
|
||||
kernel_size,
|
||||
stride,
|
||||
force_resproj=False,
|
||||
act='silu'):
|
||||
act='silu',
|
||||
reparam=False):
|
||||
super(ResConvK1KX, self).__init__()
|
||||
self.stride = stride
|
||||
self.conv1 = ConvKXBN(in_c, btn_c, 1, 1)
|
||||
self.conv2 = RepVggBlock(
|
||||
btn_c, out_c, kernel_size, stride, act='identity')
|
||||
if not reparam:
|
||||
self.conv2 = ConvKXBN(btn_c, out_c, 3, stride)
|
||||
else:
|
||||
self.conv2 = RepVggBlock(
|
||||
btn_c, out_c, kernel_size, stride, act='identity')
|
||||
|
||||
if act is None:
|
||||
self.activation_function = torch.relu
|
||||
@@ -97,7 +102,8 @@ class SuperResConvK1KX(nn.Module):
|
||||
stride,
|
||||
num_blocks,
|
||||
with_spp=False,
|
||||
act='silu'):
|
||||
act='silu',
|
||||
reparam=False):
|
||||
super(SuperResConvK1KX, self).__init__()
|
||||
if act is None:
|
||||
self.act = torch.relu
|
||||
@@ -124,7 +130,8 @@ class SuperResConvK1KX(nn.Module):
|
||||
this_kernel_size,
|
||||
this_stride,
|
||||
force_resproj,
|
||||
act=act)
|
||||
act=act,
|
||||
reparam=reparam)
|
||||
self.block_list.append(the_block)
|
||||
if block_id == 0 and with_spp:
|
||||
self.block_list.append(
|
||||
@@ -248,7 +255,8 @@ class TinyNAS(nn.Module):
|
||||
with_spp=False,
|
||||
use_focus=False,
|
||||
need_conv1=True,
|
||||
act='silu'):
|
||||
act='silu',
|
||||
reparam=False):
|
||||
super(TinyNAS, self).__init__()
|
||||
assert len(out_indices) == len(out_channels)
|
||||
self.out_indices = out_indices
|
||||
@@ -281,7 +289,8 @@ class TinyNAS(nn.Module):
|
||||
block_info['s'],
|
||||
block_info['L'],
|
||||
spp,
|
||||
act=act)
|
||||
act=act,
|
||||
reparam=reparam)
|
||||
self.block_list.append(the_block)
|
||||
elif the_block_class == 'SuperResConvKXKX':
|
||||
spp = with_spp if idx == len(structure_info) - 1 else False
|
||||
@@ -325,8 +334,8 @@ class TinyNAS(nn.Module):
|
||||
def load_tinynas_net(backbone_cfg):
|
||||
# load masternet model to path
|
||||
import ast
|
||||
|
||||
struct_str = ''.join([x.strip() for x in backbone_cfg.net_structure_str])
|
||||
net_structure_str = read_file(backbone_cfg.structure_file)
|
||||
struct_str = ''.join([x.strip() for x in net_structure_str])
|
||||
struct_info = ast.literal_eval(struct_str)
|
||||
for layer in struct_info:
|
||||
if 'nbitsA' in layer:
|
||||
@@ -342,6 +351,6 @@ def load_tinynas_net(backbone_cfg):
|
||||
use_focus=backbone_cfg.use_focus,
|
||||
act=backbone_cfg.act,
|
||||
need_conv1=backbone_cfg.need_conv1,
|
||||
)
|
||||
reparam=backbone_cfg.reparam)
|
||||
|
||||
return model
|
||||
|
||||
@@ -30,7 +30,7 @@ class SingleStageDetector(TorchModel):
|
||||
"""
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
|
||||
config_path = osp.join(model_dir, 'airdet_s.py')
|
||||
config_path = osp.join(model_dir, self.config_name)
|
||||
config = parse_config(config_path)
|
||||
self.cfg = config
|
||||
model_path = osp.join(model_dir, config.model.name)
|
||||
@@ -41,6 +41,9 @@ class SingleStageDetector(TorchModel):
|
||||
self.conf_thre = config.model.head.nms_conf_thre
|
||||
self.nms_thre = config.model.head.nms_iou_thre
|
||||
|
||||
if self.cfg.model.backbone.name == 'TinyNAS':
|
||||
self.cfg.model.backbone.structure_file = osp.join(
|
||||
model_dir, self.cfg.model.backbone.structure_file)
|
||||
self.backbone = build_backbone(self.cfg.model.backbone)
|
||||
self.neck = build_neck(self.cfg.model.neck)
|
||||
self.head = build_head(self.cfg.model.head)
|
||||
|
||||
@@ -124,11 +124,13 @@ class GFocalHead_Tiny(nn.Module):
|
||||
simOTA_iou_weight=3.0,
|
||||
octbase=8,
|
||||
simlqe=False,
|
||||
use_lqe=True,
|
||||
**kwargs):
|
||||
self.simlqe = simlqe
|
||||
self.num_classes = num_classes
|
||||
self.in_channels = in_channels
|
||||
self.strides = strides
|
||||
self.use_lqe = use_lqe
|
||||
self.feat_channels = feat_channels if isinstance(feat_channels, list) \
|
||||
else [feat_channels] * len(self.strides)
|
||||
|
||||
@@ -181,15 +183,20 @@ class GFocalHead_Tiny(nn.Module):
|
||||
groups=self.conv_groups,
|
||||
norm=self.norm,
|
||||
act=self.act))
|
||||
if not self.simlqe:
|
||||
conf_vector = [nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)]
|
||||
if self.use_lqe:
|
||||
if not self.simlqe:
|
||||
conf_vector = [
|
||||
nn.Conv2d(4 * self.total_dim, self.reg_channels, 1)
|
||||
]
|
||||
else:
|
||||
conf_vector = [
|
||||
nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1)
|
||||
]
|
||||
conf_vector += [self.relu]
|
||||
conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()]
|
||||
reg_conf = nn.Sequential(*conf_vector)
|
||||
else:
|
||||
conf_vector = [
|
||||
nn.Conv2d(4 * (self.reg_max + 1), self.reg_channels, 1)
|
||||
]
|
||||
conf_vector += [self.relu]
|
||||
conf_vector += [nn.Conv2d(self.reg_channels, 1, 1), nn.Sigmoid()]
|
||||
reg_conf = nn.Sequential(*conf_vector)
|
||||
reg_conf = None
|
||||
|
||||
return cls_convs, reg_convs, reg_conf
|
||||
|
||||
@@ -290,21 +297,27 @@ class GFocalHead_Tiny(nn.Module):
|
||||
N, C, H, W = bbox_pred.size()
|
||||
prob = F.softmax(
|
||||
bbox_pred.reshape(N, 4, self.reg_max + 1, H, W), dim=2)
|
||||
if not self.simlqe:
|
||||
prob_topk, _ = prob.topk(self.reg_topk, dim=2)
|
||||
if self.use_lqe:
|
||||
if not self.simlqe:
|
||||
prob_topk, _ = prob.topk(self.reg_topk, dim=2)
|
||||
|
||||
if self.add_mean:
|
||||
stat = torch.cat(
|
||||
[prob_topk, prob_topk.mean(dim=2, keepdim=True)], dim=2)
|
||||
if self.add_mean:
|
||||
stat = torch.cat(
|
||||
[prob_topk,
|
||||
prob_topk.mean(dim=2, keepdim=True)],
|
||||
dim=2)
|
||||
else:
|
||||
stat = prob_topk
|
||||
|
||||
quality_score = reg_conf(
|
||||
stat.reshape(N, 4 * self.total_dim, H, W))
|
||||
else:
|
||||
stat = prob_topk
|
||||
quality_score = reg_conf(
|
||||
bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W))
|
||||
|
||||
quality_score = reg_conf(stat.reshape(N, 4 * self.total_dim, H, W))
|
||||
cls_score = gfl_cls(cls_feat).sigmoid() * quality_score
|
||||
else:
|
||||
quality_score = reg_conf(
|
||||
bbox_pred.reshape(N, 4 * (self.reg_max + 1), H, W))
|
||||
|
||||
cls_score = gfl_cls(cls_feat).sigmoid() * quality_score
|
||||
cls_score = gfl_cls(cls_feat).sigmoid()
|
||||
|
||||
flatten_cls_score = cls_score.flatten(start_dim=2).transpose(1, 2)
|
||||
flatten_bbox_pred = bbox_pred.flatten(start_dim=2).transpose(1, 2)
|
||||
|
||||
@@ -14,7 +14,6 @@ class GiraffeNeckV2(nn.Module):
|
||||
self,
|
||||
depth=1.0,
|
||||
width=1.0,
|
||||
in_features=[2, 3, 4],
|
||||
in_channels=[256, 512, 1024],
|
||||
out_channels=[256, 512, 1024],
|
||||
depthwise=False,
|
||||
@@ -24,7 +23,6 @@ class GiraffeNeckV2(nn.Module):
|
||||
block_name='BasicBlock',
|
||||
):
|
||||
super().__init__()
|
||||
self.in_features = in_features
|
||||
self.in_channels = in_channels
|
||||
Conv = DWConv if depthwise else BaseConv
|
||||
|
||||
@@ -169,8 +167,7 @@ class GiraffeNeckV2(nn.Module):
|
||||
"""
|
||||
|
||||
# backbone
|
||||
features = [out_features[f] for f in self.in_features]
|
||||
[x2, x1, x0] = features
|
||||
[x2, x1, x0] = out_features
|
||||
|
||||
# node x3
|
||||
x13 = self.bu_conv13(x1)
|
||||
|
||||
15
modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py
Normal file
15
modelscope/models/cv/tinynas_detection/tinynas_damoyolo.py
Normal file
@@ -0,0 +1,15 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import Tasks
|
||||
from .detector import SingleStageDetector
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.image_object_detection, module_name=Models.tinynas_damoyolo)
|
||||
class DamoYolo(SingleStageDetector):
|
||||
|
||||
def __init__(self, model_dir, *args, **kwargs):
|
||||
self.config_name = 'damoyolo_s.py'
|
||||
super(DamoYolo, self).__init__(model_dir, *args, **kwargs)
|
||||
@@ -12,5 +12,5 @@ from .detector import SingleStageDetector
|
||||
class TinynasDetector(SingleStageDetector):
|
||||
|
||||
def __init__(self, model_dir, *args, **kwargs):
|
||||
|
||||
self.config_name = 'airdet_s.py'
|
||||
super(TinynasDetector, self).__init__(model_dir, *args, **kwargs)
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
"""PyTorch BERT model. """
|
||||
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
@@ -41,7 +40,6 @@ from transformers.modeling_utils import (PreTrainedModel,
|
||||
find_pruneable_heads_and_indices,
|
||||
prune_linear_layer)
|
||||
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .configuration_bert import BertConfig
|
||||
|
||||
@@ -50,81 +48,6 @@ logger = get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = 'BertConfig'
|
||||
|
||||
|
||||
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
||||
"""Load tf checkpoints in a pytorch model."""
|
||||
try:
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
logger.error(
|
||||
'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see '
|
||||
'https://www.tensorflow.org/install/ for installation instructions.'
|
||||
)
|
||||
raise
|
||||
tf_path = os.path.abspath(tf_checkpoint_path)
|
||||
logger.info(f'Converting TensorFlow checkpoint from {tf_path}')
|
||||
# Load weights from TF model
|
||||
init_vars = tf.train.list_variables(tf_path)
|
||||
names = []
|
||||
arrays = []
|
||||
for name, shape in init_vars:
|
||||
logger.info(f'Loading TF weight {name} with shape {shape}')
|
||||
array = tf.train.load_variable(tf_path, name)
|
||||
names.append(name)
|
||||
arrays.append(array)
|
||||
|
||||
for name, array in zip(names, arrays):
|
||||
name = name.split('/')
|
||||
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
|
||||
# which are not required for using pretrained model
|
||||
if any(n in [
|
||||
'adam_v', 'adam_m', 'AdamWeightDecayOptimizer',
|
||||
'AdamWeightDecayOptimizer_1', 'global_step'
|
||||
] for n in name):
|
||||
logger.info(f"Skipping {'/'.join(name)}")
|
||||
continue
|
||||
pointer = model
|
||||
for m_name in name:
|
||||
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
|
||||
scope_names = re.split(r'_(\d+)', m_name)
|
||||
else:
|
||||
scope_names = [m_name]
|
||||
if scope_names[0] == 'kernel' or scope_names[0] == 'gamma':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
|
||||
pointer = getattr(pointer, 'bias')
|
||||
elif scope_names[0] == 'output_weights':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif scope_names[0] == 'squad':
|
||||
pointer = getattr(pointer, 'classifier')
|
||||
else:
|
||||
try:
|
||||
pointer = getattr(pointer, scope_names[0])
|
||||
except AttributeError:
|
||||
logger.info(f"Skipping {'/'.join(name)}")
|
||||
continue
|
||||
if len(scope_names) >= 2:
|
||||
num = int(scope_names[1])
|
||||
pointer = pointer[num]
|
||||
if m_name[-11:] == '_embeddings':
|
||||
pointer = getattr(pointer, 'weight')
|
||||
elif m_name == 'kernel':
|
||||
array = np.transpose(array)
|
||||
try:
|
||||
if pointer.shape != array.shape:
|
||||
raise ValueError(
|
||||
f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched'
|
||||
)
|
||||
except AssertionError as e:
|
||||
e.args += (pointer.shape, array.shape)
|
||||
raise
|
||||
logger.info(f'Initialize PyTorch weight {name}')
|
||||
pointer.data = torch.from_numpy(array)
|
||||
return model
|
||||
|
||||
|
||||
class BertEmbeddings(nn.Module):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
@@ -750,7 +673,6 @@ class BertPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
|
||||
config_class = BertConfig
|
||||
load_tf_weights = load_tf_weights_in_bert
|
||||
base_model_prefix = 'bert'
|
||||
supports_gradient_checkpointing = True
|
||||
_keys_to_ignore_on_load_missing = [r'position_ids']
|
||||
|
||||
@@ -26,11 +26,16 @@ class EasyCVBaseDataset(object):
|
||||
if self.split_config is not None:
|
||||
self._update_data_source(kwargs['data_source'])
|
||||
|
||||
def _update_data_root(self, input_dict, data_root):
|
||||
for k, v in input_dict.items():
|
||||
if isinstance(v, str) and self.DATA_ROOT_PATTERN in v:
|
||||
input_dict.update(
|
||||
{k: v.replace(self.DATA_ROOT_PATTERN, data_root)})
|
||||
elif isinstance(v, dict):
|
||||
self._update_data_root(v, data_root)
|
||||
|
||||
def _update_data_source(self, data_source):
|
||||
data_root = next(iter(self.split_config.values()))
|
||||
data_root = data_root.rstrip(osp.sep)
|
||||
|
||||
for k, v in data_source.items():
|
||||
if isinstance(v, str) and self.DATA_ROOT_PATTERN in v:
|
||||
data_source.update(
|
||||
{k: v.replace(self.DATA_ROOT_PATTERN, data_root)})
|
||||
self._update_data_root(data_source, data_root)
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any, Mapping, Optional, Sequence, Union
|
||||
from datasets.builder import DatasetBuilder
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.utils.constant import DEFAULT_DATASET_REVISION, DownloadParams
|
||||
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .dataset_builder import MsCsvDatasetBuilder, TaskSpecificDatasetBuilder
|
||||
|
||||
@@ -95,15 +95,13 @@ def list_dataset_objects(hub_api: HubApi, max_limit: int, is_recursive: bool,
|
||||
res (list): List of objects, i.e., ['train/images/001.png', 'train/images/002.png', 'val/images/001.png', ...]
|
||||
"""
|
||||
res = []
|
||||
cookies = hub_api.check_cookies_upload_data(use_cookies=True)
|
||||
objects = hub_api.list_oss_dataset_objects(
|
||||
dataset_name=dataset_name,
|
||||
namespace=namespace,
|
||||
max_limit=max_limit,
|
||||
is_recursive=is_recursive,
|
||||
is_filter_dir=True,
|
||||
revision=version,
|
||||
cookies=cookies)
|
||||
revision=version)
|
||||
|
||||
for item in objects:
|
||||
object_key = item.get('Key')
|
||||
@@ -174,7 +172,7 @@ def get_dataset_files(subset_split_into: dict,
|
||||
modelscope_api = HubApi()
|
||||
objects = list_dataset_objects(
|
||||
hub_api=modelscope_api,
|
||||
max_limit=DownloadParams.MAX_LIST_OBJECTS_NUM.value,
|
||||
max_limit=-1,
|
||||
is_recursive=True,
|
||||
dataset_name=dataset_name,
|
||||
namespace=namespace,
|
||||
|
||||
@@ -47,22 +47,28 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
|
||||
if isinstance(audio_in, str):
|
||||
# load pcm data from url if audio_in is url str
|
||||
self.audio_in = load_bytes_from_url(audio_in)
|
||||
self.audio_in, checking_audio_fs = load_bytes_from_url(audio_in)
|
||||
elif isinstance(audio_in, bytes):
|
||||
# load pcm data from wav data if audio_in is wave format
|
||||
self.audio_in = extract_pcm_from_wav(audio_in)
|
||||
self.audio_in, checking_audio_fs = extract_pcm_from_wav(audio_in)
|
||||
else:
|
||||
self.audio_in = audio_in
|
||||
|
||||
# set the sample_rate of audio_in if checking_audio_fs is valid
|
||||
if checking_audio_fs is not None:
|
||||
self.audio_fs = checking_audio_fs
|
||||
|
||||
if recog_type is None or audio_format is None:
|
||||
self.recog_type, self.audio_format, self.audio_in = asr_utils.type_checking(
|
||||
audio_in=self.audio_in,
|
||||
recog_type=recog_type,
|
||||
audio_format=audio_format)
|
||||
|
||||
if hasattr(asr_utils, 'sample_rate_checking') and audio_fs is None:
|
||||
self.audio_fs = asr_utils.sample_rate_checking(
|
||||
if hasattr(asr_utils, 'sample_rate_checking'):
|
||||
checking_audio_fs = asr_utils.sample_rate_checking(
|
||||
self.audio_in, self.audio_format)
|
||||
if checking_audio_fs is not None:
|
||||
self.audio_fs = checking_audio_fs
|
||||
|
||||
if self.preprocessor is None:
|
||||
self.preprocessor = WavToScp()
|
||||
@@ -80,7 +86,7 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
|
||||
logger.info(f"Decoding with {inputs['audio_format']} files ...")
|
||||
|
||||
data_cmd: Sequence[Tuple[str, str]]
|
||||
data_cmd: Sequence[Tuple[str, str, str]]
|
||||
if inputs['audio_format'] == 'wav' or inputs['audio_format'] == 'pcm':
|
||||
data_cmd = ['speech', 'sound']
|
||||
elif inputs['audio_format'] == 'kaldi_ark':
|
||||
@@ -88,6 +94,9 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
|
||||
elif inputs['audio_format'] == 'tfrecord':
|
||||
data_cmd = ['speech', 'tfrecord']
|
||||
|
||||
if inputs.__contains__('mvn_file'):
|
||||
data_cmd.append(inputs['mvn_file'])
|
||||
|
||||
# generate asr inference command
|
||||
cmd = {
|
||||
'model_type': inputs['model_type'],
|
||||
|
||||
@@ -51,10 +51,10 @@ class KeyWordSpottingKwsbpPipeline(Pipeline):
|
||||
|
||||
if isinstance(audio_in, str):
|
||||
# load pcm data from url if audio_in is url str
|
||||
audio_in = load_bytes_from_url(audio_in)
|
||||
audio_in, audio_fs = load_bytes_from_url(audio_in)
|
||||
elif isinstance(audio_in, bytes):
|
||||
# load pcm data from wav data if audio_in is wave format
|
||||
audio_in = extract_pcm_from_wav(audio_in)
|
||||
audio_in, audio_fs = extract_pcm_from_wav(audio_in)
|
||||
|
||||
output = self.preprocessor.forward(self.model.forward(), audio_in)
|
||||
output = self.forward(output)
|
||||
|
||||
@@ -12,6 +12,8 @@ from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import \
|
||||
show_image_object_detection_auto_result
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
@@ -52,10 +54,18 @@ class TinynasDetectionPipeline(Pipeline):
|
||||
|
||||
bboxes, scores, labels = self.model.postprocess(inputs['data'])
|
||||
if bboxes is None:
|
||||
return None
|
||||
outputs = {
|
||||
OutputKeys.SCORES: scores,
|
||||
OutputKeys.LABELS: labels,
|
||||
OutputKeys.BOXES: bboxes
|
||||
}
|
||||
outputs = {
|
||||
OutputKeys.SCORES: [],
|
||||
OutputKeys.LABELS: [],
|
||||
OutputKeys.BOXES: []
|
||||
}
|
||||
else:
|
||||
outputs = {
|
||||
OutputKeys.SCORES: scores,
|
||||
OutputKeys.LABELS: labels,
|
||||
OutputKeys.BOXES: bboxes
|
||||
}
|
||||
return outputs
|
||||
|
||||
def show_result(self, img_path, result, save_path=None):
|
||||
show_image_object_detection_auto_result(img_path, result, save_path)
|
||||
|
||||
@@ -133,6 +133,12 @@ class WavToScp(Preprocessor):
|
||||
else:
|
||||
inputs['asr_model_config'] = asr_model_config
|
||||
|
||||
if inputs['model_config'].__contains__('mvn_file'):
|
||||
mvn_file = os.path.join(inputs['model_workspace'],
|
||||
inputs['model_config']['mvn_file'])
|
||||
assert os.path.exists(mvn_file), 'mvn_file does not exist'
|
||||
inputs['mvn_file'] = mvn_file
|
||||
|
||||
elif inputs['model_type'] == Frameworks.tf:
|
||||
assert inputs['model_config'].__contains__(
|
||||
'vocab_file'), 'vocab_file does not exist'
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import os.path as osp
|
||||
import re
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
@@ -217,7 +217,7 @@ class NLPTokenizerPreprocessorBase(Preprocessor):
|
||||
return isinstance(label, str) or isinstance(label, int)
|
||||
|
||||
if labels is not None:
|
||||
if isinstance(labels, Iterable) and all([label_can_be_mapped(label) for label in labels]) \
|
||||
if isinstance(labels, (tuple, list)) and all([label_can_be_mapped(label) for label in labels]) \
|
||||
and self.label2id is not None:
|
||||
output[OutputKeys.LABELS] = [
|
||||
self.label2id[str(label)] for label in labels
|
||||
@@ -314,8 +314,7 @@ class SequenceClassificationPreprocessor(NLPTokenizerPreprocessorBase):
|
||||
|
||||
def __init__(self, model_dir: str, mode=ModeKeys.INFERENCE, **kwargs):
|
||||
kwargs['truncation'] = kwargs.get('truncation', True)
|
||||
kwargs['padding'] = kwargs.get(
|
||||
'padding', False if mode == ModeKeys.INFERENCE else 'max_length')
|
||||
kwargs['padding'] = kwargs.get('padding', 'max_length')
|
||||
kwargs['max_length'] = kwargs.pop('sequence_length', 128)
|
||||
super().__init__(model_dir, mode=mode, **kwargs)
|
||||
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from os.path import exists
|
||||
from tempfile import TemporaryDirectory
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -9,6 +14,7 @@ import torchvision.transforms._transforms_video as transforms
|
||||
from decord import VideoReader
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from modelscope.hub.file_download import http_get_file
|
||||
from modelscope.metainfo import Preprocessors
|
||||
from modelscope.utils.constant import Fields, ModeKeys
|
||||
from modelscope.utils.type_assert import type_assert
|
||||
@@ -30,7 +36,22 @@ def ReadVideoData(cfg,
|
||||
Returns:
|
||||
data (Tensor): the normalized video clips for model inputs
|
||||
"""
|
||||
data = _decode_video(cfg, video_path, num_temporal_views_override)
|
||||
url_parsed = urlparse(video_path)
|
||||
if url_parsed.scheme in ('file', '') and exists(
|
||||
url_parsed.path): # Possibly a local file
|
||||
data = _decode_video(cfg, video_path, num_temporal_views_override)
|
||||
else:
|
||||
with TemporaryDirectory() as temporary_cache_dir:
|
||||
random_str = uuid.uuid4().hex
|
||||
http_get_file(
|
||||
url=video_path,
|
||||
local_dir=temporary_cache_dir,
|
||||
file_name=random_str,
|
||||
cookies=None)
|
||||
temp_file_path = os.path.join(temporary_cache_dir, random_str)
|
||||
data = _decode_video(cfg, temp_file_path,
|
||||
num_temporal_views_override)
|
||||
|
||||
if num_spatial_crops_override is not None:
|
||||
num_spatial_crops = num_spatial_crops_override
|
||||
transform = kinetics400_tranform(cfg, num_spatial_crops_override)
|
||||
|
||||
@@ -47,7 +47,7 @@ class LrSchedulerHook(Hook):
|
||||
return lr
|
||||
|
||||
def before_train_iter(self, trainer):
|
||||
if not self.by_epoch:
|
||||
if not self.by_epoch and trainer.iter > 0:
|
||||
if self.warmup_lr_scheduler is not None:
|
||||
self.warmup_lr_scheduler.step()
|
||||
else:
|
||||
|
||||
@@ -656,7 +656,7 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
# TODO: support MsDataset load for cv
|
||||
if hasattr(data_cfg, 'name'):
|
||||
dataset = MsDataset.load(
|
||||
dataset_name=data_cfg.name,
|
||||
dataset_name=data_cfg.pop('name'),
|
||||
**data_cfg,
|
||||
)
|
||||
cfg = ConfigDict(type=self.cfg.model.type, mode=mode)
|
||||
|
||||
@@ -57,6 +57,7 @@ def update_conf(origin_config_file, new_config_file, conf_item: [str, str]):
|
||||
|
||||
def extract_pcm_from_wav(wav: bytes) -> bytes:
|
||||
data = wav
|
||||
sample_rate = None
|
||||
if len(data) > 44:
|
||||
frame_len = 44
|
||||
file_len = len(data)
|
||||
@@ -70,29 +71,33 @@ def extract_pcm_from_wav(wav: bytes) -> bytes:
|
||||
'Subchunk1ID'] == 'fmt ':
|
||||
header_fields['SubChunk1Size'] = struct.unpack(
|
||||
'<I', data[16:20])[0]
|
||||
header_fields['SampleRate'] = struct.unpack('<I',
|
||||
data[24:28])[0]
|
||||
sample_rate = header_fields['SampleRate']
|
||||
|
||||
if header_fields['SubChunk1Size'] == 16:
|
||||
frame_len = 44
|
||||
elif header_fields['SubChunk1Size'] == 18:
|
||||
frame_len = 46
|
||||
else:
|
||||
return data
|
||||
return data, sample_rate
|
||||
|
||||
data = wav[frame_len:file_len]
|
||||
except Exception:
|
||||
# no treatment
|
||||
pass
|
||||
|
||||
return data
|
||||
return data, sample_rate
|
||||
|
||||
|
||||
def load_bytes_from_url(url: str) -> Union[bytes, str]:
|
||||
sample_rate = None
|
||||
result = urlparse(url)
|
||||
if result.scheme is not None and len(result.scheme) > 0:
|
||||
storage = HTTPStorage()
|
||||
data = storage.read(url)
|
||||
data = extract_pcm_from_wav(data)
|
||||
data, sample_rate = extract_pcm_from_wav(data)
|
||||
else:
|
||||
data = url
|
||||
|
||||
return data
|
||||
return data, sample_rate
|
||||
|
||||
@@ -231,13 +231,6 @@ class DownloadMode(enum.Enum):
|
||||
FORCE_REDOWNLOAD = 'force_redownload'
|
||||
|
||||
|
||||
class DownloadParams(enum.Enum):
|
||||
"""
|
||||
Parameters for downloading dataset.
|
||||
"""
|
||||
MAX_LIST_OBJECTS_NUM = 50000
|
||||
|
||||
|
||||
class DatasetFormations(enum.Enum):
|
||||
""" How a dataset is organized and interpreted
|
||||
"""
|
||||
|
||||
@@ -61,8 +61,8 @@ def device_placement(framework, device_name='gpu:0'):
|
||||
if framework == Frameworks.tf:
|
||||
import tensorflow as tf
|
||||
if device_type == Devices.gpu and not tf.test.is_gpu_available():
|
||||
logger.warning(
|
||||
'tensorflow cuda is not available, using cpu instead.')
|
||||
logger.debug(
|
||||
'tensorflow: cuda is not available, using cpu instead.')
|
||||
device_type = Devices.cpu
|
||||
if device_type == Devices.cpu:
|
||||
with tf.device('/CPU:0'):
|
||||
@@ -78,7 +78,8 @@ def device_placement(framework, device_name='gpu:0'):
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(f'cuda:{device_id}')
|
||||
else:
|
||||
logger.warning('cuda is not available, using cpu instead.')
|
||||
logger.debug(
|
||||
'pytorch: cuda is not available, using cpu instead.')
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
@@ -96,9 +97,7 @@ def create_device(device_name):
|
||||
if device_type == Devices.gpu:
|
||||
use_cuda = True
|
||||
if not torch.cuda.is_available():
|
||||
logger.warning(
|
||||
'cuda is not available, create gpu device failed, using cpu instead.'
|
||||
)
|
||||
logger.info('cuda is not available, using cpu instead.')
|
||||
use_cuda = False
|
||||
|
||||
if use_cuda:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -35,3 +36,10 @@ def get_default_cache_dir():
|
||||
"""
|
||||
default_cache_dir = Path.home().joinpath('.cache', 'modelscope')
|
||||
return default_cache_dir
|
||||
|
||||
|
||||
def read_file(path):
|
||||
|
||||
with open(path, 'r') as f:
|
||||
text = f.read()
|
||||
return text
|
||||
|
||||
@@ -176,7 +176,7 @@ def build_from_cfg(cfg,
|
||||
raise TypeError('default_args must be a dict or None, '
|
||||
f'but got {type(default_args)}')
|
||||
|
||||
# dynamic load installation reqruiements for this module
|
||||
# dynamic load installation requirements for this module
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
sig = (registry.name.upper(), group_key, cfg['type'])
|
||||
LazyImportModule.import_module(sig)
|
||||
@@ -193,8 +193,11 @@ def build_from_cfg(cfg,
|
||||
if isinstance(obj_type, str):
|
||||
obj_cls = registry.get(obj_type, group_key=group_key)
|
||||
if obj_cls is None:
|
||||
raise KeyError(f'{obj_type} is not in the {registry.name}'
|
||||
f' registry group {group_key}')
|
||||
raise KeyError(
|
||||
f'{obj_type} is not in the {registry.name}'
|
||||
f' registry group {group_key}. Please make'
|
||||
f' sure the correct version of 1qqQModelScope library is used.'
|
||||
)
|
||||
obj_cls.group_key = group_key
|
||||
elif inspect.isclass(obj_type) or inspect.isfunction(obj_type):
|
||||
obj_cls = obj_type
|
||||
|
||||
@@ -65,7 +65,8 @@ class RegressTool:
|
||||
def monitor_module_single_forward(self,
|
||||
module: nn.Module,
|
||||
file_name: str,
|
||||
compare_fn=None):
|
||||
compare_fn=None,
|
||||
**kwargs):
|
||||
"""Monitor a pytorch module in a single forward.
|
||||
|
||||
@param module: A torch module
|
||||
@@ -107,7 +108,7 @@ class RegressTool:
|
||||
baseline = os.path.join(tempfile.gettempdir(), name)
|
||||
self.load(baseline, name)
|
||||
with open(baseline, 'rb') as f:
|
||||
baseline_json = pickle.load(f)
|
||||
base = pickle.load(f)
|
||||
|
||||
class NumpyEncoder(json.JSONEncoder):
|
||||
"""Special json encoder for numpy types
|
||||
@@ -122,9 +123,9 @@ class RegressTool:
|
||||
return obj.tolist()
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
print(f'baseline: {json.dumps(baseline_json, cls=NumpyEncoder)}')
|
||||
print(f'baseline: {json.dumps(base, cls=NumpyEncoder)}')
|
||||
print(f'latest : {json.dumps(io_json, cls=NumpyEncoder)}')
|
||||
if not compare_io_and_print(baseline_json, io_json, compare_fn):
|
||||
if not compare_io_and_print(base, io_json, compare_fn, **kwargs):
|
||||
raise ValueError('Result not match!')
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -136,7 +137,8 @@ class RegressTool:
|
||||
ignore_keys=None,
|
||||
compare_random=True,
|
||||
reset_dropout=True,
|
||||
lazy_stop_callback=None):
|
||||
lazy_stop_callback=None,
|
||||
**kwargs):
|
||||
"""Monitor a pytorch module's backward data and cfg data within a step of the optimizer.
|
||||
|
||||
This is usually useful when you try to change some dangerous code
|
||||
@@ -265,14 +267,15 @@ class RegressTool:
|
||||
baseline_json = pickle.load(f)
|
||||
|
||||
if level == 'strict' and not compare_io_and_print(
|
||||
baseline_json['forward'], io_json, compare_fn):
|
||||
baseline_json['forward'], io_json, compare_fn, **kwargs):
|
||||
raise RuntimeError('Forward not match!')
|
||||
if not compare_backward_and_print(
|
||||
baseline_json['backward'],
|
||||
bw_json,
|
||||
compare_fn=compare_fn,
|
||||
ignore_keys=ignore_keys,
|
||||
level=level):
|
||||
level=level,
|
||||
**kwargs):
|
||||
raise RuntimeError('Backward not match!')
|
||||
cfg_opt1 = {
|
||||
'optimizer': baseline_json['optimizer'],
|
||||
@@ -286,7 +289,8 @@ class RegressTool:
|
||||
'cfg': summary['cfg'],
|
||||
'state': None if not compare_random else summary['state']
|
||||
}
|
||||
if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn):
|
||||
if not compare_cfg_and_optimizers(cfg_opt1, cfg_opt2, compare_fn,
|
||||
**kwargs):
|
||||
raise RuntimeError('Cfg or optimizers not match!')
|
||||
|
||||
|
||||
@@ -303,7 +307,8 @@ class MsRegressTool(RegressTool):
|
||||
compare_fn=None,
|
||||
ignore_keys=None,
|
||||
compare_random=True,
|
||||
lazy_stop_callback=None):
|
||||
lazy_stop_callback=None,
|
||||
**kwargs):
|
||||
|
||||
if lazy_stop_callback is None:
|
||||
|
||||
@@ -319,7 +324,7 @@ class MsRegressTool(RegressTool):
|
||||
|
||||
trainer.register_hook(EarlyStopHook())
|
||||
|
||||
def _train_loop(trainer, *args, **kwargs):
|
||||
def _train_loop(trainer, *args_train, **kwargs_train):
|
||||
with self.monitor_module_train(
|
||||
trainer,
|
||||
file_name,
|
||||
@@ -327,9 +332,11 @@ class MsRegressTool(RegressTool):
|
||||
compare_fn=compare_fn,
|
||||
ignore_keys=ignore_keys,
|
||||
compare_random=compare_random,
|
||||
lazy_stop_callback=lazy_stop_callback):
|
||||
lazy_stop_callback=lazy_stop_callback,
|
||||
**kwargs):
|
||||
try:
|
||||
return trainer.train_loop_origin(*args, **kwargs)
|
||||
return trainer.train_loop_origin(*args_train,
|
||||
**kwargs_train)
|
||||
except MsRegressTool.EarlyStopError:
|
||||
pass
|
||||
|
||||
@@ -530,7 +537,8 @@ def compare_arguments_nested(print_content,
|
||||
)
|
||||
return False
|
||||
if not all([
|
||||
compare_arguments_nested(None, sub_arg1, sub_arg2)
|
||||
compare_arguments_nested(
|
||||
None, sub_arg1, sub_arg2, rtol=rtol, atol=atol)
|
||||
for sub_arg1, sub_arg2 in zip(arg1, arg2)
|
||||
]):
|
||||
if print_content is not None:
|
||||
@@ -551,7 +559,8 @@ def compare_arguments_nested(print_content,
|
||||
print(f'{print_content}, key diff:{set(keys1) - set(keys2)}')
|
||||
return False
|
||||
if not all([
|
||||
compare_arguments_nested(None, arg1[key], arg2[key])
|
||||
compare_arguments_nested(
|
||||
None, arg1[key], arg2[key], rtol=rtol, atol=atol)
|
||||
for key in keys1
|
||||
]):
|
||||
if print_content is not None:
|
||||
@@ -574,7 +583,7 @@ def compare_arguments_nested(print_content,
|
||||
raise ValueError(f'type not supported: {type1}')
|
||||
|
||||
|
||||
def compare_io_and_print(baseline_json, io_json, compare_fn=None):
|
||||
def compare_io_and_print(baseline_json, io_json, compare_fn=None, **kwargs):
|
||||
if compare_fn is None:
|
||||
|
||||
def compare_fn(*args, **kwargs):
|
||||
@@ -602,10 +611,10 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None):
|
||||
else:
|
||||
match = compare_arguments_nested(
|
||||
f'unmatched module {key} input args', v1input['args'],
|
||||
v2input['args']) and match
|
||||
v2input['args'], **kwargs) and match
|
||||
match = compare_arguments_nested(
|
||||
f'unmatched module {key} input kwargs', v1input['kwargs'],
|
||||
v2input['kwargs']) and match
|
||||
v2input['kwargs'], **kwargs) and match
|
||||
v1output = numpify_tensor_nested(v1['output'])
|
||||
v2output = numpify_tensor_nested(v2['output'])
|
||||
res = compare_fn(v1output, v2output, key, 'output')
|
||||
@@ -615,8 +624,11 @@ def compare_io_and_print(baseline_json, io_json, compare_fn=None):
|
||||
)
|
||||
match = match and res
|
||||
else:
|
||||
match = compare_arguments_nested(f'unmatched module {key} outputs',
|
||||
v1output, v2output) and match
|
||||
match = compare_arguments_nested(
|
||||
f'unmatched module {key} outputs',
|
||||
arg1=v1output,
|
||||
arg2=v2output,
|
||||
**kwargs) and match
|
||||
return match
|
||||
|
||||
|
||||
@@ -624,7 +636,8 @@ def compare_backward_and_print(baseline_json,
|
||||
bw_json,
|
||||
level,
|
||||
ignore_keys=None,
|
||||
compare_fn=None):
|
||||
compare_fn=None,
|
||||
**kwargs):
|
||||
if compare_fn is None:
|
||||
|
||||
def compare_fn(*args, **kwargs):
|
||||
@@ -653,18 +666,26 @@ def compare_backward_and_print(baseline_json,
|
||||
data2, grad2, data_after2 = bw_json[key]['data'], bw_json[key][
|
||||
'grad'], bw_json[key]['data_after']
|
||||
match = compare_arguments_nested(
|
||||
f'unmatched module {key} tensor data', data1, data2) and match
|
||||
f'unmatched module {key} tensor data',
|
||||
arg1=data1,
|
||||
arg2=data2,
|
||||
**kwargs) and match
|
||||
if level == 'strict':
|
||||
match = compare_arguments_nested(
|
||||
f'unmatched module {key} grad data', grad1,
|
||||
grad2) and match
|
||||
f'unmatched module {key} grad data',
|
||||
arg1=grad1,
|
||||
arg2=grad2,
|
||||
**kwargs) and match
|
||||
match = compare_arguments_nested(
|
||||
f'unmatched module {key} data after step', data_after1,
|
||||
data_after2) and match
|
||||
data_after2, **kwargs) and match
|
||||
return match
|
||||
|
||||
|
||||
def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None):
|
||||
def compare_cfg_and_optimizers(baseline_json,
|
||||
cfg_json,
|
||||
compare_fn=None,
|
||||
**kwargs):
|
||||
if compare_fn is None:
|
||||
|
||||
def compare_fn(*args, **kwargs):
|
||||
@@ -686,12 +707,12 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None):
|
||||
print(
|
||||
f"Optimizer type not equal:{optimizer1['type']} and {optimizer2['type']}"
|
||||
)
|
||||
match = compare_arguments_nested('unmatched optimizer defaults',
|
||||
optimizer1['defaults'],
|
||||
optimizer2['defaults']) and match
|
||||
match = compare_arguments_nested('unmatched optimizer state_dict',
|
||||
optimizer1['state_dict'],
|
||||
optimizer2['state_dict']) and match
|
||||
match = compare_arguments_nested(
|
||||
'unmatched optimizer defaults', optimizer1['defaults'],
|
||||
optimizer2['defaults'], **kwargs) and match
|
||||
match = compare_arguments_nested(
|
||||
'unmatched optimizer state_dict', optimizer1['state_dict'],
|
||||
optimizer2['state_dict'], **kwargs) and match
|
||||
|
||||
res = compare_fn(lr_scheduler1, lr_scheduler2, None, 'lr_scheduler')
|
||||
if res is not None:
|
||||
@@ -703,16 +724,17 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None):
|
||||
print(
|
||||
f"Optimizer type not equal:{lr_scheduler1['type']} and {lr_scheduler2['type']}"
|
||||
)
|
||||
match = compare_arguments_nested('unmatched lr_scheduler state_dict',
|
||||
lr_scheduler1['state_dict'],
|
||||
lr_scheduler2['state_dict']) and match
|
||||
match = compare_arguments_nested(
|
||||
'unmatched lr_scheduler state_dict', lr_scheduler1['state_dict'],
|
||||
lr_scheduler2['state_dict'], **kwargs) and match
|
||||
|
||||
res = compare_fn(cfg1, cfg2, None, 'cfg')
|
||||
if res is not None:
|
||||
print(f'cfg compared with user compare_fn with result:{res}\n')
|
||||
match = match and res
|
||||
else:
|
||||
match = compare_arguments_nested('unmatched cfg', cfg1, cfg2) and match
|
||||
match = compare_arguments_nested(
|
||||
'unmatched cfg', arg1=cfg1, arg2=cfg2, **kwargs) and match
|
||||
|
||||
res = compare_fn(state1, state2, None, 'state')
|
||||
if res is not None:
|
||||
@@ -721,6 +743,6 @@ def compare_cfg_and_optimizers(baseline_json, cfg_json, compare_fn=None):
|
||||
match = match and res
|
||||
else:
|
||||
match = compare_arguments_nested('unmatched random state', state1,
|
||||
state2) and match
|
||||
state2, **kwargs) and match
|
||||
|
||||
return match
|
||||
|
||||
@@ -19,7 +19,7 @@ moviepy>=1.0.3
|
||||
networkx>=2.5
|
||||
numba
|
||||
onnxruntime>=1.10
|
||||
pai-easycv>=0.6.3.7
|
||||
pai-easycv>=0.6.3.9
|
||||
pandas
|
||||
psutil
|
||||
regex
|
||||
|
||||
@@ -127,7 +127,7 @@ class HubOperationTest(unittest.TestCase):
|
||||
return None
|
||||
|
||||
def test_list_model(self):
|
||||
data = self.api.list_model(TEST_MODEL_ORG)
|
||||
data = self.api.list_models(TEST_MODEL_ORG)
|
||||
assert len(data['Models']) >= 1
|
||||
|
||||
|
||||
|
||||
@@ -7,12 +7,12 @@ import uuid
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.constants import Licenses, ModelVisibility
|
||||
from modelscope.hub.errors import HTTPError, NotLoginException
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.hub.upload import upload_folder
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.test_utils import test_level
|
||||
from .test_utils import TEST_ACCESS_TOKEN1, delete_credential
|
||||
from .test_utils import TEST_ACCESS_TOKEN1, TEST_MODEL_ORG, delete_credential
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@@ -22,7 +22,7 @@ class HubUploadTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
logger.info('SetUp')
|
||||
self.api = HubApi()
|
||||
self.user = os.environ.get('TEST_MODEL_ORG', 'citest')
|
||||
self.user = TEST_MODEL_ORG
|
||||
logger.info(self.user)
|
||||
self.create_model_name = '%s/%s_%s' % (self.user, 'test_model_upload',
|
||||
uuid.uuid4().hex)
|
||||
@@ -39,7 +39,10 @@ class HubUploadTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
logger.info('TearDown')
|
||||
shutil.rmtree(self.model_dir, ignore_errors=True)
|
||||
self.api.delete_model(model_id=self.create_model_name)
|
||||
try:
|
||||
self.api.delete_model(model_id=self.create_model_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_upload_exits_repo_master(self):
|
||||
logger.info('basic test for upload!')
|
||||
@@ -50,14 +53,14 @@ class HubUploadTest(unittest.TestCase):
|
||||
license=Licenses.APACHE_V2)
|
||||
os.system("echo '111'>%s"
|
||||
% os.path.join(self.finetune_path, 'add1.py'))
|
||||
upload_folder(
|
||||
self.api.push_model(
|
||||
model_id=self.create_model_name, model_dir=self.finetune_path)
|
||||
Repository(model_dir=self.repo_path, clone_from=self.create_model_name)
|
||||
assert os.path.exists(os.path.join(self.repo_path, 'add1.py'))
|
||||
shutil.rmtree(self.repo_path, ignore_errors=True)
|
||||
os.system("echo '222'>%s"
|
||||
% os.path.join(self.finetune_path, 'add2.py'))
|
||||
upload_folder(
|
||||
self.api.push_model(
|
||||
model_id=self.create_model_name,
|
||||
model_dir=self.finetune_path,
|
||||
revision='new_revision/version1')
|
||||
@@ -69,7 +72,7 @@ class HubUploadTest(unittest.TestCase):
|
||||
shutil.rmtree(self.repo_path, ignore_errors=True)
|
||||
os.system("echo '333'>%s"
|
||||
% os.path.join(self.finetune_path, 'add3.py'))
|
||||
upload_folder(
|
||||
self.api.push_model(
|
||||
model_id=self.create_model_name,
|
||||
model_dir=self.finetune_path,
|
||||
revision='new_revision/version2',
|
||||
@@ -84,7 +87,7 @@ class HubUploadTest(unittest.TestCase):
|
||||
add4_path = os.path.join(self.finetune_path, 'temp')
|
||||
os.mkdir(add4_path)
|
||||
os.system("echo '444'>%s" % os.path.join(add4_path, 'add4.py'))
|
||||
upload_folder(
|
||||
self.api.push_model(
|
||||
model_id=self.create_model_name,
|
||||
model_dir=self.finetune_path,
|
||||
revision='new_revision/version1')
|
||||
@@ -101,7 +104,7 @@ class HubUploadTest(unittest.TestCase):
|
||||
self.api.login(TEST_ACCESS_TOKEN1)
|
||||
os.system("echo '111'>%s"
|
||||
% os.path.join(self.finetune_path, 'add1.py'))
|
||||
upload_folder(
|
||||
self.api.push_model(
|
||||
model_id=self.create_model_name,
|
||||
model_dir=self.finetune_path,
|
||||
revision='new_model_new_revision',
|
||||
@@ -119,48 +122,23 @@ class HubUploadTest(unittest.TestCase):
|
||||
logger.info('test upload without login!')
|
||||
self.api.login(TEST_ACCESS_TOKEN1)
|
||||
delete_credential()
|
||||
try:
|
||||
upload_folder(
|
||||
with self.assertRaises(NotLoginException):
|
||||
self.api.push_model(
|
||||
model_id=self.create_model_name,
|
||||
model_dir=self.finetune_path,
|
||||
visibility=ModelVisibility.PUBLIC,
|
||||
license=Licenses.APACHE_V2)
|
||||
except Exception as e:
|
||||
logger.info(e)
|
||||
self.api.login(TEST_ACCESS_TOKEN1)
|
||||
upload_folder(
|
||||
model_id=self.create_model_name,
|
||||
model_dir=self.finetune_path,
|
||||
visibility=ModelVisibility.PUBLIC,
|
||||
license=Licenses.APACHE_V2)
|
||||
Repository(
|
||||
model_dir=self.repo_path, clone_from=self.create_model_name)
|
||||
assert os.path.exists(
|
||||
os.path.join(self.repo_path, 'configuration.json'))
|
||||
shutil.rmtree(self.repo_path, ignore_errors=True)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_upload_invalid_repo(self):
|
||||
logger.info('test upload to invalid repo!')
|
||||
self.api.login(TEST_ACCESS_TOKEN1)
|
||||
try:
|
||||
upload_folder(
|
||||
with self.assertRaises(HTTPError):
|
||||
self.api.push_model(
|
||||
model_id='%s/%s' % ('speech_tts', 'invalid_model_test'),
|
||||
model_dir=self.finetune_path,
|
||||
visibility=ModelVisibility.PUBLIC,
|
||||
license=Licenses.APACHE_V2)
|
||||
except Exception as e:
|
||||
logger.info(e)
|
||||
upload_folder(
|
||||
model_id=self.create_model_name,
|
||||
model_dir=self.finetune_path,
|
||||
visibility=ModelVisibility.PUBLIC,
|
||||
license=Licenses.APACHE_V2)
|
||||
Repository(
|
||||
model_dir=self.repo_path, clone_from=self.create_model_name)
|
||||
assert os.path.exists(
|
||||
os.path.join(self.repo_path, 'configuration.json'))
|
||||
shutil.rmtree(self.repo_path, ignore_errors=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -52,7 +52,8 @@ class MsDatasetTest(unittest.TestCase):
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_ms_csv_basic(self):
|
||||
ms_ds_train = MsDataset.load(
|
||||
'afqmc_small', namespace='userxiaoming', split='train')
|
||||
'clue', subset_name='afqmc',
|
||||
split='train').to_hf_dataset().select(range(5))
|
||||
print(next(iter(ms_ds_train)))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
|
||||
@@ -45,6 +45,10 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
|
||||
'checking_item': OutputKeys.TEXT,
|
||||
'example': 'wav_example'
|
||||
},
|
||||
'test_run_with_url_pytorch': {
|
||||
'checking_item': OutputKeys.TEXT,
|
||||
'example': 'wav_example'
|
||||
},
|
||||
'test_run_with_url_tf': {
|
||||
'checking_item': OutputKeys.TEXT,
|
||||
'example': 'wav_example'
|
||||
@@ -74,6 +78,170 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
|
||||
}
|
||||
}
|
||||
|
||||
all_models_info = [
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1',
|
||||
'wav_path': 'data/test/audios/asr_example.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id': 'speech_paraformer_asr_nat-aishell1-pytorch',
|
||||
'wav_path': 'data/test/audios/asr_example.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1',
|
||||
'wav_path': 'data/test/audios/asr_example.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_paraformer_asr_nat-zh-cn-8k-common-vocab8358-tensorflow1',
|
||||
'wav_path': 'data/test/audios/asr_example_8K.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online',
|
||||
'wav_path': 'data/test/audios/asr_example.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline',
|
||||
'wav_path': 'data/test/audios/asr_example.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-online',
|
||||
'wav_path': 'data/test/audios/asr_example_8K.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-zh-cn-8k-common-vocab8358-tensorflow1-offline',
|
||||
'wav_path': 'data/test/audios/asr_example_8K.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR-large_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-offline',
|
||||
'wav_path': 'data/test/audios/asr_example.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-online',
|
||||
'wav_path': 'data/test/audios/asr_example_cn_en.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-cn-en-moe-16k-vocab8358-tensorflow1-offline',
|
||||
'wav_path': 'data/test/audios/asr_example_cn_en.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-online',
|
||||
'wav_path': 'data/test/audios/asr_example_cn_dialect.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-cn-dialect-16k-vocab8358-tensorflow1-offline',
|
||||
'wav_path': 'data/test/audios/asr_example_cn_dialect.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_paraformer_asr_nat-zh-cn-16k-common-vocab3444-tensorflow1-online',
|
||||
'wav_path': 'data/test/audios/asr_example.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_paraformer_asr_nat-zh-cn-8k-common-vocab3444-tensorflow1-online',
|
||||
'wav_path': 'data/test/audios/asr_example_8K.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-offline',
|
||||
'wav_path': 'data/test/audios/asr_example_en.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-en-16k-common-vocab1080-tensorflow1-online',
|
||||
'wav_path': 'data/test/audios/asr_example_en.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-offline',
|
||||
'wav_path': 'data/test/audios/asr_example_ru.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-ru-16k-common-vocab1664-tensorflow1-online',
|
||||
'wav_path': 'data/test/audios/asr_example_ru.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-offline',
|
||||
'wav_path': 'data/test/audios/asr_example_es.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-es-16k-common-vocab3445-tensorflow1-online',
|
||||
'wav_path': 'data/test/audios/asr_example_es.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-offline',
|
||||
'wav_path': 'data/test/audios/asr_example_ko.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-ko-16k-common-vocab6400-tensorflow1-online',
|
||||
'wav_path': 'data/test/audios/asr_example_ko.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-online',
|
||||
'wav_path': 'data/test/audios/asr_example_ja.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-ja-16k-common-vocab93-tensorflow1-offline',
|
||||
'wav_path': 'data/test/audios/asr_example_ja.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-online',
|
||||
'wav_path': 'data/test/audios/asr_example_id.wav'
|
||||
},
|
||||
{
|
||||
'model_group': 'damo',
|
||||
'model_id':
|
||||
'speech_UniASR_asr_2pass-id-16k-common-vocab1067-tensorflow1-offline',
|
||||
'wav_path': 'data/test/audios/asr_example_id.wav'
|
||||
},
|
||||
]
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.am_pytorch_model_id = 'damo/speech_paraformer_asr_nat-aishell1-pytorch'
|
||||
self.am_tf_model_id = 'damo/speech_paraformer_asr_nat-zh-cn-16k-common-vocab8358-tensorflow1'
|
||||
@@ -90,7 +258,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
|
||||
def run_pipeline(self,
|
||||
model_id: str,
|
||||
audio_in: Union[str, bytes],
|
||||
sr: int = 16000) -> Dict[str, Any]:
|
||||
sr: int = None) -> Dict[str, Any]:
|
||||
inference_16k_pipline = pipeline(
|
||||
task=Tasks.auto_speech_recognition, model=model_id)
|
||||
|
||||
@@ -136,46 +304,7 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
|
||||
return audio, fs
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_wav_pytorch(self):
|
||||
"""run with single waveform file
|
||||
"""
|
||||
|
||||
logger.info('Run ASR test with waveform file (pytorch)...')
|
||||
|
||||
wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
|
||||
|
||||
rec_result = self.run_pipeline(
|
||||
model_id=self.am_pytorch_model_id, audio_in=wav_file_path)
|
||||
self.check_result('test_run_with_wav_pytorch', rec_result)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_pcm_pytorch(self):
|
||||
"""run with wav data
|
||||
"""
|
||||
|
||||
logger.info('Run ASR test with wav data (pytorch)...')
|
||||
|
||||
audio, sr = self.wav2bytes(os.path.join(os.getcwd(), WAV_FILE))
|
||||
|
||||
rec_result = self.run_pipeline(
|
||||
model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr)
|
||||
self.check_result('test_run_with_pcm_pytorch', rec_result)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_wav_tf(self):
|
||||
"""run with single waveform file
|
||||
"""
|
||||
|
||||
logger.info('Run ASR test with waveform file (tensorflow)...')
|
||||
|
||||
wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
|
||||
|
||||
rec_result = self.run_pipeline(
|
||||
model_id=self.am_tf_model_id, audio_in=wav_file_path)
|
||||
self.check_result('test_run_with_wav_tf', rec_result)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_pcm_tf(self):
|
||||
def test_run_with_pcm(self):
|
||||
"""run with wav data
|
||||
"""
|
||||
|
||||
@@ -187,8 +316,33 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
|
||||
model_id=self.am_tf_model_id, audio_in=audio, sr=sr)
|
||||
self.check_result('test_run_with_pcm_tf', rec_result)
|
||||
|
||||
logger.info('Run ASR test with wav data (pytorch)...')
|
||||
|
||||
rec_result = self.run_pipeline(
|
||||
model_id=self.am_pytorch_model_id, audio_in=audio, sr=sr)
|
||||
self.check_result('test_run_with_pcm_pytorch', rec_result)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_url_tf(self):
|
||||
def test_run_with_wav(self):
|
||||
"""run with single waveform file
|
||||
"""
|
||||
|
||||
logger.info('Run ASR test with waveform file (tensorflow)...')
|
||||
|
||||
wav_file_path = os.path.join(os.getcwd(), WAV_FILE)
|
||||
|
||||
rec_result = self.run_pipeline(
|
||||
model_id=self.am_tf_model_id, audio_in=wav_file_path)
|
||||
self.check_result('test_run_with_wav_tf', rec_result)
|
||||
|
||||
logger.info('Run ASR test with waveform file (pytorch)...')
|
||||
|
||||
rec_result = self.run_pipeline(
|
||||
model_id=self.am_pytorch_model_id, audio_in=wav_file_path)
|
||||
self.check_result('test_run_with_wav_pytorch', rec_result)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_url(self):
|
||||
"""run with single url file
|
||||
"""
|
||||
|
||||
@@ -198,6 +352,12 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
|
||||
model_id=self.am_tf_model_id, audio_in=URL_FILE)
|
||||
self.check_result('test_run_with_url_tf', rec_result)
|
||||
|
||||
logger.info('Run ASR test with url file (pytorch)...')
|
||||
|
||||
rec_result = self.run_pipeline(
|
||||
model_id=self.am_pytorch_model_id, audio_in=URL_FILE)
|
||||
self.check_result('test_run_with_url_pytorch', rec_result)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_with_wav_dataset_pytorch(self):
|
||||
"""run with datasets, and audio format is waveform
|
||||
@@ -217,7 +377,6 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
|
||||
data.text # hypothesis text
|
||||
"""
|
||||
|
||||
logger.info('Run ASR test with waveform dataset (pytorch)...')
|
||||
logger.info('Downloading waveform testsets file ...')
|
||||
|
||||
dataset_path = download_and_untar(
|
||||
@@ -225,40 +384,38 @@ class AutomaticSpeechRecognitionTest(unittest.TestCase,
|
||||
LITTLE_TESTSETS_URL, self.workspace)
|
||||
dataset_path = os.path.join(dataset_path, 'wav', 'test')
|
||||
|
||||
logger.info('Run ASR test with waveform dataset (tensorflow)...')
|
||||
|
||||
rec_result = self.run_pipeline(
|
||||
model_id=self.am_tf_model_id, audio_in=dataset_path)
|
||||
self.check_result('test_run_with_wav_dataset_tf', rec_result)
|
||||
|
||||
logger.info('Run ASR test with waveform dataset (pytorch)...')
|
||||
|
||||
rec_result = self.run_pipeline(
|
||||
model_id=self.am_pytorch_model_id, audio_in=dataset_path)
|
||||
self.check_result('test_run_with_wav_dataset_pytorch', rec_result)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_with_wav_dataset_tf(self):
|
||||
"""run with datasets, and audio format is waveform
|
||||
datasets directory:
|
||||
<dataset_path>
|
||||
wav
|
||||
test # testsets
|
||||
xx.wav
|
||||
...
|
||||
dev # devsets
|
||||
yy.wav
|
||||
...
|
||||
train # trainsets
|
||||
zz.wav
|
||||
...
|
||||
transcript
|
||||
data.text # hypothesis text
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_with_all_models(self):
|
||||
"""run with all models
|
||||
"""
|
||||
|
||||
logger.info('Run ASR test with waveform dataset (tensorflow)...')
|
||||
logger.info('Downloading waveform testsets file ...')
|
||||
logger.info('Run ASR test with all models')
|
||||
|
||||
dataset_path = download_and_untar(
|
||||
os.path.join(self.workspace, LITTLE_TESTSETS_FILE),
|
||||
LITTLE_TESTSETS_URL, self.workspace)
|
||||
dataset_path = os.path.join(dataset_path, 'wav', 'test')
|
||||
|
||||
rec_result = self.run_pipeline(
|
||||
model_id=self.am_tf_model_id, audio_in=dataset_path)
|
||||
self.check_result('test_run_with_wav_dataset_tf', rec_result)
|
||||
for item in self.all_models_info:
|
||||
model_id = item['model_group'] + '/' + item['model_id']
|
||||
wav_path = item['wav_path']
|
||||
rec_result = self.run_pipeline(
|
||||
model_id=model_id, audio_in=wav_path)
|
||||
if rec_result.__contains__(OutputKeys.TEXT):
|
||||
logger.info(ColorCodes.MAGENTA + str(item['model_id']) + ' '
|
||||
+ ColorCodes.YELLOW
|
||||
+ str(rec_result[OutputKeys.TEXT])
|
||||
+ ColorCodes.END)
|
||||
else:
|
||||
logger.info(ColorCodes.MAGENTA + str(rec_result)
|
||||
+ ColorCodes.END)
|
||||
|
||||
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
|
||||
def test_demo_compatibility(self):
|
||||
|
||||
@@ -26,6 +26,20 @@ class TranslationTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
pipeline_ins = pipeline(self.task, model=model_id)
|
||||
print(pipeline_ins(input=inputs))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_model_name_for_en2fr(self):
|
||||
model_id = 'damo/nlp_csanmt_translation_en2fr'
|
||||
inputs = 'When I was in my 20s, I saw my very first psychotherapy client.'
|
||||
pipeline_ins = pipeline(self.task, model=model_id)
|
||||
print(pipeline_ins(input=inputs))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_model_name_for_fr2en(self):
|
||||
model_id = 'damo/nlp_csanmt_translation_fr2en'
|
||||
inputs = "Quand j'avais la vingtaine, j'ai vu mes tout premiers clients comme psychothérapeute."
|
||||
pipeline_ins = pipeline(self.task, model=model_id)
|
||||
print(pipeline_ins(input=inputs))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_with_default_model(self):
|
||||
inputs = '声明补充说,沃伦的同事都深感震惊,并且希望他能够投案自首。'
|
||||
|
||||
@@ -4,22 +4,45 @@ import unittest
|
||||
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class TinynasObjectDetectionTest(unittest.TestCase):
|
||||
class TinynasObjectDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.task = Tasks.image_object_detection
|
||||
self.model_id = 'damo/cv_tinynas_object-detection_damoyolo'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run(self):
|
||||
def test_run_airdet(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.image_object_detection, model='damo/cv_tinynas_detection')
|
||||
result = tinynas_object_detection(
|
||||
'data/test/images/image_detection.jpg')
|
||||
print(result)
|
||||
|
||||
@unittest.skip('will be enabled after damoyolo officially released')
|
||||
def test_run_damoyolo(self):
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.image_object_detection,
|
||||
model='damo/cv_tinynas_object-detection_damoyolo')
|
||||
result = tinynas_object_detection(
|
||||
'data/test/images/image_detection.jpg')
|
||||
print(result)
|
||||
|
||||
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
|
||||
def test_demo_compatibility(self):
|
||||
self.test_demo()
|
||||
self.compatibility_check()
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_image_object_detection_auto_pipeline(self):
|
||||
test_image = 'data/test/images/image_detection.jpg'
|
||||
tinynas_object_detection = pipeline(
|
||||
Tasks.image_object_detection, model='damo/cv_tinynas_detection')
|
||||
result = tinynas_object_detection(test_image)
|
||||
tinynas_object_detection.show_result(test_image, result,
|
||||
'demo_ret.jpg')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.utils.constant import DownloadMode, LogKeys, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), 'cuda unittest')
|
||||
class EasyCVTrainerTestFace2DKeypoints(unittest.TestCase):
|
||||
model_id = 'damo/cv_mobilenet_face-2d-keypoints_alignment'
|
||||
|
||||
def setUp(self):
|
||||
self.logger = get_logger()
|
||||
self.logger.info(('Testing %s.%s' %
|
||||
(type(self).__name__, self._testMethodName)))
|
||||
|
||||
def _train(self, tmp_dir):
|
||||
cfg_options = {'train.max_epochs': 2}
|
||||
|
||||
trainer_name = Trainers.easycv
|
||||
|
||||
train_dataset = MsDataset.load(
|
||||
dataset_name='face_2d_keypoints_dataset',
|
||||
namespace='modelscope',
|
||||
split='train',
|
||||
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
|
||||
eval_dataset = MsDataset.load(
|
||||
dataset_name='face_2d_keypoints_dataset',
|
||||
namespace='modelscope',
|
||||
split='train',
|
||||
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)
|
||||
|
||||
kwargs = dict(
|
||||
model=self.model_id,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
work_dir=tmp_dir,
|
||||
cfg_options=cfg_options)
|
||||
|
||||
trainer = build_trainer(trainer_name, kwargs)
|
||||
trainer.train()
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_trainer_single_gpu(self):
|
||||
temp_file_dir = tempfile.TemporaryDirectory()
|
||||
tmp_dir = temp_file_dir.name
|
||||
if not os.path.exists(tmp_dir):
|
||||
os.makedirs(tmp_dir)
|
||||
|
||||
self._train(tmp_dir)
|
||||
|
||||
results_files = os.listdir(tmp_dir)
|
||||
json_files = glob.glob(os.path.join(tmp_dir, '*.log.json'))
|
||||
self.assertEqual(len(json_files), 1)
|
||||
self.assertIn(f'{LogKeys.EPOCH}_2.pth', results_files)
|
||||
|
||||
temp_file_dir.cleanup()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -16,7 +16,8 @@ from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \
|
||||
calculate_fisher
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.data_utils import to_device
|
||||
from modelscope.utils.regress_test_utils import MsRegressTool
|
||||
from modelscope.utils.regress_test_utils import (MsRegressTool,
|
||||
compare_arguments_nested)
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
@@ -41,6 +42,33 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
|
||||
def test_trainer_repeatable(self):
|
||||
import torch # noqa
|
||||
|
||||
def compare_fn(value1, value2, key, type):
|
||||
# Ignore the differences between optimizers of two torch versions
|
||||
if type != 'optimizer':
|
||||
return None
|
||||
|
||||
match = (value1['type'] == value2['type'])
|
||||
shared_defaults = set(value1['defaults'].keys()).intersection(
|
||||
set(value2['defaults'].keys()))
|
||||
match = all([
|
||||
compare_arguments_nested(f'Optimizer defaults {key} not match',
|
||||
value1['defaults'][key],
|
||||
value2['defaults'][key])
|
||||
for key in shared_defaults
|
||||
]) and match
|
||||
match = (len(value1['state_dict']['param_groups']) == len(
|
||||
value2['state_dict']['param_groups'])) and match
|
||||
for group1, group2 in zip(value1['state_dict']['param_groups'],
|
||||
value2['state_dict']['param_groups']):
|
||||
shared_keys = set(group1.keys()).intersection(
|
||||
set(group2.keys()))
|
||||
match = all([
|
||||
compare_arguments_nested(
|
||||
f'Optimizer param_groups {key} not match', group1[key],
|
||||
group2[key]) for key in shared_keys
|
||||
]) and match
|
||||
return match
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
cfg.task = 'nli'
|
||||
cfg['preprocessor'] = {'type': 'nli-tokenizer'}
|
||||
@@ -98,7 +126,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
|
||||
name=Trainers.nlp_base_trainer, default_args=kwargs)
|
||||
|
||||
with self.regress_tool.monitor_ms_train(
|
||||
trainer, 'sbert-base-tnews', level='strict'):
|
||||
trainer, 'sbert-base-tnews', level='strict',
|
||||
compare_fn=compare_fn):
|
||||
trainer.train()
|
||||
|
||||
def finetune(self,
|
||||
|
||||
@@ -51,7 +51,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase):
|
||||
shutil.rmtree(self.tmp_dir, ignore_errors=True)
|
||||
super().tearDown()
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_trainer(self):
|
||||
kwargs = dict(
|
||||
model=self.model_id,
|
||||
@@ -65,7 +65,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase):
|
||||
for i in range(2):
|
||||
self.assertIn(f'epoch_{i+1}.pth', results_files)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_trainer_with_model_and_args(self):
|
||||
model = NAFNetForImageDenoise.from_pretrained(self.cache_path)
|
||||
kwargs = dict(
|
||||
|
||||
@@ -29,7 +29,8 @@ class TestTrainerWithNlp(unittest.TestCase):
|
||||
os.makedirs(self.tmp_dir)
|
||||
|
||||
self.dataset = MsDataset.load(
|
||||
'afqmc_small', namespace='userxiaoming', split='train')
|
||||
'clue', subset_name='afqmc',
|
||||
split='train').to_hf_dataset().select(range(2))
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
@@ -73,7 +74,7 @@ class TestTrainerWithNlp(unittest.TestCase):
|
||||
output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR)
|
||||
pipeline_sentence_similarity(output_dir)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
|
||||
def test_trainer_with_backbone_head(self):
|
||||
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
|
||||
kwargs = dict(
|
||||
@@ -99,6 +100,8 @@ class TestTrainerWithNlp(unittest.TestCase):
|
||||
model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
|
||||
cfg = read_config(model_id, revision='beta')
|
||||
cfg.train.max_epochs = 20
|
||||
cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1}
|
||||
cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1}
|
||||
cfg.train.work_dir = self.tmp_dir
|
||||
cfg_file = os.path.join(self.tmp_dir, 'config.json')
|
||||
cfg.dump(cfg_file)
|
||||
@@ -120,22 +123,24 @@ class TestTrainerWithNlp(unittest.TestCase):
|
||||
checkpoint_path=os.path.join(self.tmp_dir, 'epoch_10.pth'))
|
||||
self.assertTrue(Metrics.accuracy in eval_results)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_trainer_with_configured_datasets(self):
|
||||
model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
|
||||
cfg: Config = read_config(model_id)
|
||||
cfg.train.max_epochs = 20
|
||||
cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1}
|
||||
cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1}
|
||||
cfg.train.work_dir = self.tmp_dir
|
||||
cfg.dataset = {
|
||||
'train': {
|
||||
'name': 'afqmc_small',
|
||||
'name': 'clue',
|
||||
'subset_name': 'afqmc',
|
||||
'split': 'train',
|
||||
'namespace': 'userxiaoming'
|
||||
},
|
||||
'val': {
|
||||
'name': 'afqmc_small',
|
||||
'name': 'clue',
|
||||
'subset_name': 'afqmc',
|
||||
'split': 'train',
|
||||
'namespace': 'userxiaoming'
|
||||
},
|
||||
}
|
||||
cfg_file = os.path.join(self.tmp_dir, 'config.json')
|
||||
@@ -159,6 +164,11 @@ class TestTrainerWithNlp(unittest.TestCase):
|
||||
model_id = 'damo/nlp_structbert_sentence-similarity_chinese-base'
|
||||
cfg: Config = read_config(model_id)
|
||||
cfg.train.max_epochs = 3
|
||||
cfg.preprocessor.first_sequence = 'sentence1'
|
||||
cfg.preprocessor.second_sequence = 'sentence2'
|
||||
cfg.preprocessor.label = 'label'
|
||||
cfg.preprocessor.train['label2id'] = {'0': 0, '1': 1}
|
||||
cfg.preprocessor.val['label2id'] = {'0': 0, '1': 1}
|
||||
cfg.train.work_dir = self.tmp_dir
|
||||
cfg_file = os.path.join(self.tmp_dir, 'config.json')
|
||||
cfg.dump(cfg_file)
|
||||
|
||||
19
tests/utils/test_compatibility.py
Normal file
19
tests/utils/test_compatibility.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
class CompatibilityTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
def test_xtcocotools(self):
|
||||
from xtcocotools.coco import COCO
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user