detection and extraction

This commit is contained in:
yangdongchao
2023-04-06 00:11:23 +08:00
parent 7ee017cf0d
commit 322ed8cbb2
37 changed files with 11554 additions and 3 deletions

View File

@@ -0,0 +1,25 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .text_encoder import Text_Encoder
from .resunet_film import UNetRes_FiLM
class LASSNet(nn.Module):
    """Text-queried mask estimator: a text encoder driving a FiLM-conditioned UNet."""

    def __init__(self, device='cuda'):
        """Create the text encoder (tokens are moved to ``device``) and the UNet."""
        super().__init__()
        self.text_embedder = Text_Encoder(device)
        self.UNet = UNetRes_FiLM(channels=1, cond_embedding_dim=256)

    def forward(self, x, caption):
        """Return a sigmoid mask for ``x`` conditioned on ``caption``.

        :param x: input features; the original note gives (Batch, 1, T, 128)
            -- TODO confirm against the caller.
        :param caption: text query handed to the tokenizer.
        :return: ``torch.sigmoid`` of the UNet output, i.e. values in (0, 1).
        """
        token_ids, attention = self.text_embedder.tokenize(caption)
        # The encoder returns (text_embed, token_outputs); only the embedding is used.
        text_embed = self.text_embedder(token_ids, attention)[0]
        # The same embedding conditions both conditioning inputs of the UNet.
        logits = self.UNet(x, text_embed, text_embed)
        return torch.sigmoid(logits)

    def get_tokenizer(self):
        """Expose the tokenizer owned by the text encoder."""
        return self.text_embedder.tokenizer

View File

@@ -0,0 +1,27 @@
import torch
import torch.nn as nn
class Film(nn.Module):
    """Additive feature-wise modulation: maps a condition vector to a per-channel shift."""

    def __init__(self, channels, cond_embedding_dim):
        super().__init__()
        # Two-layer MLP: cond_embedding_dim -> 2 * channels -> channels, ReLU after each.
        hidden = channels * 2
        self.linear = nn.Sequential(
            nn.Linear(cond_embedding_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, data, cond_vec):
        """
        :param data: [batchsize, channels, samples] or [batchsize, channels, T, F] or [batchsize, channels, F, T]
        :param cond_vec: [batchsize, cond_embedding_dim]
        :return: ``data`` plus the per-channel shift broadcast over the trailing
            axes; for any other rank a warning is printed and ``data`` is
            returned unchanged.
        """
        shift = self.linear(cond_vec)  # [batchsize, channels]
        rank = data.dim()
        if rank == 3:
            return data + shift[..., None]
        if rank == 4:
            return data + shift[..., None, None]
        print("Warning: The size of input tensor,", data.size(), "is not correct. Film is not working.")
        return data

View File

@@ -0,0 +1,483 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .film import Film
class ConvBlock(nn.Module):
    """Two (conv -> batch norm -> activation) stages with 'same'-style padding."""

    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
        """
        :param kernel_size: indexable pair; each side is padded by ``k // 2``.
        :param activation: name understood by ``act``.
        :param momentum: batch-norm momentum.
        """
        super().__init__()
        self.activation = activation
        same_pad = (kernel_size[0] // 2, kernel_size[1] // 2)
        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=same_pad,
            bias=False,
        )
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
        self.init_weights()

    def init_weights(self):
        """Re-initialise both convolutions, then both batch norms."""
        for conv in (self.conv1, self.conv2):
            init_layer(conv)
        for norm in (self.bn1, self.bn2):
            init_bn(norm)

    def forward(self, x):
        """Apply the two conv/bn/activation stages in sequence."""
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = act(norm(conv(x)), self.activation)
        return x
class EncoderBlock(nn.Module):
    """ConvBlock followed by average pooling; also returns the pre-pool features."""

    def __init__(self, in_channels, out_channels, kernel_size, downsample, activation, momentum):
        super().__init__()
        self.conv_block = ConvBlock(
            in_channels, out_channels, kernel_size, activation, momentum
        )
        self.downsample = downsample

    def forward(self, x):
        """Return ``(pooled, features)``; ``features`` is the skip-connection tensor."""
        features = self.conv_block(x)
        pooled = F.avg_pool2d(features, kernel_size=self.downsample)
        return pooled, features
class DecoderBlock(nn.Module):
    """Transposed-conv upsampling, skip concatenation, then a ConvBlock."""

    def __init__(self, in_channels, out_channels, kernel_size, upsample, activation, momentum):
        """
        :param kernel_size: kernel of the trailing ConvBlock (also used by ``prune``).
        :param upsample: stride AND kernel size of the transposed convolution.
        :param activation: name understood by ``act``.
        :param momentum: batch-norm momentum.
        """
        super(DecoderBlock, self).__init__()
        self.kernel_size = kernel_size
        self.stride = upsample
        self.activation = activation
        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.stride,
            stride=self.stride,
            padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )
        self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
        # Input is the upsampled tensor concatenated with the skip tensor.
        self.conv_block2 = ConvBlock(
            out_channels * 2, out_channels, kernel_size, activation, momentum
        )

    def init_weights(self):
        """Re-initialise the transposed convolution and its batch norm.

        NOTE(review): this is not invoked from ``__init__``; call it explicitly
        if the custom initialisation is wanted.
        """
        init_layer(self.conv1)
        # Fixed: the batch norm attribute is ``bn1`` (``self.bn`` never existed).
        init_bn(self.bn1)

    def prune(self, x):
        """Prune the shape of x after transpose convolution."""
        padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
        x = x[
            :,
            :,
            padding[0] : padding[0] - self.stride[0],
            padding[1] : padding[1] - self.stride[1]]
        return x

    def forward(self, input_tensor, concat_tensor):
        """Upsample ``input_tensor``, concatenate ``concat_tensor`` on channels, refine."""
        x = act(self.bn1(self.conv1(input_tensor)), self.activation)
        # ``prune`` is intentionally not applied here (it was disabled in the original).
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x)
        return x
