Files
modelscope/tests/trainers/hooks/test_checkpoint_hook.py
wenmeng.zwm 231f400133 [to #43112534] finetune support and first case
co-contributed with 夕陌&雨泓

 * add torch epoch based trainer and dist utils
 * add hooks including optimizer, lrscheduler, logging, checkpoint, evaluation, time profiling
 * add torch model base and test
 * add optimizer and lrscheduler module
 * add sbert for text classification example
 * add task_dataset for dataset-level processor

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9338412
2022-07-14 16:25:55 +08:00

109 lines
2.9 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
from abc import ABCMeta
import json
import torch
from torch import nn
from torch.utils.data import Dataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import ModelFile
class DummyDataset(Dataset, metaclass=ABCMeta):
    """Synthetic dataset yielding random feature/label pairs for trainer tests."""

    def __len__(self):
        # Fixed small size keeps each training epoch short.
        return 20

    def __getitem__(self, idx):
        # A 5-dim random float feature plus one integer class label in [0, 4).
        sample = {
            'feat': torch.rand((5, )),
            'label': torch.randint(0, 4, (1, )),
        }
        return sample
class DummyModel(nn.Module):
    """Tiny linear + batch-norm model producing a dummy loss for trainer tests."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(5, 4)
        self.bn = nn.BatchNorm1d(4)

    def forward(self, feat, labels):
        # `labels` is accepted to match the trainer's calling convention,
        # but this dummy model never reads it.
        hidden = self.bn(self.linear(feat))
        # The raw sum of activations stands in for a real loss value so the
        # training loop has something to backpropagate.
        return dict(logits=hidden, loss=torch.sum(hidden))
class CheckpointHookTest(unittest.TestCase):
    """Verify that CheckpointHook writes one checkpoint file per epoch."""

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        # Use mkdtemp() instead of `tempfile.TemporaryDirectory().name`: the
        # latter drops the TemporaryDirectory object immediately, so its
        # finalizer may delete the directory at an arbitrary GC point in the
        # middle of the test. mkdtemp() creates the directory with no
        # finalizer attached; tearDown removes it explicitly.
        self.tmp_dir = tempfile.mkdtemp()

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir)

    def test_checkpoint_hook(self):
        # Minimal trainer config: SGD with gradient clipping, StepLR with a
        # linear warmup, and a CheckpointHook that saves every epoch.
        json_cfg = {
            'task': 'image_classification',
            'train': {
                'work_dir': self.tmp_dir,
                'dataloader': {
                    'batch_size_per_gpu': 2,
                    'workers_per_gpu': 1
                },
                'optimizer': {
                    'type': 'SGD',
                    'lr': 0.01,
                    'options': {
                        'grad_clip': {
                            'max_norm': 2.0
                        }
                    }
                },
                'lr_scheduler': {
                    'type': 'StepLR',
                    'step_size': 2,
                    'options': {
                        'warmup': {
                            'type': 'LinearWarmup',
                            'warmup_iters': 2
                        }
                    }
                },
                'hooks': [{
                    'type': 'CheckpointHook',
                    'interval': 1
                }]
            }
        }

        config_path = os.path.join(self.tmp_dir, ModelFile.CONFIGURATION)
        with open(config_path, 'w') as f:
            json.dump(json_cfg, f)

        trainer_name = 'EpochBasedTrainer'
        kwargs = dict(
            cfg_file=config_path,
            model=DummyModel(),
            data_collator=None,
            train_dataset=DummyDataset(),
            max_epochs=2)
        trainer = build_trainer(trainer_name, kwargs)
        trainer.train()

        # interval=1 with max_epochs=2 must yield exactly these two files.
        results_files = os.listdir(self.tmp_dir)
        self.assertIn('epoch_1.pth', results_files)
        self.assertIn('epoch_2.pth', results_files)
# Allow running this test module directly from the command line.
if __name__ == '__main__':
    unittest.main()