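This article explains how to persist (save and load) models in TensorFlow, how persistence works, and which file formats are involved. A simple addition example illustrates both the concrete implementation and the underlying mechanism.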
Implementing model persistence
```python
import tensorflow as tf
```
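After the code above runs, the following files are generated:
- A file with the .meta suffix. It stores the structure of the computation graph, which can be thought of as the structure of the neural network.
- The .ckpt file(s). They store the value of every variable. With the V2 checkpoint format that tf.train.Saver uses by default (the saver_def later in this article reports version: V2), there is no single .ckpt file: the values are split into a .ckpt.index file and one or more .ckpt.data-*****-of-***** files.
- The checkpoint file. It records the list of all model files in the current directory.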
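Only the first line of the saving code survived. The following is a minimal sketch of a typical save script, not the author's exact code. It assumes TF 1.15+ or 2.x with the TF1 graph API imported through `tensorflow.compat.v1` (on TF 1.4, use `import tensorflow as tf` and drop `disable_v2_behavior()`), two variables v1 = 1.0 and v2 = 2.0 (the values that appear in the MetaGraphDef dump later on), and a temporary directory in place of the original model/ folder. It lists the files that `tf.train.Saver.save` writes.

```python
import os
import tempfile

# Assumption: TF 1.15+/2.x; on TF 1.4 use `import tensorflow as tf` instead.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode + tf.Session, as in TF 1.x

# Hypothetical output location; Saver.save() requires the directory to exist.
model_dir = tempfile.mkdtemp()
save_path = os.path.join(model_dir, "model_test.ckpt")

# Pass name= to tf.Variable (not tf.constant) so the checkpoint keys are "v1"/"v2".
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

saver = tf.train.Saver()  # by default covers every global variable
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    written_prefix = saver.save(sess, save_path)  # returns the checkpoint prefix

files = sorted(os.listdir(model_dir))
print(files)
```

With the default V2 format the listing typically contains checkpoint, model_test.ckpt.data-00000-of-00001, model_test.ckpt.index and model_test.ckpt.meta.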
Persisting all variables
```python
import tensorflow as tf
```

Output:

```
INFO:tensorflow:Restoring parameters from model/model_test.ckpt
[ 3.]
```
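A self-contained sketch of restoring every variable, again assuming the `tensorflow.compat.v1` import and a temporary directory; `build_graph` is a hypothetical helper introduced here. The restoring graph deliberately initializes v1 and v2 to 0.0, so the printed sum can only come from the checkpoint. The graph still has to be declared in code before `restore()` is called.

```python
import os
import tempfile

# Assumption: TF 1.15+/2.x; on TF 1.4 use `import tensorflow as tf` instead.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

model_dir = tempfile.mkdtemp()  # hypothetical location
save_path = os.path.join(model_dir, "model_test.ckpt")


def build_graph(init1, init2):
    """Hypothetical helper: declares the v1 + v2 graph, returns (v1, v2, result)."""
    v1 = tf.Variable(tf.constant(init1, shape=[1]), name="v1")
    v2 = tf.Variable(tf.constant(init2, shape=[1]), name="v2")
    return v1, v2, v1 + v2


# 1) Save v1 = 1.0 and v2 = 2.0.
build_graph(1.0, 2.0)
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_path)

# 2) Rebuild the same graph with different initial values, then restore.
tf.reset_default_graph()
_, _, result = build_graph(0.0, 0.0)
saver = tf.train.Saver()
with tf.Session() as sess:
    # No initializer is run: restore() assigns every variable from the checkpoint.
    saver.restore(sess, save_path)
    restored_sum = sess.run(result)

print(restored_sum)  # [3.]
```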
Persisting a subset of variables
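Use case: you have already trained a 5-layer network and want to try a 6-layer one. You can load the 5 trained layers directly into the new model and then train only the 6th layer.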
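```python
# Pass name= to tf.Variable (not tf.constant); otherwise the variable is saved as "Variable"
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
```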
```python
import tensorflow as tf
```

Output:

```
INFO:tensorflow:Restoring parameters from model/model_test.ckpt
[1.]
```
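A minimal sketch of partial restoring, assuming the `tensorflow.compat.v1` import and a temporary directory; `new_layer` is a hypothetical stand-in for the extra 6th layer of the use case. Passing a list of variables to `tf.train.Saver` restricts saving and restoring to those variables, so anything outside the list must be initialized separately.

```python
import os
import tempfile

# Assumption: TF 1.15+/2.x; on TF 1.4 use `import tensorflow as tf` instead.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

model_dir = tempfile.mkdtemp()  # hypothetical location
save_path = os.path.join(model_dir, "model_test.ckpt")

# 1) Save a "trained" model that contains v1 and v2.
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_path)

# 2) New model: reuse v1 from the checkpoint, add a variable trained from scratch.
tf.reset_default_graph()
v1 = tf.Variable(tf.constant(0.0, shape=[1]), name="v1")
new_layer = tf.Variable(tf.constant(5.0, shape=[1]), name="new_layer")  # hypothetical
saver = tf.train.Saver([v1])  # only v1 is loaded; v2 in the checkpoint is ignored
with tf.Session() as sess:
    sess.run(tf.variables_initializer([new_layer]))  # not in the checkpoint
    saver.restore(sess, save_path)
    v1_value, new_value = sess.run([v1, new_layer])

print(v1_value, new_value)  # [1.] [5.]
```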
Renaming variables
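Use case: after a moving average has been applied to the model, load the averaged value directly into the original variable, so that the variable can be used as-is by subsequent operations.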
```python
tf.reset_default_graph()
```

Output:

```
v:0
v:0 ()
v/ExponentialMovingAverage:0 ()
[10.0, 0.099999905]
```
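The printout above can be reproduced with a sketch like the following. It assumes the `tensorflow.compat.v1` import, a temporary directory and a decay of 0.99, which is consistent with the printed shadow value: starting from 0, one update toward 10 gives 0 + (1 - 0.99) × 10 ≈ 0.1. The resulting checkpoint contains both v and its shadow variable v/ExponentialMovingAverage.

```python
import os
import tempfile

# Assumption: TF 1.15+/2.x; on TF 1.4 use `import tensorflow as tf` instead.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

model_dir = tempfile.mkdtemp()  # hypothetical location
save_path = os.path.join(model_dir, "model_test.ckpt")

v = tf.Variable(0, dtype=tf.float32, name="v")
names_before = [var.name for var in tf.global_variables()]

ema = tf.train.ExponentialMovingAverage(0.99)  # assumed decay
maintain_averages_op = ema.apply(tf.global_variables())  # creates the shadow variable
names_after = [var.name for var in tf.global_variables()]

saver = tf.train.Saver()  # saves both v and v/ExponentialMovingAverage
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.assign(v, 10.0))
    sess.run(maintain_averages_op)  # shadow moves 1% of the way from 0 toward 10
    saver.save(sess, save_path)
    v_value, shadow_value = sess.run([v, ema.average(v)])

print(names_before)
print(names_after)
print(v_value, shadow_value)
```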
Renaming the moving-average variable
```python
v = tf.Variable(0, dtype=tf.float32, name='v')
```

Output:

```
INFO:tensorflow:Restoring parameters from model/model_test.ckpt
0.099999905
```
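A self-contained sketch of this rename, assuming the `tensorflow.compat.v1` import, a temporary directory and decay 0.99. The dictionary passed to `tf.train.Saver` maps a name stored in the checkpoint (v/ExponentialMovingAverage) to a variable of the current graph (v), so v receives the averaged value instead of its raw value.

```python
import os
import tempfile

# Assumption: TF 1.15+/2.x; on TF 1.4 use `import tensorflow as tf` instead.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

model_dir = tempfile.mkdtemp()  # hypothetical location
save_path = os.path.join(model_dir, "model_test.ckpt")

# 1) Save v = 10 together with its moving average (about 0.1).
v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply([v])
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.assign(v, 10.0))
    sess.run(maintain_averages_op)
    saver.save(sess, save_path)

# 2) Fresh graph with only v: load the shadow variable's saved value into it.
tf.reset_default_graph()
v = tf.Variable(0, dtype=tf.float32, name="v")
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, save_path)
    restored_v = sess.run(v)

print(restored_v)
```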
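To make it easier to rename moving-average variables at load time, the tf.train.ExponentialMovingAverage class provides the variables_to_restore function, which generates the renaming dictionary that tf.train.Saver expects.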
```python
tf.reset_default_graph()
```

Output:

```
v:0
{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
INFO:tensorflow:Restoring parameters from model/model_test.ckpt
0.099999905
```
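A sketch of the same restore driven by `variables_to_restore()`, assuming the `tensorflow.compat.v1` import, a temporary directory and decay 0.99. In the restoring graph no shadow variable exists, so the generated dictionary maps the checkpoint key v/ExponentialMovingAverage to the graph variable v, which is exactly the mapping written by hand in the previous section.

```python
import os
import tempfile

# Assumption: TF 1.15+/2.x; on TF 1.4 use `import tensorflow as tf` instead.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

model_dir = tempfile.mkdtemp()  # hypothetical location
save_path = os.path.join(model_dir, "model_test.ckpt")

# 1) Save v = 10 and its moving average.
v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply([v])
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.assign(v, 10.0))
    sess.run(maintain_averages_op)
    saver.save(sess, save_path)

# 2) Let the EMA object build the {checkpoint name: variable} dictionary.
tf.reset_default_graph()
v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
restore_map = ema.variables_to_restore()
print(restore_map)

saver = tf.train.Saver(restore_map)
with tf.Session() as sess:
    saver.restore(sess, save_path)
    restored_v = sess.run(v)

print(restored_v)
```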
Saving variables as constants
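Use case: at test time you only need to compute the output layer from the input layer through forward propagation; auxiliary information such as variable initialization and model saving is unnecessary. Variables can therefore be saved as constants, so that the whole computation graph is stored in a single file.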
```python
from tensorflow.python.framework import graph_util
```

Output:

```
INFO:tensorflow:Froze 2 variables.
Converted 2 variables to const ops.
```
```python
with tf.Session() as sess:
```

Output:

```
[array([3.], dtype=float32)]
```
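A minimal sketch of freezing and reloading the addition graph, assuming the `tensorflow.compat.v1` import, a temporary directory and a hypothetical file name combined_model.pb. `convert_variables_to_constants` keeps only the nodes needed to compute "add" and turns each variable into a Const holding its current value; the resulting GraphDef can be loaded without a Saver or any initialization.

```python
import os
import tempfile

# Assumption: TF 1.15+/2.x; on TF 1.4 use `import tensorflow as tf` instead.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

model_dir = tempfile.mkdtemp()  # hypothetical location
pb_path = os.path.join(model_dir, "combined_model.pb")  # hypothetical file name

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = tf.add(v1, v2, name="add")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph_def = tf.get_default_graph().as_graph_def()
    # Keep only what "add" depends on and replace variables by constants.
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, graph_def, ["add"])
    with tf.gfile.GFile(pb_path, "wb") as f:
        f.write(frozen_graph_def.SerializeToString())

# Load the single .pb file into a fresh graph and evaluate "add:0".
tf.reset_default_graph()
with tf.Session() as sess:
    loaded_graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as f:
        loaded_graph_def.ParseFromString(f.read())
    # return_elements returns a list of tensors from the imported graph
    frozen_result = sess.run(
        tf.import_graph_def(loaded_graph_def, return_elements=["add:0"]))

print(frozen_result)
```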
How persistence works and the data formats involved
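TensorFlow uses a MetaGraph to record information about the nodes of the computation graph together with the metadata needed to run them. The MetaGraph is defined by the MetaGraphDef Protocol Buffer, and a file holding MetaGraphDef information uses the .meta suffix by default. The simple addition example below illustrates how the MetaGraph works and what the MetaGraphDef type contains; its top-level fields are meta_info_def, graph_def, saver_def and collection_def.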
```python
tf.reset_default_graph()
```

Output (the exported MetaGraphDef in text format):
meta_info_def {
stripped_op_list {
op {
name: "Add"
input_arg {
name: "x"
type_attr: "T"
}
input_arg {
name: "y"
type_attr: "T"
}
output_arg {
name: "z"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
type: DT_UINT8
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_COMPLEX64
type: DT_COMPLEX128
type: DT_STRING
}
}
}
}
op {
name: "Assign"
input_arg {
name: "ref"
type_attr: "T"
is_ref: true
}
input_arg {
name: "value"
type_attr: "T"
}
output_arg {
name: "output_ref"
type_attr: "T"
is_ref: true
}
attr {
name: "T"
type: "type"
}
attr {
name: "validate_shape"
type: "bool"
default_value {
b: true
}
}
attr {
name: "use_locking"
type: "bool"
default_value {
b: true
}
}
allows_uninitialized_input: true
}
op {
name: "Const"
output_arg {
name: "output"
type_attr: "dtype"
}
attr {
name: "value"
type: "tensor"
}
attr {
name: "dtype"
type: "type"
}
}
op {
name: "Identity"
input_arg {
name: "input"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
}
}
op {
name: "NoOp"
}
op {
name: "RestoreV2"
input_arg {
name: "prefix"
type: DT_STRING
}
input_arg {
name: "tensor_names"
type: DT_STRING
}
input_arg {
name: "shape_and_slices"
type: DT_STRING
}
output_arg {
name: "tensors"
type_list_attr: "dtypes"
}
attr {
name: "dtypes"
type: "list(type)"
has_minimum: true
minimum: 1
}
is_stateful: true
}
op {
name: "SaveV2"
input_arg {
name: "prefix"
type: DT_STRING
}
input_arg {
name: "tensor_names"
type: DT_STRING
}
input_arg {
name: "shape_and_slices"
type: DT_STRING
}
input_arg {
name: "tensors"
type_list_attr: "dtypes"
}
attr {
name: "dtypes"
type: "list(type)"
has_minimum: true
minimum: 1
}
is_stateful: true
}
op {
name: "VariableV2"
output_arg {
name: "ref"
type_attr: "dtype"
is_ref: true
}
attr {
name: "shape"
type: "shape"
}
attr {
name: "dtype"
type: "type"
}
attr {
name: "container"
type: "string"
default_value {
s: ""
}
}
attr {
name: "shared_name"
type: "string"
default_value {
s: ""
}
}
is_stateful: true
}
}
tensorflow_version: "1.4.0-rc1"
tensorflow_git_version: "v1.4.0-rc0-21-g1e25994"
}
graph_def {
node {
name: "Const"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 1
}
}
float_val: 1.0
}
}
}
}
node {
name: "v1"
op: "VariableV2"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
attr {
key: "container"
value {
s: ""
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
attr {
key: "shared_name"
value {
s: ""
}
}
}
node {
name: "v1/Assign"
op: "Assign"
input: "v1"
input: "Const"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@v1"
}
}
}
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
attr {
key: "use_locking"
value {
b: true
}
}
attr {
key: "validate_shape"
value {
b: true
}
}
}
node {
name: "v1/read"
op: "Identity"
input: "v1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@v1"
}
}
}
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
}
node {
name: "Const_1"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 1
}
}
float_val: 2.0
}
}
}
}
node {
name: "v2"
op: "VariableV2"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
attr {
key: "container"
value {
s: ""
}
}
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "shape"
value {
shape {
dim {
size: 1
}
}
}
}
attr {
key: "shared_name"
value {
s: ""
}
}
}
node {
name: "v2/Assign"
op: "Assign"
input: "v2"
input: "Const_1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@v2"
}
}
}
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
attr {
key: "use_locking"
value {
b: true
}
}
attr {
key: "validate_shape"
value {
b: true
}
}
}
node {
name: "v2/read"
op: "Identity"
input: "v2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@v2"
}
}
}
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
}
node {
name: "add"
op: "Add"
input: "v1/read"
input: "v2/read"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
}
node {
name: "save/Const"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
}
string_val: "model"
}
}
}
}
node {
name: "save/SaveV2/tensor_names"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 2
}
}
}
}
}
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 2
}
}
string_val: "v1"
string_val: "v2"
}
}
}
}
node {
name: "save/SaveV2/shape_and_slices"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 2
}
}
}
}
}
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 2
}
}
string_val: ""
string_val: ""
}
}
}
}
node {
name: "save/SaveV2"
op: "SaveV2"
input: "save/Const"
input: "save/SaveV2/tensor_names"
input: "save/SaveV2/shape_and_slices"
input: "v1"
input: "v2"
attr {
key: "dtypes"
value {
list {
type: DT_FLOAT
type: DT_FLOAT
}
}
}
}
node {
name: "save/control_dependency"
op: "Identity"
input: "save/Const"
input: "^save/SaveV2"
attr {
key: "T"
value {
type: DT_STRING
}
}
attr {
key: "_class"
value {
list {
s: "loc:@save/Const"
}
}
}
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
node {
name: "save/RestoreV2/tensor_names"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 1
}
}
string_val: "v1"
}
}
}
}
node {
name: "save/RestoreV2/shape_and_slices"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 1
}
}
string_val: ""
}
}
}
}
node {
name: "save/RestoreV2"
op: "RestoreV2"
input: "save/Const"
input: "save/RestoreV2/tensor_names"
input: "save/RestoreV2/shape_and_slices"
attr {
key: "_output_shapes"
value {
list {
shape {
unknown_rank: true
}
}
}
}
attr {
key: "dtypes"
value {
list {
type: DT_FLOAT
}
}
}
}
node {
name: "save/Assign"
op: "Assign"
input: "v1"
input: "save/RestoreV2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@v1"
}
}
}
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
attr {
key: "use_locking"
value {
b: true
}
}
attr {
key: "validate_shape"
value {
b: true
}
}
}
node {
name: "save/RestoreV2_1/tensor_names"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 1
}
}
string_val: "v2"
}
}
}
}
node {
name: "save/RestoreV2_1/shape_and_slices"
op: "Const"
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
attr {
key: "dtype"
value {
type: DT_STRING
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_STRING
tensor_shape {
dim {
size: 1
}
}
string_val: ""
}
}
}
}
node {
name: "save/RestoreV2_1"
op: "RestoreV2"
input: "save/Const"
input: "save/RestoreV2_1/tensor_names"
input: "save/RestoreV2_1/shape_and_slices"
attr {
key: "_output_shapes"
value {
list {
shape {
unknown_rank: true
}
}
}
}
attr {
key: "dtypes"
value {
list {
type: DT_FLOAT
}
}
}
}
node {
name: "save/Assign_1"
op: "Assign"
input: "v2"
input: "save/RestoreV2_1"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "_class"
value {
list {
s: "loc:@v2"
}
}
}
attr {
key: "_output_shapes"
value {
list {
shape {
dim {
size: 1
}
}
}
}
}
attr {
key: "use_locking"
value {
b: true
}
}
attr {
key: "validate_shape"
value {
b: true
}
}
}
node {
name: "save/restore_all"
op: "NoOp"
input: "^save/Assign"
input: "^save/Assign_1"
}
versions {
producer: 24
}
}
saver_def {
filename_tensor_name: "save/Const:0"
save_tensor_name: "save/control_dependency:0"
restore_op_name: "save/restore_all"
max_to_keep: 5
keep_checkpoint_every_n_hours: 10000.0
version: V2
}
collection_def {
key: "trainable_variables"
value {
bytes_list {
value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:02\007Const:0"
value: "\n\004v2:0\022\tv2/Assign\032\tv2/read:02\tConst_1:0"
}
}
}
collection_def {
key: "variables"
value {
bytes_list {
value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:02\007Const:0"
value: "\n\004v2:0\022\tv2/Assign\032\tv2/read:02\tConst_1:0"
}
}
}
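A text dump like the one above can be produced with `Saver.export_meta_graph(..., as_text=True)`. The sketch below assumes the `tensorflow.compat.v1` import, a temporary directory and a hypothetical file name model.ckpt.meta.json; the exact contents (op list, version strings) will differ from the TF 1.4 dump depending on the installed version.

```python
import os
import tempfile

# Assumption: TF 1.15+/2.x; on TF 1.4 use `import tensorflow as tf` instead.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

model_dir = tempfile.mkdtemp()  # hypothetical location
meta_json_path = os.path.join(model_dir, "model.ckpt.meta.json")  # hypothetical

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = tf.add(v1, v2, name="add")

saver = tf.train.Saver()
# Writes the MetaGraphDef as a text protocol buffer and also returns it.
meta_graph_def = saver.export_meta_graph(meta_json_path, as_text=True)

node_names = [node.name for node in meta_graph_def.graph_def.node]
print(node_names)
print(meta_graph_def.saver_def.restore_op_name)
print(sorted(meta_graph_def.collection_def.keys()))

with open(meta_json_path) as f:
    meta_text = f.read()
```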
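Besides the structure of the computation graph, the values of the variables must also be persisted. They are stored in the files named model.ckpt.index and model.ckpt.data-*****-of-*****. The .index file is stored in SSTable format and can be roughly thought of as a list of (key, value) pairs that maps each variable name to its metadata and to the location of its value, while the .data files hold the raw tensor contents. The following examples show how to load the saved model and how to inspect the saved variables with the tf.train.NewCheckpointReader class.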
```python
tf.reset_default_graph()
```

Output:

```
[3.]
```
```python
saver = tf.train.import_meta_graph('model/model_test.ckpt.meta')
```

Output:

```
INFO:tensorflow:Restoring parameters from model/model_test.ckpt
[3.]
```
```python
tf.reset_default_graph()
```

Output:

```
Variable_1 [1]
Variable [1]
value for variable v1 is [2.]
```
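A self-contained sketch of inspecting a checkpoint with `tf.train.NewCheckpointReader`, assuming the `tensorflow.compat.v1` import and a temporary directory. The reader works on the checkpoint files directly, so no graph needs to be rebuilt. Because this sketch passes name= to tf.Variable, the keys are v1 and v2 rather than the default names Variable and Variable_1.

```python
import os
import tempfile

# Assumption: TF 1.15+/2.x; on TF 1.4 use `import tensorflow as tf` instead.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

model_dir = tempfile.mkdtemp()  # hypothetical location
save_path = os.path.join(model_dir, "model_test.ckpt")

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_path)

# Read the checkpoint by its prefix; no graph or session is needed.
reader = tf.train.NewCheckpointReader(save_path)
var_to_shape = reader.get_variable_to_shape_map()  # {variable name: shape list}
for name in sorted(var_to_shape):
    print(name, var_to_shape[name])

v1_value = reader.get_tensor("v1")
print("value for variable v1 is", v1_value)
```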
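The name of the last file is fixed: checkpoint. It is generated automatically by the tf.train.Saver class and maintains the list of all model files persisted by that Saver. When the Saver deletes an old model (for example, once more than max_to_keep models have been saved), the corresponding file name is also removed from the checkpoint file.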
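To see how the checkpoint file is maintained, the sketch below (assuming the `tensorflow.compat.v1` import and a temporary directory) saves three times with `global_step` and `max_to_keep=2`. `tf.train.get_checkpoint_state` parses the checkpoint file and `tf.train.latest_checkpoint` returns the newest prefix; the oldest model is deleted and no longer listed.

```python
import os
import tempfile

# Assumption: TF 1.15+/2.x; on TF 1.4 use `import tensorflow as tf` instead.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

model_dir = tempfile.mkdtemp()  # hypothetical location
save_path = os.path.join(model_dir, "model_test.ckpt")

v = tf.Variable(tf.constant(1.0, shape=[1]), name="v")
saver = tf.train.Saver(max_to_keep=2)  # keep only the two newest models
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(3):
        saver.save(sess, save_path, global_step=step)  # model_test.ckpt-<step>.*

# The checkpoint file is a small text protocol buffer.
with open(os.path.join(model_dir, "checkpoint")) as f:
    print(f.read())

state = tf.train.get_checkpoint_state(model_dir)
kept = [os.path.basename(p) for p in state.all_model_checkpoint_paths]
latest = os.path.basename(tf.train.latest_checkpoint(model_dir))
print(kept, latest)
```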
References
- Zheng Zeyu, Liang Bowen and Gu Siyu, Tensorflow: 实战Google深度学习框架 (2nd edition)