Hi, thanks for the careful catch — it was indeed a problem in the program. The fix is to drop the second dimension of y_pred (with tf.squeeze) so its shape matches the labels; the corrected code is below.
import numpy as np
import tensorflow as tf
from tensorflow import keras

# x_train_scaled, y_train, x_valid_scaled, y_valid, x_train are assumed to
# come from the earlier data-loading and scaling steps of the notebook.

# 1. per batch: iterate over the training set and update the metric
#    1.1 compute gradients with automatic differentiation
# 2. at the end of each epoch: evaluate on the validation set
epochs = 100
batch_size = 32
steps_per_epoch = len(x_train_scaled) // batch_size
optimizer = keras.optimizers.SGD()
metric = keras.metrics.MeanSquaredError()

def random_batch(x, y, batch_size=32):
    # sample a random batch of indices (with replacement)
    idx = np.random.randint(0, len(x), size=batch_size)
    return x[idx], y[idx]

model = keras.models.Sequential([
    keras.layers.Dense(30, activation='relu',
                       input_shape=x_train.shape[1:]),
    keras.layers.Dense(1),
])

for epoch in range(epochs):
    metric.reset_states()
    for step in range(steps_per_epoch):
        x_batch, y_batch = random_batch(x_train_scaled, y_train,
                                        batch_size)
        with tf.GradientTape() as tape:
            y_pred = model(x_batch)
            # model output has shape (batch_size, 1); drop the second
            # dimension so it matches y_batch with shape (batch_size,)
            y_pred = tf.squeeze(y_pred, 1)
            loss = keras.losses.mean_squared_error(y_batch, y_pred)
            metric(y_batch, y_pred)
        grads = tape.gradient(loss, model.variables)
        grads_and_vars = zip(grads, model.variables)
        optimizer.apply_gradients(grads_and_vars)
        print("\rEpoch", epoch, " train mse:",
              metric.result().numpy(), end="")
    y_valid_pred = model(x_valid_scaled)
    y_valid_pred = tf.squeeze(y_valid_pred, 1)
    valid_loss = keras.losses.mean_squared_error(y_valid, y_valid_pred)
    print("\t", "valid mse: ", valid_loss.numpy())