mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 20:19:22 +01:00
mgeo fix finetune for rerank test case and reduce UT time
* reduce UT time * fix finetune for rerank test case Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11563740
This commit is contained in:
@@ -290,6 +290,8 @@ class MGeoRankingTrainer(NlpEpochBasedTrainer):
|
||||
label_list.extend(label_ids)
|
||||
logits_list.extend(logits)
|
||||
qid_list.extend(qids)
|
||||
if _step + 1 > self._eval_iters_per_epoch:
|
||||
break
|
||||
|
||||
logger.info('Inference time = {:.2f}s, [{:.4f} ms / sample] '.format(
|
||||
total_spent_time, total_spent_time * 1000 / total_samples))
|
||||
|
||||
@@ -50,7 +50,7 @@ class TestFinetuneMGeo(unittest.TestCase):
|
||||
results_files = os.listdir(self.tmp_dir)
|
||||
self.assertIn(f'{trainer.timestamp}.log.json', results_files)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 4, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_finetune_geotes_rerank(self):
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
@@ -82,7 +82,8 @@ class TestFinetuneMGeo(unittest.TestCase):
|
||||
cfg.train.dataloader.batch_size_per_gpu = 3
|
||||
cfg.train.dataloader.workers_per_gpu = 16
|
||||
cfg.evaluation.dataloader.workers_per_gpu = 16
|
||||
|
||||
cfg.train.train_iters_per_epoch = 10
|
||||
cfg.evaluation.val_iters_per_epoch = 10
|
||||
cfg['evaluation']['metrics'] = 'mrr@1'
|
||||
cfg.train.max_epochs = 1
|
||||
cfg.model['neg_sample'] = neg_sample
|
||||
@@ -139,14 +140,8 @@ class TestFinetuneMGeo(unittest.TestCase):
|
||||
split='validation',
|
||||
namespace='damo')
|
||||
|
||||
dataset = MsDataset.load(
|
||||
'json',
|
||||
data_files={
|
||||
'train': [train_dataset['train'] + '/train.json'],
|
||||
'test': [dev_dataset['validation'] + '/dev.json']
|
||||
})
|
||||
train_ds = dataset['train'].to_hf_dataset()
|
||||
dev_ds = dataset['test'].to_hf_dataset()
|
||||
train_ds = train_dataset['train']
|
||||
dev_ds = dev_dataset['validation']
|
||||
|
||||
model_id = 'damo/mgeo_backbone_chinese_base'
|
||||
self.finetune(
|
||||
@@ -170,6 +165,8 @@ class TestFinetuneMGeo(unittest.TestCase):
|
||||
cfg.evaluation.dataloader.batch_size_per_gpu = 64
|
||||
cfg.train.optimizer.lr = 2e-5
|
||||
cfg.train.max_epochs = 1
|
||||
cfg.train.train_iters_per_epoch = 10
|
||||
cfg.evaluation.val_iters_per_epoch = 10
|
||||
|
||||
cfg['dataset'] = {
|
||||
'train': {
|
||||
@@ -236,6 +233,8 @@ class TestFinetuneMGeo(unittest.TestCase):
|
||||
}
|
||||
cfg.train.max_epochs = 1
|
||||
cfg.train.dataloader.batch_size_per_gpu = 32
|
||||
cfg.train.train_iters_per_epoch = 10
|
||||
cfg.evaluation.val_iters_per_epoch = 10
|
||||
cfg.train.optimizer.lr = 3e-5
|
||||
cfg.train.hooks = [{
|
||||
'type': 'CheckpointHook',
|
||||
|
||||
Reference in New Issue
Block a user