I have a class that acts sort of like a callable and am trying to figure out how to add type annotations. The API looks something like this (inspired by PyTorch's nn.Module):
```python
import abc
from typing import Generic, ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T")

class FuncWrapper(Generic[P, T]):
    @abc.abstractmethod
    def forward(self, *args: P.args, **kwargs: P.kwargs) -> T:
        ...  # User-defined function

    def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
        # some class-defined wrapper with the same type hints
        return self.forward(*args, **kwargs)
```
Importantly, forward is implemented by subclassing FuncWrapper – it cannot just be passed in (because it needs attributes of self). The above type hints do seem to work (you can subclass FuncWrapper[[int, bool], str]), but they come with a couple of limitations:
Redundancy: I have to specify the type hints twice, once in the implementation of forward and once when subclassing. Ideally, I'd be able to just say something like "use the ParamSpec + TypeVar from the forward implementation".
AFAICT, there is no way to specify **kwargs in the Generic type hints; I can only specify positional-only types.
Is there a better way to type-hint such an API? The first limitation is “just” annoying, but the 2nd limitation is a big problem for us since we do need to support keyword arguments in the API.
Probably not, but without specific details that would support this, I'm not convinced that requiring users to subclass for this is the best approach.
Consider:
```python
from typing import Callable, Generic, ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T")

class FuncWrapper(Generic[P, T]):
    def __init__(self, forward: Callable[P, T]):
        self.forward: Callable[P, T] = forward

    def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> T:
        return self.forward(*args, **kwargs)
```
(This also ends up usable as a decorator as-is)
Would need more details about your intended use, and why you are wrapping user behavior with a level of indirection, to recommend anything more specific, though.
Of course, when the function doesn’t need access to self, then we can special-case those with a decorator (and we do internally in our code), but we have a lot of code that can’t be expressed like this.
Maybe there is another way to design this API, but this is deeply embedded into a lot of the libraries we already use (PyTorch's nn.Module works like this, and I'd note that AFAICT all the other deep learning libraries like TensorFlow or Keras have basically isomorphic APIs), so changing the API is pretty much impossible from my standpoint.
FWIW, the status quo that PyTorch seems to use is something like:
```python
class Module:
    forward: Callable[..., Any]
```
which basically gives up any type hinting on it, although it has the benefit of not seeing a bunch of errors about subclasses changing the signature of forward.
That doesn't seem like a problem at all: just pass self into the function and make that part of the API contract, i.e. use Concatenate to mandate that the first parameter of the decorated/passed-in function is a self reference.
It is slightly less ergonomic, since you will need to provide an annotation for self, but that seems okay to me. I wouldn't really design an API in this way to begin with, but if I had to, this would seem like a reasonable way to keep it reasonably strongly typed.
I see… you're right that it does work, although it does seem a little ugly (and feels like more boilerplate than manually specifying P when kwargs aren't involved). I think a closer example would look like:
(There is sometimes state associated with each instance of FuncWrapper, so we cannot just "hide" the constructor behind the decorator in general.)
I was hoping I could make forward a ClassVar or something to avoid having to manually "link" the forward function and the class instance… but apparently generic ClassVars are disallowed, so that doesn't work. I could also just have an alternative constructor that does the binding automatically, I guess?
This also doesn't really solve the problem for torch.nn.Module, where I can't reasonably change the API like this without rewriting years' worth of code. Although from what I'm gathering, the original API style I asked about (which nn.Module uses) just isn't possible to type today? I assume that's the case, since both answers are basically variants of "change the API".
Out of curiosity, how else would you design the nn.Module use case? I feel like subclassing is very natural for this as a user (and presumably why all the deep learning libraries have converged to a similar design). I could imagine alternatives (e.g. avoid state by functionalizing the weights), but they all seem significantly less ergonomic to me as a user.