Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-24 12:09:22 +01:00)
[to #48855014] llama finetune + deepspeed
1. LLaMA base finetune: use the trainer to finetune LLaMA into an Alpaca-style model; the results have been verified and an example is provided.
2. DeepSpeed generality improvements: decouple the mpu; surface DeepSpeed training info through the ModelScope log (the log ModelScope currently prints is wrong); support configuring DeepSpeed from the ModelScope configuration.json; support initializing the DeepSpeed optimizer and lr_scheduler; fix the error raised when DeepSpeed and DDP are used together; fix the error raised when saving a checkpoint.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12651323
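The run script below passes --deepspeed 'default_offload_opt_param.json', but the contents of that file are not part of this diff. As a rough, assumption-labeled sketch only (not the actual file), a ZeRO stage-2 optimizer-offload config consistent with the batch and precision settings in finetune_llama.py could be generated like this; all keys used are standard DeepSpeed config options:

import json

# Hypothetical sketch: the real default_offload_opt_param.json may differ.
# Values mirror the example script (micro batch 4, accumulation 8, bf16).
deepspeed_config = {
    'train_micro_batch_size_per_gpu': 4,
    'gradient_accumulation_steps': 8,
    'bf16': {
        'enabled': True
    },
    'zero_optimization': {
        'stage': 2,
        'offload_optimizer': {
            'device': 'cpu',
            'pin_memory': True
        }
    }
}

with open('default_offload_opt_param.json', 'w') as f:
    json.dump(deepspeed_config, f, indent=2)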
examples/pytorch/llama/finetune_llama.py (new file, 263 lines)
@@ -0,0 +1,263 @@
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
# Copyright (c) Alibaba, Inc. and its affiliates.

import copy
import json
import logging
import os
from dataclasses import dataclass, field

import torch

from modelscope import TrainingArgs
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Trainers
from modelscope.models.nlp.llama import LlamaForTextGeneration, LlamaTokenizer
from modelscope.msdatasets.dataset_cls.custom_datasets.torch_custom_dataset import \
    TorchCustomDataset
from modelscope.trainers import build_trainer

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = '[PAD]'
DEFAULT_EOS_TOKEN = '</s>'
DEFAULT_BOS_TOKEN = '<s>'
DEFAULT_UNK_TOKEN = '<unk>'
PROMPT_DICT = {
    'prompt_input':
    ('Below is an instruction that describes a task, paired with an input that provides further context. '
     'Write a response that appropriately completes the request.\n\n'
     '### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:'),
    'prompt_no_input':
    ('Below is an instruction that describes a task. '
     'Write a response that appropriately completes the request.\n\n'
     '### Instruction:\n{instruction}\n\n### Response:'),
}


@dataclass(init=False)
class TextGenerationArguments(TrainingArgs):
    src_txt: str = field(
        default=None,
        metadata={
            'help': 'The source text key of the preprocessor',
            'cfg_node': 'preprocessor.src_txt'
        })

    deepspeed: str = field(
        default=None,
        metadata={
            'help': 'The location of the DeepSpeed json config file.',
        })

    work_dir: str = field(
        default=None, metadata={
            'help': 'The location of the work dir.',
        })


def _tokenize_fn(strings, tokenizer):
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors='pt',
            padding='longest',
            max_length=tokenizer.model_max_length,
            truncation=True,
        ) for text in strings
    ]
    input_ids = labels = [
        tokenized.input_ids[0] for tokenized in tokenized_list
    ]
    # Lengths are counted over non-pad tokens only.
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(sources, targets, tokenizer):
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [
        _tokenize_fn(strings, tokenizer) for strings in (examples, sources)
    ]
    input_ids = examples_tokenized['input_ids']
    labels = copy.deepcopy(input_ids)
    # Mask out the prompt part of each label so the loss is computed on the
    # response tokens only.
    for label, source_len in zip(labels, sources_tokenized['input_ids_lens']):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
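
# For example, with a tokenized prompt of length 3 and a full sequence
# [p1, p2, p3, r1, r2, eos], the labels become
# [-100, -100, -100, r1, r2, eos]; positions equal to IGNORE_INDEX are
# skipped by the cross-entropy loss, so only the response is supervised.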


def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer,
                                         model):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size
    not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        # Initialize the embeddings of the new tokens with the mean of the
        # pre-existing embeddings rather than leaving them random.
        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


class SupervisedDataset(TorchCustomDataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, data_path: str, tokenizer):
        logging.warning('Loading data...')
        with open(data_path, 'r') as f:
            list_data_dict = json.load(f)

        logging.warning('Formatting inputs...')
        prompt_input, prompt_no_input = PROMPT_DICT[
            'prompt_input'], PROMPT_DICT['prompt_no_input']
        # Use the with-input template when the example carries an `input`
        # field, otherwise the instruction-only template.
        sources = [
            prompt_input.format_map(example) if example.get('input', '') != ''
            else prompt_no_input.format_map(example)
            for example in list_data_dict
        ]
        targets = [
            f"{example['output']}{tokenizer.eos_token}"
            for example in list_data_dict
        ]

        logging.warning('Tokenizing inputs... This may take some time...')
        data_dict = preprocess(sources, targets, tokenizer)

        self.input_ids = data_dict['input_ids']
        self.labels = data_dict['labels']

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i):
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: LlamaTokenizer

    def __call__(self, instances):
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ('input_ids', 'labels'))
        # Pad inputs with the pad token but labels with IGNORE_INDEX, so the
        # padded label positions stay out of the loss.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
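
# For a batch of two examples with lengths 5 and 8, both input_ids and
# labels are padded to shape (2, 8); attention_mask is False exactly at the
# inserted pad positions.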


config, args = TextGenerationArguments().parse_cli().to_config()

if __name__ == '__main__':

    def cfg_modify_fn(cfg):
        if args.use_model_config:
            cfg.merge_from_dict(config)
        else:
            cfg = config
        cfg.train.lr_scheduler = {
            'type': 'CosineAnnealingLR',
            'T_max': 1,
            'options': {
                'by_epoch': False
            }
        }
        cfg.train.optimizer = {
            'type': 'AdamW',
            'lr': 2e-5,
            'weight_decay': 0.0,
            'options': {
                'cumulative_iters': 8,
                'warmup': {
                    'type': 'LinearWarmup',
                    'warmup_ratio': 0.03
                }
            }
        }
        cfg.train.logging = {'interval': 8, 'by_epoch': False}
        cfg.train['bf16'] = True
        cfg.train.dataloader = {'batch_size_per_gpu': 4, 'workers_per_gpu': 1}
        if 'hooks' not in cfg.train:
            cfg.train['hooks'] = []
        # Register the DeepspeedHook with the json config passed on the CLI;
        # with_mpu is False because no model-parallel unit is used here.
        cfg.train.hooks.append({
            'type': 'DeepspeedHook',
            'config': args.deepspeed,
            'save_zero_checkpoint': True,
            'with_mpu': False,
        })

        cfg.preprocessor.sequence_length = 512
        return cfg
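
    # Note: with batch_size_per_gpu=4, cumulative_iters=8 and the 4
    # data-parallel workers launched by run_train_llama.sh, the effective
    # global batch size is 4 x 8 x 4 = 128.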

    model_path = args.model if os.path.exists(
        args.model) else snapshot_download(args.model)
    data_path = args.src_txt if args.src_txt else os.path.join(
        model_path, 'alpaca_data.json')
    model = LlamaForTextGeneration.from_pretrained(model_path)

    tokenizer = LlamaTokenizer.from_pretrained(
        model_path,
        model_max_length=512,
        padding_side='right',
    )

    special_tokens_dict = dict()
    special_tokens_dict['pad_token'] = DEFAULT_PAD_TOKEN
    special_tokens_dict['eos_token'] = DEFAULT_EOS_TOKEN
    special_tokens_dict['bos_token'] = DEFAULT_BOS_TOKEN
    special_tokens_dict['unk_token'] = DEFAULT_UNK_TOKEN

    smart_tokenizer_and_embedding_resize(
        special_tokens_dict=special_tokens_dict,
        tokenizer=tokenizer,
        model=model,
    )

    train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_path)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

    kwargs = dict(
        model=model,
        cfg_file=os.path.join(model_path, 'configuration.json'),
        train_dataset=train_dataset,
        data_collator=data_collator,
        max_epochs=3,
        work_dir=args.work_dir,
        cfg_modify_fn=cfg_modify_fn)

    # Construct trainer and train
    trainer = build_trainer(
        name=Trainers.text_generation_trainer, default_args=kwargs)
    trainer.train()
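For reference, SupervisedDataset expects data_path to point at a JSON list of Alpaca-style records with instruction, input and output fields (alpaca_data.json by default). A minimal sketch of a compatible file, using made-up records and a hypothetical file name:

import json

# Two made-up records in the Alpaca format read by SupervisedDataset; when
# `input` is empty the instruction-only prompt template is applied.
records = [
    {
        'instruction': 'Give three tips for staying healthy.',
        'input': '',
        'output': '1. Eat a balanced diet. 2. Exercise regularly. 3. Sleep well.'
    },
    {
        'instruction': 'Translate the sentence to French.',
        'input': 'Hello, world.',
        'output': 'Bonjour, le monde.'
    },
]

with open('my_alpaca_subset.json', 'w') as f:
    json.dump(records, f, indent=2)

# Point the training script at it with: --src_txt 'my_alpaca_subset.json'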
examples/pytorch/llama/run_train_llama.sh (new file, 9 lines)
@@ -0,0 +1,9 @@
DATA_PARALLEL_SIZE=4

export PYTHONPATH=$PYTHONPATH:./
torchrun --nproc_per_node $DATA_PARALLEL_SIZE examples/pytorch/llama/finetune_llama.py \
    --work_dir './tmp' \
    --model 'skyline2006/llama-7b' \
    --deepspeed 'default_offload_opt_param.json' \
    --eval_interval 100