Simplifying two-phase coroutines for enhanced user-friendliness

I’ve been working on a pattern for writing coroutines that have a distinct setup phase followed by an asynchronous data production phase. This is a common need for things like initialization and resource acquisition followed by asynchronously yielding results. My goal was to let users write these coroutines so that they feel as natural as a regular synchronous function, with the complexity of the asynchronous transitions hidden.

Example Usage

async def my_data_producer(context):
    # Setup phase (runs immediately)
    ... 
    async with context.session() as session:
        # Data production (only starts when the async iterator is consumed)
        for i in range(10):
            await session.send(i) 

# Converting the coroutine to an async iterator
with coro_to_aiter(int, my_data_producer, "producer") as it:
    async for val in it:
        print(val)

Summary of Current Solution

My solution wraps the coroutine in a managed task and uses an asyncio.Queue for communication between the coroutine and the consuming code. Internally, the coroutine signals completion of its setup phase by putting a sentinel value (SetupComplete) on the queue. This setup-completion handling is encapsulated in a custom asynchronous context manager to provide a clean user interface.

The core elements of this solution are:

  • Managed Task: The coroutine is executed as an asyncio.Task.
  • Communication Queue: An asyncio.Queue coordinates the flow of data and signals (e.g., setup completion, end-of-stream, exceptions) between the coroutine’s phases and the consuming code.
  • Custom Context Manager: An async context manager (async with ctx.session()) streamlines the interaction, hides the internal sentinel signaling, and provides a user-friendly way to demarcate the setup and data production phases.

I’m quite happy with how this looks from the user’s point of view, but the implementation is undeniably complex. I see no way to do it without a task, but perhaps I’m missing something. The code below is slightly simplified in that it uses asyncio.create_task() instead of taking a task group.

Code

Implementation
"""
coro_to_aiter: A function to convert a coroutine into an async iterator.

This is useful if you want to write a coroutine that yields data, but you want to do some setup before you start
yielding data.
"""

from __future__ import annotations

import asyncio
from dataclasses import dataclass
from typing import Any, AsyncIterator, Callable, Coroutine, Generic, Self, TypeVar


_T = TypeVar("_T")
_T2 = TypeVar("_T2")


class EndOfStream:
    """A sentinel value used to signal the end of a stream of data."""


class SetupComplete:
    """A token used to signal that a coroutine has started and finished setup."""


@dataclass
class _ExceptionWrapper:
    exc: BaseException


class _QueueWithTask(Generic[_T]):
    """A wrapper around a queue and a task that monitors the task while waiting for data."""

    _queue: asyncio.Queue[_T]
    _task: asyncio.Task[Any]

    def __init__(self, queue: asyncio.Queue[_T], task: asyncio.Task[Any]) -> None:
        self._queue = queue
        self._task = task

    async def get(self) -> _T | EndOfStream:
        # Create the helper task before `try` so the `finally` block never sees an unbound name.
        queue_get_task = asyncio.create_task(self._queue.get(), name="queue_get_task")
        try:
            done, _ = await asyncio.wait([self._task, queue_get_task], return_when=asyncio.FIRST_COMPLETED)
            if queue_get_task in done:
                return queue_get_task.result()

            if self._task in done:
                # logger.warning("task result: {}", self._task.result())
                # logger.warning("Task finished before queue")
                return EndOfStream()
        finally:
            # Don't leave a pending queue.get() behind if the producer task finished first.
            queue_get_task.cancel()

        raise RuntimeError("Unexpected state")


class ProducerContext(Generic[_T]):
    _queue: asyncio.Queue[_T | SetupComplete | EndOfStream | _ExceptionWrapper]
    _want_data_sema: asyncio.Semaphore  # used to signal that someone wants data

    class _Session(Generic[_T2]):
        _ctx: ProducerContext[_T2]

        def __init__(self, ctx: ProducerContext[_T2]) -> None:
            self._ctx = ctx

        async def __aenter__(self) -> Self:
            await self._ctx._queue.put(SetupComplete())
            await self._ctx._want_data_sema.acquire()
            return self

        async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> bool:
            if isinstance(exc_val, Exception):
                # logger.warning("Exception in producer coroutine: {}", exc_val)
                # Forward the exception to the consumer and suppress it here. BaseExceptions
                # such as asyncio.CancelledError are deliberately not swallowed.
                await self._ctx._queue.put(_ExceptionWrapper(exc_val))
                return True
            if exc_val is None:
                await self._ctx._queue.put(EndOfStream())
            return False

        async def send(self, data: _T2) -> None:
            await self._ctx._queue.put(data)
            await self._ctx._want_data_sema.acquire()

    def __init__(self) -> None:
        self._queue = asyncio.Queue()
        self._want_data_sema = asyncio.Semaphore(0)

    def session(self) -> ProducerContext._Session[_T]:
        return self._Session(self)


_ProducerContextT = TypeVar("_ProducerContextT", bound=ProducerContext[Any])


class _CoroToAiter(Generic[_T, _ProducerContextT]):
    """A wrapper around a coroutine to convert it into an async iterator."""

    _coroutine: Callable[[_ProducerContextT], Coroutine[Any, Any, Any]]
    _name: str
    _context_cls: Callable[[], _ProducerContextT]
    _task: asyncio.Task[Any] | None = None
    _ctx: _ProducerContextT

    def __init__(
        self,
        *,
        elem_type: type[_T],
        coroutine: Callable[[_ProducerContextT], Coroutine[Any, Any, None]],
        name: str,
        context_cls: Callable[[], _ProducerContextT],
    ) -> None:
        self._elem_type = elem_type
        self._coroutine = coroutine
        self._name = name
        self._context_cls = context_cls
        self._ctx = context_cls()

    def __enter__(self) -> AsyncIterator[_T | SetupComplete]:
        # logger.debug("Entering coro_to_aiter: {}", self._name)
        self._task = asyncio.create_task(self._coroutine(self._ctx), name=self._name)
        return _coro_to_aiter_with_task(self._task, self._elem_type, self._ctx)

    def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
        if exc_type is not None:
            # logger.warning("exit: exc_type={}", exc_type)
            pass
        del exc_type, exc_val, exc_tb
        if self._task is not None and not self._task.done():
            # logger.warning("Cancelling task")
            self._task.cancel()


def coro_to_aiter(
    elem_type: type[_T],
    coroutine: Callable[[_ProducerContextT], Coroutine[Any, Any, None]],
    name: str,
    context_cls: Callable[[], _ProducerContextT] = ProducerContext,  # type: ignore[assignment]
) -> _CoroToAiter[_T, _ProducerContextT]:
    """
    Convert a coroutine that produces data using ProducerContext into an async iterator.

    This is particularly useful for coroutines that involve setup and teardown phases, or when
    the data production is logically separated from data consumption.

    Usage:
        async def data_producer(ctx: ProducerContext[int]) -> None:
            # setup phase
            ...
            async with ctx.session() as session:
                # produce data
                await session.send(1)

        # Wrap the coroutine as an async iterator
        with coro_to_aiter(int, data_producer, "int_producer") as async_iter:

            # Use it like any other async iterator
            async for item in async_iter:
                ...

    `async_iter` in the above example will first yield a SetupComplete() after it has finished setup, and then yield
    the data.

    @param elem_type: The type of elements that the coroutine produces.
    @param coroutine: The coroutine function yielding data.
    @param name: Name of the coroutine for easier identification.
    @param context_cls: The class of the context given to the coroutine. Defaults to ProducerContext.

    @return An instance of _CoroToAiter managing the lifecycle of the coroutine and its output.
    """
    return _CoroToAiter(
        elem_type=elem_type,
        coroutine=coroutine,
        name=name,
        context_cls=context_cls,
    )


async def _coro_to_aiter_with_task(
    task: asyncio.Task[Any], elem_type: type[_T], ctx: _ProducerContextT
) -> AsyncIterator[_T | SetupComplete]:
    del elem_type  # unused, but needed for type checking

    qt = _QueueWithTask(ctx._queue, task)

    try:
        # The first item is the SetupComplete
        # logger.debug("Waiting for setup complete token")
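        start_token = await qt.get()
        if isinstance(start_token, EndOfStream):
            # The producer finished before entering ctx.session(): re-raise its
            # exception (if any) instead of failing the assertion below.
            await task
            raise RuntimeError("Producer finished without entering ctx.session()")
        assert isinstance(start_token, SetupComplete)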
        # logger.debug("Got setup complete token")
        yield start_token

        ctx._want_data_sema.release()

        while True:
            # logger.debug("Waiting for data")
            result = await qt.get()
            if isinstance(result, EndOfStream):
                await task
                break
            elif isinstance(result, _ExceptionWrapper):
                # logger.warning("Got exception in producer coroutine: {}", result.exc)
                raise result.exc
            assert not isinstance(result, SetupComplete)
            assert result is not None
            yield result
            ctx._want_data_sema.release()

        # logger.debug("Task done")
    finally:
        if not task.done():
            task.cancel()
Tests
import asyncstdlib as astd
import pytest

import coro_to_aiter as c2a


@pytest.mark.timeout(2)
async def test_coro_to_aiter_simple() -> None:
    async def coro(ctx: c2a.ProducerContext[int]) -> None:
        # logger.info("Starting coro")
        async with ctx.session() as session:
            # logger.info("Sending SetupComplete")
            for i in range(10):
                await session.send(i)

    with c2a.coro_to_aiter(int, coro, name="test producer") as c:
        it = aiter(c)

        # First item is the SetupComplete
        start_token = await anext(it)
        assert isinstance(start_token, c2a.SetupComplete)

        async for i, data in astd.enumerate(it):
            assert data == i
        assert i == 9


@pytest.mark.timeout(2)
async def test_coro_to_aiter_is_lazy() -> None:
    """Test that coro_to_aiter is lazy."""

    last_sent = -2

    async def coro(ctx: c2a.ProducerContext[int]) -> None:
        nonlocal last_sent
        last_sent = -1
        async with ctx.session() as session:
            for i in range(10):
                last_sent = i
                await session.send(i)

    with c2a.coro_to_aiter(int, coro, name="test producer") as it:
        assert last_sent == -2

        # First item is the SetupComplete
        start_token = await anext(it)
        assert isinstance(start_token, c2a.SetupComplete)

        assert last_sent == -1

        async for i, data in astd.enumerate(it):
            assert data == i
            assert last_sent == data
        assert i == 9


@pytest.mark.timeout(2)
async def test_exception_during_setup() -> None:
    class TestException(Exception):
        pass

    async def coro(ctx: c2a.ProducerContext[int]) -> None:
        raise TestException("test")

    with c2a.coro_to_aiter(int, coro, name="test producer") as it, pytest.raises(TestException):
        await anext(it)


@pytest.mark.timeout(2)
async def test_exception_during_stream() -> None:
    class TestException(Exception):
        pass

    async def coro(ctx: c2a.ProducerContext[int]) -> None:
        async with ctx.session() as session:
            for i in range(10):
                await session.send(i)
            raise TestException("test")

    with c2a.coro_to_aiter(int, coro, name="test producer") as it:
        # First item is the SetupComplete
        start_token = await anext(it)
        assert isinstance(start_token, c2a.SetupComplete)

        for i in range(10):
            assert await anext(it) == i

        with pytest.raises(TestException):
            await anext(it)

The Questions

  1. Are there significantly simpler approaches to this problem that I’m overlooking? I’ve achieved the desired functionality, but my current implementation feels overly complex for what it does.

  2. If fundamental simplification isn’t possible, could this use case indicate a potential area for evolution in the Python standard library or language itself?

I’m eager to hear insights from the Python community, especially those with expertise in coroutines and asynchronous programming. Thank you for any guidance and suggestions!

If you want that API, why not implement it directly?

from contextlib import asynccontextmanager

@asynccontextmanager
async def my_data_producer(context):
    # Setup phase (runs when the async with block is entered)
    ...

    async def produce_data():
        # Data production (only starts when the async iterator is consumed)
        for i in range(10):
            yield i

    try:
        yield produce_data()
    finally:
        # Clean-up phase
        ...
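A self-contained, runnable version of the asynccontextmanager suggestion, so the order of the phases is visible. Here the `context` argument is assumed to be a plain list used as an event log (an illustration only, not the OP's ProducerContext), and the ranges are shortened to 3 items.

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator


@asynccontextmanager
async def my_data_producer(context: list[str]) -> AsyncIterator[AsyncIterator[int]]:
    # Setup phase: runs when the `async with` block is entered.
    context.append("setup")

    async def produce_data() -> AsyncIterator[int]:
        # Data production: only starts when the async iterator is consumed.
        for i in range(3):
            context.append(f"produce {i}")
            yield i

    try:
        yield produce_data()
    finally:
        # Clean-up phase: runs when the `async with` block exits.
        context.append("cleanup")


async def main() -> list[str]:
    log: list[str] = []
    async with my_data_producer(log) as data:
        log.append("entered")
        async for value in data:
            log.append(f"got {value}")
    return log


print(asyncio.run(main()))
```

The `async with` line itself separates setup from data production, so no task or queue is needed; the cost is the inner generator the OP hoped to hide from users.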

trio’s task spawning functionality has this built in, via a TaskStatus object and nursery.start() (a nursery is trio’s equivalent of a TaskGroup):

import trio

async def task(count, *, task_status=trio.TASK_STATUS_IGNORED):
    # Do some setup...

    # open_memory_channel is Trio's equivalent to asyncio.Queue, but it's split into one object for each end.
    send, rec = trio.open_memory_channel(0)
    with send:
        # Signal completion and optionally pass back a value
        task_status.started(rec)
        for value in range(count):
            await send.send(value)

async def main():
    async with trio.open_nursery() as nursery:
        # Start the task, passing 10 as a positional argument
        rec = await nursery.start(task, 10)
        with rec:
            async for value in rec:
                print(value)

trio.run(main)

What happens is that start() requires the task to accept a task_status keyword argument. It starts running the task, then waits for it to call started() to signal that it’s set up and initialised. The argument passed to started() (if any) is returned from start(). After started() is called, the task is moved into the scope of the nursery and continues running.

TASK_STATUS_IGNORED is an implementation of TaskStatus that does nothing, and is used as the default argument so the task can still be called normally. For example, trio.run_process() can be called normally to just wait for the process to end, or via start() to give you access to the live process so you can interactively communicate with it.

Trio itself isn’t directly compatible with asyncio, but anyio provides an implementation of these semantics that works on both event loops.


I think it can be implemented as a decorator that creates a synchronous wrapper around the asynchronous function: the wrapper calls the asynchronous function, immediately executes its first step, and returns an awaitable object that executes the rest of the code. The concrete implementation is left as homework.

Maybe eager tasks satisfy this requirement?

Otherwise I’d structure it like this:

def coroutine(args):
    # ... synchronous initialization ...
    async def real_coroutine():
        # ... stuff using await ...
        ...
    return real_coroutine()

Eager tasks and return real_coroutine() both require that your setup code perform only synchronous operations, which might be fine depending on what you’re doing but limits how far you can take it.
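A concrete, runnable version of the synchronous-setup pattern, adapted here to return an async generator rather than a coroutine so it matches the OP’s data-producing use case. `open_producer` is a hypothetical name. Because it is a plain function, its validation and setup run, and can raise, at call time, before any await.

```python
import asyncio
from typing import AsyncIterator


def open_producer(count: int) -> AsyncIterator[int]:
    # Synchronous setup: runs, and can raise, as soon as open_producer() is called.
    if count < 0:
        raise ValueError("count must be non-negative")
    items = list(range(count))

    async def produce() -> AsyncIterator[int]:
        # Asynchronous phase: only runs while the returned iterator is consumed.
        for item in items:
            await asyncio.sleep(0)
            yield item

    return produce()


async def main() -> list[int]:
    it = open_producer(3)  # setup has already finished here
    return [item async for item in it]


print(asyncio.run(main()))  # [0, 1, 2]

try:
    open_producer(-1)      # fails immediately, no event loop involved
except ValueError as exc:
    print(exc)             # count must be non-negative
```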

Rereading the OP, I think I misunderstood the problem anyway.

This is what I suggested earlier (but async). I don’t think the metastructure of the OP’s solution removes any significant complexity; it just renames it (and thus adds a hurdle for anyone reading or maintaining it). The small overhead of maintaining the task and queue, plus the additional complexity when debugging, also weigh against it for me.

Yeah, I admit it’s generally a very fair tradeoff to just ask users to write an inner coroutine. In this case I really wanted to provide the most intuitive interface, ideally usable by people who do not understand coroutines or things like nonlocal (the people who will build on this in my project are more like audio engineers than programmers). The point about whether it’s worth the complexity is well taken, and even in my context it’s a fair question whether this is best avoided.

Mostly I’m just surprised that this seemed so difficult!

Ooh, this sounds very promising. Thank you, I’ll take a look into anyio!

I hadn’t realised how problematic it was to use an async context manager (or await in a finally block) in an async generator! I have suggested that this might be an antipattern the flake8-trio linter could pick up on (Suggestion: error when using an async context manager in an async generator · Issue #211 · python-trio/flake8-trio · GitHub) as I couldn’t find a preexisting lint check.