add support for chatglm2 32k

Author: wenmeng.zwm
Date: 2023-08-01 09:54:22 +08:00
parent ef1b429bef
commit 34b75afac3


@@ -122,7 +122,12 @@ def split_tensor_along_last_dim(
 class RotaryEmbedding(nn.Module):
-    def __init__(self, dim, original_impl=False, device=None, dtype=None):
+    def __init__(self,
+                 dim,
+                 rope_ratio=1,
+                 original_impl=False,
+                 device=None,
+                 dtype=None):
         super().__init__()
         inv_freq = 1.0 / (10000**(
             torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
@@ -148,7 +153,8 @@ class RotaryEmbedding(nn.Module):
             / n_elem))
         # Create position indexes `[0, 1, ..., seq_len - 1]`
-        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
+        seq_idx = torch.arange(
+            seq_len, dtype=dtype, device=device) / self.rope_ratio
         # Calculate the product of position index and $\theta_i$
         idx_theta = torch.outer(seq_idx, theta).float()
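For context, the change to the position indexes above is linear position interpolation: every position is divided by rope_ratio before the rotary angles are computed, so a long sequence reuses the angle range the model saw during training. A minimal standalone sketch of that cache computation (the function name and the dim/ratio values here are illustrative, not taken from this repo):

import torch

def rotary_cache(seq_len: int, n_elem: int, rope_ratio: float = 1.0,
                 base: float = 10000.0) -> torch.Tensor:
    # theta_i = 1 / base^(2i / n_elem), the standard RoPE frequencies
    theta = 1.0 / (base**(torch.arange(0, n_elem, 2).float() / n_elem))
    # Linear position interpolation: dividing each position index by
    # rope_ratio squeezes seq_len positions into the range the model
    # saw during training.
    seq_idx = torch.arange(seq_len).float() / rope_ratio
    idx_theta = torch.outer(seq_idx, theta)
    # cache[p, i] = (cos(p * theta_i), sin(p * theta_i))
    return torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

plain = rotary_cache(32768, 64, rope_ratio=1.0)
scaled = rotary_cache(32768, 64, rope_ratio=4.0)
# Under ratio 4, position 8192 is embedded exactly like position 2048
# was during training.
assert torch.allclose(scaled[8192], plain[2048])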
@@ -864,6 +870,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         self.rotary_pos_emb = RotaryEmbedding(
             rotary_dim // 2,
+            rope_ratio=config.rope_ratio,
             original_impl=config.original_rope,
             device=device,
             dtype=config.torch_dtype)
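Because rope_ratio is read straight from the checkpoint's config here, the effective context length scales with it. A back-of-the-envelope check (the concrete numbers are illustrative; the real ratio is whatever config.rope_ratio ships with the 32k checkpoint):

trained_len = 8192   # hypothetical context the rotary embeddings were trained on
rope_ratio = 4       # illustrative; the real value comes from config.rope_ratio
extended_len = rope_ratio * trained_len

# Every position in the extended window maps back inside the trained range,
# since positions are fed in as p / rope_ratio:
for p in (0, trained_len, extended_len - 1):
    assert p / rope_ratio < trained_len
print(f"{extended_len} tokens addressable")  # 32768 -> 32k support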
@@ -1169,7 +1176,7 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
              tokenizer,
              query: str,
              history: List[Tuple[str, str]] = None,
-             max_length: int = 8192,
+             max_length: int = None,
              num_beams=1,
              do_sample=True,
              top_p=0.8,
@@ -1181,6 +1188,8 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
         if logits_processor is None:
             logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
+        if max_length is None:
+            max_length = self.seq_length
         gen_kwargs = {
             'max_length': max_length,
             'num_beams': num_beams,
@@ -1204,7 +1213,7 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
              query: str,
              history: List[Tuple[str, str]] = None,
              past_key_values=None,
-             max_length: int = 8192,
+             max_length: int = None,
              do_sample=True,
              top_p=0.8,
             temperature=0.8,
@@ -1216,6 +1225,8 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
         if logits_processor is None:
             logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
+        if max_length is None:
+            max_length = self.seq_length
         gen_kwargs = {
             'max_length': max_length,
             'do_sample': do_sample,
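With these last four hunks, both chat and stream_chat treat max_length=None as "use the model's configured sequence length", so a 32k checkpoint is no longer silently capped at the old hard-coded 8192. A minimal stand-in for that fallback, assuming self.seq_length is populated from the checkpoint's config as the diff implies (the class and method names below are hypothetical):

from typing import Optional

class GeneratorStub:
    """Hypothetical stand-in for the fallback in chat()/stream_chat()."""

    def __init__(self, seq_length: int = 32768):
        # In the real model this would come from the checkpoint's config.
        self.seq_length = seq_length

    def build_gen_kwargs(self, max_length: Optional[int] = None) -> dict:
        # New behaviour: no explicit max_length means the full configured
        # context, not a fixed 8192.
        if max_length is None:
            max_length = self.seq_length
        return {'max_length': max_length}

stub = GeneratorStub()
assert stub.build_gen_kwargs()['max_length'] == 32768      # follows the config
assert stub.build_gen_kwargs(4096)['max_length'] == 4096   # explicit value wins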