mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 04:29:22 +01:00
100 lines
2.8 KiB
Python
100 lines
2.8 KiB
Python
import os
|
|
import shutil
|
|
import unittest
|
|
|
|
import json
|
|
import requests
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.multiprocessing as mp
|
|
|
|
from modelscope.hub.snapshot_download import snapshot_download
|
|
from modelscope.metainfo import Trainers
|
|
from modelscope.msdatasets import MsDataset
|
|
from modelscope.trainers import build_trainer
|
|
from modelscope.trainers.multi_modal.team.team_trainer_utils import (
|
|
collate_fn, train_mapping, val_mapping)
|
|
from modelscope.utils.config import Config
|
|
from modelscope.utils.constant import DownloadMode, ModeKeys, ModelFile
|
|
from modelscope.utils.logger import get_logger
|
|
from modelscope.utils.test_utils import test_level
|
|
|
|
logger = get_logger()
|
|
|
|
|
|
def train_worker(device_id):
|
|
model_id = 'damo/multi-modal_team-vit-large-patch14_multi-modal-similarity'
|
|
ckpt_dir = './ckpt'
|
|
os.makedirs(ckpt_dir, exist_ok=True)
|
|
# Use epoch=1 for faster training here
|
|
cfg = Config({
|
|
'framework': 'pytorch',
|
|
'task': 'multi-modal-similarity',
|
|
'pipeline': {
|
|
'type': 'multi-modal-similarity'
|
|
},
|
|
'model': {
|
|
'type': 'team-multi-modal-similarity'
|
|
},
|
|
'dataset': {
|
|
'name': 'Caltech101',
|
|
'class_num': 101
|
|
},
|
|
'preprocessor': {},
|
|
'train': {
|
|
'epoch': 1,
|
|
'batch_size': 32,
|
|
'ckpt_dir': ckpt_dir
|
|
},
|
|
'evaluation': {
|
|
'batch_size': 64
|
|
}
|
|
})
|
|
cfg_file = '{}/{}'.format(ckpt_dir, ModelFile.CONFIGURATION)
|
|
cfg.dump(cfg_file)
|
|
|
|
train_dataset = MsDataset.load(
|
|
cfg.dataset.name,
|
|
namespace='modelscope',
|
|
split='train',
|
|
download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
|
|
train_dataset = train_dataset.with_transform(train_mapping)
|
|
val_dataset = MsDataset.load(
|
|
cfg.dataset.name,
|
|
namespace='modelscope',
|
|
split='validation',
|
|
download_mode=DownloadMode.FORCE_REDOWNLOAD).to_hf_dataset()
|
|
val_dataset = val_dataset.with_transform(val_mapping)
|
|
|
|
default_args = dict(
|
|
cfg_file=cfg_file,
|
|
model=model_id,
|
|
device_id=device_id,
|
|
data_collator=collate_fn,
|
|
train_dataset=train_dataset,
|
|
val_dataset=val_dataset)
|
|
|
|
trainer = build_trainer(
|
|
name=Trainers.image_classification_team, default_args=default_args)
|
|
trainer.train()
|
|
trainer.evaluate()
|
|
|
|
|
|
class TEAMTransferTrainerTest(unittest.TestCase):
|
|
|
|
def tearDown(self) -> None:
|
|
super().tearDown()
|
|
shutil.rmtree('./ckpt')
|
|
|
|
@unittest.skipUnless(test_level() >= 1, 'dataset is not available')
|
|
def test_trainer(self):
|
|
if torch.cuda.device_count() > 0:
|
|
train_worker(device_id=0)
|
|
else:
|
|
train_worker(device_id=-1)
|
|
logger.info('Training done')
|
|
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|