Files
modelscope/examples/pytorch/human_detection/finetune_human_detection.py

65 lines
2.5 KiB
Python
Raw Permalink Normal View History

import os.path as osp
from argparse import ArgumentParser
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode
# Command-line options for the fine-tuning run, one entry per option:
# (flag, value type, help text). No option declares a default, so an
# option that is omitted on the command line is parsed as None.
_CLI_OPTIONS = (
    ('--dataset_name', str, 'The dataset name'),
    ('--namespace', str, 'The dataset namespace'),
    ('--model', str, 'The model id or model dir'),
    ('--num_classes', int, 'The num_classes in the dataset'),
    ('--batch_size', int, 'The training batch size'),
    ('--max_epochs', int, 'The training max epochs'),
    ('--base_lr_per_img', float, 'The base learning rate for per image'),
)

parser = ArgumentParser()
for _flag, _value_type, _help_text in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_value_type, help=_help_text)

args = parser.parse_args()
# Echo the parsed options so the run configuration is visible in the output.
print(args)
# Step 1: Prepare the dataset. Either use a dataset already available on
# ModelScope or build a COCO dataset locally.
# Both splits are loaded with the same namespace and the same forced
# re-download mode; only the split name differs between the two calls.
_shared_load_kwargs = {
    'namespace': args.namespace,
    'download_mode': DownloadMode.FORCE_REDOWNLOAD,
}
train_dataset = MsDataset.load(
    args.dataset_name, split='train', **_shared_load_kwargs)
val_dataset = MsDataset.load(
    args.dataset_name, split='validation', **_shared_load_kwargs)
# Step 2: Set up the training parameters.
# The root directory of each split is read from the loaded dataset's
# split configuration.
train_root_dir = train_dataset.config_kwargs['split_config']['train']
val_root_dir = val_dataset.config_kwargs['split_config']['validation']

# Images and the annotation file are addressed relative to each split root.
train_img_dir = osp.join(train_root_dir, 'images')
train_anno_path = osp.join(train_root_dir, 'train.json')
val_img_dir = osp.join(val_root_dir, 'images')
val_anno_path = osp.join(val_root_dir, 'val.json')

kwargs = {
    # Model id or model dir (the DAMO-YOLO-S model is used).
    'model': args.model,
    # GPUs used for training.
    'gpu_ids': [0],
    # Total batch size; images per GPU = batch_size // len(gpu_ids).
    'batch_size': args.batch_size,
    # Total number of training epochs.
    'max_epochs': args.max_epochs,
    # Number of classes in the custom dataset.
    'num_classes': args.num_classes,
    # Whether to load the pretrained model; False means training from scratch.
    'load_pretrain': True,
    # Learning rate per image; lr = base_lr_per_img * batch_size.
    'base_lr_per_img': args.base_lr_per_img,
    # Directory of the training images.
    'train_image_dir': train_img_dir,
    # Directory of the validation (test) images.
    'val_image_dir': val_img_dir,
    # Path of the training annotation file.
    'train_ann': train_anno_path,
    # Path of the validation (test) annotation file.
    'val_ann': val_anno_path,
}
# Step 3: Start the training job.
# The trainer is built under the `Trainers.tinynas_damoyolo` name with the
# keyword arguments assembled in `kwargs`, then training is launched.
trainer = build_trainer(
    name=Trainers.tinynas_damoyolo,
    default_args=kwargs,
)
trainer.train()