Merge branch 'master' into release/1.15

This commit is contained in:
mulin.lyh
2024-05-28 14:55:46 +08:00
5 changed files with 227 additions and 13 deletions

View File

@@ -404,7 +404,7 @@ class HubApi:
(owner_or_group, page_number, page_size),
cookies=cookies,
headers=self.builder_headers(self.headers))
handle_http_response(r, logger, cookies, 'list_model')
handle_http_response(r, logger, cookies, owner_or_group)
if r.status_code == HTTPStatus.OK:
if is_ok(r.json()):
data = r.json()[API_RESPONSE_FIELD_DATA]

View File

@@ -87,16 +87,34 @@ def handle_http_post_error(response, url, request_body):
def handle_http_response(response: requests.Response, logger, cookies,
model_id):
try:
response.raise_for_status()
except HTTPError as error:
if cookies is None: # code in [403] and
logger.error(
f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \
private. Please login first.')
message = _decode_response_error(response)
raise HTTPError('Response details: %s, Request id: %s' %
(message, get_request_id(response))) from error
http_error_msg = ''
if isinstance(response.reason, bytes):
try:
reason = response.reason.decode('utf-8')
except UnicodeDecodeError:
reason = response.reason.decode('iso-8859-1')
else:
reason = response.reason
request_id = get_request_id(response)
if 404 == response.status_code:
http_error_msg = 'The request model: %s does not exist!' % (model_id)
elif 403 == response.status_code:
if cookies is None:
http_error_msg = 'Authentication token does not exist, '
'failed to access model {model_id} which may not exist or may be '
'private. Please login first.'
else:
http_error_msg = 'The authentication token is invalid, failed to access model {model_id}.'
elif 400 <= response.status_code < 500:
http_error_msg = u'%s Client Error: %s, Request id: %s for url: %s' % (
response.status_code, reason, request_id, response.url)
elif 500 <= response.status_code < 600:
http_error_msg = u'%s Server Error: %s, Request id: %s, for url: %s' % (
response.status_code, reason, request_id, response.url)
if http_error_msg: # there is error.
logger.error(http_error_msg)
raise HTTPError(http_error_msg, response=response)
def raise_on_error(rsp):
@@ -160,7 +178,12 @@ def raise_for_http_status(rsp):
else:
reason = rsp.reason
request_id = get_request_id(rsp)
if 400 <= rsp.status_code < 500:
if 404 == rsp.status_code:
http_error_msg = 'The request resource(model or dataset) does not exist!,'
'url: %s, reason: %s' % (rsp.url, reason)
elif 403 == rsp.status_code:
http_error_msg = 'Authentication token does not exist or invalid.'
elif 400 <= rsp.status_code < 500:
http_error_msg = u'%s Client Error: %s, Request id: %s for url: %s' % (
rsp.status_code, reason, request_id, rsp.url)

View File

@@ -43,7 +43,8 @@ def snapshot_download(
model_id (str): A user or an organization name and a repo name separated by a `/`.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. NOTE: currently only branch and tag name is supported
cache_dir (str, Path, optional): Path to the folder where cached files are stored.
cache_dir (str, Path, optional): Path to the folder where cached files are stored, model will
be save as cache_dir/model_id/THE_MODEL_FILES.
user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
local cached file if it exists.

View File

@@ -0,0 +1,153 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""
This TDNN implementation is adapted from https://github.com/wenet-e2e/wespeaker.
TDNN replaces i-vectors for text-independent speaker verification with embeddings
extracted from a feedforward deep neural network. The specific structure can be
referred to in https://www.danielpovey.com/files/2017_interspeech_embeddings.pdf.
"""
import math
import os
from typing import Any, Dict, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi
import modelscope.models.audio.sv.pooling_layers as pooling_layers
from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.utils.constant import Tasks
from modelscope.utils.device import create_device
class TdnnLayer(nn.Module):
def __init__(self, in_dim, out_dim, context_size, dilation=1, padding=0):
"""Define the TDNN layer, essentially 1-D convolution
Args:
in_dim (int): input dimension
out_dim (int): output channels
context_size (int): context size, essentially the filter size
dilation (int, optional): Defaults to 1.
padding (int, optional): Defaults to 0.
"""
super(TdnnLayer, self).__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.context_size = context_size
self.dilation = dilation
self.padding = padding
self.conv_1d = nn.Conv1d(
self.in_dim,
self.out_dim,
self.context_size,
dilation=self.dilation,
padding=self.padding)
# Set Affine=false to be compatible with the original kaldi version
self.bn = nn.BatchNorm1d(out_dim, affine=False)
def forward(self, x):
out = self.conv_1d(x)
out = F.relu(out)
out = self.bn(out)
return out
class XVEC(nn.Module):
def __init__(self,
feat_dim=40,
hid_dim=512,
stats_dim=1500,
embed_dim=512,
pooling_func='TSTP'):
"""
Implementation of Kaldi style xvec, as described in
X-VECTORS: ROBUST DNN EMBEDDINGS FOR SPEAKER RECOGNITION
"""
super(XVEC, self).__init__()
self.feat_dim = feat_dim
self.stats_dim = stats_dim
self.embed_dim = embed_dim
self.frame_1 = TdnnLayer(feat_dim, hid_dim, context_size=5, dilation=1)
self.frame_2 = TdnnLayer(hid_dim, hid_dim, context_size=3, dilation=2)
self.frame_3 = TdnnLayer(hid_dim, hid_dim, context_size=3, dilation=3)
self.frame_4 = TdnnLayer(hid_dim, hid_dim, context_size=1, dilation=1)
self.frame_5 = TdnnLayer(
hid_dim, stats_dim, context_size=1, dilation=1)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == 'TSDP' else 2
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=self.stats_dim)
self.seg_1 = nn.Linear(self.stats_dim * self.n_stats, embed_dim)
def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T)
out = self.frame_1(x)
out = self.frame_2(out)
out = self.frame_3(out)
out = self.frame_4(out)
out = self.frame_5(out)
stats = self.pool(out)
embed_a = self.seg_1(stats)
return embed_a
@MODELS.register_module(Tasks.speaker_verification, module_name=Models.tdnn_sv)
class SpeakerVerificationTDNN(TorchModel):
def __init__(self, model_dir, model_config: Dict[str, Any], *args,
**kwargs):
super().__init__(model_dir, model_config, *args, **kwargs)
self.model_config = model_config
self.other_config = kwargs
self.feature_dim = 80
self.embed_dim = 512
self.device = create_device(self.other_config['device'])
print(self.device)
self.embedding_model = XVEC(
feat_dim=self.feature_dim, embed_dim=self.embed_dim)
pretrained_model_name = kwargs['pretrained_model']
self.__load_check_point(pretrained_model_name)
self.embedding_model.to(self.device)
self.embedding_model.eval()
def forward(self, audio):
if isinstance(audio, np.ndarray):
audio = torch.from_numpy(audio)
if len(audio.shape) == 1:
audio = audio.unsqueeze(0)
assert len(
audio.shape
) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
# audio shape: [N, T]
feature = self.__extract_feature(audio)
embedding = self.embedding_model(feature.to(self.device))
return embedding.detach().cpu()
def __extract_feature(self, audio):
features = []
for au in audio:
feature = Kaldi.fbank(
au.unsqueeze(0), num_mel_bins=self.feature_dim)
feature = feature - feature.mean(dim=0, keepdim=True)
features.append(feature.unsqueeze(0))
features = torch.cat(features)
return features
def __load_check_point(self, pretrained_model_name):
self.embedding_model.load_state_dict(
torch.load(
os.path.join(self.model_dir, pretrained_model_name),
map_location=torch.device('cpu')),
strict=True)

View File

@@ -1,8 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import uuid
from pathlib import Path
from shutil import rmtree
import requests
@@ -13,6 +15,7 @@ from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import ModelFile
from modelscope.utils.file_utils import get_model_cache_dir
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG)
@@ -148,6 +151,40 @@ class HubOperationTest(unittest.TestCase):
data = self.api.list_models(TEST_MODEL_ORG)
assert len(data['Models']) >= 1
def test_snapshot_download_location(self):
self.prepare_case()
snapshot_download_path = snapshot_download(
model_id=self.model_id, revision=self.revision)
assert os.path.exists(snapshot_download_path)
assert '/hub/' in snapshot_download_path
print(snapshot_download_path)
shutil.rmtree(snapshot_download_path)
# download with cache_dir
cache_dir = '/tmp/snapshot_download_cache_test'
snapshot_download_path = snapshot_download(
self.model_id, revision=self.revision, cache_dir=cache_dir)
expect_path = os.path.join(cache_dir, self.model_id)
assert snapshot_download_path == expect_path
assert os.path.exists(
os.path.join(snapshot_download_path, ModelFile.README))
shutil.rmtree(cache_dir)
# download with local_dir
local_dir = '/tmp/snapshot_download_local_dir'
snapshot_download_path = snapshot_download(
self.model_id, revision=self.revision, local_dir=local_dir)
assert snapshot_download_path == local_dir
assert os.path.exists(os.path.join(local_dir, ModelFile.README))
shutil.rmtree(local_dir)
# download with local_dir and cache dir, with local first.
local_dir = '/tmp/snapshot_download_local_dir'
snapshot_download_path = snapshot_download(
self.model_id,
revision=self.revision,
cache_dir=cache_dir,
local_dir=local_dir)
assert snapshot_download_path == local_dir
assert os.path.exists(os.path.join(local_dir, ModelFile.README))
if __name__ == '__main__':
unittest.main()