如何运行 tf.app.run() ?

tf.app.run()在 Tensorflow 翻译演示中是如何工作的?

tensorflow/models/rnn/translate/translate.py中,有一个对 tf.app.run()的调用。它是如何处理的?

if __name__ == "__main__":
tf.app.run()
68219 次浏览

它只是一个非常快速的包装器,处理标志解析,然后分派到您自己的主机上。

if __name__ == "__main__":

意味着当前文件是在 shell 下执行的,而不是作为模块导入的。

tf.app.run()

正如您可以通过文件 app.py看到的

def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
f = flags.FLAGS


# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None


# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access


main = main or sys.modules['__main__'].main


# Call the main function, passing through any arguments
# to the final program.
sys.exit(main(sys.argv[:1] + flags_passthrough))

让我们一行一行地来分析:

flags_passthrough = f._parse_flags(args=args)

这可以确保您通过命令行传递的参数是有效的,例如。 python my_model.py --data_dir='...' --max_iteration=10000实际上,这个特性是基于 python 标准 argparse模块实现的。

main = main or sys.modules['__main__'].main

=右侧的第一个 main是当前函数 run(main=None, argv=None)的第一个参数 。而 sys.modules['__main__']表示当前正在运行的文件(例如 my_model.py)。

因此有两种情况:

  1. 你在 my_model.py中没有 main函数,那么你必须 调用 tf.app.run(my_main_running_function)

  2. my_model.py中有一个 main函数(大多数情况是这样的)

最后一句:

sys.exit(main(sys.argv[:1] + flags_passthrough))

确保您的 main(argv)my_main_running_function(argv)函数被正确地用解析参数调用。

没有什么特别的在 tf.app。这只是一个 一般入口点脚本,它

使用可选的“ main”函数和“ argv”列表运行程序。

它与神经网络没有任何关系,它只是调用 main 函数,传递任何参数给它。

简单来说,tf.app.run()的工作就是为以后的使用设置全局标志,比如:

from tensorflow.python.platform import flags
f = flags.FLAGS

然后用一组参数运行 海关总署函数。

例如,在 张量流代码库中,用于训练/推理的程序执行的第一个入口点从这一点开始(见下面的代码)

if __name__ == "__main__":
nmt_parser = argparse.ArgumentParser()
add_arguments(nmt_parser)
FLAGS, unparsed = nmt_parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

在使用 argparse解析参数之后,使用 tf.app.run()运行函数“ main”,其定义如下:

def main(unused_argv):
default_hparams = create_hparams(FLAGS)
train_fn = train.train
inference_fn = inference.inference
run_main(FLAGS, default_hparams, train_fn, inference_fn)

因此,在设置了全局使用的标志之后,tf.app.run()只是运行您传递给它的 main函数,并将 argv作为其参数。

附注: 正如 萨尔瓦多 · 达利的回答所说,我想这只是一个很好的软件工程实践,尽管我不确定 TensorFlow 是否比普通的 CPython 执行任何 main函数的优化运行。

Google 代码在很大程度上依赖于访问库/二进制文件/Python 脚本中的全局标志,因此 tf.app.run ()解析出这些标志,在 FLAGs (或类似的)变量中创建一个全局状态,然后按照应该的方式调用 python main ()。

如果他们没有调用 tf.app.run () ,那么用户可能会忘记进行 FLAG 解析,导致这些库/二进制文件/脚本无法访问他们需要的 FLAG。

2.0兼容答案 : 如果你想在 Tensorflow 2.0中使用 tf.app.run(),我们应该使用命令,

或者可以使用 tf_upgrade_v21.x代码转换为 2.0