TensorFlow: saving a graph to a file and loading it from a file

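From what I've gathered so far, there are several different ways of dumping a TensorFlow graph into a file and then loading it into another program, but I haven't been able to find clear examples or information on how they work. What I already know is this: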

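  1. Save the model's variables into a checkpoint file (.ckpt) using tf.train.Saver() and restore them later (source)
  2. Save a model into a .pb file using tf.train.write_graph() and load it back in using tf.import_graph_def() (source)
  3. Load a model from a .pb file, retrain it, and dump it into a new .pb file using Bazel (source)
  4. Freeze the graph to save the graph and weights together (source)
  5. Use as_graph_def() to save the model, and for weights/variables, map them into constants (source)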

However, there are several questions about these different methods that I haven't been able to clear up:

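  1. Regarding checkpoint files, do they only save the trained weights of a model? Can checkpoint files be loaded into a new program and used to run the model, or do they simply serve as a way to save the weights of a model at a certain time/stage?
  2. Regarding tf.train.write_graph(), are the weights/variables saved as well?
  3. Regarding Bazel, can it only save to/load from .pb files for retraining? Is there a simple Bazel command that just dumps a graph into a .pb file?
  4. Regarding freezing, can a frozen graph be loaded using tf.import_graph_def()?
  5. The TensorFlow Android demo loads its model from a .pb file. If I wanted to substitute my own .pb file, how would I go about doing that? Would I need to change any native code/methods?
  6. In general, what exactly is the difference between all these methods? Or more broadly, what is the difference between as_graph_def(), .ckpt, and .pb?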

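In short, what I'm looking for is a way to save both a graph (that is, the various operations and so on) and its weights/variables into a file, which can then be used to load the graph and weights into another program for use (not necessarily for continued training or retraining).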

The documentation on this topic isn't very straightforward, so any answers or information would be greatly appreciated.

There are many ways to approach the problem of saving a model in TensorFlow, which can make it a bit confusing. Taking each of your sub-questions in turn:

  1. The checkpoint files (produced e.g. by calling saver.save() on a tf.train.Saver object) contain only the weights, and any other variables defined in the same program. To use them in another program, you must re-create the associated graph structure (e.g. by running code to build it again, or calling tf.import_graph_def()), which tells TensorFlow what to do with those weights. Note that calling saver.save() also produces a file containing a MetaGraphDef, which contains a graph and details of how to associate the weights from a checkpoint with that graph. See the tutorial for more details.

  2. tf.train.write_graph() only writes the graph structure, not the weights.

  3. Bazel is unrelated to reading or writing TensorFlow graphs. (Perhaps I misunderstand your question: feel free to clarify it in a comment.)

  4. A frozen graph can be loaded using tf.import_graph_def(). In this case, the weights are (typically) embedded in the graph, so you don't need to load a separate checkpoint.

  5. The main change would be to update the names of the tensor(s) that are fed into the model, and the names of the tensor(s) that are fetched from the model. In the TensorFlow Android demo, this would correspond to the inputName and outputName strings that are passed to TensorFlowClassifier.initializeTensorFlow().

  6. The GraphDef is the program structure, which typically does not change through the training process. The checkpoint is a snapshot of the state of a training process, which typically changes at every step of the training process. As a result, TensorFlow uses different storage formats for these types of data, and the low-level API provides different ways to save and load them. Higher-level libraries, such as the MetaGraphDef libraries, Keras, and skflow build on these mechanisms to provide more convenient ways to save and restore an entire model.
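To make points 1, 2, and 4 concrete, here is a minimal sketch rather than a definitive recipe. It assumes a TensorFlow build that exposes the 1.x API under tf.compat.v1 (TensorFlow 1.14+ or 2.x). The toy graph, the node names "x", "w", and "y", and the temporary output directory are all made up for illustration. It saves a checkpoint, writes the bare graph structure, freezes the graph, and then reloads the model twice: once from the checkpoint plus its MetaGraphDef, and once from the frozen GraphDef alone.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # assumption: the 1.x API is available under tf.compat.v1
tf1.disable_eager_execution()

out_dir = tempfile.mkdtemp()  # hypothetical output location

# Build a tiny graph: y = x * w
graph = tf1.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[], name="x")
    w = tf1.Variable(3.0, name="w")
    y = tf1.multiply(x, w, name="y")
    saver = tf1.train.Saver()
    with tf1.Session(graph=graph) as sess:
        sess.run(tf1.global_variables_initializer())
        # Point 1: the checkpoint holds the variable values; save() also
        # writes a .meta file containing the MetaGraphDef.
        ckpt_path = saver.save(sess, os.path.join(out_dir, "model.ckpt"))
        # Point 2: structure only; the value of w is not written here.
        tf1.train.write_graph(graph.as_graph_def(), out_dir, "graph.pb",
                              as_text=False)
        # Point 4: freezing replaces variables with constants holding
        # their current values.
        frozen_def = tf1.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), ["y"])

# "Another program", option A: MetaGraphDef + checkpoint.
with tf1.Graph().as_default() as g2:
    with tf1.Session(graph=g2) as sess2:
        restorer = tf1.train.import_meta_graph(ckpt_path + ".meta")
        restorer.restore(sess2, ckpt_path)
        restored = sess2.run("y:0", feed_dict={"x:0": 2.0})

# "Another program", option B: frozen graph, no checkpoint needed.
with tf1.Graph().as_default() as g3:
    tf1.import_graph_def(frozen_def, name="")
    with tf1.Session(graph=g3) as sess3:
        frozen = sess3.run("y:0", feed_dict={"x:0": 2.0})
```

Both restored and frozen hold the result of 2.0 * 3.0. In a real workflow the frozen GraphDef would normally be serialized to a .pb file (for example with tf.train.write_graph()) before being loaded elsewhere; here it is kept in memory to keep the sketch short.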

You can try the following code:

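import tensorflow as tf  # TensorFlow 1.x API

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# import_graph_def() returns None unless return_elements is given,
# so import into an explicit graph and hand that graph to the session.
g_in = tf.Graph()
with g_in.as_default():
    tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)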
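Once a .pb file is loaded this way, the next thing usually needed is the names of the tensors to feed and fetch (the input and output names mentioned in point 5 of the answer above). The following is only a heuristic sketch: guess_io_names is a hypothetical helper, not part of TensorFlow, and the demo graph with nodes "input", "w", and "output" is invented for illustration. It treats Placeholder nodes as inputs and nodes that no other node consumes as outputs, which may not hold for every model.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: the 1.x API is available under tf.compat.v1


def guess_io_names(graph_def):
    """Heuristic: Placeholder nodes are inputs; unconsumed nodes are outputs."""
    consumed = set()
    for node in graph_def.node:
        for inp in node.input:
            # Input references look like "name", "name:1", or "^name"
            # (the last form is a control dependency).
            consumed.add(inp.lstrip("^").split(":")[0])
    inputs = [n.name for n in graph_def.node if n.op == "Placeholder"]
    outputs = [n.name for n in graph_def.node
               if n.name not in consumed and n.op != "Placeholder"]
    return inputs, outputs


# Demo on a small in-memory graph standing in for a parsed .pb file.
with tf1.Graph().as_default() as g:
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="input")
    w = tf1.constant([[1.0], [2.0]], name="w")
    tf1.matmul(x, w, name="output")
    demo_def = g.as_graph_def()

inputs, outputs = guess_io_names(demo_def)
print(inputs, outputs)
```

With a session on the imported graph, the names found this way can be used as sess.run("output:0", feed_dict={"input:0": ...}), where the ":0" suffix selects the first output tensor of each operation.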