import numpy as np
import tensorflow as tf

def parse_csv_line(line, n_fields=9):
    # Default every field to NaN so missing values are easy to spot.
    defs = [tf.constant(np.nan)] * n_fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)
    x = tf.stack(parsed_fields[:-1])   # all fields but the last -> features
    y = tf.stack(parsed_fields[-1:])   # last field -> target
    return x, y
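# Quick sanity check: a sketch using a hypothetical sample line with 9
# comma-separated values (not part of the original listing).
sample_line = tf.constant("1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0")
x_sample, y_sample = parse_csv_line(sample_line)
# x_sample has shape (8,), y_sample has shape (1,)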
def csv_reader_dataset(filenames, n_readers=5, batch_size=32,
                       n_parse_threads=5, shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    # Read from n_readers files at a time, skipping each file's header row.
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length=n_readers)
    # shuffle() returns a new dataset, so its result must be reassigned.
    dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset
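# Example usage: a sketch assuming hypothetical CSV files matching the pattern
# below, each with a header row and 9 numeric columns (8 features + 1 target).
# The dataset repeats indefinitely, so take(1) is used to grab a single batch.
train_set = csv_reader_dataset("datasets/my_train_*.csv")
for X_batch, y_batch in train_set.take(1):
    print(X_batch.shape, y_batch.shape)  # (32, 8) and (32, 1) with the defaults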