mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 05:05:00 +02:00
add model revision in training_args and modify dataset loading in finetune text classification
1. Add parameter model_revision in training_args.py. 2. Add parameter model_revision to kwargs in finetune_text_classification.py and finetune_text_generation.py. 3. Modify dataset loading in finetune_text_classification.py for flex training. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12869552 * add model revision in training_args and modify dataset loading in finetune text classification
This commit is contained in:
@@ -70,16 +70,23 @@ def cfg_modify_fn(cfg):
|
||||
|
||||
|
||||
if args.dataset_json_file is None:
|
||||
dataset = MsDataset.load(
|
||||
args.train_dataset_name, subset_name=args.train_subset_name)
|
||||
train_dataset = dataset['train']
|
||||
validation_dataset = dataset['validation']
|
||||
train_dataset = MsDataset.load(
|
||||
args.train_dataset_name,
|
||||
subset_name=args.train_subset_name,
|
||||
split=args.train_split,
|
||||
namespace=args.train_dataset_namespace)
|
||||
validation_dataset = MsDataset.load(
|
||||
args.val_dataset_name,
|
||||
subset_name=args.val_subset_name,
|
||||
split=args.val_split,
|
||||
namespace=args.val_dataset_namespace)
|
||||
else:
|
||||
train_dataset, validation_dataset = build_dataset_from_file(
|
||||
args.dataset_json_file)
|
||||
|
||||
kwargs = dict(
|
||||
model=args.model,
|
||||
model_revision=args.model_revision,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=validation_dataset,
|
||||
seed=args.seed,
|
||||
|
||||
@@ -2,15 +2,24 @@ PYTHONPATH=. python examples/pytorch/text_classification/finetune_text_classific
|
||||
--task 'text-classification' \
|
||||
--model 'damo/nlp_structbert_backbone_base_std' \
|
||||
--train_dataset_name 'clue' \
|
||||
--val_dataset_name 'clue' \
|
||||
--train_subset_name 'tnews' \
|
||||
--val_subset_name 'tnews' \
|
||||
--train_split 'train' \
|
||||
--val_split 'validation' \
|
||||
--first_sequence 'sentence' \
|
||||
--preprocessor.label label \
|
||||
--model.num_labels 15 \
|
||||
--label label \
|
||||
--num_labels 15 \
|
||||
--labels '0,1,2,3,4,5,6,7,8,9,10,11,12,13,14' \
|
||||
--preprocessor 'sen-cls-tokenizer' \
|
||||
--use_model_config True \
|
||||
--max_epochs 1 \
|
||||
--train.dataloader.workers_per_gpu 0 \
|
||||
--evaluation.dataloader.workers_per_gpu 0 \
|
||||
--train.optimizer.lr 1e-5 \
|
||||
--per_device_train_batch_size 16 \
|
||||
--per_device_eval_batch_size 16 \
|
||||
--eval_interval 100 \
|
||||
--eval_strategy by_step \
|
||||
--work_dir './tmp' \
|
||||
--train_data_worker 0 \
|
||||
--eval_data_worker 0 \
|
||||
--lr 1e-5 \
|
||||
--eval_metrics 'seq-cls-metric' \
|
||||
|
||||
@@ -112,6 +112,7 @@ else:
|
||||
|
||||
kwargs = dict(
|
||||
model=args.model,
|
||||
model_revision=args.model_revision,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=validation_dataset,
|
||||
seed=args.seed,
|
||||
|
||||
@@ -117,6 +117,11 @@ class ModelArgs:
|
||||
'help': 'A model id or model dir',
|
||||
})
|
||||
|
||||
model_revision: str = field(
|
||||
default=None, metadata={
|
||||
'help': 'the revision of model',
|
||||
})
|
||||
|
||||
model_type: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
|
||||
Reference in New Issue
Block a user