Trying to type complex function decorator

Hello,

I’m struggling to type a function decorator which:

  • takes arguments itself, and
  • injects different args into the decorated function calls, and
  • decorates functions with different args/kwargs.

A minimal example would be the following:

import collections.abc
import functools
import typing

# P = typing.ParamSpec("P")
# R = typing.TypeVar("R")

# https://docs.python.org/3/reference/compound_stmts.html#generic-type-aliases
# type F[**P, R] = (
#     collections.abc.Callable[typing.Concatenate[int, float, P], R]
#     | collections.abc.Callable[typing.Concatenate[float, P], R]
# )
type F[**P, R] = collections.abc.Callable[P, R]


def decorator[**P, R](with_x: bool = False) -> collections.abc.Callable[[F[..., R]], F[..., R]]:
    def outer(wrapped: F[..., R]) -> F[..., R]:
        @functools.wraps(wrapped)
        def inner(*args: P.args, **kwargs: P.kwargs) -> R:
            # This inner() function can actually be quite complex!

            f = 3.14  # `f` is passed to *all* decorated functions.
            if with_x:
                x = 1  # `x` is passed only if `with_x` is truthy.
                return wrapped(x, f, *args, **kwargs)
            return wrapped(f, *args, **kwargs)

        return inner
    return outer


@decorator(with_x=True)
def fun1(x: int, f: float, s: str) -> str:
    return s.lower()


@decorator(with_x=False)
def fun2(f: float, s: str, d: dict[str, typing.Any]) -> str:
    return s.lower() + "".join(d.keys())

This code fails with error: ParamSpec "P" is unbound, and it’s also missing actual types for the call parameters (the ... ellipsis).

As shown, depending on the decorator argument with_x, the decorator function inserts either the x and f args, or just f, into the call to the decorated function. The remaining args/kwargs of the decorated function can be anything.

Thank you heaps!

PS: This seems somewhat related: Type Hinting with Decorator Removing First Argument

You can use typing.overload to create functions whose types depend on the arguments, typing.Literal to make them depend on the literal values True and False, and typing.Concatenate to add/remove the extra parameters of a function (where you’re currently using an ellipsis). For the inner function, it’s simpler if the two implementations are kept separate. So the example below:

from typing import Callable, overload, Literal, Concatenate, reveal_type, Any
from functools import wraps

type FillOne[T, **P, R] = Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]
type FillTwo[T, U, **P, R] = Callable[[Callable[Concatenate[T, U, P], R]], Callable[P, R]]

@overload
def decorator[**P, R](with_x: Literal[False] = False) -> FillOne[float, P, R]:
    ...

@overload
def decorator[**P, R](with_x: Literal[True]) -> FillTwo[int, float, P, R]:
    ...

def decorator[**P, R](with_x: bool = False) -> FillOne[float, P, R] | FillTwo[int, float, P, R]:
    f = 3.14
    x = 1
    if with_x:
        def pad_two(wrapped: Callable[Concatenate[int, float, P], R]) -> Callable[P, R]:
            @wraps(wrapped)
            def inner(*args: P.args, **kwargs: P.kwargs) -> R:
                return wrapped(x, f, *args, **kwargs)
            return inner
        return pad_two
    
    def pad_one(wrapped: Callable[Concatenate[float, P], R]) -> Callable[P, R]:
        @wraps(wrapped)
        def inner(*args: P.args, **kwargs: P.kwargs) -> R:
            return wrapped(f, *args, **kwargs)
        return inner

    return pad_one


@decorator(with_x=True)
def fun1(x: int, f: float, s: str) -> str:
    return s.lower()


@decorator(with_x=False)
def fun2(f: float, s: str, d: dict[str, Any]) -> str:
    return s.lower() + "".join(d.keys())

reveal_type(fun1)
reveal_type(fun2)

when checked with pyright, gives:

Type of "fun1" is "(s: str) -> str"
Type of "fun2" is "(s: str, d: dict[str, Any]) -> str"

Try it in pyright online

If you want to require specific argument names (x and f; the code above only requires arguments of a specific type, with any name), you can use a typing.Protocol in place of typing.Callable, but it’s a bit more complicated.

Thank you @alwaysmpe, using two different outer() functions is an interesting suggestion!

