modelscope/examples/pytorch/text_classification/finetune_text_classification.py
Commit 82482b3e96 by hemu.zp: update training args
Based on the feat/0131/nlp_args branch; original code review: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11408570

Adds support for running finetuning from the command line with training args, compatible with the configuration optimization.
2023-02-10 05:32:21 +00:00
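
A hypothetical invocation (a sketch only: it assumes TrainingArgs.from_cli maps each dataclass field below to a same-named flag; the model id, dataset, and label values are placeholders):

    PYTHONPATH=. python modelscope/examples/pytorch/text_classification/finetune_text_classification.py \
        --model '<model_id>' \
        --dataset_name '<dataset>' \
        --subset_name '<subset>' \
        --first_sequence 'sentence' \
        --label 'label' \
        --labels '0,1' \
        --max_epochs 3 \
        --per_device_train_batch_size 32 \
        --seed 42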


import os
from dataclasses import dataclass, field

from modelscope.msdatasets import MsDataset
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.trainers.training_args import TrainingArgs


def get_labels(cfg, metadata):
    # Expose the config's label2id mapping to the CLI as a
    # comma-separated string of label names.
    label2id = cfg.safe_get(metadata['cfg_node'])
    if label2id is not None:
        return ','.join(label2id.keys())


def set_labels(cfg, labels, metadata):
    # Accept a comma-separated string (or a list) of label names and write
    # the derived {label: id} mapping back into the config.
    if isinstance(labels, str):
        labels = labels.split(',')
    cfg.merge_from_dict(
        {metadata['cfg_node']: {label: id
                                for id, label in enumerate(labels)}})
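
# A minimal round-trip sketch of the two hooks above (an assumption: it uses
# modelscope.utils.config.Config, which provides safe_get/merge_from_dict;
# the label names are placeholders):
#
#   from modelscope.utils.config import Config
#   cfg = Config({'preprocessor': {}})
#   meta = {'cfg_node': 'preprocessor.label2id'}
#   set_labels(cfg, 'negative,positive', meta)
#   cfg.safe_get('preprocessor.label2id')  # -> {'negative': 0, 'positive': 1}
#   get_labels(cfg, meta)                  # -> 'negative,positive'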


@dataclass
class TextClassificationArguments(TrainingArgs):

    first_sequence: str = field(
        default=None,
        metadata={
            'help': 'The first sequence key of preprocessor',
            'cfg_node': 'preprocessor.first_sequence'
        })

    second_sequence: str = field(
        default=None,
        metadata={
            'help': 'The second sequence key of preprocessor',
            'cfg_node': 'preprocessor.second_sequence'
        })

    label: str = field(
        default=None,
        metadata={
            'help': 'The label key of preprocessor',
            'cfg_node': 'preprocessor.label'
        })

    labels: str = field(
        default=None,
        metadata={
            'help': 'The labels of the dataset',
            'cfg_node': 'preprocessor.label2id',
            'cfg_getter': get_labels,
            'cfg_setter': set_labels,
        })

    preprocessor: str = field(
        default=None,
        metadata={
            'help': 'The preprocessor type',
            'cfg_node': 'preprocessor.type'
        })

    def __call__(self, config):
        config = super().__call__(config)
        # `labels` is a comma-separated string, so count its items rather
        # than its characters.
        config.model['num_labels'] = len(self.labels.split(','))
        if config.train.lr_scheduler.type == 'LinearLR':
            # Relies on the module-level `train_dataset` defined below; the
            # trainer calls this cfg_modify_fn only after the dataset is loaded.
            config.train.lr_scheduler['total_iters'] = \
                int(len(train_dataset)
                    / self.per_device_train_batch_size) * self.max_epochs
        return config
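
# Sketch of what the cfg_modify_fn above changes for, e.g., --labels '0,1'
# with a LinearLR scheduler (values illustrative):
#   model.num_labels               -> 2
#   preprocessor.label2id          -> {'0': 0, '1': 1}
#   train.lr_scheduler.total_iters -> steps_per_epoch * max_epochs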


args = TextClassificationArguments.from_cli(
    task='text-classification', eval_metrics='seq-cls-metric')
print(args)

dataset = MsDataset.load(args.dataset_name, subset_name=args.subset_name)
train_dataset = dataset['train']
validation_dataset = dataset['validation']

kwargs = dict(
    model=args.model,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    seed=args.seed,
    cfg_modify_fn=args)

# Propagate the local rank so the trainer sees it in distributed launches.
os.environ['LOCAL_RANK'] = str(args.local_rank)

trainer: EpochBasedTrainer = build_trainer(name='trainer', default_args=kwargs)
trainer.train()
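
# Optional follow-up (a sketch; EpochBasedTrainer.evaluate() runs the
# configured eval metrics on the eval_dataset passed above):
#   metrics = trainer.evaluate()
#   print(metrics)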