Mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-25 04:29:22 +01:00
[to #42322933] bugfix: separation.evaluate() failed
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11426908
@@ -128,6 +128,7 @@ class SeparationTrainer(BaseTrainer):
         if self.device.type == 'cuda':
             run_opts['device'] = f'{self.device.type}:{self.device.index}'
         self.epoch_counter = sb.utils.epoch_loop.EpochCounter(self._max_epochs)
+        self.hparams['epoch_counter'] = self.epoch_counter
         self.hparams['checkpointer'].add_recoverables(
             {'counter': self.epoch_counter})
         modules = self.model.as_dict()
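For context, a minimal sketch of the SpeechBrain checkpointing pattern this hunk relies on (an illustration, not code from this repository; the save path and epoch limit are placeholders). Registering the counter as a recoverable lets a restarted run resume from the saved epoch, and storing it in hparams makes it reachable from the Separation brain:

    import speechbrain as sb

    counter = sb.utils.epoch_loop.EpochCounter(limit=10)
    checkpointer = sb.utils.checkpoints.Checkpointer(
        './save', recoverables={'counter': counter})
    checkpointer.recover_if_possible()  # restores counter.current from the newest checkpoint, if any
    for epoch in counter:               # resumes from the recovered epoch
        pass                            # train / validate, then save a checkpoint via the checkpointer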
@@ -162,6 +163,10 @@ class SeparationTrainer(BaseTrainer):
 
     def evaluate(self, checkpoint_path: str, *args,
                  **kwargs) -> Dict[str, float]:
+        if checkpoint_path:
+            self.hparams.checkpointer.checkpoints_dir = checkpoint_path
+        else:
+            self.model.load_check_point(device=self.device)
         value = self.separator.evaluate(
             self.eval_dataset,
             test_loader_kwargs=self.hparams['dataloader_opts'],
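A minimal sketch of how the fixed evaluate() is meant to be driven (assumptions: the trainer is built exactly as in the test further down; the model id and checkpoint directory below are placeholders, not values from this change):

    from modelscope.metainfo import Trainers
    from modelscope.trainers import build_trainer

    kwargs = dict(
        model='<speech-separation-model-id>',  # placeholder model id
        train_dataset=None,
        eval_dataset=eval_dataset,             # an MsDataset prepared as in the test case
        max_epochs=2,
        work_dir='./work_dir')
    trainer = build_trainer(Trainers.speech_separation, default_args=kwargs)

    # Pass a checkpoint directory to point the SpeechBrain checkpointer at it ...
    metrics = trainer.evaluate('./work_dir/save')  # placeholder path
    # ... or pass None to load the pretrained weights via model.load_check_point().
    metrics = trainer.evaluate(None)
    print(metrics['si-snr'])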
@@ -334,7 +339,6 @@ class Separation(sb.Brain):
 
         # Perform end-of-iteration things, like annealing, logging, etc.
         if stage == sb.Stage.VALID:
-
            # Learning rate annealing
            if isinstance(self.hparams.lr_scheduler,
                          schedulers.ReduceLROnPlateau):
@@ -357,13 +361,6 @@ class Separation(sb.Brain):
                 meta={'si-snr': stage_stats['si-snr']},
                 min_keys=['si-snr'],
             )
-        elif stage == sb.Stage.TEST:
-            self.hparams.train_logger.log_stats(
-                stats_meta={
-                    'Epoch loaded': self.hparams.epoch_counter.current
-                },
-                test_stats=stage_stats,
-            )
 
     def add_speed_perturb(self, targets, targ_lens):
         """Adds speed perturbation and random_shift to the input signals"""
@@ -35,12 +35,16 @@ class TestSeparationTrainer(unittest.TestCase):
         self.dataset = MsDataset.load(
             'csv', data_files={
                 'test': [csv_path]
-            }).to_torch_dataset(preprocessors=[
-                AudioBrainPreprocessor(
-                    takes='mix_wav:FILE', provides='mix_sig'),
-                AudioBrainPreprocessor(takes='s1_wav:FILE', provides='s1_sig'),
-                AudioBrainPreprocessor(takes='s2_wav:FILE', provides='s2_sig')
-            ])
+            }).to_torch_dataset(
+                preprocessors=[
+                    AudioBrainPreprocessor(
+                        takes='mix_wav:FILE', provides='mix_sig'),
+                    AudioBrainPreprocessor(
+                        takes='s1_wav:FILE', provides='s1_sig'),
+                    AudioBrainPreprocessor(
+                        takes='s2_wav:FILE', provides='s2_sig')
+                ],
+                to_tensor=False)
 
     def tearDown(self):
         shutil.rmtree(self.tmp_dir, ignore_errors=True)
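For reference, a sketch of the CSV layout the loader above expects (the column names come from the takes= fields in the diff; the file paths are placeholders):

    import csv

    rows = [{
        'mix_wav:FILE': 'data/mix/item0.wav',
        's1_wav:FILE': 'data/s1/item0.wav',
        's2_wav:FILE': 'data/s2/item0.wav',
    }]
    with open('test.csv', 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0]))
        writer.writeheader()
        writer.writerows(rows)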
@@ -69,6 +73,19 @@ class TestSeparationTrainer(unittest.TestCase):
         self.assertEqual(
             len(checkpoint_dirs), 2, f'Cannot find checkpoint in {save_dir}!')
 
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_eval(self):
+        kwargs = dict(
+            model=self.model_id,
+            train_dataset=None,
+            eval_dataset=self.dataset,
+            max_epochs=2,
+            work_dir=self.tmp_dir)
+        trainer = build_trainer(
+            Trainers.speech_separation, default_args=kwargs)
+        result = trainer.evaluate(None)
+        self.assertTrue('si-snr' in result)
+
 
 if __name__ == '__main__':
     unittest.main()
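To run just the new test case, something like the following works (the dotted module path is an assumption, adjust it to wherever this test file lives in the repository):

    import unittest

    suite = unittest.defaultTestLoader.loadTestsFromName(
        'tests.trainers.audio.test_separation_trainer.TestSeparationTrainer.test_eval')
    unittest.TextTestRunner(verbosity=2).run(suite)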