TensorFlow实现迁移学习

当训练样本比较少时,可以直接使用预先训练好的模型,也可以用预训练模型的参数来初始化新模型,再在自己的数据上继续训练。本文介绍如何使用TensorFlow进行迁移学习。

预处理

1
2
3
4
5
6
import glob
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile
import tensorflow.contrib.slim as slim
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Input data directory; it has 5 sub-directories, one per flower category
INPUT_DATA = '/home/seisinv/data/flower_photos/'
# File that receives the pre-processed data (written with np.save)
OUTPUT_DATA = '/home/seisinv/data/flower_photos_process.npy'

# Percentage of the data used for validation and for testing
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10

# Read the data and split it into training, validation and testing sets
def create_image_list(sess, testing_percentage, validation_percentage):
    """Load the JPEG images under INPUT_DATA and split them into three sets.

    Every sub-directory of INPUT_DATA that contains JPEG files is treated as
    one class; classes are numbered from 0 in the order ``os.walk`` yields the
    directories. Each image is decoded, converted to float32 with
    ``tf.image.convert_image_dtype``, resized to 299x299 and then randomly
    assigned to the validation, testing or training set.

    Args:
        sess: TensorFlow session used to run the image-decoding ops.
        testing_percentage: Approximate percentage (0-100) of the images
            that go to the testing set.
        validation_percentage: Approximate percentage (0-100) of the images
            that go to the validation set.

    Returns:
        A 1-D NumPy object array holding six lists, in this order:
        training images, training labels, validation images,
        validation labels, testing images, testing labels.
    """
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]

    # Initialise the data sets
    training_images = []
    training_labels = []
    testing_images = []
    testing_labels = []
    validation_images = []
    validation_labels = []

    # Build the decoding pipeline once and feed it the raw bytes of each
    # file; creating new ops per image would grow the graph without bound.
    with sess.graph.as_default():
        raw_jpeg = tf.placeholder(tf.string, shape=[], name='raw_jpeg')
        # channels=3 forces RGB output so grayscale files still match the
        # 299x299x3 input the Inception-v3 placeholder expects.
        decoded = tf.image.decode_jpeg(raw_jpeg, channels=3)
        as_float = tf.image.convert_image_dtype(decoded, dtype=tf.float32)
        # Inception-v3 takes 299x299 inputs
        resized = tf.image.resize_images(as_float, [299, 299])

    current_label = 0
    # os.walk yields the root directory first; only sub-directories hold data
    for sub_dir in sub_dirs[1:]:
        # Collect every image file of this sub-directory
        dir_name = os.path.basename(sub_dir)
        file_list = []
        for extension in extensions:
            # The extension belongs to the file-name pattern itself, not to
            # a further path component.
            file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
            file_list.extend(glob.glob(file_glob))
        if not file_list:
            continue

        # Process the images
        for file_name in file_list:
            with gfile.FastGFile(file_name, 'rb') as image_file:
                image_raw_data = image_file.read()
            image_value = sess.run(
                resized, feed_dict={raw_jpeg: image_raw_data})

            # Randomly assign the image to one of the data sets
            chance = np.random.randint(100)
            if chance < validation_percentage:
                validation_images.append(image_value)
                validation_labels.append(current_label)
            elif chance < (testing_percentage + validation_percentage):
                testing_images.append(image_value)
                testing_labels.append(current_label)
            else:
                training_images.append(image_value)
                training_labels.append(current_label)
        current_label += 1

    # Shuffle the training data; restoring the RNG state before the second
    # shuffle applies the same permutation to images and labels.
    state = np.random.get_state()
    np.random.shuffle(training_images)
    np.random.set_state(state)
    np.random.shuffle(training_labels)

    # Fill an object array slot by slot: the six lists have different
    # lengths, and newer NumPy refuses to build such a ragged array with
    # a plain np.asarray call.
    datasets = [training_images, training_labels,
                validation_images, validation_labels,
                testing_images, testing_labels]
    result = np.empty(len(datasets), dtype=object)
    for index, dataset in enumerate(datasets):
        result[index] = dataset
    return result
1
2
3
4
5
6
# Start from an empty default graph before building the decoding ops
tf.reset_default_graph()