class EncoderBlockRes1B(nn.Module):
    """Four residual conv blocks (3x3) followed by average pooling."""

    def __init__(self, in_channels, out_channels, downsample, activation, momentum):
        super().__init__()
        block_args = ((3, 3), activation, momentum)
        self.conv_block1 = ConvBlockRes(in_channels, out_channels, *block_args)
        self.conv_block2 = ConvBlockRes(out_channels, out_channels, *block_args)
        self.conv_block3 = ConvBlockRes(out_channels, out_channels, *block_args)
        self.conv_block4 = ConvBlockRes(out_channels, out_channels, *block_args)
        self.downsample = downsample

    def forward(self, x):
        """Return ``(pooled, features)``; ``features`` is the pre-pool tensor."""
        features = x
        for block in (self.conv_block1, self.conv_block2, self.conv_block3, self.conv_block4):
            features = block(features)
        return F.avg_pool2d(features, kernel_size=self.downsample), features
class DecoderBlockRes1B(nn.Module):
    """BN -> ReLU -> transposed conv upsampling, skip concat, four residual blocks."""

    def __init__(self, in_channels, out_channels, stride, activation, momentum):
        super().__init__()
        kernel = (3, 3)
        self.activation = activation
        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel,
            stride=stride,
            padding=(0, 0),
            output_padding=(0, 0),
            bias=False,
            dilation=1,
        )
        # Normalises the block input, hence ``in_channels`` (default momentum).
        self.bn1 = nn.BatchNorm2d(in_channels)
        block_args = (kernel, activation, momentum)
        self.conv_block2 = ConvBlockRes(out_channels * 2, out_channels, *block_args)
        self.conv_block3 = ConvBlockRes(out_channels, out_channels, *block_args)
        self.conv_block4 = ConvBlockRes(out_channels, out_channels, *block_args)
        self.conv_block5 = ConvBlockRes(out_channels, out_channels, *block_args)

    def init_weights(self):
        """Re-initialise the transposed convolution (not called from ``__init__``)."""
        init_layer(self.conv1)

    def prune(self, x, both=False):
        """Drop the last row of axis 2, and of axis 3 too when ``both`` is true."""
        if both:
            return x[:, :, :-1, :-1]
        return x[:, :, :-1, :]

    def forward(self, input_tensor, concat_tensor, both=False):
        """Upsample, prune, concatenate the skip tensor on channels, then refine."""
        upsampled = self.conv1(F.relu_(self.bn1(input_tensor)))
        merged = torch.cat((self.prune(upsampled, both=both), concat_tensor), dim=1)
        for block in (self.conv_block2, self.conv_block3, self.conv_block4, self.conv_block5):
            merged = block(merged)
        return merged
class EncoderBlockRes2BCond(nn.Module):
    """Two conditioned residual blocks (3x3) followed by average pooling."""

    def __init__(self, in_channels, out_channels, downsample, activation, momentum, cond_embedding_dim):
        super().__init__()
        block_args = ((3, 3), activation, momentum, cond_embedding_dim)
        self.conv_block1 = ConvBlockResCond(in_channels, out_channels, *block_args)
        self.conv_block2 = ConvBlockResCond(out_channels, out_channels, *block_args)
        self.downsample = downsample

    def forward(self, x, cond_vec):
        """Return ``(pooled, features)``; both blocks receive the same ``cond_vec``."""
        features = self.conv_block2(self.conv_block1(x, cond_vec), cond_vec)
        pooled = F.avg_pool2d(features, kernel_size=self.downsample)
        return pooled, features
