This commit is contained in:
xingjun.wang
2023-05-22 10:53:18 +08:00
parent 52aea36c12
commit 48c0d2a9af
468 changed files with 12942 additions and 7176 deletions

View File

@@ -1,13 +1,12 @@
import os
from dataclasses import dataclass, field
from modelscope import MsDataset, TrainingArgs
from modelscope.metainfo import Trainers
from modelscope.msdatasets.ms_dataset import MsDataset
from modelscope.trainers.builder import build_trainer
from modelscope.trainers.training_args import TrainingArgs
@dataclass
@dataclass(init=False)
class ImageClassificationTrainingArgs(TrainingArgs):
num_classes: int = field(
default=None,
@@ -46,26 +45,35 @@ def create_dataset(name, split):
dataset_name, namespace=namespace, subset_name='default', split=split)
def train():
args = ImageClassificationTrainingArgs.from_cli(
model='damo/cv_vit-base_image-classification_ImageNet-labels',
max_epochs=1,
lr=1e-4,
optimizer='AdamW',
warmup_iters=1,
topk=(1, ))
if args.dataset_name is not None:
train_dataset = create_dataset(args.dataset_name, split='train')
val_dataset = create_dataset(args.dataset_name, split='validation')
training_args = ImageClassificationTrainingArgs(
model='damo/cv_vit-base_image-classification_ImageNet-labels',
max_epochs=1,
lr=1e-4,
optimizer='AdamW',
warmup_iters=1,
topk=(1, )).parse_cli()
config, args = training_args.to_config()
def cfg_modify_fn(cfg):
if args.use_model_config:
cfg.merge_from_dict(config)
else:
train_dataset = create_dataset(args.train_dataset_name, split='train')
val_dataset = create_dataset(args.val_dataset_name, split='validation')
cfg = config
return cfg
def train():
train_dataset = create_dataset(
training_args.train_dataset_name, split=training_args.train_split)
val_dataset = create_dataset(
training_args.val_dataset_name, split=training_args.val_split)
kwargs = dict(
model=args.model, # model id
train_dataset=train_dataset, # training dataset
eval_dataset=val_dataset, # validation dataset
cfg_modify_fn=args # callback to modify configuration
cfg_modify_fn=cfg_modify_fn # callback to modify configuration
)
# in distributed training, specify pytorch launcher