Dataclasses: Factories for a field's repr, comparison, hash, and more

TL;DR

Add the following parameters to the dataclasses.field (T is field type):

  • repr_factory: Callable[[T], Any]
  • hash_key_factory: Callable[[T], Hashable]
  • equals_key_factory: Callable[[T], R], where R is safely comparable
  • equals_fn: Callable[[T, T], bool] (an alternative to equals_key_factory)
  • order_key_factory: Callable[[T], R], where R supports ordering

When both a factory and its corresponding non-factory parameter are given, an error is raised.

The Problem

The problem arises when you want to store, inside a dataclass, values of types you have no control over.
Yes, you can either write your own magic methods (defeating the purpose of dataclasses)
or create new container classes that implement these methods (which is not always possible).

Adding these new factory parameters will make dataclasses work for you.

The Story

Let’s assume we construct a dataclass which has some unhashable fields.
NumPy arrays are a good example of that.
We can guarantee that these fields will not be changed on any existing instance, and declare our class as frozen.
Furthermore, we don’t know how to sort them (yet), so let’s set order=False.

from numbers import Number
from dataclasses import dataclass, field
from typing import *

from numpy import array, ndarray, dtype


# Alias for a NumPy array whose shape is annotated as ``Tuple[int]`` (one
# dimension) and whose element dtype is parameterized by ``N``.
type Array[N: Number] = ndarray[Tuple[int], dtype[N]]
def Vector[N: Number](*items: N) -> Array[N]:
    """Return a NumPy array built from the positional *items*."""
    return array(items)

# Zero-argument factory for field defaults: each call builds a new one-element vector.
UNIT_VECTOR_FACTORY = lambda: Vector(1)

@dataclass(frozen=True, order=False)
class TwoVector:
    """Frozen pair of NumPy vectors relying solely on dataclass-generated methods.

    Nothing is customized here, so the generated ``__hash__`` hashes the raw
    arrays and fails, because ``numpy.ndarray`` is unhashable.
    """
    # Each field defaults to a fresh unit vector.
    first: Array = field(default_factory=UNIT_VECTOR_FACTORY)
    second: Array = field(default_factory=UNIT_VECTOR_FACTORY)
A long, long story...

Anger

However, this would raise an error when you call hash(TwoVector()) as numpy.ndarray is not hashable:

Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm Community Edition 2022.2.2\plugins\python-ce\helpers\pydev\pydevconsole.py", line 364, in runcode
    coro = func()
           ^^^^^^
  File "<input>", line 1, in <module>
  File "<string>", line 3, in __hash__
TypeError: unhashable type: 'numpy.ndarray'

Bargaining

Okay, let’s add a __hash__() method.
And, while we’re here, let’s also add __eq__() and a fancy __repr__():

Note: Here and below, the imports and type alias section will repeat, mostly without any changes.

from numbers import Number
from dataclasses import dataclass, field
from typing import *

from numpy import array, ndarray, dtype, array_equal


type Array[N: Number] = ndarray[Tuple[int], dtype[N]]
def Vector[N: Number](*items: N) -> Array[N]:
    return array(items)

UNIT_VECTOR_FACTORY = lambda: Vector(1)

def vector_repr(v: Array, /, *, sep: str = ', ') -> str:
    return f'{{ {sep.join(map(str, v))} }}'

@dataclass(frozen=True, order=False)
class TwoVector:
    first: Array = field(default_factory=UNIT_VECTOR_FACTORY)
    second: Array = field(default_factory=UNIT_VECTOR_FACTORY)
    
    def __hash__(self):
        return hash((tuple(self.first), tuple(self.second)))
    
    def __eq__(self, other) -> bool:
        if (isinstance(other, TwoVector) and other.__class__ is self.__class__):
            return array_equal(self.first, other.first) and array_equal(self.second, other.second)
        raise NotImplementedError()
    
    def __repr__(self):
        return f"{self.__class__.__name__}(first={vector_repr(self.first, sep='; ')}, second={vector_repr(self.first, sep=', ')})"

Okay, now that works. Furthermore, the hash is now consistent even when we create different instances with the same values:

