Tensorflow数据格式及多线程处理

本文介绍如何使用Tensorflow进行多线程预处理。首先介绍TFRecord格式,并介绍如何利用队列框架进行多线程数据预处理,最后介绍Tensorflow 1.3版之后推荐使用的数据集(Dataset)API。

TFRecord文件

介绍

TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。下面是tf.train.Example的定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
message Example {
Features feature = 1;
};

meassage Features {
map<string, Feature> feature =1;
};

message Feature {
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};

从上面的定义可以看出,tf.train.Example的数据结构比较简单,包含一个从属性名称到取值的字典,其中属性名称为字符串,取值可以为字符串、实数列表、整数列表。

样例

1
2
3
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def _int64_feature(value):
"""
生成整数型的属性
"""
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
"""
生成字符串型的属性
"""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

mnist = input_data.read_data_sets('/media/seisinv/Data/04_data/MNIST_data',dtype=tf.uint8, one_hot=True)

images = mnist.train.images
labels = mnist.train.labels

pixels = images.shape[1]
num_examples = mnist.train.num_examples

# 输出TFRecord文件
filename = '/media/seisinv/Data/04_data/ai/test/output.tfrecords'

# 创建writer写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)

for index in range(num_examples):
# 将图像矩阵转化成一个字符串
image_raw = images[index].tostring()
# 将一个样例转化成Example Protocol Buffer,并将所有信息写入这个数据结构中
example = tf.train.Example(features=tf.train.Features(feature={
'pixels': _int64_feature(pixels),
'label': _int64_feature(np.argmax(labels[index])),
'image_raw': _bytes_feature(image_raw)
}))

# 将一个Examples写入TFRecord文件
writer.write(example.SerializeToString())

# 关闭文件
writer.close()
Extracting /media/seisinv/Data/04_data/MNIST_data/train-images-idx3-ubyte.gz
Extracting /media/seisinv/Data/04_data/MNIST_data/train-labels-idx1-ubyte.gz
Extracting /media/seisinv/Data/04_data/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting /media/seisinv/Data/04_data/MNIST_data/t10k-labels-idx1-ubyte.gz
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# 创建reader读取TFRecord文件
reader = tf.TFRecordReader()

# 创建队列维护输入文件列表
filename_queue = tf.train.string_input_producer(['/media/seisinv/Data/04_data/ai/test/output.tfrecords'])

# 从文件中读取一个样例,也可以使用read_up_to函数一次性读取多个样例
_, serialized_example = reader.read(filename_queue)

# 解析读入的一个样例。如果需要解析多个样例,可以使用parse_example函数
features = tf.parse_single_example(serialized_example,
features = {
'image_raw': tf.FixedLenFeature([], tf.string),
'pixels': tf.FixedLenFeature([], tf.int64),
'label': tf.FixedLenFeature([], tf.int64)
})

# tf.decode_raw可以将字符串解析成图像对应的像素数组
image = tf.decode_raw(features['image_raw'], tf.uint8)
label = tf.cast(features['label'], tf.int32)
pixels = tf.cast(features['pixels'], tf.int32)

sess = tf.Session()

# 启动多线程处理输入数据
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 每次运行可读取TFRecord文件中的一个样例,当所有样例都读完之后,程序会从头读取
for i in range(2):
print(sess.run([image, label, pixels]))
[array([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  97,
        96,  77, 118,  61,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,  90, 138, 235, 235, 235, 235, 235,
       235, 251, 251, 248, 254, 245, 235, 190,  21,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0, 140, 251, 254, 254, 254, 254,
       254, 254, 254, 254, 254, 254, 254, 254, 254, 254, 189,  23,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0, 226, 254, 208, 199,
       199, 199, 199, 139,  61,  61,  61,  61,  61, 128, 222, 254, 254,
       189,  21,   0,   0,   0,   0,   0,   0,   0,   0,   0,  38,  82,
        13,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  34,
       213, 254, 254, 115,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,  84, 254, 254, 234,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,  84, 254, 254, 234,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0, 106, 157, 254, 254, 243,  51,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,  25, 117, 228, 228, 228, 253, 254, 254, 254, 254, 240,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,  68, 119, 220, 254, 254, 254, 254, 254, 254, 254, 254,
       254, 142,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,  37, 187, 253, 254, 254, 254, 223, 206, 206,  75,  68,
       215, 254, 254, 117,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0, 113, 219, 254, 242, 227, 115,  89,  31,   0,   0,
         0,   0, 200, 254, 241,  41,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0, 169, 254, 176,  62,   0,   0,   0,   0,
         0,   0,   0,  48, 231, 254, 234,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,  18, 124,   0,   0,   0,   0,
         0,   0,   0,   0,   0,  84, 254, 254, 166,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0, 139, 254, 238,  57,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0, 210, 250, 254, 168,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 242, 254, 239,
        57,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  89, 251,
       241,  86,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   5,
       206, 246, 157,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   4, 117,  69,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0], dtype=uint8), 7, 784]