class DecoderBlockRes2BCond(nn.Module):
    """BN -> ReLU -> transposed conv, skip concat, two conditioned residual blocks."""

    def __init__(self, in_channels, out_channels, stride, activation, momentum, cond_embedding_dim):
        super().__init__()
        kernel = (3, 3)
        self.activation = activation
        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel,
            stride=stride,
            padding=(0, 0),
            output_padding=(0, 0),
            bias=False,
            dilation=1,
        )
        # Normalises the block input, hence ``in_channels`` (default momentum).
        self.bn1 = nn.BatchNorm2d(in_channels)
        block_args = (kernel, activation, momentum, cond_embedding_dim)
        self.conv_block2 = ConvBlockResCond(out_channels * 2, out_channels, *block_args)
        self.conv_block3 = ConvBlockResCond(out_channels, out_channels, *block_args)

    def init_weights(self):
        """Re-initialise the transposed convolution (not called from ``__init__``)."""
        init_layer(self.conv1)

    def prune(self, x, both=False):
        """Drop the last row of axis 2, and of axis 3 too when ``both`` is true."""
        if both:
            return x[:, :, :-1, :-1]
        return x[:, :, :-1, :]

    def forward(self, input_tensor, concat_tensor, cond_vec, both=False):
        """Upsample, prune, concatenate the skip tensor, then refine with ``cond_vec``."""
        upsampled = self.conv1(F.relu_(self.bn1(input_tensor)))
        merged = torch.cat((self.prune(upsampled, both=both), concat_tensor), dim=1)
        merged = self.conv_block2(merged, cond_vec)
        return self.conv_block3(merged, cond_vec)
class EncoderBlockRes4BCond(nn.Module):
    """Four conditioned residual blocks (3x3) followed by average pooling."""

    def __init__(self, in_channels, out_channels, downsample, activation, momentum, cond_embedding_dim):
        """
        :param downsample: kernel size handed to ``F.avg_pool2d``.
        :param activation: stored by the inner blocks.
        :param momentum: forwarded to the inner blocks.
        :param cond_embedding_dim: size of the conditioning vector.
        """
        # Fixed: the original passed EncoderBlockRes4B to super(), which raises
        # TypeError because this class is not a subclass of it.
        super(EncoderBlockRes4BCond, self).__init__()
        size = (3, 3)
        self.conv_block1 = ConvBlockResCond(in_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block2 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block3 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block4 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.downsample = downsample

    def forward(self, x, cond_vec):
        """Return ``(pooled, features)``; every block receives the same ``cond_vec``."""
        encoder = self.conv_block1(x, cond_vec)
        encoder = self.conv_block2(encoder, cond_vec)
        encoder = self.conv_block3(encoder, cond_vec)
        encoder = self.conv_block4(encoder, cond_vec)
        encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
        return encoder_pool, encoder
class DecoderBlockRes4BCond(nn.Module):
    """BN -> ReLU -> transposed conv, skip concat, four conditioned residual blocks."""

    def __init__(self, in_channels, out_channels, stride, activation, momentum, cond_embedding_dim):
        """
        :param stride: stride of the 3x3 transposed convolution.
        :param activation: stored on the block and by the inner blocks.
        :param momentum: forwarded to the inner blocks.
        :param cond_embedding_dim: size of the conditioning vector.
        """
        # Fixed: the original passed DecoderBlockRes4B to super(), which raises
        # TypeError because this class is not a subclass of it.
        super(DecoderBlockRes4BCond, self).__init__()
        size = (3, 3)
        self.activation = activation
        self.conv1 = torch.nn.ConvTranspose2d(in_channels=in_channels,
            out_channels=out_channels, kernel_size=size, stride=stride,
            padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)
        # Normalises the block input, hence ``in_channels``.
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv_block2 = ConvBlockResCond(out_channels * 2, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block3 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block4 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
        self.conv_block5 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)

    def init_weights(self):
        """Re-initialise the transposed convolution (not called from ``__init__``)."""
        init_layer(self.conv1)

    def prune(self, x, both=False):
        """Drop the last row of axis 2, and of axis 3 too when ``both`` is true."""
        if both:
            x = x[:, :, 0:-1, 0:-1]
        else:
            x = x[:, :, 0:-1, :]
        return x

    def forward(self, input_tensor, concat_tensor, cond_vec, both=False):
        """Upsample, prune, concatenate the skip tensor, then refine with ``cond_vec``."""
        x = self.conv1(F.relu_(self.bn1(input_tensor)))
        x = self.prune(x, both=both)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x, cond_vec)
        x = self.conv_block3(x, cond_vec)
        x = self.conv_block4(x, cond_vec)
        x = self.conv_block5(x, cond_vec)
        return x
class EncoderBlockRes4B(nn.Module):
    """Four residual conv blocks (3x3) followed by average pooling."""

    def __init__(self, in_channels, out_channels, downsample, activation, momentum):
        super().__init__()
        block_args = ((3, 3), activation, momentum)
        self.conv_block1 = ConvBlockRes(in_channels, out_channels, *block_args)
        self.conv_block2 = ConvBlockRes(out_channels, out_channels, *block_args)
        self.conv_block3 = ConvBlockRes(out_channels, out_channels, *block_args)
        self.conv_block4 = ConvBlockRes(out_channels, out_channels, *block_args)
        self.downsample = downsample

    def forward(self, x):
        """Return ``(pooled, features)``; ``features`` is the pre-pool tensor."""
        features = x
        for block in (self.conv_block1, self.conv_block2, self.conv_block3, self.conv_block4):
            features = block(features)
        return F.avg_pool2d(features, kernel_size=self.downsample), features
