From 854c1e6cbf037b89e04aa324bc7109facfdd21c1 Mon Sep 17 00:00:00 2001 From: "bin.xue" Date: Fri, 13 Jan 2023 09:19:31 +0000 Subject: [PATCH] [to #42322933] bugfix: separation.evaluate() failed Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11426908 --- .../trainers/audio/separation_trainer.py | 13 ++++----- .../trainers/audio/test_separation_trainer.py | 29 +++++++++++++++---- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/modelscope/trainers/audio/separation_trainer.py b/modelscope/trainers/audio/separation_trainer.py index c89f479a..c425325c 100644 --- a/modelscope/trainers/audio/separation_trainer.py +++ b/modelscope/trainers/audio/separation_trainer.py @@ -128,6 +128,7 @@ class SeparationTrainer(BaseTrainer): if self.device.type == 'cuda': run_opts['device'] = f'{self.device.type}:{self.device.index}' self.epoch_counter = sb.utils.epoch_loop.EpochCounter(self._max_epochs) + self.hparams['epoch_counter'] = self.epoch_counter self.hparams['checkpointer'].add_recoverables( {'counter': self.epoch_counter}) modules = self.model.as_dict() @@ -162,6 +163,10 @@ class SeparationTrainer(BaseTrainer): def evaluate(self, checkpoint_path: str, *args, **kwargs) -> Dict[str, float]: + if checkpoint_path: + self.hparams.checkpointer.checkpoints_dir = checkpoint_path + else: + self.model.load_check_point(device=self.device) value = self.separator.evaluate( self.eval_dataset, test_loader_kwargs=self.hparams['dataloader_opts'], @@ -334,7 +339,6 @@ class Separation(sb.Brain): # Perform end-of-iteration things, like annealing, logging, etc. if stage == sb.Stage.VALID: - # Learning rate annealing if isinstance(self.hparams.lr_scheduler, schedulers.ReduceLROnPlateau): @@ -357,13 +361,6 @@ class Separation(sb.Brain): meta={'si-snr': stage_stats['si-snr']}, min_keys=['si-snr'], ) - elif stage == sb.Stage.TEST: - self.hparams.train_logger.log_stats( - stats_meta={ - 'Epoch loaded': self.hparams.epoch_counter.current - }, - test_stats=stage_stats, - ) def add_speed_perturb(self, targets, targ_lens): """Adds speed perturbation and random_shift to the input signals""" diff --git a/tests/trainers/audio/test_separation_trainer.py b/tests/trainers/audio/test_separation_trainer.py index 1023e805..4fdbab18 100644 --- a/tests/trainers/audio/test_separation_trainer.py +++ b/tests/trainers/audio/test_separation_trainer.py @@ -35,12 +35,16 @@ class TestSeparationTrainer(unittest.TestCase): self.dataset = MsDataset.load( 'csv', data_files={ 'test': [csv_path] - }).to_torch_dataset(preprocessors=[ - AudioBrainPreprocessor( - takes='mix_wav:FILE', provides='mix_sig'), - AudioBrainPreprocessor(takes='s1_wav:FILE', provides='s1_sig'), - AudioBrainPreprocessor(takes='s2_wav:FILE', provides='s2_sig') - ]) + }).to_torch_dataset( + preprocessors=[ + AudioBrainPreprocessor( + takes='mix_wav:FILE', provides='mix_sig'), + AudioBrainPreprocessor( + takes='s1_wav:FILE', provides='s1_sig'), + AudioBrainPreprocessor( + takes='s2_wav:FILE', provides='s2_sig') + ], + to_tensor=False) def tearDown(self): shutil.rmtree(self.tmp_dir, ignore_errors=True) @@ -69,6 +73,19 @@ class TestSeparationTrainer(unittest.TestCase): self.assertEqual( len(checkpoint_dirs), 2, f'Cannot find checkpoint in {save_dir}!') + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_eval(self): + kwargs = dict( + model=self.model_id, + train_dataset=None, + eval_dataset=self.dataset, + max_epochs=2, + work_dir=self.tmp_dir) + trainer = build_trainer( + Trainers.speech_separation, default_args=kwargs) + result = trainer.evaluate(None) + self.assertTrue('si-snr' in result) + if __name__ == '__main__': unittest.main()