mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 08:17:45 +01:00
Merge branch 'master' into release/1.15
This commit is contained in:
@@ -404,7 +404,7 @@ class HubApi:
|
||||
(owner_or_group, page_number, page_size),
|
||||
cookies=cookies,
|
||||
headers=self.builder_headers(self.headers))
|
||||
handle_http_response(r, logger, cookies, 'list_model')
|
||||
handle_http_response(r, logger, cookies, owner_or_group)
|
||||
if r.status_code == HTTPStatus.OK:
|
||||
if is_ok(r.json()):
|
||||
data = r.json()[API_RESPONSE_FIELD_DATA]
|
||||
|
||||
@@ -87,16 +87,34 @@ def handle_http_post_error(response, url, request_body):
|
||||
|
||||
def handle_http_response(response: requests.Response, logger, cookies,
                         model_id):
    """Raise ``HTTPError`` with a readable message if ``response`` failed.

    The status code selects the message:

    * 404: the requested model does not exist.
    * 403: no authentication token was sent (``cookies`` is None), or the
      token that was sent is invalid.
    * any other 4xx: generic client error.
    * 5xx: generic server error.

    Every other status code returns normally without raising.

    Args:
        response (requests.Response): The response to inspect.
        logger: Object whose ``error`` method receives the message right
            before the exception is raised.
        cookies: Cookies that were sent with the request; ``None`` is
            reported as a missing authentication token.
        model_id: Identifier interpolated into the 403/404 messages (some
            callers pass an owner or group name instead of a model id).

    Raises:
        HTTPError: If the status code is in the 4xx or 5xx range. The
            response is attached to the exception.
    """
    reason = response.reason
    if isinstance(reason, bytes):
        # The reason phrase may arrive as raw bytes: prefer utf-8 and fall
        # back to iso-8859-1, which can decode any byte sequence.
        try:
            reason = reason.decode('utf-8')
        except UnicodeDecodeError:
            reason = reason.decode('iso-8859-1')
    request_id = get_request_id(response)
    status_code = response.status_code

    http_error_msg = ''
    if status_code == 404:
        http_error_msg = f'The request model: {model_id} does not exist!'
    elif status_code == 403:
        if cookies is None:
            # Parenthesized so the three pieces form ONE string; written as
            # separate statements the trailing pieces are silently dropped.
            http_error_msg = (
                'Authentication token does not exist, '
                f'failed to access model {model_id} which may not exist or may be '
                'private. Please login first.')
        else:
            http_error_msg = (
                'The authentication token is invalid, '
                f'failed to access model {model_id}.')
    elif 400 <= status_code < 500:
        http_error_msg = u'%s Client Error: %s, Request id: %s for url: %s' % (
            status_code, reason, request_id, response.url)
    elif 500 <= status_code < 600:
        http_error_msg = u'%s Server Error: %s, Request id: %s, for url: %s' % (
            status_code, reason, request_id, response.url)

    if http_error_msg:  # there is error.
        logger.error(http_error_msg)
        raise HTTPError(http_error_msg, response=response)
