Multiprocessing with shared_memory (potential fix for bug 82300)

Hello -
I need to solve bug 82300 on Windows and will probably upset a number of people if I tell them to wait for a release.

Below I describe how the implementation solves the problem for my use cases, but I’d like a second pair of eyes on it.

  • Does it allow issue 82300 to progress?
  • Is it helpful at all?
  • Is there perhaps a better way I’m not aware of?

Thank you again.


Using shared memory objects makes no sense outside a multiprocessing context, so I’ll take that for granted in the rest of this post. I’ll use two names, MAIN and SUB, where SUB is any subprocess started by MAIN.

The problem (as I experience it) is that shared_memory objects (SMOs) are garbage collected before the handover has happened.

Case 1: Example that works because ref count > 0.

  • MAIN creates SMO and shares the SMO-name with SUB.
  • SUB accesses SMO, does work and closes.
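A minimal single-process sketch of Case 1, using only the standard multiprocessing.shared_memory module. MAIN creates the block and hands over nothing but its name. The "SUB" half, which in the real setup runs in a subprocess, attaches by that name. Because MAIN keeps its own reference throughout, the block survives SUB closing its handle.

```python
from multiprocessing import shared_memory

# Single-process sketch of Case 1. In the real setup the "SUB" half runs in a
# subprocess; here both halves run in one process so the example is self-contained.

# MAIN: create the SMO and keep a reference to it for the whole exchange.
main_shm = shared_memory.SharedMemory(create=True, size=16)
main_shm.buf[:5] = b"hello"
name_for_sub = main_shm.name  # only this string would be sent to SUB

# SUB: attach by name, do the work, close its own handle.
sub_shm = shared_memory.SharedMemory(name=name_for_sub, create=False)
received = bytes(sub_shm.buf[:5])  # copy the data out before closing
sub_shm.close()

# MAIN still holds a reference, so the segment survived SUB closing its handle.
still_there = bytes(main_shm.buf[:5])
main_shm.close()
main_shm.unlink()  # MAIN owns the segment, so MAIN removes it
```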

Case 2: Example that doesn’t work because ref count == 0 before MAIN connects.

  • MAIN creates task for SUB.
  • SUB creates SMO and shares SMO-name with MAIN through a queue.
  • MAIN is busy and doesn’t see the SMO-name until SUB has left the namespace where the last SMO reference existed.

I can solve Case 2 using the robust Notify-Acknowledge-Transfer (NAT) protocol, but I really don’t want to if Python’s resource tracker has a better approach.
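The hard-reference idea behind the fix can be sketched in one process. REGISTRY, create_block and release_block are hypothetical names standing in for SHMsub and its bookkeeping. The SharedMemory object created inside create_block is not garbage collected (and therefore not closed by SharedMemory.__del__) when the function returns, because the dict still points to it. Without such a reference, on Windows the block disappears once the last handle to it is closed, which is the Case 2 failure.

```python
from multiprocessing import shared_memory

# Hypothetical module-level registry standing in for SHMsub: while an entry
# exists, the SharedMemory object cannot be garbage collected (and closed).
REGISTRY = {}


def create_block(payload: bytes) -> str:
    """Create an SMO inside a function and return only its name (as in Case 2)."""
    shm = shared_memory.SharedMemory(create=True, size=len(payload))
    shm.buf[:len(payload)] = payload
    REGISTRY[shm.name] = shm  # hard reference outlives this function's scope
    return shm.name


def release_block(name: str) -> None:
    """Drop the creator's hard reference once the other side holds its own."""
    shm = REGISTRY.pop(name)
    shm.close()


name = create_block(b"abc")

# The "MAIN" side attaches later by name; this works because REGISTRY kept the
# creator's handle open after create_block returned.
reader = shared_memory.SharedMemory(name=name, create=False)
release_block(name)           # creator lets go only after the reader attached
data = bytes(reader.buf[:3])  # still readable through the reader's own handle
reader.close()
reader.unlink()  # removes the segment on POSIX; a no-op on Windows
```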

The NAT protocol solves case 2 as follows:

| Step | Description |
| --- | --- |
| NOTIFY | SUB holds a reference and notifies MAIN of the SMO's name. |
| ACK | MAIN creates its own reference and acknowledges the SMO's existence to SUB. |
| TRANSFER COMPLETE | SUB drops its reference and declares the transfer “complete”. |
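The table can be read as a tiny state machine. The sketch below uses hypothetical names (NatStep, SENDER, NEXT, refs_held, walk) that are not part of the implementation further down. It encodes who sends each step and who holds a reference afterwards, then checks the invariant that makes the protocol work: at no step is the SMO left without a reference.

```python
from enum import Enum


class NatStep(Enum):
    NOTIFY = "NOTIFY"                        # SUB -> MAIN: "SMO <name> exists"
    ACK = "ACK"                              # MAIN -> SUB: "I hold my own reference"
    TRANSFER_COMPLETE = "TRANSFER_COMPLETE"  # SUB -> MAIN: "I dropped mine"


# Who sends each step, and which step must follow it.
SENDER = {NatStep.NOTIFY: "SUB", NatStep.ACK: "MAIN", NatStep.TRANSFER_COMPLETE: "SUB"}
NEXT = {NatStep.NOTIFY: NatStep.ACK, NatStep.ACK: NatStep.TRANSFER_COMPLETE,
        NatStep.TRANSFER_COMPLETE: None}


def refs_held(step: NatStep) -> set:
    """Processes holding a reference to the SMO right after `step` is sent."""
    return {NatStep.NOTIFY: {"SUB"},
            NatStep.ACK: {"SUB", "MAIN"},
            NatStep.TRANSFER_COMPLETE: {"MAIN"}}[step]


def walk():
    """Run the protocol from NOTIFY to the end, recording (sender, step, holders)."""
    step, trace = NatStep.NOTIFY, []
    while step is not None:
        holders = refs_held(step)
        if not holders:
            raise RuntimeError(f"SMO would be unreferenced after {step.name}")
        trace.append((SENDER[step], step.name, sorted(holders)))
        step = NEXT[step]
    return trace


trace = walk()
```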

High level implementation details:

MAIN has a global dict SHMmain and SUB has a global dict SHMsub, which hold hard references so that the reference count never reaches zero. The process is roughly as follows:

  1. SUB creates SMO
  2. SUB sets hard ref in SHMsub[SMO.name] = SMO
  3. SUB notifies MAIN about SMO.name
  4. SUB does value-adding work and enters sleep mode as long as SHMsub is not empty (i.e. it won’t exit)
  5. (delay)
  6. MAIN reads notification from SUB, connects to SMO and creates entry in SHMmain[SMO.name] = SMO
  7. MAIN returns acceptance to SUB
  8. SUB either finishes work or awakens from sleep and reads the acceptance message from MAIN.
  9. SUB can now remove the hard reference with del SHMsub[SMO.name] and let the GC do its work.
  10. After deleting the reference, SUB sends a “transfer complete” message to MAIN, which signals that MAIN is now allowed to remove its hard reference as well.
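Steps 1 to 10 can be walked through sequentially in a single process. This is a sketch under stated assumptions, not the TaskManager/Worker code below: two queue.Queue objects stand in for the multiprocessing queues, and the dicts SHMsub and SHMmain are named as in the post. Steps 4 and 5 (SUB working or sleeping, and the delay) collapse to nothing here.

```python
import queue
from multiprocessing import shared_memory

to_main = queue.Queue()   # stands in for MAIN's result queue
to_sub = queue.Queue()    # stands in for SUB's signal queue
SHMsub, SHMmain = {}, {}  # hard-reference registries, named as in the post

# 1-3. SUB creates the SMO, stores a hard ref, and notifies MAIN.
smo = shared_memory.SharedMemory(create=True, size=8)
smo.buf[:4] = b"data"
SHMsub[smo.name] = smo
to_main.put(("NOTIFY", smo.name))
del smo  # SUB's local name goes away; SHMsub keeps the object alive

# 4-5. SUB would keep working or sleep here; nothing runs in between in this sketch.

# 6-7. MAIN reads the notification, attaches, stores its own hard ref, and ACKs.
kind, name = to_main.get()
SHMmain[name] = shared_memory.SharedMemory(name=name, create=False)
to_sub.put(("ACK", name))

# 8-10. SUB reads the ACK, drops its hard ref, and reports transfer complete.
kind, name = to_sub.get()
SHMsub.pop(name).close()
to_main.put(("TRANSFER_COMPLETE", name))

# MAIN receives TRANSFER COMPLETE and now holds the only reference.
kind, name = to_main.get()
payload = bytes(SHMmain[name].buf[:4])

# Clean up so the sketch does not leak the segment.
owned = SHMmain.pop(name)
owned.close()
owned.unlink()  # removes the segment on POSIX; a no-op on Windows
```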

What if…:

  • If SUB crashes after NOTIFY, MAIN will raise an OSError because the SMO name is not found. No big issue.
  • If MAIN is busy for a very long time, the NOTIFY message will not be read. This could eventually lead to an out-of-memory error. Not a problem either.
  • SUB is not blocked by MAIN, as SUB can still collect other tasks from its task queue and keep processing. The receipt of ACKNOWLEDGE merely becomes a task to clear out the SMO.name from SHMsub.
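MAIN can handle the first case by treating a missing name as “SUB went away”. try_attach below is a hypothetical helper. It is a sketch assuming that attaching to an unknown name raises FileNotFoundError (the OSError subclass the standard library raises in that case), so that MAIN can simply skip the ACK for that name.

```python
from multiprocessing import shared_memory
from typing import Optional


def try_attach(name: str) -> Optional[shared_memory.SharedMemory]:
    """Hypothetical MAIN-side helper: attach to an SMO announced by NOTIFY,
    or return None if no block with that name exists any more."""
    try:
        return shared_memory.SharedMemory(name=name, create=False)
    except FileNotFoundError:  # an OSError subclass: the name is unknown
        return None


# Usage: a name that was never created behaves like a SUB that vanished.
missing = try_attach("nat_sketch_no_such_block")

# A live block attaches normally (the "SUB" side is simulated locally here).
live = shared_memory.SharedMemory(create=True, size=8)
attached = try_attach(live.name)
attached_ok = attached is not None
if attached is not None:
    attached.close()
live.close()
live.unlink()  # removes the segment on POSIX; a no-op on Windows
```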

Implementation

There are 5 classes:

| Class | Description |
| --- | --- |
| TaskManager | Used like multiprocessing.Pool, except that I need it to do more. |
| Worker | Wrapper for the subprocess. |
| NATsignal | Envelope for sharing shm.name between MAIN and SUB. |
| TrackedSharedMemory | Wrapper for shared_memory.SharedMemory so that I can ref count in both processes. |
| Task | Envelope for the worker's job: Task(f, *args, **kwargs). |

The implementation is as follows:

import multiprocessing
from multiprocessing import shared_memory, cpu_count
from tqdm import tqdm   # third-party (pip install tqdm): required, used for the progress bars.
import time
import queue
from abc import ABC
import copy
from itertools import count
import io
import numpy as np  # third-party (pip install numpy): only used by the demo at the bottom.
import traceback
from collections import defaultdict


class TaskManager(object):
    shared_memory_references = {}  # hard references (SHMmain / SHMsub): shm name -> TrackedSharedMemory.
    shared_memory_reference_counter = defaultdict(int)  # tracker for the NAT protocol.

    def __init__(self) -> None:    
        self._cpus = cpu_count()
        self.tq = multiprocessing.Queue()  # task queue for workers.
        self.rq = multiprocessing.Queue()  # result queue for workers.
        self.pool = []                     # list of sub processes
        self.pool_sigq = {}                # signal queue for each worker.
        self.tasks = 0                     # counter for task tracking
        
    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb): # signature requires these, though I don't use them.
        self.stop()  # stop the workers.
        
        # Clean up on exit.
        for k,v in self.shared_memory_reference_counter.items():
            if k in self.shared_memory_references and v == 0:
                # drop the hard reference (this does not unlink the segment); once no other
                # variable points to the object, GC calls SharedMemory.close() on it.
                del self.shared_memory_references[k]
        
    def start(self):
        for i in range(self._cpus):  # create workers
            name = str(i)
            sigq = multiprocessing.Queue()  # we create one signal queue for each proc.
            self.pool_sigq[name] = sigq
            worker = Worker(name=name, tq=self.tq, rq=self.rq, sigq=sigq)
            self.pool.append(worker)

        with tqdm(total=self._cpus, unit="n", desc="workers ready") as pbar:
            for p in self.pool:
                p.start()

            while True:
                alive = sum(1 if p.is_alive() else 0 for p in self.pool)
                pbar.n = alive
                pbar.refresh()
                if alive < self._cpus:
                    time.sleep(0.01)
                else:
                    break  # all sub processes are alive. exit the setup loop.

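    def execute(self, tasks):
        if isinstance(tasks, Task):
            tasks = (tasks,)  # accept a single Task as well as a list/tuple of Tasks.
        if not isinstance(tasks, (list, tuple)) or not all(isinstance(i, Task) for i in tasks):
            raise TypeError("tasks must be a Task or a list/tuple of Tasks")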

        for t in tasks:
            self.tq.put(t)
            self.tasks += 1  # increment task counter.
        
        results = []  
        with tqdm(total=self.tasks, unit='task') as pbar:
            while self.tasks != 0:
                try:
                    task = self.rq.get_nowait()
                
                    if isinstance(task, NATsignal): 
                        if task.shm_name not in self.shared_memory_references:  # its a NOTIFY from a WORKER.
                            # first create a hard ref to the memory object.
                            self.shared_memory_references[task.shm_name] = TrackedSharedMemory(name=task.shm_name, create=False)
                            self.shared_memory_reference_counter[task.shm_name] += 1
                            # then send the ACKNOWLEDGEMENT directly to the WORKER.
                            self.pool_sigq[task.worker_name].put(task)
                        else:  # It's the second time we see the name so it's a TRANSFER COMPLETE
                            self.shared_memory_reference_counter[task.shm_name] -= 1 
                        # at this point we can be certain that the SHMs are in the main process.
                        continue  # keep looping as there may be more.

                    elif isinstance(task, Task):
                        if task.exception:
                            raise Exception(task.exception)

                        self.tasks -= 1  # decrement task counter.
                        pbar.set_description(task.f.__name__)
                        results.append(task)
                        pbar.update(1)
                    
                except queue.Empty:
                    time.sleep(0.01)
        return results 

    def stop(self):
        for _ in range(self._cpus):  # put enough stop messages for all workers.
            self.tq.put("stop")

        with tqdm(total=len(self.pool), unit="n", desc="workers stopping") as pbar:
            while True:
                not_alive = sum(1 if not p.is_alive() else 0 for p in self.pool)
                pbar.n = not_alive
                pbar.refresh()
                if not_alive < self._cpus:
                    time.sleep(0.01)
                else:
                    break
        self.pool.clear()

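        # clear the message queues. Queue.empty is a method (not an attribute) and is
        # unreliable across processes, so drain with get_nowait() until queue.Empty.
        for q in (self.tq, self.rq):
            while True:
                try:
                    q.get_nowait()
                except queue.Empty:
                    break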
        
  
class Worker(multiprocessing.Process):
    def __init__(self, name, tq, rq, sigq):
        super().__init__(group=None, target=self.update, name=name, daemon=False)
        self.exit = multiprocessing.Event()
        self.tq = tq  # workers task queue
        self.rq = rq  # workers result queue
        self.sigq = sigq  # worker signal receipt queue.
        
               
    def update(self):
        # TaskManager.shared_memory_references is a class attribute, so inside this
        # sub process it acts as SHMsub: each process has its own copy of the dict.

        while True:
            # first process any/all direct signals first.
            while True:
                try:
                    ack = self.sigq.get_nowait()   # receive acknowledgement of hard ref to SharedMemoryObject from SIGQ            
                    shm = TaskManager.shared_memory_references.pop(ack.shm_name)  # pop the shm
                    shm.close()  # assure closure of the shm.
                    del TaskManager.shared_memory_reference_counter[ack.shm_name]
                    self.rq.put(ack)  # respond to MAINs RQ that transfer is complete.
                except queue.Empty:
                    break

            # then deal with any tasks...
            try:  
                task = self.tq.get_nowait()
                if task == "stop":
                    self.tq.put_nowait(task)  # this assures that everyone gets the stop signal.
                    self.exit.set()
                    break
                elif isinstance(task, Task):
                    task.execute()
                    
                    for k,v in TaskManager.shared_memory_references.items():
                        if k not in TaskManager.shared_memory_reference_counter:
                            TaskManager.shared_memory_reference_counter[k] = 1
                            self.rq.put(NATsignal(k, self.name))  # send Notify from subprocess to main
                        
                    self.rq.put(task)

                else:
                    raise Exception(f"What is {task}?")
            except queue.Empty:
                time.sleep(0.01)
                continue


class NATsignal(object):
    def __init__(self, shm_name, worker_name):
        """
        shm_name: str: name from shared_memory.
        worker_name: str: required by TaskManager for sending ACK message to worker.
        """
        self.shm_name = shm_name
        self.worker_name = worker_name


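class TrackedSharedMemory(shared_memory.SharedMemory):
    def __new__(cls, name=None, create=False, size=0):
        # __init__ may not return a value (Python raises TypeError), so the
        # registry lookup lives in __new__ instead.
        if name is not None and name in TaskManager.shared_memory_references:
            return TaskManager.shared_memory_references[name]  # return from registry.
        return super().__new__(cls)

    def __init__(self, name=None, create=False, size=0) -> None:
        if getattr(self, "_registered", False):
            return  # instance came from the registry; don't open the segment twice.
        super().__init__(name, create, size)
        self._registered = True
        TaskManager.shared_memory_references[self.name] = self  # add to registry. This blocks __del__ !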


class Task(ABC):
    """
    Generic Task class for tasks.
    """
    ids = count(start=1)
    def __init__(self, f, *args, **kwargs) -> None:
        """
        f: callable 
        *args: arguments for f
        **kwargs: keyword arguments for f.
        """
        if not callable(f):
            raise TypeError
        self.task_id = next(self.ids)
        self.f = f
        self.args = copy.deepcopy(args)  # deep copy is slow unless the data is shallow.
        self.kwargs = copy.deepcopy(kwargs)
        self.result = None
        self.exception = None

    def __str__(self) -> str:
        if self.exception:
            return f"Call to {self.f.__name__}(*{self.args}, **{self.kwargs}) --> Error: {self.exception}"
        else:
            return f"Call to {self.f.__name__}(*{self.args}, **{self.kwargs}) --> Result: {self.result}"

    def execute(self):
        """ The worker calls this function. """
        try:
            self.result = self.f(*self.args, **self.kwargs)
        except Exception as e:
            f = io.StringIO()
            traceback.print_exc(limit=3, file=f)
            f.seek(0)
            error = f.read()
            f.close()
            self.exception = error


def cpu_intense_task_with_shared_memory(n):
    # create shared memory object
    arr = np.array(list(range(n)))
    shm = TrackedSharedMemory(create=True, size=arr.nbytes)
    datablock = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
    datablock[:] = arr[:]  # copy the data.
    # disconnect from the task.
    return shm.name, datablock.shape


if __name__ == "__main__":
    """ test... """
    n = 8

    tasks = [Task(f=cpu_intense_task_with_shared_memory, n=10**i) for i in range(n)]
    
    with TaskManager() as tm:  # start sub procs by using the context manager.
        results = tm.execute(tasks)
        results.sort(key=lambda x: x.task_id)

        # collect evidence that it worked.
        assert len(results) == len(tasks)

        result_names, arrays = set(), []
        total = 0 
        for r in results:
            result_name, shape = r.result
            result_names.add(result_name)
            shm = tm.shared_memory_references[result_name]
            data = np.ndarray(shape, dtype=int, buffer=shm.buf)
            total += data.shape[0]  # get the data from the workers.
            
            arrays.append(data)

        tm_names = set(tm.shared_memory_references.keys())
        assert result_names == tm_names, (result_names, tm_names)
        assert total == sum(10**i for i in range(n)), total
    # stop all subprocs by exiting the context mgr.

    # check the data is still around.
    assert sum(len(arr) for arr in arrays) == total


output

(py39) C:\Data>python.exe c:/Data/a_multi_proc_shm_test.py
workers ready: 100%|███████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 69.57n/s]
cpu_intense_task_with_shared_memory: 100%|██████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  5.77task/s]
workers stopping: 100%|████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 51.01n/s] 

(py39tablite) C:\Data>

Thanks again for the second pair of eyes.
Kind regards
Bjorn