[array([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  31, 132, 254,
       253, 254, 213,  82,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  21, 142, 233,
       252, 253, 252, 253, 252, 223,  20,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 123, 254,
       253, 254, 253, 224, 203, 203, 223, 255, 213,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
       203, 253, 252, 253, 212,  20,   0,   0,  61, 253, 252,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,  41, 243, 224, 203, 183,  41, 152,  30,   0,   0, 255, 253,
       102,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,  40,  20,   0,   0, 102, 253,  50,   0,  82,
       253, 252,  20,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  82, 214,  31,
       113, 233, 254, 233,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,  62, 102,  82,  41,
       253, 232, 253, 252, 233,  50,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 152, 253,
       254, 253, 254, 253, 254, 233, 123,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
       152, 252, 253, 252, 253, 252, 192,  50,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,  62, 183, 203, 243, 254, 253,  62,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,  40, 172, 252, 203,  20,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,  21,   0,   0,   0,   0,   0,   0,   0,   0,   0, 183, 254,
       112,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,  62, 203, 163,   0,   0,   0,   0,   0,   0,   0,   0,
        61, 253, 151,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,  21, 214, 192,   0,   0,   0,   0,   0,   0,   0,
         0,  11, 213, 254, 151,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0, 102, 253, 151,   0,   0,   0,   0,   0,
         0,   0,  41, 213, 252, 253, 111,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,  41, 255, 213,  92,  51,   0,
         0,  31,  92, 173, 253, 254, 253, 142,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 172, 252, 253,
       252, 203, 203, 233, 252, 253, 252, 253, 130,  20,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  21,
       203, 255, 253, 254, 253, 254, 253, 244, 203,  82,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,  20, 151, 151, 253, 171, 151, 151,  40,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
         0,   0,   0,   0], dtype=uint8), 3, 784]

多线程数据处理框架

队列和多线程

在Tensorflow中,队列和变量类似,都是计算图上有状态的节点,其他的计算节点可以修改它们的状态。

  • 对于变量,可以通过赋值操作修改变量的取值。
  • 对于队列,修改队列的操作主要有Enqueue, EnqueueMany和Dequeue
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# 创建一个先进先出的队列,指定最多可以保存两个元素,类型为整型
q = tf.FIFOQueue(2, 'int32')

# 初始化队列中的元素,和变量初始化类似,使用队列之前都需要明确的调用初始化过程
init = q.enqueue_many(([0,10],))

# 取出队列第一个元素,并赋值到变量x中
x = q.dequeue()

y = x + 1

# 将加1后的值加入到队列中
q_inc = q.enqueue([y])

with tf.Session() as sess:
init.run()
for _ in range(5):
v, _ = sess.run([x, q_inc])
print(v)
0
10
1
11
2

Tensorflow中还提供了两种队列,FIFOQueue和RandomShufferQueue,第一种实现的是先进先出队列,后一种会将队列中的元素打乱,每次出队列操作得到的是从当前队列所有元素中随机选择的一个。

在Tensorflow中,队列不仅是一种数据结构,而且是异步计算张量取值的一个重要机制。比如多个线程同事向一个队列写元素或者读元素。

Tensorflow中提供了tf.Coordinator和tf.QueueRunner两个类实现多线程协同的功能。tf.Coordinator类主要用于协同多个线程一起停止,并提供了should_stop, request_stopjoin三个函数。用法如下:

  • 在启动线程之前,首先需要声明一个tf.Coordinator类
  • 将这个类传入每个创建的线程,启动的线程需要一直查询should_stophanshu,当该函数返回为True时,当前线程需要退出
  • 每个启动的线程都可以通过request_stop函数通知其他的线程退出
  • 当某个线程调用request_stop函数之后,should_stop函数的返回值将被设置为True,这样其他的线程就会同时终止。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import threading
