Mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-25 04:29:22 +01:00
[to #42322933] Reduce the GPU usage of the dialog trainer
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10955485
@@ -17,7 +17,7 @@ class TestDialogModelingTrainer(unittest.TestCase):
     model_id = 'damo/nlp_space_pretrained-dialog-model'
     output_dir = './dialog_fintune_result'

-    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer_with_model_and_args(self):
         # download data set
         data_multiwoz = MsDataset.load(
@@ -33,13 +33,13 @@ class TestDialogModelingTrainer(unittest.TestCase):
         def cfg_modify_fn(cfg):
             config = {
                 'seed': 10,
-                'gpu': 4,
+                'gpu': 1,
                 'use_data_distributed': False,
                 'valid_metric_name': '-loss',
                 'num_epochs': 60,
                 'save_dir': self.output_dir,
                 'token_loss': True,
-                'batch_size': 32,
+                'batch_size': 4,
                 'log_steps': 10,
                 'valid_steps': 0,
                 'save_checkpoint': True,
@@ -71,3 +71,7 @@ class TestDialogModelingTrainer(unittest.TestCase):
         assert os.path.exists(checkpoint_path)
         trainer.evaluate(checkpoint_path=checkpoint_path)
         """
+
+
+if __name__ == '__main__':
+    unittest.main()
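For context, the sketch below illustrates the general pattern this test relies on: a cfg_modify_fn that overrides training settings is handed to a ModelScope trainer, which is how the single-GPU, batch-size-4 values in the diff take effect. This is a minimal illustration only, not code from this commit; the trainer name string, the config fields touched, and the exact keyword arguments are assumptions.

# Minimal sketch of the cfg_modify_fn pattern used with ModelScope trainers.
# The trainer name string, the config fields, and the kwargs below are
# assumptions for illustration; the real test builds a much larger config
# dict (see the diff above) and passes its own arguments.
from modelscope.trainers import build_trainer


def cfg_modify_fn(cfg):
    # Hypothetical overrides mirroring the intent of this commit:
    # train on a single GPU with a small batch size.
    cfg.merge_from_dict({'gpu': 1, 'batch_size': 4})  # assumed field names
    return cfg


kwargs = dict(
    model='damo/nlp_space_pretrained-dialog-model',
    work_dir='./dialog_fintune_result',
    cfg_modify_fn=cfg_modify_fn,
)
# 'dialog-modeling-trainer' is an assumed registry name; check
# modelscope.metainfo.Trainers for the actual registered identifier.
trainer = build_trainer(name='dialog-modeling-trainer', default_args=kwargs)
trainer.train()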