-
Notifications
You must be signed in to change notification settings - Fork 109
Open
Labels
Description
Hi team,
I've been using mujoco_warp for RL training with MyoSim musculoskeletal models.
Module mujoco_warp._src.solver 3a2dfc2 load on device 'cuda:0' took 3.03 ms (cached)
Step 0: JAX mem used: 0.20 GiB, peak: 0.57 GiB, limit: 18.95 GiB
Step 0: Warp mempool used: 0.24 GiB, high-water: 5.00 GiB
Step 4751360: JAX mem used: 0.21 GiB, peak: 2.70 GiB, limit: 18.95 GiB
Step 4751360: Warp mempool used: 0.24 GiB, high-water: 5.00 GiB
...
Step 114032640: JAX mem used: 0.23 GiB, peak: 2.70 GiB, limit: 18.95 GiB
Step 114032640: Warp mempool used: 0.24 GiB, high-water: 5.00 GiB
Warp CUDA error 2: out of memory (in function wp_alloc_device_async, .../warp/native/warp.cu:769)
Traceback (most recent call last):
File ".../site-packages/warp/_src/jax_experimental/ffi.py", line 708, in ffi_callback
self.func(*arg_list)
File ".../lib/rl/envs/mjx/mjwarp_env.py", line 141, in mjwarp_step_kernel
mjwarp.step(m, d)
File ".../mujoco_warp/_src/warp_util.py", line 99, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File ".../mujoco_warp/_src/forward.py", line 939, in step
forward(m, d)
File ".../mujoco_warp/_src/warp_util.py", line 99, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File ".../mujoco_warp/_src/forward.py", line 911, in forward
fwd_position(m, d, factorize=False)
File ".../mujoco_warp/_src/warp_util.py", line 99, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File ".../mujoco_warp/_src/forward.py", line 519, in fwd_position
collision_driver.collision(m, d)
File ".../mujoco_warp/_src/warp_util.py", line 99, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File ".../mujoco_warp/_src/collision_driver.py", line 729, in collision
_narrowphase(m, d)
File ".../mujoco_warp/_src/collision_driver.py", line 693, in _narrowphase
convex_narrowphase(m, d)
File ".../mujoco_warp/_src/warp_util.py", line 99, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File ".../mujoco_warp/_src/collision_convex.py", line 998, in convex_narrowphase
epa_face = wp.empty(shape=(d.naconmax, 6 + MJ_MAX_EPAFACES * epa_iterations), dtype=wp.vec3i)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../site-packages/warp/_src/context.py", line 5995, in empty
return warp.array(shape=shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../site-packages/warp/_src/types.py", line 2537, in __init__
self._init_new(dtype, shape, strides, device, pinned)
File ".../site-packages/warp/_src/types.py", line 2932, in _init_new
ptr = allocator.alloc(capacity)
^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../site-packages/warp/_src/context.py", line 2904, in alloc
raise RuntimeError(f"Failed to allocate {size_in_bytes} bytes on device '{self.device}'")
RuntimeError: Failed to allocate 176160768 bytes on device 'cuda:0'
E1216 11:55:02.916320 2884160 pjrt_stream_executor_client.cc:2085] Execution of replica 0 failed: UNKNOWN: FFI callback error: RuntimeError: Failed to allocate 176160768 bytes on device 'cuda:0'
Traceback (most recent call last):
File ".../lib/rl/train/mjx/./mjwarp_train.py", line 364, in <module>
make_inference_fn, params, _ = train_fn(
^^^^^^^^^
File ".../site-packages/brax/training/agents/ppo/train.py", line 707, in train
training_epoch_with_timing(training_state, env_state, epoch_keys)
File ".../site-packages/brax/training/agents/ppo/train.py", line 580, in training_epoch_with_timing
result = training_epoch(training_state, env_state, key)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../site-packages/jax/_src/pmap.py", line 72, in wrapped
outs = jitted_f(*flat_global_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: UNKNOWN: FFI callback error: RuntimeError: Failed to allocate 176160768 bytes on device 'cuda:0'
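Checking nvtop showed that there was still ~20GB of VRAM available, so the failure seems to come from the repeated per-step allocation of temporary buffers (here a ~176 MB `epa_face` array) rather than from genuinely running out of device memory. I managed to solve this issue by moving the temporary EPA/multi-CCD workspace allocation out of the per-step `convex_narrowphase` call and persisting it in the `Data` class.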
Here are the changes I made (marked with the comment "# + NEW ADD" in the code):
collision_convex.py
def _check_epa_workspace(epa_ws, naconmax: int, epa_iterations: int):
  """Validates that the persistent EPA workspace can serve this narrowphase launch.

  Every workspace array is allocated in make_data / put_data with one row per
  contact slot and, for the EPA vertex arrays, ``5 + ccd_iterations`` columns,
  where ``ccd_iterations`` is read from the host model at Data creation time.
  The ccd kernel is built here from ``m.opt.ccd_iterations``. If the two
  disagree, the scratch arrays would be undersized, so fail loudly instead.

  Args:
    epa_ws: The ``d.epa_workspace`` object (an ``EpaWorkspace``), or None.
    naconmax: Number of contact slots the kernel is launched over.
    epa_iterations: EPA iteration count the kernel is built with.

  Raises:
    ValueError: if the workspace is missing, has fewer than ``naconmax`` rows,
      or was allocated for fewer EPA iterations than ``epa_iterations``.
  """
  if epa_ws is None:
    raise ValueError("d.epa_workspace is not allocated; create Data with make_data or put_data")
  nrow, vert_width = epa_ws.epa_vert.shape
  if nrow < naconmax:
    raise ValueError(f"d.epa_workspace has {nrow} rows but naconmax is {naconmax}; recreate Data")
  if vert_width < 5 + epa_iterations:
    raise ValueError(
      f"d.epa_workspace was allocated for ccd_iterations={vert_width - 5} but "
      f"m.opt.ccd_iterations={epa_iterations}; recreate Data after changing ccd_iterations"
    )


@event_scope
def convex_narrowphase(m: Model, d: Data):
  """Runs narrowphase collision detection for convex geom pairs.

  This function handles collision detection for pairs of convex geometries that were
  identified during the broadphase. It uses the Gilbert-Johnson-Keerthi (GJK) algorithm to
  determine the distance between shapes and the Expanding Polytope Algorithm (EPA) to find
  the penetration depth and contact normal for colliding pairs.

  The convex geom types handled by this function are `SPHERE`, `CAPSULE`, `ELLIPSOID`, `CYLINDER`,
  `BOX`, `MESH`, `HFIELD`.

  To optimize performance, this function dynamically builds and launches a specialized
  kernel for each type of convex collision pair present in the model, avoiding unnecessary
  computations for non-existent pair types.

  The GJK/EPA and multi-CCD scratch buffers are not allocated here. They come from the
  persistent ``d.epa_workspace`` created together with ``d``, so no device memory is
  allocated on every step. That workspace is validated against the current model before
  any kernel is launched.

  Args:
    m: Device model.
    d: Device data; contact outputs are written into ``d.nacon`` and ``d.contact``.

  Raises:
    ValueError: if ``d.epa_workspace`` is missing or too small for ``d.naconmax`` /
      ``m.opt.ccd_iterations``.
  """

  def _pair_count(p1: int, p2: int) -> int:
    return m.geom_pair_type_count[upper_trid_index(len(GeomType), p1, p2)]

  # no convex collisions, early return
  if not any(_pair_count(g[0].value, g[1].value) for g in _CONVEX_COLLISION_PAIRS):
    return

  epa_iterations = m.opt.ccd_iterations

  # set to true to enable multiccd
  # NOTE(review): make_data / put_data allocate the multiccd_* buffers with width 1
  # (nmaxpolygon = nmaxmeshdeg = 0); they must be sized from the model before enabling this.
  use_multiccd = False

  epa_ws = d.epa_workspace
  _check_epa_workspace(epa_ws, d.naconmax, epa_iterations)

  # The argument lists are identical for every pair type (only the kernel differs),
  # so build them once. Arguments are positional: keep this order.
  inputs = [
    m.opt.ccd_tolerance,
    m.geom_type,
    m.geom_condim,
    m.geom_dataid,
    m.geom_priority,
    m.geom_solmix,
    m.geom_solref,
    m.geom_solimp,
    m.geom_size,
    m.geom_aabb,
    m.geom_rbound,
    m.geom_friction,
    m.geom_margin,
    m.geom_gap,
    m.hfield_adr,
    m.hfield_nrow,
    m.hfield_ncol,
    m.hfield_size,
    m.hfield_data,
    m.mesh_vertadr,
    m.mesh_vertnum,
    m.mesh_vert,
    m.mesh_graphadr,
    m.mesh_graph,
    m.mesh_polynum,
    m.mesh_polyadr,
    m.mesh_polynormal,
    m.mesh_polyvertadr,
    m.mesh_polyvertnum,
    m.mesh_polyvert,
    m.mesh_polymapadr,
    m.mesh_polymapnum,
    m.mesh_polymap,
    m.pair_dim,
    m.pair_solref,
    m.pair_solreffriction,
    m.pair_solimp,
    m.pair_margin,
    m.pair_gap,
    m.pair_friction,
    d.naconmax,
    d.geom_xpos,
    d.geom_xmat,
    d.collision_pair,
    d.collision_pairid,
    d.collision_worldid,
    d.ncollision,
    # persistent EPA / multi-CCD scratch buffers
    epa_ws.epa_vert,
    epa_ws.epa_vert1,
    epa_ws.epa_vert2,
    epa_ws.epa_vert_index1,
    epa_ws.epa_vert_index2,
    epa_ws.epa_face,
    epa_ws.epa_pr,
    epa_ws.epa_norm2,
    epa_ws.epa_index,
    epa_ws.epa_map,
    epa_ws.epa_horizon,
    epa_ws.multiccd_polygon,
    epa_ws.multiccd_clipped,
    epa_ws.multiccd_pnormal,
    epa_ws.multiccd_pdist,
    epa_ws.multiccd_idx1,
    epa_ws.multiccd_idx2,
    epa_ws.multiccd_n1,
    epa_ws.multiccd_n2,
    epa_ws.multiccd_endvert,
    epa_ws.multiccd_face1,
    epa_ws.multiccd_face2,
  ]
  outputs = [
    d.nacon,
    d.contact.dist,
    d.contact.pos,
    d.contact.frame,
    d.contact.includemargin,
    d.contact.friction,
    d.contact.solref,
    d.contact.solreffriction,
    d.contact.solimp,
    d.contact.dim,
    d.contact.geom,
    d.contact.worldid,
    d.contact.type,
    d.contact.geomcollisionid,
  ]

  for geom_pair in _CONVEX_COLLISION_PAIRS:
    g1 = geom_pair[0].value
    g2 = geom_pair[1].value
    if _pair_count(g1, g2):
      wp.launch(
        ccd_kernel_builder(g1, g2, m.opt.ccd_iterations, epa_iterations, use_multiccd),
        dim=d.naconmax,
        inputs=inputs,
        outputs=outputs,
      )
# io.py
def make_data(
  mjm: mujoco.MjModel,
  nworld: int = 1,
  nconmax: Optional[int] = None,
  njmax: Optional[int] = None,
  naconmax: Optional[int] = None,
) -> types.Data:
  """Creates a data object on device, initialized from a fresh MjData for ``mjm``.

  Args:
    mjm: The model containing kinematic and dynamic information (host).
    nworld: Number of worlds.
    nconmax: Number of contacts to allocate per world. Contacts exist in large
      heterogeneous arrays: one world may have more than nconmax contacts.
    njmax: Number of constraints to allocate per world. Constraint arrays are
      batched by world: no world may have more than njmax constraints.
    naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.

  Returns:
    The data object containing the current state and output arrays (device).

  Raises:
    ValueError: if nworld < 1, or nconmax / naconmax / njmax is negative.
  """
  # TODO(team): move nconmax, njmax to Model?
  # TODO(team): improve heuristic for nconmax and njmax
  nconmax = nconmax or 20
  njmax = njmax or nconmax * 6

  if nworld < 1:
    raise ValueError("nworld must be >= 1")
  if naconmax is not None:
    if naconmax < 0:
      raise ValueError("naconmax must be >= 0")
  elif nconmax < 0:
    raise ValueError("nconmax must be >= 0")
  else:
    naconmax = max(512, nworld * nconmax)
  if njmax < 0:
    raise ValueError("njmax must be >= 0")

  sparse = is_sparse(mjm)
  tile_size = types.TILE_SIZE_JTDAJ_SPARSE if sparse else types.TILE_SIZE_JTDAJ_DENSE
  njmax_pad, nv_pad = _get_padded_sizes(mjm.nv, njmax, sparse, tile_size)
  nmaxcondim = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()

  # symbolic dimension name -> concrete size, consumed by _create_array
  sizes = {"*": 1}
  sizes.update({fld.name: getattr(mjm, fld.name, None) for fld in dataclasses.fields(types.Model) if fld.type is int})
  sizes.update(
    nmaxcondim=nmaxcondim,
    nmaxpyramid=np.maximum(1, 2 * (nmaxcondim - 1)),
    njmax_pad=njmax_pad,
    nv_pad=nv_pad,
    nworld=nworld,
    naconmax=naconmax,
    njmax=njmax,
  )

  def _alloc_struct(cls):
    # allocate every field of a device struct dataclass from `sizes`
    return cls(**{fld.name: _create_array(None, fld.type, sizes) for fld in dataclasses.fields(cls)})

  contact = _alloc_struct(types.Contact)
  efc = _alloc_struct(types.Constraint)

  # world body and static geom (attached to the world) poses are precomputed
  # this speeds up scenes with many static geoms (e.g. terrains)
  # TODO(team): remove this when we introduce dof islands + sleeping
  mjd = mujoco.MjData(mjm)
  mujoco.mj_kinematics(mjm, mjd)

  # bodies that carry a mocap id, reordered by that id
  mocap_body = np.nonzero(mjm.body_mocapid >= 0)[0]
  mocap_src = mocap_body[mjm.body_mocapid[mocap_body]]

  def _per_world(host, width, dtype):
    # replicate a host array once per world along the leading axis
    return wp.array(np.tile(host, (nworld, 1)), shape=(nworld, width), dtype=dtype)

  d_kwargs = {
    "qpos": wp.array(np.tile(mjm.qpos0, nworld), shape=(nworld, mjm.nq), dtype=float),
    "contact": contact,
    "efc": efc,
    "nworld": nworld,
    "naconmax": naconmax,
    "njmax": njmax,
    # filled in after Data is constructed
    "qM": None,
    "qLD": None,
    "epa_workspace": None,
    # world body
    "xquat": _per_world(mjd.xquat, mjm.nbody, wp.quat),
    "xmat": _per_world(mjd.xmat, mjm.nbody, wp.mat33),
    "ximat": _per_world(mjd.ximat, mjm.nbody, wp.mat33),
    # static geoms
    "geom_xpos": _per_world(mjd.geom_xpos, mjm.ngeom, wp.vec3),
    "geom_xmat": _per_world(mjd.geom_xmat, mjm.ngeom, wp.mat33),
    # mocap
    "mocap_pos": _per_world(mjm.body_pos[mocap_src], mjm.nmocap, wp.vec3),
    "mocap_quat": _per_world(mjm.body_quat[mocap_src], mjm.nmocap, wp.quat),
    # equality constraints
    "eq_active": _per_world(mjm.eq_active0.astype(bool), mjm.neq, bool),
  }
  for fld in dataclasses.fields(types.Data):
    if fld.name not in d_kwargs:
      d_kwargs[fld.name] = _create_array(None, fld.type, sizes)
  d = types.Data(**d_kwargs)

  if sparse:
    qm_shape, qld_shape = (nworld, 1, mjm.nM), (nworld, 1, mjm.nC)
  else:
    qm_shape, qld_shape = (nworld, nv_pad, nv_pad), (nworld, mjm.nv, mjm.nv)
  d.qM = wp.zeros(qm_shape, dtype=float)
  d.qLD = wp.zeros(qld_shape, dtype=float)

  # Persistent convex-narrowphase scratch buffers: one row per contact slot.
  # multiccd is disabled, so its polygon / mesh-degree widths collapse to 1.
  ccd_iterations = mjm.opt.ccd_iterations
  vert_w = 5 + ccd_iterations
  face_w = 6 + types.MJ_MAX_EPAFACES * ccd_iterations
  horizon_w = 2 * types.MJ_MAX_EPAHORIZON
  max_polygon = 0
  max_meshdeg = 0
  polygon2_w = max(1, 2 * max_polygon)
  polygon_w = max(1, max_polygon)
  meshdeg_w = max(1, max_meshdeg)
  layout = (
    ("epa_vert", vert_w, wp.vec3),
    ("epa_vert1", vert_w, wp.vec3),
    ("epa_vert2", vert_w, wp.vec3),
    ("epa_vert_index1", vert_w, int),
    ("epa_vert_index2", vert_w, int),
    ("epa_face", face_w, wp.vec3i),
    ("epa_pr", face_w, wp.vec3),
    ("epa_norm2", face_w, float),
    ("epa_index", face_w, int),
    ("epa_map", face_w, int),
    ("epa_horizon", horizon_w, int),
    ("multiccd_polygon", polygon2_w, wp.vec3),
    ("multiccd_clipped", polygon2_w, wp.vec3),
    ("multiccd_pnormal", polygon_w, wp.vec3),
    ("multiccd_pdist", polygon_w, float),
    ("multiccd_idx1", meshdeg_w, int),
    ("multiccd_idx2", meshdeg_w, int),
    ("multiccd_n1", meshdeg_w, wp.vec3),
    ("multiccd_n2", meshdeg_w, wp.vec3),
    ("multiccd_endvert", meshdeg_w, wp.vec3),
    ("multiccd_face1", polygon_w, wp.vec3),
    ("multiccd_face2", polygon_w, wp.vec3),
  )
  d.epa_workspace = types.EpaWorkspace(
    **{name: wp.empty(shape=(naconmax, width), dtype=dtype) for name, width, dtype in layout}
  )
  return d
def put_data(
  mjm: mujoco.MjModel,
  mjd: mujoco.MjData,
  nworld: int = 1,
  nconmax: Optional[int] = None,
  njmax: Optional[int] = None,
  naconmax: Optional[int] = None,
) -> types.Data:
  """Moves data from host to a device.

  The host state in ``mjd`` is replicated into every world.

  Args:
    mjm: The model containing kinematic and dynamic information (host).
    mjd: The data object containing current state and output arrays (host).
    nworld: The number of worlds.
    nconmax: Number of contacts to allocate per world. Contacts exist in large
      heterogeneous arrays: one world may have more than nconmax contacts.
    njmax: Number of constraints to allocate per world. Constraint arrays are
      batched by world: no world may have more than njmax constraints.
    naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.

  Returns:
    The data object containing the current state and output arrays (device).

  Raises:
    ValueError: if nworld < 1, if nconmax / njmax is negative, or if ``mjd``
      already holds more contacts / constraints than the requested capacity.
  """
  # TODO(team): move nconmax and njmax to Model?
  # TODO(team): decide what to do about uninitialized warp-only fields created by put_data
  #             we need to ensure these are only workspace fields and don't carry state
  # TODO(team): better heuristic for nconmax and njmax
  nconmax = nconmax or max(5, 4 * mjd.ncon)
  njmax = njmax or max(5, 4 * mjd.nefc)

  if nworld < 1:
    raise ValueError("nworld must be >= 1")
  if naconmax is None:
    if nconmax < 0:
      raise ValueError("nconmax must be >= 0")
    if mjd.ncon > nconmax:
      raise ValueError(f"nconmax overflow (nconmax must be >= {mjd.ncon})")
    naconmax = max(512, nworld * nconmax)
  elif naconmax < mjd.ncon * nworld:
    raise ValueError(f"naconmax overflow (naconmax must be >= {mjd.ncon * nworld})")
  if njmax < 0:
    raise ValueError("njmax must be >= 0")
  if mjd.nefc > njmax:
    raise ValueError(f"njmax overflow (njmax must be >= {mjd.nefc})")

  sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int})
  sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
  sizes["nmaxpyramid"] = np.maximum(1, 2 * (sizes["nmaxcondim"] - 1))
  tile_size = types.TILE_SIZE_JTDAJ_SPARSE if is_sparse(mjm) else types.TILE_SIZE_JTDAJ_DENSE
  sizes["njmax_pad"], sizes["nv_pad"] = _get_padded_sizes(mjm.nv, njmax, is_sparse(mjm), tile_size)
  sizes["nworld"] = nworld
  sizes["naconmax"] = naconmax
  sizes["njmax"] = njmax

  # ensure static geom positions are computed
  # TODO: remove once MjData creation semantics are fixed
  mujoco.mj_kinematics(mjm, mjd)

  # create contact
  # Contacts are laid out world-major: row j * ncon + i holds contact i of world j.
  # worldid and efc_address below use this same layout.
  contact_kwargs = {"efc_address": None, "worldid": None, "type": None, "geomcollisionid": None}
  for f in dataclasses.fields(types.Contact):
    if f.name in contact_kwargs:
      continue
    val = getattr(mjd.contact, f.name)
    # np.tile (not np.repeat) keeps the world-major order [c0..cn, c0..cn, ...]
    val = np.tile(val, (nworld,) + (1,) * (val.ndim - 1))
    width = ((0, naconmax - val.shape[0]),) + ((0, 0),) * (val.ndim - 1)
    val = np.pad(val, width)
    contact_kwargs[f.name] = _create_array(val, f.type, sizes)
  contact = types.Contact(**contact_kwargs)

  contact.efc_address = np.zeros((naconmax, sizes["nmaxpyramid"]), dtype=int)
  for i in range(mjd.ncon):
    efc_address = mjd.contact.efc_address[i]
    if efc_address == -1:
      continue
    condim = mjd.contact.dim[i]
    ndim = max(1, 2 * (condim - 1)) if mjm.opt.cone == mujoco.mjtCone.mjCONE_PYRAMIDAL else condim
    for j in range(nworld):
      contact.efc_address[j * mjd.ncon + i, :ndim] = efc_address + np.arange(ndim)
  contact.efc_address = wp.array(contact.efc_address, dtype=int)

  contact.worldid = np.pad(np.repeat(np.arange(nworld), mjd.ncon), (0, naconmax - nworld * mjd.ncon))
  contact.worldid = wp.array(contact.worldid, dtype=int)
  contact.type = wp.ones((naconmax,), dtype=int)  # TODO(team): set values
  contact.geomcollisionid = wp.empty((naconmax,), dtype=int)  # TODO(team): set values

  # create efc
  efc_kwargs = {"J": None}
  for f in dataclasses.fields(types.Constraint):
    if f.name in efc_kwargs:
      continue
    shape = tuple(sizes[dim] if isinstance(dim, str) else dim for dim in f.type.shape)
    val = np.zeros(shape, dtype=f.type.dtype)
    if f.name in ("type", "id", "pos", "margin", "D", "vel", "aref", "frictionloss", "force"):
      val[:, : mjd.nefc] = np.tile(getattr(mjd, "efc_" + f.name), (nworld, 1))
    efc_kwargs[f.name] = wp.array(val, dtype=f.type.dtype)
  efc = types.Constraint(**efc_kwargs)

  # mj_isSparse describes the layout of mjd's own Jacobians
  if mujoco.mj_isSparse(mjm):
    efc_j = np.zeros((mjd.nefc, mjm.nv))
    mujoco.mju_sparse2dense(efc_j, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, mjd.efc_J_colind)
  else:
    efc_j = mjd.efc_J.reshape((mjd.nefc, mjm.nv))
  # J is a float array: do not take the dtype of whatever field the loop above ended on
  efc.J = np.zeros((nworld, sizes["njmax_pad"], sizes["nv_pad"]), dtype=float)
  efc.J[:, : mjd.nefc, : mjm.nv] = np.tile(efc_j, (nworld, 1, 1))
  efc.J = wp.array(efc.J, dtype=float)

  # create data
  d_kwargs = {
    "contact": contact,
    "efc": efc,
    "nworld": nworld,
    "naconmax": naconmax,
    "njmax": njmax,
    # fields set after initialization:
    "solver_niter": None,
    "qM": None,
    "qLD": None,
    "ten_J": None,
    "actuator_moment": None,
    "flexedge_J": None,
    "nacon": None,
    "ne_connect": None,
    "ne_weld": None,
    "ne_jnt": None,
    "ne_ten": None,
    "ne_flex": None,
    "nsolving": None,
    "epa_workspace": None,
  }
  for f in dataclasses.fields(types.Data):
    if f.name in d_kwargs:
      continue
    val = getattr(mjd, f.name, None)
    if val is not None:
      shape = val.shape if hasattr(val, "shape") else ()
      val = np.full((nworld,) + shape, val)
    d_kwargs[f.name] = _create_array(val, f.type, sizes)
  d = types.Data(**d_kwargs)

  d.solver_niter = wp.full((nworld,), mjd.solver_niter[0], dtype=int)

  if is_sparse(mjm):
    d.qM = wp.array(np.full((nworld, 1, mjm.nM), mjd.qM), dtype=float)
    d.qLD = wp.array(np.full((nworld, 1, mjm.nC), mjd.qLD), dtype=float)
  else:
    qM = np.zeros((mjm.nv, mjm.nv))
    mujoco.mj_fullM(mjm, qM, mjd.qM)
    qLD = np.linalg.cholesky(qM) if (mjd.qM != 0.0).any() and (mjd.qLD != 0.0).any() else np.zeros((mjm.nv, mjm.nv))
    padding = sizes["nv_pad"] - mjm.nv
    qM_padded = np.pad(qM, ((0, padding), (0, padding)), mode="constant", constant_values=0.0)
    d.qM = wp.array(np.full((nworld, sizes["nv_pad"], sizes["nv_pad"]), qM_padded), dtype=float)
    d.qLD = wp.array(np.full((nworld, mjm.nv, mjm.nv), qLD), dtype=float)

  if mujoco.mj_isSparse(mjm):
    ten_J = np.zeros((mjm.ntendon, mjm.nv))
    mujoco.mju_sparse2dense(ten_J, mjd.ten_J.reshape(-1), mjd.ten_J_rownnz, mjd.ten_J_rowadr, mjd.ten_J_colind.reshape(-1))
    d.ten_J = wp.array(np.full((nworld, mjm.ntendon, mjm.nv), ten_J), dtype=float)
    flexedge_J = np.zeros((mjm.nflexedge, mjm.nv))
    mujoco.mju_sparse2dense(
      flexedge_J, mjd.flexedge_J.reshape(-1), mjd.flexedge_J_rownnz, mjd.flexedge_J_rowadr, mjd.flexedge_J_colind.reshape(-1)
    )
    d.flexedge_J = wp.array(np.full((nworld, mjm.nflexedge, mjm.nv), flexedge_J), dtype=float)
  else:
    ten_J = mjd.ten_J.reshape((mjm.ntendon, mjm.nv))
    d.ten_J = wp.array(np.full((nworld, mjm.ntendon, mjm.nv), ten_J), dtype=float)
    flexedge_J = mjd.flexedge_J.reshape((mjm.nflexedge, mjm.nv))
    d.flexedge_J = wp.array(np.full((nworld, mjm.nflexedge, mjm.nv), flexedge_J), dtype=float)

  # TODO(taylorhowell): sparse actuator_moment
  actuator_moment = np.zeros((mjm.nu, mjm.nv))
  mujoco.mju_sparse2dense(actuator_moment, mjd.actuator_moment, mjd.moment_rownnz, mjd.moment_rowadr, mjd.moment_colind)
  d.actuator_moment = wp.array(np.full((nworld, mjm.nu, mjm.nv), actuator_moment), dtype=float)

  d.nacon = wp.array([mjd.ncon * nworld], dtype=int)
  d.ne_connect = wp.full(nworld, 3 * np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_CONNECT) & mjd.eq_active), dtype=int)
  d.ne_weld = wp.full(nworld, 6 * np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_WELD) & mjd.eq_active), dtype=int)
  d.ne_jnt = wp.full(nworld, np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_JOINT) & mjd.eq_active), dtype=int)
  d.ne_ten = wp.full(nworld, np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_TENDON) & mjd.eq_active), dtype=int)
  d.ne_flex = wp.full(nworld, np.sum((mjm.eq_type == mujoco.mjtEq.mjEQ_FLEX) & mjd.eq_active), dtype=int)
  d.nsolving = wp.array([nworld], dtype=int)

  # Persistent convex-narrowphase scratch buffers (one row per contact slot), so the
  # narrowphase does not allocate device memory every step. multiccd is disabled,
  # hence the width-1 multiccd buffers.
  ccd_iterations = mjm.opt.ccd_iterations
  epa_vert_size = 5 + ccd_iterations
  epa_face_size = 6 + types.MJ_MAX_EPAFACES * ccd_iterations
  nmaxpolygon = 0
  nmaxmeshdeg = 0
  d.epa_workspace = types.EpaWorkspace(
    epa_vert=wp.empty(shape=(naconmax, epa_vert_size), dtype=wp.vec3),
    epa_vert1=wp.empty(shape=(naconmax, epa_vert_size), dtype=wp.vec3),
    epa_vert2=wp.empty(shape=(naconmax, epa_vert_size), dtype=wp.vec3),
    epa_vert_index1=wp.empty(shape=(naconmax, epa_vert_size), dtype=int),
    epa_vert_index2=wp.empty(shape=(naconmax, epa_vert_size), dtype=int),
    epa_face=wp.empty(shape=(naconmax, epa_face_size), dtype=wp.vec3i),
    epa_pr=wp.empty(shape=(naconmax, epa_face_size), dtype=wp.vec3),
    epa_norm2=wp.empty(shape=(naconmax, epa_face_size), dtype=float),
    epa_index=wp.empty(shape=(naconmax, epa_face_size), dtype=int),
    epa_map=wp.empty(shape=(naconmax, epa_face_size), dtype=int),
    epa_horizon=wp.empty(shape=(naconmax, 2 * types.MJ_MAX_EPAHORIZON), dtype=int),
    multiccd_polygon=wp.empty(shape=(naconmax, max(1, 2 * nmaxpolygon)), dtype=wp.vec3),
    multiccd_clipped=wp.empty(shape=(naconmax, max(1, 2 * nmaxpolygon)), dtype=wp.vec3),
    multiccd_pnormal=wp.empty(shape=(naconmax, max(1, nmaxpolygon)), dtype=wp.vec3),
    multiccd_pdist=wp.empty(shape=(naconmax, max(1, nmaxpolygon)), dtype=float),
    multiccd_idx1=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=int),
    multiccd_idx2=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=int),
    multiccd_n1=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=wp.vec3),
    multiccd_n2=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=wp.vec3),
    multiccd_endvert=wp.empty(shape=(naconmax, max(1, nmaxmeshdeg)), dtype=wp.vec3),
    multiccd_face1=wp.empty(shape=(naconmax, max(1, nmaxpolygon)), dtype=wp.vec3),
    multiccd_face2=wp.empty(shape=(naconmax, max(1, nmaxpolygon)), dtype=wp.vec3),
  )
  return d
# types.py
@dataclasses.dataclass
class EpaWorkspace:
  """Persistent scratch buffers for the convex (GJK/EPA, multi-CCD) narrowphase.

  Allocated once together with Data in make_data / put_data, instead of on every
  call to convex_narrowphase. Every array has ``naconmax`` rows, one per contact
  slot. Column widths as allocated in make_data / put_data:

    epa_vert, epa_vert1, epa_vert2,
    epa_vert_index1, epa_vert_index2:        5 + ccd_iterations
    epa_face, epa_pr, epa_norm2,
    epa_index, epa_map:                      6 + MJ_MAX_EPAFACES * ccd_iterations
    epa_horizon:                             2 * MJ_MAX_EPAHORIZON
    multiccd_polygon, multiccd_clipped:      max(1, 2 * nmaxpolygon)
    multiccd_pnormal, multiccd_pdist,
    multiccd_face1, multiccd_face2:          max(1, nmaxpolygon)
    multiccd_idx1, multiccd_idx2, multiccd_n1,
    multiccd_n2, multiccd_endvert:           max(1, nmaxmeshdeg)

  ``ccd_iterations`` is read from the host model when Data is created, and
  nmaxpolygon = nmaxmeshdeg = 0 there, so the multiccd buffers have width 1.
  NOTE(review): the narrowphase kernels presumably index these arrays using
  ``m.opt.ccd_iterations``. If that option changes, the workspace must be
  re-created.
  """

  # EPA polytope vertices and their per-shape support points / indices
  epa_vert: wp.array2d(dtype=wp.vec3)
  epa_vert1: wp.array2d(dtype=wp.vec3)
  epa_vert2: wp.array2d(dtype=wp.vec3)
  epa_vert_index1: wp.array2d(dtype=int)
  epa_vert_index2: wp.array2d(dtype=int)
  # EPA polytope faces and per-face bookkeeping
  epa_face: wp.array2d(dtype=wp.vec3i)
  epa_pr: wp.array2d(dtype=wp.vec3)
  epa_norm2: wp.array2d(dtype=float)
  epa_index: wp.array2d(dtype=int)
  epa_map: wp.array2d(dtype=int)
  epa_horizon: wp.array2d(dtype=int)
  # multi-CCD polygon clipping buffers (width 1 while multiccd is disabled)
  multiccd_polygon: wp.array2d(dtype=wp.vec3)
  multiccd_clipped: wp.array2d(dtype=wp.vec3)
  multiccd_pnormal: wp.array2d(dtype=wp.vec3)
  multiccd_pdist: wp.array2d(dtype=float)
  multiccd_idx1: wp.array2d(dtype=int)
  multiccd_idx2: wp.array2d(dtype=int)
  multiccd_n1: wp.array2d(dtype=wp.vec3)
  multiccd_n2: wp.array2d(dtype=wp.vec3)
  multiccd_endvert: wp.array2d(dtype=wp.vec3)
  multiccd_face1: wp.array2d(dtype=wp.vec3)
  multiccd_face2: wp.array2d(dtype=wp.vec3)
# With these changes, the training runs stably without OOM. I wonder if these changes are reasonable? Any feedback on the implementation would be appreciated!
Thanks!