import time

def MyLoop(coord, workder_id):
"""
在一个线程中运行的程序,每隔1秒判断是否需要停止并打印自己的ID
"""
while not coord.should_stop():
if np.random.rand() < 0.1:
print("Stoping from id: %d\n"%workder_id)
coord.request_stop()
else:
print("Working on id:%d\n"%workder_id)
time.sleep(1)

# 声明tf.train.Coordinator类,协同多个线程
coord = tf.train.Coordinator()

# 声明创建2个线程
threads = [
threading.Thread(target=MyLoop, args=(coord, i, )) for i in range(2)
]

# 启动所有线程
for t in threads: t.start()

# 等待所有线程退出
coord.join(threads)
Working on id:0

Working on id:1

Working on id:0

Working on id:1

Working on id:0
Stoping from id: 1

tf.QueueRunner主要用于启动多个线程来操作同一个队列,启动的这些线程可以通过上面介绍的tf.Coordinator类来统一管理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# 声明一个队列,共100个元素,类型为实数
queue = tf.FIFOQueue(100,'float')

# 定义队列入队操作
enqueue_op = queue.enqueue([tf.random_normal([1])])

# 创建多个线程运行队列的入队操作,参数1指被操作的队列,参数2表示启动2个线程,每个线程运行enqueue_op操作
qr = tf.train.QueueRunner(queue, [enqueue_op]*2)

# 将定义过的tf.train.QueueRunner加入默认的tf.GraphKeys.QUEUE_RUNNERS集合
tf.train.add_queue_runner(qr)

# 定义出队操作
out_tensor = queue.dequeue()

with tf.Session() as sess:
coord = tf.train.Coordinator()

# 启动所有线程
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# 获取队列中的取值
for _ in range(2): print(sess.run(out_tensor)[0])

# 停止所有线程
coord.request_stop()
coord.join(threads)
1.43469
-0.598175

输入文件队列

虽然可以将多个训练样本放入一个TFRecord中,但是当训练数据量很大时,将数据分成多个TFRecord文件可以提高处理效率。Tensorflow提供了tf.train.match_filenames_once函数获取符合正则化表达式的所有文件列表,并通过tf.train.string_input_producer函数进行管理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# 模拟海量数据情况下,将数据写入不同的文件
num_shards = 2 #总共2个文件
instances_per_shard = 2 #每个文件2个数据样本

for i in range(num_shards):
filename = ('/media/seisinv/Data/04_data/ai/test/data.tfrecords_%.5d-of-%.5d' %(i, num_shards))
writer = tf.python_io.TFRecordWriter(filename)

for j in range(instances_per_shard):
# Example类中仅仅保留两个简单的信息
example = tf.train.Example(features = tf.train.Features(feature={
'i': _int64_feature(i),
'j': _int64_feature(j)
}))
writer.write(example.SerializeToString())
writer.close()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# 获取文件列表
files = tf.train.match_filenames_once('/media/seisinv/Data/04_data/ai/test/data.tfrecords_*')

# 创建输入队列,shuffer参数控制是否随机打乱读取文件的顺序,在实际过程中,一般设置为True
filename_queue = tf.train.string_input_producer(files, shuffle=False)
# filename_queue = tf.train.string_input_producer(files, shuffle=False, num_epochs=1)

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

# 解析一个样本
features = tf.parse_single_example(serialized_example,
features={
'i': tf.FixedLenFeature([], tf.int64),
'j': tf.FixedLenFeature([], tf.int64)
})

with tf.Session() as sess:
tf.local_variables_initializer().run()
print(sess.run(files))

# 协同不同的线程
coord = tf.train.Coordinator()

# 启动所有的线程
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

for i in range(6):
print(sess.run([features['i'], features['j']]))
coord.request_stop()
coord.join(threads)
[b'/media/seisinv/Data/04_data/ai/test/data.tfrecords_00000-of-00002'
 b'/media/seisinv/Data/04_data/ai/test/data.tfrecords_00001-of-00002']
