Description
jax.pmap is being deprecated (discussion, docs) in favour of the sharding API.
I read this while flicking through the jax docs and thought about the main use of pmap I know: sampling multiple chains in blackjax/numpyro.
Tagging @dfm because I saw him rewrite something as shard_map elsewhere, and he knows both projects well.
Current behaviour
The current docs for sampling on multiple chains explain how to do this with pmap. I do like pmap; it's clean and well suited to this use case.
Desired behaviour
Sharding is more flexible than pmap, but it is less straightforward here – I haven't quite worked out how to do it yet (see below). Once we have, I'll happily update the docs. I think this will be useful because:
- pmap is being deprecated, and I think we should match the latest jax docs (or at least have sharding there as an alternative). There's a mention of sharding (although the old sharding API) in the numpyro docs for multiple accelerators.
- Pretty much all the examples of sharding in practice are neural network stuff. It would be nice to have another practical example here.
I would also like to add jax.debug.visualize_array_sharding(initial_states_sharded.position["loc"]) to the docs, which I think is a very helpful visualisation from the shard_map notebook.
Progress so far
Running the notebook locally and inspecting pmap_states.position["loc"].sharding, we get NamedSharding(mesh=Mesh('<axis 0x10db00e00>': 8, axis_types=(Auto,)), spec=PartitionSpec('<axis 0x10db00e00>',), memory_kind=device). So let's create a mesh of the same shape, but with a useful axis name, chain.
```python
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1D mesh with one entry per chain, using the axis name 'chain'
mesh = jax.make_mesh((num_chains,), ('chain',))  # or mesh = Mesh(jax.devices(), ('chain',))
spec = P('chain')
named_sharding = NamedSharding(mesh, spec)

# Distribute the initial states across devices along the chain axis
initial_states_sharded = jax.device_put(initial_states, named_sharding)

jax.debug.visualize_array_sharding(initial_states_sharded.position["loc"])  # correctly shows the 8 CPUs I set up
```

The inference loop from earlier in the notebook is defined as
```python
def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    # Scan the (jitted) kernel over the keys, collecting the state at every step
    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)
    return states
```

and pmap is applied as
```python
inference_loop_multiple_chains = jax.pmap(
    inference_loop,
    in_axes=(0, None, 0, None),
    static_broadcasted_argnums=(1, 3),
)
```

I've tried a few things to get jit + shard_map, or just jit, working, but no success so far. I couldn't get enough from the docs to work out how to handle the static arguments and so on. Any ideas?
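For concreteness, here is roughly the shape of what I've been attempting. This is an unverified sketch (run_one_chain is just a name I made up), reusing kernel, num_samples, num_chains, rng_key, mesh, named_sharding, and initial_states_sharded from above: close over the static arguments instead of marking them static, vmap over the chain axis, and let jit partition the computation along the sharded inputs.

```python
# Unverified sketch: replace static_broadcasted_argnums with a closure
def run_one_chain(rng_key, initial_state):
    return inference_loop(rng_key, kernel, initial_state, num_samples)

# vmap over the leading (chain) axis; with inputs sharded along 'chain',
# jit should partition the vmapped computation across the devices in the mesh
inference_loop_multiple_chains = jax.jit(jax.vmap(run_one_chain))

keys = jax.device_put(jax.random.split(rng_key, num_chains), named_sharding)
states = inference_loop_multiple_chains(keys, initial_states_sharded)
```

And the jit + shard_map version I've been trying looks roughly like this (also unverified); since each device sees a slice of size num_chains // num_devices along the leading axis, a vmap still seems to be needed inside:

```python
from jax.experimental.shard_map import shard_map

inference_loop_sharded = jax.jit(
    shard_map(
        jax.vmap(run_one_chain),            # handles the per-device slice of chains
        mesh=mesh,
        in_specs=(P('chain'), P('chain')),  # shard keys and states along 'chain'
        out_specs=P('chain'),
    )
)
states = inference_loop_sharded(keys, initial_states_sharded)
```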
Cheers and thanks,
Theo