How to split an iterable into constant-size chunks

Much to my surprise, I could not find a "batch" function that would take as input an iterator and return an iterator.

For example:

for i in batch(range(0,10), 1): print i
[0]
[1]
...
[9]

Or:

for i in batch(range(0,10), 3): print i
[0,1,2]
[3,4,5]
[6,7,8]
[9]

Now, I wrote what I thought was a pretty simple generator:

def batch(iterable, n = 1):
    current_batch = []
    for item in iterable:
        current_batch.append(item)
        if len(current_batch) == n:
            yield current_batch
            current_batch = []
    if current_batch:
        yield current_batch

But the above does not give me what I expected:

for x in batch(range(0,10),3): print x
[0]
[0, 1]
[0, 1, 2]
[3]
[3, 4]
[3, 4, 5]
[6]
[6, 7]
[6, 7, 8]
[9]

So I've missed something, and this probably shows my complete lack of understanding of Python generators. Anyone care to point me in the right direction?

[Edit: I eventually realized that the behavior above happens only when I run this in ipython rather than python itself]


Weird, seems to work fine for me in Python 2.x:

>>> def batch(iterable, n = 1):
...    current_batch = []
...    for item in iterable:
...        current_batch.append(item)
...        if len(current_batch) == n:
...            yield current_batch
...            current_batch = []
...    if current_batch:
...        yield current_batch
...
>>> for x in batch(range(0, 10), 3):
...     print x
...
[0, 1, 2]
[3, 4, 5]
[6, 7, 8]
[9]

The recipes in the itertools module offer two ways to do this, depending on how you want to handle the final odd-sized batch (keep it, pad it with a fillvalue, ignore it, or raise an exception):

from itertools import islice, zip_longest


def batched(iterable, n):
    "Batch data into lists of length n. The last batch may be shorter."
    # batched('ABCDEFG', 3) --> ABC DEF G
    it = iter(iterable)
    while True:
        batch = list(islice(it, n))
        if not batch:
            return
        yield batch


def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
    "Collect data into non-overlapping fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, fillvalue='x') --> ABC DEF Gxx
    # grouper('ABCDEFG', 3, incomplete='strict') --> ABC DEF ValueError
    # grouper('ABCDEFG', 3, incomplete='ignore') --> ABC DEF
    args = [iter(iterable)] * n
    if incomplete == 'fill':
        return zip_longest(*args, fillvalue=fillvalue)
    if incomplete == 'strict':
        return zip(*args, strict=True)
    if incomplete == 'ignore':
        return zip(*args)
    else:
        raise ValueError('Expected fill, strict, or ignore')
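For illustration, here is a quick sketch of how grouper treats the final short group under the 'fill' and 'ignore' modes (the 'strict' mode is omitted here since zip(strict=True) needs Python 3.10+):

```python
from itertools import zip_longest

def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
    # same recipe as above, repeated so this snippet runs standalone
    args = [iter(iterable)] * n
    if incomplete == 'fill':
        return zip_longest(*args, fillvalue=fillvalue)
    if incomplete == 'ignore':
        return zip(*args)
    raise ValueError('Expected fill or ignore')

print(list(grouper('ABCDEFG', 3, fillvalue='x')))
# [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
print(list(grouper('ABCDEFG', 3, incomplete='ignore')))
# [('A', 'B', 'C'), ('D', 'E', 'F')]
```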

This may be more efficient (faster):

def batch(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]


for x in batch(range(0, 10), 3):
    print(x)

Example using a list:

data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # list of data


for x in batch(data, 3):
    print(x)


# Output


[0, 1, 2]
[3, 4, 5]
[6, 7, 8]
[9, 10]

It avoids building new lists.

As others have noted, the code you have given does exactly what you want. For another approach using itertools.islice you could see an example of the following recipe:

from itertools import islice, chain


def batch(iterable, size):
    sourceiter = iter(iterable)
    while True:
        batchiter = islice(sourceiter, size)
        try:
            # next() replaces batchiter.next() from the original Python 2 recipe
            yield chain([next(batchiter)], batchiter)
        except StopIteration:  # sourceiter is exhausted
            return
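A quick usage sketch (repeating the recipe, updated for Python 3, so the snippet runs standalone); note that each yielded batch is a lazy itertools.chain object, so it must be consumed before advancing to the next batch:

```python
from itertools import islice, chain

def batch(iterable, size):
    sourceiter = iter(iterable)
    while True:
        batchiter = islice(sourceiter, size)
        try:
            yield chain([next(batchiter)], batchiter)
        except StopIteration:  # sourceiter is exhausted
            return

# Materialize each lazy batch in order:
result = [list(b) for b in batch(range(10), 3)]
print(result)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```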

This is what I use in my project. It handles iterables or lists as efficiently as possible.

from itertools import islice


def chunker(iterable, size):
    if not hasattr(iterable, "__len__"):
        # generators don't have len, so fall back to slower
        # method that works with generators
        for chunk in chunker_gen(iterable, size):
            yield chunk
        return

    it = iter(iterable)
    for i in range(0, len(iterable), size):
        yield [k for k in islice(it, size)]


def chunker_gen(generator, size):
    iterator = iter(generator)
    for first in iterator:

        def chunk():
            yield first
            for more in islice(iterator, size - 1):
                yield more

        yield [k for k in chunk()]

This will work for any iterable.

from itertools import zip_longest, filterfalse


def batch_iterable(iterable, batch_size=2):
    args = [iter(iterable)] * batch_size
    return (tuple(filterfalse(lambda x: x is None, group)) for group in zip_longest(fillvalue=None, *args))

It works like this:

>>> list(batch_iterable(range(0, 5), 2))
[(0, 1), (2, 3), (4,)]

PS: It would not work if the iterable has None values.
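To make that caveat concrete, here is a sketch (repeating the function so the snippet runs standalone) showing how a real None in the data is indistinguishable from the fill value and gets silently dropped:

```python
from itertools import zip_longest, filterfalse

def batch_iterable(iterable, batch_size=2):
    args = [iter(iterable)] * batch_size
    return (tuple(filterfalse(lambda x: x is None, group)) for group in zip_longest(fillvalue=None, *args))

# The None item in the data is filtered out along with the padding:
print(list(batch_iterable([1, None, 2, 3], 2)))  # [(1,), (2, 3)]
```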

Here is an approach using the reduce function.

One-liner:

from functools import reduce
reduce(lambda cumulator,item: cumulator[-1].append(item) or cumulator if len(cumulator[-1]) < batch_size else cumulator + [[item]], input_array, [[]])

Or a more readable version:

from functools import reduce


def batch(input_list, batch_size):
    def reducer(cumulator, item):
        if len(cumulator[-1]) < batch_size:
            cumulator[-1].append(item)
            return cumulator
        else:
            cumulator.append([item])
            return cumulator
    return reduce(reducer, input_list, [[]])

Test:

>>> batch([1,2,3,4,5,6,7], 3)
[[1, 2, 3], [4, 5, 6], [7]]
>>> batch([1,2,3,4,5,6,7], 8)
[[1, 2, 3, 4, 5, 6, 7]]
>>> batch([1,2,3,None,4], 3)
[[1, 2, 3], [None, 4]]

Here is a very short code snippet which I know doesn't use len and works under both Python 2 and 3 (not my creation):

def chunks(iterable, size):
    from itertools import chain, islice
    iterator = iter(iterable)
    for first in iterator:
        yield list(chain([first], islice(iterator, size - 1)))

More-itertools includes two functions that do what you need: chunked and ichunked.

def batch(iterable, n):
    iterable = iter(iterable)
    while True:
        chunk = []
        for i in range(n):
            try:
                chunk.append(next(iterable))
            except StopIteration:
                if chunk:  # avoid yielding an empty final batch
                    yield chunk
                return
        yield chunk


list(batch(range(10), 3))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]

You can group iterable items by their batch index.

import itertools
from typing import Iterable


def batch(items: Iterable, batch_size: int) -> Iterable[Iterable]:
    # enumerate items and group them by batch index
    enumerated_item_groups = itertools.groupby(enumerate(items), lambda t: t[0] // batch_size)
    # extract items from enumeration tuples
    item_batches = ((t[1] for t in enumerated_items) for key, enumerated_items in enumerated_item_groups)
    return item_batches

It is often the case that you want to collect the inner iterables, so here is a more advanced version.

from typing import Any, Callable


def batch_advanced(items: Iterable, batch_size: int, batches_mapper: Callable[[Iterable], Any] = None) -> Iterable[Iterable]:
    enumerated_item_groups = itertools.groupby(enumerate(items), lambda t: t[0] // batch_size)
    if batches_mapper:
        item_batches = (batches_mapper(t[1] for t in enumerated_items) for key, enumerated_items in enumerated_item_groups)
    else:
        item_batches = ((t[1] for t in enumerated_items) for key, enumerated_items in enumerated_item_groups)
    return item_batches

Example:

print(list(batch_advanced([1, 9, 3, 5, 2, 4, 2], 4, tuple)))
# [(1, 9, 3, 5), (2, 4, 2)]
print(list(batch_advanced([1, 9, 3, 5, 2, 4, 2], 4, list)))
# [[1, 9, 3, 5], [2, 4, 2]]

Related functionality you may need:

def batch(size, i):
    """ Get the i'th batch of the given size """
    return slice(size * i, size * i + size)

Usage:

>>> [1,2,3,4,5,6,7,8,9,10][batch(3, 1)]
[4, 5, 6]

It gets the i'th batch from the sequence, and it works with other data structures as well, like pandas dataframes (df.iloc[batch(100,0)]) or numpy arrays (array[batch(100,0)]).

from itertools import filterfalse, zip_longest


class SENTINEL: pass


def batch(iterable, n):
    return (tuple(filterfalse(lambda x: x is SENTINEL, group)) for group in zip_longest(fillvalue=SENTINEL, *[iter(iterable)] * n))


print(list(batch(range(10), 3)))
# outputs: [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9,)]
print(list(batch([None] * 10, 3)))
# outputs: [(None, None, None), (None, None, None), (None, None, None), (None,)]

A Python 3.8 solution for when you are working with iterables that don't define a len function and get exhausted:

from itertools import islice


def batcher(iterable, batch_size):
    iterator = iter(iterable)
    while batch := list(islice(iterator, batch_size)):
        yield batch

Sample usage:

def my_gen():
    yield from range(10)


for batch in batcher(my_gen(), 3):
    print(batch)


[0, 1, 2]
[3, 4, 5]
[6, 7, 8]
[9]


Of course it could be implemented without the walrus operator as well.
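For instance, a sketch of the same idea without the walrus operator, so it also runs on Python versions before 3.8:

```python
from itertools import islice

def batcher(iterable, batch_size):
    iterator = iter(iterable)
    batch = list(islice(iterator, batch_size))
    while batch:
        yield batch
        batch = list(islice(iterator, batch_size))

print(list(batcher(range(10), 3)))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```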

I use:

import math


def batchify(arr, batch_size):
    num_batches = math.ceil(len(arr) / batch_size)
    return [arr[i * batch_size:(i + 1) * batch_size] for i in range(num_batches)]

Keep taking (at most) n elements until it runs out.

def chop(n, iterable):
    iterator = iter(iterable)
    while chunk := list(take(n, iterator)):
        yield chunk


def take(n, iterable):
    iterator = iter(iterable)
    for i in range(n):
        try:
            yield next(iterator)
        except StopIteration:
            return

A workable version without the new features of Python 3.8, adapted from @Atra Azami's answer.

import itertools


def batch_generator(iterable, batch_size=1):
    iterable = iter(iterable)

    while True:
        batch = list(itertools.islice(iterable, batch_size))
        if len(batch) > 0:
            yield batch
        else:
            break


for x in batch_generator(range(0, 10), 3):
    print(x)

Output:

[0, 1, 2]
[3, 4, 5]
[6, 7, 8]
[9]

Moving as much into CPython as possible, by leveraging islice and iter(callable) behavior:

from itertools import islice


def chunked(generator, size):
    """Read parts of the generator, pause each time after a chunk"""
    # islice returns results until 'size',
    # make_chunk gets repeatedly called by iter(callable).
    gen = iter(generator)
    make_chunk = lambda: list(islice(gen, size))
    return iter(make_chunk, [])
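A short usage sketch (repeating the function so the snippet runs standalone); iter(callable, sentinel) keeps calling make_chunk until it returns the sentinel, here the empty list:

```python
from itertools import islice

def chunked(generator, size):
    gen = iter(generator)
    make_chunk = lambda: list(islice(gen, size))
    # iter(callable, sentinel) stops at the first empty chunk
    return iter(make_chunk, [])

print(list(chunked(range(10), 3)))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```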

Inspired by more-itertools, simplified to the essence of that code.

This code has the following features:

  • It can take a list or a generator (no len()) as input
  • It does not require imports of other packages
  • No padding is added to the last batch
def batch_generator(items, batch_size):
    itemid = 0  # Keeps track of current position in items generator/list
    batch = []  # Empty batch
    for item in items:
        batch.append(item)  # Append items to batch
        if len(batch) == batch_size:
            yield batch
            itemid += batch_size  # Increment the position in items
            batch = []
    if batch:  # yield last bit, unless it is empty
        yield batch

I like this one,

def batch(x, bs):
    return [x[i:i+bs] for i in range(0, len(x), bs)]

This returns a list of batches of size bs. Of course, you can make it a generator by using a generator expression (i for i in iterable) instead.
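As a sketch, the generator form might look like this (batch_gen is a hypothetical name for illustration):

```python
def batch_gen(x, bs):
    # lazily yield slices instead of building the full list of batches
    return (x[i:i+bs] for i in range(0, len(x), bs))

print(list(batch_gen(list(range(10)), 3)))  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
```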