[to #42322933] Merge ANS pipeline into master

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9178339

* refactor: move aec models to audio/aec

* refactor: move aec models to audio/aec

* refactor: move aec models to audio/aec

* feat: add unittest for ANS pipeline

* Merge branch 'master' into dev/ans

* add new SoundFile to audio dependency

* Merge branch 'master' into dev/ans

* use ANS pipeline name from metainfo

* Merge branch 'master' into dev/ans

* chore: update docstring of ANS module

* Merge branch 'master' into dev/ans

* refactor: use names from metainfo

* refactor: enable ans unittest

* refactor: add more log message in unittest
Author: bin.xue
Date: 2022-06-28 14:41:08 +08:00
Committed by: huangjun.hj
Parent: a7c1cd0fc9
Commit: 04b7eba285
23 changed files with 1112 additions and 6 deletions

View File

@@ -21,6 +21,7 @@ class Models(object):
sambert_hifi_16k = 'sambert-hifi-16k'
generic_tts_frontend = 'generic-tts-frontend'
hifigan16k = 'hifigan16k'
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
kws_kwsbp = 'kws-kwsbp'
# multi-modal models
@@ -55,6 +56,7 @@ class Pipelines(object):
# audio tasks
sambert_hifigan_16k_tts = 'sambert-hifigan-16k-tts'
speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
kws_kwsbp = 'kws-kwsbp'
# multi-modal tasks
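
These registered names are what wires the new model and pipeline together at runtime. A minimal usage sketch, mirroring the unittest added later in this change (the output filename is illustrative):

from modelscope.metainfo import Pipelines
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the ANS pipeline from the names registered above and denoise a wav file.
ans = pipeline(
    Tasks.speech_signal_process,
    model='damo/speech_frcrn_ans_cirm_16k',
    pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k)
ans('speech_with_noise.wav', output_path='denoised.wav')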

View File

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .audio.ans.frcrn import FRCRNModel
from .audio.kws import GenericKeyWordSpotting
from .audio.tts.am import SambertNetHifi16k
from .audio.tts.vocoder import Hifigan16k

View File

View File

@@ -0,0 +1,248 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class UniDeepFsmn(nn.Module):
def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
super(UniDeepFsmn, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
if lorder is None:
return
self.lorder = lorder
self.hidden_size = hidden_size
self.linear = nn.Linear(input_dim, hidden_size)
self.project = nn.Linear(hidden_size, output_dim, bias=False)
self.conv1 = nn.Conv2d(
output_dim,
output_dim, [lorder, 1], [1, 1],
groups=output_dim,
bias=False)
def forward(self, input):
r"""
Args:
input: tensor with shape batch (b) x sequence (T) x feature (h)
Returns:
tensor with the same shape as the input: batch (b) x sequence (T) x feature (h)
"""
f1 = F.relu(self.linear(input))
p1 = self.project(f1)
x = torch.unsqueeze(p1, 1)
# x: batch (b) x channel (c) x sequence(T) x feature (h)
x_per = x.permute(0, 3, 2, 1)
# x_per: batch (b) x feature (h) x sequence(T) x channel (c)
y = F.pad(x_per, [0, 0, self.lorder - 1, 0])
out = x_per + self.conv1(y)
out1 = out.permute(0, 3, 2, 1)
# out1: batch (b) x channel (c) x sequence(T) x feature (h)
return input + out1.squeeze()
class ComplexUniDeepFsmn(nn.Module):
def __init__(self, nIn, nHidden=128, nOut=128):
super(ComplexUniDeepFsmn, self).__init__()
self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
self.fsmn_re_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden)
self.fsmn_im_L2 = UniDeepFsmn(nHidden, nOut, 20, nHidden)
def forward(self, x):
r"""
Args:
x: tensor with shape [batch, channel, feature, sequence, 2], e.g. [6, 256, 1, 106, 2]
Returns:
tensor with the same shape [batch, channel, feature, sequence, 2]
"""
#
b, c, h, T, d = x.size()
x = torch.reshape(x, (b, c * h, T, d))
# x: [b,h,T,2], [6, 256, 106, 2]
x = torch.transpose(x, 1, 2)
# x: [b,T,h,2], [6, 106, 256, 2]
real_L1 = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1])
imaginary_L1 = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0])
# second complex FSMN layer
real = self.fsmn_re_L2(real_L1) - self.fsmn_im_L2(imaginary_L1)
imaginary = self.fsmn_re_L2(imaginary_L1) + self.fsmn_im_L2(real_L1)
# output: [b, T, h, 2]
output = torch.stack((real, imaginary), dim=-1)
# output: [b, h, T, 2]
output = torch.transpose(output, 1, 2)
output = torch.reshape(output, (b, c, h, T, d))
return output
class ComplexUniDeepFsmn_L1(nn.Module):
def __init__(self, nIn, nHidden=128, nOut=128):
super(ComplexUniDeepFsmn_L1, self).__init__()
self.fsmn_re_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
self.fsmn_im_L1 = UniDeepFsmn(nIn, nHidden, 20, nHidden)
def forward(self, x):
r"""
Args:
x: tensor with shape [batch, channel, feature, sequence, 2], e.g. [6, 256, 1, 106, 2]
"""
b, c, h, T, d = x.size()
# x : [b,T,h,c,2]
x = torch.transpose(x, 1, 3)
x = torch.reshape(x, (b * T, h, c, d))
real = self.fsmn_re_L1(x[..., 0]) - self.fsmn_im_L1(x[..., 1])
imaginary = self.fsmn_re_L1(x[..., 1]) + self.fsmn_im_L1(x[..., 0])
# output: [b*T,h,c,2], [6*106, h, 256, 2]
output = torch.stack((real, imaginary), dim=-1)
output = torch.reshape(output, (b, T, h, c, d))
output = torch.transpose(output, 1, 3)
return output
class ComplexConv2d(nn.Module):
# https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py
def __init__(self,
in_channel,
out_channel,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
**kwargs):
super().__init__()
# Model components
self.conv_re = nn.Conv2d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
**kwargs)
self.conv_im = nn.Conv2d(
in_channel,
out_channel,
kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
**kwargs)
def forward(self, x):
r"""
Args:
x: tensor with shape [batch, channel, axis1, axis2, 2]
"""
real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1])
imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0])
output = torch.stack((real, imaginary), dim=-1)
return output
class ComplexConvTranspose2d(nn.Module):
def __init__(self,
in_channel,
out_channel,
kernel_size,
stride=1,
padding=0,
output_padding=0,
dilation=1,
groups=1,
bias=True,
**kwargs):
super().__init__()
# Model components
self.tconv_re = nn.ConvTranspose2d(
in_channel,
out_channel,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
bias=bias,
dilation=dilation,
**kwargs)
self.tconv_im = nn.ConvTranspose2d(
in_channel,
out_channel,
kernel_size=kernel_size,
stride=stride,
padding=padding,
output_padding=output_padding,
groups=groups,
bias=bias,
dilation=dilation,
**kwargs)
def forward(self, x):  # shape of x: [batch, channel, axis1, axis2, 2]
real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1])
imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0])
output = torch.stack((real, imaginary), dim=-1)
return output
class ComplexBatchNorm2d(nn.Module):
def __init__(self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True,
**kwargs):
super().__init__()
self.bn_re = nn.BatchNorm2d(
num_features=num_features,
momentum=momentum,
affine=affine,
eps=eps,
track_running_stats=track_running_stats,
**kwargs)
self.bn_im = nn.BatchNorm2d(
num_features=num_features,
momentum=momentum,
affine=affine,
eps=eps,
track_running_stats=track_running_stats,
**kwargs)
def forward(self, x):
real = self.bn_re(x[..., 0])
imag = self.bn_im(x[..., 1])
output = torch.stack((real, imag), dim=-1)
return output
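
A quick standalone shape check for the complex layers above (sizes are arbitrary; the import path assumes this module lives at modelscope.models.audio.ans.complex_nn alongside frcrn.py):

import torch
from modelscope.models.audio.ans.complex_nn import ComplexBatchNorm2d, ComplexConv2d

# Complex tensors are stored as [..., 2] with real and imaginary parts stacked last.
x = torch.randn(2, 4, 64, 50, 2)  # [batch, channel, freq, time, real/imag]
conv = ComplexConv2d(4, 8, kernel_size=(3, 3), padding=(1, 1))
bn = ComplexBatchNorm2d(8)
y = bn(conv(x))
print(y.shape)  # torch.Size([2, 8, 64, 50, 2])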

View File

@@ -0,0 +1,112 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.signal import get_window
def init_kernels(win_len, win_inc, fft_len, win_type=None, invers=False):
if win_type == 'None' or win_type is None:
window = np.ones(win_len)
else:
window = get_window(win_type, win_len, fftbins=True)**0.5
N = fft_len
fourier_basis = np.fft.rfft(np.eye(N))[:win_len]
real_kernel = np.real(fourier_basis)
imag_kernel = np.imag(fourier_basis)
kernel = np.concatenate([real_kernel, imag_kernel], 1).T
if invers:
kernel = np.linalg.pinv(kernel).T
kernel = kernel * window
kernel = kernel[:, None, :]
return torch.from_numpy(kernel.astype(np.float32)), torch.from_numpy(
window[None, :, None].astype(np.float32))
class ConvSTFT(nn.Module):
def __init__(self,
win_len,
win_inc,
fft_len=None,
win_type='hamming',
feature_type='real',
fix=True):
super(ConvSTFT, self).__init__()
if fft_len is None:
self.fft_len = int(2**np.ceil(np.log2(win_len)))
else:
self.fft_len = fft_len
kernel, _ = init_kernels(win_len, win_inc, self.fft_len, win_type)
self.weight = nn.Parameter(kernel, requires_grad=(not fix))
self.feature_type = feature_type
self.stride = win_inc
self.win_len = win_len
self.dim = self.fft_len
def forward(self, inputs):
if inputs.dim() == 2:
inputs = torch.unsqueeze(inputs, 1)
outputs = F.conv1d(inputs, self.weight, stride=self.stride)
if self.feature_type == 'complex':
return outputs
else:
dim = self.dim // 2 + 1
real = outputs[:, :dim, :]
imag = outputs[:, dim:, :]
mags = torch.sqrt(real**2 + imag**2)
phase = torch.atan2(imag, real)
return mags, phase
class ConviSTFT(nn.Module):
def __init__(self,
win_len,
win_inc,
fft_len=None,
win_type='hamming',
feature_type='real',
fix=True):
super(ConviSTFT, self).__init__()
if fft_len is None:
self.fft_len = int(2**np.ceil(np.log2(win_len)))
else:
self.fft_len = fft_len
kernel, window = init_kernels(
win_len, win_inc, self.fft_len, win_type, invers=True)
self.weight = nn.Parameter(kernel, requires_grad=(not fix))
self.feature_type = feature_type
self.win_type = win_type
self.win_len = win_len
self.win_inc = win_inc
self.stride = win_inc
self.dim = self.fft_len
self.register_buffer('window', window)
self.register_buffer('enframe', torch.eye(win_len)[:, None, :])
def forward(self, inputs, phase=None):
"""
Args:
inputs : [B, N+2, T] (complex spec) or [B, N//2+1, T] (mags)
phase: [B, N//2+1, T] (if not none)
"""
if phase is not None:
real = inputs * torch.cos(phase)
imag = inputs * torch.sin(phase)
inputs = torch.cat([real, imag], 1)
outputs = F.conv_transpose1d(inputs, self.weight, stride=self.stride)
# this is from torch-stft: https://github.com/pseeth/torch-stft
t = self.window.repeat(1, 1, inputs.size(-1))**2
coff = F.conv_transpose1d(t, self.enframe, stride=self.stride)
outputs = outputs / (coff + 1e-8)
return outputs
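
ConvSTFT/ConviSTFT implement the analysis and synthesis transforms as fixed 1-D convolutions. A small round-trip sketch with the FRCRN default frame settings (win_len=400, win_inc=100, fft_len=512); the import path is assumed as above:

import torch
from modelscope.models.audio.ans.conv_stft import ConvSTFT, ConviSTFT

stft = ConvSTFT(400, 100, 512, feature_type='complex')
istft = ConviSTFT(400, 100, 512, feature_type='complex')

wav = torch.randn(1, 16000)  # one second of 16 kHz audio
spec = stft(wav)             # [1, 514, 157]: real and imaginary halves stacked along dim 1
rec = istft(spec)            # [1, 1, 16000]: reconstructed waveform
print(spec.shape, rec.shape)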

View File

@@ -0,0 +1,309 @@
import os
from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from ...base import Model, Tensor
from .conv_stft import ConviSTFT, ConvSTFT
from .unet import UNet
class FTB(nn.Module):
def __init__(self, input_dim=257, in_channel=9, r_channel=5):
super(FTB, self).__init__()
self.in_channel = in_channel
self.conv1 = nn.Sequential(
nn.Conv2d(in_channel, r_channel, kernel_size=[1, 1]),
nn.BatchNorm2d(r_channel), nn.ReLU())
self.conv1d = nn.Sequential(
nn.Conv1d(
r_channel * input_dim, in_channel, kernel_size=9, padding=4),
nn.BatchNorm1d(in_channel), nn.ReLU())
self.freq_fc = nn.Linear(input_dim, input_dim, bias=False)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channel * 2, in_channel, kernel_size=[1, 1]),
nn.BatchNorm2d(in_channel), nn.ReLU())
def forward(self, inputs):
'''
inputs should be [Batch, Ca, Dim, Time]
'''
# T-F attention
conv1_out = self.conv1(inputs)
B, C, D, T = conv1_out.size()
reshape1_out = torch.reshape(conv1_out, [B, C * D, T])
conv1d_out = self.conv1d(reshape1_out)
conv1d_out = torch.reshape(conv1d_out, [B, self.in_channel, 1, T])
# now is also [B,C,D,T]
att_out = conv1d_out * inputs
# transpose to [B,C,T,D]
att_out = torch.transpose(att_out, 2, 3)
freqfc_out = self.freq_fc(att_out)
att_out = torch.transpose(freqfc_out, 2, 3)
cat_out = torch.cat([att_out, inputs], 1)
outputs = self.conv2(cat_out)
return outputs
@MODELS.register_module(
Tasks.speech_signal_process, module_name=Models.speech_frcrn_ans_cirm_16k)
class FRCRNModel(Model):
r""" A decorator of FRCRN for integrating into modelscope framework """
def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the frcrn model from the `model_dir` path.
Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)
self._model = FRCRN(*args, **kwargs)
model_bin_file = os.path.join(model_dir,
ModelFile.TORCH_MODEL_BIN_FILE)
if os.path.exists(model_bin_file):
checkpoint = torch.load(model_bin_file)
self._model.load_state_dict(checkpoint, strict=False)
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
output = self._model.forward(input)
return {
'spec_l1': output[0],
'wav_l1': output[1],
'mask_l1': output[2],
'spec_l2': output[3],
'wav_l2': output[4],
'mask_l2': output[5]
}
def to(self, *args, **kwargs):
self._model = self._model.to(*args, **kwargs)
return self
def eval(self):
self._model = self._model.train(False)
return self
class FRCRN(nn.Module):
r""" Frequency Recurrent CRN """
def __init__(self,
complex,
model_complexity,
model_depth,
log_amp,
padding_mode,
win_len=400,
win_inc=100,
fft_len=512,
win_type='hanning'):
r"""
Args:
complex: whether to use complex-valued networks
model_complexity: base channel count that controls the model capacity
model_depth: depth of the UNet; supported options are 10, 14 and 20
log_amp: whether to use log amplitude to estimate signals
padding_mode: padding mode of the encoder convolutions, e.g. 'zeros' or 'reflect'
win_len: length of the analysis window in samples (one frame)
win_inc: window shift in samples (equivalent to hop_size)
fft_len: number of Short Time Fourier Transform (STFT) points
win_type: window type used in the STFT, e.g. 'hanning' or 'hamming'
"""
super().__init__()
self.feat_dim = fft_len // 2 + 1
self.win_len = win_len
self.win_inc = win_inc
self.fft_len = fft_len
self.win_type = win_type
fix = True
self.stft = ConvSTFT(
self.win_len,
self.win_inc,
self.fft_len,
self.win_type,
feature_type='complex',
fix=fix)
self.istft = ConviSTFT(
self.win_len,
self.win_inc,
self.fft_len,
self.win_type,
feature_type='complex',
fix=fix)
self.unet = UNet(
1,
complex=complex,
model_complexity=model_complexity,
model_depth=model_depth,
padding_mode=padding_mode)
self.unet2 = UNet(
1,
complex=complex,
model_complexity=model_complexity,
model_depth=model_depth,
padding_mode=padding_mode)
def forward(self, inputs):
out_list = []
# [B, D*2, T]
cmp_spec = self.stft(inputs)
# [B, 1, D*2, T]
cmp_spec = torch.unsqueeze(cmp_spec, 1)
# to [B, 2, D, T] real_part/imag_part
cmp_spec = torch.cat([
cmp_spec[:, :, :self.feat_dim, :],
cmp_spec[:, :, self.feat_dim:, :],
], 1)
# [B, 2, D, T]
cmp_spec = torch.unsqueeze(cmp_spec, 4)
# [B, 1, D, T, 2]
cmp_spec = torch.transpose(cmp_spec, 1, 4)
unet1_out = self.unet(cmp_spec)
cmp_mask1 = torch.tanh(unet1_out)
unet2_out = self.unet2(unet1_out)
cmp_mask2 = torch.tanh(unet2_out)
est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask1)
out_list.append(est_spec)
out_list.append(est_wav)
out_list.append(est_mask)
cmp_mask2 = cmp_mask2 + cmp_mask1
est_spec, est_wav, est_mask = self.apply_mask(cmp_spec, cmp_mask2)
out_list.append(est_spec)
out_list.append(est_wav)
out_list.append(est_mask)
return out_list
def apply_mask(self, cmp_spec, cmp_mask):
est_spec = torch.cat([
cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 0]
- cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 1],
cmp_spec[:, :, :, :, 0] * cmp_mask[:, :, :, :, 1]
+ cmp_spec[:, :, :, :, 1] * cmp_mask[:, :, :, :, 0]
], 1)
est_spec = torch.cat([est_spec[:, 0, :, :], est_spec[:, 1, :, :]], 1)
cmp_mask = torch.squeeze(cmp_mask, 1)
cmp_mask = torch.cat([cmp_mask[:, :, :, 0], cmp_mask[:, :, :, 1]], 1)
est_wav = self.istft(est_spec)
est_wav = torch.squeeze(est_wav, 1)
return est_spec, est_wav, cmp_mask
def get_params(self, weight_decay=0.0):
# add L2 penalty
weights, biases = [], []
for name, param in self.named_parameters():
if 'bias' in name:
biases += [param]
else:
weights += [param]
params = [{
'params': weights,
'weight_decay': weight_decay,
}, {
'params': biases,
'weight_decay': 0.0,
}]
return params
def loss(self, noisy, labels, out_list, mode='Mix'):
if mode == 'SiSNR':
count = 0
while count < len(out_list):
est_spec = out_list[count]
count = count + 1
est_wav = out_list[count]
count = count + 1
est_mask = out_list[count]
count = count + 1
if count != 3:
loss = self.loss_1layer(noisy, est_spec, est_wav, labels,
est_mask, mode)
return loss
elif mode == 'Mix':
count = 0
while count < len(out_list):
est_spec = out_list[count]
count = count + 1
est_wav = out_list[count]
count = count + 1
est_mask = out_list[count]
count = count + 1
if count != 3:
amp_loss, phase_loss, SiSNR_loss = self.loss_1layer(
noisy, est_spec, est_wav, labels, est_mask, mode)
loss = amp_loss + phase_loss + SiSNR_loss
return loss, amp_loss, phase_loss
def loss_1layer(self, noisy, est, est_wav, labels, cmp_mask, mode='Mix'):
r""" Compute the loss by mode
mode == 'Mix'
est: [B, F*2, T]
labels: [B, F*2,T]
mode == 'SiSNR'
est: [B, T]
labels: [B, T]
"""
if mode == 'SiSNR':
if labels.dim() == 3:
labels = torch.squeeze(labels, 1)
if est_wav.dim() == 3:
est_wav = torch.squeeze(est_wav, 1)
return -si_snr(est_wav, labels)
elif mode == 'Mix':
if labels.dim() == 3:
labels = torch.squeeze(labels, 1)
if est_wav.dim() == 3:
est_wav = torch.squeeze(est_wav, 1)
SiSNR_loss = -si_snr(est_wav, labels)
b, d, t = est.size()
S = self.stft(labels)
Sr = S[:, :self.feat_dim, :]
Si = S[:, self.feat_dim:, :]
Y = self.stft(noisy)
Yr = Y[:, :self.feat_dim, :]
Yi = Y[:, self.feat_dim:, :]
Y_pow = Yr**2 + Yi**2
gth_mask = torch.cat([(Sr * Yr + Si * Yi) / (Y_pow + 1e-8),
(Si * Yr - Sr * Yi) / (Y_pow + 1e-8)], 1)
gth_mask[gth_mask > 2] = 1
gth_mask[gth_mask < -2] = -1
amp_loss = F.mse_loss(gth_mask[:, :self.feat_dim, :],
cmp_mask[:, :self.feat_dim, :]) * d
phase_loss = F.mse_loss(gth_mask[:, self.feat_dim:, :],
cmp_mask[:, self.feat_dim:, :]) * d
return amp_loss, phase_loss, SiSNR_loss
def l2_norm(s1, s2):
norm = torch.sum(s1 * s2, -1, keepdim=True)
return norm
def si_snr(s1, s2, eps=1e-8):
s1_s2_norm = l2_norm(s1, s2)
s2_s2_norm = l2_norm(s2, s2)
s_target = s1_s2_norm / (s2_s2_norm + eps) * s2
e_noise = s1 - s_target
target_norm = l2_norm(s_target, s_target)
noise_norm = l2_norm(e_noise, e_noise)
snr = 10 * torch.log10((target_norm) / (noise_norm + eps) + eps)
return torch.mean(snr)
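
FRCRN predicts complex masks that apply_mask combines with the noisy spectrum via complex multiplication, and the training loss mixes a mask MSE with negative SI-SNR. The module-level si_snr helper can be exercised on its own; a toy sketch with made-up signals (import path assumed as above):

import torch
from modelscope.models.audio.ans.frcrn import si_snr

clean = torch.randn(4, 16000)                # reference waveforms [B, T]
noisy = clean + 0.1 * torch.randn(4, 16000)  # degraded estimates
print(si_snr(noisy, clean))                  # scale-invariant SNR in dB, averaged over the batch
print(si_snr(clean, clean))                  # a (near) perfect estimate gives a very large value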

View File

@@ -0,0 +1,26 @@
import torch
from torch import nn
class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc_r = nn.Sequential(
nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel), nn.Sigmoid())
self.fc_i = nn.Sequential(
nn.Linear(channel, channel // reduction), nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel), nn.Sigmoid())
def forward(self, x):
b, c, _, _, _ = x.size()
x_r = self.avg_pool(x[:, :, :, :, 0]).view(b, c)
x_i = self.avg_pool(x[:, :, :, :, 1]).view(b, c)
y_r = self.fc_r(x_r).view(b, c, 1, 1, 1) - self.fc_i(x_i).view(
b, c, 1, 1, 1)
y_i = self.fc_r(x_i).view(b, c, 1, 1, 1) + self.fc_i(x_r).view(
b, c, 1, 1, 1)
y = torch.cat([y_r, y_i], 4)
return x * y
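
SELayer is a squeeze-and-excitation gate for complex features: the channel gates are computed with complex-style arithmetic from the pooled real and imaginary parts and then applied per channel. A standalone shape sketch (sizes arbitrary, import path assumed as above):

import torch
from modelscope.models.audio.ans.se_module_complex import SELayer

x = torch.randn(2, 32, 16, 50, 2)  # [batch, channel, freq, time, real/imag]
se = SELayer(32, reduction=8)
print(se(x).shape)                 # torch.Size([2, 32, 16, 50, 2])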

View File

@@ -0,0 +1,269 @@
import torch
import torch.nn as nn
from . import complex_nn
from .se_module_complex import SELayer
class Encoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding=None,
complex=False,
padding_mode='zeros'):
super().__init__()
if padding is None:
padding = [(i - 1) // 2 for i in kernel_size] # 'SAME' padding
if complex:
conv = complex_nn.ComplexConv2d
bn = complex_nn.ComplexBatchNorm2d
else:
conv = nn.Conv2d
bn = nn.BatchNorm2d
self.conv = conv(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
padding_mode=padding_mode)
self.bn = bn(out_channels)
self.relu = nn.LeakyReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Decoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding=(0, 0),
complex=False):
super().__init__()
if complex:
tconv = complex_nn.ComplexConvTranspose2d
bn = complex_nn.ComplexBatchNorm2d
else:
tconv = nn.ConvTranspose2d
bn = nn.BatchNorm2d
self.transconv = tconv(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding)
self.bn = bn(out_channels)
self.relu = nn.LeakyReLU(inplace=True)
def forward(self, x):
x = self.transconv(x)
x = self.bn(x)
x = self.relu(x)
return x
class UNet(nn.Module):
def __init__(self,
input_channels=1,
complex=False,
model_complexity=45,
model_depth=20,
padding_mode='zeros'):
super().__init__()
if complex:
model_complexity = int(model_complexity // 1.414)
self.set_size(
model_complexity=model_complexity,
input_channels=input_channels,
model_depth=model_depth)
self.encoders = []
self.model_length = model_depth // 2
self.fsmn = complex_nn.ComplexUniDeepFsmn(128, 128, 128)
self.se_layers_enc = []
self.fsmn_enc = []
for i in range(self.model_length):
fsmn_enc = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128)
self.add_module('fsmn_enc{}'.format(i), fsmn_enc)
self.fsmn_enc.append(fsmn_enc)
module = Encoder(
self.enc_channels[i],
self.enc_channels[i + 1],
kernel_size=self.enc_kernel_sizes[i],
stride=self.enc_strides[i],
padding=self.enc_paddings[i],
complex=complex,
padding_mode=padding_mode)
self.add_module('encoder{}'.format(i), module)
self.encoders.append(module)
se_layer_enc = SELayer(self.enc_channels[i + 1], 8)
self.add_module('se_layer_enc{}'.format(i), se_layer_enc)
self.se_layers_enc.append(se_layer_enc)
self.decoders = []
self.fsmn_dec = []
self.se_layers_dec = []
for i in range(self.model_length):
fsmn_dec = complex_nn.ComplexUniDeepFsmn_L1(128, 128, 128)
self.add_module('fsmn_dec{}'.format(i), fsmn_dec)
self.fsmn_dec.append(fsmn_dec)
module = Decoder(
self.dec_channels[i] * 2,
self.dec_channels[i + 1],
kernel_size=self.dec_kernel_sizes[i],
stride=self.dec_strides[i],
padding=self.dec_paddings[i],
complex=complex)
self.add_module('decoder{}'.format(i), module)
self.decoders.append(module)
if i < self.model_length - 1:
se_layer_dec = SELayer(self.dec_channels[i + 1], 8)
self.add_module('se_layer_dec{}'.format(i), se_layer_dec)
self.se_layers_dec.append(se_layer_dec)
if complex:
conv = complex_nn.ComplexConv2d
else:
conv = nn.Conv2d
linear = conv(self.dec_channels[-1], 1, 1)
self.add_module('linear', linear)
self.complex = complex
self.padding_mode = padding_mode
self.decoders = nn.ModuleList(self.decoders)
self.encoders = nn.ModuleList(self.encoders)
self.se_layers_enc = nn.ModuleList(self.se_layers_enc)
self.se_layers_dec = nn.ModuleList(self.se_layers_dec)
self.fsmn_enc = nn.ModuleList(self.fsmn_enc)
self.fsmn_dec = nn.ModuleList(self.fsmn_dec)
def forward(self, inputs):
x = inputs
# go down
xs = []
xs_se = []
xs_se.append(x)
for i, encoder in enumerate(self.encoders):
xs.append(x)
if i > 0:
x = self.fsmn_enc[i](x)
x = encoder(x)
xs_se.append(self.se_layers_enc[i](x))
# xs : x0=input x1 ... x9
x = self.fsmn(x)
p = x
for i, decoder in enumerate(self.decoders):
p = decoder(p)
if i < self.model_length - 1:
p = self.fsmn_dec[i](p)
if i == self.model_length - 1:
break
if i < self.model_length - 2:
p = self.se_layers_dec[i](p)
p = torch.cat([p, xs_se[self.model_length - 1 - i]], dim=1)
# cmp_spec: [12, 1, 513, 64, 2]
cmp_spec = self.linear(p)
return cmp_spec
def set_size(self, model_complexity, model_depth=20, input_channels=1):
if model_depth == 14:
self.enc_channels = [
input_channels, 128, 128, 128, 128, 128, 128, 128
]
self.enc_kernel_sizes = [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2),
(5, 2), (2, 2)]
self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1),
(2, 1)]
self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1),
(0, 1), (0, 1)]
self.dec_channels = [64, 128, 128, 128, 128, 128, 128, 1]
self.dec_kernel_sizes = [(2, 2), (5, 2), (5, 2), (5, 2), (6, 2),
(5, 2), (5, 2)]
self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1),
(2, 1)]
self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1),
(0, 1), (0, 1)]
elif model_depth == 10:
self.enc_channels = [
input_channels,
16,
32,
64,
128,
256,
]
self.enc_kernel_sizes = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
self.dec_channels = [128, 128, 64, 32, 16, 1]
self.dec_kernel_sizes = [(3, 3), (3, 3), (3, 3), (4, 3), (3, 3)]
self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
elif model_depth == 20:
self.enc_channels = [
input_channels, model_complexity, model_complexity,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, 128
]
self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3),
(5, 3), (5, 3), (5, 3), (5, 3), (5, 3)]
self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2), (2, 1),
(2, 2), (2, 1), (2, 2), (2, 1)]
self.enc_paddings = [
(3, 0),
(0, 3),
None, # (0, 2),
None,
None, # (3,1),
None, # (3,1),
None, # (1,2),
None,
None,
None
]
self.dec_channels = [
0, model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2
]
self.dec_kernel_sizes = [(4, 3), (4, 2), (4, 3), (4, 2), (4, 3),
(4, 2), (6, 3), (7, 4), (1, 7), (7, 1)]
self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1), (2, 2),
(2, 1), (2, 2), (1, 1), (1, 1)]
self.dec_paddings = [(1, 1), (1, 0), (1, 1), (1, 0), (1, 1),
(1, 0), (2, 1), (2, 1), (0, 3), (3, 0)]
else:
raise ValueError('Unknown model depth : {}'.format(model_depth))
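
The Encoder/Decoder blocks switch between real and complex layers via the `complex` flag. A small sketch of a single complex Encoder block, using the first kernel/stride/padding entry of the depth-20 configuration (shapes are illustrative; import path assumed as above):

import torch
from modelscope.models.audio.ans.unet import Encoder

enc = Encoder(1, 45, kernel_size=(7, 1), stride=(1, 1), padding=(3, 0), complex=True)
x = torch.randn(2, 1, 257, 100, 2)  # [batch, channel, freq, time, real/imag]
print(enc(x).shape)                 # torch.Size([2, 45, 257, 100, 2])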

View File

@@ -1,4 +1,5 @@
from .audio import LinearAECPipeline
from .audio.ans_pipeline import ANSPipeline
from .base import Pipeline
from .builder import pipeline
from .cv import * # noqa F403

View File

@@ -0,0 +1,117 @@
import os.path
from typing import Any, Dict
import librosa
import numpy as np
import soundfile as sf
import torch
from modelscope.metainfo import Pipelines
from modelscope.utils.constant import Tasks
from ..base import Input, Pipeline
from ..builder import PIPELINES
def audio_norm(x):
rms = (x**2).mean()**0.5
scalar = 10**(-25 / 20) / rms
x = x * scalar
pow_x = x**2
avg_pow_x = pow_x.mean()
rmsx = pow_x[pow_x > avg_pow_x].mean()**0.5
scalarx = 10**(-25 / 20) / rmsx
x = x * scalarx
return x
@PIPELINES.register_module(
Tasks.speech_signal_process,
module_name=Pipelines.speech_frcrn_ans_cirm_16k)
class ANSPipeline(Pipeline):
r"""ANS (Acoustic Noise Suppression) Inference Pipeline .
When invoke the class with pipeline.__call__(), it accept only one parameter:
inputs(str): the path of wav file
"""
SAMPLE_RATE = 16000
def __init__(self, model):
r"""
Args:
model: model id on modelscope hub.
"""
super().__init__(model=model)
self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
self.model = self.model.to(self.device)
self.model.eval()
def preprocess(self, inputs: Input) -> Dict[str, Any]:
assert isinstance(inputs, str) and os.path.exists(inputs) and os.path.isfile(inputs), \
f'Input file does not exist: {inputs}'
data1, fs = sf.read(inputs)
data1 = audio_norm(data1)
if fs != self.SAMPLE_RATE:
data1 = librosa.resample(data1, orig_sr=fs, target_sr=self.SAMPLE_RATE)
if len(data1.shape) > 1:
data1 = data1[:, 0]
data = data1.astype(np.float32)
inputs = np.reshape(data, [1, data.shape[0]])
return {'ndarray': inputs, 'nsamples': data.shape[0]}
def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
ndarray = inputs['ndarray']
nsamples = inputs['nsamples']
decode_do_segement = False
window = 16000
stride = int(window * 0.75)
print('inputs:{}'.format(ndarray.shape))
b, t = ndarray.shape # size()
if t > window * 120:
decode_do_segement = True
if t < window:
ndarray = np.concatenate(
[ndarray, np.zeros((ndarray.shape[0], window - t))], 1)
elif t < window + stride:
padding = window + stride - t
print('padding: {}'.format(padding))
ndarray = np.concatenate(
[ndarray, np.zeros((ndarray.shape[0], padding))], 1)
else:
if (t - window) % stride != 0:
padding = t - (t - window) // stride * stride
print('padding: {}'.format(padding))
ndarray = np.concatenate(
[ndarray, np.zeros((ndarray.shape[0], padding))], 1)
print('inputs after padding:{}'.format(ndarray.shape))
with torch.no_grad():
ndarray = torch.from_numpy(np.float32(ndarray)).to(self.device)
b, t = ndarray.shape
if decode_do_segement:
outputs = np.zeros(t)
give_up_length = (window - stride) // 2
current_idx = 0
while current_idx + window <= t:
print('current_idx: {}'.format(current_idx))
tmp_input = ndarray[:, current_idx:current_idx + window]
tmp_output = self.model(
tmp_input, )['wav_l2'][0].cpu().numpy()
end_index = current_idx + window - give_up_length
if current_idx == 0:
outputs[current_idx:
end_index] = tmp_output[:-give_up_length]
else:
outputs[current_idx
+ give_up_length:end_index] = tmp_output[
give_up_length:-give_up_length]
current_idx += stride
else:
outputs = self.model(ndarray)['wav_l2'][0].cpu().numpy()
return {'output_pcm': outputs[:nsamples]}
def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
if 'output_path' in kwargs.keys():
sf.write(kwargs['output_path'], inputs['output_pcm'],
self.SAMPLE_RATE)
return inputs
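
For long inputs the pipeline decodes in overlapping windows and discards half of the overlap at each chunk boundary: with the constants above, window = 16000 samples, stride = 12000 and give_up_length = (16000 - 12000) // 2 = 2000. A toy sketch of the resulting index bookkeeping (no model involved, lengths are made up):

window, stride = 16000, 12000
give_up_length = (window - stride) // 2  # 2000 samples dropped on each side of a chunk

t = 16000 * 130  # long enough to trigger segmented decoding (t > window * 120)
current_idx = 0
while current_idx + window <= t:
    start = current_idx if current_idx == 0 else current_idx + give_up_length
    end = current_idx + window - give_up_length
    # output samples [start, end) are taken from the chunk decoded at current_idx
    current_idx += stride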

View File

@@ -16,6 +16,7 @@ protobuf>3,<=3.20
ptflops
PyWavelets>=1.0.0
scikit-learn
SoundFile>0.10
sox
tensorboard
tensorflow==1.15.*

View File

@@ -17,6 +17,9 @@ AEC_LIB_URL = 'http://isv-data.oss-cn-hangzhou.aliyuncs.com/ics%2FMaaS%2FAEC%2Fl
'?Expires=1664085465&OSSAccessKeyId=LTAIxjQyZNde90zh&Signature=Y7gelmGEsQAJRK4yyHSYMrdWizk%3D'
AEC_LIB_FILE = 'libmitaec_pyio.so'
NOISE_SPEECH_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ANS/sample_audio/speech_with_noise.wav'
NOISE_SPEECH_FILE = 'speech_with_noise.wav'
def download(remote_path, local_path):
local_dir = os.path.dirname(local_path)
@@ -30,23 +33,40 @@ def download(remote_path, local_path):
class SpeechSignalProcessTest(unittest.TestCase):
def setUp(self) -> None:
self.model_id = 'damo/speech_dfsmn_aec_psm_16k'
pass
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_aec(self):
# A temporary hack to provide c++ lib. Download it first.
download(AEC_LIB_URL, AEC_LIB_FILE)
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run(self):
# Download audio files
download(NEAREND_MIC_URL, NEAREND_MIC_FILE)
download(FAREND_SPEECH_URL, FAREND_SPEECH_FILE)
model_id = 'damo/speech_dfsmn_aec_psm_16k'
input = {
'nearend_mic': NEAREND_MIC_FILE,
'farend_speech': FAREND_SPEECH_FILE
}
aec = pipeline(
Tasks.speech_signal_process,
model=self.model_id,
model=model_id,
pipeline_name=Pipelines.speech_dfsmn_aec_psm_16k)
aec(input, output_path='output.wav')
output_path = os.path.abspath('output.wav')
aec(input, output_path=output_path)
print(f'Processed audio saved to {output_path}')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_ans(self):
# Download audio files
download(NOISE_SPEECH_URL, NOISE_SPEECH_FILE)
model_id = 'damo/speech_frcrn_ans_cirm_16k'
ans = pipeline(
Tasks.speech_signal_process,
model=model_id,
pipeline_name=Pipelines.speech_frcrn_ans_cirm_16k)
output_path = os.path.abspath('output.wav')
ans(NOISE_SPEECH_FILE, output_path=output_path)
print(f'Processed audio saved to {output_path}')
if __name__ == '__main__':