class DecoderBlockRes4B(nn.Module):
    """BN -> ReLU -> transposed conv upsampling, skip concat, four residual blocks."""

    def __init__(self, in_channels, out_channels, stride, activation, momentum):
        super().__init__()
        kernel = (3, 3)
        self.activation = activation
        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel,
            stride=stride,
            padding=(0, 0),
            output_padding=(0, 0),
            bias=False,
            dilation=1,
        )
        # Normalises the block input, hence ``in_channels`` (default momentum).
        self.bn1 = nn.BatchNorm2d(in_channels)
        block_args = (kernel, activation, momentum)
        self.conv_block2 = ConvBlockRes(out_channels * 2, out_channels, *block_args)
        self.conv_block3 = ConvBlockRes(out_channels, out_channels, *block_args)
        self.conv_block4 = ConvBlockRes(out_channels, out_channels, *block_args)
        self.conv_block5 = ConvBlockRes(out_channels, out_channels, *block_args)

    def init_weights(self):
        """Re-initialise the transposed convolution (not called from ``__init__``)."""
        init_layer(self.conv1)

    def prune(self, x, both=False):
        """Drop the last row of axis 2, and of axis 3 too when ``both`` is true."""
        if both:
            return x[:, :, :-1, :-1]
        return x[:, :, :-1, :]

    def forward(self, input_tensor, concat_tensor, both=False):
        """Upsample, prune, concatenate the skip tensor on channels, then refine."""
        upsampled = self.conv1(F.relu_(self.bn1(input_tensor)))
        merged = torch.cat((self.prune(upsampled, both=both), concat_tensor), dim=1)
        for block in (self.conv_block2, self.conv_block3, self.conv_block4, self.conv_block5):
            merged = block(merged)
        return merged
class ConvBlockResCond(nn.Module):
    r"""Pre-activation residual block with a Film shift after each convolution."""

    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum, cond_embedding_dim):
        """
        :param kernel_size: indexable pair; each side is padded by ``k // 2``.
        :param activation: stored on the block (the forward pass uses leaky ReLU).
        :param momentum: accepted for signature parity; the batch norms here use
            their default momentum.
        :param cond_embedding_dim: size of the conditioning vector fed to Film.
        """
        super().__init__()
        self.activation = activation
        same_pad = [kernel_size[0] // 2, kernel_size[1] // 2]
        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=same_pad,
            bias=False,
        )
        film_kwargs = dict(channels=out_channels, cond_embedding_dim=cond_embedding_dim)

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv_kwargs)
        self.film1 = Film(**film_kwargs)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv_kwargs)
        self.film2 = Film(**film_kwargs)

        # A 1x1 projection (with its own Film) is only needed when widths differ.
        self.is_shortcut = in_channels != out_channels
        if self.is_shortcut:
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )
            self.film_res = Film(**film_kwargs)
        self.init_weights()

    def init_weights(self):
        """Re-initialise batch norms, convolutions and (if present) the projection."""
        for norm in (self.bn1, self.bn2):
            init_bn(norm)
        for conv in (self.conv1, self.conv2):
            init_layer(conv)
        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x, cond_vec):
        """Return the skip path plus the conditioned two-convolution branch."""
        skip = x
        branch = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
        branch = self.film1(branch, cond_vec)
        branch = self.conv2(F.leaky_relu_(self.bn2(branch), negative_slope=0.01))
        branch = self.film2(branch, cond_vec)
        if self.is_shortcut:
            skip = self.film_res(self.shortcut(skip), cond_vec)
        return skip + branch
class ConvBlockRes(nn.Module):
    r"""Pre-activation residual block: (BN -> leaky ReLU -> conv) twice plus a skip."""

    def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
        """
        :param kernel_size: indexable pair; each side is padded by ``k // 2``.
        :param activation: stored on the block (the forward pass uses leaky ReLU).
        :param momentum: accepted for signature parity; the batch norms here use
            their default momentum.
        """
        super().__init__()
        self.activation = activation
        same_pad = [kernel_size[0] // 2, kernel_size[1] // 2]
        conv_kwargs = dict(
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=same_pad,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv_kwargs)

        # A 1x1 projection is only needed when the channel widths differ.
        self.is_shortcut = in_channels != out_channels
        if self.is_shortcut:
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )
        self.init_weights()

    def init_weights(self):
        """Re-initialise batch norms, convolutions and (if present) the projection."""
        for norm in (self.bn1, self.bn2):
            init_bn(norm)
        for conv in (self.conv1, self.conv2):
            init_layer(conv)
        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, x):
        """Return the (possibly projected) input plus the two-convolution branch."""
        branch = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
        branch = self.conv2(F.leaky_relu_(self.bn2(branch), negative_slope=0.01))
        skip = self.shortcut(x) if self.is_shortcut else x
        return skip + branch
