在循环(或理解)中创建函数(或 lambdas)

我试图在一个循环中创建函数:

functions = []


for i in range(3):
def f():
return i


# alternatively: f = lambda: i


functions.append(f)

问题是所有的函数最终都是相同的。这三个函数不返回0、1和2,而是返回2:

print([f() for f in functions])
# expected output: [0, 1, 2]
# actual output:   [2, 2, 2]

为什么会发生这种情况? 为了得到分别输出0、1和2的3个不同函数,我应该做些什么?

89933 次浏览

您在使用 后期装订时遇到了一个问题——每个函数都尽可能晚地查找 i(因此,当循环结束后调用时,i将被设置为 2)。

通过强制早期绑定很容易修复: 像这样将 def f():改为 def f(i=i)::

def f(i=i):
return i

默认值(i=i中右边的 i是参数名 i的默认值,也就是 i=i中左边的 i)在 def时查找,而不是在 call时查找,所以本质上它们是一种专门查找早期绑定的方法。

如果你担心 f得到一个额外的参数(因此可能会被错误地调用) ,有一种更复杂的方法涉及到使用一个闭包作为“函数工厂”:

def make_f(i):
def f():
return i
return f

在循环中使用 f = make_f(i)而不是 def语句。

解释

这里的问题是在创建函数 f时没有保存 i的值。相反,f电话时查找 i的值。

仔细想想,这种行为完全说得通。事实上,这是函数工作的唯一合理方式。假设你有一个函数可以访问一个全局变量,像这样:

global_var = 'foo'


def my_function():
print(global_var)


global_var = 'bar'
my_function()

当您读取这段代码时,您当然希望它打印“ bar”,而不是“ foo”,因为在声明这个函数之后,global_var的值已经发生了变化。同样的事情也发生在您自己的代码中: 当您调用 f时,i的值已经改变并被设置为 2

解决方案

实际上有很多方法可以解决这个问题,下面是一些选择:

  • 通过使用 i作为默认参数来强制早期绑定 i

    与闭包变量(如 i)不同,在定义函数时会立即计算默认参数:

    for i in range(3):
    def f(i=i):  # <- right here is the important bit
    return i
    
    
    functions.append(f)
    

    为了深入了解这是如何工作的: 函数的默认参数被存储为函数的一个属性; 因此 i目前值被快照并保存。

    >>> i = 0
    >>> def f(i=i):
    ...     pass
    >>> f.__defaults__  # this is where the current value of i is stored
    (0,)
    >>> # assigning a new value to i has no effect on the function's default arguments
    >>> i = 5
    >>> f.__defaults__
    (0,)
    
  • Use a function factory to capture the current value of i in a closure

    The root of your problem is that i is a variable that can change. We can work around this problem by creating another variable that is guaranteed to never change - and the easiest way to do this is a closure:

    def f_factory(i):
    def f():
    return i  # i is now a *local* variable of f_factory and can't ever change
    return f
    
    
    for i in range(3):
    f = f_factory(i)
    functions.append(f)
    
  • Use functools.partial to bind the current value of i to f

    functools.partial lets you attach arguments to an existing function. In a way, it too is a kind of function factory.

    import functools
    
    
    def f(i):
    return i
    
    
    for i in range(3):
    f_with_i = functools.partial(f, i)  # important: use a different variable than "f"
    functions.append(f_with_i)
    

Caveat: These solutions only work if you assign a new value to the variable. If you modify the object stored in the variable, you'll experience the same problem again:

>>> i = []  # instead of an int, i is now a *mutable* object
>>> def f(i=i):
...     print('i =', i)
...
>>> i.append(5)  # instead of *assigning* a new value to i, we're *mutating* it
>>> f()
i = [5]

注意,即使我们将 i转换为默认参数,它仍然发生了变化!如果代码 变异 i,那么必须将 收到i绑定到函数,如下所示:

  • def f(i=i.copy()):
  • f = f_factory(i.copy())
  • f_with_i = functools.partial(f, i.copy())

要添加到@Aran-fee 的出色答案中,在第二个解决方案中,你可能还希望修改函数内部的变量,这可以通过关键字 nonlocal来实现:

def f_factory(i):
def f(offset):
nonlocal i
i += offset
return i  # i is now a *local* variable of f_factory and can't ever change
return f


for i in range(3):
f = f_factory(i)
print(f(10))

你可以这样试试:

l=[]
for t in range(10):
def up(y):
print(y)
l.append(up)
l[5]('printing in 5th function')

只要把最后一行修改为

functions.append(f())

编辑: 这是因为 f是一个函数—— python 将函数视为一等公民,您可以将它们传递给以后调用的变量。因此,原始代码所做的是将函数本身附加到列表中,而您想要做的是将函数的 结果附加到列表中,这就是上一行通过调用函数所实现的。

你必须将每个 i值保存在内存中的一个单独的空间中,例如:

class StaticValue:
val = None


def __init__(self, value: int):
StaticValue.val = value


@staticmethod
def get_lambda():
return lambda x: x*StaticValue.val




class NotStaticValue:
def __init__(self, value: int):
self.val = value


def get_lambda(self):
return lambda x: x*self.val




if __name__ == '__main__':
def foo():
return [lambda x: x*i for i in range(4)]


def bar():
return [StaticValue(i).get_lambda() for i in range(4)]


def foo_repaired():
return [NotStaticValue(i).get_lambda() for i in range(4)]


print([x(2) for x in foo()])
print([x(2) for x in bar()])
print([x(2) for x in foo_repaired()])


Result:
[6, 6, 6, 6]
[6, 6, 6, 6]
[0, 2, 4, 6]