mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
update training args
Based on the feat/0131/nlp_args branch; original code review: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11408570. Adds support for running fine-tuning from the command line with training args, compatible with the configuration optimization.
This commit is contained in:
@@ -1,51 +1,43 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.msdatasets.ms_dataset import MsDataset
|
||||
from modelscope.trainers.builder import build_trainer
|
||||
from modelscope.trainers.training_args import (ArgAttr, CliArgumentParser,
|
||||
training_args)
|
||||
from modelscope.trainers.training_args import TrainingArgs
|
||||
|
||||
|
||||
def define_parser():
|
||||
training_args.num_classes = ArgAttr(
|
||||
cfg_node_name=[
|
||||
'model.mm_model.head.num_classes',
|
||||
'model.mm_model.train_cfg.augments.0.num_classes',
|
||||
'model.mm_model.train_cfg.augments.1.num_classes'
|
||||
],
|
||||
type=int,
|
||||
help='number of classes')
|
||||
@dataclass
|
||||
class ImageClassificationTrainingArgs(TrainingArgs):
|
||||
num_classes: int = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'cfg_node': [
|
||||
'model.mm_model.head.num_classes',
|
||||
'model.mm_model.train_cfg.augments.0.num_classes',
|
||||
'model.mm_model.train_cfg.augments.1.num_classes'
|
||||
],
|
||||
'help':
|
||||
'number of classes',
|
||||
})
|
||||
|
||||
training_args.train_batch_size.default = 16
|
||||
training_args.train_data_worker.default = 1
|
||||
training_args.max_epochs.default = 1
|
||||
training_args.optimizer.default = 'AdamW'
|
||||
training_args.lr.default = 1e-4
|
||||
training_args.warmup_iters = ArgAttr(
|
||||
'train.lr_config.warmup_iters',
|
||||
type=int,
|
||||
default=1,
|
||||
help='number of warmup epochs')
|
||||
training_args.topk = ArgAttr(
|
||||
cfg_node_name=[
|
||||
'train.evaluation.metric_options.topk',
|
||||
'evaluation.metric_options.topk'
|
||||
],
|
||||
default=(1, ),
|
||||
help='evaluation using topk, tuple format, eg (1,), (1,5)')
|
||||
topk: tuple = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'cfg_node': [
|
||||
'train.evaluation.metric_options.topk',
|
||||
'evaluation.metric_options.topk'
|
||||
],
|
||||
'help':
|
||||
'evaluation using topk, tuple format, eg (1,), (1,5)',
|
||||
})
|
||||
|
||||
training_args.train_data = ArgAttr(
|
||||
type=str, default='tany0699/cats_and_dogs', help='train dataset')
|
||||
training_args.validation_data = ArgAttr(
|
||||
type=str, default='tany0699/cats_and_dogs', help='validation dataset')
|
||||
training_args.model_id = ArgAttr(
|
||||
type=str,
|
||||
default='damo/cv_vit-base_image-classification_ImageNet-labels',
|
||||
help='model name')
|
||||
|
||||
parser = CliArgumentParser(training_args)
|
||||
return parser
|
||||
warmup_iters: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'cfg_node': 'train.lr_config.warmup_iters',
|
||||
'help': 'The warmup iters',
|
||||
})
|
||||
|
||||
|
||||
def create_dataset(name, split):
|
||||
@@ -54,21 +46,26 @@ def create_dataset(name, split):
|
||||
dataset_name, namespace=namespace, subset_name='default', split=split)
|
||||
|
||||
|
||||
def train(parser):
|
||||
cfg_dict = parser.get_cfg_dict()
|
||||
args = parser.args
|
||||
train_dataset = create_dataset(args.train_data, split='train')
|
||||
val_dataset = create_dataset(args.validation_data, split='validation')
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
cfg.merge_from_dict(cfg_dict)
|
||||
return cfg
|
||||
def train():
|
||||
args = ImageClassificationTrainingArgs.from_cli(
|
||||
model='damo/cv_vit-base_image-classification_ImageNet-labels',
|
||||
max_epochs=1,
|
||||
lr=1e-4,
|
||||
optimizer='AdamW',
|
||||
warmup_iters=1,
|
||||
topk=(1, ))
|
||||
if args.dataset_name is not None:
|
||||
train_dataset = create_dataset(args.dataset_name, split='train')
|
||||
val_dataset = create_dataset(args.dataset_name, split='validation')
|
||||
else:
|
||||
train_dataset = create_dataset(args.train_dataset_name, split='train')
|
||||
val_dataset = create_dataset(args.val_dataset_name, split='validation')
|
||||
|
||||
kwargs = dict(
|
||||
model=args.model_id, # model id
|
||||
model=args.model, # model id
|
||||
train_dataset=train_dataset, # training dataset
|
||||
eval_dataset=val_dataset, # validation dataset
|
||||
cfg_modify_fn=cfg_modify_fn # callback to modify configuration
|
||||
cfg_modify_fn=args # callback to modify configuration
|
||||
)
|
||||
|
||||
# in distributed training, specify pytorch launcher
|
||||
@@ -82,5 +79,4 @@ def train(parser):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = define_parser()
|
||||
train(parser)
|
||||
train()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
PYTHONPATH=. python -m torch.distributed.launch --nproc_per_node=2 \
|
||||
examples/pytorch/finetune_image_classification.py \
|
||||
--num_classes 2 \
|
||||
--train_data 'tany0699/cats_and_dogs' \
|
||||
--validation_data 'tany0699/cats_and_dogs'
|
||||
--train_dataset_name 'tany0699/cats_and_dogs' \
|
||||
--val_dataset_name 'tany0699/cats_and_dogs'
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.trainers import EpochBasedTrainer, build_trainer
|
||||
from modelscope.trainers.training_args import TrainingArgs
|
||||
|
||||
|
||||
def get_labels(cfg, metadata):
    """Render the label mapping stored in the config as a comma-separated string.

    Args:
        cfg: Configuration object exposing ``safe_get(key_chain)``.
        metadata: Field metadata; ``metadata['cfg_node']`` is the key chain
            of the label-to-id mapping inside ``cfg``.

    Returns:
        The mapping's keys joined with ``','`` in the mapping's iteration
        order, or ``None`` when ``safe_get`` returns ``None`` for the node.
    """
    label2id = cfg.safe_get(metadata['cfg_node'])
    if label2id is None:
        return None
    # ``str.join`` only accepts strings; label names may be stored as
    # non-string keys (e.g. integers), so convert each key explicitly.
    return ','.join(str(label) for label in label2id.keys())
|
||||
|
||||
|
||||
def set_labels(cfg, labels, metadata):
    """Write a label-to-id mapping into the config.

    Each label is mapped to its position in ``labels`` (starting at 0) and
    the resulting dict is merged into ``cfg`` under ``metadata['cfg_node']``.

    Args:
        cfg: Configuration object exposing ``merge_from_dict``.
        labels: An iterable of labels, or a comma-separated string such as
            ``'neg,pos'``.
        metadata: Field metadata; ``metadata['cfg_node']`` is the key chain
            to write the mapping to.
    """
    if isinstance(labels, str):
        # Strip surrounding whitespace so that "neg, pos" yields the label
        # "pos" rather than " pos".
        labels = [name.strip() for name in labels.split(',')]
    # ``index`` instead of ``id`` to avoid shadowing the builtin.
    label2id = {label: index for index, label in enumerate(labels)}
    cfg.merge_from_dict({metadata['cfg_node']: label2id})
|
||||
|
||||
|
||||
@dataclass
class TextClassificationArguments(TrainingArgs):
    """Command-line arguments for the text-classification fine-tuning example.

    Every field carries a ``cfg_node`` entry in its metadata naming the
    configuration key it corresponds to; ``labels`` additionally supplies a
    custom getter/setter pair that converts between a comma-separated string
    and the ``label2id`` mapping.
    NOTE(review): how ``cfg_node`` / ``cfg_getter`` / ``cfg_setter`` are
    consumed is defined by ``TrainingArgs``, which is not visible here.
    """

    # Dataset column holding the first input sentence.
    first_sequence: str = field(
        default=None,
        metadata={
            'help': 'The first sequence key of preprocessor',
            'cfg_node': 'preprocessor.first_sequence'
        })

    # Dataset column holding the optional second input sentence.
    second_sequence: str = field(
        default=None,
        metadata={
            'help': 'The second sequence key of preprocessor',
            'cfg_node': 'preprocessor.second_sequence'
        })

    # Dataset column holding the target label.
    label: str = field(
        default=None,
        metadata={
            'help': 'The label key of preprocessor',
            'cfg_node': 'preprocessor.label'
        })

    # Comma-separated label names, e.g. '0,1,2'.
    labels: str = field(
        default=None,
        metadata={
            'help': 'The labels of the dataset',
            'cfg_node': 'preprocessor.label2id',
            'cfg_getter': get_labels,
            'cfg_setter': set_labels,
        })

    # Registered preprocessor type name.
    preprocessor: str = field(
        default=None,
        metadata={
            'help': 'The preprocessor type',
            'cfg_node': 'preprocessor.type'
        })

    def __call__(self, config):
        """Apply the arguments to ``config`` and derive dependent settings.

        After the base-class call, sets ``config.model['num_labels']`` from
        the number of labels and, for a ``LinearLR`` scheduler, sets
        ``total_iters`` from the training-set size, batch size and epochs.

        Args:
            config: The configuration to modify.

        Returns:
            The modified configuration.
        """
        config = super().__call__(config)
        labels = self.labels
        if labels is not None:
            if isinstance(labels, str):
                # ``labels`` is declared as a comma-separated string; count
                # the entries, not the characters of the string.
                labels = labels.split(',')
            config.model['num_labels'] = len(labels)
        if config.train.lr_scheduler.type == 'LinearLR':
            # ``train_dataset`` is the module-level dataset created later in
            # this script; it must exist by the time this object is called.
            config.train.lr_scheduler['total_iters'] = \
                int(len(train_dataset) / self.per_device_train_batch_size) * self.max_epochs
        return config
