第一次出现的值大于现有值的Numpy

我在numpy中有一个1D数组,我想找到索引的位置,其中一个值超过numpy数组中的值。

如。

aa = range(-10,10)

aa中找到超出5值的位置。

228338 次浏览

这样更快一些(看起来也更好一些)

np.argmax(aa>5)

因为argmax将在第一个True处停止(“如果最大值多次出现,则返回第一次出现对应的索引。”)并且不会保存另一个列表。

In [2]: N = 10000


In [3]: aa = np.arange(-N,N)


In [4]: timeit np.argmax(aa>N/2)
100000 loops, best of 3: 52.3 us per loop


In [5]: timeit np.where(aa>N/2)[0][0]
10000 loops, best of 3: 141 us per loop


In [6]: timeit np.nonzero(aa>N/2)[0][0]
10000 loops, best of 3: 142 us per loop
In [34]: a=np.arange(-10,10)


In [35]: a
Out[35]:
array([-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,
3,   4,   5,   6,   7,   8,   9])


In [36]: np.where(a>5)
Out[36]: (array([16, 17, 18, 19]),)


In [37]: np.where(a>5)[0][0]
Out[37]: 16

给定数组的排序内容,还有一个更快的方法:searchsorted

import time
N = 10000
aa = np.arange(-N,N)
%timeit np.searchsorted(aa, N/2)+1
%timeit np.argmax(aa>N/2)
%timeit np.where(aa>N/2)[0][0]
%timeit np.nonzero(aa>N/2)[0][0]


# Output
100000 loops, best of 3: 5.97 µs per loop
10000 loops, best of 3: 46.3 µs per loop
10000 loops, best of 3: 154 µs per loop
10000 loops, best of 3: 154 µs per loop

我也对此感兴趣,我已经将所有建议的答案与perfplot进行了比较。(声明:我是perfplot的作者。)

如果你知道你正在查找的数组是已经排序,那么

numpy.searchsorted(a, alpha)

是给你的。它是O(log(n))操作,也就是说,速度几乎不取决于数组的大小。没有比这更快的了。

如果你对数组一无所知,你就不会出错

numpy.argmax(a > alpha)

已经排序:

enter image description here

未分类的:

enter image description here

代码重现情节:

import numpy
import perfplot




alpha = 0.5
numpy.random.seed(0)




def argmax(data):
return numpy.argmax(data > alpha)




def where(data):
return numpy.where(data > alpha)[0][0]




def nonzero(data):
return numpy.nonzero(data > alpha)[0][0]




def searchsorted(data):
return numpy.searchsorted(data, alpha)




perfplot.save(
"out.png",
# setup=numpy.random.rand,
setup=lambda n: numpy.sort(numpy.random.rand(n)),
kernels=[argmax, where, nonzero, searchsorted],
n_range=[2 ** k for k in range(2, 23)],
xlabel="len(array)",
)

我会说

i = np.min(np.where(V >= x))

其中V是vector (1d数组),x是值,i是结果索引。

元素间步长为常数的数组

对于range或任何其他线性递增数组,你可以简单地通过编程计算索引,根本不需要实际遍历数组:

def first_index_calculate_range_like(val, arr):
if len(arr) == 0:
raise ValueError('no value greater than {}'.format(val))
elif len(arr) == 1:
if arr[0] > val:
return 0
else:
raise ValueError('no value greater than {}'.format(val))


first_value = arr[0]
step = arr[1] - first_value
# For linearly decreasing arrays or constant arrays we only need to check
# the first element, because if that does not satisfy the condition
# no other element will.
if step <= 0:
if first_value > val:
return 0
else:
raise ValueError('no value greater than {}'.format(val))


calculated_position = (val - first_value) / step


if calculated_position < 0:
return 0
elif calculated_position > len(arr) - 1:
raise ValueError('no value greater than {}'.format(val))


return int(calculated_position) + 1

有人可能会改进一下。我已经确保它可以正确地为一些示例数组和值,但这并不意味着不可能有错误在那里,特别是考虑到它使用浮动…

>>> import numpy as np
>>> first_index_calculate_range_like(5, np.arange(-10, 10))
16
>>> np.arange(-10, 10)[16]  # double check
6


>>> first_index_calculate_range_like(4.8, np.arange(-10, 10))
15

假设它可以计算位置而不需要任何迭代,它将是常数时间(O(1)),并且可能击败所有其他提到的方法。但是,它要求数组中的步长为常数,否则将产生错误的结果。

使用numba的一般解决方案

更通用的方法是使用numba函数:

@nb.njit
def first_index_numba(val, arr):
for idx in range(len(arr)):
if arr[idx] > val:
return idx
return -1

这对任何数组都适用,但它必须遍历数组,所以在平均情况下,它将是O(n):

>>> first_index_numba(4.8, np.arange(-10, 10))
15
>>> first_index_numba(5, np.arange(-10, 10))
16

基准

尽管Nico Schlömer已经提供了一些基准,但我认为将我的新解决方案包括在内并测试不同的“值”可能是有用的。

