Persisting TensorFlow Models

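This article explains how to persist (save and load) models in TensorFlow, how persistence works, and which file formats are involved. A simple addition example is used throughout to show the concrete implementation and the underlying mechanism.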

Implementing model persistence

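import os

import tensorflow as tf

# name= belongs to tf.Variable, so the variables are saved as 'v1' and 'v2'
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')

result = v1 + v2

init_op = tf.global_variables_initializer()

# Declare a tf.train.Saver to save the model
saver = tf.train.Saver()

# Saver.save expects the target directory to exist
os.makedirs('model', exist_ok=True)

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, 'model/model_test.ckpt')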

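After this code runs, the model directory contains the following files (TensorFlow 1.x writes checkpoints in the V2 format by default):

  1. model_test.ckpt.meta: the structure of the computation graph, which can be thought of as the structure of the neural network.
  2. model_test.ckpt.index and model_test.ckpt.data-00000-of-00001: the value of every variable. The older V1 format stored these values in a single .ckpt file.
  3. checkpoint: the list of all model files saved in this directory.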
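To check which files a save actually produces, the sketch below saves the same two-variable graph into a fresh temporary directory and lists it. It assumes TensorFlow 1.13+ or 2.x and uses the `tf.compat.v1` API (on older 1.x releases, `import tensorflow as tf` is the equivalent); the temporary directory replaces `model/` only so that the listing contains nothing else.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Graph mode, as in TensorFlow 1.x
tf.disable_v2_behavior()

tf.reset_default_graph()
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

# A throwaway directory instead of 'model/', so the listing shows only this save
save_dir = tempfile.mkdtemp()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, os.path.join(save_dir, 'model_test.ckpt'))

# Print every file the Saver wrote
files = sorted(os.listdir(save_dir))
for name in files:
    print(name)
```

The data file name carries a shard suffix; a single-process save normally writes one shard.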

Restoring all variables

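import tensorflow as tf

tf.reset_default_graph()

# Load the persisted graph structure from the .meta file
saver = tf.train.import_meta_graph('model/model_test.ckpt.meta')

with tf.Session() as sess:
    # Restore the variable values into the imported graph
    saver.restore(sess, 'model/model_test.ckpt')
    # `result` is not defined in this script, so fetch the sum by its tensor name
    result = tf.get_default_graph().get_tensor_by_name('add:0')
    print(sess.run(result))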
INFO:tensorflow:Restoring parameters from model/model_test.ckpt
[ 3.]
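Instead of importing the .meta file, the graph can also be rebuilt in Python and only the variable values restored; this works as long as the variable names match those in the checkpoint. The sketch below assumes the `tf.compat.v1` API (TensorFlow 1.13+ or 2.x) and a temporary directory; `build_graph` is a hypothetical helper. It changes v1 to 10 before saving, so the restored sum shows the values really come from the checkpoint rather than from the initializers.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_path = os.path.join(tempfile.mkdtemp(), 'model_test.ckpt')


def build_graph():
    """Hypothetical helper: define v1, v2 and their sum with fixed names."""
    v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
    v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
    return v1, v2, v1 + v2


# Save: initialize, then change v1 so the checkpoint differs from the initializer
tf.reset_default_graph()
v1, v2, result = build_graph()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.assign(v1, [10.0]))
    saver.save(sess, ckpt_path)

# Restore into a freshly rebuilt graph: no initializer is run here
tf.reset_default_graph()
v1, v2, result = build_graph()
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, ckpt_path)
    restored = sess.run(result)
print(restored)  # [12.]
```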

Persisting a subset of variables

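Use case: suppose a 5-layer network has already been trained and you want to try a 6-layer network. The trained 5 layers can be loaded directly into the new model, so that only the sixth layer needs to be trained.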

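import os

import tensorflow as tf

# Start from an empty graph so the names below are exactly 'v1', 'v2' and 'add'
tf.reset_default_graph()

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')

result = v1 + v2

init_op = tf.global_variables_initializer()

# Persist only the variable v1
saver = tf.train.Saver([v1])

os.makedirs('model', exist_ok=True)
with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, 'model/model_test.ckpt')

import tensorflow as tf

tf.reset_default_graph()

# Load the persisted graph
saver = tf.train.import_meta_graph('model/model_test.ckpt.meta')

with tf.Session() as sess:
    saver.restore(sess, 'model/model_test.ckpt')
    # v1 is not a Python variable in this script, so look it up by tensor name
    v1 = tf.get_default_graph().get_tensor_by_name('v1:0')
    print(sess.run(v1))
INFO:tensorflow:Restoring parameters from model/model_test.ckpt
[1.]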
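The 5-layer to 6-layer scenario can be sketched by passing only the pretrained variables to the Saver and initializing the new ones separately. In this minimal sketch, assuming the `tf.compat.v1` API (TensorFlow 1.13+ or 2.x), the scope names `pretrained` and `new_layer` and the variables `w` and `b` are hypothetical stand-ins for real layers; the variables are collected by scope with `tf.get_collection`.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt = os.path.join(tempfile.mkdtemp(), 'pretrained.ckpt')

# "Old" model: only the hypothetical pretrained layer exists
tf.reset_default_graph()
with tf.variable_scope('pretrained'):
    w = tf.Variable(tf.constant(1.0, shape=[1]), name='w')
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, ckpt)

# "New" model: the same pretrained layer plus a new layer
tf.reset_default_graph()
with tf.variable_scope('pretrained'):
    w = tf.Variable(tf.constant(0.0, shape=[1]), name='w')
with tf.variable_scope('new_layer'):
    b = tf.Variable(tf.constant(5.0, shape=[1]), name='b')

pretrained_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='pretrained')
new_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='new_layer')

# This Saver only knows about the pretrained variables
loader = tf.train.Saver(pretrained_vars)

with tf.Session() as sess:
    # Initialize the new layer, restore the old one from the checkpoint
    sess.run(tf.variables_initializer(new_vars))
    loader.restore(sess, ckpt)
    w_val, b_val = sess.run([w, b])
print(w_val, b_val)
```

Restoring a variable that is absent from the checkpoint would raise an error, which is why the new layer is left out of `loader` and initialized on its own.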

Renaming variables

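Use case: after a moving average has been applied to the model, the averaged values can be loaded straight into the original variables, so that those variables can be used as-is in later operations.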

import tensorflow as tf

tf.reset_default_graph()

v = tf.Variable(0, dtype=tf.float32, name='v')

# At this point there is only one variable, v:0
for variables in tf.global_variables():
    print(variables.name)

# Use the moving-average class tf.train.ExponentialMovingAverage
ema = tf.train.ExponentialMovingAverage(0.99)

# Applying the moving average creates the shadow variable v/ExponentialMovingAverage:0
maintain_averages_op = ema.apply(tf.global_variables())

# Now both variables are printed
for variables in tf.global_variables():
    print(variables.name, variables.shape)

saver = tf.train.Saver()

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    sess.run(tf.assign(v, 10))

    sess.run(maintain_averages_op)

    # Save both variables
    saver.save(sess, 'model/model_test.ckpt')

    print(sess.run([v, ema.average(v)]))
v:0
v:0 ()
v/ExponentialMovingAverage:0 ()
[10.0, 0.099999905]
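The value 0.099999905 printed above is a single moving-average step. TensorFlow documents the update as shadow = decay * shadow + (1 - decay) * variable; with decay = 0.99, shadow = 0 and v = 10 that is exactly 0.1 in real arithmetic. The NumPy sketch below repeats the step in float32. Writing it as shadow -= (shadow - value) * (1 - decay), with 1 - decay computed in float32, is an assumption about TensorFlow's internals, chosen because it reproduces the printed value; `ema_update` is a hypothetical helper.

```python
import numpy as np


def ema_update(shadow, value, decay):
    """Hypothetical helper: one step of shadow = decay*shadow + (1-decay)*value in float32."""
    shadow = np.float32(shadow)
    # 1 - 0.99 in float32 is slightly below 0.01
    one_minus_decay = np.float32(1.0) - np.float32(decay)
    return shadow - (shadow - np.float32(value)) * one_minus_decay


shadow = ema_update(0.0, 10.0, 0.99)
print(shadow)  # 0.099999905
```

The tiny gap from 0.1 is therefore float32 rounding, not a bug in the moving average.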

Renaming the moving-average variable on load

import tensorflow as tf

tf.reset_default_graph()

v = tf.Variable(0, dtype=tf.float32, name='v')

# Rename on load: the saved moving average of v is assigned directly to v
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "model/model_test.ckpt")
    print(sess.run(v))
INFO:tensorflow:Restoring parameters from model/model_test.ckpt
0.099999905

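To make renaming the moving-average variables on load easier, tf.train.ExponentialMovingAverage provides the variables_to_restore function, which generates the renaming dictionary that tf.train.Saver expects.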

import tensorflow as tf

tf.reset_default_graph()

v = tf.Variable(0, dtype=tf.float32, name='v')

for variables in tf.global_variables():
    print(variables.name)

# Use the moving-average class tf.train.ExponentialMovingAverage
ema = tf.train.ExponentialMovingAverage(0.99)

# Print the renaming dictionary generated by the class's own helper
print(ema.variables_to_restore())

# Pass that dictionary to tf.train.Saver so the shadow value is loaded into v
saver = tf.train.Saver(ema.variables_to_restore())

with tf.Session() as sess:
    saver.restore(sess, 'model/model_test.ckpt')
    print(sess.run(v))
v:0
{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
INFO:tensorflow:Restoring parameters from model/model_test.ckpt
0.099999905
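The dictionary above contains only v because the graph has no other variables. According to the variables_to_restore documentation, variables without a moving average are mapped to their own names, so a model can still be restored when it also holds, say, a step counter. The sketch below checks this, assuming the `tf.compat.v1` API (TensorFlow 1.13+ or 2.x); the non-trainable variable `step` is hypothetical.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

tf.reset_default_graph()
v = tf.Variable(0, dtype=tf.float32, name='v')
# Hypothetical non-trainable counter: it gets no moving average
step = tf.Variable(0, trainable=False, name='step')

ema = tf.train.ExponentialMovingAverage(0.99)
mapping = ema.variables_to_restore()

# Checkpoint name -> variable that receives the value
for ckpt_name, var in sorted(mapping.items()):
    print(ckpt_name, '->', var.op.name)
```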

Saving variables as constants

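Use case: at inference time you only need to compute the output layer from the input layer by forward propagation; auxiliary information such as variable initialization and model saving is not needed. The variables can therefore be saved as constants, so that the whole computation graph fits in a single file.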

import tensorflow as tf
from tensorflow.python.framework import graph_util

tf.reset_default_graph()

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')

result = v1 + v2

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)

    # Export the GraphDef part of the current graph; it alone suffices to compute outputs from inputs
    graph_def = tf.get_default_graph().as_graph_def()

    # Convert the variables and their values to constants and drop unneeded nodes.
    # The last argument lists the names of the nodes to keep, not variable names
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])

    # Write the exported model to a file
    with tf.gfile.GFile('model/model_test.pb', 'wb') as f:
        f.write(output_graph_def.SerializeToString())
INFO:tensorflow:Froze 2 variables.
Converted 2 variables to const ops.
import tensorflow as tf

with tf.Session() as sess:
    model_filename = 'model/model_test.pb'

    # Read the saved model file and parse it into a GraphDef protocol buffer
    with tf.gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Load the saved graph into the current graph. return_elements takes
    # tensor names such as "add:0", not node names
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print(sess.run(result))
[array([3.], dtype=float32)]
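To confirm that freezing really removes the variables, the sketch below freezes the graph in memory, lists the op types left in the frozen GraphDef, and evaluates it in a brand-new graph. It assumes the `tf.compat.v1` API (TensorFlow 1.13+ or 2.x), where `tf.graph_util.convert_variables_to_constants` is the public name of the function used above; the exact op types left over can differ between versions, so only the absence of variable ops is checked.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

tf.reset_default_graph()
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = tf.add(v1, v2, name='add')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    frozen = tf.graph_util.convert_variables_to_constants(
        sess, tf.get_default_graph().as_graph_def(), ['add'])

# Op types remaining after freezing (variables should now be Const nodes)
op_types = sorted({node.op for node in frozen.node})
print(op_types)

# The frozen GraphDef runs on its own: no variables, no initializer, no checkpoint
with tf.Graph().as_default():
    out, = tf.import_graph_def(frozen, return_elements=['add:0'])
    with tf.Session() as sess:
        value = sess.run(out)
print(value)
```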

How persistence works and the data formats

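TensorFlow uses a meta graph (MetaGraph) to record information about the nodes of the computation graph, together with the metadata needed to run them. The meta graph is defined by the MetaGraphDef protocol buffer, and files that store MetaGraphDef information use the .meta suffix by default. The simple addition example below illustrates how the meta graph works and what the MetaGraphDef type contains.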

import tensorflow as tf

tf.reset_default_graph()

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

saver = tf.train.Saver()

# as_text=True writes the MetaGraphDef in protobuf text format (not JSON, despite the suffix)
saver.export_meta_graph('model/model_test.ckpt.meta.json', as_text=True)
meta_info_def {
  stripped_op_list {
    op {
      name: "Add"
      input_arg {
        name: "x"
        type_attr: "T"
      }
      input_arg {
        name: "y"
        type_attr: "T"
      }
      output_arg {
        name: "z"
        type_attr: "T"
      }
      attr {
        name: "T"
        type: "type"
        allowed_values {
          list {
            type: DT_HALF
            type: DT_FLOAT
            type: DT_DOUBLE
            type: DT_UINT8
            type: DT_INT8
            type: DT_INT16
            type: DT_INT32
            type: DT_INT64
            type: DT_COMPLEX64
            type: DT_COMPLEX128
            type: DT_STRING
          }
        }
      }
    }
    op {
      name: "Assign"
      input_arg {
        name: "ref"
        type_attr: "T"
        is_ref: true
      }
      input_arg {
        name: "value"
        type_attr: "T"
      }
      output_arg {
        name: "output_ref"
        type_attr: "T"
        is_ref: true
      }
      attr {
        name: "T"
        type: "type"
      }
      attr {
        name: "validate_shape"
        type: "bool"
        default_value {
          b: true
        }
      }
      attr {
        name: "use_locking"
        type: "bool"
        default_value {
          b: true
        }
      }
      allows_uninitialized_input: true
    }
    op {
      name: "Const"
      output_arg {
        name: "output"
        type_attr: "dtype"
      }
      attr {
        name: "value"
        type: "tensor"
      }
      attr {
        name: "dtype"
        type: "type"
      }
    }
    op {
      name: "Identity"
      input_arg {
        name: "input"
        type_attr: "T"
      }
      output_arg {
        name: "output"
        type_attr: "T"
      }
      attr {
        name: "T"
        type: "type"
      }
    }
    op {
      name: "NoOp"
    }
    op {
      name: "RestoreV2"
      input_arg {
        name: "prefix"
        type: DT_STRING
      }
      input_arg {
        name: "tensor_names"
        type: DT_STRING
      }
      input_arg {
        name: "shape_and_slices"
        type: DT_STRING
      }
      output_arg {
        name: "tensors"
        type_list_attr: "dtypes"
      }
      attr {
        name: "dtypes"
        type: "list(type)"
        has_minimum: true
        minimum: 1
      }
      is_stateful: true
    }
    op {
      name: "SaveV2"
      input_arg {
        name: "prefix"
        type: DT_STRING
      }
      input_arg {
        name: "tensor_names"
        type: DT_STRING
      }
      input_arg {
        name: "shape_and_slices"
        type: DT_STRING
      }
      input_arg {
        name: "tensors"
        type_list_attr: "dtypes"
      }
      attr {
        name: "dtypes"
        type: "list(type)"
        has_minimum: true
        minimum: 1
      }
      is_stateful: true
    }
    op {
      name: "VariableV2"
      output_arg {
        name: "ref"
        type_attr: "dtype"
        is_ref: true
      }
      attr {
        name: "shape"
        type: "shape"
      }
      attr {
        name: "dtype"
        type: "type"
      }
      attr {
        name: "container"
        type: "string"
        default_value {
          s: ""
        }
      }
      attr {
        name: "shared_name"
        type: "string"
        default_value {
          s: ""
        }
      }
      is_stateful: true
    }
  }
  tensorflow_version: "1.4.0-rc1"
  tensorflow_git_version: "v1.4.0-rc0-21-g1e25994"
}
graph_def {
  node {
    name: "Const"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_FLOAT
          tensor_shape {
            dim {
              size: 1
            }
          }
          float_val: 1.0
        }
      }
    }
  }
  node {
    name: "v1"
    op: "VariableV2"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "container"
      value {
        s: ""
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "shape"
      value {
        shape {
          dim {
            size: 1
          }
        }
      }
    }
    attr {
      key: "shared_name"
      value {
        s: ""
      }
    }
  }
  node {
    name: "v1/Assign"
    op: "Assign"
    input: "v1"
    input: "Const"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_class"
      value {
        list {
          s: "loc:@v1"
        }
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "use_locking"
      value {
        b: true
      }
    }
    attr {
      key: "validate_shape"
      value {
        b: true
      }
    }
  }
  node {
    name: "v1/read"
    op: "Identity"
    input: "v1"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_class"
      value {
        list {
          s: "loc:@v1"
        }
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
  }
  node {
    name: "Const_1"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_FLOAT
          tensor_shape {
            dim {
              size: 1
            }
          }
          float_val: 2.0
        }
      }
    }
  }
  node {
    name: "v2"
    op: "VariableV2"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "container"
      value {
        s: ""
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "shape"
      value {
        shape {
          dim {
            size: 1
          }
        }
      }
    }
    attr {
      key: "shared_name"
      value {
        s: ""
      }
    }
  }
  node {
    name: "v2/Assign"
    op: "Assign"
    input: "v2"
    input: "Const_1"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_class"
      value {
        list {
          s: "loc:@v2"
        }
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "use_locking"
      value {
        b: true
      }
    }
    attr {
      key: "validate_shape"
      value {
        b: true
      }
    }
  }
  node {
    name: "v2/read"
    op: "Identity"
    input: "v2"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_class"
      value {
        list {
          s: "loc:@v2"
        }
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
  }
  node {
    name: "add"
    op: "Add"
    input: "v1/read"
    input: "v2/read"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
  }
  node {
    name: "save/Const"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_STRING
          tensor_shape {
          }
          string_val: "model"
        }
      }
    }
  }
  node {
    name: "save/SaveV2/tensor_names"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 2
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_STRING
          tensor_shape {
            dim {
              size: 2
            }
          }
          string_val: "v1"
          string_val: "v2"
        }
      }
    }
  }
  node {
    name: "save/SaveV2/shape_and_slices"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 2
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_STRING
          tensor_shape {
            dim {
              size: 2
            }
          }
          string_val: ""
          string_val: ""
        }
      }
    }
  }
  node {
    name: "save/SaveV2"
    op: "SaveV2"
    input: "save/Const"
    input: "save/SaveV2/tensor_names"
    input: "save/SaveV2/shape_and_slices"
    input: "v1"
    input: "v2"
    attr {
      key: "dtypes"
      value {
        list {
          type: DT_FLOAT
          type: DT_FLOAT
        }
      }
    }
  }
  node {
    name: "save/control_dependency"
    op: "Identity"
    input: "save/Const"
    input: "^save/SaveV2"
    attr {
      key: "T"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "_class"
      value {
        list {
          s: "loc:@save/Const"
        }
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
          }
        }
      }
    }
  }
  node {
    name: "save/RestoreV2/tensor_names"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_STRING
          tensor_shape {
            dim {
              size: 1
            }
          }
          string_val: "v1"
        }
      }
    }
  }
  node {
    name: "save/RestoreV2/shape_and_slices"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_STRING
          tensor_shape {
            dim {
              size: 1
            }
          }
          string_val: ""
        }
      }
    }
  }
  node {
    name: "save/RestoreV2"
    op: "RestoreV2"
    input: "save/Const"
    input: "save/RestoreV2/tensor_names"
    input: "save/RestoreV2/shape_and_slices"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            unknown_rank: true
          }
        }
      }
    }
    attr {
      key: "dtypes"
      value {
        list {
          type: DT_FLOAT
        }
      }
    }
  }
  node {
    name: "save/Assign"
    op: "Assign"
    input: "v1"
    input: "save/RestoreV2"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_class"
      value {
        list {
          s: "loc:@v1"
        }
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "use_locking"
      value {
        b: true
      }
    }
    attr {
      key: "validate_shape"
      value {
        b: true
      }
    }
  }
  node {
    name: "save/RestoreV2_1/tensor_names"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_STRING
          tensor_shape {
            dim {
              size: 1
            }
          }
          string_val: "v2"
        }
      }
    }
  }
  node {
    name: "save/RestoreV2_1/shape_and_slices"
    op: "Const"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "dtype"
      value {
        type: DT_STRING
      }
    }
    attr {
      key: "value"
      value {
        tensor {
          dtype: DT_STRING
          tensor_shape {
            dim {
              size: 1
            }
          }
          string_val: ""
        }
      }
    }
  }
  node {
    name: "save/RestoreV2_1"
    op: "RestoreV2"
    input: "save/Const"
    input: "save/RestoreV2_1/tensor_names"
    input: "save/RestoreV2_1/shape_and_slices"
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            unknown_rank: true
          }
        }
      }
    }
    attr {
      key: "dtypes"
      value {
        list {
          type: DT_FLOAT
        }
      }
    }
  }
  node {
    name: "save/Assign_1"
    op: "Assign"
    input: "v2"
    input: "save/RestoreV2_1"
    attr {
      key: "T"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "_class"
      value {
        list {
          s: "loc:@v2"
        }
      }
    }
    attr {
      key: "_output_shapes"
      value {
        list {
          shape {
            dim {
              size: 1
            }
          }
        }
      }
    }
    attr {
      key: "use_locking"
      value {
        b: true
      }
    }
    attr {
      key: "validate_shape"
      value {
        b: true
      }
    }
  }
  node {
    name: "save/restore_all"
    op: "NoOp"
    input: "^save/Assign"
    input: "^save/Assign_1"
  }
  versions {
    producer: 24
  }
}
saver_def {
  filename_tensor_name: "save/Const:0"
  save_tensor_name: "save/control_dependency:0"
  restore_op_name: "save/restore_all"
  max_to_keep: 5
  keep_checkpoint_every_n_hours: 10000.0
  version: V2
}
collection_def {
  key: "trainable_variables"
  value {
    bytes_list {
      value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:02\007Const:0"
      value: "\n\004v2:0\022\tv2/Assign\032\tv2/read:02\tConst_1:0"
    }
  }
}
collection_def {
  key: "variables"
  value {
    bytes_list {
      value: "\n\004v1:0\022\tv1/Assign\032\tv1/read:02\007Const:0"
      value: "\n\004v2:0\022\tv2/Assign\032\tv2/read:02\tConst_1:0"
    }
  }
}
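The exported text has four top-level parts: meta_info_def (including the op definitions in stripped_op_list), graph_def (every node, including the save/ and restore nodes that tf.train.Saver adds), saver_def, and collection_def. The sketch below reads the same fields programmatically; it assumes the `tf.compat.v1` API (TensorFlow 1.13+ or 2.x) and relies on `Saver.export_meta_graph` returning the MetaGraphDef proto when no filename is given.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

tf.reset_default_graph()
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2
saver = tf.train.Saver()

# Without a filename, export_meta_graph only returns the MetaGraphDef proto
meta_graph_def = saver.export_meta_graph()

# meta_info_def: definitions of the op types used by the graph
op_names = [op.name for op in meta_graph_def.meta_info_def.stripped_op_list.op]
# graph_def: all nodes, including those added by the Saver
node_names = [node.name for node in meta_graph_def.graph_def.node]
# saver_def: the op the Saver runs to restore all variables
restore_op = meta_graph_def.saver_def.restore_op_name
# collection_def: named collections such as 'variables'
collection_keys = sorted(meta_graph_def.collection_def.keys())

print(op_names)
print(node_names)
print(restore_op)
print(collection_keys)
```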

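Besides the structure of the computation graph, the values of the variables must be persisted as well. The files named model.ckpt.index and model.ckpt.data-*****-of-***** store the values of all variables. The .index file is stored in the SSTable format and can roughly be understood as a list of (key, value) pairs that maps each variable name to the location of its value in the .data file. The following example shows how to inspect the saved variables with tf.train.NewCheckpointReader.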

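import tensorflow as tf

tf.reset_default_graph()

# Note: name= is passed to tf.constant here, not to tf.Variable, so the variables
# get the default names 'Variable' and 'Variable_1' (see the reader output below)
v1 = tf.Variable(tf.constant(1.0, shape=[1], name="v1"))
v2 = tf.Variable(tf.constant(2.0, shape=[1], name="v2"))

result = v1 + v2

init_op = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(result))
    saver.save(sess, 'model/model_test.ckpt')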
[3.]
import tensorflow as tf

tf.reset_default_graph()

saver = tf.train.import_meta_graph('model/model_test.ckpt.meta')

with tf.Session() as sess:
    saver.restore(sess, 'model/model_test.ckpt')
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
INFO:tensorflow:Restoring parameters from model/model_test.ckpt
[3.]
import tensorflow as tf

tf.reset_default_graph()

# Pass the checkpoint prefix; the trailing .data and .index suffixes can be omitted
reader = tf.train.NewCheckpointReader('model/model_test.ckpt')

# Get the names and shapes of all saved variables
global_variables = reader.get_variable_to_shape_map()
for variable_name in global_variables:
    print(variable_name, global_variables[variable_name])

# 'Variable_1' is the second variable created above, i.e. v2
print("value for variable v2 is", reader.get_tensor("Variable_1"))
Variable_1 [1]
Variable [1]
value for variable v2 is [2.]
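When name= is given to tf.Variable itself, the checkpoint keys are the readable names, and the reader can dump every saved value into a dictionary. This minimal sketch assumes the `tf.compat.v1` API (TensorFlow 1.13+ or 2.x) and a temporary directory instead of `model/`.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_path = os.path.join(tempfile.mkdtemp(), 'model_test.ckpt')

tf.reset_default_graph()
# name= goes on tf.Variable, so the checkpoint keys are 'v1' and 'v2'
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, ckpt_path)

# Read every (name, value) pair stored in the checkpoint
reader = tf.train.NewCheckpointReader(ckpt_path)
shapes = reader.get_variable_to_shape_map()
values = {name: reader.get_tensor(name) for name in sorted(shapes)}
for name, value in values.items():
    print(name, shapes[name], value)
```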

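The name of the last file is fixed: checkpoint. It is generated automatically by tf.train.Saver and keeps track of all model files persisted by that Saver. When the Saver deletes an old model file (for example, once more than max_to_keep checkpoints exist), that model's file name is also removed from the checkpoint file.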
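The checkpoint file is a small protobuf text file with one `key: "value"` line per entry, such as model_checkpoint_path (the latest model) and one all_model_checkpoint_paths line per model still kept. In practice tf.train.get_checkpoint_state and tf.train.latest_checkpoint read it for you; the standard-library sketch below only illustrates the format. `parse_checkpoint_state` is a hypothetical helper, and the sample contents (saves with global_step 1 and 2) are made up for illustration.

```python
def parse_checkpoint_state(text):
    """Hypothetical helper: parse a text-format 'checkpoint' file into a dict.

    Every key maps to a list, because keys such as all_model_checkpoint_paths repeat.
    """
    state = {}
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        # Split at the first colon only, so values containing ':' stay intact
        key, _, raw = line.partition(':')
        value = raw.strip().strip('"')
        state.setdefault(key.strip(), []).append(value)
    return state


# Made-up contents of a checkpoint file after two saves
sample = '''model_checkpoint_path: "model_test.ckpt-2"
all_model_checkpoint_paths: "model_test.ckpt-1"
all_model_checkpoint_paths: "model_test.ckpt-2"
'''

state = parse_checkpoint_state(sample)
print(state['model_checkpoint_path'][0])  # model_test.ckpt-2
print(state['all_model_checkpoint_paths'])
```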

References

  • Zheng Zeyu, Liang Bowen and Gu Siyu, TensorFlow: 实战Google深度学习框架 (2nd edition)