[to #42322933] feat: add speech separation pipeline

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11255740
This commit is contained in:
bin.xue
2023-01-03 13:18:44 +08:00
committed by yingda.chen
parent 01c498cd14
commit 0fdf37312f
13 changed files with 1196 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:34c2f1867f7882614b7087f2fd2acb722d0f520a2ec50b2d116d5b3f0c05f84b
size 141134

View File

@@ -113,6 +113,7 @@ class Models(object):
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
speech_mossformer_separation_temporal_8k = 'speech_mossformer_separation_temporal_8k'
kws_kwsbp = 'kws-kwsbp'
generic_asr = 'generic-asr'
wenet_asr = 'wenet-asr'
@@ -317,6 +318,7 @@ class Pipelines(object):
speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
speech_separation = 'speech-separation'
kws_kwsbp = 'kws-kwsbp'
asr_inference = 'asr-inference'
asr_wenet_inference = 'asr-wenet-inference'

View File

@@ -0,0 +1,68 @@
# Copyright 2018 Northwestern Polytechnical University (author: Ke Wang)
# made publicly available under the MIT License
# at https://github.com/wangkenpu/Conv-TasNet-PyTorch/blob/64188ffa48971218fdd68b66906970f215d7eca2/model/layer_norm.py
from __future__ import absolute_import, division, print_function
import torch
import torch.nn as nn
class CLayerNorm(nn.LayerNorm):
"""Channel-wise layer normalization."""
def __init__(self, *args, **kwargs):
super(CLayerNorm, self).__init__(*args, **kwargs)
def forward(self, sample):
"""Forward function.
Args:
sample: [batch_size, channels, length]
"""
if sample.dim() != 3:
raise RuntimeError('{} only accepts 3-D tensors as input'.format(
self.__class__.__name__))
# [N, C, T] -> [N, T, C]
sample = torch.transpose(sample, 1, 2)
# LayerNorm
sample = super().forward(sample)
# [N, T, C] -> [N, C, T]
sample = torch.transpose(sample, 1, 2)
return sample
class GLayerNorm(nn.Module):
"""Global Layer Normalization for TasNet."""
def __init__(self, channels, eps=1e-5):
super(GLayerNorm, self).__init__()
self.eps = eps
self.norm_dim = channels
self.gamma = nn.Parameter(torch.Tensor(channels))
self.beta = nn.Parameter(torch.Tensor(channels))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.gamma)
nn.init.zeros_(self.beta)
def forward(self, sample):
"""Forward function.
Args:
sample: [batch_size, channels, length]
"""
if sample.dim() != 3:
raise RuntimeError('{} only accepts 3-D tensors as input'.format(
self.__class__.__name__))
# [N, C, T] -> [N, T, C]
sample = torch.transpose(sample, 1, 2)
# Mean and variance [N, 1, 1]
mean = torch.mean(sample, (1, 2), keepdim=True)
var = torch.mean((sample - mean)**2, (1, 2), keepdim=True)
sample = (sample
- mean) / torch.sqrt(var + self.eps) * self.gamma + self.beta
# [N, T, C] -> [N, C, T]
sample = torch.transpose(sample, 1, 2)
return sample
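Both layers take a [batch, channels, length] tensor and return the same shape: CLayerNorm normalizes each frame over the channel dimension, while GLayerNorm normalizes over channels and time jointly. A minimal usage sketch, assuming this file is importable as modelscope.models.audio.separation.layer_norm (the shapes are illustrative):

import torch
from modelscope.models.audio.separation.layer_norm import CLayerNorm, GLayerNorm

x = torch.randn(4, 64, 1000)   # [batch, channels, length]
cln = CLayerNorm(64)           # per-frame normalization over channels
gln = GLayerNorm(64)           # global normalization over channels and time
assert cln(x).shape == x.shape
assert gln(x).shape == x.shape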

View File

@@ -0,0 +1,472 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import os
from typing import Any, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.models.audio.separation.mossformer_block import (
MossFormerModule, ScaledSinuEmbedding)
from modelscope.models.audio.separation.mossformer_conv_module import (
CumulativeLayerNorm, GlobalLayerNorm)
from modelscope.models.base import Tensor
from modelscope.utils.constant import Tasks
EPS = 1e-8
@MODELS.register_module(
Tasks.speech_separation,
module_name=Models.speech_mossformer_separation_temporal_8k)
class MossFormer(TorchModel):
"""Library to support MossFormer speech separation.
Args:
model_dir (str): the model path.
"""
def __init__(self, model_dir: str, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
self.encoder = Encoder(
kernel_size=kwargs['kernel_size'],
out_channels=kwargs['out_channels'])
self.decoder = Decoder(
in_channels=kwargs['in_channels'],
out_channels=1,
kernel_size=kwargs['kernel_size'],
stride=kwargs['stride'],
bias=kwargs['bias'])
self.mask_net = MossFormerMaskNet(
kwargs['in_channels'],
kwargs['out_channels'],
MossFormerM(kwargs['num_blocks'], kwargs['d_model'],
kwargs['attn_dropout'], kwargs['group_size'],
kwargs['query_key_dim'], kwargs['expansion_factor'],
kwargs['causal']),
norm=kwargs['norm'],
num_spks=kwargs['num_spks'])
self.num_spks = kwargs['num_spks']
def forward(self, inputs: Tensor) -> Dict[str, Any]:
# Separation
mix_w = self.encoder(inputs)
est_mask = self.mask_net(mix_w)
mix_w = torch.stack([mix_w] * self.num_spks)
sep_h = mix_w * est_mask
# Decoding
est_source = torch.cat(
[
self.decoder(sep_h[i]).unsqueeze(-1)
for i in range(self.num_spks)
],
dim=-1,
)
# T changed after conv1d in encoder, fix it here
t_origin = inputs.size(1)
t_est = est_source.size(1)
if t_origin > t_est:
est_source = F.pad(est_source, (0, 0, 0, t_origin - t_est))
else:
est_source = est_source[:, :t_origin, :]
return est_source
def load_check_point(self, load_path=None, device=None):
if not load_path:
load_path = self.model_dir
if not device:
device = torch.device('cpu')
self.encoder.load_state_dict(
torch.load(
os.path.join(load_path, 'encoder.bin'), map_location=device),
strict=True)
self.decoder.load_state_dict(
torch.load(
os.path.join(load_path, 'decoder.bin'), map_location=device),
strict=True)
self.mask_net.load_state_dict(
torch.load(
os.path.join(load_path, 'masknet.bin'), map_location=device),
strict=True)
def select_norm(norm, dim, shape):
"""Just a wrapper to select the normalization type.
"""
if norm == 'gln':
return GlobalLayerNorm(dim, shape, elementwise_affine=True)
if norm == 'cln':
return CumulativeLayerNorm(dim, elementwise_affine=True)
if norm == 'ln':
return nn.GroupNorm(1, dim, eps=1e-8)
else:
return nn.BatchNorm1d(dim)
class Encoder(nn.Module):
"""Convolutional Encoder Layer.
Args:
kernel_size: Length of filters.
in_channels: Number of input channels.
out_channels: Number of output channels.
Example:
-------
>>> x = torch.randn(2, 1000)
>>> encoder = Encoder(kernel_size=4, out_channels=64)
>>> h = encoder(x)
>>> h.shape
torch.Size([2, 64, 499])
"""
def __init__(self,
kernel_size: int = 2,
out_channels: int = 64,
in_channels: int = 1):
super(Encoder, self).__init__()
self.conv1d = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=kernel_size // 2,
groups=1,
bias=False,
)
self.in_channels = in_channels
def forward(self, x: torch.Tensor):
"""Return the encoded output.
Args:
x: Input tensor with dimensionality [B, L].
Returns:
Encoded tensor with dimensionality [B, N, T_out].
where B = Batchsize
L = Number of timepoints
N = Number of filters
T_out = Number of timepoints at the output of the encoder
"""
# B x L -> B x 1 x L
if self.in_channels == 1:
x = torch.unsqueeze(x, dim=1)
# B x 1 x L -> B x N x T_out
x = self.conv1d(x)
x = F.relu(x)
return x
class Decoder(nn.ConvTranspose1d):
"""A decoder layer that consists of ConvTranspose1d.
Args:
kernel_size: Length of filters.
in_channels: Number of input channels.
out_channels: Number of output channels.
Example
---------
>>> x = torch.randn(2, 100, 1000)
>>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
>>> h = decoder(x)
>>> h.shape
torch.Size([2, 1003])
"""
def __init__(self, *args, **kwargs):
super(Decoder, self).__init__(*args, **kwargs)
def forward(self, x):
"""Return the decoded output.
Args:
x: Input tensor with dimensionality [B, N, L].
where, B = Batchsize,
N = number of filters
L = time points
"""
if x.dim() not in [2, 3]:
raise RuntimeError('{} only accepts 2-D or 3-D tensors as input'.format(
self.__class__.__name__))
x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))
if torch.squeeze(x).dim() == 1:
x = torch.squeeze(x, dim=1)
else:
x = torch.squeeze(x)
return x
class IdentityBlock:
"""This block is used when we want to have identity transformation within the Dual_path block.
Example
-------
>>> x = torch.randn(10, 100)
>>> IB = IdentityBlock()
>>> xhat = IB(x)
"""
def __init__(self, **kwargs):
pass
def __call__(self, x):
return x
class MossFormerM(nn.Module):
"""This class implements the transformer encoder.
Args:
num_blocks : int
Number of mossformer blocks to include.
d_model : int
The dimension of the input embedding.
attn_dropout : float
Dropout for the self-attention (Optional).
group_size: int
the chunk size
query_key_dim: int
the attention vector dimension
expansion_factor: int
the expansion factor for the linear projection in conv module
causal: bool
true for causal / false for non causal
Example
-------
>>> import torch
>>> x = torch.rand((8, 60, 512)) #B, S, N
>>> net = MossFormerM(num_blocks=8, d_model=512)
>>> output = net(x)
>>> output.shape
torch.Size([8, 60, 512])
"""
def __init__(self,
num_blocks,
d_model=None,
attn_dropout=0.1,
group_size=256,
query_key_dim=128,
expansion_factor=4.,
causal=False):
super().__init__()
self.mossformerM = MossFormerModule(
dim=d_model,
depth=num_blocks,
group_size=group_size,
query_key_dim=query_key_dim,
expansion_factor=expansion_factor,
causal=causal,
attn_dropout=attn_dropout)
import speechbrain as sb
self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)
def forward(self, src: torch.Tensor):
"""
Args:
src: Tensor shape [B, S, N],
where, B = Batchsize,
S = time points
N = number of filters
The sequence to the encoder layer (required).
"""
output = self.mossformerM(src)
output = self.norm(output)
return output
class ComputeAttention(nn.Module):
"""Computation block for dual-path processing.
Args:
att_mdl : torch.nn.module
Model to process within the chunks.
out_channels : int
Dimensionality of attention model.
norm : str
Normalization type.
skip_connection : bool
Skip connection around the attention module.
Example
---------
>>> att_block = MossFormerM(num_blocks=8, d_model=512)
>>> comp_att = ComputeAttention(att_block, 512)
>>> x = torch.randn(10, 512, 64)
>>> x = comp_att(x)
>>> x.shape
torch.Size([10, 512, 64])
"""
def __init__(
self,
att_mdl,
out_channels,
norm='ln',
skip_connection=True,
):
super(ComputeAttention, self).__init__()
self.att_mdl = att_mdl
self.skip_connection = skip_connection
# Norm
self.norm = norm
if norm is not None:
self.att_norm = select_norm(norm, out_channels, 3)
def forward(self, x: torch.Tensor):
"""Returns the output tensor.
Args:
x: Input tensor of dimension [B, N, S].
Returns:
out: Output tensor of dimension [B, N, S].
where, B = Batchsize,
N = number of filters
S = time points
"""
# [B, S, N]
att_out = x.permute(0, 2, 1).contiguous()
att_out = self.att_mdl(att_out)
# [B, N, S]
att_out = att_out.permute(0, 2, 1).contiguous()
if self.norm is not None:
att_out = self.att_norm(att_out)
# [B, N, S]
if self.skip_connection:
att_out = att_out + x
out = att_out
return out
class MossFormerMaskNet(nn.Module):
"""The dual path model which is the basis for dualpathrnn, sepformer, dptnet.
Args:
in_channels : int
Number of channels at the output of the encoder.
out_channels : int
Number of channels that would be inputted to the intra and inter blocks.
att_model : torch.nn.module
Attention model to process the input sequence.
norm : str
Normalization type.
num_spks : int
Number of sources (speakers).
skip_connection : bool
Skip connection around attention module.
use_global_pos_enc : bool
Global positional encodings.
Example
---------
>>> att_model = MossFormerM(num_blocks=8, d_model=64)
>>> mossformer_masknet = MossFormerMaskNet(64, 64, att_model, num_spks=2)
>>> x = torch.randn(10, 64, 2000)
>>> x = mossformer_masknet(x)
>>> x.shape
torch.Size([2, 10, 64, 2000])
"""
def __init__(
self,
in_channels,
out_channels,
att_model,
norm='ln',
num_spks=2,
skip_connection=True,
use_global_pos_enc=True,
):
super(MossFormerMaskNet, self).__init__()
self.num_spks = num_spks
self.norm = select_norm(norm, in_channels, 3)
self.conv1d_encoder = nn.Conv1d(
in_channels, out_channels, 1, bias=False)
self.use_global_pos_enc = use_global_pos_enc
if self.use_global_pos_enc:
self.pos_enc = ScaledSinuEmbedding(out_channels)
self.mdl = copy.deepcopy(
ComputeAttention(
att_model,
out_channels,
norm,
skip_connection=skip_connection,
))
self.conv1d_out = nn.Conv1d(
out_channels, out_channels * num_spks, kernel_size=1)
self.conv1_decoder = nn.Conv1d(
out_channels, in_channels, 1, bias=False)
self.prelu = nn.PReLU()
self.activation = nn.ReLU()
# gated output layer
self.output = nn.Sequential(
nn.Conv1d(out_channels, out_channels, 1), nn.Tanh())
self.output_gate = nn.Sequential(
nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid())
def forward(self, x: torch.Tensor):
"""Returns the output tensor.
Args:
x: Input tensor of dimension [B, N, S].
Returns:
out: Output tensor of dimension [spks, B, N, S]
where, spks = Number of speakers
B = Batchsize,
N = number of filters
S = the number of time frames
"""
# the comment before each line gives the tensor shape after executing that line
# [B, N, L]
x = self.norm(x)
# [B, N, L]
x = self.conv1d_encoder(x)
if self.use_global_pos_enc:
base = x
x = x.transpose(1, -1)
emb = self.pos_enc(x)
emb = emb.transpose(0, -1)
x = base + emb
# [B, N, S]
x = self.mdl(x)
x = self.prelu(x)
# [B, N*spks, S]
x = self.conv1d_out(x)
b, _, s = x.shape
# [B*spks, N, S]
x = x.view(b * self.num_spks, -1, s)
# [B*spks, N, S]
x = self.output(x) * self.output_gate(x)
# [B*spks, N, S]
x = self.conv1_decoder(x)
# [B, spks, N, S]
_, n, L = x.shape
x = x.view(b, self.num_spks, n, L)
x = self.activation(x)
# [spks, B, N, S]
x = x.transpose(0, 1)
return x
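A rough sketch of how the encoder, mask network, and decoder compose, mirroring MossFormer.forward above. The hyperparameter values are illustrative rather than the shipped 8 kHz configuration, and the import path modelscope.models.audio.separation.mossformer is assumed:

import torch
from modelscope.models.audio.separation.mossformer import (
    Decoder, Encoder, MossFormerM, MossFormerMaskNet)

num_spks, n_filters = 2, 512
mix = torch.randn(1, 8000)                       # one second of 8 kHz audio, [B, L]
encoder = Encoder(kernel_size=16, out_channels=n_filters)
mask_net = MossFormerMaskNet(
    n_filters, n_filters,
    MossFormerM(num_blocks=2, d_model=n_filters),
    norm='ln', num_spks=num_spks)
decoder = Decoder(
    in_channels=n_filters, out_channels=1, kernel_size=16, stride=8, bias=False)

mix_w = encoder(mix)                             # [B, N, T]
est_mask = mask_net(mix_w)                       # [spks, B, N, T]
sep_h = torch.stack([mix_w] * num_spks) * est_mask
est_source = torch.cat(
    [decoder(sep_h[i]).unsqueeze(-1) for i in range(num_spks)], dim=-1)
print(est_source.shape)                          # [B, T', spks]; forward() trims/pads T' back to L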

View File

@@ -0,0 +1,265 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn.functional as F
from torch import einsum, nn
from modelscope.models.audio.separation.mossformer_conv_module import \
MossFormerConvModule
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def padding_to_multiple_of(n, mult):
remainder = n % mult
if remainder == 0:
return 0
return mult - remainder
class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
def forward(self, x):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
class ScaledSinuEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = nn.Parameter(torch.ones(1, ))
inv_freq = 1. / (10000**(torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, x):
n, device = x.shape[1], x.device
t = torch.arange(n, device=device).type_as(self.inv_freq)
sinu = einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1)
return emb * self.scale
class OffsetScale(nn.Module):
def __init__(self, dim, heads=1):
super().__init__()
self.gamma = nn.Parameter(torch.ones(heads, dim))
self.beta = nn.Parameter(torch.zeros(heads, dim))
nn.init.normal_(self.gamma, std=0.02)
def forward(self, x):
out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
return out.unbind(dim=-2)
class FFConvM(nn.Module):
def __init__(self, dim_in, dim_out, norm_klass=nn.LayerNorm, dropout=0.1):
super().__init__()
self.mdl = nn.Sequential(
norm_klass(dim_in), nn.Linear(dim_in, dim_out), nn.SiLU(),
MossFormerConvModule(dim_out), nn.Dropout(dropout))
def forward(self, x):
output = self.mdl(x)
return output
class MossFormerBlock(nn.Module):
def __init__(self,
dim,
group_size=256,
query_key_dim=128,
expansion_factor=1.,
causal=False,
dropout=0.1,
rotary_pos_emb=None,
norm_klass=nn.LayerNorm,
shift_tokens=True):
super().__init__()
hidden_dim = int(dim * expansion_factor)
self.group_size = group_size
self.causal = causal
self.shift_tokens = shift_tokens
# positional embeddings
self.rotary_pos_emb = rotary_pos_emb
# norm
self.dropout = nn.Dropout(dropout)
# projections
self.to_hidden = FFConvM(
dim_in=dim,
dim_out=hidden_dim,
norm_klass=norm_klass,
dropout=dropout,
)
self.to_qk = FFConvM(
dim_in=dim,
dim_out=query_key_dim,
norm_klass=norm_klass,
dropout=dropout,
)
self.qk_offset_scale = OffsetScale(query_key_dim, heads=4)
self.to_out = FFConvM(
dim_in=dim * 2,
dim_out=dim,
norm_klass=norm_klass,
dropout=dropout,
)
self.gateActivate = nn.Sigmoid()
def forward(self, x):
# normalization is applied inside FFConvM, so just alias the input here
normed_x = x
# token shift: shift half of the feature channels by one time step (a cheap but effective trick)
if self.shift_tokens:
x_shift, x_pass = normed_x.chunk(2, dim=-1)
x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.)
normed_x = torch.cat((x_shift, x_pass), dim=-1)
# initial projections
v, u = self.to_hidden(normed_x).chunk(2, dim=-1)
qk = self.to_qk(normed_x)
# offset and scale
quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v,
u)
# projection out and residual
out = (att_u * v) * self.gateActivate(att_v * u)
x = x + self.to_out(out)
return x
def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask=None):
b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size
from einops import rearrange
if exists(mask):
lin_mask = rearrange(mask, '... -> ... 1')
lin_k = lin_k.masked_fill(~lin_mask, 0.)
# rotate queries and keys
if exists(self.rotary_pos_emb):
quad_q, lin_q, quad_k, lin_k = map(
self.rotary_pos_emb.rotate_queries_or_keys,
(quad_q, lin_q, quad_k, lin_k))
# padding for groups
padding = padding_to_multiple_of(n, g)
if padding > 0:
quad_q, quad_k, lin_q, lin_k, v, u = map(
lambda t: F.pad(t, (0, 0, 0, padding), value=0.),
(quad_q, quad_k, lin_q, lin_k, v, u))
mask = default(mask,
torch.ones((b, n), device=device, dtype=torch.bool))
mask = F.pad(mask, (0, padding), value=False)
# group along sequence
quad_q, quad_k, lin_q, lin_k, v, u = map(
lambda t: rearrange(t, 'b (g n) d -> b g n d', n=self.group_size),
(quad_q, quad_k, lin_q, lin_k, v, u))
if exists(mask):
mask = rearrange(mask, 'b (g j) -> b g 1 j', j=g)
# calculate quadratic attention output
sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
attn = F.relu(sim)**2
attn = self.dropout(attn)
if exists(mask):
attn = attn.masked_fill(~mask, 0.)
if self.causal:
causal_mask = torch.ones((g, g), dtype=torch.bool,
device=device).triu(1)
attn = attn.masked_fill(causal_mask, 0.)
quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)
# calculate linear attention output
if self.causal:
lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
# exclusive cumulative sum along group dimension
lin_kv = lin_kv.cumsum(dim=1)
lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.)
lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)
lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
# exclusive cumulative sum along group dimension
lin_ku = lin_ku.cumsum(dim=1)
lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value=0.)
lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
else:
lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)
lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)
# fold back groups into full sequence, and excise out padding
quad_attn_out_v, lin_attn_out_v = map(
lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n],
(quad_out_v, lin_out_v))
quad_attn_out_u, lin_attn_out_u = map(
lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n],
(quad_out_u, lin_out_u))
# gate
return quad_attn_out_v + lin_attn_out_v, quad_attn_out_u + lin_attn_out_u
class MossFormerModule(nn.Module):
def __init__(self,
dim,
depth,
group_size=256,
query_key_dim=128,
expansion_factor=4.,
causal=False,
attn_dropout=0.1,
norm_type='scalenorm',
shift_tokens=True):
super().__init__()
assert norm_type in (
'scalenorm',
'layernorm'), 'norm_type must be one of scalenorm or layernorm'
if norm_type == 'scalenorm':
norm_klass = ScaleNorm
elif norm_type == 'layernorm':
norm_klass = nn.LayerNorm
from rotary_embedding_torch import RotaryEmbedding
rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
# max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
self.layers = nn.ModuleList([
MossFormerBlock(
dim=dim,
group_size=group_size,
query_key_dim=query_key_dim,
expansion_factor=expansion_factor,
causal=causal,
dropout=attn_dropout,
rotary_pos_emb=rotary_pos_emb,
norm_klass=norm_klass,
shift_tokens=shift_tokens) for _ in range(depth)
])
def forward(self, x):
for mossformer_layer in self.layers:
x = mossformer_layer(x)
return x
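MossFormerModule stacks `depth` MossFormerBlock layers; each block pads the sequence to a multiple of group_size, computes quadratic attention within groups plus a low-cost linear attention across them, and gates the two value streams before the residual projection. A small shape check with illustrative sizes, assuming the same import path convention as above:

import torch
from modelscope.models.audio.separation.mossformer_block import MossFormerModule

net = MossFormerModule(dim=128, depth=1, group_size=64, query_key_dim=64)
x = torch.rand(2, 100, 128)     # [batch, time, dim]; 100 is not a multiple of group_size
y = net(x)                      # padded to 128 frames internally, then trimmed back
assert y.shape == x.shape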

View File

@@ -0,0 +1,272 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.init as init
from torch import Tensor
EPS = 1e-8
class GlobalLayerNorm(nn.Module):
"""Calculate Global Layer Normalization.
Args:
dim : (int or list or torch.Size)
Input shape from an expected input of size.
eps : float
A value added to the denominator for numerical stability.
elementwise_affine : bool
A boolean value that when set to True,
this module has learnable per-element affine parameters
initialized to ones (for weights) and zeros (for biases).
Example
-------
>>> x = torch.randn(5, 10, 20)
>>> GLN = GlobalLayerNorm(10, 3)
>>> x_norm = GLN(x)
"""
def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
super(GlobalLayerNorm, self).__init__()
self.dim = dim
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
if shape == 3:
self.weight = nn.Parameter(torch.ones(self.dim, 1))
self.bias = nn.Parameter(torch.zeros(self.dim, 1))
if shape == 4:
self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
else:
self.register_parameter('weight', None)
self.register_parameter('bias', None)
def forward(self, x):
"""Returns the normalized tensor.
Args:
x: Tensor of size [N, C, K, S] or [N, C, L].
"""
# N x 1 x 1
# cln: mean,var N x 1 x K x S
# gln: mean,var N x 1 x 1
if x.dim() == 3:
mean = torch.mean(x, (1, 2), keepdim=True)
var = torch.mean((x - mean)**2, (1, 2), keepdim=True)
if self.elementwise_affine:
x = (self.weight * (x - mean) / torch.sqrt(var + self.eps)
+ self.bias) # yapf: disable
else:
x = (x - mean) / torch.sqrt(var + self.eps)
if x.dim() == 4:
mean = torch.mean(x, (1, 2, 3), keepdim=True)
var = torch.mean((x - mean)**2, (1, 2, 3), keepdim=True)
if self.elementwise_affine:
x = (self.weight * (x - mean) / torch.sqrt(var + self.eps)
+ self.bias) # yapf: disable
else:
x = (x - mean) / torch.sqrt(var + self.eps)
return x
class CumulativeLayerNorm(nn.LayerNorm):
"""Calculate Cumulative Layer Normalization.
Args:
dim: Dimension that you want to normalize.
elementwise_affine: Learnable per-element affine parameters.
Example
-------
>>> x = torch.randn(5, 10, 20)
>>> CLN = CumulativeLayerNorm(10)
>>> x_norm = CLN(x)
"""
def __init__(self, dim, elementwise_affine=True):
super(CumulativeLayerNorm, self).__init__(
dim, elementwise_affine=elementwise_affine, eps=1e-8)
def forward(self, x):
"""Returns the normalized tensor.
Args:
x: Tensor size [N, C, K, S] or [N, C, L]
"""
# N x K x S x C
if x.dim() == 4:
x = x.permute(0, 2, 3, 1).contiguous()
# N x K x S x C == only channel norm
x = super().forward(x)
# N x C x K x S
x = x.permute(0, 3, 1, 2).contiguous()
if x.dim() == 3:
x = torch.transpose(x, 1, 2)
# N x L x C == only channel norm
x = super().forward(x)
# N x C x L
x = torch.transpose(x, 1, 2)
return x
def select_norm(norm, dim, shape):
"""Just a wrapper to select the normalization type.
"""
if norm == 'gln':
return GlobalLayerNorm(dim, shape, elementwise_affine=True)
if norm == 'cln':
return CumulativeLayerNorm(dim, elementwise_affine=True)
if norm == 'ln':
return nn.GroupNorm(1, dim, eps=1e-8)
else:
return nn.BatchNorm1d(dim)
class Swish(nn.Module):
"""
Swish is a smooth, non-monotonic activation function that consistently matches or outperforms ReLU
on deep networks across a variety of challenging domains such as image classification and machine translation.
"""
def __init__(self):
super(Swish, self).__init__()
def forward(self, inputs: Tensor) -> Tensor:
return inputs * inputs.sigmoid()
class GLU(nn.Module):
"""
The gating mechanism is called Gated Linear Units (GLU), which was first introduced for natural language processing
in the paper “Language Modeling with Gated Convolutional Networks”
"""
def __init__(self, dim: int) -> None:
super(GLU, self).__init__()
self.dim = dim
def forward(self, inputs: Tensor) -> Tensor:
outputs, gate = inputs.chunk(2, dim=self.dim)
return outputs * gate.sigmoid()
class Transpose(nn.Module):
""" Wrapper class of torch.transpose() for Sequential module. """
def __init__(self, shape: tuple):
super(Transpose, self).__init__()
self.shape = shape
def forward(self, x: Tensor) -> Tensor:
return x.transpose(*self.shape)
class Linear(nn.Module):
"""
Wrapper class of torch.nn.Linear
Weight initialize by xavier initialization and bias initialize to zeros.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True) -> None:
super(Linear, self).__init__()
self.linear = nn.Linear(in_features, out_features, bias=bias)
init.xavier_uniform_(self.linear.weight)
if bias:
init.zeros_(self.linear.bias)
def forward(self, x: Tensor) -> Tensor:
return self.linear(x)
class DepthwiseConv1d(nn.Module):
"""
When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
this operation is termed in the literature as a depthwise convolution.
Args:
in_channels (int): Number of channels in the input
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
bias (bool, optional): If True, adds a learnable bias to the output. Default: True
Inputs: inputs
- **inputs** (batch, in_channels, time): Tensor containing input vector
Returns: outputs
- **outputs** (batch, out_channels, time): Tensor produced by the depthwise 1-D convolution.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
bias: bool = False,
) -> None:
super(DepthwiseConv1d, self).__init__()
assert out_channels % in_channels == 0, 'out_channels should be a constant multiple of in_channels'
self.conv = nn.Conv1d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
groups=in_channels,
stride=stride,
padding=padding,
bias=bias,
)
def forward(self, inputs: Tensor) -> Tensor:
return self.conv(inputs)
class MossFormerConvModule(nn.Module):
"""Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).
This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
to aid training deep models.
Args:
in_channels (int): Number of channels in the input
kernel_size (int or tuple, optional): Size of the convolving kernel Default: 17
dropout_p (float, optional): probability of dropout
"""
def __init__(self,
in_channels: int,
kernel_size: int = 17,
expansion_factor: int = 2) -> None:
super(MossFormerConvModule, self).__init__()
assert (
kernel_size - 1
) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
assert expansion_factor == 2, 'Currently, Only Supports expansion_factor 2'
self.sequential = nn.Sequential(
Transpose(shape=(1, 2)),
DepthwiseConv1d(
in_channels,
in_channels,
kernel_size,
stride=1,
padding=(kernel_size - 1) // 2),
)
def forward(self, inputs: Tensor) -> Tensor:
"""
Args:
inputs (batch, time, dim): Tensor contains input sequences
Returns:
outputs (batch, time, dim): Tensor produced by the convolution module.
"""
return inputs + self.sequential(inputs).transpose(1, 2)
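In practice this module is a residual depthwise 1-D convolution over time, and select_norm picks between global, cumulative, group, and batch normalization. A brief sketch with illustrative shapes, under the assumed import path modelscope.models.audio.separation.mossformer_conv_module:

import torch
from modelscope.models.audio.separation.mossformer_conv_module import (
    MossFormerConvModule, select_norm)

x = torch.randn(2, 100, 64)                 # [batch, time, dim]
conv = MossFormerConvModule(in_channels=64, kernel_size=17)
assert conv(x).shape == x.shape             # depthwise conv + residual preserves the shape

gln = select_norm('gln', dim=64, shape=3)   # GlobalLayerNorm for [batch, channels, length]
y = gln(torch.randn(2, 64, 100))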

View File

@@ -27,6 +27,7 @@ class OutputKeys(object):
OUTPUT_IMG = 'output_img'
OUTPUT_VIDEO = 'output_video'
OUTPUT_PCM = 'output_pcm'
OUTPUT_PCM_LIST = 'output_pcm_list'
OUTPUT_WAV = 'output_wav'
IMG_EMBEDDING = 'img_embedding'
SPO_LIST = 'spo_list'
@@ -696,6 +697,7 @@ TASK_OUTPUTS = {
Tasks.speech_signal_process: [OutputKeys.OUTPUT_PCM],
Tasks.acoustic_echo_cancellation: [OutputKeys.OUTPUT_PCM],
Tasks.acoustic_noise_suppression: [OutputKeys.OUTPUT_PCM],
Tasks.speech_separation: [OutputKeys.OUTPUT_PCM_LIST],
# text_to_speech result for a single sample
# {

View File

@@ -214,6 +214,8 @@ TASK_INPUTS = {
'nearend_mic': InputType.AUDIO,
'farend_speech': InputType.AUDIO
},
Tasks.speech_separation:
InputType.AUDIO,
Tasks.acoustic_noise_suppression:
InputType.AUDIO,
Tasks.text_to_speech:

View File

@@ -0,0 +1,68 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import io
from typing import Any, Dict
import numpy
import soundfile as sf
import torch
from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.models.base import Input
from modelscope.outputs import OutputKeys
from modelscope.pipelines import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
Tasks.speech_separation, module_name=Pipelines.speech_separation)
class SeparationPipeline(Pipeline):
def __init__(self, model, **kwargs):
"""create a speech separation pipeline for prediction
Args:
model: model id on modelscope hub.
"""
logger.info('loading model...')
super().__init__(model=model, **kwargs)
self.model.load_check_point(device=self.device)
self.model.eval()
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
if isinstance(inputs, str):
file_bytes = File.read(inputs)
data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
if fs != 8000:
raise ValueError(
'modelscope error: The audio sample rate should be 8000')
elif isinstance(inputs, bytes):
data = torch.from_numpy(
numpy.frombuffer(inputs, dtype=numpy.float32))
else:
raise TypeError(
'modelscope error: Unsupported input type {}, expected a wav file path or raw float32 PCM bytes'
.format(type(inputs)))
return dict(data=data)
def postprocess(self, inputs: Dict[str, Any],
**post_params) -> Dict[str, Any]:
return inputs
def forward(
self, inputs: Dict[str, Any], **forward_params
) -> Dict[str, Any]:
"""Forward computations from the mixture to the separated signals."""
logger.info('Start forward...')
# Unpack lists and put tensors in the right device
mix = inputs['data'].to(self.device)
mix = torch.unsqueeze(mix, dim=1).transpose(0, 1)
est_source = self.model(mix)
result = []
for ns in range(self.model.num_spks):
signal = est_source[0, :, ns]
signal = signal / signal.abs().max() * 0.5
result.append(signal.unsqueeze(0).cpu())
logger.info('Finish forward.')
return {OutputKeys.OUTPUT_PCM_LIST: result}
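Besides a wav file path, the preprocess above also accepts raw 32-bit float PCM bytes at 8 kHz. A hedged usage sketch ('mix.wav' is a placeholder for any 8 kHz mono mixture; the model id is the one used in the test below):

import soundfile as sf
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

data, fs = sf.read('mix.wav', dtype='float32')   # placeholder 8 kHz mono mixture
assert fs == 8000
separation = pipeline(
    Tasks.speech_separation,
    model='damo/speech_mossformer_separation_temporal_8k')
result = separation(data.tobytes())              # raw float32 PCM bytes
for i, pcm in enumerate(result[OutputKeys.OUTPUT_PCM_LIST]):
    sf.write(f'spk{i}.wav', pcm.numpy().T, 8000)  # each pcm tensor is [1, num_samples]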

View File

@@ -158,6 +158,7 @@ class AudioTasks(object):
auto_speech_recognition = 'auto-speech-recognition'
text_to_speech = 'text-to-speech'
speech_signal_process = 'speech-signal-process'
speech_separation = 'speech-separation'
acoustic_echo_cancellation = 'acoustic-echo-cancellation'
acoustic_noise_suppression = 'acoustic-noise-suppression'
keyword_spotting = 'keyword-spotting'

View File

@@ -29,9 +29,11 @@ pygments>=2.12.0
pysptk>=0.1.15,<0.2.0
pytorch_wavelets
PyWavelets>=1.0.0
rotary_embedding_torch>=0.1.5
scikit-learn
SoundFile>0.10
sox
speechbrain>=0.5
torchaudio
tqdm
traitlets>=5.3.0

View File

@@ -0,0 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path
import unittest
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
MIX_SPEECH_FILE = 'data/test/audios/mix_speech.wav'
class SpeechSeparationTest(unittest.TestCase, DemoCompatibilityCheck):
def setUp(self) -> None:
pass
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_normal(self):
import torchaudio
model_id = 'damo/speech_mossformer_separation_temporal_8k'
separation = pipeline(Tasks.speech_separation, model=model_id)
result = separation(os.path.join(os.getcwd(), MIX_SPEECH_FILE))
self.assertTrue(OutputKeys.OUTPUT_PCM_LIST in result)
self.assertEqual(len(result[OutputKeys.OUTPUT_PCM_LIST]), 2)
for i, signal in enumerate(result[OutputKeys.OUTPUT_PCM_LIST]):
save_file = f'output_spk{i}.wav'
# Estimated source
torchaudio.save(save_file, signal, 8000)
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
self.compatibility_check()
if __name__ == '__main__':
unittest.main()