Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
[to #42322933] Merge ANS pipeline into master
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9178339

* refactor: move aec models to audio/aec
* refactor: move aec models to audio/aec
* refactor: move aec models to audio/aec
* refactor: move aec models to audio/aec
* feat: add unittest for ANS pipeline
* Merge branch 'master' into dev/ans
* add new SoundFile to audio dependency
* Merge branch 'master' into dev/ans
* use ANS pipeline name from metainfo
* Merge branch 'master' into dev/ans
* chore: update docstring of ANS module
* Merge branch 'master' into dev/ans
* refactor: use names from metainfo
* refactor: enable ans unittest
* refactor: add more log message in unittest
@@ -21,6 +21,7 @@ class Models(object):
     sambert_hifi_16k = 'sambert-hifi-16k'
     generic_tts_frontend = 'generic-tts-frontend'
     hifigan16k = 'hifigan16k'
+    speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
     kws_kwsbp = 'kws-kwsbp'
 
     # multi-modal models

@@ -55,6 +56,7 @@ class Pipelines(object):
     # audio tasks
     sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts'
     speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
+    speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
     kws_kwsbp = 'kws-kwsbp'
 
     # multi-modal tasks
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
+from .audio.ans.frcrn import FRCRNModel
 from .audio.kws import GenericKeyWordSpotting
 from .audio.tts.am import SambertNetHifi16k
 from .audio.tts.vocoder import Hifigan16k
0 modelscope/models/audio/aec/network/__init__.py Normal file
0 modelscope/models/audio/ans/__init__.py Normal file
248 modelscope/models/audio/ans/complex_nn.py Normal file
@@ -0,0 +1,248 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class UniDeepFsmn(nn.Module):

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
        super(UniDeepFsmn, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        if lorder is None:
            return

        self.lorder = lorder
        self.hidden_size = hidden_size

        self.linear = nn.Linear(input_dim, hidden_size)
        self.project = nn.Linear(hidden_size, output_dim, bias=False)
        self.conv1 = nn.Conv2d(
            output_dim,
            output_dim, [lorder, 1], [1, 1],
            groups=output_dim,
            bias=False)

    def forward(self, input):
        r"""
        Args:
            input: tensor with shape batch (b) x sequence (T) x feature (h)

        Returns:
            tensor with the same shape as input: batch (b) x sequence (T) x feature (h)
        """
        f1 = F.relu(self.linear(input))
        p1 = self.project(f1)
        x = torch.unsqueeze(p1, 1)
        # x: batch (b) x channel (c) x sequence (T) x feature (h)
        x_per = x.permute(0, 3, 2, 1)
        # x_per: batch (b) x feature (h) x sequence (T) x channel (c)
        y = F.pad(x_per, [0, 0, self.lorder - 1, 0])
        out = x_per + self.conv1(y)
        out1 = out.permute(0, 3, 2, 1)
        # out1: batch (b) x channel (c) x sequence (T) x feature (h)
        return input + out1.squeeze()


class ComplexUniDeepFsmn(nn.Module):

    def __init__(self, nIn, nHidden=128, nOut=128):
        super(ComplexUniDeepFsmn, self).__init__()

        self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
        self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
        self.fsmn_re_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden)
        self.fsmn_im_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden)

    def forward(self, x):
        r"""
        Args:
            x: tensor with shape [batch, channel, feature, sequence, 2],
                e.g. [6, 256, 1, 106, 2]

        Returns:
            tensor with the same shape as x
        """
        b, c, h, T, d = x.size()
        x = torch.reshape(x, (b, c * h, T, d))
        # x: [b, c*h, T, 2], e.g. [6, 256, 106, 2]
        x = torch.transpose(x, 1, 2)
        # x: [b, T, c*h, 2], e.g. [6, 106, 256, 2]

        real_L1 = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1])
        imaginary_L1 = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0])
        real = self.fsmn_re_L2(real_L1) - self.fsmn_im_L2(imaginary_L1)
        imaginary = self.fsmn_re_L2(imaginary_L1) + self.fsmn_im_L2(real_L1)
        # output: [b, T, c*h, 2]
        output = torch.stack((real, imaginary), dim=-1)

        # output: [b, c*h, T, 2]
        output = torch.transpose(output, 1, 2)
        output = torch.reshape(output, (b, c, h, T, d))

        return output


class ComplexUniDeepFsmn_L1(nn.Module):

    def __init__(self, nIn, nHidden=128, nOut=128):
        super(ComplexUniDeepFsmn_L1, self).__init__()
        self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
        self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)

    def forward(self, x):
        r"""
        Args:
            x: tensor with shape [batch, channel, feature, sequence, 2],
                e.g. [6, 256, 1, 106, 2]

        Returns:
            tensor with the same shape as x
        """
        b, c, h, T, d = x.size()
        # x: [b, T, h, c, 2]
        x = torch.transpose(x, 1, 3)
        x = torch.reshape(x, (b * T, h, c, d))

        real = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1])
        imaginary = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0])
        # output: [b*T, h, c, 2], e.g. [6*106, h, 256, 2]
        output = torch.stack((real, imaginary), dim=-1)

        output = torch.reshape(output, (b, T, h, c, d))
        output = torch.transpose(output, 1, 3)
        return output


class ComplexConv2d(nn.Module):
    # https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 **kwargs):
        super().__init__()

        # Model components
        self.conv_re = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **kwargs)
        self.conv_im = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **kwargs)

    def forward(self, x):
        r"""
        Args:
            x: tensor with shape [batch, channel, axis1, axis2, 2]
        """
        real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1])
        imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0])
        output = torch.stack((real, imaginary), dim=-1)
        return output


class ComplexConvTranspose2d(nn.Module):

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 **kwargs):
        super().__init__()

        # Model components
        self.tconv_re = nn.ConvTranspose2d(
            in_channel,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            **kwargs)
        self.tconv_im = nn.ConvTranspose2d(
            in_channel,
            out_channel,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            **kwargs)

    def forward(self, x):
        # shape of x: [batch, channel, axis1, axis2, 2]
        real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1])
        imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0])
        output = torch.stack((real, imaginary), dim=-1)
        return output


class ComplexBatchNorm2d(nn.Module):

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True,
                 **kwargs):
        super().__init__()
        self.bn_re = nn.BatchNorm2d(
            num_features=num_features,
            momentum=momentum,
            affine=affine,
            eps=eps,
            track_running_stats=track_running_stats,
            **kwargs)
        self.bn_im = nn.BatchNorm2d(
            num_features=num_features,
            momentum=momentum,
            affine=affine,
            eps=eps,
            track_running_stats=track_running_stats,
            **kwargs)

    def forward(self, x):
        real = self.bn_re(x[..., 0])
        imag = self.bn_im(x[..., 1])
        output = torch.stack((real, imag), dim=-1)
        return output
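All of the complex layers above share one convention: the last tensor axis holds the real and imaginary parts, and a complex weight w = c + id is applied to an input x = a + ib via (ac - bd) + i(ad + bc). A minimal sanity check of that convention for ComplexConv2d, with illustrative shapes that are not part of the commit:

import torch

# hypothetical sizes, chosen only to demonstrate the shape contract
conv = ComplexConv2d(in_channel=1, out_channel=4, kernel_size=3, padding=1)
x = torch.randn(2, 1, 16, 16, 2)  # [batch, channel, axis1, axis2, re/im]
y = conv(x)                       # -> [2, 4, 16, 16, 2]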
112 modelscope/models/audio/ans/conv_stft.py Normal file
@@ -0,0 +1,112 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window


def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
    if win_type == 'None' or win_type is None:
        window = np.ones(win_len)
    else:
        window = get_window(win_type, win_len, fftbins=True)**0.5

    N = fft_len
    fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
    real_kernel = np.real(fourier_basis)
    imag_kernel = np.imag(fourier_basis)
    kernel = np.concatenate([real_kernel, imag_kernel], 1).T

    if invers:
        kernel = np.linalg.pinv(kernel).T

    kernel = kernel * window
    kernel = kernel[:, None, :]
    return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(
        window[None, :, None].astype(np.float32))


class ConvSTFT(nn.Module):

    def __init__(self,
                 win_len,
                 win_inc,
                 fft_len=None,
                 win_type='hamming',
                 feature_type='real',
                 fix=True):
        super(ConvSTFT, self).__init__()

        if fft_len is None:
            # next power of two (builtin int: np.int was removed in NumPy 1.24)
            self.fft_len = int(2**np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len

        kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
        self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.feature_type = feature_type
        self.stride = win_inc
        self.win_len = win_len
        self.dim = self.fft_len

    def forward(self, inputs):
        if inputs.dim() == 2:
            inputs = torch.unsqueeze(inputs, 1)

        outputs = F.conv1d(inputs, self.weight, stride=self.stride)

        if self.feature_type == 'complex':
            return outputs
        else:
            dim = self.dim // 2 + 1
            real = outputs[:, :dim, :]
            imag = outputs[:, dim:, :]
            mags = torch.sqrt(real**2 + imag**2)
            phase = torch.atan2(imag, real)
            return mags, phase


class ConviSTFT(nn.Module):

    def __init__(self,
                 win_len,
                 win_inc,
                 fft_len=None,
                 win_type='hamming',
                 feature_type='real',
                 fix=True):
        super(ConviSTFT, self).__init__()
        if fft_len is None:
            self.fft_len = int(2**np.ceil(np.log2(win_len)))
        else:
            self.fft_len = fft_len
        kernel, window = init_kernels(
            win_len, win_inc, self.fft_len, win_type, invers=True)
        self.weight = nn.Parameter(kernel, requires_grad=(not fix))
        self.feature_type = feature_type
        self.win_type = win_type
        self.win_len = win_len
        self.win_inc = win_inc
        self.stride = win_inc
        self.dim = self.fft_len
        self.register_buffer('window', window)
        self.register_buffer('enframe', torch.eye(win_len)[:, None, :])

    def forward(self, inputs, phase=None):
        """
        Args:
            inputs: [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
            phase: [B, N//2+1, T] (if not None)
        """
        if phase is not None:
            real = inputs * torch.cos(phase)
            imag = inputs * torch.sin(phase)
            inputs = torch.cat([real, imag], 1)
        outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)

        # this is from torch-stft: https://github.com/pseeth/torch-stft
        t = self.window.repeat(1, 1, inputs.size(-1))**2
        coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
        outputs = outputs / (coff + 1e-8)
        return outputs
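ConvSTFT and ConviSTFT implement STFT analysis and synthesis as fixed 1-D convolutions against a precomputed Fourier basis, which keeps the whole front end differentiable and exportable. A quick round-trip sketch, using the window settings that FRCRN passes in below (shapes are assumptions for illustration, not part of the commit):

import torch

stft = ConvSTFT(400, 100, 512, win_type='hanning', feature_type='complex')
istft = ConviSTFT(400, 100, 512, win_type='hanning', feature_type='complex')

wav = torch.randn(1, 16000)  # one second of 16 kHz audio
spec = stft(wav)             # [1, 514, 157]: 257 real bins stacked over 257 imaginary bins
recon = istft(spec)          # [1, 1, 16000]: approximate reconstruction of wav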
309 modelscope/models/audio/ans/frcrn.py Normal file
@@ -0,0 +1,309 @@
import os
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from ...base import Model, Tensor
from .conv_stft import ConviSTFT, ConvSTFT
from .unet import UNet


class FTB(nn.Module):

    def __init__(self, input_dim=257, in_channel=9, r_channel=5):
        super(FTB, self).__init__()
        self.in_channel = in_channel
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(r_channel), nn.ReLU())

        self.conv1d = nn.Sequential(
            nn.Conv1d(
                r_channel * input_dim, in_channel, kernel_size=9, padding=4),
            nn.BatchNorm1d(in_channel), nn.ReLU())
        self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]),
            nn.BatchNorm2d(in_channel), nn.ReLU())

    def forward(self, inputs):
        '''
        inputs should be [Batch, Ca, Dim, Time]
        '''
        # T-F attention
        conv1_out = self.conv1(inputs)
        B, C, D, T = conv1_out.size()
        reshape1_out = torch.reshape(conv1_out, [B, C * D, T])
        conv1d_out = self.conv1d(reshape1_out)
        conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T])

        # now is also [B, C, D, T]
        att_out = conv1d_out * inputs

        # transpose to [B, C, T, D]
        att_out = torch.transpose(att_out, 2, 3)
        freqfc_out = self.freq_fc(att_out)
        att_out = torch.transpose(freqfc_out, 2, 3)

        cat_out = torch.cat([att_out, inputs], 1)
        outputs = self.conv2(cat_out)
        return outputs


@MODELS.register_module(
    Tasks.speech_signal_process, module_name=Models.speech_frcrn_ans_cirm_16k)
class FRCRNModel(Model):
    r""" A wrapper of FRCRN for integrating into the modelscope framework """

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the FRCRN model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)
        self._model = FRCRN(*args, **kwargs)
        model_bin_file = os.path.join(model_dir,
                                      ModelFile.TORCH_MODEL_BIN_FILE)
        if os.path.exists(model_bin_file):
            checkpoint = torch.load(model_bin_file)
            self._model.load_state_dict(checkpoint, strict=False)

    def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
        output = self._model.forward(input)
        return {
            'spec_l1': output[0],
            'wav_l1': output[1],
            'mask_l1': output[2],
            'spec_l2': output[3],
            'wav_l2': output[4],
            'mask_l2': output[5]
        }

    def to(self, *args, **kwargs):
        self._model = self._model.to(*args, **kwargs)
        return self

    def eval(self):
        self._model = self._model.train(False)
        return self


class FRCRN(nn.Module):
    r""" Frequency Recurrent CRN """

    def __init__(self,
                 complex,
                 model_complexity,
                 model_depth,
                 log_amp,
                 padding_mode,
                 win_len=400,
                 win_inc=100,
                 fft_len=512,
                 win_type='hanning'):
        r"""
        Args:
            complex: whether to use complex networks.
            model_complexity: defines the model width (channels per layer)
            model_depth: depth of the UNet; supported values are 10, 14 and 20
            log_amp: whether to use log amplitude to estimate signals
            padding_mode: padding of the encoder's convolutions, 'zeros' or 'reflect'
            win_len: length of the window defining one frame of samples
            win_inc: window shift (equivalent to hop_size)
            fft_len: number of Short Time Fourier Transform (STFT) points
            win_type: window type used in the STFT, e.g. 'hanning', 'hamming'
        """
        super().__init__()
        self.feat_dim = fft_len // 2 + 1

        self.win_len = win_len
        self.win_inc = win_inc
        self.fft_len = fft_len
        self.win_type = win_type

        fix = True
        self.stft = ConvSTFT(
            self.win_len,
            self.win_inc,
            self.fft_len,
            self.win_type,
            feature_type='complex',
            fix=fix)
        self.istft = ConviSTFT(
            self.win_len,
            self.win_inc,
            self.fft_len,
            self.win_type,
            feature_type='complex',
            fix=fix)
        self.unet = UNet(
            1,
            complex=complex,
            model_complexity=model_complexity,
            model_depth=model_depth,
            padding_mode=padding_mode)
        self.unet2 = UNet(
            1,
            complex=complex,
            model_complexity=model_complexity,
            model_depth=model_depth,
            padding_mode=padding_mode)

    def forward(self, inputs):
        out_list = []
        # [B, D*2, T]
        cmp_spec = self.stft(inputs)
        # [B, 1, D*2, T]
        cmp_spec = torch.unsqueeze(cmp_spec, 1)

        # to [B, 2, D, T] real_part/imag_part
        cmp_spec = torch.cat([
            cmp_spec[:, :, :self.feat_dim, :],
            cmp_spec[:, :, self.feat_dim:, :],
        ], 1)

        # [B, 2, D, T, 1]
        cmp_spec = torch.unsqueeze(cmp_spec, 4)
        # [B, 1, D, T, 2]
        cmp_spec = torch.transpose(cmp_spec, 1, 4)
        unet1_out = self.unet(cmp_spec)
        cmp_mask1 = torch.tanh(unet1_out)
        unet2_out = self.unet2(unet1_out)
        cmp_mask2 = torch.tanh(unet2_out)
        est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask1)
        out_list.append(est_spec)
        out_list.append(est_wav)
        out_list.append(est_mask)
        cmp_mask2 = cmp_mask2 + cmp_mask1
        est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2)
        out_list.append(est_spec)
        out_list.append(est_wav)
        out_list.append(est_mask)
        return out_list

    def apply_mask(self, cmp_spec, cmp_mask):
        # complex multiplication of the spectrum with the estimated mask
        est_spec = torch.cat([
            cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 0]
            - cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 1],
            cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 1]
            + cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 0]
        ], 1)
        est_spec = torch.cat([est_spec[:, 0, :, :], est_spec[:, 1, :, :]], 1)
        cmp_mask = torch.squeeze(cmp_mask, 1)
        cmp_mask = torch.cat([cmp_mask[:, :, :, 0], cmp_mask[:, :, :, 1]], 1)

        est_wav = self.istft(est_spec)
        est_wav = torch.squeeze(est_wav, 1)
        return est_spec, est_wav, cmp_mask

    def get_params(self, weight_decay=0.0):
        # add L2 penalty to weights only, not biases
        weights, biases = [], []
        for name, param in self.named_parameters():
            if 'bias' in name:
                biases += [param]
            else:
                weights += [param]
        params = [{
            'params': weights,
            'weight_decay': weight_decay,
        }, {
            'params': biases,
            'weight_decay': 0.0,
        }]
        return params

    def loss(self, noisy, labels, out_list, mode='Mix'):
        if mode == 'SiSNR':
            count = 0
            while count < len(out_list):
                est_spec = out_list[count]
                count = count + 1
                est_wav = out_list[count]
                count = count + 1
                est_mask = out_list[count]
                count = count + 1
                # only the second-stage estimate (count == 6) contributes
                if count != 3:
                    loss = self.loss_1layer(noisy, est_spec, est_wav, labels,
                                            est_mask, mode)
                    return loss

        elif mode == 'Mix':
            count = 0
            while count < len(out_list):
                est_spec = out_list[count]
                count = count + 1
                est_wav = out_list[count]
                count = count + 1
                est_mask = out_list[count]
                count = count + 1
                # only the second-stage estimate (count == 6) contributes
                if count != 3:
                    amp_loss, phase_loss, SiSNR_loss = self.loss_1layer(
                        noisy, est_spec, est_wav, labels, est_mask, mode)
                    loss = amp_loss + phase_loss + SiSNR_loss
                    return loss, amp_loss, phase_loss

    def loss_1layer(self, noisy, est, est_wav, labels, cmp_mask, mode='Mix'):
        r""" Compute the loss for one stage, depending on mode.

        mode == 'Mix':
            est: [B, F*2, T]
            labels: [B, F*2, T]
        mode == 'SiSNR':
            est: [B, T]
            labels: [B, T]
        """
        if mode == 'SiSNR':
            if labels.dim() == 3:
                labels = torch.squeeze(labels, 1)
            if est_wav.dim() == 3:
                est_wav = torch.squeeze(est_wav, 1)
            return -si_snr(est_wav, labels)
        elif mode == 'Mix':
            if labels.dim() == 3:
                labels = torch.squeeze(labels, 1)
            if est_wav.dim() == 3:
                est_wav = torch.squeeze(est_wav, 1)
            SiSNR_loss = -si_snr(est_wav, labels)

            b, d, t = est.size()
            S = self.stft(labels)
            Sr = S[:, :self.feat_dim, :]
            Si = S[:, self.feat_dim:, :]
            Y = self.stft(noisy)
            Yr = Y[:, :self.feat_dim, :]
            Yi = Y[:, self.feat_dim:, :]
            Y_pow = Yr**2 + Yi**2
            # ground-truth complex ideal ratio mask, clipped to [-1, 1]
            gth_mask = torch.cat([(Sr * Yr + Si * Yi) / (Y_pow + 1e-8),
                                  (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)], 1)
            gth_mask[gth_mask > 2] = 1
            gth_mask[gth_mask < -2] = -1
            amp_loss = F.mse_loss(gth_mask[:, :self.feat_dim, :],
                                  cmp_mask[:, :self.feat_dim, :]) * d
            phase_loss = F.mse_loss(gth_mask[:, self.feat_dim:, :],
                                    cmp_mask[:, self.feat_dim:, :]) * d
            return amp_loss, phase_loss, SiSNR_loss


def l2_norm(s1, s2):
    norm = torch.sum(s1 * s2, -1, keepdim=True)
    return norm


def si_snr(s1, s2, eps=1e-8):
    s1_s2_norm = l2_norm(s1, s2)
    s2_s2_norm = l2_norm(s2, s2)
    s_target = s1_s2_norm / (s2_s2_norm + eps) * s2
    e_noise = s1 - s_target
    target_norm = l2_norm(s_target, s_target)
    noise_norm = l2_norm(e_noise, e_noise)
    snr = 10 * torch.log10((target_norm) / (noise_norm + eps) + eps)
    return torch.mean(snr)
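si_snr is the scale-invariant SNR: the estimate s1 is projected onto the reference s2 to obtain s_target, and the ratio of target energy to residual energy is returned in dB. A small sanity check (an editor's sketch, not from the commit): the score is very high when the estimate matches the reference up to a scale factor, and drops as noise is added.

import torch

ref = torch.randn(4, 16000)
print(si_snr(0.5 * ref, ref))                           # very large: rescaling the estimate does not hurt
print(si_snr(ref + 0.1 * torch.randn(4, 16000), ref))   # roughly 20 dB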
26 modelscope/models/audio/ans/se_module_complex.py Normal file
@@ -0,0 +1,26 @@
import torch
from torch import nn


class SELayer(nn.Module):

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc_r = nn.Sequential(
            nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel), nn.Sigmoid())
        self.fc_i = nn.Sequential(
            nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel), nn.Sigmoid())

    def forward(self, x):
        b, c, _, _, _ = x.size()
        x_r = self.avg_pool(x[:, :, :, :, 0]).view(b, c)
        x_i = self.avg_pool(x[:, :, :, :, 1]).view(b, c)
        y_r = self.fc_r(x_r).view(b, c, 1, 1, 1) - self.fc_i(x_i).view(
            b, c, 1, 1, 1)
        y_i = self.fc_r(x_i).view(b, c, 1, 1, 1) + self.fc_i(x_r).view(
            b, c, 1, 1, 1)
        y = torch.cat([y_r, y_i], 4)
        return x * y
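SELayer is a squeeze-and-excitation gate adapted to complex features: channel statistics are pooled separately from the real and imaginary planes, the two fully connected branches are combined with the complex product convention, and the resulting per-channel gate rescales the input element-wise. A small shape check with illustrative sizes (not from the commit):

import torch

se = SELayer(channel=64, reduction=8)
x = torch.randn(2, 64, 32, 10, 2)  # [batch, channel, freq, time, re/im]
y = se(x)                          # gated output, same shape: [2, 64, 32, 10, 2]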
269 modelscope/models/audio/ans/unet.py Normal file
@@ -0,0 +1,269 @@
import torch
import torch.nn as nn

from . import complex_nn
from .se_module_complex import SELayer


class Encoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding=None,
                 complex=False,
                 padding_mode='zeros'):
        super().__init__()
        if padding is None:
            padding = [(i - 1) // 2 for i in kernel_size]  # 'SAME' padding

        if complex:
            conv = complex_nn.ComplexConv2d
            bn = complex_nn.ComplexBatchNorm2d
        else:
            conv = nn.Conv2d
            bn = nn.BatchNorm2d

        self.conv = conv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            padding_mode=padding_mode)
        self.bn = bn(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Decoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding=(0, 0),
                 complex=False):
        super().__init__()
        if complex:
            tconv = complex_nn.ComplexConvTranspose2d
            bn = complex_nn.ComplexBatchNorm2d
        else:
            tconv = nn.ConvTranspose2d
            bn = nn.BatchNorm2d

        self.transconv = tconv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding)
        self.bn = bn(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        x = self.transconv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class UNet(nn.Module):

    def __init__(self,
                 input_channels=1,
                 complex=False,
                 model_complexity=45,
                 model_depth=20,
                 padding_mode='zeros'):
        super().__init__()

        if complex:
            model_complexity = int(model_complexity // 1.414)

        self.set_size(
            model_complexity=model_complexity,
            input_channels=input_channels,
            model_depth=model_depth)
        self.encoders = []
        self.model_length = model_depth // 2
        self.fsmn = complex_nn.ComplexUniDeepFsmn(128, 128, 128)
        self.se_layers_enc = []
        self.fsmn_enc = []
        for i in range(self.model_length):
            fsmn_enc = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128)
            self.add_module('fsmn_enc{}'.format(i), fsmn_enc)
            self.fsmn_enc.append(fsmn_enc)
            module = Encoder(
                self.enc_channels[i],
                self.enc_channels[i + 1],
                kernel_size=self.enc_kernel_sizes[i],
                stride=self.enc_strides[i],
                padding=self.enc_paddings[i],
                complex=complex,
                padding_mode=padding_mode)
            self.add_module('encoder{}'.format(i), module)
            self.encoders.append(module)
            se_layer_enc = SELayer(self.enc_channels[i + 1], 8)
            self.add_module('se_layer_enc{}'.format(i), se_layer_enc)
            self.se_layers_enc.append(se_layer_enc)
        self.decoders = []
        self.fsmn_dec = []
        self.se_layers_dec = []
        for i in range(self.model_length):
            fsmn_dec = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128)
            self.add_module('fsmn_dec{}'.format(i), fsmn_dec)
            self.fsmn_dec.append(fsmn_dec)
            module = Decoder(
                self.dec_channels[i] * 2,
                self.dec_channels[i + 1],
                kernel_size=self.dec_kernel_sizes[i],
                stride=self.dec_strides[i],
                padding=self.dec_paddings[i],
                complex=complex)
            self.add_module('decoder{}'.format(i), module)
            self.decoders.append(module)
            if i < self.model_length - 1:
                se_layer_dec = SELayer(self.dec_channels[i + 1], 8)
                self.add_module('se_layer_dec{}'.format(i), se_layer_dec)
                self.se_layers_dec.append(se_layer_dec)
        if complex:
            conv = complex_nn.ComplexConv2d
        else:
            conv = nn.Conv2d

        linear = conv(self.dec_channels[-1], 1, 1)

        self.add_module('linear', linear)
        self.complex = complex
        self.padding_mode = padding_mode

        self.decoders = nn.ModuleList(self.decoders)
        self.encoders = nn.ModuleList(self.encoders)
        self.se_layers_enc = nn.ModuleList(self.se_layers_enc)
        self.se_layers_dec = nn.ModuleList(self.se_layers_dec)
        self.fsmn_enc = nn.ModuleList(self.fsmn_enc)
        self.fsmn_dec = nn.ModuleList(self.fsmn_dec)

    def forward(self, inputs):
        x = inputs
        # go down
        xs = []
        xs_se = []
        xs_se.append(x)
        for i, encoder in enumerate(self.encoders):
            xs.append(x)
            if i > 0:
                x = self.fsmn_enc[i](x)
            x = encoder(x)
            xs_se.append(self.se_layers_enc[i](x))
        # xs : x0=input x1 ... x9
        x = self.fsmn(x)

        p = x
        for i, decoder in enumerate(self.decoders):
            p = decoder(p)
            if i < self.model_length - 1:
                p = self.fsmn_dec[i](p)
            if i == self.model_length - 1:
                break
            if i < self.model_length - 2:
                p = self.se_layers_dec[i](p)
            p = torch.cat([p, xs_se[self.model_length - 1 - i]], dim=1)

        # cmp_spec: [12, 1, 513, 64, 2]
        cmp_spec = self.linear(p)
        return cmp_spec

    def set_size(self, model_complexity, model_depth=20, input_channels=1):
        if model_depth == 14:
            self.enc_channels = [
                input_channels, 128, 128, 128, 128, 128, 128, 128
            ]
            self.enc_kernel_sizes = [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2),
                                     (5, 2), (2, 2)]
            self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1),
                                (2, 1)]
            self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1),
                                 (0, 1), (0, 1)]
            self.dec_channels = [64, 128, 128, 128, 128, 128, 128, 1]
            self.dec_kernel_sizes = [(2, 2), (5, 2), (5, 2), (5, 2), (6, 2),
                                     (5, 2), (5, 2)]
            self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1),
                                (2, 1)]
            self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1),
                                 (0, 1), (0, 1)]

        elif model_depth == 10:
            self.enc_channels = [
                input_channels,
                16,
                32,
                64,
                128,
                256,
            ]
            self.enc_kernel_sizes = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
            self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
            self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
            self.dec_channels = [128, 128, 64, 32, 16, 1]
            self.dec_kernel_sizes = [(3, 3), (3, 3), (3, 3), (4, 3), (3, 3)]
            self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
            self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]

        elif model_depth == 20:
            self.enc_channels = [
                input_channels, model_complexity, model_complexity,
                model_complexity * 2, model_complexity * 2,
                model_complexity * 2, model_complexity * 2,
                model_complexity * 2, model_complexity * 2,
                model_complexity * 2, 128
            ]
            self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3),
                                     (5, 3), (5, 3), (5, 3), (5, 3), (5, 3)]
            self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2), (2, 1),
                                (2, 2), (2, 1), (2, 2), (2, 1)]
            self.enc_paddings = [
                (3, 0),
                (0, 3),
                None,  # (0, 2),
                None,
                None,  # (3, 1),
                None,  # (3, 1),
                None,  # (1, 2),
                None,
                None,
                None
            ]
            self.dec_channels = [
                0, model_complexity * 2, model_complexity * 2,
                model_complexity * 2, model_complexity * 2,
                model_complexity * 2, model_complexity * 2,
                model_complexity * 2, model_complexity * 2,
                model_complexity * 2, model_complexity * 2,
                model_complexity * 2
            ]
            self.dec_kernel_sizes = [(4, 3), (4, 2), (4, 3), (4, 2), (4, 3),
                                     (4, 2), (6, 3), (7, 4), (1, 7), (7, 1)]
            self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1), (2, 2),
                                (2, 1), (2, 2), (1, 1), (1, 1)]
            self.dec_paddings = [(1, 1), (1, 0), (1, 1), (1, 0), (1, 1),
                                 (1, 0), (2, 1), (2, 1), (0, 3), (3, 0)]
        else:
            raise ValueError('Unknown model depth : {}'.format(model_depth))
@@ -1,4 +1,5 @@
 from .audio import LinearAECPipeline
+from .audio.ans_pipeline import ANSPipeline
 from .base import Pipeline
 from .builder import pipeline
 from .cv import *  # noqa F403
117 modelscope/pipelines/audio/ans_pipeline.py Normal file
@@ -0,0 +1,117 @@
import os.path
from typing import Any, Dict

import librosa
import numpy as np
import soundfile as sf
import torch

from modelscope.metainfo import Pipelines
from modelscope.utils.constant import Tasks
from ..base import Input, Pipeline
from ..builder import PIPELINES


def audio_norm(x):
    # normalize to -25 dB RMS, then renormalize using only the
    # above-average-power samples so silence does not skew the level
    rms = (x**2).mean()**0.5
    scalar = 10**(-25 / 20) / rms
    x = x * scalar
    pow_x = x**2
    avg_pow_x = pow_x.mean()
    rmsx = pow_x[pow_x > avg_pow_x].mean()**0.5
    scalarx = 10**(-25 / 20) / rmsx
    x = x * scalarx
    return x


@PIPELINES.register_module(
    Tasks.speech_signal_process,
    module_name=Pipelines.speech_frcrn_ans_cirm_16k)
class ANSPipeline(Pipeline):
    r"""ANS (Acoustic Noise Suppression) inference pipeline.

    When the class is invoked via pipeline.__call__(), it accepts only one
    parameter:
        inputs(str): path of the input wav file
    """
    SAMPLE_RATE = 16000

    def __init__(self, model):
        r"""
        Args:
            model: model id on the modelscope hub.
        """
        super().__init__(model=model)
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)
        self.model.eval()

    def preprocess(self, inputs: Input) -> Dict[str, Any]:
        assert isinstance(inputs, str) and os.path.exists(inputs) and os.path.isfile(inputs), \
            f'Input file does not exist: {inputs}'
        data1, fs = sf.read(inputs)
        data1 = audio_norm(data1)
        if fs != self.SAMPLE_RATE:
            data1 = librosa.resample(data1, fs, self.SAMPLE_RATE)
        if len(data1.shape) > 1:
            # keep the first channel only
            data1 = data1[:, 0]
        data = data1.astype(np.float32)
        inputs = np.reshape(data, [1, data.shape[0]])
        return {'ndarray': inputs, 'nsamples': data.shape[0]}

    def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        ndarray = inputs['ndarray']
        nsamples = inputs['nsamples']
        decode_do_segement = False
        window = 16000
        stride = int(window * 0.75)
        print('inputs: {}'.format(ndarray.shape))
        b, t = ndarray.shape
        # for very long inputs (> 120 s) decode in overlapping segments
        if t > window * 120:
            decode_do_segement = True

        # zero-pad so the signal length fits the window/stride grid
        if t < window:
            ndarray = np.concatenate(
                [ndarray, np.zeros((ndarray.shape[0], window - t))], 1)
        elif t < window + stride:
            padding = window + stride - t
            print('padding: {}'.format(padding))
            ndarray = np.concatenate(
                [ndarray, np.zeros((ndarray.shape[0], padding))], 1)
        else:
            if (t - window) % stride != 0:
                padding = t - (t - window) // stride * stride
                print('padding: {}'.format(padding))
                ndarray = np.concatenate(
                    [ndarray, np.zeros((ndarray.shape[0], padding))], 1)
        print('inputs after padding: {}'.format(ndarray.shape))
        with torch.no_grad():
            ndarray = torch.from_numpy(np.float32(ndarray)).to(self.device)
            b, t = ndarray.shape
            if decode_do_segement:
                outputs = np.zeros(t)
                give_up_length = (window - stride) // 2
                current_idx = 0
                while current_idx + window <= t:
                    print('current_idx: {}'.format(current_idx))
                    tmp_input = ndarray[:, current_idx:current_idx + window]
                    tmp_output = self.model(
                        tmp_input)['wav_l2'][0].cpu().numpy()
                    end_index = current_idx + window - give_up_length
                    if current_idx == 0:
                        outputs[current_idx:
                                end_index] = tmp_output[:-give_up_length]
                    else:
                        outputs[current_idx
                                + give_up_length:end_index] = tmp_output[
                                    give_up_length:-give_up_length]
                    current_idx += stride
            else:
                outputs = self.model(ndarray)['wav_l2'][0].cpu().numpy()
        return {'output_pcm': outputs[:nsamples]}

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        if 'output_path' in kwargs.keys():
            sf.write(kwargs['output_path'], inputs['output_pcm'],
                     self.SAMPLE_RATE)
        return inputs
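For long inputs the forward pass above decodes in overlapping windows: window = 16000 samples, stride = 12000, so give_up_length = (16000 - 12000) // 2 = 2000, and every interior segment contributes only its middle portion while the overlapped edges are discarded before the pieces are stitched together. End-to-end, the pipeline is used the same way as in the unit test below; a minimal sketch (the file names are placeholders):

from modelscope.metainfo import Pipelines
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

ans = pipeline(
    Tasks.speech_signal_process,
    model='damo/speech_frcrn_ans_cirm_16k',
    pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k)
# reads a wav file, denoises it, and writes the result to output_path
ans('speech_with_noise.wav', output_path='denoised.wav')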
@@ -16,6 +16,7 @@ protobuf>3,<=3.20
 ptflops
 PyWavelets>=1.0.0
 scikit-learn
+SoundFile>0.10
 sox
 tensorboard
 tensorflow==1.15.*
@@ -17,6 +17,9 @@ AEC_LIB_URL = 'http://isv-data.oss-cn-hangzhou.aliyuncs.com/ics%2FMaaS%2FAEC%2Fl'
     '?Expires=1664085465&OSSAccessKeyId=LTAIxjQyZNde90zh&Signature=Y7gelmGEsQAJRK4yyHSYMrdWizk%3D'
 AEC_LIB_FILE = 'libmitaec_pyio.so'
 
+NOISE_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ANS/sample_audio/speech_with_noise.wav'
+NOISE_SPEECH_FILE = 'speech_with_noise.wav'
+
 
 def download(remote_path, local_path):
     local_dir = os.path.dirname(local_path)
@@ -30,23 +33,40 @@ def download(remote_path, local_path):
 class SpeechSignalProcessTest(unittest.TestCase):
 
     def setUp(self) -> None:
-        self.model_id = 'damo/speech_dfsmn_aec_psm_16k'
+        pass
 
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_aec(self):
         # A temporary hack to provide c++ lib. Download it first.
         download(AEC_LIB_URL, AEC_LIB_FILE)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run(self):
         # Download audio files
         download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
         download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)
+        model_id = 'damo/speech_dfsmn_aec_psm_16k'
         input = {
             'nearend_mic': NEAREND_MIC_FILE,
             'farend_speech': FAREND_SPEECH_FILE
         }
         aec = pipeline(
             Tasks.speech_signal_process,
-            model=self.model_id,
+            model=model_id,
             pipeline_name=Pipelines.speech_dfsmn_aec_psm_16k)
-        aec(input, output_path='output.wav')
+        output_path = os.path.abspath('output.wav')
+        aec(input, output_path=output_path)
+        print(f'Processed audio saved to {output_path}')
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_ans(self):
+        # Download audio files
+        download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE)
+        model_id = 'damo/speech_frcrn_ans_cirm_16k'
+        ans = pipeline(
+            Tasks.speech_signal_process,
+            model=model_id,
+            pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k)
+        output_path = os.path.abspath('output.wav')
+        ans(NOISE_SPEECH_FILE, output_path=output_path)
+        print(f'Processed audio saved to {output_path}')
 
 
 if __name__ == '__main__':