mirror of
https://github.com/AIGC-Audio/AudioGPT.git
synced 2026-05-18 05:04:58 +02:00
detection and extraction
This commit is contained in:
25
sound_extraction/model/LASSNet.py
Normal file
25
sound_extraction/model/LASSNet.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .text_encoder import Text_Encoder
|
||||
from .resunet_film import UNetRes_FiLM
|
||||
|
||||
class LASSNet(nn.Module):
|
||||
def __init__(self, device='cuda'):
|
||||
super(LASSNet, self).__init__()
|
||||
self.text_embedder = Text_Encoder(device)
|
||||
self.UNet = UNetRes_FiLM(channels=1, cond_embedding_dim=256)
|
||||
|
||||
def forward(self, x, caption):
|
||||
# x: (Batch, 1, T, 128))
|
||||
input_ids, attns_mask = self.text_embedder.tokenize(caption)
|
||||
|
||||
cond_vec = self.text_embedder(input_ids, attns_mask)[0]
|
||||
dec_cond_vec = cond_vec
|
||||
|
||||
mask = self.UNet(x, cond_vec, dec_cond_vec)
|
||||
mask = torch.sigmoid(mask)
|
||||
return mask
|
||||
|
||||
def get_tokenizer(self):
|
||||
return self.text_embedder.tokenizer
|
||||
27
sound_extraction/model/film.py
Normal file
27
sound_extraction/model/film.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Film(nn.Module):
|
||||
def __init__(self, channels, cond_embedding_dim):
|
||||
super(Film, self).__init__()
|
||||
self.linear = nn.Sequential(
|
||||
nn.Linear(cond_embedding_dim, channels * 2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Linear(channels * 2, channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
def forward(self, data, cond_vec):
|
||||
"""
|
||||
:param data: [batchsize, channels, samples] or [batchsize, channels, T, F] or [batchsize, channels, F, T]
|
||||
:param cond_vec: [batchsize, cond_embedding_dim]
|
||||
:return:
|
||||
"""
|
||||
bias = self.linear(cond_vec) # [batchsize, channels]
|
||||
if len(list(data.size())) == 3:
|
||||
data = data + bias[..., None]
|
||||
elif len(list(data.size())) == 4:
|
||||
data = data + bias[..., None, None]
|
||||
else:
|
||||
print("Warning: The size of input tensor,", data.size(), "is not correct. Film is not working.")
|
||||
return data
|
||||
483
sound_extraction/model/modules.py
Normal file
483
sound_extraction/model/modules.py
Normal file
@@ -0,0 +1,483 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
from .film import Film
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
|
||||
super(ConvBlock, self).__init__()
|
||||
|
||||
self.activation = activation
|
||||
padding = (kernel_size[0] // 2, kernel_size[1] // 2)
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=(1, 1),
|
||||
dilation=(1, 1),
|
||||
padding=padding,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=(1, 1),
|
||||
dilation=(1, 1),
|
||||
padding=padding,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
init_layer(self.conv1)
|
||||
init_layer(self.conv2)
|
||||
init_bn(self.bn1)
|
||||
init_bn(self.bn2)
|
||||
|
||||
def forward(self, x):
|
||||
x = act(self.bn1(self.conv1(x)), self.activation)
|
||||
x = act(self.bn2(self.conv2(x)), self.activation)
|
||||
return x
|
||||
|
||||
|
||||
class EncoderBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, downsample, activation, momentum):
|
||||
super(EncoderBlock, self).__init__()
|
||||
|
||||
self.conv_block = ConvBlock(
|
||||
in_channels, out_channels, kernel_size, activation, momentum
|
||||
)
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self, x):
|
||||
encoder = self.conv_block(x)
|
||||
encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
|
||||
return encoder_pool, encoder
|
||||
|
||||
|
||||
class DecoderBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, upsample, activation, momentum):
|
||||
super(DecoderBlock, self).__init__()
|
||||
self.kernel_size = kernel_size
|
||||
self.stride = upsample
|
||||
self.activation = activation
|
||||
|
||||
self.conv1 = torch.nn.ConvTranspose2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=self.stride,
|
||||
stride=self.stride,
|
||||
padding=(0, 0),
|
||||
bias=False,
|
||||
dilation=(1, 1),
|
||||
)
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(out_channels, momentum=momentum)
|
||||
|
||||
self.conv_block2 = ConvBlock(
|
||||
out_channels * 2, out_channels, kernel_size, activation, momentum
|
||||
)
|
||||
|
||||
def init_weights(self):
|
||||
init_layer(self.conv1)
|
||||
init_bn(self.bn)
|
||||
|
||||
def prune(self, x):
|
||||
"""Prune the shape of x after transpose convolution."""
|
||||
padding = (self.kernel_size[0] // 2, self.kernel_size[1] // 2)
|
||||
x = x[
|
||||
:,
|
||||
:,
|
||||
padding[0] : padding[0] - self.stride[0],
|
||||
padding[1] : padding[1] - self.stride[1]]
|
||||
return x
|
||||
|
||||
def forward(self, input_tensor, concat_tensor):
|
||||
x = act(self.bn1(self.conv1(input_tensor)), self.activation)
|
||||
# from IPython import embed; embed(using=False); os._exit(0)
|
||||
# x = self.prune(x)
|
||||
x = torch.cat((x, concat_tensor), dim=1)
|
||||
x = self.conv_block2(x)
|
||||
return x
|
||||
|
||||
|
||||
class EncoderBlockRes1B(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, downsample, activation, momentum):
|
||||
super(EncoderBlockRes1B, self).__init__()
|
||||
size = (3,3)
|
||||
|
||||
self.conv_block1 = ConvBlockRes(in_channels, out_channels, size, activation, momentum)
|
||||
self.conv_block2 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
|
||||
self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
|
||||
self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self, x):
|
||||
encoder = self.conv_block1(x)
|
||||
encoder = self.conv_block2(encoder)
|
||||
encoder = self.conv_block3(encoder)
|
||||
encoder = self.conv_block4(encoder)
|
||||
encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
|
||||
return encoder_pool, encoder
|
||||
|
||||
class DecoderBlockRes1B(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride, activation, momentum):
|
||||
super(DecoderBlockRes1B, self).__init__()
|
||||
size = (3,3)
|
||||
self.activation = activation
|
||||
|
||||
self.conv1 = torch.nn.ConvTranspose2d(in_channels=in_channels,
|
||||
out_channels=out_channels, kernel_size=size, stride=stride,
|
||||
padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(in_channels)
|
||||
self.conv_block2 = ConvBlockRes(out_channels * 2, out_channels, size, activation, momentum)
|
||||
self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
|
||||
self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
|
||||
self.conv_block5 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
|
||||
|
||||
def init_weights(self):
|
||||
init_layer(self.conv1)
|
||||
|
||||
def prune(self, x, both=False):
|
||||
"""Prune the shape of x after transpose convolution.
|
||||
"""
|
||||
if(both): x = x[:, :, 0 : - 1, 0:-1]
|
||||
else: x = x[:, :, 0: - 1, :]
|
||||
return x
|
||||
|
||||
def forward(self, input_tensor, concat_tensor,both=False):
|
||||
x = self.conv1(F.relu_(self.bn1(input_tensor)))
|
||||
x = self.prune(x,both=both)
|
||||
x = torch.cat((x, concat_tensor), dim=1)
|
||||
x = self.conv_block2(x)
|
||||
x = self.conv_block3(x)
|
||||
x = self.conv_block4(x)
|
||||
x = self.conv_block5(x)
|
||||
return x
|
||||
|
||||
|
||||
class EncoderBlockRes2BCond(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, downsample, activation, momentum, cond_embedding_dim):
|
||||
super(EncoderBlockRes2BCond, self).__init__()
|
||||
size = (3, 3)
|
||||
|
||||
self.conv_block1 = ConvBlockResCond(in_channels, out_channels, size, activation, momentum, cond_embedding_dim)
|
||||
self.conv_block2 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self, x, cond_vec):
|
||||
encoder = self.conv_block1(x, cond_vec)
|
||||
encoder = self.conv_block2(encoder, cond_vec)
|
||||
encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
|
||||
return encoder_pool, encoder
|
||||
|
||||
class DecoderBlockRes2BCond(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride, activation, momentum, cond_embedding_dim):
|
||||
super(DecoderBlockRes2BCond, self).__init__()
|
||||
size = (3, 3)
|
||||
self.activation = activation
|
||||
|
||||
self.conv1 = torch.nn.ConvTranspose2d(in_channels=in_channels,
|
||||
out_channels=out_channels, kernel_size=size, stride=stride,
|
||||
padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(in_channels)
|
||||
self.conv_block2 = ConvBlockResCond(out_channels * 2, out_channels, size, activation, momentum, cond_embedding_dim)
|
||||
self.conv_block3 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
|
||||
|
||||
def init_weights(self):
|
||||
init_layer(self.conv1)
|
||||
|
||||
def prune(self, x, both=False):
|
||||
"""Prune the shape of x after transpose convolution.
|
||||
"""
|
||||
if(both): x = x[:, :, 0 : - 1, 0:-1]
|
||||
else: x = x[:, :, 0: - 1, :]
|
||||
return x
|
||||
|
||||
def forward(self, input_tensor, concat_tensor, cond_vec, both=False):
|
||||
x = self.conv1(F.relu_(self.bn1(input_tensor)))
|
||||
x = self.prune(x, both=both)
|
||||
x = torch.cat((x, concat_tensor), dim=1)
|
||||
x = self.conv_block2(x, cond_vec)
|
||||
x = self.conv_block3(x, cond_vec)
|
||||
return x
|
||||
|
||||
class EncoderBlockRes4BCond(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, downsample, activation, momentum, cond_embedding_dim):
|
||||
super(EncoderBlockRes4B, self).__init__()
|
||||
size = (3,3)
|
||||
|
||||
self.conv_block1 = ConvBlockResCond(in_channels, out_channels, size, activation, momentum, cond_embedding_dim)
|
||||
self.conv_block2 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
|
||||
self.conv_block3 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
|
||||
self.conv_block4 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self, x, cond_vec):
|
||||
encoder = self.conv_block1(x, cond_vec)
|
||||
encoder = self.conv_block2(encoder, cond_vec)
|
||||
encoder = self.conv_block3(encoder, cond_vec)
|
||||
encoder = self.conv_block4(encoder, cond_vec)
|
||||
encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
|
||||
return encoder_pool, encoder
|
||||
|
||||
class DecoderBlockRes4BCond(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride, activation, momentum, cond_embedding_dim):
|
||||
super(DecoderBlockRes4B, self).__init__()
|
||||
size = (3, 3)
|
||||
self.activation = activation
|
||||
|
||||
self.conv1 = torch.nn.ConvTranspose2d(in_channels=in_channels,
|
||||
out_channels=out_channels, kernel_size=size, stride=stride,
|
||||
padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(in_channels)
|
||||
self.conv_block2 = ConvBlockResCond(out_channels * 2, out_channels, size, activation, momentum, cond_embedding_dim)
|
||||
self.conv_block3 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
|
||||
self.conv_block4 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
|
||||
self.conv_block5 = ConvBlockResCond(out_channels, out_channels, size, activation, momentum, cond_embedding_dim)
|
||||
|
||||
def init_weights(self):
|
||||
init_layer(self.conv1)
|
||||
|
||||
def prune(self, x, both=False):
|
||||
"""Prune the shape of x after transpose convolution.
|
||||
"""
|
||||
if(both): x = x[:, :, 0 : - 1, 0:-1]
|
||||
else: x = x[:, :, 0: - 1, :]
|
||||
return x
|
||||
|
||||
def forward(self, input_tensor, concat_tensor, cond_vec, both=False):
|
||||
x = self.conv1(F.relu_(self.bn1(input_tensor)))
|
||||
x = self.prune(x,both=both)
|
||||
x = torch.cat((x, concat_tensor), dim=1)
|
||||
x = self.conv_block2(x, cond_vec)
|
||||
x = self.conv_block3(x, cond_vec)
|
||||
x = self.conv_block4(x, cond_vec)
|
||||
x = self.conv_block5(x, cond_vec)
|
||||
return x
|
||||
|
||||
class EncoderBlockRes4B(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, downsample, activation, momentum):
|
||||
super(EncoderBlockRes4B, self).__init__()
|
||||
size = (3, 3)
|
||||
|
||||
self.conv_block1 = ConvBlockRes(in_channels, out_channels, size, activation, momentum)
|
||||
self.conv_block2 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
|
||||
self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
|
||||
self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self, x):
|
||||
encoder = self.conv_block1(x)
|
||||
encoder = self.conv_block2(encoder)
|
||||
encoder = self.conv_block3(encoder)
|
||||
encoder = self.conv_block4(encoder)
|
||||
encoder_pool = F.avg_pool2d(encoder, kernel_size=self.downsample)
|
||||
return encoder_pool, encoder
|
||||
|
||||
class DecoderBlockRes4B(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, stride, activation, momentum):
|
||||
super(DecoderBlockRes4B, self).__init__()
|
||||
size = (3,3)
|
||||
self.activation = activation
|
||||
|
||||
self.conv1 = torch.nn.ConvTranspose2d(in_channels=in_channels,
|
||||
out_channels=out_channels, kernel_size=size, stride=stride,
|
||||
padding=(0, 0), output_padding=(0, 0), bias=False, dilation=1)
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(in_channels)
|
||||
self.conv_block2 = ConvBlockRes(out_channels * 2, out_channels, size, activation, momentum)
|
||||
self.conv_block3 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
|
||||
self.conv_block4 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
|
||||
self.conv_block5 = ConvBlockRes(out_channels, out_channels, size, activation, momentum)
|
||||
|
||||
def init_weights(self):
|
||||
init_layer(self.conv1)
|
||||
|
||||
def prune(self, x, both=False):
|
||||
"""Prune the shape of x after transpose convolution.
|
||||
"""
|
||||
if(both): x = x[:, :, 0 : - 1, 0:-1]
|
||||
else: x = x[:, :, 0: - 1, :]
|
||||
return x
|
||||
|
||||
def forward(self, input_tensor, concat_tensor,both=False):
|
||||
x = self.conv1(F.relu_(self.bn1(input_tensor)))
|
||||
x = self.prune(x,both=both)
|
||||
x = torch.cat((x, concat_tensor), dim=1)
|
||||
x = self.conv_block2(x)
|
||||
x = self.conv_block3(x)
|
||||
x = self.conv_block4(x)
|
||||
x = self.conv_block5(x)
|
||||
return x
|
||||
|
||||
class ConvBlockResCond(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, activation, momentum, cond_embedding_dim):
|
||||
r"""Residual block.
|
||||
"""
|
||||
super(ConvBlockResCond, self).__init__()
|
||||
|
||||
self.activation = activation
|
||||
padding = [kernel_size[0] // 2, kernel_size[1] // 2]
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(in_channels)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||
|
||||
self.conv1 = nn.Conv2d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size, stride=(1, 1),
|
||||
dilation=(1, 1), padding=padding, bias=False)
|
||||
self.film1 = Film(channels=out_channels, cond_embedding_dim=cond_embedding_dim)
|
||||
self.conv2 = nn.Conv2d(in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size, stride=(1, 1),
|
||||
dilation=(1, 1), padding=padding, bias=False)
|
||||
self.film2 = Film(channels=out_channels, cond_embedding_dim=cond_embedding_dim)
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = nn.Conv2d(in_channels=in_channels,
|
||||
out_channels=out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
|
||||
self.film_res = Film(channels=out_channels, cond_embedding_dim=cond_embedding_dim)
|
||||
self.is_shortcut = True
|
||||
else:
|
||||
self.is_shortcut = False
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
init_bn(self.bn1)
|
||||
init_bn(self.bn2)
|
||||
init_layer(self.conv1)
|
||||
init_layer(self.conv2)
|
||||
|
||||
if self.is_shortcut:
|
||||
init_layer(self.shortcut)
|
||||
|
||||
def forward(self, x, cond_vec):
|
||||
origin = x
|
||||
x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
|
||||
x = self.film1(x, cond_vec)
|
||||
x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))
|
||||
x = self.film2(x, cond_vec)
|
||||
if self.is_shortcut:
|
||||
residual = self.shortcut(origin)
|
||||
residual = self.film_res(residual, cond_vec)
|
||||
return residual + x
|
||||
else:
|
||||
return origin + x
|
||||
|
||||
class ConvBlockRes(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, activation, momentum):
|
||||
r"""Residual block.
|
||||
"""
|
||||
super(ConvBlockRes, self).__init__()
|
||||
|
||||
self.activation = activation
|
||||
padding = [kernel_size[0] // 2, kernel_size[1] // 2]
|
||||
|
||||
self.bn1 = nn.BatchNorm2d(in_channels)
|
||||
self.bn2 = nn.BatchNorm2d(out_channels)
|
||||
|
||||
self.conv1 = nn.Conv2d(in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size, stride=(1, 1),
|
||||
dilation=(1, 1), padding=padding, bias=False)
|
||||
|
||||
self.conv2 = nn.Conv2d(in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size, stride=(1, 1),
|
||||
dilation=(1, 1), padding=padding, bias=False)
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.shortcut = nn.Conv2d(in_channels=in_channels,
|
||||
out_channels=out_channels, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0))
|
||||
self.is_shortcut = True
|
||||
else:
|
||||
self.is_shortcut = False
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
init_bn(self.bn1)
|
||||
init_bn(self.bn2)
|
||||
init_layer(self.conv1)
|
||||
init_layer(self.conv2)
|
||||
|
||||
if self.is_shortcut:
|
||||
init_layer(self.shortcut)
|
||||
|
||||
def forward(self, x):
|
||||
origin = x
|
||||
x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
|
||||
x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))
|
||||
|
||||
if self.is_shortcut:
|
||||
return self.shortcut(origin) + x
|
||||
else:
|
||||
return origin + x
|
||||
|
||||
def init_layer(layer):
|
||||
"""Initialize a Linear or Convolutional layer. """
|
||||
nn.init.xavier_uniform_(layer.weight)
|
||||
|
||||
if hasattr(layer, 'bias'):
|
||||
if layer.bias is not None:
|
||||
layer.bias.data.fill_(0.)
|
||||
|
||||
def init_bn(bn):
|
||||
"""Initialize a Batchnorm layer. """
|
||||
bn.bias.data.fill_(0.)
|
||||
bn.weight.data.fill_(1.)
|
||||
|
||||
def init_gru(rnn):
|
||||
"""Initialize a GRU layer. """
|
||||
|
||||
def _concat_init(tensor, init_funcs):
|
||||
(length, fan_out) = tensor.shape
|
||||
fan_in = length // len(init_funcs)
|
||||
|
||||
for (i, init_func) in enumerate(init_funcs):
|
||||
init_func(tensor[i * fan_in: (i + 1) * fan_in, :])
|
||||
|
||||
def _inner_uniform(tensor):
|
||||
fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
|
||||
nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))
|
||||
|
||||
for i in range(rnn.num_layers):
|
||||
_concat_init(
|
||||
getattr(rnn, 'weight_ih_l{}'.format(i)),
|
||||
[_inner_uniform, _inner_uniform, _inner_uniform]
|
||||
)
|
||||
torch.nn.init.constant_(getattr(rnn, 'bias_ih_l{}'.format(i)), 0)
|
||||
|
||||
_concat_init(
|
||||
getattr(rnn, 'weight_hh_l{}'.format(i)),
|
||||
[_inner_uniform, _inner_uniform, nn.init.orthogonal_]
|
||||
)
|
||||
torch.nn.init.constant_(getattr(rnn, 'bias_hh_l{}'.format(i)), 0)
|
||||
|
||||
|
||||
def act(x, activation):
|
||||
if activation == 'relu':
|
||||
return F.relu_(x)
|
||||
|
||||
elif activation == 'leaky_relu':
|
||||
return F.leaky_relu_(x, negative_slope=0.2)
|
||||
|
||||
elif activation == 'swish':
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
else:
|
||||
raise Exception('Incorrect activation!')
|
||||
110
sound_extraction/model/resunet_film.py
Normal file
110
sound_extraction/model/resunet_film.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from .modules import *
|
||||
import numpy as np
|
||||
|
||||
class UNetRes_FiLM(nn.Module):
|
||||
def __init__(self, channels, cond_embedding_dim, nsrc=1):
|
||||
super(UNetRes_FiLM, self).__init__()
|
||||
activation = 'relu'
|
||||
momentum = 0.01
|
||||
|
||||
self.nsrc = nsrc
|
||||
self.channels = channels
|
||||
self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blocks}
|
||||
|
||||
self.encoder_block1 = EncoderBlockRes2BCond(in_channels=channels * nsrc, out_channels=32,
|
||||
downsample=(2, 2), activation=activation, momentum=momentum,
|
||||
cond_embedding_dim=cond_embedding_dim)
|
||||
self.encoder_block2 = EncoderBlockRes2BCond(in_channels=32, out_channels=64,
|
||||
downsample=(2, 2), activation=activation, momentum=momentum,
|
||||
cond_embedding_dim=cond_embedding_dim)
|
||||
self.encoder_block3 = EncoderBlockRes2BCond(in_channels=64, out_channels=128,
|
||||
downsample=(2, 2), activation=activation, momentum=momentum,
|
||||
cond_embedding_dim=cond_embedding_dim)
|
||||
self.encoder_block4 = EncoderBlockRes2BCond(in_channels=128, out_channels=256,
|
||||
downsample=(2, 2), activation=activation, momentum=momentum,
|
||||
cond_embedding_dim=cond_embedding_dim)
|
||||
self.encoder_block5 = EncoderBlockRes2BCond(in_channels=256, out_channels=384,
|
||||
downsample=(2, 2), activation=activation, momentum=momentum,
|
||||
cond_embedding_dim=cond_embedding_dim)
|
||||
self.encoder_block6 = EncoderBlockRes2BCond(in_channels=384, out_channels=384,
|
||||
downsample=(2, 2), activation=activation, momentum=momentum,
|
||||
cond_embedding_dim=cond_embedding_dim)
|
||||
self.conv_block7 = ConvBlockResCond(in_channels=384, out_channels=384,
|
||||
kernel_size=(3, 3), activation=activation, momentum=momentum,
|
||||
cond_embedding_dim=cond_embedding_dim)
|
||||
self.decoder_block1 = DecoderBlockRes2BCond(in_channels=384, out_channels=384,
|
||||
stride=(2, 2), activation=activation, momentum=momentum,
|
||||
cond_embedding_dim=cond_embedding_dim)
|
||||
self.decoder_block2 = DecoderBlockRes2BCond(in_channels=384, out_channels=384,
|
||||
stride=(2, 2), activation=activation, momentum=momentum,
|
||||
cond_embedding_dim=cond_embedding_dim)
|
||||
self.decoder_block3 = DecoderBlockRes2BCond(in_channels=384, out_channels=256,
|
||||
stride=(2, 2), activation=activation, momentum=momentum,
|
||||
cond_embedding_dim=cond_embedding_dim)
|
||||
self.decoder_block4 = DecoderBlockRes2BCond(in_channels=256, out_channels=128,
|
||||
stride=(2, 2), activation=activation, momentum=momentum,
|
||||
cond_embedding_dim=cond_embedding_dim)
|
||||
self.decoder_block5 = DecoderBlockRes2BCond(in_channels=128, out_channels=64,
|
||||
stride=(2, 2), activation=activation, momentum=momentum,
|
||||
cond_embedding_dim=cond_embedding_dim)
|
||||
self.decoder_block6 = DecoderBlockRes2BCond(in_channels=64, out_channels=32,
|
||||
stride=(2, 2), activation=activation, momentum=momentum,
|
||||
cond_embedding_dim=cond_embedding_dim)
|
||||
|
||||
self.after_conv_block1 = ConvBlockResCond(in_channels=32, out_channels=32,
|
||||
kernel_size=(3, 3), activation=activation, momentum=momentum,
|
||||
cond_embedding_dim=cond_embedding_dim)
|
||||
|
||||
self.after_conv2 = nn.Conv2d(in_channels=32, out_channels=1,
|
||||
kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
init_layer(self.after_conv2)
|
||||
|
||||
def forward(self, sp, cond_vec, dec_cond_vec):
|
||||
"""
|
||||
Args:
|
||||
input: sp: (batch_size, channels_num, segment_samples)
|
||||
Outputs:
|
||||
output_dict: {
|
||||
'wav': (batch_size, channels_num, segment_samples),
|
||||
'sp': (batch_size, channels_num, time_steps, freq_bins)}
|
||||
"""
|
||||
|
||||
x = sp
|
||||
# Pad spectrogram to be evenly divided by downsample ratio.
|
||||
origin_len = x.shape[2] # time_steps
|
||||
pad_len = int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio - origin_len
|
||||
x = F.pad(x, pad=(0, 0, 0, pad_len))
|
||||
x = x[..., 0: x.shape[-1] - 2] # (bs, channels, T, F)
|
||||
|
||||
# UNet
|
||||
(x1_pool, x1) = self.encoder_block1(x, cond_vec) # x1_pool: (bs, 32, T / 2, F / 2)
|
||||
(x2_pool, x2) = self.encoder_block2(x1_pool, cond_vec) # x2_pool: (bs, 64, T / 4, F / 4)
|
||||
(x3_pool, x3) = self.encoder_block3(x2_pool, cond_vec) # x3_pool: (bs, 128, T / 8, F / 8)
|
||||
(x4_pool, x4) = self.encoder_block4(x3_pool, dec_cond_vec) # x4_pool: (bs, 256, T / 16, F / 16)
|
||||
(x5_pool, x5) = self.encoder_block5(x4_pool, dec_cond_vec) # x5_pool: (bs, 512, T / 32, F / 32)
|
||||
(x6_pool, x6) = self.encoder_block6(x5_pool, dec_cond_vec) # x6_pool: (bs, 1024, T / 64, F / 64)
|
||||
x_center = self.conv_block7(x6_pool, dec_cond_vec) # (bs, 2048, T / 64, F / 64)
|
||||
x7 = self.decoder_block1(x_center, x6, dec_cond_vec) # (bs, 1024, T / 32, F / 32)
|
||||
x8 = self.decoder_block2(x7, x5, dec_cond_vec) # (bs, 512, T / 16, F / 16)
|
||||
x9 = self.decoder_block3(x8, x4, cond_vec) # (bs, 256, T / 8, F / 8)
|
||||
x10 = self.decoder_block4(x9, x3, cond_vec) # (bs, 128, T / 4, F / 4)
|
||||
x11 = self.decoder_block5(x10, x2, cond_vec) # (bs, 64, T / 2, F / 2)
|
||||
x12 = self.decoder_block6(x11, x1, cond_vec) # (bs, 32, T, F)
|
||||
x = self.after_conv_block1(x12, cond_vec) # (bs, 32, T, F)
|
||||
x = self.after_conv2(x) # (bs, channels, T, F)
|
||||
|
||||
# Recover shape
|
||||
x = F.pad(x, pad=(0, 2))
|
||||
x = x[:, :, 0: origin_len, :]
|
||||
return x
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = UNetRes_FiLM(channels=1, cond_embedding_dim=16)
|
||||
cond_vec = torch.randn((1, 16))
|
||||
dec_vec = cond_vec
|
||||
print(model(torch.randn((1, 1, 1001, 513)), cond_vec, dec_vec).size())
|
||||
45
sound_extraction/model/text_encoder.py
Normal file
45
sound_extraction/model/text_encoder.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import *
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
# pretrained model name: (model class, model tokenizer, output dimension, token style)
|
||||
MODELS = {
|
||||
'prajjwal1/bert-mini': (BertModel, BertTokenizer),
|
||||
}
|
||||
|
||||
class Text_Encoder(nn.Module):
|
||||
def __init__(self, device):
|
||||
super(Text_Encoder, self).__init__()
|
||||
self.base_model = 'prajjwal1/bert-mini'
|
||||
self.dropout = 0.1
|
||||
|
||||
self.tokenizer = MODELS[self.base_model][1].from_pretrained(self.base_model)
|
||||
|
||||
self.bert_layer = MODELS[self.base_model][0].from_pretrained(self.base_model,
|
||||
add_pooling_layer=False,
|
||||
hidden_dropout_prob=self.dropout,
|
||||
attention_probs_dropout_prob=self.dropout,
|
||||
output_hidden_states=True)
|
||||
|
||||
self.linear_layer = nn.Sequential(nn.Linear(256, 256), nn.ReLU(inplace=True))
|
||||
|
||||
self.device = device
|
||||
|
||||
def tokenize(self, caption):
|
||||
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
tokenized = self.tokenizer(caption, add_special_tokens=False, padding=True, return_tensors='pt')
|
||||
input_ids = tokenized['input_ids']
|
||||
attns_mask = tokenized['attention_mask']
|
||||
|
||||
input_ids = input_ids.to(self.device)
|
||||
attns_mask = attns_mask.to(self.device)
|
||||
return input_ids, attns_mask
|
||||
|
||||
def forward(self, input_ids, attns_mask):
|
||||
# input_ids, attns_mask = self.tokenize(caption)
|
||||
output = self.bert_layer(input_ids=input_ids, attention_mask=attns_mask)[0]
|
||||
cls_embed = output[:, 0, :]
|
||||
text_embed = self.linear_layer(cls_embed)
|
||||
|
||||
return text_embed, output # text_embed: (batch, hidden_size)
|
||||
98
sound_extraction/utils/create_mixtures.py
Normal file
98
sound_extraction/utils/create_mixtures.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
def add_noise_and_scale(front, noise, snr_l=0, snr_h=0, scale_lower=1.0, scale_upper=1.0):
|
||||
"""
|
||||
:param front: front-head audio, like vocal [samples,channel], will be normlized so any scale will be fine
|
||||
:param noise: noise, [samples,channel], any scale
|
||||
:param snr_l: Optional
|
||||
:param snr_h: Optional
|
||||
:param scale_lower: Optional
|
||||
:param scale_upper: Optional
|
||||
:return: scaled front and noise (noisy = front + noise), all_mel_e2e outputs are noramlized within [-1 , 1]
|
||||
"""
|
||||
snr = None
|
||||
noise, front = normalize_energy_torch(noise), normalize_energy_torch(front) # set noise and vocal to equal range [-1,1]
|
||||
# print("normalize:",torch.max(noise),torch.max(front))
|
||||
if snr_l is not None and snr_h is not None:
|
||||
front, noise, snr = _random_noise(front, noise, snr_l=snr_l, snr_h=snr_h) # remix them with a specific snr
|
||||
|
||||
noisy, noise, front = unify_energy_torch(noise + front, noise, front) # normalize noisy, noise and vocal energy into [-1,1]
|
||||
|
||||
# print("unify:", torch.max(noise), torch.max(front), torch.max(noisy))
|
||||
scale = _random_scale(scale_lower, scale_upper) # random scale these three signal
|
||||
|
||||
# print("Scale",scale)
|
||||
noisy, noise, front = noisy * scale, noise * scale, front * scale # apply scale
|
||||
# print("after scale", torch.max(noisy), torch.max(noise), torch.max(front), snr, scale)
|
||||
|
||||
front, noise = _to_numpy(front), _to_numpy(noise) # [num_samples]
|
||||
mixed_wav = front + noise
|
||||
|
||||
return front, noise, mixed_wav, snr, scale
|
||||
|
||||
def _random_scale(lower=0.3, upper=0.9):
|
||||
return float(uniform_torch(lower, upper))
|
||||
|
||||
def _random_noise(clean, noise, snr_l=None, snr_h=None):
|
||||
snr = uniform_torch(snr_l,snr_h)
|
||||
clean_weight = 10 ** (float(snr) / 20)
|
||||
return clean, noise/clean_weight, snr
|
||||
|
||||
def _to_numpy(wav):
|
||||
return np.transpose(wav, (1, 0))[0].numpy() # [num_samples]
|
||||
|
||||
def normalize_energy(audio, alpha = 1):
|
||||
'''
|
||||
:param audio: 1d waveform, [batchsize, *],
|
||||
:param alpha: the value of output range from: [-alpha,alpha]
|
||||
:return: 1d waveform which value range from: [-alpha,alpha]
|
||||
'''
|
||||
val_max = activelev(audio)
|
||||
return (audio / val_max) * alpha
|
||||
|
||||
def normalize_energy_torch(audio, alpha = 1):
|
||||
'''
|
||||
If the signal is almost empty(determined by threshold), if will only be divided by 2**15
|
||||
:param audio: 1d waveform, 2**15
|
||||
:param alpha: the value of output range from: [-alpha,alpha]
|
||||
:return: 1d waveform which value range from: [-alpha,alpha]
|
||||
'''
|
||||
val_max = activelev_torch([audio])
|
||||
return (audio / val_max) * alpha
|
||||
|
||||
def unify_energy(*args):
|
||||
max_amp = activelev(args)
|
||||
mix_scale = 1.0/max_amp
|
||||
return [x * mix_scale for x in args]
|
||||
|
||||
def unify_energy_torch(*args):
|
||||
max_amp = activelev_torch(args)
|
||||
mix_scale = 1.0/max_amp
|
||||
return [x * mix_scale for x in args]
|
||||
|
||||
def activelev(*args):
|
||||
'''
|
||||
need to update like matlab
|
||||
'''
|
||||
return np.max(np.abs([*args]))
|
||||
|
||||
def activelev_torch(*args):
|
||||
'''
|
||||
need to update like matlab
|
||||
'''
|
||||
res = []
|
||||
args = args[0]
|
||||
for each in args:
|
||||
res.append(torch.max(torch.abs(each)))
|
||||
return max(res)
|
||||
|
||||
def uniform_torch(lower, upper):
|
||||
if(abs(lower-upper)<1e-5):
|
||||
return upper
|
||||
return (upper-lower)*torch.rand(1)+lower
|
||||
|
||||
if __name__ == "__main__":
|
||||
wav1 = torch.randn(1, 32000)
|
||||
wav2 = torch.randn(1, 32000)
|
||||
target, noise, snr, scale = add_noise_and_scale(wav1, wav2)
|
||||
159
sound_extraction/utils/stft.py
Normal file
159
sound_extraction/utils/stft.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
from scipy.signal import get_window
|
||||
import librosa.util as librosa_util
|
||||
from librosa.util import pad_center, tiny
|
||||
# from audio_processing import window_sumsquare
|
||||
|
||||
def window_sumsquare(window, n_frames, hop_length=512, win_length=1024,
|
||||
n_fft=1024, dtype=np.float32, norm=None):
|
||||
"""
|
||||
# from librosa 0.6
|
||||
Compute the sum-square envelope of a window function at a given hop length.
|
||||
This is used to estimate modulation effects induced by windowing
|
||||
observations in short-time fourier transforms.
|
||||
Parameters
|
||||
----------
|
||||
window : string, tuple, number, callable, or list-like
|
||||
Window specification, as in `get_window`
|
||||
n_frames : int > 0
|
||||
The number of analysis frames
|
||||
hop_length : int > 0
|
||||
The number of samples to advance between frames
|
||||
win_length : [optional]
|
||||
The length of the window function. By default, this matches `n_fft`.
|
||||
n_fft : int > 0
|
||||
The length of each analysis frame.
|
||||
dtype : np.dtype
|
||||
The data type of the output
|
||||
Returns
|
||||
-------
|
||||
wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
|
||||
The sum-squared envelope of the window function
|
||||
"""
|
||||
if win_length is None:
|
||||
win_length = n_fft
|
||||
|
||||
n = n_fft + hop_length * (n_frames - 1)
|
||||
x = np.zeros(n, dtype=dtype)
|
||||
|
||||
# Compute the squared window at the desired length
|
||||
win_sq = get_window(window, win_length, fftbins=True)
|
||||
win_sq = librosa_util.normalize(win_sq, norm=norm)**2
|
||||
win_sq = librosa_util.pad_center(win_sq, n_fft)
|
||||
|
||||
# Fill the envelope
|
||||
for i in range(n_frames):
|
||||
sample = i * hop_length
|
||||
x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
|
||||
return x
|
||||
|
||||
class STFT(torch.nn.Module):
|
||||
"""adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
|
||||
def __init__(self, filter_length=1024, hop_length=512, win_length=1024,
|
||||
window='hann'):
|
||||
super(STFT, self).__init__()
|
||||
self.filter_length = filter_length
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.window = window
|
||||
self.forward_transform = None
|
||||
scale = self.filter_length / self.hop_length
|
||||
fourier_basis = np.fft.fft(np.eye(self.filter_length))
|
||||
|
||||
cutoff = int((self.filter_length / 2 + 1))
|
||||
fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
|
||||
np.imag(fourier_basis[:cutoff, :])])
|
||||
|
||||
forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
|
||||
inverse_basis = torch.FloatTensor(
|
||||
np.linalg.pinv(scale * fourier_basis).T[:, None, :])
|
||||
|
||||
if window is not None:
|
||||
assert(filter_length >= win_length)
|
||||
# get window and zero center pad it to filter_length
|
||||
fft_window = get_window(window, win_length, fftbins=True)
|
||||
fft_window = pad_center(fft_window, filter_length)
|
||||
fft_window = torch.from_numpy(fft_window).float()
|
||||
|
||||
# window the bases
|
||||
forward_basis *= fft_window
|
||||
inverse_basis *= fft_window
|
||||
|
||||
self.register_buffer('forward_basis', forward_basis.float())
|
||||
self.register_buffer('inverse_basis', inverse_basis.float())
|
||||
|
||||
def transform(self, input_data):
|
||||
num_batches = input_data.size(0)
|
||||
num_samples = input_data.size(1)
|
||||
|
||||
self.num_samples = num_samples
|
||||
|
||||
# similar to librosa, reflect-pad the input
|
||||
input_data = input_data.view(num_batches, 1, num_samples)
|
||||
input_data = F.pad(
|
||||
input_data.unsqueeze(1),
|
||||
(int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
|
||||
mode='reflect')
|
||||
input_data = input_data.squeeze(1)
|
||||
|
||||
forward_transform = F.conv1d(
|
||||
input_data,
|
||||
Variable(self.forward_basis, requires_grad=False),
|
||||
stride=self.hop_length,
|
||||
padding=0)
|
||||
|
||||
cutoff = int((self.filter_length / 2) + 1)
|
||||
real_part = forward_transform[:, :cutoff, :]
|
||||
imag_part = forward_transform[:, cutoff:, :]
|
||||
|
||||
magnitude = torch.sqrt(real_part**2 + imag_part**2)
|
||||
phase = torch.autograd.Variable(
|
||||
torch.atan2(imag_part.data, real_part.data))
|
||||
|
||||
return magnitude, phase # [batch_size, F(513), T(1251)]
|
||||
|
||||
def inverse(self, magnitude, phase):
|
||||
recombine_magnitude_phase = torch.cat(
|
||||
[magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
|
||||
|
||||
inverse_transform = F.conv_transpose1d(
|
||||
recombine_magnitude_phase,
|
||||
Variable(self.inverse_basis, requires_grad=False),
|
||||
stride=self.hop_length,
|
||||
padding=0)
|
||||
|
||||
if self.window is not None:
|
||||
window_sum = window_sumsquare(
|
||||
self.window, magnitude.size(-1), hop_length=self.hop_length,
|
||||
win_length=self.win_length, n_fft=self.filter_length,
|
||||
dtype=np.float32)
|
||||
# remove modulation effects
|
||||
approx_nonzero_indices = torch.from_numpy(
|
||||
np.where(window_sum > tiny(window_sum))[0])
|
||||
window_sum = torch.autograd.Variable(
|
||||
torch.from_numpy(window_sum), requires_grad=False)
|
||||
window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
|
||||
inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
|
||||
|
||||
# scale by hop ratio
|
||||
inverse_transform *= float(self.filter_length) / self.hop_length
|
||||
|
||||
inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
|
||||
inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
|
||||
|
||||
return inverse_transform #[batch_size, 1, sample_num]
|
||||
|
||||
def forward(self, input_data):
|
||||
self.magnitude, self.phase = self.transform(input_data)
|
||||
reconstruction = self.inverse(self.magnitude, self.phase)
|
||||
return reconstruction
|
||||
|
||||
if __name__ == '__main__':
|
||||
a = torch.randn(4, 320000)
|
||||
stft = STFT()
|
||||
mag, phase = stft.transform(a)
|
||||
# rec_a = stft.inverse(mag, phase)
|
||||
print(mag.shape)
|
||||
23
sound_extraction/utils/wav_io.py
Normal file
23
sound_extraction/utils/wav_io.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import librosa
|
||||
import librosa.filters
|
||||
import math
|
||||
import numpy as np
|
||||
import scipy.io.wavfile
|
||||
|
||||
def load_wav(path):
|
||||
max_length = 32000 * 10
|
||||
wav = librosa.core.load(path, sr=32000)[0]
|
||||
if len(wav) > max_length:
|
||||
audio = wav[0:max_length]
|
||||
|
||||
# pad audio to max length, 10s for AudioCaps
|
||||
if len(wav) < max_length:
|
||||
# audio = torch.nn.functional.pad(audio, (0, self.max_length - audio.size(1)), 'constant')
|
||||
wav = np.pad(wav, (0, max_length - len(wav)), 'constant')
|
||||
wav = wav[...,None]
|
||||
return wav
|
||||
|
||||
|
||||
def save_wav(wav, path):
|
||||
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
||||
scipy.io.wavfile.write(path, 32000, wav.astype(np.int16))
|
||||
Reference in New Issue
Block a user