What's the purpose of tf.app.flags in TensorFlow?

我在 Tensorflow 读了一些示例代码,我发现了以下代码

flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('max_steps', 2000, 'Number of steps to run trainer.')
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
flags.DEFINE_integer('batch_size', 100, 'Batch size.  '
'Must divide evenly into the dataset sizes.')
flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
'for unit testing.')

tensorflow/tensorflow/g3doc/tutorials/mnist/fully_connected_feed.py

但是我找不到任何关于 tf.app.flags使用的文档。

我发现这个标志的实现在 tensorflow/tensorflow/python/platform/default/_flags.py

显然,这个 tf.app.flags以某种方式用于配置网络,那么为什么它不在 API 文档中?有人能解释一下这是怎么回事吗?

78331 次浏览

tf.app.flags模块目前是围绕 python-gflags, so the documentation for that project is the best resource for how to use it argparse的一个薄包装器,它实现了 python-gflags中功能的一个子集。

请注意,这个模块目前被打包为便于编写演示应用程序,并且在技术上不是公共 API 的一部分,因此它在将来可能会发生变化。

我们建议您使用 argparse或任何您喜欢的库来实现您自己的标志解析。

编辑: tf.app.flags模块实际上并没有使用 python-gflags实现,但是它使用了一个类似的 API。

当您使用 tf.app.run()时,您可以使用 tf.app.flags在线程之间非常方便地传输变量。有关 tf.app.flags的进一步使用,请参见 这个

经过多次尝试,我发现这打印所有 FLAGS 键以及实际值-

for key in tf.app.flags.FLAGS.flag_values_dict():


print(key, FLAGS[key].value)

The tf.app.flags module is a functionality provided by Tensorflow to implement command line flags for your Tensorflow program. As an example, the code you came across would do the following:

flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')

第一个参数定义标志的名称,而第二个参数定义默认值,以防在执行文件时没有指定标志。

因此,如果运行以下命令:

$ python fully_connected_feed.py --learning_rate 1.00

then the learning rate is set to 1.00 and will remain 0.01 if the flag is not specified.

正如前面提到的 在这篇文章中,文档可能不存在,因为这可能是 Google 内部要求开发人员使用的东西。

此外,正如文章中提到的,使用 Tensorflow 标志比使用其他 Python 包(如 argparse)提供的标志功能有几个优势,特别是在处理 Tensorflow 模型时,最重要的是你可以向代码提供 Tensorflow 特定的信息,比如使用哪个 GPU 的信息。

简短的回答:

在谷歌,他们使用标志系统来设置参数的默认值。类似于 argparse。它们使用自己的标志系统,而不是 argparse 或 sys.argv。

来源: 我以前在那里工作过。

Long Answer:

对于那个例子中的参数,它们被称为 超参数。在神经网络中,有多个参数可以进行优化,以获得期望的结果。例如,对于 批次大小,它是可以在一次拍摄中传递给 优化器的数据向量(可以是图像、文本或原始数据点)的数量。

你可以谷歌一下这个参数的名字,看看它的目的是什么。如果你想学习深度学习,我建议你参加 Andrew Ng 的课程。