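"""Fine-tune a token-classification model (e.g. NER) with ModelScope.

Flow: parse training arguments from the CLI, load the train/validation
datasets, derive the label2id mapping from the data, and fine-tune with
EpochBasedTrainer.
"""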
from dataclasses import dataclass, field

from modelscope import (EpochBasedTrainer, MsDataset, TrainingArgs,
                        build_dataset_from_file)

@dataclass(init=False)
class TokenClassificationArguments(TrainingArgs):
    trainer: str = field(
        default=None, metadata={
            'help': 'The trainer used',
        })

    work_dir: str = field(
        default='./tmp',
        metadata={
            'help': 'The working path for saving checkpoints',
        })

    preprocessor: str = field(
        default=None,
        metadata={
            'help': 'The preprocessor type',
            'cfg_node': 'preprocessor.type'
        })

    preprocessor_padding: str = field(
        default=None,
        metadata={
            'help': 'The preprocessor padding strategy',
            'cfg_node': 'preprocessor.padding'
        })

    mode: str = field(
        default='inference',
        metadata={
            'help': 'The preprocessor mode',
            'cfg_node': 'preprocessor.mode'
        })

    first_sequence: str = field(
        default=None,
        metadata={
            'cfg_node': 'preprocessor.first_sequence',
            'help': 'The key of the first sequence in the dataset',
        })

    label: str = field(
        default=None,
        metadata={
            'cfg_node': 'preprocessor.label',
            'help': 'The key of the label column in the dataset',
        })

    sequence_length: int = field(
        default=128,
        metadata={
            'cfg_node': 'preprocessor.sequence_length',
            'help': 'The maximum sequence length',
        })
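
    # Fields referenced later (model, train_dataset_name, train_subset_name,
    # train_dataset_namespace, dataset_json_file, use_model_config) are not
    # defined here; they are presumably inherited from the TrainingArgs base
    # class.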


# Parse command-line arguments into the dataclass, then split them into a
# Config object and the plain argument namespace.
training_args = TokenClassificationArguments().parse_cli()
config, args = training_args.to_config()
print(args)


def get_label_list(labels):
    """Return the sorted list of unique labels across all samples."""
    unique_labels = set()
    for label in labels:
        unique_labels = unique_labels | set(label)
    label_list = list(unique_labels)
    label_list.sort()
    return label_list
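
# For example, for a column of per-token tag sequences:
#   get_label_list([['B-LOC', 'O'], ['B-PER', 'O']]) -> ['B-LOC', 'B-PER', 'O']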


def cfg_modify_fn(cfg):
    # Either merge the CLI overrides into the model's own configuration, or
    # use the CLI-built config directly.
    if args.use_model_config:
        cfg.merge_from_dict(config)
    else:
        cfg = config
    # Derive the label-to-id mapping from the union of the train and
    # validation label columns.
    labels = train_dataset[training_args.label] + validation_dataset[
        training_args.label]
    label_enumerate_values = get_label_list(labels)
    label2id = {label: idx for idx, label in enumerate(label_enumerate_values)}
    cfg.merge_from_dict({'preprocessor.label2id': label2id})
    cfg.merge_from_dict({'model.num_labels': len(label_enumerate_values)})
    cfg.merge_from_dict({'preprocessor.use_fast': True})
    cfg.merge_from_dict({
        'evaluation.metrics': {
            'type': 'token-cls-metric',
            'label2id': label2id,
        }
    })
    # A LinearLR scheduler needs the total number of optimizer steps up front.
    if cfg.train.lr_scheduler.type == 'LinearLR':
        cfg.train.lr_scheduler['total_iters'] = \
            int(len(train_dataset) / cfg.train.dataloader.batch_size_per_gpu) \
            * cfg.train.max_epochs
    return cfg
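
# Note: merge_from_dict accepts dot-separated keys ('preprocessor.label2id'),
# updating the corresponding nested config node; cfg_modify_fn itself is
# applied by the trainer when it builds its configuration, before training.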


# Load the train/validation splits from the ModelScope hub, or build both
# datasets from a local JSON description file.
if args.dataset_json_file is None:
    train_dataset = MsDataset.load(
        args.train_dataset_name,
        subset_name=args.train_subset_name,
        split='train',
        namespace=args.train_dataset_namespace)['train']
    validation_dataset = MsDataset.load(
        args.train_dataset_name,
        subset_name=args.train_subset_name,
        split='validation',
        namespace=args.train_dataset_namespace)['validation']
else:
    train_dataset, validation_dataset = build_dataset_from_file(
        args.dataset_json_file)
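
# MsDataset.load returns a split-keyed mapping here, hence the ['train'] and
# ['validation'] indexing even though a single split was requested.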

kwargs = dict(
    model=args.model,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    work_dir=args.work_dir,
    cfg_modify_fn=cfg_modify_fn)

trainer = EpochBasedTrainer(**kwargs)
trainer.train()
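
# Hypothetical invocation (the filename, model/dataset ids and column names
# are placeholders; flag names are assumed to mirror the dataclass fields
# above and those inherited from TrainingArgs):
#
#   python finetune_token_classification.py \
#       --model '<model-id>' \
#       --trainer nlp-base-trainer \
#       --train_dataset_name '<dataset-name>' \
#       --first_sequence tokens \
#       --label ner_tags \
#       --sequence_length 128 \
#       --work_dir ./tmp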