|
||||
|
||||
|
||||
# Parse the command line; the keyword arguments are passed to ``from_cli``
# as given in the original call.
args = TextClassificationArguments.from_cli(
    task='text-classification',
    eval_metrics='seq-cls-metric',
)

print(args)

# Load the requested dataset and pick the two splits used for fine-tuning.
dataset = MsDataset.load(args.dataset_name, subset_name=args.subset_name)
train_dataset, validation_dataset = dataset['train'], dataset['validation']

# ``args`` itself is passed as the config-modification callback.
kwargs = {
    'model': args.model,
    'train_dataset': train_dataset,
    'eval_dataset': validation_dataset,
    'seed': args.seed,
    'cfg_modify_fn': args,
}

# Export the local rank through the environment before building the trainer.
os.environ['LOCAL_RANK'] = str(args.local_rank)
trainer: EpochBasedTrainer = build_trainer(
    name='trainer', default_args=kwargs)
trainer.train()
|
||||
12
examples/pytorch/text_classification/run_train.sh
Normal file
12
examples/pytorch/text_classification/run_train.sh
Normal file
@@ -0,0 +1,12 @@
|
||||
# Fine-tune damo/nlp_structbert_backbone_base_std on the 'tnews' subset of 'clue'.
PYTHONPATH=. python examples/pytorch/text_classification/finetune_text_classification.py \
    --model 'damo/nlp_structbert_backbone_base_std' \
    --dataset_name 'clue' \
    --subset_name 'tnews' \
    --first_sequence 'sentence' \
    --preprocessor.label label \
    --model.num_labels 15 \
    --labels '0,1,2,3,4,5,6,7,8,9,10,11,12,13,14' \
    --preprocessor 'sen-cls-tokenizer' \
    --train.dataloader.workers_per_gpu 0 \
    --evaluation.dataloader.workers_per_gpu 0 \
    --train.optimizer.lr 1e-5
|
||||
1
examples/pytorch/transformers/configuration.json
Normal file
1
examples/pytorch/transformers/configuration.json
Normal file
@@ -0,0 +1 @@
|
||||
{"framework":"pytorch","train":{"work_dir":"/tmp","max_epochs":10,"dataloader":{"batch_size_per_gpu":16,"workers_per_gpu":0},"optimizer":{"type":"SGD","lr":0.001},"lr_scheduler":{"type":"StepLR","step_size":2},"hooks":[{"type":"CheckpointHook","interval":1}]},"evaluation":{"dataloader":{"batch_size_per_gpu":16,"workers_per_gpu":0,"shuffle":false}}}
|
||||
57
examples/pytorch/transformers/finetune_transformers_model.py
Normal file
57
examples/pytorch/transformers/finetune_transformers_model.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import os
from dataclasses import dataclass, field

from datasets import load_dataset
from transformers import (BertForSequenceClassification, BertTokenizerFast,
                          default_data_collator)

from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.trainers.default_config import DEFAULT_CONFIG
# TrainingArgs is defined in ``training_args`` (which itself imports
# DEFAULT_CONFIG from ``default_config``), so it must be imported from there.
from modelscope.trainers.training_args import TrainingArgs


@dataclass
class TransformersArguments(TrainingArgs):
    """Training arguments extended with the number of classification labels."""

    # Forwarded to ``BertForSequenceClassification.from_pretrained``.
    num_labels: int = field(
        default=None, metadata={
            'help': 'The number of labels',
        })


args = TransformersArguments.from_cli(
    task='text-classification', eval_metrics='seq-cls-metric')

print(args)

dataset = load_dataset(args.dataset_name, args.subset_name)

model = BertForSequenceClassification.from_pretrained(
    args.model, num_labels=args.num_labels)
tokenizer = BertTokenizerFast.from_pretrained(args.model)


def tokenize_sentence(row):
    """Tokenize ``row['sentence']``, padding to a fixed length of 128."""
    return tokenizer(row['sentence'], padding='max_length', max_length=128)


# Tokenize every row, drop the raw 'sentence' and 'idx' columns, and rename
# 'label' to 'labels'.
dataset = dataset.map(tokenize_sentence).remove_columns(
    ['sentence', 'idx']).rename_column('label', 'labels')

# Dump the default configuration to a file that is handed to the trainer as
# ``cfg_file``.
cfg_file = os.path.join(args.work_dir or './', 'configuration.json')
DEFAULT_CONFIG.dump(cfg_file)

kwargs = dict(
    model=model,
    cfg_file=cfg_file,
    # data_collator
    data_collator=default_data_collator,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    seed=args.seed,
    cfg_modify_fn=args)

