Sum Types in the stdlib

Hi all,

This idea is motivated by my earlier question here.

At a high level, the idea is to make something like this “work”: the case statements should match, and type checkers should be able to infer that the match is exhaustive.

Currently, the example below runs correctly, but static type checkers cannot understand it: they report that the case patterns will never match, because nothing in the program suggests that the two @dataclasses defined inside Shape are supposed to be Shape’s variants.

from dataclasses import dataclass

class Shape:
    @dataclass
    class Square:
        center: tuple[float, float]
        length: float

    @dataclass
    class Circle:
        center: tuple[float, float]
        radius: float


def area(s: Shape):
    match s:
        case Shape.Square(center, length):   # type checkers claim that
            return length * length           # these patterns will never match!
        case Shape.Circle(center, radius):
            return 3.14 * radius * radius

We can get pretty close to the desired behavior by using this: the match statement works, and (most?) type checkers are able to verify that the match is exhaustive.

from dataclasses import dataclass

@dataclass
class Square:
    center: tuple[float, float]
    length: float

@dataclass
class Circle:
    center: tuple[float, float]
    radius: float

type Shape = Square | Circle

def area(s: Shape):
    match s:
        case Square(center, length):
            return length * length
        case Circle(center, radius):
            return 3.14 * radius * radius

The issue with this approach — which is what I’m proposing to fix — is that Square and Circle are not well encapsulated.

In the following example, Square is redefined for PaymentEndpoint, rendering the first case statement in the area(s) function incorrect.

from dataclasses import dataclass

@dataclass
class Square:
    center: tuple[float, float]
    length: float

@dataclass
class Circle:
    center: tuple[float, float]
    radius: float

type Shape = Square | Circle

@dataclass
class Square:
    token: ...

@dataclass
class Stripe:
    token: ...

type PaymentEndpoint = Square | Stripe

def area(s: Shape):
    match s:
        case Square(center, length): ## <-- uh oh!
            return length * length
        case Circle(center, radius):
            return 3.14 * radius * radius

This example is a little contrived: the problem goes away if Shape and PaymentEndpoint were placed in different files. And they probably should be! But it’s not hard to imagine(*) cases where it’s convenient to have multiple union types defined in the same file with the same variant names; these use cases are not covered right now.

Another, more abstract downside: the construct type U = X | Y only informs U that it is related to X and Y; it doesn’t tell X (or Y) that they are related to U.

Proposal. The concrete proposal is to introduce a new marker (kind of like dataclass): with the decorator, the classes nested inside the decorated class become its variants. Thus, a one- or two-line change to the first example would make it work:

from dataclasses import dataclass
from dataclasses import marker    # <- new!

@marker                           # <- new!
class Shape:
    @dataclass
    class Square:
        center: tuple[float, float]
        length: float

    @dataclass
    class Circle:
        center: tuple[float, float]
        radius: float

# And now this works:
def area(s: Shape):
    match s:
        case Shape.Square(center, length):   # type checkers now accept these patterns
            return length * length
        case Shape.Circle(center, radius):
            return 3.14 * radius * radius

(marker is deliberately a poor name — just want to focus on the idea/interface instead of the proper naming.)

Curious to hear what others think. Thanks all!

For comparison and prior art, here is how Rust does it – I think it does a good job.

enum Shape {
    Square { center: (f64, f64), length: f64 },
    Circle { center: (f64, f64), radius: f64 },
}

enum PaymentEndpoint {
    Square { token: String },
    Stripe { token: String },
}

fn area(s: &Shape) -> f64 {
    match s {
        Shape::Circle{ center, radius } => 3.14*radius*radius,
        Shape::Square{ center, length } => length * length
    }
}

Some related previous discussions:

As you noticed, it’s perfectly possible to implement stuff like this with metaprogramming (see my POC for even more extreme variants). I would much prefer if there is an effort for a generic way to teach type checkers about such dynamic constructs so that not every possible variant needs to be implemented in the stdlib.

Thanks all - those are useful references.

I would much prefer if there is an effort for a generic way to teach type checkers about such dynamic constructs so that not every possible variant needs to be implemented in the stdlib.

Yes and no. I agree that being able to “teach” type checkers about dynamic constructs would be very useful.

On the other hand, I think (just my opinion!) sum types are sufficiently generic and ubiquitous to warrant treatment in the stdlib. They are the counterpart to @dataclass (product types), a great addition to the language that was provided by third-party libraries (e.g. attrs) in the past.

Yes; are there semi-widespread public projects that implement sum types the same way attrs implemented product types?

Ideally there would be multiple such projects so that different approaches can be compared. Putting something into the stdlib is setting it in stone, so ideally we can draw on already existing designs.

So this kind of thing - where you want a sum type namespaced inside its abstract parent - is actually already expressible in a static-type-checking-compatible way today, like so. FunctionInfo is the parent/namespace, and then ResidualJac etc. are the variants of the sum type.

It’s not exactly the cleanest syntax in the world, but it does work.

Thanks Patrick - that’s a neat snippet (TIL about ClassVar).

However, I tried the following but it doesn’t seem to work completely.

from dataclasses import dataclass
from typing import ClassVar

class Shape:
    Square: ClassVar[type["Square"]]

@dataclass
class Square(Shape):
    center: tuple[float, float]
    length: float

Shape.Square = Square
Square.__qualname__ = "Shape.Square"


def area(s: Shape):
    match s:
        case Shape.Square(_, length):
            return length * length

In this example, when I hover over the s in match s, my IDE reports a warning that it cannot determine that the match is exhaustive.

Cases within match statement do not exhaustively handle all values
  Unhandled type: "Shape"
  If exhaustive handling is not intended, add "case _: pass" basedpyright(reportMatchNotExhaustive)
(parameter) s: Shape

I think there’s no way to get exhaustive matching using inheritance with the tools we currently have. Anyone could define an additional subclass at any time, right?

In the spirit of “consenting adults”, you are right: there are no language features to prevent subclassing and thus guarantee that static analysis could verify exhaustive checks. (And I like it this way!)

But insofar as this @marker is a hint/directive to the type checker, it could imply @typing.final, so that

  • type checkers may assume there are no subclasses and that all possible variants are in the @marker’d class
  • type checkers may warn if they see a marker class being subclassed.
  1. The phrase “consenting adults” here is misapplied.
  2. The overall sentiment is incorrect. You can ensure at runtime, not just via incomplete static analysis, that two or more types are disjoint, and in multiple ways: from raising in __init_subclass__ to ensuring they have incompatible slots.

Nothing prevents this from being possible, and from being type-checkable, with what currently exists; type checkers just need to understand what already exists better.
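
As a rough sketch of the __init_subclass__ route mentioned above: the `_sealed` flag is a hypothetical name chosen for illustration, and this is only one of several ways to close a hierarchy at runtime.

```python
from dataclasses import dataclass


class Shape:
    _sealed = False  # hypothetical flag, flipped once all variants exist

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if Shape._sealed:
            raise TypeError(f"Shape is sealed; cannot define {cls.__name__}")


@dataclass
class Square(Shape):
    center: tuple[float, float]
    length: float


@dataclass
class Circle(Shape):
    center: tuple[float, float]
    radius: float


Shape._sealed = True  # from here on, new direct or indirect subclasses raise
```

Because __init_subclass__ is inherited, attempting class Triangle(Shape) or class BigSquare(Square) after the flag is set raises TypeError at class-definition time.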

In the interest of showing a productive answer:

from dataclasses import dataclass
from typing import final


@final
@dataclass
class Square:
    center: tuple[float, float]
    length: float


@final
@dataclass
class Circle:
    center: tuple[float, float]
    radius: float

type Shape = Square | Circle

def area(s: Shape) -> float:
    match s:
        case Square(center, length):
            return length * length
        case Circle(center, radius):
            return 3.14 * radius * radius

It’s written differently from Rust’s enum types: this is just a union that carries enough information to match on, but you can do what you seem to want with it.
