检测 NumPy 数组是否至少包含一个非数值?

我需要写一个函数来检测输入是否至少包含一个非数值的值。如果找到一个非数值,我将引发一个错误(因为计算应该只返回一个数值)。输入数组的维数事先并不知道-无论 ndim 是什么,函数都应该给出正确的值。作为一个额外的复杂输入可能是一个单一的浮点数或 numpy.float64,甚至像零维数组一样的奇怪的东西。

解决这个问题的显而易见的方法是编写一个递归函数,它迭代数组中的每个可迭代对象,直到找到一个非迭代对象。它将在每个不可迭代的对象上应用 numpy.isnan()函数。如果至少找到一个非数值,那么函数将立即返回 False。否则,如果迭代器中的所有值都是数值,那么它最终将返回 True。

这种方法工作得很好,但是速度非常慢,我希望 笨蛋有更好的方法来实现这一点。有什么替代方案更快更麻木?

这是我的模型:

def contains_nan( myarray ):
"""
@param myarray : An n-dimensional array or a single float
@type myarray : numpy.ndarray, numpy.array, float
@returns: bool
Returns true if myarray is numeric or only contains numeric values.
Returns false if at least one non-numeric value exists
Not-A-Number is given by the numpy.isnan() function.
"""
return True
163636 次浏览

这应该比迭代更快,并且不管形状如何都能正常工作。

numpy.isnan(myarray).any()

编辑: 快30倍:

import timeit
s = 'import numpy;a = numpy.arange(10000.).reshape((100,100));a[10,10]=numpy.nan'
ms = [
'numpy.isnan(a).any()',
'any(numpy.isnan(x) for x in a.flatten())']
for m in ms:
print "  %.2f s" % timeit.Timer(m, s).timeit(1000), m

结果:

  0.11 s numpy.isnan(a).any()
3.75 s any(numpy.isnan(x) for x in a.flatten())

额外的好处: 对于非数组 NumPy 类型,它工作得很好:

>>> a = numpy.float64(42.)
>>> numpy.isnan(a).any()
False
>>> a = numpy.float64(numpy.nan)
>>> numpy.isnan(a).any()
True

使用 numpy 1.3或 svn 可以做到这一点

In [1]: a = arange(10000.).reshape(100,100)


In [3]: isnan(a.max())
Out[3]: False


In [4]: a[50,50] = nan


In [5]: isnan(a.max())
Out[5]: True


In [6]: timeit isnan(a.max())
10000 loops, best of 3: 66.3 µs per loop

在比较中对奶奶的处理在早期版本中并不一致。

如果无穷大是一个可能的值,我将使用 Numpy 是有限的

numpy.isfinite(myarray).all()

如果上面的值为 True,那么 myarray不包含 numpy.nannumpy.inf-numpy.inf

numpy.isnan对于 numpy.inf值没有问题,例如:

In [11]: import numpy as np


In [12]: b = np.array([[4, np.inf],[np.nan, -np.inf]])


In [13]: np.isnan(b)
Out[13]:
array([[False, False],
[ True, False]], dtype=bool)


In [14]: np.isfinite(b)
Out[14]:
array([[ True, False],
[False, False]], dtype=bool)

如果 A至少含有 nan的一个元素,(np.where(np.isnan(A)))[0].shape[0]将大于 0A可能是 n x m基质。

例如:

import numpy as np


A = np.array([1,2,4,np.nan])


if (np.where(np.isnan(A)))[0].shape[0]:
print "A contains nan"
else:
print "A does not contain nan"

噗! 几微秒! 永远不要在微秒内解决一个可以在纳秒内解决的问题。

请注意,公认的答案是:

  • 遍历整个数据,而不管是否找到 nan
  • 创建一个大小为 N 的临时数组,这是多余的。

一个更好的解决方案是在找到 NAN 时立即返回 True:

import numba
import numpy as np


NAN = float("nan")


@numba.njit(nogil=True)
def _any_nans(a):
for x in a:
if np.isnan(x): return True
return False


@numba.jit
def any_nans(a):
if not a.dtype.kind=='f': return False
return _any_nans(a.flat)


array1M = np.random.rand(1000000)
assert any_nans(array1M)==False
%timeit any_nans(array1M)  # 573us


array1M[0] = NAN
assert any_nans(array1M)==True
%timeit any_nans(array1M)  # 774ns  (!nanoseconds)

在 n 维空间工作:

array1M_nd = array1M.reshape((len(array1M)/2, 2))
assert any_nans(array1M_nd)==True
%timeit any_nans(array1M_nd)  # 774ns

将此与麻木的本机解决方案进行比较:

def any_nans(a):
if not a.dtype.kind=='f': return False
return np.isnan(a).any()


array1M = np.random.rand(1000000)
assert any_nans(array1M)==False
%timeit any_nans(array1M)  # 456us


array1M[0] = NAN
assert any_nans(array1M)==True
%timeit any_nans(array1M)  # 470us


%timeit np.isnan(array1M).any()  # 532us

早期退出方法是3个数量级的加速比(在某些情况下)。 对于一个简单的注释来说不算太寒酸。