@@ -122,13 +122,19 @@ def split_tensor_along_last_dim(
 
 class RotaryEmbedding(nn.Module):
 
-    def __init__(self, dim, original_impl=False, device=None, dtype=None):
+    def __init__(self,
+                 dim,
+                 rope_ratio=1,
+                 original_impl=False,
+                 device=None,
+                 dtype=None):
         super().__init__()
         inv_freq = 1.0 / (10000**(
             torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
         self.register_buffer('inv_freq', inv_freq)
         self.dim = dim
         self.original_impl = original_impl
+        self.rope_ratio = rope_ratio
 
     def forward_impl(self,
                      seq_len: int,
@@ -148,7 +154,8 @@ class RotaryEmbedding(nn.Module):
                 / n_elem))
 
         # Create position indexes `[0, 1, ..., seq_len - 1]`
-        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
+        seq_idx = torch.arange(
+            seq_len, dtype=dtype, device=device) / self.rope_ratio
 
         # Calculate the product of position index and $\theta_i$
         idx_theta = torch.outer(seq_idx, theta).float()
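The hunk above applies linear position interpolation to RoPE: dividing the position index by `rope_ratio` squeezes `rope_ratio` real positions into each position slot the model saw during pre-training, which stretches the usable context window without retraining. A minimal standalone sketch of the scaled angle computation, assuming the same base of 10000 used in `forward_impl`; the helper name `rotary_angles` and the sizes are hypothetical:

    import torch

    def rotary_angles(seq_len, n_elem, rope_ratio=1, base=10000):
        # theta_i = 1 / base^(2i / n_elem), as computed in forward_impl
        theta = 1.0 / (base**(torch.arange(0, n_elem, 2).float() / n_elem))
        # The only change in the diff: scale the position index by 1 / rope_ratio
        seq_idx = torch.arange(seq_len).float() / rope_ratio
        return torch.outer(seq_idx, theta)

    # With rope_ratio=2, position 2k receives the angles position k had before,
    # so a 16-token sequence reuses the rotation range of an 8-token one.
    a1 = rotary_angles(8, 64)
    a2 = rotary_angles(16, 64, rope_ratio=2)
    assert torch.allclose(a1[4], a2[8])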
@@ -864,6 +871,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         self.rotary_pos_emb = RotaryEmbedding(
             rotary_dim // 2,
+            rope_ratio=config.rope_ratio,
             original_impl=config.original_rope,
             device=device,
             dtype=config.torch_dtype)
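For this `config.rope_ratio` read to stay compatible with checkpoints whose config predates the field, the config class would default it to 1, the neutral value from `__init__` above. The config file is not part of this diff, so the sketch below is an assumption, not the real `ChatGLMConfig`:

    class ConfigSketch:
        # Hypothetical stand-in for ChatGLMConfig; only the attributes
        # this diff touches are shown.
        def __init__(self, rope_ratio=1, original_rope=True, **kwargs):
            self.rope_ratio = rope_ratio  # 1 reproduces the pre-diff behaviour
            self.original_rope = original_rope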
@@ -1169,7 +1177,7 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
              tokenizer,
              query: str,
              history: List[Tuple[str, str]] = None,
-             max_length: int = 8192,
+             max_length: int = None,
              num_beams=1,
              do_sample=True,
              top_p=0.8,
@@ -1181,6 +1189,8 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
         if logits_processor is None:
             logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
+        if max_length is None:
+            max_length = self.seq_length
         gen_kwargs = {
             'max_length': max_length,
             'num_beams': num_beams,
@@ -1204,7 +1214,7 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
                     query: str,
                     history: List[Tuple[str, str]] = None,
                     past_key_values=None,
-                    max_length: int = 8192,
+                    max_length: int = None,
                     do_sample=True,
                     top_p=0.8,
                     temperature=0.8,
@@ -1216,6 +1226,8 @@ class ChatGLM2ForConditionalGeneration(ChatGLMPreTrainedModel):
         if logits_processor is None:
             logits_processor = LogitsProcessorList()
         logits_processor.append(InvalidScoreLogitsProcessor())
+        if max_length is None:
+            max_length = self.seq_length
         gen_kwargs = {
             'max_length': max_length,
             'do_sample': do_sample,
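The last four hunks make the same change in `chat` and `stream_chat`: `max_length` now defaults to `None` and is resolved to the model's configured `seq_length` at call time, so the generation cap follows the checkpoint's context window instead of a hard-coded 8192. A minimal sketch of the pattern, with a hypothetical class and attribute placement:

    class GeneratorSketch:
        def __init__(self, seq_length: int):
            # Stand-in for the context length the real model reads from its config
            self.seq_length = seq_length

        def chat(self, query: str, max_length: int = None):
            # Resolve the default lazily: a per-instance value cannot be baked
            # into the signature, which is why the hard-coded 8192 was dropped.
            if max_length is None:
                max_length = self.seq_length
            return {'query': query, 'max_length': max_length}

    assert GeneratorSketch(seq_length=32768).chat('hi')['max_length'] == 32768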