>>> hash(TwoVector())
8946237269106090120
>>> hash(TwoVector())
8946237269106090120
>>> hash(TwoVector(Vector(1, 2), Vector(2, 3)))
8374439192799522889
>>> hash(TwoVector(Vector(1, 2), Vector(2, 3)))
8374439192799522889
>>> TwoVector(Vector(1, 2), Vector(2, 3)) == TwoVector(Vector(1, 2), Vector(2, 3))
True
>>> TwoVector(Vector(1, 2), Vector(2, 3)) == TwoVector()
False
>>> TwoVector()
TwoVector(first={ 1 }, second={ 1 })
>>> TwoVector(Vector(1, 2), Vector(2, 3))
TwoVector(first={ 1; 2 }, second={ 2, 3 })

Even the __repr__ works as intended!
Now let’s add a new field – name:

from numbers import Number
from dataclasses import dataclass, field
from typing import *

from numpy import array, ndarray, dtype, array_equal


type Array[N: Number] = ndarray[Tuple[int], dtype[N]]
def Vector[N: Number](*items: N) -> Array[N]:
    return array(items)

UNIT_VECTOR_FACTORY = lambda: Vector(1)

def vector_repr(v: Array, /, *, sep: str = ', ') -> str:
    return f'{{ {sep.join(map(str, v))} }}'

@dataclass(frozen=True, order=False)
class TwoVector:
    first: Array = field(default_factory=UNIT_VECTOR_FACTORY)
    second: Array = field(default_factory=UNIT_VECTOR_FACTORY)
    name: str = field(kw_only=True, default="New TwoVector")
    
    def __hash__(self):
        return hash((tuple(self.first), tuple(self.second)))
    
    def __eq__(self, other) -> bool:
        if (isinstance(other, TwoVector) and other.__class__ is self.__class__):
            return array_equal(self.first, other.first) and array_equal(self.second, other.second)
        raise NotImplementedError()
    
    def __repr__(self):
        return f"{self.__class__.__name__}(first={vector_repr(self.first, sep='; ')}, second={vector_repr(self.first, sep=', ')})"

Depression

And now all our methods are no longer valid, since they all ignore the new field, name:

>>> TwoVector()
TwoVector(first={ 1 }, second={ 1 })
>>> TwoVector(name='123') == TwoVector(name='myVector')
True

So every hand-written method has to be updated to take name into account:

from numbers import Number
from dataclasses import dataclass, field
from typing import *

from numpy import array, ndarray, dtype, array_equal


type Array[N: Number] = ndarray[Tuple[int], dtype[N]]
def Vector[N: Number](*items: N) -> Array[N]:
    return array(items)

UNIT_VECTOR_FACTORY = lambda: Vector(1)

def vector_repr(v: Array, /, *, sep: str = ', ') -> str:
    return f'{{ {sep.join(map(str, v))} }}'

@dataclass(frozen=True, order=False)
class TwoVector:
    first: Array = field(default_factory=UNIT_VECTOR_FACTORY)
    second: Array = field(default_factory=UNIT_VECTOR_FACTORY)
    name: str = field(kw_only=True, default="New TwoVector")
    
    def __hash__(self):
        return hash((tuple(self.first), tuple(self.second), self.name))
    
    def __eq__(self, other) -> bool:
        if (isinstance(other, TwoVector) and other.__class__ is self.__class__):
            return array_equal(self.first, other.first) and array_equal(self.second, other.second) and self.name == other.name
        raise NotImplementedError()
    
    def __repr__(self):
        return f"{self.__class__.__name__}(first={vector_repr(self.first, sep='; ')}, second={vector_repr(self.first, sep=', ')}, name={self.name!r})"

Acceptance

As you can see above, this diverges heavily from the original idea of dataclasses – automatic generation of the data-model magic methods –
as each time we add a new field, we must update all existing methods in all places.

However, there IS a workaround – instead of storing the numpy.ndarray objects themselves, store them inside a proxy class which implements all these magic methods:

from numbers import Number
from dataclasses import dataclass, field
from typing import *

from numpy import array, ndarray, dtype, array_equal


