Files
modelscope/tests/trainers/test_tinynas_damoyolo_trainer.py
2026-03-07 22:40:43 +08:00

150 lines
5.3 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import subprocess
import sys
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.import_utils import exists
from modelscope.utils.test_utils import test_level
def _setup():
    """Fetch the DAMO-YOLO model snapshot used by the trainer tests.

    Returns:
        The local directory returned by ``snapshot_download`` for
        ``damo/cv_tinynas_object-detection_damoyolo``.
    """
    return snapshot_download('damo/cv_tinynas_object-detection_damoyolo')
class TestTinynasDamoyoloTrainerSingleGPU(unittest.TestCase):
    """Trainer tests for the TinyNAS DAMO-YOLO object-detection model.

    Every test trains on the sample images and ``coco_sample.json``
    annotations under ``./data/test`` and writes its output to
    ``WORK_DIR``, which ``tearDown`` removes again.
    """

    # Version of pycocotools installed before the first test that runs.
    PYCOCOTOOLS_REQUIREMENT = 'pycocotools==2.0.8'

    WORK_DIR = './workdirs'
    EXP_NAME = 'damoyolo_s'
    IMAGE_DIR = './data/test/images/image_detection/images'
    ANN_FILE = (
        './data/test/images/image_detection/annotations/coco_sample.json')
    PRETRAINED_WEIGHTS = 'damoyolo_tinynasL25_S.pt'

    # Set once pycocotools has been installed for this class, so the
    # install runs at most once instead of before every test.
    _pycocotools_installed = False

    def setUp(self):
        cls = type(self)
        if not cls._pycocotools_installed:
            # Install into the interpreter that is running the tests: a bare
            # ``pip`` on PATH may belong to a different Python. This stays
            # best-effort (exit status ignored), as in the original.
            subprocess.run(
                [
                    sys.executable, '-m', 'pip', 'install',
                    cls.PYCOCOTOOLS_REQUIREMENT
                ],
                capture_output=True,
                check=False)
            cls._pycocotools_installed = True
        self.model_id = 'damo/cv_tinynas_object-detection_damoyolo'
        self.cache_path = _setup()

    def tearDown(self) -> None:
        super().tearDown()
        shutil.rmtree(self.WORK_DIR, ignore_errors=True)

    def _train(self, **overrides):
        """Build a DAMO-YOLO trainer, run ``train()`` and return the trainer.

        The dataset, work-dir and experiment-name arguments shared by all
        tests are filled in here; ``overrides`` supplies (or replaces) the
        test-specific arguments.
        """
        kwargs = dict(
            train_image_dir=self.IMAGE_DIR,
            val_image_dir=self.IMAGE_DIR,
            train_ann=self.ANN_FILE,
            val_ann=self.ANN_FILE,
            work_dir=self.WORK_DIR,
            exp_name=self.EXP_NAME,
        )
        kwargs.update(overrides)
        trainer = build_trainer(
            name=Trainers.tinynas_damoyolo, default_args=kwargs)
        trainer.train()
        return trainer

    @unittest.skipUnless(
        exists('transformers<5.0'),
        'Skip test because transformers version is too high.')
    def test_trainer_from_scratch_singleGPU(self):
        max_epochs = 3
        trainer = self._train(
            cfg_file=os.path.join(self.cache_path, 'configuration.json'),
            gpu_ids=[0],
            batch_size=2,
            max_epochs=max_epochs,
            num_classes=80,
            base_lr_per_img=0.001,
            cache_path=self.cache_path,
        )
        # Evaluate the checkpoint named after the final epoch, derived from
        # max_epochs so the two cannot drift apart.
        trainer.evaluate(
            checkpoint_path=os.path.join(self.WORK_DIR, self.EXP_NAME,
                                         f'epoch_{max_epochs}_ckpt.pth'))

    @unittest.skipUnless(
        exists('transformers<5.0'),
        'Skip test because transformers version is too high.')
    def test_trainer_from_scratch_singleGPU_model_id(self):
        trainer = self._train(
            model=self.model_id,
            gpu_ids=[0],
            batch_size=2,
            max_epochs=3,
            num_classes=80,
            load_pretrain=True,
            base_lr_per_img=0.001,
        )
        trainer.evaluate(
            checkpoint_path=os.path.join(self.cache_path,
                                         self.PRETRAINED_WEIGHTS))

    @unittest.skip('multiGPU test is verified offline')
    def test_trainer_from_scratch_multiGPU(self):
        self._train(
            cfg_file=os.path.join(self.cache_path, 'configuration.json'),
            gpu_ids=[0, 1],
            batch_size=32,
            max_epochs=3,
            num_classes=1,
            cache_path=self.cache_path,
        )

    @unittest.skipUnless(
        exists('transformers<5.0'),
        'Skip test because transformers version is too high.')
    def test_trainer_finetune_singleGPU(self):
        self._train(
            cfg_file=os.path.join(self.cache_path, 'configuration.json'),
            gpu_ids=[0],
            batch_size=16,
            max_epochs=3,
            num_classes=1,
            load_pretrain=True,
            pretrain_model=os.path.join(self.cache_path,
                                        self.PRETRAINED_WEIGHTS),
            cache_path=self.cache_path,
        )
if __name__ == '__main__':
    # Allow running this module directly: discover and run its test cases.
    unittest.main()