请教老师,个人训练skip-gram模型时,为什么用@tf.function加速train_step速度仅是1.5s/it?(不加速是3s/it)
模型向前计算速度正常极速,但是整个train_step极慢一个iter要1.5s。
@tf.function
def train_step(inp_w_id, inp_v_id, inp_neg_v_ids):
with tf.GradientTape() as tape:
loss = skip_gram_model(inp_w_id, inp_v_id, inp_neg_v_ids) ***# 这步模型向前计算运行时间极快,正常。***
loss_ = tf.reduce_mean(loss) #
variables = skip_gram_model.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss_
调用整个train_step,本行就需要1.5s。
train_step(inp_w_id, inp_v_id, inp_neg_v_ids)
是啥原因呢?apply_gradients的问题吗?s/it