type Array[N: Number] = ndarray[Tuple[int], dtype[N]]

@dataclass(frozen=True, order=False, slots=True)
class Vector[N: Number]:
    _data: Array[N]
    
    def __init__(self, *items: N):
        object.__setattr__(self, '_data', array(items))
    
    def __hash__(self) -> int:
        return hash(tuple(self._data))
    
    def __format__(self, format_spec: str) -> str:
        prefix = ''
        sep = ', '
        pars = ('(', ')')
        if (not format_spec):
            format_spec = 'c,('
        
        for s in [ ',', ';', ':', ' ' ]:
            if (s in format_spec):
                sep = s.strip() + ' '
                break
        
        if ('c' in format_spec):
            prefix = self.__class__.__name__
            pars = ('(', ')')
        elif ('{' in format_spec):
            pars = ('{ ', ' }')
        
        return f'{prefix}{pars[0]}{sep.join(map(str, self._data))}{pars[1]}'
    
    def __repr__(self):
        return self.__format__('{;')
    
    def __eq__(self, other):
        if (isinstance(other, Vector) and other.__class__ is self.__class__):
            return array_equal(self._data, other._data)
        raise NotImplementedError()

# Zero-argument factory for field defaults: each call builds a new one-element Vector.
UNIT_VECTOR_FACTORY = lambda: Vector(1)

@dataclass(frozen=True, order=False, slots=True)
class TwoVector:
    """Frozen pair of ``Vector`` proxies plus a keyword-only ``name``.

    No magic methods are written by hand here; the fields are ``Vector``
    objects, which supply their own ``__hash__``, ``__eq__`` and ``__repr__``.
    """
    # Each vector field defaults to a fresh unit Vector.
    first: Vector = field(default_factory=UNIT_VECTOR_FACTORY)
    second: Vector = field(default_factory=UNIT_VECTOR_FACTORY)
    name: str = field(kw_only=True, default="New TwoVector")
>>> TwoVector()
TwoVector(first={ 1 }, second={ 1 }, name='New TwoVector')
>>> TwoVector(name='123') == TwoVector(name='myVector')
False
>>> TwoVector(Vector(1, 2), Vector(2, 3)) == TwoVector(Vector(1, 2), Vector(2, 3))
True

How this may look

from numbers import Number
from dataclasses import dataclass, field
from typing import *

from numpy import array, ndarray, dtype, array_equal


# Alias for a NumPy array whose shape is annotated as ``Tuple[int]`` (one
# dimension) and whose element dtype is parameterized by ``N``.
type Array[N: Number] = ndarray[Tuple[int], dtype[N]]
def Vector[N: Number](*items: N) -> Array[N]:
    """Return a NumPy array built from the positional *items*."""
    return array(items)

# Zero-argument factory for field defaults: each call builds a new one-element vector.
UNIT_VECTOR_FACTORY = lambda: Vector(1)

def vector_repr(v: Array, /, *, sep: str = ', ') -> str:
    """Render *v* as ``{ a<sep>b<sep>... }`` using ``str()`` of each element."""
    return f'{{ {sep.join(map(str, v))} }}'

@dataclass(frozen=True, order=False, slots=True)
class TwoVector:
    """Pair of NumPy vectors declared with the proposed per-field factories.

    NOTE: ``repr_factory``, ``equals_fn`` and ``hash_key_factory`` are the
    parameters proposed in this document; they are not accepted by the
    existing ``dataclasses.field``, so this snippet is illustrative only.
    """
    # The fields hold plain arrays (``Vector`` is only the constructor helper),
    # and ``hash_key_factory`` returns a hashable key, so ``tuple`` is enough.
    first: Array = field(
        default_factory=UNIT_VECTOR_FACTORY,
        repr_factory=vector_repr,
        equals_fn=array_equal,
        hash_key_factory=tuple,
    )
    second: Array = field(
        default_factory=UNIT_VECTOR_FACTORY,
        repr_factory=vector_repr,
        equals_fn=array_equal,
        hash_key_factory=tuple,
    )
    name: str = field(kw_only=True, default="New TwoVector")