[0, 0]
[0, 1]
[1, 0]
[1, 1]
[0, 0]
[0, 1]

在上面的例子中,由于没有打乱文件列表的顺序,因为会依次读取样本数据中每个样本的信息,而且当所有样本都被读取之后,程序会自动从头开始。如果限制num_epochs为1,那么程序将会报错。

组合训练数据

当得到单个样本之后,可以通过tf.train.batch和tf.train.shuffle_batch函数以队列的形式生成一个batch。入队操作是生成一个样本,出队操作得到一个batch的样本。两者唯一的区别在于是否会将数据顺序打乱。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
tf.reset_default_graph()

# 获取文件列表
files = tf.train.match_filenames_once('/media/seisinv/Data/04_data/ai/test/data.tfrecords_*')

# 创建输入队列,shuffer参数控制是否随机打乱读取文件的顺序,在实际过程中,一般设置为True
filename_queue = tf.train.string_input_producer(files, shuffle=False)
# filename_queue = tf.train.string_input_producer(files, shuffle=False, num_epochs=1)

reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)

# 解析一个样本
features = tf.parse_single_example(serialized_example,
features={
'i': tf.FixedLenFeature([], tf.int64),
'j': tf.FixedLenFeature([], tf.int64)
})

example, label = features['i'], features['j']

batch_size = 3 # batch大小

# 设置队列大小
capacity = 1000 + 3 * batch_size

example_batch, label_batch = tf.train.batch([example, label],
batch_size=batch_size, capacity=capacity)

with tf.Session() as sess:
tf.local_variables_initializer().run()

# 协同并启动所的线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

# 实际问题中,这里一般是神经网络的输入
for i in range(2):
cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch])
print(cur_example_batch, cur_label_batch)

coord.request_stop()
coord.join()
[0 0 1] [0 1 0]
[1 0 0] [1 0 1]

tf.train.shuffle_batch的用法和tf.train.batch类似,区别有二:

  • 输出的样本顺序会被打乱
  • 增加一个参数min_after_dequeue,限制出队时样本的最少个数。这是因为当队列中元素太少时,随机打乱顺序的作用不大,只有当样本多于一定数量时,才开始随机出队

tf.train.batch、tf.train.shuffle_batch、tf.train.shuffle_batch_join都可以通过num_threads指定多个线程执行入队操作(包括数据读取和预处理)。区别在于:

  • tf.train.shuffle_batch函数多个线程会同时读取一个文件的不同样本并进行预处理。如果一个文件中的样例比较相似(比如属于同一类),那么神经网络的训练效果可能会受到影响,因此尽量将同一个TFRecord文件中的样本随机打乱
  • tf.train.shuffle_batch_join函数多个线程处理不同文件中的不同样本,不同线程会读取不同文件,具体来说,是将tf.train.string_input_producer函数生成的文件队列平均分配到不同的线程上。但是,如果读取数据的线程数比总文件数还大,那么多个线程可能会读取同一个文件中相近部分的数据。而且多个线程读取多个文件可能导致过多的硬盘寻址,从而使得读取效率降低。

总结具体的多线程文件读入和预处理步骤包括:

  • 输入文件列表,通过tf.train.string_input_producer函数生成输入文件队列(可以随机打乱)
  • 通过tf.train.batch或其他2个函数进行多线程入队操作(包括数据读入和预处理)
  • 生成样本batch,执行训练或者测试过程

数据集(Dataset)

前面介绍了,通过队列进行多线程输入和预处理。从Tensorflow 1.3开始,数据集被正式推荐为输入数据的首选框架。下面介绍数据集的基本用法。

基本用法

在数据集框架中,每个数据集代表一个数据来源,可能是一个张量、TFRecord文件、或者文本文件。由于训练数据通常都很大,无法全部写入内存中,因此从数据集中读取数据时需要使用一个迭代器按顺序读取,这点和队列相似,并且数据集也是计算图上的一个节点。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
input_data = [1, 2, 3, 4, 5]

# 从一个数组创建数据集
dataset = tf.data.Dataset.from_tensor_slices(input_data)

# 定义迭代器遍历数据集
iterator = dataset.make_one_shot_iterator()

