Files
modelscope/examples/pytorch/finetune_image_classification.py

87 lines
2.7 KiB
Python
Raw Normal View History

import os
2023-01-11 10:43:56 +08:00
from modelscope.metainfo import Trainers
from modelscope.msdatasets.ms_dataset import MsDataset
from modelscope.trainers.builder import build_trainer
2023-01-11 10:43:56 +08:00
from modelscope.trainers.training_args import (ArgAttr, CliArgumentParser,
training_args)
def define_parser():
    """Register image-classification CLI arguments and build the parser.

    Mutates the module-level ``training_args`` object in place. It adds the
    arguments ``num_classes``, ``warmup_iters``, ``topk``, ``train_data``,
    ``validation_data`` and ``model_id``, and it overrides the defaults of
    several arguments already declared on ``training_args``.

    Returns:
        CliArgumentParser: a parser constructed from the updated
        ``training_args``.
    """
    # One --num_classes value is mapped onto all three config nodes below.
    training_args.num_classes = ArgAttr(
        cfg_node_name=[
            'model.mm_model.head.num_classes',
            'model.mm_model.train_cfg.augments.0.num_classes',
            'model.mm_model.train_cfg.augments.1.num_classes',
        ],
        type=int,
        help='number of classes')

    # New defaults for arguments that training_args already declares,
    # applied in this order.
    default_overrides = (
        ('train_batch_size', 16),
        ('train_data_worker', 1),
        ('max_epochs', 1),
        ('optimizer', 'AdamW'),
        ('lr', 1e-4),
    )
    for arg_name, default_value in default_overrides:
        getattr(training_args, arg_name).default = default_value

    # NOTE(review): the help text says "epochs" while the option maps to
    # 'train.lr_config.warmup_iters'. Whether the value counts epochs or
    # iterations depends on the model's lr_config; confirm.
    training_args.warmup_iters = ArgAttr(
        'train.lr_config.warmup_iters',
        type=int,
        default=1,
        help='number of warmup epochs')

    # NOTE(review): no ``type`` is given, so a value passed on the command line
    # may not arrive as a tuple. Confirm how ArgAttr/CliArgumentParser converts it.
    training_args.topk = ArgAttr(
        cfg_node_name=[
            'train.evaluation.metric_options.topk',
            'evaluation.metric_options.topk',
        ],
        default=(1, ),
        help='evaluation using topk, tuple format, eg (1,), (1,5)')

    # Training and validation both default to the same hub dataset.
    default_dataset = 'tany0699/cats_and_dogs'
    training_args.train_data = ArgAttr(
        type=str, default=default_dataset, help='train dataset')
    training_args.validation_data = ArgAttr(
        type=str, default=default_dataset, help='validation dataset')
    training_args.model_id = ArgAttr(
        type=str,
        default='damo/cv_vit-base_image-classification_ImageNet-labels',
        help='model name')

    return CliArgumentParser(training_args)
def create_dataset(name, split, subset_name='default'):
    """Load a ModelScope dataset identified as ``'<namespace>/<dataset_name>'``.

    Args:
        name: dataset identifier, e.g. ``'tany0699/cats_and_dogs'``.
        split: split to load (this script uses ``'train'`` and
            ``'validation'``).
        subset_name: subset passed to ``MsDataset.load``. Defaults to
            ``'default'``, the value previously hard-coded.

    Returns:
        The object returned by ``MsDataset.load`` for the requested split.

    Raises:
        ValueError: if ``name`` is not exactly two non-empty parts separated
            by a single ``'/'``.
    """
    namespace, sep, dataset_name = name.partition('/')
    # Validate up front. A bare tuple-unpack of name.split('/') fails with an
    # opaque "values to unpack" message, and an empty namespace or dataset
    # name would otherwise be forwarded to MsDataset.load unchecked.
    if not sep or not namespace or not dataset_name or '/' in dataset_name:
        raise ValueError(
            "dataset name must look like '<namespace>/<dataset_name>', "
            'got {!r}'.format(name))
    return MsDataset.load(
        dataset_name,
        namespace=namespace,
        subset_name=subset_name,
        split=split)
def train(parser):
    """Build an image-classification trainer from parsed CLI settings and run it.

    Args:
        parser: argument parser exposing ``get_cfg_dict()`` and ``args``
            (e.g. the object returned by ``define_parser()``).
    """
    # get_cfg_dict() is called before parser.args is read. Keep this order.
    # NOTE(review): it presumably parses argv and populates parser.args; confirm.
    overrides = parser.get_cfg_dict()
    cli = parser.args

    def apply_overrides(cfg):
        # Merge the CLI-derived overrides into the model configuration.
        cfg.merge_from_dict(overrides)
        return cfg

    # The training split is loaded before the validation split.
    trainer_args = {
        'model': cli.model_id,
        'train_dataset': create_dataset(cli.train_data, split='train'),
        'eval_dataset': create_dataset(
            cli.validation_data, split='validation'),
        'cfg_modify_fn': apply_overrides,
    }

    # A MASTER_ADDR in the environment is taken to mean a distributed run,
    # so the pytorch launcher is requested.
    if 'MASTER_ADDR' in os.environ:
        trainer_args['launcher'] = 'pytorch'

    build_trainer(
        name=Trainers.image_classification,
        default_args=trainer_args).train()
if __name__ == '__main__':
    # Register the CLI arguments, then parse them and run fine-tuning.
    train(define_parser())