mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-22 02:59:24 +01:00
fix transformer example and fix some bugs
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12626375
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from modelscope import MsDataset, TrainingArgs
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets.ms_dataset import MsDataset
|
||||
from modelscope.trainers.args import TrainingArgs
|
||||
from modelscope.trainers.builder import build_trainer
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(init=False)
|
||||
class ImageClassificationTrainingArgs(TrainingArgs):
|
||||
num_classes: int = field(
|
||||
default=None,
|
||||
@@ -46,26 +45,35 @@ def create_dataset(name, split):
|
||||
dataset_name, namespace=namespace, subset_name='default', split=split)
|
||||
|
||||
|
||||
def train():
|
||||
args = ImageClassificationTrainingArgs.from_cli(
|
||||
model='damo/cv_vit-base_image-classification_ImageNet-labels',
|
||||
max_epochs=1,
|
||||
lr=1e-4,
|
||||
optimizer='AdamW',
|
||||
warmup_iters=1,
|
||||
topk=(1, ))
|
||||
if args.dataset_name is not None:
|
||||
train_dataset = create_dataset(args.dataset_name, split='train')
|
||||
val_dataset = create_dataset(args.dataset_name, split='validation')
|
||||
training_args = ImageClassificationTrainingArgs(
|
||||
model='damo/cv_vit-base_image-classification_ImageNet-labels',
|
||||
max_epochs=1,
|
||||
lr=1e-4,
|
||||
optimizer='AdamW',
|
||||
warmup_iters=1,
|
||||
topk=(1, )).parse_cli()
|
||||
config, args = training_args.to_config()
|
||||
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
if args.use_model_config:
|
||||
cfg.merge_from_dict(config)
|
||||
else:
|
||||
train_dataset = create_dataset(args.train_dataset_name, split='train')
|
||||
val_dataset = create_dataset(args.val_dataset_name, split='validation')
|
||||
cfg = config
|
||||
return cfg
|
||||
|
||||
|
||||
def train():
|
||||
train_dataset = create_dataset(
|
||||
training_args.train_dataset_name, split=training_args.train_split)
|
||||
val_dataset = create_dataset(
|
||||
training_args.val_dataset_name, split=training_args.val_split)
|
||||
|
||||
kwargs = dict(
|
||||
model=args.model, # model id
|
||||
train_dataset=train_dataset, # training dataset
|
||||
eval_dataset=val_dataset, # validation dataset
|
||||
cfg_modify_fn=args # callback to modify configuration
|
||||
cfg_modify_fn=cfg_modify_fn # callback to modify configuration
|
||||
)
|
||||
|
||||
# in distributed training, specify pytorch launcher
|
||||
|
||||
Reference in New Issue
Block a user