os.environ['LOCAL_RANK'] = str(args.local_rank)
trainer: EpochBasedTrainer = build_trainer(name='trainer', default_args=kwargs)
trainer.train()
|
||||
5
examples/pytorch/transformers/run_train.sh
Normal file
5
examples/pytorch/transformers/run_train.sh
Normal file
@@ -0,0 +1,5 @@
|
||||
PYTHONPATH=. python examples/pytorch/transformers/finetune_transformers_model.py \
|
||||
--model bert-base-uncased \
|
||||
--num_labels 15 \
|
||||
--dataset_name clue \
|
||||
--subset_name tnews
|
||||
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
||||
ReferringVideoObjectSegmentationTrainer)
|
||||
from .multi_modal import CLIPTrainer
|
||||
from .nlp import SequenceClassificationTrainer, TextRankingTrainer
|
||||
from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer, NlpTrainerArguments
|
||||
from .nlp_trainer import NlpEpochBasedTrainer, VecoTrainer
|
||||
from .trainer import EpochBasedTrainer
|
||||
|
||||
else:
|
||||
@@ -28,8 +28,7 @@ else:
|
||||
],
|
||||
'multi_modal': ['CLIPTrainer'],
|
||||
'nlp': ['SequenceClassificationTrainer', 'TextRankingTrainer'],
|
||||
'nlp_trainer':
|
||||
['NlpEpochBasedTrainer', 'VecoTrainer', 'NlpTrainerArguments'],
|
||||
'nlp_trainer': ['NlpEpochBasedTrainer', 'VecoTrainer'],
|
||||
'trainer': ['EpochBasedTrainer']
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,41 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from modelscope.utils.config import Config
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
DEFAULT_CONFIG = Config({
|
||||
'framework': 'pytorch',
|
||||
'train': {
|
||||
'work_dir': '/tmp',
|
||||
'max_epochs': 10,
|
||||
'dataloader': {
|
||||
'batch_size_per_gpu': 16,
|
||||
'workers_per_gpu': 0
|
||||
},
|
||||
'optimizer': {
|
||||
'type': 'SGD',
|
||||
'lr': 1e-3
|
||||
},
|
||||
'lr_scheduler': {
|
||||
'type': 'StepLR',
|
||||
'step_size': 2
|
||||
},
|
||||
'hooks': [{
|
||||
'type': 'CheckpointHook',
|
||||
'interval': 1
|
||||
}]
|
||||
},
|
||||
'evaluation': {
|
||||
'dataloader': {
|
||||
'batch_size_per_gpu': 16,
|
||||
'workers_per_gpu': 0,
|
||||
'shuffle': False
|
||||
},
|
||||
}
|
||||
})
|
||||
|
||||
DEFAULT_HOOKS_CONFIG = {
|
||||
'train': {
|
||||
'hooks': [{
|
||||
'type': 'CheckpointHook',
|
||||
@@ -25,10 +58,46 @@ def merge_cfg(cfg: Config):
|
||||
Args:
|
||||
cfg: The input cfg to be merged into.
|
||||
"""
|
||||
cfg.merge_from_dict(DEFAULT_CONFIG, force=False)
|
||||
cfg.merge_from_dict(DEFAULT_HOOKS_CONFIG, force=False)
|
||||
# pop duplicate hook
|
||||
|
||||
if any(['BestCkptSaverHook' == hook['type'] for hook in cfg.train.hooks]):
|
||||
cfg.train.hooks = list(
|
||||
filter(lambda hook: hook['type'] != 'CheckpointHook',
|
||||
cfg.train.hooks))
|
||||
|
||||
|
||||
def merge_hooks(cfg: Config) -> List[Dict]:
    """Combine ``cfg.train.hooks`` with hooks derived from shortcut config keys.

    Each known key chain is converted to a hook dict through
    ``_key_chain_to_hook``; chains that produce ``None`` are skipped.

    Args:
        cfg: The configuration to read hooks from.

    Returns:
        A new list: a copy of ``cfg.train.hooks`` followed by the derived
        hooks, in the order the key chains are listed below.
    """
    shortcut_hooks = (
        ('train.logging', 'TextLoggerHook'),
        ('train.checkpoint.period', 'CheckpointHook'),
        ('train.checkpoint.best', 'BestCkptSaverHook'),
        ('evaluation.period', 'EvaluationHook'),
    )
    merged = cfg.train.hooks.copy()
    derived = (_key_chain_to_hook(cfg, chain, kind)
               for chain, kind in shortcut_hooks)
    merged.extend(hook for hook in derived if hook is not None)
    return merged
|
||||
|
||||
|
||||
def _key_chain_to_hook(cfg: Config, key_chain: str,
                       hook_type: str) -> Optional[Dict]:
    """Build a hook dict of ``hook_type`` from the parameters at ``key_chain``.

    Args:
        cfg: The configuration to read from.
        key_chain: Dotted key chain whose value holds the hook parameters.
        hook_type: Value stored under the hook's ``'type'`` key.

    Returns:
        ``{'type': hook_type}`` updated with the parameters found at
        ``key_chain``, or ``None`` when ``_check_basic_hook`` returns False.
    """
    if _check_basic_hook(cfg, key_chain, hook_type):
        hook_cfg = dict(type=hook_type)
        hook_cfg.update(cfg.safe_get(key_chain))
        return hook_cfg
    return None
|
||||
|
||||
|
||||
def _check_basic_hook(cfg: Config, key_chain: str, hook_type: str) -> bool:
|
||||
if cfg.safe_get(key_chain) is None:
|
||||
return False
|
||||
hooks = list(
|
||||
filter(lambda hook: hook['type'] == hook_type, cfg.train.hooks))
|
||||
assert len(hooks) == 0, f'The key_chain {key_chain} and the traditional hook ' \
|
||||
f'cannot exist at the same time, ' \
|
||||
f'please delete {hook_type} in the configuration file.'
|
||||
return True
|
||||
|
||||
@@ -1,429 +1,21 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from modelscope.metainfo import Trainers
|
||||
from modelscope.metrics.builder import build_metric
|
||||
from modelscope.models.base import Model, TorchModel
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.utils.config import Config, ConfigDict
|
||||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ModeKeys,
|
||||
ModelFile)
|
||||
from modelscope.utils.hub import parse_label_mapping
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModeKeys
|
||||
from .base import TRAINERS
|
||||
from .trainer import EpochBasedTrainer
|
||||
|
||||
|
||||
@dataclass
|
||||
class NlpTrainerArguments:
|
||||
"""The arguments for the nlp trainer.
|
||||
|
||||
All the arguments listed here have None default values, which means follow the default value in the input
|
||||
cfg dict.
|
||||
"""
|
||||
|
||||
work_dir: Optional[str] = field(
|
||||
default=None, metadata={'help': 'The work dir(key: train.work_dir)'})
|
||||
|
||||
task: Optional[str] = field(
|
||||
default=None, metadata={'help': 'The task type(key: task)'})
|
||||
|
||||
preprocessor_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={'help': 'The preprocessor type(key: preprocessor.type)'})
|
||||
|
||||
train_first_sequence: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The key of first sentence for the training dataset(key:preprocessor.train.'
|
||||
'first_sequence/dataset.train.first_sequence)'
|
||||
})
|
||||
|
||||
train_second_sequence: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The key of second sentence for the training dataset(key:preprocessor.train.'
|
||||
'second_sequence/dataset.train.second_sequence)'
|
||||
})
|
||||
|
||||
train_label: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The key of label for the training dataset(key:preprocessor.train.'
|
||||
'second_sequence/dataset.train.second_sequence)'
|
||||
})
|
||||
|
||||
eval_first_sequence: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The key of first sentence for the eval dataset(key:preprocessor.val.'
|
||||
'first_sequence/dataset.val.first_sequence), '
|
||||
'if not provided, the trainer will use the train_first_sequence for evaluation'
|
||||
})
|
||||
|
||||
eval_second_sequence: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The key of second sentence for the eval dataset(key:preprocessor.val.'
|
||||
'second_sequence/dataset.val.second_sequence),'
|
||||
'if not provided, the trainer will use the train_second_sequence for evaluation'
|
||||
})
|
||||
|
||||
eval_label: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The key of label for the eval dataset(key:preprocessor.val.'
|
||||
'second_sequence/dataset.val.second_sequence),'
|
||||
'if not provided, the trainer will use the train_label for evaluation'
|
||||
})
|
||||
|
||||
labels: Optional[List] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The labels list of the dataset(key:dataset.train.labels),'
|
||||
'This parameter has the same effect with "label2id"'
|
||||
})
|
||||
|
||||
max_epochs: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The max_epochs of the training loop(key: train.max_epochs)'
|
||||
})
|
||||
|
||||
train_batch_size_per_gpu: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The train batch size per gpu(key: train.dataloader.batch_size_per_gpu)'
|
||||
})
|
||||
|
||||
train_workers_per_gpu: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The number of workers per gpu(key: train.dataloader.workers_per_gpu)'
|
||||
})
|
||||
|
||||
train_shuffle: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'Shuffle the train dataset or not(key: train.dataloader.shuffle)'
|
||||
})
|
||||
|
||||
eval_batch_size_per_gpu: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The eval batch size per gpu(key: evaluation.dataloader.batch_size_per_gpu)'
|
||||
})
|
||||
|
||||
eval_workers_per_gpu: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The number of workers per gpu(key: evaluation.dataloader.workers_per_gpu)'
|
||||
})
|
||||
|
||||
eval_shuffle: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'Shuffle the eval dataset or not(key: evaluation.dataloader.shuffle)'
|
||||
})
|
||||
|
||||
optimizer_args: Optional[Dict] = field(
|
||||
default=None,
|
||||
metadata={'help': 'The optimizer config dict(key: train.optimizer)'})
|
||||
|
||||
lr_scheduler_args: Optional[Dict] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help': 'The lr_scheduler config dict(key: train.lr_scheduler)'
|
||||
})
|
||||
|
||||
checkpoint_saving_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The checkpoint saving type(key: The ckpt hook dict in train.hooks), '
|
||||
'valid options: "BestCkptSaverHook", "CheckpointHook"'
|
||||
})
|
||||
|
||||
checkpoint_by_epoch: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'Saving checkpoint by epoch or not(key: The by_epoch key in '
|
||||
'ckpt hook dict in train.hooks)'
|
||||
})
|
||||
|
||||
checkpoint_interval: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The checkpoint saving interval(key: The interval key in '
|
||||
'ckpt hook dict in train.hooks)'
|
||||
})
|
||||
|
||||
metric_key: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The metric key for the BestCkptSaverHook(key: The metric_key key in '
|
||||
'ckpt hook dict in train.hooks), if the checkpoint_saving_type is "CheckpointHook" or '
|
||||
'"None", the metric_key key has no effects'
|
||||
})
|
||||
|
||||
evaluation_type: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The evaluation type(key: The evaluation hook dict in train.hooks), '
|
||||
'valid options: "EvaluationHook", "None"'
|
||||
})
|
||||
|
||||
evaluation_by_epoch: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'Evaluating by epoch or not(key: The by_epoch key in '
|
||||
'evaluation hook dict in train.hooks)'
|
||||
})
|
||||
|
||||
evaluation_interval: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
'help':
|
||||
'The evaluating interval(key: The interval key in '
|
||||
'evaluation hook dict in train.hooks)'
|
||||
})
|
||||
|
||||
metrics: Optional[List[str]] = field(
|
||||
default=None,
|
||||
metadata={'help': 'The metrics class keys(key: evaluation.metrics)'})
|
||||
|
||||
default_train_config = ConfigDict({
|
||||
'work_dir':
|
||||
'/tmp',
|
||||
'max_epochs':
|
||||
5,
|
||||
'dataloader': {
|
||||
'batch_size_per_gpu': 32,
|
||||
'workers_per_gpu': 0
|
||||
},
|
||||
'optimizer': {
|
||||
'type': 'AdamW',
|
||||
'lr': 2e-5,
|
||||
'options': {}
|
||||
},
|
||||
'lr_scheduler': {
|
||||
'type': 'LinearLR',
|
||||
'start_factor': 1.0,
|
||||
'end_factor': 0.0,
|
||||
'total_iters': 10000,
|
||||
'options': {
|
||||
'by_epoch': False
|
||||
}
|
||||
},
|
||||
'hooks': [{
|
||||
'type': 'CheckpointHook',
|
||||
'by_epoch': False,
|
||||
'interval': 100
|
||||
}, {
|
||||
'type': 'TextLoggerHook',
|
||||
'interval': 1
|
||||
}, {
|
||||
'type': 'IterTimerHook'
|
||||
}, {
|
||||
'type': 'EvaluationHook',
|
||||
'by_epoch': False,
|
||||
'interval': 100
|
||||
}]
|
||||
})
|
||||
|
||||
def __call__(self, cfg):
|
||||
"""
|
||||
|
||||
Args:
|
||||
cfg(`Config`): The cfg to be modified.
|
||||
|
||||
Returns:
|
||||
The cfg after modification.
|
||||
"""
|
||||
|
||||
if self.task is not None:
|
||||
cfg.task = self.task
|
||||
|
||||
if self.preprocessor_type is not None:
|
||||
if not hasattr(cfg, 'preprocessor'):
|
||||
cfg.preprocessor = ConfigDict()
|
||||
cfg.preprocessor.type = self.preprocessor_type
|
||||
|
||||
if self.train_first_sequence is not None or self.train_second_sequence \
|
||||
is not None or self.train_label is not None or self.labels is not None:
|
||||
if not hasattr(cfg, 'dataset'):
|
||||
cfg.dataset = ConfigDict()
|
||||
if not hasattr(cfg.dataset, 'train'):
|
||||
cfg.dataset.train = ConfigDict()
|
||||
if self.train_first_sequence is not None:
|
||||
cfg.dataset.train.first_sequence = self.train_first_sequence
|
||||
if self.train_second_sequence is not None:
|
||||
cfg.dataset.train.second_sequence = self.train_second_sequence
|
||||
if self.train_label is not None:
|
||||
cfg.dataset.train.label = self.train_label
|
||||
if self.labels is not None:
|
||||
cfg.dataset.train.labels = self.labels
|
||||
|
||||
if self.eval_first_sequence is not None or self.eval_second_sequence \
|
||||
is not None or self.eval_label is not None:
|
||||
if not hasattr(cfg, 'dataset'):
|
||||
cfg.dataset = ConfigDict()
|
||||
if not hasattr(cfg.dataset, 'val'):
|
||||
cfg.dataset.val = ConfigDict()
|
||||
if self.eval_first_sequence is not None:
|
||||
cfg.dataset.val.first_sequence = self.eval_first_sequence
|
||||
if self.eval_second_sequence is not None:
|
||||
cfg.dataset.val.second_sequence = self.eval_second_sequence
|
||||
if self.eval_label is not None:
|
||||
cfg.dataset.val.label = self.eval_label
|
||||
|
||||
if self.max_epochs is not None or self.train_batch_size_per_gpu is not None \
|
||||
or self.train_shuffle is not None or self.optimizer_args is not None \
|
||||
or self.work_dir is not None or self.lr_scheduler_args is not None\
|
||||
or self.train_workers_per_gpu is not None:
|
||||
if not hasattr(cfg, 'train'):
|
||||
cfg.train = deepcopy(self.default_train_config)
|
||||
if not hasattr(cfg.train, 'dataloader'):
|
||||
cfg.train.dataloader = deepcopy(
|
||||
self.default_train_config.dataloader)
|
||||
if not hasattr(cfg.train, 'optimizer'):
|
||||
cfg.train.optimizer = deepcopy(
|
||||
self.default_train_config.optimizer)
|
||||
if not hasattr(cfg.train, 'lr_scheduler'):
|
||||
cfg.train.lr_scheduler = deepcopy(
|
||||
self.default_train_config.lr_scheduler)
|
||||
if self.work_dir is not None:
|
||||
cfg.train.work_dir = self.work_dir
|
||||
if self.max_epochs is not None:
|
||||
cfg.train.max_epochs = self.max_epochs
|
||||
if self.train_batch_size_per_gpu is not None:
|
||||
cfg.train.dataloader.batch_size_per_gpu = self.train_batch_size_per_gpu
|
||||
if self.train_workers_per_gpu is not None:
|
||||
cfg.train.dataloader.workers_per_gpu = self.train_workers_per_gpu
|
||||
if self.train_shuffle is not None:
|
||||
cfg.train.dataloader.shuffle = self.train_shuffle
|
||||
if self.optimizer_args is not None:
|
||||
if cfg.train.optimizer.type != self.optimizer_args.get(
|
||||
'type', cfg.train.optimizer.type):
|
||||
cfg.train.optimizer = ConfigDict(
|
||||
deepcopy(self.optimizer_args))
|
||||
else:
|
||||
cfg.train.optimizer = Config._merge_a_into_b(
|
||||
self.optimizer_args, cfg.train.optimizer, force=True)
|
||||
if self.lr_scheduler_args is not None:
|
||||
if cfg.train.lr_scheduler.type != self.lr_scheduler_args.get(
|
||||
'type', cfg.train.lr_scheduler.type):
|
||||
cfg.train.lr_scheduler = ConfigDict(
|
||||
deepcopy(self.lr_scheduler_args))
|
||||
else:
|
||||
cfg.train.lr_scheduler = Config._merge_a_into_b(
|
||||
self.lr_scheduler_args,
|
||||
cfg.train.lr_scheduler,
|
||||
force=True)
|
||||
|
||||
if self.checkpoint_saving_type is not None or self.checkpoint_by_epoch is not None \
|
||||
or self.checkpoint_interval is not None or self.metric_key is not None:
|
||||
if not any([
|
||||
self.checkpoint_saving_type == hook['type']
|
||||
for hook in cfg.train.hooks
|
||||
]):
|
||||
cfg.train.hooks = list(
|
||||
filter(
|
||||
lambda hook: hook['type'] not in
|
||||
['CheckpointHook', 'BestCkptSaverHook'],
|
||||
cfg.train.hooks))
|
||||
cfg.train.hooks.append(
|
||||
deepcopy(self.default_train_config.hooks[0]))
|
||||
cfg.train.hooks[-1].type = self.checkpoint_saving_type
|
||||
checkpoint_hook = list(
|
||||
filter(
|
||||
lambda hook: hook[
|
||||
'type'] in ['CheckpointHook', 'BestCkptSaverHook'],
|
||||
cfg.train.hooks))[0]
|
||||
if self.checkpoint_by_epoch is not None:
|
||||
checkpoint_hook['by_epoch'] = self.checkpoint_by_epoch
|
||||
if self.checkpoint_interval is not None:
|
||||
checkpoint_hook['interval'] = self.checkpoint_interval
|
||||
if checkpoint_hook['type'] == 'BestCkptSaverHook':
|
||||
assert self.metric_key is not None, 'The metric_key must be provided ' \
|
||||
'if the ckpt saving hook is "BestCkptSaverHook"'
|
||||
checkpoint_hook['metric_key'] = self.metric_key
|
||||
|
||||
if self.evaluation_type is not None or self.evaluation_by_epoch is not None \
|
||||
or self.evaluation_interval is not None or self.eval_batch_size_per_gpu is not None or \
|
||||
self.eval_shuffle is not None or self.metrics is not None:
|
||||
if self.evaluation_type is not None and not any([
|
||||
self.evaluation_type == hook['type']
|
||||
for hook in cfg.train.hooks
|
||||
]):
|
||||
cfg.train.hooks = list(
|
||||
filter(lambda hook: hook['type'] not in ['EvaluationHook'],
|
||||
cfg.train.hooks))
|
||||
if self.evaluation_type != 'None':
|
||||
cfg.train.hooks.append(
|
||||
deepcopy(self.default_train_config.hooks[3]))
|
||||
cfg.train.hooks[-1].type = self.evaluation_type
|
||||
|
||||
evaluation_hook = list(
|
||||
filter(lambda hook: hook['type'] in ['EvaluationHook'],
|
||||
cfg.train.hooks))
|
||||
evaluation_hook = evaluation_hook[0] if len(
|
||||
evaluation_hook) > 0 else None
|
||||
|
||||
if evaluation_hook is not None and self.evaluation_by_epoch is not None:
|
||||
evaluation_hook['by_epoch'] = self.evaluation_by_epoch
|
||||
if evaluation_hook is not None and self.evaluation_interval is not None:
|
||||
evaluation_hook['interval'] = self.evaluation_interval
|
||||
|
||||
if not hasattr(cfg, 'evaluation'):
|
||||
cfg.evaluation = ConfigDict({
|
||||
'dataloader': {
|
||||
'batch_size_per_gpu': 32,
|
||||
'workers_per_gpu': 0,
|
||||
'shuffle': False
|
||||
}
|
||||
})
|
||||
|
||||
if self.metrics is not None:
|
||||
cfg.evaluation.metrics = self.metrics
|
||||
if self.eval_batch_size_per_gpu is not None:
|
||||
cfg.evaluation.dataloader.batch_size_per_gpu = self.eval_batch_size_per_gpu
|
||||
if self.eval_workers_per_gpu is not None:
|
||||
cfg.evaluation.dataloader.workers_per_gpu = self.eval_workers_per_gpu
|
||||
if self.eval_shuffle is not None:
|
||||
cfg.evaluation.dataloader.shuffle = self.eval_shuffle
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
@TRAINERS.register_module(module_name=Trainers.nlp_base_trainer)
|
||||
class NlpEpochBasedTrainer(EpochBasedTrainer):
|
||||
"""Add code to adapt with nlp models.
|
||||
|
||||
@@ -27,7 +27,7 @@ from modelscope.trainers.hooks.builder import HOOKS
|
||||
from modelscope.trainers.hooks.priority import Priority, get_priority
|
||||
from modelscope.trainers.lrscheduler.builder import build_lr_scheduler
|
||||
from modelscope.trainers.optimizer.builder import build_optimizer
|
||||
from modelscope.utils.config import Config, ConfigDict
|
||||
from modelscope.utils.config import Config, ConfigDict, JSONIteratorEncoder
|
||||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, ConfigFields,
|
||||
ConfigKeys, ModeKeys, ModelFile,
|
||||
TrainerStages)
|
||||
@@ -41,7 +41,7 @@ from modelscope.utils.torch_utils import (broadcast, get_dist_info,
|
||||
is_master, set_random_seed)
|
||||
from .base import BaseTrainer
|
||||
from .builder import TRAINERS
|
||||
from .default_config import merge_cfg
|
||||
from .default_config import merge_cfg, merge_hooks
|
||||
from .hooks.hook import Hook
|
||||
from .parallel.builder import build_parallel
|
||||
from .parallel.utils import is_parallel
|
||||
@@ -129,6 +129,15 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
# add default config
|
||||
merge_cfg(self.cfg)
|
||||
self.cfg = self.rebuild_config(self.cfg)
|
||||
self.logger = get_logger(log_level=self.cfg.get('log_level', 'INFO'))
|
||||
self.logger.info(
|
||||
'==========================Training Config Start=========================='
|
||||
)
|
||||
self.logger.info(
|
||||
json.dumps(self.cfg._cfg_dict, indent=4, cls=JSONIteratorEncoder))
|
||||
self.logger.info(
|
||||
'===========================Training Config End==========================='
|
||||
)
|
||||
if 'cfg_options' in kwargs:
|
||||
self.cfg.merge_from_dict(kwargs['cfg_options'])
|
||||
|
||||
@@ -488,7 +497,8 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
self.train_dataloader = self.get_train_dataloader()
|
||||
self.data_loader = self.train_dataloader
|
||||
self.register_optimizers_hook()
|
||||
self.register_hook_from_cfg(self.cfg.train.hooks)
|
||||
hooks = merge_hooks(self.cfg)
|
||||
self.register_hook_from_cfg(hooks)
|
||||
self.set_checkpoint_file_to_hook(checkpoint_path)
|
||||
self.model.train()
|
||||
|
||||
@@ -1006,7 +1016,7 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
if not inserted:
|
||||
self._hooks.insert(0, hook)
|
||||
|
||||
def register_hook_from_cfg(self, hook_cfg: Dict) -> None:
|
||||
def register_hook_from_cfg(self, hook_cfg: List) -> None:
|
||||
"""Register a hook from its cfg.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,169 +1,625 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import dataclasses
|
||||
import re
|
||||
from argparse import Action, ArgumentDefaultsHelpFormatter, ArgumentParser
|
||||
from typing import Any, Dict, List, Union
|
||||
from dataclasses import dataclass, field, fields
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from addict import Dict as Adict
|
||||
from modelscope.trainers.default_config import DEFAULT_CONFIG
|
||||
from modelscope.utils.config import Config, ConfigDict
|
||||
from modelscope.utils.hub import read_config
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ArgAttr():
|
||||
""" Attributes for each arg
|
||||
def get_flatten_value(config: Config, metadata: Dict, exclusions=None):
|
||||
cfg_node = metadata['cfg_node']
|
||||
if exclusions is None:
|
||||
exclusions = []
|
||||
|
||||
Args:
|
||||
cfg_node_name (str or list[str]): if set empty, it means a normal arg for argparse, otherwise it means
|
||||
this arg value correspond to those nodes in configuration file, and will replace them for training.
|
||||
default: default value for current argument.
|
||||
type: type for current argument.
|
||||
choices (list of str): choices of value for this argument.
|
||||
help (str): help str for this argument.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# define argument train_batch_size which corresponds to train.dataloader.batch_size_per_gpu
|
||||
training_args = Adict(
|
||||
train_batch_size=ArgAttr(
|
||||
'train.dataloader.batch_size_per_gpu',
|
||||
default=16,
|
||||
type=int,
|
||||
help='training batch size')
|
||||
)
|
||||
|
||||
# num_classes which will modify three places in configuration
|
||||
training_args = Adict(
|
||||
num_classes = ArgAttr(
|
||||
['model.mm_model.head.num_classes',
|
||||
'model.mm_model.train_cfg.augments.0.num_classes',
|
||||
'model.mm_model.train_cfg.augments.1.num_classes'],
|
||||
type=int,
|
||||
help='number of classes')
|
||||
)
|
||||
```
|
||||
# a normal argument which has no relation with configuration
|
||||
training_args = Adict(
|
||||
local_rank = ArgAttr(
|
||||
'',
|
||||
default=1,
|
||||
type=int,
|
||||
help='local rank for current training process')
|
||||
)
|
||||
|
||||
"""
|
||||
cfg_node_name: Union[str, List[str]] = ''
|
||||
default: Any = None
|
||||
type: type = None
|
||||
choices: List[str] = None
|
||||
help: str = ''
|
||||
values = config.safe_get(cfg_node)
|
||||
if isinstance(values, dict):
|
||||
param_map = []
|
||||
for key, value in values.items():
|
||||
if key in exclusions or not isinstance(value,
|
||||
(str, int, float, bool)):
|
||||
continue
|
||||
value = add_quotes_for_str(value)
|
||||
param_map.append(f'{key}={value}')
|
||||
return ','.join(param_map)
|
||||
else:
|
||||
return values
|
||||
|
||||
|
||||
training_args = Adict(
|
||||
train_batch_size=ArgAttr(
|
||||
'train.dataloader.batch_size_per_gpu',
|
||||
default=16,
|
||||
type=int,
|
||||
help='training batch size'),
|
||||
train_data_worker=ArgAttr(
|
||||
'train.dataloader.workers_per_gpu',
|
||||
default=8,
|
||||
type=int,
|
||||
help='number of data worker used for training'),
|
||||
eval_batch_size=ArgAttr(
|
||||
'evaluation.dataloader.batch_size_per_gpu',
|
||||
default=16,
|
||||
type=int,
|
||||
help='training batch size'),
|
||||
max_epochs=ArgAttr(
|
||||
'train.max_epochs',
|
||||
default=10,
|
||||
type=int,
|
||||
help='max number of training epoch'),
|
||||
work_dir=ArgAttr(
|
||||
'train.work_dir',
|
||||
default='./work_dir',
|
||||
type=str,
|
||||
help='training directory to save models and training logs'),
|
||||
lr=ArgAttr(
|
||||
'train.optimizer.lr',
|
||||
default=0.001,
|
||||
type=float,
|
||||
help='initial learning rate'),
|
||||
optimizer=ArgAttr(
|
||||
'train.optimizer.type',
|
||||
default='SGD',
|
||||
type=str,
|
||||
choices=[
|
||||
'Adadelta', 'Adagrad', 'Adam', 'AdamW', 'Adamax', 'ASGD',
|
||||
'RMSprop', 'Rprop'
|
||||
'SGD'
|
||||
],
|
||||
help='optimizer type'),
|
||||
local_rank=ArgAttr(
|
||||
'', default=0, type=int, help='local rank for this process'))
|
||||
def set_flatten_value(config: Config, values: Union[str, List[str]],
                      metadata: Dict):
    """Write flattened ``key=value`` pairs under a config node.

    Args:
        config: The config to update; it is modified through
            `merge_from_dict`.
        values: Either one comma-separated string such as
            ``'weight_decay=0.8,eps=1e-6'`` or a list of ``key=value``
            strings. `None` leaves the config untouched.
        metadata: The field metadata; `metadata['cfg_node']` is the parent
            node the keys are written under.

    Returns:
        The same `config` object.

    Raises:
        ValueError: If a non-empty item contains no `=`.
    """
    cfg_node = metadata['cfg_node']
    if values is None:
        return config

    pairs = values.split(',') if isinstance(values, str) else values
    for kv in pairs:
        kv = kv.strip()
        if not kv:
            continue
        # Split on the first `=` only, so a value may itself contain `=`.
        key, sep, value = kv.partition('=')
        if not sep:
            raise ValueError(
                f'Expected an item in `key=value` form, but got: {kv!r}')
        # Strip both sides so that 'a=1, b=2' does not produce a ' b' key.
        value = parse_value(value.strip())
        config.merge_from_dict({cfg_node + '.' + key.strip(): value})
    return config
|
||||
|
||||
|
||||
def get_base_hook_args(config: Config, metadata: Dict):
    """Read a hook-related argument from the config.

    The value stored at `metadata['cfg_node']` wins when it is present;
    otherwise the entry of type `metadata['hook_type']` in `train.hooks`
    is consulted through `get_hook_param`.

    Args:
        config: The config to read from.
        metadata: The field metadata, providing `cfg_node`, `hook_type`
            and `key`.

    Returns:
        When the config node exists: `True` if `key` is `'type'`, else the
        node value. Otherwise whatever `get_hook_param` returns.
    """
    cfg_node, hook_type, key = (
        metadata[name] for name in ('cfg_node', 'hook_type', 'key'))
    node_value = config.safe_get(cfg_node)
    if node_value is not None:
        # For the `type` key only the existence of the node matters.
        return True if key == 'type' else node_value
    return get_hook_param(config, hook_type, key)
|
||||
|
||||
|
||||
def set_base_hook_args(config: Config, value: Any, metadata: Dict):
    """Write a hook-related argument into the config.

    Any entry of type `metadata['hook_type']` is first dropped from
    `train.hooks`, then the value is stored at `metadata['cfg_node']`.

    Args:
        config: The config to update in place.
        value: The value to write. When `metadata['key']` is `'type'` it is
            treated as an on/off flag.
        metadata: The field metadata, providing `cfg_node`, `hook_type`
            and `key`.

    Returns:
        The same `config` object, matching `set_flatten_value` so callers
        can treat every setter uniformly.
    """
    cfg_node = metadata['cfg_node']
    hook_type = metadata['hook_type']
    key = metadata['key']
    if 'hooks' in config.train:
        config.train.hooks = [
            hook for hook in config.train.hooks if hook['type'] != hook_type
        ]
    if key == 'type':
        # A truthy flag only guarantees the node exists; an already
        # configured node is left as it is.
        if value and config.safe_get(cfg_node) is None:
            config.merge_from_dict({cfg_node: {}})
    else:
        config.merge_from_dict({cfg_node: value})
    return config
|
||||
|
||||
|
||||
def get_strategy(config: Config,
                 metadata: Dict,
                 value_pair: Tuple[str] = ('by_epoch', 'by_step')):
    """Translate a boolean hook flag from the config into a strategy name.

    Args:
        config: The config to read from.
        metadata: The field metadata forwarded to `get_base_hook_args`.
        value_pair: The names returned for a truthy and a falsy flag,
            in that order.

    Returns:
        `None` when the flag is not configured, otherwise `value_pair[0]`
        for a truthy flag and `value_pair[1]` for a falsy one.
    """
    by_first = get_base_hook_args(config, metadata)
    if by_first is not None:
        return value_pair[0 if by_first else 1]
    return None
|
||||
|
||||
|
||||
def set_strategy(config: Config,
                 value: Any,
                 metadata: Dict,
                 value_pair: Tuple[str] = ('by_epoch', 'by_step')):
    """Store a strategy name in the config as a boolean hook flag.

    Args:
        config: The config to update in place.
        value: The strategy name; the flag written is `True` only when it
            equals `value_pair[0]`.
        metadata: The field metadata forwarded to `set_base_hook_args`.
        value_pair: The strategy names standing for `True` and `False`,
            in that order.

    Returns:
        The same `config` object, matching `set_flatten_value`.
    """
    set_base_hook_args(config, value == value_pair[0], metadata)
    return config
|
||||
|
||||
|
||||
def get_hook_param(config, hook_type: str, key='type'):
    """Look up a hook of the given type in the `train.hooks` list.

    Args:
        config: The config providing `safe_get`.
        hook_type: The value of the `type` entry to match.
        key: `'type'` to ask whether such a hook exists, otherwise the
            attribute to read from the first matching hook.

    Returns:
        A bool when `key` is `'type'`; otherwise the attribute of the first
        matching hook, or `None` when no hook matches or the attribute is
        missing.
    """
    matched = [
        hook for hook in config.safe_get('train.hooks', [])
        if hook['type'] == hook_type
    ]
    if key == 'type':
        return bool(matched)
    if matched:
        return getattr(matched[0], key, None)
    return None
|
||||
|
||||
|
||||
def add_quotes_for_str(value: Union[str, float, bool, None]) -> str:
    """Render a value as text, wrapping strings in double quotes.

    Args:
        value: The value to render.

    Returns:
        `"value"` (with the double quotes) for a `str`, otherwise
        `str(value)`.
    """
    return f'"{value}"' if isinstance(value, str) else str(value)
|
||||
|
||||
|
||||
def parse_value(value: str) -> Union[str, int, float, bool, None]:
    """Convert a command-line string into the Python value it spells.

    Rules are tried in this order:
        1. Boolean and null constants (`True`/`true`, `False`/`false`,
           `None`/`none`/`null`).
        2. Anything containing a quote character is a string; every single
           and double quote is removed.
        3. An optionally signed run of digits is an `int`.
        4. A decimal or exponent literal is a `float`.
        5. Everything else is returned unchanged.

    Args:
        value: The raw string.

    Returns:
        The parsed value.
    """
    const_map = {
        'True': True,
        'true': True,
        'False': False,
        'false': False,
        'None': None,
        'none': None,
        'null': None
    }
    if value in const_map:
        return const_map[value]
    elif '"' in value or "'" in value:
        return value.replace('"', '').replace("'", '')
    elif re.fullmatch(r'[+-]?\d+', value):
        # The sign is accepted so that '-5' is an int rather than a string.
        return int(value)
    elif re.fullmatch(
            r'[+-]?(?=\d*[.eE])(?=\.?\d)\d*\.?\d*(?:[eE][+-]?\d+)?', value):
        # fullmatch, not match: a numeric prefix such as '1.5abc' must not
        # reach float(), which would raise ValueError.
        return float(value)
    else:
        return value
|
||||
|
||||
|
||||
@dataclass
class TrainingArgs:
    """Training arguments that map onto nodes of a `Config`.

    Each field's `metadata` may carry:
        help: The help text shown by the command-line parser.
        cfg_node: The config key chain (or a list of them) the field is
            read from and written to. Fields without it are plain arguments.
        cfg_getter / cfg_setter: Callables used instead of a plain
            `safe_get` / `merge_from_dict` on `cfg_node`.
        hook_type / key: Extra information used by the hook getters and
            setters.
        choices: The values accepted by the command-line parser.

    Calling an instance with a config writes every field that is not `None`
    into the config and fills the remaining fields from the config.
    """

    model: str = field(
        default=None, metadata={
            'help': 'A model id or model dir',
        })

    seed: int = field(
        default=42, metadata={
            'help': 'The random seed',
        })

    task: str = field(
        default=None,
        metadata={
            'help': 'The task code to be used',
            'cfg_node': 'task'
        })

    dataset_name: str = field(
        default=None, metadata={
            'help': 'The dataset name',
        })

    subset_name: str = field(
        default=None, metadata={
            'help': 'The subset name of the dataset',
        })

    train_dataset_name: str = field(
        default=None, metadata={
            'help': 'The train dataset name',
        })

    val_dataset_name: str = field(
        default=None, metadata={
            'help': 'The validation dataset name',
        })

    per_device_train_batch_size: int = field(
        default=None,
        metadata={
            'cfg_node': 'train.dataloader.batch_size_per_gpu',
            'help': 'The training batch size per GPU',
        })

    train_data_worker: int = field(
        default=None,
        metadata={
            'cfg_node': 'train.dataloader.workers_per_gpu',
            'help': 'The number of data workers for train dataloader',
        })

    train_shuffle: bool = field(
        default=None,
        metadata={
            'cfg_node': 'train.dataloader.shuffle',
            'help': 'Shuffle the train dataset or not',
        })

    per_device_eval_batch_size: int = field(
        default=None,
        metadata={
            'cfg_node': 'evaluation.dataloader.batch_size_per_gpu',
            'help': 'The eval batch size per GPU',
        })

    eval_data_worker: int = field(
        default=None,
        metadata={
            'cfg_node': 'evaluation.dataloader.workers_per_gpu',
            'help': 'The number of data workers for eval dataloader',
        })

    eval_shuffle: bool = field(
        default=None,
        metadata={
            'cfg_node': 'evaluation.dataloader.shuffle',
            'help': 'Shuffle the eval dataset or not',
        })

    max_epochs: int = field(
        default=None,
        metadata={
            'cfg_node': 'train.max_epochs',
            'help': 'The training epochs',
        })

    work_dir: str = field(
        default=None,
        metadata={
            'cfg_node': 'train.work_dir',
            'help': 'The training dir to save models and logs',
        })

    lr: float = field(
        default=None,
        metadata={
            'cfg_node': 'train.optimizer.lr',
            'help': 'The learning rate of the optimizer',
        })

    optimizer: str = field(
        default=None,
        metadata={
            'cfg_node': 'train.optimizer.type',
            'help': 'The optimizer type',
        })

    optimizer_params: str = field(
        default=None,
        metadata={
            'cfg_node':
            'train.optimizer',
            'cfg_getter':
            partial(get_flatten_value, exclusions=['type', 'lr', 'options']),
            'cfg_setter':
            set_flatten_value,
            'help':
            'The optimizer init params except `lr`',
        })

    lr_scheduler_params: str = field(
        default=None,
        metadata={
            'cfg_node':
            'train.lr_scheduler',
            'cfg_getter':
            partial(get_flatten_value, exclusions=['type', 'lr', 'options']),
            'cfg_setter':
            set_flatten_value,
            'help':
            'The lr_scheduler init params',
        })

    local_rank: int = field(
        default=0, metadata={
            'help': 'The training local rank',
        })

    save_ckpt: bool = field(
        default=True,
        metadata={
            'help':
            'Periodically save checkpoint when True, corresponding to CheckpointHook',
            'cfg_node': 'train.checkpoint.period',
            'hook_type': 'CheckpointHook',
            'key': 'type',
            'cfg_getter': get_base_hook_args,
            'cfg_setter': set_base_hook_args,
        })

    save_ckpt_best: bool = field(
        default=None,
        metadata={
            'help':
            'Save best checkpoint when True, corresponding to BestCkptSaverHook',
            'cfg_node': 'train.checkpoint.best',
            'hook_type': 'BestCkptSaverHook',
            'key': 'type',
            'cfg_getter': get_base_hook_args,
            'cfg_setter': set_base_hook_args,
        })

    evaluate: bool = field(
        default=True,
        metadata={
            'help': 'Evaluate when True, corresponding to EvaluationHook',
            'cfg_node': 'evaluation.period',
            'hook_type': 'EvaluationHook',
            'key': 'type',
            'cfg_getter': get_base_hook_args,
            'cfg_setter': set_base_hook_args,
        })

    save_ckpt_strategy: str = field(
        default=None,
        metadata={
            'help': 'Periodically save checkpoint by epoch or by step, '
            'use with `CheckpointHook`, can be `by_epoch` or `by_step`',
            'cfg_node': 'train.checkpoint.period.by_epoch',
            'hook_type': 'CheckpointHook',
            'key': 'by_epoch',
            'choices': ['by_epoch', 'by_step'],
            'cfg_getter': get_strategy,
            'cfg_setter': set_strategy,
        })

    save_ckpt_best_strategy: str = field(
        default=None,
        metadata={
            'help': 'Save best checkpoint by epoch or by step, '
            'use with `BestCkptSaverHook`, can be `by_epoch` or `by_step`',
            'cfg_node': 'train.checkpoint.best.by_epoch',
            'hook_type': 'BestCkptSaverHook',
            'key': 'by_epoch',
            'choices': ['by_epoch', 'by_step'],
            'cfg_getter': get_strategy,
            'cfg_setter': set_strategy,
        })

    ckpt_period_interval: int = field(
        default=1,
        metadata={
            'help':
            'The interval of epoch or iter of saving checkpoint period',
            'cfg_node': 'train.checkpoint.period.interval',
            'hook_type': 'CheckpointHook',
            'key': 'interval',
            'cfg_getter': get_base_hook_args,
            'cfg_setter': set_base_hook_args,
        })

    ckpt_best_interval: int = field(
        default=None,
        metadata={
            'help': 'The interval of epoch or iter of saving checkpoint best',
            'cfg_node': 'train.checkpoint.best.interval',
            'hook_type': 'BestCkptSaverHook',
            'key': 'interval',
            'cfg_getter': get_base_hook_args,
            'cfg_setter': set_base_hook_args,
        })

    metric_for_best_model: str = field(
        default=None,
        metadata={
            'help':
            'Which metric key to judge the checkpoint is better or not, use with `BestCkptSaverHook`, '
            'please make sure this key is returned by the `evaluation_metrics` classes',
            'cfg_node':
            'train.checkpoint.best.metric_key',
            'hook_type':
            'BestCkptSaverHook',
            'key':
            'metric_key',
            'cfg_getter':
            get_base_hook_args,
            'cfg_setter':
            set_base_hook_args,
        })

    metric_rule_for_best_model: str = field(
        default=None,
        metadata={
            'help':
            'Which rule to compare the value of `metric_for_best_model`, '
            'use with `BestCkptSaverHook`, can be `max` or `min`',
            'cfg_node':
            'train.checkpoint.best.rule',
            'hook_type':
            'BestCkptSaverHook',
            'key':
            'rule',
            'cfg_getter':
            get_base_hook_args,
            'cfg_setter':
            set_base_hook_args,
        })

    # NOTE: `peroid` is a misspelling of `period`; the name is kept because
    # it is also the public `--save_ckpt_peroid_limit` command-line option.
    save_ckpt_peroid_limit: int = field(
        default=None,
        metadata={
            'help':
            'The max saving number of checkpoint, older checkpoints will be deleted.',
            'cfg_node': 'train.checkpoint.period.max_checkpoint_num',
            'hook_type': 'CheckpointHook',
            'key': 'max_checkpoint_num',
            'cfg_getter': get_base_hook_args,
            'cfg_setter': set_base_hook_args,
        })

    save_ckpt_best_limit: int = field(
        default=None,
        metadata={
            'help':
            'The max saving number of checkpoint, worse checkpoints will be deleted.',
            'cfg_node': 'train.checkpoint.best.max_checkpoint_num',
            'hook_type': 'BestCkptSaverHook',
            'key': 'max_checkpoint_num',
            'cfg_getter': get_base_hook_args,
            'cfg_setter': set_base_hook_args,
        })

    logging_interval: int = field(
        default=None,
        metadata={
            'help': 'The interval of iter of logging information',
            'cfg_node': 'train.logging.interval',
            'hook_type': 'TextLoggerHook',
            'key': 'interval',
            'cfg_getter': get_base_hook_args,
            'cfg_setter': set_base_hook_args,
        })

    eval_strategy: str = field(
        default=None,
        metadata={
            'help': 'Evaluate model by epoch or by step, '
            'use with `EvaluationHook`, can be `by_epoch` or `by_step`',
            'cfg_node': 'evaluation.period.by_epoch',
            'hook_type': 'EvaluationHook',
            'key': 'by_epoch',
            'choices': ['by_epoch', 'by_step'],
            'cfg_getter': get_strategy,
            'cfg_setter': set_strategy,
        })

    eval_interval: int = field(
        default=1,
        metadata={
            'help': 'Evaluation interval by epoch or iter',
            'cfg_node': 'evaluation.period.interval',
            'hook_type': 'EvaluationHook',
            'key': 'interval',
            'cfg_getter': get_base_hook_args,
            'cfg_setter': set_base_hook_args,
        })

    eval_metrics: str = field(
        default=None,
        metadata={
            'help': 'The metric module name used in evaluation',
            'cfg_node': 'evaluation.metrics'
        })

    @classmethod
    def from_cli(cls, parser_args=None, **extra_kwargs):
        """Construct a TrainingArg class by the parameters of CLI.

        Args:
            parser_args: The argument strings handed to `parse_known_args`;
                `None` is passed through unchanged.
            **extra_kwargs: Extra args which can be defined in code.

        Returns:
            The output TrainingArg class with the parameters from CLI.
        """
        self = cls(**extra_kwargs)
        parser = CliArgumentParser(self)
        args, unknown = parser.parse_known_args(parser_args)
        unknown = [item for item in unknown if item not in ('\\', '\n')]

        # Unrecognized options are merged into the model config as
        # `key_chain: value`. Both `--key value` and `--key=value` are
        # understood; a trailing option without a value is ignored.
        _unknown = {}
        tokens = iter(unknown)
        for token in tokens:
            # Only the leading dashes belong to the option syntax.
            name = token.lstrip('-')
            if '=' in name:
                name, raw_value = name.split('=', 1)
            else:
                raw_value = next(tokens, None)
                if raw_value is None:
                    break
            _unknown[name] = parse_value(raw_value)
        cfg_dict = vars(args)

        if args.model is not None:
            try:
                cfg = read_config(args.model)
            except Exception as e:
                # Best effort: fall back to the code-defined defaults.
                print('Read config failed with error:', e)
            else:
                cfg.merge_from_dict(_unknown)
                self = cls.from_config(cfg, **extra_kwargs)
        # Only options typed on the command line override the values above.
        for key, value in cfg_dict.items():
            if key is not None and hasattr(self,
                                           key) and key in parser.manual_args:
                setattr(self, key, value)
        return self

    def to_args(self):
        """Convert the TrainingArg class to key-value pairs.

        Returns: The key-value pair.

        """
        return {f.name: getattr(self, f.name) for f in fields(self)}

    @classmethod
    def from_config(cls, config=DEFAULT_CONFIG, **kwargs):
        """Construct the TrainingArg class by a `Config` class.

        Args:
            config: The Config class. By default, `DEFAULT_CONFIG` is used.
            **kwargs: Extra args which can be defined in code.

        Returns: The output TrainingArg class with the parameters from the config.

        """

        self = cls(**kwargs)
        for f in fields(self):
            # Values given in code (or non-None defaults) are kept.
            if 'cfg_node' in f.metadata and getattr(self, f.name) is None:
                self._to_field(f, config)
        return self

    def _to_field(self, f, config):
        """Fill the field `f` of this instance from `config`."""
        assert 'cfg_node' in f.metadata
        if 'cfg_getter' in f.metadata:
            cfg_getter = f.metadata['cfg_getter']
            setattr(self, f.name, cfg_getter(config, f.metadata))
        else:
            cfg_node = f.metadata['cfg_node']
            if not isinstance(cfg_node, str):
                # Several nodes share this value (`_to_config` writes all of
                # them); read the first one, as `print_help` does.
                cfg_node = cfg_node[0]
            setattr(self, f.name, config.safe_get(cfg_node))

    def _to_config(self, f, config: Config):
        """Write the field `f` of this instance into `config`.

        Returns:
            The updated config.
        """
        assert 'cfg_node' in f.metadata
        value = getattr(self, f.name)
        if 'cfg_setter' in f.metadata:
            cfg_setter = f.metadata['cfg_setter']
            updated = cfg_setter(config, value, f.metadata)
            # Setters that only modify `config` in place may return None.
            if updated is not None:
                config = updated
        else:
            cfg_node = f.metadata['cfg_node']
            if isinstance(cfg_node, str):
                cfg_node = [cfg_node]
            for _node in cfg_node:
                config.merge_from_dict({_node: value})
        return config

    def __call__(self, cfg: Config):
        """Synchronize this instance with `cfg`.

        Fields holding a value are written into `cfg`; fields that are
        `None` are filled from `cfg`.

        Args:
            cfg: The config to update.

        Returns:
            The same `cfg` object.
        """
        for f in fields(self):
            if 'cfg_node' not in f.metadata:
                continue

            value = getattr(self, f.name)
            if value is not None:
                self._to_config(f, cfg)
            else:
                self._to_field(f, cfg)
        return cfg
|
||||
|
||||
|
||||
class CliArgumentParser(ArgumentParser):
|
||||
""" Argument Parser to define and parse command-line args for training.
|
||||
|
||||
Args:
|
||||
arg_dict (dict of `ArgAttr` or list of them): dict or list of dict which defines different
|
||||
parameters for training.
|
||||
training_args (TrainingArgs): dict or list of dict which defines different
|
||||
paramters for training.
|
||||
"""
|
||||
|
||||
def __init__(self, arg_dict: Union[Dict[str, ArgAttr],
|
||||
List[Dict[str, ArgAttr]]], **kwargs):
|
||||
def __init__(self, training_args: TrainingArgs = None, **kwargs):
|
||||
if 'formatter_class' not in kwargs:
|
||||
kwargs['formatter_class'] = ArgumentDefaultsHelpFormatter
|
||||
super().__init__(**kwargs)
|
||||
self.arg_dict = arg_dict if isinstance(
|
||||
arg_dict, Dict) else self._join_args(arg_dict)
|
||||
self.training_args = training_args
|
||||
self.define_args()
|
||||
|
||||
def _join_args(self, arg_dict_list: List[Dict[str, ArgAttr]]):
|
||||
total_args = arg_dict_list[0].copy()
|
||||
for args in arg_dict_list[1:]:
|
||||
total_args.update(args)
|
||||
return total_args
|
||||
def get_manual_args(self, args):
|
||||
return [arg[2:] for arg in args if arg.startswith('--')]
|
||||
|
||||
def _parse_known_args(self, args: List = None, namespace=None):
|
||||
self.model_id = namespace.model if namespace is not None else None
|
||||
if '--model' in args:
|
||||
self.model_id = args[args.index('--model') + 1]
|
||||
self.manual_args = self.get_manual_args(args)
|
||||
return super()._parse_known_args(args, namespace)
|
||||
|
||||
def print_help(self, file=None):
|
||||
config = DEFAULT_CONFIG
|
||||
if self.model_id is not None:
|
||||
try:
|
||||
config = read_config(self.model_id)
|
||||
except Exception as e:
|
||||
print('Read config failed with error:', e)
|
||||
|
||||
if config is not None:
|
||||
for action_group in self._optionals._group_actions:
|
||||
if hasattr(self.training_args, action_group.dest):
|
||||
value = getattr(self.training_args, action_group.dest)
|
||||
f = {f.name: f
|
||||
for f in fields(self.training_args)
|
||||
}.get(action_group.dest)
|
||||
if value is not None:
|
||||
action_group.default = value
|
||||
elif 'cfg_node' in f.metadata:
|
||||
cfg_node = f.metadata['cfg_node']
|
||||
if isinstance(cfg_node, str):
|
||||
cfg_node = [cfg_node]
|
||||
|
||||
assert isinstance(cfg_node, (list, tuple))
|
||||
if isinstance(cfg_node[0], str):
|
||||
action_group.default = config.safe_get(cfg_node[0])
|
||||
else:
|
||||
action_group.default = cfg_node[0](config)
|
||||
return super().print_help(file)
|
||||
|
||||
def define_args(self):
|
||||
for arg_name, arg_attr in self.arg_dict.items():
|
||||
name = f'--{arg_name}'
|
||||
kwargs = dict(type=arg_attr.type, help=arg_attr.help)
|
||||
if arg_attr.default is not None:
|
||||
kwargs['default'] = arg_attr.default
|
||||
else:
|
||||
kwargs['required'] = True
|
||||
if self.training_args is not None:
|
||||
for f in fields(self.training_args):
|
||||
arg_name = f.name
|
||||
arg_attr = getattr(self.training_args, f.name)
|
||||
name = f'--{arg_name}'
|
||||
kwargs = dict(type=f.type, help=f.metadata['help'])
|
||||
kwargs['default'] = arg_attr
|
||||
|
||||
if arg_attr.choices is not None:
|
||||
kwargs['choices'] = arg_attr.choices
|
||||
if 'choices' in f.metadata:
|
||||
kwargs['choices'] = f.metadata['choices']
|
||||
|
||||
kwargs['action'] = SingleAction
|
||||
self.add_argument(name, **kwargs)
|
||||
|
||||
def get_cfg_dict(self, args=None):
|
||||
"""
|
||||
Args:
|
||||
args (default None):
|
||||
List of strings to parse. The default is taken from sys.argv. (same as argparse.ArgumentParser)
|
||||
|
||||
Returns:
|
||||
cfg_dict (dict of config): each key is a config node name such as 'train.max_epochs', this cfg_dict
|
||||
should be used with function `cfg.merge_from_dict` to update config object.
|
||||
"""
|
||||
self.args, remainning = self.parse_known_args(args)
|
||||
args_dict = vars(self.args)
|
||||
cfg_dict = {}
|
||||
for k, v in args_dict.items():
|
||||
if k not in self.arg_dict or self.arg_dict[k].cfg_node_name == '':
|
||||
continue
|
||||
cfg_node = self.arg_dict[k].cfg_node_name
|
||||
if isinstance(cfg_node, list):
|
||||
for node in cfg_node:
|
||||
cfg_dict[node] = v
|
||||
else:
|
||||
cfg_dict[cfg_node] = v
|
||||
|
||||
return cfg_dict
|
||||
kwargs['action'] = SingleAction
|
||||
self.add_argument(name, **kwargs)
|
||||
|
||||
|
||||
class DictAction(Action):
|
||||
@@ -215,8 +671,8 @@ class DictAction(Action):
|
||||
inside these brackets are ignored.
|
||||
"""
|
||||
assert (string.count('(') == string.count(')')) and (
|
||||
string.count('[') == string.count(']')), \
|
||||
f'Imbalanced brackets exist in {string}'
|
||||
string.count('[')
|
||||
== string.count(']')), f'Imbalanced brackets exist in {string}'
|
||||
end = len(string)
|
||||
for idx, char in enumerate(string):
|
||||
pre = string[:idx]
|
||||
|
||||
@@ -337,7 +337,7 @@ class Config:
|
||||
super(Config, self).__setattr__('_filename', _filename)
|
||||
super(Config, self).__setattr__('_text', _text)
|
||||
|
||||
def safe_get(self, key_chain: str, default=None):
|
||||
def safe_get(self, key_chain: str, default=None, type_field='type'):
|
||||
"""Get a value with a key-chain in str format, if key does not exist, the default value will be returned.
|
||||
|
||||
This method is safe to call, and will not edit any value.
|
||||
@@ -345,7 +345,9 @@ class Config:
|
||||
Args:
|
||||
key_chain: The input key chain, for example: 'train.hooks[0].type'
|
||||
default: The default value returned when any key does not exist, default None.
|
||||
|
||||
type_field: Get an object from a list or tuple for example by 'train.hooks.CheckPointHook', in which
|
||||
'hooks' is a list, and 'CheckPointHook' is a value of the content of key `type_field`.
|
||||
If there are multiple matched objects, the first element will be returned.
|
||||
Returns:
|
||||
The value, or the default value.
|
||||
"""
|
||||
@@ -357,7 +359,16 @@ class Config:
|
||||
if '[' in key:
|
||||
key, val = key.split('[')
|
||||
val, _ = val.split(']')
|
||||
_cfg_dict = getattr(_cfg_dict, key)
|
||||
|
||||
if isinstance(_cfg_dict, (list, tuple)):
|
||||
assert type_field is not None, 'Getting object without an index from a list or tuple ' \
|
||||
'needs an valid `type_field` param.'
|
||||
_sub_cfg_dict = list(
|
||||
filter(lambda sub: getattr(sub, type_field) == key,
|
||||
_cfg_dict))
|
||||
_cfg_dict = _sub_cfg_dict[0]
|
||||
else:
|
||||
_cfg_dict = getattr(_cfg_dict, key)
|
||||
if val is not None:
|
||||
_cfg_dict = _cfg_dict[int(val)]
|
||||
return _cfg_dict
|
||||
|
||||
@@ -8,12 +8,13 @@ from modelscope.metainfo import Preprocessors, Trainers
|
||||
from modelscope.models import Model
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.trainers import NlpTrainerArguments, build_trainer
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.trainers.hooks import Hook
|
||||
from modelscope.trainers.nlp_trainer import (EpochBasedTrainer,
|
||||
NlpEpochBasedTrainer)
|
||||
from modelscope.trainers.optimizer.child_tuning_adamw_optimizer import \
|
||||
calculate_fisher
|
||||
from modelscope.trainers.training_args import TrainingArgs
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.data_utils import to_device
|
||||
from modelscope.utils.regress_test_utils import (MsRegressTool,
|
||||
@@ -43,7 +44,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
|
||||
dataset = MsDataset.load('clue', subset_name='tnews')
|
||||
train_dataset = dataset['train']
|
||||
validation_dataset = dataset['validation']
|
||||
cfg_modify_fn = NlpTrainerArguments(
|
||||
cfg_modify_fn = TrainingArgs(
|
||||
task=Tasks.text_classification,
|
||||
preprocessor_type=Preprocessors.sen_cls_tokenizer,
|
||||
train_first_sequence='sentence',
|
||||
|
||||
@@ -1,17 +1,8 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modelscope.trainers.training_args import (ArgAttr, CliArgumentParser,
|
||||
training_args)
|
||||
from modelscope.trainers.default_config import DEFAULT_CONFIG
|
||||
from modelscope.trainers.training_args import CliArgumentParser, TrainingArgs
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
@@ -25,54 +16,32 @@ class TrainingArgsTest(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_define_args(self):
|
||||
myparser = CliArgumentParser(training_args)
|
||||
myparser = CliArgumentParser(TrainingArgs())
|
||||
input_args = [
|
||||
'--max_epochs', '100', '--work_dir', 'ddddd', '--train_batch_size',
|
||||
'8', '--unkown', 'unkown'
|
||||
'--max_epochs', '100', '--work_dir', 'ddddd',
|
||||
'--per_device_train_batch_size', '8', '--unkown', 'unkown'
|
||||
]
|
||||
args, remainning = myparser.parse_known_args(input_args)
|
||||
myparser.print_help()
|
||||
self.assertTrue(args.max_epochs == 100)
|
||||
self.assertTrue(args.work_dir == 'ddddd')
|
||||
self.assertTrue(args.train_batch_size == 8)
|
||||
self.assertTrue(args.per_device_train_batch_size == 8)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_new_args(self):
|
||||
training_args.num_classes = ArgAttr(
|
||||
'model.mm_model.head.num_classes',
|
||||
type=int,
|
||||
help='number of classes')
|
||||
training_args.mean = ArgAttr(
|
||||
'train.data.mean', help='3-dim mean vector')
|
||||
training_args.flip = ArgAttr('train.data.flip', help='flip or not')
|
||||
training_args.img_size = ArgAttr(
|
||||
'train.data.img_size', help='image size')
|
||||
myparser = CliArgumentParser(training_args)
|
||||
def test_flatten_args(self):
|
||||
cfg = DEFAULT_CONFIG
|
||||
input_args = [
|
||||
'--max_epochs', '100', '--work_dir', 'ddddd', '--train_batch_size',
|
||||
'8', '--num_classes', '10', '--mean', '[125.0,125.0,125.0]',
|
||||
'--flip', 'false', '--img_size', '(640,640)'
|
||||
'--optimizer_params',
|
||||
'weight_decay=0.8,eps=1e-6,correct_bias=False',
|
||||
'--lr_scheduler_params', 'initial_lr=3e-5,niter_decay=1'
|
||||
]
|
||||
args, remainning = myparser.parse_known_args(input_args)
|
||||
myparser.print_help()
|
||||
self.assertTrue(args.max_epochs == 100)
|
||||
self.assertTrue(args.work_dir == 'ddddd')
|
||||
self.assertTrue(args.train_batch_size == 8)
|
||||
self.assertTrue(args.num_classes == 10)
|
||||
self.assertTrue(len(args.mean) == 3)
|
||||
self.assertTrue(not args.flip)
|
||||
self.assertAlmostEqual(args.mean[0], 125.0)
|
||||
self.assertAlmostEqual(args.img_size, (640, 640))
|
||||
|
||||
cfg_dict = myparser.get_cfg_dict(args=input_args)
|
||||
self.assertTrue(cfg_dict['model.mm_model.head.num_classes'] == 10)
|
||||
self.assertAlmostEqual(cfg_dict['train.data.mean'],
|
||||
[125.0, 125.0, 125.0])
|
||||
self.assertTrue(not cfg_dict['train.data.flip'])
|
||||
self.assertEqual(cfg_dict['train.dataloader.batch_size_per_gpu'], 8)
|
||||
self.assertEqual(cfg_dict['train.work_dir'], 'ddddd')
|
||||
self.assertEqual(cfg_dict['train.max_epochs'], 100)
|
||||
self.assertEqual(cfg_dict['train.data.img_size'], (640, 640))
|
||||
training_args = TrainingArgs.from_cli(input_args)
|
||||
cfg = training_args(cfg)
|
||||
self.assertAlmostEqual(cfg.train.optimizer.weight_decay, 0.8)
|
||||
self.assertAlmostEqual(cfg.train.optimizer.eps, 1e-6)
|
||||
self.assertFalse(cfg.train.optimizer.correct_bias)
|
||||
self.assertAlmostEqual(cfg.train.lr_scheduler.initial_lr, 3e-5)
|
||||
self.assertEqual(cfg.train.lr_scheduler.niter_decay, 1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user