mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-22 02:59:24 +01:00
Design doc: https://yuque.antfin.com/pai/rwqgvl/khy4uw5dgi39s6ke

Usage:

```python
from modelscope.trainers.training_args import (ArgAttr, MSArgumentParser,
                                               training_args)

training_args.topk = ArgAttr(
    cfg_node_name=['train.evaluation.metric_options.topk',
                   'evaluation.metric_options.topk'],
    default=(1,),
    help='evaluation using topk, tuple format, eg (1,), (1,5)')
training_args.train_data = ArgAttr(
    type=str, default='tany0699/cats_and_dogs', help='train dataset')
training_args.validation_data = ArgAttr(
    type=str, default='tany0699/cats_and_dogs', help='validation dataset')
training_args.model_id = ArgAttr(
    type=str,
    default='damo/cv_vit-base_image-classification_ImageNet-labels',
    help='model name')

parser = MSArgumentParser(training_args)
cfg_dict = parser.get_cfg_dict()
args = parser.args

train_dataset = create_dataset(args.train_data, split='train')
val_dataset = create_dataset(args.validation_data, split='validation')


def cfg_modify_fn(cfg):
    cfg.merge_from_dict(cfg_dict)
    return cfg


kwargs = dict(
    model=args.model_id,  # model id
    train_dataset=train_dataset,  # training dataset
    eval_dataset=val_dataset,  # validation dataset
    cfg_modify_fn=cfg_modify_fn  # callback to modify configuration
)

trainer = build_trainer(name=Trainers.image_classification, default_args=kwargs)
# start to train
trainer.train()
```

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11225071
6 lines
238 B
Bash
# Fine-tune the image-classification example with 2 processes via
# torch.distributed.launch (--nproc_per_node=2).
# PYTHONPATH=. puts the current directory on Python's module search path;
# this assumes the command is run from the repository root.
PYTHONPATH=. python -m torch.distributed.launch --nproc_per_node=2 \
    examples/pytorch/finetune_image_classification.py \
    --num_classes 2 \
    --train_data 'tany0699/cats_and_dogs' \
    --validation_data 'tany0699/cats_and_dogs'