# 返回一个输入数据的张量,类似于队列中的dequeue()
x = iterator.get_next()

y = x*x

with tf.Session() as sess:
for i in range(len(input_data)):
print(sess.run(y))
1
4
9
16
25

利用数据集读取数据的步骤包括:

  • 创建数据集
  • 定义遍历器
  • 使用get_next读取数据张量
1
2
3
4
5
6
7
8
9
10
11
12
input_files = ["/media/seisinv/Data/04_data/ai/test/test1.txt","/media/seisinv/Data/04_data/ai/test/test2.txt"]

# 从文本文件创建数据集,假定每行表示一个训练样本
dataset = tf.data.TextLineDataset(input_files)

iterator = dataset.make_one_shot_iterator()

x = iterator.get_next()

with tf.Session() as sess:
for i in range(3):
print(sess.run(x))
b'1\t2'
b'3\t4'
b'5\t6'
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def parser(record):
"""
解析一个TFRecord的方法
"""
features = tf.parse_single_example(record,
features = {
'i': tf.FixedLenFeature([], tf.int64),
'j': tf.FixedLenFeature([], tf.int64)
})
return features['i'], features['j']

input_files = ['/media/seisinv/Data/04_data/ai/test/data.tfrecords_00000-of-00002','/media/seisinv/Data/04_data/ai/test/data.tfrecords_00001-of-00002']

# 从TFRecord文件创建数据集
dataset = tf.data.TFRecordDataset(input_files)

# 调用map函数对数据集中的每条数据应用parser操作,当然可以将parser换成其他的预处理操作。
dataset = dataset.map(parser)

iterator = dataset.make_one_shot_iterator()

feat1, feat2 = iterator.get_next()

with tf.Session() as sess:
for i in range(4):
print(sess.run([feat1, feat2]))
[0, 0]
[0, 1]
[1, 0]
[1, 1]

除了简单地使用one_shot_iterator(需要事先确定所欲的参数)来遍历数据外,还可以使用placeholder来初始化数据集,initializable_iterator初始化迭代器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
input_files = tf.placeholder(tf.string)

# 从TFRecord文件创建数据集
dataset = tf.data.TFRecordDataset(input_files)

# 调用map函数对数据集中的每条数据应用parser操作,当然可以将parser换成其他的预处理操作。
dataset = dataset.map(parser, num_parallel_calls=2)

# 可以不需要事先知道所有参数,对迭代器初始化
iterator = dataset.make_initializable_iterator()

feat1, feat2 = iterator.get_next()

with tf.Session() as sess:
sess.run(iterator.initializer,
feed_dict={input_files: ['/media/seisinv/Data/04_data/ai/test/data.tfrecords_00000-of-00002','/media/seisinv/Data/04_data/ai/test/data.tfrecords_00001-of-00002']})
while True:
try:
print(sess.run([feat1, feat2]))
except tf.errors.OutOfRangeError:
break
[0, 0]
[0, 1]
[1, 0]
[1, 1]

数据集的高层封装

在上一节介绍的队列框架中,预处理,shuffle,batch等操作有的在队列上进行,有的在图片张量上进行,整个处理流程在处理队列和张量的代码中来回切换。而在数据集中,所有的操作都是在数据集上进行,代码结构简洁、干净。

通过map方法可以封装更加复杂的预处理流程,比如:

1
dataset = dataset.map(lambda x: preprocess(x, image_size, image_size, None))

在数据集框架中,shuffle和batch操作由两个方法独立实现:

1
2
dataset = dataset.shuffle(buffer_size) # 和前面的min_after_dequeue相似
dataset = dataset.batch(batch_size)

数据集框架还提供了很多其他的函数,具体的参考Tensorflow相关文档。

结论

本文主要介绍Tensorflow所支持的多线程输入及预处理框架,包括:
- TFRecord格式可以将不同类型的数据统一管理,其数据结构可以简单理解为一种复杂的字典结构
- 队列框架支持从生成文件列表队列到batch组合队列,中间支持多线程并行预处理样本集
- 数据集框架是目前Tensorflow推荐的高层输入和预处理框架,该框架提供了随机打乱、样本batch、数据复制等高层操作。

参考资料

  • 郑泽宇、梁博文和顾思宇,Tensorflow: 实战Google深度学习框架(第二版)