Feat/chatglm2 32k support (#433)

* add support for chatglm2 32k

* fix
Author: wenmeng zhou
Date: 2023-08-01 10:58:29 +08:00
Committed by: GitHub
Parent: f2396ddb6a
Commit: 2809bdbff6

@@ -122,13 +122,19 @@ def split_tensor_along_last_dim(
 class RotaryEmbedding(nn.Module):
-    def __init__(self, dim, original_impl=False, device=None, dtype=None):
+    def __init__(self,
+                 dim,
+                 rope_ratio=1,
+                 original_impl=False,
+                 device=None,
+                 dtype=None):
         super().__init__()
         inv_freq = 1.0 / (10000**(
             torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
         self.register_buffer('inv_freq', inv_freq)
         self.dim = dim
         self.original_impl = original_impl
+        self.rope_ratio = rope_ratio

     def forward_impl(self,
                      seq_len: int,
@@ -148,7 +154,8 @@ class RotaryEmbedding(nn.Module):
                 / n_elem))

         # Create position indexes `[0, 1, ..., seq_len - 1]`
-        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
+        seq_idx = torch.arange(
+            seq_len, dtype=dtype, device=device) / self.rope_ratio

         # Calculate the product of position index and $\theta_i$
         idx_theta = torch.outer(seq_idx, theta).float()
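
The division by rope_ratio is linear position interpolation: raw position n is looked up at angle index n / rope_ratio, so a ratio of 16 lets 32K token positions span only the first 2K index values of the rotary table, keeping them inside the range the model was trained on. A minimal self-contained sketch of the scaled cache (rotary_cache is an illustrative name, not this repo's API):

import torch

def rotary_cache(seq_len, n_elem, rope_ratio=1, base=10000,
                 dtype=torch.float32):
    # Per-pair inverse frequencies: theta_i = base^(-2i / n_elem).
    theta = 1.0 / (base**(torch.arange(0, n_elem, 2, dtype=dtype) / n_elem))
    # Linear position interpolation: position n is evaluated at n / rope_ratio.
    seq_idx = torch.arange(seq_len, dtype=dtype) / rope_ratio
    # (seq_len, n_elem // 2) table of rotation angles, stored as cos/sin pairs.
    idx_theta = torch.outer(seq_idx, theta)
    return torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

# Every 16th entry of a ratio-16 table coincides with the unscaled table,
# and rope_ratio=1 reproduces the pre-change behaviour exactly.
assert torch.allclose(rotary_cache(32768, 64, rope_ratio=16)[::16],
                      rotary_cache(2048, 64))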
@@ -864,6 +871,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.rotary_pos_emb = RotaryEmbedding(
             rotary_dim // 2,
+            rope_ratio=config.rope_ratio,
             original_impl=config.original_rope,
             device=device,
             dtype=config.torch_dtype)
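
Since config.rope_ratio is read unconditionally here, the bundled ChatGLMConfig presumably supplies a default (1 would keep existing checkpoints unchanged); an external config lacking the field would raise AttributeError at construction time. A hedged sketch of the new wiring, reusing the RotaryEmbedding class from the first hunk with a stand-in config (all values are placeholders, not the shipped ones):

from types import SimpleNamespace
import torch

# Stand-in for the checkpoint's config.json; real values ship with the model.
config = SimpleNamespace(rope_ratio=16, original_rope=True,
                         torch_dtype=torch.float16)
rotary_dim = 128  # per-head dimension in the real model

rotary_pos_emb = RotaryEmbedding(
    rotary_dim // 2,  # rotary embedding covers half of each head dimension
    rope_ratio=config.rope_ratio,
    original_impl=config.original_rope,
    device=torch.device('cpu'),
    dtype=config.torch_dtype)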
@@ -1169,7 +1177,7 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
              tokenizer,
              query: str,
              history: List[Tuple[str, str]] = None,
-             max_length: int = 8192,
+             max_length: int = None,
              num_beams=1,
              do_sample=True,
              top_p=0.8,
@@ -1181,6 +1189,8 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
         if logits_processor is None:
             logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
+        if max_length is None:
+            max_length = self.seq_length
         gen_kwargs = {
             'max_length': max_length,
             'num_beams': num_beams,
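
With the hard-coded 8192 removed, a caller that omits max_length now gets the loaded checkpoint's full window via self.seq_length, so the 8K and 32K variants each default to their own limit. (Strictly the annotation would now be Optional[int]; the diff keeps int.) A usage sketch, assuming a model loaded with the attributes this diff relies on:

# No max_length passed: generation may run up to model.seq_length.
response, history = model.chat(tokenizer, 'Summarize the document above.')

# The old ceiling is still one argument away for callers that depended on it.
response, history = model.chat(tokenizer, 'Hello', max_length=8192)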
@@ -1204,7 +1214,7 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
                 query: str,
                 history: List[Tuple[str, str]] = None,
                 past_key_values=None,
-                max_length: int = 8192,
+                max_length: int = None,
                 do_sample=True,
                 top_p=0.8,
                 temperature=0.8,
@@ -1216,6 +1226,8 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
         if logits_processor is None:
             logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
+        if max_length is None:
+            max_length = self.seq_length
         gen_kwargs = {
             'max_length': max_length,
             'do_sample': do_sample,
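
stream_chat receives the identical fallback, keeping the two entry points consistent. Usage sketch under the same assumptions (stream_chat yields the cumulative response and updated history):

# Streaming also inherits the checkpoint's window by default.
for response, history in model.stream_chat(tokenizer, 'Hello', history=[]):
    print(response)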