From 7298bd2bb450f6be85346b9b7076c7e67f3892de Mon Sep 17 00:00:00 2001 From: "ada.drx" Date: Tue, 7 Feb 2023 02:55:33 +0000 Subject: [PATCH] mgeo fix finetune for rerank test case and reduce UT time * reduce UT time * fix finetune for rerank test case Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11563740 --- .../multi_modal/mgeo_ranking_trainer.py | 2 ++ tests/trainers/test_finetune_mgeo.py | 19 +++++++++---------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/modelscope/trainers/multi_modal/mgeo_ranking_trainer.py b/modelscope/trainers/multi_modal/mgeo_ranking_trainer.py index 6079a8a8..16a87bb1 100644 --- a/modelscope/trainers/multi_modal/mgeo_ranking_trainer.py +++ b/modelscope/trainers/multi_modal/mgeo_ranking_trainer.py @@ -290,6 +290,8 @@ class MGeoRankingTrainer(NlpEpochBasedTrainer): label_list.extend(label_ids) logits_list.extend(logits) qid_list.extend(qids) + if _step + 1 > self._eval_iters_per_epoch: + break logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format( total_spent_time, total_spent_time * 1000 / total_samples)) diff --git a/tests/trainers/test_finetune_mgeo.py b/tests/trainers/test_finetune_mgeo.py index b492497b..02a081b9 100644 --- a/tests/trainers/test_finetune_mgeo.py +++ b/tests/trainers/test_finetune_mgeo.py @@ -50,7 +50,7 @@ class TestFinetuneMGeo(unittest.TestCase): results_files = os.listdir(self.tmp_dir) self.assertIn(f'{trainer.timestamp}.log.json', results_files) - @unittest.skipUnless(test_level() >= 4, 'skip test in current test level') + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_finetune_geotes_rerank(self): def cfg_modify_fn(cfg): @@ -82,7 +82,8 @@ class TestFinetuneMGeo(unittest.TestCase): cfg.train.dataloader.batch_size_per_gpu = 3 cfg.train.dataloader.workers_per_gpu = 16 cfg.evaluation.dataloader.workers_per_gpu = 16 - + cfg.train.train_iters_per_epoch = 10 + cfg.evaluation.val_iters_per_epoch = 10 cfg['evaluation']['metrics'] = 'mrr@1' cfg.train.max_epochs = 1 cfg.model['neg_sample'] = neg_sample @@ -139,14 +140,8 @@ class TestFinetuneMGeo(unittest.TestCase): split='validation', namespace='damo') - dataset = MsDataset.load( - 'json', - data_files={ - 'train': [train_dataset['train'] + '/train.json'], - 'test': [dev_dataset['validation'] + '/dev.json'] - }) - train_ds = dataset['train'].to_hf_dataset() - dev_ds = dataset['test'].to_hf_dataset() + train_ds = train_dataset['train'] + dev_ds = dev_dataset['validation'] model_id = 'damo/mgeo_backbone_chinese_base' self.finetune( @@ -170,6 +165,8 @@ class TestFinetuneMGeo(unittest.TestCase): cfg.evaluation.dataloader.batch_size_per_gpu = 64 cfg.train.optimizer.lr = 2e-5 cfg.train.max_epochs = 1 + cfg.train.train_iters_per_epoch = 10 + cfg.evaluation.val_iters_per_epoch = 10 cfg['dataset'] = { 'train': { @@ -236,6 +233,8 @@ class TestFinetuneMGeo(unittest.TestCase): } cfg.train.max_epochs = 1 cfg.train.dataloader.batch_size_per_gpu = 32 + cfg.train.train_iters_per_epoch = 10 + cfg.evaluation.val_iters_per_epoch = 10 cfg.train.optimizer.lr = 3e-5 cfg.train.hooks = [{ 'type': 'CheckpointHook',