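Teacher, I understand that when batch_size=1, each element of the dataset has the form (x with shape (8,), y with shape (1,)). But when batch_size=2, why is each element (x with shape (2, 8), y with shape (2, 1)) rather than something like (x with shape (8,), y with shape (1,), x with shape (8,), y with shape (1,))? Batching happens after parsing, and parsing turns one line of data into two results, x and y. So why does batching two lines still give just two results, x and y, instead of a structure like x, y, x, y?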
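import tensorflow as tf

# parse_csv_line is defined elsewhere in the course code; it maps one CSV line to (x, y).
def csv_reader_dataset(filenames, n_readers=5,
                       batch_size=32, n_parse_threads=5,
                       shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    # Read n_readers files at a time, skipping the header line of each file.
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length=n_readers
    )
    # shuffle returns a new dataset, so the result has to be assigned back.
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line,
                          num_parallel_calls=n_parse_threads)
    # batch stacks each component of consecutive elements along a new first axis.
    dataset = dataset.batch(batch_size)
    return dataset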
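
To make the shape question concrete, here is a rough pure-Python imitation of what batching does to (x, y) elements. It is only a sketch, not TensorFlow's implementation: `batch_elements` and the `rows` data are hypothetical names made up for illustration, and nested lists stand in for tensors. The point it tries to show is that batching keeps the element structure (one x, one y) and stacks each component separately, which matches the documented behavior of `Dataset.batch` adding an outer dimension to every component.

```python
from itertools import islice

def batch_elements(elements, batch_size):
    """Group consecutive (x, y) elements and stack each component separately."""
    it = iter(elements)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        # zip(*chunk) regroups [(x0, y0), (x1, y1)] into (x0, x1) and (y0, y1),
        # so a batch is still ONE (x, y) pair with a new leading dimension.
        xs, ys = zip(*chunk)
        yield list(xs), list(ys)

# Hypothetical parsed rows: x has 8 features, y has 1 label.
rows = [([float(i)] * 8, [float(i)]) for i in range(4)]

batches = list(batch_elements(rows, batch_size=2))
x_batch, y_batch = batches[0]
print(len(x_batch), len(x_batch[0]))  # rows and columns of x in one batch
print(len(y_batch), len(y_batch[0]))  # rows and columns of y in one batch
```

Under this reading, an "x, y, x, y" element never appears because the batch is built per component: all the x values go into one array and all the y values into another, which is also the layout a model expects for its inputs and labels.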