Merge branch 'master' into release/1.15

This commit is contained in:
mulin.lyh
2024-05-28 14:55:46 +08:00
5 changed files with 227 additions and 13 deletions

View File

@@ -404,7 +404,7 @@ class HubApi:
(owner_or_group, page_number, page_size), (owner_or_group, page_number, page_size),
cookies=cookies, cookies=cookies,
headers=self.builder_headers(self.headers)) headers=self.builder_headers(self.headers))
handle_http_response(r, logger, cookies, 'list_model') handle_http_response(r, logger, cookies, owner_or_group)
if r.status_code == HTTPStatus.OK: if r.status_code == HTTPStatus.OK:
if is_ok(r.json()): if is_ok(r.json()):
data = r.json()[API_RESPONSE_FIELD_DATA] data = r.json()[API_RESPONSE_FIELD_DATA]

View File

@@ -87,16 +87,34 @@ def handle_http_post_error(response, url, request_body):
def handle_http_response(response: requests.Response, logger, cookies,
                         model_id):
    """Log and raise an ``HTTPError`` when ``response`` carries a 4xx/5xx status.

    Nothing happens (and ``None`` is returned) for any other status code.

    Args:
        response (requests.Response): The response to inspect.
        logger: Logger whose ``error`` method receives the error message.
        cookies: Cookies that were sent with the request. ``None`` selects the
            "token does not exist" wording for a 403 response; any other value
            selects the "token is invalid" wording.
        model_id: Identifier interpolated into the 404 and 403 messages.

    Raises:
        HTTPError: If ``response.status_code`` is in the range [400, 600).
            The original response is attached to the exception.
    """
    http_error_msg = ''
    # ``reason`` may arrive as bytes; try utf-8 first and fall back to
    # iso-8859-1 when the bytes are not valid utf-8.
    if isinstance(response.reason, bytes):
        try:
            reason = response.reason.decode('utf-8')
        except UnicodeDecodeError:
            reason = response.reason.decode('iso-8859-1')
    else:
        reason = response.reason
    request_id = get_request_id(response)
    status_code = response.status_code
    if status_code == 404:
        http_error_msg = f'The request model: {model_id} does not exist!'
    elif status_code == 403:
        if cookies is None:
            # The pieces are wrapped in parentheses so they form ONE string;
            # as bare consecutive statements only the first piece would be
            # assigned and the rest silently discarded.
            http_error_msg = (
                'Authentication token does not exist, '
                f'failed to access model {model_id} which may not exist or may be '
                'private. Please login first.')
        else:
            http_error_msg = (
                'The authentication token is invalid, '
                f'failed to access model {model_id}.')
    elif 400 <= status_code < 500:
        http_error_msg = u'%s Client Error: %s, Request id: %s for url: %s' % (
            status_code, reason, request_id, response.url)
    elif 500 <= status_code < 600:
        http_error_msg = u'%s Server Error: %s, Request id: %s, for url: %s' % (
            status_code, reason, request_id, response.url)
    if http_error_msg:  # a non-empty message means an error status was seen
        logger.error(http_error_msg)
        raise HTTPError(http_error_msg, response=response)
def raise_on_error(rsp): def raise_on_error(rsp):
@@ -160,7 +178,12 @@ def raise_for_http_status(rsp):
else: else:
reason = rsp.reason reason = rsp.reason
request_id = get_request_id(rsp) request_id = get_request_id(rsp)
if 400 <= rsp.status_code < 500: if 404 == rsp.status_code:
http_error_msg = 'The request resource(model or dataset) does not exist!,'
'url: %s, reason: %s' % (rsp.url, reason)
elif 403 == rsp.status_code:
http_error_msg = 'Authentication token does not exist or invalid.'
elif 400 <= rsp.status_code < 500:
http_error_msg = u'%s Client Error: %s, Request id: %s for url: %s' % ( http_error_msg = u'%s Client Error: %s, Request id: %s for url: %s' % (
rsp.status_code, reason, request_id, rsp.url) rsp.status_code, reason, request_id, rsp.url)

View File