def init_layer(layer):
    """Xavier-uniform the weight of a Linear/Conv layer and zero its bias, if any."""
    nn.init.xavier_uniform_(layer.weight)
    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)
def init_bn(bn):
    """Reset a batch-norm layer to the identity affine transform (bias 0, weight 1)."""
    for param, value in ((bn.bias, 0.), (bn.weight, 1.)):
        param.data.fill_(value)
def init_gru(rnn):
    """Initialize a GRU layer.

    Each stacked weight matrix is split into three equal row chunks that are
    initialised independently; all biases are set to zero.
    """

    def _concat_init(tensor, init_funcs):
        # Apply init_funcs[i] in place to the i-th block of rows.
        rows, _ = tensor.shape
        chunk = rows // len(init_funcs)
        start = 0
        for init_func in init_funcs:
            init_func(tensor[start:start + chunk, :])
            start += chunk

    def _inner_uniform(tensor):
        # Uniform in [-sqrt(3 / fan_in), sqrt(3 / fan_in)].
        fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
        bound = math.sqrt(3 / fan_in)
        nn.init.uniform_(tensor, -bound, bound)

    recipes = (
        ('ih', [_inner_uniform, _inner_uniform, _inner_uniform]),
        # The last chunk of the hidden-to-hidden weights is orthogonal.
        ('hh', [_inner_uniform, _inner_uniform, nn.init.orthogonal_]),
    )
    for layer in range(rnn.num_layers):
        for kind, init_funcs in recipes:
            _concat_init(getattr(rnn, 'weight_{}_l{}'.format(kind, layer)), init_funcs)
            torch.nn.init.constant_(getattr(rnn, 'bias_{}_l{}'.format(kind, layer)), 0)
def act(x, activation):
    """Apply the named activation to ``x``.

    'relu' and 'leaky_relu' (slope 0.2) operate in place on ``x``; 'swish'
    returns a new tensor. Any other name raises ``Exception``.
    """
    if activation == 'relu':
        return F.relu_(x)
    if activation == 'leaky_relu':
        return F.leaky_relu_(x, negative_slope=0.2)
    if activation == 'swish':
        return x * torch.sigmoid(x)
    raise Exception('Incorrect activation!')

View File

@@ -0,0 +1,110 @@
from .modules import *
import numpy as np
class UNetRes_FiLM(nn.Module):
    """Six-level residual UNet whose blocks are conditioned through Film layers."""

    def __init__(self, channels, cond_embedding_dim, nsrc=1):
        """
        :param channels: channel count of the input (multiplied by ``nsrc``).
        :param cond_embedding_dim: size of the conditioning vectors.
        :param nsrc: multiplier on the input channel count.
        """
        super().__init__()
        activation = 'relu'
        momentum = 0.01
        self.nsrc = nsrc
        self.channels = channels
        self.downsample_ratio = 2 ** 6  # This number equals 2^{#encoder_blocks}

        shared = dict(
            activation=activation,
            momentum=momentum,
            cond_embedding_dim=cond_embedding_dim,
        )
        pool = (2, 2)

        # Encoder: widths 32 -> 64 -> 128 -> 256 -> 384 -> 384.
        self.encoder_block1 = EncoderBlockRes2BCond(
            in_channels=channels * nsrc, out_channels=32, downsample=pool, **shared)
        self.encoder_block2 = EncoderBlockRes2BCond(
            in_channels=32, out_channels=64, downsample=pool, **shared)
        self.encoder_block3 = EncoderBlockRes2BCond(
            in_channels=64, out_channels=128, downsample=pool, **shared)
        self.encoder_block4 = EncoderBlockRes2BCond(
            in_channels=128, out_channels=256, downsample=pool, **shared)
        self.encoder_block5 = EncoderBlockRes2BCond(
            in_channels=256, out_channels=384, downsample=pool, **shared)
        self.encoder_block6 = EncoderBlockRes2BCond(
            in_channels=384, out_channels=384, downsample=pool, **shared)

        # Bottleneck.
        self.conv_block7 = ConvBlockResCond(
            in_channels=384, out_channels=384, kernel_size=(3, 3), **shared)

        # Decoder: widths 384 -> 384 -> 256 -> 128 -> 64 -> 32.
        self.decoder_block1 = DecoderBlockRes2BCond(
            in_channels=384, out_channels=384, stride=pool, **shared)
        self.decoder_block2 = DecoderBlockRes2BCond(
            in_channels=384, out_channels=384, stride=pool, **shared)
        self.decoder_block3 = DecoderBlockRes2BCond(
            in_channels=384, out_channels=256, stride=pool, **shared)
        self.decoder_block4 = DecoderBlockRes2BCond(
            in_channels=256, out_channels=128, stride=pool, **shared)
        self.decoder_block5 = DecoderBlockRes2BCond(
            in_channels=128, out_channels=64, stride=pool, **shared)
        self.decoder_block6 = DecoderBlockRes2BCond(
            in_channels=64, out_channels=32, stride=pool, **shared)

        self.after_conv_block1 = ConvBlockResCond(
            in_channels=32, out_channels=32, kernel_size=(3, 3), **shared)
        self.after_conv2 = nn.Conv2d(
            in_channels=32, out_channels=1,
            kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)
        self.init_weights()

    def init_weights(self):
        """Re-initialise the final 1x1 convolution."""
        init_layer(self.after_conv2)

    def forward(self, sp, cond_vec, dec_cond_vec):
        """Run the conditioned UNet over a 4-D input.

        Args:
            sp: tensor indexed as (batch, channels, time_steps, freq_bins).
            cond_vec: conditioning for encoder blocks 1-3, decoder blocks 3-6
                and the post-decoder residual block.
            dec_cond_vec: conditioning for encoder blocks 4-6, the bottleneck
                and decoder blocks 1-2.
        Outputs:
            Single-channel tensor (raw, no activation) with the time axis cut
            back to the input length and the two trimmed frequency bins
            restored as zeros.
        """
        origin_len = sp.shape[2]  # time_steps
        # Pad the time axis up to the next multiple of the downsample ratio.
        pad_len = -origin_len % self.downsample_ratio
        x = F.pad(sp, pad=(0, 0, 0, pad_len))
        # Drop the last two frequency bins; they are zero-padded back at the end.
        x = x[..., :x.shape[-1] - 2]  # (bs, channels, T, F)

        encoder_plan = (
            (self.encoder_block1, cond_vec),
            (self.encoder_block2, cond_vec),
            (self.encoder_block3, cond_vec),
            (self.encoder_block4, dec_cond_vec),
            (self.encoder_block5, dec_cond_vec),
            (self.encoder_block6, dec_cond_vec),
        )
        skips = []
        for block, cond in encoder_plan:
            x, skip = block(x, cond)
            skips.append(skip)

        x = self.conv_block7(x, dec_cond_vec)

        decoder_plan = (
            (self.decoder_block1, dec_cond_vec),
            (self.decoder_block2, dec_cond_vec),
            (self.decoder_block3, cond_vec),
            (self.decoder_block4, cond_vec),
            (self.decoder_block5, cond_vec),
            (self.decoder_block6, cond_vec),
        )
        # Skips are consumed deepest-first.
        for block, cond in decoder_plan:
            x = block(x, skips.pop(), cond)

        x = self.after_conv_block1(x, cond_vec)
        x = self.after_conv2(x)

        # Recover shape
        x = F.pad(x, pad=(0, 2))
        return x[:, :, :origin_len, :]
