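# Fine-tune ChatGLM-6B / ChatGLM2-6B with ModelScope. Supports full
# fine-tuning, P-tuning v2, LoRA, weight quantization and AMP training.
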
import os
from dataclasses import dataclass, field

import numpy as np
import torch
from chatglm_trainer import Seq2SeqTrainer
from text_generation_metric import TextGenerationMetric
from transformers import DataCollatorForSeq2Seq

from modelscope import snapshot_download
from modelscope.models import Model
from modelscope.msdatasets import MsDataset
from modelscope.swift import Swift
from modelscope.swift.lora import LoRAConfig
from modelscope.trainers.training_args import TrainingArgs
from modelscope.utils.config import ConfigDict
from modelscope.utils.hub import read_config
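

# Command-line arguments: TrainingArgs supplies the common ModelScope trainer
# flags; the fields below add the ChatGLM-specific options.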
@dataclass(init=False)
class Chatglm6bArguments(TrainingArgs):

    ptuning_checkpoint: str = field(
        default=None,
        metadata={
            'help': 'The p-tuning checkpoint previously trained.',
        })

    pre_seq_len: int = field(
        default=None, metadata={
            'help': 'The prefix sequence length for P-tuning v2.',
        })

    prefix_projection: bool = field(
        default=False, metadata={
            'help':
            'Whether to run the prefix embeddings through a projection MLP (P-tuning v2).',
        })

    quantization_bit: int = field(
        default=None, metadata={
            'help': 'The number of bits to quantize the model weights to.',
        })

    prompt_column: str = field(
        default=None,
        metadata={
            'help': 'The name of the dataset column containing the prompts.'
        },
    )

    response_column: str = field(
        default=None,
        metadata={
            'help': 'The name of the dataset column containing the responses.'
        },
    )

    history_column: str = field(
        default=None,
        metadata={
            'help':
            'The name of the dataset column containing the chat history.'
        },
    )

    source_prefix: str = field(
        default='',
        metadata={
            'help':
            'A prefix to add before every source text (useful for T5-style models).'
        })

    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            'help':
            'Whether to ignore the tokens corresponding to padded labels in the loss computation.'
        },
    )

    max_source_length: int = field(
        default=1024,
        metadata={
            'help':
            ('The maximum total input sequence length after tokenization. Sequences longer '
             'than this will be truncated, sequences shorter will be padded.')
        },
    )

    max_target_length: int = field(
        default=128,
        metadata={
            'help':
            ('The maximum total sequence length for target text after tokenization. Sequences longer '
             'than this will be truncated, sequences shorter will be padded.')
        },
    )

    max_train_samples: int = field(
        default=None,
        metadata={
            'help':
            ('For debugging purposes or quicker training, truncate the number of training examples to this '
             'value if set.')
        },
    )

    max_eval_samples: int = field(
        default=None,
        metadata={
            'help':
            ('For debugging purposes or quicker training, truncate the number of evaluation examples to this '
             'value if set.')
        },
    )

    preprocessing_num_workers: int = field(
        default=None,
        metadata={
            'help': 'The number of processes to use for the preprocessing.'
        },
    )

    use_lora: int = field(
        default=0,
        metadata={'help': 'Whether to use LoRA to train the model.'},
    )

    lora_rank: int = field(
        default=32,
        metadata={'help': 'The LoRA rank.'},
    )

    lora_alpha: int = field(
        default=32,
        metadata={'help': 'The LoRA alpha.'},
    )

    lora_dropout: float = field(
        default=0.05,
        metadata={'help': 'The LoRA dropout probability.'},
    )

    use_amp: int = field(
        default=0,
        metadata={
            'help':
            'Whether to use AMP (automatic mixed precision) to train the model.'
        },
    )

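
# Parse the CLI flags and materialize the training configuration. The merged
# config is dumped to ./configuration.json, which the trainer reads back.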
args = Chatglm6bArguments(eval_metrics='chatglm').parse_cli()
print(args)
config, _ = args.to_config(ignore_default_config=args.use_model_config)
config.dump('./configuration.json')
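
# ChatGLM-6B and ChatGLM2-6B ship different tokenizer implementations.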
if config['model']['type'] == 'chatglm6b':
    from modelscope.models.nlp import ChatGLMTokenizer
else:
    from modelscope.models.nlp import ChatGLM2Tokenizer as ChatGLMTokenizer

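
# Patch the training config at runtime: register the AMP optimizer hook when
# requested, size the LinearLR schedule to the actual number of training
# steps, and set the generation kwargs used during evaluation.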
def cfg_modify_fn(cfg):
    if args.use_model_config:
        cfg.merge_from_dict(config)
    else:
        cfg = config
    if args.use_amp:
        if not getattr(cfg.train, 'hooks', None):
            cfg.train.hooks = []
        cfg.train.hooks.append({
            'type': 'TorchAMPOptimizerHook',
            # Optional loss_scale parameter here.
        })
    if cfg.train.lr_scheduler.type == 'LinearLR':
        cfg.train.lr_scheduler['total_iters'] = \
            int(len(train_dataset) / cfg.train.dataloader.batch_size_per_gpu) * cfg.train.max_epochs
    cfg['gen_kwargs'] = {
        'do_sample': True,
        'top_p': 0.7,
        'max_length': 512,
        'temperature': 0.95
    }
    return cfg


train_dataset = MsDataset.load(
    args.train_dataset_name,
    subset_name=args.train_subset_name,
    split=args.train_split)
validation_dataset = MsDataset.load(
    args.val_dataset_name,
    subset_name=args.val_subset_name,
    split=args.val_split)

model_dir = snapshot_download(args.model)
model_config = read_config(model_dir)
model_config['model'] = ConfigDict({
    'type': config['model']['type'],
})

model_config['model']['pre_seq_len'] = args.pre_seq_len
model_config['model']['prefix_projection'] = args.prefix_projection
tokenizer = ChatGLMTokenizer.from_pretrained(model_dir, trust_remote_code=True)

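# With LoRA on more than one GPU, shard the model across devices via
# device_map='auto' and tell the trainer not to move it (device='cpu'), so
# the two placement mechanisms do not fight each other.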
device_map_kwargs = {}
device_kwargs = {}
if args.use_lora != 0 and torch.cuda.device_count() > 1:
    device_map_kwargs['device_map'] = 'auto'
    # No placement for model, leave the model to `device_map`
    device_kwargs['device'] = 'cpu'
model = Model.from_pretrained(
    model_dir, cfg_dict=model_config, **device_map_kwargs)

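# Optionally restore a previously trained prefix encoder (P-tuning v2
# checkpoint) on top of the freshly loaded base weights.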
if args.ptuning_checkpoint is not None:
    # Evaluation
    # Loading extra state dict of prefix encoder
    prefix_state_dict = torch.load(
        os.path.join(args.ptuning_checkpoint, 'pytorch_model.bin'))
    new_prefix_state_dict = {}
    for k, v in prefix_state_dict.items():
        if k.startswith('transformer.prefix_encoder.'):
            new_prefix_state_dict[k[len('transformer.prefix_encoder.'):]] = v
    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

if args.quantization_bit is not None:
    print(f'Quantized to {args.quantization_bit} bit')
    model = model.quantize(args.quantization_bit)
if args.pre_seq_len is not None:
    # P-tuning v2: keep the frozen transformer in fp16 and train only the
    # prefix encoder in fp32.
    model = model.half()
    model.transformer.prefix_encoder.float()
elif not args.use_lora:
    # Full finetune: keep everything in fp32.
    model = model.float()

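# LoRA wraps the attention query/key/value projections. With AMP the base
# weights stay in fp32 (autocast handles the mixed precision); otherwise the
# model is trained in bfloat16.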
if args.use_lora != 0:
    lora_config = LoRAConfig(
        replace_modules=['attention.query_key_value'],
        rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout)
    if args.use_amp:
        model = model.float()
    else:
        model = model.bfloat16()
    Swift.prepare_model(model, lora_config)

prefix = args.source_prefix if args.source_prefix is not None else ''

# Get the column names for input/target.
prompt_column = args.prompt_column
response_column = args.response_column
history_column = args.history_column

# Temporarily set max_target_length for training.
max_target_length = args.max_target_length

# Report how many parameters training will actually update.
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
trainable_params = sum([np.prod(p.size()) for p in model_parameters])

model_parameters = filter(lambda p: not p.requires_grad, model.parameters())
non_trainable_params = sum([np.prod(p.size()) for p in model_parameters])

print('trainable_params: {} ({:.2f}%), non_trainable_params: {}'.format(
    trainable_params,
    trainable_params / (trainable_params + non_trainable_params) * 100,
    non_trainable_params))

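
# Evaluation preprocessing: rebuild the multi-turn chat prompt
# ('[Round i]\n问:...\n答:...'), tokenize prompts and targets separately, and
# mask label padding with -100 so it is ignored by the loss.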
def preprocess_function_eval(examples):
    inputs, targets = [], []
    for i in range(len(examples[prompt_column])):
        if examples[prompt_column][i] and examples[response_column][i]:
            query = examples[prompt_column][i]
            if history_column is None or len(examples[history_column][i]) == 0:
                prompt = query
            else:
                prompt = ''
                history = examples[history_column][i]
                for turn_idx, (old_query, response) in enumerate(history):
                    prompt += '[Round {}]\n问:{}\n答:{}\n'.format(
                        turn_idx, old_query, response)
                prompt += '[Round {}]\n问:{}\n答:'.format(len(history), query)
            inputs.append(prompt)
            targets.append(examples[response_column][i])

    inputs = [prefix + inp for inp in inputs]
    model_inputs = tokenizer(
        inputs,
        max_length=args.max_source_length,
        truncation=True,
        padding=True)
    labels = tokenizer(
        text_target=targets, max_length=max_target_length, truncation=True)

    if args.ignore_pad_token_for_loss:
        labels['input_ids'] = [[(lb if lb != tokenizer.pad_token_id else -100)
                                for lb in label]
                               for label in labels['input_ids']]
    model_inputs['labels'] = labels['input_ids']

    return model_inputs

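
# Training preprocessing: concatenate prompt and answer ids with the model's
# special tokens, mask the prompt span of the labels with -100, and pad
# everything to max_source_length + max_target_length.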
def preprocess_function_train(examples):
    max_seq_length = args.max_source_length + args.max_target_length

    model_inputs = {
        'input_ids': [],
        'labels': [],
    }
    for i in range(len(examples[prompt_column])):
        if examples[prompt_column][i] and examples[response_column][i]:
            query, answer = examples[prompt_column][i], examples[
                response_column][i]

            if history_column is None:
                prompt = query
            else:
                prompt = ''
                history = examples[history_column][i]
                for turn_idx, (old_query, response) in enumerate(history):
                    prompt += '[Round {}]\n问:{}\n答:{}\n'.format(
                        turn_idx, old_query, response)
                prompt += '[Round {}]\n问:{}\n答:'.format(len(history), query)

            prompt = prefix + prompt
            a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
            b_ids = tokenizer.encode(text=answer, add_special_tokens=False)

            # Reserve room for the special tokens added below.
            if len(a_ids) > args.max_source_length - 1:
                a_ids = a_ids[:args.max_source_length - 1]

            if len(b_ids) > args.max_target_length - 2:
                b_ids = b_ids[:args.max_target_length - 2]

            input_ids = tokenizer.build_inputs_with_special_tokens(
                a_ids, b_ids)

            # ChatGLM marks the start of generation with a BOS token;
            # ChatGLM2 prepends two special prefix tokens to the source ids,
            # so the context spans len(a_ids) + 2 positions.
            if config['model']['type'] == 'chatglm6b':
                context_length = input_ids.index(tokenizer.bos_token_id)
            else:
                context_length = len(a_ids) + 2
            mask_position = context_length - 1
            labels = [-100] * context_length + input_ids[mask_position + 1:]

            pad_len = max_seq_length - len(input_ids)
            input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
            labels = labels + [tokenizer.pad_token_id] * pad_len
            if args.ignore_pad_token_for_loss:
                labels = [(lb if lb != tokenizer.pad_token_id else -100)
                          for lb in labels]

            model_inputs['input_ids'].append(input_ids)
            model_inputs['labels'].append(labels)

    return model_inputs

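
# Tokenize both datasets up front; num_proc parallelizes the preprocessing.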
train_dataset = train_dataset.to_hf_dataset().map(
    preprocess_function_train,
    batched=True,
    num_proc=args.preprocessing_num_workers,
    desc='Running tokenizer on train dataset',
)

validation_dataset = validation_dataset.to_hf_dataset().map(
    preprocess_function_eval,
    batched=True,
    num_proc=args.preprocessing_num_workers,
    desc='Running tokenizer on eval dataset',
)

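# The collator pads labels with -100 (ignored by the loss) when
# ignore_pad_token_for_loss is set; padding=False because the training
# examples are already padded to max_seq_length in the preprocessor.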
# Data collator
label_pad_token_id = -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=None,
    padding=False)

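# Gradient checkpointing trades compute for memory; enable_input_require_grads
# keeps checkpointing usable when the embeddings are frozen (LoRA/P-tuning).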
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

trainer = Seq2SeqTrainer(
    model=model,
    cfg_file='./configuration.json',
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    seed=args.seed,
    data_collator=data_collator,
    remove_unused_data=True,
    cfg_modify_fn=cfg_modify_fn,
    **device_kwargs)
trainer.tokenizer = tokenizer
trainer.train()
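
# Example invocation (illustrative only; the model and dataset ids below are
# placeholders, substitute the ones you actually use):
#
#   python finetune.py \
#       --model ZhipuAI/chatglm2-6b \
#       --train_dataset_name AdvertiseGen --train_split train \
#       --val_dataset_name AdvertiseGen --val_split validation \
#       --prompt_column content --response_column summary \
#       --use_lora 1 --use_amp 1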