new Feat/0817 (#504)

This commit is contained in:
Jintao
2023-08-29 16:43:36 +08:00
committed by GitHub
parent 4bb190e419
commit 2ee5ebaf35
22 changed files with 435 additions and 3245 deletions

.gitignore vendored

@@ -124,6 +124,8 @@ replace.sh
result.png
result.jpg
result.mp4
runs/
ckpt/
# Pytorch
*.pth


@@ -1,4 +1,3 @@
<h1 align="center">LLM SFT Example</h1>
<p align="center">
@@ -14,46 +13,59 @@
<a href="README_CN.md">中文</a>&nbsp &nbspEnglish
</p>
## Note!!!
## Note
1. This README.md file is **copied from** [ms-swift](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/README.md)
2. This directory has been **migrated** to [ms-swift](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm), and the files in this directory are **no longer maintained**.
## Features
1. supported sft method: lora, qlora, full, ...
2. supported models: [**qwen-7b**](https://github.com/QwenLM/Qwen-7B), baichuan-7b, baichuan-13b, chatglm2-6b, llama2-7b, llama2-13b, llama2-70b, openbuddy-llama2-13b, ...
1. supported sft methods: [lora](https://arxiv.org/abs/2106.09685), [qlora](https://arxiv.org/abs/2305.14314), full (full-parameter fine-tuning), ...
2. supported models: [**qwen-7b**](https://github.com/QwenLM/Qwen-7B), baichuan-7b, baichuan-13b, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-13b, llama2-70b, openbuddy-llama2-13b, openbuddy-llama-65b, polylm-13b, ...
3. supported features: quantization, ddp, model parallelism (device map), gradient checkpointing, gradient accumulation, push to modelscope hub, custom datasets, ...
4. supported datasets: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, ...
## Prepare the Environment
Experimental environment: A10, 3090, A100, ... (V100 does not support bf16 or quantization)
```bash
# Please note the cuda version
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
# Installing miniconda
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
sh Miniconda3-latest-Linux-x86_64.sh
# Setting up a conda virtual environment
conda create --name ms-sft python=3.10
conda activate ms-sft
# Setting up a global pip mirror for faster downloads
pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
pip install torch torchvision torchaudio -U
pip install sentencepiece charset_normalizer cpm_kernels tiktoken -U
pip install matplotlib tqdm tensorboard -U
pip install matplotlib scikit-learn tqdm tensorboard -U
pip install transformers datasets -U
pip install accelerate transformers_stream_generator -U
pip install ms-swift modelscope -U
# Installing from source is recommended for faster bug fixes
git clone https://github.com/modelscope/swift.git
cd swift
pip install -r requirements.txt
pip install .
# Do the same for modelscope... (git clone ...)
# You can also install it from pypi
pip install ms-swift modelscope -U
```
## Run SFT and Inference
```bash
# Clone the repository and enter the code directory.
git clone https://github.com/modelscope/swift.git
cd swift/examples/pytorch/llm
# sft(qlora) and infer qwen-7b, Requires 10GB VRAM.
# sft(qlora) and infer qwen-7b, Requires 16GB VRAM.
# If you want to use quantization, you need to `pip install bitsandbytes`
bash scripts/qwen_7b/qlora/sft.sh
# If you want to push the model to the modelscope hub during training
bash scripts/qwen_7b/qlora/sft_push_to_hub.sh
bash scripts/qwen_7b/qlora/infer.sh
# sft(qlora+ddp) and infer qwen-7b, Requires 4*10GB VRAM.
# sft(qlora+ddp) and infer qwen-7b, Requires 4*16GB VRAM.
bash scripts/qwen_7b/qlora_ddp/sft.sh
bash scripts/qwen_7b/qlora_ddp/infer.sh


@@ -1,4 +1,3 @@
<h1 align="center">大模型微调的例子</h1>
<p align="center">
@@ -14,47 +13,61 @@
中文&nbsp;&nbsp;<a href="README.md">English</a>
</p>
## Note!!!
## Note
1. This README_CN.md is **copied from** [ms-swift](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/README_CN.md)
2. This directory has been **migrated** to [ms-swift](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm), and the files in this directory are **no longer maintained**.
## Features
1. Supported SFT methods: lora, qlora, full-parameter fine-tuning, ...
2. Supported models: [**qwen-7b**](https://github.com/QwenLM/Qwen-7B), baichuan-7b, baichuan-13b, chatglm2-6b, llama2-7b, llama2-13b, llama2-70b, openbuddy-llama2-13b, ...
1. Supported SFT methods: [lora](https://arxiv.org/abs/2106.09685), [qlora](https://arxiv.org/abs/2305.14314), full-parameter fine-tuning, ...
2. Supported models: [**qwen-7b**](https://github.com/QwenLM/Qwen-7B), baichuan-7b, baichuan-13b, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-13b, llama2-70b, openbuddy-llama2-13b, openbuddy-llama-65b, polylm-13b, ...
3. Supported features: model quantization, DDP, model parallelism (device_map), gradient checkpointing, gradient accumulation, push to the modelscope hub, custom datasets, ...
4. Supported datasets: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, ...
## Prepare the Environment
Experimental environment: A10, 3090, A100, etc. (V100 does not support bf16 or quantization)
```bash
# Please mind the cuda version
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
# Install miniconda
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
# Keep pressing [ENTER]; answer yes to the final prompt
sh Miniconda3-latest-Linux-x86_64.sh
# Set up a conda virtual environment
conda create --name ms-sft python=3.10
conda activate ms-sft
# Set a global pip mirror and install the required python packages
pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
pip install torch torchvision torchaudio -U
pip install sentencepiece charset_normalizer cpm_kernels tiktoken -U
pip install matplotlib tqdm tensorboard -U
pip install matplotlib scikit-learn tqdm tensorboard -U
pip install transformers datasets -U
pip install accelerate transformers_stream_generator -U
pip install ms-swift modelscope -U
# Installing swift and modelscope from source is recommended: more features and faster bug fixes
git clone https://github.com/modelscope/swift.git
cd swift
pip install -r requirements.txt
pip install .
# Do the same for modelscope... (git clone ...)
# Alternatively, you can install them from pypi
pip install ms-swift modelscope -U
```
## Run SFT and Inference
```bash
# Clone the repository and enter the code directory
git clone https://github.com/modelscope/swift.git
cd swift/examples/pytorch/llm
# sft (qlora) and infer qwen-7b; requires 10GB VRAM.
# sft (qlora) and infer qwen-7b; requires 16GB VRAM.
# If you want to use quantization, you need to `pip install bitsandbytes`
bash scripts/qwen_7b/qlora/sft.sh
# If you want to push the weights to the modelscope hub during training
bash scripts/qwen_7b/qlora/sft_push_to_hub.sh
bash scripts/qwen_7b/qlora/infer.sh
# sft (qlora+ddp) and infer qwen-7b; requires 4 x 10GB VRAM.
# sft (qlora+ddp) and infer qwen-7b; requires 4 x 16GB VRAM.
bash scripts/qwen_7b/qlora_ddp/sft.sh
bash scripts/qwen_7b/qlora_ddp/infer.sh


@@ -26,8 +26,12 @@ if TYPE_CHECKING:
from .pipelines import Pipeline, pipeline
from .utils.hub import read_config, create_model_if_not_exist
from .utils.logger import get_logger
from .utils.constant import Tasks
from .utils.hf_util import AutoConfig, GenerationConfig
from .utils.hf_util import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from .utils.hf_util import (AutoModel, AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification)
from .utils.hf_util import AutoTokenizer
from .msdatasets import MsDataset
@@ -68,9 +72,12 @@ else:
'pipelines': ['Pipeline', 'pipeline'],
'utils.hub': ['read_config', 'create_model_if_not_exist'],
'utils.logger': ['get_logger'],
'utils.constant': ['Tasks'],
'utils.hf_util': [
'AutoConfig', 'GenerationConfig', 'AutoModel',
'AutoModelForCausalLM', 'AutoModelForSeq2SeqLM', 'AutoTokenizer'
'AutoModelForCausalLM', 'AutoModelForSeq2SeqLM', 'AutoTokenizer',
'AutoModelForSequenceClassification',
'AutoModelForTokenClassification'
],
'msdatasets': ['MsDataset']
}
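
The two newly exported Auto classes can be used from the package root like their transformers counterparts. A minimal sketch, assuming a current modelscope install; the model id below is a placeholder, not something referenced by this commit:

```python
# Minimal sketch of the new root-level exports; the model id is a placeholder.
from modelscope import AutoModelForSequenceClassification, AutoTokenizer

model_id = 'your-org/your-sequence-classification-model'  # hypothetical id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id)
inputs = tokenizer('an example sentence', return_tensors='pt')
logits = model(**inputs).logits
```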


@@ -4,10 +4,11 @@ import os.path as osp
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from modelscope.hub.check_model import check_local_model_is_latest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import Tasks
from modelscope.models.builder import build_backbone, build_model
from modelscope.utils.automodel_utils import (can_load_by_ms,
try_to_load_hf_model)
from modelscope.utils.config import Config
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
from modelscope.utils.device import verify_device
@@ -84,12 +85,22 @@ class Model(ABC):
device(str, `optional`): The device to load the model.
**kwargs:
task(str, `optional`): The `Tasks` enumeration value to replace the task value
read out of config in the `model_name_or_path`. This is useful when the model to be loaded is not
equal to the model saved.
For example, load a `backbone` into a `text-classification` model.
Other kwargs will be directly fed into the `model` key, to replace the default configs.
use_hf(bool): If set True, will use AutoModel in hf to initialize the model to keep compatibility
with huggingface transformers.
read out of config in the `model_name_or_path`. This is useful when the model to be loaded is not
equal to the model saved.
For example, load a `backbone` into a `text-classification` model.
Other kwargs will be directly fed into the `model` key, to replace the default configs.
use_hf(bool, `optional`):
If set to True, the model is initialized with AutoModel or AutoModelFor* from huggingface transformers.
If set to False, the model is loaded with modelscope's native loading logic.
If set to None, the loading mode is selected automatically.
ignore_file_pattern(List[str], `optional`):
This parameter is passed to snapshot_download
device_map(str | Dict[str, str], `optional`):
This parameter is passed to AutoModel or AutoModelFor*
torch_dtype(torch.dtype, `optional`):
This parameter is passed to AutoModel or AutoModelFor*
config(PretrainedConfig, `optional`):
This parameter is passed to AutoModel or AutoModelFor*
Returns:
A model instance.
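
A hedged sketch of how the documented kwargs combine; the model id is a placeholder and nothing here is prescribed by this commit:

```python
# Hypothetical call showing the documented kwargs: device_map/torch_dtype are
# forwarded to AutoModelFor*, ignore_file_pattern to snapshot_download.
import torch
from modelscope.models import Model

model = Model.from_pretrained(
    'your-org/your-causal-lm',        # hypothetical model id
    use_hf=True,                      # force the transformers loading path
    device_map='auto',
    torch_dtype=torch.float16,
    ignore_file_pattern=[r'\.bin$'],  # skip matching files when downloading
)
```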
@@ -115,14 +126,14 @@ class Model(ABC):
)
invoked_by = '%s/%s' % (Invoke.KEY, invoked_by)
ignore_file_pattern = kwargs.get('ignore_file_pattern', None)
local_model_dir = snapshot_download(
model_name_or_path, revision, user_agent=invoked_by)
model_name_or_path,
revision,
user_agent=invoked_by,
ignore_file_pattern=ignore_file_pattern)
logger.info(f'initialize model from {local_model_dir}')
if kwargs.pop('use_hf', False):
from modelscope import AutoModel
return AutoModel.from_pretrained(local_model_dir)
if cfg_dict is not None:
cfg = cfg_dict
else:
@@ -134,6 +145,23 @@ class Model(ABC):
model_cfg = cfg.model
if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
model_cfg.type = model_cfg.model_type
model_type = model_cfg.type
if isinstance(device, str) and device.startswith('gpu'):
device = 'cuda' + device[3:]
use_hf = kwargs.pop('use_hf', None)
if use_hf is None and can_load_by_ms(local_model_dir, task_name,
model_type):
use_hf = False
model = None
if use_hf in {True, None}:
model = try_to_load_hf_model(local_model_dir, task_name, use_hf,
**kwargs)
if model is not None:
device_map = kwargs.get('device_map', None)
if device_map is None and device is not None:
model = model.to(device)
return model
# use ms
model_cfg.model_dir = local_model_dir
# install and import remote repos before build


@@ -1,23 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from transformers.models.llama import (LlamaConfig, LlamaTokenizer,
LlamaTokenizerFast)
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .configuration import LlamaConfig
from .text_generation import LlamaForTextGeneration
from .backbone import LlamaModel
from .tokenization import LlamaTokenizer
from .tokenization_fast import LlamaTokenizerFast
from .text_generation import LlamaForTextGeneration
else:
_import_structure = {
'configuration': ['LlamaConfig'],
'text_generation': ['LlamaForTextGeneration'],
'backbone': ['LlamaModel'],
'tokenization': ['LlamaTokenizer'],
'tokenization_fast': ['LlamaTokenizerFast'],
'text_generation': ['LlamaForTextGeneration'],
}
_extra_objects = {
'LlamaConfig': LlamaConfig,
'LlamaTokenizer': LlamaTokenizer,
'LlamaTokenizerFast': LlamaTokenizerFast,
}
import sys
sys.modules[__name__] = LazyImportModule(
@@ -25,5 +26,5 @@ else:
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
extra_objects=_extra_objects,
)
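
Since the config and tokenizers are now registered via `_extra_objects`, importing them through the modelscope path should hand back the transformers classes themselves. A small sketch under that assumption:

```python
# Sketch: with _extra_objects above, these names are expected to resolve to the
# transformers classes rather than local copies (assumption about
# LazyImportModule behavior, not verified here).
from transformers.models.llama import LlamaConfig as HFLlamaConfig

from modelscope.models.nlp.llama import LlamaConfig

print(LlamaConfig is HFLlamaConfig)  # expected: True
```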


@@ -18,389 +18,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.llama import LlamaConfig
from transformers.models.llama import LlamaModel as LlamaModelHF
from transformers.models.llama import \
LlamaPreTrainedModel as LlamaPreTrainedModelHF
from modelscope.metainfo import Models
from modelscope.models import Model, TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import AttentionBackboneModelOutput
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from .configuration import LlamaConfig
logger = get_logger()
_CONFIG_FOR_DOC = 'LlamaConfig'
# This file is mainly copied from the llama code of transformers
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len),
torch.tensor(torch.finfo(dtype).min, device=device),
device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len,
past_key_values_length, # noqa
dtype=dtype,
device=device),
mask
],
dim=-1) # noqa
return mask[None, None, :, :].expand(bsz, 1, tgt_len,
tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor,
dtype: torch.dtype,
tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool),
torch.finfo(dtype).min)
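
For intuition, a tiny run of the causal-mask helper above (assuming `_make_causal_mask` from this file is in scope):

```python
# Toy example: batch of 1, tgt_len=3, no cached keys/values. Each position may
# attend to itself and earlier positions; future positions are filled with the
# dtype minimum so they vanish after softmax.
import torch

mask = _make_causal_mask(torch.Size([1, 3]), torch.float32, torch.device('cpu'))
print(mask.shape)  # torch.Size([1, 1, 3, 3])
print(mask[0, 0])  # zeros on and below the diagonal, finfo(float32).min above
```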
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
variance = hidden_states.to(torch.float32).pow(2).mean(
-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance
+ self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
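
A quick numerical check of the normalization above: with the default unit weight, the output is `x / sqrt(mean(x**2) + eps)` along the last dimension.

```python
# Sanity check for LlamaRMSNorm with its default unit weight.
import torch

norm = LlamaRMSNorm(hidden_size=4)
x = torch.randn(2, 3, 4)
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + norm.variance_epsilon)
print(torch.allclose(norm(x), expected))  # True
```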
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None):
super().__init__()
inv_freq = 1.0 / (
base**(torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer('inv_freq', inv_freq)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(
self.max_seq_len_cached,
device=self.inv_freq.device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer(
'cos_cached', emb.cos()[None, None, :, :], persistent=False)
self.register_buffer(
'sin_cached', emb.sin()[None, None, :, :], persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(
self.max_seq_len_cached,
device=x.device,
dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer(
'cos_cached', emb.cos()[None, None, :, :], persistent=False)
self.register_buffer(
'sin_cached', emb.sin()[None, None, :, :], persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
cos = torch.gather(
cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
sin = torch.gather(
sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
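
A short sketch of the rotary-embedding helpers above applied to random query/key tensors (shapes only, no pretrained weights involved):

```python
# bsz=1, heads=2, seq_len=4, head_dim=8; shapes are preserved by the rotation.
import torch

rotary = LlamaRotaryEmbedding(dim=8, max_position_embeddings=16)
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
cos, sin = rotary(q, seq_len=4)
position_ids = torch.arange(4).unsqueeze(0)  # [1, seq_len]
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 4, 8]) twice
```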
class LlamaMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.max_position_embeddings = config.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
f' and `num_heads`: {self.num_heads}).')
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(
bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(
bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(
bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
attn_weights = torch.matmul(query_states, key_states.transpose(
2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is'
f' {attn_weights.size()}')
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights,
torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
f' {attn_output.size()}')
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
)
self.input_layernorm = LlamaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, )
if output_attentions:
outputs += (self_attn_weights, )
if use_cache:
outputs += (present_key_value, )
return outputs
class LlamaPreTrainedModel(TorchModel, PreTrainedModel):
config_class = LlamaConfig
base_model_prefix = 'model'
supports_gradient_checkpointing = True
_no_split_modules = ['LlamaDecoderLayer']
_keys_to_ignore_on_load_unexpected = [r'decoder\.version']
def __init__(self, config, **kwargs):
super().__init__(config.name_or_path, **kwargs)
super(Model, self).__init__(config)
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LlamaModel):
module.gradient_checkpointing = value
class MsModelMixin:
@classmethod
def _instantiate(cls, **kwargs):
@@ -416,272 +48,22 @@ class LlamaPreTrainedModel(TorchModel, PreTrainedModel):
Returns:
The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
"""
model_dir = kwargs.pop('model_dir', None)
if model_dir is None:
config = LlamaConfig(**kwargs)
model = cls(config)
else:
model = super(Model, cls).from_pretrained(
model = super(MsModelMixin, cls).from_pretrained(
pretrained_model_name_or_path=model_dir, **kwargs)
model.model_dir = model_dir
return model
class LlamaPreTrainedModel(MsModelMixin, LlamaPreTrainedModelHF, TorchModel):
pass
@MODELS.register_module(Tasks.backbone, module_name=Models.llama2)
@MODELS.register_module(Tasks.backbone, module_name=Models.llama)
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig, **kwargs):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
self.padding_idx)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype,
tgt_len=input_shape[-1]).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else
expanded_attn_mask + combined_attention_mask)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, AttentionBackboneModelOutput]:
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
Padding will be ignored by default should you provide it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range `[0, config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed
or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`,
with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape `(batch_size, 1)`
instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
'You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time'
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
'You have to specify either decoder_input_ids or decoder_inputs_embeds'
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds,
past_key_values_length)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
'`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states, )
past_key_value = past_key_values[
idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (
layer_outputs[2 if output_attentions else 1], )
if output_attentions:
all_self_attns += (layer_outputs[1], )
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states, )
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v for v in
[hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None)
return AttentionBackboneModelOutput(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class LlamaModel(MsModelMixin, LlamaModelHF, TorchModel):
pass
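
With the file reduced to the thin wrappers above, loading goes through `MsModelMixin._instantiate` and then straight into the transformers implementation. A hedged sketch; the model id is a placeholder:

```python
# Hypothetical use of the wrapper; the id must point to a llama backbone on the
# modelscope hub.
from modelscope.models import Model

backbone = Model.from_pretrained('your-org/llama-7b-backbone')  # hypothetical id
print(type(backbone).__name__)  # expected: LlamaModel (the wrapper above)
```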


@@ -1,101 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LLaMA model configuration"""
from transformers.configuration_utils import PretrainedConfig
LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
# This file is mainly copied from the llama code of transformers
class LlamaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LLaMA-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`LlamaModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-6):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
"""
model_type = 'llama'
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
hidden_act='silu',
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
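
The class above can be instantiated directly for toy-sized models; everything not passed keeps its default:

```python
# Toy configuration built from the (now removed) local LlamaConfig.
config = LlamaConfig(
    hidden_size=256,
    intermediate_size=688,
    num_hidden_layers=2,
    num_attention_heads=4,
)
print(config.vocab_size, config.rms_norm_eps)  # 32000 1e-06
```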


@@ -20,8 +20,8 @@ import shutil
import json
import torch
from transformers.models.llama import LlamaConfig
from .configuration import LlamaConfig
from .text_generation import LlamaForTextGeneration
# This file is mainly copied from the llama code of transformers


@@ -19,165 +19,97 @@
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
import torch
from transformers.models.llama import LlamaForCausalLM
from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import AttentionTextGenerationModelOutput
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import Tasks
from modelscope.utils.streaming_output import \
PretrainedModelStreamingOutputMixin
from .backbone import LlamaModel, LlamaPreTrainedModel
from .backbone import MsModelMixin
def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]],
max_length: int, tokenizer):
system_prompt = f'<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n'
system_ids = tokenizer(system_prompt, return_tensors='pt').input_ids
text_prompt = f'{text.strip()} [/INST]'
text_ids = tokenizer(text_prompt, return_tensors='pt').input_ids
prompt_length = system_ids.shape[-1] + text_ids.shape[-1]
if prompt_length > max_length:
raise RuntimeError(
f'prompt length {prompt_length} (system + text) exceeds max_length {max_length}'
)
history_prompt = ''
history_ids_list = []
# traverse history in reverse order
for user, bot in history[::-1]:
assert isinstance(user, str)
assert isinstance(bot, str)
round_prompt = f'{user.strip()} [/INST] {bot.strip()} </s><s>[INST] '
round_ids = tokenizer(round_prompt, return_tensors='pt').input_ids
if prompt_length + round_ids.shape[-1] > max_length:
# excess history should not be appended to the prompt
break
else:
history_prompt = round_prompt + history_prompt
history_ids_list = [round_ids] + history_ids_list
prompt_length += round_ids.shape[-1]
prompt_list = [system_prompt, history_prompt, text_prompt]
prompt_ids_list = [system_ids] + history_ids_list + [text_ids]
return ''.join(prompt_list), torch.cat(prompt_ids_list, dim=1)
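
A sketch of how `get_chat_prompt` assembles a llama-2 style prompt. Any llama tokenizer should work; the model id below is a placeholder:

```python
# Hypothetical usage of get_chat_prompt; the tokenizer id is a placeholder.
from modelscope import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('your-org/llama2-7b-chat-ms')
prompt, prompt_ids = get_chat_prompt(
    system='You are a helpful assistant.',
    text='What is the capital of France?',
    history=[('Hello!', 'Hi, how can I help you?')],
    max_length=4096,
    tokenizer=tokenizer)
print(prompt)            # '<s>[INST] <<SYS>>...' style prompt
print(prompt_ids.shape)  # [1, prompt_length]
```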
# This file is mainly copied from the llama code of transformers
@MODELS.register_module(Tasks.text_generation, module_name=Models.llama2)
@MODELS.register_module(Tasks.text_generation, module_name=Models.llama)
class LlamaForTextGeneration(LlamaPreTrainedModel,
PretrainedModelStreamingOutputMixin):
_keys_to_ignore_on_load_missing = [r'lm_head.weight']
class LlamaForTextGeneration(MsModelMixin, LlamaForCausalLM, TorchModel):
def __init__(self, config, **kwargs):
super().__init__(config)
self.model = LlamaModel(config)
self.lm_head = nn.Linear(
config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, AttentionTextGenerationModelOutput]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits, ) + outputs[1:]
return (loss, ) + output if loss is not None else output
# There is a conflict between `ModelOutputBase` in modelscope
# and the `send_to_device` function in the accelerate library,
# so a plain dict is returned instead of AttentionTextGenerationModelOutput for now.
return dict(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs):
if past_key_values:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get('position_ids', None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
def chat(self, input: Dict, tokenizer) -> Dict:
import copy
gen_kwargs = copy.copy(input)
if 'text' not in input:
text: str = ''
else:
model_inputs = {'input_ids': input_ids}
text: str = input['text']
gen_kwargs.pop('text')
model_inputs.update({
'position_ids': position_ids,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
})
return model_inputs
if 'system' not in input:
system: str = ''
else:
system: str = input['system']
gen_kwargs.pop('system')
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(
past_state.index_select(0, beam_idx)
for past_state in layer_past), )
return reordered_past
if 'history' not in input:
history = []
else:
history: List[Tuple] = copy.copy(input['history'])
gen_kwargs.pop('history')
def generate(self, inputs: Dict[str, Tensor],
**kwargs) -> Dict[str, Tensor]:
return super().generate(**inputs, **kwargs)
if 'max_length' not in gen_kwargs:
gen_kwargs['max_length'] = 4096
prompt, prompt_ids = get_chat_prompt(
system=system,
text=text,
history=history,
max_length=gen_kwargs['max_length'],
tokenizer=tokenizer)
input_ids = prompt_ids.to(self.device)
generate_ids = self.generate(input_ids, **gen_kwargs)
# remove input tokens
generate_ids = generate_ids[:, input_ids.shape[1]:]
response = tokenizer.batch_decode(
generate_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)[0]
response = response.strip()
history.append((text, response))
return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history}
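
A hedged end-to-end sketch of the new `chat` interface; the model id is a placeholder and a GPU large enough for llama2 is assumed:

```python
# Hypothetical multi-turn chat with the method defined above.
from modelscope import AutoTokenizer
from modelscope.models import Model
from modelscope.outputs import OutputKeys

model_id = 'your-org/llama2-7b-chat-ms'  # hypothetical id
model = Model.from_pretrained(model_id, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)

result = model.chat({'text': 'Tell me a joke.', 'max_length': 2048}, tokenizer)
print(result[OutputKeys.RESPONSE])
# feed the returned history back in for the next turn
result = model.chat(
    {'text': 'Another one, please.', 'history': result[OutputKeys.HISTORY]},
    tokenizer)
```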


@@ -1,272 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LLaMA."""
import os
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from modelscope.utils.logger import get_logger
# This file is mainly copied from the llama code of transformers
logger = get_logger()
VOCAB_FILES_NAMES = {'vocab_file': 'tokenizer.model'}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': {
'hf-internal-testing/llama-tokenizer':
'https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model',
},
'tokenizer_file': {
'hf-internal-testing/llama-tokenizer':
'https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json',
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'hf-internal-testing/llama-tokenizer': 2048,
}
class LlamaTokenizer(PreTrainedTokenizer):
"""
Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
Args:
vocab_file (`str`):
Path to the vocabulary file.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ['input_ids', 'attention_mask']
def __init__(
self,
vocab_file,
unk_token='<unk>',
bos_token='<s>',
eos_token='</s>',
pad_token=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
add_bos_token=True,
add_eos_token=False,
clean_up_tokenization_spaces=False,
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
bos_token = AddedToken(
bos_token, lstrip=False, rstrip=False) if isinstance(
bos_token, str) else bos_token
eos_token = AddedToken(
eos_token, lstrip=False, rstrip=False) if isinstance(
eos_token, str) else eos_token
unk_token = AddedToken(
unk_token, lstrip=False, rstrip=False) if isinstance(
unk_token, str) else unk_token
pad_token = AddedToken(
pad_token, lstrip=False, rstrip=False) if isinstance(
pad_token, str) else pad_token
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
sp_model_kwargs=self.sp_model_kwargs,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)
self.vocab_file = vocab_file
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
def __getstate__(self):
state = self.__dict__.copy()
state['sp_model'] = None
return state
def __setstate__(self, d):
self.__dict__ = d
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(self.vocab_file)
@property
def vocab_size(self):
"""Returns vocab size"""
return self.sp_model.get_piece_size()
def get_vocab(self):
"""Returns vocab as a dict"""
vocab = {
self.convert_ids_to_tokens(i): i
for i in range(self.vocab_size)
}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text):
"""Returns a tokenized string."""
return self.sp_model.encode(text, out_type=str)
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.sp_model.piece_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
token = self.sp_model.IdToPiece(index)
return token
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = []
out_string = ''
prev_is_special = False
for i, token in enumerate(tokens):
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
if not prev_is_special and i != 0:
out_string += ' '
out_string += self.sp_model.decode(current_sub_tokens) + token
prev_is_special = True
current_sub_tokens = []
else:
current_sub_tokens.append(token)
prev_is_special = False
out_string += self.sp_model.decode(current_sub_tokens)
return out_string
def save_vocabulary(self,
save_directory,
filename_prefix: Optional[str] = None) -> Tuple[str]:
"""
Save the vocabulary and special tokens file to a directory.
Args:
save_directory (`str`):
The directory in which to save the vocabulary.
Returns:
`Tuple(str)`: Paths to the files saved.
"""
if not os.path.isdir(save_directory):
logger.error(
f'Vocabulary path ({save_directory}) should be a directory')
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + '-' if filename_prefix else '')
+ VOCAB_FILES_NAMES['vocab_file'])
if os.path.abspath(self.vocab_file) != os.path.abspath(
out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, 'wb') as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file, )
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = bos_token_id + token_ids_0 + eos_token_id
if token_ids_1 is not None:
output = output + bos_token_id + token_ids_1 + eos_token_id
return output
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=True)
bos_token_id = [1] if self.add_bos_token else []
eos_token_id = [1] if self.add_eos_token else []
if token_ids_1 is None:
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
return (bos_token_id + # noqa
([0] * len(token_ids_0)) + eos_token_id + bos_token_id # noqa
+ # noqa
([0] * len(token_ids_1)) + eos_token_id) # noqa
def create_token_type_ids_from_sequences(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
sequence pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
if token_ids_1 is None, only returns the first portion of the mask (0s).
Args:
token_ids_0 (`List[int]`):
List of ids.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
sep = [self.sep_token_id]
cls = [self.cls_token_id]
if token_ids_1 is None:
return len(cls + token_ids_0 + sep) * [0]
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1
+ sep) * [1]
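
A quick check of the special-token handling in the (removed) tokenizer above, with its defaults `add_bos_token=True, add_eos_token=False`; the vocabulary path is a placeholder:

```python
# Only a BOS id is prepended with the default settings.
tok = LlamaTokenizer('path/to/tokenizer.model')  # placeholder path
ids = tok.build_inputs_with_special_tokens([100, 200, 300])
print(ids == [tok.bos_token_id, 100, 200, 300])  # expected: True
```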


@@ -1,127 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from shutil import copyfile
from typing import Optional, Tuple
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import is_sentencepiece_available
from transformers.utils.versions import require_version
from modelscope.utils.logger import get_logger
# This file is mainly copied from the llama code of transformers
require_version('tokenizers>=0.13.3')
if is_sentencepiece_available():
from .tokenization import LlamaTokenizer
else:
LlamaTokenizer = None
logger = get_logger()
VOCAB_FILES_NAMES = {
'vocab_file': 'tokenizer.model',
'tokenizer_file': 'tokenizer.json'
}
class LlamaTokenizerFast(PreTrainedTokenizerFast):
"""
Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
This uses notably ByteFallback and no normalization.
```
from transformers import LlamaTokenizerFast
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.encode("Hello this is a test")
>>> [1, 15043, 445, 338, 263, 1243]
```
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
contains the vocabulary necessary to instantiate a tokenizer.
tokenizer_file (`str`):
[tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
contains everything needed to load the tokenizer.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether to clean up spaces after decoding; cleanup consists of removing potential artifacts like extra
spaces.
bos_token (`str`, *optional*, defaults to `"<s>"`):
The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
eos_token (`str`, *optional*, defaults to `"</s>"`):
The end of sequence token.
unk_token (`str`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
"""
vocab_files_names = VOCAB_FILES_NAMES
slow_tokenizer_class = LlamaTokenizer
padding_side = 'left'
def __init__(
self,
vocab_file=None,
tokenizer_file=None,
clean_up_tokenization_spaces=False,
unk_token='<unk>',
bos_token='<s>',
eos_token='</s>',
**kwargs,
):
super().__init__(
vocab_file=vocab_file,
tokenizer_file=tokenizer_file,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
**kwargs,
)
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = bool(self.vocab_file)
def save_vocabulary(self,
save_directory: str,
filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
'tokenizer.')
if not os.path.isdir(save_directory):
logger.error(
f'Vocabulary path ({save_directory}) should be a directory')
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + '-' if filename_prefix else '')
+ VOCAB_FILES_NAMES['vocab_file'])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file, )

View File

@@ -1,23 +1,25 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.models.nlp.llama import LlamaConfig as Llama2Config
from modelscope.models.nlp.llama import LlamaTokenizer as Llama2Tokenizer
from modelscope.models.nlp.llama import \
LlamaTokenizerFast as Llama2TokenizerFast
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .configuration import Llama2Config
from .text_generation import Llama2ForTextGeneration
from .backbone import Llama2Model
from .tokenization import Llama2Tokenizer
from .tokenization_fast import Llama2TokenizerFast
from .text_generation import Llama2ForTextGeneration
else:
_import_structure = {
'configuration': ['Llama2Config'],
'text_generation': ['Llama2ForTextGeneration'],
'backbone': ['Llama2Model'],
'tokenization': ['Llama2Tokenizer'],
'tokenization_fast': ['Llama2TokenizerFast'],
'text_generation': ['Llama2ForTextGeneration'],
}
_extra_objects = {
'Llama2Config': Llama2Config,
'Llama2Tokenizer': Llama2Tokenizer,
'Llama2TokenizerFast': Llama2TokenizerFast,
}
import sys
sys.modules[__name__] = LazyImportModule(
@@ -25,5 +27,5 @@ else:
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
extra_objects=_extra_objects,
)
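
As a rough sketch of how these lazily registered symbols are consumed downstream; the package path `modelscope.models.nlp.llama2` is inferred from the surrounding file layout and should be treated as an assumption.

```python
# Hypothetical usage: the from-import below is what triggers LazyImportModule
# to resolve the underlying submodule the first time the symbol is needed.
from modelscope.models.nlp.llama2 import Llama2Config

# A deliberately tiny configuration, for illustration only (not a real checkpoint).
config = Llama2Config(vocab_size=1000, hidden_size=64, intermediate_size=172,
                      num_hidden_layers=2, num_attention_heads=4)
print(config.hidden_size, config.num_hidden_layers)  # 64 2
```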

795
modelscope/models/nlp/llama2/backbone.py Executable file → Normal file
View File

@@ -1,795 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from modelscope import Model, TorchModel
from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.models.nlp.llama import LlamaModel as LlamaModel2
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from ... import MODELS
from .configuration import Llama2Config
logger = get_logger()
_CONFIG_FOR_DOC = 'Llama2Config'
# This file is mainly copied from the llama code of transformers
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len),
torch.finfo(dtype).min,
device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
_tmp_value = torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device)
mask = torch.cat([_tmp_value, mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len,
tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor,
dtype: torch.dtype,
tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len,
src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool),
torch.finfo(dtype).min)
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance
+ self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base**(torch.arange(
0, self.dim, 2).float().to(device) / self.dim)) # noqa
self.register_buffer('inv_freq', inv_freq)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.get_default_dtype())
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer(
'cos_cached',
emb.cos()[None, None, :, :].to(dtype),
persistent=False)
self.register_buffer(
'sin_cached',
emb.sin()[None, None, :, :].to(dtype),
persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(
seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer(
'cos_cached',
emb.cos()[None, None, :, :].to(dtype),
persistent=False)
self.register_buffer(
'sin_cached',
emb.sin()[None, None, :, :].to(dtype),
persistent=False)
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings)
- (self.scaling_factor - 1))**(
self.dim / (self.dim - 2))
inv_freq = 1.0 / (base**(torch.arange(
0, self.dim, 2).float().to(device) / self.dim)) # noqa
self.register_buffer('inv_freq', inv_freq)
t = torch.arange(
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer(
'cos_cached',
emb.cos()[None, None, :, :].to(dtype),
persistent=False)
self.register_buffer(
'sin_cached',
emb.sin()[None, None, :, :].to(dtype),
persistent=False)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.pretraining_tp = config.pretraining_tp
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(
self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(
self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
if self.pretraining_tp > 1:
slice = self.intermediate_size // self.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
gate_proj = torch.cat([
F.linear(x, gate_proj_slices[i])
for i in range(self.pretraining_tp)
],
dim=-1) # noqa
up_proj = torch.cat([
F.linear(x, up_proj_slices[i])
for i in range(self.pretraining_tp)
],
dim=-1) # noqa
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(
slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i])
for i in range(self.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(
self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :,
None, :, :].expand(batch,
num_key_value_heads,
n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
head_dim)
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Llama2Config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.pretraining_tp = config.pretraining_tp
self.max_position_embeddings = config.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
f' and `num_heads`: {self.num_heads}).')
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=False)
self.v_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=False)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False)
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings)
else:
scaling_type = self.config.rope_scaling['type']
scaling_factor = self.config.rope_scaling['factor']
if scaling_type == 'linear':
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor)
elif scaling_type == 'dynamic':
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor)
else:
raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads
* self.head_dim) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i])
for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i])
for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i])
for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(
2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
f' {attn_weights.size()}')
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
f' {attn_output.size()}')
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.pretraining_tp > 1:
attn_output = attn_output.split(
self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1)
attn_output = sum([
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: Llama2Config):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor,
torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states, )
if output_attentions:
outputs += (self_attn_weights, )
if use_cache:
outputs += (present_key_value, )
return outputs
class LlamaPreTrainedModel(TorchModel, PreTrainedModel):
config_class = Llama2Config
base_model_prefix = 'model'
supports_gradient_checkpointing = True
_no_split_modules = ['LlamaDecoderLayer']
_skip_keys_device_placement = 'past_key_values'
def __init__(self, config, **kwargs):
super().__init__(config.name_or_path, **kwargs)
super(Model, self).__init__(config)
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, Llama2Model):
module.gradient_checkpointing = value
@classmethod
def _instantiate(cls, **kwargs):
"""Instantiate the model.
Args:
kwargs: Input args.
model_dir: The model dir used to load the checkpoint and the label information.
num_labels: An optional arg to tell the model how many classes to initialize.
Method will call utils.parse_label_mapping if num_labels not supplied.
If num_labels is not found, the model will use the default setting (2 classes).
Returns:
The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
"""
model_dir = kwargs.pop('model_dir', None)
if model_dir is None:
config = Llama2Config(**kwargs)
model = cls(config)
else:
model = super(Model, cls).from_pretrained(
pretrained_model_name_or_path=model_dir, **kwargs)
model.model_dir = model_dir
return model
@MODELS.register_module(Tasks.backbone, module_name=Models.llama2)
class Llama2Model(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: Llama2Config
"""
def __init__(self, config: Llama2Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size,
self.padding_idx)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype,
tgt_len=input_shape[-1]).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else
expanded_attn_mask + combined_attention_mask)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
'You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time'
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError(
'You have to specify either decoder_input_ids or decoder_inputs_embeds'
)
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds,
past_key_values_length)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
'`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states, )
past_key_value = past_key_values[
idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (
layer_outputs[2 if output_attentions else 1], )
if output_attentions:
all_self_attns += (layer_outputs[1], )
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states, )
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v for v in
[hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
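
The rotary-embedding path above (`LlamaRotaryEmbedding` → `apply_rotary_pos_emb`) can be exercised in isolation. Below is a small self-contained sketch of the cos/sin cache and the rotation; the toy dimensions are illustrative assumptions and nothing here touches the model classes.

```python
import torch

def rotate_half(x):
    # Same helper as above: swap and negate the two halves of the last dim.
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

# Build the same cos/sin cache as LlamaRotaryEmbedding, for a toy head_dim.
head_dim, seq_len, base = 8, 4, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(seq_len, dtype=torch.float32)
freqs = torch.einsum('i,j->ij', t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)              # [seq_len, head_dim]
cos = emb.cos()[None, None]                          # [1, 1, seq_len, head_dim]
sin = emb.sin()[None, None]

# Apply the rotation to a random query tensor, as apply_rotary_pos_emb() does.
q = torch.randn(1, 2, seq_len, head_dim)             # [bs, heads, seq, head_dim]
position_ids = torch.arange(seq_len)[None]           # [bs, seq]
cos_p = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
sin_p = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
q_rot = (q * cos_p) + (rotate_half(q) * sin_p)
print(q_rot.shape)                                   # torch.Size([1, 2, 4, 8])
```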

View File

@@ -1,165 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" LLaMA model configuration"""
from transformers.configuration_utils import PretrainedConfig
from modelscope.utils.logger import get_logger
logger = get_logger()
LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class Llama2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LLaMA-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`LlamaModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
pretraining_tp (`int`, *optional*, defaults to `1`):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-6):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
"""
model_type = 'llama'
keys_to_ignore_at_inference = ['past_key_values']
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act='silu',
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_scaling=None,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling,
dict) or len(self.rope_scaling) != 2:
raise ValueError(
'`rope_scaling` must be a dictionary with two fields, `type` and `factor`, '
f'got {self.rope_scaling}')
rope_scaling_type = self.rope_scaling.get('type', None)
rope_scaling_factor = self.rope_scaling.get('factor', None)
if rope_scaling_type is None or rope_scaling_type not in [
'linear', 'dynamic'
]:
raise ValueError(
f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(
rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(
f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}"
)
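
A brief sketch of how the `rope_scaling` validation above behaves when constructing the config. The import path is an assumption (after this commit `Llama2Config` is aliased from `modelscope.models.nlp.llama`, as the `__init__` diff above shows), and the values are illustrative only.

```python
from modelscope.models.nlp.llama2 import Llama2Config  # import path assumed

# Accepted: a dict with exactly the `type` and `factor` fields, factor > 1.0.
cfg = Llama2Config(rope_scaling={'type': 'linear', 'factor': 2.0})

# Rejected: an unknown scaling type makes _rope_scaling_validation raise.
try:
    Llama2Config(rope_scaling={'type': 'ntk', 'factor': 2.0})
except ValueError as err:
    print(err)  # explains the accepted scaling types
```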

View File

@@ -1,268 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
from modelscope.metainfo import Models
from modelscope.outputs import OutputKeys
from modelscope.models.builder import MODELS
from modelscope.models.nlp.llama import \
LlamaForTextGeneration as Llama2ForTextGeneration
from modelscope.utils.constant import Tasks
from ... import MODELS
from .backbone import Llama2Model, LlamaPreTrainedModel
def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]],
max_length: int, tokenizer):
system_prompt = f'<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n'
system_ids = tokenizer(system_prompt, return_tensors='pt').input_ids
text_prompt = f'{text.strip()} [/INST]'
text_ids = tokenizer(text_prompt, return_tensors='pt').input_ids
prompt_length = system_ids.shape[-1] + text_ids.shape[-1]
if prompt_length > max_length:
raise RuntimeError(
f'prepend prompt length {prompt_length} is bigger than max_length {max_length}'
)
history_prompt = ''
history_ids_list = []
# traverse history in reverse order
for user, bot in history[::-1]:
assert isinstance(user, str)
assert isinstance(bot, str)
round_prompt = f'{user.strip()} [/INST] {bot.strip()} </s><s>[INST] '
round_ids = tokenizer(round_prompt, return_tensors='pt').input_ids
if prompt_length + round_ids.shape[-1] > max_length:
# excess history should not be appended to the prompt
break
else:
history_prompt = round_prompt + history_prompt
history_ids_list = [round_ids] + history_ids_list
prompt_length += round_ids.shape[-1]
prompt_list = [system_prompt, history_prompt, text_prompt]
prompt_ids_list = [system_ids] + history_ids_list + [text_ids]
return ''.join(prompt_list), torch.cat(prompt_ids_list, dim=1)
# This file is mainly copied from the llama code of transformers
@MODELS.register_module(Tasks.text_generation, module_name=Models.llama2)
class Llama2ForTextGeneration(LlamaPreTrainedModel):
_tied_weights_keys = ['lm_head.weight']
def __init__(self, config):
super().__init__(config)
self.model = Llama2Model(config)
self.pretraining_tp = config.pretraining_tp
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(
config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else
self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(
self.vocab_size // self.pretraining_tp, dim=0)
logits = [
F.linear(hidden_states, lm_head_slices[i])
for i in range(self.pretraining_tp)
]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits, ) + outputs[1:]
return (loss, ) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(self,
input_ids,
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
**kwargs):
if past_key_values:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get('position_ids', None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}
model_inputs.update({
'position_ids': position_ids,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
})
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past), )
return reordered_past
def chat(self, input: Dict, tokenizer) -> Dict:
import copy
gen_kwargs = copy.copy(input)
if 'text' not in input:
text: str = ''
else:
text: str = input['text']
gen_kwargs.pop('text')
if 'system' not in input:
system: str = ''
else:
system: str = input['system']
gen_kwargs.pop('system')
if 'history' not in input:
history = []
else:
history: List[Tuple] = copy.copy(input['history'])
gen_kwargs.pop('history')
if 'max_length' not in gen_kwargs:
gen_kwargs['max_length'] = 4096
prompt, prompt_ids = get_chat_prompt(
system=system,
text=text,
history=history,
max_length=gen_kwargs['max_length'],
tokenizer=tokenizer)
input_ids = prompt_ids.to(self.device)
generate_ids = self.generate(input_ids, **gen_kwargs)
# remove input tokens
generate_ids = generate_ids[:, input_ids.shape[1]:]
response = tokenizer.batch_decode(
generate_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)[0]
response = response.strip()
history.append((text, response))
return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history}
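
For orientation, the `[INST]`/`<<SYS>>` layout that `get_chat_prompt` assembles (chronological once the reverse traversal above finishes) can be reproduced by hand. The snippet below is a hand-written illustration of the returned prompt string when the whole history fits within `max_length`; the system, history, and text strings are placeholders and no tokenizer is involved.

```python
system = 'You are a helpful assistant.'
history = [('Hi', 'Hello! How can I help?')]
text = 'Tell me a joke.'

# System block opens the first [INST] turn.
prompt = f'<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n'
# Each past round closes its turn and opens the next one.
for user, bot in history:  # chronological order
    prompt += f'{user.strip()} [/INST] {bot.strip()} </s><s>[INST] '
# The current user message closes the final [INST] turn, ready for generation.
prompt += f'{text.strip()} [/INST]'
print(prompt)
```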

View File

@@ -1,410 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for LLaMA."""
import os
from shutil import copyfile
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer
from modelscope.utils.logger import get_logger
if TYPE_CHECKING:
from transformers.pipelines.conversational import Conversation
logger = get_logger()
VOCAB_FILES_NAMES = {'vocab_file': 'tokenizer.model'}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file': {
'hf-internal-testing/llama-tokenizer':
'https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model',
},
'tokenizer_file': {
'hf-internal-testing/llama-tokenizer':
'https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json',
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'hf-internal-testing/llama-tokenizer': 2048,
}
SPIECE_UNDERLINE = '▁'
B_INST, E_INST = '[INST]', '[/INST]'
B_SYS, E_SYS = '<<SYS>>\n', '\n<</SYS>>\n\n'
# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your\
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not\
correct. If you don't know the answer to a question, please don't share false information."""
# fmt: on
class Llama2Tokenizer(PreTrainedTokenizer):
"""
Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is
no padding token in the original model.
Args:
vocab_file (`str`):
Path to the vocabulary file.
legacy (`bool`, *optional*, defaults to `True`):
Whether or not the `legacy` behaviour of the tokenizer should be used. Legacy is before the merge of #24622
which includes fixes to properly handle tokens that appear after special tokens. A simple example:
- `legacy=True`:
```python
>>> from transformers import T5Tokenizer
>>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True)
>>> tokenizer.encode("Hello <extra_id_0>.")
[8774, 32099, 3, 5, 1]
```
- `legacy=False`:
```python
>>> from transformers import T5Tokenizer
>>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False)
>>> tokenizer.encode("Hello <extra_id_0>.") # the extra space `[3]` is no longer here
[8774, 32099, 5, 1]
```
Checkout the pull request and the issue [here](https://github.com/huggingface/transformers/pull/24565) for
more details.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ['input_ids', 'attention_mask']
def __init__(
self,
vocab_file,
unk_token='<unk>',
bos_token='<s>',
eos_token='</s>',
pad_token=None,
sp_model_kwargs: Optional[Dict[str, Any]] = None,
add_bos_token=True,
add_eos_token=False,
clean_up_tokenization_spaces=False,
legacy=True,
**kwargs,
):
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
bos_token = AddedToken(
bos_token, lstrip=False, rstrip=False) if isinstance(
bos_token, str) else bos_token
eos_token = AddedToken(
eos_token, lstrip=False, rstrip=False) if isinstance(
eos_token, str) else eos_token
unk_token = AddedToken(
unk_token, lstrip=False, rstrip=False) if isinstance(
unk_token, str) else unk_token
pad_token = AddedToken(
pad_token, lstrip=False, rstrip=False) if isinstance(
pad_token, str) else pad_token
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
add_bos_token=add_bos_token,
add_eos_token=add_eos_token,
sp_model_kwargs=self.sp_model_kwargs,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
legacy=legacy,
**kwargs,
)
if legacy:
logger.warning_once(
f'You are using the legacy behaviour of the {self.__class__}. '
f'This means that tokens that come after special '
f'tokens will not be properly handled. We recommend you to'
' read the related pull request available at https://github.com/huggingface/transformers/pull/24565'
)
self.legacy = legacy
self.vocab_file = vocab_file
self.add_bos_token = add_bos_token
self.add_eos_token = add_eos_token
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.Load(vocab_file)
def __getstate__(self):
state = self.__dict__.copy()
state['sp_model'] = None
state['sp_model_proto'] = self.sp_model.serialized_model_proto()
return state
def __setstate__(self, d):
self.__dict__ = d
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
@property
def vocab_size(self):
"""Returns vocab size"""
return self.sp_model.get_piece_size()
def get_vocab(self):
"""Returns vocab as a dict"""
vocab = {
self.convert_ids_to_tokens(i): i
for i in range(self.vocab_size)
}
vocab.update(self.added_tokens_encoder)
return vocab
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize
def tokenize(self, text, **kwargs) -> List[str]:
# Replace the SPIECE_UNDERLINE with a space to make sure SPIECE_UNDERLINE is only used at
# the beginning of the text
if not self.legacy:
text = SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, ' ')
return super().tokenize(text, **kwargs)
# Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize
def _tokenize(self, text):
"""
Returns a tokenized string.
Since the sentencepiece internal model always adds a SPIECE_UNDERLINE at the beginning of the provided text,
we need to remove it by hand when the current text is a subsequence. This happens whenever the `self.tokenize`
function is called with special tokens: the input is split on the special tokens, and each subsequence is
passed to `_tokenize`. Thus, if a subsequence did not start with a `" "` or SPIECE_UNDERLINE, we have to remove
the extra `SPIECE_UNDERLINE` prepended.
"""
if not self.legacy:
is_first = text.startswith(SPIECE_UNDERLINE)
if is_first:
text = text[1:]
tokens = self.sp_model.encode(text, out_type=str)
if not self.legacy and not is_first and not text.startswith(
' ') and tokens[0].startswith(SPIECE_UNDERLINE):
tokens = ([tokens[0][1:]]
if len(tokens[0]) > 1 else []) + tokens[1:]
return tokens
def _convert_token_to_id(self, token):
"""Converts a token (str) in an id using the vocab."""
return self.sp_model.piece_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
token = self.sp_model.IdToPiece(index)
return token
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (string) in a single string."""
current_sub_tokens = []
out_string = ''
prev_is_special = False
for i, token in enumerate(tokens):
# make sure that special tokens are not decoded using sentencepiece model
if token in self.all_special_tokens:
if not prev_is_special and i != 0:
out_string += ' '
out_string += self.sp_model.decode(current_sub_tokens) + token
prev_is_special = True
current_sub_tokens = []
else:
current_sub_tokens.append(token)
prev_is_special = False
out_string += self.sp_model.decode(current_sub_tokens)
return out_string
def save_vocabulary(self,
save_directory,
filename_prefix: Optional[str] = None) -> Tuple[str]:
"""
Save the vocabulary and special tokens file to a directory.
Args:
save_directory (`str`):
The directory in which to save the vocabulary.
Returns:
`Tuple(str)`: Paths to the files saved.
"""
if not os.path.isdir(save_directory):
logger.error(
f'Vocabulary path ({save_directory}) should be a directory')
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + '-' if filename_prefix else '')
+ VOCAB_FILES_NAMES['vocab_file'])
if os.path.abspath(self.vocab_file) != os.path.abspath(
out_vocab_file) and os.path.isfile(self.vocab_file):
copyfile(self.vocab_file, out_vocab_file)
elif not os.path.isfile(self.vocab_file):
with open(out_vocab_file, 'wb') as fi:
content_spiece_model = self.sp_model.serialized_model_proto()
fi.write(content_spiece_model)
return (out_vocab_file, )
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = bos_token_id + token_ids_0 + eos_token_id
if token_ids_1 is not None:
output = output + bos_token_id + token_ids_1 + eos_token_id
return output
def get_special_tokens_mask(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer `prepare_for_model` method.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
return super().get_special_tokens_mask(
token_ids_0=token_ids_0,
token_ids_1=token_ids_1,
already_has_special_tokens=True)
bos_token_id = [1] if self.add_bos_token else []
eos_token_id = [1] if self.add_eos_token else []
if token_ids_1 is None:
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
return bos_token_id + (
[0] * len(token_ids_0)) + eos_token_id + bos_token_id + (
[0] * len(token_ids_1)) + eos_token_id # noqa
def create_token_type_ids_from_sequences(
self,
token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None) -> List[int]:
"""
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
sequence pair mask has the following format:
```
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
```
if token_ids_1 is None, only returns the first portion of the mask (0s).
Args:
token_ids_0 (`List[int]`):
List of ids.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
"""
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
if token_ids_1 is not None:
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
return output
def _build_conversation_input_ids(
self, conversation: 'Conversation') -> List[int]:
"""Builds the input ids for a conversation.
This is the format used in the provided examples. System prompts should be manually added at the beginning of
the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used.
```
<bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos>
<bos>[INST] Prompt [/INST] Answer <eos>
<bos>[INST] Prompt [/INST]
```
If you want to use your own system prompt, make sure to include both `B_SYS` and `E_SYS`, as in the following:
```python
>>> from transformers import Conversation
>>> Conversation(
... "<<SYS>>\n Only answer with emojis, and charades\n<</SYS>>\n\nHow can I build a house in 10 septs?"
... )
```
Args:
conversation (`Conversation`):
Conversation to build input ids for.
Returns:
`List[int]`:
Input ids for the conversation.
"""
dialogue = list(conversation.iter_texts())
if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
[not is_user for is_user, msg in dialogue[1::2]]): # noqa
raise ValueError(
"The model only supports 'user' and 'assistant' roles, "
'starting with user and alternating (u/a/u/a/u...)')
dialog_tokens: List[int] = []
if len(conversation.past_user_inputs) > 0:
if not conversation.past_user_inputs[0].startswith(
B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
conversation.past_user_inputs[0] = (
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS
+ conversation.past_user_inputs[0])
elif not dialogue[0][1].startswith(
B_SYS) or E_SYS not in dialogue[0][1]:
dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT
+ E_SYS + dialogue[0][1])
dialog_tokens += sum(
[[self.bos_token_id] + self.encode(
f'{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ',
add_special_tokens=False) + [self.eos_token_id]
for prompt, answer in zip(dialogue[::2], dialogue[1::2])],
[],
)
if not (dialogue[-1][0]):
raise ValueError(
f"Last message must be from user, got {dialogue[-1]['role']}")
dialog_tokens += [self.bos_token_id] + self.encode(
f'{B_INST} {(dialogue[-1][1]).strip()} {E_INST}',
add_special_tokens=False)
return dialog_tokens
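For reference, a minimal, self-contained sketch of the prompt layout the conversation builder above encodes for a single user turn; the shortened system prompt is a placeholder for `DEFAULT_SYSTEM_PROMPT`, and the real method additionally prepends `bos_token_id` and appends `eos_token_id` after each completed answer:

```python
B_INST, E_INST = '[INST]', '[/INST]'
B_SYS, E_SYS = '<<SYS>>\n', '\n<</SYS>>\n\n'
system_prompt = 'You are a helpful assistant.'  # shortened placeholder

user_prompt = 'How do I bake bread?'
text = f'{B_INST} {(B_SYS + system_prompt + E_SYS + user_prompt).strip()} {E_INST}'
print(text)
# [INST] <<SYS>>
# You are a helpful assistant.
# <</SYS>>
#
# How do I bake bread? [/INST]
```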

View File

@@ -1,251 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from shutil import copyfile
from typing import TYPE_CHECKING, Optional, Tuple
from tokenizers import processors
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import is_sentencepiece_available
from transformers.utils.versions import require_version
from modelscope.utils import logger as logging
if TYPE_CHECKING:
from transformers.pipelines.conversational import Conversation
require_version('tokenizers>=0.13.3')
if is_sentencepiece_available():
from .tokenization import Llama2Tokenizer
else:
Llama2Tokenizer = None
logger = logging.get_logger()
VOCAB_FILES_NAMES = {
'vocab_file': 'tokenizer.model',
'tokenizer_file': 'tokenizer.json'
}
B_INST, E_INST = '[INST]', '[/INST]'
B_SYS, E_SYS = '<<SYS>>\n', '\n<</SYS>>\n\n'
# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your\
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not\
correct. If you don't know the answer to a question, please don't share false information."""
# fmt: on
class Llama2TokenizerFast(PreTrainedTokenizerFast):
"""
Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
It notably uses ByteFallback and no normalization.
```
from transformers import LlamaTokenizerFast
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
tokenizer.encode("Hello this is a test")
>>> [1, 15043, 445, 338, 263, 1243]
```
If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
values of the first token and final token of an encoded sequence will not be correct). For more details, check out
the [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods.
Args:
vocab_file (`str`):
[SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
contains the vocabulary necessary to instantiate a tokenizer.
tokenizer_file (`str`):
[tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
contains everything needed to load the tokenizer.
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`):
Whether to clean up spaces after decoding; cleanup consists of removing potential artifacts like extra
spaces.
bos_token (`str`, *optional*, defaults to `"<s>"`):
The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
eos_token (`str`, *optional*, defaults to `"</s>"`):
The end of sequence token.
unk_token (`str`, *optional*, defaults to `"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
"""
vocab_files_names = VOCAB_FILES_NAMES
slow_tokenizer_class = Llama2Tokenizer
padding_side = 'left'
model_input_names = ['input_ids', 'attention_mask']
def __init__(
self,
vocab_file=None,
tokenizer_file=None,
clean_up_tokenization_spaces=False,
unk_token='<unk>',
bos_token='<s>',
eos_token='</s>',
add_bos_token=True,
add_eos_token=False,
**kwargs,
):
super().__init__(
vocab_file=vocab_file,
tokenizer_file=tokenizer_file,
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
unk_token=unk_token,
bos_token=bos_token,
eos_token=eos_token,
**kwargs,
)
self._add_bos_token = add_bos_token
self._add_eos_token = add_eos_token
self.update_post_processor()
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = bool(self.vocab_file)
def update_post_processor(self):
"""
Updates the underlying post processor with the current `bos_token` and `eos_token`.
"""
bos = self.bos_token
bos_token_id = self.bos_token_id
eos = self.eos_token
eos_token_id = self.eos_token_id
single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') * self.add_eos_token}"
pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') * self.add_eos_token}"
special_tokens = []
if self.add_bos_token:
special_tokens.append((bos, bos_token_id))
if self.add_eos_token:
special_tokens.append((eos, eos_token_id))
self._tokenizer.post_processor = processors.TemplateProcessing(
single=single, pair=pair, special_tokens=special_tokens)
@property
def add_eos_token(self):
return self._add_eos_token
@property
def add_bos_token(self):
return self._add_bos_token
@add_eos_token.setter
def add_eos_token(self, value):
self._add_eos_token = value
self.update_post_processor()
@add_bos_token.setter
def add_bos_token(self, value):
self._add_bos_token = value
self.update_post_processor()
def save_vocabulary(self,
save_directory: str,
filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
'tokenizer.')
if not os.path.isdir(save_directory):
logger.error(
f'Vocabulary path ({save_directory}) should be a directory')
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + '-' if filename_prefix else '')
+ VOCAB_FILES_NAMES['vocab_file'])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
return (out_vocab_file, )
def _build_conversation_input_ids(self, conversation: 'Conversation'):
"""Builds the input ids for a conversation.
This is the format used in the provided examples. System prompts should be manually added at the beginning of
the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used.
```
<bos>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer <eos>
<bos>[INST] Prompt [/INST] Answer <eos>
<bos>[INST] Prompt [/INST]
```
If you want to use your own system prompt, make sure to wrap it with both `B_SYS` and `E_SYS`, as in the following:
```python
>>> from transformers import Conversation
>>> Conversation(
... "<<SYS>>\n Only answer with emojis, and charades\n<</SYS>>\n\nHow can I build a house in 10 septs?"
... )
```
Args:
conversation (`Conversation`):
Conversation to build input ids for.
Returns:
`List[int]`:
Input ids for the conversation.
"""
dialogue = list(conversation.iter_texts())
if not all([is_user for is_user, msg in dialogue[::2]]) or not all(
[not is_user for is_user, msg in dialogue[1::2]]): # noqa
raise ValueError(
"The model only supports 'user' and 'assistant' roles, "
'starting with user and alternating (u/a/u/a/u...)')
dialog_tokens = []
if len(conversation.past_user_inputs) > 0:
if not conversation.past_user_inputs[0].startswith(
B_SYS) or E_SYS not in conversation.past_user_inputs[0]:
conversation.past_user_inputs[0] = (
B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS
+ conversation.past_user_inputs[0])
elif not dialogue[0][1].startswith(
B_SYS) or E_SYS not in dialogue[0][1]:
dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT
+ E_SYS + dialogue[0][1])
dialog_tokens += sum(
[[self.bos_token_id] + self.encode(
f'{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ',
add_special_tokens=False) + [self.eos_token_id]
for prompt, answer in zip(dialogue[::2], dialogue[1::2])],
[],
)
if not (dialogue[-1][0]):
raise ValueError(
f"Last message must be from user, got {dialogue[-1]['role']}")
dialog_tokens += [self.bos_token_id] + self.encode(
f'{B_INST} {(dialogue[-1][1]).strip()} {E_INST}',
add_special_tokens=False)
return dialog_tokens

View File

@@ -19,6 +19,7 @@ from modelscope.utils.chinese_utils import remove_space_between_chinese_chars
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.hub import Config, read_config
from modelscope.utils.streaming_output import PipelineStreamingOutputMixin
from modelscope.utils.torch_utils import is_on_same_device
__all__ = [
'TextGenerationPipeline',
@@ -242,23 +243,39 @@ class ChatGLM6bV2TextGenerationPipeline(Pipeline):
quantization_bit=None,
use_bf16=False,
**kwargs):
from modelscope.models.nlp import (ChatGLM2Config,
ChatGLM2ForConditionalGeneration,
ChatGLM2Tokenizer)
from modelscope import AutoTokenizer
device: str = kwargs.get('device', 'gpu')
if isinstance(model, str):
revision = kwargs.get('revision', None)
model_dir = snapshot_download(
model) if not os.path.exists(model) else model
model = ChatGLM2ForConditionalGeneration.from_pretrained(model_dir)
if torch.cuda.is_available():
model = model.cuda()
model,
revision=revision) if not os.path.exists(model) else model
default_device_map = None
if device.startswith('gpu') or device.startswith('cuda'):
default_device_map = {'': 0}
device_map = kwargs.get('device_map', default_device_map)
default_torch_dtype = None
if use_bf16:
default_torch_dtype = torch.bfloat16
torch_dtype = kwargs.get('torch_dtype', default_torch_dtype)
model = Model.from_pretrained(
model_dir,
trust_remote_code=True,
device_map=device_map,
torch_dtype=torch_dtype)
else:
if ((device.startswith('gpu') or device.startswith('cuda'))
and is_on_same_device(model)):
model.cuda()
if use_bf16:
model.bfloat16()
if quantization_bit is not None:
model = model.quantize(quantization_bit)
if use_bf16:
model = model.bfloat16()
self.model = model
self.model.eval()
self.tokenizer = ChatGLM2Tokenizer.from_pretrained(
self.model.model_dir)
self.tokenizer = AutoTokenizer.from_pretrained(
self.model.model_dir, trust_remote_code=True)
super().__init__(model=model, **kwargs)
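For context, a minimal usage sketch of the reworked pipeline. The task constant, the input payload, and the kwarg plumbing through the `pipeline()` factory are assumptions; the model id and revision match the tests later in this commit, and the keyword arguments are the ones the constructor above reads (`revision`, `device_map`, `torch_dtype`, `quantization_bit`, `use_bf16`):

```python
import torch
from modelscope.metainfo import Tasks
from modelscope.pipelines import pipeline

pipe = pipeline(
    Tasks.chat,                   # assumed task for ChatGLM6bV2TextGenerationPipeline
    model='ZhipuAI/chatglm2-6b',
    revision='v1.0.7',
    torch_dtype=torch.float16,    # forwarded to Model.from_pretrained above
    device_map={'': 0})           # the default when device is 'gpu'/'cuda'
print(pipe({'text': 'Hello, who are you?', 'history': []}))  # assumed input schema
```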

View File

@@ -0,0 +1,84 @@
import os
from typing import Optional
from modelscope.metainfo import Tasks
from modelscope.utils.ast_utils import INDEX_KEY
from modelscope.utils.import_utils import LazyImportModule
def can_load_by_ms(model_dir: str, task_name: str, model_type: str) -> bool:
if ('MODELS', task_name,
model_type) in LazyImportModule.AST_INDEX[INDEX_KEY]:
return True
ms_wrapper_path = os.path.join(model_dir, 'ms_wrapper.py')
if os.path.exists(ms_wrapper_path):
return True
return False
def _can_load_by_hf_automodel(automodel_class: type, config) -> bool:
automodel_class_name = automodel_class.__name__
if type(config) in automodel_class._model_mapping.keys():
return True
if hasattr(config, 'auto_map') and automodel_class_name in config.auto_map:
return True
return False
def get_hf_automodel_class(model_dir: str, task_name: str) -> Optional[type]:
from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForTokenClassification,
AutoModelForSequenceClassification)
automodel_mapping = {
Tasks.backbone: AutoModel,
Tasks.chat: AutoModelForCausalLM,
Tasks.text_generation: AutoModelForCausalLM,
Tasks.text_classification: AutoModelForSequenceClassification,
Tasks.token_classification: AutoModelForTokenClassification,
}
automodel_class = automodel_mapping.get(task_name, None)
if automodel_class is None:
return None
config_path = os.path.join(model_dir, 'config.json')
if not os.path.exists(config_path):
return None
try:
try:
config = AutoConfig.from_pretrained(
model_dir, trust_remote_code=True)
except (FileNotFoundError, ValueError):
return None
if _can_load_by_hf_automodel(automodel_class, config):
return automodel_class
if (automodel_class is AutoModelForCausalLM
and _can_load_by_hf_automodel(AutoModelForSeq2SeqLM, config)):
return AutoModelForSeq2SeqLM
return None
except Exception:
return None
def try_to_load_hf_model(model_dir: str, task_name: str,
use_hf: Optional[bool], **kwargs):
automodel_class = get_hf_automodel_class(model_dir, task_name)
if use_hf and automodel_class is None:
raise ValueError(f'Model import failed. You used `use_hf={use_hf}`, '
'but the model cannot be loaded as a Hugging Face model')
model = None
if automodel_class is not None:
# use hf
device_map = kwargs.get('device_map', None)
torch_dtype = kwargs.get('torch_dtype', None)
config = kwargs.get('config', None)
model = automodel_class.from_pretrained(
model_dir,
device_map=device_map,
torch_dtype=torch_dtype,
config=config,
trust_remote_code=True)
return model
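A minimal sketch of how these helpers might be called by model-loading code; the import path of this new file and the local snapshot path are placeholders, not shown in the diff:

```python
import torch
from modelscope.metainfo import Tasks
from modelscope.utils.automodel_utils import try_to_load_hf_model  # placeholder module path

model_dir = '/path/to/local/snapshot'   # e.g. the result of snapshot_download()
model = try_to_load_hf_model(
    model_dir,
    task_name=Tasks.text_generation,
    use_hf=None,                         # only use_hf=True turns "not an HF model" into an error
    torch_dtype=torch.float16)
if model is None:
    ...                                  # no matching AutoModel class: fall back to ModelScope loading
```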

View File

@@ -3,13 +3,21 @@
import os
import sys
from transformers import CONFIG_MAPPING
from transformers import AutoConfig as AutoConfigHF
from transformers import AutoModel as AutoModelHF
from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
from transformers import \
AutoModelForSequenceClassification as AutoModelForSequenceClassificationHF
from transformers import \
AutoModelForTokenClassification as AutoModelForTokenClassificationHF
from transformers import AutoTokenizer as AutoTokenizerHF
from transformers import GenerationConfig as GenerationConfigHF
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers import (PretrainedConfig, PreTrainedModel,
PreTrainedTokenizerBase)
from transformers.models.auto.tokenization_auto import (
TOKENIZER_MAPPING_NAMES, get_tokenizer_config)
from modelscope import snapshot_download
from modelscope.utils.constant import Invoke
@@ -70,6 +78,58 @@ patch_tokenizer_base()
patch_model_base()
def check_hf_code(model_dir: str, auto_class: type,
trust_remote_code: bool) -> None:
config_path = os.path.join(model_dir, 'config.json')
if not os.path.exists(config_path):
raise FileNotFoundError(f'{config_path} is not found')
config_dict = PretrainedConfig.get_config_dict(config_path)[0]
auto_class_name = auto_class.__name__
# load from repo
if trust_remote_code:
has_remote_code = False
if auto_class is AutoTokenizerHF:
tokenizer_config_dict = get_tokenizer_config(model_dir)
auto_map = tokenizer_config_dict.get('auto_map', None)
if auto_map is not None:
module_name = auto_map.get(auto_class_name, None)
if module_name is not None:
module_name = module_name[0]
has_remote_code = True
else:
auto_map = config_dict.get('auto_map', None)
if auto_map is not None:
module_name = auto_map.get(auto_class_name, None)
has_remote_code = module_name is not None
if has_remote_code:
module_path = os.path.join(model_dir,
module_name.split('.')[0] + '.py')
if not os.path.exists(module_path):
raise FileNotFoundError(f'{module_path} is not found')
return
# trust_remote_code is False or has_remote_code is False
model_type = config_dict.get('model_type', None)
if model_type is None:
raise ValueError(f'`model_type` key is not found in {config_path}')
if auto_class is AutoConfigHF:
if model_type not in CONFIG_MAPPING:
raise ValueError(f'{model_type} not found in HF CONFIG_MAPPING')
elif auto_class is AutoTokenizerHF:
if model_type not in TOKENIZER_MAPPING_NAMES:
raise ValueError(
f'{model_type} not found in HF TOKENIZER_MAPPING_NAMES')
else:
mapping_names = [
m.model_type for m in auto_class._model_mapping.keys()
]
if model_type not in mapping_names:
raise ValueError(
f'{model_type} not found in HF auto_class._model_mapping')
def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
"""Get a custom wrapper class for auto classes to download the models from the ModelScope hub
Args:
@@ -79,12 +139,15 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
Returns:
The wrapper
"""
default_ignore_file_pattern = ignore_file_pattern
class ClassWrapper(module_class):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
**kwargs):
ignore_file_pattern = kwargs.pop('ignore_file_pattern',
default_ignore_file_pattern)
if not os.path.exists(pretrained_model_name_or_path):
revision = kwargs.pop('revision', None)
model_dir = snapshot_download(
@@ -95,11 +158,18 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
else:
model_dir = pretrained_model_name_or_path
model = module_class.from_pretrained(model_dir, *model_args,
**kwargs)
model.model_dir = model_dir
return model
if module_class is not GenerationConfigHF:
trust_remote_code = kwargs.get('trust_remote_code', False)
check_hf_code(model_dir, module_class, trust_remote_code)
module_obj = module_class.from_pretrained(model_dir, *model_args,
**kwargs)
if module_class.__name__.startswith('AutoModel'):
module_obj.model_dir = model_dir
return module_obj
ClassWrapper.__name__ = module_class.__name__
ClassWrapper.__qualname__ = module_class.__qualname__
return ClassWrapper
@@ -109,6 +179,12 @@ AutoModelForCausalLM = get_wrapped_class(
AutoModelForCausalLMHF, ignore_file_pattern=[r'\w+\.safetensors'])
AutoModelForSeq2SeqLM = get_wrapped_class(
AutoModelForSeq2SeqLMHF, ignore_file_pattern=[r'\w+\.safetensors'])
AutoModelForSequenceClassification = get_wrapped_class(
AutoModelForSequenceClassificationHF,
ignore_file_pattern=[r'\w+\.safetensors'])
AutoModelForTokenClassification = get_wrapped_class(
AutoModelForTokenClassificationHF,
ignore_file_pattern=[r'\w+\.safetensors'])
AutoTokenizer = get_wrapped_class(
AutoTokenizerHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])
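With these wrappers, the ModelScope-flavoured Auto classes resolve a model id against the ModelScope hub, download a snapshot (skipping the `ignore_file_pattern` weight files), run `check_hf_code`, and then delegate to the underlying transformers class. A minimal sketch, reusing the chatglm2 id and revision from the tests below; whether `AutoModel` applies depends on the repo's `auto_map`:

```python
from modelscope import AutoModel, AutoTokenizer

model_id = 'ZhipuAI/chatglm2-6b'
tokenizer = AutoTokenizer.from_pretrained(
    model_id, revision='v1.0.7', trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_id, revision='v1.0.7', trust_remote_code=True)
print(model.model_dir)  # set by the wrapper for AutoModel* classes
```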

View File

@@ -5,12 +5,10 @@ import shutil
import tempfile
import unittest
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.models.base import Model
from modelscope.utils.test_utils import test_level
class BaseTest(unittest.TestCase):
@@ -25,15 +23,31 @@ class BaseTest(unittest.TestCase):
shutil.rmtree(self.tmp_dir)
super().tearDown()
def test_from_pretrained(self):
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_from_pretrained_baichuan(self):
model = Model.from_pretrained(
'baichuan-inc/baichuan-7B', revision='v1.0.5')
'baichuan-inc/Baichuan-13B-Chat',
revision='v1.0.8',
torch_dtype=torch.float16,
device='gpu')
print(model.__class__.__name__)
self.assertIsNotNone(model)
def test_from_pretrained_hf(self):
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_from_pretrained_chatglm2(self):
model = Model.from_pretrained(
'ZhipuAI/chatglm2-6b',
revision='v1.0.7',
torch_dtype=torch.float16,
device='gpu')
print(model.__class__.__name__)
self.assertIsNotNone(model)
def test_from_pretrained_ms(self):
model = Model.from_pretrained(
'damo/nlp_structbert_sentence-similarity_chinese-tiny',
use_hf=True)
device='gpu')
print(model.__class__.__name__)
self.assertIsNotNone(model)