modelscope/examples/pytorch/human_detection/finetune_human_detection.py

import os.path as osp
from argparse import ArgumentParser

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode

parser = ArgumentParser()
parser.add_argument('--dataset_name', type=str, help='The dataset name')
parser.add_argument('--namespace', type=str, help='The dataset namespace')
parser.add_argument('--model', type=str, help='The model id or model dir')
parser.add_argument(
    '--num_classes', type=int, help='The number of classes in the dataset')
parser.add_argument('--batch_size', type=int, help='The training batch size')
parser.add_argument('--max_epochs', type=int, help='The training max epochs')
parser.add_argument(
    '--base_lr_per_img',
    type=float,
    help='The base learning rate per image')
args = parser.parse_args()
print(args)
# Step 1: Dataset preparation. You can use an existing dataset hosted on
# ModelScope, or build your own COCO-format dataset locally.
train_dataset = MsDataset.load(
    args.dataset_name,
    namespace=args.namespace,
    split='train',
    download_mode=DownloadMode.FORCE_REDOWNLOAD)
val_dataset = MsDataset.load(
    args.dataset_name,
    namespace=args.namespace,
    split='validation',
    download_mode=DownloadMode.FORCE_REDOWNLOAD)
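
# If you build the dataset locally instead (Step 1's second option), the paths
# assembled in Step 2 below imply the following layout under each split root
# (a sketch assuming you mirror the names this script joins; adjust as needed):
#
#   <split_root>/
#       images/        # all images for the split
#       train.json     # COCO-format annotations for the train split
#       val.json       # COCO-format annotations for the validation split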
# Step 2: Set up the training parameters.
train_root_dir = train_dataset.config_kwargs['split_config']['train']
val_root_dir = val_dataset.config_kwargs['split_config']['validation']
train_img_dir = osp.join(train_root_dir, 'images')
val_img_dir = osp.join(val_root_dir, 'images')
train_anno_path = osp.join(train_root_dir, 'train.json')
val_anno_path = osp.join(val_root_dir, 'val.json')
kwargs = dict(
    model=args.model,  # model id or local dir, e.g. the DAMO-YOLO-S model
    gpu_ids=[0],  # GPUs used for training
    batch_size=args.batch_size,  # total batch size; images per GPU = batch_size // len(gpu_ids)
    max_epochs=args.max_epochs,  # total number of training epochs
    num_classes=args.num_classes,  # number of classes in the custom dataset
    load_pretrain=True,  # load pretrained weights; if False, train from scratch
    base_lr_per_img=args.base_lr_per_img,  # per-image learning rate; lr = base_lr_per_img * batch_size
    train_image_dir=train_img_dir,  # training image directory
    val_image_dir=val_img_dir,  # validation image directory
    train_ann=train_anno_path,  # training annotation file path
    val_ann=val_anno_path,  # validation annotation file path
)
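
# Worked example (assumed values, not from the script): with batch_size=16 and
# gpu_ids=[0], each GPU processes 16 // 1 = 16 images per step; with
# base_lr_per_img=0.001, the effective learning rate is 0.001 * 16 = 0.016.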
# Step 3: Launch the training job.
trainer = build_trainer(name=Trainers.tinynas_damoyolo, default_args=kwargs)
trainer.train()
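
# Example invocation (a sketch; the dataset name, namespace, and model id below
# are illustrative placeholders, not guaranteed to exist on ModelScope):
#
#   python finetune_human_detection.py \
#       --dataset_name my_person_dataset \
#       --namespace my_namespace \
#       --model damo/cv_tinynas_human-detection_damoyolo \
#       --num_classes 1 \
#       --batch_size 16 \
#       --max_epochs 30 \
#       --base_lr_per_img 0.001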