在 Tensorflow,得到图中所有张量的名称

我正在用 Tensorflowskflow创建神经网络; 出于某种原因,我想得到一些给定输入的内部张量的值,所以我使用 myClassifier.get_layer_value(input, "tensorName")myClassifier作为 skflow.estimators.TensorFlowEstimator

然而,我发现很难找到正确的张量名称的语法,即使知道它的名称(我正在混淆操作和张量) ,所以我使用张量板绘制图形,并寻找名称。

有没有一种不用张量板就可以列举图中所有张量的方法?

160513 次浏览

你可以的

[n.name for n in tf.get_default_graph().as_graph_def().node]

此外,如果你在一个 IPython 笔记本电脑原型,你可以显示图表直接在笔记本,见 show_graph功能在亚历山大的深度梦想 笔记本

tf.all_variables()可以得到你想要的信息。

此外,今天在 TensorFlow 学习中制作的 这个承诺在估计器中提供了一个函数 get_variable_names,您可以使用它轻松地检索所有变量名。

有一种方法可以比雅罗斯拉夫的答案稍微快一点,那就是使用 获取操作。下面是一个简单的例子:

import tensorflow as tf


a = tf.constant(1.3, name='const_a')
b = tf.Variable(3.1, name='variable_b')
c = tf.add(a, b, name='addition')
d = tf.multiply(c, a, name='multiply')


for op in tf.get_default_graph().get_operations():
print(str(op.name))

我认为这样也可以:

print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))

但与萨尔瓦多和雅罗斯拉夫的回答相比,我不知道哪一个更好。

接受的答案只给出一个包含名称的字符串列表。我倾向于采用一种不同的方法,这种方法可以(几乎)直接访问张量:

graph = tf.get_default_graph()
list_of_tuples = [op.values() for op in graph.get_operations()]

list_of_tuples现在包含每个张量,每个张量都在一个元组中。你也可以调整它直接得到张量:

graph = tf.get_default_graph()
list_of_tuples = [op.values()[0] for op in graph.get_operations()]

以前的答案都很好,我只是想分享一个我写的从图表中选择张量的实用函数:

def get_graph_op(graph, and_conds=None, op='and', or_conds=None):
"""Selects nodes' names in the graph if:
- The name contains all items in and_conds
- OR/AND depending on op
- The name contains any item in or_conds


Condition starting with a "!" are negated.
Returns all ops if no optional arguments is given.


Args:
graph (tf.Graph): The graph containing sought tensors
and_conds (list(str)), optional): Defaults to None.
"and" conditions
op (str, optional): Defaults to 'and'.
How to link the and_conds and or_conds:
with an 'and' or an 'or'
or_conds (list(str), optional): Defaults to None.
"or conditions"


Returns:
list(str): list of relevant tensor names
"""
assert op in {'and', 'or'}


if and_conds is None:
and_conds = ['']
if or_conds is None:
or_conds = ['']


node_names = [n.name for n in graph.as_graph_def().node]


ands = {
n for n in node_names
if all(
cond in n if '!' not in cond
else cond[1:] not in n
for cond in and_conds
)}


ors = {
n for n in node_names
if any(
cond in n if '!' not in cond
else cond[1:] not in n
for cond in or_conds
)}


if op == 'and':
return [
n for n in node_names
if n in ands.intersection(ors)
]
elif op == 'or':
return [
n for n in node_names
if n in ands.union(ors)
]

所以如果你有一个带操作的图表:

['model/classifier/dense/kernel',
'model/classifier/dense/kernel/Assign',
'model/classifier/dense/kernel/read',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd',
'model/classifier/ArgMax/dimension',
'model/classifier/ArgMax']

然后逃跑

get_graph_op(tf.get_default_graph(), ['dense', '!kernel'], 'or', ['Assign'])

报税表:

['model/classifier/dense/kernel/Assign',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd']

这对我很有效:

for n in tf.get_default_graph().as_graph_def().node:
print('\n',n)

因为 OP 要求的是张量列表而不是操作/节点列表,所以代码应该略有不同:

graph = tf.get_default_graph()
tensors_per_node = [node.values() for node in graph.get_operations()]
tensor_names = [tensor.name for tensors in tensors_per_node for tensor in tensors]

我试着总结一下答案:

要获得图中的所有 节点: (类型为 tensorflow.core.framework.node_def_pb2.NodeDef)

all_nodes = [n for n in tf.get_default_graph().as_graph_def().node]

要获得图中的所有 指挥中心: (类型为 tensorflow.python.framework.ops.Operation)

all_ops = tf.get_default_graph().get_operations()

要获得图中的所有 变量: (类型为 tensorflow.python.ops.resource_variable_ops.ResourceVariable)

all_vars = tf.global_variables()

要获得图中的所有 张量: (类型为 tensorflow.python.framework.ops.Tensor)

all_tensors = [tensor for op in tf.get_default_graph().get_operations() for tensor in op.values()]

要获得图中的所有 占位符: (类型为 tensorflow.python.framework.ops.Tensor)

all_placeholders = [placeholder for op in tf.get_default_graph().get_operations() if op.type=='Placeholder' for placeholder in op.values()]

张量流2

要获得 Tensorflow 2中的图形,您需要首先实例化一个 tf.function并访问 graph属性,而不是 tf.get_default_graph(),例如:

graph = func.get_concrete_function().graph

其中 functf.function

下面的解决方案适用于 TensorFlow 2.3-

def load_pb(path_to_pb):
with tf.io.gfile.GFile(path_to_pb, 'rb') as f:
graph_def = tf.compat.v1.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def, name='')
return graph
tf_graph = load_pb(MODEL_FILE)
sess = tf.compat.v1.Session(graph=tf_graph)


# Show tensor names in graph
for op in tf_graph.get_operations():
print(op.values())

其中 MODEL_FILE是冻结图的路径。

取自 给你