I’ve been working on a pattern for writing coroutines that have a distinct setup phase followed by an asynchronous data production phase. This is a common need when a coroutine has to initialize itself or acquire resources before it starts yielding results asynchronously. My goal is to let users write such coroutines so that they feel as natural as a regular synchronous function, with the complexity of the asynchronous transitions hidden.
Example Usage
async def my_data_producer(context):
    """Example producer: perform the setup, then stream the integers 0 through 9."""
    # Setup phase: everything before the session is opened.
    ...
    async with context.session() as session:
        # Data production phase: each value is handed over through the session,
        # and only once the async iterator is being consumed.
        for value in range(10):
            await session.send(value)
# Converting the coroutine to an async iterator
with coro_to_aiter(int, my_data_producer, "producer") as it:
async for val in it:
print(val)
Summary of Current Solution
My solution centers around wrapping the coroutine into a managed task and using an asyncio.Queue for communication between the coroutine and the consuming code. Internally, the coroutine signals completion of its setup phase by putting a sentinel value on the queue. This setup-completion handling is encapsulated within a custom asynchronous context manager to provide a clean user interface.
The core elements of this solution are:
- Managed Task: The coroutine is executed as an asyncio.Task.
- Communication Queue: An asyncio.Queue coordinates the flow of data and signals (e.g., setup completion, end-of-stream, exceptions) between the coroutine’s phases and the consuming code.
- Custom Context Manager: An asynchronous context manager (async with ctx.session()) streamlines the interaction, hides the internal sentinel signaling, and provides a user-friendly way to demarcate the setup and data production phases.
I’m quite happy with how this looks from the user’s point of view, but the implementation is certainly complex. I see no way to do it without a task, but perhaps I’m missing something. The code below is a bit simplified in that it uses asyncio.create_task() instead of taking a task group.
Code
Implementation
"""
coro_to_aiter: A function to convert a coroutine into an async iterator.
This is useful if you want to write a coroutine that yields data, but you want to do some setup before you start
yielding data.
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from typing import Any, AsyncIterator, Callable, Coroutine, Generic, Self, TypeVar
_T = TypeVar("_T")
_T2 = TypeVar("_T2")
class EndOfStream:
    """Sentinel type: an instance marks that a stream of data has ended."""
class SetupComplete:
    """Token type: an instance signals that a coroutine has started and its setup is done."""
@dataclass
class _ExceptionWrapper:
    """Queue item that carries an exception from the producer to the consumer."""

    # The captured exception; the consuming side re-raises it.
    exc: BaseException
class _QueueWithTask(Generic[_T]):
    """A wrapper around a queue and a task that monitors the task while waiting for data.

    Waiting for the next queue item also watches the task, so a task that stops
    without queueing anything else cannot leave the caller blocked forever.
    """

    _queue: asyncio.Queue[_T]
    _task: asyncio.Task[Any]

    def __init__(self, queue: asyncio.Queue[_T], task: asyncio.Task[Any]) -> None:
        self._queue = queue
        self._task = task

    async def get(self) -> _T | EndOfStream:
        """
        Return the next queue item, or an EndOfStream once the task is done.

        If a queue item and the task completion are both available, the queue item wins.

        @return The next item taken from the queue, or an EndOfStream instance when the task
            finished normally (or was cancelled) without another item being available.
        @raise BaseException: the exception the task failed with, when the task finished with
            an exception and no queue item was available.
        """
        # Created outside the try block so the finally clause can always refer to it.
        queue_get_task = asyncio.create_task(self._queue.get(), name="queue_get_task")
        try:
            done, _ = await asyncio.wait(
                {self._task, queue_get_task}, return_when=asyncio.FIRST_COMPLETED
            )
            if queue_get_task in done:
                return queue_get_task.result()
            if self._task in done:
                # Surface a failure of the task instead of reporting a clean end of
                # stream. A cancelled task is still reported as EndOfStream.
                if not self._task.cancelled():
                    self._task.result()  # re-raises the task's exception, if any
                return EndOfStream()
        finally:
            # No-op if the helper task already finished; otherwise stop the pending get().
            queue_get_task.cancel()
        raise RuntimeError("Unexpected state")
class ProducerContext(Generic[_T]):
    """
    Context object handed to a producer coroutine.

    It owns the queue through which the producer passes signals and data to the consumer,
    and the semaphore on which the producer waits until more data is wanted.
    """

    _queue: asyncio.Queue[_T | SetupComplete | EndOfStream | _ExceptionWrapper]
    _want_data_sema: asyncio.Semaphore  # used to signal that someone wants data

    class _Session(Generic[_T2]):
        """Async context manager that delimits the data production phase of a producer."""

        _ctx: ProducerContext[_T2]

        def __init__(self, ctx: ProducerContext[_T2]) -> None:
            self._ctx = ctx

        async def __aenter__(self) -> Self:
            """Queue a SetupComplete token, then wait until data is wanted."""
            await self._ctx._queue.put(SetupComplete())
            await self._ctx._want_data_sema.acquire()
            return self

        async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> bool:
            """
            Queue the end-of-stream marker, or forward an error raised inside the session.

            @return True when an ordinary exception was queued for the consumer and is
                therefore suppressed here; False otherwise.
            """
            if exc_type is None:
                await self._ctx._queue.put(EndOfStream())
                return False
            if isinstance(exc_val, Exception):
                # logger.warning("Exception in producer coroutine: {}", exc_val)
                await self._ctx._queue.put(_ExceptionWrapper(exc_val))
                # suppress the exception here; it travels to the consumer via the queue
                return True
            # asyncio.CancelledError, KeyboardInterrupt, SystemExit and other non-Exception
            # BaseExceptions must keep propagating: swallowing a cancellation would let
            # the producer carry on as if it had never been cancelled.
            return False

        async def send(self, data: _T2) -> None:
            """Queue one data item, then wait until more data is wanted."""
            await self._ctx._queue.put(data)
            await self._ctx._want_data_sema.acquire()

    def __init__(self) -> None:
        self._queue = asyncio.Queue()
        # Starts at zero so the producer blocks until the semaphore is released.
        self._want_data_sema = asyncio.Semaphore(0)

    def session(self) -> ProducerContext._Session[_T]:
        """Return a new session bound to this context, for use with `async with`."""
        return self._Session(self)
_ProducerContextT = TypeVar("_ProducerContextT", bound=ProducerContext[Any])
class _CoroToAiter(Generic[_T, _ProducerContextT]):
    """
    Context manager that runs a producer coroutine as a task.

    Entering starts the task and returns the async iterator over its output; exiting
    cancels the task if it has not finished yet.
    """

    _elem_type: type[_T]
    _coroutine: Callable[[_ProducerContextT], Coroutine[Any, Any, Any]]
    _name: str
    _context_cls: Callable[[], _ProducerContextT]
    _task: asyncio.Task[Any] | None = None
    _ctx: _ProducerContextT

    def __init__(
        self,
        *,
        elem_type: type[_T],
        coroutine: Callable[[_ProducerContextT], Coroutine[Any, Any, None]],
        name: str,
        context_cls: Callable[[], _ProducerContextT],
    ) -> None:
        self._name = name
        self._elem_type = elem_type
        self._coroutine = coroutine
        self._context_cls = context_cls
        # The context is created up front; the task itself only starts in __enter__.
        self._ctx = context_cls()

    def __enter__(self) -> AsyncIterator[_T | SetupComplete]:
        """Start the producer task and return the async iterator that consumes it."""
        # logger.debug("Entering coro_to_aiter: {}", self._name)
        producer = self._coroutine(self._ctx)
        self._task = asyncio.create_task(producer, name=self._name)
        return _coro_to_aiter_with_task(self._task, self._elem_type, self._ctx)

    def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
        """Cancel the producer task if it is still running; exceptions are not suppressed."""
        task = self._task
        if task is None or task.done():
            return
        # logger.warning("Cancelling task")
        task.cancel()
def coro_to_aiter(
    elem_type: type[_T],
    coroutine: Callable[[_ProducerContextT], Coroutine[Any, Any, None]],
    name: str,
    context_cls: Callable[[], _ProducerContextT] = ProducerContext,  # type: ignore[assignment]
) -> _CoroToAiter[_T, _ProducerContextT]:
    """
    Turn a producer coroutine built on ProducerContext into an async iterator.

    The coroutine performs its setup, opens a session on the context it receives, and
    sends its data through that session. The object returned here is a context manager:
    entering it starts the coroutine and gives back the async iterator to consume.

    Usage:
        async def data_producer(ctx: ProducerContext[int]) -> None:
            # setup phase
            ...
            async with ctx.session() as session:
                # produce data
                await session.send(1)

        # Wrap the coroutine as an async iterator
        with coro_to_aiter(int, data_producer, "int_producer") as async_iter:
            # Use it like any other async iterator
            async for item in async_iter:
                ...

    In the example above, `async_iter` yields a SetupComplete() first, once setup has
    finished, and the data after that.

    @param elem_type: The type of elements that the coroutine produces.
    @param coroutine: The coroutine function that produces the data.
    @param name: Name given to the coroutine's task for easier identification.
    @param context_cls: Factory for the context passed to the coroutine. Defaults to ProducerContext.
    @return A _CoroToAiter instance managing the lifecycle of the coroutine and its output.
    """
    wrapper = _CoroToAiter(
        elem_type=elem_type,
        coroutine=coroutine,
        name=name,
        context_cls=context_cls,
    )
    return wrapper
async def _coro_to_aiter_with_task(
    task: asyncio.Task[Any], elem_type: type[_T], ctx: _ProducerContextT
) -> AsyncIterator[_T | SetupComplete]:
    """
    Async generator that relays what the producer task puts on the context's queue.

    It yields the SetupComplete token first and then every data item. After each yield the
    context's semaphore is released, which lets the producer continue.

    @param task: The task running the producer coroutine.
    @param elem_type: The element type; only used for type checking.
    @param ctx: The context shared with the producer (queue and semaphore).
    @raise RuntimeError: if the producer finished without signalling setup completion.
    @raise BaseException: any exception forwarded by the producer, or raised by its task.
    """
    del elem_type  # unused, but needed for type checking
    qt = _QueueWithTask(ctx._queue, task)
    try:
        # The first item is the SetupComplete
        # logger.debug("Waiting for setup complete token")
        start_token = await qt.get()
        if isinstance(start_token, EndOfStream):
            # The producer stopped before opening its session. Awaiting the task
            # re-raises the exception it failed with, if there was one.
            await task
            raise RuntimeError("Producer finished without signalling setup completion")
        assert isinstance(start_token, SetupComplete)
        # logger.debug("Got setup complete token")
        yield start_token
        ctx._want_data_sema.release()
        while True:
            # logger.debug("Waiting for data")
            result = await qt.get()
            if isinstance(result, EndOfStream):
                # Wait for the code after the producer's session to finish as well.
                await task
                break
            elif isinstance(result, _ExceptionWrapper):
                # logger.warning("Got exception in producer coroutine: {}", result.exc)
                raise result.exc
            assert not isinstance(result, SetupComplete)
            assert result is not None
            yield result
            ctx._want_data_sema.release()
        # logger.debug("Task done")
    finally:
        # Also reached when the consumer stops iterating early or an error is raised.
        if not task.done():
            task.cancel()
Tests
import asyncstdlib as astd
import pytest
import coro_to_aiter as c2a
@pytest.mark.timeout(2)
async def test_coro_to_aiter_simple() -> None:
    """A producer sending 0..9 yields the SetupComplete token, then the ten values in order."""

    async def produce(ctx: c2a.ProducerContext[int]) -> None:
        # logger.info("Starting coro")
        async with ctx.session() as session:
            # logger.info("Sending SetupComplete")
            for value in range(10):
                await session.send(value)

    with c2a.coro_to_aiter(int, produce, name="test producer") as stream:
        iterator = aiter(stream)
        # The setup token comes before any data.
        first = await anext(iterator)
        assert isinstance(first, c2a.SetupComplete)
        received = [item async for item in iterator]
        assert received == list(range(10))
@pytest.mark.timeout(2)
async def test_coro_to_aiter_is_lazy() -> None:
    """Test that the producer only advances when the consumer asks for the next item."""
    # -2: producer not started, -1: producer started, n: value n was handed to send().
    last_sent = -2

    async def produce(ctx: c2a.ProducerContext[int]) -> None:
        nonlocal last_sent
        last_sent = -1
        async with ctx.session() as session:
            for value in range(10):
                last_sent = value
                await session.send(value)

    with c2a.coro_to_aiter(int, produce, name="test producer") as stream:
        # Nothing has been awaited yet, so the producer has not run.
        assert last_sent == -2
        # First item is the SetupComplete
        first = await anext(stream)
        assert isinstance(first, c2a.SetupComplete)
        assert last_sent == -1
        async for position, item in astd.enumerate(stream):
            assert item == position
            # The producer must not have moved past the item just received.
            assert last_sent == item
        assert position == 9
@pytest.mark.timeout(2)
async def test_exception_during_setup() -> None:
    """An exception raised before the session is opened reaches the consumer."""

    class SetupError(Exception):
        pass

    async def failing_setup(ctx: c2a.ProducerContext[int]) -> None:
        raise SetupError("test")

    with c2a.coro_to_aiter(int, failing_setup, name="test producer") as stream:
        with pytest.raises(SetupError):
            await anext(stream)
@pytest.mark.timeout(2)
async def test_exception_during_stream() -> None:
    """An exception raised inside the session is delivered after the data sent before it."""

    class StreamError(Exception):
        pass

    async def failing_stream(ctx: c2a.ProducerContext[int]) -> None:
        async with ctx.session() as session:
            for value in range(10):
                await session.send(value)
            raise StreamError("test")

    with c2a.coro_to_aiter(int, failing_stream, name="test producer") as stream:
        # First item is the SetupComplete
        first = await anext(stream)
        assert isinstance(first, c2a.SetupComplete)
        # All ten values arrive before the error does.
        for expected in range(10):
            assert await anext(stream) == expected
        with pytest.raises(StreamError):
            await anext(stream)
The Questions
- Are there significantly simpler approaches to this problem that I’m overlooking? I’ve achieved the desired functionality, but my current implementation feels overly complex for what it does.
- If fundamental simplification isn’t possible, could this use case indicate a potential area for evolution in the Python standard library or language itself?
I’m eager to hear insights from the Python community, especially those with expertise in coroutines and asynchronous programming. Thank you for any guidance and suggestions!