Here is the code for our ResNet residual block:
import tensorflow as tf

def residual_block(x, output_channel):
    """Residual connection implementation."""
    input_channel = x.get_shape().as_list()[-1]
    if input_channel * 2 == output_channel:
        # Channel count doubles, so the spatial size is halved with stride 2.
        increase_dim = True
        strides = (2, 2)
    elif input_channel == output_channel:
        increase_dim = False
        strides = (1, 1)
    else:
        raise Exception("output_channel must equal input_channel or input_channel * 2")
    conv1 = tf.layers.conv2d(x,
                             output_channel,
                             (3, 3),
                             strides=strides,
                             padding='same',
                             activation=tf.nn.relu,
                             name='conv1')
    conv2 = tf.layers.conv2d(conv1,
                             output_channel,
                             (3, 3),
                             strides=(1, 1),
                             padding='same',
                             activation=tf.nn.relu,
                             name='conv2')
    if increase_dim:
        # Shortcut branch: [None, width, height, channel] -> [None, width/2, height/2, channel*2].
        # Downsample spatially with average pooling, then zero-pad the channel dimension.
        pooled_x = tf.layers.average_pooling2d(x,
                                               (2, 2),
                                               (2, 2),
                                               padding='valid')
        padded_x = tf.pad(pooled_x,
                          [[0, 0],
                           [0, 0],
                           [0, 0],
                           [input_channel // 2, input_channel // 2]])
    else:
        padded_x = x
    # Element-wise addition of the residual branch and the (possibly transformed) shortcut.
    output_x = conv2 + padded_x
    return output_x
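As a quick sanity check on the shortcut branch, here is a minimal sketch of building one dimension-increasing block. The 32x32x16 input size, the output_channel of 32, and the 'demo_block' scope name are illustrative assumptions, not part of the original code; wrapping the block in its own tf.variable_scope also keeps the fixed layer names 'conv1' and 'conv2' from colliding if several blocks are later stacked.

x = tf.placeholder(tf.float32, [None, 32, 32, 16])
with tf.variable_scope('demo_block'):
    y = residual_block(x, 32)
# The padded shortcut lines up with conv2, so the addition is valid:
print(y.get_shape().as_list())  # [None, 16, 16, 32]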
For the pooling here we used 'valid' padding.
In fact, either 'same' or 'valid' can work: pooling always shrinks the spatial dimensions, so the real requirement is that after pooling and tf.pad the shortcut ends up with exactly the same shape as conv2. As long as that holds, 'valid' and 'same' are both fine.
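To make that concrete, here is a small sketch comparing the two padding modes for 2x2 average pooling with stride 2 (the 32x32x16 placeholder size is an illustrative assumption). For even spatial sizes such as 32, both modes produce 16x16, which matches conv2's stride-2 output, so either choice satisfies the shape requirement in this setup.

x = tf.placeholder(tf.float32, [None, 32, 32, 16])
valid_pool = tf.layers.average_pooling2d(x, (2, 2), (2, 2), padding='valid')
same_pool = tf.layers.average_pooling2d(x, (2, 2), (2, 2), padding='same')
print(valid_pool.get_shape().as_list())  # [None, 16, 16, 16]
print(same_pool.get_shape().as_list())   # [None, 16, 16, 16]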