I recently submitted a fix to typeshed to indicate that contextlib.ContextDecorator.__call__ and contextlib.AsyncContextDecorator.__call__ do not in fact return callables of the same type as their input; they only return callables with the same call signature as their input: Improve [Async]ContextDecorator type hinting (python/typeshed PR #13416)
Those returned values will be callable and will have a __wrapped__ attribute referring to the wrapped callable, but they won't expose any of the other attributes and methods of the original type. The current fully generic signature for these methods in typeshed is therefore incorrect: it omits the added __wrapped__ attribute, and it wrongly implies that the input type's other attributes and methods remain available.
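A minimal runtime sketch of the point above, assuming a hypothetical no-op context manager `noop_cm` and a hypothetical callable class `Greeter` (neither comes from the stdlib or typeshed). The object returned by ContextDecorator.__call__ is a plain function carrying __wrapped__, not an instance of the input type:

```python
import contextlib


class noop_cm(contextlib.ContextDecorator):
    """Hypothetical context manager that does nothing, usable as a decorator."""

    def __enter__(self) -> "noop_cm":
        return self

    def __exit__(self, *exc_info: object) -> bool:
        return False  # never suppress exceptions


class Greeter:
    """Hypothetical callable class with an extra method and class attribute."""

    greeting = "hello"

    def __call__(self, name: str) -> str:
        return f"{self.greeting}, {name}"

    def shout(self, name: str) -> str:
        return self(name).upper()


greeter = Greeter()
wrapped = noop_cm()(greeter)  # ContextDecorator.__call__ wraps it in a new function

print(wrapped("world"))                 # hello, world  (call signature preserved)
print(wrapped.__wrapped__ is greeter)   # True          (__wrapped__ added by functools.wraps)
print(isinstance(wrapped, Greeter))     # False         (not the input type)
print(hasattr(wrapped, "shout"))        # False         (methods of Greeter not exposed)
```

A stub declaring that the decorator returns its input type would accept `wrapped.shout("world")`, even though that call fails with AttributeError at runtime.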
The mypy_primer check picked up that jax was affected by this signature clarification, so I also submitted a PR to jax to ensure it wouldn't break as type checkers picked up the stdlib signature fix: Clarify type hints for wrapped functions (jax-ml/jax PR #25994)
Working on the jax PR highlighted a problem severe enough that I have suggested reverting the typeshed fix for the time being: correcting the signature of those methods to exclude access to unrelated attributes and methods also strips all of the wrapped callable's overload information, which is devastating for return type inference on the wrapped callable. The inferred return type is always the full union of the potential return types, even when a more specific overload matches the call.
The heart of the problem is this approach to specifying wrapped callables:
```python
from typing import Callable, Generic, ParamSpec, TypeVar

_P = ParamSpec("_P")
_R = TypeVar("_R")

class _WrappedCallable(Generic[_P, _R]):
    __wrapped__: Callable[_P, _R]
    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
```
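Splitting up the parameter spec and return type information like that is what loses the overload information: an overloaded function has several (parameters, return type) pairings, but _P and _R can each only capture one solution, so the link between a particular overload's parameters and its return type is discarded.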
To allow type checkers to preserve the overload details, we would instead need the ability to declare something like the following:
```python
_C = CallableSpec("_C")  # Akin to a Callable-only TypeVar

class WrappedCallable(Generic[_C]):
    __wrapped__: _C
    def __call__(self, *args: _C.args, **kwargs: _C.kwargs) -> _C.retval: ...
```
The CallableSpec invocation shown would be equivalent to _C = CallableSpec("_C", bound=Callable[..., Any]), so generic definitions could place restrictions on the required callable signatures.
Somewhat related recent topics: