TensorFlow使用记录 (四): 模型保存

模型文件

tensorflow 训练保存的模型注意包含两个部分:网络结构和参数值。

.meta

 .meta 文件以 “protocol buffer”格式保存了整个模型的结构图,模型上定义的操作等信息。

 .data & .index

.data 和 .index 文件合在一起组成了 ckpt 文件,保存了网络结构中所有 权重和偏置 的数值。

.data文件保存的是变量值,.index文件保存的是.data文件中数据和 .meta文件中结构图之间的对应关系

查看 ckpt 模型文件中保存的 Tensor 信息:

import tensorflow as tf

checkpoint_path = 'cnn_mnist.ckpt'
reader = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()

# Print tensor name and values
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))
知识兔

模型保存 tf.train.Saver

 tf.train.Saver()
saver.save(sess,"model_test.ckpt")
知识兔

Saver类的构造函数定义:

def __init__(self,
             var_list=None, # 指定要保存的变量的序列或字典,默认为None,保存所有变量
             reshape=False,
             sharded=False,
             max_to_keep=5, # 定义最多保存最近的多少个模型文件
             keep_checkpoint_every_n_hours=10000.0,
             name=None,
             restore_sequentially=False,
             saver_def=None,
             builder=None,
             defer_build=False,
             allow_empty=False,
             write_version=saver_pb2.SaverDef.V2,
             pad_step_number=False,
             save_relative_paths=False,
             filename=None):
知识兔

saver.save函数定义:

def save(self,
         sess,                     # 当前的会话环境
         save_path,                # 模型保存路径
         global_step=None,         # 训练轮次,如果添加,会在模型文件名称后加上这个轮次的后缀
         latest_filename=None,     # checkpoint 文本文件的名称,默认为 'checkpoint'
         meta_graph_suffix="meta", # 保存的网络图结构文件的后缀
         write_meta_graph=True,    # 定义是否保存网络结构
         write_state=True,
         strip_default_attrs=False,
         save_debug_info=False):
知识兔

模型加载 

加载模型时候可以先加载图结构,再加载图中的参数(在Session中操作):

'./model_saved/model_test.meta')
saver.restore(sess, tf.train.latest_checkpoint('./model_saved'))
知识兔

或者一次性加载:

 tf.train.Saver()
saver.restore(sess, './model_saved/model_test.ckpt')
# or
saver.restore(sess, tf.train.latest_checkpoint('./model_saved'))
知识兔
计算机