请稍等 ...
×

采纳答案成功!

向帮助你的同学说点啥吧!感谢那些助人为乐的人

pass

for x_batch, y_batch in dataset.take(steps_per_shard):
for x_example, y_example in zip(x_batch, y_batch):#解batch
老师,此时的数据集是之前的8个属性与类标签结合起来,相当于一行9个属性的形式,现在从数据集中取出前steps_per_shard个batch,我想请教的是:获取到steps_per_shard个batch后,是如何赋值给 x_batch, y_batch 的,又是怎样实现对batch解绑的,赋值给 x_example, y_example的,解绑是指将9个属性拆分成8+1吗?

正在回答 回答被采纳积分+3

1回答

正十七 2020-02-24 23:38:47

如下面的代码所示:

def parse_csv_line(line, n_fields = 9):
    defs = [tf.constant(np.nan)] * n_fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)
    x = tf.stack(parsed_fields[0:-1])
    y = tf.stack(parsed_fields[-1:])
    return x, y
 
def csv_reader_dataset(filenames, n_readers=5,
                       batch_size=32, n_parse_threads=5,
                       shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length = n_readers
    )
    dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line,
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

在csv_reader_dataset中,我们在map函数中对每一行进行解析。

在parse_csv_line中,我们把每一行给拆成了前八个和后一个,即x和y。这样csv_reader_dataset返回的dataset里,每一个batch都是两个元素,即x_batch和y_batch。

而batch解绑定直接用了for循环,如下,并不是9拆8+1。

with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
    for x_batch, y_batch in dataset.skip(shard_id * steps_per_shard).take(steps_per_shard):                for x_example, y_example in zip(x_batch, y_batch):
        writer.write(serialize_example(x_example, y_example))
        all_filenames.append(filename_fullpath)


0 回复 有任何疑惑可以回复我~
  • 提问者 战战的坚果 #1
    老师,以下是我的理解,您看对吗?
    在csv_reader_dataset中,在map函数中对每一行进行解析。
    
    在parse_csv_line中,把每一行给拆成了前八个和后一个,即x和y。这样csv_reader_dataset返回的dataset里,每一个batch都是两个元素,即x_batch:即batchsize条数据的前8个元素和y_batch:即batchsize条数据的后1个元素,则
     for x_batch, y_batch in dataset.skip(shard_id * steps_per_shard).take(steps_per_shard):
    执行后,取出了steps_per_shard个batch,然后通过for循环取出一个batch,并将这一个batch(由两部分组成)赋值给x_batch, y_batch,此时x_batch:是batchsize条数据的前8个元素,y_batch是batchsize条数据的后1个元素,此时通过第二个for循环,将x_batch, y_batch连接起来,依次取出batchsize条数据的第一条、第二条。。。第batchsize条数据,将第一条数据的前8个元素赋值给x_example, 后一个元素赋值y_example,所以serialize_example(x_example, y_example)是将一行数据序列化。
    回复 有任何疑惑可以回复我~ 2020-02-28 12:04:47
问题已解决,确定采纳
还有疑问,暂不采纳
意见反馈 帮助中心 APP下载
官方微信