diff --git a/.gitignore b/.gitignore index 6cc2df63..7e956c3b 100644 --- a/.gitignore +++ b/.gitignore @@ -124,6 +124,8 @@ replace.sh result.png result.jpg result.mp4 +runs/ +ckpt/ # Pytorch *.pth diff --git a/examples/pytorch/llm/README.md b/examples/pytorch/llm/README.md index 259b1302..408e2986 100644 --- a/examples/pytorch/llm/README.md +++ b/examples/pytorch/llm/README.md @@ -1,4 +1,3 @@ -

LLM SFT Example

@@ -14,46 +13,59 @@ 中文  |  English

-## Note!!! +## Note 1. This README.md file is **copied from** [ms-swift](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/README.md) 2. This directory has been **migrated** to [ms-swift](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm), and the files in this directory are **no longer maintained**. ## Features -1. supported sft method: lora, qlora, full, ... -2. supported models: [**qwen-7b**](https://github.com/QwenLM/Qwen-7B), baichuan-7b, baichuan-13b, chatglm2-6b, llama2-7b, llama2-13b, llama2-70b, openbuddy-llama2-13b, ... +1. supported sft methods: [lora](https://arxiv.org/abs/2106.09685), [qlora](https://arxiv.org/abs/2305.14314), full (full-parameter fine-tuning), ... +2. supported models: [**qwen-7b**](https://github.com/QwenLM/Qwen-7B), baichuan-7b, baichuan-13b, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-13b, llama2-70b, openbuddy-llama2-13b, openbuddy-llama-65b, polylm-13b, ... 3. supported feature: quantization, ddp, model parallelism(device map), gradient checkpoint, gradient accumulation steps, push to modelscope hub, custom datasets, ... 4. supported datasets: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, ... ## Prepare the Environment +Experimental environment: A10, 3090, A100, ... (V100 does not support bf16 or quantization) ```bash -# Please note the cuda version -conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y +# Installing miniconda +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh +sh Miniconda3-latest-Linux-x86_64.sh +# Setting up a conda virtual environment +conda create --name ms-sft python=3.10 +conda activate ms-sft + +# Setting up a global pip mirror for faster downloads +pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ + +pip install torch torchvision torchaudio -U pip install sentencepiece charset_normalizer cpm_kernels tiktoken -U -pip install matplotlib tqdm tensorboard -U +pip install matplotlib scikit-learn tqdm tensorboard -U pip install transformers datasets -U pip install accelerate transformers_stream_generator -U +pip install ms-swift modelscope -U # Recommended installation from source code for faster bug fixes git clone https://github.com/modelscope/swift.git cd swift pip install -r requirements.txt pip install . # same as modelscope...(git clone ...) -# You can also install it from pypi -pip install ms-swift modelscope -U ``` ## Run SFT and Inference ```bash +# Clone the repository and enter the code directory. git clone https://github.com/modelscope/swift.git cd swift/examples/pytorch/llm -# sft(qlora) and infer qwen-7b, Requires 10GB VRAM. +# sft(qlora) and infer qwen-7b, requires 16GB VRAM. +# If you want to use quantization, you need to `pip install bitsandbytes` bash scripts/qwen_7b/qlora/sft.sh +# If you want to push the model to modelscope hub during training +bash scripts/qwen_7b/qlora/sft_push_to_hub.sh bash scripts/qwen_7b/qlora/infer.sh -# sft(qlora+ddp) and infer qwen-7b, Requires 4*10GB VRAM. +# sft(qlora+ddp) and infer qwen-7b, requires 4*16GB VRAM. bash scripts/qwen_7b/qlora_ddp/sft.sh bash scripts/qwen_7b/qlora_ddp/infer.sh diff --git a/examples/pytorch/llm/README_CN.md b/examples/pytorch/llm/README_CN.md index 0c827493..acbcb3d7 100644 --- a/examples/pytorch/llm/README_CN.md +++ b/examples/pytorch/llm/README_CN.md @@ -1,4 +1,3 @@ -

大模型微调的例子

@@ -14,47 +13,61 @@ 中文  |  English

-## 请注意!!! +## 请注意 1. 该README_CN.md**拷贝**自[ms-swift](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/README_CN.md) 2. 该目录已经**迁移**至[ms-swift](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm), 此目录中的文件**不再维护**. ## 特性 -1. 支持的sft方法: lora, qlora, 全参数微调, ... -2. 支持的模型: [**qwen-7b**](https://github.com/QwenLM/Qwen-7B), baichuan-7b, baichuan-13b, chatglm2-6b, llama2-7b, llama2-13b, llama2-70b, openbuddy-llama2-13b, ... +1. [lora](https://arxiv.org/abs/2106.09685), [qlora](https://arxiv.org/abs/2305.14314), 全参数微调, ... +2. 支持的模型: [**qwen-7b**](https://github.com/QwenLM/Qwen-7B), baichuan-7b, baichuan-13b, chatglm2-6b, chatglm2-6b-32k, llama2-7b, llama2-13b, llama2-70b, openbuddy-llama2-13b, openbuddy-llama-65b, polylm-13b, ... 3. 支持的特性: 模型量化, DDP, 模型并行(device_map), gradient checkpoint, 梯度累加, 支持推送modelscope hub, 支持自定义数据集, ... 4. 支持的数据集: alpaca-en(gpt4), alpaca-zh(gpt4), finance-en, multi-alpaca-all, code-en, instinwild-en, instinwild-zh, ... ## 准备实验环境 +实验环境: A10, 3090, A100均可. (V100不支持bf16, 量化) ```bash -# 请注意修改cuda的版本 -conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y +# 安装miniconda +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh +# 一直[ENTER], 最后一个选项yes即可 +sh Miniconda3-latest-Linux-x86_64.sh +# conda虚拟环境搭建 +conda create --name ms-sft python=3.10 +conda activate ms-sft + +# pip设置全局镜像与相关python包安装 +pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ + +pip install torch torchvision torchaudio -U pip install sentencepiece charset_normalizer cpm_kernels tiktoken -U -pip install matplotlib tqdm tensorboard -U +pip install matplotlib scikit-learn tqdm tensorboard -U pip install transformers datasets -U pip install accelerate transformers_stream_generator -U +pip install ms-swift modelscope -U # 推荐从源码安装swift和modelscope, 这具有更多的特性和更快的bug修复 git clone https://github.com/modelscope/swift.git cd swift pip install -r requirements.txt pip install . # modelscope类似...(git clone ...) -# 当然, 你也可以从pypi上下载 -pip install ms-swift modelscope -U ``` ## 微调和推理 ```bash +# clone仓库并进入代码目录 git clone https://github.com/modelscope/swift.git cd swift/examples/pytorch/llm -# 微调(qlora)+推理 qwen-7b, 需要10G显存. +# 微调(qlora)+推理 qwen-7b, 需要16GB显存. +# 如果你想要使用量化, 你需要`pip install bitsandbytes` bash scripts/qwen_7b/qlora/sft.sh +# 如果你想在训练时, 将权重push到modelscope hub中. +bash scripts/qwen_7b/qlora/sft_push_to_hub.sh bash scripts/qwen_7b/qlora/infer.sh -# 微调(qlora+ddp)+推理 qwen-7b, 需要4卡*10G显存. +# 微调(qlora+ddp)+推理 qwen-7b, 需要4卡*16GB显存. 
bash scripts/qwen_7b/qlora_ddp/sft.sh bash scripts/qwen_7b/qlora_ddp/infer.sh diff --git a/modelscope/__init__.py b/modelscope/__init__.py index bf95cb81..ac362be1 100644 --- a/modelscope/__init__.py +++ b/modelscope/__init__.py @@ -26,8 +26,12 @@ if TYPE_CHECKING: from .pipelines import Pipeline, pipeline from .utils.hub import read_config, create_model_if_not_exist from .utils.logger import get_logger + from .utils.constant import Tasks from .utils.hf_util import AutoConfig, GenerationConfig - from .utils.hf_util import AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM + from .utils.hf_util import (AutoModel, AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoModelForTokenClassification) from .utils.hf_util import AutoTokenizer from .msdatasets import MsDataset @@ -68,9 +72,12 @@ else: 'pipelines': ['Pipeline', 'pipeline'], 'utils.hub': ['read_config', 'create_model_if_not_exist'], 'utils.logger': ['get_logger'], + 'utils.constant': ['Tasks'], 'utils.hf_util': [ 'AutoConfig', 'GenerationConfig', 'AutoModel', - 'AutoModelForCausalLM', 'AutoModelForSeq2SeqLM', 'AutoTokenizer' + 'AutoModelForCausalLM', 'AutoModelForSeq2SeqLM', 'AutoTokenizer', + 'AutoModelForSequenceClassification', + 'AutoModelForTokenClassification' ], 'msdatasets': ['MsDataset'] } diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index 02f50483..788d5c43 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -4,10 +4,11 @@ import os.path as osp from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Union -from modelscope.hub.check_model import check_local_model_is_latest from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Tasks from modelscope.models.builder import build_backbone, build_model +from modelscope.utils.automodel_utils import (can_load_by_ms, + try_to_load_hf_model) from modelscope.utils.config import Config from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile from modelscope.utils.device import verify_device @@ -84,12 +85,22 @@ class Model(ABC): device(str, `optional`): The device to load the model. **kwargs: task(str, `optional`): The `Tasks` enumeration value to replace the task value - read out of config in the `model_name_or_path`. This is useful when the model to be loaded is not - equal to the model saved. - For example, load a `backbone` into a `text-classification` model. - Other kwargs will be directly fed into the `model` key, to replace the default configs. - use_hf(bool): If set True, will use AutoModel in hf to initialize the model to keep compatibility - with huggingface transformers. + read out of config in the `model_name_or_path`. This is useful when the model to be loaded is not + equal to the model saved. + For example, load a `backbone` into a `text-classification` model. + Other kwargs will be directly fed into the `model` key, to replace the default configs. + use_hf(bool, `optional`): + If set to True, it will initialize the model using AutoModel or AutoModelFor* from hf. + If set to False, the model is loaded using the modelscope mode. + If set to None, the loading mode will be automatically selected. 
+ ignore_file_pattern(List[str], `optional`): + This parameter is passed to snapshot_download + device_map(str | Dict[str, str], `optional`): + This parameter is passed to AutoModel or AutoModelFor* + torch_dtype(torch.dtype, `optional`): + This parameter is passed to AutoModel or AutoModelFor* + config(PretrainedConfig, `optional`): + This parameter is passed to AutoModel or AutoModelFor* Returns: A model instance. @@ -115,14 +126,14 @@ class Model(ABC): ) invoked_by = '%s/%s' % (Invoke.KEY, invoked_by) + ignore_file_pattern = kwargs.get('ignore_file_pattern', None) local_model_dir = snapshot_download( - model_name_or_path, revision, user_agent=invoked_by) + model_name_or_path, + revision, + user_agent=invoked_by, + ignore_file_pattern=ignore_file_pattern) logger.info(f'initialize model from {local_model_dir}') - if kwargs.pop('use_hf', False): - from modelscope import AutoModel - return AutoModel.from_pretrained(local_model_dir) - if cfg_dict is not None: cfg = cfg_dict else: @@ -134,6 +145,23 @@ class Model(ABC): model_cfg = cfg.model if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'): model_cfg.type = model_cfg.model_type + model_type = model_cfg.type + if isinstance(device, str) and device.startswith('gpu'): + device = 'cuda' + device[3:] + use_hf = kwargs.pop('use_hf', None) + if use_hf is None and can_load_by_ms(local_model_dir, task_name, + model_type): + use_hf = False + model = None + if use_hf in {True, None}: + model = try_to_load_hf_model(local_model_dir, task_name, use_hf, + **kwargs) + if model is not None: + device_map = kwargs.get('device_map', None) + if device_map is None and device is not None: + model = model.to(device) + return model + # use ms model_cfg.model_dir = local_model_dir # install and import remote repos before build diff --git a/modelscope/models/nlp/llama/__init__.py b/modelscope/models/nlp/llama/__init__.py index 9cc10253..d5b6fd19 100644 --- a/modelscope/models/nlp/llama/__init__.py +++ b/modelscope/models/nlp/llama/__init__.py @@ -1,23 +1,24 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from typing import TYPE_CHECKING +from transformers.models.llama import (LlamaConfig, LlamaTokenizer, + LlamaTokenizerFast) + from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .configuration import LlamaConfig - from .text_generation import LlamaForTextGeneration from .backbone import LlamaModel - from .tokenization import LlamaTokenizer - from .tokenization_fast import LlamaTokenizerFast + from .text_generation import LlamaForTextGeneration else: _import_structure = { - 'configuration': ['LlamaConfig'], - 'text_generation': ['LlamaForTextGeneration'], 'backbone': ['LlamaModel'], - 'tokenization': ['LlamaTokenizer'], - 'tokenization_fast': ['LlamaTokenizerFast'], + 'text_generation': ['LlamaForTextGeneration'], + } + _extra_objects = { + 'LlamaConfig': LlamaConfig, + 'LlamaTokenizer': LlamaTokenizer, + 'LlamaTokenizerFast': LlamaTokenizerFast, } - import sys sys.modules[__name__] = LazyImportModule( @@ -25,5 +26,5 @@ else: globals()['__file__'], _import_structure, module_spec=__spec__, - extra_objects={}, + extra_objects=_extra_objects, ) diff --git a/modelscope/models/nlp/llama/backbone.py b/modelscope/models/nlp/llama/backbone.py index 16be099f..0ac5bf5c 100755 --- a/modelscope/models/nlp/llama/backbone.py +++ b/modelscope/models/nlp/llama/backbone.py @@ -18,389 +18,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
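As a usage sketch of the reworked loading path in base_model.py above (the model id and ignore pattern below are hypothetical placeholders, not part of this change): `Model.from_pretrained` now forwards `device_map`, `torch_dtype` and `config` to the Hugging Face `AutoModel`/`AutoModelFor*` loaders via `try_to_load_hf_model`, passes `ignore_file_pattern` through to `snapshot_download`, and only builds through the modelscope registry when `use_hf` resolves to False.

```python
# Sketch only: the model id and ignore pattern are hypothetical placeholders.
import torch
from modelscope import Model

model = Model.from_pretrained(
    'your-namespace/your-llama2-model',         # hypothetical hub id
    device_map='auto',                          # forwarded to AutoModel / AutoModelFor*
    torch_dtype=torch.float16,                  # forwarded to AutoModel / AutoModelFor*
    ignore_file_pattern=[r'.+\.safetensors$'],  # forwarded to snapshot_download
    use_hf=None,  # None: can_load_by_ms decides between the ms and hf loaders
)
```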
""" PyTorch LLaMA model.""" -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.utils.checkpoint -from torch import nn -from transformers.activations import ACT2FN -from transformers.modeling_utils import PreTrainedModel +from transformers.models.llama import LlamaConfig +from transformers.models.llama import LlamaModel as LlamaModelHF +from transformers.models.llama import \ + LlamaPreTrainedModel as LlamaPreTrainedModelHF from modelscope.metainfo import Models from modelscope.models import Model, TorchModel from modelscope.models.builder import MODELS -from modelscope.outputs import AttentionBackboneModelOutput from modelscope.utils.constant import Tasks from modelscope.utils.logger import get_logger -from .configuration import LlamaConfig logger = get_logger() -_CONFIG_FOR_DOC = 'LlamaConfig' - -# This file is mainly copied from the llama code of transformers -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), - torch.tensor(torch.finfo(dtype).min, device=device), - device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - mask = torch.cat( - [ - torch.zeros( - tgt_len, - past_key_values_length, # noqa - dtype=dtype, - device=device), - mask - ], - dim=-1) # noqa - return mask[None, None, :, :].expand(bsz, 1, tgt_len, - tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, - dtype: torch.dtype, - tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, - src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), - torch.finfo(dtype).min) - - -class LlamaRMSNorm(nn.Module): - - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - variance = hidden_states.to(torch.float32).pow(2).mean( - -1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance - + self.variance_epsilon) - - # convert into half-precision if necessary - if self.weight.dtype in [torch.float16, torch.bfloat16]: - hidden_states = hidden_states.to(self.weight.dtype) - - return self.weight * hidden_states - - -class LlamaRotaryEmbedding(torch.nn.Module): - - def __init__(self, - dim, - max_position_embeddings=2048, - base=10000, - device=None): - super().__init__() - inv_freq = 1.0 / ( - base**(torch.arange(0, dim, 2).float().to(device) / dim)) - self.register_buffer('inv_freq', inv_freq) - - # Build here to make `torch.jit.trace` work. 
- self.max_seq_len_cached = max_position_embeddings - t = torch.arange( - self.max_seq_len_cached, - device=self.inv_freq.device, - dtype=self.inv_freq.dtype) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - 'cos_cached', emb.cos()[None, None, :, :], persistent=False) - self.register_buffer( - 'sin_cached', emb.sin()[None, None, :, :], persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. - if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, - device=x.device, - dtype=self.inv_freq.dtype) - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.register_buffer( - 'cos_cached', emb.cos()[None, None, :, :], persistent=False) - self.register_buffer( - 'sin_cached', emb.sin()[None, None, :, :], persistent=False) - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] - gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) - cos = torch.gather( - cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) - sin = torch.gather( - sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LlamaMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - ): - super().__init__() - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.act_fn = ACT2FN[hidden_act] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: LlamaConfig): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}' - f' and `num_heads`: {self.num_heads}).') - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, 
self.hidden_size, bias=False) - self.rotary_emb = LlamaRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings) - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, - self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states).view( - bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view( - bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view( - bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids) - # [bsz, nh, t, hd] - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - attn_weights = torch.matmul(query_states, key_states.transpose( - 2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is' - f' {attn_weights.size()}') - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' - ) - attn_weights = attn_weights + attention_mask - attn_weights = torch.max( - attn_weights, - torch.tensor(torch.finfo(attn_weights.dtype).min)) - - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is' - f' {attn_output.size()}') - - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaDecoderLayer(nn.Module): - - def __init__(self, config: LlamaConfig): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP( - hidden_size=self.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - ) - self.input_layernorm = LlamaRMSNorm( - config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm( - config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: 
Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, - torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states, ) - - if output_attentions: - outputs += (self_attn_weights, ) - - if use_cache: - outputs += (present_key_value, ) - - return outputs - - -class LlamaPreTrainedModel(TorchModel, PreTrainedModel): - config_class = LlamaConfig - base_model_prefix = 'model' - supports_gradient_checkpointing = True - _no_split_modules = ['LlamaDecoderLayer'] - _keys_to_ignore_on_load_unexpected = [r'decoder\.version'] - - def __init__(self, config, **kwargs): - super().__init__(config.name_or_path, **kwargs) - super(Model, self).__init__(config) - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, LlamaModel): - module.gradient_checkpointing = value +class MsModelMixin: @classmethod def _instantiate(cls, **kwargs): @@ -416,272 +48,22 @@ class LlamaPreTrainedModel(TorchModel, PreTrainedModel): Returns: The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained """ - model_dir = kwargs.pop('model_dir', None) if model_dir is None: config = LlamaConfig(**kwargs) model = cls(config) else: - model = super(Model, cls).from_pretrained( + model = super(MsModelMixin, cls).from_pretrained( pretrained_model_name_or_path=model_dir, **kwargs) model.model_dir = model_dir return model +class LlamaPreTrainedModel(MsModelMixin, LlamaPreTrainedModelHF, TorchModel): + pass + + +@MODELS.register_module(Tasks.backbone, module_name=Models.llama2) @MODELS.register_module(Tasks.backbone, module_name=Models.llama) -class 
LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig, **kwargs): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, - self.padding_idx) - self.layers = nn.ModuleList([ - LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers) - ]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, - inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, - tgt_len=input_shape[-1]).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else - expanded_attn_mask + combined_attention_mask) - - return combined_attention_mask - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, AttentionBackboneModelOutput]: - r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. - Padding will be ignored by default should you provide it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. 
- - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. - Selected in the range `[0, config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed - or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, - with each tuple having 2 tensors of shape - `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape - `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the - cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` - (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` - instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
- """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else - self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - 'You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time' - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - 'You have to specify either decoder_input_ids or decoder_inputs_embeds' - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, - past_key_values_length) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' 
- ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states, ) - - past_key_value = past_key_values[ - idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += ( - layer_outputs[2 if output_attentions else 1], ) - - if output_attentions: - all_self_attns += (layer_outputs[1], ) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states, ) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v for v in - [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None) - return AttentionBackboneModelOutput( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) +class LlamaModel(MsModelMixin, LlamaModelHF, TorchModel): + pass diff --git a/modelscope/models/nlp/llama/configuration.py b/modelscope/models/nlp/llama/configuration.py deleted file mode 100644 index cab02410..00000000 --- a/modelscope/models/nlp/llama/configuration.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" LLaMA model configuration""" - -from transformers.configuration_utils import PretrainedConfig - -LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -# This file is mainly copied from the llama code of transformers -class LlamaConfig(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. 
Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`LlamaModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-12): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. 
- tie_word_embeddings(`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - """ - model_type = 'llama' - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - hidden_act='silu', - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - tie_word_embeddings=False, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/modelscope/models/nlp/llama/convert_llama_weights_to_hf.py b/modelscope/models/nlp/llama/convert_llama_weights_to_hf.py index d1ad316c..8c0c8568 100644 --- a/modelscope/models/nlp/llama/convert_llama_weights_to_hf.py +++ b/modelscope/models/nlp/llama/convert_llama_weights_to_hf.py @@ -20,8 +20,8 @@ import shutil import json import torch +from transformers.models.llama import LlamaConfig -from .configuration import LlamaConfig from .text_generation import LlamaForTextGeneration # This file is mainly copied from the llama code of transformers diff --git a/modelscope/models/nlp/llama/text_generation.py b/modelscope/models/nlp/llama/text_generation.py index 119561c3..0a325df2 100644 --- a/modelscope/models/nlp/llama/text_generation.py +++ b/modelscope/models/nlp/llama/text_generation.py @@ -19,165 +19,97 @@ # limitations under the License. 
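With the local configuration.py deleted, `LlamaConfig` now comes straight from `transformers`; its fields line up with the removed copy, so existing call sites keep working. A small sketch of the drop-in replacement, assuming a recent `transformers` release:

```python
# The defaults below mirror the removed modelscope configuration.py (LLaMA-7B sized).
from transformers.models.llama import LlamaConfig

config = LlamaConfig(
    vocab_size=32000,
    hidden_size=4096,
    intermediate_size=11008,
    num_hidden_layers=32,
    num_attention_heads=32,
    max_position_embeddings=2048,
    rms_norm_eps=1e-6,
)
print(config.model_type)  # 'llama'
```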
from typing import Dict, List, Optional, Tuple, Union -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss +import torch +from transformers.models.llama import LlamaForCausalLM from modelscope.metainfo import Models -from modelscope.models.base import Tensor, TorchModel +from modelscope.models.base import TorchModel from modelscope.models.builder import MODELS -from modelscope.outputs import AttentionTextGenerationModelOutput +from modelscope.outputs import OutputKeys from modelscope.utils.constant import Tasks -from modelscope.utils.streaming_output import \ - PretrainedModelStreamingOutputMixin -from .backbone import LlamaModel, LlamaPreTrainedModel +from .backbone import MsModelMixin + + +def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]], + max_length: int, tokenizer): + system_prompt = f'[INST] <>\n{system}\n<>\n\n' + system_ids = tokenizer(system_prompt, return_tensors='pt').input_ids + + text_prompt = f'{text.strip()} [/INST]' + text_ids = tokenizer(text_prompt, return_tensors='pt').input_ids + + prompt_length = system_ids.shape[-1] + text_ids.shape[-1] + if prompt_length > max_length: + raise RuntimeError( + f'prepend prompt length {prompt_length} is bigger than max_length {max_length}' + ) + + history_prompt = '' + history_ids_list = [] + # traverse history in reverse order + for user, bot in history[::-1]: + assert isinstance(user, str) + assert isinstance(bot, str) + round_prompt = f'{user.strip()} [/INST] {bot.strip()} [INST] ' + round_ids = tokenizer(round_prompt, return_tensors='pt').input_ids + if prompt_length + round_ids.shape[-1] > max_length: + # excess history should not be appended to the prompt + break + else: + history_prompt = round_prompt + history_prompt + history_ids_list = [round_ids] + history_ids_list + prompt_length += round_ids.shape[-1] + + prompt_list = [system_prompt, history_prompt, text_prompt] + prompt_ids_list = [system_ids] + history_ids_list + [text_ids] + + return ''.join(prompt_list), torch.cat(prompt_ids_list, dim=1) # This file is mainly copied from the llama code of transformers +@MODELS.register_module(Tasks.text_generation, module_name=Models.llama2) @MODELS.register_module(Tasks.text_generation, module_name=Models.llama) -class LlamaForTextGeneration(LlamaPreTrainedModel, - PretrainedModelStreamingOutputMixin): - _keys_to_ignore_on_load_missing = [r'lm_head.weight'] +class LlamaForTextGeneration(MsModelMixin, LlamaForCausalLM, TorchModel): - def __init__(self, config, **kwargs): - super().__init__(config) - self.model = LlamaModel(config) - - self.lm_head = nn.Linear( - config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - 
output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, AttentionTextGenerationModelOutput]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - """ - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else - self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - logits = self.lm_head(hidden_states) - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits, ) + outputs[1:] - return (loss, ) + output if loss is not None else output - - # There is a conflict between the `ModelOutputBase` in the modelscope - # and the `send_to_device` function in the accelerate library. 
- # Temporarily change AttentionTextGenerationModelOutput to dict - return dict( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation(self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get('position_ids', None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {'inputs_embeds': inputs_embeds} + def chat(self, input: Dict, tokenizer) -> Dict: + import copy + gen_kwargs = copy.copy(input) + if 'text' not in input: + text: str = '' else: - model_inputs = {'input_ids': input_ids} + text: str = input['text'] + gen_kwargs.pop('text') - model_inputs.update({ - 'position_ids': position_ids, - 'past_key_values': past_key_values, - 'use_cache': kwargs.get('use_cache'), - 'attention_mask': attention_mask, - }) - return model_inputs + if 'system' not in input: + system: str = '' + else: + system: str = input['system'] + gen_kwargs.pop('system') - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += (tuple( - past_state.index_select(0, beam_idx) - for past_state in layer_past), ) - return reordered_past + if 'history' not in input: + history = [] + else: + history: List[Tuple] = copy.copy(input['history']) + gen_kwargs.pop('history') - def generate(self, inputs: Dict[str, Tensor], - **kwargs) -> Dict[str, Tensor]: - return super().generate(**inputs, **kwargs) + if 'max_length' not in gen_kwargs: + gen_kwargs['max_length'] = 4096 + + prompt, prompt_ids = get_chat_prompt( + system=system, + text=text, + history=history, + max_length=gen_kwargs['max_length'], + tokenizer=tokenizer) + input_ids = prompt_ids.to(self.device) + generate_ids = self.generate(input_ids, **gen_kwargs) + # remove input tokens + generate_ids = generate_ids[:, input_ids.shape[1]:] + response = tokenizer.batch_decode( + generate_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=False)[0] + response = response.strip() + history.append((text, response)) + + return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history} diff --git a/modelscope/models/nlp/llama/tokenization.py b/modelscope/models/nlp/llama/tokenization.py deleted file mode 100644 index cd423683..00000000 --- a/modelscope/models/nlp/llama/tokenization.py +++ /dev/null @@ -1,272 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -# coding=utf-8 -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tokenization classes for LLaMA.""" -import os -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple - -import sentencepiece as spm -from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer - -from modelscope.utils.logger import get_logger - -# This file is mainly copied from the llama code of transformers -logger = get_logger() - -VOCAB_FILES_NAMES = {'vocab_file': 'tokenizer.model'} - -PRETRAINED_VOCAB_FILES_MAP = { - 'vocab_file': { - 'hf-internal-testing/llama-tokenizer': - 'https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model', - }, - 'tokenizer_file': { - 'hf-internal-testing/llama-tokenizer': - 'https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json', - }, -} -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - 'hf-internal-testing/llama-tokenizer': 2048, -} - - -class LlamaTokenizer(PreTrainedTokenizer): - """ - Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. - - Args: - vocab_file (`str`): - Path to the vocabulary file. - """ - - vocab_files_names = VOCAB_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - model_input_names = ['input_ids', 'attention_mask'] - - def __init__( - self, - vocab_file, - unk_token='', - bos_token='', - eos_token='', - pad_token=None, - sp_model_kwargs: Optional[Dict[str, Any]] = None, - add_bos_token=True, - add_eos_token=False, - clean_up_tokenization_spaces=False, - **kwargs, - ): - self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - bos_token = AddedToken( - bos_token, lstrip=False, rstrip=False) if isinstance( - bos_token, str) else bos_token - eos_token = AddedToken( - eos_token, lstrip=False, rstrip=False) if isinstance( - eos_token, str) else eos_token - unk_token = AddedToken( - unk_token, lstrip=False, rstrip=False) if isinstance( - unk_token, str) else unk_token - pad_token = AddedToken( - pad_token, lstrip=False, rstrip=False) if isinstance( - pad_token, str) else pad_token - super().__init__( - bos_token=bos_token, - eos_token=eos_token, - unk_token=unk_token, - pad_token=pad_token, - add_bos_token=add_bos_token, - add_eos_token=add_eos_token, - sp_model_kwargs=self.sp_model_kwargs, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - **kwargs, - ) - self.vocab_file = vocab_file - self.add_bos_token = add_bos_token - self.add_eos_token = add_eos_token - self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.Load(vocab_file) - - def __getstate__(self): - state = self.__dict__.copy() - state['sp_model'] = None - return state - - def __setstate__(self, d): - self.__dict__ = d - self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.Load(self.vocab_file) - - @property - def vocab_size(self): - """Returns vocab size""" - return self.sp_model.get_piece_size() - - def get_vocab(self): - """Returns vocab as a dict""" - vocab = { - self.convert_ids_to_tokens(i): i - for i in range(self.vocab_size) - } - 
vocab.update(self.added_tokens_encoder) - return vocab - - def _tokenize(self, text): - """Returns a tokenized string.""" - return self.sp_model.encode(text, out_type=str) - - def _convert_token_to_id(self, token): - """Converts a token (str) in an id using the vocab.""" - return self.sp_model.piece_to_id(token) - - def _convert_id_to_token(self, index): - """Converts an index (integer) in a token (str) using the vocab.""" - token = self.sp_model.IdToPiece(index) - return token - - def convert_tokens_to_string(self, tokens): - """Converts a sequence of tokens (string) in a single string.""" - current_sub_tokens = [] - out_string = '' - prev_is_special = False - for i, token in enumerate(tokens): - # make sure that special tokens are not decoded using sentencepiece model - if token in self.all_special_tokens: - if not prev_is_special and i != 0: - out_string += ' ' - out_string += self.sp_model.decode(current_sub_tokens) + token - prev_is_special = True - current_sub_tokens = [] - else: - current_sub_tokens.append(token) - prev_is_special = False - out_string += self.sp_model.decode(current_sub_tokens) - return out_string - - def save_vocabulary(self, - save_directory, - filename_prefix: Optional[str] = None) -> Tuple[str]: - """ - Save the vocabulary and special tokens file to a directory. - - Args: - save_directory (`str`): - The directory in which to save the vocabulary. - - Returns: - `Tuple(str)`: Paths to the files saved. - """ - if not os.path.isdir(save_directory): - logger.error( - f'Vocabulary path ({save_directory}) should be a directory') - return - out_vocab_file = os.path.join( - save_directory, (filename_prefix + '-' if filename_prefix else '') - + VOCAB_FILES_NAMES['vocab_file']) - - if os.path.abspath(self.vocab_file) != os.path.abspath( - out_vocab_file) and os.path.isfile(self.vocab_file): - copyfile(self.vocab_file, out_vocab_file) - elif not os.path.isfile(self.vocab_file): - with open(out_vocab_file, 'wb') as fi: - content_spiece_model = self.sp_model.serialized_model_proto() - fi.write(content_spiece_model) - - return (out_vocab_file, ) - - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - bos_token_id = [self.bos_token_id] if self.add_bos_token else [] - eos_token_id = [self.eos_token_id] if self.add_eos_token else [] - - output = bos_token_id + token_ids_0 + eos_token_id - - if token_ids_1 is not None: - output = output + bos_token_id + token_ids_1 + eos_token_id - - return output - - def get_special_tokens_mask( - self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None, - already_has_special_tokens: bool = False) -> List[int]: - """ - Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding - special tokens using the tokenizer `prepare_for_model` method. - - Args: - token_ids_0 (`List[int]`): - List of IDs. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - already_has_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not the token list is already formatted with special tokens for the model. - - Returns: - `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 
- """ - if already_has_special_tokens: - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, - token_ids_1=token_ids_1, - already_has_special_tokens=True) - - bos_token_id = [1] if self.add_bos_token else [] - eos_token_id = [1] if self.add_eos_token else [] - - if token_ids_1 is None: - return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id - return (bos_token_id + # noqa - ([0] * len(token_ids_0)) + eos_token_id + bos_token_id # noqa - + # noqa - ([0] * len(token_ids_1)) + eos_token_id) # noqa - - def create_token_type_ids_from_sequences( - self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None) -> List[int]: - """ - Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT - sequence pair mask has the following format: - - ``` - 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 - | first sequence | second sequence | - ``` - - if token_ids_1 is None, only returns the first portion of the mask (0s). - - Args: - token_ids_0 (`List[int]`): - List of ids. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - - Returns: - `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). - """ - sep = [self.sep_token_id] - cls = [self.cls_token_id] - - if token_ids_1 is None: - return len(cls + token_ids_0 + sep) * [0] - return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 - + sep) * [1] diff --git a/modelscope/models/nlp/llama/tokenization_fast.py b/modelscope/models/nlp/llama/tokenization_fast.py deleted file mode 100644 index 13696b59..00000000 --- a/modelscope/models/nlp/llama/tokenization_fast.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -# coding=utf-8 -# Copyright 2020 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from shutil import copyfile -from typing import Optional, Tuple - -from transformers.tokenization_utils_fast import PreTrainedTokenizerFast -from transformers.utils import is_sentencepiece_available -from transformers.utils.versions import require_version - -from modelscope.utils.logger import get_logger - -# This file is mainly copied from the llama code of transformers -require_version('tokenizers>=0.13.3') - -if is_sentencepiece_available(): - from .tokenization import LlamaTokenizer -else: - LlamaTokenizer = None - -logger = get_logger() -VOCAB_FILES_NAMES = { - 'vocab_file': 'tokenizer.model', - 'tokenizer_file': 'tokenizer.json' -} - - -class LlamaTokenizerFast(PreTrainedTokenizerFast): - """ - Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. - - This uses notably ByteFallback and no normalization. 
- - ``` - from transformers import LlamaTokenizerFast - - tokenizer = LlaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") - tokenizer.encode("Hello this is a test") - >>> [1, 15043, 445, 338, 263, 1243] - ``` - - This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should - refer to this superclass for more information regarding those methods. - - Args: - vocab_file (`str`): - [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that - contains the vocabulary necessary to instantiate a tokenizer. - tokenizer_file (`str`): - [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that - contains everything needed to load the tokenizer. - - clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`): - Wether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra - spaces. - - bos_token (`str`, *optional*, defaults to `""`): - The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. - - eos_token (`str`, *optional*, defaults to `""`): - The end of sequence token. - - unk_token (`str`, *optional*, defaults to `""`): - The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this - token instead. - """ - - vocab_files_names = VOCAB_FILES_NAMES - slow_tokenizer_class = LlamaTokenizer - padding_side = 'left' - - def __init__( - self, - vocab_file=None, - tokenizer_file=None, - clean_up_tokenization_spaces=False, - unk_token='', - bos_token='', - eos_token='', - **kwargs, - ): - super().__init__( - vocab_file=vocab_file, - tokenizer_file=tokenizer_file, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - unk_token=unk_token, - bos_token=bos_token, - eos_token=eos_token, - **kwargs, - ) - - self.vocab_file = vocab_file - self.can_save_slow_tokenizer = False if not self.vocab_file else True - - def save_vocabulary(self, - save_directory: str, - filename_prefix: Optional[str] = None) -> Tuple[str]: - if not self.can_save_slow_tokenizer: - raise ValueError( - 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow ' - 'tokenizer.') - - if not os.path.isdir(save_directory): - logger.error( - f'Vocabulary path ({save_directory}) should be a directory') - return - out_vocab_file = os.path.join( - save_directory, (filename_prefix + '-' if filename_prefix else '') - + VOCAB_FILES_NAMES['vocab_file']) - - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): - copyfile(self.vocab_file, out_vocab_file) - - return (out_vocab_file, ) diff --git a/modelscope/models/nlp/llama2/__init__.py b/modelscope/models/nlp/llama2/__init__.py index 12a295b6..82bed042 100644 --- a/modelscope/models/nlp/llama2/__init__.py +++ b/modelscope/models/nlp/llama2/__init__.py @@ -1,23 +1,25 @@ # Copyright (c) Alibaba, Inc. and its affiliates. 
from typing import TYPE_CHECKING +from modelscope.models.nlp.llama import LlamaConfig as Llama2Config +from modelscope.models.nlp.llama import LlamaTokenizer as Llama2Tokenizer +from modelscope.models.nlp.llama import \ + LlamaTokenizerFast as Llama2TokenizerFast from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .configuration import Llama2Config - from .text_generation import Llama2ForTextGeneration from .backbone import Llama2Model - from .tokenization import Llama2Tokenizer - from .tokenization_fast import Llama2TokenizerFast + from .text_generation import Llama2ForTextGeneration else: _import_structure = { - 'configuration': ['Llama2Config'], - 'text_generation': ['Llama2ForTextGeneration'], 'backbone': ['Llama2Model'], - 'tokenization': ['Llama2Tokenizer'], - 'tokenization_fast': ['Llama2TokenizerFast'], + 'text_generation': ['Llama2ForTextGeneration'], + } + _extra_objects = { + 'Llama2Config': Llama2Config, + 'Llama2Tokenizer': Llama2Tokenizer, + 'Llama2TokenizerFast': Llama2TokenizerFast, } - import sys sys.modules[__name__] = LazyImportModule( @@ -25,5 +27,5 @@ else: globals()['__file__'], _import_structure, module_spec=__spec__, - extra_objects={}, + extra_objects=_extra_objects, ) diff --git a/modelscope/models/nlp/llama2/backbone.py b/modelscope/models/nlp/llama2/backbone.py old mode 100755 new mode 100644 index ee0d742b..3627efef --- a/modelscope/models/nlp/llama2/backbone.py +++ b/modelscope/models/nlp/llama2/backbone.py @@ -1,795 +1,4 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch LLaMA model.""" -import math -from typing import List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast -from transformers.modeling_utils import PreTrainedModel - -from modelscope import Model, TorchModel from modelscope.metainfo import Models +from modelscope.models.builder import MODELS +from modelscope.models.nlp.llama import LlamaModel as LlamaModel2 from modelscope.utils.constant import Tasks -from modelscope.utils.logger import get_logger -from ... 
import MODELS -from .configuration import Llama2Config - -logger = get_logger() - -_CONFIG_FOR_DOC = 'Llama2Config' - - -# This file is mainly copied from the llama code of transformers -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask(input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0): - """ - Make causal mask used for bi-directional self-attention. - """ - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), - torch.finfo(dtype).min, - device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - - if past_key_values_length > 0: - _tmp_value = torch.zeros( - tgt_len, past_key_values_length, dtype=dtype, device=device) - mask = torch.cat([_tmp_value, mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, - tgt_len + past_key_values_length) - - -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, - dtype: torch.dtype, - tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len - - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, - src_len).to(dtype) - - inverted_mask = 1.0 - expanded_mask - - return inverted_mask.masked_fill( - inverted_mask.to(torch.bool), - torch.finfo(dtype).min) - - -class LlamaRMSNorm(nn.Module): - - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance - + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -class LlamaRotaryEmbedding(torch.nn.Module): - - def __init__(self, - dim, - max_position_embeddings=2048, - base=10000, - device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - inv_freq = 1.0 / (self.base**(torch.arange( - 0, self.dim, 2).float().to(device) / self.dim)) # noqa - self.register_buffer('inv_freq', inv_freq) - - # Build here to make `torch.jit.trace` work. 
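`LlamaRotaryEmbedding` caches cos/sin tables built from the inverse frequencies `1 / base**(2i/d)`, and `apply_rotary_pos_emb` later combines them with `rotate_half` to rotate the query/key vectors by a position-dependent angle. A self-contained sketch of that computation on toy shapes (all sizes and the random tensor are made up for illustration):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

head_dim, seq_len, base = 8, 4, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # [head_dim/2]
t = torch.arange(seq_len).float()
freqs = torch.einsum('i,j->ij', t, inv_freq)   # [seq_len, head_dim/2]
emb = torch.cat((freqs, freqs), dim=-1)        # [seq_len, head_dim]
cos, sin = emb.cos(), emb.sin()

q = torch.randn(1, 1, seq_len, head_dim)       # [bs, num_heads, seq_len, head_dim]
q_rot = q * cos + rotate_half(q) * sin         # the same formula apply_rotary_pos_emb uses

# Position 0 is left unchanged because cos(0) = 1 and sin(0) = 0.
assert torch.allclose(q_rot[..., 0, :], q[..., 0, :])
```

The linear and dynamic-NTK subclasses that follow only change how `t` or `base` is scaled before this cache is built.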
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype()) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - 'cos_cached', - emb.cos()[None, None, :, :].to(dtype), - persistent=False) - self.register_buffer( - 'sin_cached', - emb.sin()[None, None, :, :].to(dtype), - persistent=False) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache( - seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), - ) - - -class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" - - def __init__(self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - t = t / self.scaling_factor - - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - 'cos_cached', - emb.cos()[None, None, :, :].to(dtype), - persistent=False) - self.register_buffer( - 'sin_cached', - emb.sin()[None, None, :, :].to(dtype), - persistent=False) - - -class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" - - def __init__(self, - dim, - max_position_embeddings=2048, - base=10000, - device=None, - scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - - if seq_len > self.max_position_embeddings: - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - - (self.scaling_factor - 1))**( - self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base**(torch.arange( - 0, self.dim, 2).float().to(device) / self.dim)) # noqa - self.register_buffer('inv_freq', inv_freq) - - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - - freqs = torch.einsum('i,j->ij', t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - 'cos_cached', - emb.cos()[None, None, :, :].to(dtype), - persistent=False) - self.register_buffer( - 'sin_cached', - emb.sin()[None, None, :, :].to(dtype), - persistent=False) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. - cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class LlamaMLP(nn.Module): - - def __init__(self, config): - super().__init__() - self.pretraining_tp = config.pretraining_tp - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear( - self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, x): - if self.pretraining_tp > 1: - slice = self.intermediate_size // self.pretraining_tp - gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) - up_proj_slices = self.up_proj.weight.split(slice, dim=0) - down_proj_slices = self.down_proj.weight.split(slice, dim=1) - - gate_proj = torch.cat([ - F.linear(x, gate_proj_slices[i]) - for i in range(self.pretraining_tp) - ], - dim=-1) # noqa - up_proj = torch.cat([ - F.linear(x, up_proj_slices[i]) - for i in range(self.pretraining_tp) - ], - dim=-1) # noqa - - intermediate_states = (self.act_fn(gate_proj) * up_proj).split( - slice, dim=2) - down_proj = [ - F.linear(intermediate_states[i], down_proj_slices[i]) - for i in range(self.pretraining_tp) - ] - down_proj = sum(down_proj) - else: - down_proj = self.down_proj( - self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - - return down_proj - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
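The `repeat_kv` helper below implements the grouped-query-attention trick of tiling each key/value head `n_rep` times so that a small number of KV heads can serve all query heads. A quick standalone check of the `torch.repeat_interleave` equivalence its docstring mentions (shapes are arbitrary):

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 4, 5, 16)          # 4 key/value heads
expanded = repeat_kv(kv, n_rep=8)      # e.g. 32 query heads sharing those 4 KV heads
assert expanded.shape == (2, 32, 5, 16)
assert torch.equal(expanded, torch.repeat_interleave(kv, repeats=8, dim=1))
```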
The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, - None, :, :].expand(batch, - num_key_value_heads, - n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, - head_dim) - - -class LlamaAttention(nn.Module): - """Multi-headed attention from 'Attention Is All You Need' paper""" - - def __init__(self, config: Llama2Config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.hidden_size // self.num_heads - self.num_key_value_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.pretraining_tp = config.pretraining_tp - self.max_position_embeddings = config.max_position_embeddings - - if (self.head_dim * self.num_heads) != self.hidden_size: - raise ValueError( - f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}' - f' and `num_heads`: {self.num_heads}).') - self.q_proj = nn.Linear( - self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=False) - self.v_proj = nn.Linear( - self.hidden_size, - self.num_key_value_heads * self.head_dim, - bias=False) - self.o_proj = nn.Linear( - self.num_heads * self.head_dim, self.hidden_size, bias=False) - self._init_rope() - - def _init_rope(self): - if self.config.rope_scaling is None: - self.rotary_emb = LlamaRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings) - else: - scaling_type = self.config.rope_scaling['type'] - scaling_factor = self.config.rope_scaling['factor'] - if scaling_type == 'linear': - self.rotary_emb = LlamaLinearScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor) - elif scaling_type == 'dynamic': - self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( - self.head_dim, - max_position_embeddings=self.max_position_embeddings, - scaling_factor=scaling_factor) - else: - raise ValueError(f'Unknown RoPE scaling type {scaling_type}') - - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, - self.head_dim).transpose(1, 2).contiguous() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() - - if self.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads - * self.head_dim) // self.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) - for i in range(self.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) - for i in 
range(self.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) - for i in range(self.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - past_key_value = (key_states, value_states) if use_cache else None - - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - attn_weights = torch.matmul(query_states, key_states.transpose( - 2, 3)) / math.sqrt(self.head_dim) - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is' - f' {attn_weights.size()}') - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax( - attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_output = torch.matmul(attn_weights, value_states) - - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is' - f' {attn_output.size()}') - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - - if self.pretraining_tp > 1: - attn_output = attn_output.split( - self.hidden_size // self.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.pretraining_tp, dim=1) - attn_output = sum([ - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.pretraining_tp) - ]) - else: - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class LlamaDecoderLayer(nn.Module): - - def __init__(self, config: Llama2Config): - super().__init__() - self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config=config) - self.mlp = LlamaMLP(config) - self.input_layernorm = LlamaRMSNorm( - config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = LlamaRMSNorm( - config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] 
= None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, - torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - """ - - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - outputs = (hidden_states, ) - - if output_attentions: - outputs += (self_attn_weights, ) - - if use_cache: - outputs += (present_key_value, ) - - return outputs - - -class LlamaPreTrainedModel(TorchModel, PreTrainedModel): - config_class = Llama2Config - base_model_prefix = 'model' - supports_gradient_checkpointing = True - _no_split_modules = ['LlamaDecoderLayer'] - _skip_keys_device_placement = 'past_key_values' - - def __init__(self, config, **kwargs): - super().__init__(config.name_or_path, **kwargs) - super(Model, self).__init__(config) - - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, Llama2Model): - module.gradient_checkpointing = value - - @classmethod - def _instantiate(cls, **kwargs): - """Instantiate the model. - - Args: - kwargs: Input args. - model_dir: The model dir used to load the checkpoint and the label information. - num_labels: An optional arg to tell the model how many classes to initialize. - Method will call utils.parse_label_mapping if num_labels not supplied. - If num_labels is not found, the model will use the default setting (2 classes). 
- - Returns: - The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained - """ - - model_dir = kwargs.pop('model_dir', None) - if model_dir is None: - config = Llama2Config(**kwargs) - model = cls(config) - else: - model = super(Model, cls).from_pretrained( - pretrained_model_name_or_path=model_dir, **kwargs) - model.model_dir = model_dir - return model - - -@MODELS.register_module(Tasks.backbone, module_name=Models.llama2) -class Llama2Model(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: Llama2Config - """ - - def __init__(self, config: Llama2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, - self.padding_idx) - self.layers = nn.ModuleList([ - LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers) - ]) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, - inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, - tgt_len=input_shape[-1]).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else - expanded_attn_mask + combined_attention_mask) - - return combined_attention_mask - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else - self.config.output_hidden_states) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - 'You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time' - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - 'You have to specify 
either decoder_input_ids or decoder_inputs_embeds' - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, - past_key_values_length) - - hidden_states = inputs_embeds - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...' - ) - use_cache = False - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states, ) - - past_key_value = past_key_values[ - idx] if past_key_values is not None else None - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += ( - layer_outputs[2 if output_attentions else 1], ) - - if output_attentions: - all_self_attns += (layer_outputs[1], ) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states, ) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v for v in - [hidden_states, next_cache, all_hidden_states, all_self_attns] - if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) diff --git a/modelscope/models/nlp/llama2/configuration.py b/modelscope/models/nlp/llama2/configuration.py deleted file mode 100644 index c9f38fe4..00000000 --- a/modelscope/models/nlp/llama2/configuration.py +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. 
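In the forward pass above, when a key/value cache is supplied the freshly decoded tokens are assigned absolute positions starting at the cache length. A tiny sketch of that `position_ids` construction for single-token decoding (the cache length is invented):

```python
import torch

past_key_values_length, seq_length = 6, 1   # decoding one new token on top of a 6-token cache
position_ids = torch.arange(
    past_key_values_length,
    seq_length + past_key_values_length,
    dtype=torch.long).unsqueeze(0)
print(position_ids)   # tensor([[6]]) -- the new token sits at absolute position 6
```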
It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" LLaMA model configuration""" - -from transformers.configuration_utils import PretrainedConfig - -from modelscope.utils.logger import get_logger - -logger = get_logger() - -LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} - - -class Llama2Config(PretrainedConfig): - r""" - This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA - model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the LLaMA-7B. - - Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the - documentation from [`PretrainedConfig`] for more information. - - - Args: - vocab_size (`int`, *optional*, defaults to 32000): - Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`LlamaModel`] - hidden_size (`int`, *optional*, defaults to 4096): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 11008): - Dimension of the MLP representations. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer encoder. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer encoder. - num_key_value_heads (`int`, *optional*): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details checkout [this - paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to - `num_attention_heads`. - pretraining_tp (`int`, *optional*, defaults to `1`): - Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this - document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is - necessary to ensure exact reproducibility of the pretraining results. Please refer to [this - issue](https://github.com/pytorch/pytorch/issues/76232). - hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): - The non-linear activation function (function or string) in the decoder. - max_position_embeddings (`int`, *optional*, defaults to 2048): - The maximum sequence length that this model might ever be used with. 
Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). - initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - rms_norm_eps (`float`, *optional*, defaults to 1e-12): - The epsilon used by the rms normalization layers. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - tie_word_embeddings(`bool`, *optional*, defaults to `False`): - Whether to tie weight embeddings - rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling - strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format - is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update - `max_position_embeddings` to the expected new maximum. See the following thread for more information on how - these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an - experimental feature, subject to breaking API changes in future versions. - """ - model_type = 'llama' - keys_to_ignore_at_inference = ['past_key_values'] - - def __init__( - self, - vocab_size=32000, - hidden_size=4096, - intermediate_size=11008, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=None, - hidden_act='silu', - max_position_embeddings=2048, - initializer_range=0.02, - rms_norm_eps=1e-6, - use_cache=True, - pad_token_id=0, - bos_token_id=1, - eos_token_id=2, - pretraining_tp=1, - tie_word_embeddings=False, - rope_scaling=None, - **kwargs, - ): - self.vocab_size = vocab_size - self.max_position_embeddings = max_position_embeddings - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - - # for backward compatibility - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.pretraining_tp = pretraining_tp - self.use_cache = use_cache - self.rope_scaling = rope_scaling - self._rope_scaling_validation() - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) - - def _rope_scaling_validation(self): - """ - Validate the `rope_scaling` configuration. 
- """ - if self.rope_scaling is None: - return - - if not isinstance(self.rope_scaling, - dict) or len(self.rope_scaling) != 2: - raise ValueError( - '`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, ' - f'got {self.rope_scaling}') - rope_scaling_type = self.rope_scaling.get('type', None) - rope_scaling_factor = self.rope_scaling.get('factor', None) - if rope_scaling_type is None or rope_scaling_type not in [ - 'linear', 'dynamic' - ]: - raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance( - rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError( - f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}" - ) diff --git a/modelscope/models/nlp/llama2/text_generation.py b/modelscope/models/nlp/llama2/text_generation.py index 71ccaffe..4d77d3f1 100644 --- a/modelscope/models/nlp/llama2/text_generation.py +++ b/modelscope/models/nlp/llama2/text_generation.py @@ -1,268 +1,5 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Dict, List, Optional, Tuple, Union - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss -from transformers.modeling_outputs import CausalLMOutputWithPast - from modelscope.metainfo import Models -from modelscope.outputs import OutputKeys +from modelscope.models.builder import MODELS +from modelscope.models.nlp.llama import \ + LlamaForTextGeneration as Llama2ForTextGeneration from modelscope.utils.constant import Tasks -from ... 
import MODELS -from .backbone import Llama2Model, LlamaPreTrainedModel - - -def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]], - max_length: int, tokenizer): - system_prompt = f'[INST] <>\n{system}\n<>\n\n' - system_ids = tokenizer(system_prompt, return_tensors='pt').input_ids - - text_prompt = f'{text.strip()} [/INST]' - text_ids = tokenizer(text_prompt, return_tensors='pt').input_ids - - prompt_length = system_ids.shape[-1] + text_ids.shape[-1] - if prompt_length > max_length: - raise RuntimeError( - f'prepend prompt length {prompt_length} is bigger than max_length {max_length}' - ) - - history_prompt = '' - history_ids_list = [] - # traverse history in reverse order - for user, bot in history[::-1]: - assert isinstance(user, str) - assert isinstance(bot, str) - round_prompt = f'{user.strip()} [/INST] {bot.strip()} [INST] ' - round_ids = tokenizer(round_prompt, return_tensors='pt').input_ids - if prompt_length + round_ids.shape[-1] > max_length: - # excess history should not be appended to the prompt - break - else: - history_prompt = round_prompt + history_prompt - history_ids_list = [round_ids] + history_ids_list - prompt_length += round_ids.shape[-1] - - prompt_list = [system_prompt, history_prompt, text_prompt] - prompt_ids_list = [system_ids] + history_ids_list + [text_ids] - - return ''.join(prompt_list), torch.cat(prompt_ids_list, dim=1) - - -# This file is mainly copied from the llama code of transformers -@MODELS.register_module(Tasks.text_generation, module_name=Models.llama2) -class Llama2ForTextGeneration(LlamaPreTrainedModel): - _tied_weights_keys = ['lm_head.weight'] - - def __init__(self, config): - super().__init__(config) - self.model = Llama2Model(config) - self.pretraining_tp = config.pretraining_tp - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear( - config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens - - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
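For the loss, the forward pass below shifts logits and labels by one position so that each token predicts its successor, and positions labelled `-100` are skipped via `CrossEntropyLoss`'s default `ignore_index`. A toy illustration of that computation (vocabulary size and label values are made up):

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size, seq_len = 10, 5
logits = torch.randn(1, seq_len, vocab_size)     # [batch, seq_len, vocab]
labels = torch.tensor([[3, 7, 2, -100, 5]])      # -100 marks ignored (e.g. prompt/padding) positions

# Shift so that tokens < n predict n, as in the forward pass below.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = CrossEntropyLoss()(shift_logits, shift_labels)   # ignore_index defaults to -100
print(loss.item())
```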
- - Returns: - - """ - - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else - self.config.output_hidden_states) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - - hidden_states = outputs[0] - if self.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split( - self.vocab_size // self.pretraining_tp, dim=0) - logits = [ - F.linear(hidden_states, lm_head_slices[i]) - for i in range(self.pretraining_tp) - ] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() - - loss = None - if labels is not None: - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits, ) + outputs[1:] - return (loss, ) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - def prepare_inputs_for_generation(self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - **kwargs): - if past_key_values: - input_ids = input_ids[:, -1:] - - position_ids = kwargs.get('position_ids', None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {'inputs_embeds': inputs_embeds} - else: - model_inputs = {'input_ids': input_ids} - - model_inputs.update({ - 'position_ids': position_ids, - 'past_key_values': past_key_values, - 'use_cache': kwargs.get('use_cache'), - 'attention_mask': attention_mask, - }) - return model_inputs - - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += (tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past), ) - return reordered_past - - def chat(self, input: Dict, tokenizer) -> Dict: - import copy - gen_kwargs = copy.copy(input) - if 'text' not in input: - text: str = '' - else: - text: str = input['text'] - gen_kwargs.pop('text') - - if 'system' not in input: - system: str = '' - else: - system: str = input['system'] - gen_kwargs.pop('system') - - if 'history' not in input: - history = [] - else: - history: List[Tuple] = 
copy.copy(input['history']) - gen_kwargs.pop('history') - - if 'max_length' not in gen_kwargs: - gen_kwargs['max_length'] = 4096 - - prompt, prompt_ids = get_chat_prompt( - system=system, - text=text, - history=history, - max_length=gen_kwargs['max_length'], - tokenizer=tokenizer) - input_ids = prompt_ids.to(self.device) - generate_ids = self.generate(input_ids, **gen_kwargs) - # remove input tokens - generate_ids = generate_ids[:, input_ids.shape[1]:] - response = tokenizer.batch_decode( - generate_ids, - skip_special_tokens=True, - clean_up_tokenization_spaces=False)[0] - response = response.strip() - history.append((text, response)) - - return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history} diff --git a/modelscope/models/nlp/llama2/tokenization.py b/modelscope/models/nlp/llama2/tokenization.py deleted file mode 100644 index bb276621..00000000 --- a/modelscope/models/nlp/llama2/tokenization.py +++ /dev/null @@ -1,410 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tokenization classes for LLaMA.""" -import os -from shutil import copyfile -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple - -import sentencepiece as spm -from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer - -from modelscope.utils.logger import get_logger - -if TYPE_CHECKING: - from transformers.pipelines.conversational import Conversation - -logger = get_logger() - -VOCAB_FILES_NAMES = {'vocab_file': 'tokenizer.model'} - -PRETRAINED_VOCAB_FILES_MAP = { - 'vocab_file': { - 'hf-internal-testing/llama-tokenizer': - 'https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model', - }, - 'tokenizer_file': { - 'hf-internal-testing/llama-tokenizer': - 'https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json', - }, -} -PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - 'hf-internal-testing/llama-tokenizer': 2048, -} -SPIECE_UNDERLINE = '▁' - -B_INST, E_INST = '[INST]', '[/INST]' -B_SYS, E_SYS = '<>\n', '\n<>\n\n' - -# fmt: off -DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your\ -answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ -that your responses are socially unbiased and positive in nature. - -If a question does not make any sense, or is not factually coherent, explain why instead of answering something not\ -correct. 
If you don't know the answer to a question, please don't share false information.""" -# fmt: on - - -class Llama2Tokenizer(PreTrainedTokenizer): - """ - Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is - no padding token in the original model. - - Args: - vocab_file (`str`): - Path to the vocabulary file. - legacy (`bool`, *optional*, defaults to `True`): - Whether or not the `legacy` behaviour of the tokenizer should be used. Legacy is before the merge of #24622 - which includes fixes to properly handle tokens that appear after special tokens. A simple example: - - - `legacy=True`: - ```python - >>> from transformers import T5Tokenizer - - >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True) - >>> tokenizer.encode("Hello .") - [8774, 32099, 3, 5, 1] - ``` - - `legacy=False`: - ```python - >>> from transformers import T5Tokenizer - - >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False) - >>> tokenizer.encode("Hello .") # the extra space `[3]` is no longer here - [8774, 32099, 5, 1] - ``` - Checkout the pull request and the issue [here](https://github.com/huggingface/transformers/pull/24565) for - more details. - - """ - - vocab_files_names = VOCAB_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP - max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - model_input_names = ['input_ids', 'attention_mask'] - - def __init__( - self, - vocab_file, - unk_token='', - bos_token='', - eos_token='', - pad_token=None, - sp_model_kwargs: Optional[Dict[str, Any]] = None, - add_bos_token=True, - add_eos_token=False, - clean_up_tokenization_spaces=False, - legacy=True, - **kwargs, - ): - self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs - bos_token = AddedToken( - bos_token, lstrip=False, rstrip=False) if isinstance( - bos_token, str) else bos_token - eos_token = AddedToken( - eos_token, lstrip=False, rstrip=False) if isinstance( - eos_token, str) else eos_token - unk_token = AddedToken( - unk_token, lstrip=False, rstrip=False) if isinstance( - unk_token, str) else unk_token - pad_token = AddedToken( - pad_token, lstrip=False, rstrip=False) if isinstance( - pad_token, str) else pad_token - super().__init__( - bos_token=bos_token, - eos_token=eos_token, - unk_token=unk_token, - pad_token=pad_token, - add_bos_token=add_bos_token, - add_eos_token=add_eos_token, - sp_model_kwargs=self.sp_model_kwargs, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - legacy=legacy, - **kwargs, - ) - if legacy: - logger.warning_once( - f'You are using the legacy behaviour of the {self.__class__}. ' - f'This means that tokens that come after special ' - f'tokens will not be properly handled. 
We recommend you to' - ' read the related pull request available at https://github.com/huggingface/transformers/pull/24565' - ) - self.legacy = legacy - self.vocab_file = vocab_file - self.add_bos_token = add_bos_token - self.add_eos_token = add_eos_token - self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.Load(vocab_file) - - def __getstate__(self): - state = self.__dict__.copy() - state['sp_model'] = None - state['sp_model_proto'] = self.sp_model.serialized_model_proto() - return state - - def __setstate__(self, d): - self.__dict__ = d - self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) - self.sp_model.LoadFromSerializedProto(self.sp_model_proto) - - @property - def vocab_size(self): - """Returns vocab size""" - return self.sp_model.get_piece_size() - - def get_vocab(self): - """Returns vocab as a dict""" - vocab = { - self.convert_ids_to_tokens(i): i - for i in range(self.vocab_size) - } - vocab.update(self.added_tokens_encoder) - return vocab - - # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize - def tokenize(self, text, **kwargs) -> List[str]: - # Replace the SPIECE_UNDERLINE with a space to make sure SPIECE_UNDERLINE is only used at - # the beginning of the text - if not self.legacy: - text = SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, ' ') - return super().tokenize(text, **kwargs) - - # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize - def _tokenize(self, text): - """ - Returns a tokenized string. - - Since the sentencepiece internal model always adds a SPIECE_UNDERLINE, at the beginning of the provided text, - we need to remove it by hand when the current text is a subsequence. This happens whenever the `self.tokenize` - function is called with specials tokens: the input is split on the special tokens, and each subsequence is - passed to `_tokenize`. Thus if a subsequence did not start with a `" "` or SPIECE_UNDERLINE, we have to remove - the extra `SPIECE_UNDERLINE` prepended. - """ - if not self.legacy: - is_first = text.startswith(SPIECE_UNDERLINE) - if is_first: - text = text[1:] - - tokens = self.sp_model.encode(text, out_type=str) - - if not self.legacy and not is_first and not text.startswith( - ' ') and tokens[0].startswith(SPIECE_UNDERLINE): - tokens = ([tokens[0][1:]] - if len(tokens[0]) > 1 else []) + tokens[1:] - return tokens - - def _convert_token_to_id(self, token): - """Converts a token (str) in an id using the vocab.""" - return self.sp_model.piece_to_id(token) - - def _convert_id_to_token(self, index): - """Converts an index (integer) in a token (str) using the vocab.""" - token = self.sp_model.IdToPiece(index) - return token - - def convert_tokens_to_string(self, tokens): - """Converts a sequence of tokens (string) in a single string.""" - current_sub_tokens = [] - out_string = '' - prev_is_special = False - for i, token in enumerate(tokens): - # make sure that special tokens are not decoded using sentencepiece model - if token in self.all_special_tokens: - if not prev_is_special and i != 0: - out_string += ' ' - out_string += self.sp_model.decode(current_sub_tokens) + token - prev_is_special = True - current_sub_tokens = [] - else: - current_sub_tokens.append(token) - prev_is_special = False - out_string += self.sp_model.decode(current_sub_tokens) - return out_string - - def save_vocabulary(self, - save_directory, - filename_prefix: Optional[str] = None) -> Tuple[str]: - """ - Save the vocabulary and special tokens file to a directory. 
- - Args: - save_directory (`str`): - The directory in which to save the vocabulary. - - Returns: - `Tuple(str)`: Paths to the files saved. - """ - if not os.path.isdir(save_directory): - logger.error( - f'Vocabulary path ({save_directory}) should be a directory') - return - out_vocab_file = os.path.join( - save_directory, (filename_prefix + '-' if filename_prefix else '') - + VOCAB_FILES_NAMES['vocab_file']) - - if os.path.abspath(self.vocab_file) != os.path.abspath( - out_vocab_file) and os.path.isfile(self.vocab_file): - copyfile(self.vocab_file, out_vocab_file) - elif not os.path.isfile(self.vocab_file): - with open(out_vocab_file, 'wb') as fi: - content_spiece_model = self.sp_model.serialized_model_proto() - fi.write(content_spiece_model) - - return (out_vocab_file, ) - - def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - bos_token_id = [self.bos_token_id] if self.add_bos_token else [] - eos_token_id = [self.eos_token_id] if self.add_eos_token else [] - - output = bos_token_id + token_ids_0 + eos_token_id - - if token_ids_1 is not None: - output = output + bos_token_id + token_ids_1 + eos_token_id - - return output - - def get_special_tokens_mask( - self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None, - already_has_special_tokens: bool = False) -> List[int]: - """ - Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding - special tokens using the tokenizer `prepare_for_model` method. - - Args: - token_ids_0 (`List[int]`): - List of IDs. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - already_has_special_tokens (`bool`, *optional*, defaults to `False`): - Whether or not the token list is already formatted with special tokens for the model. - - Returns: - `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. - """ - if already_has_special_tokens: - return super().get_special_tokens_mask( - token_ids_0=token_ids_0, - token_ids_1=token_ids_1, - already_has_special_tokens=True) - - bos_token_id = [1] if self.add_bos_token else [] - eos_token_id = [1] if self.add_eos_token else [] - - if token_ids_1 is None: - return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id - return bos_token_id + ( - [0] * len(token_ids_0)) + eos_token_id + bos_token_id + ( - [0] * len(token_ids_1)) + eos_token_id # noqa - - def create_token_type_ids_from_sequences( - self, - token_ids_0: List[int], - token_ids_1: Optional[List[int]] = None) -> List[int]: - """ - Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT - sequence pair mask has the following format: - - ``` - 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 - | first sequence | second sequence | - ``` - - if token_ids_1 is None, only returns the first portion of the mask (0s). - - Args: - token_ids_0 (`List[int]`): - List of ids. - token_ids_1 (`List[int]`, *optional*): - Optional second list of IDs for sequence pairs. - - Returns: - `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). 
- """ - bos_token_id = [self.bos_token_id] if self.add_bos_token else [] - eos_token_id = [self.eos_token_id] if self.add_eos_token else [] - - output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) - - if token_ids_1 is not None: - output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) - - return output - - def _build_conversation_input_ids( - self, conversation: 'Conversation') -> List[int]: - """Builds the input ids for a conversation. - This is the format used in the provided examples. System prompts should be manually added at the beginning of - the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used. - ``` - [INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer - [INST] Prompt [/INST] Answer - [INST] Prompt [/INST] - ``` - - If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following: - ```python - >>> from transformers import Conversation - - >>> Conversation( - ... "<>\n Only answer with emojis, and charades\n<>\n\nHow can I build a house in 10 septs?" - ... ) - ``` - Args: - conversation (`Conversation`): - Conversation to build input ids for. - Returns: - `List[int]`: - Input ids for the conversation. - """ - dialogue = list(conversation.iter_texts()) - if not all([is_user for is_user, msg in dialogue[::2]]) or not all( - [not is_user for is_user, msg in dialogue[1::2]]): # noqa - raise ValueError( - "The model only supports 'user' and 'assistant' roles, " - 'starting with user and alternating (u/a/u/a/u...)') - - dialog_tokens: List[int] = [] - if len(conversation.past_user_inputs) > 0: - if not conversation.past_user_inputs[0].startswith( - B_SYS) or E_SYS not in conversation.past_user_inputs[0]: - conversation.past_user_inputs[0] = ( - B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS - + conversation.past_user_inputs[0]) - elif not dialogue[0][1].startswith( - B_SYS) or E_SYS not in dialogue[0][1]: - dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT - + E_SYS + dialogue[0][1]) - - dialog_tokens += sum( - [[self.bos_token_id] + self.encode( - f'{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ', - add_special_tokens=False) + [self.eos_token_id] - for prompt, answer in zip(dialogue[::2], dialogue[1::2])], - [], - ) - if not (dialogue[-1][0]): - raise ValueError( - f"Last message must be from user, got {dialogue[-1]['role']}") - dialog_tokens += [self.bos_token_id] + self.encode( - f'{B_INST} {(dialogue[-1][1]).strip()} {E_INST}', - add_special_tokens=False) - return dialog_tokens diff --git a/modelscope/models/nlp/llama2/tokenization_fast.py b/modelscope/models/nlp/llama2/tokenization_fast.py deleted file mode 100644 index 13862955..00000000 --- a/modelscope/models/nlp/llama2/tokenization_fast.py +++ /dev/null @@ -1,251 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. -# Copyright 2020 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import os
-from shutil import copyfile
-from typing import TYPE_CHECKING, Optional, Tuple
-
-from tokenizers import processors
-from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
-from transformers.utils import is_sentencepiece_available
-from transformers.utils.versions import require_version
-
-from modelscope.utils import logger as logging
-
-if TYPE_CHECKING:
-    from transformers.pipelines.conversational import Conversation
-
-require_version('tokenizers>=0.13.3')
-
-if is_sentencepiece_available():
-    from .tokenization import Llama2Tokenizer
-else:
-    Llama2Tokenizer = None
-
-logger = logging.get_logger()
-VOCAB_FILES_NAMES = {
-    'vocab_file': 'tokenizer.model',
-    'tokenizer_file': 'tokenizer.json'
-}
-
-B_INST, E_INST = '[INST]', '[/INST]'
-B_SYS, E_SYS = '<<SYS>>\n', '\n<</SYS>>\n\n'
-
-# fmt: off
-DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your\
-answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\
-that your responses are socially unbiased and positive in nature.
-
-If a question does not make any sense, or is not factually coherent, explain why instead of answering something not\
-correct. If you don't know the answer to a question, please don't share false information."""
-# fmt: on
-
-
-class Llama2TokenizerFast(PreTrainedTokenizerFast):
-    """
-    Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding.
-
-    This uses notably ByteFallback and no normalization.
-
-    ```
-    from transformers import LlamaTokenizerFast
-
-    tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
-    tokenizer.encode("Hello this is a test")
-    >>> [1, 15043, 445, 338, 263, 1243]
-    ```
-
-    If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or
-    call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the
-    values of the first token and final token of an encoded sequence will not be correct). For more details, checkout
-    [post-processors](https://huggingface.co/docs/tokenizers/api/post-processors) documentation.
-
-
-    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
-    refer to this superclass for more information regarding those methods.
-
-    Args:
-        vocab_file (`str`):
-            [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that
-            contains the vocabulary necessary to instantiate a tokenizer.
-        tokenizer_file (`str`):
-            [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that
-            contains everything needed to load the tokenizer.
-
-        clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`):
-            Whether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra
-            spaces.
-
-        bos_token (`str`, *optional*, defaults to `"<s>"`):
-            The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token.
-
-        eos_token (`str`, *optional*, defaults to `"</s>"`):
-            The end of sequence token.
-
-        unk_token (`str`, *optional*, defaults to `"<unk>"`):
-            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
-            token instead. 
- """ - - vocab_files_names = VOCAB_FILES_NAMES - slow_tokenizer_class = Llama2Tokenizer - padding_side = 'left' - model_input_names = ['input_ids', 'attention_mask'] - - def __init__( - self, - vocab_file=None, - tokenizer_file=None, - clean_up_tokenization_spaces=False, - unk_token='', - bos_token='', - eos_token='', - add_bos_token=True, - add_eos_token=False, - **kwargs, - ): - super().__init__( - vocab_file=vocab_file, - tokenizer_file=tokenizer_file, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - unk_token=unk_token, - bos_token=bos_token, - eos_token=eos_token, - **kwargs, - ) - self._add_bos_token = add_bos_token - self._add_eos_token = add_eos_token - self.update_post_processor() - - self.vocab_file = vocab_file - self.can_save_slow_tokenizer = False if not self.vocab_file else True - - def update_post_processor(self): - """ - Updates the underlying post processor with the current `bos_token` and `eos_token`. - """ - bos = self.bos_token - bos_token_id = self.bos_token_id - - eos = self.eos_token - eos_token_id = self.eos_token_id - - single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') * self.add_eos_token}" - pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') * self.add_eos_token}" - - special_tokens = [] - if self.add_bos_token: - special_tokens.append((bos, bos_token_id)) - if self.add_eos_token: - special_tokens.append((eos, eos_token_id)) - self._tokenizer.post_processor = processors.TemplateProcessing( - single=single, pair=pair, special_tokens=special_tokens) - - @property - def add_eos_token(self): - return self._add_eos_token - - @property - def add_bos_token(self): - return self._add_bos_token - - @add_eos_token.setter - def add_eos_token(self, value): - self._add_eos_token = value - self.update_post_processor() - - @add_bos_token.setter - def add_bos_token(self, value): - self._add_bos_token = value - self.update_post_processor() - - def save_vocabulary(self, - save_directory: str, - filename_prefix: Optional[str] = None) -> Tuple[str]: - if not self.can_save_slow_tokenizer: - raise ValueError( - 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow ' - 'tokenizer.') - - if not os.path.isdir(save_directory): - logger.error( - f'Vocabulary path ({save_directory}) should be a directory') - return - out_vocab_file = os.path.join( - save_directory, (filename_prefix + '-' if filename_prefix else '') - + VOCAB_FILES_NAMES['vocab_file']) - - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): - copyfile(self.vocab_file, out_vocab_file) - - return (out_vocab_file, ) - - def _build_conversation_input_ids(self, conversation: 'Conversation'): - """Builds the input ids for a conversation. - This is the format used in the provided examples. System prompts should be manually added at the beginning of - the conversation. If no system prompt is given, the `DEFAULT_SYSTEM_PROMPT` will be used. - ``` - [INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer - [INST] Prompt [/INST] Answer - [INST] Prompt [/INST] - ``` - - If you want to use your own system prompt, make sure to use both `B_SYS` and `E_SYS` use the following: - ```python - >>> from transformers import Conversation - - >>> Conversation( - ... "<>\n Only answer with emojis, and charades\n<>\n\nHow can I build a house in 10 septs?" - ... ) - ``` - Args: - conversation (`Conversation`): - Conversation to build input ids for. - Returns: - `List[int]`: - Input ids for the conversation. 
- """ - dialogue = list(conversation.iter_texts()) - if not all([is_user for is_user, msg in dialogue[::2]]) or not all( - [not is_user for is_user, msg in dialogue[1::2]]): # noqa - raise ValueError( - "The model only supports 'user' and 'assistant' roles, " - 'starting with user and alternating (u/a/u/a/u...)') - - dialog_tokens = [] - if len(conversation.past_user_inputs) > 0: - if not conversation.past_user_inputs[0].startswith( - B_SYS) or E_SYS not in conversation.past_user_inputs[0]: - conversation.past_user_inputs[0] = ( - B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS - + conversation.past_user_inputs[0]) - elif not dialogue[0][1].startswith( - B_SYS) or E_SYS not in dialogue[0][1]: - dialogue[0] = (dialogue[0][0], B_SYS + DEFAULT_SYSTEM_PROMPT - + E_SYS + dialogue[0][1]) - - dialog_tokens += sum( - [[self.bos_token_id] + self.encode( - f'{B_INST} {(prompt[1]).strip()} {E_INST} {(answer[1]).strip()} ', - add_special_tokens=False) + [self.eos_token_id] - for prompt, answer in zip(dialogue[::2], dialogue[1::2])], - [], - ) - if not (dialogue[-1][0]): - raise ValueError( - f"Last message must be from user, got {dialogue[-1]['role']}") - dialog_tokens += [self.bos_token_id] + self.encode( - f'{B_INST} {(dialogue[-1][1]).strip()} {E_INST}', - add_special_tokens=False) - return dialog_tokens diff --git a/modelscope/pipelines/nlp/text_generation_pipeline.py b/modelscope/pipelines/nlp/text_generation_pipeline.py index edfd01cf..dc6bfdb8 100644 --- a/modelscope/pipelines/nlp/text_generation_pipeline.py +++ b/modelscope/pipelines/nlp/text_generation_pipeline.py @@ -19,6 +19,7 @@ from modelscope.utils.chinese_utils import remove_space_between_chinese_chars from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.hub import Config, read_config from modelscope.utils.streaming_output import PipelineStreamingOutputMixin +from modelscope.utils.torch_utils import is_on_same_device __all__ = [ 'TextGenerationPipeline', @@ -242,23 +243,39 @@ class ChatGLM6bV2TextGenerationPipeline(Pipeline): quantization_bit=None, use_bf16=False, **kwargs): - from modelscope.models.nlp import (ChatGLM2Config, - ChatGLM2ForConditionalGeneration, - ChatGLM2Tokenizer) + from modelscope import AutoTokenizer + device: str = kwargs.get('device', 'gpu') if isinstance(model, str): + revision = kwargs.get('revision', None) model_dir = snapshot_download( - model) if not os.path.exists(model) else model - model = ChatGLM2ForConditionalGeneration.from_pretrained(model_dir) - if torch.cuda.is_available(): - model = model.cuda() + model, + revision=revision) if not os.path.exists(model) else model + default_device_map = None + if device.startswith('gpu') or device.startswith('cuda'): + default_device_map = {'': 0} + device_map = kwargs.get('device_map', default_device_map) + default_torch_dtype = None + if use_bf16: + default_torch_dtype = torch.bfloat16 + torch_dtype = kwargs.get('torch_dtype', default_torch_dtype) + model = Model.from_pretrained( + model_dir, + trust_remote_code=True, + device_map=device_map, + torch_dtype=torch_dtype) + else: + if ((device.startswith('gpu') or device.startswith('cuda')) + and is_on_same_device(model)): + model.cuda() + if use_bf16: + model.bfloat16() if quantization_bit is not None: model = model.quantize(quantization_bit) - if use_bf16: - model = model.bfloat16() + self.model = model self.model.eval() - self.tokenizer = ChatGLM2Tokenizer.from_pretrained( - self.model.model_dir) + self.tokenizer = AutoTokenizer.from_pretrained( + self.model.model_dir, trust_remote_code=True) 
super().__init__(model=model, **kwargs) diff --git a/modelscope/utils/automodel_utils.py b/modelscope/utils/automodel_utils.py new file mode 100644 index 00000000..56075618 --- /dev/null +++ b/modelscope/utils/automodel_utils.py @@ -0,0 +1,84 @@ +import os +from typing import Optional + +from modelscope.metainfo import Tasks +from modelscope.utils.ast_utils import INDEX_KEY +from modelscope.utils.import_utils import LazyImportModule + + +def can_load_by_ms(model_dir: str, tast_name: str, model_type: str) -> bool: + if ('MODELS', tast_name, + model_type) in LazyImportModule.AST_INDEX[INDEX_KEY]: + return True + ms_wrapper_path = os.path.join(model_dir, 'ms_wrapper.py') + if os.path.exists(ms_wrapper_path): + return True + return False + + +def _can_load_by_hf_automodel(automodel_class: type, config) -> bool: + automodel_class_name = automodel_class.__name__ + if type(config) in automodel_class._model_mapping.keys(): + return True + if hasattr(config, 'auto_map') and automodel_class_name in config.auto_map: + return True + return False + + +def get_hf_automodel_class(model_dir: str, task_name: str) -> Optional[type]: + from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + AutoModelForTokenClassification, + AutoModelForSequenceClassification) + automodel_mapping = { + Tasks.backbone: AutoModel, + Tasks.chat: AutoModelForCausalLM, + Tasks.text_generation: AutoModelForCausalLM, + Tasks.text_classification: AutoModelForSequenceClassification, + Tasks.token_classification: AutoModelForTokenClassification, + } + automodel_class = automodel_mapping.get(task_name, None) + if automodel_class is None: + return None + config_path = os.path.join(model_dir, 'config.json') + if not os.path.exists(config_path): + return None + try: + try: + config = AutoConfig.from_pretrained( + model_dir, trust_remote_code=True) + except (FileNotFoundError, ValueError): + return None + + if _can_load_by_hf_automodel(automodel_class, config): + return automodel_class + if (automodel_class is AutoModelForCausalLM + and _can_load_by_hf_automodel(AutoModelForSeq2SeqLM, config)): + return AutoModelForSeq2SeqLM + return None + except Exception: + return None + + +def try_to_load_hf_model(model_dir: str, task_name: str, + use_hf: Optional[bool], **kwargs): + automodel_class = get_hf_automodel_class(model_dir, task_name) + + if use_hf and automodel_class is None: + raise ValueError(f'Model import failed. 
You used `use_hf={use_hf}`, ' + 'but the model is not a model of hf') + + model = None + if automodel_class is not None: + # use hf + device_map = kwargs.get('device_map', None) + torch_dtype = kwargs.get('torch_dtype', None) + config = kwargs.get('config', None) + + model = automodel_class.from_pretrained( + model_dir, + device_map=device_map, + torch_dtype=torch_dtype, + config=config, + trust_remote_code=True) + return model diff --git a/modelscope/utils/hf_util.py b/modelscope/utils/hf_util.py index dc3afc40..2a534eb6 100644 --- a/modelscope/utils/hf_util.py +++ b/modelscope/utils/hf_util.py @@ -3,13 +3,21 @@ import os import sys +from transformers import CONFIG_MAPPING from transformers import AutoConfig as AutoConfigHF from transformers import AutoModel as AutoModelHF from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF +from transformers import \ + AutoModelForSequenceClassification as AutoModelForSequenceClassificationHF +from transformers import \ + AutoModelForTokenClassification as AutoModelForTokenClassificationHF from transformers import AutoTokenizer as AutoTokenizerHF from transformers import GenerationConfig as GenerationConfigHF -from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers import (PretrainedConfig, PreTrainedModel, + PreTrainedTokenizerBase) +from transformers.models.auto.tokenization_auto import ( + TOKENIZER_MAPPING_NAMES, get_tokenizer_config) from modelscope import snapshot_download from modelscope.utils.constant import Invoke @@ -70,6 +78,58 @@ patch_tokenizer_base() patch_model_base() +def check_hf_code(model_dir: str, auto_class: type, + trust_remote_code: bool) -> None: + config_path = os.path.join(model_dir, 'config.json') + if not os.path.exists(config_path): + raise FileNotFoundError(f'{config_path} is not found') + config_dict = PretrainedConfig.get_config_dict(config_path)[0] + auto_class_name = auto_class.__name__ + # load from repo + if trust_remote_code: + has_remote_code = False + if auto_class is AutoTokenizerHF: + tokenizer_config_dict = get_tokenizer_config(model_dir) + auto_map = tokenizer_config_dict.get('auto_map', None) + if auto_map is not None: + module_name = auto_map.get(auto_class_name, None) + if module_name is not None: + module_name = module_name[0] + has_remote_code = True + else: + auto_map = config_dict.get('auto_map', None) + if auto_map is not None: + module_name = auto_map.get(auto_class_name, None) + has_remote_code = module_name is not None + + if has_remote_code: + module_path = os.path.join(model_dir, + module_name.split('.')[0] + '.py') + if not os.path.exists(module_path): + raise FileNotFoundError(f'{module_path} is not found') + return + + # trust_remote_code is False or has_remote_code is False + model_type = config_dict.get('model_type', None) + if model_type is None: + raise ValueError(f'`model_type` key is not found in {config_path}') + + if auto_class is AutoConfigHF: + if model_type not in CONFIG_MAPPING: + raise ValueError(f'{model_type} not found in HF CONFIG_MAPPING') + elif auto_class is AutoTokenizerHF: + if model_type not in TOKENIZER_MAPPING_NAMES: + raise ValueError( + f'{model_type} not found in HF TOKENIZER_MAPPING_NAMES') + else: + mapping_names = [ + m.model_type for m in auto_class._model_mapping.keys() + ] + if model_type not in mapping_names: + raise ValueError( + f'{model_type} not found in HF auto_class._model_mapping') + + def get_wrapped_class(module_class, 
ignore_file_pattern=[], **kwargs): """Get a custom wrapper class for auto classes to download the models from the ModelScope hub Args: @@ -79,12 +139,15 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs): Returns: The wrapper """ + default_ignore_file_pattern = ignore_file_pattern class ClassWrapper(module_class): @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + ignore_file_pattern = kwargs.pop('ignore_file_pattern', + default_ignore_file_pattern) if not os.path.exists(pretrained_model_name_or_path): revision = kwargs.pop('revision', None) model_dir = snapshot_download( @@ -95,11 +158,18 @@ def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs): else: model_dir = pretrained_model_name_or_path - model = module_class.from_pretrained(model_dir, *model_args, - **kwargs) - model.model_dir = model_dir - return model + if module_class is not GenerationConfigHF: + trust_remote_code = kwargs.get('trust_remote_code', False) + check_hf_code(model_dir, module_class, trust_remote_code) + module_obj = module_class.from_pretrained(model_dir, *model_args, + **kwargs) + if module_class.__name__.startswith('AutoModel'): + module_obj.model_dir = model_dir + return module_obj + + ClassWrapper.__name__ = module_class.__name__ + ClassWrapper.__qualname__ = module_class.__qualname__ return ClassWrapper @@ -109,6 +179,12 @@ AutoModelForCausalLM = get_wrapped_class( AutoModelForCausalLMHF, ignore_file_pattern=[r'\w+\.safetensors']) AutoModelForSeq2SeqLM = get_wrapped_class( AutoModelForSeq2SeqLMHF, ignore_file_pattern=[r'\w+\.safetensors']) +AutoModelForSequenceClassification = get_wrapped_class( + AutoModelForSequenceClassificationHF, + ignore_file_pattern=[r'\w+\.safetensors']) +AutoModelForTokenClassification = get_wrapped_class( + AutoModelForTokenClassificationHF, + ignore_file_pattern=[r'\w+\.safetensors']) AutoTokenizer = get_wrapped_class( AutoTokenizerHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors']) diff --git a/tests/models/test_model_base.py b/tests/models/test_model_base.py index 9d353ec5..9569755a 100644 --- a/tests/models/test_model_base.py +++ b/tests/models/test_model_base.py @@ -5,12 +5,10 @@ import shutil import tempfile import unittest -import numpy as np import torch -import torch.nn as nn -import torch.nn.functional as F from modelscope.models.base import Model +from modelscope.utils.test_utils import test_level class BaseTest(unittest.TestCase): @@ -25,15 +23,31 @@ class BaseTest(unittest.TestCase): shutil.rmtree(self.tmp_dir) super().tearDown() - def test_from_pretrained(self): + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_from_pretrained_baichuan(self): model = Model.from_pretrained( - 'baichuan-inc/baichuan-7B', revision='v1.0.5') + 'baichuan-inc/Baichuan-13B-Chat', + revision='v1.0.8', + torch_dtype=torch.float16, + device='gpu') + print(model.__class__.__name__) self.assertIsNotNone(model) - def test_from_pretrained_hf(self): + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_from_pretrained_chatglm2(self): + model = Model.from_pretrained( + 'ZhipuAI/chatglm2-6b', + revision='v1.0.7', + torch_dtype=torch.float16, + device='gpu') + print(model.__class__.__name__) + self.assertIsNotNone(model) + + def test_from_pretrained_ms(self): model = Model.from_pretrained( 'damo/nlp_structbert_sentence-similarity_chinese-tiny', - use_hf=True) + device='gpu') + print(model.__class__.__name__) self.assertIsNotNone(model)
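A minimal usage sketch of the loading path exercised by the updated tests, assuming that `Model.from_pretrained` routes HF-style repos through the new `try_to_load_hf_model` helper (only the helper itself appears in this diff) and reusing the model id and revision from the ChatGLM2 test; the `AutoTokenizer` import mirrors the pipeline change above. Illustrative only, not part of the patch:

```python
import torch

from modelscope import AutoTokenizer
from modelscope.models.base import Model

# Unified entry point: for an HF-style repo the loader is expected to fall back to
# the transformers AutoModel* classes via try_to_load_hf_model, forwarding
# device_map / torch_dtype as in the pipeline change above.
model = Model.from_pretrained(
    'ZhipuAI/chatglm2-6b',   # id and revision taken from the updated test
    revision='v1.0.7',
    torch_dtype=torch.float16,
    device='gpu')

# The wrapped AutoTokenizer downloads from the ModelScope hub, validates the repo
# with check_hf_code, and then defers to transformers.AutoTokenizer.
tokenizer = AutoTokenizer.from_pretrained(
    'ZhipuAI/chatglm2-6b', revision='v1.0.7', trust_remote_code=True)
```

Note that when `use_hf=True` is requested and the directory is not an HF-style repo, `try_to_load_hf_model` raises a `ValueError` rather than silently falling back.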