replacing pmap docs with shard #805

@theorashid

Description

jax.pmap is being deprecated (discussion, docs) in favour of the sharding API.

I read this while flicking through the jax docs and thought about the main use of pmap I know: sampling multiple chains in blackjax/numpyro.

Tagging @dfm because I saw him rewrite something as shard_map elsewhere, and he knows both projects well.

Current behaviour

The current docs for sampling on multiple chains explain how to do this with pmap. I do like pmap; it's clean and well suited to this use case.

Desired behaviour

Sharding is more flexible than pmap, but it is less straightforward here – I haven't quite worked out how to do it yet (see below). Once we have a working approach, I'll happily update the docs. I think this will be useful because:

  • pmap is being deprecated, and I think we should match the latest jax docs (or at least have sharding there as an alternative). There is a mention of sharding (albeit the old sharding API) in the numpyro docs for multiple accelerators.
  • Pretty much all the examples of sharding in practice are neural network stuff. It would be nice to have another practical example here.

I would also like to add jax.debug.visualize_array_sharding(initial_states_sharded.position["loc"]) to the docs, which I think is a very helpful visualisation from the shard_map notebook.

Progress so far

Running the notebook locally and inspecting pmap_states.position["loc"].sharding, we get NamedSharding(mesh=Mesh('<axis 0x10db00e00>': 8, axis_types=(Auto,)), spec=PartitionSpec('<axis 0x10db00e00>',), memory_kind=device). So let's create a mesh of the same shape, but with a useful axis name, chain.

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((num_chains,), ('chain',))  # or mesh = Mesh(jax.devices(), ('chain',))
spec = P('chain',)
named_sharding = NamedSharding(mesh, spec)

initial_states_sharded = jax.device_put(initial_states, named_sharding)

jax.debug.visualize_array_sharding(initial_states_sharded.position["loc"])  # correctly shows the 8 CPU devices I set up

The inference loop from earlier in the notebook is defined as

def inference_loop(rng_key, kernel, initial_state, num_samples):

    @jax.jit
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

and pmap applied as

inference_loop_multiple_chains = jax.pmap(inference_loop, in_axes=(0, None, 0, None), static_broadcasted_argnums=(1, 3))

I've tried a few things to get jit + shard_map or just jit working, but no success so far. I couldn't work out from the docs how to handle the static arguments etc. Any ideas?
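For concreteness, here are two rough, untested sketches of the kind of thing I've been trying, reusing kernel, num_samples, num_chains, rng_key, mesh and named_sharding from above (those names are just what I'm using locally). The first is jit only: vmap the single-chain loop over the chain axis, close over the static arguments, and let the 'chain'-sharded inputs drive the partitioning.

from functools import partial

# Untested sketch: vmap over the chain axis and jit the whole thing.
# kernel and num_samples are closed over, so they stay static under jit;
# the explicit in/out shardings split the chain axis across devices.
@partial(jax.jit, in_shardings=(named_sharding, named_sharding), out_shardings=named_sharding)
def inference_loop_multiple_chains(rng_keys, initial_states):
    run_chain = lambda key, state: inference_loop(key, kernel, state, num_samples)
    return jax.vmap(run_chain)(rng_keys, initial_states)

chain_keys = jax.device_put(jax.random.split(rng_key, num_chains), named_sharding)
states = inference_loop_multiple_chains(chain_keys, initial_states_sharded)

The second is the shard_map version, where each shard arrives with a leading axis of chains-per-device (1 here, one chain per device), which needs stripping before calling the single-chain loop and re-adding on the way out:

from jax.experimental.shard_map import shard_map  # or jax.shard_map on newer jax versions

# Untested sketch: run one chain per device under shard_map.
def run_shard(keys, states):
    key = jax.tree.map(lambda x: x[0], keys)      # drop the per-shard leading axis
    state = jax.tree.map(lambda x: x[0], states)
    out = inference_loop(key, kernel, state, num_samples)
    return jax.tree.map(lambda x: x[None], out)   # re-add it so out_specs can stack chains

inference_loop_multiple_chains = jax.jit(
    shard_map(run_shard, mesh=mesh, in_specs=(P('chain'), P('chain')), out_specs=P('chain'))
)

Neither of these is confirmed working yet – they're just where I've got to, so corrections welcome.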

Cheers and thanks,
Theo
