[to #42322933] feat: add speech separation pipeline
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11255740
3
data/test/audios/mix_speech.wav
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:34c2f1867f7882614b7087f2fd2acb722d0f520a2ec50b2d116d5b3f0c05f84b
size 141134
@@ -113,6 +113,7 @@ class Models(object):
    speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
    speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
    speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
    speech_mossformer_separation_temporal_8k = 'speech_mossformer_separation_temporal_8k'
    kws_kwsbp = 'kws-kwsbp'
    generic_asr = 'generic-asr'
    wenet_asr = 'wenet-asr'

@@ -317,6 +318,7 @@ class Pipelines(object):
    speech_dfsmn_aec_psm_16k = 'speech-dfsmn-aec-psm-16k'
    speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
    speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
    speech_separation = 'speech-separation'
    kws_kwsbp = 'kws-kwsbp'
    asr_inference = 'asr-inference'
    asr_wenet_inference = 'asr-wenet-inference'
0
modelscope/models/audio/separation/__init__.py
Normal file
68
modelscope/models/audio/separation/layer_norm.py
Normal file
@@ -0,0 +1,68 @@
# Copyright 2018 Northwestern Polytechnical University (author: Ke Wang)
# made publicly available under the MIT License
# at https://github.com/wangkenpu/Conv-TasNet-PyTorch/blob/64188ffa48971218fdd68b66906970f215d7eca2/model/layer_norm.py

from __future__ import absolute_import, division, print_function

import torch
import torch.nn as nn


class CLayerNorm(nn.LayerNorm):
    """Channel-wise layer normalization."""

    def __init__(self, *args, **kwargs):
        super(CLayerNorm, self).__init__(*args, **kwargs)

    def forward(self, sample):
        """Forward function.

        Args:
            sample: [batch_size, channels, length]
        """
        if sample.dim() != 3:
            raise RuntimeError('{} only accepts 3-D tensor as input'.format(
                self.__class__.__name__))
        # [N, C, T] -> [N, T, C]
        sample = torch.transpose(sample, 1, 2)
        # LayerNorm
        sample = super().forward(sample)
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample


class GLayerNorm(nn.Module):
    """Global Layer Normalization for TasNet."""

    def __init__(self, channels, eps=1e-5):
        super(GLayerNorm, self).__init__()
        self.eps = eps
        self.norm_dim = channels
        self.gamma = nn.Parameter(torch.Tensor(channels))
        self.beta = nn.Parameter(torch.Tensor(channels))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.gamma)
        nn.init.zeros_(self.beta)

    def forward(self, sample):
        """Forward function.

        Args:
            sample: [batch_size, channels, length]
        """
        if sample.dim() != 3:
            raise RuntimeError('{} only accepts 3-D tensor as input'.format(
                self.__class__.__name__))
        # [N, C, T] -> [N, T, C]
        sample = torch.transpose(sample, 1, 2)
        # Mean and variance [N, 1, 1]
        mean = torch.mean(sample, (1, 2), keepdim=True)
        var = torch.mean((sample - mean)**2, (1, 2), keepdim=True)
        sample = (sample
                  - mean) / torch.sqrt(var + self.eps) * self.gamma + self.beta
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample
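A minimal usage sketch of the two layers above; the batch/channel/length sizes are illustrative only, not taken from the released model:

>>> import torch
>>> from modelscope.models.audio.separation.layer_norm import CLayerNorm, GLayerNorm
>>> x = torch.randn(4, 64, 1000)  # [batch, channels, length]
>>> CLayerNorm(64)(x).shape       # per-channel LayerNorm over the channel axis
torch.Size([4, 64, 1000])
>>> GLayerNorm(64)(x).shape       # mean/variance taken over time and channels jointly
torch.Size([4, 64, 1000])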
472
modelscope/models/audio/separation/mossformer.py
Normal file
@@ -0,0 +1,472 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import copy
import os
from typing import Any, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.models.audio.separation.mossformer_block import (
    MossFormerModule, ScaledSinuEmbedding)
from modelscope.models.audio.separation.mossformer_conv_module import (
    CumulativeLayerNorm, GlobalLayerNorm)
from modelscope.models.base import Tensor
from modelscope.utils.constant import Tasks

EPS = 1e-8


@MODELS.register_module(
    Tasks.speech_separation,
    module_name=Models.speech_mossformer_separation_temporal_8k)
class MossFormer(TorchModel):
    """Library to support MossFormer speech separation.

    Args:
        model_dir (str): the model path.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        super().__init__(model_dir, *args, **kwargs)
        self.encoder = Encoder(
            kernel_size=kwargs['kernel_size'],
            out_channels=kwargs['out_channels'])
        self.decoder = Decoder(
            in_channels=kwargs['in_channels'],
            out_channels=1,
            kernel_size=kwargs['kernel_size'],
            stride=kwargs['stride'],
            bias=kwargs['bias'])
        self.mask_net = MossFormerMaskNet(
            kwargs['in_channels'],
            kwargs['out_channels'],
            MossFormerM(kwargs['num_blocks'], kwargs['d_model'],
                        kwargs['attn_dropout'], kwargs['group_size'],
                        kwargs['query_key_dim'], kwargs['expansion_factor'],
                        kwargs['causal']),
            norm=kwargs['norm'],
            num_spks=kwargs['num_spks'])
        self.num_spks = kwargs['num_spks']

    def forward(self, inputs: Tensor) -> Dict[str, Any]:
        # Separation
        mix_w = self.encoder(inputs)
        est_mask = self.mask_net(mix_w)
        mix_w = torch.stack([mix_w] * self.num_spks)
        sep_h = mix_w * est_mask
        # Decoding
        est_source = torch.cat(
            [
                self.decoder(sep_h[i]).unsqueeze(-1)
                for i in range(self.num_spks)
            ],
            dim=-1,
        )
        # T changed after conv1d in encoder, fix it here
        t_origin = inputs.size(1)
        t_est = est_source.size(1)
        if t_origin > t_est:
            est_source = F.pad(est_source, (0, 0, 0, t_origin - t_est))
        else:
            est_source = est_source[:, :t_origin, :]
        return est_source

    def load_check_point(self, load_path=None, device=None):
        if not load_path:
            load_path = self.model_dir
        if not device:
            device = torch.device('cpu')
        self.encoder.load_state_dict(
            torch.load(
                os.path.join(load_path, 'encoder.bin'), map_location=device),
            strict=True)
        self.decoder.load_state_dict(
            torch.load(
                os.path.join(load_path, 'decoder.bin'), map_location=device),
            strict=True)
        self.mask_net.load_state_dict(
            torch.load(
                os.path.join(load_path, 'masknet.bin'), map_location=device),
            strict=True)
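The constructor above pulls every hyperparameter from **kwargs (in practice these come from the model's configuration on the hub). A hedged sketch of a plausible configuration dict follows; the values and the model_dir path are illustrative assumptions, not the released 8 kHz model's settings:

>>> config = dict(
...     in_channels=512, out_channels=512, kernel_size=16, stride=8, bias=False,
...     num_blocks=24, d_model=512, attn_dropout=0.1, group_size=256,
...     query_key_dim=128, expansion_factor=4., causal=False,
...     norm='ln', num_spks=2)
>>> model = MossFormer(model_dir='/path/to/local/model', **config)
>>> # expects encoder.bin / decoder.bin / masknet.bin inside model_dir
>>> model.load_check_point()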
def select_norm(norm, dim, shape):
    """Just a wrapper to select the normalization type.
    """

    if norm == 'gln':
        return GlobalLayerNorm(dim, shape, elementwise_affine=True)
    if norm == 'cln':
        return CumulativeLayerNorm(dim, elementwise_affine=True)
    if norm == 'ln':
        return nn.GroupNorm(1, dim, eps=1e-8)
    else:
        return nn.BatchNorm1d(dim)


class Encoder(nn.Module):
    """Convolutional Encoder Layer.

    Args:
        kernel_size: Length of filters.
        in_channels: Number of input channels.
        out_channels: Number of output channels.

    Example
    -------
    >>> x = torch.randn(2, 1000)
    >>> encoder = Encoder(kernel_size=4, out_channels=64)
    >>> h = encoder(x)
    >>> h.shape
    torch.Size([2, 64, 499])
    """

    def __init__(self,
                 kernel_size: int = 2,
                 out_channels: int = 64,
                 in_channels: int = 1):
        super(Encoder, self).__init__()
        self.conv1d = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=kernel_size // 2,
            groups=1,
            bias=False,
        )
        self.in_channels = in_channels

    def forward(self, x: torch.Tensor):
        """Return the encoded output.

        Args:
            x: Input tensor with dimensionality [B, L].

        Returns:
            Encoded tensor with dimensionality [B, N, T_out],
            where B = Batchsize
                  L = Number of timepoints
                  N = Number of filters
                  T_out = Number of timepoints at the output of the encoder
        """
        # B x L -> B x 1 x L
        if self.in_channels == 1:
            x = torch.unsqueeze(x, dim=1)
        # B x 1 x L -> B x N x T_out
        x = self.conv1d(x)
        x = F.relu(x)

        return x


class Decoder(nn.ConvTranspose1d):
    """A decoder layer that consists of ConvTranspose1d.

    Args:
        kernel_size: Length of filters.
        in_channels: Number of input channels.
        out_channels: Number of output channels.

    Example
    ---------
    >>> x = torch.randn(2, 100, 1000)
    >>> decoder = Decoder(kernel_size=4, in_channels=100, out_channels=1)
    >>> h = decoder(x)
    >>> h.shape
    torch.Size([2, 1003])
    """

    def __init__(self, *args, **kwargs):
        super(Decoder, self).__init__(*args, **kwargs)

    def forward(self, x):
        """Return the decoded output.

        Args:
            x: Input tensor with dimensionality [B, N, L],
                where, B = Batchsize,
                       N = number of filters
                       L = time points
        """

        if x.dim() not in [2, 3]:
            raise RuntimeError('{} only accepts 2/3-D tensor as input'.format(
                self.__class__.__name__))
        x = super().forward(x if x.dim() == 3 else torch.unsqueeze(x, 1))

        if torch.squeeze(x).dim() == 1:
            x = torch.squeeze(x, dim=1)
        else:
            x = torch.squeeze(x)
        return x
class IdentityBlock:
    """This block is used when we want to have identity transformation within the Dual_path block.

    Example
    -------
    >>> x = torch.randn(10, 100)
    >>> IB = IdentityBlock()
    >>> xhat = IB(x)
    """

    def __init__(self, **kwargs):
        pass

    def __call__(self, x):
        return x


class MossFormerM(nn.Module):
    """This class implements the transformer encoder.

    Args:
        num_blocks : int
            Number of mossformer blocks to include.
        d_model : int
            The dimension of the input embedding.
        attn_dropout : float
            Dropout for the self-attention (Optional).
        group_size: int
            the chunk size
        query_key_dim: int
            the attention vector dimension
        expansion_factor: int
            the expansion factor for the linear projection in conv module
        causal: bool
            true for causal / false for non causal

    Example
    -------
    >>> import torch
    >>> x = torch.rand((8, 60, 512))  # B, S, N
    >>> net = MossFormerM(num_blocks=8, d_model=512)
    >>> output = net(x)
    >>> output.shape
    torch.Size([8, 60, 512])
    """

    def __init__(self,
                 num_blocks,
                 d_model=None,
                 attn_dropout=0.1,
                 group_size=256,
                 query_key_dim=128,
                 expansion_factor=4.,
                 causal=False):
        super().__init__()

        self.mossformerM = MossFormerModule(
            dim=d_model,
            depth=num_blocks,
            group_size=group_size,
            query_key_dim=query_key_dim,
            expansion_factor=expansion_factor,
            causal=causal,
            attn_dropout=attn_dropout)
        import speechbrain as sb
        self.norm = sb.nnet.normalization.LayerNorm(d_model, eps=1e-6)

    def forward(self, src: torch.Tensor):
        """
        Args:
            src: Tensor of shape [B, S, N],
                where, B = Batchsize,
                       S = time points
                       N = number of filters
                The sequence to the encoder layer (required).
        """
        output = self.mossformerM(src)
        output = self.norm(output)

        return output
class ComputeAttention(nn.Module):
    """Computation block for dual-path processing.

    Args:
        att_mdl : torch.nn.module
            Model to process within the chunks.
        out_channels : int
            Dimensionality of attention model.
        norm : str
            Normalization type.
        skip_connection : bool
            Skip connection around the attention module.

    Example
    ---------
    >>> att_block = MossFormerM(num_blocks=8, d_model=64)
    >>> comp_att = ComputeAttention(att_block, 64)
    >>> x = torch.randn(10, 64, 512)
    >>> x = comp_att(x)
    >>> x.shape
    torch.Size([10, 64, 512])
    """

    def __init__(
        self,
        att_mdl,
        out_channels,
        norm='ln',
        skip_connection=True,
    ):
        super(ComputeAttention, self).__init__()

        self.att_mdl = att_mdl
        self.skip_connection = skip_connection

        # Norm
        self.norm = norm
        if norm is not None:
            self.att_norm = select_norm(norm, out_channels, 3)

    def forward(self, x: torch.Tensor):
        """Returns the output tensor.

        Args:
            x: Input tensor of dimension [B, N, S].

        Returns:
            out: Output tensor of dimension [B, N, S],
                where, B = Batchsize,
                       N = number of filters
                       S = time points
        """
        # [B, S, N]
        att_out = x.permute(0, 2, 1).contiguous()

        att_out = self.att_mdl(att_out)

        # [B, N, S]
        att_out = att_out.permute(0, 2, 1).contiguous()
        if self.norm is not None:
            att_out = self.att_norm(att_out)

        # [B, N, S]
        if self.skip_connection:
            att_out = att_out + x

        out = att_out
        return out
class MossFormerMaskNet(nn.Module):
    """The dual path model which is the basis for dualpathrnn, sepformer, dptnet.

    Args:
        in_channels : int
            Number of channels at the output of the encoder.
        out_channels : int
            Number of channels that would be inputted to the intra and inter blocks.
        att_model : torch.nn.module
            Attention model to process the input sequence.
        norm : str
            Normalization type.
        num_spks : int
            Number of sources (speakers).
        skip_connection : bool
            Skip connection around attention module.
        use_global_pos_enc : bool
            Global positional encodings.

    Example
    ---------
    >>> mossformer_block = MossFormerM(num_blocks=8, d_model=64)
    >>> mossformer_masknet = MossFormerMaskNet(64, 64, mossformer_block, num_spks=2)
    >>> x = torch.randn(10, 64, 2000)
    >>> x = mossformer_masknet(x)
    >>> x.shape
    torch.Size([2, 10, 64, 2000])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        att_model,
        norm='ln',
        num_spks=2,
        skip_connection=True,
        use_global_pos_enc=True,
    ):
        super(MossFormerMaskNet, self).__init__()
        self.num_spks = num_spks
        self.norm = select_norm(norm, in_channels, 3)
        self.conv1d_encoder = nn.Conv1d(
            in_channels, out_channels, 1, bias=False)
        self.use_global_pos_enc = use_global_pos_enc

        if self.use_global_pos_enc:
            self.pos_enc = ScaledSinuEmbedding(out_channels)

        self.mdl = copy.deepcopy(
            ComputeAttention(
                att_model,
                out_channels,
                norm,
                skip_connection=skip_connection,
            ))

        self.conv1d_out = nn.Conv1d(
            out_channels, out_channels * num_spks, kernel_size=1)
        self.conv1_decoder = nn.Conv1d(
            out_channels, in_channels, 1, bias=False)
        self.prelu = nn.PReLU()
        self.activation = nn.ReLU()
        # gated output layer
        self.output = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Tanh())
        self.output_gate = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid())

    def forward(self, x: torch.Tensor):
        """Returns the output tensor.

        Args:
            x: Input tensor of dimension [B, N, S].

        Returns:
            out: Output tensor of dimension [spks, B, N, S],
                where, spks = Number of speakers
                       B = Batchsize,
                       N = number of filters
                       S = the number of time frames
        """

        # before each line we indicate the shape after executing the line
        # [B, N, L]
        x = self.norm(x)
        # [B, N, L]
        x = self.conv1d_encoder(x)
        if self.use_global_pos_enc:
            base = x
            x = x.transpose(1, -1)
            emb = self.pos_enc(x)
            emb = emb.transpose(0, -1)
            x = base + emb
        # [B, N, S]
        x = self.mdl(x)
        x = self.prelu(x)
        # [B, N*spks, S]
        x = self.conv1d_out(x)
        b, _, s = x.shape
        # [B*spks, N, S]
        x = x.view(b * self.num_spks, -1, s)
        # [B*spks, N, S]
        x = self.output(x) * self.output_gate(x)
        # [B*spks, N, S]
        x = self.conv1_decoder(x)
        # [B, spks, N, S]
        _, n, L = x.shape
        x = x.view(b, self.num_spks, n, L)
        x = self.activation(x)
        # [spks, B, N, S]
        x = x.transpose(0, 1)
        return x
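A hedged shape sketch of the encoder/decoder pair defined above; kernel_size=16 and 512 filters are illustrative values, not the released model's configuration:

>>> enc = Encoder(kernel_size=16, out_channels=512)
>>> dec = Decoder(in_channels=512, out_channels=1, kernel_size=16, stride=8, bias=False)
>>> wav = torch.randn(1, 8000)  # one second of 8 kHz audio
>>> w = enc(wav)
>>> w.shape                     # stride is kernel_size // 2, so T_out = (8000 - 16) // 8 + 1
torch.Size([1, 512, 999])
>>> dec(w).shape                # transposed convolution restores the original length here
torch.Size([1, 8000])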
265
modelscope/models/audio/separation/mossformer_block.py
Normal file
@@ -0,0 +1,265 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
import torch.nn.functional as F
from torch import einsum, nn

from modelscope.models.audio.separation.mossformer_conv_module import \
    MossFormerConvModule


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def padding_to_multiple_of(n, mult):
    remainder = n % mult
    if remainder == 0:
        return 0
    return mult - remainder


class ScaleNorm(nn.Module):

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim**-0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g


class ScaledSinuEmbedding(nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1, ))
        inv_freq = 1. / (10000**(torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        n, device = x.shape[1], x.device
        t = torch.arange(n, device=device).type_as(self.inv_freq)
        sinu = einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1)
        return emb * self.scale


class OffsetScale(nn.Module):

    def __init__(self, dim, heads=1):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(heads, dim))
        self.beta = nn.Parameter(torch.zeros(heads, dim))
        nn.init.normal_(self.gamma, std=0.02)

    def forward(self, x):
        out = einsum('... d, h d -> ... h d', x, self.gamma) + self.beta
        return out.unbind(dim=-2)


class FFConvM(nn.Module):

    def __init__(self, dim_in, dim_out, norm_klass=nn.LayerNorm, dropout=0.1):
        super().__init__()
        self.mdl = nn.Sequential(
            norm_klass(dim_in), nn.Linear(dim_in, dim_out), nn.SiLU(),
            MossFormerConvModule(dim_out), nn.Dropout(dropout))

    def forward(self, x):
        output = self.mdl(x)
        return output
class MossFormerBlock(nn.Module):

    def __init__(self,
                 dim,
                 group_size=256,
                 query_key_dim=128,
                 expansion_factor=1.,
                 causal=False,
                 dropout=0.1,
                 rotary_pos_emb=None,
                 norm_klass=nn.LayerNorm,
                 shift_tokens=True):
        super().__init__()
        hidden_dim = int(dim * expansion_factor)
        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens
        # positional embeddings
        self.rotary_pos_emb = rotary_pos_emb
        # dropout
        self.dropout = nn.Dropout(dropout)
        # projections
        self.to_hidden = FFConvM(
            dim_in=dim,
            dim_out=hidden_dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )
        self.to_qk = FFConvM(
            dim_in=dim,
            dim_out=query_key_dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )
        self.qk_offset_scale = OffsetScale(query_key_dim, heads=4)
        self.to_out = FFConvM(
            dim_in=dim * 2,
            dim_out=dim,
            norm_klass=norm_klass,
            dropout=dropout,
        )
        self.gateActivate = nn.Sigmoid()

    def forward(self, x):
        # prenorm
        normed_x = x
        # do token shift - a great, costless trick from an independent AI researcher in Shenzhen
        if self.shift_tokens:
            x_shift, x_pass = normed_x.chunk(2, dim=-1)
            x_shift = F.pad(x_shift, (0, 0, 1, -1), value=0.)
            normed_x = torch.cat((x_shift, x_pass), dim=-1)

        # initial projections
        v, u = self.to_hidden(normed_x).chunk(2, dim=-1)
        qk = self.to_qk(normed_x)
        # offset and scale
        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)
        att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v,
                                          u)

        # projection out and residual
        out = (att_u * v) * self.gateActivate(att_v * u)
        x = x + self.to_out(out)
        return x

    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask=None):
        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size

        from einops import rearrange
        if exists(mask):
            lin_mask = rearrange(mask, '... -> ... 1')
            lin_k = lin_k.masked_fill(~lin_mask, 0.)

        # rotate queries and keys
        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(
                self.rotary_pos_emb.rotate_queries_or_keys,
                (quad_q, lin_q, quad_k, lin_k))

        # padding for groups
        padding = padding_to_multiple_of(n, g)
        if padding > 0:
            quad_q, quad_k, lin_q, lin_k, v, u = map(
                lambda t: F.pad(t, (0, 0, 0, padding), value=0.),
                (quad_q, quad_k, lin_q, lin_k, v, u))
            mask = default(mask,
                           torch.ones((b, n), device=device, dtype=torch.bool))
            mask = F.pad(mask, (0, padding), value=False)

        # group along sequence
        quad_q, quad_k, lin_q, lin_k, v, u = map(
            lambda t: rearrange(t, 'b (g n) d -> b g n d', n=self.group_size),
            (quad_q, quad_k, lin_q, lin_k, v, u))

        if exists(mask):
            mask = rearrange(mask, 'b (g j) -> b g 1 j', j=g)

        # calculate quadratic attention output
        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g
        attn = F.relu(sim)**2
        attn = self.dropout(attn)
        if exists(mask):
            attn = attn.masked_fill(~mask, 0.)
        if self.causal:
            causal_mask = torch.ones((g, g), dtype=torch.bool,
                                     device=device).triu(1)
            attn = attn.masked_fill(causal_mask, 0.)

        quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
        quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)

        # calculate linear attention output
        if self.causal:
            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
            # exclusive cumulative sum along group dimension
            lin_kv = lin_kv.cumsum(dim=1)
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value=0.)
            lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)

            lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
            # exclusive cumulative sum along group dimension
            lin_ku = lin_ku.cumsum(dim=1)
            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value=0.)
            lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
        else:
            lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
            lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)

            lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
            lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)

        # fold back groups into full sequence, and excise out padding
        quad_attn_out_v, lin_attn_out_v = map(
            lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n],
            (quad_out_v, lin_out_v))
        quad_attn_out_u, lin_attn_out_u = map(
            lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n],
            (quad_out_u, lin_out_u))

        # gate
        return quad_attn_out_v + lin_attn_out_v, quad_attn_out_u + lin_attn_out_u


class MossFormerModule(nn.Module):

    def __init__(self,
                 dim,
                 depth,
                 group_size=256,
                 query_key_dim=128,
                 expansion_factor=4.,
                 causal=False,
                 attn_dropout=0.1,
                 norm_type='scalenorm',
                 shift_tokens=True):
        super().__init__()
        assert norm_type in (
            'scalenorm',
            'layernorm'), 'norm_type must be one of scalenorm or layernorm'

        if norm_type == 'scalenorm':
            norm_klass = ScaleNorm
        elif norm_type == 'layernorm':
            norm_klass = nn.LayerNorm

        from rotary_embedding_torch import RotaryEmbedding
        rotary_pos_emb = RotaryEmbedding(dim=min(32, query_key_dim))
        # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
        self.layers = nn.ModuleList([
            MossFormerBlock(
                dim=dim,
                group_size=group_size,
                query_key_dim=query_key_dim,
                expansion_factor=expansion_factor,
                causal=causal,
                dropout=attn_dropout,
                rotary_pos_emb=rotary_pos_emb,
                norm_klass=norm_klass,
                shift_tokens=shift_tokens) for _ in range(depth)
        ])

    def forward(self, x):
        for mossformer_layer in self.layers:
            x = mossformer_layer(x)
        return x
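Two small doctest-style checks of the helpers above; the input sizes are illustrative only:

>>> padding_to_multiple_of(60, 256)      # pad a 60-frame sequence up to one full group
196
>>> emb = ScaledSinuEmbedding(512)
>>> emb(torch.randn(2, 60, 512)).shape   # positional embedding for a 60-step sequence
torch.Size([60, 512])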
272
modelscope/models/audio/separation/mossformer_conv_module.py
Normal file
@@ -0,0 +1,272 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
import torch.nn as nn
import torch.nn.init as init
from torch import Tensor

EPS = 1e-8


class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization.

    Args:
        dim : (int or list or torch.Size)
            Input shape from an expected input of size.
        eps : float
            A value added to the denominator for numerical stability.
        elementwise_affine : bool
            A boolean value that when set to True,
            this module has learnable per-element affine parameters
            initialized to ones (for weights) and zeros (for biases).

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            if shape == 3:
                self.weight = nn.Parameter(torch.ones(self.dim, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1))
            if shape == 4:
                self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x):
        """Returns the normalized tensor.

        Args:
            x: Tensor of size [N, C, K, S] or [N, C, L].
        """
        # N x 1 x 1
        # cln: mean,var N x 1 x K x S
        # gln: mean,var N x 1 x 1
        if x.dim() == 3:
            mean = torch.mean(x, (1, 2), keepdim=True)
            var = torch.mean((x - mean)**2, (1, 2), keepdim=True)
            if self.elementwise_affine:
                x = (self.weight * (x - mean) / torch.sqrt(var + self.eps)
                     + self.bias)  # yapf: disable
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)

        if x.dim() == 4:
            mean = torch.mean(x, (1, 2, 3), keepdim=True)
            var = torch.mean((x - mean)**2, (1, 2, 3), keepdim=True)
            if self.elementwise_affine:
                x = (self.weight * (x - mean) / torch.sqrt(var + self.eps)
                     + self.bias)  # yapf: disable
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)
        return x


class CumulativeLayerNorm(nn.LayerNorm):
    """Calculate Cumulative Layer Normalization.

    Args:
        dim: Dimension that you want to normalize.
        elementwise_affine: Learnable per-element affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super(CumulativeLayerNorm, self).__init__(
            dim, elementwise_affine=elementwise_affine, eps=1e-8)

    def forward(self, x):
        """Returns the normalized tensor.

        Args:
            x: Tensor of size [N, C, K, S] or [N, C, L]
        """
        # N x K x S x C
        if x.dim() == 4:
            x = x.permute(0, 2, 3, 1).contiguous()
            # N x K x S x C == only channel norm
            x = super().forward(x)
            # N x C x K x S
            x = x.permute(0, 3, 1, 2).contiguous()
        if x.dim() == 3:
            x = torch.transpose(x, 1, 2)
            # N x L x C == only channel norm
            x = super().forward(x)
            # N x C x L
            x = torch.transpose(x, 1, 2)
        return x


def select_norm(norm, dim, shape):
    """Just a wrapper to select the normalization type.
    """

    if norm == 'gln':
        return GlobalLayerNorm(dim, shape, elementwise_affine=True)
    if norm == 'cln':
        return CumulativeLayerNorm(dim, elementwise_affine=True)
    if norm == 'ln':
        return nn.GroupNorm(1, dim, eps=1e-8)
    else:
        return nn.BatchNorm1d(dim)


class Swish(nn.Module):
    """
    Swish is a smooth, non-monotonic function that consistently matches or outperforms ReLU on deep networks applied
    to a variety of challenging domains such as image classification and machine translation.
    """

    def __init__(self):
        super(Swish, self).__init__()

    def forward(self, inputs: Tensor) -> Tensor:
        return inputs * inputs.sigmoid()


class GLU(nn.Module):
    """
    The gating mechanism is called Gated Linear Units (GLU), which was first introduced for natural language processing
    in the paper “Language Modeling with Gated Convolutional Networks”.
    """

    def __init__(self, dim: int) -> None:
        super(GLU, self).__init__()
        self.dim = dim

    def forward(self, inputs: Tensor) -> Tensor:
        outputs, gate = inputs.chunk(2, dim=self.dim)
        return outputs * gate.sigmoid()


class Transpose(nn.Module):
    """ Wrapper class of torch.transpose() for Sequential module. """

    def __init__(self, shape: tuple):
        super(Transpose, self).__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return x.transpose(*self.shape)


class Linear(nn.Module):
    """
    Wrapper class of torch.nn.Linear.
    Weights are initialized with Xavier initialization and biases are initialized to zeros.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True) -> None:
        super(Linear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        init.xavier_uniform_(self.linear.weight)
        if bias:
            init.zeros_(self.linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        return self.linear(x)


class DepthwiseConv1d(nn.Module):
    """
    When groups == in_channels and out_channels == K * in_channels, where K is a positive integer,
    this operation is termed in the literature as depthwise convolution.
    Args:
        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: False
    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by depthwise 1-D convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = False,
    ) -> None:
        super(DepthwiseConv1d, self).__init__()
        assert out_channels % in_channels == 0, 'out_channels should be a constant multiple of in_channels'
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=in_channels,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        return self.conv(inputs)
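A short doctest-style sketch of the depthwise layer above, with illustrative sizes:

>>> dw = DepthwiseConv1d(in_channels=64, out_channels=64, kernel_size=17, padding=8)
>>> dw(torch.randn(2, 64, 100)).shape   # 'SAME'-style padding keeps the time length
torch.Size([2, 64, 100])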
class MossFormerConvModule(nn.Module):
    """Conformer convolution module starts with a pointwise convolution and a gated linear unit (GLU).

    This is followed by a single 1-D depthwise convolution layer. Batchnorm is deployed just after the convolution
    to aid training deep models.

    Args:
        in_channels (int): Number of channels in the input
        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 17
        expansion_factor (int, optional): Expansion factor of the module; only 2 is supported. Default: 2
    """

    def __init__(self,
                 in_channels: int,
                 kernel_size: int = 17,
                 expansion_factor: int = 2) -> None:
        super(MossFormerConvModule, self).__init__()
        assert (
            kernel_size - 1
        ) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        assert expansion_factor == 2, 'Currently, only supports expansion_factor 2'

        self.sequential = nn.Sequential(
            Transpose(shape=(1, 2)),
            DepthwiseConv1d(
                in_channels,
                in_channels,
                kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2),
        )

    def forward(self, inputs: Tensor) -> Tensor:
        """
        Args:
            inputs (batch, time, dim): Tensor containing input sequences

        Returns:
            outputs (batch, time, dim): Tensor produced by the conformer convolution module.
        """
        return inputs + self.sequential(inputs).transpose(1, 2)
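And a matching sketch for the convolution module itself; the batch/time/dim values are illustrative:

>>> conv = MossFormerConvModule(in_channels=512)
>>> conv(torch.randn(2, 60, 512)).shape   # the residual connection preserves [batch, time, dim]
torch.Size([2, 60, 512])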
@@ -27,6 +27,7 @@ class OutputKeys(object):
    OUTPUT_IMG = 'output_img'
    OUTPUT_VIDEO = 'output_video'
    OUTPUT_PCM = 'output_pcm'
    OUTPUT_PCM_LIST = 'output_pcm_list'
    OUTPUT_WAV = 'output_wav'
    IMG_EMBEDDING = 'img_embedding'
    SPO_LIST = 'spo_list'

@@ -696,6 +697,7 @@ TASK_OUTPUTS = {
    Tasks.speech_signal_process: [OutputKeys.OUTPUT_PCM],
    Tasks.acoustic_echo_cancellation: [OutputKeys.OUTPUT_PCM],
    Tasks.acoustic_noise_suppression: [OutputKeys.OUTPUT_PCM],
    Tasks.speech_separation: [OutputKeys.OUTPUT_PCM_LIST],

    # text_to_speech result for a single sample
    # {
@@ -214,6 +214,8 @@ TASK_INPUTS = {
        'nearend_mic': InputType.AUDIO,
        'farend_speech': InputType.AUDIO
    },
    Tasks.speech_separation:
    InputType.AUDIO,
    Tasks.acoustic_noise_suppression:
    InputType.AUDIO,
    Tasks.text_to_speech:
68
modelscope/pipelines/audio/separation_pipeline.py
Normal file
@@ -0,0 +1,68 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import io
from typing import Any, Dict

import numpy
import soundfile as sf
import torch

from modelscope.fileio import File
from modelscope.metainfo import Pipelines
from modelscope.models.base import Input
from modelscope.outputs import OutputKeys
from modelscope.pipelines import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.speech_separation, module_name=Pipelines.speech_separation)
class SeparationPipeline(Pipeline):

    def __init__(self, model, **kwargs):
        """Create a speech separation pipeline for prediction.

        Args:
            model: model id on modelscope hub.
        """
        logger.info('loading model...')
        super().__init__(model=model, **kwargs)
        self.model.load_check_point(device=self.device)
        self.model.eval()

    def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
        if isinstance(inputs, str):
            file_bytes = File.read(inputs)
            data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
            if fs != 8000:
                raise ValueError(
                    'modelscope error: The audio sample rate should be 8000')
        elif isinstance(inputs, bytes):
            data = torch.from_numpy(
                numpy.frombuffer(inputs, dtype=numpy.float32))
        return dict(data=data)

    def postprocess(self, inputs: Dict[str, Any],
                    **post_params) -> Dict[str, Any]:
        return inputs

    def forward(
        self, inputs: Dict[str, Any], **forward_params
    ) -> Dict[str, Any]:  # mix, targets, stage, noise=None):
        """Forward computations from the mixture to the separated signals."""
        logger.info('Start forward...')
        # Unpack lists and put tensors in the right device
        mix = inputs['data'].to(self.device)
        mix = torch.unsqueeze(mix, dim=1).transpose(0, 1)
        est_source = self.model(mix)
        result = []
        for ns in range(self.model.num_spks):
            signal = est_source[0, :, ns]
            signal = signal / signal.abs().max() * 0.5
            result.append(signal.unsqueeze(0).cpu())
        logger.info('Finish forward.')
        return {OutputKeys.OUTPUT_PCM_LIST: result}
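Besides a wav file path, preprocess() also accepts raw bytes interpreted as 8 kHz float32 PCM. A hedged sketch of that branch, reusing the repository's test fixture path as an assumed example input:

>>> import soundfile as sf
>>> from modelscope.pipelines import pipeline
>>> from modelscope.utils.constant import Tasks
>>> separation = pipeline(Tasks.speech_separation,
...                       model='damo/speech_mossformer_separation_temporal_8k')
>>> data, fs = sf.read('data/test/audios/mix_speech.wav', dtype='float32')
>>> result = separation(data.tobytes())  # same OUTPUT_PCM_LIST structure as the file-path branch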
@@ -158,6 +158,7 @@ class AudioTasks(object):
    auto_speech_recognition = 'auto-speech-recognition'
    text_to_speech = 'text-to-speech'
    speech_signal_process = 'speech-signal-process'
    speech_separation = 'speech-separation'
    acoustic_echo_cancellation = 'acoustic-echo-cancellation'
    acoustic_noise_suppression = 'acoustic-noise-suppression'
    keyword_spotting = 'keyword-spotting'
@@ -29,9 +29,11 @@ pygments>=2.12.0
pysptk>=0.1.15,<0.2.0
pytorch_wavelets
PyWavelets>=1.0.0
rotary_embedding_torch>=0.1.5
scikit-learn
SoundFile>0.10
sox
speechbrain>=0.5
torchaudio
tqdm
traitlets>=5.3.0
39
tests/pipelines/test_speech_separation.py
Normal file
@@ -0,0 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os.path
import unittest

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level

MIX_SPEECH_FILE = 'data/test/audios/mix_speech.wav'


class SpeechSeparationTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        pass

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_normal(self):
        import torchaudio
        model_id = 'damo/speech_mossformer_separation_temporal_8k'
        separation = pipeline(Tasks.speech_separation, model=model_id)
        result = separation(os.path.join(os.getcwd(), MIX_SPEECH_FILE))
        self.assertTrue(OutputKeys.OUTPUT_PCM_LIST in result)
        self.assertEqual(len(result[OutputKeys.OUTPUT_PCM_LIST]), 2)
        for i, signal in enumerate(result[OutputKeys.OUTPUT_PCM_LIST]):
            save_file = f'output_spk{i}.wav'
            # Estimated source
            torchaudio.save(save_file, signal, 8000)

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()
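The pipeline rejects anything that is not 8 kHz, so inputs at other rates need resampling first. A hedged sketch using torchaudio (already imported in the test above), assuming a hypothetical mono 16 kHz input file and with `separation` built as in test_normal:

>>> import torchaudio
>>> wav, fs = torchaudio.load('some_16k_input.wav')        # hypothetical file name
>>> wav_8k = torchaudio.functional.resample(wav, orig_freq=fs, new_freq=8000)
>>> torchaudio.save('input_8k.wav', wav_8k, 8000)
>>> result = separation('input_8k.wav')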