Defining overload-preserving signatures for wrapped callables

I recently submitted a fix to typeshed indicating that contextlib.ContextDecorator.__call__ and contextlib.AsyncContextDecorator.__call__ do not, in fact, return callables of the same type as their input; they only return callables with the same call signature as their input: python/typeshed PR #13416 (Improve [Async]ContextDecorator type hinting).

Those returned values will be callable and will have a __wrapped__ attribute pointing to the wrapped callable, but they won't expose any of the other attributes and methods of the original type. The fully generic signature currently in typeshed for these methods is therefore incorrect: it omits the added __wrapped__ attribute, and it incorrectly implies that the original type's other attributes and methods will still be available.
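A minimal runtime sketch of that behaviour, using the real contextlib.ContextDecorator base class (the track and Greeter classes are hypothetical examples). ContextDecorator.__call__ builds a new wrapper function with functools.wraps, so the result gains __wrapped__ but loses the methods of the original object, even though a fully generic signature of the form "func: _F -> _F" would tell a type checker that wrapped.describe() is still valid:

```python
import contextlib


class track(contextlib.ContextDecorator):
    """Hypothetical context manager that records when it is entered and exited."""

    def __init__(self) -> None:
        self.events: list[str] = []

    def __enter__(self) -> "track":
        self.events.append("enter")
        return self

    def __exit__(self, *exc: object) -> bool:
        self.events.append("exit")
        return False  # do not suppress exceptions


class Greeter:
    """Hypothetical callable object that also exposes an extra method."""

    def __call__(self, name: str) -> str:
        return f"hello {name}"

    def describe(self) -> str:
        return "greeter"


tracker = track()
greeter = Greeter()
wrapped = tracker(greeter)  # ContextDecorator.__call__ returns a new wrapper function

print(wrapped("jax"))                  # hello jax (call signature preserved)
print(tracker.events)                  # ['enter', 'exit'] (call ran inside the context)
print(wrapped.__wrapped__ is greeter)  # True (set by functools.wraps)
print(hasattr(wrapped, "describe"))    # False (Greeter's other methods are gone)
print(type(wrapped).__name__)          # function
```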

The mypy_primer check showed that jax was affected by this signature clarification, so I also submitted a PR to jax to ensure it wouldn't break as type checkers start picking up the stdlib signature fix: jax-ml/jax PR #25994 (Clarify type hints for wrapped functions).

Working on the jax PR highlighted a problem severe enough that I have suggested reverting the typeshed fix for the time being: correcting the signature of those methods to exclude access to unrelated attributes and methods also strips all of the detailed overload information from the signature, which is devastating for return type inference on the wrapped callable. The inferred return type is always the full union of the possible return types, even when a more specific overload matches the call.

The heart of the problem is this approach to specifying wrapped callables:

```python
from collections.abc import Callable
from typing import Generic, ParamSpec, TypeVar

_P = ParamSpec("_P")
_R = TypeVar("_R")

class _WrappedCallable(Generic[_P, _R]):
    __wrapped__: Callable[_P, _R]
    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
```

Splitting the parameter specification and the return type into separate type variables like that is what loses the overload information.

To allow type checkers to preserve the overload details, we would instead need to be able to declare something like the following:

```python
_C = CallableSpec("_C") # Akin to a Callable-only TypeVar

class WrappedCallable(Generic[_C]):
    __wrapped__: _C
    def __call__(self, *args: _C.args, **kwargs: _C.kwargs) -> _C.retval: ...
```

The CallableSpec invocation shown would be equivalent to _C = CallableSpec("_C", bound=Callable[..., Any]). Supporting an explicit bound like this would allow generic definitions to place restrictions on the accepted callable signatures.

Somewhat related recent topics:

This is a common enough pattern that, rather than making people rewrite their existing ParamSpec uses, I'd prefer a specification update saying that whenever Callable[P, R] appears as part of a type, type checkers should tie the solution for the return type R to the solution for the ParamSpec P. That should be sound, and it should address the overload case.

If it can be done that way, that would be excellent (I had assumed, perhaps incorrectly, that it wouldn't be feasible).

Why not just:

```python
from __future__ import annotations
from typing import Any
from collections.abc import Callable

class _WrappedCallable[C: Callable[..., Any]]:
    __wrapped__: C
    def __call__[**P, R](self: _WrappedCallable[Callable[P, R]], *args: P.args, **kwargs: P.kwargs) -> R: ...
```

Edit: Never mind, I don't think this preserves overloads.

The syntax without type parameter lists in my initial post is just an artifact of wanting to support Python versions older than 3.12.

But hopefully @mikeshardmind is right that it’s feasible to adjust the expected typechecker behaviour in this case without having to change the way the wrapping idiom is expressed in the type hints.

FYI, I think JAX drops support for Python 3.11 in about a year.