Tensorflow模型的基本用法

本文介绍Tensorflow的基本知识,包括三要素:计算图、张量和会话。这部分是Tensorflow的基础。

计算模型——计算图

Tensorflow通过计算图的形式描述所有计算过程,其中每一个计算都是计算图中的一个节点,而节点之间的边描述了计算之间的依赖关系。

1
2
3
4
5
6
7
import tensorflow as tf
a = tf.constant([1,0, 2.0], name='a')
b = tf.constant([2,0, 3.0], name='b')
result = a + b

# 查看张量所属的计算图。如果没有指定,使用默认计算图
print(a.graph, a.graph is tf.get_default_graph())
<tensorflow.python.framework.ops.Graph object at 0x7f4a1c41d6d8> True

默认情况下使用系统指定的计算图,当然,tensorflow也支持通过tf.Graph函数生成新的计算图。不同的计算图上的张量和运算都不会共享。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
g1 = tf.Graph()
with g1.as_default():
v = tf.get_variable('v', initializer=tf.zeros_initializer(),shape=[1])

g2 = tf.Graph()
with g2.as_default():
v = tf.get_variable('v', initializer=tf.ones_initializer(),shape=[1])

# 在计算图g1中取变量v的值
with tf.Session(graph=g1) as sess:
tf.global_variables_initializer().run()
with tf.variable_scope("", reuse=True):
print(sess.run(tf.get_variable('v')))

# 在计算图g2中取变量v的值
with tf.Session(graph=g2) as sess:
tf.global_variables_initializer().run()
with tf.variable_scope("", reuse=True):
print(sess.run(tf.get_variable('v')))
[ 0.]
[ 1.]

从上面可以看出,通过计算图,可以隔离张量和计算。另外,计算图还提供了管理张量和计算的机制。比如,计算图可以通过tf.Graph.device函数指定运行计算的设备(CPU或者GPU)。

1
2
3
4
5
g = tf.Graph()

# 指定计算运行的设备
with g.device('/cpu:0'):
result = a + b

另外,计算图可以有效地整合Tensorflow程序中的资源,包括张量、变量或者运行Tensorflow程序所需要的队列资源等。为了方便起见,Tensorflow也自动管理了一些常用的集合,比如集合tf.GraphKeys.VARIABLES表示所有的变量,tf.GraphKeys.TRAINABLE_VARIABLES表示所有可学习的变量,tf.GraphKeys.SUMMARIES表示日志相关的张量,tf.GraphKeys.QUEUE_RUNNERS表示处理输入的QueueRunner,tf.GraphKeys.MOVING_AVERAGE_VARIABLES表示所有计算了滑动平均值的变量

数据模型——张量

张量是tensorflow管理数据的形式。需要注意的是,张量中并没有真正保存数字,保存的是如何得到这些数字的计算过程

1
2
3
4
5
a = tf.constant([1,0, 2.0], name='a')
b = tf.constant([2,0, 3.0], name='b')
result1 = tf.add(a, b, name='add')
result2 = tf.add(result1, a, name="add")
print(result1, result2)
Tensor("add_9:0", shape=(3,), dtype=float32) Tensor("add_10:0", shape=(3,), dtype=float32)

张量中主要保存三个属性,名字、维度和类型。其中名字中包括计算节点和输出的第几个结果,比如add_9:0表示计算节点add_9的第一个结果(编号从0开始)。类型的作用是,在运行计算时出现类型不匹配时会报错。

张量的第一个作用是对中间计算结果的引用,提高代码的可读性;第二个作用是,当计算图构造完成之后,利用张量可以获得中间计算结果

运行模型——会话

会话拥有并管理tensorflow程序运行时的所有资源。并且所有计算完成之后,需要关闭会话来帮助系统回收资源,否则可能会出现资源泄露的问题。

Tensorflow中有两种使用会话的方式:一种方式是明确调用会话生成函数和关闭函数,当程序因为异常退出时,关闭会话的函数可能不会被执行从而导致资源泄露;另一种方式是利用上下文管理器来使用会话,当上下文管理器退出时会自动释放所有资源,可以解决因为异常退出时资源泄露的问题。

1
2
3
4
5
6
7
# 运行会话方式一
a = tf.constant([1.0, 2.0])
b = tf.constant([2.0, 3.0])
c = a + b
sess = tf.Session()
print(sess.run(c))
sess.close()
[ 3.  5.]
1
2
3
4
5
6
# 运行会话方式二
a = tf.constant([1.0, 2.0])
b = tf.constant([2.0, 3.0])
c = a + b
with tf.Session() as sess:
print(sess.run(c))
[ 3.  5.]

如果不指定,tensorflow会自动生成一个默认的计算图。而会话需要手动指定,指定之后,可以使用eval函数计算张量的取值。

1
2
3
4
5
6
7
8
9
10
11
sess = tf.Session()

# 方式一
with sess.as_default():
print(c.eval())

# 方式二
print(sess.run(c))

# 方式三
print(c.eval(session=sess))
[ 3.  5.]
[ 3.  5.]
[ 3.  5.]

在交互环境下,可以直接构建默认会话的函数,使用这个函数会自动将生成的会话注册为默认会话。

1
2
3
sess = tf.InteractiveSession()
print(c.eval())
sess.close()
[ 3.  5.]

配置会话:

通过ConfigProto可以配置并行的线程数、GPU分配策略、运算超时时间等参数

1
2
3
4
5
config = tf.ConfigProto(allow_soft_placement=True, # True表示当GPU无法运行时,可以放到CPU上运行
log_device_placement=True) # True表示记录每个计算步骤安排在哪个设备上

sess1 = tf.InteractiveSession(config=config)
sess2 = tf.Session(config=config)

参考资料

  • 郑泽宇、梁博文和顾思宇,Tensorflow: 实战Google深度学习框架(第二版)