if __name__ == "__main__":
    # Smoke test: push one random (1, 1, 1001, 513) input through the network
    # and print the output size.
    net = UNetRes_FiLM(channels=1, cond_embedding_dim=16)
    condition = torch.randn((1, 16))
    dummy_input = torch.randn((1, 1, 1001, 513))
    output = net(dummy_input, condition, condition)
    print(output.size())

View File

@@ -0,0 +1,45 @@
import torch
import torch.nn as nn
from transformers import *
import warnings
# Suppress every warning for the whole process (no category or module filter is given).
warnings.filterwarnings('ignore')
# pretrained model name -> (model class, tokenizer class)
MODELS = {
'prajjwal1/bert-mini': (BertModel, BertTokenizer),
}
class Text_Encoder(nn.Module):
    """Pretrained BERT encoder followed by a small projection head."""

    def __init__(self, device):
        """Load tokenizer and model for ``self.base_model``; remember ``device``."""
        super().__init__()
        self.base_model = 'prajjwal1/bert-mini'
        self.dropout = 0.1
        model_entry = MODELS[self.base_model]
        self.tokenizer = model_entry[1].from_pretrained(self.base_model)
        self.bert_layer = model_entry[0].from_pretrained(
            self.base_model,
            add_pooling_layer=False,
            hidden_dropout_prob=self.dropout,
            attention_probs_dropout_prob=self.dropout,
            output_hidden_states=True,
        )
        # 256 in/out features; presumably the hidden size of the base model -- TODO confirm.
        self.linear_layer = nn.Sequential(nn.Linear(256, 256), nn.ReLU(inplace=True))
        self.device = device

    def tokenize(self, caption):
        """Tokenize ``caption`` and return ``(input_ids, attention_mask)`` on ``self.device``."""
        encoded = self.tokenizer(caption, add_special_tokens=False, padding=True, return_tensors='pt')
        token_ids = encoded['input_ids'].to(self.device)
        attention = encoded['attention_mask'].to(self.device)
        return token_ids, attention

    def forward(self, input_ids, attns_mask):
        """Return ``(text_embed, output)``.

        ``output`` is the first element of the model's return value and
        ``text_embed`` is the projection of its position-0 vector.
        NOTE(review): ``tokenize`` passes ``add_special_tokens=False``, so
        position 0 may be the first word piece rather than a [CLS] token --
        confirm this is intended.
        """
        output = self.bert_layer(input_ids=input_ids, attention_mask=attns_mask)[0]
        first_token = output[:, 0, :]
        text_embed = self.linear_layer(first_token)
        return text_embed, output  # text_embed: (batch, hidden_size)

