Two alternative solutions:
- Instead of copying every i to the outside index and fake-testing its truth, create outside access to i itself (a lambda that closes over the comprehension's i, assigned with :=) for potential use after an error occurs.
- Use itertools, to benefit from C speed.
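The first approach relies on late binding: a lambda doesn't copy i when it is created. It reads i's current value each time it is called. A minimal illustration, where `snapshots` is a hypothetical helper name used only for this demo:

```python
# Late binding: all three lambdas close over the same comprehension variable i
# and look up its value only when called, i.e. after the loop has finished.
fs = [lambda: i for i in range(3)]
print([g() for g in fs])  # [2, 2, 2]


def snapshots(items):
    # Hypothetical helper: a single lambda is created before the inner loop
    # binds i, yet each call sees the value i has at that moment.
    return [peek() for _ in [None] if (peek := lambda: i) for i in items]


print(snapshots([10, 20, 30]))  # [10, 20, 30]
```

That's why one lambda, created once via the `for _ in [None]` clause, is enough: calling it after the exception returns the i of the failing item.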
Code with demo:
from itertools import compress, count

# Setup: only the item at index 42 fails the check
iterable = [True] * 100
iterable[42] = False

def f(v):
    assert v

# Solution 1: last_i leaks out of the comprehension and reads i lazily
try:
    [f(v)
     for _ in [None]
     if (last_i := lambda: i)
     for i, v in enumerate(iterable)]
except AssertionError:
    print('error at', last_i())

# Solution 2: compress draws one value from ctr per item, so ctr tracks the position
try:
    ctr = count(1)
    [f(v) for v in compress(iterable, ctr)]
except AssertionError:
    print('error at', next(ctr) - 2)
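Why `next(ctr) - 2`: `compress` pulls one selector from `ctr` for every data item it pulls, before handing that item out, and every value of `count(1)` is truthy, so nothing gets filtered. If the item at 0-based index k fails, `ctr` has produced 1 through k+1, so the next call returns k+2. A sketch of the same trick as a reusable function, where `failing_index` and `must_be_true` are hypothetical names:

```python
from itertools import compress, count


def failing_index(items, check):
    """Return the 0-based index of the first item for which check raises
    AssertionError, or None if none does (hypothetical helper)."""
    ctr = count(1)  # all values truthy, so compress drops nothing
    try:
        # compress consumes exactly one value of ctr per item it yields
        [check(v) for v in compress(items, ctr)]
    except AssertionError:
        # ctr has produced 1..k+1 for the item at index k; next() gives k+2
        return next(ctr) - 2
    return None


def must_be_true(v):
    assert v


print(failing_index([True, True, False, True], must_be_true))  # 2
print(failing_index([True] * 5, must_be_true))                 # None
```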
Output (Attempt This Online!):
error at 42
error at 42
Benchmark results:
482.7 ± 3.7 μs [f(v) for v in iterable]
485.3 ± 4.6 μs it = iter(iterable); [f(v) for v in it]
607.8 ± 6.4 μs ctr = count(1); [f(v) for v in compress(iterable, ctr)]
629.2 ± 4.3 μs [f(v) for _ in [None] if (last_i := lambda: i) for i, v in enumerate(iterable)]
699.4 ± 7.1 μs [f(v) for i,v in enumerate(iterable, start=1) if (index := i)]
938.4 ± 7.5 μs [f((index:=i,v)[1]) for i,v in enumerate(iterable, start=1)]
Python: 3.10.8 (main, Oct 11 2022, 11:35:05) [GCC 11.3.0]
(The it = ... solution is from my later “One more” message. I had to omit your wrapper solution; for some reason it doesn't work with this timeit usage.)
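For comparison, the wrapper idea itself works as a standalone function, at least in this sketch: `index` counts items that passed the check, so at the failure it equals the failing item's 0-based index. It needs `nonlocal index` (commented out in the setup string of the benchmark code) so that `index += 1` rebinds the enclosing counter instead of raising `UnboundLocalError`. The names `failing_index_via_wrapper` and `must_be_true` are hypothetical, and I'm not claiming this is the reason it failed under timeit.

```python
def failing_index_via_wrapper(items, check):
    index = 0  # number of items that passed check so far

    def wrapper(v):
        nonlocal index  # without this, index += 1 raises UnboundLocalError
        r = check(v)
        index += 1      # only reached if check(v) didn't raise
        return r

    try:
        [wrapper(v) for v in items]
    except AssertionError:
        return index    # passed-item count == 0-based index of the failure
    return None


def must_be_true(v):
    assert v


print(failing_index_via_wrapper([True, True, False], must_be_true))  # 2
```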
Benchmark code:
from timeit import timeit
from statistics import mean, stdev
import sys

setup = '''
from itertools import compress, count
def f(v):
    return v
iterable = range(5000)
index = 0
def wrapper(v):
    # nonlocal index
    r = f(v)
    index += 1
    return r
'''

solutions = [
    '''[f(v) for v in iterable]''',
    '''[f(v) for i,v in enumerate(iterable, start=1) if (index := i)]''',
    '''[f((index:=i,v)[1]) for i,v in enumerate(iterable, start=1)]''',
    #'''[wrapper(i) for i in iterable]''',
    '''[f(v) for _ in [None] if (last_i := lambda: i) for i, v in enumerate(iterable)]''',
    '''ctr = count(1); [f(v) for v in compress(iterable, ctr)]''',
    '''it = iter(iterable); [f(v) for v in it]''',
]

times = {s: [] for s in solutions}

def stats(s):
    # mean ± stdev, in μs, of the 10 fastest timings of solution s
    ts = [t * 1e6 for t in sorted(times[s])[:10]]
    return f'{mean(ts):6.1f} ± {stdev(ts):3.1f} μs '

# 100 rounds; each round times every solution once (average of 10 runs)
for _ in range(100):
    for s in solutions:
        t = timeit(s, setup, number=10) / 10
        times[s].append(t)

# Print fastest first
for s in sorted(solutions, key=stats):
    print(stats(s), s)

print('Python:', sys.version)
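If you want a wrapper-style solution in the benchmark, one option (a sketch, not what the benchmark code does) is to define the helpers at module level and pass `globals=globals()` to `timeit`, so the statement runs against the module namespace and the wrapper can use `global index`. The names mirror the setup string; the printed figure depends on the machine and isn't comparable to the table above unless run in the same environment.

```python
from timeit import timeit


def f(v):
    return v


iterable = range(5000)
index = 0


def wrapper(v):
    global index  # index is a module-level name in this sketch
    r = f(v)
    index += 1
    return r


# globals= makes timeit execute the statement in this module's namespace,
# so no setup string is needed and wrapper updates the module-level index.
n = 10
t = timeit('[wrapper(v) for v in iterable]', globals=globals(), number=n) / n
print(f'{t * 1e6:.1f} μs per run')
```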
Attempt This Online!