Multiprocessing.Pool - Pass Data to Workers w/o Globals: A Proposal

Intro

Link to Code and Tests

This post introduces a proposal for a new keyword argument to Pool's __init__() method, named expect_initret. It defaults to False; when set to True, the return value of the initializer function is passed to the function we are mapping over as a kwarg. The two patterns below illustrate this feature.

Note: A similar issue was opened years ago; it got some attention, but was ultimately closed due to backwards-compatibility concerns. I’ve designed this implementation based on the feedback from that issue.
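For concreteness, here is a rough sketch (mine, not the patch linked above) of how the extended signature could look; everything except expect_initret matches the existing Pool.__init__() parameters:

class Pool:
    def __init__(self, processes=None, initializer=None, initargs=(),
                 maxtasksperchild=None, context=None,
                 expect_initret=False):
        # Proposed behavior: when expect_initret is True, each worker calls
        # initializer(*initargs) once, keeps the return value, and passes it
        # to the mapped function as the `initret` keyword argument.
        ...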

Pattern 1: Initialize Object in Child Process without Global Scope

This pattern is used to initialize an object after each worker (i.e. subprocess) has been created. Oftentimes the need for this arises when the func we are applying satisfies one of two cases:

  1. func is an instance method, and the instance bound to it contains an item that is not pickle-able (more reading here); see the sketch after this list.
  2. There are global variables that hold onto sockets, like database connections, which should usually not be serialized and passed to child processes.
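As a quick illustration of case 1, here is a minimal, hypothetical example (the Client class is made up for demonstration) of the kind of object that cannot cross the process boundary:

import pickle
import socket

class Client:
    def __init__(self):
        self.sock = socket.socket()  # sockets cannot be pickled

    def fetch(self, key):
        ...

client = Client()
# Pickling the bound method requires pickling `client`, and that fails
# because of the socket it holds:
# pickle.dumps(client.fetch)  # raises TypeError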

We will use a SQLAlchemy Engine object as our example. Our goal: give each worker process its own engine object.

The current implementation of Pool allows for this behavior; however, it forces the user to define a global variable inside the initializer() function, as follows:

from multiprocessing.pool import Pool
from sqlalchemy import create_engine

def initializer(db_url):
    global sqla_engine
    sqla_engine = create_engine(db_url)

def insert_record(record):
    # `table` is a sqlalchemy.Table defined elsewhere in the module
    sqla_engine.execute(table.insert(record))

records = [...]  # Dictionaries of DB records

with Pool(initializer=initializer,
          initargs=("mysql://foo:bar@localhost",)) as pool:
    pool.map(insert_record, records)

Note: There are plenty of arguments for/against global variables. There are also arguments for/against variables being made available outside their lexical scope. I intend not to get into these arguments - the goal here is to provide an alternative to the current globals-only solution to initializing Pool workers.

Using expect_initret, the parallelized insertion of records looks as follows:

from multiprocessing.pool import Pool

import sqlalchemy
from sqlalchemy import create_engine

def initializer(db_url):
    return create_engine(db_url)

def insert_record(record, initret: sqlalchemy.engine.Engine = None):
    sqla_engine = initret  # For readability's sake
    # `table` is a sqlalchemy.Table defined elsewhere in the module
    sqla_engine.execute(table.insert(record))

records = [...]  # Dictionaries of DB records

with Pool(initializer=initializer,
          initargs=("mysql://foo:bar@localhost",),
          expect_initret=True) as pool:
    pool.map(insert_record, records)

So, we preserve lexical scoping of the Engine object, at the expense of a somewhat ambiguous kwarg named initret to our mapped function insert_record(). This becomes a bit more readable with type hinting.

Pattern 2: Pass Object from Parent to Child & Avoid Global Scope

The idea here is to create a large object ONCE, like a big map or dictionary, in the parent process, and pass that object to each Pool worker. Specifically, the object will be made available in each worker’s local scope as a parameter to our mapped function.

Let’s consider the dummy problem of counting every “on” bit in all integers smaller than 2**16 (e.g. “10101” => 3 “on” bits).
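As a point of reference, a purely sequential baseline for this dummy problem is a one-liner (the function name here is mine, just for illustration):

def count_bits_sequential(n: int = 2**16) -> int:
    # Count the "on" bits of every integer in [0, n)
    return sum(bin(i).count("1") for i in range(n))

With the current Pool API, the parallelized version, which again relies on a global defined inside initializer(), looks like this: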

from multiprocessing.pool import Pool
from typing import Dict, List

def initializer(cache: Dict[int, str]) -> None:
    # Stash the cache in a module-level global for count_bits() to use
    global int_to_binary_cache
    int_to_binary_cache = cache

def count_bits(i: int) -> int:
    return int_to_binary_cache[i].count("1")

def parallel_bit_counter(int_ls: List[int]) -> int:
    big_int_to_binary_cache = {
        i: bin(i) for i in range(2**16)
    }
    with Pool(initializer=initializer,
              initargs=(big_int_to_binary_cache,)) as p:
        return sum(
            p.imap_unordered(count_bits, int_ls))

Note: You can also send data to Pool workers with class attributes, which buys a bit more encapsulation.
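Here is a minimal sketch of that class-attribute variant under the current Pool API; the WorkerState name and layout are my own, not part of the proposal:

from multiprocessing.pool import Pool

class WorkerState:
    # Populated once per worker process by init_worker()
    cache = {}

def init_worker(cache):
    WorkerState.cache = cache

def count_bits(i):
    return WorkerState.cache[i].count("1")

def parallel_bit_counter_with_class_attr(n=2**16):
    int_to_binary_cache = {i: bin(i) for i in range(n)}
    with Pool(initializer=init_worker,
              initargs=(int_to_binary_cache,)) as p:
        return sum(p.imap_unordered(count_bits, range(n)))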

With expect_initret, the implementation looks as follows:

from multiprocessing.pool import Pool
from typing import Dict, List

def initializer(int_to_binary_cache: Dict[int, str]
                ) -> Dict[int, str]:
    # The identity function
    return int_to_binary_cache

def count_bits(i: int, initret: Dict[int, str] = None) -> int:
    return initret[i].count("1")

def parallel_bit_counter(int_ls: List[int]) -> int:
    big_int_to_binary_cache = {
        i: bin(i) for i in range(2**16)
    }
    with Pool(initializer=initializer,
              initargs=(big_int_to_binary_cache,),
              expect_initret=True) as p:
        return sum(
            p.imap_unordered(count_bits, int_ls))

Yet again, I am the first to admit that the initret kwarg is somewhat ambiguous. However, the goal is to let Python users choose between the following:

  1. An explicit flow of data, with lexically scoped variables, at the expense of a somewhat ambiguous kwarg, initret.
  2. Preservation of proper variable names, at the expense of an implicit flow of data, with globally scoped variables defined within a function.

Final Thoughts

For those interested, the path to getting stuck deep, deep in the cavernous rabbit hole of Python’s multiprocessing.Pool is as follows:

  1. Get stuck in a pickle while prematurely optimizing an application that predicts the bioactivity of food compounds.
  2. Give a Python Boston User Group talk on how you can very easily do the same!
  3. Have a crazy idea that you can keep others from repeating your past mistakes by extending a CPython lib!

If you were to take every library written in Python, and…

  1. Count every function that accesses a global variable, or defines a global from within a function via the global keyword
  2. Count every function that does NOT access globals, and adheres to lexical scoping

…the count of (2) would be overwhelmingly higher than (1).

Given that Python users (like me) are more familiar with functions that do not create global variables as a side effect, it is my hope that this API extension, and the examples above, will enable more Python users to take advantage of the Pool interface, while preserving every bit of the beautifully abstracted multiprocessing.Pool module.