mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 03:59:23 +01:00
support llama2 (#393)
* Unify sft and infer code into a single file * update llama2 sft infer
This commit is contained in:
@@ -5,6 +5,7 @@ import os
|
|||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -15,7 +16,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from datasets import Dataset as HFDataset
|
from datasets import Dataset as HfDataset
|
||||||
from datasets import concatenate_datasets
|
from datasets import concatenate_datasets
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
@@ -36,6 +37,8 @@ from torch.utils.data import Dataset
|
|||||||
from torchmetrics import Accuracy, MeanMetric
|
from torchmetrics import Accuracy, MeanMetric
|
||||||
#
|
#
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
||||||
|
GenerationConfig, HfArgumentParser, TextStreamer)
|
||||||
|
|
||||||
#
|
#
|
||||||
from modelscope import (Model, MsDataset, get_logger, read_config,
|
from modelscope import (Model, MsDataset, get_logger, read_config,
|
||||||
@@ -57,6 +60,7 @@ PROMPT = """Human: {instruction}
|
|||||||
AI: """
|
AI: """
|
||||||
|
|
||||||
logger = get_logger()
|
logger = get_logger()
|
||||||
|
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
@@ -150,12 +154,12 @@ def get_T_max(dataset_len: int, batch_size: int, max_epochs: int,
|
|||||||
return T_max
|
return T_max
|
||||||
|
|
||||||
|
|
||||||
def tokenize_function(example: Dict[str, str],
|
def tokenize_function(example: Dict[str, Optional[str]],
|
||||||
tokenizer,
|
tokenizer,
|
||||||
max_length: Optional[int] = 2048) -> Dict[str, Any]:
|
max_length: Optional[int] = 2048) -> Dict[str, Any]:
|
||||||
"""Only applicable to baichuan and chatglm2. Other models need to be tested"""
|
"""Only applicable to baichuan and chatglm2. Other models need to be tested"""
|
||||||
instruction = example['instruction']
|
instruction: str = example['instruction']
|
||||||
input_: str = example['input']
|
input_ = example['input']
|
||||||
if input_ is not None and input_ != '':
|
if input_ is not None and input_ != '':
|
||||||
# instruction = instruction + '\n'
|
# instruction = instruction + '\n'
|
||||||
if input_.startswith('输入:'):
|
if input_.startswith('输入:'):
|
||||||
@@ -187,7 +191,7 @@ def tokenize_function(example: Dict[str, str],
|
|||||||
return {'input_ids': input_ids, 'labels': labels}
|
return {'input_ids': input_ids, 'labels': labels}
|
||||||
|
|
||||||
|
|
||||||
def stat_dataset(dataset: HFDataset) -> None:
|
def stat_dataset(dataset: HfDataset) -> None:
|
||||||
"""Statistical analysis was performed on the data set"""
|
"""Statistical analysis was performed on the data set"""
|
||||||
_token_len = []
|
_token_len = []
|
||||||
for d in dataset:
|
for d in dataset:
|
||||||
@@ -202,8 +206,8 @@ def stat_dataset(dataset: HFDataset) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def print_examples(examples: Dict[str, Any], tokenizer) -> None:
|
def print_example(example: Dict[str, Any], tokenizer) -> None:
|
||||||
input_ids, labels = examples['input_ids'], examples['labels']
|
input_ids, labels = example['input_ids'], example['labels']
|
||||||
print(f'[INPUT_IDS] {input_ids}')
|
print(f'[INPUT_IDS] {input_ids}')
|
||||||
print(f'[INPUT] {tokenizer.decode(input_ids)}')
|
print(f'[INPUT] {tokenizer.decode(input_ids)}')
|
||||||
print()
|
print()
|
||||||
@@ -305,48 +309,24 @@ def _add_special_token(tokenizer):
|
|||||||
f'pad_token_id: {tokenizer.pad_token_id}')
|
f'pad_token_id: {tokenizer.pad_token_id}')
|
||||||
|
|
||||||
|
|
||||||
def get_baichuan7B_model_tokenizer(model_dir: str,
|
def get_baichuan_model_tokenizer(model_dir: str,
|
||||||
load_model: bool = True,
|
load_model: bool = True,
|
||||||
add_special_token: bool = True):
|
add_special_token: bool = True):
|
||||||
sys.path.insert(0, model_dir)
|
sys.path.insert(0, model_dir)
|
||||||
from configuration_baichuan import BaiChuanConfig
|
model_config = AutoConfig.from_pretrained(
|
||||||
from tokenization_baichuan import BaiChuanTokenizer
|
model_dir, trust_remote_code=True)
|
||||||
from modeling_baichuan import BaiChuanForCausalLM
|
|
||||||
model_config = BaiChuanConfig.from_pretrained(model_dir)
|
|
||||||
model_config.torch_dtype = torch.float16
|
model_config.torch_dtype = torch.float16
|
||||||
logger.info(f'model_config: {model_config}')
|
logger.info(f'model_config: {model_config}')
|
||||||
tokenizer = BaiChuanTokenizer.from_pretrained(model_dir)
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_dir, trust_remote_code=True)
|
||||||
model = None
|
model = None
|
||||||
if load_model:
|
if load_model:
|
||||||
model = BaiChuanForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model_dir,
|
model_dir,
|
||||||
config=model_config,
|
config=model_config,
|
||||||
device_map='auto',
|
device_map='auto',
|
||||||
torch_dtype=torch.float16)
|
torch_dtype=torch.float16,
|
||||||
#
|
trust_remote_code=True)
|
||||||
if add_special_token:
|
|
||||||
_add_special_token(tokenizer)
|
|
||||||
return model, tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def get_baichuan13B_model_tokenizer(model_dir: str,
|
|
||||||
load_model: bool = True,
|
|
||||||
add_special_token: bool = True):
|
|
||||||
sys.path.insert(0, model_dir)
|
|
||||||
from configuration_baichuan import BaichuanConfig
|
|
||||||
from tokenization_baichuan import BaichuanTokenizer
|
|
||||||
from modeling_baichuan import BaichuanForCausalLM
|
|
||||||
model_config = BaichuanConfig.from_pretrained(model_dir)
|
|
||||||
model_config.torch_dtype = torch.float16
|
|
||||||
logger.info(f'model_config: {model_config}')
|
|
||||||
tokenizer = BaichuanTokenizer.from_pretrained(model_dir)
|
|
||||||
model = None
|
|
||||||
if load_model:
|
|
||||||
model = BaichuanForCausalLM.from_pretrained(
|
|
||||||
model_dir,
|
|
||||||
config=model_config,
|
|
||||||
device_map='auto',
|
|
||||||
torch_dtype=torch.float16)
|
|
||||||
#
|
#
|
||||||
if add_special_token:
|
if add_special_token:
|
||||||
_add_special_token(tokenizer)
|
_add_special_token(tokenizer)
|
||||||
@@ -371,23 +351,43 @@ def get_chatglm2_model_tokenizer(model_dir: str,
|
|||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def get_llama2_model_tokenizer(model_dir: str,
|
||||||
|
load_model: bool = True,
|
||||||
|
add_special_token: bool = True):
|
||||||
|
config = AutoConfig.from_pretrained(model_dir)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||||
|
model = None
|
||||||
|
if load_model:
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_dir,
|
||||||
|
config=config,
|
||||||
|
device_map='auto',
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
if add_special_token:
|
||||||
|
_add_special_token(tokenizer)
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
def get_alpaca_en_zh_dataset(
|
def get_alpaca_en_zh_dataset(
|
||||||
tokenize_function,
|
tokenize_function,
|
||||||
only_val: bool = False,
|
only_val: bool = False,
|
||||||
test_split_p: float = 0.01,
|
test_split_p: float = 0.01,
|
||||||
split_seed: int = 42) -> Tuple[HFDataset, HFDataset]:
|
split_seed: int = 42,
|
||||||
|
data_sample: Optional[int] = None) -> Tuple[HfDataset, HfDataset]:
|
||||||
"""
|
"""
|
||||||
split: Literal['train', 'validation', None]
|
split: Literal['train', 'validation', None]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dataset_en: HFDataset = MsDataset.load(
|
dataset_en: HfDataset = MsDataset.load(
|
||||||
'AI-ModelScope/alpaca-gpt4-data-en', split='train').to_hf_dataset()
|
'AI-ModelScope/alpaca-gpt4-data-en', split='train').to_hf_dataset()
|
||||||
dataset_zh: HFDataset = MsDataset.load(
|
dataset_zh: HfDataset = MsDataset.load(
|
||||||
'AI-ModelScope/alpaca-gpt4-data-zh', split='train').to_hf_dataset()
|
'AI-ModelScope/alpaca-gpt4-data-zh', split='train').to_hf_dataset()
|
||||||
dataset_en = dataset_en.remove_columns(['text'])
|
dataset_en = dataset_en.remove_columns(['text'])
|
||||||
dataset: HFDataset = concatenate_datasets([dataset_zh, dataset_en])
|
dataset: HfDataset = concatenate_datasets([dataset_zh, dataset_en])
|
||||||
#
|
#
|
||||||
# dataset = dataset.select(range(1000)) # for debug
|
if data_sample is not None:
|
||||||
|
dataset = dataset.select(range(data_sample))
|
||||||
dataset = dataset.train_test_split(test_split_p, seed=split_seed)
|
dataset = dataset.train_test_split(test_split_p, seed=split_seed)
|
||||||
if only_val:
|
if only_val:
|
||||||
dataset = dataset['test']
|
dataset = dataset['test']
|
||||||
|
|||||||
@@ -1,62 +0,0 @@
|
|||||||
# ### Setting up experimental environment.
|
|
||||||
from _common import *
|
|
||||||
from transformers import TextStreamer
|
|
||||||
|
|
||||||
device_ids = [0, 1]
|
|
||||||
select_device(device_ids)
|
|
||||||
# Note: You need to set the value of `CKPT_FPATH`
|
|
||||||
CKPT_FAPTH = '/path/to/your/iter_xxx.pth'
|
|
||||||
|
|
||||||
# ### Loading Model and Tokenizer
|
|
||||||
BAICHUAN_TYPE = '13B' # Literal['7B', '13B']
|
|
||||||
if BAICHUAN_TYPE == '7B':
|
|
||||||
model_dir = snapshot_download('baichuan-inc/baichuan-7B', 'v1.0.5')
|
|
||||||
model, tokenizer = get_baichuan7B_model_tokenizer(model_dir)
|
|
||||||
else:
|
|
||||||
model_dir = snapshot_download('baichuan-inc/Baichuan-13B-Base', 'v1.0.2')
|
|
||||||
model, tokenizer = get_baichuan13B_model_tokenizer(model_dir)
|
|
||||||
model.bfloat16() # Consistent with training
|
|
||||||
|
|
||||||
# ### Preparing lora
|
|
||||||
LORA_TARGET_MODULES = ['W_pack']
|
|
||||||
LORA_RANK = 8
|
|
||||||
LORA_ALPHA = 32
|
|
||||||
LORA_DROPOUT_P = 0 # Arbitrary value
|
|
||||||
lora_config = LoRAConfig(
|
|
||||||
replace_modules=LORA_TARGET_MODULES,
|
|
||||||
rank=LORA_RANK,
|
|
||||||
lora_alpha=LORA_ALPHA,
|
|
||||||
lora_dropout=LORA_DROPOUT_P,
|
|
||||||
pretrained_weights=CKPT_FAPTH)
|
|
||||||
logger.info(f'lora_config: {lora_config}')
|
|
||||||
Swift.prepare_model(model, lora_config)
|
|
||||||
|
|
||||||
# ### Loading Dataset
|
|
||||||
_, test_dataset = get_alpaca_en_zh_dataset(None, True)
|
|
||||||
|
|
||||||
# ### Inference
|
|
||||||
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
|
||||||
mini_test_dataset = test_dataset.select(range(5))
|
|
||||||
for d in mini_test_dataset:
|
|
||||||
output = d['output']
|
|
||||||
d['output'] = None
|
|
||||||
input_ids = tokenize_function(d, tokenizer)['input_ids']
|
|
||||||
print(f'[TEST]{tokenizer.decode(input_ids)}', end='')
|
|
||||||
input_ids = torch.tensor(input_ids)[None].cuda()
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
|
||||||
generate_ids = model.generate(
|
|
||||||
input_ids=input_ids,
|
|
||||||
max_new_tokens=512,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
streamer=streamer,
|
|
||||||
pad_token_id=tokenizer.eos_token_id,
|
|
||||||
temperature=0.7,
|
|
||||||
top_k=50,
|
|
||||||
top_p=0.7,
|
|
||||||
do_sample=True)
|
|
||||||
print()
|
|
||||||
print(f'[LABELS]{output}')
|
|
||||||
print(
|
|
||||||
'-----------------------------------------------------------------------------------'
|
|
||||||
)
|
|
||||||
# input('next[ENTER]')
|
|
||||||
@@ -1,186 +0,0 @@
|
|||||||
# ### Setting up experimental environment.
|
|
||||||
"""
|
|
||||||
pip install modelscope
|
|
||||||
pip install numpy pandas matplotlib scikit-learn
|
|
||||||
pip install transformers datasets
|
|
||||||
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
|
|
||||||
pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer accelerate
|
|
||||||
|
|
||||||
pip install numpy -U # Resolve torchmetrics dependencies and update numpy
|
|
||||||
"""
|
|
||||||
|
|
||||||
from _common import *
|
|
||||||
|
|
||||||
device_ids = [0, 1]
|
|
||||||
select_device(device_ids)
|
|
||||||
seed_everything(42)
|
|
||||||
|
|
||||||
# ### Loading Model and Tokenizer
|
|
||||||
BAICHUAN_TYPE = '13B' # Literal['7B', '13B']
|
|
||||||
WORK_DIR = f'runs/baichuan_{BAICHUAN_TYPE}'
|
|
||||||
#
|
|
||||||
if BAICHUAN_TYPE == '7B':
|
|
||||||
model_dir = snapshot_download('baichuan-inc/baichuan-7B', 'v1.0.5')
|
|
||||||
model, tokenizer = get_baichuan7B_model_tokenizer(model_dir)
|
|
||||||
else:
|
|
||||||
model_dir = snapshot_download('baichuan-inc/Baichuan-13B-Base', 'v1.0.2')
|
|
||||||
model, tokenizer = get_baichuan13B_model_tokenizer(model_dir)
|
|
||||||
#
|
|
||||||
GRADIENT_CHECKPOINTING = True
|
|
||||||
if GRADIENT_CHECKPOINTING:
|
|
||||||
# baichuan13B does not implement the `get_input_embeddings` function
|
|
||||||
if BAICHUAN_TYPE == '13B':
|
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
|
||||||
return self.model.embed_tokens
|
|
||||||
|
|
||||||
model.__class__.get_input_embeddings = get_input_embeddings.__get__(
|
|
||||||
model)
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
model.enable_input_require_grads()
|
|
||||||
|
|
||||||
# ### Preparing lora
|
|
||||||
LORA_TARGET_MODULES = ['W_pack']
|
|
||||||
LORA_RANK = 8
|
|
||||||
LORA_ALPHA = 32
|
|
||||||
LORA_DROPOUT_P = 0.1
|
|
||||||
lora_config = LoRAConfig(
|
|
||||||
replace_modules=LORA_TARGET_MODULES,
|
|
||||||
rank=LORA_RANK,
|
|
||||||
lora_alpha=LORA_ALPHA,
|
|
||||||
lora_dropout=LORA_DROPOUT_P)
|
|
||||||
logger.info(f'lora_config: {lora_config}')
|
|
||||||
Swift.prepare_model(model, lora_config)
|
|
||||||
#
|
|
||||||
show_freeze_layers(model)
|
|
||||||
print_model_info(model)
|
|
||||||
_p = list(model.parameters())[100]
|
|
||||||
logger.info(f'device: {_p.device}, dtype: {_p.dtype}')
|
|
||||||
model.bfloat16()
|
|
||||||
|
|
||||||
# ### Loading Dataset
|
|
||||||
tokenize_function = partial(tokenize_function, tokenizer=tokenizer)
|
|
||||||
train_dataset, val_dataset = get_alpaca_en_zh_dataset(tokenize_function)
|
|
||||||
# Data analysis
|
|
||||||
stat_dataset(train_dataset)
|
|
||||||
stat_dataset(val_dataset)
|
|
||||||
data_collate_fn = partial(data_collate_fn, tokenizer=tokenizer)
|
|
||||||
print_examples(train_dataset[0], tokenizer)
|
|
||||||
|
|
||||||
# ### Setting Config
|
|
||||||
cfg_file = os.path.join(model_dir, 'configuration.json')
|
|
||||||
#
|
|
||||||
BATCH_SIZE = 1
|
|
||||||
MAX_EPOCHS = 1
|
|
||||||
T_max = get_T_max(len(train_dataset), BATCH_SIZE, MAX_EPOCHS, True)
|
|
||||||
WORK_DIR = get_work_dir(WORK_DIR)
|
|
||||||
EVAL_INTERVAL = 500
|
|
||||||
CONFIG = Config({
|
|
||||||
'train': {
|
|
||||||
'dataloader': {
|
|
||||||
'batch_size_per_gpu': BATCH_SIZE,
|
|
||||||
'workers_per_gpu': 1,
|
|
||||||
'shuffle': True,
|
|
||||||
'drop_last': True,
|
|
||||||
'pin_memory': True
|
|
||||||
},
|
|
||||||
'max_epochs':
|
|
||||||
MAX_EPOCHS,
|
|
||||||
'work_dir':
|
|
||||||
WORK_DIR,
|
|
||||||
'optimizer': {
|
|
||||||
'type': 'AdamW',
|
|
||||||
'lr': 1e-4,
|
|
||||||
'weight_decay': 0.01,
|
|
||||||
'options': {
|
|
||||||
'cumulative_iters': 16,
|
|
||||||
'grad_clip': {
|
|
||||||
'norm_type': 2,
|
|
||||||
'max_norm': 2.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'lr_scheduler': {
|
|
||||||
'type': 'CosineAnnealingLR',
|
|
||||||
'T_max': T_max,
|
|
||||||
'eta_min': 1e-5,
|
|
||||||
'options': {
|
|
||||||
'by_epoch': False,
|
|
||||||
'warmup': {
|
|
||||||
'type': 'LinearWarmup',
|
|
||||||
'warmup_ratio': 0.1,
|
|
||||||
'warmup_iters': 200
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'hooks': [
|
|
||||||
{
|
|
||||||
'type': 'CheckpointHook',
|
|
||||||
'by_epoch': False,
|
|
||||||
'interval': EVAL_INTERVAL,
|
|
||||||
'max_checkpoint_num': 1
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'type': 'EvaluationHook',
|
|
||||||
'by_epoch': False,
|
|
||||||
'interval': EVAL_INTERVAL
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'type': 'BestCkptSaverHook',
|
|
||||||
'metric_key': 'acc',
|
|
||||||
'save_best': True,
|
|
||||||
'rule': 'max',
|
|
||||||
'max_checkpoint_num': 1
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'type': 'TextLoggerHook',
|
|
||||||
'by_epoch': True, # Whether EpochBasedTrainer is used
|
|
||||||
'interval': 5
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'type': 'TensorboardHook',
|
|
||||||
'by_epoch': False,
|
|
||||||
'interval': 5
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
'evaluation': {
|
|
||||||
'dataloader': {
|
|
||||||
'batch_size_per_gpu': BATCH_SIZE,
|
|
||||||
'workers_per_gpu': 1,
|
|
||||||
'shuffle': False,
|
|
||||||
'drop_last': False,
|
|
||||||
'pin_memory': True
|
|
||||||
},
|
|
||||||
'metrics': [{
|
|
||||||
'type': 'my_metric',
|
|
||||||
'vocab_size': tokenizer.vocab_size
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
# ### Finetuning
|
|
||||||
|
|
||||||
|
|
||||||
def cfg_modify_fn(cfg: Config) -> Config:
|
|
||||||
cfg.update(CONFIG)
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
trainer = EpochBasedTrainer(
|
|
||||||
model=model,
|
|
||||||
cfg_file=cfg_file,
|
|
||||||
data_collator=data_collate_fn,
|
|
||||||
train_dataset=train_dataset,
|
|
||||||
eval_dataset=val_dataset,
|
|
||||||
remove_unused_data=True,
|
|
||||||
seed=42,
|
|
||||||
device='cpu', # No placement for model, leave the model to `device_map`
|
|
||||||
cfg_modify_fn=cfg_modify_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer.train()
|
|
||||||
|
|
||||||
# ### Visualization
|
|
||||||
tb_dir = os.path.join(WORK_DIR, 'tensorboard_output')
|
|
||||||
plot_image(tb_dir, ['loss'], 0.9)
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
# ### Setting up experimental environment.
|
|
||||||
from _common import *
|
|
||||||
from transformers import TextStreamer
|
|
||||||
|
|
||||||
device_ids = [0, 1]
|
|
||||||
select_device(device_ids)
|
|
||||||
# Note: You need to set the value of `CKPT_FPATH`
|
|
||||||
CKPT_FAPTH = '/path/to/your/iter_xxx.pth'
|
|
||||||
|
|
||||||
# ### Loading Model and Tokenizer
|
|
||||||
model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6')
|
|
||||||
model, tokenizer = get_chatglm2_model_tokenizer(model_dir)
|
|
||||||
model.bfloat16() # Consistent with training
|
|
||||||
|
|
||||||
# ### Preparing lora
|
|
||||||
LORA_TARGET_MODULES = ['query_key_value']
|
|
||||||
LORA_RANK = 8
|
|
||||||
LORA_ALPHA = 32
|
|
||||||
LORA_DROPOUT_P = 0 # Arbitrary value
|
|
||||||
lora_config = LoRAConfig(
|
|
||||||
replace_modules=LORA_TARGET_MODULES,
|
|
||||||
rank=LORA_RANK,
|
|
||||||
lora_alpha=LORA_ALPHA,
|
|
||||||
lora_dropout=LORA_DROPOUT_P,
|
|
||||||
pretrained_weights=CKPT_FAPTH)
|
|
||||||
logger.info(f'lora_config: {lora_config}')
|
|
||||||
Swift.prepare_model(model, lora_config)
|
|
||||||
|
|
||||||
# ### Loading Dataset
|
|
||||||
_, test_dataset = get_alpaca_en_zh_dataset(None, True)
|
|
||||||
|
|
||||||
# ### Inference
|
|
||||||
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
|
||||||
mini_test_dataset = test_dataset.select(range(5))
|
|
||||||
for d in mini_test_dataset:
|
|
||||||
output = d['output']
|
|
||||||
d['output'] = None
|
|
||||||
input_ids = tokenize_function(d, tokenizer)['input_ids']
|
|
||||||
print(f'[TEST]{tokenizer.decode(input_ids)}', end='')
|
|
||||||
input_ids = torch.tensor(input_ids)[None].cuda()
|
|
||||||
attention_mask = torch.ones_like(input_ids)
|
|
||||||
generate_ids = model.generate(
|
|
||||||
input_ids=input_ids,
|
|
||||||
max_new_tokens=512,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
streamer=streamer,
|
|
||||||
pad_token_id=tokenizer.eos_token_id,
|
|
||||||
temperature=0.7,
|
|
||||||
top_k=50,
|
|
||||||
top_p=0.7,
|
|
||||||
do_sample=True)
|
|
||||||
print()
|
|
||||||
print(f'[LABELS]{output}')
|
|
||||||
print(
|
|
||||||
'-----------------------------------------------------------------------------------'
|
|
||||||
)
|
|
||||||
# input('next[ENTER]')
|
|
||||||
@@ -1,173 +0,0 @@
|
|||||||
# ### Setting up experimental environment.
|
|
||||||
"""
|
|
||||||
pip install modelscope
|
|
||||||
pip install numpy pandas matplotlib scikit-learn
|
|
||||||
pip install transformers datasets
|
|
||||||
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
|
|
||||||
pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer accelerate
|
|
||||||
|
|
||||||
pip install numpy -U # Resolve torchmetrics dependencies and update numpy
|
|
||||||
"""
|
|
||||||
|
|
||||||
from _common import *
|
|
||||||
|
|
||||||
device_ids = [0, 1]
|
|
||||||
select_device(device_ids)
|
|
||||||
seed_everything(42)
|
|
||||||
|
|
||||||
# ### Loading Model and Tokenizer
|
|
||||||
WORK_DIR = 'runs/chatglm2'
|
|
||||||
#
|
|
||||||
model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6')
|
|
||||||
model, tokenizer = get_chatglm2_model_tokenizer(model_dir)
|
|
||||||
#
|
|
||||||
GRADIENT_CHECKPOINTING = True
|
|
||||||
if GRADIENT_CHECKPOINTING:
|
|
||||||
model.gradient_checkpointing_enable()
|
|
||||||
model.enable_input_require_grads()
|
|
||||||
|
|
||||||
# ### Preparing lora
|
|
||||||
LORA_TARGET_MODULES = ['query_key_value']
|
|
||||||
LORA_RANK = 8
|
|
||||||
LORA_ALPHA = 32
|
|
||||||
LORA_DROPOUT_P = 0.1
|
|
||||||
lora_config = LoRAConfig(
|
|
||||||
replace_modules=LORA_TARGET_MODULES,
|
|
||||||
rank=LORA_RANK,
|
|
||||||
lora_alpha=LORA_ALPHA,
|
|
||||||
lora_dropout=LORA_DROPOUT_P)
|
|
||||||
logger.info(f'lora_config: {lora_config}')
|
|
||||||
Swift.prepare_model(model, lora_config)
|
|
||||||
#
|
|
||||||
show_freeze_layers(model)
|
|
||||||
print_model_info(model)
|
|
||||||
_p = list(model.parameters())[100]
|
|
||||||
logger.info(f'device: {_p.device}, dtype: {_p.dtype}')
|
|
||||||
model.bfloat16()
|
|
||||||
|
|
||||||
# ### Loading Dataset
|
|
||||||
tokenize_function = partial(tokenize_function, tokenizer=tokenizer)
|
|
||||||
train_dataset, val_dataset = get_alpaca_en_zh_dataset(tokenize_function)
|
|
||||||
# Data analysis
|
|
||||||
stat_dataset(train_dataset)
|
|
||||||
stat_dataset(val_dataset)
|
|
||||||
data_collate_fn = partial(data_collate_fn, tokenizer=tokenizer)
|
|
||||||
print_examples(train_dataset[0], tokenizer)
|
|
||||||
|
|
||||||
# ### Setting Config
|
|
||||||
cfg_file = os.path.join(model_dir, 'configuration.json')
|
|
||||||
#
|
|
||||||
BATCH_SIZE = 1
|
|
||||||
MAX_EPOCHS = 1
|
|
||||||
T_max = get_T_max(len(train_dataset), BATCH_SIZE, MAX_EPOCHS, True)
|
|
||||||
WORK_DIR = get_work_dir(WORK_DIR)
|
|
||||||
EVAL_INTERVAL = 500
|
|
||||||
CONFIG = Config({
|
|
||||||
'train': {
|
|
||||||
'dataloader': {
|
|
||||||
'batch_size_per_gpu': BATCH_SIZE,
|
|
||||||
'workers_per_gpu': 1,
|
|
||||||
'shuffle': True,
|
|
||||||
'drop_last': True,
|
|
||||||
'pin_memory': True
|
|
||||||
},
|
|
||||||
'max_epochs':
|
|
||||||
MAX_EPOCHS,
|
|
||||||
'work_dir':
|
|
||||||
WORK_DIR,
|
|
||||||
'optimizer': {
|
|
||||||
'type': 'AdamW',
|
|
||||||
'lr': 1e-4,
|
|
||||||
'weight_decay': 0.01,
|
|
||||||
'options': {
|
|
||||||
'cumulative_iters': 16,
|
|
||||||
'grad_clip': {
|
|
||||||
'norm_type': 2,
|
|
||||||
'max_norm': 2.0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'lr_scheduler': {
|
|
||||||
'type': 'CosineAnnealingLR',
|
|
||||||
'T_max': T_max,
|
|
||||||
'eta_min': 1e-5,
|
|
||||||
'options': {
|
|
||||||
'by_epoch': False,
|
|
||||||
'warmup': {
|
|
||||||
'type': 'LinearWarmup',
|
|
||||||
'warmup_ratio': 0.1,
|
|
||||||
'warmup_iters': 200
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
'hooks': [
|
|
||||||
{
|
|
||||||
'type': 'CheckpointHook',
|
|
||||||
'by_epoch': False,
|
|
||||||
'interval': EVAL_INTERVAL,
|
|
||||||
'max_checkpoint_num': 1
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'type': 'EvaluationHook',
|
|
||||||
'by_epoch': False,
|
|
||||||
'interval': EVAL_INTERVAL
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'type': 'BestCkptSaverHook',
|
|
||||||
'metric_key': 'acc',
|
|
||||||
'save_best': True,
|
|
||||||
'rule': 'max',
|
|
||||||
'max_checkpoint_num': 1
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'type': 'TextLoggerHook',
|
|
||||||
'by_epoch': True, # Whether EpochBasedTrainer is used
|
|
||||||
'interval': 5
|
|
||||||
},
|
|
||||||
{
|
|
||||||
'type': 'TensorboardHook',
|
|
||||||
'by_epoch': False,
|
|
||||||
'interval': 5
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
'evaluation': {
|
|
||||||
'dataloader': {
|
|
||||||
'batch_size_per_gpu': BATCH_SIZE,
|
|
||||||
'workers_per_gpu': 1,
|
|
||||||
'shuffle': False,
|
|
||||||
'drop_last': False,
|
|
||||||
'pin_memory': True
|
|
||||||
},
|
|
||||||
'metrics': [{
|
|
||||||
'type': 'my_metric',
|
|
||||||
'vocab_size': tokenizer.vocab_size
|
|
||||||
}]
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
# ### Finetuning
|
|
||||||
|
|
||||||
|
|
||||||
def cfg_modify_fn(cfg: Config) -> Config:
|
|
||||||
cfg.update(CONFIG)
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
|
|
||||||
trainer = EpochBasedTrainer(
|
|
||||||
model=model,
|
|
||||||
cfg_file=cfg_file,
|
|
||||||
data_collator=data_collate_fn,
|
|
||||||
train_dataset=train_dataset,
|
|
||||||
eval_dataset=val_dataset,
|
|
||||||
remove_unused_data=True,
|
|
||||||
seed=42,
|
|
||||||
device='cpu', # No placement for model, leave the model to `device_map`
|
|
||||||
cfg_modify_fn=cfg_modify_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer.train()
|
|
||||||
|
|
||||||
# ### Visualization
|
|
||||||
tb_dir = os.path.join(WORK_DIR, 'tensorboard_output')
|
|
||||||
plot_image(tb_dir, ['loss'], 0.9)
|
|
||||||
122
examples/pytorch/llm/llm_infer.py
Normal file
122
examples/pytorch/llm/llm_infer.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
# ### Setting up experimental environment.
|
||||||
|
from _common import *
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Arguments:
|
||||||
|
device: str = '0' # e.g. '-1'; '0'; '0,1'
|
||||||
|
model_type: str = field(
|
||||||
|
default='baichuan-7b',
|
||||||
|
metadata={
|
||||||
|
'choices':
|
||||||
|
['baichuan-7b', 'baichuan-13b', 'chatglm2', 'llama2-7b']
|
||||||
|
})
|
||||||
|
ckpt_fpath: str = '' # e.g. '/path/to/your/iter_xxx.pth'
|
||||||
|
eval_human: bool = False # False: eval test_dataset
|
||||||
|
data_sample: Optional[int] = None
|
||||||
|
#
|
||||||
|
lora_target_modules: Optional[List[str]] = None
|
||||||
|
lora_rank: int = 8
|
||||||
|
lora_alpha: int = 32
|
||||||
|
lora_dropout_p: float = 0.1
|
||||||
|
#
|
||||||
|
max_new_tokens: int = 512
|
||||||
|
temperature: float = 0.9
|
||||||
|
top_k: int = 50
|
||||||
|
top_p: float = 0.9
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.lora_target_modules is None:
|
||||||
|
if self.model_type in {'baichuan-7b', 'baichuan-13b'}:
|
||||||
|
self.lora_target_modules = ['W_pack']
|
||||||
|
elif self.model_type == 'chatglm2':
|
||||||
|
self.lora_target_modules = ['query_key_value']
|
||||||
|
elif self.model_type == 'llama2-7b':
|
||||||
|
self.lora_target_modules = ['q_proj', 'k_proj', 'v_proj']
|
||||||
|
else:
|
||||||
|
raise ValueError(f'model_type: {self.model_type}')
|
||||||
|
#
|
||||||
|
if not os.path.isfile(self.ckpt_fpath):
|
||||||
|
raise ValueError('Please enter a valid fpath')
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> Arguments:
|
||||||
|
args, = HfArgumentParser([Arguments]).parse_args_into_dataclasses()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
args = parse_args()
|
||||||
|
logger.info(args)
|
||||||
|
select_device(args.device)
|
||||||
|
|
||||||
|
# ### Loading Model and Tokenizer
|
||||||
|
if args.model_type == 'baichuan-7b':
|
||||||
|
model_dir = snapshot_download('baichuan-inc/baichuan-7B', 'v1.0.5')
|
||||||
|
model, tokenizer = get_baichuan_model_tokenizer(model_dir)
|
||||||
|
elif args.model_type == 'baichuan-13b':
|
||||||
|
model_dir = snapshot_download('baichuan-inc/Baichuan-13B-Base', 'v1.0.2')
|
||||||
|
model, tokenizer = get_baichuan_model_tokenizer(model_dir)
|
||||||
|
elif args.model_type == 'chatglm2':
|
||||||
|
model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6')
|
||||||
|
model, tokenizer = get_chatglm2_model_tokenizer(model_dir)
|
||||||
|
elif args.model_type == 'llama2-7b':
|
||||||
|
model_dir = snapshot_download('modelscope/Llama-2-7b-ms', 'v1.0.0')
|
||||||
|
model, tokenizer = get_llama2_model_tokenizer(model_dir)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'model_type: {args.model_type}')
|
||||||
|
|
||||||
|
# ### Preparing lora
|
||||||
|
lora_config = LoRAConfig(
|
||||||
|
replace_modules=args.lora_target_modules,
|
||||||
|
rank=args.lora_rank,
|
||||||
|
lora_alpha=args.lora_alpha,
|
||||||
|
lora_dropout=args.lora_dropout_p,
|
||||||
|
pretrained_weights=args.ckpt_fpath)
|
||||||
|
logger.info(f'lora_config: {lora_config}')
|
||||||
|
Swift.prepare_model(model, lora_config)
|
||||||
|
model.bfloat16() # Consistent with training
|
||||||
|
|
||||||
|
# ### Inference
|
||||||
|
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
generation_config = GenerationConfig(
|
||||||
|
max_new_tokens=args.max_new_tokens,
|
||||||
|
temperature=args.temperature,
|
||||||
|
top_k=args.top_k,
|
||||||
|
top_p=args.top_p,
|
||||||
|
do_sample=True,
|
||||||
|
pad_token_id=tokenizer.eos_token_id)
|
||||||
|
logger.info(generation_config)
|
||||||
|
|
||||||
|
|
||||||
|
def inference(data: Dict[str, Optional[str]]) -> str:
|
||||||
|
input_ids = tokenize_function(data, tokenizer)['input_ids']
|
||||||
|
print(f'[TEST]{tokenizer.decode(input_ids)}', end='')
|
||||||
|
input_ids = torch.tensor(input_ids)[None].cuda()
|
||||||
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
generate_ids = model.generate(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
streamer=streamer,
|
||||||
|
generation_config=generation_config)
|
||||||
|
output_text = tokenizer.decode(generate_ids[0])
|
||||||
|
return output_text
|
||||||
|
|
||||||
|
|
||||||
|
if args.eval_human:
|
||||||
|
while True:
|
||||||
|
instruction = input('<<< ')
|
||||||
|
data = {'instruction': instruction, 'input': None, 'output': None}
|
||||||
|
inference(data)
|
||||||
|
print('-' * 80)
|
||||||
|
else:
|
||||||
|
_, test_dataset = get_alpaca_en_zh_dataset(
|
||||||
|
None, True, split_seed=42, data_sample=None)
|
||||||
|
mini_test_dataset = test_dataset.select(range(10))
|
||||||
|
for data in mini_test_dataset:
|
||||||
|
output = data['output']
|
||||||
|
data['output'] = None
|
||||||
|
inference(data)
|
||||||
|
print()
|
||||||
|
print(f'[LABELS]{output}')
|
||||||
|
print('-' * 80)
|
||||||
|
# input('next[ENTER]')
|
||||||
237
examples/pytorch/llm/llm_sft.py
Normal file
237
examples/pytorch/llm/llm_sft.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
# ### Setting up experimental environment.
|
||||||
|
"""
|
||||||
|
pip install modelscope
|
||||||
|
pip install numpy pandas matplotlib scikit-learn
|
||||||
|
pip install transformers datasets
|
||||||
|
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
|
||||||
|
pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer
|
||||||
|
pip install accelerate transformers_stream_generator
|
||||||
|
|
||||||
|
pip install numpy -U # Resolve torchmetrics dependencies and update numpy
|
||||||
|
"""
|
||||||
|
|
||||||
|
from _common import *
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Arguments:
|
||||||
|
device: str = '0,1' # e.g. '-1'; '0'; '0,1'
|
||||||
|
seed: int = 42
|
||||||
|
model_type: str = field(
|
||||||
|
default='baichuan-7b',
|
||||||
|
metadata={
|
||||||
|
'choices':
|
||||||
|
['baichuan-7b', 'baichuan-13b', 'chatglm2', 'llama2-7b']
|
||||||
|
})
|
||||||
|
data_sample: Optional[int] = None
|
||||||
|
#
|
||||||
|
lora_target_modules: Optional[List[str]] = None
|
||||||
|
lora_rank: int = 8
|
||||||
|
lora_alpha: int = 32
|
||||||
|
lora_dropout_p: float = 0.1
|
||||||
|
#
|
||||||
|
gradient_checkpoint: bool = True
|
||||||
|
batch_size: int = 1
|
||||||
|
max_epochs: int = 1
|
||||||
|
eval_interval: int = 500
|
||||||
|
learning_rate: float = 1e-4
|
||||||
|
weight_decay: float = 0.01
|
||||||
|
n_accumulate_grad: int = 16
|
||||||
|
grad_clip_norm: float = 1.
|
||||||
|
warmup_iters: int = 200
|
||||||
|
last_max_checkpoint_num: int = 1
|
||||||
|
best_max_checkpoint_num: int = 1
|
||||||
|
#
|
||||||
|
logging_interval: int = 5
|
||||||
|
tb_interval: int = 5
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.lora_target_modules is None:
|
||||||
|
if self.model_type in {'baichuan-7b', 'baichuan-13b'}:
|
||||||
|
self.lora_target_modules = ['W_pack']
|
||||||
|
elif self.model_type == 'chatglm2':
|
||||||
|
self.lora_target_modules = ['query_key_value']
|
||||||
|
elif self.model_type == 'llama2-7b':
|
||||||
|
self.lora_target_modules = ['q_proj', 'k_proj', 'v_proj']
|
||||||
|
else:
|
||||||
|
raise ValueError(f'model_type: {self.model_type}')
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> Arguments:
|
||||||
|
args, = HfArgumentParser([Arguments]).parse_args_into_dataclasses()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
args = parse_args()
|
||||||
|
logger.info(args)
|
||||||
|
select_device(args.device)
|
||||||
|
seed_everything(args.seed)
|
||||||
|
|
||||||
|
# ### Loading Model and Tokenizer
|
||||||
|
if args.model_type == 'baichuan-7b':
|
||||||
|
model_dir = snapshot_download('baichuan-inc/baichuan-7B', 'v1.0.5')
|
||||||
|
model, tokenizer = get_baichuan_model_tokenizer(model_dir)
|
||||||
|
elif args.model_type == 'baichuan-13b':
|
||||||
|
model_dir = snapshot_download('baichuan-inc/Baichuan-13B-Base', 'v1.0.2')
|
||||||
|
model, tokenizer = get_baichuan_model_tokenizer(model_dir)
|
||||||
|
elif args.model_type == 'chatglm2':
|
||||||
|
model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6')
|
||||||
|
model, tokenizer = get_chatglm2_model_tokenizer(model_dir)
|
||||||
|
elif args.model_type == 'llama2-7b':
|
||||||
|
model_dir = snapshot_download('modelscope/Llama-2-7b-ms', 'v1.0.0')
|
||||||
|
model, tokenizer = get_llama2_model_tokenizer(model_dir)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'model_type: {args.model_type}')
|
||||||
|
|
||||||
|
#
|
||||||
|
if args.gradient_checkpoint:
|
||||||
|
# baichuan13B does not implement the `get_input_embeddings` function
|
||||||
|
if args.model_type == 'baichuan-13b':
|
||||||
|
|
||||||
|
def get_input_embeddings(self):
|
||||||
|
return self.model.embed_tokens
|
||||||
|
|
||||||
|
model.__class__.get_input_embeddings = get_input_embeddings.__get__(
|
||||||
|
model)
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
|
model.enable_input_require_grads()
|
||||||
|
|
||||||
|
# ### Preparing lora
|
||||||
|
lora_config = LoRAConfig(
|
||||||
|
replace_modules=args.lora_target_modules,
|
||||||
|
rank=args.lora_rank,
|
||||||
|
lora_alpha=args.lora_alpha,
|
||||||
|
lora_dropout=args.lora_dropout_p)
|
||||||
|
logger.info(f'lora_config: {lora_config}')
|
||||||
|
Swift.prepare_model(model, lora_config)
|
||||||
|
#
|
||||||
|
show_freeze_layers(model)
|
||||||
|
print_model_info(model)
|
||||||
|
_p: Parameter = list(model.parameters())[100]
|
||||||
|
logger.info(f'device: {_p.device}, dtype: {_p.dtype}')
|
||||||
|
model.bfloat16()
|
||||||
|
|
||||||
|
# ### Loading Dataset
|
||||||
|
tokenize_function = partial(tokenize_function, tokenizer=tokenizer)
|
||||||
|
train_dataset, val_dataset = get_alpaca_en_zh_dataset(
|
||||||
|
tokenize_function, split_seed=42, data_sample=args.data_sample)
|
||||||
|
# Data analysis
|
||||||
|
stat_dataset(train_dataset)
|
||||||
|
stat_dataset(val_dataset)
|
||||||
|
data_collate_fn = partial(data_collate_fn, tokenizer=tokenizer)
|
||||||
|
print_example(train_dataset[0], tokenizer)
|
||||||
|
|
||||||
|
# ### Setting Config
|
||||||
|
cfg_file = os.path.join(model_dir, 'configuration.json')
|
||||||
|
#
|
||||||
|
T_max = get_T_max(len(train_dataset), args.batch_size, args.max_epochs, True)
|
||||||
|
work_dir = get_work_dir(f'runs/{args.model_type}')
|
||||||
|
config = Config({
|
||||||
|
'train': {
|
||||||
|
'dataloader': {
|
||||||
|
'batch_size_per_gpu': args.batch_size,
|
||||||
|
'workers_per_gpu': 1,
|
||||||
|
'shuffle': True,
|
||||||
|
'drop_last': True,
|
||||||
|
'pin_memory': True
|
||||||
|
},
|
||||||
|
'max_epochs':
|
||||||
|
args.max_epochs,
|
||||||
|
'work_dir':
|
||||||
|
work_dir,
|
||||||
|
'optimizer': {
|
||||||
|
'type': 'AdamW',
|
||||||
|
'lr': args.learning_rate,
|
||||||
|
'weight_decay': args.weight_decay,
|
||||||
|
'options': {
|
||||||
|
'cumulative_iters': args.n_accumulate_grad,
|
||||||
|
'grad_clip': {
|
||||||
|
'norm_type': 2,
|
||||||
|
'max_norm': args.grad_clip_norm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'lr_scheduler': {
|
||||||
|
'type': 'CosineAnnealingLR',
|
||||||
|
'T_max': T_max,
|
||||||
|
'eta_min': 0,
|
||||||
|
'options': {
|
||||||
|
'by_epoch': False,
|
||||||
|
'warmup': {
|
||||||
|
'type': 'LinearWarmup',
|
||||||
|
'warmup_ratio': 0.1,
|
||||||
|
'warmup_iters': args.warmup_iters
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'hooks': [
|
||||||
|
{
|
||||||
|
'type': 'CheckpointHook',
|
||||||
|
'by_epoch': False,
|
||||||
|
'interval': args.eval_interval,
|
||||||
|
'max_checkpoint_num': args.last_max_checkpoint_num
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'type': 'EvaluationHook',
|
||||||
|
'by_epoch': False,
|
||||||
|
'interval': args.eval_interval
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'type': 'BestCkptSaverHook',
|
||||||
|
'metric_key': 'loss',
|
||||||
|
'save_best': True,
|
||||||
|
'rule': 'min',
|
||||||
|
'max_checkpoint_num': args.best_max_checkpoint_num
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'type': 'TextLoggerHook',
|
||||||
|
'by_epoch': True, # Whether EpochBasedTrainer is used
|
||||||
|
'interval': args.logging_interval
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'type': 'TensorboardHook',
|
||||||
|
'by_epoch': False,
|
||||||
|
'interval': args.tb_interval
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
'evaluation': {
|
||||||
|
'dataloader': {
|
||||||
|
'batch_size_per_gpu': args.batch_size,
|
||||||
|
'workers_per_gpu': 1,
|
||||||
|
'shuffle': False,
|
||||||
|
'drop_last': False,
|
||||||
|
'pin_memory': True
|
||||||
|
},
|
||||||
|
'metrics': [{
|
||||||
|
'type': 'my_metric',
|
||||||
|
'vocab_size': tokenizer.vocab_size
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
# ### Finetuning
|
||||||
|
|
||||||
|
|
||||||
|
def cfg_modify_fn(cfg: Config) -> Config:
|
||||||
|
cfg.update(config)
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
trainer = EpochBasedTrainer(
|
||||||
|
model=model,
|
||||||
|
cfg_file=cfg_file,
|
||||||
|
data_collator=data_collate_fn,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=val_dataset,
|
||||||
|
remove_unused_data=True,
|
||||||
|
seed=42,
|
||||||
|
device='cpu', # No placement for model, leave the model to `device_map`
|
||||||
|
cfg_modify_fn=cfg_modify_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# ### Visualization
|
||||||
|
tb_dir = os.path.join(work_dir, 'tensorboard_output')
|
||||||
|
plot_image(tb_dir, ['loss'], 0.9)
|
||||||
5
examples/pytorch/llm/run_infer.sh
Normal file
5
examples/pytorch/llm/run_infer.sh
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
python llm_infer.py \
|
||||||
|
--device 0 \
|
||||||
|
--model_type llama2-7b \
|
||||||
|
--ckpt_fpath "runs/llama2-7b/vx_xxx/output_best/pytorch_model.bin" \
|
||||||
|
--eval_human true
|
||||||
8
examples/pytorch/llm/run_sft.sh
Normal file
8
examples/pytorch/llm/run_sft.sh
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
DATE=$(date +"%Y%m%d-%H%M%S")
|
||||||
|
nohup python llm_sft.py \
|
||||||
|
--device 0 \
|
||||||
|
--model_type llama2-7b \
|
||||||
|
--data_sample 25000 \
|
||||||
|
&> train_$DATE.out &
|
||||||
Reference in New Issue
Block a user