diff --git a/modelscope/models/nlp/chatglm2/text_generation.py b/modelscope/models/nlp/chatglm2/text_generation.py
index 1052b875..21323e64 100644
--- a/modelscope/models/nlp/chatglm2/text_generation.py
+++ b/modelscope/models/nlp/chatglm2/text_generation.py
@@ -122,13 +122,19 @@ def split_tensor_along_last_dim(
 
 class RotaryEmbedding(nn.Module):
 
-    def __init__(self, dim, original_impl=False, device=None, dtype=None):
+    def __init__(self,
+                 dim,
+                 rope_ratio=1,
+                 original_impl=False,
+                 device=None,
+                 dtype=None):
         super().__init__()
         inv_freq = 1.0 / (10000**(
             torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
         self.register_buffer('inv_freq', inv_freq)
         self.dim = dim
         self.original_impl = original_impl
+        self.rope_ratio = rope_ratio
 
     def forward_impl(self,
                      seq_len: int,
@@ -148,7 +154,8 @@ class RotaryEmbedding(nn.Module):
             / n_elem))
 
         # Create position indexes `[0, 1, ..., seq_len - 1]`
-        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
+        seq_idx = torch.arange(
+            seq_len, dtype=dtype, device=device) / self.rope_ratio
 
         # Calculate the product of position index and $\theta_i$
         idx_theta = torch.outer(seq_idx, theta).float()
@@ -864,6 +871,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         self.rotary_pos_emb = RotaryEmbedding(
             rotary_dim // 2,
+            rope_ratio=config.rope_ratio,
             original_impl=config.original_rope,
             device=device,
             dtype=config.torch_dtype)
@@ -1169,7 +1177,7 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
              tokenizer,
              query: str,
              history: List[Tuple[str, str]] = None,
-             max_length: int = 8192,
+             max_length: int = None,
              num_beams=1,
              do_sample=True,
              top_p=0.8,
@@ -1181,6 +1189,8 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
         if logits_processor is None:
             logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
+        if max_length is None:
+            max_length = self.seq_length
         gen_kwargs = {
             'max_length': max_length,
             'num_beams': num_beams,
@@ -1204,7 +1214,7 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
              query: str,
              history: List[Tuple[str, str]] = None,
              past_key_values=None,
-             max_length: int = 8192,
+             max_length: int = None,
              do_sample=True,
              top_p=0.8,
              temperature=0.8,
@@ -1216,6 +1226,8 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
         if logits_processor is None:
             logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
+        if max_length is None:
+            max_length = self.seq_length
         gen_kwargs = {
             'max_length': max_length,
             'do_sample': do_sample,
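
Context for the two changes: dividing the position indexes by rope_ratio before the outer product with theta is linear position interpolation, the usual way to stretch a RoPE model trained at seq_length to a longer window (rope_ratio=1 reproduces the old behavior); and chat/stream_chat now default max_length to the checkpoint's configured self.seq_length instead of a hard-coded 8192, so the model reads config.rope_ratio and the context size from its own config. Below is a minimal, self-contained sketch of the rotary cache math, not the modelscope code itself; the helper name rope_cache is made up for illustration, and it strips the dtype/device handling of forward_impl.

import torch

def rope_cache(seq_len, n_elem, rope_ratio=1.0, base=10000):
    # Per-dimension rotation frequencies: theta_i = base^(-2i / n_elem)
    theta = 1.0 / (base**(torch.arange(0, n_elem, 2).float() / n_elem))
    # The change in this diff: compress position indexes by rope_ratio so
    # longer sequences map into the angle range seen during pretraining.
    seq_idx = torch.arange(seq_len).float() / rope_ratio
    idx_theta = torch.outer(seq_idx, theta)
    return torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

# With rope_ratio=4, position 4*k gets exactly the angles position k gets in
# the unscaled cache, so 4x longer inputs stay inside the trained range.
long_cache = rope_cache(32768, 64, rope_ratio=4.0)
base_cache = rope_cache(8192, 64, rope_ratio=1.0)
assert torch.allclose(long_cache[4 * 8191], base_cache[8191])

One consequence of the config-driven defaults worth noting in review: configs without a rope_ratio attribute will now fail at ChatGLMModel construction, so older checkpoints need the field (or a getattr fallback) added.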