fix error report (#868)

Co-authored-by: mulin.lyh <mulin.lyh@taobao.com>
This commit is contained in:
liuyhwangyh
2024-05-28 14:38:19 +08:00
committed by GitHub
parent f93a184d88
commit 17da5e2264
6 changed files with 74 additions and 316 deletions

View File

@@ -404,7 +404,7 @@ class HubApi:
(owner_or_group, page_number, page_size),
cookies=cookies,
headers=self.builder_headers(self.headers))
handle_http_response(r, logger, cookies, 'list_model')
handle_http_response(r, logger, cookies, owner_or_group)
if r.status_code == HTTPStatus.OK:
if is_ok(r.json()):
data = r.json()[API_RESPONSE_FIELD_DATA]

View File

@@ -87,16 +87,34 @@ def handle_http_post_error(response, url, request_body):
def handle_http_response(response: requests.Response, logger, cookies,
model_id):
http_error_msg = ''
if isinstance(response.reason, bytes):
try:
response.raise_for_status()
except HTTPError as error:
if cookies is None: # code in [403] and
logger.error(
f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \
private. Please login first.')
message = _decode_response_error(response)
raise HTTPError('Response details: %s, Request id: %s' %
(message, get_request_id(response))) from error
reason = response.reason.decode('utf-8')
except UnicodeDecodeError:
reason = response.reason.decode('iso-8859-1')
else:
reason = response.reason
request_id = get_request_id(response)
if 404 == response.status_code:
http_error_msg = 'The request model: %s does not exist!' % (model_id)
elif 403 == response.status_code:
if cookies is None:
http_error_msg = 'Authentication token does not exist, '
'failed to access model {model_id} which may not exist or may be '
'private. Please login first.'
else:
http_error_msg = 'The authentication token is invalid, failed to access model {model_id}.'
elif 400 <= response.status_code < 500:
http_error_msg = u'%s Client Error: %s, Request id: %s for url: %s' % (
response.status_code, reason, request_id, response.url)
elif 500 <= response.status_code < 600:
http_error_msg = u'%s Server Error: %s, Request id: %s, for url: %s' % (
response.status_code, reason, request_id, response.url)
if http_error_msg: # there is error.
logger.error(http_error_msg)
raise HTTPError(http_error_msg, response=response)
def raise_on_error(rsp):
@@ -160,7 +178,12 @@ def raise_for_http_status(rsp):
else:
reason = rsp.reason
request_id = get_request_id(rsp)
if 400 <= rsp.status_code < 500:
if 404 == rsp.status_code:
http_error_msg = 'The request resource(model or dataset) does not exist!,'
'url: %s, reason: %s' % (rsp.url, reason)
elif 403 == rsp.status_code:
http_error_msg = 'Authentication token does not exist or invalid.'
elif 400 <= rsp.status_code < 500:
http_error_msg = u'%s Client Error: %s, Request id: %s for url: %s' % (
rsp.status_code, reason, request_id, rsp.url)

View File

@@ -43,7 +43,8 @@ def snapshot_download(
model_id (str): A user or an organization name and a repo name separated by a `/`.
revision (str, optional): An optional Git revision id which can be a branch name, a tag, or a
commit hash. NOTE: currently only branch and tag name is supported
cache_dir (str, Path, optional): Path to the folder where cached files are stored.
cache_dir (str, Path, optional): Path to the folder where cached files are stored, model will
be save as cache_dir/model_id/THE_MODEL_FILES.
user_agent (str, dict, optional): The user-agent info in the form of a dictionary or a string.
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
local cached file if it exists.

View File

@@ -1,303 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class Conv1d_O(nn.Module):
def __init__(
self,
out_channels,
kernel_size,
input_shape=None,
in_channels=None,
stride=1,
dilation=1,
padding='same',
groups=1,
bias=True,
padding_mode='reflect',
skip_transpose=False,
):
super().__init__()
self.kernel_size = kernel_size
self.stride = stride
self.dilation = dilation
self.padding = padding
self.padding_mode = padding_mode
self.unsqueeze = False
self.skip_transpose = skip_transpose
if input_shape is None and in_channels is None:
raise ValueError('Must provide one of input_shape or in_channels')
if in_channels is None:
in_channels = self._check_input_shape(input_shape)
self.conv = nn.Conv1d(
in_channels,
out_channels,
self.kernel_size,
stride=self.stride,
dilation=self.dilation,
padding=0,
groups=groups,
bias=bias,
)
def forward(self, x):
"""Returns the output of the convolution.
Arguments
---------
x : torch.Tensor (batch, time, channel)
input to convolve. 2d or 4d tensors are expected.
"""
if not self.skip_transpose:
x = x.transpose(1, -1)
if self.unsqueeze:
x = x.unsqueeze(1)
if self.padding == 'same':
x = self._manage_padding(x, self.kernel_size, self.dilation,
self.stride)
elif self.padding == 'causal':
num_pad = (self.kernel_size - 1) * self.dilation
x = F.pad(x, (num_pad, 0))
elif self.padding == 'valid':
pass
else:
raise ValueError(
"Padding must be 'same', 'valid' or 'causal'. Got "
+ self.padding)
wx = self.conv(x)
if self.unsqueeze:
wx = wx.squeeze(1)
if not self.skip_transpose:
wx = wx.transpose(1, -1)
return wx
def _manage_padding(
self,
x,
kernel_size: int,
dilation: int,
stride: int,
):
# Detecting input shape
L_in = x.shape[-1]
# Time padding
padding = get_padding_elem(L_in, stride, kernel_size, dilation)
# Applying padding
x = F.pad(x, padding, mode=self.padding_mode)
return x
def _check_input_shape(self, shape):
"""Checks the input shape and returns the number of input channels.
"""
if len(shape) == 2:
self.unsqueeze = True
in_channels = 1
elif self.skip_transpose:
in_channels = shape[1]
elif len(shape) == 3:
in_channels = shape[2]
else:
raise ValueError('conv1d expects 2d, 3d inputs. Got '
+ str(len(shape)))
# Kernel size must be odd
if self.kernel_size % 2 == 0:
raise ValueError(
'The field kernel size must be an odd number. Got %s.' %
(self.kernel_size))
return in_channels
# Skip transpose as much as possible for efficiency
class Conv1d(Conv1d_O):
def __init__(self, *args, **kwargs):
super().__init__(skip_transpose=True, *args, **kwargs)
def get_padding_elem(L_in: int, stride: int, kernel_size: int, dilation: int):
"""This function computes the number of elements to add for zero-padding.
Arguments
---------
L_in : int
stride: int
kernel_size : int
dilation : int
"""
if stride > 1:
n_steps = math.ceil(((L_in - kernel_size * dilation) / stride) + 1)
L_out = stride * (n_steps - 1) + kernel_size * dilation
padding = [kernel_size // 2, kernel_size // 2]
else:
L_out = (L_in - dilation * (kernel_size - 1) - 1) // stride + 1
padding = [(L_in - L_out) // 2, (L_in - L_out) // 2]
return padding
class BatchNorm1d_O(nn.Module):
def __init__(
self,
input_shape=None,
input_size=None,
eps=1e-05,
momentum=0.1,
affine=True,
track_running_stats=True,
combine_batch_time=False,
skip_transpose=False,
):
super().__init__()
self.combine_batch_time = combine_batch_time
self.skip_transpose = skip_transpose
if input_size is None and skip_transpose:
input_size = input_shape[1]
elif input_size is None:
input_size = input_shape[-1]
self.norm = nn.BatchNorm1d(
input_size,
eps=eps,
momentum=momentum,
affine=affine,
track_running_stats=track_running_stats,
)
def forward(self, x):
"""Returns the normalized input tensor.
Arguments
---------
x : torch.Tensor (batch, time, [channels])
input to normalize. 2d or 3d tensors are expected in input
4d tensors can be used when combine_dims=True.
"""
shape_or = x.shape
if self.combine_batch_time:
if x.ndim == 3:
x = x.reshape(shape_or[0] * shape_or[1], shape_or[2])
else:
x = x.reshape(shape_or[0] * shape_or[1], shape_or[3],
shape_or[2])
elif not self.skip_transpose:
x = x.transpose(-1, 1)
x_n = self.norm(x)
if self.combine_batch_time:
x_n = x_n.reshape(shape_or)
elif not self.skip_transpose:
x_n = x_n.transpose(1, -1)
return x_n
class BatchNorm1d(BatchNorm1d_O):
def __init__(self, *args, **kwargs):
super().__init__(skip_transpose=True, *args, **kwargs)
class Xvector(torch.nn.Module):
"""This model extracts X-vectors for speaker recognition and diarization.
Arguments
---------
device : str
Device used e.g. "cpu" or "cuda".
activation : torch class
A class for constructing the activation layers.
tdnn_blocks : int
Number of time-delay neural (TDNN) layers.
tdnn_channels : list of ints
Output channels for TDNN layer.
tdnn_kernel_sizes : list of ints
List of kernel sizes for each TDNN layer.
tdnn_dilations : list of ints
List of dilations for kernels in each TDNN layer.
lin_neurons : int
Number of neurons in linear layers.
Example
-------
>>> compute_xvect = Xvector('cpu')
>>> input_feats = torch.rand([5, 10, 40])
>>> outputs = compute_xvect(input_feats)
>>> outputs.shape
torch.Size([5, 1, 512])
"""
def __init__(
self,
device='cpu',
activation=torch.nn.LeakyReLU,
tdnn_blocks=5,
tdnn_channels=[512, 512, 512, 512, 1500],
tdnn_kernel_sizes=[5, 3, 3, 1, 1],
tdnn_dilations=[1, 2, 3, 1, 1],
lin_neurons=512,
in_channels=80,
):
super().__init__()
self.blocks = nn.ModuleList()
# TDNN layers
for block_index in range(tdnn_blocks):
out_channels = tdnn_channels[block_index]
self.blocks.extend([
Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=tdnn_kernel_sizes[block_index],
dilation=tdnn_dilations[block_index],
),
activation(),
BatchNorm1d(input_size=out_channels),
])
in_channels = tdnn_channels[block_index]
def forward(self, x, lens=None):
"""Returns the x-vectors.
Arguments
---------
x : torch.Tensor
"""
x = x.transpose(1, 2)
for layer in self.blocks:
try:
x = layer(x, lengths=lens)
except TypeError:
x = layer(x)
x = x.transpose(1, 2)
return x

View File

@@ -1,8 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import uuid
from pathlib import Path
from shutil import rmtree
import requests
@@ -13,6 +15,7 @@ from modelscope.hub.file_download import model_file_download
from modelscope.hub.repository import Repository
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.utils.constant import ModelFile
from modelscope.utils.file_utils import get_model_cache_dir
from modelscope.utils.test_utils import (TEST_ACCESS_TOKEN1,
TEST_MODEL_CHINESE_NAME,
TEST_MODEL_ORG)
@@ -148,6 +151,40 @@ class HubOperationTest(unittest.TestCase):
data = self.api.list_models(TEST_MODEL_ORG)
assert len(data['Models']) >= 1
def test_snapshot_download_location(self):
self.prepare_case()
snapshot_download_path = snapshot_download(
model_id=self.model_id, revision=self.revision)
assert os.path.exists(snapshot_download_path)
assert '/hub/' in snapshot_download_path
print(snapshot_download_path)
shutil.rmtree(snapshot_download_path)
# download with cache_dir
cache_dir = '/tmp/snapshot_download_cache_test'
snapshot_download_path = snapshot_download(
self.model_id, revision=self.revision, cache_dir=cache_dir)
expect_path = os.path.join(cache_dir, self.model_id)
assert snapshot_download_path == expect_path
assert os.path.exists(
os.path.join(snapshot_download_path, ModelFile.README))
shutil.rmtree(cache_dir)
# download with local_dir
local_dir = '/tmp/snapshot_download_local_dir'
snapshot_download_path = snapshot_download(
self.model_id, revision=self.revision, local_dir=local_dir)
assert snapshot_download_path == local_dir
assert os.path.exists(os.path.join(local_dir, ModelFile.README))
shutil.rmtree(local_dir)
# download with local_dir and cache dir, with local first.
local_dir = '/tmp/snapshot_download_local_dir'
snapshot_download_path = snapshot_download(
self.model_id,
revision=self.revision,
cache_dir=cache_dir,
local_dir=local_dir)
assert snapshot_download_path == local_dir
assert os.path.exists(os.path.join(local_dir, ModelFile.README))
if __name__ == '__main__':
unittest.main()