View File

@@ -0,0 +1,98 @@
import torch
import numpy as np
def add_noise_and_scale(front, noise, snr_l=0, snr_h=0, scale_lower=1.0, scale_upper=1.0):
    """
    Mix a foreground signal with noise at a random SNR and a random overall scale.

    :param front: foreground audio (e.g. vocals), [samples, channel]; it is peak-normalized first, so any scale is fine
    :param noise: noise, [samples, channel], any scale
    :param snr_l: lower SNR bound in dB; pass None (with snr_h) to skip SNR remixing
    :param snr_h: upper SNR bound in dB
    :param scale_lower: lower bound of the random output scale
    :param scale_upper: upper bound of the random output scale
    :return: (front, noise, mixed_wav, snr, scale); the three signals are numpy
        arrays holding the first channel, with mixed_wav = front + noise
    """
    snr = None
    # Bring both signals to a peak amplitude of 1.
    noise = normalize_energy_torch(noise)
    front = normalize_energy_torch(front)

    # Remix at a random SNR drawn from [snr_l, snr_h].
    if snr_l is not None and snr_h is not None:
        front, noise, snr = _random_noise(front, noise, snr_l=snr_l, snr_h=snr_h)

    # Rescale all three by the same factor so that the largest peak among
    # mixture, noise and foreground becomes 1.
    _, noise, front = unify_energy_torch(noise + front, noise, front)

    # Apply one random gain to both components.
    scale = _random_scale(scale_lower, scale_upper)
    noise = noise * scale
    front = front * scale

    front = _to_numpy(front)  # [num_samples]
    noise = _to_numpy(noise)  # [num_samples]
    mixed_wav = front + noise
    return front, noise, mixed_wav, snr, scale
def _random_scale(lower=0.3, upper=0.9):
    """Draw a scale via ``uniform_torch(lower, upper)`` and return it as a Python float."""
    drawn = uniform_torch(lower, upper)
    return float(drawn)
def _random_noise(clean, noise, snr_l=None, snr_h=None):
    """Attenuate ``noise`` for a random SNR (dB) drawn from [snr_l, snr_h].

    Returns ``(clean, scaled_noise, snr)``; ``clean`` is passed through untouched.
    """
    snr = uniform_torch(snr_l, snr_h)
    # Amplitude ratio corresponding to ``snr`` decibels.
    amplitude_ratio = 10 ** (float(snr) / 20)
    scaled_noise = noise / amplitude_ratio
    return clean, scaled_noise, snr
def _to_numpy(wav):
return np.transpose(wav, (1, 0))[0].numpy() # [num_samples]
def normalize_energy(audio, alpha = 1):
    '''
    Peak-normalize a numpy waveform.

    :param audio: waveform array, [batchsize, *]
    :param alpha: target peak; the output lies within [-alpha, alpha]
    :return: ``audio`` divided by its largest absolute value, times ``alpha``
    '''
    peak = activelev(audio)
    normalized = audio / peak
    return normalized * alpha
def normalize_energy_torch(audio, alpha = 1):
    '''
    Peak-normalize a torch waveform.

    NOTE(review): there is no special handling for (near-)silent input; a
    zero peak leads to a division by zero.

    :param audio: waveform tensor
    :param alpha: target peak; the output lies within [-alpha, alpha]
    :return: ``audio`` divided by its largest absolute value, times ``alpha``
    '''
    peak = activelev_torch([audio])
    normalized = audio / peak
    return normalized * alpha
def unify_energy(*args):
    """Scale all given arrays by one factor so their joint peak becomes 1.

    Returns a list with one scaled array per argument, in order.
    """
    gain = 1.0 / activelev(args)
    scaled = []
    for signal in args:
        scaled.append(signal * gain)
    return scaled
def unify_energy_torch(*args):
    """Scale all given tensors by one factor so their joint peak becomes 1.

    Returns a list with one scaled tensor per argument, in order.
    """
    gain = 1.0 / activelev_torch(args)
    scaled = []
    for signal in args:
        scaled.append(signal * gain)
    return scaled
def activelev(*args):
    '''
    Return the largest absolute value found across all arguments.

    The original note says this still needs updating to match the MATLAB
    implementation.
    '''
    magnitudes = np.abs(list(args))
    return np.max(magnitudes)
def activelev_torch(*args):
    '''
    Return the largest absolute value over an iterable of tensors.

    Only the first positional argument is used; it must be an iterable of
    tensors. The original note says this still needs updating to match the
    MATLAB implementation.
    '''
    tensors = args[0]
    peaks = [torch.max(torch.abs(tensor)) for tensor in tensors]
    return max(peaks)
def uniform_torch(lower, upper):
    """Sample uniformly from [lower, upper) as a one-element tensor.

    When the bounds differ by less than 1e-5 no random number is drawn and
    ``upper`` itself is returned unchanged.
    """
    span = upper - lower
    if abs(span) < 1e-5:
        return upper
    return span * torch.rand(1) + lower
