mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-18 01:07:44 +01:00
65 lines
2.5 KiB
Python
65 lines
2.5 KiB
Python
import os.path as osp
|
||
from argparse import ArgumentParser
|
||
|
||
from modelscope.metainfo import Trainers
|
||
from modelscope.msdatasets import MsDataset
|
||
from modelscope.trainers import build_trainer
|
||
from modelscope.utils.constant import DownloadMode
|
||
|
||
# Command-line options for the fine-tuning run, one row per option:
# (flag, value type, help text). No option declares a default, so argparse
# leaves any option omitted on the command line set to None.
_cli_options = (
    ('--dataset_name', str, 'The dataset name'),
    ('--namespace', str, 'The dataset namespace'),
    ('--model', str, 'The model id or model dir'),
    ('--num_classes', int, 'The num_classes in the dataset'),
    ('--batch_size', int, 'The training batch size'),
    ('--max_epochs', int, 'The training max epochs'),
    ('--base_lr_per_img', float, 'The base learning rate for per image'),
)

parser = ArgumentParser()
for _flag, _value_type, _help_text in _cli_options:
    parser.add_argument(_flag, type=_value_type, help=_help_text)

args = parser.parse_args()
# Echo the parsed options to stdout.
print(args)
|
||
|
||
# Step 1: Prepare the dataset. You can use a dataset that already exists on
# ModelScope, or build your own COCO dataset locally.
# Both splits are loaded from the same dataset name and namespace with
# DownloadMode.FORCE_REDOWNLOAD, so those keyword arguments are shared.
_load_options = dict(
    namespace=args.namespace,
    download_mode=DownloadMode.FORCE_REDOWNLOAD,
)
train_dataset = MsDataset.load(
    args.dataset_name, split='train', **_load_options)
val_dataset = MsDataset.load(
    args.dataset_name, split='validation', **_load_options)
|
||
|
||
# Step 2: Set up the training parameters.
# The root directory of each split is read from the loaded dataset's
# split config.
train_root_dir = train_dataset.config_kwargs['split_config']['train']
val_root_dir = val_dataset.config_kwargs['split_config']['validation']

# Image folder and annotation file are addressed relative to each split root.
train_img_dir = osp.join(train_root_dir, 'images')
train_anno_path = osp.join(train_root_dir, 'train.json')
val_img_dir = osp.join(val_root_dir, 'images')
val_anno_path = osp.join(val_root_dir, 'val.json')

# Arguments handed to the trainer as its default_args.
kwargs = {
    # Model id or model dir (original note: use the DAMO-YOLO-S model).
    'model': args.model,
    # GPUs used for training.
    'gpu_ids': [0],
    # Total batch size; images per GPU equals batch_size // len(gpu_ids).
    'batch_size': args.batch_size,
    # Total number of training epochs.
    'max_epochs': args.max_epochs,
    # Number of classes in the custom dataset.
    'num_classes': args.num_classes,
    # Whether to load the pretrained model; if False, train from scratch.
    'load_pretrain': True,
    # Learning rate per image; lr = base_lr_per_img * batch_size.
    'base_lr_per_img': args.base_lr_per_img,
    # Directory of the training images.
    'train_image_dir': train_img_dir,
    # Directory of the validation images.
    'val_image_dir': val_img_dir,
    # Path of the training annotation file.
    'train_ann': train_anno_path,
    # Path of the validation annotation file.
    'val_ann': val_anno_path,
}
|
||
|
||
# Step 3: Launch the training job.
# Build the trainer registered as Trainers.tinynas_damoyolo from the
# arguments assembled above, then run training.
trainer = build_trainer(
    name=Trainers.tinynas_damoyolo,
    default_args=kwargs,
)
trainer.train()
|