Merge branch 'master' of github.com:modelscope/modelscope into auto_support

This commit is contained in:
wenmeng.zwm
2023-07-24 09:54:52 +08:00
37 changed files with 2632 additions and 732 deletions

View File

@@ -142,6 +142,14 @@ class Chatglm6bArguments(TrainingArgs):
metadata={'help': 'The lora alpha'},
)
use_amp: int = field(
default=0,
metadata={
'help':
'Whether to use AMP (automatic mixed precision) to train the model.'
},
)
args = Chatglm6bArguments(eval_metrics='chatglm').parse_cli()
print(args)
@@ -159,6 +167,13 @@ def cfg_modify_fn(cfg):
cfg.merge_from_dict(config)
else:
cfg = config
if args.use_amp:
if not getattr(cfg.train, 'hooks', None):
cfg.train.hooks = []
cfg.train.hooks.append({
'type': 'TorchAMPOptimizerHook',
# Optional loss_scale parameter here.
})
if cfg.train.lr_scheduler.type == 'LinearLR':
cfg.train.lr_scheduler['total_iters'] = \
int(len(train_dataset) / cfg.train.dataloader.batch_size_per_gpu) * cfg.train.max_epochs
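# A minimal, runnable sketch (not part of this diff) of the hook entry that the
# branch above appends; the explicit `loss_scale` value is an assumed example.
from modelscope.utils.config import Config

amp_cfg = Config({'train': {}})
if not getattr(amp_cfg.train, 'hooks', None):
    amp_cfg.train.hooks = []
amp_cfg.train.hooks.append({
    'type': 'TorchAMPOptimizerHook',
    'loss_scale': 512.0,  # assumed example value; omit it to use the hook's default
})
print(amp_cfg.train.hooks)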
@@ -192,7 +207,7 @@ tokenizer = ChatGLMTokenizer.from_pretrained(model_dir, trust_remote_code=True)
device_map_kwargs = {}
device_kwargs = {}
if args.use_lora != 0:
if args.use_lora != 0 and torch.cuda.device_count() > 1:
device_map_kwargs['device_map'] = 'auto'
# No device placement here; leave the model to `device_map`
device_kwargs['device'] = 'cpu'
@@ -228,7 +243,10 @@ if args.use_lora != 0:
rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout)
model = model.bfloat16()
if args.use_amp:
model = model.float()
else:
model = model.bfloat16()
Swift.prepare_model(model, lora_config)
prefix = args.source_prefix if args.source_prefix is not None else ''

View File

@@ -5,6 +5,7 @@ import os
import random
import re
import sys
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -15,7 +16,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import Dataset as HFDataset
from datasets import Dataset as HfDataset
from datasets import concatenate_datasets
from matplotlib.axes import Axes
from matplotlib.figure import Figure
@@ -36,6 +37,8 @@ from torch.utils.data import Dataset
from torchmetrics import Accuracy, MeanMetric
#
from tqdm import tqdm
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
GenerationConfig, HfArgumentParser, TextStreamer)
#
from modelscope import (Model, MsDataset, get_logger, read_config,
@@ -51,25 +54,16 @@ from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.registry import default_group
#
TEST_SPLIT_P = 0.01
SPLIT_SEED = 42
MAX_LENGTH: Optional[int] = 2048
COLOR, COLOR_S = '#FFE2D9', '#FF7043'
PROMPT = """### 用户
{instruction}
### AI助手
"""
PROMPT = """Human: {instruction}
AI: """
logger = get_logger()
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
#
def get_model_dir(model_id: str, model_revision: Optional[str] = None) -> str:
model_dir = snapshot_download(model_id, model_revision)
return model_dir
def _get_version(work_dir: str) -> int:
if os.path.isdir(work_dir):
fnames = os.listdir(work_dir)
@@ -96,28 +90,40 @@ def get_work_dir(work_dir: str) -> str:
return work_dir
def select_device(device_ids: List[int]) -> Device:
def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
if isinstance(device, list):
device_ids = device
device_str = ','.join([str(d) for d in device])
else:
device_ids = [int(d) for d in device.split(',') if d != '-1']
device_str = device
device_str = device_str.replace(' ', '')
return device_ids, device_str
def select_device(device: Union[List[int], str]) -> Device:
"""Call this function before cuda is initialized.
Return: master device
device: e.g. []: 'cpu', [0], [0, 1, 2]
e.g. '-1': 'cpu', '0', '0,1,2'
"""
if torch.cuda.is_initialized():
logger.warning('CUDA has already been initialized; device selection will have no effect.')
return torch.device('cuda:0')
#
device_ids, device_str = _format_device(device)
#
os.environ['CUDA_VISIBLE_DEVICES'] = device_str
log_s = 'Using device: '
if len(device_ids) == 0: # cpu
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
device: str = 'cpu'
log_s += device
if len(device_ids) == 0:
master_device: str = 'cpu'
log_s += 'cpu'
else:
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
[str(d) for d in device_ids])
assert torch.cuda.is_available(
) and torch.cuda.device_count() >= len(device_ids)
log_s += f"cuda:{','.join([str(d) for d in device_ids])}" # e.g. 'cuda:1,7,8'
device = 'cuda:0'
master_device = 'cuda:0'
log_s += f'cuda:{device_str}'
logger.info(log_s)
return torch.device(device)
return torch.device(master_device)
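# Usage sketch (illustrative only; assumes a machine with at least two GPUs,
# and only one of these calls would be made, before any CUDA work):
#   select_device('0,1')   # string form: sets CUDA_VISIBLE_DEVICES='0,1', returns device('cuda:0')
#   select_device([0, 1])  # equivalent list form
#   select_device('-1')    # CPU only: returns device('cpu')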
def seed_everything(seed: Optional[int] = None, gpu_dtm: bool = False) -> int:
@@ -148,10 +154,12 @@ def get_T_max(dataset_len: int, batch_size: int, max_epochs: int,
return T_max
def tokenize_function(example: Dict[str, str], tokenizer) -> Dict[str, Any]:
def tokenize_function(example: Dict[str, Optional[str]],
tokenizer,
max_length: Optional[int] = 2048) -> Dict[str, Any]:
"""Only applicable to baichuan and chatglm2. Other models need to be tested"""
instruction = example['instruction']
input_: str = example['input']
instruction: str = example['instruction']
input_ = example['input']
if input_ is not None and input_ != '':
# instruction = instruction + '\n'
if input_.startswith('输入:'):
@@ -159,12 +167,12 @@ def tokenize_function(example: Dict[str, str], tokenizer) -> Dict[str, Any]:
else:
instruction = instruction + input_
output = example['output']
src_text = PROMPT.format(instruction=instruction, add_special_tokens=False)
src_text = PROMPT.format(instruction=instruction)
src_input_ids: List[int] = tokenizer(
src_text, return_attention_mask=False,
add_special_tokens=True)['input_ids']
# tokenizer.bos_token_id: Avoid `tgt_input_ids` being empty
tgt_input_ids = [tokenizer.bos_token_id]
#
tgt_input_ids = []
if output is not None:
tgt_input_ids += tokenizer(
output, return_attention_mask=False,
@@ -175,15 +183,15 @@ def tokenize_function(example: Dict[str, str], tokenizer) -> Dict[str, Any]:
labels = None
input_ids = src_input_ids + tgt_input_ids
#
if MAX_LENGTH is not None:
input_ids = input_ids[-MAX_LENGTH:]
if max_length is not None:
input_ids = input_ids[-max_length:]
if labels is not None:
labels = labels[-MAX_LENGTH:]
labels = labels[-max_length:]
#
return {'input_ids': input_ids, 'labels': labels}
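# Usage sketch (assumed example record; a loaded `tokenizer` is presumed, as in
# the scripts below where this function is bound with functools.partial):
#   example = {'instruction': 'Translate to English:', 'input': '你好', 'output': 'Hello'}
#   encoded = tokenize_function(example, tokenizer, max_length=2048)
#   # encoded['input_ids']: prompt ids followed by response ids and eos
#   # encoded['labels']:    -100 over the prompt span, then the response ids and eos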
def stat_dataset(dataset: HFDataset) -> None:
def stat_dataset(dataset: HfDataset) -> None:
"""Statistical analysis was performed on the data set"""
_token_len = []
for d in dataset:
@@ -198,10 +206,12 @@ def stat_dataset(dataset: HFDataset) -> None:
)
def print_examples(examples: Dict[str, Any], tokenizer) -> None:
input_ids, labels = examples['input_ids'], examples['labels']
print(f'[INPUT_IDS] {tokenizer.decode(input_ids)}')
def print_example(example: Dict[str, Any], tokenizer) -> None:
input_ids, labels = example['input_ids'], example['labels']
print(f'[INPUT_IDS] {input_ids}')
print(f'[INPUT] {tokenizer.decode(input_ids)}')
print()
print(f'[LABELS_IDS] {labels}')
print(
f'[LABELS] {tokenizer.decode([lb if lb != -100 else 0 for lb in labels])}'
)
@@ -283,66 +293,49 @@ class MyMetric(Metric):
}
def merge(self, other: 'MyMetric') -> None:
"""This script does not support ddp"""
"""This script does not support ddp. TODO"""
raise NotImplementedError
def get_baichuan7B_model_tokenizer(model_dir: Optional[str] = None,
load_model: bool = True):
if model_dir is None:
model_id = 'baichuan-inc/baichuan-7B'
model_dir = get_model_dir(model_id, None)
#
def _add_special_token(tokenizer):
if tokenizer.eos_token_id is None:
tokenizer.eos_token_id = 2
if tokenizer.bos_token_id is None:
tokenizer.bos_token_id = 1
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = 0
logger.info(f'bos_token_id: {tokenizer.bos_token_id}, '
f'eos_token_id: {tokenizer.eos_token_id}, '
f'pad_token_id: {tokenizer.pad_token_id}')
def get_baichuan_model_tokenizer(model_dir: str,
load_model: bool = True,
add_special_token: bool = True):
sys.path.insert(0, model_dir)
from configuration_baichuan import BaiChuanConfig
from tokenization_baichuan import BaiChuanTokenizer
from modeling_baichuan import BaiChuanForCausalLM
model_config = BaiChuanConfig.from_pretrained(model_dir)
model_config = AutoConfig.from_pretrained(
model_dir, trust_remote_code=True)
model_config.torch_dtype = torch.float16
logger.info(f'model_config: {model_config}')
tokenizer = BaiChuanTokenizer.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(
model_dir, trust_remote_code=True)
model = None
if load_model:
model = BaiChuanForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
model_dir,
config=model_config,
device_map='auto',
torch_dtype=torch.float16)
torch_dtype=torch.float16,
trust_remote_code=True)
#
if add_special_token:
_add_special_token(tokenizer)
return model, tokenizer
def get_baichuan13B_model_tokenizer(model_dir: Optional[str] = None,
load_model: bool = True):
if model_dir is None:
model_id = 'baichuan-inc/Baichuan-13B-Base'
model_dir = get_model_dir(model_id, 'v1.0.1')
#
sys.path.insert(0, model_dir)
from configuration_baichuan import BaichuanConfig
from tokenization_baichuan import BaichuanTokenizer
from modeling_baichuan import BaichuanForCausalLM
model_config = BaichuanConfig.from_pretrained(model_dir)
model_config.torch_dtype = torch.float16
logger.info(f'model_config: {model_config}')
tokenizer = BaichuanTokenizer.from_pretrained(model_dir)
model = None
if load_model:
model = BaichuanForCausalLM.from_pretrained(
model_dir,
config=model_config,
device_map='auto',
torch_dtype=torch.float16)
#
return model, tokenizer
def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
load_model: bool = True):
if model_dir is None:
model_id = 'ZhipuAI/chatglm2-6b'
model_dir = snapshot_download(model_id, None)
#
def get_chatglm2_model_tokenizer(model_dir: str,
load_model: bool = True,
add_special_token: bool = True):
config = read_config(model_dir)
config['model'] = ConfigDict({'type': 'chatglm2-6b'})
tokenizer = ChatGLM2Tokenizer.from_pretrained(model_dir)
@@ -353,25 +346,49 @@ def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
cfg_dict=config,
device_map='auto',
torch_dtype=torch.float16)
if add_special_token:
_add_special_token(tokenizer)
return model, tokenizer
def get_llama2_model_tokenizer(model_dir: str,
load_model: bool = True,
add_special_token: bool = True):
config = AutoConfig.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = None
if load_model:
model = AutoModelForCausalLM.from_pretrained(
model_dir,
config=config,
device_map='auto',
torch_dtype=torch.float16,
)
if add_special_token:
_add_special_token(tokenizer)
return model, tokenizer
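# Usage sketch (mirrors the llama2 branch added in llm_sft.py / llm_infer.py below):
#   model_dir = snapshot_download('modelscope/Llama-2-7b-ms', 'v1.0.0')
#   model, tokenizer = get_llama2_model_tokenizer(model_dir)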
def get_alpaca_en_zh_dataset(
tokenize_function,
only_val: bool = False) -> Tuple[HFDataset, HFDataset]:
only_val: bool = False,
test_split_p: float = 0.01,
split_seed: int = 42,
data_sample: Optional[int] = None) -> Tuple[HfDataset, HfDataset]:
"""
only_val: if True, build and return only the test split.
"""
dataset_en: HFDataset = MsDataset.load(
dataset_en: HfDataset = MsDataset.load(
'AI-ModelScope/alpaca-gpt4-data-en', split='train').to_hf_dataset()
dataset_zh: HFDataset = MsDataset.load(
dataset_zh: HfDataset = MsDataset.load(
'AI-ModelScope/alpaca-gpt4-data-zh', split='train').to_hf_dataset()
dataset_en = dataset_en.remove_columns(['text'])
dataset: HFDataset = concatenate_datasets([dataset_zh, dataset_en])
dataset: HfDataset = concatenate_datasets([dataset_zh, dataset_en])
#
# dataset = dataset.select(range(1000)) # for debug
dataset = dataset.train_test_split(TEST_SPLIT_P, seed=SPLIT_SEED)
if data_sample is not None:
dataset = dataset.select(range(data_sample))
dataset = dataset.train_test_split(test_split_p, seed=split_seed)
if only_val:
dataset = dataset['test']
if tokenize_function is not None:
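# Usage sketch (matches the call in llm_sft.py below):
#   train_dataset, val_dataset = get_alpaca_en_zh_dataset(
#       tokenize_function, split_seed=42, data_sample=args.data_sample)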
@@ -428,7 +445,7 @@ def plot_image(tb_dir: str,
fname = os.listdir(tb_dir)[0]
tb_path = os.path.join(tb_dir, fname)
data = read_tensorboard_file(tb_path)
#
for k in data.keys():
_data = data[k]
steps = [d['step'] for d in _data]

View File

@@ -1,62 +0,0 @@
# ### Setting up experimental environment.
from _common import *
from transformers import TextStreamer
device_ids = [0, 1]
logger.info(device_ids)
select_device(device_ids)
# ### Loading Model and Tokenizer
# Note: You need to set the value of `CKPT_FPATH`
BAICHUAN_TYPE = '13B' # Literal['7B', '13B']
CKPT_FPATH = '/path/to/your/xxx.pth'
LORA_TARGET_MODULES = ['W_pack']
if BAICHUAN_TYPE == '7B':
model, tokenizer = get_baichuan7B_model_tokenizer()
else:
model, tokenizer = get_baichuan13B_model_tokenizer()
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
model.bfloat16() # Consistent with training
# ### Preparing lora
LORA_RANK = 8
LORA_ALPHA = 32
LORA_DROPOUT_P = 0 # Arbitrary value
lora_config = LoRAConfig(
replace_modules=LORA_TARGET_MODULES,
rank=LORA_RANK,
lora_alpha=LORA_ALPHA,
lora_dropout=LORA_DROPOUT_P,
pretrained_weights=CKPT_FPATH)
logger.info(f'lora_config: {lora_config}')
Swift.prepare_model(model, lora_config)
# ### Loading Dataset
_, test_dataset = get_alpaca_en_zh_dataset(None, True)
# ### Inference
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
for d in test_dataset[:5]:
output = d['output']
d['output'] = None
input_ids = tokenize_function(d, tokenizer)['input_ids']
print(f'[TEST]{tokenizer.decode(input_ids)}', end='')
input_ids = torch.tensor(input_ids)[None].cuda()
attention_mask = torch.ones_like(input_ids)
generate_ids = model.generate(
input_ids=input_ids,
max_new_tokens=512,
attention_mask=attention_mask,
streamer=streamer,
pad_token_id=tokenizer.pad_token_id,
temperature=0.7,
top_k=50,
do_sample=True)
print()
print(f'[LABELS]{output}')
print(
'-----------------------------------------------------------------------------------'
)
# input('next[ENTER]')

View File

@@ -1,199 +0,0 @@
# ### Setting up experimental environment.
"""
pip install modelscope
pip install numpy pandas matplotlib scikit-learn
pip install transformers datasets
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install tqdm
pip install tensorboard
pip install torchmetrics
pip install sentencepiece
pip install accelerate
pip install numpy -U # Resolve torchmetrics dependencies and update numpy
"""
from _common import *
device_ids = [0, 1, 2, 3]
logger.info(device_ids)
select_device(device_ids)
seed_everything(42)
# ### Loading Model and Tokenizer
BAICHUAN_TYPE = '13B' # Literal['7B', '13B']
WORK_DIR = f'runs/baichuan_{BAICHUAN_TYPE}'
LORA_TARGET_MODULES = ['W_pack']
#
if BAICHUAN_TYPE == '7B':
model_id = 'baichuan-inc/baichuan-7B'
model_dir = get_model_dir(model_id, None)
model, tokenizer = get_baichuan7B_model_tokenizer(model_dir)
else:
model_id = 'baichuan-inc/Baichuan-13B-Base'
model_dir = get_model_dir(model_id, 'v1.0.1')
model, tokenizer = get_baichuan13B_model_tokenizer(model_dir)
#
GRADIENT_CHECKPOINTING = True
if GRADIENT_CHECKPOINTING:
# baichuan13B does not implement the `get_input_embeddings` function
if BAICHUAN_TYPE == '13B':
def get_input_embeddings(self):
return self.model.embed_tokens
model.__class__.get_input_embeddings = get_input_embeddings.__get__(
model)
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
#
logger.info(
f'bos_token_id: {tokenizer.bos_token_id}, eos_token_id: {tokenizer.eos_token_id}, '
f'pad_token_id: {tokenizer.pad_token_id}')
# ### Preparing lora
LORA_RANK = 8
LORA_ALPHA = 32
LORA_DROPOUT_P = 0.1
lora_config = LoRAConfig(
replace_modules=LORA_TARGET_MODULES,
rank=LORA_RANK,
lora_alpha=LORA_ALPHA,
lora_dropout=LORA_DROPOUT_P)
logger.info(f'lora_config: {lora_config}')
Swift.prepare_model(model, lora_config)
#
show_freeze_layers(model)
print_model_info(model)
_p = list(model.parameters())[100]
logger.info(f'device: {_p.device}, dtype: {_p.dtype}')
model.bfloat16()
# ### Loading Dataset
tokenize_function = partial(tokenize_function, tokenizer=tokenizer)
train_dataset, val_dataset = get_alpaca_en_zh_dataset(tokenize_function)
# Data analysis
stat_dataset(train_dataset)
stat_dataset(val_dataset)
data_collate_fn = partial(data_collate_fn, tokenizer=tokenizer)
print_examples(train_dataset[0], tokenizer)
# ### Setting Config
cfg_file = os.path.join(model_dir, 'configuration.json')
#
BATCH_SIZE = 1
MAX_EPOCHS = 1
T_max = get_T_max(len(train_dataset), BATCH_SIZE, MAX_EPOCHS, True)
WORK_DIR = get_work_dir(WORK_DIR)
EVAL_INTERVAL = 500
CONFIG = Config({
'train': {
'dataloader': {
'batch_size_per_gpu': BATCH_SIZE,
'workers_per_gpu': 1,
'shuffle': True,
'drop_last': True,
'pin_memory': True
},
'max_epochs':
MAX_EPOCHS,
'work_dir':
WORK_DIR,
'optimizer': {
'type': 'AdamW',
'lr': 1e-4,
'weight_decay': 0.01,
'options': {
'cumulative_iters': 16,
'grad_clip': {
'norm_type': 2,
'max_norm': 2.0
}
}
},
'lr_scheduler': {
'type': 'CosineAnnealingLR',
'T_max': T_max,
'eta_min': 1e-5,
'options': {
'by_epoch': False,
'warmup': {
'type': 'LinearWarmup',
'warmup_ratio': 0.1,
'warmup_iters': 200
}
}
},
'hooks': [
{
'type': 'CheckpointHook',
'by_epoch': False,
'interval': EVAL_INTERVAL,
'max_checkpoint_num': 1
},
{
'type': 'EvaluationHook',
'by_epoch': False,
'interval': EVAL_INTERVAL
},
{
'type': 'BestCkptSaverHook',
'metric_key': 'acc',
'save_best': True,
'rule': 'max',
'max_checkpoint_num': 1
},
{
'type': 'TextLoggerHook',
'by_epoch': True, # Whether EpochBasedTrainer is used
'interval': 5
},
{
'type': 'TensorboardHook',
'by_epoch': False,
'interval': 5
}
]
},
'evaluation': {
'dataloader': {
'batch_size_per_gpu': BATCH_SIZE,
'workers_per_gpu': 1,
'shuffle': False,
'drop_last': False,
'pin_memory': True
},
'metrics': [{
'type': 'my_metric',
'vocab_size': tokenizer.vocab_size
}]
}
})
# ### Finetuning
def cfg_modify_fn(cfg: Config) -> Config:
cfg.update(CONFIG)
return cfg
trainer = EpochBasedTrainer(
model=model,
cfg_file=cfg_file,
data_collator=data_collate_fn,
train_dataset=train_dataset,
eval_dataset=val_dataset,
remove_unused_data=True,
seed=42,
device='cpu',  # No device placement here; leave the model to `device_map`
cfg_modify_fn=cfg_modify_fn,
)
trainer.train()
# ### Visualization
tb_dir = os.path.join(WORK_DIR, 'tensorboard_output')
plot_image(tb_dir, ['loss'], 0.9)

View File

@@ -1,60 +0,0 @@
# ### Setting up experimental environment.
from _common import *
from transformers import TextStreamer
device_ids = [0, 1]
logger.info(device_ids)
select_device(device_ids)
# ### Loading Model and Tokenizer
# Note: You need to set the value of `CKPT_FPATH`
CKPT_FPATH = '/path/to/your/xxx.pth'
LORA_TARGET_MODULES = ['query_key_value']
model, tokenizer = get_chatglm2_model_tokenizer()
if tokenizer.eos_token_id is None:
tokenizer.eos_token_id = tokenizer.pad_token_id
if tokenizer.bos_token_id is None:
tokenizer.bos_token_id = 1
model.bfloat16() # Consistent with training
# ### Preparing lora
LORA_RANK = 8
LORA_ALPHA = 32
LORA_DROPOUT_P = 0 # Arbitrary value
lora_config = LoRAConfig(
replace_modules=LORA_TARGET_MODULES,
rank=LORA_RANK,
lora_alpha=LORA_ALPHA,
lora_dropout=LORA_DROPOUT_P,
pretrained_weights=CKPT_FPATH)
logger.info(f'lora_config: {lora_config}')
Swift.prepare_model(model, lora_config)
# ### Loading Dataset
_, test_dataset = get_alpaca_en_zh_dataset(None, True)
# ### Inference
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
for d in test_dataset[:5]:
output = d['output']
d['output'] = None
input_ids = tokenize_function(d, tokenizer)['input_ids']
print(f'[TEST]{tokenizer.decode(input_ids)}', end='')
input_ids = torch.tensor(input_ids)[None].cuda()
attention_mask = torch.ones_like(input_ids)
generate_ids = model.generate(
input_ids=input_ids,
max_new_tokens=512,
attention_mask=attention_mask,
streamer=streamer,
pad_token_id=tokenizer.pad_token_id,
temperature=0.7,
top_k=50,
do_sample=True)
print()
print(f'[LABELS]{output}')
print(
'-----------------------------------------------------------------------------------'
)
# input('next[ENTER]')

View File

@@ -1,188 +0,0 @@
# ### Setting up experimental environment.
"""
pip install modelscope
pip install numpy pandas matplotlib scikit-learn
pip install transformers datasets
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install tqdm
pip install tensorboard
pip install torchmetrics
pip install sentencepiece
pip install accelerate
pip install numpy -U # Resolve torchmetrics dependencies and update numpy
"""
from _common import *
device_ids = [0, 1, 2, 3]
logger.info(device_ids)
select_device(device_ids)
seed_everything(42)
# ### Loading Model and Tokenizer
model_id = 'ZhipuAI/chatglm2-6b'
WORK_DIR = 'runs/chatglm2'
LORA_TARGET_MODULES = ['query_key_value']
#
model_dir = get_model_dir(model_id, None)
model, tokenizer = get_chatglm2_model_tokenizer(model_dir)
# chatglm2 does not support gradient_checkpointing
GRADIENT_CHECKPOINTING = False
if GRADIENT_CHECKPOINTING:
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
logger.info(tokenizer.special_tokens)
if tokenizer.eos_token_id is None:
tokenizer.eos_token_id = tokenizer.pad_token_id
if tokenizer.bos_token_id is None:
tokenizer.bos_token_id = 1
#
logger.info(
f'bos_token_id: {tokenizer.bos_token_id}, eos_token_id: {tokenizer.eos_token_id}, '
f'pad_token_id: {tokenizer.pad_token_id}')
# ### Preparing lora
LORA_RANK = 8
LORA_ALPHA = 32
LORA_DROPOUT_P = 0.1
lora_config = LoRAConfig(
replace_modules=LORA_TARGET_MODULES,
rank=LORA_RANK,
lora_alpha=LORA_ALPHA,
lora_dropout=LORA_DROPOUT_P)
logger.info(f'lora_config: {lora_config}')
Swift.prepare_model(model, lora_config)
#
show_freeze_layers(model)
print_model_info(model)
_p = list(model.parameters())[100]
logger.info(f'device: {_p.device}, dtype: {_p.dtype}')
model.bfloat16()
# ### Loading Dataset
tokenize_function = partial(tokenize_function, tokenizer=tokenizer)
train_dataset, val_dataset = get_alpaca_en_zh_dataset(tokenize_function)
# Data analysis
stat_dataset(train_dataset)
stat_dataset(val_dataset)
data_collate_fn = partial(data_collate_fn, tokenizer=tokenizer)
print_examples(train_dataset[0], tokenizer)
# ### Setting Config
cfg_file = os.path.join(model_dir, 'configuration.json')
#
BATCH_SIZE = 1
MAX_EPOCHS = 1
T_max = get_T_max(len(train_dataset), BATCH_SIZE, MAX_EPOCHS, True)
WORK_DIR = get_work_dir(WORK_DIR)
EVAL_INTERVAL = 500
CONFIG = Config({
'train': {
'dataloader': {
'batch_size_per_gpu': BATCH_SIZE,
'workers_per_gpu': 1,
'shuffle': True,
'drop_last': True,
'pin_memory': True
},
'max_epochs':
MAX_EPOCHS,
'work_dir':
WORK_DIR,
'optimizer': {
'type': 'AdamW',
'lr': 1e-4,
'weight_decay': 0.01,
'options': {
'cumulative_iters': 16,
'grad_clip': {
'norm_type': 2,
'max_norm': 2.0
}
}
},
'lr_scheduler': {
'type': 'CosineAnnealingLR',
'T_max': T_max,
'eta_min': 1e-5,
'options': {
'by_epoch': False,
'warmup': {
'type': 'LinearWarmup',
'warmup_ratio': 0.1,
'warmup_iters': 200
}
}
},
'hooks': [
{
'type': 'CheckpointHook',
'by_epoch': False,
'interval': EVAL_INTERVAL,
'max_checkpoint_num': 1
},
{
'type': 'EvaluationHook',
'by_epoch': False,
'interval': EVAL_INTERVAL
},
{
'type': 'BestCkptSaverHook',
'metric_key': 'acc',
'save_best': True,
'rule': 'max',
'max_checkpoint_num': 1
},
{
'type': 'TextLoggerHook',
'by_epoch': True, # Whether EpochBasedTrainer is used
'interval': 5
},
{
'type': 'TensorboardHook',
'by_epoch': False,
'interval': 5
}
]
},
'evaluation': {
'dataloader': {
'batch_size_per_gpu': BATCH_SIZE,
'workers_per_gpu': 1,
'shuffle': False,
'drop_last': False,
'pin_memory': True
},
'metrics': [{
'type': 'my_metric',
'vocab_size': tokenizer.vocab_size
}]
}
})
# ### Finetuning
def cfg_modify_fn(cfg: Config) -> Config:
cfg.update(CONFIG)
return cfg
trainer = EpochBasedTrainer(
model=model,
cfg_file=cfg_file,
data_collator=data_collate_fn,
train_dataset=train_dataset,
eval_dataset=val_dataset,
remove_unused_data=True,
seed=42,
device='cpu',  # No device placement here; leave the model to `device_map`
cfg_modify_fn=cfg_modify_fn,
)
trainer.train()
# ### Visualization
tb_dir = os.path.join(WORK_DIR, 'tensorboard_output')
plot_image(tb_dir, ['loss'], 0.9)

View File

@@ -0,0 +1,122 @@
# ### Setting up experimental environment.
from _common import *
@dataclass
class Arguments:
device: str = '0' # e.g. '-1'; '0'; '0,1'
model_type: str = field(
default='baichuan-7b',
metadata={
'choices':
['baichuan-7b', 'baichuan-13b', 'chatglm2', 'llama2-7b']
})
ckpt_fpath: str = '' # e.g. '/path/to/your/iter_xxx.pth'
eval_human: bool = False # False: eval test_dataset
data_sample: Optional[int] = None
#
lora_target_modules: Optional[List[str]] = None
lora_rank: int = 8
lora_alpha: int = 32
lora_dropout_p: float = 0.1
#
max_new_tokens: int = 512
temperature: float = 0.9
top_k: int = 50
top_p: float = 0.9
def __post_init__(self):
if self.lora_target_modules is None:
if self.model_type in {'baichuan-7b', 'baichuan-13b'}:
self.lora_target_modules = ['W_pack']
elif self.model_type == 'chatglm2':
self.lora_target_modules = ['query_key_value']
elif self.model_type == 'llama2-7b':
self.lora_target_modules = ['q_proj', 'k_proj', 'v_proj']
else:
raise ValueError(f'model_type: {self.model_type}')
#
if not os.path.isfile(self.ckpt_fpath):
raise ValueError(f'ckpt_fpath is not an existing file: {self.ckpt_fpath}')
def parse_args() -> Arguments:
args, = HfArgumentParser([Arguments]).parse_args_into_dataclasses()
return args
args = parse_args()
logger.info(args)
select_device(args.device)
# ### Loading Model and Tokenizer
if args.model_type == 'baichuan-7b':
model_dir = snapshot_download('baichuan-inc/baichuan-7B', 'v1.0.5')
model, tokenizer = get_baichuan_model_tokenizer(model_dir)
elif args.model_type == 'baichuan-13b':
model_dir = snapshot_download('baichuan-inc/Baichuan-13B-Base', 'v1.0.2')
model, tokenizer = get_baichuan_model_tokenizer(model_dir)
elif args.model_type == 'chatglm2':
model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6')
model, tokenizer = get_chatglm2_model_tokenizer(model_dir)
elif args.model_type == 'llama2-7b':
model_dir = snapshot_download('modelscope/Llama-2-7b-ms', 'v1.0.0')
model, tokenizer = get_llama2_model_tokenizer(model_dir)
else:
raise ValueError(f'model_type: {args.model_type}')
# ### Preparing lora
lora_config = LoRAConfig(
replace_modules=args.lora_target_modules,
rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout_p,
pretrained_weights=args.ckpt_fpath)
logger.info(f'lora_config: {lora_config}')
Swift.prepare_model(model, lora_config)
model.bfloat16() # Consistent with training
# ### Inference
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_config = GenerationConfig(
max_new_tokens=args.max_new_tokens,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
do_sample=True,
pad_token_id=tokenizer.eos_token_id)
logger.info(generation_config)
def inference(data: Dict[str, Optional[str]]) -> str:
input_ids = tokenize_function(data, tokenizer)['input_ids']
print(f'[TEST]{tokenizer.decode(input_ids)}', end='')
input_ids = torch.tensor(input_ids)[None].cuda()
attention_mask = torch.ones_like(input_ids)
generate_ids = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
streamer=streamer,
generation_config=generation_config)
output_text = tokenizer.decode(generate_ids[0])
return output_text
if args.eval_human:
while True:
instruction = input('<<< ')
data = {'instruction': instruction, 'input': None, 'output': None}
inference(data)
print('-' * 80)
else:
_, test_dataset = get_alpaca_en_zh_dataset(
None, True, split_seed=42, data_sample=None)
mini_test_dataset = test_dataset.select(range(10))
for data in mini_test_dataset:
output = data['output']
data['output'] = None
inference(data)
print()
print(f'[LABELS]{output}')
print('-' * 80)
# input('next[ENTER]')

View File

@@ -0,0 +1,237 @@
# ### Setting up experimental environment.
"""
pip install modelscope
pip install numpy pandas matplotlib scikit-learn
pip install transformers datasets
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer
pip install accelerate transformers_stream_generator
pip install numpy -U # Resolve torchmetrics dependencies and update numpy
"""
from _common import *
@dataclass
class Arguments:
device: str = '0,1' # e.g. '-1'; '0'; '0,1'
seed: int = 42
model_type: str = field(
default='baichuan-7b',
metadata={
'choices':
['baichuan-7b', 'baichuan-13b', 'chatglm2', 'llama2-7b']
})
data_sample: Optional[int] = None
#
lora_target_modules: Optional[List[str]] = None
lora_rank: int = 8
lora_alpha: int = 32
lora_dropout_p: float = 0.1
#
gradient_checkpoint: bool = True
batch_size: int = 1
max_epochs: int = 1
eval_interval: int = 500
learning_rate: float = 1e-4
weight_decay: float = 0.01
n_accumulate_grad: int = 16
grad_clip_norm: float = 1.
warmup_iters: int = 200
last_max_checkpoint_num: int = 1
best_max_checkpoint_num: int = 1
#
logging_interval: int = 5
tb_interval: int = 5
def __post_init__(self):
if self.lora_target_modules is None:
if self.model_type in {'baichuan-7b', 'baichuan-13b'}:
self.lora_target_modules = ['W_pack']
elif self.model_type == 'chatglm2':
self.lora_target_modules = ['query_key_value']
elif self.model_type == 'llama2-7b':
self.lora_target_modules = ['q_proj', 'k_proj', 'v_proj']
else:
raise ValueError(f'model_type: {self.model_type}')
def parse_args() -> Arguments:
args, = HfArgumentParser([Arguments]).parse_args_into_dataclasses()
return args
args = parse_args()
logger.info(args)
select_device(args.device)
seed_everything(args.seed)
# ### Loading Model and Tokenizer
if args.model_type == 'baichuan-7b':
model_dir = snapshot_download('baichuan-inc/baichuan-7B', 'v1.0.5')
model, tokenizer = get_baichuan_model_tokenizer(model_dir)
elif args.model_type == 'baichuan-13b':
model_dir = snapshot_download('baichuan-inc/Baichuan-13B-Base', 'v1.0.2')
model, tokenizer = get_baichuan_model_tokenizer(model_dir)
elif args.model_type == 'chatglm2':
model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6')
model, tokenizer = get_chatglm2_model_tokenizer(model_dir)
elif args.model_type == 'llama2-7b':
model_dir = snapshot_download('modelscope/Llama-2-7b-ms', 'v1.0.0')
model, tokenizer = get_llama2_model_tokenizer(model_dir)
else:
raise ValueError(f'model_type: {args.model_type}')
#
if args.gradient_checkpoint:
# baichuan13B does not implement the `get_input_embeddings` function
if args.model_type == 'baichuan-13b':
def get_input_embeddings(self):
return self.model.embed_tokens
model.__class__.get_input_embeddings = get_input_embeddings.__get__(
model)
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
# ### Preparing lora
lora_config = LoRAConfig(
replace_modules=args.lora_target_modules,
rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout_p)
logger.info(f'lora_config: {lora_config}')
Swift.prepare_model(model, lora_config)
#
show_freeze_layers(model)
print_model_info(model)
_p: Parameter = list(model.parameters())[100]
logger.info(f'device: {_p.device}, dtype: {_p.dtype}')
model.bfloat16()
# ### Loading Dataset
tokenize_function = partial(tokenize_function, tokenizer=tokenizer)
train_dataset, val_dataset = get_alpaca_en_zh_dataset(
tokenize_function, split_seed=42, data_sample=args.data_sample)
# Data analysis
stat_dataset(train_dataset)
stat_dataset(val_dataset)
data_collate_fn = partial(data_collate_fn, tokenizer=tokenizer)
print_example(train_dataset[0], tokenizer)
# ### Setting Config
cfg_file = os.path.join(model_dir, 'configuration.json')
#
T_max = get_T_max(len(train_dataset), args.batch_size, args.max_epochs, True)
work_dir = get_work_dir(f'runs/{args.model_type}')
config = Config({
'train': {
'dataloader': {
'batch_size_per_gpu': args.batch_size,
'workers_per_gpu': 1,
'shuffle': True,
'drop_last': True,
'pin_memory': True
},
'max_epochs':
args.max_epochs,
'work_dir':
work_dir,
'optimizer': {
'type': 'AdamW',
'lr': args.learning_rate,
'weight_decay': args.weight_decay,
'options': {
'cumulative_iters': args.n_accumulate_grad,
'grad_clip': {
'norm_type': 2,
'max_norm': args.grad_clip_norm
}
}
},
'lr_scheduler': {
'type': 'CosineAnnealingLR',
'T_max': T_max,
'eta_min': 0,
'options': {
'by_epoch': False,
'warmup': {
'type': 'LinearWarmup',
'warmup_ratio': 0.1,
'warmup_iters': args.warmup_iters
}
}
},
'hooks': [
{
'type': 'CheckpointHook',
'by_epoch': False,
'interval': args.eval_interval,
'max_checkpoint_num': args.last_max_checkpoint_num
},
{
'type': 'EvaluationHook',
'by_epoch': False,
'interval': args.eval_interval
},
{
'type': 'BestCkptSaverHook',
'metric_key': 'loss',
'save_best': True,
'rule': 'min',
'max_checkpoint_num': args.best_max_checkpoint_num
},
{
'type': 'TextLoggerHook',
'by_epoch': True, # Whether EpochBasedTrainer is used
'interval': args.logging_interval
},
{
'type': 'TensorboardHook',
'by_epoch': False,
'interval': args.tb_interval
}
]
},
'evaluation': {
'dataloader': {
'batch_size_per_gpu': args.batch_size,
'workers_per_gpu': 1,
'shuffle': False,
'drop_last': False,
'pin_memory': True
},
'metrics': [{
'type': 'my_metric',
'vocab_size': tokenizer.vocab_size
}]
}
})
# ### Finetuning
def cfg_modify_fn(cfg: Config) -> Config:
cfg.update(config)
return cfg
trainer = EpochBasedTrainer(
model=model,
cfg_file=cfg_file,
data_collator=data_collate_fn,
train_dataset=train_dataset,
eval_dataset=val_dataset,
remove_unused_data=True,
seed=42,
device='cpu',  # No device placement here; leave the model to `device_map`
cfg_modify_fn=cfg_modify_fn,
)
trainer.train()
# ### Visualization
tb_dir = os.path.join(work_dir, 'tensorboard_output')
plot_image(tb_dir, ['loss'], 0.9)

View File

@@ -0,0 +1,5 @@
python llm_infer.py \
--device 0 \
--model_type llama2-7b \
--ckpt_fpath "runs/llama2-7b/vx_xxx/output_best/pytorch_model.bin" \
--eval_human true

View File

@@ -0,0 +1,8 @@
#!/bin/bash
DATE=$(date +"%Y%m%d-%H%M%S")
nohup python llm_sft.py \
--device 0 \
--model_type llama2-7b \
--data_sample 25000 \
&> train_$DATE.out &

View File

@@ -49,11 +49,9 @@ from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.registry import default_group
#
SYSTEM_TEXT = """{system}"""
USER_TEXT = """\n\n### 用户
{user}"""
ASSISTANT_PROMPT = """\n\n### 助手
"""
PROMPT = """System: {system}
Human: {user}
AI: """
MAX_LENGTH = 2048
TEST_MAX_LENGTH = MAX_LENGTH
@@ -62,11 +60,6 @@ logger = get_logger()
#
def get_model_dir(model_id: str, model_revision: Optional[str] = None) -> str:
model_dir = snapshot_download(model_id, model_revision)
return model_dir
def _get_version(work_dir: str) -> int:
if os.path.isdir(work_dir):
fnames = os.listdir(work_dir)
@@ -93,28 +86,40 @@ def get_work_dir(work_dir: str) -> str:
return work_dir
def select_device(device_ids: List[int]) -> Device:
def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
if isinstance(device, list):
device_ids = device
device_str = ','.join([str(d) for d in device])
else:
device_ids = [int(d) for d in device.split(',') if d != '-1']
device_str = device
device_str = device_str.replace(' ', '')
return device_ids, device_str
def select_device(device: Union[List[int], str]) -> Device:
"""Call this function before cuda is initialized.
Return: master device
device: e.g. []: 'cpu', [0], [0, 1, 2]
e.g. '-1': 'cpu', '0', '0,1,2'
"""
if torch.cuda.is_initialized():
logger.warning('CUDA has already been initialized; device selection will have no effect.')
return torch.device('cuda:0')
#
device_ids, device_str = _format_device(device)
#
os.environ['CUDA_VISIBLE_DEVICES'] = device_str
log_s = 'Using device: '
if len(device_ids) == 0: # cpu
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
device: str = 'cpu'
log_s += device
if len(device_ids) == 0:
master_device: str = 'cpu'
log_s += 'cpu'
else:
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
[str(d) for d in device_ids])
assert torch.cuda.is_available(
) and torch.cuda.device_count() >= len(device_ids)
log_s += f"cuda:{','.join([str(d) for d in device_ids])}" # e.g. 'cuda:1,7,8'
device = 'cuda:0'
master_device = 'cuda:0'
log_s += f'cuda:{device_str}'
logger.info(log_s)
return torch.device(device)
return torch.device(master_device)
def seed_everything(seed: Optional[int] = None, gpu_dtm: bool = False) -> int:
@@ -148,37 +153,27 @@ def get_T_max(dataset_len: int, batch_size: int, max_epochs: int,
def tokenize_function(system: str, user: str, assistant: Optional[str],
tokenizer) -> Dict[str, Any]:
"""Only applicable to baichuan and chatglm2. Other models need to be tested"""
system_text = SYSTEM_TEXT.format(system=system)
user_text = USER_TEXT.format(user=user)
system_text_ids: List[int] = tokenizer(
system_text, return_attention_mask=False,
src_text = PROMPT.format(system=system, user=user)
src_input_ids: List[int] = tokenizer(
src_text, return_attention_mask=False,
add_special_tokens=True)['input_ids']
user_text_ids: List[int] = tokenizer(
user_text, return_attention_mask=False,
add_special_tokens=False)['input_ids']
assistant_p_input_ids: List[int] = tokenizer(
ASSISTANT_PROMPT,
return_attention_mask=False,
add_special_tokens=False)['input_ids']
# tokenizer.bos_token_id: Avoid `assistant` being empty
assistant_input_ids: List[int] = [tokenizer.bos_token_id]
#
tgt_input_ids: List[int] = []
if assistant is not None:
assistant_input_ids += tokenizer(
tgt_input_ids += tokenizer(
assistant, return_attention_mask=False,
add_special_tokens=False)['input_ids']
assistant_input_ids += [tokenizer.eos_token_id]
tgt_input_ids += [tokenizer.eos_token_id]
labels = [-100] * len(src_input_ids) + tgt_input_ids
else:
labels = None
input_ids = src_input_ids + tgt_input_ids
#
input_ids = system_text_ids + user_text_ids + assistant_p_input_ids + assistant_input_ids
if assistant is not None: # train, val
if assistant is not None:
if len(input_ids) > MAX_LENGTH:
return {}
len_mask = len(input_ids) - len(assistant_input_ids)
labels = [-100] * len_mask + assistant_input_ids
else: # test
else:
input_ids = input_ids[-TEST_MAX_LENGTH:]
labels = None
#
return {'input_ids': input_ids, 'labels': labels}
@@ -305,12 +300,21 @@ class MyMetric(Metric):
raise NotImplementedError
def get_baichuan_model_tokenizer(model_dir: Optional[str] = None,
load_model: bool = True):
if model_dir is None:
model_id = 'baichuan-inc/baichuan-7B'
model_dir = get_model_dir(model_id, None)
#
def _add_special_token(tokenizer):
if tokenizer.eos_token_id is None:
tokenizer.eos_token_id = 2
if tokenizer.bos_token_id is None:
tokenizer.bos_token_id = 1
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = 0
logger.info(f'bos_token_id: {tokenizer.bos_token_id}, '
f'eos_token_id: {tokenizer.eos_token_id}, '
f'pad_token_id: {tokenizer.pad_token_id}')
def get_baichuan7B_model_tokenizer(model_dir: str,
load_model: bool = True,
add_special_token: bool = True):
sys.path.insert(0, model_dir)
from configuration_baichuan import BaiChuanConfig
from tokenization_baichuan import BaiChuanTokenizer
@@ -327,15 +331,14 @@ def get_baichuan_model_tokenizer(model_dir: Optional[str] = None,
device_map='auto',
torch_dtype=torch.float16)
#
if add_special_token:
_add_special_token(tokenizer)
return model, tokenizer
def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
load_model: bool = True):
if model_dir is None:
model_id = 'ZhipuAI/chatglm2-6b'
model_dir = snapshot_download(model_id, None)
#
def get_chatglm2_model_tokenizer(model_dir: str,
load_model: bool = True,
add_special_token: bool = True):
config = read_config(model_dir)
config['model'] = ConfigDict({'type': 'chatglm2-6b'})
tokenizer = ChatGLM2Tokenizer.from_pretrained(model_dir)
@@ -346,6 +349,8 @@ def get_chatglm2_model_tokenizer(model_dir: Optional[str] = None,
cfg_dict=config,
device_map='auto',
torch_dtype=torch.float16)
if add_special_token:
_add_special_token(tokenizer)
return model, tokenizer

View File

@@ -54,7 +54,6 @@
"from _common import *\n",
"from transformers import TextStreamer\n",
"device_ids = [0, 1]\n",
"logger.info(device_ids)\n",
"select_device(device_ids)"
]
},
@@ -146,9 +145,8 @@
"CKPT_FAPTH = '/home/hackathon/my_git/agent/runs/baichuan/v10-20230702-172449/output_best/pytorch_model.bin'\n",
"LORA_TARGET_MODULES = ['W_pack']\n",
"\n",
"model, tokenizer = get_baichuan_model_tokenizer()\n",
"if tokenizer.pad_token_id is None:\n",
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
"model_dir = snapshot_download('baichuan-inc/baichuan-7B', 'v1.0.5')\n",
"model, tokenizer = get_baichuan7B_model_tokenizer(model_dir)\n",
"model.bfloat16() # Consistent with training"
]
},
@@ -451,8 +449,8 @@
" attention_mask = torch.ones_like(input_ids)\n",
" generate_ids = model.generate(input_ids=input_ids, max_new_tokens=512,\n",
" attention_mask=attention_mask,\n",
" streamer=streamer, pad_token_id=tokenizer.pad_token_id, \n",
" temperature=0.7, top_k=50, do_sample=True)\n",
" streamer=streamer, pad_token_id=tokenizer.eos_token_id, \n",
" temperature=0.7, top_k=50, top_p=0.7, do_sample=True)\n",
" print()\n",
" print(f'[LABELS]{assistant}')\n",
" print('-----------------------------------------------------------------------------------')\n",

View File

@@ -33,16 +33,12 @@
"metadata": {},
"outputs": [],
"source": [
"# !pip install modelscope -U\n",
"# !pip install modelscope\n",
"# !pip install numpy pandas matplotlib scikit-learn\n",
"# !pip install transformers datasets\n",
"# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
"# !pip install tqdm\n",
"# !pip install tensorboard\n",
"# !pip install torchmetrics\n",
"# !pip install sentencepiece\n",
"# !pip install accelerate\n",
"#\n",
"# !conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia\n",
"# !pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer accelerate\n",
"\n",
"# !pip install numpy -U # Resolve torchmetrics dependencies and update numpy"
]
},
@@ -75,8 +71,7 @@
],
"source": [
"from _common import *\n",
"device_ids = [0, 1, 2, 3]\n",
"logger.info(device_ids)\n",
"device_ids = [0, 1]\n",
"select_device(device_ids)\n",
"_ = seed_everything(42)"
]
@@ -132,22 +127,16 @@
}
],
"source": [
"model_id = 'baichuan-inc/baichuan-7B'\n",
"WORK_DIR = 'runs/baichuan'\n",
"LORA_TARGET_MODULES = ['W_pack']\n",
"#\n",
"model_dir = get_model_dir(model_id, None)\n",
"model, tokenizer = get_baichuan_model_tokenizer(model_dir)\n",
"model_dir = snapshot_download('baichuan-inc/baichuan-7B', 'v1.0.5')\n",
"model, tokenizer = get_baichuan7B_model_tokenizer(model_dir)\n",
"#\n",
"GRADIENT_CHECKPOINTING = True\n",
"if GRADIENT_CHECKPOINTING:\n",
" model.gradient_checkpointing_enable()\n",
" model.enable_input_require_grads()\n",
"if tokenizer.pad_token_id is None:\n",
" tokenizer.pad_token_id = tokenizer.eos_token_id\n",
"#\n",
"logger.info(f'bos_token_id: {tokenizer.bos_token_id}, eos_token_id: {tokenizer.eos_token_id}, '\n",
" f'pad_token_id: {tokenizer.pad_token_id}')"
" model.enable_input_require_grads()"
]
},
{

View File

@@ -55,7 +55,6 @@
"from _common import *\n",
"from transformers import TextStreamer\n",
"device_ids = [0, 1]\n",
"logger.info(device_ids)\n",
"select_device(device_ids)"
]
},
@@ -143,11 +142,8 @@
"CKPT_FAPTH = '/home/hackathon/my_git/agent/runs/chatglm2/v1-20230702-203505/output_best/pytorch_model.bin'\n",
"LORA_TARGET_MODULES = ['query_key_value']\n",
"\n",
"model, tokenizer = get_chatglm2_model_tokenizer()\n",
"if tokenizer.eos_token_id is None:\n",
" tokenizer.eos_token_id = tokenizer.pad_token_id\n",
"if tokenizer.bos_token_id is None:\n",
" tokenizer.bos_token_id = 1\n",
"model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6')\n",
"model, tokenizer = get_chatglm2_model_tokenizer(model_dir)\n",
"model.bfloat16() # Consistent with training"
]
},
@@ -484,8 +480,8 @@
" attention_mask = torch.ones_like(input_ids)\n",
" generate_ids = model.generate(input_ids=input_ids, max_new_tokens=512,\n",
" attention_mask=attention_mask,\n",
" streamer=streamer, pad_token_id=tokenizer.pad_token_id, \n",
" temperature=0.7, top_k=50, do_sample=True)\n",
" streamer=streamer, pad_token_id=tokenizer.eos_token_id, \n",
" temperature=0.7, top_k=50, top_p=0.7, do_sample=True)\n",
" print()\n",
" print(f'[LABELS]{assistant}')\n",
" print('-----------------------------------------------------------------------------------')\n",

View File

@@ -40,22 +40,18 @@
"metadata": {},
"outputs": [],
"source": [
"# !pip install modelscope -U\n",
"# !pip install modelscope\n",
"# !pip install numpy pandas matplotlib scikit-learn\n",
"# !pip install transformers datasets\n",
"# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
"# !pip install tqdm\n",
"# !pip install tensorboard\n",
"# !pip install torchmetrics\n",
"# !pip install sentencepiece\n",
"# !pip install accelerate\n",
"#\n",
"# !conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia\n",
"# !pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer accelerate\n",
"\n",
"# !pip install numpy -U # Resolve torchmetrics dependencies and update numpy"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
@@ -80,8 +76,7 @@
],
"source": [
"from _common import *\n",
"device_ids = [0, 1, 2, 3]\n",
"logger.info(device_ids)\n",
"device_ids = [0, 1]\n",
"select_device(device_ids)\n",
"_ = seed_everything(42)"
]
@@ -136,25 +131,16 @@
}
],
"source": [
"model_id = 'ZhipuAI/chatglm2-6b'\n",
"WORK_DIR = 'runs/chatglm2'\n",
"LORA_TARGET_MODULES = ['query_key_value']\n",
"#\n",
"model_dir = get_model_dir(model_id, None)\n",
"model_dir = snapshot_download('ZhipuAI/chatglm2-6b', 'v1.0.6')\n",
"model, tokenizer = get_chatglm2_model_tokenizer(model_dir)\n",
"# chatglm2 does not support gradient_checkpointing\n",
"GRADIENT_CHECKPOINTING = False\n",
"#\n",
"GRADIENT_CHECKPOINTING = True\n",
"if GRADIENT_CHECKPOINTING:\n",
" model.gradient_checkpointing_enable()\n",
" model.enable_input_require_grads()\n",
"logger.info(tokenizer.special_tokens)\n",
"if tokenizer.eos_token_id is None:\n",
" tokenizer.eos_token_id = tokenizer.pad_token_id\n",
"if tokenizer.bos_token_id is None:\n",
" tokenizer.bos_token_id = 1\n",
"#\n",
"logger.info(f'bos_token_id: {tokenizer.bos_token_id}, eos_token_id: {tokenizer.eos_token_id}, '\n",
" f'pad_token_id: {tokenizer.pad_token_id}')"
" model.enable_input_require_grads()"
]
},
{

View File

@@ -165,6 +165,7 @@ class Models(object):
doc2bot = 'doc2bot'
peer = 'peer'
llama = 'llama'
llama2 = 'llama2'
chatglm_6b = 'chatglm6b'
chatglm2_6b = 'chatglm2-6b'
@@ -522,6 +523,7 @@ class Pipelines(object):
soonet_video_temporal_grounding = 'soonet-video-temporal-grounding'
efficient_diffusion_tuning = 'efficient-diffusion-tuning'
multimodal_dialogue = 'multimodal-dialogue'
llama2_text_generation_pipeline = 'llama2-text-generation-pipeline'
# science tasks
protein_structure = 'unifold-protein-structure'

View File

@@ -75,6 +75,7 @@ if TYPE_CHECKING:
DocumentGroundedDialogRerankModel)
from .xlm_roberta import XLMRobertaConfig, XLMRobertaModel
from .llama import LlamaForTextGeneration, LlamaConfig, LlamaModel, LlamaTokenizer, LlamaTokenizerFast
from .llama2 import Llama2ForTextGeneration, Llama2Config, Llama2Model, Llama2Tokenizer, Llama2TokenizerFast
else:
_import_structure = {
@@ -170,6 +171,10 @@ else:
'LlamaForTextGeneration', 'LlamaConfig', 'LlamaModel',
'LlamaTokenizer', 'LlamaTokenizerFast'
],
'llama2': [
'Llama2ForTextGeneration', 'Llama2Config', 'Llama2Model',
'Llama2Tokenizer', 'Llama2TokenizerFast'
],
}
import sys

View File

@@ -1095,6 +1095,7 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss(ignore_index=-100)
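# When the model is loaded with `device_map='auto'`, the logits and the labels
# may sit on different devices; move the labels to the logits' device first.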
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))

View File

@@ -0,0 +1,29 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .configuration import Llama2Config
from .text_generation import Llama2ForTextGeneration
from .backbone import Llama2Model
from .tokenization import Llama2Tokenizer
from .tokenization_fast import Llama2TokenizerFast
else:
_import_structure = {
'configuration': ['Llama2Config'],
'text_generation': ['Llama2ForTextGeneration'],
'backbone': ['Llama2Model'],
'tokenization': ['Llama2Tokenizer'],
'tokenization_fast': ['Llama2TokenizerFast'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
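# With the lazy module in place, the classes can be imported directly from the
# package path (a usage sketch, not part of this file):
#   from modelscope.models.nlp.llama2 import Llama2Config, Llama2ForTextGeneration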

View File

@@ -0,0 +1,795 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from modelscope import Model, TorchModel
from modelscope.metainfo import Models
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from ... import MODELS
from .configuration import Llama2Config
logger = get_logger(__name__)
_CONFIG_FOR_DOC = 'Llama2Config'
# This file is mainly copied from the llama code of transformers
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len),
torch.finfo(dtype).min,
device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
_tmp_value = torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device)
mask = torch.cat([_tmp_value, mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len,
tgt_len + past_key_values_length)
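# Worked example (illustration only): for input_ids_shape == (1, 3), float32 and
# no cached keys, the returned mask keeps the lower triangle at 0 and fills
# future positions with the dtype minimum:
#   [[[[0, m, m],
#      [0, 0, m],
#      [0, 0, 0]]]]   where m = torch.finfo(torch.float32).min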
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor,
dtype: torch.dtype,
tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool),
torch.finfo(dtype).min)
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance
+ self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
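# In formula form (a restatement of the code above):
#   y = weight * x / sqrt(mean(x**2, dim=-1) + eps)
# with the normalization computed in float32 and cast back to the input dtype.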
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base**(torch.arange(
0, self.dim, 2).float().to(device) / self.dim)) # noqa
self.register_buffer('inv_freq', inv_freq)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.get_default_dtype())
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# The permutation differs from the paper, but it yields the same result
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer(
'cos_cached',
emb.cos()[None, None, :, :].to(dtype),
persistent=False)
self.register_buffer(
'sin_cached',
emb.sin()[None, None, :, :].to(dtype),
persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(
seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# The permutation differs from the paper, but it yields the same result
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer(
'cos_cached',
emb.cos()[None, None, :, :].to(dtype),
persistent=False)
self.register_buffer(
'sin_cached',
emb.sin()[None, None, :, :].to(dtype),
persistent=False)
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings)
- (self.scaling_factor - 1))**(
self.dim / (self.dim - 2))
inv_freq = 1.0 / (base**(torch.arange(
0, self.dim, 2).float().to(device) / self.dim)) # noqa
self.register_buffer('inv_freq', inv_freq)
t = torch.arange(
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# The permutation differs from the paper, but it yields the same result
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer(
'cos_cached',
emb.cos()[None, None, :, :].to(dtype),
persistent=False)
self.register_buffer(
'sin_cached',
emb.sin()[None, None, :, :].to(dtype),
persistent=False)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
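# Shape sketch (illustration only):
#   q, k:          [bs, num_heads, seq_len, head_dim]
#   cos, sin:      [1, 1, seq_len, head_dim] as returned by LlamaRotaryEmbedding.forward
#   position_ids:  [bs, seq_len]
# After indexing, cos/sin become [bs, 1, seq_len, head_dim] and broadcast across heads.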
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.pretraining_tp = config.pretraining_tp
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(
self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
if self.pretraining_tp > 1:
slice = self.intermediate_size // self.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
gate_proj = torch.cat([
F.linear(x, gate_proj_slices[i])
for i in range(self.pretraining_tp)
],
dim=-1) # noqa
up_proj = torch.cat([
F.linear(x, up_proj_slices[i])
for i in range(self.pretraining_tp)
],
dim=-1) # noqa
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(
slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i])
for i in range(self.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(
self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :,
None, :, :].expand(batch,
num_key_value_heads,
n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
head_dim)
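# Illustrative check of the equivalence stated above (hypothetical shapes, shown as
# a comment so nothing runs at import time):
#     kv = torch.randn(2, 4, 16, 64)  # (batch, num_key_value_heads, seqlen, head_dim)
#     out = repeat_kv(kv, 8)          # -> (2, 32, 16, 64)
#     assert torch.equal(out, torch.repeat_interleave(kv, repeats=8, dim=1))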
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Llama2Config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.pretraining_tp = config.pretraining_tp
self.max_position_embeddings = config.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
f' and `num_heads`: {self.num_heads}).')
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=False)
self.v_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=False)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False)
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings)
else:
scaling_type = self.config.rope_scaling['type']
scaling_factor = self.config.rope_scaling['factor']
if scaling_type == 'linear':
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor)
elif scaling_type == 'dynamic':
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor)
else:
raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads
* self.head_dim) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i])
for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i])
for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i])
for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(
2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
f' {attn_weights.size()}')
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
f' {attn_output.size()}')
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.pretraining_tp > 1:
attn_output = attn_output.split(
self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1)
attn_output = sum([
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: Llama2Config):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, )
if output_attentions:
outputs += (self_attn_weights, )
if use_cache:
outputs += (present_key_value, )
return outputs
class LlamaPreTrainedModel(TorchModel, PreTrainedModel):
config_class = Llama2Config
base_model_prefix = 'model'
supports_gradient_checkpointing = True
_no_split_modules = ['LlamaDecoderLayer']
_skip_keys_device_placement = 'past_key_values'
def __init__(self, config, **kwargs):
super().__init__(config.name_or_path, **kwargs)
super(Model, self).__init__(config)
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LlamaModel):
module.gradient_checkpointing = value
@classmethod
def _instantiate(cls, **kwargs):
"""Instantiate the model.
Args:
kwargs: Input args.
model_dir: The model dir used to load the checkpoint and the label information.
num_labels: An optional arg to tell the model how many classes to initialize.
Method will call utils.parse_label_mapping if num_labels not supplied.
If num_labels is not found, the model will use the default setting (2 classes).
Returns:
The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
"""
model_dir = kwargs.pop('model_dir', None)
if model_dir is None:
config = Llama2Config(**kwargs)
model = cls(config)
else:
model = super(Model, cls).from_pretrained(
pretrained_model_name_or_path=model_dir, **kwargs)
model.model_dir = model_dir
return model
@MODELS.register_module(Tasks.backbone, module_name=Models.llama2)
class Llama2Model(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: Llama2Config
"""
def __init__(self, config: Llama2Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
self.padding_idx)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype,
tgt_len=input_shape[-1]).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else
expanded_attn_mask + combined_attention_mask)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
'You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time'
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
'You have to specify either decoder_input_ids or decoder_inputs_embeds'
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds,
past_key_values_length)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
'`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states, )
past_key_value = past_key_values[
idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (
layer_outputs[2 if output_attentions else 1], )
if output_attentions:
all_self_attns += (layer_outputs[1], )
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states, )
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v for v in
[hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
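# Illustrative smoke-test sketch for the backbone (the sizes are assumptions chosen
# to stay small, not values from a released checkpoint):
#     config = Llama2Config(num_hidden_layers=2, hidden_size=256,
#                           intermediate_size=688, num_attention_heads=8)
#     model = Llama2Model(config)
#     out = model(input_ids=torch.randint(0, config.vocab_size, (1, 16)))
#     out.last_hidden_state.shape  # torch.Size([1, 16, 256])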

View File

@@ -0,0 +1,165 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LLaMA model configuration"""
from transformers.configuration_utils import PretrainedConfig
from modelscope.utils.logger import get_logger
logger = get_logger(__name__)
LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class Llama2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LLaMA-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`LlamaModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by mean-pooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
pretraining_tp (`int`, *optional*, defaults to `1`):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-6):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
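Example (an illustrative sketch; the import path is an assumption that mirrors
`Llama2Tokenizer`'s):
```python
>>> from modelscope.models.nlp.llama2 import Llama2Config
>>> # LLaMA-7B-like defaults, with linear RoPE scaling to double the usable context
>>> config = Llama2Config(rope_scaling={'type': 'linear', 'factor': 2.0})
>>> config.num_key_value_heads == config.num_attention_heads
True
```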
"""
model_type = 'llama'
keys_to_ignore_at_inference = ['past_key_values']
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act='silu',
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling,
dict) or len(self.rope_scaling) != 2:
raise ValueError(
'`rope_scaling` must be a dictionary with two fields, `type` and `factor`, '
f'got {self.rope_scaling}')
rope_scaling_type = self.rope_scaling.get('type', None)
rope_scaling_factor = self.rope_scaling.get('factor', None)
if rope_scaling_type is None or rope_scaling_type not in [
'linear', 'dynamic'
]:
raise ValueError(
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(
rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(
f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}"
)

View File

@@ -0,0 +1,188 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
from modelscope.metainfo import Models
from modelscope.utils.constant import Tasks
from ... import MODELS
from .backbone import Llama2Model, LlamaPreTrainedModel
# This file is mainly copied from the llama code of transformers
@MODELS.register_module(Tasks.text_generation, module_name=Models.llama2)
class Llama2ForTextGeneration(LlamaPreTrainedModel):
_tied_weights_keys = ['lm_head.weight']
def __init__(self, config):
super().__init__(config)
self.model = Llama2Model(config)
self.pretraining_tp = config.pretraining_tp
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(
config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(
self.vocab_size // self.pretraining_tp, dim=0)
logits = [
F.linear(hidden_states, lm_head_slices[i])
for i in range(self.pretraining_tp)
]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits, ) + outputs[1:]
return (loss, ) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs):
if past_key_values:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get('position_ids', None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}
model_inputs.update({
'position_ids': position_ids,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
})
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past), )
return reordered_past
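# Illustrative training sketch (assumes `model` is a loaded Llama2ForTextGeneration
# and `tokenizer` a matching Llama2Tokenizer; kept as a comment, nothing runs here):
# because `forward` shifts the labels internally, plain causal-LM training can pass
# `labels` equal to `input_ids`:
#     batch = tokenizer('Hello world', return_tensors='pt')
#     out = model(input_ids=batch.input_ids, labels=batch.input_ids)
#     out.loss.backward()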

View File

@@ -0,0 +1,410 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LLaMA."""
import os
from shutil import copyfile
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from modelscope.utils.logger import get_logger
if TYPE_CHECKING:
from transformers.pipelines.conversational import Conversation
logger = get_logger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'tokenizer.model'}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': {
'hf-internal-testing/llama-tokenizer':
'https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model',
},
'tokenizer_file': {
'hf-internal-testing/llama-tokenizer':
'https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json',
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'hf-internal-testing/llama-tokenizer': 2048,
}
SPIECE_UNDERLINE = ''
B_INST, E_INST = '[INST]', '[/INST]'
B_SYS, E_SYS = '<<SYS>>\n', '\n<</SYS>>\n\n'
# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your\
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not\
correct. If you don't know the answer to a question, please don't share false information."""
# fmt: on
class Llama2Tokenizer(PreTrainedTokenizer):
"""
Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
no padding token in the original model.
Args:
vocab_file (`str`):
Path to the vocabulary file.
legacy (`bool`, *optional*, defaults to `True`):
Whether or not the `legacy` behaviour of the tokenizer should be used. Legacy behaviour predates the merge of #24622,
which includes fixes to properly handle tokens that appear after special tokens. A simple example:
- `legacy=True`:
```python
>>> from transformers import T5Tokenizer
>>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True)
>>> tokenizer.encode("Hello <extra_id_0>.")
[8774, 32099, 3, 5, 1]
```
- `legacy=False`:
```python
>>> from transformers import T5Tokenizer
>>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False)
>>> tokenizer.encode("Hello <extra_id_0>.") # the extra space `[3]` is no longer here
[8774, 32099, 5, 1]
```
Check out the pull request and the issue [here](https://github.com/huggingface/transformers/pull/24565) for
more details.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ['input_ids', 'attention_mask']
def __init__(
self,
vocab_file,
unk_token='<unk>',
bos_token='<s>',
eos_token='</s>',
pad_token=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
add_bos_token=True,
add_eos_token=False,
clean_up_tokenization_spaces=False,
legacy=True,
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
bos_token = AddedToken(
bos_token, lstrip=False, rstrip=False) if isinstance(
bos_token, str) else bos_token
eos_token = AddedToken(
eos_token, lstrip=False, rstrip=False) if isinstance(
eos_token, str) else eos_token
unk_token = AddedToken(
unk_token, lstrip=False, rstrip=False) if isinstance(
unk_token, str) else unk_token
pad_token = AddedToken(
pad_token, lstrip=False, rstrip=False) if isinstance(
pad_token, str) else pad_token
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
sp_model_kwargs=self.sp_model_kwargs,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
legacy=legacy,
**kwargs,
)
if legacy:
logger.warning_once(
f'You are using the legacy behaviour of the {self.__class__}. '
f'This means that tokens that come after special '
f'tokens will not be properly handled. We recommend you to'
' read the related pull request available at https://github.com/huggingface/transformers/pull/24565'
)
self.legacy = legacy
self.vocab_file = vocab_file
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
def __getstate__(self):
state = self.__dict__.copy()
state['sp_model'] = None
state['sp_model_proto'] = self.sp_model.serialized_model_proto()
return state
def __setstate__(self, d):
self.__dict__ = d
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
@property
def vocab_size(self):
"""Returns vocab size"""
return self.sp_model.get_piece_size()
def get_vocab(self):
"""Returns vocab as a dict"""
vocab = {
self.convert_ids_to_tokens(i): i
for i in range(self.vocab_size)
}
vocab.update(self.added_tokens_encoder)
return vocab
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
def tokenize(self, text, **kwargs) -> List[str]:
# Replace the SPIECE_UNDERLINE with a space to make sure SPIECE_UNDERLINE is only used at
# the beginning of the text
if not self.legacy:
text = SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, ' ')
return super().tokenize(text, **kwargs)
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
def _tokenize(self, text):
"""
Returns a tokenized string.
Since the sentencepiece internal model always adds a SPIECE_UNDERLINE at the beginning of the provided text,
we need to remove it by hand when the current text is a subsequence. This happens whenever the `self.tokenize`
function is called with special tokens: the input is split on the special tokens, and each subsequence is
passed to `_tokenize`. Thus, if a subsequence did not start with a `" "` or SPIECE_UNDERLINE, we have to remove
the extra `SPIECE_UNDERLINE` prepended.
"""
if not self.legacy:
is_first = text.startswith(SPIECE_UNDERLINE)
if is_first:
text = text[1:]
tokens = self.sp_model.encode(text, out_type=str)
if not self.legacy and not is_first and not text.startswith(
' ') and tokens[0].startswith(SPIECE_UNDERLINE):
tokens = ([tokens[0][1:]]
if len(tokens[0]) > 1 else []) + tokens[1:]
return tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.sp_model.piece_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
token = self.sp_model.IdToPiece(index)
return token
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = []
out_string = ''
prev_is_special = False
for i, token in enumerate(tokens):
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
if not prev_is_special and i != 0:
out_string += ' '
out_string += self.sp_model.decode(current_sub_tokens) + token
prev_is_special = True
current_sub_tokens = []
else:
current_sub_tokens.append(token)
prev_is_special = False
out_string += self.sp_model.decode(current_sub_tokens)
return out_string
def save_vocabulary(self,
save_directory,
filename_prefix: Optional[str] = None) -> Tuple[str]:
"""
Save the vocabulary and special tokens file to a directory.
Args:
save_directory (`str`):
The directory in which to save the vocabulary.
Returns:
`Tuple(str)`: Paths to the files saved.
"""
if not os.path.isdir(save_directory):
logger.error(
f'Vocabulary path ({save_directory}) should be a directory')
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + '-' if filename_prefix else '')
+ VOCAB_FILES_NAMES['vocab_file'])
if os.path.abspath(self.vocab_file) != os.path.abspath(
out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, 'wb') as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file, )
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = bos_token_id + token_ids_0 + eos_token_id
if token_ids_1 is not None:
output = output + bos_token_id + token_ids_1 + eos_token_id
return output
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=True)
bos_token_id = [1] if self.add_bos_token else []
eos_token_id = [1] if self.add_eos_token else []
if token_ids_1 is None:
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
return bos_token_id + (
[0] * len(token_ids_0)) + eos_token_id + bos_token_id + (
[0] * len(token_ids_1)) + eos_token_id # noqa
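# Illustrative values (assuming the defaults add_bos_token=True, add_eos_token=False):
#     tokenizer.get_special_tokens_mask([5, 6, 7])  # -> [1, 0, 0, 0]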
def create_token_type_ids_from_sequences(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
"""
Creates a mask from the two sequences passed, to be used in a sequence-pair classification task. A
sequence-pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
if token_ids_1 is None, only returns the first portion of the mask (0s).
Args:
token_ids_0 (`List[int]`):
List of ids.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
if token_ids_1 is not None:
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
return output
def _build_conversation_input_ids(
self, conversation: 'Conversation') -> List[int]:
"""Builds the input ids for a conversation.
This is the format used in the provided examples. System prompts should be manually added at the beginning of
the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used.
```
<bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos>
<bos>[INST] Prompt [/INST] Answer <eos>
<bos>[INST] Prompt [/INST]
```
If you want to use your own system prompt, make sure to include both `B_SYS` and `E_SYS`, as in the following:
```python
>>> from transformers import Conversation
>>> Conversation(
... "<<SYS>>\n Only answer with emojis, and charades\n<</SYS>>\n\nHow can I build a house in 10 septs?"
... )
```
Args:
conversation (`Conversation`):
Conversation to build input ids for.
Returns:
`List[int]`:
Input ids for the conversation.
"""
dialogue = list(conversation.iter_texts())
if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
[not is_user for is_user, msg in dialogue[1::2]]): # noqa
raise ValueError(
"The model only supports 'user' and 'assistant' roles, "
'starting with user and alternating (u/a/u/a/u...)')
dialog_tokens: List[int] = []
if len(conversation.past_user_inputs) > 0:
if not conversation.past_user_inputs[0].startswith(
B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
conversation.past_user_inputs[0] = (
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS
+ conversation.past_user_inputs[0])
elif not dialogue[0][1].startswith(
B_SYS) or E_SYS not in dialogue[0][1]:
dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT
+ E_SYS + dialogue[0][1])
dialog_tokens += sum(
[[self.bos_token_id] + self.encode(
f'{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ',
add_special_tokens=False) + [self.eos_token_id]
for prompt, answer in zip(dialogue[::2], dialogue[1::2])],
[],
)
if not (dialogue[-1][0]):
raise ValueError(
f"Last message must be from user, got {dialogue[-1]['role']}")
dialog_tokens += [self.bos_token_id] + self.encode(
f'{B_INST} {(dialogue[-1][1]).strip()} {E_INST}',
add_special_tokens=False)
return dialog_tokens

View File

@@ -0,0 +1,249 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from shutil import copyfile
from typing import TYPE_CHECKING, Optional, Tuple
from tokenizers import processors
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import is_sentencepiece_available, logging
from transformers.utils.versions import require_version
if TYPE_CHECKING:
from transformers.pipelines.conversational import Conversation
require_version('tokenizers>=0.13.3')
if is_sentencepiece_available():
from .tokenization import Llama2Tokenizer
else:
Llama2Tokenizer = None
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
'vocab_file': 'tokenizer.model',
'tokenizer_file': 'tokenizer.json'
}
B_INST, E_INST = '[INST]', '[/INST]'
B_SYS, E_SYS = '<<SYS>>\n', '\n<</SYS>>\n\n'
# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your\
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not\
correct. If you don't know the answer to a question, please don't share false information."""
# fmt: on
class Llama2TokenizerFast(PreTrainedTokenizerFast):
"""
Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
It notably uses ByteFallback and no normalization.
```
from transformers import LlamaTokenizerFast
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.encode("Hello this is a test")
>>> [1, 15043, 445, 338, 263, 1243]
```
If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
values of the first token and final token of an encoded sequence will not be correct). For more details, check out
the [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
contains the vocabulary necessary to instantiate a tokenizer.
tokenizer_file (`str`):
[tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
contains everything needed to load the tokenizer.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether to clean up spaces after decoding; cleanup consists of removing potential artifacts like extra
spaces.
bos_token (`str`, *optional*, defaults to `"<s>"`):
The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
eos_token (`str`, *optional*, defaults to `"</s>"`):
The end of sequence token.
unk_token (`str`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
"""
vocab_files_names = VOCAB_FILES_NAMES
slow_tokenizer_class = Llama2Tokenizer
padding_side = 'left'
model_input_names = ['input_ids', 'attention_mask']
def __init__(
self,
vocab_file=None,
tokenizer_file=None,
clean_up_tokenization_spaces=False,
unk_token='<unk>',
bos_token='<s>',
eos_token='</s>',
add_bos_token=True,
add_eos_token=False,
**kwargs,
):
super().__init__(
vocab_file=vocab_file,
tokenizer_file=tokenizer_file,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
**kwargs,
)
self._add_bos_token = add_bos_token
self._add_eos_token = add_eos_token
self.update_post_processor()
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = False if not self.vocab_file else True
def update_post_processor(self):
"""
Updates the underlying post processor with the current `bos_token` and `eos_token`.
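With the defaults (`add_bos_token=True`, `add_eos_token=False`) this expands to the
template `"<s>:0 $A:0"` for single sequences and `"<s>:0 $A:0 <s>:1 $B:1"` for pairs
(an illustrative expansion of the f-strings below).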
"""
bos = self.bos_token
bos_token_id = self.bos_token_id
eos = self.eos_token
eos_token_id = self.eos_token_id
single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') * self.add_eos_token}"
pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') * self.add_eos_token}"
special_tokens = []
if self.add_bos_token:
special_tokens.append((bos, bos_token_id))
if self.add_eos_token:
special_tokens.append((eos, eos_token_id))
self._tokenizer.post_processor = processors.TemplateProcessing(
single=single, pair=pair, special_tokens=special_tokens)
@property
def add_eos_token(self):
return self._add_eos_token
@property
def add_bos_token(self):
return self._add_bos_token
@add_eos_token.setter
def add_eos_token(self, value):
self._add_eos_token = value
self.update_post_processor()
@add_bos_token.setter
def add_bos_token(self, value):
self._add_bos_token = value
self.update_post_processor()
def save_vocabulary(self,
save_directory: str,
filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
'tokenizer.')
if not os.path.isdir(save_directory):
logger.error(
f'Vocabulary path ({save_directory}) should be a directory')
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + '-' if filename_prefix else '')
+ VOCAB_FILES_NAMES['vocab_file'])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file, )
def _build_conversation_input_ids(self, conversation: 'Conversation'):
"""Builds the input ids for a conversation.
This is the format used in the provided examples. System prompts should be manually added at the beginning of
the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used.
```
<bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos>
<bos>[INST] Prompt [/INST] Answer <eos>
<bos>[INST] Prompt [/INST]
```
If you want to use your own system prompt, make sure to include both `B_SYS` and `E_SYS`, as in the following:
```python
>>> from transformers import Conversation
>>> Conversation(
... "<<SYS>>\n Only answer with emojis, and charades\n<</SYS>>\n\nHow can I build a house in 10 septs?"
... )
```
Args:
conversation (`Conversation`):
Conversation to build input ids for.
Returns:
`List[int]`:
Input ids for the conversation.
"""
dialogue = list(conversation.iter_texts())
if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
[not is_user for is_user, msg in dialogue[1::2]]): # noqa
raise ValueError(
"The model only supports 'user' and 'assistant' roles, "
'starting with user and alternating (u/a/u/a/u...)')
dialog_tokens = []
if len(conversation.past_user_inputs) > 0:
if not conversation.past_user_inputs[0].startswith(
B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
conversation.past_user_inputs[0] = (
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS
+ conversation.past_user_inputs[0])
elif not dialogue[0][1].startswith(
B_SYS) or E_SYS not in dialogue[0][1]:
dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT
+ E_SYS + dialogue[0][1])
dialog_tokens += sum(
[[self.bos_token_id] + self.encode(
f'{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ',
add_special_tokens=False) + [self.eos_token_id]
for prompt, answer in zip(dialogue[::2], dialogue[1::2])],
[],
)
if not (dialogue[-1][0]):
raise ValueError(
f"Last message must be from user, got {dialogue[-1]['role']}")
dialog_tokens += [self.bos_token_id] + self.encode(
f'{B_INST} {(dialogue[-1][1]).strip()} {E_INST}',
add_special_tokens=False)
return dialog_tokens

View File

@@ -210,7 +210,12 @@ class AutomaticSpeechRecognitionPipeline(Pipeline):
if isinstance(audio_in, str):
# for funasr code, generate wav.scp from url or local path
self.audio_in, self.raw_inputs = generate_scp_from_url(audio_in)
if audio_in.startswith('http') or os.path.isfile(audio_in):
self.audio_in, self.raw_inputs = generate_scp_from_url(
audio_in)
else:
raise FileNotFoundError(
f'file {audio_in} NOT FOUND, please CHECK!')
elif isinstance(audio_in, bytes):
self.audio_in = audio_in
self.raw_inputs = None

View File

@@ -232,7 +232,13 @@ class SpeakerDiarizationPipeline(Pipeline):
def forward(self, audio_in: Union[tuple, str, Any] = None) -> list:
"""Decoding
"""
logger.info('Speaker Diarization Processing: {0} ...'.format(audio_in))
# log file_path/url or tuple (str, str)
if isinstance(audio_in, str) or \
(isinstance(audio_in, tuple) and all(isinstance(item, str) for item in audio_in)):
logger.info(f'Speaker Diarization Processing: {audio_in} ...')
else:
logger.info(
f'Speaker Diarization Processing: {str(audio_in)[:100]} ...')
data_cmd, raw_inputs = None, None
if isinstance(audio_in, tuple) or isinstance(audio_in, list):

View File

@@ -180,8 +180,13 @@ class SpeakerVerificationPipeline(Pipeline):
def forward(self, audio_in: Union[tuple, str, Any] = None) -> list:
"""Decoding
"""
logger.info(
'Speaker Verification Processing: {0} ...'.format(audio_in))
# log file_path/url or tuple (str, str)
if isinstance(audio_in, str) or \
(isinstance(audio_in, tuple) and all(isinstance(item, str) for item in audio_in)):
logger.info(f'Speaker Verification Processing: {audio_in} ...')
else:
logger.info(
f'Speaker Verification Processing: {str(audio_in)[:100]} ...')
data_cmd, raw_inputs = None, None
if isinstance(audio_in, tuple) or isinstance(audio_in, list):

View File

@@ -146,7 +146,8 @@ class _DiffuersChineseStableDiffusionPipeline(StableDiffusionPipeline):
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None):
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None):
r"""
Encodes the prompt into text encoder hidden states.
@@ -169,7 +170,14 @@ class _DiffuersChineseStableDiffusionPipeline(StableDiffusionPipeline):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):

View File

@@ -0,0 +1,99 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright (c) 2022 Zhipu.AI
from typing import Any, Dict, Union
import torch
from modelscope import Model, snapshot_download
from modelscope.metainfo import Pipelines, Preprocessors
from modelscope.models.nlp.llama2 import Llama2Tokenizer
from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.nlp.text_generation_pipeline import \
TextGenerationPipeline
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Fields, Tasks
@PIPELINES.register_module(
Tasks.text_generation,
module_name=Pipelines.llama2_text_generation_pipeline)
class Llama2TaskPipeline(TextGenerationPipeline):
def __init__(self,
model: Union[Model, str],
preprocessor: Preprocessor = None,
config_file: str = None,
device: str = 'gpu',
auto_collate=True,
**kwargs):
"""Use `model` and `preprocessor` to create a generation pipeline for prediction.
Args:
model (str or Model): Supply either a local model dir which supports the text generation task,
or a model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied.
kwargs (dict, `optional`):
Extra kwargs passed into the preprocessor's constructor.
Examples:
>>> from modelscope.utils.constant import Tasks
>>> import torch
>>> from modelscope.pipelines import pipeline
>>> from modelscope import snapshot_download, Model
>>> model_dir = snapshot_download("modelscope/Llama-2-13b-chat-ms",
>>> ignore_file_pattern = [r'\\w+\\.safetensors'])
>>> pipe = pipeline(task=Tasks.text_generation, model=model_dir, device_map='auto',
>>> torch_dtype=torch.float16)
>>> inputs="咖啡的作用是什么?"
>>> result = pipe(inputs,max_length=200, do_sample=True, top_p=0.85,
>>> temperature=1.0, repetition_penalty=1., eos_token_id=2, bos_token_id=1, pad_token_id=0)
>>> print(result['text'])
To view other examples, please check tests/pipelines/test_llama2_text_generation_pipeline.py.
"""
self.model = Model.from_pretrained(
model, device_map='auto', torch_dtype=torch.float16)
self.tokenizer = Llama2Tokenizer.from_pretrained(model)
super().__init__(model=self.model, **kwargs)
def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
return inputs
def _sanitize_parameters(self, **pipeline_parameters):
return {}, pipeline_parameters, {}
def forward(self,
inputs,
max_length=50,
do_sample=True,
top_p=0.85,
temperature=1.0,
repetition_penalty=1.,
eos_token_id=2,
bos_token_id=1,
pad_token_id=0,
**forward_params) -> Dict[str, Any]:
output = {}
inputs = self.tokenizer(inputs, return_tensors='pt')
generate_ids = self.model.generate(
inputs.input_ids.to('cuda'),
max_length=max_length,
do_sample=do_sample,
top_p=top_p,
temperature=temperature,
repetition_penalty=repetition_penalty,
eos_token_id=eos_token_id,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
**forward_params)
out = self.tokenizer.batch_decode(
generate_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)[0]
output['text'] = out
return output
# format the outputs from pipeline
def postprocess(self, input, **kwargs) -> Dict[str, Any]:
return input
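A compact usage sketch mirroring the docstring example above; the model id and generation parameters are illustrative and can be swapped for any Llama-2 chat checkpoint on the hub:

import torch
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(
    task=Tasks.text_generation,
    model='modelscope/Llama-2-7b-chat-ms',
    device_map='auto',
    torch_dtype=torch.float16)
result = pipe(
    'Why is the sky blue?',
    max_length=200,
    do_sample=True,
    top_p=0.85)
print(result['text'])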

View File

@@ -50,6 +50,7 @@ class CheckpointHook(Hook):
hub_revision (str): Which branch to push the model to, default is `master`.
upload_strategy (str): The action taken when the previous upload has not finished
and the next one arrives; can be `cancel` or `wait`.
save_trainer_state (bool): Whether to save the trainer state for resuming training, default True.
kwargs:
by_epoch (bool): Same as `save_strategy`, but has a higher priority; legacy argument.
output_sub_dir (str): The folder under the `save_dir` to save the output checkpoint for inference.
@@ -75,6 +76,7 @@ class CheckpointHook(Hook):
private_hub: Optional[bool] = True,
hub_revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
upload_strategy: Optional[str] = UploadStrategy.cancel,
save_trainer_state: Optional[bool] = True,
**kwargs):
self.interval = interval
self.save_dir = save_dir
@@ -97,6 +99,7 @@ class CheckpointHook(Hook):
self.private_hub = private_hub
self.hub_revision = hub_revision
self.upload_strategy = upload_strategy
self.save_trainer_state = save_trainer_state
self.tag = -1
self.is_model_id = None
self.max_checkpoint_num = None
@@ -219,7 +222,8 @@ class CheckpointHook(Hook):
checkpoint_path_prefix = os.path.join(self.save_dir, prefix)
meta = self._create_training_state(trainer)
self.processor.save_checkpoints(trainer, checkpoint_path_prefix,
self.output_dir, meta)
self.output_dir, meta,
self.save_trainer_state)
self.save_evaluate_results(trainer)
self.history_checkpoints.append(checkpoint_path_prefix)
self._remove_obsolete_checkpoints(trainer)
@@ -399,7 +403,8 @@ class BestCkptSaverHook(CheckpointHook):
self._best_ckpt_file = checkpoint_path_prefix
meta = self._create_training_state(trainer)
self.processor.save_checkpoints(trainer, checkpoint_path_prefix,
self.output_dir, meta)
self.output_dir, meta,
self.save_trainer_state)
self.save_evaluate_results(trainer)
self.history_checkpoints.add(checkpoint_path_prefix)
self._remove_obsolete_checkpoints(trainer)
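A hedged sketch of how the new flag could be set from a trainer config; it assumes the hook is declared through the `train.hooks` list, and the surrounding keys are illustrative:

# Hypothetical cfg_modify_fn: keep checkpoints smaller by skipping the
# optimizer/lr_scheduler states via the new `save_trainer_state` flag.
def cfg_modify_fn(cfg):
    if not getattr(cfg.train, 'hooks', None):
        cfg.train.hooks = []
    cfg.train.hooks.append({
        'type': 'CheckpointHook',
        'interval': 1,
        'save_trainer_state': False,
    })
    return cfg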

View File

@@ -104,7 +104,8 @@ class CheckpointProcessor:
trainer,
checkpoint_path_prefix,
output_dir,
meta=None):
meta=None,
save_optimizers=True):
"""Save the state dict for trainer and model.
This is a strategic function which can be overridden by registering another hook's processor.
@@ -115,13 +116,15 @@ class CheckpointProcessor:
like: /tmp/test/epoch_0
output_dir(`str`): The output dir for inference.
meta: (`dict`): The meta info needed to be saved into files.
save_optimizers (`bool`): Whether to save the optimizer and lr_scheduler states.
"""
model = trainer.unwrap_module(trainer.model)
_model_file, _train_state_file = self._get_state_file_name(
checkpoint_path_prefix)
# Save pth file without model state_dict
self.save_trainer_state(trainer, model, _train_state_file, meta)
self.save_trainer_state(trainer, model, _train_state_file, meta,
save_optimizers)
self.save_model_state(model, _model_file)
self.link(model, _model_file, output_dir)
@@ -175,7 +178,8 @@ class CheckpointProcessor:
'changing to copy the bin file, this may use more disk space.')
shutil.copyfile(src_file, dest_file)
def save_trainer_state(self, trainer, model, train_state_file, meta):
def save_trainer_state(self, trainer, model, train_state_file, meta,
save_optimizers):
"""Save the trainer state, including optimizer/lr_scheduler's state dict, random states etc.
Args:
@@ -183,12 +187,13 @@ class CheckpointProcessor:
model: The model instance.
train_state_file: The target file name for saving trainer states.
meta: Some extra meta info.
save_optimizers: Whether to save the optimizer and lr_scheduler states.
"""
save_checkpoint(
model,
train_state_file,
trainer.optimizer,
trainer.lr_scheduler,
trainer.optimizer if save_optimizers else None,
trainer.lr_scheduler if save_optimizers else None,
meta=meta,
with_model=False)
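The same optional `save_optimizers` argument is threaded through the processor subclasses below (DeepSpeed, MPU, DreamBooth, LoRA diffusion). A hedged sketch of how a custom processor would stay compatible with the extended signature; the class name is illustrative and the import path is assumed:

# Assumed import path for CheckpointProcessor.
from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \
    CheckpointProcessor


class MyCheckpointProcessor(CheckpointProcessor):

    def save_checkpoints(self,
                         trainer,
                         checkpoint_path_prefix,
                         output_dir,
                         meta=None,
                         save_optimizers=True):
        # Custom pre-save logic (e.g. exporting extra artifacts) would go here,
        # then defer to the base class, which drops the optimizer/lr_scheduler
        # states when save_optimizers is False.
        super().save_checkpoints(trainer, checkpoint_path_prefix, output_dir,
                                 meta, save_optimizers)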

View File

@@ -156,7 +156,8 @@ class DeepspeedProcessor(CheckpointProcessor, LrSchedulerProcessor,
trainer,
checkpoint_path_prefix,
output_dir,
meta=None):
meta=None,
save_optimizers=True):
model = trainer.unwrap_module(trainer.model)
_train_state_file = checkpoint_path_prefix + self.rank_name(
) + CheckpointProcessor.TRAINER_STATE_SUFFIX

View File

@@ -57,7 +57,8 @@ class MpuProcessor(CheckpointProcessor):
trainer,
checkpoint_path_prefix,
output_dir,
meta=None):
meta=None,
save_optimizers=True):
model = trainer.unwrap_module(trainer.model)
_train_state_file = checkpoint_path_prefix + self.rank_name(
) + CheckpointProcessor.TRAINER_STATE_SUFFIX
@@ -65,8 +66,8 @@ class MpuProcessor(CheckpointProcessor):
save_checkpoint(
model,
_train_state_file,
trainer.optimizer,
trainer.lr_scheduler,
trainer.optimizer if save_optimizers else None,
trainer.lr_scheduler if save_optimizers else None,
meta=meta,
with_model=False)

View File

@@ -41,7 +41,8 @@ class DreamboothCheckpointProcessor(CheckpointProcessor):
trainer,
checkpoint_path_prefix,
output_dir,
meta=None):
meta=None,
save_optimizers=True):
"""Save the state dict for dreambooth model.
"""
pipeline_args = {}

View File

@@ -21,7 +21,8 @@ class LoraDiffusionCheckpointProcessor(CheckpointProcessor):
trainer,
checkpoint_path_prefix,
output_dir,
meta=None):
meta=None,
save_optimizers=True):
"""Save the state dict for lora tune model.
"""
trainer.model.unet = trainer.model.unet.to(torch.float32)

View File

@@ -25,7 +25,7 @@ def get_logger(log_file: Optional[str] = None,
logger_name = __name__.split('.')[0]
logger = logging.getLogger(logger_name)
logger.propagate = False
if logger_name in init_loggers:
add_file_handler_if_needed(logger, log_file, file_mode, log_level)
return logger

View File

@@ -0,0 +1,47 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import torch
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class Llama2TextGenerationPipelineTest(unittest.TestCase):
def setUp(self) -> None:
self.llama2_model_id_7B_chat_ms = 'modelscope/Llama-2-7b-chat-ms'
self.llama2_input_chat_ch = '天空为什么是蓝色的?'
def run_pipeline_with_model_id(self,
model_id,
input,
init_kwargs={},
run_kwargs={}):
pipeline_ins = pipeline(
task=Tasks.text_generation, model=model_id, **init_kwargs)
pipeline_ins._model_prepare = True
result = pipeline_ins(input, **run_kwargs)
print(result['text'])
# 7B_ms_chat
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_llama2_7B_chat_ms_with_model_name_with_chat_ch_with_args(self):
self.run_pipeline_with_model_id(
self.llama2_model_id_7B_chat_ms,
self.llama2_input_chat_ch,
init_kwargs={
'device_map': 'auto',
'torch_dtype': torch.float16
},
run_kwargs={
'max_length': 200,
'do_sample': True,
'top_p': 0.85
})
if __name__ == '__main__':
unittest.main()