[to #42322933] bugfix: separation.evaluate() failed

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11426908
This commit is contained in:
bin.xue
2023-01-13 09:19:31 +00:00
committed by wenmeng.zwm
parent 241c675f60
commit 854c1e6cbf
2 changed files with 28 additions and 14 deletions

View File

@@ -128,6 +128,7 @@ class SeparationTrainer(BaseTrainer):
if self.device.type == 'cuda':
run_opts['device'] = f'{self.device.type}:{self.device.index}'
self.epoch_counter = sb.utils.epoch_loop.EpochCounter(self._max_epochs)
self.hparams['epoch_counter'] = self.epoch_counter
self.hparams['checkpointer'].add_recoverables(
{'counter': self.epoch_counter})
modules = self.model.as_dict()
@@ -162,6 +163,10 @@ class SeparationTrainer(BaseTrainer):
def evaluate(self, checkpoint_path: str, *args,
**kwargs) -> Dict[str, float]:
if checkpoint_path:
self.hparams.checkpointer.checkpoints_dir = checkpoint_path
else:
self.model.load_check_point(device=self.device)
value = self.separator.evaluate(
self.eval_dataset,
test_loader_kwargs=self.hparams['dataloader_opts'],
@@ -334,7 +339,6 @@ class Separation(sb.Brain):
# Perform end-of-iteration things, like annealing, logging, etc.
if stage == sb.Stage.VALID:
# Learning rate annealing
if isinstance(self.hparams.lr_scheduler,
schedulers.ReduceLROnPlateau):
@@ -357,13 +361,6 @@ class Separation(sb.Brain):
meta={'si-snr': stage_stats['si-snr']},
min_keys=['si-snr'],
)
elif stage == sb.Stage.TEST:
self.hparams.train_logger.log_stats(
stats_meta={
'Epoch loaded': self.hparams.epoch_counter.current
},
test_stats=stage_stats,
)
def add_speed_perturb(self, targets, targ_lens):
"""Adds speed perturbation and random_shift to the input signals"""

View File

@@ -35,12 +35,16 @@ class TestSeparationTrainer(unittest.TestCase):
self.dataset = MsDataset.load(
'csv', data_files={
'test': [csv_path]
}).to_torch_dataset(preprocessors=[
AudioBrainPreprocessor(
takes='mix_wav:FILE', provides='mix_sig'),
AudioBrainPreprocessor(takes='s1_wav:FILE', provides='s1_sig'),
AudioBrainPreprocessor(takes='s2_wav:FILE', provides='s2_sig')
])
}).to_torch_dataset(
preprocessors=[
AudioBrainPreprocessor(
takes='mix_wav:FILE', provides='mix_sig'),
AudioBrainPreprocessor(
takes='s1_wav:FILE', provides='s1_sig'),
AudioBrainPreprocessor(
takes='s2_wav:FILE', provides='s2_sig')
],
to_tensor=False)
def tearDown(self):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
@@ -69,6 +73,19 @@ class TestSeparationTrainer(unittest.TestCase):
self.assertEqual(
len(checkpoint_dirs), 2, f'Cannot find checkpoint in {save_dir}!')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_eval(self):
kwargs = dict(
model=self.model_id,
train_dataset=None,
eval_dataset=self.dataset,
max_epochs=2,
work_dir=self.tmp_dir)
trainer = build_trainer(
Trainers.speech_separation, default_args=kwargs)
result = trainer.evaluate(None)
self.assertTrue('si-snr' in result)
if __name__ == '__main__':
unittest.main()