fix numpy tensor error for csanmt

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11621108
xiangpeng.wxp
2023-02-10 06:06:35 +00:00
committed by wenmeng.zwm
parent 27304e38b0
commit e252113294


@@ -85,7 +85,7 @@ class CsanmtForTranslation(Model):
         src_bias = tf.compat.v1.get_variable('encoder_input_bias',
                                              [hidden_size])
-        eos_padding = tf.zeros([tf.shape(input=features)[0], 1], tf.int64)
+        eos_padding = tf.zeros_like(features, dtype=tf.int64)[:, :1]
         src_seq = tf.concat([features, eos_padding], 1)
         src_mask = tf.cast(tf.not_equal(src_seq, 0), dtype=tf.float32)
         shift_src_mask = src_mask[:, :-1]
@@ -135,7 +135,7 @@ class CsanmtForTranslation(Model):
         embedding_mat = tf.compat.v1.get_variable(
             'Weights', [vocab_size, hidden_size], initializer=initializer)
-        eos_padding = tf.zeros([tf.shape(input=features)[0], 1], tf.int64)
+        eos_padding = tf.zeros_like(features, dtype=tf.int64)[:, :1]
         input_seq = tf.concat([features, eos_padding], 1)
         input_mask = tf.cast(tf.not_equal(input_seq, 0), dtype=tf.float32)
         shift_input_mask = input_mask[:, :-1]
@@ -233,7 +233,7 @@ class CsanmtForTranslation(Model):
             'Weights', [trg_vocab_size, hidden_size],
             initializer=initializer)
-        eos_padding = tf.zeros([tf.shape(input=labels)[0], 1], tf.int64)
+        eos_padding = tf.zeros_like(labels, dtype=tf.int64)[:, :1]
         trg_seq = tf.concat([labels, eos_padding], 1)
         trg_mask = tf.cast(tf.not_equal(trg_seq, 0), dtype=tf.float32)
         shift_trg_mask = trg_mask[:, :-1]
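
All three hunks above make the same one-line change: the EOS padding column is no longer built with tf.zeros from a shape list that mixes the symbolic scalar tf.shape(...)[0] with a Python int, but sliced out of tf.zeros_like, which inherits its shape from the input tensor. A minimal sketch of the two patterns on a hypothetical toy batch (TF2 eager mode assumed; `features` and its values are illustrative, not from the source):

    import tensorflow as tf

    # Hypothetical batch of token ids, shape [batch, seq_len], int64.
    features = tf.constant([[5, 7, 0], [3, 2, 0]], dtype=tf.int64)

    # Old pattern: shape list mixes a symbolic scalar with a Python int.
    old = tf.zeros([tf.shape(input=features)[0], 1], tf.int64)

    # New pattern: slice one zero column out of zeros_like, so the shape
    # is taken from `features` without building an explicit shape list.
    new = tf.zeros_like(features, dtype=tf.int64)[:, :1]

    print(old.numpy().tolist(), new.numpy().tolist())  # [[0], [0]] twice
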
@@ -520,16 +520,16 @@ class CsanmtForTranslation(Model):
         init_log_probs = \
             tf.constant([[0.] + [tf.float32.min] * (beam_size - 1)])
         init_log_probs = tf.tile(init_log_probs, [batch_size, 1])
         init_scores = tf.zeros_like(init_log_probs)
-        fin_seqs = tf.zeros([batch_size, beam_size, 1], tf.int32)
+        fin_seqs = tf.cast(tf.fill([batch_size, beam_size, 1], 0), tf.int32)
         fin_scores = tf.fill([batch_size, beam_size], tf.float32.min)
-        fin_flags = tf.zeros([batch_size, beam_size], tf.bool)
+        fin_flags = tf.cast(tf.fill([batch_size, beam_size], 0), tf.bool)
         states_key = [
-            tf.zeros([batch_size, 0, hidden_size])
+            tf.fill([batch_size, 0, hidden_size], 0.0)
             for layer in range(num_decoder_layers)
         ]
         states_val = [
-            tf.zeros([batch_size, 0, hidden_size])
+            tf.fill([batch_size, 0, hidden_size], 0.0)
             for layer in range(num_decoder_layers)
         ]
         for layer in range(num_decoder_layers):
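
The beam-search initializers switch from tf.zeros(shape, dtype) to tf.fill (with a tf.cast where a non-float dtype is needed). Both forms produce the same tensors; a small equivalence check (hypothetical sizes, TF2 eager mode assumed):

    import tensorflow as tf

    batch_size, beam_size = 2, 4

    # Old and new initializations of the finished-hypothesis flags.
    old_flags = tf.zeros([batch_size, beam_size], tf.bool)
    new_flags = tf.cast(tf.fill([batch_size, beam_size], 0), tf.bool)

    print(bool(tf.reduce_all(tf.equal(old_flags, new_flags))))  # True
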
@@ -1117,7 +1117,8 @@ def attention_bias(inputs, mode, inf=-1e9, dtype=None):
     elif mode == 'causal':
         length = inputs
-        lower_triangle = tf.linalg.band_part(tf.ones([length, length]), -1, 0)
+        lower_triangle = tf.linalg.band_part(
+            tf.fill([length, length], 1.0), -1, 0)
         ret = inf * (1.0 - lower_triangle)
         ret = tf.reshape(ret, [1, 1, length, length])
     else:
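
In the causal branch, tf.ones([length, length]) becomes tf.fill([length, length], 1.0), the same zeros/ones-to-fill substitution as above. A sketch of what this branch computes for a small concrete length (TF2 eager mode assumed; `length = 3` is illustrative):

    import tensorflow as tf

    length, inf = 3, -1e9

    # Lower-triangular ones: position i may attend to positions j <= i.
    lower_triangle = tf.linalg.band_part(
        tf.fill([length, length], 1.0), -1, 0)

    # 0.0 on allowed positions, -1e9 on future ones; reshaped so it
    # broadcasts against [batch, heads, length, length] logits.
    ret = inf * (1.0 - lower_triangle)
    ret = tf.reshape(ret, [1, 1, length, length])
    print(ret.numpy()[0, 0])
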