if __name__ == "__main__":
    # Smoke test. Inputs follow the [samples, channel] layout documented on
    # add_noise_and_scale (the original used [1, 32000]).
    wav1 = torch.randn(32000, 1)
    wav2 = torch.randn(32000, 1)
    # The function returns five values; unpacking only four raised ValueError.
    target, noise, mixed_wav, snr, scale = add_noise_and_scale(wav1, wav2)

View File

@@ -0,0 +1,159 @@
import torch
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.signal import get_window
import librosa.util as librosa_util
from librosa.util import pad_center, tiny
# from audio_processing import window_sumsquare
def window_sumsquare(window, n_frames, hop_length=512, win_length=1024,
                     n_fft=1024, dtype=np.float32, norm=None):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.
    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.
    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`
    n_frames : int > 0
        The number of analysis frames
    hop_length : int > 0
        The number of samples to advance between frames
    win_length : [optional]
        The length of the window function. If None, `n_fft` is used.
    n_fft : int > 0
        The length of each analysis frame.
    dtype : np.dtype
        The data type of the output
    norm : [optional]
        Forwarded to `librosa.util.normalize`.
    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft
    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)
    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm)**2
    # `size` is passed by keyword: newer librosa releases reject it positionally.
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)
    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x
class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft

    Implements the STFT and its inverse as 1-D (transposed) convolutions with
    fixed Fourier bases stored as buffers.
    """

    def __init__(self, filter_length=1024, hop_length=512, win_length=1024,
                 window='hann'):
        """
        :param filter_length: FFT size (length of each analysis frame).
        :param hop_length: number of samples between successive frames.
        :param win_length: window length; must not exceed ``filter_length``.
        :param window: window specification for ``scipy.signal.get_window``,
            or None to leave the bases unwindowed.
        """
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))
        cutoff = int((self.filter_length / 2 + 1))
        # Stack real and imaginary parts of the non-redundant half of the DFT.
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])
        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])
        if window is not None:
            assert(filter_length >= win_length)
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            # `size` is passed by keyword: newer librosa releases reject it positionally.
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()
            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window
        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        """Return ``(magnitude, phase)`` of a batch of waveforms.

        :param input_data: tensor indexed as [batch_size, num_samples].
        :return: two tensors of shape [batch_size, filter_length // 2 + 1, frames];
            ``phase`` is computed from detached values.
        """
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)
        self.num_samples = num_samples
        half = int(self.filter_length / 2)
        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (half, half, 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)
        # Registered buffers never require grad, so they can be used directly.
        forward_transform = F.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0)
        cutoff = half + 1
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]
        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.atan2(imag_part.detach(), real_part.detach())
        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Reconstruct waveforms from ``magnitude`` and ``phase``.

        :return: tensor of shape [batch_size, 1, sample_num].
        """
        recombine_magnitude_phase = torch.cat(
            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0)
        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0])
            # Follow the input's device instead of assuming the default GPU.
            device = magnitude.device
            window_sum = torch.from_numpy(window_sum).to(device)
            approx_nonzero_indices = approx_nonzero_indices.to(device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length
        # Strip the reflect padding added by ``transform`` from both ends.
        half = int(self.filter_length / 2)
        inverse_transform = inverse_transform[:, :, half:]
        inverse_transform = inverse_transform[:, :, :-half]
        return inverse_transform

    def forward(self, input_data):
        """Transform then invert; caches ``self.magnitude`` and ``self.phase``."""
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
if __name__ == '__main__':
    # Smoke test: forward STFT of a random batch, printing the magnitude shape.
    waveforms = torch.randn(4, 320000)
    transformer = STFT()
    mag, phase = transformer.transform(waveforms)
    print(mag.shape)

View File

@@ -0,0 +1,23 @@
import librosa
import librosa.filters
import math
import numpy as np
import scipy.io.wavfile
def load_wav(path):
    """Load ``path`` at 32 kHz and force it to exactly 10 seconds.

    Longer clips are truncated and shorter ones are zero-padded at the end
    (10 s matches AudioCaps clips, per the original note).

    :param path: audio file readable by librosa.
    :return: numpy array of shape [320000, 1].
    """
    max_length = 32000 * 10
    wav = librosa.core.load(path, sr=32000)[0]
    if len(wav) > max_length:
        # Fixed: the truncated clip used to be stored in an unused variable,
        # so over-long audio was returned at full length.
        wav = wav[:max_length]
    elif len(wav) < max_length:
        wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
    # Add a trailing channel axis.
    wav = wav[..., None]
    return wav
def save_wav(wav, path):
    """Peak-normalize ``wav`` to the int16 range and write it as a 32 kHz WAV file.

    The caller's array is left untouched (the original scaled it in place,
    which also failed for integer arrays). The peak is floored at 0.01 so
    near-silent input is not blown up to full scale.

    :param wav: numpy waveform.
    :param path: destination file path.
    """
    gain = 32767 / max(0.01, np.max(np.abs(wav)))
    scaled = wav * gain
    scipy.io.wavfile.write(path, 32000, scaled.astype(np.int16))