Type Inference with Generic Endofunctions

An endofunction, or self-map, is a function whose domain and codomain are identical. We may declare a generic class with a method that is always an endofunction on the generic type, as follows:

class Transform[T]:
    def apply(self, x: T) -> T: ...

It is clear that, no matter what T is bound to, t.apply(x) should return the type of x, assuming of course that x is compatible with T. Currently, type-checkers infer Any/Unknown in this scenario (pyright playground, mypy-playground).

A real-world use case where this behavior would make a lot of sense is data transformations, such as scikit-learn dataset transformations. Even with type annotations, the type variable of such classes often cannot be bound at initialization time, because they usually need to see some example data before all parameters are determined. That happens only during the call to the .fit(data) method, and by then it is too late to bind T.

Yet, it would often be useful to retain the endomorphic character of these classes, e.g. if you feed in a pandas.DataFrame you get back a pandas.DataFrame, and if you feed in a numpy.ndarray you get back a numpy.ndarray, instead of coercing to Any.

Idea: If a class-bound type var is unresolved, then when calling a method, treat the type var as if it were scoped to the method instead.

An example
import numpy as np
from dataclasses import dataclass
from typing import assert_type

@dataclass
class Scaler[T]:
    factor: float | T

    def apply(self, x: T) -> T:
        return self.factor * x

# case ①: T is bound at initialization time
factors = np.array([1.0, 3.6])
array_scaler = Scaler(factors)
assert_type(array_scaler, Scaler[np.ndarray])

# transform some data
x = np.array([[1, 2], [3, 4], [5, 6]])  # 3×2
y = array_scaler.apply(x)
assert_type(y, np.ndarray)  # ✅

# case ②: T remains unbound
generic_scaler = Scaler(2.5)
y1 = generic_scaler.apply(x)
assert_type(y1, np.ndarray)  # ❌
y2 = generic_scaler.apply(3.8)
assert_type(y2, float)  # ❌

In case ②, the type checker could behave as if the class were defined like this instead:

from typing import Any

@dataclass
class Scaler:
    factor: float | Any  # or bound[T] if applicable

    def apply[T](self, x: T) -> T:
        return self.factor * x

This would give the desired results.

No, that is not what the type describes; see this example:

from typing import TypeVar, Generic, reveal_type

T = TypeVar("T")

class Transform(Generic[T]):
    def __init__(self, value: T):
        self.value = value
    def apply(self, x: T) -> T: return self.value


class A: pass
class B(A): pass

t = Transform[A](A())
x: B = B()
reveal_type(t.apply(x))  # must be `A`; `B` would be wrong.

Any (or Unknown) is the only valid choice as a default for T.


In your example, T is bound to A; I am talking about the case when T remains unresolved.

If a class-bound type var is unresolved…

The class-bound type var is not unresolved in your example. It is resolved to Any at construction time. Class type variables must be resolved at the time a class instance is constructed. If the constructor call doesn’t determine the value, then the type variable takes on its default value. If no default value is specified for the type parameter, it is implicitly Any (for TypeVars), ... (for ParamSpecs), or *tuple[Any, ...] (for TypeVarTuples). That means the type of Transform() is Transform[Any].

It sounds like what you want here is a method-scoped type variable.

Code sample in pyright playground

class Transform:
    def apply[T](self, x: T) -> T: return x

t = Transform()
x: int = 2
reveal_type(t.apply(x)) # int

Ah, then sadly it is not possible. The type-checker would have to remember that the fallback was used.

What I want is not a method-scoped type-var per se, but rather to fall back to method scoping if the type-var wasn’t resolved. I thought my example was instructive, but let me elaborate further, then:

When we initialize this class as Scaler(np.array([1, 2, 3])), you would expect that it can be applied to numpy arrays. If we instead initialize Scaler(torch.tensor([1, 2, 3])), we should be able to feed it torch tensors, and so on. All we need is some tensor type T that supports self-multiplication (i.e. __mul__(self, other: Self) -> Self).

On the other hand, if we initialize Scaler(2.4), then the resulting object is universal in that it can be used with any tensor-type that supports float-multiplication (i.e. __mul__(self, other: float) -> Self). Most tensor libraries support both.

One solution would be to have two classes instead, but that is annoying and produces a lot of code duplication:

from dataclasses import dataclass
from typing import overload

@dataclass
class UniversalScaler:
    factor: float
    def apply[T](self, x: T) -> T:
        return self.factor * x

@dataclass
class SpecializedScaler[T]:
    factor: T
    def apply(self, x: T) -> T:
        return self.factor * x

@overload
def make_scaler(factor: float) -> UniversalScaler: ...
@overload
def make_scaler[T](factor: T) -> SpecializedScaler[T]: ...
def make_scaler(factor):
    if isinstance(factor, float):
        return UniversalScaler(factor)
    return SpecializedScaler(factor)

With this proposal, the single Scaler class would capture both behaviors.

Code sample in pyright playground

from dataclasses import dataclass
from typing import Protocol, Self, assert_type, reveal_type

class TensorProto(Protocol):
    def __mul__(self, other: float | Self) -> Self: ...
    def __rmul__(self, other: float | Self) -> Self:
        return self.__mul__(other)

@dataclass
class A(TensorProto):
    values: list[float]
    
    def __mul__(self, other: float | Self) -> Self:
        cls = type(self)
        if isinstance(other, int | float):
            return cls([x * other for x in self.values])
        return cls([x * y for x, y in zip(self.values, other.values)])

@dataclass
class B(TensorProto):
    values: list[float]

    def __mul__(self, other: float | Self) -> Self:
        cls = type(self)
        if isinstance(other, int | float):
            return cls([x * other for x in self.values])
        return cls([x * y for x, y in zip(self.values, other.values)])

@dataclass
class Scaler[T: TensorProto]:
    factor: float | T

    def apply(self, x: T) -> T:
        return self.factor * x

x = A([1.0, 2.0, 3.0])
y = B([1.0, 2.0, 3.0])

specialized_to_A = Scaler(x)
assert_type(specialized_to_A, Scaler[A])  # ✅
assert_type(specialized_to_A.apply(x), A)  # ✅

specialized_to_B = Scaler(y)
assert_type(specialized_to_B, Scaler[B])  # ✅
assert_type(specialized_to_B.apply(y), B)  # ✅

generic_scaler = Scaler(2.5)
reveal_type(generic_scaler)  # Scaler[Unknown]
reveal_type(generic_scaler.apply(x))  # Unknown (desired inference: A)
reveal_type(generic_scaler.apply(y))  # Unknown (desired inference: B)

Essentially, it’s about getting crisper type inference in the unresolved case, while maintaining the behavior in the specialized case.

I think your problem is a more general problem that needs a more complicated solution. As the Array API takes off, what you really want is a variety of functions of the form:

def f(x: Array[T], y: Array[U]) -> Array[Promoted[T, U]]: ...

Where T and U capture the Array API provider and dtype, the promotion verifies that both arrays come from the same Array API provider or else complains, and Promoted returns the promoted type according to the rules specified by the Array API provider. (Each provider can have its own types and its own promotion rules.)

I haven’t seen any proposals to implement anything like this.

However, it would be really nice to at least check that the arrays come from the same provider since that’s often a runtime error.

That’s an interesting aspect, but it goes in a somewhat different direction, since the data transformations I have in mind are mostly unary operations that preserve the scalar type (standardization, min-max scaling, clipping, etc.).

I don’t see how the transform in your question is “unary”, since it takes a value of one type T and combines it in apply with values of a possibly different type U, and then you do need to promote. If it were a truly unary operation, then a simple unbound type variable would solve your problem, right?

I actually think I found a solution via overloads:

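from dataclasses import dataclass
from typing import overload

# TensorProto as defined in the previous sample

@dataclass
class Scaler[T: TensorProto | float]:
    factor: T

    @overload
    def apply[S: TensorProto](self: "Scaler[float]", x: S) -> S: ...
    @overload
    def apply(self, x: T) -> T: ...
    def apply(self, x):
        return self.factor * x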

Code sample in pyright playground

EDIT: Unfortunately, this only works easily when the class has only one parameter (Code sample in pyright playground); otherwise, one has to write a full suite of __new__ overloads as well.