|
||||
|
||||
|
||||
def raise_on_error(rsp):
|
||||
@@ -160,7 +178,12 @@ def raise_for_http_status(rsp):
|
||||
else:
|
||||
reason = rsp.reason
|
||||
request_id = get_request_id(rsp)
|
||||
if 400 <= rsp.status_code < 500:
|
||||
if 404 == rsp.status_code:
|
||||
http_error_msg = 'The request resource(model or dataset) does not exist!,'
|
||||
'url: %s, reason: %s' % (rsp.url, reason)
|
||||
elif 403 == rsp.status_code:
|
||||
http_error_msg = 'Authentication token does not exist or invalid.'
|
||||
elif 400 <= rsp.status_code < 500:
|
||||
http_error_msg = u'%s Client Error: %s, Request id: %s for url: %s' % (
|
||||
rsp.status_code, reason, request_id, rsp.url)
|
||||
|
||||
|
||||
@@ -43,7 +43,8 @@ def snapshot_download(
|
||||
model_id (str): A user or an organization name and a repo name separated by a `/`.
|
||||
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
|
||||
commit hash. NOTE: currently only branch and tag name is supported
|
||||
cache_dir (str, Path, optional): Path to the folder where cached files are stored.
|
||||
cache_dir (str, Path, optional): Path to the folder where cached files are stored; the model will
|
||||
be saved as cache_dir/model_id/THE_MODEL_FILES.
|
||||
user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
|
||||
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
|
||||
local cached file if it exists.
|
||||
|
||||
153
modelscope/models/audio/sv/xvector.py
Normal file
153
modelscope/models/audio/sv/xvector.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
"""
|
||||
This TDNN implementation is adapted from https://github.com/wenet-e2e/wespeaker.
|
||||
TDNN replaces i-vectors for text-independent speaker verification with embeddings
|
||||
extracted from a feedforward deep neural network. The specific structure can be
|
||||
referred to in https://www.danielpovey.com/files/2017_interspeech_embeddings.pdf.
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio.compliance.kaldi as Kaldi
|
||||
|
||||
import modelscope.models.audio.sv.pooling_layers as pooling_layers
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models import MODELS, TorchModel
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.device import create_device
|
||||
|
||||
|
||||
class TdnnLayer(nn.Module):
    """One TDNN layer: 1-D convolution, then ReLU, then batch normalization."""

    def __init__(self, in_dim, out_dim, context_size, dilation=1, padding=0):
        """Define the TDNN layer, essentially 1-D convolution

        Args:
            in_dim (int): input dimension
            out_dim (int): output channels
            context_size (int): context size, essentially the filter size
            dilation (int, optional): Defaults to 1.
            padding (int, optional): Defaults to 0.
        """
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.context_size = context_size
        self.dilation = dilation
        self.padding = padding
        self.conv_1d = nn.Conv1d(
            self.in_dim,
            self.out_dim,
            self.context_size,
            dilation=self.dilation,
            padding=self.padding)

        # Set Affine=false to be compatible with the original kaldi version
        self.bn = nn.BatchNorm1d(out_dim, affine=False)

    def forward(self, x):
        """Apply convolution, ReLU and batch norm (in that order) to ``x``.

        Args:
            x (torch.Tensor): Input accepted by ``nn.Conv1d`` with
                ``in_dim`` channels.

        Returns:
            torch.Tensor: The normalized activations with ``out_dim``
            channels.
        """
        out = self.conv_1d(x)
        out = F.relu(out)
        out = self.bn(out)
        return out
|
||||
|
||||
|
||||
class XVEC(nn.Module):
    """Kaldi style x-vector embedding network.

    Five frame-level ``TdnnLayer`` blocks are followed by a pooling layer
    looked up by name in ``pooling_layers`` and one linear layer that
    produces the embedding.
    """

    def __init__(self,
                 feat_dim=40,
                 hid_dim=512,
                 stats_dim=1500,
                 embed_dim=512,
                 pooling_func='TSTP'):
        """
        Implementation of Kaldi style xvec, as described in
        X-VECTORS: ROBUST DNN EMBEDDINGS FOR SPEAKER RECOGNITION

        Args:
            feat_dim (int): Input feature dimension. Defaults to 40.
            hid_dim (int): Channels of the first four TDNN layers.
                Defaults to 512.
            stats_dim (int): Channels of the last TDNN layer, i.e. the
                pooling layer's input dimension. Defaults to 1500.
            embed_dim (int): Output embedding dimension. Defaults to 512.
            pooling_func (str): Name of the class in ``pooling_layers``
                used for pooling. Defaults to 'TSTP'.
        """
        super().__init__()
        self.feat_dim = feat_dim
        self.stats_dim = stats_dim
        self.embed_dim = embed_dim

        self.frame_1 = TdnnLayer(feat_dim, hid_dim, context_size=5, dilation=1)
        self.frame_2 = TdnnLayer(hid_dim, hid_dim, context_size=3, dilation=2)
        self.frame_3 = TdnnLayer(hid_dim, hid_dim, context_size=3, dilation=3)
        self.frame_4 = TdnnLayer(hid_dim, hid_dim, context_size=1, dilation=1)
        self.frame_5 = TdnnLayer(
            hid_dim, stats_dim, context_size=1, dilation=1)
        # 'TAP' and 'TSDP' are sized as one statistic per channel, every
        # other pooling name as two; this fixes the input width of seg_1.
        self.n_stats = 1 if pooling_func in ('TAP', 'TSDP') else 2
        self.pool = getattr(pooling_layers, pooling_func)(
            in_dim=self.stats_dim)
        self.seg_1 = nn.Linear(self.stats_dim * self.n_stats, embed_dim)

    def forward(self, x):
        """Compute the embedding for a batch of feature sequences.

        Args:
            x (torch.Tensor): Features laid out as (B, T, F).

        Returns:
            torch.Tensor: Output of the final linear layer.
        """
        x = x.permute(0, 2, 1)  # (B,T,F) -> (B,F,T)

        out = self.frame_1(x)
        out = self.frame_2(out)
        out = self.frame_3(out)
        out = self.frame_4(out)
        out = self.frame_5(out)

        stats = self.pool(out)
        embed_a = self.seg_1(stats)
        return embed_a
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.speaker_verification, module_name=Models.tdnn_sv)
class SpeakerVerificationTDNN(TorchModel):
    """Speaker-verification model wrapping an ``XVEC`` embedding network.

    Raw audio is converted to mean-normalized fbank features and passed
    through the pretrained network; ``forward`` returns the embeddings on
    the CPU.
    """

    def __init__(self, model_dir, model_config: Dict[str, Any], *args,
                 **kwargs):
        """Build the network, load its checkpoint and switch to eval mode.

        Args:
            model_dir: Model directory, forwarded to ``TorchModel``; the
                checkpoint is looked up relative to ``self.model_dir``.
            model_config (Dict[str, Any]): Model configuration, stored as
                ``self.model_config``.
            **kwargs: Must contain ``device`` (passed to ``create_device``)
                and ``pretrained_model`` (checkpoint file name).

        Raises:
            KeyError: If ``device`` or ``pretrained_model`` is missing
                from ``kwargs``.
        """
        super().__init__(model_dir, model_config, *args, **kwargs)
        self.model_config = model_config
        self.other_config = kwargs

        self.feature_dim = 80
        self.embed_dim = 512
        self.device = create_device(self.other_config['device'])

        self.embedding_model = XVEC(
            feat_dim=self.feature_dim, embed_dim=self.embed_dim)
        pretrained_model_name = kwargs['pretrained_model']
        self.__load_check_point(pretrained_model_name)

        self.embedding_model.to(self.device)
        self.embedding_model.eval()

    def forward(self, audio):
        """Return speaker embeddings for ``audio``.

        Args:
            audio (np.ndarray or torch.Tensor): Waveform(s) shaped [T] or
                [N, T]; a 1-D input is treated as a batch of one.

        Returns:
            torch.Tensor: Detached embeddings moved to the CPU.
        """
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        if len(audio.shape) == 1:
            audio = audio.unsqueeze(0)
        assert len(
            audio.shape
        ) == 2, 'modelscope error: the shape of input audio to model needs to be [N, T]'
        # audio shape: [N, T]
        feature = self.__extract_feature(audio)
        embedding = self.embedding_model(feature.to(self.device))

        return embedding.detach().cpu()

    def __extract_feature(self, audio):
        """Compute per-utterance fbank features with the mean removed.

        Each row of ``audio`` is converted separately, the mean over the
        first feature axis is subtracted, and the results are concatenated
        along a new leading batch axis.
        """
        features = []
        for au in audio:
            feature = Kaldi.fbank(
                au.unsqueeze(0), num_mel_bins=self.feature_dim)
            feature = feature - feature.mean(dim=0, keepdim=True)
            features.append(feature.unsqueeze(0))
        features = torch.cat(features)
        return features

    def __load_check_point(self, pretrained_model_name):
        """Load ``pretrained_model_name`` from the model dir into the network.

        Weights are mapped to the CPU first; ``strict=True`` makes any
        missing or unexpected key an error.
        """
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be loaded here.
        self.embedding_model.load_state_dict(
            torch.load(
                os.path.join(self.model_dir, pretrained_model_name),
                map_location=torch.device('cpu')),
            strict=True)
|
||||
@@ -1,8 +1,10 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
|
||||
import requests
|
||||
@@ -13,6 +15,7 @@ from modelscope.hub.file_download import model_file_download
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.file_utils import get_model_cache_dir
|
||||
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
|
||||
TEST_MODEL_CHINESE_NAME,
|
||||
TEST_MODEL_ORG)
|
||||
@@ -148,6 +151,40 @@ class HubOperationTest(unittest.TestCase):
|
||||
data = self.api.list_models(TEST_MODEL_ORG)
|
||||
assert len(data['Models']) >= 1
|
||||
|
||||
def test_snapshot_download_location(self):
    """Check where ``snapshot_download`` places files for each option.

    Covers the default cache, an explicit ``cache_dir``, an explicit
    ``local_dir``, and both together (``local_dir`` is expected to win).
    """
    self.prepare_case()

    # All scratch paths live under one unique temporary root that is
    # removed even when an assertion fails, instead of fixed /tmp paths
    # that can collide between runs and were left behind on failure.
    tmp_root = tempfile.mkdtemp()
    self.addCleanup(shutil.rmtree, tmp_root, ignore_errors=True)
    cache_dir = os.path.join(tmp_root, 'snapshot_download_cache_test')
    local_dir = os.path.join(tmp_root, 'snapshot_download_local_dir')

    # download to the default location
    snapshot_download_path = snapshot_download(
        model_id=self.model_id, revision=self.revision)
    assert os.path.exists(snapshot_download_path)
    assert '/hub/' in snapshot_download_path
    shutil.rmtree(snapshot_download_path)

    # download with cache_dir
    snapshot_download_path = snapshot_download(
        self.model_id, revision=self.revision, cache_dir=cache_dir)
    expect_path = os.path.join(cache_dir, self.model_id)
    assert snapshot_download_path == expect_path
    assert os.path.exists(
        os.path.join(snapshot_download_path, ModelFile.README))
    shutil.rmtree(cache_dir)

    # download with local_dir
    snapshot_download_path = snapshot_download(
        self.model_id, revision=self.revision, local_dir=local_dir)
    assert snapshot_download_path == local_dir
    assert os.path.exists(os.path.join(local_dir, ModelFile.README))
    shutil.rmtree(local_dir)

    # download with local_dir and cache dir, with local first.
    snapshot_download_path = snapshot_download(
        self.model_id,
        revision=self.revision,
        cache_dir=cache_dir,
        local_dir=local_dir)
    assert snapshot_download_path == local_dir
    assert os.path.exists(os.path.join(local_dir, ModelFile.README))
|
||||
|
||||
|
||||
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user