wenmeng.zwm
|
b8ec677739
|
add training args support and image classification fintune example
design doc: https://yuque.antfin.com/pai/rwqgvl/khy4uw5dgi39s6ke
usage:
```python
from modelscope.trainers.training_args import (ArgAttr, MSArgumentParser,
training_args)
training_args.topk = ArgAttr(cfg_node_name=['train.evaluation.metric_options.topk',
'evaluation.metric_options.topk'],
default=(1,), help='evaluation using topk, tuple format, eg (1,), (1,5)')
training_args.train_data = ArgAttr(type=str, default='tany0699/cats_and_dogs', help='train dataset')
training_args.validation_data = ArgAttr(type=str, default='tany0699/cats_and_dogs', help='validation dataset')
training_args.model_id = ArgAttr(type=str, default='damo/cv_vit-base_image-classification_ImageNet-labels', help='model name')
parser = MSArgumentParser(training_args)
cfg_dict = parser.get_cfg_dict()
args = parser.args
train_dataset = create_dataset(args.train_data, split='train')
val_dataset = create_dataset(args.validation_data, split='validation')
def cfg_modify_fn(cfg):
cfg.merge_from_dict(cfg_dict)
return cfg
kwargs = dict(
model=args.model_id, # model id
train_dataset=train_dataset, # training dataset
eval_dataset=val_dataset, # validation dataset
cfg_modify_fn=cfg_modify_fn # callback to modify configuration
)
trainer = build_trainer(name=Trainers.image_classification, default_args=kwargs)
# start to train
trainer.train()
```
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11225071
|
2022-12-30 07:35:15 +08:00 |
|