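"""Fine-tune a text classification model with the ModelScope trainer.

The script parses training options from the command line via ``TrainingArgs``,
loads a dataset from the ModelScope hub with ``MsDataset``, and launches an
epoch-based trainer on it.
"""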
import os
from dataclasses import dataclass, field

from modelscope.msdatasets import MsDataset
from modelscope.trainers import EpochBasedTrainer, build_trainer
from modelscope.trainers.training_args import TrainingArgs

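# The ``cfg_getter``/``cfg_setter`` hooks below translate between the CLI form
# of the ``labels`` argument (a comma-separated string) and the
# ``preprocessor.label2id`` mapping stored in the model configuration.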
def get_labels(cfg, metadata):
    label2id = cfg.safe_get(metadata['cfg_node'])
    if label2id is not None:
        return ','.join(label2id.keys())


def set_labels(cfg, labels, metadata):
    if isinstance(labels, str):
        labels = labels.split(',')
    cfg.merge_from_dict(
        {metadata['cfg_node']: {label: idx
                                for idx, label in enumerate(labels)}})
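
# Each field maps a command-line option onto a node of the model configuration
# through its ``cfg_node`` metadata; ``TrainingArgs`` takes care of parsing the
# option and writing the value into the configuration.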
@dataclass
class TextClassificationArguments(TrainingArgs):

    first_sequence: str = field(
        default=None,
        metadata={
            'help': 'The first sequence key of preprocessor',
            'cfg_node': 'preprocessor.first_sequence'
        })

    second_sequence: str = field(
        default=None,
        metadata={
            'help': 'The second sequence key of preprocessor',
            'cfg_node': 'preprocessor.second_sequence'
        })

    label: str = field(
        default=None,
        metadata={
            'help': 'The label key of preprocessor',
            'cfg_node': 'preprocessor.label'
        })

    labels: str = field(
        default=None,
        metadata={
            'help': 'The labels of the dataset',
            'cfg_node': 'preprocessor.label2id',
            'cfg_getter': get_labels,
            'cfg_setter': set_labels,
        })

    preprocessor: str = field(
        default=None,
        metadata={
            'help': 'The preprocessor type',
            'cfg_node': 'preprocessor.type'
        })
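
    # The argument instance doubles as the trainer's ``cfg_modify_fn``: it is
    # called on the loaded configuration (after the datasets below have been
    # created) to apply the remaining adjustments before training starts.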
    def __call__(self, config):
        config = super().__call__(config)
        # ``labels`` may still be the comma-separated string taken from the
        # command line, so count the entries rather than the characters.
        labels = self.labels.split(',') if isinstance(self.labels, str) else self.labels
        config.model['num_labels'] = len(labels)
        if config.train.lr_scheduler.type == 'LinearLR':
            config.train.lr_scheduler['total_iters'] = \
                int(len(train_dataset) / self.per_device_train_batch_size) * self.max_epochs
        return config
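
# Parse the command line; the task and evaluation metric are fixed for
# text classification.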
args = TextClassificationArguments.from_cli(
    task='text-classification', eval_metrics='seq-cls-metric')

print(args)

dataset = MsDataset.load(args.dataset_name, subset_name=args.subset_name)
train_dataset = dataset['train']
validation_dataset = dataset['validation']

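# Assemble the trainer arguments; ``args`` itself is passed as ``cfg_modify_fn``.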
kwargs = dict(
    model=args.model,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    seed=args.seed,
    cfg_modify_fn=args)

os.environ['LOCAL_RANK'] = str(args.local_rank)
trainer: EpochBasedTrainer = build_trainer(name='trainer', default_args=kwargs)
trainer.train()
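
# Example invocation (illustrative only: the option names follow the dataclass
# fields above, and the model/dataset identifiers are placeholders):
#
#   python this_script.py \
#       --model <model_id> \
#       --dataset_name <dataset_id> --subset_name <subset> \
#       --first_sequence sentence1 --second_sequence sentence2 \
#       --label label --labels '0,1' \
#       --max_epochs 3 --per_device_train_batch_size 16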