TensorFlow,为什么保存模型后有3个文件?

在阅读了 医生之后,我在 TensorFlow中保存了一个模型,下面是我的演示代码:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()


# Add ops to save and restore all the variables.
saver = tf.train.Saver()


# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model.
..
# Save the variables to disk.
save_path = saver.save(sess, "/tmp/model.ckpt")
print("Model saved in file: %s" % save_path)

但在那之后,我发现有3个文件

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

我不能通过恢复 model.ckpt文件来恢复模型,因为没有这样的文件。这是我的密码

with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/tmp/model.ckpt")

那为什么有三份文件?

93868 次浏览

试试这个:

with tf.Session() as sess:
saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
saver.restore(sess, "/tmp/model.ckpt")

TensorFlow 保存方法保存三种文件,因为它将 图形结构图形结构变量值分开存储。.meta文件描述了保存的图形结构,因此在还原检查点之前需要导入它(否则它不知道保存的检查点值对应于哪些变量)。

或者,你可以这样做:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")


...


# Now load the checkpoint variable values
with tf.Session() as sess:
saver = tf.train.Saver()
saver.restore(sess, "/tmp/model.ckpt")

尽管没有名为 model.ckpt的文件,但是在还原时仍然使用该名称引用保存的检查点。来自 saver.py源代码:

用户只需要与用户指定的前缀... 进行交互 任何物理路径的名称。

  • 元文件 : 描述保存的图形结构,包括 GraphDef、 SaverDef 等; 然后应用 tf.train.import_meta_graph('/tmp/model.ckpt.meta'),将恢复 SaverGraph

  • Index file : 它是一个字符串不可变的表(tensorflow: : Table: : Table)。每个键是一个张量的名称,其值是一个序列化的 BundleEntryProto。每个 BundleEntryProto 都描述张量的元数据: 哪个“数据”文件包含张量的内容、该文件的偏移量、校验和、一些辅助数据等。

  • Data file : 它是 TensorBundle 集合,保存所有变量的值。

我恢复从 Word2Vec张量流教程训练的单词嵌入。

如果您创建了多个检查点:

创建的文件如下所示

Model-ckpt-55695. data-00000-of-00001

Ckpt-55695. index

Model ckpt-55695元

试试这个

def restore_session(self, session):
saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
saver.restore(session, './tmp/model.ckpt-55695')

调用 return _ session ()时:

def test_word2vec():
opts = Options()
with tf.Graph().as_default(), tf.Session() as session:
with tf.device("/cpu:0"):
model = Word2Vec(opts, session)
model.restore_session(session)
model.get_embedding("assistance")

例如,如果你训练一个退学的 CNN,你可以这样做:

def predict(image, model_name):
"""
image -> single image, (width, height, channels)
model_name -> model file that was saved without any extensions
"""
with tf.Session() as sess:
saver = tf.train.import_meta_graph('./' + model_name + '.meta')
saver.restore(sess, './' + model_name)
# Substitute 'logits' with your model
prediction = tf.argmax(logits, 1)
# 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})