测试设置:

import numpy as np
import math
import numba as nb


def first_index_using_argmax(val, arr):
return np.argmax(arr > val)


def first_index_using_where(val, arr):
return np.where(arr > val)[0][0]


def first_index_using_nonzero(val, arr):
return np.nonzero(arr > val)[0][0]


def first_index_using_searchsorted(val, arr):
return np.searchsorted(arr, val) + 1


def first_index_using_min(val, arr):
return np.min(np.where(arr > val))


def first_index_calculate_range_like(val, arr):
if len(arr) == 0:
raise ValueError('empty array')
elif len(arr) == 1:
if arr[0] > val:
return 0
else:
raise ValueError('no value greater than {}'.format(val))


first_value = arr[0]
step = arr[1] - first_value
if step <= 0:
if first_value > val:
return 0
else:
raise ValueError('no value greater than {}'.format(val))


calculated_position = (val - first_value) / step


if calculated_position < 0:
return 0
elif calculated_position > len(arr) - 1:
raise ValueError('no value greater than {}'.format(val))


return int(calculated_position) + 1


@nb.njit
def first_index_numba(val, arr):
for idx in range(len(arr)):
if arr[idx] > val:
return idx
return -1


funcs = [
first_index_using_argmax,
first_index_using_min,
first_index_using_nonzero,
first_index_calculate_range_like,
first_index_numba,
first_index_using_searchsorted,
first_index_using_where
]


from simple_benchmark import benchmark, MultiArgument

这些图是用以下方法生成的:

%matplotlib notebook
b.plot()

Item在开头

b = benchmark(
funcs,
{2**i: MultiArgument([0, np.arange(2**i)]) for i in range(2, 20)},
argument_name="array size")

enter image description here

numba函数的性能最好,其次是compute -函数和searchsorted函数。其他解决方案的性能要差得多。

Item在最后

b = benchmark(
funcs,
{2**i: MultiArgument([2**i-2, np.arange(2**i)]) for i in range(2, 20)},
argument_name="array size")

enter image description here

对于小型数组,numba函数执行得非常快,但是对于较大的数组,它被compute -function和searchsorted函数所超越。

项目位于根号(len)

b = benchmark(
funcs,
{2**i: MultiArgument([np.sqrt(2**i), np.arange(2**i)]) for i in range(2, 20)},
argument_name="array size")

enter image description here

这个更有趣。numba和calculate函数执行得很好,但这实际上触发了最坏的情况searchsorted,在这种情况下工作得不好。

没有值满足条件时函数的比较

另一个有趣的地方是,如果没有应该返回索引的值,这些函数将如何表现:

arr = np.ones(100)
value = 2


for func in funcs:
print(func.__name__)
try:
print('-->', func(value, arr))
except Exception as e:
print('-->', e)

结果是:

first_index_using_argmax
--> 0
first_index_using_min
--> zero-size array to reduction operation minimum which has no identity
first_index_using_nonzero
--> index 0 is out of bounds for axis 0 with size 0
first_index_calculate_range_like
--> no value greater than 2
first_index_numba
--> -1
first_index_using_searchsorted
--> 101
first_index_using_where
--> index 0 is out of bounds for axis 0 with size 0

Searchsorted、argmax和numba返回错误的值。然而,searchsortednumba返回的索引不是数组的有效索引。

函数whereminnonzerocalculate抛出异常。然而,只有calculate的异常实际上说了任何有用的东西。

这意味着实际上必须将这些调用包装在一个适当的包装器函数中,该函数捕获异常或无效的返回值并进行适当处理,至少在您不确定值是否可以在数组中时是这样。


注意:calculate和searchsorted选项仅在特殊情况下有效。“计算”函数需要一个常量步长,而searchsorted则需要对数组进行排序。所以这些在适当的情况下可能有用,但不是这个问题的一般解决方案。如果你在处理排序 Python列表,你可能想看看平分模块,而不是使用Numpys searchsorted。

我想求婚

np.min(np.append(np.where(aa>5)[0],np.inf))

这将返回满足条件的最小下标,而如果不满足条件则返回无穷大(并且where返回空数组)。

你应该使用__ABC0来代替np.argmax。后者即使没有找到值也会返回位置0,这不是您期望的索引。

>>> aa = np.array(range(-10,10))
>>> print(aa)
array([-10,  -9,  -8,  -7,  -6,  -5,  -4,  -3,  -2,  -1,   0,   1,   2,
3,   4,   5,   6,   7,   8,   9])

如果满足条件,则返回一个索引数组。

>>> idx = np.where(aa > 5)[0]
>>> print(idx)
array([16, 17, 18, 19], dtype=int64)

否则,如果不满足,则返回一个空数组。

>>> not_found = len(np.where(aa > 20)[0])
>>> print(not_found)
array([], dtype=int64)

在这种情况下,针对argmax的点是:越简单越好,如果解决方案不含糊。因此,要检查是否有东西符合条件,只需执行if len(np.where(aa > value_to_search)[0]) > 0