An endofunction, or self-map, is a function whose domain and codomain are identical. We may declare a generic class with a method that is always an endofunction on the generic type, as follows:
class Transform[T]:
    def apply(self, x: T) -> T: ...
It is clear that, no matter what T is bound to, t.apply(x) should return the type of x, assuming of course that x is compatible with T. Currently, type checkers infer Any/Unknown in this scenario (pyright playground, mypy-playground).
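The scenario can be reproduced with a small runnable sketch. It uses the older TypeVar/Generic spelling, which is equivalent to class Transform[T] but also runs on interpreters older than Python 3.12. The identity body of apply is an assumption made only so the code executes; the original leaves the body elided.

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class Transform(Generic[T]):
    """Older spelling of `class Transform[T]`."""

    def apply(self, x: T) -> T:
        # Identity body, assumed here only to make the sketch runnable.
        return x


# Nothing in this call lets a type checker solve T for the instance.
t = Transform()

# At runtime the result is plainly an int. Statically, the post reports
# that checkers fall back to Any/Unknown for a call like this one.
result = t.apply(42)
```

Running the file raises no error, which is the point: the loss of information is purely static.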
A real-world use case where this behavior would make a lot of sense is data transformations, such as scikit-learn dataset transformations. Even with type annotations, such classes often fail to bind T at initialization time, because they usually need to see some example data before all parameters are determined. That happens only during the call to the .fit(data) method, but by then it is too late to bind T.
Yet, it would often be useful to retain the endomorphic character of these classes, e.g. if you feed in a pandas.DataFrame you get back a pandas.DataFrame, and if you feed in a numpy.ndarray you get back a numpy.ndarray, instead of coercing to Any.
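A minimal sketch of this fit-then-transform pattern, using only the standard library. MeanCenterer is a hypothetical class invented for illustration and is not part of scikit-learn; lists and tuples stand in for DataFrames and arrays. A type checker may well flag the type(x)(...) call; the sketch is about runtime behavior.

```python
from statistics import fmean
from typing import Generic, Optional, Sequence, TypeVar

T = TypeVar("T", bound=Sequence[float])


class MeanCenterer(Generic[T]):
    """Hypothetical transformer: subtracts the mean seen during fit()."""

    def __init__(self) -> None:
        # No data is seen here, so nothing constrains T yet.
        self.mean: Optional[float] = None

    def fit(self, data: T) -> "MeanCenterer[T]":
        # The parameter is only determined once example data arrives.
        self.mean = fmean(data)
        return self

    def transform(self, x: T) -> T:
        if self.mean is None:
            raise RuntimeError("call fit() before transform()")
        # Rebuild the same container type that was passed in
        # (works at runtime for list and tuple).
        return type(x)(v - self.mean for v in x)


centerer = MeanCenterer()        # T is unsolved at this point
centerer.fit([1.0, 2.0, 3.0])    # the mean becomes 2.0, but T stays unsolved
as_list = centerer.transform([1.0, 2.0, 3.0])
as_tuple = centerer.transform((4.0, 6.0))
```

At runtime a list goes in and a list comes out, and likewise for the tuple. That per-call relationship is what the annotations would ideally preserve.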
Idea: If a class-bound type variable is unresolved, then when a method is called, treat the type variable as if it were bound to the method instead.
An example
import numpy as np
from dataclasses import dataclass
from typing import assert_type
@dataclass
class Scaler[T]:
    factor: float | T
    def apply(self, x: T) -> T:
        return self.factor * x
# case ①: T is bound at initialization time
factors = np.array([1.0, 3.6])
array_scaler = Scaler(factors)
assert_type(array_scaler, Scaler[np.ndarray])
# transform some data
x = np.array([[1, 2], [3, 4], [5, 6]])  # 3×2
y = array_scaler.apply(x)
assert_type(y, np.ndarray)  # ✅
# case ②: T remains unbound
generic_scaler = Scaler(2.5)
y1 = generic_scaler.apply(x)
assert_type(y1, np.ndarray) # ❌
y2 = generic_scaler.apply(3.8)
assert_type(y2, float) # ❌
In case ②, the type checker could treat the class as if it had been defined like this instead:
@dataclass
class Scaler:
    factor: float | Any  # or bound[T] if applicable
    def apply[T](self, x: T) -> T:
        return self.factor * x
This would give the desired results.
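Until type checkers do something along these lines automatically, the same rewrite can be done by hand. The sketch below is one way to do it under simplifying assumptions: factor is a plain float, and the method-scoped type variable N (a name introduced here) is constrained to float and complex so that the multiplication type-checks. It uses the older TypeVar spelling so that it also runs before Python 3.12.

```python
from dataclasses import dataclass
from typing import TypeVar

# Type variable used only by the method, not by the class.
N = TypeVar("N", float, complex)


@dataclass
class Scaler:
    factor: float

    def apply(self, x: N) -> N:
        # N is solved afresh at every call site.
        return self.factor * x


scaler = Scaler(2.5)
y_float = scaler.apply(4.0)        # checkers infer float
y_complex = scaler.apply(1 + 2j)   # checkers infer complex
```

The trade-off is that this version gives up case ①: the class is no longer generic, so factor cannot share a type with the input. The proposal would keep both behaviors in a single class.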