From 0fdf37312fd9b2ac35c3cd92dd16c919c75bb5ff Mon Sep 17 00:00:00 2001 From: "bin.xue" Date: Tue, 3 Jan 2023 13:18:44 +0800 Subject: [PATCH] [to #42322933] feat:add speech separation pipeline Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11255740 --- data/test/audios/mix_speech.wav | 3 + modelscope/metainfo.py | 2 + .../models/audio/separation/__init__.py | 0 .../models/audio/separation/layer_norm.py | 68 +++ .../models/audio/separation/mossformer.py | 472 ++++++++++++++++++ .../audio/separation/mossformer_block.py | 265 ++++++++++ .../separation/mossformer_conv_module.py | 272 ++++++++++ modelscope/outputs/outputs.py | 2 + modelscope/pipeline_inputs.py | 2 + .../pipelines/audio/separation_pipeline.py | 68 +++ modelscope/utils/constant.py | 1 + requirements/audio.txt | 2 + tests/pipelines/test_speech_separation.py | 39 ++ 13 files changed, 1196 insertions(+) create mode 100644 data/test/audios/mix_speech.wav create mode 100644 modelscope/models/audio/separation/__init__.py create mode 100644 modelscope/models/audio/separation/layer_norm.py create mode 100644 modelscope/models/audio/separation/mossformer.py create mode 100644 modelscope/models/audio/separation/mossformer_block.py create mode 100644 modelscope/models/audio/separation/mossformer_conv_module.py create mode 100644 modelscope/pipelines/audio/separation_pipeline.py create mode 100644 tests/pipelines/test_speech_separation.py diff --git a/data/test/audios/mix_speech.wav b/data/test/audios/mix_speech.wav new file mode 100644 index 00000000..b200e668 --- /dev/null +++ b/data/test/audios/mix_speech.wav @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:34c2f1867f7882614b7087f2fd2acb722d0f520a2ec50b2d116d5b3f0c05f84b +size 141134 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 3b14870f..3cbd788a 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -113,6 +113,7 @@ class Models(object): speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield' + speech_mossformer_separation_temporal_8k = 'speech_mossformer_separation_temporal_8k' kws_kwsbp = 'kws-kwsbp' generic_asr = 'generic-asr' wenet_asr = 'wenet-asr' @@ -317,6 +318,7 @@ class Pipelines(object): speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k' speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k' speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield' + speech_separation = 'speech-separation' kws_kwsbp = 'kws-kwsbp' asr_inference = 'asr-inference' asr_wenet_inference = 'asr-wenet-inference' diff --git a/modelscope/models/audio/separation/__init__.py b/modelscope/models/audio/separation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/separation/layer_norm.py b/modelscope/models/audio/separation/layer_norm.py new file mode 100644 index 00000000..a4145cbc --- /dev/null +++ b/modelscope/models/audio/separation/layer_norm.py @@ -0,0 +1,68 @@ +# Copyright 2018 Northwestern Polytechnical University (author: Ke Wang) +# made publicly available under the MIT License +# at https://github.com/wangkenpu/Conv-TasNet-PyTorch/blob/64188ffa48971218fdd68b66906970f215d7eca2/model/layer_norm.py + +from __future__ import absolute_import, division, print_function + +import torch +import torch.nn as nn + + +class CLayerNorm(nn.LayerNorm): + """Channel-wise layer normalization.""" + + def __init__(self, *args, **kwargs): + 
super(CLayerNorm, self).__init__(*args, **kwargs) + + def forward(self, sample): + """Forward function. + + Args: + sample: [batch_size, channels, length] + """ + if sample.dim() != 3: + raise RuntimeError('{} only accept 3-D tensor as input'.format( + self.__name__)) + # [N, C, T] -> [N, T, C] + sample = torch.transpose(sample, 1, 2) + # LayerNorm + sample = super().forward(sample) + # [N, T, C] -> [N, C, T] + sample = torch.transpose(sample, 1, 2) + return sample + + +class GLayerNorm(nn.Module): + """Global Layer Normalization for TasNet.""" + + def __init__(self, channels, eps=1e-5): + super(GLayerNorm, self).__init__() + self.eps = eps + self.norm_dim = channels + self.gamma = nn.Parameter(torch.Tensor(channels)) + self.beta = nn.Parameter(torch.Tensor(channels)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.gamma) + nn.init.zeros_(self.beta) + + def forward(self, sample): + """Forward function. + + Args: + sample: [batch_size, channels, length] + """ + if sample.dim() != 3: + raise RuntimeError('{} only accept 3-D tensor as input'.format( + self.__name__)) + # [N, C, T] -> [N, T, C] + sample = torch.transpose(sample, 1, 2) + # Mean and variance [N, 1, 1] + mean = torch.mean(sample, (1, 2), keepdim=True) + var = torch.mean((sample - mean)**2, (1, 2), keepdim=True) + sample = (sample + - mean) / torch.sqrt(var + self.eps) * self.gamma + self.beta + # [N, T, C] -> [N, C, T] + sample = torch.transpose(sample, 1, 2) + return sample diff --git a/modelscope/models/audio/separation/mossformer.py b/modelscope/models/audio/separation/mossformer.py new file mode 100644 index 00000000..3ab42f74 --- /dev/null +++ b/modelscope/models/audio/separation/mossformer.py @@ -0,0 +1,472 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import copy +import os +from typing import Any, Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models import MODELS, TorchModel +from modelscope.models.audio.separation.mossformer_block import ( + MossFormerModule, ScaledSinuEmbedding) +from modelscope.models.audio.separation.mossformer_conv_module import ( + CumulativeLayerNorm, GlobalLayerNorm) +from modelscope.models.base import Tensor +from modelscope.utils.constant import Tasks + +EPS = 1e-8 + + +@MODELS.register_module( + Tasks.speech_separation, + module_name=Models.speech_mossformer_separation_temporal_8k) +class MossFormer(TorchModel): + """Library to support MossFormer speech separation. + + Args: + model_dir (str): the model path. 
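+        kwargs: additional model hyper-parameters, normally read from the
+            model configuration, including kernel_size, in_channels,
+            out_channels, stride, bias, num_blocks, d_model, attn_dropout,
+            group_size, query_key_dim, expansion_factor, causal, norm and
+            num_spks.
+
+    The forward pass takes a mixture tensor of shape [B, L] and returns the
+    separated sources as a tensor of shape [B, L, num_spks].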
+ """ + + def __init__(self, model_dir: str, *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + self.encoder = Encoder( + kernel_size=kwargs['kernel_size'], + out_channels=kwargs['out_channels']) + self.decoder = Decoder( + in_channels=kwargs['in_channels'], + out_channels=1, + kernel_size=kwargs['kernel_size'], + stride=kwargs['stride'], + bias=kwargs['bias']) + self.mask_net = MossFormerMaskNet( + kwargs['in_channels'], + kwargs['out_channels'], + MossFormerM(kwargs['num_blocks'], kwargs['d_model'], + kwargs['attn_dropout'], kwargs['group_size'], + kwargs['query_key_dim'], kwargs['expansion_factor'], + kwargs['causal']), + norm=kwargs['norm'], + num_spks=kwargs['num_spks']) + self.num_spks = kwargs['num_spks'] + + def forward(self, inputs: Tensor) -> Dict[str, Any]: + # Separation + mix_w = self.encoder(inputs) + est_mask = self.mask_net(mix_w) + mix_w = torch.stack([mix_w] * self.num_spks) + sep_h = mix_w * est_mask + # Decoding + est_source = torch.cat( + [ + self.decoder(sep_h[i]).unsqueeze(-1) + for i in range(self.num_spks) + ], + dim=-1, + ) + # T changed after conv1d in encoder, fix it here + t_origin = inputs.size(1) + t_est = est_source.size(1) + if t_origin > t_est: + est_source = F.pad(est_source, (0, 0, 0, t_origin - t_est)) + else: + est_source = est_source[:, :t_origin, :] + return est_source + + def load_check_point(self, load_path=None, device=None): + if not load_path: + load_path = self.model_dir + if not device: + device = torch.device('cpu') + self.encoder.load_state_dict( + torch.load( + os.path.join(load_path, 'encoder.bin'), map_location=device), + strict=True) + self.decoder.load_state_dict( + torch.load( + os.path.join(load_path, 'decoder.bin'), map_location=device), + strict=True) + self.mask_net.load_state_dict( + torch.load( + os.path.join(load_path, 'masknet.bin'), map_location=device), + strict=True) + + +def select_norm(norm, dim, shape): + """Just a wrapper to select the normalization type. + """ + + if norm == 'gln': + return GlobalLayerNorm(dim, shape, elementwise_affine=True) + if norm == 'cln': + return CumulativeLayerNorm(dim, elementwise_affine=True) + if norm == 'ln': + return nn.GroupNorm(1, dim, eps=1e-8) + else: + return nn.BatchNorm1d(dim) + + +class Encoder(nn.Module): + """Convolutional Encoder Layer. + + Args: + kernel_size: Length of filters. + in_channels: Number of input channels. + out_channels: Number of output channels. + + Example: + ------- + >>> x = torch.randn(2, 1000) + >>> encoder = Encoder(kernel_size=4, out_channels=64) + >>> h = encoder(x) + >>> h.shape + torch.Size([2, 64, 499]) + """ + + def __init__(self, + kernel_size: int = 2, + out_channels: int = 64, + in_channels: int = 1): + super(Encoder, self).__init__() + self.conv1d = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=kernel_size // 2, + groups=1, + bias=False, + ) + self.in_channels = in_channels + + def forward(self, x: torch.Tensor): + """Return the encoded output. + + Args: + x: Input tensor with dimensionality [B, L]. + + Returns: + Encoded tensor with dimensionality [B, N, T_out]. + where B = Batchsize + L = Number of timepoints + N = Number of filters + T_out = Number of timepoints at the output of the encoder + """ + # B x L -> B x 1 x L + if self.in_channels == 1: + x = torch.unsqueeze(x, dim=1) + # B x 1 x L -> B x N x T_out + x = self.conv1d(x) + x = F.relu(x) + + return x + + +class Decoder(nn.ConvTranspose1d): + """A decoder layer that consists of ConvTranspose1d. 
+ + Args: + kernel_size: Length of filters. + in_channels: Number of input channels. + out_channels: Number of output channels. + + Example + --------- + >>> x = torch.randn(2, 100, 1000) + >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1) + >>> h = decoder(x) + >>> h.shape + torch.Size([2, 1003]) + """ + + def __init__(self, *args, **kwargs): + super(Decoder, self).__init__(*args, **kwargs) + + def forward(self, x): + """Return the decoded output. + + Args: + x: Input tensor with dimensionality [B, N, L]. + where, B = Batchsize, + N = number of filters + L = time points + """ + + if x.dim() not in [2, 3]: + raise RuntimeError('{} accept 3/4D tensor as input'.format( + self.__name__)) + x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1)) + + if torch.squeeze(x).dim() == 1: + x = torch.squeeze(x, dim=1) + else: + x = torch.squeeze(x) + return x + + +class IdentityBlock: + """This block is used when we want to have identity transformation within the Dual_path block. + + Example + ------- + >>> x = torch.randn(10, 100) + >>> IB = IdentityBlock() + >>> xhat = IB(x) + """ + + def _init__(self, **kwargs): + pass + + def __call__(self, x): + return x + + +class MossFormerM(nn.Module): + """This class implements the transformer encoder. + + Args: + num_blocks : int + Number of mossformer blocks to include. + d_model : int + The dimension of the input embedding. + attn_dropout : float + Dropout for the self-attention (Optional). + group_size: int + the chunk size + query_key_dim: int + the attention vector dimension + expansion_factor: int + the expansion factor for the linear projection in conv module + causal: bool + true for causal / false for non causal + + Example + ------- + >>> import torch + >>> x = torch.rand((8, 60, 512)) #B, S, N + >>> net = MossFormerM(num_blocks=8, d_model=512) + >>> output, _ = net(x) + >>> output.shape + torch.Size([8, 60, 512]) + """ + + def __init__(self, + num_blocks, + d_model=None, + attn_dropout=0.1, + group_size=256, + query_key_dim=128, + expansion_factor=4., + causal=False): + super().__init__() + + self.mossformerM = MossFormerModule( + dim=d_model, + depth=num_blocks, + group_size=group_size, + query_key_dim=query_key_dim, + expansion_factor=expansion_factor, + causal=causal, + attn_dropout=attn_dropout) + import speechbrain as sb + self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6) + + def forward(self, src: torch.Tensor): + """ + Args: + src: Tensor shape [B, S, N], + where, B = Batchsize, + S = time points + N = number of filters + The sequence to the encoder layer (required). + """ + output = self.mossformerM(src) + output = self.norm(output) + + return output + + +class ComputeAttention(nn.Module): + """Computation block for dual-path processing. + + Args: + att_mdl : torch.nn.module + Model to process within the chunks. + out_channels : int + Dimensionality of attention model. + norm : str + Normalization type. + skip_connection : bool + Skip connection around the attention module. 
+ + Example + --------- + >>> att_block = MossFormerM(num_blocks=8, d_model=512) + >>> comp_att = ComputeAttention(att_block, 512) + >>> x = torch.randn(10, 64, 512) + >>> x = comp_att(x) + >>> x.shape + torch.Size([10, 64, 512]) + """ + + def __init__( + self, + att_mdl, + out_channels, + norm='ln', + skip_connection=True, + ): + super(ComputeAttention, self).__init__() + + self.att_mdl = att_mdl + self.skip_connection = skip_connection + + # Norm + self.norm = norm + if norm is not None: + self.att_norm = select_norm(norm, out_channels, 3) + + def forward(self, x: torch.Tensor): + """Returns the output tensor. + + Args: + x: Input tensor of dimension [B, S, N]. + + Returns: + out: Output tensor of dimension [B, S, N]. + where, B = Batchsize, + N = number of filters + S = time points + """ + # [B, S, N] + att_out = x.permute(0, 2, 1).contiguous() + + att_out = self.att_mdl(att_out) + + # [B, N, S] + att_out = att_out.permute(0, 2, 1).contiguous() + if self.norm is not None: + att_out = self.att_norm(att_out) + + # [B, N, S] + if self.skip_connection: + att_out = att_out + x + + out = att_out + return out + + +class MossFormerMaskNet(nn.Module): + """The dual path model which is the basis for dualpathrnn, sepformer, dptnet. + + Args: + in_channels : int + Number of channels at the output of the encoder. + out_channels : int + Number of channels that would be inputted to the intra and inter blocks. + att_model : torch.nn.module + Attention model to process the input sequence. + norm : str + Normalization type. + num_spks : int + Number of sources (speakers). + skip_connection : bool + Skip connection around attention module. + use_global_pos_enc : bool + Global positional encodings. + + Example + --------- + >>> mossformer_block = MossFormerM(num_blocks=8, d_model=512) + >>> mossformer_masknet = MossFormerMaskNet(64, 64, att_model, num_spks=2) + >>> x = torch.randn(10, 64, 2000) + >>> x = mossformer_masknet(x) + >>> x.shape + torch.Size([2, 10, 64, 2000]) + """ + + def __init__( + self, + in_channels, + out_channels, + att_model, + norm='ln', + num_spks=2, + skip_connection=True, + use_global_pos_enc=True, + ): + super(MossFormerMaskNet, self).__init__() + self.num_spks = num_spks + self.norm = select_norm(norm, in_channels, 3) + self.conv1d_encoder = nn.Conv1d( + in_channels, out_channels, 1, bias=False) + self.use_global_pos_enc = use_global_pos_enc + + if self.use_global_pos_enc: + self.pos_enc = ScaledSinuEmbedding(out_channels) + + self.mdl = copy.deepcopy( + ComputeAttention( + att_model, + out_channels, + norm, + skip_connection=skip_connection, + )) + + self.conv1d_out = nn.Conv1d( + out_channels, out_channels * num_spks, kernel_size=1) + self.conv1_decoder = nn.Conv1d( + out_channels, in_channels, 1, bias=False) + self.prelu = nn.PReLU() + self.activation = nn.ReLU() + # gated output layer + self.output = nn.Sequential( + nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()) + self.output_gate = nn.Sequential( + nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()) + + def forward(self, x: torch.Tensor): + """Returns the output tensor. + + Args: + x: Input tensor of dimension [B, N, S]. 
+ + Returns: + out: Output tensor of dimension [spks, B, N, S] + where, spks = Number of speakers + B = Batchsize, + N = number of filters + S = the number of time frames + """ + + # before each line we indicate the shape after executing the line + # [B, N, L] + x = self.norm(x) + # [B, N, L] + x = self.conv1d_encoder(x) + if self.use_global_pos_enc: + base = x + x = x.transpose(1, -1) + emb = self.pos_enc(x) + emb = emb.transpose(0, -1) + x = base + emb + # [B, N, S] + x = self.mdl(x) + x = self.prelu(x) + # [B, N*spks, S] + x = self.conv1d_out(x) + b, _, s = x.shape + # [B*spks, N, S] + x = x.view(b * self.num_spks, -1, s) + # [B*spks, N, S] + x = self.output(x) * self.output_gate(x) + # [B*spks, N, S] + x = self.conv1_decoder(x) + # [B, spks, N, S] + _, n, L = x.shape + x = x.view(b, self.num_spks, n, L) + x = self.activation(x) + # [spks, B, N, S] + x = x.transpose(0, 1) + return x diff --git a/modelscope/models/audio/separation/mossformer_block.py b/modelscope/models/audio/separation/mossformer_block.py new file mode 100644 index 00000000..1db8d010 --- /dev/null +++ b/modelscope/models/audio/separation/mossformer_block.py @@ -0,0 +1,265 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.nn.functional as F +from torch import einsum, nn + +from modelscope.models.audio.separation.mossformer_conv_module import \ + MossFormerConvModule + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +def padding_to_multiple_of(n, mult): + remainder = n % mult + if remainder == 0: + return 0 + return mult - remainder + + +class ScaleNorm(nn.Module): + + def __init__(self, dim, eps=1e-5): + super().__init__() + self.scale = dim**-0.5 + self.eps = eps + self.g = nn.Parameter(torch.ones(1)) + + def forward(self, x): + norm = torch.norm(x, dim=-1, keepdim=True) * self.scale + return x / norm.clamp(min=self.eps) * self.g + + +class ScaledSinuEmbedding(nn.Module): + + def __init__(self, dim): + super().__init__() + self.scale = nn.Parameter(torch.ones(1, )) + inv_freq = 1. / (10000**(torch.arange(0, dim, 2).float() / dim)) + self.register_buffer('inv_freq', inv_freq) + + def forward(self, x): + n, device = x.shape[1], x.device + t = torch.arange(n, device=device).type_as(self.inv_freq) + sinu = einsum('i , j -> i j', t, self.inv_freq) + emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1) + return emb * self.scale + + +class OffsetScale(nn.Module): + + def __init__(self, dim, heads=1): + super().__init__() + self.gamma = nn.Parameter(torch.ones(heads, dim)) + self.beta = nn.Parameter(torch.zeros(heads, dim)) + nn.init.normal_(self.gamma, std=0.02) + + def forward(self, x): + out = einsum('... d, h d -> ... 
h d', x, self.gamma) + self.beta + return out.unbind(dim=-2) + + +class FFConvM(nn.Module): + + def __init__(self, dim_in, dim_out, norm_klass=nn.LayerNorm, dropout=0.1): + super().__init__() + self.mdl = nn.Sequential( + norm_klass(dim_in), nn.Linear(dim_in, dim_out), nn.SiLU(), + MossFormerConvModule(dim_out), nn.Dropout(dropout)) + + def forward(self, x): + output = self.mdl(x) + return output + + +class MossFormerBlock(nn.Module): + + def __init__(self, + dim, + group_size=256, + query_key_dim=128, + expansion_factor=1., + causal=False, + dropout=0.1, + rotary_pos_emb=None, + norm_klass=nn.LayerNorm, + shift_tokens=True): + super().__init__() + hidden_dim = int(dim * expansion_factor) + self.group_size = group_size + self.causal = causal + self.shift_tokens = shift_tokens + # positional embeddings + self.rotary_pos_emb = rotary_pos_emb + # norm + self.dropout = nn.Dropout(dropout) + # projections + self.to_hidden = FFConvM( + dim_in=dim, + dim_out=hidden_dim, + norm_klass=norm_klass, + dropout=dropout, + ) + self.to_qk = FFConvM( + dim_in=dim, + dim_out=query_key_dim, + norm_klass=norm_klass, + dropout=dropout, + ) + self.qk_offset_scale = OffsetScale(query_key_dim, heads=4) + self.to_out = FFConvM( + dim_in=dim * 2, + dim_out=dim, + norm_klass=norm_klass, + dropout=dropout, + ) + self.gateActivate = nn.Sigmoid() + + def forward(self, x): + # prenorm + normed_x = x + # do token shift - a great, costless trick from an independent AI researcher in Shenzhen + if self.shift_tokens: + x_shift, x_pass = normed_x.chunk(2, dim=-1) + x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.) + normed_x = torch.cat((x_shift, x_pass), dim=-1) + + # initial projections + v, u = self.to_hidden(normed_x).chunk(2, dim=-1) + qk = self.to_qk(normed_x) + # offset and scale + quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk) + att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, + u) + + # projection out and residual + out = (att_u * v) * self.gateActivate(att_v * u) + x = x + self.to_out(out) + return x + + def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask=None): + b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size + + from einops import rearrange + if exists(mask): + lin_mask = rearrange(mask, '... -> ... 1') + lin_k = lin_k.masked_fill(~lin_mask, 0.) + + # rotate queries and keys + if exists(self.rotary_pos_emb): + quad_q, lin_q, quad_k, lin_k = map( + self.rotary_pos_emb.rotate_queries_or_keys, + (quad_q, lin_q, quad_k, lin_k)) + + # padding for groups + padding = padding_to_multiple_of(n, g) + if padding > 0: + quad_q, quad_k, lin_q, lin_k, v, u = map( + lambda t: F.pad(t, (0, 0, 0, padding), value=0.), + (quad_q, quad_k, lin_q, lin_k, v, u)) + mask = default(mask, + torch.ones((b, n), device=device, dtype=torch.bool)) + mask = F.pad(mask, (0, padding), value=False) + + # group along sequence + quad_q, quad_k, lin_q, lin_k, v, u = map( + lambda t: rearrange(t, 'b (g n) d -> b g n d', n=self.group_size), + (quad_q, quad_k, lin_q, lin_k, v, u)) + + if exists(mask): + mask = rearrange(mask, 'b (g j) -> b g 1 j', j=g) + + # calculate quadratic attention output + sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g + attn = F.relu(sim)**2 + attn = self.dropout(attn) + if exists(mask): + attn = attn.masked_fill(~mask, 0.) + if self.causal: + causal_mask = torch.ones((g, g), dtype=torch.bool, + device=device).triu(1) + attn = attn.masked_fill(causal_mask, 0.) + + quad_out_v = einsum('... i j, ... j d -> ... 
i d', attn, v) + quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u) + + # calculate linear attention output + if self.causal: + lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g + # exclusive cumulative sum along group dimension + lin_kv = lin_kv.cumsum(dim=1) + lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.) + lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q) + + lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g + # exclusive cumulative sum along group dimension + lin_ku = lin_ku.cumsum(dim=1) + lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value=0.) + lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q) + else: + lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n + lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv) + + lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n + lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku) + + # fold back groups into full sequence, and excise out padding + quad_attn_out_v, lin_attn_out_v = map( + lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], + (quad_out_v, lin_out_v)) + quad_attn_out_u, lin_attn_out_u = map( + lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], + (quad_out_u, lin_out_u)) + + # gate + return quad_attn_out_v + lin_attn_out_v, quad_attn_out_u + lin_attn_out_u + + +class MossFormerModule(nn.Module): + + def __init__(self, + dim, + depth, + group_size=256, + query_key_dim=128, + expansion_factor=4., + causal=False, + attn_dropout=0.1, + norm_type='scalenorm', + shift_tokens=True): + super().__init__() + assert norm_type in ( + 'scalenorm', + 'layernorm'), 'norm_type must be one of scalenorm or layernorm' + + if norm_type == 'scalenorm': + norm_klass = ScaleNorm + elif norm_type == 'layernorm': + norm_klass = nn.LayerNorm + + from rotary_embedding_torch import RotaryEmbedding + rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim)) + # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J + self.layers = nn.ModuleList([ + MossFormerBlock( + dim=dim, + group_size=group_size, + query_key_dim=query_key_dim, + expansion_factor=expansion_factor, + causal=causal, + dropout=attn_dropout, + rotary_pos_emb=rotary_pos_emb, + norm_klass=norm_klass, + shift_tokens=shift_tokens) for _ in range(depth) + ]) + + def forward(self, x): + for mossformer_layer in self.layers: + x = mossformer_layer(x) + return x diff --git a/modelscope/models/audio/separation/mossformer_conv_module.py b/modelscope/models/audio/separation/mossformer_conv_module.py new file mode 100644 index 00000000..283269b3 --- /dev/null +++ b/modelscope/models/audio/separation/mossformer_conv_module.py @@ -0,0 +1,272 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import torch +import torch.nn as nn +import torch.nn.init as init +from torch import Tensor + +EPS = 1e-8 + + +class GlobalLayerNorm(nn.Module): + """Calculate Global Layer Normalization. + + Args: + dim : (int or list or torch.Size) + Input shape from an expected input of size. + eps : float + A value added to the denominator for numerical stability. + elementwise_affine : bool + A boolean value that when set to True, + this module has learnable per-element affine parameters + initialized to ones (for weights) and zeros (for biases). 
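+
+    Normalization statistics are computed over both the channel and time
+    dimensions, i.e. y = weight * (x - E[x]) / sqrt(Var[x] + eps) + bias.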
+ + Example + ------- + >>> x = torch.randn(5, 10, 20) + >>> GLN = GlobalLayerNorm(10, 3) + >>> x_norm = GLN(x) + """ + + def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True): + super(GlobalLayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.elementwise_affine = elementwise_affine + + if self.elementwise_affine: + if shape == 3: + self.weight = nn.Parameter(torch.ones(self.dim, 1)) + self.bias = nn.Parameter(torch.zeros(self.dim, 1)) + if shape == 4: + self.weight = nn.Parameter(torch.ones(self.dim, 1, 1)) + self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + + def forward(self, x): + """Returns the normalized tensor. + + Args: + x: Tensor of size [N, C, K, S] or [N, C, L]. + """ + # N x 1 x 1 + # cln: mean,var N x 1 x K x S + # gln: mean,var N x 1 x 1 + if x.dim() == 3: + mean = torch.mean(x, (1, 2), keepdim=True) + var = torch.mean((x - mean)**2, (1, 2), keepdim=True) + if self.elementwise_affine: + x = (self.weight * (x - mean) / torch.sqrt(var + self.eps) + + self.bias) # yapf: disable + else: + x = (x - mean) / torch.sqrt(var + self.eps) + + if x.dim() == 4: + mean = torch.mean(x, (1, 2, 3), keepdim=True) + var = torch.mean((x - mean)**2, (1, 2, 3), keepdim=True) + if self.elementwise_affine: + x = (self.weight * (x - mean) / torch.sqrt(var + self.eps) + + self.bias) # yapf: disable + else: + x = (x - mean) / torch.sqrt(var + self.eps) + return x + + +class CumulativeLayerNorm(nn.LayerNorm): + """Calculate Cumulative Layer Normalization. + + Args: + dim: Dimension that you want to normalize. + elementwise_affine: Learnable per-element affine parameters. + + Example + ------- + >>> x = torch.randn(5, 10, 20) + >>> CLN = CumulativeLayerNorm(10) + >>> x_norm = CLN(x) + """ + + def __init__(self, dim, elementwise_affine=True): + super(CumulativeLayerNorm, self).__init__( + dim, elementwise_affine=elementwise_affine, eps=1e-8) + + def forward(self, x): + """Returns the normalized tensor. + + Args: + x: Tensor size [N, C, K, S] or [N, C, L] + """ + # N x K x S x C + if x.dim() == 4: + x = x.permute(0, 2, 3, 1).contiguous() + # N x K x S x C == only channel norm + x = super().forward(x) + # N x C x K x S + x = x.permute(0, 3, 1, 2).contiguous() + if x.dim() == 3: + x = torch.transpose(x, 1, 2) + # N x L x C == only channel norm + x = super().forward(x) + # N x C x L + x = torch.transpose(x, 1, 2) + return x + + +def select_norm(norm, dim, shape): + """Just a wrapper to select the normalization type. + """ + + if norm == 'gln': + return GlobalLayerNorm(dim, shape, elementwise_affine=True) + if norm == 'cln': + return CumulativeLayerNorm(dim, elementwise_affine=True) + if norm == 'ln': + return nn.GroupNorm(1, dim, eps=1e-8) + else: + return nn.BatchNorm1d(dim) + + +class Swish(nn.Module): + """ + Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks applied + to a variety of challenging domains such as Image classification and Machine translation. 
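+    Defined as swish(x) = x * sigmoid(x) (also known as SiLU).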
+ """ + + def __init__(self): + super(Swish, self).__init__() + + def forward(self, inputs: Tensor) -> Tensor: + return inputs * inputs.sigmoid() + + +class GLU(nn.Module): + """ + The gating mechanism is called Gated Linear Units (GLU), which was first introduced for natural language processing + in the paper “Language Modeling with Gated Convolutional Networks” + """ + + def __init__(self, dim: int) -> None: + super(GLU, self).__init__() + self.dim = dim + + def forward(self, inputs: Tensor) -> Tensor: + outputs, gate = inputs.chunk(2, dim=self.dim) + return outputs * gate.sigmoid() + + +class Transpose(nn.Module): + """ Wrapper class of torch.transpose() for Sequential module. """ + + def __init__(self, shape: tuple): + super(Transpose, self).__init__() + self.shape = shape + + def forward(self, x: Tensor) -> Tensor: + return x.transpose(*self.shape) + + +class Linear(nn.Module): + """ + Wrapper class of torch.nn.Linear + Weight initialize by xavier initialization and bias initialize to zeros. + """ + + def __init__(self, + in_features: int, + out_features: int, + bias: bool = True) -> None: + super(Linear, self).__init__() + self.linear = nn.Linear(in_features, out_features, bias=bias) + init.xavier_uniform_(self.linear.weight) + if bias: + init.zeros_(self.linear.bias) + + def forward(self, x: Tensor) -> Tensor: + return self.linear(x) + + +class DepthwiseConv1d(nn.Module): + """ + When groups == in_channels and out_channels == K * in_channels, where K is a positive integer, + this operation is termed in literature as depthwise convolution. + Args: + in_channels (int): Number of channels in the input + out_channels (int): Number of channels produced by the convolution + kernel_size (int or tuple): Size of the convolving kernel + stride (int, optional): Stride of the convolution. Default: 1 + padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0 + bias (bool, optional): If True, adds a learnable bias to the output. Default: True + Inputs: inputs + - **inputs** (batch, in_channels, time): Tensor containing input vector + Returns: outputs + - **outputs** (batch, out_channels, time): Tensor produces by depthwise 1-D convolution. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: int = 0, + bias: bool = False, + ) -> None: + super(DepthwiseConv1d, self).__init__() + assert out_channels % in_channels == 0, 'out_channels should be constant multiple of in_channels' + self.conv = nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + groups=in_channels, + stride=stride, + padding=padding, + bias=bias, + ) + + def forward(self, inputs: Tensor) -> Tensor: + return self.conv(inputs) + + +class MossFormerConvModule(nn.Module): + """Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU). + + This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution + to aid training deep models. 
+ + Args: + in_channels (int): Number of channels in the input + kernel_size (int or tuple, optional): Size of the convolving kernel Default: 17 + dropout_p (float, optional): probability of dropout + """ + + def __init__(self, + in_channels: int, + kernel_size: int = 17, + expansion_factor: int = 2) -> None: + super(MossFormerConvModule, self).__init__() + assert ( + kernel_size - 1 + ) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding" + assert expansion_factor == 2, 'Currently, Only Supports expansion_factor 2' + + self.sequential = nn.Sequential( + Transpose(shape=(1, 2)), + DepthwiseConv1d( + in_channels, + in_channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2), + ) + + def forward(self, inputs: Tensor) -> Tensor: + """ + Args: + inputs (batch, time, dim): Tensor contains input sequences + + Returns: + outputs (batch, time, dim): Tensor produces by conformer convolution module. + """ + return inputs + self.sequential(inputs).transpose(1, 2) diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py index bbd05740..784dbf71 100644 --- a/modelscope/outputs/outputs.py +++ b/modelscope/outputs/outputs.py @@ -27,6 +27,7 @@ class OutputKeys(object): OUTPUT_IMG = 'output_img' OUTPUT_VIDEO = 'output_video' OUTPUT_PCM = 'output_pcm' + OUTPUT_PCM_LIST = 'output_pcm_list' OUTPUT_WAV = 'output_wav' IMG_EMBEDDING = 'img_embedding' SPO_LIST = 'spo_list' @@ -696,6 +697,7 @@ TASK_OUTPUTS = { Tasks.speech_signal_process: [OutputKeys.OUTPUT_PCM], Tasks.acoustic_echo_cancellation: [OutputKeys.OUTPUT_PCM], Tasks.acoustic_noise_suppression: [OutputKeys.OUTPUT_PCM], + Tasks.speech_separation: [OutputKeys.OUTPUT_PCM_LIST], # text_to_speech result for a single sample # { diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py index ffb1ac7d..2732be98 100644 --- a/modelscope/pipeline_inputs.py +++ b/modelscope/pipeline_inputs.py @@ -214,6 +214,8 @@ TASK_INPUTS = { 'nearend_mic': InputType.AUDIO, 'farend_speech': InputType.AUDIO }, + Tasks.speech_separation: + InputType.AUDIO, Tasks.acoustic_noise_suppression: InputType.AUDIO, Tasks.text_to_speech: diff --git a/modelscope/pipelines/audio/separation_pipeline.py b/modelscope/pipelines/audio/separation_pipeline.py new file mode 100644 index 00000000..8942cdb3 --- /dev/null +++ b/modelscope/pipelines/audio/separation_pipeline.py @@ -0,0 +1,68 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import io +from typing import Any, Dict + +import numpy +import soundfile as sf +import torch + +from modelscope.fileio import File +from modelscope.metainfo import Pipelines +from modelscope.models.base import Input +from modelscope.outputs import OutputKeys +from modelscope.pipelines import Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.speech_separation, module_name=Pipelines.speech_separation) +class SeparationPipeline(Pipeline): + + def __init__(self, model, **kwargs): + """create a speech separation pipeline for prediction + + Args: + model: model id on modelscope hub. 
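+
+        A typical invocation (mirroring tests/pipelines/test_speech_separation.py;
+        the input should be an 8 kHz mono wav file path or raw float32 PCM bytes):
+
+            >>> from modelscope.outputs import OutputKeys
+            >>> from modelscope.pipelines import pipeline
+            >>> from modelscope.utils.constant import Tasks
+            >>> separation = pipeline(
+            ...     Tasks.speech_separation,
+            ...     model='damo/speech_mossformer_separation_temporal_8k')
+            >>> result = separation('data/test/audios/mix_speech.wav')
+            >>> len(result[OutputKeys.OUTPUT_PCM_LIST])  # one tensor per speaker
+            2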
+ """ + logger.info('loading model...') + super().__init__(model=model, **kwargs) + self.model.load_check_point(device=self.device) + self.model.eval() + + def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]: + if isinstance(inputs, str): + file_bytes = File.read(inputs) + data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32') + if fs != 8000: + raise ValueError( + 'modelscope error: The audio sample rate should be 8000') + elif isinstance(inputs, bytes): + data = torch.from_numpy( + numpy.frombuffer(inputs, dtype=numpy.float32)) + return dict(data=data) + + def postprocess(self, inputs: Dict[str, Any], + **post_params) -> Dict[str, Any]: + return inputs + + def forward( + self, inputs: Dict[str, Any], **forward_params + ) -> Dict[str, Any]: # mix, targets, stage, noise=None): + """Forward computations from the mixture to the separated signals.""" + logger.info('Start forward...') + # Unpack lists and put tensors in the right device + mix = inputs['data'].to(self.device) + mix = torch.unsqueeze(mix, dim=1).transpose(0, 1) + est_source = self.model(mix) + result = [] + for ns in range(self.model.num_spks): + signal = est_source[0, :, ns] + signal = signal / signal.abs().max() * 0.5 + result.append(signal.unsqueeze(0).cpu()) + logger.info('Finish forward.') + return {OutputKeys.OUTPUT_PCM_LIST: result} diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index e01f1e06..1c615d3e 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -158,6 +158,7 @@ class AudioTasks(object): auto_speech_recognition = 'auto-speech-recognition' text_to_speech = 'text-to-speech' speech_signal_process = 'speech-signal-process' + speech_separation = 'speech-separation' acoustic_echo_cancellation = 'acoustic-echo-cancellation' acoustic_noise_suppression = 'acoustic-noise-suppression' keyword_spotting = 'keyword-spotting' diff --git a/requirements/audio.txt b/requirements/audio.txt index 2dd63417..0e7bf4e3 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -29,9 +29,11 @@ pygments>=2.12.0 pysptk>=0.1.15,<0.2.0 pytorch_wavelets PyWavelets>=1.0.0 +rotary_embedding_torch>=0.1.5 scikit-learn SoundFile>0.10 sox +speechbrain>=0.5 torchaudio tqdm traitlets>=5.3.0 diff --git a/tests/pipelines/test_speech_separation.py b/tests/pipelines/test_speech_separation.py new file mode 100644 index 00000000..219eb9fd --- /dev/null +++ b/tests/pipelines/test_speech_separation.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+
+import os.path
+import unittest
+
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.demo_utils import DemoCompatibilityCheck
+from modelscope.utils.test_utils import test_level
+
+MIX_SPEECH_FILE = 'data/test/audios/mix_speech.wav'
+
+
+class SpeechSeparationTest(unittest.TestCase, DemoCompatibilityCheck):
+
+    def setUp(self) -> None:
+        pass
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_normal(self):
+        import torchaudio
+        model_id = 'damo/speech_mossformer_separation_temporal_8k'
+        separation = pipeline(Tasks.speech_separation, model=model_id)
+        result = separation(os.path.join(os.getcwd(), MIX_SPEECH_FILE))
+        self.assertTrue(OutputKeys.OUTPUT_PCM_LIST in result)
+        self.assertEqual(len(result[OutputKeys.OUTPUT_PCM_LIST]), 2)
+        for i, signal in enumerate(result[OutputKeys.OUTPUT_PCM_LIST]):
+            save_file = f'output_spk{i}.wav'
+            # Estimated source
+            torchaudio.save(save_file, signal, 8000)
+
+    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
+    def test_demo_compatibility(self):
+        self.compatibility_check()
+
+
+if __name__ == '__main__':
+    unittest.main()