OOM in convex_narrowphase and memory pre-allocation in Data #926

@acrlw

Hi team,

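I've been using mujoco_warp for RL training with MyoSim musculoskeletal models. Memory usage stays flat for more than 114M steps (JAX and the Warp mempool each use about 0.2 GiB), and then a step fails with an out-of-memory error inside convex_narrowphase: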

Module mujoco_warp._src.solver 3a2dfc2 load on device 'cuda:0' took 3.03 ms  (cached)
  Step 0: JAX mem used: 0.20 GiB, peak: 0.57 GiB, limit: 18.95 GiB
  Step 0: Warp mempool used: 0.24 GiB, high-water: 5.00 GiB
  Step 4751360: JAX mem used: 0.21 GiB, peak: 2.70 GiB, limit: 18.95 GiB
  Step 4751360: Warp mempool used: 0.24 GiB, high-water: 5.00 GiB
  ...
  Step 114032640: JAX mem used: 0.23 GiB, peak: 2.70 GiB, limit: 18.95 GiB
  Step 114032640: Warp mempool used: 0.24 GiB, high-water: 5.00 GiB
Warp CUDA error 2: out of memory (in function wp_alloc_device_async, .../warp/native/warp.cu:769)
Traceback (most recent call last):
  File ".../site-packages/warp/_src/jax_experimental/ffi.py", line 708, in ffi_callback
    self.func(*arg_list)
  File ".../lib/rl/envs/mjx/mjwarp_env.py", line 141, in mjwarp_step_kernel
    mjwarp.step(m, d)
  File ".../mujoco_warp/_src/warp_util.py", line 99, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../mujoco_warp/_src/forward.py", line 939, in step
    forward(m, d)
  File ".../mujoco_warp/_src/warp_util.py", line 99, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../mujoco_warp/_src/forward.py", line 911, in forward
    fwd_position(m, d, factorize=False)
  File ".../mujoco_warp/_src/warp_util.py", line 99, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../mujoco_warp/_src/forward.py", line 519, in fwd_position
    collision_driver.collision(m, d)
  File ".../mujoco_warp/_src/warp_util.py", line 99, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../mujoco_warp/_src/collision_driver.py", line 729, in collision
    _narrowphase(m, d)
  File ".../mujoco_warp/_src/collision_driver.py", line 693, in _narrowphase
    convex_narrowphase(m, d)
  File ".../mujoco_warp/_src/warp_util.py", line 99, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File ".../mujoco_warp/_src/collision_convex.py", line 998, in convex_narrowphase
    epa_face = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=wp.vec3i)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/warp/_src/context.py", line 5995, in empty
    return warp.array(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/warp/_src/types.py", line 2537, in __init__
    self._init_new(dtype, shape, strides, device, pinned)
  File ".../site-packages/warp/_src/types.py", line 2932, in _init_new
    ptr = allocator.alloc(capacity)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/warp/_src/context.py", line 2904, in alloc
    raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device '{self.device}'")
RuntimeError: Failed to allocate 176160768 bytes on device 'cuda:0'

E1216 11:55:02.916320 2884160 pjrt_stream_executor_client.cc:2085] Execution of replica 0 failed: UNKNOWN: FFI callback error: RuntimeError: Failed to allocate 176160768 bytes on device 'cuda:0'
Traceback (most recent call last):
  File ".../lib/rl/train/mjx/./mjwarp_train.py", line 364, in <module>
    make_inference_fn, params, _ = train_fn(
                                   ^^^^^^^^^
  File ".../site-packages/brax/training/agents/ppo/train.py", line 707, in train
    training_epoch_with_timing(training_state, env_state, epoch_keys)
  File ".../site-packages/brax/training/agents/ppo/train.py", line 580, in training_epoch_with_timing
    result = training_epoch(training_state, env_state, key)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../site-packages/jax/_src/pmap.py", line 72, in wrapped
    outs = jitted_f(*flat_global_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: UNKNOWN: FFI callback error: RuntimeError: Failed to allocate 176160768 bytes on device 'cuda:0'
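The request that fails is the per-step epa_face buffer allocated in convex_narrowphase (the wp.empty call in the traceback), whose size is naconmax × (6 + MJ_MAX_EPAFACES × epa_iterations) × 12 bytes. The sketch below uses hypothetical helper names and assumes Warp requests exactly shape × element size (the logged byte count is divisible by 12, which is consistent with that). It converts the logged number back into an element count:

```python
VEC3I_NBYTES = 3 * 4  # wp.vec3i holds three 32-bit ints


def epa_face_nbytes(naconmax: int, max_epafaces: int, epa_iterations: int) -> int:
  """Bytes for wp.empty((naconmax, 6 + max_epafaces * epa_iterations), dtype=wp.vec3i)."""
  return naconmax * (6 + max_epafaces * epa_iterations) * VEC3I_NBYTES


def vec3i_count(nbytes: int) -> int:
  """Convert a failed request size back into a number of vec3i elements."""
  count, rem = divmod(nbytes, VEC3I_NBYTES)
  if rem:
    raise ValueError(f"{nbytes} bytes is not a whole number of vec3i elements")
  return count


failed = 176160768  # from "Failed to allocate 176160768 bytes" in the traceback
print(vec3i_count(failed), failed / 2**20)  # prints: 14680064 168.0
```

A 168 MiB request is small next to the free VRAM reported by nvtop, so the element count mainly helps confirm which allocation failed; it does not by itself explain the failure.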

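Checking nvtop showed that about 20 GB of VRAM was still free, so the GPU did not appear to actually be full when the 176160768-byte request failed. I worked around the issue by moving the temporary EPA/multiccd workspace allocation out of convex_narrowphase, which runs on every simulation step, and storing it persistently in the Data class.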
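The general idea behind the workaround is "allocate once, reuse every step". The following is a minimal, framework-free sketch of that pattern; PersistentWorkspace is a hypothetical name, and the alloc callback stands in for a device allocator such as wp.empty:

```python
from typing import Any, Callable, Dict, Tuple

Key = Tuple[str, Tuple[int, ...], str]


class PersistentWorkspace:
  """Hypothetical allocate-once buffer cache (not part of mujoco_warp).

  `alloc(shape, dtype)` is called only the first time a given
  (name, shape, dtype) combination is requested; later calls reuse the buffer.
  """

  def __init__(self, alloc: Callable[[Tuple[int, ...], str], Any]):
    self._alloc = alloc
    self._buffers: Dict[Key, Any] = {}

  def get(self, name: str, shape: Tuple[int, ...], dtype: str) -> Any:
    key = (name, tuple(shape), dtype)
    if key not in self._buffers:
      self._buffers[key] = self._alloc(tuple(shape), dtype)
    return self._buffers[key]


# Simulate many steps: only the first request allocates.
allocations = []
ws = PersistentWorkspace(lambda shape, dtype: allocations.append((shape, dtype)) or bytearray(1))
for _ in range(1000):
  ws.get("epa_face", (8, 56), "vec3i")
print(len(allocations))  # prints: 1
```

The patch in this issue uses the eager variant instead (allocating in make_data/put_data), which keeps all allocation outside the FFI callback. A lazy cache like this one would still allocate once inside the first step. A changed shape, such as a new ccd_iterations value, also creates a new buffer without freeing the old one.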

Here are the changes I made (marked with the comment "# + NEW ADD" in the code):

collision_convex.py

@event_scope
def convex_narrowphase(m: Model, d: Data):
  """Runs narrowphase collision detection for convex geom pairs.

  This function handles collision detection for pairs of convex geometries that were
  identified during the broadphase. It uses the Gilbert-Johnson-Keerthi (GJK) algorithm to
  determine the distance between shapes and the Expanding Polytope Algorithm (EPA) to find
  the penetration depth and contact normal for colliding pairs.

  The convex geom types handled by this function are `SPHERE`, `CAPSULE`, `ELLIPSOID`, `CYLINDER`,
  `BOX`, `MESH`, `HFIELD`.

  To optimize performance, this function dynamically builds and launches a specialized
  kernel for each type of convex collision pair present in the model, avoiding unnecessary
  computations for non-existent pair types.
  """

  def _pair_count(p1: int, p2: int) -> int:
    return m.geom_pair_type_count[upper_trid_index(len(GeomType), p1, p2)]

  # no convex collisions, early return
  if not any(_pair_count(g[0].value, g[1].value) for g in _CONVEX_COLLISION_PAIRS):
    return

  epa_iterations = m.opt.ccd_iterations

  # set to true to enable multiccd
  use_multiccd = False

  # + NEW ADD
  epa_ws = d.epa_workspace
  epa_vert = epa_ws.epa_vert
  epa_vert1 = epa_ws.epa_vert1
  epa_vert2 = epa_ws.epa_vert2
  epa_vert_index1 = epa_ws.epa_vert_index1
  epa_vert_index2 = epa_ws.epa_vert_index2
  epa_face = epa_ws.epa_face
  epa_pr = epa_ws.epa_pr
  epa_norm2 = epa_ws.epa_norm2
  epa_index = epa_ws.epa_index
  epa_map = epa_ws.epa_map
  epa_horizon = epa_ws.epa_horizon
  multiccd_polygon = epa_ws.multiccd_polygon
  multiccd_clipped = epa_ws.multiccd_clipped
  multiccd_pnormal = epa_ws.multiccd_pnormal
  multiccd_pdist = epa_ws.multiccd_pdist
  multiccd_idx1 = epa_ws.multiccd_idx1
  multiccd_idx2 = epa_ws.multiccd_idx2
  multiccd_n1 = epa_ws.multiccd_n1
  multiccd_n2 = epa_ws.multiccd_n2
  multiccd_endvert = epa_ws.multiccd_endvert
  multiccd_face1 = epa_ws.multiccd_face1
  multiccd_face2 = epa_ws.multiccd_face2

  for geom_pair in _CONVEX_COLLISION_PAIRS:
    g1 = geom_pair[0].value
    g2 = geom_pair[1].value
    if _pair_count(g1, g2):
      wp.launch(
        ccd_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, use_multiccd),
        dim=d.naconmax,
        inputs=[
          m.opt.ccd_tolerance,
          m.geom_type,
          m.geom_condim,
          m.geom_dataid,
          m.geom_priority,
          m.geom_solmix,
          m.geom_solref,
          m.geom_solimp,
          m.geom_size,
          m.geom_aabb,
          m.geom_rbound,
          m.geom_friction,
          m.geom_margin,
          m.geom_gap,
          m.hfield_adr,
          m.hfield_nrow,
          m.hfield_ncol,
          m.hfield_size,
          m.hfield_data,
          m.mesh_vertadr,
          m.mesh_vertnum,
          m.mesh_vert,
          m.mesh_graphadr,
          m.mesh_graph,
          m.mesh_polynum,
          m.mesh_polyadr,
          m.mesh_polynormal,
          m.mesh_polyvertadr,
          m.mesh_polyvertnum,
          m.mesh_polyvert,
          m.mesh_polymapadr,
          m.mesh_polymapnum,
          m.mesh_polymap,
          m.pair_dim,
          m.pair_solref,
          m.pair_solreffriction,
          m.pair_solimp,
          m.pair_margin,
          m.pair_gap,
          m.pair_friction,
          d.naconmax,
          d.geom_xpos,
          d.geom_xmat,
          d.collision_pair,
          d.collision_pairid,
          d.collision_worldid,
          d.ncollision,
          epa_vert,
          epa_vert1,
          epa_vert2,
          epa_vert_index1,
          epa_vert_index2,
          epa_face,
          epa_pr,
          epa_norm2,
          epa_index,
          epa_map,
          epa_horizon,
          multiccd_polygon,
          multiccd_clipped,
          multiccd_pnormal,
          multiccd_pdist,
          multiccd_idx1,
          multiccd_idx2,
          multiccd_n1,
          multiccd_n2,
          multiccd_endvert,
          multiccd_face1,
          multiccd_face2,
        ],
        outputs=[
          d.nacon,
          d.contact.dist,
          d.contact.pos,
          d.contact.frame,
          d.contact.includemargin,
          d.contact.friction,
          d.contact.solref,
          d.contact.solreffriction,
          d.contact.solimp,
          d.contact.dim,
          d.contact.geom,
          d.contact.worldid,
          d.contact.type,
          d.contact.geomcollisionid,
        ],
      )
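With a persistent workspace, there is a new coupling to keep in mind. The buffers are sized once from mjm.opt.ccd_iterations in make_data/put_data, but convex_narrowphase reads m.opt.ccd_iterations on every call. Changing ccd_iterations after creating Data could therefore leave the buffers too small for the kernels. The following is a hypothetical host-side guard (check_epa_workspace is not part of mujoco_warp) that compares the preallocated shapes against what the current options require:

```python
from typing import Sequence


def check_epa_workspace(
  face_shape: Sequence[int],
  vert_shape: Sequence[int],
  naconmax: int,
  ccd_iterations: int,
  max_epafaces: int,
) -> None:
  """Raise if the preallocated EPA buffers are smaller than the current options need.

  Hypothetical guard; the size formulas mirror the ones used in make_data/put_data.
  """
  need_face = (naconmax, 6 + max_epafaces * ccd_iterations)
  need_vert = (naconmax, 5 + ccd_iterations)
  for name, have, need in (("epa_face", face_shape, need_face), ("epa_vert", vert_shape, need_vert)):
    if len(have) != 2 or have[0] < need[0] or have[1] < need[1]:
      raise ValueError(
        f"{name} workspace has shape {tuple(have)}, need at least {need}; "
        "recreate Data after changing ccd_iterations"
      )


# Possible call site in convex_narrowphase, before the launch loop
# (assumes the field names introduced by the patch):
#   check_epa_workspace(d.epa_workspace.epa_face.shape, d.epa_workspace.epa_vert.shape,
#                       d.naconmax, m.opt.ccd_iterations, MJ_MAX_EPAFACES)
```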

io.py

def make_data(
  mjm: mujoco.MjModel,
  nworld: int = 1,
  nconmax: Optional[int] = None,
  njmax: Optional[int] = None,
  naconmax: Optional[int] = None,
) -> types.Data:
  """Creates a data object on device.

  Args:
    mjm: The model containing kinematic and dynamic information (host).
    nworld: Number of worlds.
    nconmax: Number of contacts to allocate per world. Contacts exist in large
             heterogeneous arrays: one world may have more than nconmax contacts.
    njmax: Number of constraints to allocate per world. Constraint arrays are
           batched by world: no world may have more than njmax constraints.
    naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.

  Returns:
    The data object containing the current state and output arrays (device).
  """
  # TODO(team): move nconmax, njmax to Model?
  # TODO(team): improve heuristic for nconmax and njmax
  nconmax = nconmax or 20
  njmax = njmax or nconmax * 6

  if nworld < 1:
    raise ValueError(f"nworld must be >= 1")

  if naconmax is None:
    if nconmax < 0:
      raise ValueError("nconmax must be >= 0")
    naconmax = max(512, nworld * nconmax)
  elif naconmax < 0:
    raise ValueError("naconmax must be >= 0")

  if njmax < 0:
    raise ValueError("njmax must be >= 0")

  sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int})
  sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
  sizes["nmaxpyramid"] = np.maximum(1, 2 * (sizes["nmaxcondim"] - 1))
  tile_size = types.TILE_SIZE_JTDAJ_SPARSE if is_sparse(mjm) else types.TILE_SIZE_JTDAJ_DENSE
  sizes["njmax_pad"], sizes["nv_pad"] = _get_padded_sizes(mjm.nv, njmax, is_sparse(mjm), tile_size)
  sizes["nworld"] = nworld
  sizes["naconmax"] = naconmax
  sizes["njmax"] = njmax

  contact = types.Contact(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Contact)})
  efc = types.Constraint(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Constraint)})

  # world body and static geom (attached to the world) poses are precomputed
  # this speeds up scenes with many static geoms (e.g. terrains)
  # TODO(team): remove this when we introduce dof islands + sleeping
  mjd = mujoco.MjData(mjm)
  mujoco.mj_kinematics(mjm, mjd)

  # mocap
  mocap_body = np.nonzero(mjm.body_mocapid >= 0)[0]
  mocap_id = mjm.body_mocapid[mocap_body]

  d_kwargs = {
    "qpos": wp.array(np.tile(mjm.qpos0, nworld), shape=(nworld, mjm.nq), dtype=float),
    "contact": contact,
    "efc": efc,
    "nworld": nworld,
    "naconmax": naconmax,
    "njmax": njmax,
    "qM": None,
    "qLD": None,
    "epa_workspace": None,  # + NEW ADD
    # world body
    "xquat": wp.array(np.tile(mjd.xquat, (nworld, 1)), shape=(nworld, mjm.nbody), dtype=wp.quat),
    "xmat": wp.array(np.tile(mjd.xmat, (nworld, 1)), shape=(nworld, mjm.nbody), dtype=wp.mat33),
    "ximat": wp.array(np.tile(mjd.ximat, (nworld, 1)), shape=(nworld, mjm.nbody), dtype=wp.mat33),
    # static geoms
    "geom_xpos": wp.array(np.tile(mjd.geom_xpos, (nworld, 1)), shape=(nworld, mjm.ngeom), dtype=wp.vec3),
    "geom_xmat": wp.array(np.tile(mjd.geom_xmat, (nworld, 1)), shape=(nworld, mjm.ngeom), dtype=wp.mat33),
    # mocap
    "mocap_pos": wp.array(np.tile(mjm.body_pos[mocap_body[mocap_id]], (nworld, 1)), shape=(nworld, mjm.nmocap), dtype=wp.vec3),
    "mocap_quat": wp.array(
      np.tile(mjm.body_quat[mocap_body[mocap_id]], (nworld, 1)), shape=(nworld, mjm.nmocap), dtype=wp.quat
    ),
    # equality constraints
    "eq_active": wp.array(np.tile(mjm.eq_active0.astype(bool), (nworld, 1)), shape=(nworld, mjm.neq), dtype=bool),
  }
  for f in dataclasses.fields(types.Data):
    if f.name in d_kwargs:
      continue
    d_kwargs[f.name] = _create_array(None, f.type, sizes)

  d = types.Data(**d_kwargs)

  if is_sparse(mjm):
    d.qM = wp.zeros((nworld, 1, mjm.nM), dtype=float)
    d.qLD = wp.zeros((nworld, 1, mjm.nC), dtype=float)
  else:
    d.qM = wp.zeros((nworld, sizes["nv_pad"], sizes["nv_pad"]), dtype=float)
    d.qLD = wp.zeros((nworld, mjm.nv, mjm.nv), dtype=float)

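  # + NEW ADD
  # sized from the host model; should match the m.opt.ccd_iterations that
  # convex_narrowphase uses at runtime
  ccd_iterations = mjm.opt.ccd_iterations
  epa_vert_size = 5 + ccd_iterations
  epa_face_size = 6 + types.MJ_MAX_EPAFACES * ccd_iterations
  # multiccd is disabled in convex_narrowphase (use_multiccd = False), so the
  # multiccd buffers only need a placeholder width of 1
  nmaxpolygon = 0
  nmaxmeshdeg = 0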

  d.epa_workspace = types.EpaWorkspace(
    epa_vert=wp.empty(shape=(naconmax, epa_vert_size), dtype=wp.vec3),
    epa_vert1=wp.empty(shape=(naconmax, epa_vert_size), dtype=wp.vec3),
    epa_vert2=wp.empty(shape=(naconmax, epa_vert_size), dtype=wp.vec3),
    epa_vert_index1=wp.empty(shape=(naconmax, epa_vert_size), dtype=int),
    epa_vert_index2=wp.empty(shape=(naconmax, epa_vert_size), dtype=int),
    epa_face=wp.empty(shape=(naconmax, epa_face_size), dtype=wp.vec3i),
    epa_pr=wp.empty(shape=(naconmax, epa_face_size), dtype=wp.vec3),
    epa_norm2=wp.empty(shape=(naconmax, epa_face_size), dtype=float),
    epa_index=wp.empty(shape=(naconmax, epa_face_size), dtype=int),
    epa_map=wp.empty(shape=(naconmax, epa_face_size), dtype=int),
    epa_horizon=wp.empty(shape=(naconmax, 2 * types.MJ_MAX_EPAHORIZON), dtype=int),
    multiccd_polygon=wp.empty(shape=(naconmax, max(1, 2 * nmaxpolygon)), dtype=wp.vec3),
    multiccd_clipped=wp.empty(shape=(naconmax, max(1, 2 * nmaxpolygon)), dtype=wp.vec3),
    multiccd_pnormal=wp.empty(shape=(naconmax, max(1, nmaxpolygon)), dtype=wp.vec3),
    multiccd_pdist=wp.empty(shape=(naconmax, max(1, nmaxpolygon)), dtype=float),
    multiccd_idx1=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=int),
    multiccd_idx2=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=int),
    multiccd_n1=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=wp.vec3),
    multiccd_n2=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=wp.vec3),
    multiccd_endvert=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=wp.vec3),
    multiccd_face1=wp.empty(shape=(naconmax, max(1, nmaxpolygon)), dtype=wp.vec3),
    multiccd_face2=wp.empty(shape=(naconmax, max(1, nmaxpolygon)), dtype=wp.vec3),
  )

  return d
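In make_data (and likewise put_data), the "epa_workspace": None entry keeps the generic field loop, which calls _create_array for every remaining Data field, from trying to build the nested EpaWorkspace. This is presumably because _create_array is written for array-typed fields. The workspace is assigned after Data(**d_kwargs), so types.Data must also declare an epa_workspace field; otherwise the dataclass constructor raises TypeError on the unexpected keyword. The sketch below mimics the pattern with stand-in classes and a hypothetical fill_missing_fields helper:

```python
import dataclasses
from typing import Any, Callable, Dict


def fill_missing_fields(cls: type, explicit: Dict[str, Any], create: Callable[[Any], Any]) -> Dict[str, Any]:
  """Mirror of the make_data/put_data field loop (hypothetical helper).

  Fields already present in `explicit` are skipped even when their value is
  None, so a nested dataclass can be attached after construction.
  """
  kwargs = dict(explicit)
  for f in dataclasses.fields(cls):
    if f.name in kwargs:  # membership test: a None placeholder counts as provided
      continue
    kwargs[f.name] = create(f.type)
  return kwargs


@dataclasses.dataclass
class Workspace:  # stand-in for types.EpaWorkspace
  epa_face: list


@dataclasses.dataclass
class Data:  # stand-in for types.Data, including the extra epa_workspace field
  qpos: list
  epa_workspace: Workspace


created = []
kwargs = fill_missing_fields(Data, {"epa_workspace": None}, lambda t: created.append(t) or [])
d = Data(**kwargs)
d.epa_workspace = Workspace(epa_face=[])  # assigned after construction, as in the patch
print(created)  # prints: [<class 'list'>]
```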


def put_data(
  mjm: mujoco.MjModel,
  mjd: mujoco.MjData,
  nworld: int = 1,
  nconmax: Optional[int] = None,
  njmax: Optional[int] = None,
  naconmax: Optional[int] = None,
) -> types.Data:
  """Moves data from host to a device.

  Args:
    mjm: The model containing kinematic and dynamic information (host).
    mjd: The data object containing current state and output arrays (host).
    nworld: The number of worlds.
    nconmax: Number of contacts to allocate per world. Contacts exist in large
             heterogeneous arrays: one world may have more than nconmax contacts.
    njmax: Number of constraints to allocate per world. Constraint arrays are
           batched by world: no world may have more than njmax constraints.
    naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.

  Returns:
    The data object containing the current state and output arrays (device).
  """
  # TODO(team): move nconmax and njmax to Model?
  # TODO(team): decide what to do about uninitialized warp-only fields created by put_data
  #             we need to ensure these are only workspace fields and don't carry state

  # TODO(team): better heuristic for nconmax and njmax
  nconmax = nconmax or max(5, 4 * mjd.ncon)
  njmax = njmax or max(5, 4 * mjd.nefc)

  if nworld < 1:
    raise ValueError(f"nworld must be >= 1")

  if naconmax is None:
    if nconmax < 0:
      raise ValueError("nconmax must be >= 0")

    if mjd.ncon > nconmax:
      raise ValueError(f"nconmax overflow (nconmax must be >= {mjd.ncon})")

    naconmax = max(512, nworld * nconmax)
  elif naconmax < mjd.ncon * nworld:
    raise ValueError(f"naconmax overflow (naconmax must be >= {mjd.ncon * nworld})")

  if njmax < 0:
    raise ValueError("njmax must be >= 0")

  if mjd.nefc > njmax:
    raise ValueError(f"njmax overflow (njmax must be >= {mjd.nefc})")

  sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int})
  sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
  sizes["nmaxpyramid"] = np.maximum(1, 2 * (sizes["nmaxcondim"] - 1))
  tile_size = types.TILE_SIZE_JTDAJ_SPARSE if is_sparse(mjm) else types.TILE_SIZE_JTDAJ_DENSE
  sizes["njmax_pad"], sizes["nv_pad"] = _get_padded_sizes(mjm.nv, njmax, is_sparse(mjm), tile_size)
  sizes["nworld"] = nworld
  sizes["naconmax"] = naconmax
  sizes["njmax"] = njmax

  # ensure static geom positions are computed
  # TODO: remove once MjData creation semantics are fixed
  mujoco.mj_kinematics(mjm, mjd)

  # create contact
  contact_kwargs = {"efc_address": None, "worldid": None, "type": None, "geomcollisionid": None}
  for f in dataclasses.fields(types.Contact):
    if f.name in contact_kwargs:
      continue
    val = getattr(mjd.contact, f.name)
    val = np.repeat(val, nworld, axis=0)
    width = ((0, naconmax - val.shape[0]),) + ((0, 0),) * (val.ndim - 1)
    val = np.pad(val, width)
    contact_kwargs[f.name] = _create_array(val, f.type, sizes)

  contact = types.Contact(**contact_kwargs)

  contact.efc_address = np.zeros((naconmax, sizes["nmaxpyramid"]), dtype=int)
  for i in range(mjd.ncon):
    efc_address = mjd.contact.efc_address[i]
    if efc_address == -1:
      continue
    condim = mjd.contact.dim[i]
    ndim = max(1, 2 * (condim - 1)) if mjm.opt.cone == mujoco.mjtCone.mjCONE_PYRAMIDAL else condim
    for j in range(nworld):
      contact.efc_address[j * mjd.ncon + i, :ndim] = efc_address + np.arange(ndim)

  contact.efc_address = wp.array(contact.efc_address, dtype=int)
  contact.worldid = np.pad(np.repeat(np.arange(nworld), mjd.ncon), (0, naconmax - nworld * mjd.ncon))
  contact.worldid = wp.array(contact.worldid, dtype=int)
  contact.type = wp.ones((naconmax,), dtype=int)  # TODO(team): set values
  contact.geomcollisionid = wp.empty((naconmax,), dtype=int)  # TODO(team): set values

  # create efc
  efc_kwargs = {"J": None}

  for f in dataclasses.fields(types.Constraint):
    if f.name in efc_kwargs:
      continue
    shape = tuple(sizes[dim] if isinstance(dim, str) else dim for dim in f.type.shape)
    val = np.zeros(shape, dtype=f.type.dtype)
    if f.name in ("type", "id", "pos", "margin", "D", "vel", "aref", "frictionloss", "force"):
      val[:, : mjd.nefc] = np.tile(getattr(mjd, "efc_" + f.name), (nworld, 1))
    efc_kwargs[f.name] = wp.array(val, dtype=f.type.dtype)

  efc = types.Constraint(**efc_kwargs)

  if mujoco.mj_isSparse(mjm):
    efc_j = np.zeros((mjd.nefc, mjm.nv))
    mujoco.mju_sparse2dense(efc_j, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, mjd.efc_J_colind)
  else:
    efc_j = mjd.efc_J.reshape((mjd.nefc, mjm.nv))
  efc.J = np.zeros((nworld, sizes["njmax_pad"], sizes["nv_pad"]), dtype=f.type.dtype)
  efc.J[:, : mjd.nefc, : mjm.nv] = np.tile(efc_j, (nworld, 1, 1))
  efc.J = wp.array(efc.J, dtype=float)

  # create data
  d_kwargs = {
    "contact": contact,
    "efc": efc,
    "nworld": nworld,
    "naconmax": naconmax,
    "njmax": njmax,
    # fields set after initialization:
    "solver_niter": None,
    "qM": None,
    "qLD": None,
    "ten_J": None,
    "actuator_moment": None,
    "flexedge_J": None,
    "nacon": None,
    "ne_connect": None,
    "ne_weld": None,
    "ne_jnt": None,
    "ne_ten": None,
    "ne_flex": None,
    "nsolving": None,
    "epa_workspace": None,  # + NEW ADD
  }
  for f in dataclasses.fields(types.Data):
    if f.name in d_kwargs:
      continue
    val = getattr(mjd, f.name, None)
    if val is not None:
      shape = val.shape if hasattr(val, "shape") else ()
      val = np.full((nworld,) + shape, val)
    d_kwargs[f.name] = _create_array(val, f.type, sizes)

  d = types.Data(**d_kwargs)
  d.solver_niter = wp.full((nworld,), mjd.solver_niter[0], dtype=int)

  if is_sparse(mjm):
    d.qM = wp.array(np.full((nworld, 1, mjm.nM), mjd.qM), dtype=float)
    d.qLD = wp.array(np.full((nworld, 1, mjm.nC), mjd.qLD), dtype=float)
  else:
    qM = np.zeros((mjm.nv, mjm.nv))
    mujoco.mj_fullM(mjm, qM, mjd.qM)
    qLD = np.linalg.cholesky(qM) if (mjd.qM != 0.0).any() and (mjd.qLD != 0.0).any() else np.zeros((mjm.nv, mjm.nv))
    padding = sizes["nv_pad"] - mjm.nv
    qM_padded = np.pad(qM, ((0, padding), (0, padding)), mode="constant", constant_values=0.0)
    d.qM = wp.array(np.full((nworld, sizes["nv_pad"], sizes["nv_pad"]), qM_padded), dtype=float)
    d.qLD = wp.array(np.full((nworld, mjm.nv, mjm.nv), qLD), dtype=float)

  if mujoco.mj_isSparse(mjm):
    ten_J = np.zeros((mjm.ntendon, mjm.nv))
    mujoco.mju_sparse2dense(ten_J, mjd.ten_J.reshape(-1), mjd.ten_J_rownnz, mjd.ten_J_rowadr, mjd.ten_J_colind.reshape(-1))
    d.ten_J = wp.array(np.full((nworld, mjm.ntendon, mjm.nv), ten_J), dtype=float)
    flexedge_J = np.zeros((mjm.nflexedge, mjm.nv))
    mujoco.mju_sparse2dense(
      flexedge_J, mjd.flexedge_J.reshape(-1), mjd.flexedge_J_rownnz, mjd.flexedge_J_rowadr, mjd.flexedge_J_colind.reshape(-1)
    )
    d.flexedge_J = wp.array(np.full((nworld, mjm.nflexedge, mjm.nv), flexedge_J), dtype=float)
  else:
    ten_J = mjd.ten_J.reshape((mjm.ntendon, mjm.nv))
    d.ten_J = wp.array(np.full((nworld, mjm.ntendon, mjm.nv), ten_J), dtype=float)
    flexedge_J = mjd.flexedge_J.reshape((mjm.nflexedge, mjm.nv))
    d.flexedge_J = wp.array(np.full((nworld, mjm.nflexedge, mjm.nv), flexedge_J), dtype=float)

  # TODO(taylorhowell): sparse actuator_moment
  actuator_moment = np.zeros((mjm.nu, mjm.nv))
  mujoco.mju_sparse2dense(actuator_moment, mjd.actuator_moment, mjd.moment_rownnz, mjd.moment_rowadr, mjd.moment_colind)
  d.actuator_moment = wp.array(np.full((nworld, mjm.nu, mjm.nv), actuator_moment), dtype=float)

  d.nacon = wp.array([mjd.ncon * nworld], dtype=int)
  d.ne_connect = wp.full(nworld, 3 * np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_CONNECT) & mjd.eq_active), dtype=int)
  d.ne_weld = wp.full(nworld, 6 * np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_WELD) & mjd.eq_active), dtype=int)
  d.ne_jnt = wp.full(nworld, np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_JOINT) & mjd.eq_active), dtype=int)
  d.ne_ten = wp.full(nworld, np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_TENDON) & mjd.eq_active), dtype=int)
  d.ne_flex = wp.full(nworld, np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_FLEX) & mjd.eq_active), dtype=int)
  d.nsolving = wp.array([nworld], dtype=int)

  # + NEW ADD
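  # sized from the host model; should match the m.opt.ccd_iterations that
  # convex_narrowphase uses at runtime
  ccd_iterations = mjm.opt.ccd_iterations
  epa_vert_size = 5 + ccd_iterations
  epa_face_size = 6 + types.MJ_MAX_EPAFACES * ccd_iterations
  # multiccd is disabled in convex_narrowphase (use_multiccd = False), so the
  # multiccd buffers only need a placeholder width of 1
  nmaxpolygon = 0
  nmaxmeshdeg = 0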

  d.epa_workspace = types.EpaWorkspace(
    epa_vert=wp.empty(shape=(naconmax, epa_vert_size), dtype=wp.vec3),
    epa_vert1=wp.empty(shape=(naconmax, epa_vert_size), dtype=wp.vec3),
    epa_vert2=wp.empty(shape=(naconmax, epa_vert_size), dtype=wp.vec3),
    epa_vert_index1=wp.empty(shape=(naconmax, epa_vert_size), dtype=int),
    epa_vert_index2=wp.empty(shape=(naconmax, epa_vert_size), dtype=int),
    epa_face=wp.empty(shape=(naconmax, epa_face_size), dtype=wp.vec3i),
    epa_pr=wp.empty(shape=(naconmax, epa_face_size), dtype=wp.vec3),
    epa_norm2=wp.empty(shape=(naconmax, epa_face_size), dtype=float),
    epa_index=wp.empty(shape=(naconmax, epa_face_size), dtype=int),
    epa_map=wp.empty(shape=(naconmax, epa_face_size), dtype=int),
    epa_horizon=wp.empty(shape=(naconmax, 2 * types.MJ_MAX_EPAHORIZON), dtype=int),
    multiccd_polygon=wp.empty(shape=(naconmax, max(1, 2 * nmaxpolygon)), dtype=wp.vec3),
    multiccd_clipped=wp.empty(shape=(naconmax, max(1, 2 * nmaxpolygon)), dtype=wp.vec3),
    multiccd_pnormal=wp.empty(shape=(naconmax, max(1, nmaxpolygon)), dtype=wp.vec3),
    multiccd_pdist=wp.empty(shape=(naconmax, max(1, nmaxpolygon)), dtype=float),
    multiccd_idx1=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=int),
    multiccd_idx2=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=int),
    multiccd_n1=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=wp.vec3),
    multiccd_n2=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=wp.vec3),
    multiccd_endvert=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=wp.vec3),
    multiccd_face1=wp.empty(shape=(naconmax, max(1, nmaxpolygon)), dtype=wp.vec3),
    multiccd_face2=wp.empty(shape=(naconmax, max(1, nmaxpolygon)), dtype=wp.vec3),
  )

  return d
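The EpaWorkspace construction is duplicated verbatim in make_data and put_data. One way to keep the two in sync is a single shape table that both functions consume. The sketch below is a hypothetical helper (epa_workspace_spec and workspace_nbytes are illustrative names, not mujoco_warp API). The shapes and dtypes are copied from the patch, and the byte sizes assume Warp's default 32-bit int and float:

```python
from typing import Dict, Tuple

Spec = Dict[str, Tuple[Tuple[int, int], str]]

# Bytes per element for each dtype tag (wp.vec3/wp.vec3i are three 32-bit values).
ELEM_NBYTES = {"vec3": 12, "vec3i": 12, "int": 4, "float": 4}


def epa_workspace_spec(
  naconmax: int,
  ccd_iterations: int,
  max_epafaces: int,
  max_epahorizon: int,
  nmaxpolygon: int = 0,
  nmaxmeshdeg: int = 0,
) -> Spec:
  """Hypothetical single source of truth for the EpaWorkspace shapes."""
  vert = 5 + ccd_iterations
  face = 6 + max_epafaces * ccd_iterations
  poly, poly2, deg = max(1, nmaxpolygon), max(1, 2 * nmaxpolygon), max(1, nmaxmeshdeg)
  cols = {
    "epa_vert": (vert, "vec3"), "epa_vert1": (vert, "vec3"), "epa_vert2": (vert, "vec3"),
    "epa_vert_index1": (vert, "int"), "epa_vert_index2": (vert, "int"),
    "epa_face": (face, "vec3i"), "epa_pr": (face, "vec3"), "epa_norm2": (face, "float"),
    "epa_index": (face, "int"), "epa_map": (face, "int"),
    "epa_horizon": (2 * max_epahorizon, "int"),
    "multiccd_polygon": (poly2, "vec3"), "multiccd_clipped": (poly2, "vec3"),
    "multiccd_pnormal": (poly, "vec3"), "multiccd_pdist": (poly, "float"),
    "multiccd_idx1": (deg, "int"), "multiccd_idx2": (deg, "int"),
    "multiccd_n1": (deg, "vec3"), "multiccd_n2": (deg, "vec3"), "multiccd_endvert": (deg, "vec3"),
    "multiccd_face1": (poly, "vec3"), "multiccd_face2": (poly, "vec3"),
  }
  return {name: ((naconmax, ncol), tag) for name, (ncol, tag) in cols.items()}


def workspace_nbytes(spec: Spec) -> int:
  """Total device memory the workspace occupies, in bytes."""
  return sum(rows * ncol * ELEM_NBYTES[tag] for (rows, ncol), tag in spec.values())


# make_data and put_data could then share one construction, e.g.
#   WP_DTYPES = {"vec3": wp.vec3, "vec3i": wp.vec3i, "int": int, "float": float}
#   spec = epa_workspace_spec(naconmax, mjm.opt.ccd_iterations,
#                             types.MJ_MAX_EPAFACES, types.MJ_MAX_EPAHORIZON)
#   d.epa_workspace = types.EpaWorkspace(
#     **{name: wp.empty(shape=shape, dtype=WP_DTYPES[tag]) for name, (shape, tag) in spec.items()})
```

workspace_nbytes also makes the persistent cost visible up front. This memory is now held for the lifetime of Data instead of being requested on every step.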

types.py

@dataclasses.dataclass
class EpaWorkspace:
  epa_vert: wp.array2d(dtype=wp.vec3)
  epa_vert1: wp.array2d(dtype=wp.vec3)
  epa_vert2: wp.array2d(dtype=wp.vec3)
  epa_vert_index1: wp.array2d(dtype=int)
  epa_vert_index2: wp.array2d(dtype=int)
  epa_face: wp.array2d(dtype=wp.vec3i)
  epa_pr: wp.array2d(dtype=wp.vec3)
  epa_norm2: wp.array2d(dtype=float)
  epa_index: wp.array2d(dtype=int)
  epa_map: wp.array2d(dtype=int)
  epa_horizon: wp.array2d(dtype=int)
  multiccd_polygon: wp.array2d(dtype=wp.vec3)
  multiccd_clipped: wp.array2d(dtype=wp.vec3)
  multiccd_pnormal: wp.array2d(dtype=wp.vec3)
  multiccd_pdist: wp.array2d(dtype=float)
  multiccd_idx1: wp.array2d(dtype=int)
  multiccd_idx2: wp.array2d(dtype=int)
  multiccd_n1: wp.array2d(dtype=wp.vec3)
  multiccd_n2: wp.array2d(dtype=wp.vec3)
  multiccd_endvert: wp.array2d(dtype=wp.vec3)
  multiccd_face1: wp.array2d(dtype=wp.vec3)
  multiccd_face2: wp.array2d(dtype=wp.vec3)

With these changes, training runs stably without OOM. Are these changes reasonable? Any feedback on the implementation would be appreciated!

Thanks!