mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
add qwen 7b base and chat
添加QWen 7b base模型和chat模型及相关pipelines Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13482235 * add qwen 7b base and chat * fix logger * update examples, lint test * add unittest for qwen base and chat * rename qwen to qwen-7b * resolve imports and add a registry to text-generation * reset load model from pretrained * fix precheck * skip qwen test case now * remove strange file
This commit is contained in:
committed by
wenmeng.zwm
parent
54e65b034d
commit
fc0a0bcf60
@@ -1,69 +0,0 @@
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from torch import device as Device
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from modelscope import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def _format_device(device: Union[List[int], str]) -> Tuple[List[int], str]:
|
||||
if isinstance(device, list):
|
||||
device_ids = device
|
||||
device_str = ','.join([str(d) for d in device])
|
||||
else:
|
||||
device_ids = [int(d) for d in device.split(',') if d != '-1']
|
||||
device_str = device
|
||||
device_str = device_str.replace(' ', '')
|
||||
return device_ids, device_str
|
||||
|
||||
|
||||
def select_device(device: Union[List[int], str]) -> Device:
|
||||
"""Call this function before cuda is initialized.
|
||||
device: e.g. []: 'cpu', [0], [0, 1, 2]
|
||||
e.g. '-1': 'cpu', '0', '0,1,2'
|
||||
"""
|
||||
if torch.cuda.is_initialized():
|
||||
logger.warning('CUDA has been initialized! Device selection fails!')
|
||||
return torch.device('cuda:0')
|
||||
|
||||
device_ids, device_str = _format_device(device)
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = device_str
|
||||
log_s = 'Using device: '
|
||||
if len(device_ids) == 0:
|
||||
master_device: str = 'cpu'
|
||||
log_s += 'cpu'
|
||||
else:
|
||||
assert torch.cuda.is_available(
|
||||
) and torch.cuda.device_count() >= len(device_ids)
|
||||
master_device = 'cuda:0'
|
||||
log_s += f'cuda:{device_str}'
|
||||
logger.info(log_s)
|
||||
return torch.device(master_device)
|
||||
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
|
||||
def parse_args(class_type: Type[_T],
|
||||
argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]:
|
||||
parser = HfArgumentParser([class_type])
|
||||
args, remaining_args = parser.parse_args_into_dataclasses(
|
||||
argv, return_remaining_strings=True)
|
||||
logger.info(f'args: {args}')
|
||||
return args, remaining_args
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceArguments:
|
||||
device: str = '0' # e.g. '-1'; '0'; '0,1'
|
||||
|
||||
|
||||
def parse_device(argv: Optional[List[str]] = None) -> List[str]:
|
||||
args, remaining_args = parse_args(DeviceArguments, argv)
|
||||
select_device(args.device)
|
||||
return remaining_args
|
||||
@@ -1,18 +1,26 @@
|
||||
# ### Setting up experimental environment.
|
||||
import os
|
||||
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import List, Optional
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Avoid cuda initialization caused by library import (e.g. peft, accelerate)
|
||||
from _parser import *
|
||||
# argv = parse_device(['--device', '1'])
|
||||
argv = parse_device()
|
||||
import torch
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
from utils import (DATASET_MAPPER, DEFAULT_PROMPT, MODEL_MAPPER, get_dataset,
|
||||
get_model_tokenizer, inference, parse_args, process_dataset,
|
||||
tokenize_function)
|
||||
|
||||
from utils import *
|
||||
from modelscope import get_logger
|
||||
from modelscope.swift import LoRAConfig, Swift
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferArguments:
|
||||
model_type: str = field(
|
||||
default='baichuan-7b', metadata={'choices': list(MODEL_MAPPER.keys())})
|
||||
default='qwen-7b', metadata={'choices': list(MODEL_MAPPER.keys())})
|
||||
sft_type: str = field(
|
||||
default='lora', metadata={'choices': ['lora', 'full']})
|
||||
ckpt_path: str = '/path/to/your/iter_xxx.pth'
|
||||
@@ -114,7 +122,7 @@ def llm_infer(args: InferArguments) -> None:
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args, remaining_argv = parse_args(InferArguments, argv)
|
||||
args, remaining_argv = parse_args(InferArguments)
|
||||
if len(remaining_argv) > 0:
|
||||
if args.ignore_args_error:
|
||||
logger.warning(f'remaining_argv: {remaining_argv}')
|
||||
|
||||
@@ -1,35 +1,52 @@
|
||||
# ### Setting up experimental environment.
|
||||
"""
|
||||
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
|
||||
pip install sentencepiece charset_normalizer cpm_kernels tiktoken -U
|
||||
pip install matplotlib scikit-learn -U
|
||||
pip install transformers datasets -U
|
||||
pip install tqdm tensorboard torchmetrics -U
|
||||
pip install accelerate transformers_stream_generator -U
|
||||
|
||||
# Install the latest version of modelscope from source
|
||||
git clone https://github.com/modelscope/modelscope.git
|
||||
cd modelscope
|
||||
pip install -r requirements.txt
|
||||
pip install .
|
||||
|
||||
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
|
||||
pip install numpy pandas -U # Resolve torchmetrics dependencies and update numpy
|
||||
pip install matplotlib scikit-learn -U
|
||||
pip install transformers datasets -U
|
||||
pip install tqdm tensorboard torchmetrics sentencepiece charset_normalizer -U
|
||||
pip install accelerate transformers_stream_generator -U
|
||||
"""
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Avoid cuda initialization caused by library import (e.g. peft, accelerate)
|
||||
from _parser import *
|
||||
# argv = parse_device(['--device', '1'])
|
||||
argv = parse_device()
|
||||
import os
|
||||
# os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import List, Optional
|
||||
|
||||
from utils import *
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from utils import (DATASET_MAPPER, DEFAULT_PROMPT, MODEL_MAPPER,
|
||||
data_collate_fn, get_dataset, get_model_tokenizer,
|
||||
get_T_max, get_work_dir, parse_args, plot_images,
|
||||
print_example, print_model_info, process_dataset,
|
||||
seed_everything, show_freeze_layers, stat_dataset,
|
||||
tokenize_function)
|
||||
|
||||
from modelscope import get_logger
|
||||
from modelscope.swift import LoRAConfig, Swift
|
||||
from modelscope.trainers import EpochBasedTrainer
|
||||
from modelscope.utils.config import Config
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SftArguments:
|
||||
seed: int = 42
|
||||
model_type: str = field(
|
||||
default='baichuan-7b', metadata={'choices': list(MODEL_MAPPER.keys())})
|
||||
default='qwen-7b', metadata={'choices': list(MODEL_MAPPER.keys())})
|
||||
# baichuan-7b: 'lora': 16G; 'full': 80G
|
||||
sft_type: str = field(
|
||||
default='lora', metadata={'choices': ['lora', 'full']})
|
||||
output_dir: Optional[str] = None
|
||||
ignore_args_error: bool = False # True: notebook compatibility
|
||||
|
||||
dataset: str = field(
|
||||
@@ -82,6 +99,10 @@ class SftArguments:
|
||||
else:
|
||||
raise ValueError(f'sft_type: {self.sft_type}')
|
||||
|
||||
if self.output_dir is None:
|
||||
self.output_dir = 'runs'
|
||||
self.output_dir = os.path.join(self.output_dir, self.model_type)
|
||||
|
||||
if self.lora_target_modules is None:
|
||||
self.lora_target_modules = MODEL_MAPPER[self.model_type]['lora_TM']
|
||||
|
||||
@@ -145,7 +166,7 @@ def llm_sft(args: SftArguments) -> None:
|
||||
|
||||
T_max = get_T_max(
|
||||
len(train_dataset), args.batch_size, args.max_epochs, True)
|
||||
work_dir = get_work_dir(f'runs/{args.model_type}')
|
||||
work_dir = get_work_dir(args.output_dir)
|
||||
config = Config({
|
||||
'train': {
|
||||
'dataloader': {
|
||||
@@ -257,7 +278,7 @@ def llm_sft(args: SftArguments) -> None:
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args, remaining_argv = parse_args(SftArguments, argv)
|
||||
args, remaining_argv = parse_args(SftArguments)
|
||||
if len(remaining_argv) > 0:
|
||||
if args.ignore_args_error:
|
||||
logger.warning(f'remaining_argv: {remaining_argv}')
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
#!/bin/bash
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1 \
|
||||
python llm_infer.py \
|
||||
--device 0,1 \
|
||||
--model_type openbuddy-llama2-13b \
|
||||
--ckpt_path "runs/openbuddy-llama2-13b/vx_xxx/output_best/pytorch_model.bin" \
|
||||
--model_type qwen-7b \
|
||||
--ckpt_path "runs/qwen-7b/vx_xxx/output_best/pytorch_model.bin" \
|
||||
--eval_human true
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
#!/bin/bash
|
||||
|
||||
DATE=$(date +"%Y%m%d-%H%M%S")
|
||||
nohup python llm_sft.py \
|
||||
--device 0,1 \
|
||||
--model_type openbuddy-llama2-13b \
|
||||
CUDA_VISIBLE_DEVICES=0,1 \
|
||||
python llm_sft.py \
|
||||
--model_type qwen-7b \
|
||||
--output_dir runs \
|
||||
--dataset alpaca-en,alpaca-zh \
|
||||
--dataset_sample 20000 \
|
||||
&> train_$DATE.out &
|
||||
--dataset_sample 20000
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
from .dataset import *
|
||||
from .models import *
|
||||
from .utils import *
|
||||
from .dataset import DATASET_MAPPER, get_dataset, process_dataset
|
||||
from .models import MODEL_MAPPER, get_model_tokenizer
|
||||
from .utils import (DEFAULT_PROMPT, MyMetric, data_collate_fn, get_T_max,
|
||||
get_work_dir, inference, parse_args, plot_images,
|
||||
print_example, print_model_info, read_tensorboard_file,
|
||||
seed_everything, show_freeze_layers, stat_dataset,
|
||||
tensorboard_smoothing, tokenize_function)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict, NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
@@ -8,6 +7,7 @@ from torch import dtype as Dtype
|
||||
from modelscope import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, Model,
|
||||
get_logger, read_config, snapshot_download)
|
||||
from modelscope.models.nlp.chatglm2 import ChatGLM2Config, ChatGLM2Tokenizer
|
||||
from modelscope.models.nlp.qwen import QWenConfig, QWenTokenizer
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@@ -63,11 +63,32 @@ def get_model_tokenizer_chatglm2(model_dir: str,
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def get_model_tokenizer_qwen(model_dir: str,
|
||||
torch_dtype: Dtype,
|
||||
load_model: bool = True):
|
||||
config = read_config(model_dir)
|
||||
logger.info(config)
|
||||
model_config = QWenConfig.from_pretrained(model_dir)
|
||||
model_config.torch_dtype = torch_dtype
|
||||
logger.info(model_config)
|
||||
tokenizer = QWenTokenizer.from_pretrained(model_dir)
|
||||
model = None
|
||||
if load_model:
|
||||
model = Model.from_pretrained(
|
||||
model_dir,
|
||||
cfg_dict=config,
|
||||
config=model_config,
|
||||
device_map='auto',
|
||||
torch_dtype=torch_dtype)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
class LoRATM(NamedTuple):
|
||||
# default lora target modules
|
||||
baichuan = ['W_pack']
|
||||
chatglm2 = ['query_key_value']
|
||||
llama2 = ['q_proj', 'k_proj', 'v_proj']
|
||||
qwen = ['c_attn']
|
||||
|
||||
|
||||
# Reference: 'https://modelscope.cn/models/{model_id}/summary'
|
||||
@@ -105,7 +126,15 @@ MODEL_MAPPER = {
|
||||
},
|
||||
'openbuddy-llama2-13b': {
|
||||
'model_id': 'OpenBuddy/openbuddy-llama2-13b-v8.1-fp16',
|
||||
'revision': 'v1.0.0',
|
||||
'lora_TM': LoRATM.llama2
|
||||
},
|
||||
'qwen-7b': {
|
||||
'model_id': 'QWen/qwen-7b',
|
||||
'revision': 'v1.0.0',
|
||||
'get_function': get_model_tokenizer_qwen,
|
||||
'torch_dtype': torch.bfloat16,
|
||||
'lora_TM': LoRATM.qwen,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,34 +3,24 @@ import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import Any, Callable, Counter, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Counter, Dict, List, Optional, Tuple, Type, TypeVar
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset as HfDataset
|
||||
from numpy import ndarray
|
||||
from tensorboard.backend.event_processing.event_accumulator import \
|
||||
EventAccumulator
|
||||
from torch import Tensor
|
||||
from torch import device as Device
|
||||
from torch import dtype as Dtype
|
||||
from torch.nn import Module
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torchmetrics import Accuracy, MeanMetric
|
||||
from tqdm import tqdm
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
from transformers import GenerationConfig, HfArgumentParser, TextStreamer
|
||||
|
||||
from modelscope import get_logger
|
||||
from modelscope.metrics.base import Metric
|
||||
from modelscope.metrics.builder import METRICS
|
||||
from modelscope.swift import LoRAConfig, Swift
|
||||
from modelscope.trainers import EpochBasedTrainer
|
||||
from modelscope.utils.config import Config, ConfigDict
|
||||
from modelscope.utils.registry import default_group
|
||||
|
||||
COLOR, COLOR_S = '#FFE2D9', '#FF7043'
|
||||
@@ -318,3 +308,15 @@ def inference(input_ids: List[int],
|
||||
generation_config=generation_config)
|
||||
output_text = tokenizer.decode(generate_ids[0])
|
||||
return output_text
|
||||
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
|
||||
def parse_args(class_type: Type[_T],
|
||||
argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]:
|
||||
parser = HfArgumentParser([class_type])
|
||||
args, remaining_args = parser.parse_args_into_dataclasses(
|
||||
argv, return_remaining_strings=True)
|
||||
logger.info(f'args: {args}')
|
||||
return args, remaining_args
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -173,6 +173,7 @@ class Models(object):
|
||||
llama2 = 'llama2'
|
||||
chatglm_6b = 'chatglm6b'
|
||||
chatglm2_6b = 'chatglm2-6b'
|
||||
qwen_7b = 'qwen-7b'
|
||||
|
||||
# audio models
|
||||
sambert_hifigan = 'sambert-hifigan'
|
||||
|
||||
@@ -77,6 +77,7 @@ if TYPE_CHECKING:
|
||||
from .xlm_roberta import XLMRobertaConfig, XLMRobertaModel
|
||||
from .llama import LlamaForTextGeneration, LlamaConfig, LlamaModel, LlamaTokenizer, LlamaTokenizerFast
|
||||
from .llama2 import Llama2ForTextGeneration, Llama2Config, Llama2Model, Llama2Tokenizer, Llama2TokenizerFast
|
||||
from .qwen import QWenForTextGeneration, QWenConfig, QWenModel, QWenTokenizer
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
@@ -177,6 +178,8 @@ else:
|
||||
'Llama2ForTextGeneration', 'Llama2Config', 'Llama2Model',
|
||||
'Llama2Tokenizer', 'Llama2TokenizerFast'
|
||||
],
|
||||
'qwen':
|
||||
['QWenForTextGeneration', 'QWenConfig', 'QWenModel', 'QWenTokenizer'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
27
modelscope/models/nlp/qwen/__init__.py
Normal file
27
modelscope/models/nlp/qwen/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .configuration import QWenConfig
|
||||
from .text_generation import QWenForTextGeneration
|
||||
from .backbone import QWenModel
|
||||
from .tokenization import QWenTokenizer
|
||||
else:
|
||||
_import_structure = {
|
||||
'configuration': ['QWenConfig'],
|
||||
'backbone': ['QWenModel'],
|
||||
'tokenization': ['QWenTokenizer'],
|
||||
'text_generation': ['QWenForTextGeneration'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
776
modelscope/models/nlp/qwen/backbone.py
Normal file
776
modelscope/models/nlp/qwen/backbone.py
Normal file
@@ -0,0 +1,776 @@
|
||||
# Copyright 2023 Alibaba Group. All rights reserved.
|
||||
"""Model classes for QWen."""
|
||||
|
||||
import importlib
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
from torch.cuda.amp import autocast
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.modeling_outputs import (BaseModelOutputWithPast,
|
||||
CausalLMOutputWithPast)
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.trainer_utils import set_seed
|
||||
from transformers.utils import (ModelOutput, add_code_sample_docstrings,
|
||||
add_start_docstrings,
|
||||
add_start_docstrings_to_model_forward, logging)
|
||||
from transformers.utils.model_parallel_utils import (assert_device_map,
|
||||
get_device_map)
|
||||
|
||||
from modelscope import Model, TorchModel
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from ... import MODELS
|
||||
from .configuration import QWenConfig
|
||||
|
||||
try:
|
||||
from einops import rearrange
|
||||
except ImportError:
|
||||
rearrange = None
|
||||
try:
|
||||
from flash_attn.layers.rotary import apply_rotary_emb_func
|
||||
from einops import rearrange
|
||||
use_flash_rotary = True
|
||||
print('use flash_attn rotary')
|
||||
except ImportError:
|
||||
use_flash_rotary = False
|
||||
print('import flash_attn rotary fail')
|
||||
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import rms_norm
|
||||
print('use flash_attn rms_norm')
|
||||
except ImportError:
|
||||
rms_norm = None
|
||||
print('import flash_attn rms_norm fail')
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
_CHECKPOINT_FOR_DOC = 'qwen-7b'
|
||||
_CONFIG_FOR_DOC = 'QWenConfig'
|
||||
|
||||
QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ['qwen-7b']
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
|
||||
except ImportError:
|
||||
flash_attn_unpadded_func = None
|
||||
|
||||
|
||||
class FlashSelfAttention(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
causal=False,
|
||||
softmax_scale=None,
|
||||
attention_dropout=0.0,
|
||||
device=None,
|
||||
dtype=None):
|
||||
super().__init__()
|
||||
assert flash_attn_unpadded_func is not None, (
|
||||
'Please install FlashAttention first, '
|
||||
'e.g., with pip install flash-attn')
|
||||
assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
|
||||
self.causal = causal
|
||||
self.softmax_scale = softmax_scale
|
||||
self.dropout_p = attention_dropout
|
||||
|
||||
def forward(self, q, k, v):
|
||||
assert all(
|
||||
(i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
|
||||
assert all((i.is_cuda for i in (q, k, v)))
|
||||
batch_size, seqlen_q = q.shape[0], q.shape[1]
|
||||
seqlen_k = k.shape[1]
|
||||
q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
|
||||
cu_seqlens_q = torch.arange(
|
||||
0, (batch_size + 1) * seqlen_q,
|
||||
step=seqlen_q,
|
||||
dtype=torch.int32,
|
||||
device=q.device)
|
||||
|
||||
if self.training:
|
||||
assert seqlen_k == seqlen_q
|
||||
|
||||
is_causal = self.causal
|
||||
cu_seqlens_k = cu_seqlens_q
|
||||
else:
|
||||
is_causal = seqlen_q == seqlen_k
|
||||
cu_seqlens_k = torch.arange(
|
||||
0, (batch_size + 1) * seqlen_k,
|
||||
step=seqlen_k,
|
||||
dtype=torch.int32,
|
||||
device=q.device)
|
||||
self.dropout_p = 0
|
||||
output = flash_attn_unpadded_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
self.dropout_p,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=is_causal)
|
||||
|
||||
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
|
||||
return output
|
||||
|
||||
|
||||
class QWenAttention(nn.Module):
|
||||
|
||||
def __init__(self, config, layer_number=None):
|
||||
super().__init__()
|
||||
|
||||
max_positions = config.max_position_embeddings
|
||||
self.register_buffer(
|
||||
'bias',
|
||||
torch.tril(
|
||||
torch.ones((max_positions, max_positions),
|
||||
dtype=torch.bool)).view(1, 1, max_positions,
|
||||
max_positions),
|
||||
persistent=False,
|
||||
)
|
||||
self.register_buffer(
|
||||
'masked_bias', torch.tensor(-1e4), persistent=False)
|
||||
self.layer_number = max(1, layer_number)
|
||||
self.params_dtype = config.params_dtype
|
||||
self.seq_length = config.seq_length
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
self.split_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
|
||||
self.use_flash_attn = config.use_flash_attn
|
||||
self.scale_attn_weights = True
|
||||
|
||||
self.layer_idx = None
|
||||
|
||||
self.projection_size = config.kv_channels * config.num_attention_heads
|
||||
|
||||
assert self.projection_size % config.num_attention_heads == 0
|
||||
self.hidden_size_per_attention_head = \
|
||||
self.projection_size // config.num_attention_heads
|
||||
|
||||
self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
|
||||
|
||||
self.c_proj = nn.Linear(
|
||||
config.hidden_size, self.projection_size, bias=not config.no_bias)
|
||||
|
||||
if self.use_flash_attn:
|
||||
self.core_attention_flash = FlashSelfAttention(
|
||||
causal=True, attention_dropout=config.attn_pdrop)
|
||||
|
||||
self.bf16 = config.bf16
|
||||
|
||||
if config.rotary_pct == 1.0:
|
||||
self.rotary_ndims = None
|
||||
else:
|
||||
assert config.rotary_pct < 1
|
||||
self.rotary_ndims = int(self.hidden_size_per_attention_head
|
||||
* config.rotary_pct)
|
||||
dim = (
|
||||
self.rotary_ndims if self.rotary_ndims is not None else
|
||||
self.hidden_size_per_attention_head)
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
dim, base=config.rotary_emb_base, ntk_alpha=config.ntk_alpha)
|
||||
|
||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||
|
||||
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
||||
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
||||
|
||||
if self.scale_attn_weights:
|
||||
attn_weights = attn_weights / torch.full(
|
||||
[],
|
||||
value.size(-1)**0.5,
|
||||
dtype=attn_weights.dtype,
|
||||
device=attn_weights.device)
|
||||
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length
|
||||
- query_length:key_length, :key_length]
|
||||
mask_value = torch.finfo(attn_weights.dtype).min
|
||||
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
|
||||
attn_weights.device)
|
||||
attn_weights = torch.where(causal_mask,
|
||||
attn_weights.to(attn_weights.dtype),
|
||||
mask_value)
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
attn_weights = attn_weights.type(value.dtype)
|
||||
attn_weights = self.attn_dropout(attn_weights)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
def _upcast_and_reordered_attn(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attention_mask=None,
|
||||
head_mask=None):
|
||||
bsz, num_heads, q_seq_len, dk = query.size()
|
||||
_, _, k_seq_len, _ = key.size()
|
||||
|
||||
attn_weights = torch.empty(
|
||||
bsz * num_heads,
|
||||
q_seq_len,
|
||||
k_seq_len,
|
||||
dtype=torch.float32,
|
||||
device=query.device)
|
||||
|
||||
scale_factor = 1.0
|
||||
if self.scale_attn_weights:
|
||||
scale_factor /= float(value.size(-1))**0.5
|
||||
|
||||
with autocast(enabled=False):
|
||||
q, k = query.reshape(-1, q_seq_len,
|
||||
dk), key.transpose(-1, -2).reshape(
|
||||
-1, dk, k_seq_len)
|
||||
attn_weights = torch.baddbmm(
|
||||
attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
||||
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len,
|
||||
k_seq_len)
|
||||
|
||||
query_length, key_length = query.size(-2), key.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length
|
||||
- query_length:key_length, :key_length]
|
||||
mask_value = torch.finfo(attn_weights.dtype).min
|
||||
mask_value = torch.tensor(
|
||||
mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
||||
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||
|
||||
if attn_weights.dtype != torch.float32:
|
||||
raise RuntimeError(
|
||||
'Error with upcasting, attn_weights does not have dtype torch.float32'
|
||||
)
|
||||
attn_weights = attn_weights.type(value.dtype)
|
||||
attn_weights = self.attn_dropout(attn_weights)
|
||||
|
||||
if head_mask is not None:
|
||||
attn_weights = attn_weights * head_mask
|
||||
|
||||
attn_output = torch.matmul(attn_weights, value)
|
||||
|
||||
return attn_output, attn_weights
|
||||
|
||||
def _split_heads(self, tensor, num_heads, attn_head_size):
|
||||
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
||||
tensor = tensor.view(new_shape)
|
||||
return tensor
|
||||
|
||||
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
||||
tensor = tensor.contiguous()
|
||||
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size, )
|
||||
return tensor.view(new_shape)
|
||||
|
||||
def forward(self,
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False):
|
||||
|
||||
mixed_x_layer = self.c_attn(hidden_states)
|
||||
query, key, value = mixed_x_layer.split(self.split_size, dim=2)
|
||||
|
||||
query = self._split_heads(query, self.num_heads, self.head_dim)
|
||||
key = self._split_heads(key, self.num_heads, self.head_dim)
|
||||
value = self._split_heads(value, self.num_heads, self.head_dim)
|
||||
|
||||
kv_seq_len = hidden_states.size()[1]
|
||||
if layer_past:
|
||||
kv_seq_len += layer_past[0].shape[1]
|
||||
rotary_pos_emb = self.rotary_emb(kv_seq_len).to(hidden_states.device)
|
||||
|
||||
if rotary_pos_emb is not None:
|
||||
if isinstance(rotary_pos_emb, tuple):
|
||||
rotary_pos_emb = rotary_pos_emb
|
||||
else:
|
||||
rotary_pos_emb = ((rotary_pos_emb, ) * 2)
|
||||
|
||||
if rotary_pos_emb is not None:
|
||||
q_pos_emb, k_pos_emb = rotary_pos_emb
|
||||
cur_len = query.shape[1]
|
||||
q_pos_emb = q_pos_emb[:, -cur_len:, :, :]
|
||||
k_pos_emb = k_pos_emb[:, -cur_len:, :, :]
|
||||
query = apply_rotary_pos_emb(query, q_pos_emb)
|
||||
key = apply_rotary_pos_emb(key, k_pos_emb)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past[0], layer_past[1]
|
||||
key = torch.cat((past_key, key), dim=1)
|
||||
value = torch.cat((past_value, value), dim=1)
|
||||
|
||||
if use_cache:
|
||||
present = (key, value)
|
||||
else:
|
||||
present = None
|
||||
|
||||
if self.use_flash_attn:
|
||||
q, k, v = query, key, value
|
||||
context_layer = self.core_attention_flash(q, k, v)
|
||||
|
||||
context_layer = rearrange(context_layer,
|
||||
'b s h d -> b s (h d)').contiguous()
|
||||
else:
|
||||
query = query.permute(0, 2, 1, 3)
|
||||
key = key.permute(0, 2, 1, 3)
|
||||
value = value.permute(0, 2, 1, 3)
|
||||
attn_output, attn_weight = self._attn(query, key, value,
|
||||
attention_mask, head_mask)
|
||||
context_layer = self._merge_heads(attn_output, self.num_heads,
|
||||
self.head_dim)
|
||||
|
||||
attn_output = self.c_proj(context_layer)
|
||||
outputs = (attn_output, present)
|
||||
if output_attentions:
|
||||
if self.use_flash_attn:
|
||||
raise ValueError(
|
||||
'Cannot output attentions while using flash-attn')
|
||||
else:
|
||||
outputs += (attn_weight, )
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class QWenMLP(nn.Module):
|
||||
|
||||
def __init__(self, intermediate_size, config):
|
||||
super().__init__()
|
||||
|
||||
self.w1 = nn.Linear(
|
||||
config.hidden_size,
|
||||
config.ffn_hidden_size // 2,
|
||||
bias=not config.no_bias)
|
||||
self.w2 = nn.Linear(
|
||||
config.hidden_size,
|
||||
config.ffn_hidden_size // 2,
|
||||
bias=not config.no_bias)
|
||||
|
||||
ff_dim_in = config.ffn_hidden_size // 2
|
||||
self.c_proj = nn.Linear(
|
||||
ff_dim_in, config.hidden_size, bias=not config.no_bias)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
|
||||
a1 = self.w1(hidden_states)
|
||||
a2 = self.w2(hidden_states)
|
||||
intermediate_parallel = a1 * F.silu(a2)
|
||||
|
||||
output = self.c_proj(intermediate_parallel)
|
||||
return output
|
||||
|
||||
|
||||
class QWenBlock(nn.Module):
|
||||
|
||||
def __init__(self, config, layer_idx=None, num_expert=1):
|
||||
super().__init__()
|
||||
self.num_expert = num_expert
|
||||
self.layer_number = layer_idx
|
||||
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
||||
hidden_size = config.hidden_size
|
||||
self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
|
||||
self.bf16 = config.bf16
|
||||
|
||||
if config.n_inner is not None:
|
||||
inner_dim = config.n_inner
|
||||
else:
|
||||
ff_mult = 4 * 2 / 3
|
||||
inner_dim = ff_mult * hidden_size
|
||||
|
||||
self.ln_1 = RMSNorm(
|
||||
hidden_size,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
self.attn = QWenAttention(config, layer_number=layer_idx)
|
||||
self.ln_2 = RMSNorm(
|
||||
hidden_size,
|
||||
eps=config.layer_norm_epsilon,
|
||||
)
|
||||
|
||||
self.mlp = QWenMLP(inner_dim, config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = False,
|
||||
output_attentions: Optional[bool] = False,
|
||||
):
|
||||
layernorm_output = self.ln_1(hidden_states)
|
||||
|
||||
attn_outputs = self.attn(
|
||||
layernorm_output,
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions)
|
||||
attn_output = attn_outputs[0]
|
||||
|
||||
outputs = attn_outputs[1:]
|
||||
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = hidden_states
|
||||
layernorm_input = attn_output + residual
|
||||
|
||||
layernorm_output = self.ln_2(layernorm_input)
|
||||
|
||||
if self.apply_residual_connection_post_layernorm:
|
||||
residual = layernorm_output
|
||||
else:
|
||||
residual = layernorm_input
|
||||
|
||||
mlp_output = self.mlp(layernorm_output)
|
||||
hidden_states = residual + mlp_output
|
||||
|
||||
if use_cache:
|
||||
outputs = (hidden_states, ) + outputs
|
||||
else:
|
||||
outputs = (hidden_states, ) + outputs[1:]
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class QWenPreTrainedModel(TorchModel, PreTrainedModel):
    """Base class bridging the ModelScope ``TorchModel`` interface with the
    HuggingFace ``PreTrainedModel`` machinery (weight init, gradient
    checkpointing, ``from_pretrained`` loading) for the QWen architecture.
    """

    # HuggingFace plumbing for loading/sharding.
    config_class = QWenConfig
    base_model_prefix = 'transformer'
    is_parallelizable = False
    supports_gradient_checkpointing = True
    _no_split_modules = ['QWenBlock']

    def __init__(self, config, **kwargs):
        # Initialize the ModelScope side (TorchModel takes a model dir/name).
        super().__init__(config.name_or_path, **kwargs)
        # Jump past `Model` in the MRO so PreTrainedModel.__init__ runs with
        # the config.  NOTE(review): relies on `Model` sitting between
        # TorchModel and PreTrainedModel in the MRO — confirm against
        # modelscope.models.base.
        super(Model, self).__init__(config)

    def _init_weights(self, module):
        """Initialize the weights."""
        # GPT-style init: normal(0, initializer_range) for linear/embedding
        # weights, zero biases, unit RMSNorm scales.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(
                mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)

        # GPT-2-style scaled init for residual projections: shrink by
        # sqrt(2 * n_layer) to keep residual-stream variance bounded.
        for name, p in module.named_parameters():
            if name == 'c_proj.weight':
                p.data.normal_(
                    mean=0.0,
                    std=(self.config.initializer_range
                         / math.sqrt(2 * self.config.n_layer)))

    def _set_gradient_checkpointing(self, module, value=False):
        # Hook invoked by PreTrainedModel.gradient_checkpointing_enable().
        if isinstance(module, QWenModel):
            module.gradient_checkpointing = value

    @classmethod
    def _instantiate(cls, **kwargs):
        """ModelScope entry point: build from kwargs when no ``model_dir``
        is supplied, otherwise load pretrained weights from that directory
        via the HuggingFace loader."""
        model_dir = kwargs.pop('model_dir', None)
        if model_dir is None:
            config = QWenConfig(**kwargs)
            model = cls(config)
        else:
            # Bypass the ModelScope from_pretrained and use HuggingFace's.
            model = super(Model, cls).from_pretrained(
                pretrained_model_name_or_path=model_dir, **kwargs)
        model.model_dir = model_dir
        return model
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.backbone, module_name=Models.qwen_7b)
class QWenModel(QWenPreTrainedModel):
    """QWen-7B decoder-only transformer backbone: token embeddings, a stack
    of ``QWenBlock`` layers and a final RMSNorm.  Produces hidden states,
    not LM logits (see the text-generation head for logits).
    """

    _keys_to_ignore_on_load_missing = ['attn.masked_bias']

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): reads config.padded_vocab_size and config.pos_emb,
        # which QWenConfig does not declare — presumably supplied through
        # the checkpoint's config.json kwargs; confirm before constructing
        # from a bare QWenConfig().
        self.vocab_size = config.padded_vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        self.embed_dim = config.hidden_size

        max_sequence_length = config.max_position_embeddings
        self.position_embedding_type = config.pos_emb
        self.gradient_checkpointing = False

        if self.position_embedding_type == 'learned':
            # NOTE(review): this branch looks broken/dead for the shipped
            # rotary configuration — self.init_method and
            # self.position_embeddings are never defined anywhere visible,
            # and the init call is duplicated; taking this branch would
            # raise AttributeError.  Confirm and clean up.
            self.wpe = nn.Embedding(max_sequence_length, self.embed_dim)
            self.init_method(self.position_embeddings.weight)
            self._position_embeddings_key = 'position_embeddings'
            self.init_method(self.position_embeddings.weight)
        else:
            # Rotary embeddings are applied inside attention; no table here.
            self.wpe = None
            self._position_embeddings_key = ''

        # Token embedding table.
        self.wte = nn.Embedding(self.vocab_size, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([
            QWenBlock(
                config,
                layer_idx=i,
            ) for i in range(config.num_hidden_layers)
        ])
        # Final normalization before the LM head.
        self.ln_f = RMSNorm(
            self.embed_dim,
            eps=config.layer_norm_epsilon,
        )

        self.post_init()

    def get_input_embeddings(self):
        # HuggingFace hook used by resize_token_embeddings / weight tying.
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """GPT-2-style forward pass over the block stack.

        Returns ``BaseModelOutputWithPast`` or, when ``return_dict`` is
        false, a tuple of the non-None values among (hidden_states,
        presents, all_hidden_states).  NOTE(review): the tuple path omits
        ``all_self_attentions`` — confirm whether that is intentional.
        """
        # Fall back to config defaults for unspecified flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else
            self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                'You cannot specify both input_ids and inputs_embeds at the same time'
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError(
                'You have to specify either input_ids or inputs_embeds')

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        # past_length = number of positions already held in the KV cache.
        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)

        if position_ids is None:
            # Continue position numbering after the cached prefix.
            position_ids = torch.arange(
                past_length,
                input_shape[-1] + past_length,
                dtype=torch.long,
                device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError('batch_size has to be defined and > 0')
            # Expand [batch, seq] -> [batch, 1, 1, seq] and convert to an
            # additive mask: 0 where attendable, dtype-min where masked.
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask[:, None, None, :]

            attention_mask = attention_mask.to(dtype=self.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(
                self.dtype).min

        # Decoder-only model: any supplied cross-attention mask is discarded.
        encoder_attention_mask = None

        # Per-layer head mask in the shape the blocks expect.
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        hidden_states = inputs_embeds
        if self.wpe is not None:
            # Only for the 'learned' position-embedding variant.
            position_embeds = self.wpe(position_ids)
            hidden_states = hidden_states + position_embeds

        hidden_states = self.drop(hidden_states)
        output_shape = input_shape + (hidden_states.size(-1), )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                # Checkpointing recomputes activations; a KV cache built
                # during the forward pass would be invalid.
                logger.warning_once(
                    '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
                )
                use_cache = False

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            if output_hidden_states:
                # Record the input to each block.  NOTE(review): the final
                # post-ln_f hidden state is never appended — confirm.
                all_hidden_states = all_hidden_states + (hidden_states, )

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    # Bind the bool flags so checkpoint() only sees tensors.

                    def custom_forward(*inputs):
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                # NOTE(review): this indexing assumes the block's extras are
                # ordered (attn_probs, present) when both are returned —
                # the reverse of upstream GPT-2; confirm against QWenBlock.
                presents = presents + (
                    outputs[2 if output_attentions else 1], )

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1], )

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)

        if not return_dict:
            return tuple(v
                         for v in [hidden_states, presents, all_hidden_states]
                         if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions)
|
||||
|
||||
|
||||
class RotaryEmbedding(torch.nn.Module):
    """Rotary position embedding with NTK-aware base scaling.

    Caches a ``[1, seq, 1, dim]`` table of rotation angles and grows it
    lazily as longer sequences are requested.
    """

    def __init__(self, dim, base=10000, ntk_alpha=1.0):
        super().__init__()
        # NTK-aware scaling of the rotary base frequency.
        base = base * ntk_alpha**(dim / (dim - 2))
        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        if importlib.util.find_spec('einops') is None:
            raise RuntimeError('einops is required for Rotary Embedding')

        # Lazily-built angle table and the length it currently covers.
        self._rotary_pos_emb_cache = None
        self._seq_len_cached = 0

    def update_rotary_pos_emb_cache(self, max_seq_len, offset=0):
        """Rebuild the cached table if it is too short for the request."""
        needed = max_seq_len + offset
        if needed <= self._seq_len_cached:
            return
        self._seq_len_cached = needed
        positions = torch.arange(needed, device=self.inv_freq.device)
        angles = torch.outer(positions.type_as(self.inv_freq), self.inv_freq)
        # Duplicate so the table covers both halves of the head dimension.
        table = torch.cat((angles, angles), dim=-1)
        from einops import rearrange
        self._rotary_pos_emb_cache = rearrange(table, 'n d -> 1 n 1 d')

    def forward(self, max_seq_len, offset=0):
        self.update_rotary_pos_emb_cache(max_seq_len, offset)
        return self._rotary_pos_emb_cache[:, offset:offset + max_seq_len]
|
||||
|
||||
|
||||
def _rotate_half(x):
    """Split the last dim into two halves and return ``(-second, first)``."""
    from einops import rearrange
    pairs = rearrange(x, '... (j d) -> ... j d', j=2)
    first, second = pairs.unbind(dim=-2)
    return torch.cat((-second, first), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(t, freqs, use_flash_rotary=False):
    """Apply rotary position embedding *freqs* to tensor *t*.

    Computes in fp32 and casts back to ``t``'s dtype.  The flash path uses
    the fused ``apply_rotary_emb_func`` kernel; the fallback rotates only
    the leading ``freqs.shape[-1]`` channels and passes the rest through.
    """
    if use_flash_rotary:
        freqs = freqs.squeeze(0).squeeze(1)
        half = freqs.shape[-1] // 2
        cos = freqs[:, :half].cos()
        sin = freqs[:, :half].sin()
        return apply_rotary_emb_func(t.float(), cos, sin).type_as(t)

    rot_dim = freqs.shape[-1]
    rotated = t[..., :rot_dim].float()
    passthrough = t[..., rot_dim:].float()
    rotated = (rotated * freqs.cos()) + (_rotate_half(rotated) * freqs.sin())
    return torch.cat((rotated, passthrough), dim=-1).type_as(t)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization with a learned scale.

    Uses the fused ``rms_norm`` kernel when available, otherwise a pure
    PyTorch fallback computed in fp32 and cast back to the input dtype.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Normalize by the RMS over the last dimension.
        mean_sq = torch.mean(x * x, dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        if rms_norm is None:
            # Fallback: fp32 normalization, then rescale in input dtype.
            return self._norm(x.float()).type_as(x) * self.weight
        return rms_norm(x, self.weight, self.eps)
|
||||
74
modelscope/models/nlp/qwen/configuration.py
Normal file
74
modelscope/models/nlp/qwen/configuration.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import GenerationConfig, PretrainedConfig
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
QWEN_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
||||
|
||||
|
||||
class QWenConfig(PretrainedConfig):
    """Configuration for the QWen-7B model family.

    Defaults mirror the released 7B checkpoint (4096 hidden, 32 layers,
    32 heads, 22016 FFN).  NOTE(review): ``QWenModel`` also reads
    ``padded_vocab_size`` and ``pos_emb``, which are not defaulted here —
    presumably they arrive through ``**kwargs`` from the checkpoint's
    config.json; confirm before constructing a bare ``QWenConfig()``.
    """

    model_type = 'qwen'
    # Keys dropped from model outputs during inference serialization.
    keys_to_ignore_at_inference = ['past_key_values']
    # Canonical HF attribute names mapped onto GPT-2-style field names.
    attribute_map = {
        'hidden_size': 'n_embd',
        'num_attention_heads': 'n_head',
        'max_position_embeddings': 'n_positions',
        'num_hidden_layers': 'n_layer',
    }

    def __init__(
        self,
        vocab_size=151851,
        n_embd=4096,
        n_layer=32,
        n_head=32,
        n_inner=None,  # MLP inner dim; None -> derived inside QWenBlock
        embd_pdrop=0.0,
        attn_pdrop=0.0,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        scale_attn_weights=True,
        use_cache=True,
        eos_token_id=151643,
        apply_residual_connection_post_layernorm=False,
        bf16=True,  # run in bfloat16 by default
        kv_channels=128,
        rotary_pct=1.0,
        rotary_emb_base=10000,
        ntk_alpha=1.0,
        use_flash_attn=True,
        ffn_hidden_size=22016,
        no_bias=True,
        tie_word_embeddings=False,
        **kwargs,
    ):
        self.eos_token_id = eos_token_id
        # PretrainedConfig absorbs any extra kwargs as plain attributes.
        super().__init__(
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs)

        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.bf16 = bf16
        self.kv_channels = kv_channels
        self.rotary_pct = rotary_pct
        self.rotary_emb_base = rotary_emb_base
        self.ntk_alpha = ntk_alpha
        self.use_flash_attn = use_flash_attn
        self.ffn_hidden_size = ffn_hidden_size
        self.no_bias = no_bias
        self.tie_word_embeddings = tie_word_embeddings
|
||||
401
modelscope/models/nlp/qwen/qwen_generation_utils.py
Normal file
401
modelscope/models/nlp/qwen/qwen_generation_utils.py
Normal file
@@ -0,0 +1,401 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generation support."""
|
||||
|
||||
from typing import Iterable, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.generation import LogitsProcessor
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
# Types.
# A chat history as (query, response) string pairs.
HistoryType = List[Tuple[str, str]]
# One tokenized sequence.
TokensType = List[int]
# A batch of tokenized sequences.
BatchTokensType = List[List[int]]
|
||||
|
||||
|
||||
def pad_batch(batch: BatchTokensType, pad_id: int,
              seq_length: int) -> BatchTokensType:
    """Right-pad every sequence in *batch* to ``seq_length`` with *pad_id*.

    Mutates the inner lists in place; sequences already at or beyond
    ``seq_length`` are left untouched.  Returns the same ``batch`` object.
    """
    for tokens in batch:
        shortfall = seq_length - len(tokens)
        if shortfall > 0:
            tokens.extend([pad_id] * shortfall)
    return batch
|
||||
|
||||
|
||||
def get_ltor_masks_and_position_ids(
    data,
    eod_token,
    reset_position_ids,
    reset_attention_mask,
    eod_mask_loss,
):
    """Build masks and position id for left to right model.

    Args:
        data: (micro_batch, seq_len) tensor of token ids.
        eod_token: end-of-document token id.
        reset_position_ids: restart position numbering after each EOD.
        reset_attention_mask: block attention across EOD boundaries.
        eod_mask_loss: zero the loss on EOD tokens.

    Returns:
        (attention_mask, loss_mask, position_ids) — attention_mask is
        boolean with True at positions that must NOT be attended (the
        lower-triangular mask is inverted at the end).
    """

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    # A per-sample mask is only needed when it is edited per document.
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(
        torch.ones((att_mask_batch, seq_length, seq_length),
                   device=data.device)).view(att_mask_batch, 1, seq_length,
                                             seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(
        seq_length, dtype=torch.long, device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modifed based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indecies where EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indecies from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indecies:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask attention loss.
                if reset_attention_mask:
                    # Tokens after this EOD cannot attend to earlier docs.
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    # Restart numbering at 0 for the next document.
                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                    prev_index = i + 1

    # Convert attention mask to binary:
    # True now marks the positions that must be masked out.
    attention_mask = (attention_mask < 0.5)

    return attention_mask, loss_mask, position_ids
|
||||
|
||||
|
||||
def get_batch(context_tokens: torch.LongTensor, eod_id: int):
    """Build (tokens, attention_mask, position_ids) for a context batch.

    Convenience wrapper around ``get_ltor_masks_and_position_ids`` with all
    reset/masking behavior disabled.
    """
    # NOTE: .to(context_tokens.device) is a no-op kept from the original.
    tokens = context_tokens.contiguous().to(context_tokens.device)
    mask, _, pos_ids = get_ltor_masks_and_position_ids(
        tokens,
        eod_id,
        reset_position_ids=False,
        reset_attention_mask=False,
        eod_mask_loss=False)
    return tokens, mask, pos_ids
|
||||
|
||||
|
||||
def get_stop_words_ids(chat_format, tokenizer):
    """Return the token-id sequences that should stop generation for the
    given *chat_format* ('raw' or 'chatml')."""
    if chat_format == 'chatml':
        return [[tokenizer.im_end_id], [tokenizer.im_start_id]]
    if chat_format == 'raw':
        return [tokenizer.encode('Human:'), [tokenizer.eod_id]]
    raise NotImplementedError(f'Unknown chat format {chat_format!r}')
|
||||
|
||||
|
||||
def make_context(
    tokenizer: PreTrainedTokenizer,
    query: str,
    history: List[Tuple[str, str]] = [],
    system: str = '',
    max_window_size: int = 6144,
    chat_format: str = 'chatml',
):
    """Build the prompt (raw text and token ids) for one chat turn.

    For 'chatml', wraps system/user/assistant turns in
    ``<|im_start|>``/``<|im_end|>`` markers and keeps as many recent
    history turns as fit within ``max_window_size`` tokens.  For 'raw',
    the query is used verbatim.

    NOTE(review): the mutable-list default for ``history`` is safe only
    because it is never mutated here.

    Returns:
        (raw_text, context_tokens)
    """

    if chat_format == 'chatml':
        im_start, im_end = '<|im_start|>', '<|im_end|>'
        im_start_tokens = [tokenizer.im_start_id]
        im_end_tokens = [tokenizer.im_end_id]
        nl_tokens = tokenizer.encode('\n')

        def _tokenize_str(role, content):
            # Returns ("role\ncontent", token ids of that text).
            return f'{role}\n{content}', tokenizer.encode(
                role) + nl_tokens + tokenizer.encode(content)

        system_text, system_tokens_part = _tokenize_str('system', system)
        system_tokens = im_start_tokens + system_tokens_part + im_end_tokens

        raw_text = ''
        context_tokens = []

        # Walk history newest-to-oldest, prepending turns while they still
        # fit in the window (the system prompt is always reserved).
        for turn_query, turn_response in reversed(history):
            query_text, query_tokens_part = _tokenize_str('user', turn_query)
            query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
            response_text, response_tokens_part = _tokenize_str(
                'assistant', turn_response)
            response_tokens = im_start_tokens + response_tokens_part + im_end_tokens

            next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
            prev_chat = f'\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}'

            current_context_size = len(system_tokens) + len(
                next_context_tokens) + len(context_tokens)
            if current_context_size < max_window_size:
                context_tokens = next_context_tokens + context_tokens
                raw_text = prev_chat + raw_text
            else:
                break

        context_tokens = system_tokens + context_tokens
        raw_text = f'{im_start}{system_text}{im_end}' + raw_text
        # Append the current user turn and open the assistant turn.
        context_tokens += (
            nl_tokens + im_start_tokens + _tokenize_str('user', query)[1]
            + im_end_tokens + nl_tokens + im_start_tokens
            + tokenizer.encode('assistant') + nl_tokens)
        raw_text += f'\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n'

    elif chat_format == 'raw':
        raw_text = query
        context_tokens = tokenizer.encode(raw_text)
    else:
        raise NotImplementedError(f'Unknown chat format {chat_format!r}')

    return raw_text, context_tokens
|
||||
|
||||
|
||||
def _decode_default(
|
||||
tokens: List[int],
|
||||
*,
|
||||
stop_words: List[str],
|
||||
eod_words: List[str],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
raw_text_len: int,
|
||||
verbose: bool = False,
|
||||
return_end_reason: bool = False,
|
||||
):
|
||||
trim_decode_tokens = tokenizer.decode(tokens)[raw_text_len:]
|
||||
if verbose:
|
||||
print('\nRaw Generate: ', trim_decode_tokens)
|
||||
|
||||
end_reason = f'Gen length {len(tokens)}'
|
||||
for stop_word in stop_words:
|
||||
trim_decode_tokens = trim_decode_tokens.replace(stop_word, '').strip()
|
||||
for eod_word in eod_words:
|
||||
if eod_word in trim_decode_tokens:
|
||||
end_reason = f'Gen {eod_word!r}'
|
||||
trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
|
||||
trim_decode_tokens = trim_decode_tokens.strip()
|
||||
if verbose:
|
||||
print('\nEnd Reason:', end_reason)
|
||||
print('\nGenerate: ', trim_decode_tokens)
|
||||
|
||||
if return_end_reason:
|
||||
return trim_decode_tokens, end_reason
|
||||
else:
|
||||
return trim_decode_tokens
|
||||
|
||||
|
||||
def _decode_chatml(
|
||||
tokens: List[int],
|
||||
*,
|
||||
stop_words: List[str],
|
||||
eod_token_ids: List[int],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
raw_text_len: int,
|
||||
context_length: int,
|
||||
verbose: bool = False,
|
||||
return_end_reason: bool = False,
|
||||
):
|
||||
end_reason = f'Gen length {len(tokens)}'
|
||||
eod_token_idx = context_length
|
||||
for eod_token_idx in range(context_length, len(tokens)):
|
||||
if tokens[eod_token_idx] in eod_token_ids:
|
||||
end_reason = f'Gen {tokenizer.decode([tokens[eod_token_idx]])!r}'
|
||||
break
|
||||
|
||||
trim_decode_tokens = tokenizer.decode(
|
||||
tokens[:eod_token_idx])[raw_text_len:]
|
||||
if verbose:
|
||||
print('\nRaw Generate w/o EOD:',
|
||||
tokenizer.decode(tokens)[raw_text_len:])
|
||||
print('\nRaw Generate:', trim_decode_tokens)
|
||||
print('\nEnd Reason:', end_reason)
|
||||
for stop_word in stop_words:
|
||||
trim_decode_tokens = trim_decode_tokens.replace(stop_word, '').strip()
|
||||
trim_decode_tokens = trim_decode_tokens.strip()
|
||||
if verbose:
|
||||
print('\nGenerate:', trim_decode_tokens)
|
||||
|
||||
if return_end_reason:
|
||||
return trim_decode_tokens, end_reason
|
||||
else:
|
||||
return trim_decode_tokens
|
||||
|
||||
|
||||
def decode_tokens(
    tokens: Union[torch.LongTensor, TokensType],
    tokenizer: PreTrainedTokenizer,
    raw_text_len: int,
    context_length: int,
    chat_format: str,
    verbose: bool = False,
    return_end_reason: bool = False,
) -> str:
    """Dispatch token decoding to the handler for *chat_format*."""
    if torch.is_tensor(tokens):
        tokens = tokens.cpu().numpy().tolist()

    if chat_format == 'chatml':
        return _decode_chatml(
            tokens,
            stop_words=[],
            eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
            tokenizer=tokenizer,
            raw_text_len=raw_text_len,
            context_length=context_length,
            verbose=verbose,
            return_end_reason=return_end_reason,
        )
    if chat_format == 'raw':
        return _decode_default(
            tokens,
            stop_words=['<|endoftext|>'],
            eod_words=['<|endoftext|>'],
            tokenizer=tokenizer,
            raw_text_len=raw_text_len,
            verbose=verbose,
            return_end_reason=return_end_reason,
        )
    raise NotImplementedError(f'Unknown chat format {chat_format!r}')
|
||||
|
||||
|
||||
class StopWordsLogitsProcessor(LogitsProcessor):
    """
    :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop generation.

    When the tail of a sample's input ids matches any stop-word sequence,
    the EOS logit for that sample is boosted so high that EOS is certain
    to be sampled next.

    Args:
        stop_words_ids (:obj:`List[List[int]]`):
            List of list of token ids of stop ids. In order to get the tokens of the words
            that should not appear in the generated text, use :obj:`tokenizer(bad_word,
            add_prefix_space=True).input_ids`.
        eos_token_id (:obj:`int`):
            The id of the `end-of-sequence` token.
    """

    def __init__(self, stop_words_ids: Iterable[Iterable[int]],
                 eos_token_id: int):

        # Validate shape and contents up front so misuse fails loudly.
        if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
            raise ValueError(
                f'`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}.'
            )
        if any(not isinstance(bad_word_ids, list)
               for bad_word_ids in stop_words_ids):
            raise ValueError(
                f'`stop_words_ids` has to be a list of lists, but is {stop_words_ids}.'
            )
        if any(
                any((not isinstance(token_id, (int,
                                               np.integer)) or token_id < 0)
                    for token_id in stop_word_ids)
                for stop_word_ids in stop_words_ids):
            raise ValueError(
                f'Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}.'
            )

        # A bare [eos] stop sequence is redundant — drop it.
        self.stop_words_ids = list(
            filter(lambda bad_token_seq: bad_token_seq != [eos_token_id],
                   stop_words_ids))
        self.eos_token_id = eos_token_id
        for stop_token_seq in self.stop_words_ids:
            assert len(
                stop_token_seq
            ) > 0, 'Stop words token sequences {} cannot have an empty list'.format(
                stop_words_ids)

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor) -> torch.FloatTensor:
        """Boost the EOS logit for every sample whose ids end in a stop word."""
        stopped_samples = self._calc_stopped_samples(input_ids)
        for i, should_stop in enumerate(stopped_samples):
            if should_stop:
                # Overwhelmingly large logit -> EOS is effectively forced.
                scores[i, self.eos_token_id] = float(2**30)
        return scores

    def _tokens_match(self, prev_tokens: torch.LongTensor,
                      tokens: List[int]) -> bool:
        """Return True when *prev_tokens* ends with the sequence *tokens*."""
        if len(tokens) == 0:
            # if bad word tokens is just one token always ban it
            return True
        elif len(tokens) > len(prev_tokens):
            # if bad word tokens are longer then prev input_ids they can't be equal
            return False
        elif prev_tokens[-len(tokens):].tolist() == tokens:
            # if tokens match
            return True
        else:
            return False

    def _calc_stopped_samples(self,
                              prev_input_ids: Iterable[int]) -> Iterable[int]:
        """Per-sample flags: does each id sequence end in any stop word?"""
        stopped_samples = []
        for prev_input_ids_slice in prev_input_ids:
            match = False
            for stop_token_seq in self.stop_words_ids:
                if self._tokens_match(prev_input_ids_slice, stop_token_seq):
                    # a stop word matched — no need to check the rest
                    match = True
                    break
            stopped_samples.append(match)

        return stopped_samples
|
||||
|
||||
|
||||
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ This function has been mostly taken from huggingface conversational
    ai code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-
    conversational-ai-with-transfer-learning-2d818ac26313

    Applies top-k and/or nucleus (top-p) filtering to *logits* IN PLACE,
    setting removed entries to *filter_value*, and returns the same tensor.
    """
    if top_k > 0:
        # Mask everything below the k-th largest value in each row.
        kth_value = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_value] = filter_value

    if top_p > 0.0:
        # Nucleus filtering: keep the smallest prefix of the sorted
        # distribution whose cumulative probability exceeds top_p.
        sorted_logits, sorted_indices = torch.sort(
            logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1)

        remove_mask = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold survives.
        remove_mask[..., 1:] \
            = remove_mask[..., :-1].clone()
        remove_mask[..., 0] = 0
        # Scatter the mask from sorted order back to vocabulary order.
        for row in range(sorted_indices.size(0)):
            drop = sorted_indices[row][remove_mask[row]]
            logits[row][drop] = filter_value

    return logits
|
||||
|
||||
|
||||
def switch(val1, val2, boolean):
    """Elementwise select: take *val2* where *boolean* is true, else *val1*."""
    gate = boolean.type_as(val1)
    return val1 * (1 - gate) + val2 * gate
|
||||
194
modelscope/models/nlp/qwen/text_generation.py
Normal file
194
modelscope/models/nlp/qwen/text_generation.py
Normal file
@@ -0,0 +1,194 @@
|
||||
import warnings
|
||||
from typing import Callable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from ... import MODELS
|
||||
from .backbone import QWenModel, QWenPreTrainedModel
|
||||
from .qwen_generation_utils import (BatchTokensType, HistoryType,
|
||||
StopWordsLogitsProcessor, decode_tokens,
|
||||
get_batch, get_stop_words_ids,
|
||||
make_context, pad_batch, switch,
|
||||
top_k_logits)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.text_generation, module_name=Models.qwen_7b)
|
||||
@MODELS.register_module(Tasks.chat, module_name=Models.qwen_7b)
|
||||
class QWenForTextGeneration(QWenPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r'h\.\d+\.attn\.rotary_emb\.inv_freq']
|
||||
_keys_to_ignore_on_load_unexpected = [r'h\.\d+\.attn\.masked_bias']
|
||||
|
||||
    def __init__(self, config):
        """Causal-LM wrapper: QWen backbone plus a linear LM head."""
        super().__init__(config)
        self.transformer = QWenModel(config)
        # Untied output projection (config defaults tie_word_embeddings=False).
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.post_init()
|
||||
|
||||
    def get_output_embeddings(self):
        # HuggingFace hook so generate()/resize utilities can find the head.
        return self.lm_head
|
||||
|
||||
    def set_output_embeddings(self, new_embeddings):
        # HuggingFace hook used when resizing or tying embeddings.
        self.lm_head = new_embeddings
|
||||
|
||||
def prepare_inputs_for_generation(self,
|
||||
input_ids,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
**kwargs):
|
||||
token_type_ids = kwargs.get('token_type_ids', None)
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
attention_mask = kwargs.get('attention_mask', None)
|
||||
position_ids = kwargs.get('position_ids', None)
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past_key_values:
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
else:
|
||||
position_ids = None
|
||||
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {'inputs_embeds': inputs_embeds}
|
||||
else:
|
||||
model_inputs = {'input_ids': input_ids}
|
||||
|
||||
model_inputs.update({
|
||||
'past_key_values': past_key_values,
|
||||
'use_cache': kwargs.get('use_cache'),
|
||||
'position_ids': position_ids,
|
||||
'attention_mask': attention_mask,
|
||||
'token_type_ids': token_type_ids,
|
||||
})
|
||||
return model_inputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
hidden_states = transformer_outputs[0]
|
||||
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
labels = labels.to(lm_logits.device)
|
||||
shift_logits = lm_logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(
|
||||
shift_logits.view(-1, shift_logits.size(-1)),
|
||||
shift_labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits, ) + transformer_outputs[1:]
|
||||
return ((loss, ) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=lm_logits,
|
||||
past_key_values=transformer_outputs.past_key_values,
|
||||
hidden_states=transformer_outputs.hidden_states,
|
||||
attentions=transformer_outputs.attentions,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reorder_cache(past_key_values: Tuple[Tuple[torch.Tensor]],
|
||||
beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
|
||||
|
||||
return tuple(
|
||||
tuple(
|
||||
past_state.index_select(0, beam_idx.to(past_state.device))
|
||||
for past_state in layer_past)
|
||||
for layer_past in past_key_values)
|
||||
|
||||
def chat(
|
||||
self,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
query: str,
|
||||
history: Optional[HistoryType],
|
||||
system: str = '',
|
||||
append_history: bool = True,
|
||||
) -> Tuple[str, HistoryType]:
|
||||
|
||||
if history is None:
|
||||
history = []
|
||||
|
||||
raw_text, context_tokens = make_context(
|
||||
tokenizer,
|
||||
query,
|
||||
history=history,
|
||||
system=system,
|
||||
max_window_size=6144,
|
||||
chat_format=self.generation_config.chat_format)
|
||||
|
||||
stop_words_ids = get_stop_words_ids(self.generation_config.chat_format,
|
||||
tokenizer)
|
||||
input_ids = torch.tensor([context_tokens]).to(self.device)
|
||||
|
||||
outputs = self.generate(
|
||||
input_ids,
|
||||
stop_words_ids=stop_words_ids,
|
||||
return_dict_in_generate=False,
|
||||
)
|
||||
|
||||
response = decode_tokens(
|
||||
outputs[0],
|
||||
tokenizer,
|
||||
raw_text_len=len(raw_text),
|
||||
context_length=len(context_tokens),
|
||||
chat_format=self.generation_config.chat_format,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
if append_history:
|
||||
history.append((query, response))
|
||||
|
||||
return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history}
|
||||
226
modelscope/models/nlp/qwen/tokenization.py
Normal file
226
modelscope/models/nlp/qwen/tokenization.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# TODO
|
||||
# Copyright 2023 Alibaba Group. All rights reserved.
|
||||
"""Tokenization classes for QWen."""
|
||||
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import unicodedata
|
||||
from io import open
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import json
|
||||
import tiktoken
|
||||
from transformers import AddedToken, PreTrainedTokenizer
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
TIKTOKEN_NAME = 'qwen.tiktoken'
|
||||
|
||||
|
||||
class QWenTokenizer(PreTrainedTokenizer):
    """QWen tokenizer backed by a tiktoken BPE encoding.

    NOTE: special tokens embedded in raw text are NOT expanded by
    ``tokenize``/``encode`` — this avoids prompt-injection attacks.
    """

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path,
                        cache_dir=None,
                        *inputs,
                        **kwargs):
        """Load from a local directory containing the ``qwen.tiktoken`` file."""
        merges_file = os.path.join(pretrained_model_name_or_path,
                                   TIKTOKEN_NAME)
        tokenizer = cls(merges_file, *inputs, **kwargs)
        return tokenizer

    def __init__(self,
                 merges_file,
                 errors='replace',
                 max_len=None,
                 unk_token='<|endoftext|>',
                 bos_token='<|endoftext|>',
                 eos_token='<|endoftext|>',
                 pad_token=None,
                 add_prefix_space=False,
                 add_bos_token=False,
                 add_more_sp_tokens=True,
                 **kwargs):
        # Normalize plain-string special tokens into AddedToken objects.
        bos_token = AddedToken(
            bos_token, lstrip=False, rstrip=False) if isinstance(
                bos_token, str) else bos_token
        eos_token = AddedToken(
            eos_token, lstrip=False, rstrip=False) if isinstance(
                eos_token, str) else eos_token
        unk_token = AddedToken(
            unk_token, lstrip=False, rstrip=False) if isinstance(
                unk_token, str) else unk_token
        pad_token = AddedToken(
            pad_token, lstrip=False, rstrip=False) if isinstance(
                pad_token, str) else pad_token
        super().__init__(
            errors=errors,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_prefix_space=add_prefix_space,
            add_bos_token=add_bos_token,
        )
        self.add_bos_token = add_bos_token
        self.max_len = max_len if max_len is not None else int(1e12)

        self.errors = errors  # how to handle errors in decoding

        name = 'QWen'
        ENDOFTEXT = '<|endoftext|>'
        IMSTART = '<|im_start|>'
        IMEND = '<|im_end|>'
        if add_more_sp_tokens:
            special_tokens = (ENDOFTEXT, IMSTART, IMEND, '<R>', '<S>', '<X>',
                              '<mask>', '<sep>') + tuple(
                                  [f'<extra_{i}>' for i in range(200)])
        else:
            special_tokens = (ENDOFTEXT, IMSTART, IMEND)

        PAT_STR = (
            r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}|"""
            r""" ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""")

        def load_tiktoken_bpe(tiktoken_bpe_file: str) -> 'dict[bytes, int]':
            # Use a context manager so the file handle is always closed
            # (the original leaked it).
            with open(tiktoken_bpe_file, 'rb') as f:
                contents = f.read()
            return {
                base64.b64decode(token): int(rank)
                for token, rank in (line.split()
                                    for line in contents.splitlines() if line)
            }

        mergeable_ranks = load_tiktoken_bpe(merges_file)
        # Special tokens get ids immediately after the BPE ranks.
        special_tokens = {
            token: index
            for index, token in enumerate(
                special_tokens, start=len(mergeable_ranks))
        }
        self.special_tokens = special_tokens
        enc = tiktoken.Encoding(
            name,
            pat_str=PAT_STR,
            mergeable_ranks=mergeable_ranks,
            special_tokens=special_tokens,
        )
        assert len(mergeable_ranks) + len(
            special_tokens
        ) == enc.n_vocab, f'{len(mergeable_ranks) + len(special_tokens)} != {enc.n_vocab} in encoding'

        self.mergeable_ranks = mergeable_ranks
        self.encoder = self.mergeable_ranks
        # decoder maps token-id -> token bytes (keys of mergeable_ranks
        # are base64-decoded byte strings).
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.tokenizer = enc  # type: tiktoken.Encoding
        self.eod_id = self.tokenizer.eot_token
        self.im_start_id = special_tokens[IMSTART]
        self.im_end_id = special_tokens[IMEND]

    def __len__(self):
        return self.tokenizer.n_vocab

    def get_vocab(self):
        return self.mergeable_ranks

    def convert_tokens_to_ids(self, tokens):
        """Convert a token (or list of tokens) to vocabulary id(s)."""
        ids = []
        # Remove support for py2
        if isinstance(tokens, str):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.encoder.get(tokens)
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                ids.append(self.encoder.get(token))
        if len(ids) > self.max_len:
            logger.warning(
                'Token indices sequence length is longer than the specified maximum '
                ' sequence length for this OpenAI GPT model ({} > {}). Running this'
                ' sequence through the model will result in indexing errors'.
                format(len(ids), self.max_len))
        return ids

    def save_vocabulary(self,
                        save_directory: str,
                        filename_prefix: Optional[str] = None) -> Tuple[str]:
        """
        Save only the vocabulary of the tokenizer (vocabulary + added tokens).

        Writes the vocabulary back in the tiktoken format that
        ``load_tiktoken_bpe`` reads (``base64(token) rank`` per line).
        The original implementation called ``json.dump(f, ...)`` with the
        arguments swapped, crashed when ``filename_prefix`` was ``None``,
        and returned a bare string instead of the annotated tuple.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        file_name = TIKTOKEN_NAME if filename_prefix is None \
            else filename_prefix + '-' + TIKTOKEN_NAME
        file_path = os.path.join(save_directory, file_name)
        with open(file_path, 'wb') as f:
            for token_bytes, rank in sorted(
                    self.mergeable_ranks.items(), key=lambda kv: kv[1]):
                f.write(base64.b64encode(token_bytes))
                f.write(b' ')
                f.write(str(rank).encode('ascii'))
                f.write(b'\n')
        return (file_path, )

    def tokenize(self, text: str, **kwargs) -> List[str]:
        """
        Converts a string in a sequence of tokens, replacing unknown tokens with the `unk_token`.

        Args:
            text (`str`):
                The sequence to be encoded.
            kwargs (additional keyword arguments, *optional*):
                Will be passed to the underlying model specific encode method. See details in
                [`~PreTrainedTokenizerBase.__call__`]

        Returns:
            `List[str]`: The list of tokens. NOTE: the elements are actually
            the raw byte strings stored in ``self.decoder``.
        """
        tokens = []
        text = unicodedata.normalize('NFC', text)
        # encode_ordinary ignores special tokens on purpose (injection safety).
        for t in self.tokenizer.encode_ordinary(text):
            tokens.append(self.decoder[t])
        return tokens

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """
        Converts a sequence of tokens into a single string.

        ``tokenize`` produces byte-string tokens (values of ``self.decoder``),
        so the pieces are joined at the byte level and decoded as UTF-8. The
        original implementation referenced an undefined ``self.byte_decoder``
        and always raised AttributeError.
        """
        data = b''.join(
            t.encode('utf-8') if isinstance(t, str) else t for t in tokens)
        return data.decode('utf-8', errors=self.errors)

    @property
    def vocab_size(self):
        return self.tokenizer.n_vocab

    def _convert_id_to_token(self, index: int) -> str:
        # Byte-level ids have no single-string token form here.
        raise NotImplementedError

    def _tokenize(self, text, **kwargs):
        """
        Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
        vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).

        Do NOT take care of added tokens.
        """
        raise NotImplementedError

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool = None,
        **kwargs,
    ) -> str:
        """Decode ids to text, optionally dropping special-token ids.

        The original version silently ignored ``skip_special_tokens``.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        if skip_special_tokens:
            special_ids = set(self.special_tokens.values())
            token_ids = [i for i in token_ids if i not in special_ids]
        return self.tokenizer.decode(token_ids)
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import GenerationConfig
|
||||
|
||||
from modelscope import snapshot_download
|
||||
from modelscope.metainfo import Pipelines
|
||||
@@ -20,8 +21,12 @@ from modelscope.utils.hub import Config, read_config
|
||||
from modelscope.utils.streaming_output import PipelineStreamingOutputMixin
|
||||
|
||||
__all__ = [
|
||||
'TextGenerationPipeline', 'TextGenerationT5Pipeline',
|
||||
'ChatGLM6bTextGenerationPipeline', 'ChatGLM6bV2TextGenerationPipeline'
|
||||
'TextGenerationPipeline',
|
||||
'TextGenerationT5Pipeline',
|
||||
'ChatGLM6bTextGenerationPipeline',
|
||||
'ChatGLM6bV2TextGenerationPipeline',
|
||||
'QWenChatPipeline',
|
||||
'QWenTextGenerationPipeline',
|
||||
]
|
||||
|
||||
|
||||
@@ -268,3 +273,104 @@ class ChatGLM6bV2TextGenerationPipeline(Pipeline):
|
||||
# format the outputs from pipeline
|
||||
def postprocess(self, input, **kwargs) -> Dict[str, Any]:
|
||||
return input
|
||||
|
||||
|
||||
@PIPELINES.register_module(group_key=Tasks.chat, module_name='qwen-chat')
class QWenChatPipeline(Pipeline):
    """Pipeline wrapping the QWen chat model's multi-turn ``chat`` API."""

    def __init__(self, model: Union[Model, str], **kwargs):
        from modelscope.models.nlp import QWenConfig, QWenTokenizer, QWenForTextGeneration

        torch_dtype = kwargs.get('torch_dtype', torch.bfloat16)
        device_map = kwargs.get('device_map', 'auto')

        if isinstance(model, str):
            # A local path is used directly; otherwise fetch from the hub.
            model_dir = model if os.path.exists(model) else snapshot_download(
                model)

            cfg = read_config(model_dir)
            hf_config = QWenConfig.from_pretrained(model_dir)
            hf_config.torch_dtype = torch_dtype

            model = QWenForTextGeneration.from_pretrained(
                model_dir,
                cfg_dict=cfg,
                config=hf_config,
                device_map=device_map,
                torch_dtype=torch_dtype)
            model.generation_config = GenerationConfig.from_pretrained(
                model_dir)

        self.model = model
        self.model.eval()
        self.tokenizer = QWenTokenizer.from_pretrained(self.model.model_dir)

        super().__init__(model=model, **kwargs)

    def _sanitize_parameters(self, **pipeline_parameters):
        # Everything the caller passes is treated as a forward parameter.
        return {}, pipeline_parameters, {}

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        # The raw query string is consumed directly by the model's chat().
        return inputs

    # define the forward pass
    def forward(self, inputs: str, **forward_params) -> Dict[str, Any]:
        # Delegate prompt building, generation and decoding to the model.
        return self.model.chat(
            self.tokenizer, inputs, forward_params.get('history', None),
            forward_params.get('system', ''),
            forward_params.get('append_history', True))

    # format the outputs from pipeline
    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        # chat() already returns the final {response, history} dict.
        return input
|
||||
|
||||
|
||||
@PIPELINES.register_module(
    group_key=Tasks.text_generation, module_name='qwen-text-generation')
class QWenTextGenerationPipeline(Pipeline):
    """Pipeline for single-shot text continuation with the QWen base model."""

    def __init__(self, model: Union[Model, str], **kwargs):
        from modelscope.models.nlp import QWenConfig, QWenTokenizer, QWenForTextGeneration

        torch_dtype = kwargs.get('torch_dtype', torch.bfloat16)
        device_map = kwargs.get('device_map', 'auto')

        if isinstance(model, str):
            # A local path is used directly; otherwise fetch from the hub.
            model_dir = model if os.path.exists(model) else snapshot_download(
                model)

            cfg = read_config(model_dir)
            hf_config = QWenConfig.from_pretrained(model_dir)
            hf_config.torch_dtype = torch_dtype

            model = QWenForTextGeneration.from_pretrained(
                model_dir,
                cfg_dict=cfg,
                config=hf_config,
                device_map=device_map,
                torch_dtype=torch_dtype)
            model.generation_config = GenerationConfig.from_pretrained(
                model_dir)

        self.model = model
        self.model.eval()
        self.tokenizer = QWenTokenizer.from_pretrained(self.model.model_dir)

        super().__init__(model=model, **kwargs)

    def _sanitize_parameters(self, **pipeline_parameters):
        # Everything the caller passes is treated as a forward parameter.
        return {}, pipeline_parameters, {}

    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
        # The raw prompt string is consumed directly by the model's chat().
        return inputs

    # define the forward pass
    def forward(self, inputs: str, **forward_params) -> Dict[str, Any]:
        # Reuse the model's chat() without history for plain continuation,
        # keeping only the generated response text.
        result = self.model.chat(self.tokenizer, inputs, history=None)
        return {OutputKeys.TEXT: result[OutputKeys.RESPONSE]}

    # format the outputs from pipeline
    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
        # forward() already produced the final output dict.
        return input
|
||||
|
||||
81
tests/pipelines/test_qwen_text_generation_pipeline.py
Normal file
81
tests/pipelines/test_qwen_text_generation_pipeline.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class QWenTextGenerationPipelineTest(unittest.TestCase):
    """Smoke tests for the QWen base (text-generation) and chat pipelines.

    The checkpoints are read from local relative paths, so these tests only
    run at test level >= 3 on machines that have the weights available.
    """

    def setUp(self) -> None:
        self.qwen_base = '../qwen_7b_ckpt_modelscope/'  # local test only
        self.qwen_chat = '../qwen_7b_ckpt_chat_modelscope/'  # local test only

        self.qwen_base_input = '蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n埃塞俄比亚的首都是'
        self.qwen_chat_system = 'You are a helpful assistant.'
        self.qwen_chat_input = [
            '今天天气真好,我', 'How do you do? ', "What's your", '今夜阳光明媚', '宫廷玉液酒,',
            '7 * 8 + 32 =? ', '请问把大象关冰箱总共要几步?', '1+3=?',
            '请将下面这句话翻译为英文:在哪里跌倒就在哪里趴着'
        ]

    def run_pipeline_with_model_id(self,
                                   model_id,
                                   input,
                                   init_kwargs=None,
                                   run_kwargs=None):
        """Run the text-generation pipeline once and print the result.

        ``init_kwargs``/``run_kwargs`` default to ``None`` rather than ``{}``
        to avoid the shared-mutable-default-argument pitfall.
        """
        init_kwargs = {} if init_kwargs is None else init_kwargs
        run_kwargs = {} if run_kwargs is None else run_kwargs
        pipeline_ins = pipeline(
            task=Tasks.text_generation, model=model_id, **init_kwargs)
        pipeline_ins._model_prepare = True
        result = pipeline_ins(input, **run_kwargs)
        print(result['text'])

    def run_chat_pipeline_with_model_id(self,
                                        model_id,
                                        inputs,
                                        system,
                                        init_kwargs=None,
                                        run_kwargs=None):
        """Drive the chat pipeline through multiple turns, threading history.

        ``init_kwargs``/``run_kwargs`` default to ``None`` rather than ``{}``
        to avoid the shared-mutable-default-argument pitfall.
        """
        init_kwargs = {} if init_kwargs is None else init_kwargs
        run_kwargs = {} if run_kwargs is None else run_kwargs
        pipeline_ins = pipeline(task=Tasks.chat, model=model_id, **init_kwargs)
        pipeline_ins._model_prepare = True

        history = None
        for turn_idx, query in enumerate(inputs, start=1):
            results = pipeline_ins(
                query,
                history=history,
                system=system,
            )
            response, history = results['response'], results['history']
            print(f'===== Turn {turn_idx} ====')
            print('Query:', query, end='\n')
            print('Response:', response, end='\n')

    # 7B_ms_base
    @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
    def test_qwen_base_with_text_generation(self):
        self.run_pipeline_with_model_id(
            self.qwen_base,
            self.qwen_base_input,
            init_kwargs={
                'device_map': 'auto',
            })

    # 7B_ms_chat
    @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
    def test_qwen_chat_with_chat(self):
        self.run_chat_pipeline_with_model_id(
            self.qwen_chat,
            self.qwen_chat_input,
            self.qwen_chat_system,
            init_kwargs={
                'device_map': 'auto',
            })


if __name__ == '__main__':
    unittest.main()
|
||||
Reference in New Issue
Block a user