diff --git a/modelscope/models/nlp/chatglm2/text_generation.py b/modelscope/models/nlp/chatglm2/text_generation.py index 1052b875..57a7270f 100644 --- a/modelscope/models/nlp/chatglm2/text_generation.py +++ b/modelscope/models/nlp/chatglm2/text_generation.py @@ -122,7 +122,12 @@ def split_tensor_along_last_dim( class RotaryEmbedding(nn.Module): - def __init__(self, dim, original_impl=False, device=None, dtype=None): + def __init__(self, + dim, + rope_ratio=1, + original_impl=False, + device=None, + dtype=None): super().__init__() inv_freq = 1.0 / (10000**( torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) @@ -148,7 +153,8 @@ class RotaryEmbedding(nn.Module): / n_elem)) # Create position indexes `[0, 1, ..., seq_len - 1]` - seq_idx = torch.arange(seq_len, dtype=dtype, device=device) + seq_idx = torch.arange( + seq_len, dtype=dtype, device=device) / self.rope_ratio # Calculate the product of position index and $\theta_i$ idx_theta = torch.outer(seq_idx, theta).float() @@ -864,6 +870,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel): self.rotary_pos_emb = RotaryEmbedding( rotary_dim // 2, + rope_ratio=config.rope_ratio, original_impl=config.original_rope, device=device, dtype=config.torch_dtype) @@ -1169,7 +1176,7 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel): tokenizer, query: str, history: List[Tuple[str, str]] = None, - max_length: int = 8192, + max_length: int = None, num_beams=1, do_sample=True, top_p=0.8, @@ -1181,6 +1188,8 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel): if logits_processor is None: logits_processor = LogitsProcessorList() logits_processor.append(InvalidScoreLogitsProcessor()) + if max_length is None: + max_length = self.seq_length gen_kwargs = { 'max_length': max_length, 'num_beams': num_beams, @@ -1204,7 +1213,7 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel): query: str, history: List[Tuple[str, str]] = None, past_key_values=None, - max_length: int = 8192, + max_length: int = None, do_sample=True, top_p=0.8, temperature=0.8, @@ -1216,6 +1225,8 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel): if logits_processor is None: logits_processor = LogitsProcessorList() logits_processor.append(InvalidScoreLogitsProcessor()) + if max_length is None: + max_length = self.seq_length gen_kwargs = { 'max_length': max_length, 'do_sample': do_sample,