The only caveat with the suggestion is that in my case the inner() function contains quite a lot of code which I would then have to either duplicate or pop into a shared helper function/generator.

Which raises the question: how much code refactoring is sensible just to please the static type checker?

My background is C++, so you might be asking the wrong person, but I think you can do it with a single function: you just need to give pad_one and pad_two the same name and make them separate overloads.

The online checker worked, but mypy gives me a whole bunch of errors. I shall tinker some more!

Thank you, I’ll try that and see what happens!

It looks like mypy isn’t happy resolving the type aliases in the overload signatures. Spelling out the overload return types in full works:

from typing import Callable, overload, Literal, Concatenate, reveal_type, Any
from functools import wraps

type FillOne[T, **P, R] = Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]
type FillTwo[T, U, **P, R] = Callable[[Callable[Concatenate[T, U, P], R]], Callable[P, R]]

@overload
def decorator[**P, R](with_x: Literal[False] = False) -> Callable[[Callable[Concatenate[float, P], R]], Callable[P, R]]:
    ...

@overload
def decorator[**P, R](with_x: Literal[True]) -> Callable[[Callable[Concatenate[int, float, P], R]], Callable[P, R]]:
    ...

def decorator[**P, R](with_x: bool = False) -> FillTwo[int, float, P, R] | FillOne[float, P, R]:
    f = 3.14
    x = 1
    if with_x:
        def pad_two(wrapped: Callable[Concatenate[int, float, P], R]) -> Callable[P, R]:
            @wraps(wrapped)
            def inner(*args: P.args, **kwargs: P.kwargs) -> R:
                return wrapped(x, f, *args, **kwargs)
            return inner
        return pad_two
    
    def pad_one(wrapped: Callable[Concatenate[float, P], R]) -> Callable[P, R]:
        @wraps(wrapped)
        def inner(*args: P.args, **kwargs: P.kwargs) -> R:
            return wrapped(f, *args, **kwargs)
        return inner

    return pad_one


@decorator(with_x=True)
def fun1(x: int, f: float, s: str) -> str:
    return s.lower()


@decorator(with_x=False)
def fun2(f: float, s: str, d: dict[str, Any]) -> str:
    return s.lower() + "".join(d.keys())

reveal_type(fun1)
reveal_type(fun2)

And it works on mypy online. I should probably raise an issue on the mypy GitHub…

I did that, commented out the @wraps decorator, and then cast the wrapped function like so:

from typing import Callable, overload, Literal, Concatenate, reveal_type, Any, cast
from functools import wraps

type FillOne[T, **P, R] = Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]
type FillTwo[T, U, **P, R] = Callable[[Callable[Concatenate[T, U, P], R]], Callable[P, R]]

type FunOne[**P, R] = Callable[Concatenate[float, P], R]
type FunTwo[**P, R] = Callable[Concatenate[int, float, P], R]

@overload
def decorator[**P, R](with_x: Literal[False] = False) -> Callable[[FunOne[P, R]], Callable[P, R]]:
    ...

@overload
def decorator[**P, R](with_x: Literal[True]) -> Callable[[FunTwo[P, R]], Callable[P, R]]:
    ...

def decorator[**P, R](with_x: bool = False) -> FillTwo[int, float, P, R] | FillOne[float, P, R]:
    f = 3.14
    x = 1

    @overload
    def outer(wrapped: FunOne[P, R]) -> Callable[P, R]:
        ...

    @overload
    def outer(wrapped: FunTwo[P, R]) -> Callable[P, R]:
        ...

    def outer(wrapped):  # type: ignore[no-untyped-def]
        # @wraps(wrapped)
        def inner(*args: P.args, **kwargs: P.kwargs) -> R:
            if with_x:
                fillTwoFunc = cast(FunTwo[P, R], wrapped)
                return fillTwoFunc(x, f, *args, **kwargs)
            fillOneFunc = cast(FunOne[P, R], wrapped)
            return fillOneFunc(f, *args, **kwargs)
        return inner
    return outer

@decorator(with_x=True)
def fun1(x: int, f: float, s: str) -> str:
    return s.lower()

@decorator(with_x=False)
def fun2(f: float, s: str, d: dict[str, Any]) -> str:
    return s.lower() + "".join(d.keys())

I am somewhat puzzled that I must annotate the signature of the implementation of the @overload function, considering the docs say otherwise.