mirror of https://github.com/modelscope/modelscope.git
synced 2025-12-21 18:49:23 +01:00
89 lines · 2.7 KiB · Python
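# Fine-tune a token-classification (NER) model with the ModelScope trainer.
# All hyperparameters come from the command line via a TrainingArgs dataclass.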
from dataclasses import dataclass, field

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.trainers.training_args import (TrainingArgs, get_flatten_value,
                                               set_flatten_value)


@dataclass
class TokenClassificationArguments(TrainingArgs):

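    # Each field below becomes a CLI flag via TrainingArgs.from_cli; a
    # 'cfg_node' entry in the metadata binds that flag to a node in the
    # trainer's configuration file.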
    trainer: str = field(
        default=Trainers.default, metadata={
            'help': 'The trainer used',
        })

    preprocessor: str = field(
        default=None,
        metadata={
            'help': 'The preprocessor type',
            'cfg_node': 'preprocessor.type'
        })

    preprocessor_padding: str = field(
        default=None,
        metadata={
            'help': 'The preprocessor padding',
            'cfg_node': 'preprocessor.padding'
        })

    train_dataset_params: str = field(
        default=None,
        metadata={
            'cfg_node': 'dataset.train',
            'cfg_getter': get_flatten_value,
            'cfg_setter': set_flatten_value,
            'help': 'The parameters for train dataset',
        })

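    # Invoked by the trainer as `cfg_modify_fn` to patch the loaded
    # configuration before training starts.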
    def __call__(self, config):
        config = super().__call__(config)
        # NOTE: train_dataset and eval_dataset are module-level globals
        # defined below; the trainer only invokes this hook after they exist.
        if config.safe_get('dataset.train.label') == 'ner_tags':
            ner_tags_labels = train_dataset['ner_tags'] + eval_dataset[
                'ner_tags']
            label_enumerate_values = self._get_label_list(ner_tags_labels)
            config.merge_from_dict(
                {'dataset.train.labels': label_enumerate_values})
        if config.train.lr_scheduler.type == 'LinearLR':
            config.train.lr_scheduler['total_iters'] = \
                int(len(train_dataset) / self.per_device_train_batch_size) * self.max_epochs
        return config

    # TODO: Future performance optimization in MsDataset
    @staticmethod
    def _get_label_list(labels):
        # Union of all tag values across sequences, sorted for a
        # deterministic label order.
        unique_labels = set()
        for label in labels:
            unique_labels = unique_labels | set(label)
        label_list = list(unique_labels)
        label_list.sort()
        return label_list


args = TokenClassificationArguments.from_cli(task='token-classification')
print(args)

# load dataset
train_dataset = MsDataset.load(
    args.dataset_name,
    subset_name=args.subset_name,
    split='train',
    namespace='damo')['train']
eval_dataset = MsDataset.load(
    args.dataset_name,
    subset_name=args.subset_name,
    split='validation',
    namespace='damo')['validation']

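# `cfg_modify_fn=args` hands the parsed arguments to the trainer, which calls
# them (TokenClassificationArguments.__call__) to rewrite the config.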
kwargs = dict(
    model=args.model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    work_dir=args.work_dir,
    cfg_modify_fn=args)

trainer = build_trainer(name=args.trainer, default_args=kwargs)
trainer.train()
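
# Example invocation (hypothetical; assumes TrainingArgs exposes each field as
# a same-named CLI flag, and every value below is a placeholder):
#   python finetune_token_classification.py \
#       --model <model_id> --dataset_name <dataset> --subset_name <subset> \
#       --max_epochs 5 --work_dir ./work_dir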