Mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-16 16:27:45 +01:00
[to #42322933] add speech separation finetune
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11379892
data/test/audios/s1_speech.wav (new file, 3 lines)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:437b1064a0e38219a9043e25e4761c9f1161c0431636dcea159b44524e0f34eb
size 141134
data/test/audios/s2_speech.wav (new file, 3 lines)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b1eb51be6751b35aa521866ef0cd1caa64e39451cd7f4b22dee5c1cb7e3e43d5
size 141134
@@ -429,6 +429,7 @@ class Trainers(object):
     speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
     speech_kws_fsmn_char_ctc_nearfield = 'speech_kws_fsmn_char_ctc_nearfield'
     speech_kantts_trainer = 'speech-kantts-trainer'
+    speech_separation = 'speech-separation'
 
 
 class Preprocessors(object):
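
The new registry key is what build_trainer resolves at runtime. A minimal sketch of the lookup, assuming modelscope.trainers.audio.separation_trainer has been imported so that its @TRAINERS.register_module decorator (later in this diff) has run:

    from modelscope.metainfo import Trainers
    from modelscope.trainers.builder import TRAINERS

    # Registry lookup by the new name; normally done indirectly via build_trainer.
    trainer_cls = TRAINERS.get(Trainers.speech_separation)
    print(trainer_cls)  # the SeparationTrainer class added below
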
@@ -93,6 +93,10 @@ class MossFormer(TorchModel):
                 os.path.join(load_path, 'masknet.bin'), map_location=device),
             strict=True)
 
+    def as_dict(self):
+        return dict(
+            encoder=self.encoder, decoder=self.decoder, masknet=self.mask_net)
+
 
 def select_norm(norm, dim, shape):
     """Just a wrapper to select the normalization type.
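
The added as_dict is what lets the new trainer hand MossFormer's sub-networks to SpeechBrain. A rough sketch of the wiring that SeparationTrainer.__init__ (later in this diff) performs, assuming model is a loaded MossFormer and hparams is the dict loaded from hparams.yaml:

    modules = model.as_dict()  # {'encoder': ..., 'decoder': ..., 'masknet': ...}
    hparams['checkpointer'].add_recoverables(modules)  # each sub-net gets saved/restored
    separator = Separation(
        modules=modules,  # sb.Brain exposes these as self.modules on the right device
        opt_class=hparams['optimizer'],
        hparams=hparams,
        run_opts={'device': 'cpu'},  # missing run_opts keys fall back to Brain defaults
        checkpointer=hparams['checkpointer'])
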
@@ -63,6 +63,9 @@ class SeparationPipeline(Pipeline):
         for ns in range(self.model.num_spks):
             signal = est_source[0, :, ns]
             signal = signal / signal.abs().max() * 0.5
-            result.append(signal.unsqueeze(0).cpu())
+            signal = signal.unsqueeze(0).cpu()
+            # convert tensor to pcm
+            output = (signal.numpy() * 32768).astype(numpy.int16).tobytes()
+            result.append(output)
         logger.info('Finish forward.')
         return {OutputKeys.OUTPUT_PCM_LIST: result}
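
With this change the pipeline returns raw 16-bit PCM bytes instead of tensors. A consumer converts them back with numpy, which mirrors what the updated unit test at the bottom of this diff does:

    import numpy
    import soundfile as sf

    # 'output' is one element of result[OutputKeys.OUTPUT_PCM_LIST];
    # 8000 Hz matches the 8k MossFormer model.
    samples = numpy.frombuffer(output, dtype=numpy.int16)
    sf.write('output_spk0.wav', samples, 8000)
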
@@ -8,7 +8,7 @@ if TYPE_CHECKING:
     from .builder import PREPROCESSORS, build_preprocessor
     from .common import Compose, ToTensor, Filter
     from .asr import WavToScp
-    from .audio import LinearAECAndFbank
+    from .audio import LinearAECAndFbank, AudioBrainPreprocessor
     from .image import (LoadImage, load_image,
                         ImageColorEnhanceFinetunePreprocessor,
                         ImageInstanceSegmentationPreprocessor,
@@ -46,7 +46,7 @@ else:
         'base': ['Preprocessor'],
         'builder': ['PREPROCESSORS', 'build_preprocessor'],
         'common': ['Compose', 'ToTensor', 'Filter'],
-        'audio': ['LinearAECAndFbank'],
+        'audio': ['LinearAECAndFbank', 'AudioBrainPreprocessor'],
         'asr': ['WavToScp'],
         'video': ['ReadVideoData', 'MovieSceneSegmentationPreprocessor'],
         'image': [
@@ -11,7 +11,34 @@ import torch
 from modelscope.fileio import File
 from modelscope.preprocessors import Preprocessor
 from modelscope.preprocessors.builder import PREPROCESSORS
-from modelscope.utils.constant import Fields
+from modelscope.utils.constant import Fields, ModeKeys
 
 
+class AudioBrainPreprocessor(Preprocessor):
+    """A preprocessor that takes an audio file path and reads the file into a tensor.
+
+    Args:
+        takes: the field name of the input audio file
+        provides: the field name of the output tensor
+        mode: the process mode, 'inference' by default
+    """
+
+    def __init__(self,
+                 takes: str,
+                 provides: str,
+                 mode=ModeKeys.INFERENCE,
+                 *args,
+                 **kwargs):
+        super(AudioBrainPreprocessor, self).__init__(mode, *args, **kwargs)
+        self.takes = takes
+        self.provides = provides
+        import speechbrain as sb
+        self.read_audio = sb.dataio.dataio.read_audio
+
+    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        result = self.read_audio(data[self.takes])
+        data[self.provides] = result
+        return data
+
+
 def load_kaldi_feature_transform(filename):
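
A short sketch of the new preprocessor in isolation (the wav path here is just an assumption; the trainer test at the end of this diff attaches three of these to an MsDataset):

    from modelscope.preprocessors.audio import AudioBrainPreprocessor

    preprocessor = AudioBrainPreprocessor(takes='mix_wav:FILE', provides='mix_sig')
    sample = {'mix_wav:FILE': 'data/test/audios/mix_speech.wav'}
    sample = preprocessor(sample)
    # 'mix_sig' now holds the waveform tensor read by speechbrain's read_audio
    print(sample['mix_sig'].shape)
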
modelscope/trainers/audio/separation_trainer.py (new file, 564 lines)
@@ -0,0 +1,564 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import csv
import os
from typing import Dict, Optional, Union

import numpy as np
import speechbrain as sb
import speechbrain.nnet.schedulers as schedulers
import torch
import torch.nn.functional as F
import torchaudio
from torch.cuda.amp import autocast
from torch.utils.data import Dataset
from tqdm import tqdm

from modelscope.metainfo import Trainers
from modelscope.models import Model, TorchModel
from modelscope.msdatasets import MsDataset
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, ModelFile
from modelscope.utils.device import create_device
from modelscope.utils.logger import get_logger
from modelscope.utils.torch_utils import (get_dist_info, get_local_rank,
                                          init_dist)

EVAL_KEY = 'si-snr'
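# Checkpoints are ranked by this key via min_key/min_keys below; the recipe's
# loss is the negative SI-SNR, so smaller values mean better separation.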

logger = get_logger()


@TRAINERS.register_module(module_name=Trainers.speech_separation)
class SeparationTrainer(BaseTrainer):
    """A trainer for speech separation.

    Args:
        model: id or local path of the model
        work_dir: local path to store all training outputs
        cfg_file: config file of the model
        train_dataset: dataset for training
        eval_dataset: dataset for evaluation
        model_revision: the git revision of the model on the model hub
    """

    def __init__(self,
                 model: str,
                 work_dir: str,
                 cfg_file: Optional[str] = None,
                 train_dataset: Optional[Union[MsDataset, Dataset]] = None,
                 eval_dataset: Optional[Union[MsDataset, Dataset]] = None,
                 model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
                 **kwargs):

        if isinstance(model, str):
            self.model_dir = self.get_or_download_model_dir(
                model, model_revision)
            if cfg_file is None:
                cfg_file = os.path.join(self.model_dir,
                                        ModelFile.CONFIGURATION)
        else:
            assert cfg_file is not None, 'Config file should not be None if model is not from pretrained!'
            self.model_dir = os.path.dirname(cfg_file)

        BaseTrainer.__init__(self, cfg_file)

        self.model = self.build_model()
        self.work_dir = work_dir
        if kwargs.get('launcher', None) is not None:
            init_dist(kwargs['launcher'])
        _, world_size = get_dist_info()
        self._dist = world_size > 1

        device_name = kwargs.get('device', 'gpu')
        if self._dist:
            local_rank = get_local_rank()
            device_name = f'cuda:{local_rank}'
        self.device = create_device(device_name)

        if 'max_epochs' not in kwargs:
            assert hasattr(
                self.cfg.train, 'max_epochs'
            ), 'max_epochs is missing from the configuration file'
            self._max_epochs = self.cfg.train.max_epochs
        else:
            self._max_epochs = kwargs['max_epochs']
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        hparams_file = os.path.join(self.model_dir, 'hparams.yaml')
        overrides = {
            'output_folder': self.work_dir,
            'seed': self.cfg.train.seed,
            'lr': self.cfg.train.optimizer.lr,
            'weight_decay': self.cfg.train.optimizer.weight_decay,
            'clip_grad_norm': self.cfg.train.optimizer.clip_grad_norm,
            'factor': self.cfg.train.lr_scheduler.factor,
            'patience': self.cfg.train.lr_scheduler.patience,
            'dont_halve_until_epoch':
            self.cfg.train.lr_scheduler.dont_halve_until_epoch,
        }
        # load hyper params
        from hyperpyyaml import load_hyperpyyaml
        with open(hparams_file) as fin:
            self.hparams = load_hyperpyyaml(fin, overrides=overrides)
        # Create experiment directory
        sb.create_experiment_directory(
            experiment_directory=self.work_dir,
            hyperparams_to_save=hparams_file,
            overrides=overrides,
        )

        run_opts = {
            'debug': False,
            'device': 'cpu',
            'data_parallel_backend': False,
            'distributed_launch': False,
            'distributed_backend': 'nccl',
            'find_unused_parameters': False
        }
        if self.device.type == 'cuda':
            run_opts['device'] = f'{self.device.type}:{self.device.index}'
        self.epoch_counter = sb.utils.epoch_loop.EpochCounter(self._max_epochs)
        self.hparams['checkpointer'].add_recoverables(
            {'counter': self.epoch_counter})
        modules = self.model.as_dict()
        self.hparams['checkpointer'].add_recoverables(modules)
        # Brain class initialization
        self.separator = Separation(
            modules=modules,
            opt_class=self.hparams['optimizer'],
            hparams=self.hparams,
            run_opts=run_opts,
            checkpointer=self.hparams['checkpointer'],
        )

    def build_model(self) -> torch.nn.Module:
        """Instantiate and return a pytorch model."""
        model = Model.from_pretrained(
            self.model_dir, cfg_dict=self.cfg, training=True)
        if isinstance(model, TorchModel) and hasattr(model, 'model'):
            return model.model
        elif isinstance(model, torch.nn.Module):
            return model

    def train(self, *args, **kwargs):
        self.separator.fit(
            self.epoch_counter,
            self.train_dataset,
            self.eval_dataset,
            train_loader_kwargs=self.hparams['dataloader_opts'],
            valid_loader_kwargs=self.hparams['dataloader_opts'],
        )

    def evaluate(self, checkpoint_path: str, *args,
                 **kwargs) -> Dict[str, float]:
        value = self.separator.evaluate(
            self.eval_dataset,
            test_loader_kwargs=self.hparams['dataloader_opts'],
            min_key=EVAL_KEY)
        return {EVAL_KEY: value}


class Separation(sb.Brain):
    """A subclass of speechbrain.Brain that implements the training steps."""

    def compute_forward(self, mix, targets, stage, noise=None):
        """Forward computations from the mixture to the separated signals."""

        # Unpack lists and put tensors on the right device
        mix, mix_lens = mix
        mix, mix_lens = mix.to(self.device), mix_lens.to(self.device)

        # Convert targets to tensor
        targets = torch.cat(
            [
                targets[i][0].unsqueeze(-1)
                for i in range(self.hparams.num_spks)
            ],
            dim=-1,
        ).to(self.device)

        # Add speech distortions
        if stage == sb.Stage.TRAIN:
            with torch.no_grad():
                if self.hparams.use_speedperturb or self.hparams.use_rand_shift:
                    mix, targets = self.add_speed_perturb(targets, mix_lens)

                    mix = targets.sum(-1)

                if self.hparams.use_wavedrop:
                    mix = self.hparams.wavedrop(mix, mix_lens)

                if self.hparams.limit_training_signal_len:
                    mix, targets = self.cut_signals(mix, targets)

        # Separation
        mix_w = self.modules['encoder'](mix)
        est_mask = self.modules['masknet'](mix_w)
        mix_w = torch.stack([mix_w] * self.hparams.num_spks)
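        # one copy of the encoded mixture per speaker, so each estimated
        # mask can be applied to its own copy below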
        sep_h = mix_w * est_mask

        # Decoding
        est_source = torch.cat(
            [
                self.modules['decoder'](sep_h[i]).unsqueeze(-1)
                for i in range(self.hparams.num_spks)
            ],
            dim=-1,
        )
        # T changed after conv1d in encoder, fix it here
        T_origin = mix.size(1)
        T_est = est_source.size(1)
        if T_origin > T_est:
            est_source = F.pad(est_source, (0, 0, 0, T_origin - T_est))
        else:
            est_source = est_source[:, :T_origin, :]

        return est_source, targets

    def compute_objectives(self, predictions, targets):
        """Computes the SI-SNR loss."""
        return self.hparams.loss(targets, predictions)

    # yapf: disable
    def fit_batch(self, batch):
        """Trains one batch."""
        # Unpack the batch list
        mixture = batch.mix_sig
        targets = [batch.s1_sig, batch.s2_sig]

        if self.hparams.num_spks == 3:
            targets.append(batch.s3_sig)

        if self.auto_mix_prec:
            with autocast():
                predictions, targets = self.compute_forward(
                    mixture, targets, sb.Stage.TRAIN)
                loss = self.compute_objectives(predictions, targets)

                # hard-threshold the easy data items
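                # (keep only items whose loss is still above the threshold;
                # already well-separated "easy" items no longer contribute)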
                if self.hparams.threshold_byloss:
                    th = self.hparams.threshold
                    loss_to_keep = loss[loss > th]
                    if loss_to_keep.nelement() > 0:
                        loss = loss_to_keep.mean()
                    else:
                        print('loss has zero elements!!')
                else:
                    loss = loss.mean()

            # the fix for computational problems
            if loss < self.hparams.loss_upper_lim and loss.nelement() > 0:
                self.scaler.scale(loss).backward()
                if self.hparams.clip_grad_norm >= 0:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.modules.parameters(),
                        self.hparams.clip_grad_norm,
                    )
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.nonfinite_count += 1
                logger.info(
                    'infinite loss or empty loss! it happened {} times so far - skipping this batch'
                    .format(self.nonfinite_count))
                loss.data = torch.tensor(0).to(self.device)
        else:
            predictions, targets = self.compute_forward(
                mixture, targets, sb.Stage.TRAIN)
            loss = self.compute_objectives(predictions, targets)
            if self.hparams.threshold_byloss:
                th = self.hparams.threshold
                loss_to_keep = loss[loss > th]
                if loss_to_keep.nelement() > 0:
                    loss = loss_to_keep.mean()
            else:
                loss = loss.mean()
            # the fix for computational problems
            if loss < self.hparams.loss_upper_lim and loss.nelement() > 0:
                loss.backward()
                if self.hparams.clip_grad_norm >= 0:
                    torch.nn.utils.clip_grad_norm_(self.modules.parameters(),
                                                   self.hparams.clip_grad_norm)
                self.optimizer.step()
            else:
                self.nonfinite_count += 1
                logger.info(
                    'infinite loss or empty loss! it happened {} times so far - skipping this batch'
                    .format(self.nonfinite_count))
                loss.data = torch.tensor(0).to(self.device)
        self.optimizer.zero_grad()
        return loss.detach().cpu()
    # yapf: enable

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches."""
        snt_id = batch.id
        mixture = batch.mix_sig
        targets = [batch.s1_sig, batch.s2_sig]
        if self.hparams.num_spks == 3:
            targets.append(batch.s3_sig)

        with torch.no_grad():
            predictions, targets = self.compute_forward(
                mixture, targets, stage)
            loss = self.compute_objectives(predictions, targets)

        # Manage audio file saving
        if stage == sb.Stage.TEST and self.hparams.save_audio:
            if hasattr(self.hparams, 'n_audio_to_save'):
                if self.hparams.n_audio_to_save > 0:
                    self.save_audio(snt_id[0], mixture, targets, predictions)
                    self.hparams.n_audio_to_save -= 1
            else:
                self.save_audio(snt_id[0], mixture, targets, predictions)

        return loss.mean().detach()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch."""
        # Compute/store important stats
        stage_stats = {'si-snr': stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats

        # Perform end-of-iteration things, like annealing, logging, etc.
        if stage == sb.Stage.VALID:

            # Learning rate annealing
            if isinstance(self.hparams.lr_scheduler,
                          schedulers.ReduceLROnPlateau):
                current_lr, next_lr = self.hparams.lr_scheduler(
                    [self.optimizer], epoch, stage_loss)
                schedulers.update_learning_rate(self.optimizer, next_lr)
            else:
                # if we do not use ReduceLROnPlateau, we do not change the lr
                current_lr = self.hparams.optimizer.optim.param_groups[0]['lr']

            self.hparams.train_logger.log_stats(
                stats_meta={
                    'epoch': epoch,
                    'lr': current_lr
                },
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={'si-snr': stage_stats['si-snr']},
                min_keys=['si-snr'],
            )
        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={
                    'Epoch loaded': self.hparams.epoch_counter.current
                },
                test_stats=stage_stats,
            )

    def add_speed_perturb(self, targets, targ_lens):
        """Adds speed perturbation and random shift to the input signals."""

        min_len = -1
        recombine = False
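
        # speed perturbation changes each source's length; min_len tracks the
        # shortest one so the sources can be trimmed and recombined below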
        if self.hparams.use_speedperturb:
            # Performing speed change (independently on each source)
            new_targets = []
            recombine = True

            for i in range(targets.shape[-1]):
                new_target = self.hparams.speedperturb(targets[:, :, i],
                                                       targ_lens)
                new_targets.append(new_target)
                if i == 0:
                    min_len = new_target.shape[-1]
                else:
                    if new_target.shape[-1] < min_len:
                        min_len = new_target.shape[-1]

        if self.hparams.use_rand_shift:
            # Performing random_shift (independently on each source)
            recombine = True
            for i in range(targets.shape[-1]):
                rand_shift = torch.randint(self.hparams.min_shift,
                                           self.hparams.max_shift, (1, ))
                new_targets[i] = new_targets[i].to(self.device)
                new_targets[i] = torch.roll(
                    new_targets[i], shifts=(rand_shift[0], ), dims=1)

        # Re-combination
        if recombine:
            if self.hparams.use_speedperturb:
                targets = torch.zeros(
                    targets.shape[0],
                    min_len,
                    targets.shape[-1],
                    device=targets.device,
                    dtype=torch.float,
                )
                for i, new_target in enumerate(new_targets):
                    targets[:, :, i] = new_targets[i][:, 0:min_len]

        mix = targets.sum(-1)
        return mix, targets

    def cut_signals(self, mixture, targets):
        """Selects a random segment of a given length within the mixture.
        The corresponding targets are selected accordingly."""
        randstart = torch.randint(
            0,
            1 + max(0, mixture.shape[1] - self.hparams.training_signal_len),
            (1, ),
        ).item()
        targets = targets[:, randstart:randstart
                          + self.hparams.training_signal_len, :]
        mixture = mixture[:, randstart:randstart
                          + self.hparams.training_signal_len]
        return mixture, targets

    def reset_layer_recursively(self, layer):
        """Reinitializes the parameters of the neural networks."""
        if hasattr(layer, 'reset_parameters'):
            layer.reset_parameters()
        for child_layer in layer.modules():
            if layer != child_layer:
                self.reset_layer_recursively(child_layer)

    def save_results(self, test_data):
        """Computes the SDR and SI-SNR metrics and saves them into a csv file."""

        # This package is required for SDR computation
        from mir_eval.separation import bss_eval_sources

        # Path of the csv file where results are stored
        save_file = os.path.join(self.hparams.output_folder,
                                 'test_results.csv')

        # Variable init
        all_sdrs = []
        all_sdrs_i = []
        all_sisnrs = []
        all_sisnrs_i = []
        csv_columns = ['snt_id', 'sdr', 'sdr_i', 'si-snr', 'si-snr_i']

        test_loader = sb.dataio.dataloader.make_dataloader(
            test_data, **self.hparams.dataloader_opts)

        with open(save_file, 'w') as results_csv:
            writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
            writer.writeheader()

            # Loop over all test sentences
            with tqdm(test_loader, dynamic_ncols=True) as t:
                for i, batch in enumerate(t):

                    # Apply separation
                    mixture, mix_len = batch.mix_sig
                    snt_id = batch.id
                    targets = [batch.s1_sig, batch.s2_sig]
                    if self.hparams.num_spks == 3:
                        targets.append(batch.s3_sig)

                    with torch.no_grad():
                        predictions, targets = self.compute_forward(
                            batch.mix_sig, targets, sb.Stage.TEST)

                    # Compute SI-SNR
                    sisnr = self.compute_objectives(predictions, targets)

                    # Compute SI-SNR improvement
                    mixture_signal = torch.stack(
                        [mixture] * self.hparams.num_spks, dim=-1)
                    mixture_signal = mixture_signal.to(targets.device)
                    sisnr_baseline = self.compute_objectives(
                        mixture_signal, targets)
                    sisnr_i = sisnr.mean() - sisnr_baseline.mean()

                    # Compute SDR
                    sdr, _, _, _ = bss_eval_sources(
                        targets[0].t().cpu().numpy(),
                        predictions[0].t().detach().cpu().numpy(),
                    )

                    sdr_baseline, _, _, _ = bss_eval_sources(
                        targets[0].t().cpu().numpy(),
                        mixture_signal[0].t().detach().cpu().numpy(),
                    )

                    sdr_i = sdr.mean() - sdr_baseline.mean()

                    # Save the row to the csv file
                    row = {
                        'snt_id': snt_id[0],
                        'sdr': sdr.mean(),
                        'sdr_i': sdr_i,
                        'si-snr': -sisnr.item(),
                        'si-snr_i': -sisnr_i.item(),
                    }
                    writer.writerow(row)

                    # Metric accumulation
                    all_sdrs.append(sdr.mean())
                    all_sdrs_i.append(sdr_i.mean())
                    all_sisnrs.append(-sisnr.item())
                    all_sisnrs_i.append(-sisnr_i.item())

            row = {
                'snt_id': 'avg',
                'sdr': np.array(all_sdrs).mean(),
                'sdr_i': np.array(all_sdrs_i).mean(),
                'si-snr': np.array(all_sisnrs).mean(),
                'si-snr_i': np.array(all_sisnrs_i).mean(),
            }
            writer.writerow(row)

        logger.info('Mean SISNR is {}'.format(np.array(all_sisnrs).mean()))
        logger.info('Mean SISNRi is {}'.format(np.array(all_sisnrs_i).mean()))
        logger.info('Mean SDR is {}'.format(np.array(all_sdrs).mean()))
        logger.info('Mean SDRi is {}'.format(np.array(all_sdrs_i).mean()))

    def save_audio(self, snt_id, mixture, targets, predictions):
        """Saves the test audio (mixture, targets, and estimated sources) on disk."""

        # Create output folder
        save_path = os.path.join(self.hparams.save_folder, 'audio_results')
        if not os.path.exists(save_path):
            os.mkdir(save_path)

        for ns in range(self.hparams.num_spks):

            # Estimated source
            signal = predictions[0, :, ns]
            signal = signal / signal.abs().max() * 0.5
            save_file = os.path.join(
                save_path, 'item{}_source{}hat.wav'.format(snt_id, ns + 1))
            torchaudio.save(save_file,
                            signal.unsqueeze(0).cpu(),
                            self.hparams.sample_rate)

            # Original source
            signal = targets[0, :, ns]
            signal = signal / signal.abs().max() * 0.5
            save_file = os.path.join(
                save_path, 'item{}_source{}.wav'.format(snt_id, ns + 1))
            torchaudio.save(save_file,
                            signal.unsqueeze(0).cpu(),
                            self.hparams.sample_rate)

        # Mixture
        signal = mixture[0][0, :]
        signal = signal / signal.abs().max() * 0.5
        save_file = os.path.join(save_path, 'item{}_mix.wav'.format(snt_id))
        torchaudio.save(save_file,
                        signal.unsqueeze(0).cpu(), self.hparams.sample_rate)
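
End of the new trainer. A minimal sketch of driving it end to end, mirroring the unit test at the bottom of this diff; 'train.csv' is a hypothetical file with the header id,mix_wav:FILE,s1_wav:FILE,s2_wav:FILE as built in that test:

    import tempfile

    from modelscope.metainfo import Trainers
    from modelscope.msdatasets import MsDataset
    from modelscope.preprocessors.audio import AudioBrainPreprocessor
    from modelscope.trainers import build_trainer

    dataset = MsDataset.load(
        'csv', data_files={'test': ['train.csv']}).to_torch_dataset(
            preprocessors=[
                AudioBrainPreprocessor(takes='mix_wav:FILE', provides='mix_sig'),
                AudioBrainPreprocessor(takes='s1_wav:FILE', provides='s1_sig'),
                AudioBrainPreprocessor(takes='s2_wav:FILE', provides='s2_sig'),
            ])
    trainer = build_trainer(
        Trainers.speech_separation,
        default_args=dict(
            model='damo/speech_mossformer_separation_temporal_8k',
            train_dataset=dataset,
            eval_dataset=dataset,
            max_epochs=2,
            work_dir=tempfile.mkdtemp()))
    trainer.model.load_check_point(device=trainer.device)  # place pretrained weights
    trainer.train()
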
@@ -13,6 +13,7 @@ librosa
 lxml
 matplotlib
 MinDAEC
+mir_eval>=0.7
 msgpack>=1.0.4
 nara_wpe
 nltk
@@ -3,6 +3,8 @@
 import os.path
 import unittest
 
+import numpy
+
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks

@@ -19,7 +21,7 @@ class SpeechSeparationTest(unittest.TestCase, DemoCompatibilityCheck):
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_normal(self):
-        import torchaudio
+        import soundfile as sf
         model_id = 'damo/speech_mossformer_separation_temporal_8k'
         separation = pipeline(Tasks.speech_separation, model=model_id)
         result = separation(os.path.join(os.getcwd(), MIX_SPEECH_FILE))

@@ -27,8 +29,8 @@ class SpeechSeparationTest(unittest.TestCase, DemoCompatibilityCheck):
         self.assertEqual(len(result[OutputKeys.OUTPUT_PCM_LIST]), 2)
         for i, signal in enumerate(result[OutputKeys.OUTPUT_PCM_LIST]):
             save_file = f'output_spk{i}.wav'
-            # Estimated source
-            torchaudio.save(save_file, signal, 8000)
+            sf.write(save_file, numpy.frombuffer(signal, dtype=numpy.int16),
+                     8000)
 
     @unittest.skip('demo compatibility test is only enabled on a needed-basis')
     def test_demo_compatibility(self):
tests/trainers/audio/test_separation_trainer.py (new file, 74 lines)
@@ -0,0 +1,74 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import shutil
import tempfile
import unittest

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.preprocessors.audio import AudioBrainPreprocessor
from modelscope.trainers import build_trainer
from modelscope.utils.test_utils import test_level

MIX_SPEECH_FILE = 'data/test/audios/mix_speech.wav'
S1_SPEECH_FILE = 'data/test/audios/s1_speech.wav'
S2_SPEECH_FILE = 'data/test/audios/s2_speech.wav'


class TestSeparationTrainer(unittest.TestCase):

    def setUp(self):
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

        self.model_id = 'damo/speech_mossformer_separation_temporal_8k'

        csv_path = os.path.join(self.tmp_dir, 'test.csv')
        mix_path = os.path.join(os.getcwd(), MIX_SPEECH_FILE)
        s1_path = os.path.join(os.getcwd(), S1_SPEECH_FILE)
        s2_path = os.path.join(os.getcwd(), S2_SPEECH_FILE)
        with open(csv_path, 'w') as w:
            w.write(f'id,mix_wav:FILE,s1_wav:FILE,s2_wav:FILE\n'
                    f'0,{mix_path},{s1_path},{s2_path}\n')
        self.dataset = MsDataset.load(
            'csv', data_files={
                'test': [csv_path]
            }).to_torch_dataset(preprocessors=[
                AudioBrainPreprocessor(
                    takes='mix_wav:FILE', provides='mix_sig'),
                AudioBrainPreprocessor(takes='s1_wav:FILE', provides='s1_sig'),
                AudioBrainPreprocessor(takes='s2_wav:FILE', provides='s2_sig')
            ])

    def tearDown(self):
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        kwargs = dict(
            model=self.model_id,
            train_dataset=self.dataset,
            eval_dataset=self.dataset,
            max_epochs=2,
            work_dir=self.tmp_dir)
        trainer = build_trainer(
            Trainers.speech_separation, default_args=kwargs)
        # model placement
        trainer.model.load_check_point(device=trainer.device)
        trainer.train()

        logging_path = os.path.join(self.tmp_dir, 'train_log.txt')
        self.assertTrue(
            os.path.exists(logging_path),
            f'Cannot find logging file {logging_path}')
        save_dir = os.path.join(self.tmp_dir, 'save')
        checkpoint_dirs = os.listdir(save_dir)
        self.assertEqual(
            len(checkpoint_dirs), 2, f'Cannot find checkpoint in {save_dir}!')


if __name__ == '__main__':
    unittest.main()
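
The test exercises train() only; the trainer's evaluate() goes through the same SpeechBrain object and restores the best checkpoint by the 'si-snr' key. A sketch of calling it on the trainer built above (the checkpoint_path argument is unused by this trainer):

    result = trainer.evaluate(None)
    print(result)  # {'si-snr': <negative SI-SNR>}; lower is better
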