Extended contextvars: forking without copying the context (so changes in a parent propagate to its children), plus generator support

Currently, the contextvars library suffers from a few issues:

  • when creating a child context via the copy() method, the parent context is fully copied, which has two potential drawbacks:

    • it could be a performance issue if many contextvars are present (although in CPython a Context is backed by an immutable HAMT, so copy() itself is O(1))
    • if the parent context is updated after the copy, the changes are not visible in its children
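  • it does not work well with generators. For example, if we define a generator via def gen() and iterate it as for x in ctx.run(gen), only the call that creates the generator runs in the context; the generator body, which executes inside __next__ (and send()), runs in whatever context is current at that moment. What we mostly need is to wrap its __next__ and its finalization (__del__, and __dealloc__ in the Cython wrapper below).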

This could be solved by an external library, or perhaps added to the standard library, since it is tightly coupled with it. Below is a solution (written in Cython) that I’d like to share for inspiration; it is quick and still a bit dirty, but it works for basic use cases.

from contextvars import Context, ContextVar
from functools import wraps
from typing import Any, Collection, Iterable, Iterator, Tuple

cdef class KeepContext:
    """Wrap a generator or a function. Then, it always runs in the given context.

    Calling the wrapper runs the wrapped callable in `ctx`; iterating it, `send()`, `throw()`,
    `close()` and finalization run the corresponding method of the wrapped generator in `ctx`.
    """

    cdef object gen  # wrapped generator (or plain callable)
    cdef object ctx  # context every operation is run in (normally a contextvars.Context)
    cdef bint finalized  # set once the wrapped object's __del__ has been run

    def __init__(self, gen, ctx):
        self.gen = gen
        self.ctx = ctx
        self.finalized = False

    def __call__(self, *largs, **kargs):
        """Call the wrapped callable inside `ctx`."""
        return self.ctx.run(self.gen, *largs, **kargs)

    def __iter__(self):
        return self

    def __next__(self):
        """Advance the wrapped generator inside `ctx`."""
        return self.ctx.run(self.gen.__next__)

    def send(self, x):
        """Send `x` into the wrapped generator inside `ctx`."""
        return self.ctx.run(self.gen.send, x)

    def throw(self, *args):
        """Raise an exception inside the wrapped generator, running it in `ctx`."""
        return self.ctx.run(self.gen.throw, *args)

    def close(self):
        """Close the wrapped generator inside `ctx`, so its cleanup code sees that context."""
        return self.ctx.run(self.gen.close)

    cdef _finalize(self):
        # Run the wrapped object's __del__ in `ctx` at most once: Cython versions that support
        # __del__ on cdef classes call both __del__ and __dealloc__ for the same object.
        if self.finalized:
            return
        self.finalized = True
        try:
            del_ = self.gen.__del__
        except AttributeError:
            # Plain functions (and None before __init__ ran) have no __del__.
            return
        self.ctx.run(del_)

    def __dealloc__(self):
        # Fallback for Cython versions that do not call __del__ on cdef classes.
        self._finalize()

    def __del__(self):
        self._finalize()

    def get_ctx(self):
        """Return the context this wrapper runs in."""
        return self.ctx


cdef class ForkerData:
    """One layer of a Forker's overlay chain.

    `peak` holds the keys set in this layer, `parent` is the layer it overlays (the root layer
    is its own parent), `ctx` is the context the layer was created for, and `del_hooks` are
    callables run when the layer is finalized.
    """
    # NOTE(review): ctx/parent/peak/del_hooks are not declared with `cdef` in this file;
    # presumably they are declared in a matching .pxd (a cdef class has no __dict__) — confirm.

    def __init__(self, ctx, parent=None):
        self.ctx = ctx
        # The root layer points to itself, which terminates the parent walk in Forker.
        self.parent = self if parent is None else parent
        self.peak = {}
        self.del_hooks = []

    def __del__(self):
        """Run every deletion hook, even when an earlier one raises.

        The first exception raised by a hook (if any) is re-raised after all hooks have run.
        """
        first_error = None
        for del_hook in self.del_hooks:
            try:
                del_hook()
            except Exception as exc:
                # Keep going so the remaining hooks still get to release their resources.
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error


cdef class Forker:
    """Overlayed contextvars.

    A function has a context (with logging environment, configs, ...), then, it can fork the
    context, creating an empty context overlay (in constant time - no copying of the base context).
    Modifications of the overlayed context are not visible to the base context. Modifications of
    the base context are visible in the overlay unless overwritten by the overlayed context.

    TL;DR Forker has a dict interface which respects contexts in the manner described above. It
    also stores the context layer structure and includes context manipulation methods.

    The base forker is created via `f = Forker()`. Its root layer is bound only inside the root
    context (`f.get_root_ctx()`, a fresh `contextvars.Context` unless one is passed in) and in the
    contexts forked from it, so the dict interface must be used from within them, e.g.
    `f.get_root_ctx().run(main)`. There you can add some config `f["foo"] = 5`. Then you
    can fork the context via `ctx = f.copy()` where `ctx` is a normal `contextvars.Context`. You
    can run a function in that context: `ctx.run(fn, arg1, arg2)` or you can do both these steps at
    once using `f.fork(fn)(arg1, arg2)`. Inside that call, `f["bar"] = 6` will set "bar" only
    inside the forked context. If you change `f["foo"] = 10` in the base context (not in the fork),
    the change is visible inside the forked context (in contrast with the `contextvars` logic).
    """

    # NOTE(review): `data` and `root_data` are not declared with `cdef` in this file; presumably
    # they (and the cdef methods below) are declared in a matching .pxd — confirm.

    def __init__(Forker self, ctx=None, initial_data=()):
        """Create the root layer inside `ctx` (a new empty Context when None).

        `initial_data` seeds the root layer; it may be anything `dict.update` accepts.
        """
        self.data = ContextVar("data")

        if ctx is None:
            ctx = Context()

        self.root_data = ForkerData(ctx)
        self.root_data.peak.update(initial_data)

        # Bind the root layer only inside `ctx`; contexts copied from it inherit the binding.
        ctx.run(self.data.set, self.root_data)

    cdef inline ForkerData cdata(Forker self):
        # Raises LookupError when called outside the root context and the contexts derived from it.
        return self.data.get()

    def get_cdata(Forker self) -> ForkerData:
        """Return the layer bound in the current context."""
        return self.cdata()

    def get_ctx(Forker self):
        """Return the context recorded for the current layer."""
        return self.cdata().ctx

    def get_root_ctx(Forker self):
        """Return the context holding the root layer."""
        return self.root_data.ctx

    def run_in_child_ctx(Forker self, object ctx, fn, *largs, **kargs):
        """Call `fn(*largs, **kargs)` in `ctx`.

        When `ctx` is the current layer's context, `fn` is called directly, because
        `Context.run` refuses to enter a context that is already entered.
        """
        if ctx is self.cdata().ctx:
            return fn(*largs, **kargs)
        return ctx.run(fn, *largs, **kargs)

    def wrap(Forker self, object fn):
        """Apply KeepContext with the current context."""
        return KeepContext(fn, self.cdata().ctx)

    def wrap_gen(Forker self, object genfn):
        """Wrap a function that returns a generator and then the generator.

        Both the call to `genfn` and the returned generator run in the context that is current
        when `wrap_gen` is called.
        """
        ctx = self.cdata().ctx

        @wraps(genfn)
        def wrapper(*largs, **kargs):
            gen = ctx.run(genfn, *largs, **kargs)
            return KeepContext(gen, ctx)

        return wrapper

    def copy(Forker self):
        """Create a forked context whose new, empty layer overlays the current one."""
        parent = self.cdata()
        ctx = parent.ctx.copy()
        ctx.run(self.data.set, ForkerData(ctx, parent))
        return ctx

    def fork(Forker self, fn):
        """Apply KeepContext with a newly forked context.

        The result is a callable: `f.fork(fn)(arg1, arg2)` runs `fn(arg1, arg2)` in the fork.
        """
        return KeepContext(fn, self.copy())

    def fork_gen(Forker self, fn, *largs, **kargs):
        """Apply KeepContext with a newly forked context to the result of fn.

        The function `fn` is also run in that forked context.
        """
        ctx = self.copy()
        return KeepContext(ctx.run(fn, *largs, **kargs), ctx)

    cdef inline dict todict(Forker self):
        """Convert the internal overlayed dict structure to a flat dict."""
        cdef ForkerData cdata = self.cdata()
        cdef list layers = [cdata]
        cdef dict result = {}

        # Collect the chain from the current layer up to the root (which is its own parent).
        while cdata.parent is not cdata:
            cdata = cdata.parent
            layers.append(cdata)

        # Apply root first so values from nearer (more recently forked) layers win.
        for cdata in reversed(layers):
            result.update(cdata.peak)

        return result

    cdef inline object items_c(Forker self):
        return self.todict().items()

    def items(Forker self):
        """Return the (key, value) pairs visible from the current layer."""
        return self.items_c()

    def keys(Forker self) -> Collection[str]:
        """Return the keys visible from the current layer."""
        return self.todict().keys()

    def values(Forker self) -> Collection[Any]:
        """Return the values visible from the current layer."""
        return self.todict().values()

    cdef inline void update_c(Forker self, object items: Iterable[Tuple[str, Any]]):
        cdef dict peak = self.cdata().peak
        cdef str key
        cdef object value

        for key, value in items:
            self.setitem_c0(peak, key, value)

    def update(Forker self, object items: Iterable[Tuple[str, Any]]):
        """Set several keys in the current layer.

        Accepts an iterable of (key, value) pairs or, like `dict.update`, any object with a
        `keys()` method (e.g. a dict).
        """
        if hasattr(items, "keys"):
            # Iterating a mapping yields only its keys; turn it into pairs like dict.update does.
            items = [(key, items[key]) for key in items.keys()]
        self.update_c(items)

    cdef inline object get_c(Forker self, str key, object default = None):
        # Also returns `default` when no layer is bound in the current context (LookupError).
        try:
            return self.getitem_c(key)
        except LookupError:
            return default

    def get(Forker self, str key, object default = None) -> Any:
        """Return the value of `key` seen from the current layer, or `default`."""
        return self.get_c(key, default)

    def __contains__(Forker self, str key) -> bool:
        try:
            self.getitem_c(key)
        except LookupError:
            return False
        return True

    cdef inline object getitem_c(Forker self, str key):
        # Walk from the current layer towards the root; the first layer defining `key` wins.
        cdef ForkerData cdata = self.cdata()
        while True:
            try:
                return cdata.peak[key]
            except KeyError:
                if cdata.parent is cdata:
                    # KeyError (a LookupError subclass) matches the dict protocol; `from None`
                    # drops the chained KeyError of the root layer's lookup.
                    raise KeyError(key) from None
                cdata = cdata.parent

    def __getitem__(Forker self, str key) -> Any:
        """Return the value of `key` seen from the current layer; raise KeyError if unset."""
        return self.getitem_c(key)

    cdef void setitem_c0(Forker self, dict peak, str key, object value: Any):
        peak[key] = value

    cdef inline void setitem_c(Forker self, str key, object value: Any):
        self.setitem_c0(self.cdata().peak, key, value)

    def __setitem__(Forker self, key: str, value: Any) -> None:
        """Set `key` in the current layer only."""
        self.setitem_c(key, value)