Hi! We actually have similar code in the large project later on in Chapter 10. In TensorFlow 2.0, besides the fit function, you can also write your own custom training loop.
Here is the training loop from the Chapter 10 seq2seq + attention implementation. It simply iterates over the dataset and feeds each batch to the network:
EPOCHS = 10

for epoch in range(EPOCHS):
    start = time.time()

    encoding_hidden = encoder.initialize_hidden_state()
    total_loss = 0

    for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):
        batch_loss = train_step(inp, targ, encoding_hidden)
        total_loss += batch_loss

        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, batch, batch_loss.numpy()))

    # saving (checkpoint) the model every 2 epochs
    if (epoch + 1) % 2 == 0:
        checkpoint.save(file_prefix=checkpoint_prefix)

    print('Epoch {} Loss {:.4f}'.format(epoch + 1, total_loss / steps_per_epoch))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

The implementation of train_step is as follows:
@tf.function
def train_step(inp, targ, encoding_hidden):
    loss = 0

    with tf.GradientTape() as tape:
        encoding_output, encoding_hidden = encoder(inp, encoding_hidden)

        decoding_hidden = encoding_hidden
        decoding_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)

        # Teacher forcing - feeding the target as the next input
        for t in range(1, targ.shape[1]):
            # passing enc_output to the decoder
            predictions, decoding_hidden, _ = decoder(decoding_input, decoding_hidden, encoding_output)

            loss += loss_function(targ[:, t], predictions)

            # using teacher forcing
            decoding_input = tf.expand_dims(targ[:, t], 1)

    batch_loss = (loss / int(targ.shape[1]))

    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return batch_loss
Here, once the data has been fed through the network, the loss is computed by loss_function. This is exactly the place where you can customize the loss to fit your own needs.
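For reference, a common way to write loss_function for this kind of seq2seq model is a masked sparse categorical cross-entropy, so that padded positions do not contribute to the loss. The sketch below assumes that padding tokens have id 0 and that the decoder outputs raw logits over the target vocabulary; you can replace the body with whatever loss suits your task:

# A minimal sketch of a custom loss_function (assumptions: padding id is 0,
# predictions are logits). Swap in your own loss logic here as needed.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    # mask out padding positions so they don't affect the loss
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask

    return tf.reduce_mean(loss_)

Because train_step only calls loss_function(targ[:, t], predictions), any function with this signature that returns a scalar tensor will plug straight into the training loop above.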