A common task that comes up when writing code is to group items by some property. For example: group people by their last name, or group integers by their last digit. To be specific, grouping takes an iterable and a key function, and returns a dictionary mapping each result of the key function to the list of items that produced that result.
Grouping is similar to filtering or mapping, in that it’s actually not very much code to write, but it’s useful in enough places to make a standardized method worthwhile. I find myself adding this method to projects all the time.
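To illustrate how little code the ad-hoc version takes, here is a minimal sketch of what tends to get rewritten in each project. It uses `collections.defaultdict` and covers both examples above. The names and numbers are made-up sample data.

```python
from collections import defaultdict

# Hypothetical sample data: group people by their last name.
people = ["Alice Smith", "Bob Jones", "Carol Smith"]
by_last_name = defaultdict(list)
for person in people:
    by_last_name[person.split()[-1]].append(person)

# Group integers by their last digit.
by_last_digit = defaultdict(list)
for n in [12, 7, 22, 17, 30]:
    by_last_digit[n % 10].append(n)

print(dict(by_last_name))   # {'Smith': ['Alice Smith', 'Carol Smith'], 'Jones': ['Bob Jones']}
print(dict(by_last_digit))  # {2: [12, 22], 7: [7, 17], 0: [30]}
```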
One major obstacle here is that `itertools.groupby` exists, and has the wrong semantics. In particular, `itertools.groupby` requires same-key items to be contiguous in the sequence in order for them to end up in the same group. This is a bit like having a `sort` method that’s a single pass of a bubble sort: it kind of has the right idea, and it’s more efficient than full sorting, but in practice it tends not to match up with what you actually need. Perhaps the `itertools.groupby` method could be modified to have an option to merge non-contiguous groups? For example, that might look like `groups = dict(itertools.groupby(items, key=key, ensure_unique_keys=True))`?
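The contiguity problem is easy to demonstrate with today’s `itertools.groupby`, which has no such option. In the sketch below, the `parity` key function and the sample list are illustrative. Feeding unsorted input into a dict comprehension silently keeps only the last run for each key. The usual workaround is to sort by the same key first, which requires orderable keys and an O(n log n) sort.

```python
import itertools


def parity(n: int) -> str:
    """Illustrative key function."""
    return "even" if n % 2 == 0 else "odd"


items = [1, 2, 3, 4, 5, 6]

# groupby only merges *adjacent* items with equal keys, so each key shows up repeatedly.
runs = [(k, list(g)) for k, g in itertools.groupby(items, key=parity)]
print(runs)
# [('odd', [1]), ('even', [2]), ('odd', [3]), ('even', [4]), ('odd', [5]), ('even', [6])]

# Building a dict from those runs overwrites earlier groups with later ones.
naive = {k: list(g) for k, g in itertools.groupby(items, key=parity)}
print(naive)  # {'odd': [5], 'even': [6]}

# Workaround: sort by the same key so equal keys become contiguous (keys must be orderable).
merged = {k: list(g) for k, g in itertools.groupby(sorted(items, key=parity), key=parity)}
print(merged)  # {'even': [2, 4, 6], 'odd': [1, 3, 5]}
```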
Anyway, to be concrete, this is what I imagined the method being:
```python
from typing import Callable, Dict, Hashable, Iterable, List, TypeVar

TVal = TypeVar('TVal')
TKey = TypeVar('TKey', bound=Hashable)  # keys end up as dict keys, so they must be hashable


def group_by(items: Iterable[TVal],
             *,
             key: Callable[[TVal], TKey],
             ) -> Dict[TKey, List[TVal]]:
    """Groups items based on whether they produce the same key from a function.

    Args:
        items: The items to group (any iterable, e.g. a list or a range).
        key: Items that produce the same value from this function get grouped together.

    Returns:
        A dictionary mapping outputs that were produced by the grouping function to
        the list of items that produced that output.

    Examples:
        >>> group_by([1, 2, 3], key=lambda i: i == 2)
        {False: [1, 3], True: [2]}
        >>> group_by(range(10), key=lambda i: i % 3)
        {0: [0, 3, 6, 9], 1: [1, 4, 7], 2: [2, 5, 8]}
    """
    result: Dict[TKey, List[TVal]] = {}
    for item in items:
        item_key = key(item)
        # Create the group's list the first time a key is seen, then append.
        # Groups keep first-seen key order, and items keep their input order.
        result.setdefault(item_key, []).append(item)
    return result
```