with tf.Session() as sess:
    processed_data = create_image_list(
        sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
    # np.save does not create missing directories and fails with
    # FileNotFoundError when the target directory is absent, so make sure
    # it exists first.
    os.makedirs(os.path.dirname(OUTPUT_DATA), exist_ok=True)
    np.save(OUTPUT_DATA, processed_data)
---------------------------------------------------------------------------

FileNotFoundError                         Traceback (most recent call last)

<ipython-input-3-89177c393fa3> in <module>()
      4     processed_data = create_image_list(
      5     sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
----> 6     np.save(OUTPUT_DATA, processed_data)


~/anaconda3/envs/fwi_ai/lib/python3.5/site-packages/numpy/lib/npyio.py in save(file, arr, allow_pickle, fix_imports)
    488         if not file.endswith('.npy'):
    489             file = file + '.npy'
--> 490         fid = open(file, "wb")
    491         own_fid = True
    492     elif is_pathlib_path(file):


FileNotFoundError: [Errno 2] No such file or directory: '/home/seisinv/data/flower_photos_process.npy'

迁移学习

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# 加载Tensorflow-Slim定义好的inception_v3模型
import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3

# Pre-processed data file
INPUT_DATA = '/home/seisinv/data/flower_photos_process.npy'
# Where the trained model is saved
TRAIN_FILE = '/home/seisinv/data/model'
# Pre-trained model checkpoint provided by Google
CKPT_FILE = '/home/seisinv/data/model/inception_v3.ckpt'

# Parameters used by the training process
LEARNING_RATE = 0.0001
STEPS = 300
# NOTE(review): 23 is an unusual batch size - possibly a typo for 32; confirm
BATCH = 23
N_CLASSES = 5

# Name prefixes of the parameters that are NOT loaded from the pre-trained
# model; here these are the final fully connected layers
CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
# Name prefixes of the parameters to train; again the final fully connected
# layers
TRAINABLE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'

# Collect every variable that has to be loaded from the pre-trained model
def get_tuned_variables():
    """Return the model variables to restore from the pre-trained checkpoint.

    A variable is kept unless its op name starts with one of the
    comma-separated prefixes listed in CHECKPOINT_EXCLUDE_SCOPES.

    Returns:
        A list with the variables from ``slim.get_model_variables()`` that
        are not excluded.
    """
    # str.startswith accepts a tuple, so one call checks every prefix
    exclusions = tuple(
        scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(','))
    return [var for var in slim.get_model_variables()
            if not var.op.name.startswith(exclusions)]

# Build the list of variables that will be trained
def get_trainable_variable():
    """Return the trainable variables selected by TRAINABLE_SCOPES.

    TRAINABLE_SCOPES is split on commas and each stripped entry is passed as
    the ``scope`` argument of ``tf.get_collection`` for the
    TRAINABLE_VARIABLES collection; the per-scope results are concatenated
    in order.

    Returns:
        A list of the variables found for all scopes.
    """
    prefixes = map(str.strip, TRAINABLE_SCOPES.split(','))
    return [
        variable
        for prefix in prefixes
        for variable in tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, prefix)
    ]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def main():
    """Load the pre-processed data and define the Inception-v3 inputs.

    Reads the six data sets stored in INPUT_DATA, prints how many examples
    each split contains and creates the image and label placeholders.
    """
    # The file stores an object array of Python lists, which NumPy pickles;
    # since NumPy 1.16.3 np.load refuses pickled data unless allow_pickle
    # is set. Unpickling can execute arbitrary code, so only load files you
    # produced yourself.
    processed_data = np.load(INPUT_DATA, allow_pickle=True)
    training_images = processed_data[0]
    n_training_examples = len(training_images)
    training_labels = processed_data[1]
    validation_images = processed_data[2]
    validation_labels = processed_data[3]
    testing_images = processed_data[4]
    testing_labels = processed_data[5]
    print("%d training examples, %d validation examples and %d "
          "testing examples." % (n_training_examples,
                                 len(validation_images),
                                 len(testing_images)))

    # Define the inputs of Inception-v3
    images = tf.placeholder(
        tf.float32, [None, 299, 299, 3],
        name='input_image')
    labels = tf.placeholder(
        tf.int64, [None], name='labels')
1
# Adjacent string literals are joined into one string, so this prints "ab"
print("a""b")
ab
1
# Reset the default graph - presumably so the code above can be re-run
# from a clean state; confirm against the notebook this came from
tf.reset_default_graph()

更多谷歌训练好的模型可以参考这里

参考资料

  • 郑泽宇、梁博文和顾思宇,《TensorFlow:实战Google深度学习框架(第二版)》