@@ -43,7 +43,8 @@ def snapshot_download(
model_id (str): A user or an organization name and a repo name separated by a `/`. model_id (str): A user or an organization name and a repo name separated by a `/`.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. NOTE: currently only branch and tag name is supported commit hash. NOTE: currently only branch and tag name is supported
cache_dir (str, Path, optional): Path to the folder where cached files are stored. cache_dir (str, Path, optional): Path to the folder where cached files are stored, model will
be saved as cache_dir/model_id/THE_MODEL_FILES.
user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string. user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
local cached file if it exists. local cached file if it exists.

View File

@@ -0,0 +1,153 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""
This TDNN implementation is adapted from https://github.com/wenet-e2e/wespeaker.
TDNN replaces i-vectors for text-independent speaker verification with embeddings
extracted from a feedforward deep neural network. The specific structure can be
referred to in https://www.danielpovey.com/files/2017_interspeech_embeddings.pdf.
"""
import math
import os
from typing import Any, Dict, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi
import modelscope.models.audio.sv.pooling_layers as pooling_layers
from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.utils.constant import Tasks
from modelscope.utils.device import create_device
class TdnnLayer(nn.Module):
    """One TDNN layer: a 1-D convolution followed by ReLU and batch norm."""

    def __init__(self, in_dim, out_dim, context_size, dilation=1, padding=0):
        """Create the convolution and normalisation sub-modules.

        Args:
            in_dim (int): number of input channels of the convolution
            out_dim (int): number of output channels of the convolution
            context_size (int): context size, used as the convolution kernel size
            dilation (int, optional): convolution dilation. Defaults to 1.
            padding (int, optional): convolution padding. Defaults to 0.
        """
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.context_size = context_size
        self.dilation = dilation
        self.padding = padding
        self.conv_1d = nn.Conv1d(
            in_channels=in_dim,
            out_channels=out_dim,
            kernel_size=context_size,
            dilation=dilation,
            padding=padding)
        # affine=False keeps the layer compatible with the original kaldi
        # version (no learnable scale/shift in the normalisation).
        self.bn = nn.BatchNorm1d(out_dim, affine=False)

    def forward(self, x):
        """Apply convolution, then ReLU, then batch normalisation to ``x``."""
        activated = F.relu(self.conv_1d(x))
        return self.bn(activated)
class XVEC(nn.Module):
    """Kaldi style x-vector network: five TDNN layers, pooling, one linear layer."""

    def __init__(self,
                 feat_dim=40,
                 hid_dim=512,
                 stats_dim=1500,
                 embed_dim=512,
                 pooling_func='TSTP'):
        """
        Implementation of Kaldi style xvec, as described in
        X-VECTORS: ROBUST DNN EMBEDDINGS FOR SPEAKER RECOGNITION

        Args:
            feat_dim (int): input channels of the first TDNN layer
            hid_dim (int): output channels of TDNN layers 1-4
            stats_dim (int): output channels of the fifth TDNN layer, also
                passed to the pooling layer as ``in_dim``
            embed_dim (int): output size of the final linear layer
            pooling_func (str): name of the class looked up in
                ``pooling_layers``
        """
        super().__init__()
        self.feat_dim = feat_dim
        self.stats_dim = stats_dim
        self.embed_dim = embed_dim

        # Frame-level layers.
        self.frame_1 = TdnnLayer(feat_dim, hid_dim, context_size=5, dilation=1)
        self.frame_2 = TdnnLayer(hid_dim, hid_dim, context_size=3, dilation=2)
        self.frame_3 = TdnnLayer(hid_dim, hid_dim, context_size=3, dilation=3)
        self.frame_4 = TdnnLayer(hid_dim, hid_dim, context_size=1, dilation=1)
        self.frame_5 = TdnnLayer(
            hid_dim, stats_dim, context_size=1, dilation=1)

        # 'TAP' and 'TSDP' are sized as one stats_dim-wide vector, every other
        # pooling function as two (presumably mean + std - confirm against
        # pooling_layers).
        if pooling_func in ('TAP', 'TSDP'):
            self.n_stats = 1
        else:
            self.n_stats = 2
        pooling_cls = getattr(pooling_layers, pooling_func)
        self.pool = pooling_cls(in_dim=self.stats_dim)

        # Segment-level layer producing the embedding.
        self.seg_1 = nn.Linear(self.stats_dim * self.n_stats, embed_dim)

    def forward(self, x):
        """Run the network on ``x`` and return the output of ``seg_1``."""
        out = x.permute(0, 2, 1)  # (B,T,F) -> (B,F,T)
        frame_layers = (self.frame_1, self.frame_2, self.frame_3,
                        self.frame_4, self.frame_5)
        for frame_layer in frame_layers:
            out = frame_layer(out)
        pooled = self.pool(out)
        return self.seg_1(pooled)
@MODELS.register_module(Tasks.speaker_verification, module_name=Models.tdnn_sv)
class SpeakerVerificationTDNN(TorchModel):
    """Speaker-verification model wrapping an ``XVEC`` embedding network.

    The constructor builds the network, loads the pretrained weights from
    ``model_dir`` and puts the network in eval mode on the requested device.
    """

    def __init__(self, model_dir, model_config: Dict[str, Any], *args,
                 **kwargs):
        """Build the embedding network and load its checkpoint.

        Args:
            model_dir: Passed to the parent constructor; the checkpoint is
                later read relative to ``self.model_dir``.
            model_config (Dict[str, Any]): Stored as ``self.model_config``.
            **kwargs: Must contain ``device`` (passed to ``create_device``)
                and ``pretrained_model`` (checkpoint file name).

        Raises:
            KeyError: If ``device`` or ``pretrained_model`` is missing.
        """
        super().__init__(model_dir, model_config, *args, **kwargs)
        self.model_config = model_config
        self.other_config = kwargs

        self.feature_dim = 80
        self.embed_dim = 512
        self.device = create_device(self.other_config['device'])

        self.embedding_model = XVEC(
            feat_dim=self.feature_dim, embed_dim=self.embed_dim)
        pretrained_model_name = kwargs['pretrained_model']
        self.__load_check_point(pretrained_model_name)

        self.embedding_model.to(self.device)
        self.embedding_model.eval()

    def forward(self, audio):
        """Compute embeddings for a batch of audio.

        Args:
            audio: ``np.ndarray`` or tensor of shape [T] or [N, T]. A 1-D
                input is treated as a batch of one.

        Returns:
            The output of the embedding network, detached and moved to CPU.

        Raises:
            AssertionError: If the input is not 1-D or 2-D.
        """
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        if len(audio.shape) == 1:
            audio = audio.unsqueeze(0)
        assert len(
            audio.shape
        ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
        # audio shape: [N, T]
        # The result is detached right away, so skip building an autograd
        # graph for the forward pass.
        with torch.no_grad():
            feature = self.__extract_feature(audio)
            embedding = self.embedding_model(feature.to(self.device))
        return embedding.detach().cpu()

    def __extract_feature(self, audio):
        """Return mean-normalised fbank features stacked over the batch."""
        features = []
        for au in audio:
            feature = Kaldi.fbank(
                au.unsqueeze(0), num_mel_bins=self.feature_dim)
            # Subtract the mean over dim 0 of each utterance's features.
            feature = feature - feature.mean(dim=0, keepdim=True)
            features.append(feature)
        return torch.stack(features)

    def __load_check_point(self, pretrained_model_name):
        """Load ``pretrained_model_name`` from ``model_dir`` into the network."""
        checkpoint_path = os.path.join(self.model_dir, pretrained_model_name)
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True where torch supports it).
        state_dict = torch.load(
            checkpoint_path, map_location=torch.device('cpu'))
        self.embedding_model.load_state_dict(state_dict, strict=True)

View File

@@ -1,8 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
import os import os
import shutil
import tempfile import tempfile
import unittest import unittest
import uuid import uuid
from pathlib import Path
from shutil import rmtree from shutil import rmtree
import requests import requests
@@ -13,6 +15,7 @@ from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import ModelFile from modelscope.utils.constant import ModelFile
from modelscope.utils.file_utils import get_model_cache_dir
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1, from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
TEST_MODEL_CHINESE_NAME, TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG) TEST_MODEL_ORG)
@@ -148,6 +151,40 @@ class HubOperationTest(unittest.TestCase):
data = self.api.list_models(TEST_MODEL_ORG) data = self.api.list_models(TEST_MODEL_ORG)
assert len(data['Models']) >= 1 assert len(data['Models']) >= 1
def test_snapshot_download_location(self):
    """Check where ``snapshot_download`` places files.

    Covers four cases: default location, explicit ``cache_dir``, explicit
    ``local_dir``, and both together (``local_dir`` is expected to win).
    """
    self.prepare_case()

    tmp_root = tempfile.gettempdir()
    cache_dir = os.path.join(tmp_root, 'snapshot_download_cache_test')
    local_dir = os.path.join(tmp_root, 'snapshot_download_local_dir')
    # Registered up-front so the directories are removed even when an
    # assertion fails, and after the last case, which has no inline cleanup.
    for tmp_path in (cache_dir, local_dir):
        self.addCleanup(shutil.rmtree, tmp_path, ignore_errors=True)

    # download to the default location
    snapshot_download_path = snapshot_download(
        model_id=self.model_id, revision=self.revision)
    self.addCleanup(
        shutil.rmtree, snapshot_download_path, ignore_errors=True)
    assert os.path.exists(snapshot_download_path)
    assert '/hub/' in snapshot_download_path
    print(snapshot_download_path)
    shutil.rmtree(snapshot_download_path)

    # download with cache_dir
    snapshot_download_path = snapshot_download(
        self.model_id, revision=self.revision, cache_dir=cache_dir)
    expect_path = os.path.join(cache_dir, self.model_id)
    assert snapshot_download_path == expect_path
    assert os.path.exists(
        os.path.join(snapshot_download_path, ModelFile.README))
    shutil.rmtree(cache_dir)

    # download with local_dir
    snapshot_download_path = snapshot_download(
        self.model_id, revision=self.revision, local_dir=local_dir)
    assert snapshot_download_path == local_dir
    assert os.path.exists(os.path.join(local_dir, ModelFile.README))
    shutil.rmtree(local_dir)

    # download with local_dir and cache dir, with local first.
    snapshot_download_path = snapshot_download(
        self.model_id,
        revision=self.revision,
        cache_dir=cache_dir,
        local_dir=local_dir)
    assert snapshot_download_path == local_dir
    assert os.path.exists(os.path.join(local_dir, ModelFile.README))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()