Proposal: overloading methods by argument type and by whether the call is synchronous or asynchronous

Inspired by the `singledispatch` decorator in the `functools` module, I thought it would be very useful to have a decorator that can also overload methods or functions based on whether they are synchronous or asynchronous.

This can greatly simplify the use of libraries containing both synchronous and asynchronous code that perform the same task but use different dependencies depending on the context. It can also facilitate the migration of synchronous code to asynchronous code.
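
For context, here is a minimal sketch of the existing type-based dispatch that this idea builds on, using `functools.singledispatchmethod`. The `Formatter` class is a hypothetical example and not part of the proposal; it dispatches on the argument type only, with no notion of sync or async.

```python
from functools import singledispatchmethod


class Formatter:
    @singledispatchmethod
    def process(self, x) -> str:
        # Fallback for types without a registered overload
        raise NotImplementedError(f"Unsupported type: {type(x).__name__}")

    @process.register
    def _(self, x: int) -> str:
        return f"Processing integer: {x}"

    @process.register
    def _(self, x: str) -> str:
        return f"Processing string: {x}"


formatter = Formatter()
print(formatter.process(42))  # Processing integer: 42
print(formatter.process("hello"))  # Processing string: hello
```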

Example of use:

import asyncio


class Example:
    @coroutinedispatch
    def process(self, x: int) -> str:
        return f"Processing integer: {x}"

    @process.register
    async def _(self, x: int) -> str:
        await asyncio.sleep(0.1)
        return f"Processing integer asynchronously: {x}"

    @process.register
    def _(self, x: str) -> str:
        return f"Processing string: {x}"

    @process.register
    async def _(self, x: str) -> str:
        await asyncio.sleep(0.1)
        return f"Processing string asynchronously: {x}"


example = Example()
print(example.process(42))  # Processing integer: 42
print(example.process("hello"))  # Processing string: hello


async def main():
    print(await example.process(42))  # Processing integer asynchronously: 42
    print(await example.process("hello"))  # Processing string asynchronously: hello


asyncio.run(main())

I thought an implementation might look something like this:

import inspect
import asyncio
from typing import Any, Callable, get_type_hints, get_origin
from functools import lru_cache


class coroutinedispatch:
    def __init__(self, func: Callable):
        self._sync_methods = {}
        self._async_methods = {}
        self._name = func.__name__

        arg_types = self.get_arg_types(func)

        if inspect.iscoroutinefunction(func):
            self._async_methods[arg_types] = func
        else:
            self._sync_methods[arg_types] = func

    def get_arg_types(self, func: Callable) -> tuple:
        try:
            hints = get_type_hints(func)
            sig = inspect.signature(func)
            params = list(sig.parameters.values())

            if params and params[0].name in ("self", "cls"):
                params = params[1:]

            return tuple(hints.get(p.name, Any) for p in params)
        except Exception:
            return ()

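    # Not cached: lru_cache would key on the argument values, which fails for
    # unhashable arguments and conflates equal values such as 1 and 1.0.
    def match_types(self, provided_args: tuple, expected_types: tuple) -> bool: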
        if len(provided_args) != len(expected_types):
            return False

        for arg, expected in zip(provided_args, expected_types):
            if expected is Any:
                continue

            origin = get_origin(expected)
            if origin is not None:
                expected = origin

            if not isinstance(arg, expected):
                return False

        return True

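    # Not cached: lru_cache would key on the argument values, which fails for
    # unhashable arguments and conflates equal values such as 1 and 1.0.
    def _find_matching_method(self, is_async: bool, *args: Any) -> Callable: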
        """Find the method that matches the argument types."""
        methods = self._async_methods if is_async else self._sync_methods

        # Search for exact match
        for arg_types, method in methods.items():
            if self.match_types(args, arg_types):
                return method

        # If no match, raise descriptive error
        arg_type_names = tuple(type(arg).__name__ for arg in args)
        context = "async" if is_async else "sync"
        available = list(methods.keys())

        raise TypeError(
            f"No matching {context} method '{self._name}' found for "
            f"arguments: {arg_type_names}. Available: {available}"
        )

    def register(self, func: Callable) -> Callable:
        """Register a new method overload."""
        arg_types = self.get_arg_types(func)

        if inspect.iscoroutinefunction(func):
            self._async_methods[arg_types] = func
        else:
            self._sync_methods[arg_types] = func

        return self

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # A running event loop in this thread is treated as "async context".
        # get_running_loop() raises RuntimeError when no loop is running.
        try:
            asyncio.get_running_loop()
            is_async_context = True
        except RuntimeError:
            is_async_context = False

        # Exclude 'self' from arguments if present
        check_args = args[1:] if args and hasattr(args[0], self._name) else args

        try:
            method = self._find_matching_method(is_async_context, *check_args)
        except TypeError:
            # If no async/sync method, try the other
            method = self._find_matching_method(not is_async_context, *check_args)

        # Call outside the try block so that a TypeError raised by the method
        # itself is not mistaken for a failed lookup.
        return method(*args, **kwargs)

    def __get__(self, obj, objtype=None):
        """Support for bound methods"""
        if obj is None:
            return self

        import functools

        return functools.partial(self.__call__, obj)

I would like to know your opinions and be able to discuss this idea.

A problem with this approach is the way the “async context” detection works.

This is something that a few libraries do (Django being a famous example), and it causes all sorts of headaches. The issue is that checking for “async artifacts” (in this case a running event loop) to infer whether something is being called within an “async context” assumes that the caller explicitly set up this event loop, or is even aware of it, which isn’t always the case. One example is libraries that implicitly set up event loops behind the scenes, fsspec being a notable one.

In practice this means that you can run into a situation where someone is using library a and library b in their code. a exposes a coroutinedispatch function, b implicitly sets up an event loop at some point somewhere, and then the user of library a now suddenly has a coroutine at hand and doesn’t know why.

def do_something():
    value = coro_dispatch_fn()  # returns an 'int' here
    files = fsspec.filesystem("http").ls()
    value = coro_dispatch_fn()  # now returns a 'Coroutine[int]'

I think the return type of a function changing because of a side effect somewhere else is quite confusing and a potential major source of all sorts of bugs.
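
A minimal sketch of why this happens, independent of any particular library: a running-loop check reports on the state of the current thread, not on who the caller is. The names `loop_is_running` and `plain_sync_helper` are hypothetical, chosen for illustration.

```python
import asyncio


def loop_is_running() -> bool:
    # True when an event loop is running in the current thread
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return False
    return True


def plain_sync_helper() -> bool:
    # An ordinary synchronous function that knows nothing about asyncio
    return loop_is_running()


async def main() -> bool:
    # The helper is still called synchronously, just underneath a running loop
    return plain_sync_helper()


outside = plain_sync_helper()
inside = asyncio.run(main())
print(outside, inside)  # False True
```

The same synchronous helper gets two different answers, depending only on whether something further up the stack started a loop.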

Another aspect to consider would be that this is impossible to type-check. A static type checker can infer if something is guaranteed to occur within an async context (e.g. it’s being called within an async function), but it cannot infer if something is guaranteed to not be called within an async context, at least when “async context” is being used as in your proposal.


Good morning, thank you so much for taking the time to respond to my proposal.

There are different ways to determine whether we are in an asynchronous context or not. For example, we can also do this to determine if we are in a synchronous or asynchronous context:

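import inspect
import sys

CO_COROUTINE = getattr(inspect, 'CO_COROUTINE', 0x0080)

def _called_from_async(depth: int = 2) -> bool:
    try:
        # Frame 0 is this helper, frame 1 its caller, frame 2 the caller's caller
        frame = sys._getframe(depth)
        # The flag is set on code objects compiled from 'async def'
        return bool(frame.f_code.co_flags & CO_COROUTINE)

    except (ValueError, AttributeError):
        return False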

I tested the code you suggested and it works correctly, but I can see that a problem might occur when an asynchronous function calls a synchronous one: the event-loop check can then produce a false positive, reporting that the synchronous function is running in an asynchronous context.

In my tests, sys._getframe works correctly and is slightly faster (1.27µs, compared with 2µs for asyncio.get_event_loop).
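
As a rough sketch of how the frame-based check behaves in the three situations discussed so far (the helper names are hypothetical, and `depth=1` is used because there is no dispatcher frame in between):

```python
import asyncio
import inspect
import sys


def called_from_async(depth: int = 1) -> bool:
    # Frame 0 is this function; frame 1 is whoever called it
    frame = sys._getframe(depth)
    return bool(frame.f_code.co_flags & inspect.CO_COROUTINE)


def sync_helper() -> bool:
    return called_from_async()


async def direct() -> bool:
    # Called straight from an 'async def' body
    return called_from_async()


async def via_helper() -> bool:
    # Called from a synchronous function that itself runs under the loop
    return sync_helper()


results = (sync_helper(), asyncio.run(direct()), asyncio.run(via_helper()))
print(results)  # (False, True, False)
```

The check only looks at the immediate calling frame, so a synchronous helper is reported as synchronous even while a loop is running.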

Example of coroutinedispatch with sys._getframe:

import inspect
import asyncio
import sys
from typing import Any, Callable, get_type_hints, get_origin
from functools import lru_cache


class coroutinedispatch:
    CO_COROUTINE = getattr(inspect, "CO_COROUTINE", 0x0080)

    def __init__(self, func: Callable):
        self._sync_methods = {}
        self._async_methods = {}
        self._name = func.__name__

        arg_types = self.get_arg_types(func)

        if inspect.iscoroutinefunction(func):
            self._async_methods[arg_types] = func
        else:
            self._sync_methods[arg_types] = func

    def _called_from_async(self, depth: int = 2) -> bool:
        """Detect if caller is inside an async def."""
        try:
            frame = sys._getframe(depth)
            return bool(frame.f_code.co_flags & self.CO_COROUTINE)
        except (ValueError, AttributeError):
            return False

    def get_arg_types(self, func: Callable) -> tuple:
        try:
            hints = get_type_hints(func)
            sig = inspect.signature(func)
            params = list(sig.parameters.values())

            if params and params[0].name in ("self", "cls"):
                params = params[1:]

            return tuple(hints.get(p.name, Any) for p in params)
        except Exception:
            return ()

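    # Not cached: lru_cache would key on the argument values, which fails for
    # unhashable arguments and conflates equal values such as 1 and 1.0.
    def match_types(self, provided_args: tuple, expected_types: tuple) -> bool: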
        if len(provided_args) != len(expected_types):
            return False

        for arg, expected in zip(provided_args, expected_types):
            if expected is Any:
                continue

            origin = get_origin(expected)
            if origin is not None:
                expected = origin

            if not isinstance(arg, expected):
                return False

        return True

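    # Not cached: lru_cache would key on the argument values, which fails for
    # unhashable arguments and conflates equal values such as 1 and 1.0.
    def _find_matching_method(self, is_async: bool, *args: Any) -> Callable: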
        """Find the method that matches the argument types."""
        methods = self._async_methods if is_async else self._sync_methods

        # Search for exact match
        for arg_types, method in methods.items():
            if self.match_types(args, arg_types):
                return method

        # If no match, raise descriptive error
        arg_type_names = tuple(type(arg).__name__ for arg in args)
        context = "async" if is_async else "sync"
        available = list(methods.keys())

        raise TypeError(
            f"No matching {context} method '{self._name}' found for "
            f"arguments: {arg_type_names}. Available: {available}"
        )

    def register(self, func: Callable) -> Callable:
        """Register a new method overload."""
        arg_types = self.get_arg_types(func)

        if inspect.iscoroutinefunction(func):
            self._async_methods[arg_types] = func
        else:
            self._sync_methods[arg_types] = func

        return self

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Must be called directly from __call__ so that depth=2 is the caller's frame
        is_async_context = self._called_from_async()

        # Exclude 'self' from arguments if present
        check_args = args[1:] if args and hasattr(args[0], self._name) else args

        try:
            method = self._find_matching_method(is_async_context, *check_args)
        except TypeError:
            # If no async/sync method, try the other
            method = self._find_matching_method(not is_async_context, *check_args)

        # Call outside the try block so that a TypeError raised by the method
        # itself is not mistaken for a failed lookup.
        return method(*args, **kwargs)

    def __get__(self, obj, objtype=None):
        """Support for bound methods"""
        if obj is None:
            return self

        import functools

        return functools.partial(self.__call__, obj)

I look forward to any further suggestions or problems you might encounter.

But why detect the state we are in? The user already has to state explicitly what they want (a plain call or an await).

I get the noise reduction, but the user can explicitly choose what they want. So how is runtime detection better than explicit naming?
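
A minimal sketch of the explicit-naming alternative, assuming a hypothetical `Repository` class with a sync `get` and an async `aget`. The caller picks the flavour by name, so nothing has to be detected at runtime:

```python
import asyncio


class Repository:
    """Hypothetical example: one operation, two explicitly named entry points."""

    def get(self, key: str) -> str:
        return f"sync:{key}"

    async def aget(self, key: str) -> str:
        await asyncio.sleep(0)  # stand-in for real asynchronous I/O
        return f"async:{key}"


repo = Repository()
sync_result = repo.get("example")
async_result = asyncio.run(repo.aget("example"))
print(sync_result, async_result)  # sync:example async:example
```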

My proposal aims to simplify the use of any library/framework for the user. For example, currently in Django, if you want to execute a sync/async query, you run the following code.

User.objects.get(username="example") # Sync
await User.objects.aget(username="example") # Async

With these changes, the choice of which method to execute would be handled internally:

def sync_function():
    User.objects.get(username="example") # Sync

async def async_function():
    await User.objects.get(username="example") # Async

This proposal could halve the API documentation for libraries that support both synchronous and asynchronous operations, and also prevent errors such as synchronous operations running in an asynchronous context (blocking) and asynchronous operations running in a synchronous context (leaving the coroutine unexecuted).

x = example.process(42)
await x

What is the type of x? What code is executed for example.process?

If you say the asynchronous one, because it is executed when there is a running event loop, then what about the following code?

async def f(x):
    print(await x)
asyncio.run(f(example.process(42)))

If you say the synchronous one, because there is no await immediately preceding the call (we can consider await + call a single syntax construct), then your hybrid method will be incompatible with a lot of existing code.
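
As a minimal sketch of the kind of existing code in question: a coroutine object can be created in synchronous code, with no loop running, and awaited later. The `process` function here is a hypothetical stand-in written as an ordinary `async def`, not the proposed decorator.

```python
import asyncio


async def process(x: int) -> str:
    return f"Processing integer asynchronously: {x}"


async def f(x):
    # Awaits whatever awaitable it was handed
    return await x


coro = process(42)  # created in synchronous code, not yet running
result = asyncio.run(f(coro))
print(result)  # Processing integer asynchronously: 42
```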

For the first example, the IDE will tell you it’s Coroutine[str] | str, although I would still need to improve the class to ensure that works.

For the second example:

async def f(x):
    print(await x)


asyncio.run(f(example.process(42)))

What you would be passing to the function f is a str because you are invoking the function in a synchronous context. If you needed to get the coroutine, you would have to invoke it within an asynchronous context.

async def bar(f):
    print(await f)


async def foo(x):
    # example.process(x) is invoked inside an async def, so it yields the coroutine
    await bar(example.process(x))


asyncio.run(foo(42))
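
To illustrate the consequence for the earlier `asyncio.run(f(example.process(42)))` call, assuming the dispatcher really does return a plain `str` there: awaiting a `str` raises `TypeError`. The string literal below is a stand-in for that return value.

```python
import asyncio


async def f(x):
    return await x


def run_with_plain_str() -> str:
    try:
        # Stand-in for f(example.process(42)) when process returns a str
        asyncio.run(f("Processing integer: 42"))
    except TypeError as exc:
        return type(exc).__name__
    return "no error"


outcome = run_with_plain_str()
print(outcome)  # TypeError
```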

It’s clear that if you need to obtain the coroutine in a synchronous context, this proposal won’t work for you, but you could always choose not to use it; ultimately, this is a proposal to add to functools so that people who find it useful can use it.