# Copyright 2025 The Newton Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import dataclasses
import warnings
from typing import Any, Optional, Sequence
import mujoco
import numpy as np
import warp as wp
from mujoco_warp._src import bvh
from mujoco_warp._src import math as mjmath
from mujoco_warp._src import render_util
from mujoco_warp._src import smooth
from mujoco_warp._src import types
from mujoco_warp._src import warp_util
from mujoco_warp._src.types import MJ_MINVAL
from mujoco_warp._src.types import BiasType
from mujoco_warp._src.types import TrnType
from mujoco_warp._src.types import vec10
from mujoco_warp._src.util_pkg import check_version
def _is_array_spec(typ) -> bool:
  """Returns True when `typ` describes a warp array field.

  A field type counts as an array spec if it is a `wp.array` instance, or a bracket-style
  annotation object whose class is named `_ArrayAnnotation`.
  """
  if isinstance(typ, wp.array):
    return True
  return type(typ).__name__ == "_ArrayAnnotation"
def _create_array(data: Any, spec, sizes: dict[str, int]) -> wp.array | None:
  """Allocates a warp array for one field spec and fills it from `data`.

  The target shape comes from `spec.shape`, which is treated as `(0,)` when absent. String
  entries are resolved through `sizes`; other entries are used as-is. A spec shape of `(0,)`
  means "no fixed shape", in which case `shape=None` is passed to `wp.array`.

  Args:
    data: host values, or None to allocate a zero-filled array.
    spec: field annotation providing `dtype` and optionally `shape`.
    sizes: concrete values for the symbolic dimension names used in `spec.shape`.

  Returns:
    The device array, or None if `data` is None and the spec has no fixed shape.
  """
  spec_shape = getattr(spec, "shape", (0,))
  if spec_shape == (0,):
    shape = None
  else:
    shape = tuple(sizes[dim] if isinstance(dim, str) else dim for dim in spec_shape)

  if data is not None:
    array = wp.array(np.array(data), dtype=spec.dtype, shape=shape)
  elif shape is not None:
    array = wp.zeros(shape, dtype=spec.dtype)
  else:
    return None  # no data and no fixed shape: nothing to allocate

  if spec_shape and spec_shape[0] == "*":
    # private attribute that lets JAX determine which fields are batched
    array._is_batched = True
    # legacy (deprecated) behavior: stride 0 on the leading (batch) dimension
    array.strides = (0,) + array.strides[1:]
  return array
def is_sparse(mjm: mujoco.MjModel) -> bool:
  """Returns whether the model should use sparse (rather than dense) matrices.

  With `mjJAC_AUTO`, a warp-specific threshold is used: sparse when nv > 32. Otherwise the
  decision is deferred to MuJoCo's `mj_isSparse`.

  Args:
    mjm: The model containing kinematic and dynamic information (host).

  Returns:
    True for sparse, False for dense.
  """
  if mjm.opt.jacobian == mujoco.mjtJacobian.mjJAC_AUTO:
    return bool(mjm.nv > 32)
  return bool(mujoco.mj_isSparse(mjm))
def put_model(mjm: mujoco.MjModel) -> types.Model:
  """Creates a model on device.

  Validates that `mjm` only uses features supported by MJWarp. Precomputes the warp-specific
  index/address tables derived from the host model. Uploads every array field of
  `types.Model` to the device.

  Args:
    mjm: The model containing kinematic and dynamic information (host).

  Returns:
    The model containing kinematic and dynamic information (device).

  Raises:
    NotImplementedError: if the model uses an unsupported feature. This covers unsupported
      actuator/equality/geom/sensor/wrap types, integrators, cones, solvers and flags; noslip;
      body/actuator/sensor plugins; implicit integrators with fluid; and non-zero margins on
      box/mesh pairs that use MULTICCD or box-box NATIVECCD.
    ValueError: if a dense Jacobian is requested with nv > 60, or a geom plugin has more
      than `types._NPLUGINATTR` attributes.
  """
  # check for compatible cuda toolkit and driver versions
  warp_util.check_toolkit_driver()

  # model: check supported features in array types
  for field, field_type, mj_type in (
    (mjm.actuator_trntype, types.TrnType, mujoco.mjtTrn),
    (mjm.actuator_dyntype, types.DynType, mujoco.mjtDyn),
    (mjm.actuator_gaintype, types.GainType, mujoco.mjtGain),
    (mjm.actuator_biastype, types.BiasType, mujoco.mjtBias),
    (mjm.eq_type, types.EqType, mujoco.mjtEq),
    (mjm.geom_type, types.GeomType, mujoco.mjtGeom),
    (mjm.sensor_type, types.SensorType, mujoco.mjtSensor),
    (mjm.wrap_type, types.WrapType, mujoco.mjtWrap),
  ):
    missing = ~np.isin(field, field_type)
    if missing.any():
      names = [mj_type(v).name for v in field[missing]]
      raise NotImplementedError(f"{names} not supported.")

  # opt: check supported features in scalar types
  for field, field_type, mj_type in (
    (mjm.opt.integrator, types.IntegratorType, mujoco.mjtIntegrator),
    (mjm.opt.cone, types.ConeType, mujoco.mjtCone),
    (mjm.opt.solver, types.SolverType, mujoco.mjtSolver),
  ):
    if field not in set(field_type):
      raise NotImplementedError(f"{mj_type(field).name} is unsupported.")

  # opt: check supported features in scalar flag types
  for field, field_type, mj_type in (
    (mjm.opt.disableflags, types.DisableBit, mujoco.mjtDisableBit),
    (mjm.opt.enableflags, types.EnableBit, mujoco.mjtEnableBit),
  ):
    unsupported = field & ~np.bitwise_or.reduce(field_type)
    if unsupported:
      raise NotImplementedError(f"{mj_type(unsupported).name} is unsupported.")

  if mjm.opt.noslip_iterations > 0:
    raise NotImplementedError("noslip solver not implemented.")

  if (mjm.opt.viscosity > 0 or mjm.opt.density > 0) and mjm.opt.integrator in (
    mujoco.mjtIntegrator.mjINT_IMPLICITFAST,
    mujoco.mjtIntegrator.mjINT_IMPLICIT,
  ):
    raise NotImplementedError("Implicit integrators and fluid model not implemented.")

  if (mjm.body_plugin != -1).any():
    raise NotImplementedError("Body plugins not supported.")

  if (mjm.actuator_plugin != -1).any():
    raise NotImplementedError("Actuator plugins not supported.")

  if (mjm.sensor_plugin != -1).any():
    raise NotImplementedError("Sensor plugins not supported.")

  # TODO(team): remove after _update_gradient for Newton uses tile operations for islands
  nv_max = 60
  if mjm.nv > nv_max and mjm.opt.jacobian == mujoco.mjtJacobian.mjJAC_DENSE:
    raise ValueError(f"Dense is unsupported for nv > {nv_max} (nv = {mjm.nv}).")

  collision_sensors = (mujoco.mjtSensor.mjSENS_GEOMDIST, mujoco.mjtSensor.mjSENS_GEOMNORMAL, mujoco.mjtSensor.mjSENS_GEOMFROMTO)
  is_collision_sensor = np.isin(mjm.sensor_type, collision_sensors)

  def _check_friction(name: str, id_: int, condim: int, friction, checks):
    """Warns about friction coefficients below MJ_MINMU that are active at `condim`.

    `checks` is a list of (min_condim, indices): the coefficients at `indices` are checked
    when `condim >= min_condim`.
    """
    for min_condim, indices in checks:
      if condim >= min_condim:
        for idx in indices:
          if friction[idx] < types.MJ_MINMU:
            warnings.warn(
              f"{name} {id_}: friction[{idx}] ({friction[idx]}) < MJ_MINMU ({types.MJ_MINMU}) with condim={condim} may cause NaN"
            )

  # geom friction: (sliding, torsional, rolling)
  for geomid in range(mjm.ngeom):
    _check_friction("geom", geomid, mjm.geom_condim[geomid], mjm.geom_friction[geomid], [(3, [0]), (4, [1]), (6, [2])])
  # pair friction: (tangent1, tangent2, torsional, rolling1, rolling2); both tangential
  # coefficients are active from condim 3 onwards
  for pairid in range(mjm.npair):
    _check_friction("pair", pairid, mjm.pair_dim[pairid], mjm.pair_friction[pairid], [(3, [0, 1]), (4, [2]), (6, [3, 4])])

  # create opt
  opt_kwargs = {f.name: getattr(mjm.opt, f.name, None) for f in dataclasses.fields(types.Option)}
  if hasattr(mjm.opt, "impratio"):
    opt_kwargs["impratio_invsqrt"] = 1.0 / np.sqrt(np.maximum(mjm.opt.impratio, mujoco.mjMINVAL))
  opt = types.Option(**opt_kwargs)

  # C MuJoCo tolerance was chosen for float64 architecture, but we default to float32 on GPU
  # adjust the tolerance for lower precision, to avoid the solver spending iterations needlessly
  # bouncing around the optimal solution
  opt.tolerance = max(opt.tolerance, 1e-6)

  # warp only fields
  ls_parallel_id = mujoco.mj_name2id(mjm, mujoco.mjtObj.mjOBJ_NUMERIC, "ls_parallel")
  opt.ls_parallel = (ls_parallel_id > -1) and (mjm.numeric_data[mjm.numeric_adr[ls_parallel_id]] == 1)
  opt.ls_parallel_min_step = 1.0e-6  # TODO(team): determine good default setting
  # initial value; reassigned below once the number of filtered geom pairs is known
  opt.broadphase = types.BroadphaseType.NXN
  opt.broadphase_filter = types.BroadphaseFilter.PLANE | types.BroadphaseFilter.SPHERE | types.BroadphaseFilter.OBB
  opt.graph_conditional = True
  opt.run_collision_detection = True
  contact_sensor_maxmatch_id = mujoco.mj_name2id(mjm, mujoco.mjtObj.mjOBJ_NUMERIC, "contact_sensor_maxmatch")
  if contact_sensor_maxmatch_id > -1:
    opt.contact_sensor_maxmatch = mjm.numeric_data[mjm.numeric_adr[contact_sensor_maxmatch_id]]
  else:
    opt.contact_sensor_maxmatch = 64

  # place opt on device
  for f in dataclasses.fields(types.Option):
    if _is_array_spec(f.type):
      setattr(opt, f.name, _create_array(getattr(opt, f.name), f.type, {"*": 1}))
    else:
      setattr(opt, f.name, f.type(getattr(opt, f.name)))

  # create stat
  stat = types.Statistic(meaninertia=_create_array([mjm.stat.meaninertia], types.array("*", float), {"*": 1}))

  # create model
  m = types.Model(**{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model)})
  sparse = is_sparse(mjm)

  m.opt = opt
  m.stat = stat
  m.callback = types.Callback()
  m.nv_pad = _get_padded_sizes(
    mjm.nv, 0, sparse, types.TILE_SIZE_JTDAJ_SPARSE if sparse else types.TILE_SIZE_JTDAJ_DENSE
  )[1]
  m.nacttrnbody = (mjm.actuator_trntype == mujoco.mjtTrn.mjTRN_BODY).sum()
  m.nsensortaxel = mjm.mesh_vertnum[mjm.sensor_objid[mjm.sensor_type == mujoco.mjtSensor.mjSENS_TACTILE]].sum()
  m.nsensorcontact = (mjm.sensor_type == mujoco.mjtSensor.mjSENS_CONTACT).sum()
  m.nrangefinder = (mjm.sensor_type == mujoco.mjtSensor.mjSENS_RANGEFINDER).sum()
  m.nmaxcondim = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
  m.nmaxpyramid = np.maximum(1, 2 * (m.nmaxcondim - 1))
  m.has_sdf_geom = (mjm.geom_type == mujoco.mjtGeom.mjGEOM_SDF).any()
  m.block_dim = types.BlockDim()
  m.is_sparse = sparse
  m.has_fluid = mjm.opt.wind.any() or mjm.opt.density > 0 or mjm.opt.viscosity > 0
  m.max_ten_J_rownnz = int(mjm.ten_J_rownnz.max()) if mjm.ntendon else 0

  # body ids grouped by tree level (depth-based traversal)
  bodies, body_depth = {}, np.zeros(mjm.nbody, dtype=int) - 1
  for i in range(mjm.nbody):
    body_depth[i] = body_depth[mjm.body_parentid[i]] + 1
    bodies.setdefault(body_depth[i], []).append(i)
  m.body_tree = tuple(wp.array(bodies[i], dtype=int) for i in sorted(bodies))

  # branch-based traversal data: one branch per leaf body, listing its ancestors (excluding
  # the world body) from the top of the tree down to the leaf
  children_count = np.bincount(mjm.body_parentid[1:], minlength=mjm.nbody)

  def _ancestor_chain(body):
    # iterative rather than recursive so very deep kinematic chains cannot exceed
    # Python's recursion limit
    chain = []
    while body:
      chain.append(body)
      body = mjm.body_parentid[body]
    return chain[::-1]

  branches = [_ancestor_chain(leaf) for leaf in np.where(children_count[1:] == 0)[0] + 1]
  m.nbranch = len(branches)
  body_branches = []
  body_branch_start = []
  offset = 0
  for branch in branches:
    body_branches.extend(branch)
    body_branch_start.append(offset)
    offset += len(branch)
  body_branch_start.append(offset)
  m.body_branches = np.array(body_branches, dtype=int)
  m.body_branch_start = np.array(body_branch_start, dtype=int)

  m.mocap_bodyid = np.arange(mjm.nbody)[mjm.body_mocapid >= 0]
  m.mocap_bodyid = m.mocap_bodyid[mjm.body_mocapid[mjm.body_mocapid >= 0].argsort()]
  m.body_fluid_ellipsoid = np.zeros(mjm.nbody, dtype=bool)
  m.body_fluid_ellipsoid[mjm.geom_bodyid[mjm.geom_fluid.reshape(mjm.ngeom, mujoco.mjNFLUID)[:, 0] > 0]] = True
  jnt_limited_slide_hinge = mjm.jnt_limited & np.isin(mjm.jnt_type, (mujoco.mjtJoint.mjJNT_SLIDE, mujoco.mjtJoint.mjJNT_HINGE))
  m.jnt_limited_slide_hinge_adr = np.nonzero(jnt_limited_slide_hinge)[0]
  m.jnt_limited_ball_adr = np.nonzero(mjm.jnt_limited & (mjm.jnt_type == mujoco.mjtJoint.mjJNT_BALL))[0]
  m.dof_tri_row, m.dof_tri_col = np.tril_indices(mjm.nv)

  # precalculated geom pairs
  filterparent = not (mjm.opt.disableflags & types.DisableBit.FILTERPARENT)

  geom1, geom2 = np.triu_indices(mjm.ngeom, k=1)
  m.nxn_geom_pair = np.stack((geom1, geom2), axis=1)
  bodyid1 = mjm.geom_bodyid[geom1]
  bodyid2 = mjm.geom_bodyid[geom2]
  contype1 = mjm.geom_contype[geom1]
  contype2 = mjm.geom_contype[geom2]
  conaffinity1 = mjm.geom_conaffinity[geom1]
  conaffinity2 = mjm.geom_conaffinity[geom2]
  weldid1 = mjm.body_weldid[bodyid1]
  weldid2 = mjm.body_weldid[bodyid2]
  weld_parentid1 = mjm.body_weldid[mjm.body_parentid[weldid1]]
  weld_parentid2 = mjm.body_weldid[mjm.body_parentid[weldid2]]

  self_collision = weldid1 == weldid2
  parent_child_collision = (
    filterparent & (weldid1 != 0) & (weldid2 != 0) & ((weldid1 == weld_parentid2) | (weldid2 == weld_parentid1))
  )
  mask = np.array((contype1 & conaffinity2) | (contype2 & conaffinity1), dtype=bool)
  exclude = np.isin((bodyid1 << 16) + bodyid2, mjm.exclude_signature)

  # contact pair id per nxn geom pair: -2 = filtered out, -1 = candidate without an explicit
  # <pair>, >= 0 = explicit pair id
  nxn_pairid_contact = -1 * np.ones(len(geom1), dtype=int)
  nxn_pairid_contact[~(mask & ~self_collision & ~parent_child_collision & ~exclude)] = -2

  # contact pairs
  def upper_tri_index(n, i, j):
    """Index of geom pair (i, j) in the row-major strict upper triangle of an n x n matrix."""
    i, j = (j, i) if j < i else (i, j)
    return (i * (2 * n - i - 3)) // 2 + j - 1

  for i in range(mjm.npair):
    nxn_pairid_contact[upper_tri_index(mjm.ngeom, mjm.pair_geom1[i], mjm.pair_geom2[i])] = i

  # collision sensors: each distinct geom pair requested by a collision sensor gets a dense
  # collision index; sensor_collision_start_adr records that index per (sensor, geom pair)
  sensor_collision_adr = np.nonzero(is_collision_sensor)[0]
  nxn_pairid_collision = -1 * np.ones(len(geom1), dtype=int)
  assigned_pairids = set()
  sensor_collision_start_adr = []
  for sensorid in sensor_collision_adr:
    objtype = mjm.sensor_objtype[sensorid]
    objid = mjm.sensor_objid[sensorid]
    reftype = mjm.sensor_reftype[sensorid]
    refid = mjm.sensor_refid[sensorid]

    # get lists of geoms to collide
    if objtype == types.ObjType.BODY:
      n1 = mjm.body_geomnum[objid]
      id1 = mjm.body_geomadr[objid]
    else:
      n1 = 1
      id1 = objid
    if reftype == types.ObjType.BODY:
      n2 = mjm.body_geomnum[refid]
      id2 = mjm.body_geomadr[refid]
    else:
      n2 = 1
      id2 = refid

    # collide all pairs
    for geom1id in range(id1, id1 + n1):
      for geom2id in range(id2, id2 + n2):
        pairid = upper_tri_index(mjm.ngeom, geom1id, geom2id)
        if pairid in assigned_pairids:
          sensor_collision_start_adr.append(nxn_pairid_collision[pairid])
        else:
          npairids = len(assigned_pairids)
          nxn_pairid_collision[pairid] = npairids
          sensor_collision_start_adr.append(npairids)
          assigned_pairids.add(pairid)

  m.nsensorcollision = (nxn_pairid_collision >= 0).sum()
  m.sensor_collision_start_adr = np.array(sensor_collision_start_adr)

  nxn_include = (nxn_pairid_contact > -2) | (nxn_pairid_collision >= 0)

  if nxn_include.sum() < 250_000:
    opt.broadphase = types.BroadphaseType.NXN
  elif mjm.ngeom < 1000:
    opt.broadphase = types.BroadphaseType.SAP_TILE
  else:
    opt.broadphase = types.BroadphaseType.SAP_SEGMENTED

  m.nxn_geom_pair_filtered = m.nxn_geom_pair[nxn_include]
  m.nxn_pairid = np.hstack([nxn_pairid_contact.reshape((-1, 1)), nxn_pairid_collision.reshape((-1, 1))])
  m.nxn_pairid_filtered = m.nxn_pairid[nxn_include]

  # count contact pair types
  def geom_trid_index(i, j):
    """Index of geom type pair (i, j) in the row-major upper triangle (with diagonal)."""
    i, j = (j, i) if j < i else (i, j)
    return (i * (2 * len(types.GeomType) - i - 1)) // 2 + j

  m.geom_pair_type_count = tuple(
    np.bincount(
      [geom_trid_index(mjm.geom_type[geom1[i]], mjm.geom_type[geom2[i]]) for i in np.arange(len(geom1)) if nxn_include[i]],
      minlength=len(types.GeomType) * (len(types.GeomType) + 1) // 2,
    )
  )

  # check for unsupported margin + multicontact / box-box CCD combinations
  use_multiccd = (mjm.opt.disableflags & types.DisableBit.MULTICCD) == 0
  nativeccd_disabled = mjm.opt.disableflags & types.DisableBit.NATIVECCD
  BOX = int(mujoco.mjtGeom.mjGEOM_BOX)
  MESH = int(mujoco.mjtGeom.mjGEOM_MESH)
  has_boxbox = m.geom_pair_type_count[geom_trid_index(BOX, BOX)] > 0
  has_multiccd_pairs = has_boxbox or (
    use_multiccd
    and (m.geom_pair_type_count[geom_trid_index(BOX, MESH)] > 0 or m.geom_pair_type_count[geom_trid_index(MESH, MESH)] > 0)
  )

  if has_multiccd_pairs:

    def _check_margin(name, t1, t2, margin):
      """Raises for a non-zero margin on a box/mesh pair under MULTICCD or box-box NATIVECCD."""
      if use_multiccd:
        raise NotImplementedError(
          f"{name} has non-zero margin ({margin}) with MULTICCD enabled. Set margin to 0 or disable MULTICCD."
        )
      if t1 == BOX and t2 == BOX and not nativeccd_disabled:
        raise NotImplementedError(
          f"{name} has non-zero margin ({margin}) with NATIVECCD enabled. Set margin to 0 or disable NATIVECCD."
        )

    def _geom_name(g):
      return mujoco.mj_id2name(mjm, mujoco.mjtObj.mjOBJ_GEOM, g) or str(g)

    # implicit (non-<pair>) candidate geom pairs
    for idx in np.nonzero(nxn_include & (nxn_pairid_contact == -1))[0]:
      g1, g2 = int(geom1[idx]), int(geom2[idx])
      t1, t2 = int(mjm.geom_type[g1]), int(mjm.geom_type[g2])
      m1, m2 = float(mjm.geom_margin[g1]), float(mjm.geom_margin[g2])
      if (m1 or m2) and t1 in (BOX, MESH) and t2 in (BOX, MESH):
        _check_margin(f"geom pair ({_geom_name(g1)}, {_geom_name(g2)})", t1, t2, (m1, m2))

    # explicit <pair> elements
    for pid in range(mjm.npair):
      g1, g2 = int(mjm.pair_geom1[pid]), int(mjm.pair_geom2[pid])
      t1, t2 = int(mjm.geom_type[g1]), int(mjm.geom_type[g2])
      pm = float(mjm.pair_margin[pid])
      if pm and t1 in (BOX, MESH) and t2 in (BOX, MESH):
        _check_margin(f"pair {pid} ({_geom_name(g1)}, {_geom_name(g2)})", t1, t2, pm)

  m.nmaxpolygon = np.append(mjm.mesh_polyvertnum, 0).max()
  m.nmaxmeshdeg = np.append(mjm.mesh_polymapnum, 0).max()

  # filter plugins for only geom plugins, drop the rest
  m.plugin, m.plugin_attr = [], []
  m.geom_plugin_index = np.full_like(mjm.geom_type, -1)

  for i in range(len(mjm.geom_plugin)):
    if mjm.geom_plugin[i] == -1:
      continue
    p = mjm.geom_plugin[i]
    m.geom_plugin_index[i] = len(m.plugin)
    m.plugin.append(mjm.plugin[p])
    start = mjm.plugin_attradr[p]
    end = mjm.plugin_attradr[p + 1] if p + 1 < mjm.nplugin else len(mjm.plugin_attr)
    values = mjm.plugin_attr[start:end]
    # attribute values are stored as zero-terminated character codes; parse each as a float
    attr_values = []
    current = []
    for v in values:
      if v == 0:
        if current:
          s = "".join(chr(int(x)) for x in current)
          attr_values.append(float(s))
          current = []
      else:
        current.append(v)
    if len(attr_values) > types._NPLUGINATTR:
      raise ValueError(f"Plugin has {len(attr_values)} attributes, which exceeds the maximum of {types._NPLUGINATTR}. ")
    # pad with zeros to _NPLUGINATTR
    attr_values += [0.0] * (types._NPLUGINATTR - len(attr_values))
    m.plugin_attr.append(attr_values[: types._NPLUGINATTR])

  # equality constraint addresses
  m.eq_connect_adr = np.nonzero(mjm.eq_type == types.EqType.CONNECT)[0]
  m.eq_wld_adr = np.nonzero(mjm.eq_type == types.EqType.WELD)[0]
  m.eq_jnt_adr = np.nonzero(mjm.eq_type == types.EqType.JOINT)[0]
  m.eq_ten_adr = np.nonzero(mjm.eq_type == types.EqType.TENDON)[0]
  m.eq_flex_adr = np.nonzero(mjm.eq_type == types.EqType.FLEX)[0]

  # fixed tendon: a tendon whose first wrap is a joint wrap
  m.tendon_jnt_adr, m.wrap_jnt_adr = [], []
  for i in range(mjm.ntendon):
    adr = mjm.tendon_adr[i]
    if mjm.wrap_type[adr] == mujoco.mjtWrap.mjWRAP_JOINT:
      tendon_num = mjm.tendon_num[i]
      for j in range(tendon_num):
        m.tendon_jnt_adr.append(i)
        m.wrap_jnt_adr.append(adr + j)

  # spatial tendon
  m.tendon_site_pair_adr, m.tendon_geom_adr = [], []
  m.ten_wrapadr_site, m.ten_wrapnum_site = [0], []
  for i, tendon_num in enumerate(mjm.tendon_num):
    adr = mjm.tendon_adr[i]
    # sites: only tendons made entirely of site wraps contribute a site wrap count
    if (mjm.wrap_type[adr : adr + tendon_num] == mujoco.mjtWrap.mjWRAP_SITE).all():
      m.ten_wrapadr_site.append(m.ten_wrapadr_site[-1] + tendon_num)
      m.ten_wrapnum_site.append(tendon_num)
    else:
      m.ten_wrapadr_site.append(m.ten_wrapadr_site[-1])
      m.ten_wrapnum_site.append(0)

    # geoms
    for j in range(tendon_num):
      wrap_type = mjm.wrap_type[adr + j]
      if j < tendon_num - 1:
        next_wrap_type = mjm.wrap_type[adr + j + 1]
        if wrap_type == mujoco.mjtWrap.mjWRAP_SITE and next_wrap_type == mujoco.mjtWrap.mjWRAP_SITE:
          m.tendon_site_pair_adr.append(i)
      if wrap_type == mujoco.mjtWrap.mjWRAP_SPHERE or wrap_type == mujoco.mjtWrap.mjWRAP_CYLINDER:
        m.tendon_geom_adr.append(i)

  m.tendon_limited_adr = np.nonzero(mjm.tendon_limited)[0]
  m.wrap_site_adr = np.nonzero(mjm.wrap_type == mujoco.mjtWrap.mjWRAP_SITE)[0]
  m.wrap_site_pair_adr = np.setdiff1d(m.wrap_site_adr[np.nonzero(np.diff(m.wrap_site_adr) == 1)[0]], mjm.tendon_adr[1:] - 1)
  m.wrap_geom_adr = np.nonzero(np.isin(mjm.wrap_type, [mujoco.mjtWrap.mjWRAP_SPHERE, mujoco.mjtWrap.mjWRAP_CYLINDER]))[0]

  # pulley scaling: wraps from each pulley to the end of its tendon are scaled by
  # 1 / wrap_prm of that pulley (a later pulley in the same tendon overrides earlier ones)
  m.wrap_pulley_scale = np.ones(mjm.nwrap, dtype=float)
  pulley_adr = np.nonzero(mjm.wrap_type == mujoco.mjtWrap.mjWRAP_PULLEY)[0]
  for tadr, tnum in zip(mjm.tendon_adr, mjm.tendon_num):
    for padr in pulley_adr:
      if tadr <= padr < tadr + tnum:
        m.wrap_pulley_scale[padr : tadr + tnum] = 1.0 / mjm.wrap_prm[padr]

  m.actuator_trntype_body_adr = np.nonzero(mjm.actuator_trntype == mujoco.mjtTrn.mjTRN_BODY)[0]

  # sensor addresses
  m.sensor_pos_adr = np.nonzero(
    (mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_POS)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_JOINTLIMITPOS)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_TENDONLIMITPOS)
  )[0]
  m.sensor_limitpos_adr = np.nonzero(
    (mjm.sensor_type == mujoco.mjtSensor.mjSENS_JOINTLIMITPOS) | (mjm.sensor_type == mujoco.mjtSensor.mjSENS_TENDONLIMITPOS)
  )[0]
  m.sensor_vel_adr = np.nonzero(
    (mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_VEL)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_JOINTLIMITVEL)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_TENDONLIMITVEL)
  )[0]
  m.sensor_limitvel_adr = np.nonzero(
    (mjm.sensor_type == mujoco.mjtSensor.mjSENS_JOINTLIMITVEL) | (mjm.sensor_type == mujoco.mjtSensor.mjSENS_TENDONLIMITVEL)
  )[0]
  # acc-stage sensors, excluding the types that have dedicated address lists below
  # (touch, joint/tendon limit force, tendon actuator force); the exclusions must be
  # combined with `&`, as for the pos/vel lists above
  m.sensor_acc_adr = np.nonzero(
    (mjm.sensor_needstage == mujoco.mjtStage.mjSTAGE_ACC)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_TOUCH)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_JOINTLIMITFRC)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_TENDONLIMITFRC)
    & (mjm.sensor_type != mujoco.mjtSensor.mjSENS_TENDONACTFRC)
  )[0]
  m.sensor_rangefinder_adr = np.nonzero(mjm.sensor_type == mujoco.mjtSensor.mjSENS_RANGEFINDER)[0]
  m.rangefinder_sensor_adr = np.full(mjm.nsensor, -1)
  m.rangefinder_sensor_adr[m.sensor_rangefinder_adr] = np.arange(len(m.sensor_rangefinder_adr))
  m.collision_sensor_adr = np.full(mjm.nsensor, -1)
  m.collision_sensor_adr[sensor_collision_adr] = np.arange(len(sensor_collision_adr))
  m.sensor_touch_adr = np.nonzero(mjm.sensor_type == mujoco.mjtSensor.mjSENS_TOUCH)[0]
  limitfrc_sensors = (mujoco.mjtSensor.mjSENS_JOINTLIMITFRC, mujoco.mjtSensor.mjSENS_TENDONLIMITFRC)
  m.sensor_limitfrc_adr = np.nonzero(np.isin(mjm.sensor_type, limitfrc_sensors))[0]
  m.sensor_e_potential = (mjm.sensor_type == mujoco.mjtSensor.mjSENS_E_POTENTIAL).any()
  m.sensor_e_kinetic = (mjm.sensor_type == mujoco.mjtSensor.mjSENS_E_KINETIC).any()
  m.sensor_tendonactfrc_adr = np.nonzero(mjm.sensor_type == mujoco.mjtSensor.mjSENS_TENDONACTFRC)[0]
  subtreevel_sensors = (mujoco.mjtSensor.mjSENS_SUBTREELINVEL, mujoco.mjtSensor.mjSENS_SUBTREEANGMOM)
  m.sensor_subtree_vel = np.isin(mjm.sensor_type, subtreevel_sensors).any()
  m.sensor_contact_adr = np.nonzero(mjm.sensor_type == mujoco.mjtSensor.mjSENS_CONTACT)[0]
  m.sensor_adr_to_contact_adr = np.clip(np.cumsum(mjm.sensor_type == mujoco.mjtSensor.mjSENS_CONTACT) - 1, a_min=0, a_max=None)
  m.sensor_rne_postconstraint = np.isin(
    mjm.sensor_type,
    [
      mujoco.mjtSensor.mjSENS_ACCELEROMETER,
      mujoco.mjtSensor.mjSENS_FORCE,
      mujoco.mjtSensor.mjSENS_TORQUE,
      mujoco.mjtSensor.mjSENS_FRAMELINACC,
      mujoco.mjtSensor.mjSENS_FRAMEANGACC,
    ],
  ).any()
  m.sensor_rangefinder_bodyid = mjm.site_bodyid[mjm.sensor_objid[mjm.sensor_type == mujoco.mjtSensor.mjSENS_RANGEFINDER]]
  m.taxel_vertadr = [
    j + mjm.mesh_vertadr[mjm.sensor_objid[i]]
    for i in range(mjm.nsensor)
    if mjm.sensor_type[i] == mujoco.mjtSensor.mjSENS_TACTILE
    for j in range(mjm.mesh_vertnum[mjm.sensor_objid[i]])
  ]
  m.taxel_sensorid = [
    i
    for i in range(mjm.nsensor)
    if mjm.sensor_type[i] == mujoco.mjtSensor.mjSENS_TACTILE
    for j in range(mjm.mesh_vertnum[mjm.sensor_objid[i]])
  ]

  # qM_tiles records the block diagonal structure of qM
  tile_corners = [i for i in range(mjm.nv) if mjm.dof_parentid[i] == -1]
  tiles = {}
  for i in range(len(tile_corners)):
    tile_beg = tile_corners[i]
    tile_end = mjm.nv if i == len(tile_corners) - 1 else tile_corners[i + 1]
    tiles.setdefault(tile_end - tile_beg, []).append(tile_beg)
  m.qM_tiles = tuple(types.TileSet(adr=wp.array(tiles[sz], dtype=int), size=sz) for sz in sorted(tiles.keys()))

  # qLD_updates has dof tree ordering of qLD updates for sparse factor m
  qLD_updates, dof_depth = {}, np.zeros(mjm.nv, dtype=int) - 1
  for k in range(mjm.nv):
    # skip diagonal rows
    if mjm.M_rownnz[k] == 1:
      continue
    dof_depth[k] = dof_depth[mjm.dof_parentid[k]] + 1
    i = mjm.dof_parentid[k]
    diag_k = mjm.M_rowadr[k] + mjm.M_rownnz[k] - 1
    Madr_ki = diag_k - 1
    while i > -1:
      qLD_updates.setdefault(dof_depth[i], []).append((i, k, Madr_ki))
      i = mjm.dof_parentid[i]
      Madr_ki -= 1

  m.qLD_updates = tuple(wp.array(qLD_updates[i], dtype=wp.vec3i) for i in sorted(qLD_updates))

  # Build concatenated updates for fused kernel
  all_updates_flat = []
  level_offsets = [0]
  for level in sorted(qLD_updates):
    all_updates_flat.extend(qLD_updates[level])
    level_offsets.append(len(all_updates_flat))
  m.qLD_all_updates = all_updates_flat if all_updates_flat else [(0, 0, 0)]
  m.qLD_level_offsets = level_offsets

  # indices for sparse qM_fullm (used in solver)
  m.qM_fullm_i, m.qM_fullm_j = [], []
  for i in range(mjm.nv):
    j = i
    while j > -1:
      m.qM_fullm_i.append(i)
      m.qM_fullm_j.append(j)
      j = mjm.dof_parentid[j]

  # Gather-based sparse mul_m: for each row, all (col, madr) including diagonal
  row_elements = [[] for _ in range(mjm.nv)]

  # Add diagonal
  for i in range(mjm.nv):
    row_elements[i].append((i, mjm.dof_Madr[i]))

  # Add off-diagonals: ancestors (lower) and descendants (upper)
  for i in range(mjm.nv):
    madr_ij, j = mjm.dof_Madr[i], i
    while True:
      madr_ij, j = madr_ij + 1, mjm.dof_parentid[j]
      if j == -1:
        break
      row_elements[i].append((j, madr_ij))  # row i gathers M[i,j] * vec[j]
      row_elements[j].append((i, madr_ij))  # row j gathers M[j,i] * vec[i]

  # Flatten into CSR-like arrays
  m.qM_mulm_rowadr = [0]
  m.qM_mulm_col = []
  m.qM_mulm_madr = []
  for i in range(mjm.nv):
    for col, madr in row_elements[i]:
      m.qM_mulm_col.append(col)
      m.qM_mulm_madr.append(madr)
    m.qM_mulm_rowadr.append(len(m.qM_mulm_col))

  m.flexedge_J_rownnz = mjm.flexedge_J_rownnz
  m.flexedge_J_rowadr = mjm.flexedge_J_rowadr
  m.flexedge_J_colind = mjm.flexedge_J_colind.reshape(-1)

  # place m on device
  sizes = dict({"*": 1}, **{f.name: getattr(m, f.name) for f in dataclasses.fields(types.Model) if f.type is int})
  for f in dataclasses.fields(types.Model):
    if _is_array_spec(f.type):
      setattr(m, f.name, _create_array(getattr(m, f.name), f.type, sizes))

  return m
def _get_padded_sizes(nv: int, njmax: int, is_sparse: bool, tile_size: int):
  """Returns `(njmax_padded, nv_padded)`.

  njmax is always rounded up to a multiple of `tile_size`, to avoid out-of-bounds accesses.
  nv is rounded up to a multiple of `tile_size` when the model is sparse or nv > 32.
  Otherwise nv is rounded up to a multiple of 4 only, to get the fast load path.
  """

  def _round_up(value, multiple):
    return (value + multiple - 1) // multiple * multiple

  nv_multiple = tile_size if (is_sparse or nv > 32) else 4
  return _round_up(njmax, tile_size), _round_up(nv, nv_multiple)
def _default_nconmax(mjm: mujoco.MjModel, mjd: Optional[mujoco.MjData] = None) -> int:
  """Returns a default guess for an ideal nconmax given a Model and optional Data.

  This guess is based off a very simple heuristic, and may need to be manually raised if MJWarp
  reports ncon overflow, or lowered in order to get the very best performance.

  The raw estimate is rounded up to the next size in the sequence 16, 24, 32, 48, 64, ...,
  which alternates 2 * 2**k and 3 * 2**k. Estimates larger than 8192 keep following the same
  sequence.

  Args:
    mjm: The model containing kinematic and dynamic information (host).
    mjd: Optional host data; when given, the result is at least `mjd.ncon`.

  Returns:
    The suggested number of contacts to allocate per world.
  """
  has_sdf = (mjm.geom_type == mujoco.mjtGeom.mjGEOM_SDF).any()
  has_flex = mjm.nflex > 0
  nconmax = max(
    mjm.nv * 0.35 * (mjm.nhfield > 0) * 10 + 45,
    256 * has_flex,
    64 * has_sdf,
    mjd.ncon if mjd is not None else 0,
  )

  def _size(k):
    # k-th entry of 16, 24, 32, 48, ...
    return (2 + k % 2) * 2 ** (k // 2 + 3)

  valid_sizes = (2 + (np.arange(19) % 2)) * (2 ** (np.arange(19) // 2 + 3))  # 16, 24, 32, 48, ... 8192
  idx = int(np.searchsorted(valid_sizes, nconmax))
  if idx < len(valid_sizes):
    return int(valid_sizes[idx])
  # searchsorted returns len(valid_sizes) for estimates above 8192; indexing with it would
  # raise IndexError, so continue along the same sequence instead
  k = len(valid_sizes)
  while _size(k) < nconmax:
    k += 1
  return int(_size(k))
def _default_njmax(mjm: mujoco.MjModel, mjd: Optional[mujoco.MjData] = None) -> int:
  """Returns a default guess for an ideal njmax given a Model and optional Data.

  This guess is based off a very simple heuristic, and may need to be manually raised if MJWarp
  reports that a world needs more constraint rows than njmax, or lowered in order to get the
  very best performance.

  The raw estimate is rounded up to the next size in the sequence 16, 24, 32, 48, 64, ...,
  which alternates 2 * 2**k and 3 * 2**k. Estimates larger than 8192 keep following the same
  sequence.

  Args:
    mjm: The model containing kinematic and dynamic information (host).
    mjd: Optional host data; when given, the result is at least `mjd.nefc`.

  Returns:
    The suggested number of constraint rows to allocate per world.
  """
  has_sdf = (mjm.geom_type == mujoco.mjtGeom.mjGEOM_SDF).any()
  has_flex = mjm.nflex > 0
  njmax = max(
    mjm.nv * 2.26 * (mjm.nhfield > 0) * 18 + 53,
    512 * has_flex,
    256 * has_sdf,
    mjd.nefc if mjd is not None else 0,
  )

  def _size(k):
    # k-th entry of 16, 24, 32, 48, ...
    return (2 + k % 2) * 2 ** (k // 2 + 3)

  valid_sizes = (2 + (np.arange(19) % 2)) * (2 ** (np.arange(19) // 2 + 3))  # 16, 24, 32, 48, ... 8192
  idx = int(np.searchsorted(valid_sizes, njmax))
  if idx < len(valid_sizes):
    return int(valid_sizes[idx])
  # searchsorted returns len(valid_sizes) for estimates above 8192 (e.g. nv > ~200 with a
  # heightfield); indexing with it would raise IndexError, so continue the sequence instead
  k = len(valid_sizes)
  while _size(k) < njmax:
    k += 1
  return int(_size(k))
def _body_pair_nnz(mjm: mujoco.MjModel, body1: int, body2: int) -> int:
  """Returns the number of unique DOFs in the kinematic tree union of two bodies.

  Each body is first mapped to its weld body. Then the DOF chain from that body's last DOF
  up through `dof_parentid` is collected, and the size of the union of both chains is
  returned.
  """
  visited = set()
  for body in (mjm.body_weldid[body1], mjm.body_weldid[body2]):
    dof = mjm.body_dofadr[body] + mjm.body_dofnum[body] - 1
    # once a visited DOF is reached, the rest of the chain is already in the union
    while dof >= 0 and dof not in visited:
      visited.add(dof)
      dof = mjm.dof_parentid[dof]
  return len(visited)
def _default_njmax_nnz(mjm: mujoco.MjModel, nconmax: int, njmax: int) -> int:
  """Returns a heuristic estimate for the number of non-zeros in the sparse constraint Jacobian.

  Assumes all equality, friction, and limit constraints are active and counts their
  non-zeros. For contacts, assumes njmax contact rows. Each row is given the worst-case
  body-pair non-zeros over all candidate collisions: explicit pairs, contype/conaffinity
  filtered geom-geom pairs on distinct non-static bodies, flex-geom pairs, and flex
  self-collisions.

  Args:
    mjm: The model containing kinematic and dynamic information (host).
    nconmax: Maximum number of contacts per world (not used by the current heuristic).
    njmax: Maximum number of constraint rows per world.

  Returns:
    Estimated number of non-zeros in the constraint Jacobian. The estimate is raised to at
    least 1 and then capped at njmax * nv; the cap takes precedence.
  """

  def _eq_bodies(eqid):
    """Returns the body pair of equality `eqid`, mapping site objects to their bodies."""
    obj1, obj2 = mjm.eq_obj1id[eqid], mjm.eq_obj2id[eqid]
    if mjm.eq_objtype[eqid] != mujoco.mjtObj.mjOBJ_SITE:
      return obj1, obj2
    return mjm.site_bodyid[obj1], mjm.site_bodyid[obj2]

  def _tendon_eq_nnz(eqid):
    """Non-zeros of a tendon equality: the union of both tendons' Jacobian columns."""
    ten1, ten2 = mjm.eq_obj1id[eqid], mjm.eq_obj2id[eqid]
    nnz1 = mjm.ten_J_rownnz[ten1] if ten1 < mjm.ntendon else 0
    if not (ten2 >= 0 and ten2 < mjm.ntendon):
      return nnz1
    adr1 = mjm.ten_J_rowadr[ten1]
    adr2 = mjm.ten_J_rowadr[ten2]
    nnz2 = mjm.ten_J_rownnz[ten2]
    columns = {mjm.ten_J_colind[adr1 + c] for c in range(nnz1)}
    columns.update(mjm.ten_J_colind[adr2 + c] for c in range(nnz2))
    return len(columns)

  nnz = 0

  # equality constraints (assume all active)
  for eqid in range(mjm.neq):
    kind = mjm.eq_type[eqid]
    if kind == mujoco.mjtEq.mjEQ_CONNECT:
      nnz += 3 * _body_pair_nnz(mjm, *_eq_bodies(eqid))
    elif kind == mujoco.mjtEq.mjEQ_WELD:
      nnz += 6 * _body_pair_nnz(mjm, *_eq_bodies(eqid))
    elif kind == mujoco.mjtEq.mjEQ_JOINT:
      nnz += 1 if mjm.eq_obj2id[eqid] < 0 else 2
    elif kind == mujoco.mjtEq.mjEQ_TENDON:
      nnz += _tendon_eq_nnz(eqid)
    elif kind == mujoco.mjtEq.mjEQ_FLEX:
      flexid = mjm.eq_obj1id[eqid]
      if flexid < mjm.nflex:
        first_edge = mjm.flex_edgeadr[flexid]
        nnz += sum(mjm.flexedge_J_rownnz[first_edge + e] for e in range(mjm.flex_edgenum[flexid]))

  # friction constraints
  nnz += (mjm.dof_frictionloss > 0).sum()
  nnz += sum(mjm.ten_J_rownnz[t] for t in range(mjm.ntendon) if mjm.tendon_frictionloss[t] > 0)

  # limit constraints (assume all active)
  for jntid in range(mjm.njnt):
    if not mjm.jnt_limited[jntid]:
      continue
    jnt_type = mjm.jnt_type[jntid]
    if jnt_type == mujoco.mjtJoint.mjJNT_BALL:
      nnz += 3
    elif jnt_type in (mujoco.mjtJoint.mjJNT_SLIDE, mujoco.mjtJoint.mjJNT_HINGE):
      nnz += 1
  nnz += sum(mjm.ten_J_rownnz[t] for t in range(mjm.ntendon) if mjm.tendon_limited[t])

  # contact constraints: collect every candidate body pair, then take the worst case
  candidates = set()

  # explicit contact pairs
  for pairid in range(mjm.npair):
    candidates.add((mjm.geom_bodyid[mjm.pair_geom1[pairid]], mjm.geom_bodyid[mjm.pair_geom2[pairid]]))

  # geom-geom pairs: distinct bodies, not both static, passing the contype/conaffinity test
  for ga in range(mjm.ngeom):
    ba = mjm.geom_bodyid[ga]
    ct_a, ca_a = mjm.geom_contype[ga], mjm.geom_conaffinity[ga]
    for gb in range(ga + 1, mjm.ngeom):
      bb = mjm.geom_bodyid[gb]
      if ba == bb or (mjm.body_weldid[ba] == 0 and mjm.body_weldid[bb] == 0):
        continue
      if (ct_a & mjm.geom_conaffinity[gb]) or (mjm.geom_contype[gb] & ca_a):
        candidates.add((min(ba, bb), max(ba, bb)))

  # flex vertex contacts and flex self-collision
  for flexid in range(mjm.nflex):
    fct, fca = mjm.flex_contype[flexid], mjm.flex_conaffinity[flexid]
    first_vert = mjm.flex_vertadr[flexid]
    flex_bodies = {mjm.flex_vertbodyid[first_vert + v] for v in range(mjm.flex_vertnum[flexid])}
    geom_bodies = {
      mjm.geom_bodyid[g]
      for g in range(mjm.ngeom)
      if (fct & mjm.geom_conaffinity[g]) or (mjm.geom_contype[g] & fca)
    }
    candidates.update((fb, gb) for fb in flex_bodies for gb in geom_bodies if fb != gb)
    if mjm.flex_selfcollide[flexid]:
      ordered = sorted(flex_bodies)
      candidates.update((ordered[a], ordered[b]) for a in range(len(ordered)) for b in range(a + 1, len(ordered)))

  worst_contact_nnz = max((_body_pair_nnz(mjm, b1, b2) for b1, b2 in candidates), default=0)
  nnz += njmax * worst_contact_nnz

  return int(min(max(nnz, 1), njmax * mjm.nv))
def _resolve_batch_size(na: int | None, n: int | None, nworld: int, default: int) -> int:
if na is not None:
return na
if n is not None:
return n * nworld
return default
[docs]
def make_data(
  mjm: mujoco.MjModel,
  nworld: int = 1,
  nconmax: Optional[int] = None,
  nccdmax: Optional[int] = None,
  njmax: Optional[int] = None,
  njmax_nnz: Optional[int] = None,
  naconmax: Optional[int] = None,
  naccdmax: Optional[int] = None,
) -> types.Data:
  """Creates a data object on device.

  Args:
    mjm: The model containing kinematic and dynamic information (host).
    nworld: Number of worlds.
    nconmax: Number of contacts to allocate per world. Contacts exist in large
      heterogeneous arrays: one world may have more than nconmax contacts.
    nccdmax: Number of CCD contacts to allocate per world. Same semantics as nconmax.
    njmax: Number of constraints to allocate per world. Constraint arrays are
      batched by world: no world may have more than njmax constraints.
    njmax_nnz: Number of non-zeros in constraint Jacobian (sparse). Defaults to njmax * nv.
    naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.
    naccdmax: Maximum number of CCD contacts. Defaults to naconmax.

  Returns:
    The data object containing the current state and output arrays (device).

  Raises:
    ValueError: if a size argument is negative, nworld < 1, or a CCD contact
      budget exceeds the corresponding contact budget.
  """
  # TODO(team): move nconmax, njmax to Model?
  if nconmax is None:
    nconmax = _default_nconmax(mjm)
  if njmax is None:
    njmax = _default_njmax(mjm)

  if nconmax < 0:
    raise ValueError("nconmax must be >= 0")
  if njmax < 0:
    raise ValueError("njmax must be >= 0")
  if nworld < 1:
    raise ValueError("nworld must be >= 1")

  naconmax = _resolve_batch_size(naconmax, nconmax, nworld, 0)
  if naconmax < 0:
    raise ValueError("naconmax must be >= 0")

  naccdmax = _resolve_batch_size(naccdmax, nccdmax, nworld, naconmax)
  if naccdmax < 0:
    raise ValueError("naccdmax must be >= 0")
  elif naccdmax > naconmax:
    raise ValueError(f"naccdmax ({naccdmax}) must be <= naconmax ({naconmax})")

  # nccdmax is only validated here; the allocation itself is sized by naccdmax
  if nccdmax is not None:
    if nccdmax < 0:
      raise ValueError("nccdmax must be >= 0")
    elif nccdmax > nconmax:
      raise ValueError(f"nccdmax ({nccdmax}) must be <= nconmax ({nconmax})")

  sparse = is_sparse(mjm)

  # symbolic dimension name -> concrete size, used by _create_array to shape every field
  sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int})
  sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
  sizes["nmaxpyramid"] = np.maximum(1, 2 * (sizes["nmaxcondim"] - 1))
  tile_size = types.TILE_SIZE_JTDAJ_SPARSE if sparse else types.TILE_SIZE_JTDAJ_DENSE
  sizes["njmax_pad"], sizes["nv_pad"] = _get_padded_sizes(mjm.nv, njmax, sparse, tile_size)
  sizes["nworld"] = nworld
  sizes["naconmax"] = naconmax
  sizes["njmax"] = njmax

  if njmax_nnz is None:
    njmax_nnz = _default_njmax_nnz(mjm, nconmax, njmax) if sparse else njmax * mjm.nv

  # contacts are allocated exactly once; efc_address = -1 marks "no constraint rows",
  # the same sentinel put_data and reset_data use
  contact = types.Contact(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Contact)})
  contact.efc_address = wp.array(np.full((naconmax, sizes["nmaxpyramid"]), -1, dtype=int), dtype=int)

  efc = types.Constraint(**{f.name: _create_array(None, f.type, sizes) for f in dataclasses.fields(types.Constraint)})
  if sparse:
    efc.J_rownnz = wp.zeros((nworld, njmax), dtype=int)
    efc.J_rowadr = wp.zeros((nworld, njmax), dtype=int)
    efc.J_colind = wp.zeros((nworld, 1, njmax_nnz), dtype=int)
    efc.J = wp.zeros((nworld, 1, njmax_nnz), dtype=float)
  else:
    # dense layout: sparse index arrays are empty placeholders
    efc.J_rownnz = wp.zeros((nworld, 0), dtype=int)
    efc.J_rowadr = wp.zeros((nworld, 0), dtype=int)
    efc.J_colind = wp.zeros((nworld, 0, 0), dtype=int)
    efc.J = wp.zeros((nworld, sizes["njmax_pad"], sizes["nv_pad"]), dtype=float)

  # world body and static geom (attached to the world) poses are precomputed
  # this speeds up scenes with many static geoms (e.g. terrains)
  # TODO(team): remove this when we introduce dof islands + sleeping
  mjd = mujoco.MjData(mjm)
  mujoco.mj_kinematics(mjm, mjd)

  # mocap
  mocap_body = np.nonzero(mjm.body_mocapid >= 0)[0]
  mocap_id = mjm.body_mocapid[mocap_body]

  d_kwargs = {
    "qpos": wp.array(np.tile(mjm.qpos0, nworld), shape=(nworld, mjm.nq), dtype=float),
    "contact": contact,
    "efc": efc,
    "nworld": nworld,
    "naconmax": naconmax,
    "naccdmax": naccdmax,
    "njmax": njmax,
    "njmax_pad": sizes["njmax_pad"],
    "njmax_nnz": njmax_nnz,
    "qM": None,
    "qLD": None,
    # world body
    "xquat": wp.array(np.tile(mjd.xquat, (nworld, 1)), shape=(nworld, mjm.nbody), dtype=wp.quat),
    "xmat": wp.array(np.tile(mjd.xmat, (nworld, 1)), shape=(nworld, mjm.nbody), dtype=wp.mat33),
    "ximat": wp.array(np.tile(mjd.ximat, (nworld, 1)), shape=(nworld, mjm.nbody), dtype=wp.mat33),
    # static geoms
    "geom_xpos": wp.array(np.tile(mjd.geom_xpos, (nworld, 1)), shape=(nworld, mjm.ngeom), dtype=wp.vec3),
    "geom_xmat": wp.array(np.tile(mjd.geom_xmat, (nworld, 1)), shape=(nworld, mjm.ngeom), dtype=wp.mat33),
    # mocap
    "mocap_pos": wp.array(np.tile(mjm.body_pos[mocap_body[mocap_id]], (nworld, 1)), shape=(nworld, mjm.nmocap), dtype=wp.vec3),
    "mocap_quat": wp.array(
      np.tile(mjm.body_quat[mocap_body[mocap_id]], (nworld, 1)), shape=(nworld, mjm.nmocap), dtype=wp.quat
    ),
    # equality constraints
    "eq_active": wp.array(np.tile(mjm.eq_active0.astype(bool), (nworld, 1)), shape=(nworld, mjm.neq), dtype=bool),
    # island arrays
    "nisland": None,
    "tree_island": None,
  }
  # every remaining Data field is zero-allocated from its declared shape
  for f in dataclasses.fields(types.Data):
    if f.name in d_kwargs:
      continue
    d_kwargs[f.name] = _create_array(None, f.type, sizes)

  d = types.Data(**d_kwargs)

  if sparse:
    d.qM = wp.zeros((nworld, 1, mjm.nM), dtype=float)
    d.qLD = wp.zeros((nworld, 1, mjm.nC), dtype=float)
  else:
    d.qM = wp.zeros((nworld, sizes["nv_pad"], sizes["nv_pad"]), dtype=float)
    d.qLD = wp.zeros((nworld, mjm.nv, mjm.nv), dtype=float)

  # island discovery arrays
  d.nisland = wp.zeros((nworld,), dtype=int)
  d.tree_island = wp.zeros((nworld, mjm.ntree), dtype=int)

  return d
[docs]
def put_data(
  mjm: mujoco.MjModel,
  mjd: mujoco.MjData,
  nworld: int = 1,
  nconmax: Optional[int] = None,
  nccdmax: Optional[int] = None,
  njmax: Optional[int] = None,
  njmax_nnz: Optional[int] = None,
  naconmax: Optional[int] = None,
  naccdmax: Optional[int] = None,
) -> types.Data:
  """Moves data from host to a device.

  The host state in ``mjd`` is replicated into every world. Contacts are laid
  out world-major: device contact ``w * mjd.ncon + i`` is host contact ``i``
  in world ``w``.

  Args:
    mjm: The model containing kinematic and dynamic information (host).
    mjd: The data object containing current state and output arrays (host).
    nworld: The number of worlds.
    nconmax: Number of contacts to allocate per world. Contacts exist in large
      heterogeneous arrays: one world may have more than nconmax contacts.
    nccdmax: Number of CCD contacts to allocate per world. Same semantics as nconmax.
    njmax: Number of constraints to allocate per world. Constraint arrays are
      batched by world: no world may have more than njmax constraints.
    njmax_nnz: Number of non-zeros in constraint Jacobian (sparse). Defaults to njmax * nv.
    naconmax: Number of contacts to allocate for all worlds. Overrides nconmax.
    naccdmax: Maximum number of CCD contacts. Defaults to naconmax.

  Returns:
    The data object containing the current state and output arrays (device).

  Raises:
    ValueError: if a size argument is negative, nworld < 1, or the host data
      does not fit in the requested contact / constraint / Jacobian buffers.
  """
  # TODO(team): move nconmax and njmax to Model?
  # TODO(team): decide what to do about uninitialized warp-only fields created by put_data
  # we need to ensure these are only workspace fields and don't carry state

  if nconmax is None:
    nconmax = _default_nconmax(mjm, mjd)
  if njmax is None:
    njmax = _default_njmax(mjm, mjd)

  if nconmax < 0:
    raise ValueError("nconmax must be >= 0")
  if njmax < 0:
    raise ValueError("njmax must be >= 0")
  if nworld < 1:
    raise ValueError("nworld must be >= 1")

  naconmax_is_input = naconmax is not None
  naconmax = _resolve_batch_size(naconmax, nconmax, nworld, 0)
  if naconmax < 0:
    raise ValueError("naconmax must be >= 0")
  if not naconmax_is_input and mjd.ncon > nconmax:
    raise ValueError(f"nconmax overflow (nconmax must be >= {mjd.ncon})")
  elif naconmax < mjd.ncon * nworld:
    raise ValueError(f"naconmax overflow (naconmax must be >= {mjd.ncon * nworld})")

  naccdmax = _resolve_batch_size(naccdmax, nccdmax, nworld, naconmax)
  if naccdmax < 0:
    raise ValueError("naccdmax must be >= 0")
  elif naccdmax > naconmax:
    raise ValueError(f"naccdmax ({naccdmax}) must be <= naconmax ({naconmax})")

  # nccdmax is only validated here; the allocation itself is sized by naccdmax
  if nccdmax is not None:
    if nccdmax < 0:
      raise ValueError("nccdmax must be >= 0")
    elif nccdmax > nconmax:
      raise ValueError(f"nccdmax ({nccdmax}) must be <= nconmax ({nconmax})")

  if mjd.nefc > njmax:
    raise ValueError(f"njmax overflow (njmax must be >= {mjd.nefc})")

  sparse = is_sparse(mjm)

  # symbolic dimension name -> concrete size, used to shape every field
  sizes = dict({"*": 1}, **{f.name: getattr(mjm, f.name, None) for f in dataclasses.fields(types.Model) if f.type is int})
  sizes["nmaxcondim"] = np.concatenate(([0], mjm.geom_condim, mjm.pair_dim)).max()
  sizes["nmaxpyramid"] = np.maximum(1, 2 * (sizes["nmaxcondim"] - 1))
  tile_size = types.TILE_SIZE_JTDAJ_SPARSE if sparse else types.TILE_SIZE_JTDAJ_DENSE
  sizes["njmax_pad"], sizes["nv_pad"] = _get_padded_sizes(mjm.nv, njmax, sparse, tile_size)
  sizes["nworld"] = nworld
  sizes["naconmax"] = naconmax
  sizes["njmax"] = njmax

  if njmax_nnz is None:
    njmax_nnz = _default_njmax_nnz(mjm, nconmax, njmax) if sparse else njmax * mjm.nv

  # ensure static geom positions are computed
  # TODO: remove once MjData creation semantics are fixed
  mujoco.mj_kinematics(mjm, mjd)

  # create contact
  contact_kwargs = {"efc_address": None, "worldid": None, "type": None, "geomcollisionid": None}
  for f in dataclasses.fields(types.Contact):
    if f.name in contact_kwargs:
      continue
    val = getattr(mjd.contact, f.name)
    # replicate world-major (all of world 0's contacts, then world 1's, ...) so that row
    # w * ncon + i lines up with contact.worldid and contact.efc_address built below;
    # np.repeat would interleave worlds and pair contacts with the wrong world/address
    val = np.tile(val, (nworld,) + (1,) * (val.ndim - 1))
    width = ((0, naconmax - val.shape[0]),) + ((0, 0),) * (val.ndim - 1)
    val = np.pad(val, width)
    contact_kwargs[f.name] = _create_array(val, f.type, sizes)
  contact = types.Contact(**contact_kwargs)

  # -1 marks unused efc slots; constraint arrays are per-world, so addresses need no world offset
  contact.efc_address = np.full((naconmax, sizes["nmaxpyramid"]), -1, dtype=int)
  for i in range(mjd.ncon):
    efc_address = mjd.contact.efc_address[i]
    if efc_address == -1:
      continue
    condim = mjd.contact.dim[i]
    ndim = max(1, 2 * (condim - 1)) if mjm.opt.cone == mujoco.mjtCone.mjCONE_PYRAMIDAL else condim
    for j in range(nworld):
      contact.efc_address[j * mjd.ncon + i, :ndim] = efc_address + np.arange(ndim)
  contact.efc_address = wp.array(contact.efc_address, dtype=int)
  contact.worldid = np.pad(np.repeat(np.arange(nworld), mjd.ncon), (0, naconmax - nworld * mjd.ncon))
  contact.worldid = wp.array(contact.worldid, dtype=int)
  contact.type = wp.ones((naconmax,), dtype=int)  # TODO(team): set values
  contact.geomcollisionid = wp.empty((naconmax,), dtype=int)  # TODO(team): set values

  # create efc
  efc_kwargs = {"J_rownnz": None, "J_rowadr": None, "J_colind": None, "J": None}
  for f in dataclasses.fields(types.Constraint):
    if f.name in efc_kwargs:
      continue
    shape = tuple(sizes[dim] if isinstance(dim, str) else dim for dim in f.type.shape)
    val = np.zeros(shape, dtype=f.type.dtype)
    if f.name in ("type", "id", "pos", "margin", "D", "vel", "aref", "frictionloss", "force"):
      val[:, : mjd.nefc] = np.tile(getattr(mjd, "efc_" + f.name), (nworld, 1))
    efc_kwargs[f.name] = wp.array(val, dtype=f.type.dtype)
  efc = types.Constraint(**efc_kwargs)

  if sparse:
    J_rownnz = np.zeros(njmax, dtype=np.int32)
    J_rowadr = np.zeros(njmax, dtype=np.int32)
    J_colind = np.zeros(njmax_nnz, dtype=np.int32)
    J = np.zeros(njmax_nnz, dtype=np.float64)
    if mjd.nefc:
      if mujoco.mj_isSparse(mjm):
        nnz = int(mjd.efc_J_rownnz[: mjd.nefc].sum())
        if nnz > njmax_nnz:
          raise ValueError(f"njmax_nnz overflow (njmax_nnz must be >= {nnz})")
        J_rownnz[: mjd.nefc] = mjd.efc_J_rownnz[: mjd.nefc]
        J_rowadr[: mjd.nefc] = mjd.efc_J_rowadr[: mjd.nefc]
        J_colind[:nnz] = mjd.efc_J_colind[:nnz]
        J[:nnz] = mjd.efc_J[:nnz]
      else:
        dense_J = mjd.efc_J.reshape((-1, mjm.nv))[: mjd.nefc]
        mujoco.mju_dense2sparse(
          J[: mjd.nefc * mjm.nv], dense_J, J_rownnz[: mjd.nefc], J_rowadr[: mjd.nefc], J_colind[: mjd.nefc * mjm.nv]
        )
    efc.J_rownnz = wp.array(np.tile(J_rownnz, (nworld, 1)), dtype=int)
    efc.J_rowadr = wp.array(np.tile(J_rowadr, (nworld, 1)), dtype=int)
    efc.J_colind = wp.array(np.tile(J_colind, (nworld, 1)).reshape((nworld, 1, -1)), dtype=int)
    efc.J = wp.array(np.tile(J, (nworld, 1)).reshape((nworld, 1, -1)), dtype=float)
  else:
    efc.J_rownnz = wp.zeros((nworld, 0), dtype=int)
    efc.J_rowadr = wp.zeros((nworld, 0), dtype=int)
    efc.J_colind = wp.zeros((nworld, 0, 0), dtype=int)
    mj_efc_J = np.zeros((mjd.nefc, mjm.nv))
    if mjd.nefc:
      if mujoco.mj_isSparse(mjm):
        mujoco.mju_sparse2dense(mj_efc_J, mjd.efc_J, mjd.efc_J_rownnz, mjd.efc_J_rowadr, mjd.efc_J_colind)
      else:
        mj_efc_J = mjd.efc_J.reshape((mjd.nefc, mjm.nv))
    efc_J = np.zeros((nworld, sizes["njmax_pad"], sizes["nv_pad"]), dtype=float)
    efc_J[:, : mjd.nefc, : mjm.nv] = np.tile(mj_efc_J, (nworld, 1, 1))
    efc.J = wp.array(efc_J, dtype=float)

  # create data
  d_kwargs = {
    "contact": contact,
    "efc": efc,
    "nworld": nworld,
    "naconmax": naconmax,
    "naccdmax": naccdmax,
    "njmax": njmax,
    "njmax_pad": sizes["njmax_pad"],
    "njmax_nnz": njmax_nnz,
    # fields set after initialization:
    "solver_niter": None,
    "qM": None,
    "qLD": None,
    "nacon": None,
    # island arrays
    "nisland": None,
    "tree_island": None,
  }
  # remaining fields: replicate the host value into every world when mjd has it, else zero-allocate
  for f in dataclasses.fields(types.Data):
    if f.name in d_kwargs:
      continue
    val = getattr(mjd, f.name, None)
    if val is not None:
      shape = val.shape if hasattr(val, "shape") else ()
      val = np.full((nworld,) + shape, val)
    d_kwargs[f.name] = _create_array(val, f.type, sizes)

  d = types.Data(**d_kwargs)

  d.solver_niter = wp.full((nworld,), mjd.solver_niter[0], dtype=int)

  if sparse:
    d.qM = wp.array(np.full((nworld, 1, mjm.nM), mjd.qM), dtype=float)
    d.qLD = wp.array(np.full((nworld, 1, mjm.nC), mjd.qLD), dtype=float)
  else:
    qM = np.zeros((mjm.nv, mjm.nv))
    mujoco.mj_fullM(mjm, qM, mjd.qM)
    # only factorize when the host data actually holds a mass matrix and factor
    qLD = np.linalg.cholesky(qM) if (mjd.qM != 0.0).any() and (mjd.qLD != 0.0).any() else np.zeros((mjm.nv, mjm.nv))
    padding = sizes["nv_pad"] - mjm.nv
    qM_padded = np.pad(qM, ((0, padding), (0, padding)), mode="constant", constant_values=0.0)
    d.qM = wp.array(np.full((nworld, sizes["nv_pad"], sizes["nv_pad"]), qM_padded), dtype=float)
    d.qLD = wp.array(np.full((nworld, mjm.nv, mjm.nv), qLD), dtype=float)

  # island arrays
  d.nisland = wp.array(np.full(nworld, mjd.nisland), dtype=int)
  d.tree_island = wp.array(np.tile(mjd.tree_island, (nworld, 1)), dtype=int)

  d.nacon = wp.array([mjd.ncon * nworld], dtype=int)

  return d
[docs]
def get_data_into(
  result: mujoco.MjData,
  mjm: mujoco.MjModel,
  d: types.Data,
  world_id: int = 0,
):
  """Gets data from a device into an existing mujoco.MjData.

  Only world ``world_id`` is copied. Its contacts are selected through
  ``d.contact.worldid``, and constraint rows are permuted so that each
  contact's rows are contiguous and follow the equality / friction / limit
  rows. ``result`` is reallocated when its contact or constraint counts differ
  from the device world's.

  Args:
    result: The data object containing the current state and output arrays (host).
    mjm: The model containing kinematic and dynamic information (host).
    d: The data object containing the current state and output arrays (device).
    world_id: The id of the world to get the data from.
  """
  cone_pyramidal = mjm.opt.cone == mujoco.mjtCone.mjCONE_PYRAMIDAL

  # nacon and nefc can overflow: only pull up to the allocated maximum
  nacon = min(d.nacon.numpy()[0], d.naconmax)
  nefc = min(d.nefc.numpy()[world_id], d.njmax)

  contact_worldid = d.contact.worldid.numpy()
  ncon_filter = np.zeros_like(contact_worldid, dtype=bool)
  ncon_filter[:nacon] = contact_worldid[:nacon] == world_id
  ncon = ncon_filter.sum()

  if nefc != result.nefc or ncon != result.ncon:
    # TODO(team): if sparse, set nJ based on sparse efc_J
    mujoco._functions._realloc_con_efc(result, ncon=ncon, nefc=nefc, nJ=nefc * mjm.nv)

  ne = d.ne.numpy()[world_id]
  nf = d.nf.numpy()[world_id]
  nl = d.nl.numpy()[world_id]
  nefl = ne + nf + nl  # rows that precede the contact rows on the host

  # mujoco expects contiguous efc ordering for contacts; MuJoCo Warp does not guarantee it,
  # so efc_idx maps host row k -> device row efc_idx[k]
  if ncon > 0:
    contact_dim = d.contact.dim.numpy()[ncon_filter]
    contact_efc_address = d.contact.efc_address.numpy()[ncon_filter]
    contact_nrow = np.maximum(1, 2 * (contact_dim - 1)) if cone_pyramidal else contact_dim
    contact_rows = [contact_efc_address[i, : contact_nrow[i]] for i in range(ncon)]
    efc_idx = np.concatenate((np.arange(nefl), *contact_rows))
    # host address of each contact: running sum of preceding contacts' row counts
    contact_efc_address_ordered = nefl + np.concatenate(([0], np.cumsum(contact_nrow[:-1])))
  else:
    efc_idx = np.arange(nefc)
    contact_efc_address_ordered = np.empty(0)

  efc_idx = efc_idx[:nefc]  # don't emit indices for overflow constraints

  result.solver_niter[0] = d.solver_niter.numpy()[world_id]
  result.ncon = ncon
  result.ne = ne
  result.nf = nf
  result.nl = nl
  result.time = d.time.numpy()[world_id]

  # per-world arrays; device mat33 fields are flattened to (-1, 9) for the host
  flattened_mats = ("xmat", "ximat", "geom_xmat", "site_xmat", "cam_xmat")
  for name in (
    "energy", "qpos", "qvel", "act", "qacc_warmstart", "ctrl", "qfrc_applied", "xfrc_applied",
    "eq_active", "mocap_pos", "mocap_quat", "qacc", "act_dot", "xpos", "xquat", "xmat", "xipos",
    "ximat", "xanchor", "xaxis", "geom_xpos", "geom_xmat", "site_xpos", "site_xmat", "cam_xpos",
    "cam_xmat", "light_xpos", "light_xdir", "subtree_com", "cdof", "cinert", "flexvert_xpos",
  ):
    value = getattr(d, name).numpy()[world_id]
    if name in flattened_mats:
      value = value.reshape((-1, 9))
    getattr(result, name)[:] = value

  if mjm.nflexedge > 0:
    result.flexedge_J[:] = d.flexedge_J.numpy()[world_id].reshape(-1)
  for name in ("flexedge_length", "flexedge_velocity", "actuator_length", "moment_rownnz", "moment_rowadr"):
    getattr(result, name)[:] = getattr(d, name).numpy()[world_id]
  if mjm.nu:
    result.moment_colind[:] = d.moment_colind.numpy()[world_id]
    result.actuator_moment[:] = d.actuator_moment.numpy()[world_id]
  for name in (
    "crb", "qLDiagInv", "ten_velocity", "actuator_velocity", "cvel", "cdof_dot", "qfrc_bias",
    "qfrc_spring", "qfrc_damper", "qfrc_gravcomp", "qfrc_fluid", "qfrc_passive", "subtree_linvel",
    "subtree_angmom", "actuator_force", "qfrc_actuator", "qfrc_smooth", "qacc_smooth",
    "qfrc_constraint", "qfrc_inverse",
  ):
    getattr(result, name)[:] = getattr(d, name).numpy()[world_id]

  # contact
  for name in ("dist", "pos", "frame", "includemargin", "friction", "solref", "solreffriction", "solimp", "dim", "geom"):
    value = getattr(d.contact, name).numpy()[ncon_filter]
    if name == "frame":
      value = value.reshape((-1, 9))
    getattr(result.contact, name)[:ncon] = value
  result.contact.efc_address[:ncon] = contact_efc_address_ordered[:ncon]

  # mass matrix; in dense mode mj_factorM recomputes the factorization on the host,
  # so it must run after the copies above
  if is_sparse(mjm):
    result.qM[:] = d.qM.numpy()[world_id, 0]
    result.qLD[:] = d.qLD.numpy()[world_id, 0]
  else:
    qM_dense = d.qM.numpy()[world_id]
    adr = 0
    for row in range(mjm.nv):
      col = row
      # walk the dof ancestor chain of `row`, filling result.qM in order
      while col >= 0:
        result.qM[adr] = qM_dense[row, col]
        col = mjm.dof_parentid[col]
        adr += 1
    mujoco.mj_factorM(mjm, result)

  if nefc > 0:
    if is_sparse(mjm):
      efc_J = np.zeros((nefc, mjm.nv))
      mujoco.mju_sparse2dense(
        efc_J,
        d.efc.J.numpy()[world_id, 0],
        d.efc.J_rownnz.numpy()[world_id, :nefc],
        d.efc.J_rowadr.numpy()[world_id, :nefc],
        d.efc.J_colind.numpy()[world_id, 0],
      )
    else:
      efc_J = d.efc.J.numpy()[world_id, :nefc, : mjm.nv]

    efc_J_host = efc_J[efc_idx]
    # write to mujoco result (format depends on mj_isSparse)
    if mujoco.mj_isSparse(mjm):
      mujoco.mju_dense2sparse(
        result.efc_J,
        efc_J_host,
        result.efc_J_rownnz,
        result.efc_J_rowadr,
        result.efc_J_colind,
      )
    else:
      result.efc_J[: nefc * mjm.nv] = efc_J_host.flatten()

  # efc rows, permuted into host order
  for name in ("type", "id", "pos", "margin", "D", "vel", "aref", "frictionloss", "state", "force"):
    getattr(result, "efc_" + name)[:] = getattr(d.efc, name).numpy()[world_id, efc_idx]

  # rne_postconstraint, tendon
  for name in ("cacc", "cfrc_int", "cfrc_ext", "ten_length"):
    getattr(result, name)[:] = getattr(d, name).numpy()[world_id]
  if mjm.ntendon > 0:
    result.ten_J[:] = d.ten_J.numpy()[world_id]
  # tendon wrapping, sensors
  for name in ("ten_wrapadr", "ten_wrapnum", "wrap_obj", "wrap_xpos", "sensordata"):
    getattr(result, name)[:] = getattr(d, name).numpy()[world_id]

  # islands
  nisland = d.nisland.numpy()[world_id]
  result.nisland = nisland
  if nisland:
    result.tree_island[:] = d.tree_island.numpy()[world_id]
[docs]
def reset_data(m: types.Model, d: types.Data, reset: Optional[wp.array] = None):
  """Clear data, set defaults; optionally by world.

  Args:
    m: The model containing kinematic and dynamic information (device).
    d: The data object containing the current state and output arrays (device).
    reset: Per-world bitmask. Reset if True. When None, every world is reset.
  """

  # each kernel skips non-reset worlds only when a mask was given (resolved at compile time)
  @wp.kernel(module="unique", enable_backward=False)
  def reset_xfrc_applied(reset_in: wp.array[bool], xfrc_applied_out: wp.array2d[wp.spatial_vector]):
    worldid, bodyid, elemid = wp.tid()

    if wp.static(reset is not None):
      if not reset_in[worldid]:
        return

    xfrc_applied_out[worldid, bodyid][elemid] = 0.0

  @wp.kernel(module="unique", enable_backward=False)
  def reset_qM(reset_in: wp.array[bool], qM_out: wp.array3d[float]):
    worldid, elemid1, elemid2 = wp.tid()

    if wp.static(reset is not None):
      if not reset_in[worldid]:
        return

    qM_out[worldid, elemid1, elemid2] = 0.0

  @wp.kernel(module="unique", enable_backward=False)
  def reset_nworld(
    # Model:
    nq: int,
    nv: int,
    nu: int,
    na: int,
    neq: int,
    nsensordata: int,
    qpos0: wp.array2d[float],
    eq_active0: wp.array[bool],
    # Data in:
    nworld_in: int,
    # In:
    reset_in: wp.array[bool],
    # Data out:
    solver_niter_out: wp.array[int],
    ne_out: wp.array[int],
    nf_out: wp.array[int],
    nl_out: wp.array[int],
    nefc_out: wp.array[int],
    time_out: wp.array[float],
    energy_out: wp.array[wp.vec2],
    qpos_out: wp.array2d[float],
    qvel_out: wp.array2d[float],
    act_out: wp.array2d[float],
    qacc_warmstart_out: wp.array2d[float],
    ctrl_out: wp.array2d[float],
    qfrc_applied_out: wp.array2d[float],
    eq_active_out: wp.array2d[bool],
    qacc_out: wp.array2d[float],
    act_dot_out: wp.array2d[float],
    sensordata_out: wp.array2d[float],
    nacon_out: wp.array[int],
  ):
    worldid = wp.tid()

    if wp.static(reset is not None):
      if not reset_in[worldid]:
        return

    solver_niter_out[worldid] = 0
    if worldid == 0:
      nacon_out[0] = 0
    ne_out[worldid] = 0
    nf_out[worldid] = 0
    nl_out[worldid] = 0
    nefc_out[worldid] = 0
    time_out[worldid] = 0.0
    energy_out[worldid] = wp.vec2(0.0, 0.0)

    # qpos0 may be batched per world or shared (first dim 1)
    qpos0_id = worldid % qpos0.shape[0]
    for i in range(nq):
      qpos_out[worldid, i] = qpos0[qpos0_id, i]
    for i in range(nv):
      qvel_out[worldid, i] = 0.0
      qacc_warmstart_out[worldid, i] = 0.0
      qfrc_applied_out[worldid, i] = 0.0
      qacc_out[worldid, i] = 0.0
    for i in range(nu):
      ctrl_out[worldid, i] = 0.0
    # act / act_dot have na entries, which need not equal nu: reset them over their own size
    for i in range(na):
      act_out[worldid, i] = 0.0
      act_dot_out[worldid, i] = 0.0
    for i in range(neq):
      eq_active_out[worldid, i] = eq_active0[i]
    for i in range(nsensordata):
      sensordata_out[worldid, i] = 0.0

  @wp.kernel(module="unique", enable_backward=False)
  def reset_mocap(
    # Model:
    body_mocapid: wp.array[int],
    body_pos: wp.array2d[wp.vec3],
    body_quat: wp.array2d[wp.quat],
    # In:
    reset_in: wp.array[bool],
    # Data out:
    mocap_pos_out: wp.array2d[wp.vec3],
    mocap_quat_out: wp.array2d[wp.quat],
  ):
    worldid, bodyid = wp.tid()

    if wp.static(reset is not None):
      if not reset_in[worldid]:
        return

    mocapid = body_mocapid[bodyid]

    if mocapid >= 0:
      mocap_pos_out[worldid, mocapid] = body_pos[worldid % body_pos.shape[0], bodyid]
      mocap_quat_out[worldid, mocapid] = body_quat[worldid % body_quat.shape[0], bodyid]

  @wp.kernel(module="unique", enable_backward=False)
  def reset_contact(
    # Data in:
    nacon_in: wp.array[int],
    # In:
    reset_in: wp.array[bool],
    nefcaddress: int,
    # Data out:
    contact_dist_out: wp.array[float],
    contact_pos_out: wp.array[wp.vec3],
    contact_frame_out: wp.array[wp.mat33],
    contact_includemargin_out: wp.array[float],
    contact_friction_out: wp.array[types.vec5],
    contact_solref_out: wp.array[wp.vec2],
    contact_solreffriction_out: wp.array[wp.vec2],
    contact_solimp_out: wp.array[types.vec5],
    contact_dim_out: wp.array[int],
    contact_geom_out: wp.array[wp.vec2i],
    contact_flex_out: wp.array[wp.vec2i],
    contact_vert_out: wp.array[wp.vec2i],
    contact_efc_address_out: wp.array2d[int],
    contact_worldid_out: wp.array[int],
    contact_type_out: wp.array[int],
    contact_geomcollisionid_out: wp.array[int],
  ):
    conid = wp.tid()

    # only contacts below the current count are live
    if conid >= nacon_in[0]:
      return

    worldid = contact_worldid_out[conid]
    if wp.static(reset is not None):
      if worldid >= 0:
        if not reset_in[worldid]:
          return

    contact_dist_out[conid] = 0.0
    contact_pos_out[conid] = wp.vec3(0.0)
    contact_frame_out[conid] = wp.mat33(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
    contact_includemargin_out[conid] = 0.0
    contact_friction_out[conid] = types.vec5(0.0, 0.0, 0.0, 0.0, 0.0)
    contact_solref_out[conid] = wp.vec2(0.0, 0.0)
    contact_solreffriction_out[conid] = wp.vec2(0.0, 0.0)
    contact_solimp_out[conid] = types.vec5(0.0, 0.0, 0.0, 0.0, 0.0)
    contact_dim_out[conid] = 0
    contact_geom_out[conid] = wp.vec2i(0, 0)
    contact_flex_out[conid] = wp.vec2i(0, 0)
    contact_vert_out[conid] = wp.vec2i(0, 0)
    for i in range(nefcaddress):
      contact_efc_address_out[conid, i] = -1
    contact_worldid_out[conid] = 0
    contact_type_out[conid] = 0
    contact_geomcollisionid_out[conid] = 0

  # explicit None check: do not rely on the truthiness of a device array
  reset_input = reset if reset is not None else wp.ones(d.nworld, dtype=bool)

  wp.launch(reset_xfrc_applied, dim=(d.nworld, m.nbody, 6), inputs=[reset_input], outputs=[d.xfrc_applied])
  wp.launch(
    reset_qM,
    dim=(d.nworld, d.qM.shape[1], d.qM.shape[2]),
    inputs=[reset_input],
    outputs=[d.qM],
  )

  # set mocap_pos/quat = body_pos/quat for mocap bodies
  wp.launch(
    reset_mocap,
    dim=(d.nworld, m.nbody),
    inputs=[m.body_mocapid, m.body_pos, m.body_quat, reset_input],
    outputs=[d.mocap_pos, d.mocap_quat],
  )

  # clear contacts; must run before reset_nworld zeroes nacon, which this kernel reads
  wp.launch(
    reset_contact,
    dim=d.naconmax,
    inputs=[d.nacon, reset_input, d.contact.efc_address.shape[1]],
    outputs=[
      d.contact.dist,
      d.contact.pos,
      d.contact.frame,
      d.contact.includemargin,
      d.contact.friction,
      d.contact.solref,
      d.contact.solreffriction,
      d.contact.solimp,
      d.contact.dim,
      d.contact.geom,
      d.contact.flex,
      d.contact.vert,
      d.contact.efc_address,
      d.contact.worldid,
      d.contact.type,
      d.contact.geomcollisionid,
    ],
  )

  wp.launch(
    reset_nworld,
    dim=d.nworld,
    inputs=[m.nq, m.nv, m.nu, m.na, m.neq, m.nsensordata, m.qpos0, m.eq_active0, d.nworld, reset_input],
    outputs=[
      d.solver_niter,
      d.ne,
      d.nf,
      d.nl,
      d.nefc,
      d.time,
      d.energy,
      d.qpos,
      d.qvel,
      d.act,
      d.qacc_warmstart,
      d.ctrl,
      d.qfrc_applied,
      d.eq_active,
      d.qacc,
      d.act_dot,
      d.sensordata,
      d.nacon,
    ],
  )
# kernel_analyzer: off
@wp.kernel
def _init_subtreemass(
  body_mass_in: wp.array2d[float],
  body_subtreemass_out: wp.array2d[float],
):
  """Seeds every body's subtree mass with the body's own mass."""
  worldid, bodyid = wp.tid()
  # rows wrap modulo the batch size, so a batch of 1 is shared by all worlds
  src_row = worldid % body_mass_in.shape[0]
  dst_row = worldid % body_subtreemass_out.shape[0]
  body_subtreemass_out[dst_row, bodyid] = body_mass_in[src_row, bodyid]
@wp.kernel
def _accumulate_subtreemass(
  body_parentid: wp.array[int],
  body_subtreemass_io: wp.array2d[float],
  body_tree_: wp.array[int],
):
  """Atomically adds a body's subtree mass into its parent's (the world body 0 is skipped).

  Presumably launched once per tree level, from the leaves upward, so that a
  child's total is final before it is added -- confirm at the launch site.
  NOTE(review): rows wrap modulo the batch size; with a batch of 1 and several
  launched worlds, every world adds into the same row -- confirm launch dims.
  """
  worldid, nodeid = wp.tid()
  bodyid = body_tree_[nodeid]
  if bodyid == 0:
    return
  row = worldid % body_subtreemass_io.shape[0]
  wp.atomic_add(body_subtreemass_io, row, body_parentid[bodyid], body_subtreemass_io[row, bodyid])
@wp.kernel
def _copy_qpos0_to_qpos(
  qpos0: wp.array2d[float],
  qpos_out: wp.array2d[float],
):
  """Writes the reference configuration qpos0 into each world's qpos."""
  worldid, qposid = wp.tid()
  # qpos0 rows wrap modulo its batch size (a batch of 1 is shared)
  qpos_out[worldid, qposid] = qpos0[worldid % qpos0.shape[0], qposid]
@wp.kernel
def _copy_tendon_length0(
  ten_length_in: wp.array2d[float],
  tendon_length0_out: wp.array2d[float],
):
  """Stores each world's current tendon lengths as the model's tendon_length0."""
  worldid, tenid = wp.tid()
  # output rows wrap modulo its batch size
  tendon_length0_out[worldid % tendon_length0_out.shape[0], tenid] = ten_length_in[worldid, tenid]
@wp.kernel
def _compute_meaninertia(
  nv: int,
  is_sparse: bool,
  dof_Madr_in: wp.array[int],
  qM_in: wp.array3d[float],
  meaninertia_out: wp.array[float],
):
  """Compute mean diagonal inertia from qM at qpos0."""
  worldid = wp.tid()
  out_row = worldid % meaninertia_out.shape[0]

  if nv == 0:
    meaninertia_out[out_row] = 1.0  # Default from MuJoCo
    return

  diag_sum = float(0.0)
  for dofid in range(nv):
    if is_sparse:
      # sparse: flattened storage, diagonal entry located at dof_Madr[dofid]
      diag_sum += qM_in[worldid, 0, dof_Madr_in[dofid]]
    else:
      # dense: square matrix, diagonal entry at [dofid, dofid]
      diag_sum += qM_in[worldid, dofid, dofid]

  meaninertia_out[out_row] = diag_sum / float(nv)
@wp.kernel
def _set_unit_vector(
  dofid_target: int,
  unit_vec_out: wp.array2d[float],
):
  """Fills each world's row with the standard basis vector e_{dofid_target}."""
  worldid = wp.tid()
  for dofid in range(unit_vec_out.shape[1]):
    entry = float(0.0)
    if dofid == dofid_target:
      entry = 1.0
    unit_vec_out[worldid, dofid] = entry
@wp.kernel
def _extract_dof_A_diag(
  dofid: int,
  result_vec_in: wp.array2d[float],
  dof_A_diag_out: wp.array2d[float],
):
  """Copies component `dofid` of each world's result vector into the dof_A_diag table."""
  worldid = wp.tid()
  dof_A_diag_out[worldid % dof_A_diag_out.shape[0], dofid] = result_vec_in[worldid, dofid]
@wp.kernel
def _finalize_dof_invweight0(
  dof_jntid: wp.array[int],
  jnt_type: wp.array[int],
  jnt_dofadr: wp.array[int],
  dof_A_diag_in: wp.array2d[float],
  dof_invweight0_out: wp.array2d[float],
):
  """Turns per-dof diagonal entries into dof_invweight0.

  Free joints average their translational and rotational triples separately,
  ball joints average their single triple, and all other joints copy the
  entry unchanged.
  """
  worldid, dofid = wp.tid()
  out_row = worldid % dof_invweight0_out.shape[0]
  in_row = worldid % dof_A_diag_in.shape[0]

  jntid = dof_jntid[dofid]
  jtype = jnt_type[jntid]
  dofadr = jnt_dofadr[jntid]

  is_free = jtype == int(types.JointType.FREE.value)
  is_ball = jtype == int(types.JointType.BALL.value)

  if is_free or is_ball:
    # first dof of the 3-dof group this dof belongs to
    group = dofadr
    if is_free and dofid >= dofadr + 3:
      group = dofadr + 3
    dof_invweight0_out[out_row, dofid] = wp.static(1.0 / 3.0) * (
      dof_A_diag_in[in_row, group] + dof_A_diag_in[in_row, group + 1] + dof_A_diag_in[in_row, group + 2]
    )
  else:
    # HINGE/SLIDE: 1 DOF, no averaging
    dof_invweight0_out[out_row, dofid] = dof_A_diag_in[in_row, dofid]
# Write row `row_idx` of the 6 x nv Jacobian of body `bodyid_target` into
# body_jac_row_out[worldid, :].
#   rows 0-2: linear part, cdof_lin + cross(cdof_ang, offset)
#   rows 3-5: angular part, cdof_ang
# `offset` runs from the subtree com of the target's root body to the target's
# xipos. Only dofs on the dof_parentid chain ending at the nearest ancestor
# that has dofs are nonzero. The row stays all-zero if that walk reaches the
# world body.
@wp.kernel
def _compute_body_jac_row(
  nv: int,
  bodyid_target: int,
  row_idx: int,
  body_parentid: wp.array[int],
  body_rootid: wp.array[int],
  body_dofadr: wp.array[int],
  body_dofnum: wp.array[int],
  dof_parentid: wp.array[int],
  subtree_com_in: wp.array2d[wp.vec3],
  xipos_in: wp.array2d[wp.vec3],
  cdof_in: wp.array2d[wp.spatial_vector],
  body_jac_row_out: wp.array2d[float],
):
  worldid = wp.tid()

  # clear the row; only ancestor dofs are filled in below
  for i in range(nv):
    body_jac_row_out[worldid, i] = 0.0

  # walk up to the first body (target or ancestor) that owns dofs
  dofbody = bodyid_target
  while dofbody > 0 and body_dofnum[dofbody] == 0:
    dofbody = body_parentid[dofbody]
  if dofbody == 0:
    return

  lever = xipos_in[worldid, bodyid_target] - subtree_com_in[worldid, body_rootid[bodyid_target]]

  # vector component selected by this row: rows 0/3 -> x, 1/4 -> y, others -> z
  axis = int(2)
  if row_idx == 0 or row_idx == 3:
    axis = 0
  elif row_idx == 1 or row_idx == 4:
    axis = 1

  # walk the dof ancestor chain backwards, starting from the body's last dof
  dofid = body_dofadr[dofbody] + body_dofnum[dofbody] - 1
  while dofid >= 0:
    cdof = cdof_in[worldid, dofid]
    ang = wp.spatial_top(cdof)
    entry = float(0.0)
    if row_idx < 3:
      lin = wp.spatial_bottom(cdof) + wp.cross(ang, lever)
      entry = lin[axis]
    else:
      entry = ang[axis]
    body_jac_row_out[worldid, dofid] = entry
    dofid = dof_parentid[dofid]
# Store the diagonal entry A[row, row] = J[row] . result_vec in
# body_A_diag_out[., bodyid_target, row_idx]. set_const_0 fills result_vec by
# solving against J[row], so this is J[row] inv(M) J[row]' assuming
# smooth.solve_m computes inv(M) @ rhs. The output world row is taken modulo
# its leading dimension.
@wp.kernel
def _compute_body_A_diag_entry(
  nv: int,
  bodyid_target: int,
  row_idx: int,
  body_jac_row_in: wp.array2d[float],
  result_vec_in: wp.array2d[float],
  body_A_diag_out: wp.array3d[float],
):
  worldid = wp.tid()
  acc = float(0.0)
  for dofid in range(nv):
    acc += body_jac_row_in[worldid, dofid] * result_vec_in[worldid, dofid]
  body_A_diag_out[worldid % body_A_diag_out.shape[0], bodyid_target, row_idx] = acc
# Reduce the six per-body diagonal entries to body_invweight0 = (trans, rot).
#   trans: mean of entries 0..2
#   rot:   mean of entries 3..5
# The world body and bodies with body_weldid == 0 get (0, 0). If exactly one
# component is below mjMINVAL while the other is above it, the small one is
# replaced by the other so neither direction ends up degenerate.
@wp.kernel
def _finalize_body_invweight0(
  body_weldid: wp.array[int],
  body_A_diag_in: wp.array3d[float],
  body_invweight0_out: wp.array2d[wp.vec2],
):
  worldid, bodyid = wp.tid()
  out_worldid = worldid % body_invweight0_out.shape[0]
  in_worldid = worldid % body_A_diag_in.shape[0]

  if bodyid == 0 or body_weldid[bodyid] == 0:
    body_invweight0_out[out_worldid, bodyid] = wp.vec2(0.0, 0.0)
    return

  trans = wp.static(1.0 / 3.0) * (
    body_A_diag_in[in_worldid, bodyid, 0] + body_A_diag_in[in_worldid, bodyid, 1] + body_A_diag_in[in_worldid, bodyid, 2]
  )
  rot = wp.static(1.0 / 3.0) * (
    body_A_diag_in[in_worldid, bodyid, 3] + body_A_diag_in[in_worldid, bodyid, 4] + body_A_diag_in[in_worldid, bodyid, 5]
  )

  # the two conditions are mutually exclusive, so at most one fallback applies
  if rot > mujoco.mjMINVAL and trans < mujoco.mjMINVAL:
    trans = rot
  elif trans > mujoco.mjMINVAL and rot < mujoco.mjMINVAL:
    rot = trans

  body_invweight0_out[out_worldid, bodyid] = wp.vec2(trans, rot)
# Scatter the sparse Jacobian row of tendon `tenid_target` into the dense
# per-world vector ten_J_vec_out[worldid, :].
# Only the nonzero columns are written. The caller (set_const_0) zeroes
# ten_J_vec_out before each launch.
@wp.kernel
def _copy_tendon_jacobian(
  tenid_target: int,
  ten_J_rownnz: wp.array[int],
  ten_J_rowadr: wp.array[int],
  ten_J_colind: wp.array[int],
  ten_J_in: wp.array2d[float],
  ten_J_vec_out: wp.array2d[float],
):
  worldid = wp.tid()
  rowadr = ten_J_rowadr[tenid_target]
  for i in range(ten_J_rownnz[tenid_target]):
    sparseid = rowadr + i
    ten_J_vec_out[worldid, ten_J_colind[sparseid]] = ten_J_in[worldid, sparseid]
# tendon_invweight0[tenid_target] = J_t . result_vec, where J_t is the sparse
# Jacobian row of the tendon and result_vec is the solve of that row (see
# set_const_0). The sum runs over the row's nonzero entries only, and the
# output world row is taken modulo its leading dimension.
@wp.kernel
def _compute_tendon_dot_product(
  # Model:
  ten_J_rownnz: wp.array[int],
  ten_J_rowadr: wp.array[int],
  ten_J_colind: wp.array[int],
  # In:
  tenid_target: int,
  ten_J_in: wp.array2d[float],
  result_vec_in: wp.array2d[float],
  # Out:
  tendon_invweight0_out: wp.array2d[float],
):
  worldid = wp.tid()
  start = ten_J_rowadr[tenid_target]
  stop = start + ten_J_rownnz[tenid_target]
  acc = float(0.0)
  for sparseid in range(start, stop):
    acc += ten_J_in[worldid, sparseid] * result_vec_in[worldid, ten_J_colind[sparseid]]
  tendon_invweight0_out[worldid % tendon_invweight0_out.shape[0], tenid_target] = acc
# Record camera reference quantities at the current (qpos0) configuration:
#   cam_pos0:    camera position relative to its body's xpos
#   cam_poscom0: camera position relative to the subtree com of the target body
#                (or of the camera's own body when there is no target)
#   cam_mat0:    camera orientation
# Each output is a separately batched model field, so each is indexed modulo
# its own leading dimension. Reusing one output's index for another could
# write out of bounds when their world batch sizes differ.
@wp.kernel
def _compute_cam_pos0(
  cam_bodyid: wp.array[int],
  cam_targetbodyid: wp.array[int],
  cam_xpos_in: wp.array2d[wp.vec3],
  cam_xmat_in: wp.array2d[wp.mat33],
  xpos_in: wp.array2d[wp.vec3],
  subtree_com_in: wp.array2d[wp.vec3],
  cam_pos0_out: wp.array2d[wp.vec3],
  cam_poscom0_out: wp.array2d[wp.vec3],
  cam_mat0_out: wp.array2d[wp.mat33],
):
  worldid, camid = wp.tid()
  bodyid = cam_bodyid[camid]
  targetid = cam_targetbodyid[camid]
  cam_xpos = cam_xpos_in[worldid, camid]

  combodyid = bodyid
  if targetid >= 0:
    combodyid = targetid

  cam_pos0_out[worldid % cam_pos0_out.shape[0], camid] = cam_xpos - xpos_in[worldid, bodyid]
  cam_poscom0_out[worldid % cam_poscom0_out.shape[0], camid] = cam_xpos - subtree_com_in[worldid, combodyid]
  cam_mat0_out[worldid % cam_mat0_out.shape[0], camid] = cam_xmat_in[worldid, camid]
# Record light reference quantities at the current (qpos0) configuration:
#   light_pos0:    light position relative to its body's xpos
#   light_poscom0: light position relative to the subtree com of the target body
#                  (or of the light's own body when there is no target)
#   light_dir0:    light direction
# Each output is a separately batched model field, so each is indexed modulo
# its own leading dimension to avoid out-of-bounds writes when batch sizes differ.
@wp.kernel
def _compute_light_pos0(
  light_bodyid: wp.array[int],
  light_targetbodyid: wp.array[int],
  light_xpos_in: wp.array2d[wp.vec3],
  light_xdir_in: wp.array2d[wp.vec3],
  xpos_in: wp.array2d[wp.vec3],
  subtree_com_in: wp.array2d[wp.vec3],
  light_pos0_out: wp.array2d[wp.vec3],
  light_poscom0_out: wp.array2d[wp.vec3],
  light_dir0_out: wp.array2d[wp.vec3],
):
  worldid, lightid = wp.tid()
  bodyid = light_bodyid[lightid]
  targetid = light_targetbodyid[lightid]
  light_xpos = light_xpos_in[worldid, lightid]

  combodyid = bodyid
  if targetid >= 0:
    combodyid = targetid

  light_pos0_out[worldid % light_pos0_out.shape[0], lightid] = light_xpos - xpos_in[worldid, bodyid]
  light_poscom0_out[worldid % light_poscom0_out.shape[0], lightid] = light_xpos - subtree_com_in[worldid, combodyid]
  light_dir0_out[worldid % light_dir0_out.shape[0], lightid] = light_xdir_in[worldid, lightid]
# Expand the sparse moment row of actuator `actid_target` into the dense
# per-world vector act_moment_vec_out[worldid, :]. The row is zeroed first,
# then the nonzero entries are scattered to their column indices.
@wp.kernel
def _copy_actuator_moment(
  actid_target: int,
  moment_rownnz_in: wp.array2d[int],
  moment_rowadr_in: wp.array2d[int],
  moment_colind_in: wp.array2d[int],
  actuator_moment_in: wp.array2d[float],
  act_moment_vec_out: wp.array2d[float],
):
  worldid = wp.tid()
  for dofid in range(act_moment_vec_out.shape[1]):
    act_moment_vec_out[worldid, dofid] = 0.0

  start = moment_rowadr_in[worldid, actid_target]
  stop = start + moment_rownnz_in[worldid, actid_target]
  for sparseid in range(start, stop):
    act_moment_vec_out[worldid, moment_colind_in[worldid, sparseid]] = actuator_moment_in[worldid, sparseid]
# actuator_acc0[actid_target] = ||result_vec||_2, where result_vec is the
# solve of the actuator's moment vector (see set_const_0).
# actuator_acc0 is a model field whose world batch may be smaller than the
# launch (e.g. a single shared row). Its row is therefore taken modulo its
# leading dimension, like the other model outputs in this file. Indexing with
# the raw worldid writes out of bounds when nworld exceeds the batch size.
@wp.kernel
def _compute_actuator_acc0(
  actid_target: int,
  nv: int,
  result_vec_in: wp.array2d[float],
  actuator_acc0_out: wp.array2d[float],
):
  worldid = wp.tid()
  norm_sq = float(0.0)
  for i in range(nv):
    norm_sq += result_vec_in[worldid, i] * result_vec_in[worldid, i]
  actuator_acc0_out[worldid % actuator_acc0_out.shape[0], actid_target] = wp.sqrt(norm_sq)
# dof_M0[dofid] = armature[dofid] + cdof' * inert_vec(crb[body(dofid)], cdof),
# i.e. the dof's armature plus the quadratic form of its composite rigid body
# inertia along cdof. The armature row is taken modulo its batch size.
@wp.kernel
def _compute_dof_M0(
  dof_bodyid: wp.array[int],
  dof_armature: wp.array2d[float],
  cdof_in: wp.array2d[wp.spatial_vector],
  crb_in: wp.array2d[vec10],
  dof_M0_out: wp.array2d[float],
):
  worldid, dofid = wp.tid()
  cdof = cdof_in[worldid, dofid]
  crb = crb_in[worldid, dof_bodyid[dofid]]
  inert_cdof = mjmath.inert_vec(crb, cdof)
  armature = dof_armature[worldid % dof_armature.shape[0], dofid]
  dof_M0_out[worldid, dofid] = armature + wp.dot(cdof, inert_cdof)
# Convert a damping ratio stored in biasprm[2] into an absolute damping value.
# This only applies to affine-bias actuators with gainprm[0] == -biasprm[1]
# (within MJ_MINVAL) and biasprm[2] > 0. For those:
#   mass            = sum_j dof_M0[j] / moment_j^2   over nonzero moment entries
#   biasprm[2]      = -dampratio * 2 * sqrt(kp * mass)
# Because the stored value becomes negative, a second pass skips the actuator.
@wp.kernel
def _resolve_dampratio(
  actuator_biastype: wp.array[int],
  actuator_gainprm: wp.array2d[types.vec10f],
  moment_rownnz_in: wp.array2d[int],
  moment_rowadr_in: wp.array2d[int],
  moment_colind_in: wp.array2d[int],
  actuator_moment_in: wp.array2d[float],
  dof_M0_in: wp.array2d[float],
  nv: int,
  actuator_biasprm: wp.array2d[types.vec10f],
):
  worldid, actid = wp.tid()

  if actuator_biastype[actid] != BiasType.AFFINE:
    return

  biasprm_id = worldid % actuator_biasprm.shape[0]
  biasprm = actuator_biasprm[biasprm_id, actid]
  kp = actuator_gainprm[worldid % actuator_gainprm.shape[0], actid][0]

  # not position-like, or biasprm[2] does not hold a positive dampratio
  if wp.abs(kp + biasprm[1]) > MJ_MINVAL or biasprm[2] <= 0.0:
    return
  dampratio = biasprm[2]

  # reflected mass seen through the actuator's sparse moment row
  mass = float(0.0)
  start = moment_rowadr_in[worldid, actid]
  stop = start + moment_rownnz_in[worldid, actid]
  for sparseid in range(start, stop):
    moment = actuator_moment_in[worldid, sparseid]
    if wp.abs(moment) > MJ_MINVAL:
      mass += dof_M0_in[worldid, moment_colind_in[worldid, sparseid]] / (moment * moment)

  updated = biasprm
  updated[2] = -(dampratio * 2.0 * wp.sqrt(kp * mass))
  actuator_biasprm[biasprm_id, actid] = updated
# Set actuator_lengthrange from joint/tendon limits scaled by gear[0].
# Endpoints are swapped when gear[0] <= 0 so the range stays ordered.
# Joint transmissions (JOINT/JOINTINPARENT) use jnt_range when the joint is
# limited. Tendon transmissions use tendon_range when the tendon is limited.
# Everything else gets (0, 0).
# The output is a model field whose world batch may be smaller than the
# launch, so its row is taken modulo its leading dimension. This matches the
# other model outputs in this file and avoids out-of-bounds writes for an
# unbatched field with nworld > 1.
@wp.kernel
def _set_length_range(
  actuator_trntype: wp.array[int],
  actuator_trnid: wp.array[wp.vec2i],
  actuator_gear: wp.array2d[wp.spatial_vector],
  jnt_limited: wp.array[int],
  jnt_range: wp.array2d[wp.vec2],
  tendon_limited: wp.array[int],
  tendon_range: wp.array2d[wp.vec2],
  ntendon: int,
  actuator_lengthrange_out: wp.array2d[wp.vec2],
):
  worldid, actid = wp.tid()
  trntype = actuator_trntype[actid]
  id0 = actuator_trnid[actid][0]
  gear0 = actuator_gear[worldid % actuator_gear.shape[0], actid][0]
  lr = wp.vec2(0.0, 0.0)
  if trntype == TrnType.JOINT or trntype == TrnType.JOINTINPARENT:
    if jnt_limited[id0]:
      rng = jnt_range[worldid % jnt_range.shape[0], id0]
      if gear0 > 0.0:
        lr = wp.vec2(rng[0] * gear0, rng[1] * gear0)
      else:
        lr = wp.vec2(rng[1] * gear0, rng[0] * gear0)
  elif trntype == TrnType.TENDON:
    if ntendon > 0 and tendon_limited[id0]:
      rng = tendon_range[worldid % tendon_range.shape[0], id0]
      if gear0 > 0.0:
        lr = wp.vec2(rng[0] * gear0, rng[1] * gear0)
      else:
        lr = wp.vec2(rng[1] * gear0, rng[0] * gear0)
  actuator_lengthrange_out[worldid % actuator_lengthrange_out.shape[0], actid] = lr
# kernel_analyzer: on
[docs]
def set_const_fixed(m: types.Model, d: types.Data):
  """Compute fixed quantities (independent of qpos0).

  Computes:
    - body_subtreemass: mass of body and all descendants (depends on body_mass)
    - ngravcomp: number of bodies with body_gravcomp > 0 in at least one world
      (depends on body_gravcomp)

  Args:
    m: The model containing kinematic and dynamic information (device).
    d: The data object containing the current state and output arrays (device).
  """
  # presumably seeds body_subtreemass from body_mass (kernel defined elsewhere)
  wp.launch(_init_subtreemass, dim=(d.nworld, m.nbody), inputs=[m.body_mass], outputs=[m.body_subtreemass])

  # walk the body_tree levels in reverse order, accumulating each level into
  # its parents (assumes body_tree is ordered root-to-leaf — confirm)
  for level in reversed(m.body_tree):
    wp.launch(
      _accumulate_subtreemass,
      dim=(d.nworld, level.size),
      inputs=[m.body_parentid, m.body_subtreemass, level],
    )

  # TODO(team): refactor for graph capture compatibility
  gravcomp_in_any_world = (m.body_gravcomp.numpy() > 0.0).any(axis=0)
  m.ngravcomp = int(np.count_nonzero(gravcomp_in_any_world))
[docs]
def set_const_0(m: types.Model, d: types.Data):
  """Compute quantities that depend on qpos0.

  Computes:
    - stat.meaninertia: mean of the qM diagonal at qpos0
    - tendon_length0: tendon resting lengths
    - dof_invweight0: inverse inertia for DOFs
    - body_invweight0: inverse spatial inertia for bodies
    - tendon_invweight0: inverse weight for tendons
    - cam_pos0, cam_poscom0, cam_mat0: camera references
    - light_pos0, light_poscom0, light_dir0: light references
    - actuator_acc0: acceleration from unit actuator force
    - actuator_biasprm[2] (dampratio resolution): for position actuators where
      gainprm[0] == -biasprm[1] and biasprm[2] > 0, converts dampratio to
      damping via biasprm[2] = -dampratio * 2 * sqrt(kp * reflected_mass)

  The kinematics, mass matrix and transmission are evaluated at qpos0 using d
  as scratch space. d.qpos is restored afterwards, even if one of the stages
  raises. Other Data fields written by the smooth.* passes keep their qpos0
  values.

  Args:
    m: The model containing kinematic and dynamic information (device).
    d: The data object containing the current state and output arrays (device).
  """
  qpos_saved = wp.clone(d.qpos)
  try:
    wp.launch(_copy_qpos0_to_qpos, dim=(d.nworld, m.nq), inputs=[m.qpos0], outputs=[d.qpos])

    smooth.kinematics(m, d)
    smooth.com_pos(m, d)
    smooth.camlight(m, d)
    smooth.flex(m, d)
    smooth.tendon(m, d)
    smooth.crb(m, d)
    smooth.tendon_armature(m, d)
    smooth.factor_m(m, d)
    smooth.transmission(m, d)

    # Compute meaninertia from qM diagonal at qpos0
    wp.launch(
      _compute_meaninertia,
      dim=d.nworld,
      inputs=[m.nv, m.is_sparse, m.dof_Madr, d.qM],
      outputs=[m.stat.meaninertia],
    )
    wp.launch(_copy_tendon_length0, dim=(d.nworld, m.ntendon), inputs=[d.ten_length], outputs=[m.tendon_length0])

    _set_const_dof_invweight0(m, d)
    _set_const_body_invweight0(m, d)
    _set_const_tendon_invweight0(m, d)

    wp.launch(
      _compute_cam_pos0,
      dim=(d.nworld, m.ncam),
      inputs=[m.cam_bodyid, m.cam_targetbodyid, d.cam_xpos, d.cam_xmat, d.xpos, d.subtree_com],
      outputs=[m.cam_pos0, m.cam_poscom0, m.cam_mat0],
    )
    wp.launch(
      _compute_light_pos0,
      dim=(d.nworld, m.nlight),
      inputs=[m.light_bodyid, m.light_targetbodyid, d.light_xpos, d.light_xdir, d.xpos, d.subtree_com],
      outputs=[m.light_pos0, m.light_poscom0, m.light_dir0],
    )

    _set_const_actuator_acc0(m, d)
    _set_const_dampratio(m, d)
  finally:
    # always hand the caller back its own configuration
    wp.copy(d.qpos, qpos_saved)


def _set_const_dof_invweight0(m: types.Model, d: types.Data):
  """Computes m.dof_invweight0 from the diagonal of inv(M), one solve per dof."""
  # FREE: 6 DOFs, trans gets mean(A[0:3]), rot gets mean(A[3:6])
  # BALL: 3 DOFs, all get mean(A[0:3])
  # HINGE/SLIDE: 1 DOF, gets A[0,0]
  if m.nv == 0:
    return
  unit_vec = wp.zeros((d.nworld, m.nv), dtype=float)
  result_vec = wp.zeros((d.nworld, m.nv), dtype=float)
  dof_A_diag = wp.zeros((d.nworld, m.nv), dtype=float)
  # TODO(team): more efficient approach instead of looping over nv?
  for dofid in range(m.nv):
    wp.launch(_set_unit_vector, dim=d.nworld, inputs=[dofid], outputs=[unit_vec])
    smooth.solve_m(m, d, result_vec, unit_vec)
    wp.launch(_extract_dof_A_diag, dim=d.nworld, inputs=[dofid, result_vec], outputs=[dof_A_diag])
  wp.launch(
    _finalize_dof_invweight0,
    dim=(d.nworld, m.nv),
    inputs=[m.dof_jntid, m.jnt_type, m.jnt_dofadr, dof_A_diag],
    outputs=[m.dof_invweight0],
  )


def _set_const_body_invweight0(m: types.Model, d: types.Data):
  """Computes m.body_invweight0 from the mean diagonal of J inv(M) J' per body."""
  # J is the 6 x nv body Jacobian (3 rows translation, 3 rows rotation)
  if m.nv == 0:
    m.body_invweight0.zero_()
    return
  body_jac_row = wp.zeros((d.nworld, m.nv), dtype=float)
  body_result_vec = wp.zeros((d.nworld, m.nv), dtype=float)
  body_A_diag = wp.zeros((d.nworld, m.nbody, 6), dtype=float)
  # TODO(team): more efficient approach instead of nested iterations?
  for bodyid in range(1, m.nbody):
    for row_idx in range(6):
      wp.launch(
        _compute_body_jac_row,
        dim=d.nworld,
        inputs=[
          m.nv,
          bodyid,
          row_idx,
          m.body_parentid,
          m.body_rootid,
          m.body_dofadr,
          m.body_dofnum,
          m.dof_parentid,
          d.subtree_com,
          d.xipos,
          d.cdof,
        ],
        outputs=[body_jac_row],
      )
      smooth.solve_m(m, d, body_result_vec, body_jac_row)
      wp.launch(
        _compute_body_A_diag_entry,
        dim=d.nworld,
        inputs=[m.nv, bodyid, row_idx, body_jac_row, body_result_vec],
        outputs=[body_A_diag],
      )
  wp.launch(
    _finalize_body_invweight0,
    dim=(d.nworld, m.nbody),
    inputs=[m.body_weldid, body_A_diag],
    outputs=[m.body_invweight0],
  )


def _set_const_tendon_invweight0(m: types.Model, d: types.Data):
  """Computes m.tendon_invweight0[t] = J_t inv(M) J_t' for every tendon."""
  if m.ntendon == 0:
    return
  ten_J_vec = wp.empty((d.nworld, m.nv), dtype=float)
  ten_result_vec = wp.empty((d.nworld, m.nv), dtype=float)
  for tenid in range(m.ntendon):
    # _copy_tendon_jacobian writes only the nonzero columns
    ten_J_vec.zero_()
    wp.launch(
      _copy_tendon_jacobian,
      dim=d.nworld,
      inputs=[tenid, m.ten_J_rownnz, m.ten_J_rowadr, m.ten_J_colind, d.ten_J],
      outputs=[ten_J_vec],
    )
    smooth.solve_m(m, d, ten_result_vec, ten_J_vec)
    wp.launch(
      _compute_tendon_dot_product,
      dim=d.nworld,
      inputs=[m.ten_J_rownnz, m.ten_J_rowadr, m.ten_J_colind, tenid, d.ten_J, ten_result_vec],
      outputs=[m.tendon_invweight0],
    )


def _set_const_actuator_acc0(m: types.Model, d: types.Data):
  """Computes m.actuator_acc0[i] = ||inv(M) * actuator_moment[i]||."""
  if m.nu == 0 or m.nv == 0:
    return
  act_moment_vec = wp.zeros((d.nworld, m.nv), dtype=float)
  act_result_vec = wp.zeros((d.nworld, m.nv), dtype=float)
  for actid in range(m.nu):
    wp.launch(
      _copy_actuator_moment,
      dim=d.nworld,
      inputs=[actid, d.moment_rownnz, d.moment_rowadr, d.moment_colind, d.actuator_moment],
      outputs=[act_moment_vec],
    )
    smooth.solve_m(m, d, act_result_vec, act_moment_vec)
    wp.launch(_compute_actuator_acc0, dim=d.nworld, inputs=[actid, m.nv, act_result_vec], outputs=[m.actuator_acc0])


def _set_const_dampratio(m: types.Model, d: types.Data):
  """Computes dof_M0, then converts dampratio to damping in m.actuator_biasprm."""
  if m.nu == 0 or m.nv == 0:
    return
  dof_M0 = wp.zeros((d.nworld, m.nv), dtype=float)
  wp.launch(
    _compute_dof_M0,
    dim=(d.nworld, m.nv),
    inputs=[m.dof_bodyid, m.dof_armature, d.cdof, d.crb],
    outputs=[dof_M0],
  )
  wp.launch(
    _resolve_dampratio,
    dim=(d.nworld, m.nu),
    inputs=[
      m.actuator_biastype,
      m.actuator_gainprm,
      d.moment_rownnz,
      d.moment_rowadr,
      d.moment_colind,
      d.actuator_moment,
      dof_M0,
      m.nv,
    ],
    outputs=[m.actuator_biasprm],
  )
[docs]
def set_const(m: types.Model, d: types.Data):
  """Recomputes derived constant model fields after runtime model edits.

  Call this after modifying model parameters at runtime. It refreshes every
  field derived from the edited ones, which makes modifications safe that
  would otherwise leave the model inconsistent. It runs set_const_fixed and
  then set_const_0.

  Model fields that can be modified safely with set_const:
    Field                            | Notes
    ---------------------------------|----------------------------------------------
    qpos0, qpos_spring               |
    body_mass, body_inertia,         | Mass and inertia are usually scaled together
    body_ipos, body_iquat            | since inertia is sum(m * r^2).
    body_pos, body_quat              | Unsafe for static bodies (invalidates BVH).
    body_gravcomp                    | If changing from 0 to >0 bodies, required.
    dof_armature                     |
    eq_data                          | For connect/weld, offsets computed if not set.
    hfield_size                      |
    tendon_stiffness, tendon_damping | Only if changing from/to zero.
    actuator_gainprm, actuator_biasprm | For position actuators with dampratio.

  NOTE(review): set_const_fixed / set_const_0 as written do not visibly
  recompute eq_data offsets. Confirm the eq_data note above.

  For selective updates, call the sub-functions directly based on what changed:
    Modified Field  | Call
    ----------------|------------------
    body_mass       | set_const
    body_gravcomp   | set_const_fixed
    body_inertia    | set_const_0
    qpos0           | set_const_0

  Computes:
    - Fixed quantities (via set_const_fixed):
      - body_subtreemass: mass of body and all descendants
      - ngravcomp: count of bodies with gravity compensation
    - qpos0-dependent quantities (via set_const_0):
      - tendon_length0: tendon resting lengths
      - dof_invweight0: inverse inertia for DOFs
      - body_invweight0: inverse spatial inertia for bodies
      - tendon_invweight0: inverse weight for tendons
      - cam_pos0, cam_poscom0, cam_mat0: camera references
      - light_pos0, light_poscom0, light_dir0: light references
      - actuator_acc0: acceleration from unit actuator force
      - actuator_biasprm[2] (dampratio resolution)

  Skips: actuator_length0 (not in mjwarp).

  Args:
    m: The model containing kinematic and dynamic information (device).
    d: The data object containing the current state and output arrays (device).
  """
  # qpos0-independent quantities first, then those evaluated at qpos0
  for stage in (set_const_fixed, set_const_0):
    stage(m, d)
[docs]
def set_length_range(m: types.Model, d: types.Data, index: int = -1):
  """Compute feasible actuator length ranges from joint/tendon limits.

  For joint and tendon transmissions with limits, copies the range directly
  from jnt_range or tendon_range scaled by gear[0]. The endpoints are swapped
  when gear[0] <= 0. Actuators without limits get (0, 0). This covers the
  common robotics use case; simulation-based computation for general
  transmissions is not yet implemented.

  Args:
    m: The model containing kinematic and dynamic information (device).
      m.actuator_lengthrange is updated in place.
    d: The data object. Only d.nworld is read; the argument otherwise mirrors
      the MuJoCo C signature.
    index: Actuator index to compute for, or -1 for all actuators. When an
      index is given, only that actuator's length range is updated.

  Raises:
    ValueError: if index is neither -1 nor a valid actuator index.
  """
  if m.nu == 0:
    return
  if index < -1 or index >= m.nu:
    raise ValueError(f"index must be -1 or in [0, {m.nu}), got {index}")

  # for a single actuator, compute into scratch space so other actuators' ranges are untouched
  if index == -1:
    lengthrange_out = m.actuator_lengthrange
  else:
    lengthrange_out = wp.zeros((d.nworld, m.nu), dtype=wp.vec2)

  wp.launch(
    _set_length_range,
    dim=(d.nworld, m.nu),
    inputs=[
      m.actuator_trntype,
      m.actuator_trnid,
      m.actuator_gear,
      m.jnt_limited,
      m.jnt_range,
      m.tendon_limited,
      m.tendon_range,
      m.ntendon,
    ],
    outputs=[lengthrange_out],
  )

  if index >= 0:
    # TODO(team): refactor for graph capture compatibility
    lengthrange = m.actuator_lengthrange.numpy()
    computed = lengthrange_out.numpy()
    lengthrange[:, index] = computed[: lengthrange.shape[0], index]
    m.actuator_lengthrange.assign(lengthrange)
def override_model(model: types.Model | mujoco.MjModel, overrides: dict[str, Any] | Sequence[str]):
  """Overrides model parameters.

  Overrides are of the format:
    opt.iterations = 1
    opt.ls_parallel = True
    opt.cone = pyramidal
    opt.disableflags = contact | spring
    opt.gravity = [0, 0, -9.81]

  Vector-valued fields accept space- or comma-separated numbers, optionally
  wrapped in [] or (). Fields that exist only on mjw.Model (or only on
  mujoco.MjModel) are silently skipped for the other model type.

  Args:
    model: The model to modify in place (mjw.Model or mujoco.MjModel).
    overrides: A mapping from dotted field name to value, or a sequence of
      "field = value" strings.

  Raises:
    ValueError: on a malformed override string, an unknown field, or an
      unrecognized enum/bool value.
  """
  enum_fields = {
    "opt.broadphase": types.BroadphaseType,
    "opt.broadphase_filter": types.BroadphaseFilter,
    "opt.cone": types.ConeType,
    "opt.disableflags": types.DisableBit,
    "opt.enableflags": types.EnableBit,
    "opt.integrator": types.IntegratorType,
    "opt.solver": types.SolverType,
  }
  # MuJoCo pybind11 enums don't support iteration, so we provide explicit mappings
  mj_enum_fields = {
    "opt.jacobian": {
      "DENSE": mujoco.mjtJacobian.mjJAC_DENSE,
      "SPARSE": mujoco.mjtJacobian.mjJAC_SPARSE,
      "AUTO": mujoco.mjtJacobian.mjJAC_AUTO,
    },
  }
  mjw_only_fields = {
    "opt.broadphase",
    "opt.broadphase_filter",
    "opt.ls_parallel",
    "opt.graph_conditional",
    "opt.contact_sensor_maxmatch",
  }
  mj_only_fields = {"opt.jacobian"}

  def _parse_floats(text: str) -> list[float]:
    # accepts "1 2 3", "[1 2 3]", "[1, 2, 3]" and "(1, 2, 3)"
    return [float(p) for p in text.strip().strip("[]()").replace(",", " ").split()]

  if not isinstance(overrides, dict):
    overrides_dict = {}
    for override in overrides:
      if "=" not in override:
        raise ValueError(f"Invalid override format: {override}")
      k, v = override.split("=", 1)
      overrides_dict[k.strip()] = v.strip()
    overrides = overrides_dict

  for key, val in overrides.items():
    # skip overrides on MjModel for properties that are only on mjw.Model
    if key in mjw_only_fields and isinstance(model, mujoco.MjModel):
      continue
    if key in mj_only_fields and isinstance(model, types.Model):
      continue

    obj, attrs = model, key.split(".")
    for i, attr in enumerate(attrs):
      if not hasattr(obj, attr):
        raise ValueError(f"Unrecognized model field: {key}")
      if i < len(attrs) - 1:
        obj = getattr(obj, attr)
        continue
      typ = type(getattr(obj, attr))
      if key in mj_enum_fields and isinstance(val, str):
        enum_member = val.strip().upper()
        if enum_member not in mj_enum_fields[key]:
          raise ValueError(f"Unrecognized enum value for {key}: {enum_member}")
        val = mj_enum_fields[key][enum_member]
      elif key in enum_fields and isinstance(val, str):
        # special case: enum value, possibly several flags joined with "|"
        enum_members = val.split("|")
        val = 0
        for enum_member in enum_members:
          enum_member = enum_member.strip().upper()
          if enum_member not in enum_fields[key].__members__:
            raise ValueError(f"Unrecognized enum value for {enum_fields[key].__name__}: {enum_member}")
          val |= int(enum_fields[key][enum_member])
      elif typ is bool and isinstance(val, str):
        # special case: "true", "TRUE", "false", "FALSE" etc.
        if val.upper() not in ("TRUE", "FALSE"):
          raise ValueError(f"Unrecognized value for field: {key}")
        val = val.upper() == "TRUE"
      elif typ is wp.array and isinstance(val, str):
        arr = getattr(obj, attr)
        val = wp.array([arr.dtype(*_parse_floats(val))], dtype=arr.dtype)
      elif typ is np.ndarray and isinstance(val, str):
        arr = getattr(obj, attr)
        val = np.array(_parse_floats(val), dtype=arr.dtype)
      else:
        val = typ(val)
      setattr(obj, attr, val)
def find_keys(model: mujoco.MjModel, keyname_prefix: str) -> list[int]:
  """Finds keyframes whose name starts with keyname_prefix.

  Unnamed keyframes (mj_id2name returns None for them) are treated as having
  an empty name. They are therefore only matched by an empty prefix instead of
  raising AttributeError.

  Args:
    model: The MuJoCo model whose keyframes are searched.
    keyname_prefix: Prefix the keyframe name must start with.

  Returns:
    Matching keyframe ids in increasing order.
  """
  return [
    keyid
    for keyid in range(model.nkey)
    if (mujoco.mj_id2name(model, mujoco.mjtObj.mjOBJ_KEY, keyid) or "").startswith(keyname_prefix)
  ]
def make_trajectory(model: mujoco.MjModel, keys: list[int]) -> np.ndarray:
  """Make a ctrl trajectory with linear interpolation.

  Steps advance by model.opt.timestep. For each keyframe, every step with a
  time before the keyframe's key_time gets a control linearly interpolated
  between the previous keyframe (at its key_time) and this one. The first
  step at or after key_time gets the keyframe's control exactly.

  Args:
    model: The MuJoCo model providing key_ctrl, key_time, nu and opt.timestep.
    keys: Keyframe ids, in strictly increasing key_time order.

  Returns:
    Array of shape (nsteps, nu) with one control vector per step.

  Raises:
    ValueError: if the first keyframe's time is not 0.0, or if keyframe
      times are not strictly increasing.
  """
  ctrls = []
  prev_ctrl_key = np.zeros(model.nu, dtype=np.float64)
  prev_time, time = 0.0, 0.0
  for key in keys:
    ctrl_key, ctrl_time = model.key_ctrl[key], model.key_time[key]
    if not ctrls and ctrl_time != 0.0:
      raise ValueError("first keyframe must have time 0.0")
    elif ctrls and ctrl_time <= prev_time:
      raise ValueError("keyframes must be in time order")
    while time < ctrl_time:
      frac = (time - prev_time) / (ctrl_time - prev_time)
      ctrls.append(prev_ctrl_key * (1 - frac) + ctrl_key * frac)
      time += model.opt.timestep
    ctrls.append(ctrl_key)
    time += model.opt.timestep
    prev_ctrl_key = ctrl_key
    # anchor the next segment at this keyframe's own time, not the simulation
    # time after the step; otherwise ordered keys closer than one timestep are
    # rejected and the interpolation lags by one step
    prev_time = ctrl_time
  return np.array(ctrls)
# Fill one camera's slice of the flat ray buffer. Thread (x, y) covers one
# pixel, stored row-major at offset + y * img_w + x.
@wp.kernel
def _build_rays(
  # In:
  offset: int,
  img_w: int,
  img_h: int,
  projection: int,
  fovy: float,
  sensorsize: wp.vec2,
  intrinsic: wp.vec4,
  znear: float,
  # Out:
  ray_out: wp.array[wp.vec3],
):
  px, py = wp.tid()
  pixelid = py * img_w + px
  ray = render_util.compute_ray(projection, fovy, sensorsize, intrinsic, img_w, img_h, px, py, znear)
  ray_out[offset + pixelid] = ray
[docs]
def create_render_context(
  mjm: mujoco.MjModel,
  nworld: int = 1,
  cam_res: list[tuple[int, int]] | tuple[int, int] | None = None,
  render_rgb: list[bool] | bool | None = None,
  render_depth: list[bool] | bool | None = None,
  render_seg: list[bool] | bool | None = None,
  use_textures: bool = True,
  use_shadows: bool = False,
  enabled_geom_groups: Sequence[int] = (0, 1, 2),
  cam_active: list[bool] | None = None,
  flex_render_smooth: bool = True,
  use_precomputed_rays: bool = True,
) -> types.RenderContext:
  """Creates a render context on device.

  Args:
    mjm: The model containing kinematic and dynamic information on host.
    nworld: The number of worlds.
    cam_res: The width and height to render each camera image. If None, uses the
      MuJoCo model values.
    render_rgb: Whether to render RGB images. If None, uses the MuJoCo model values.
    render_depth: Whether to render depth images. If None, uses the MuJoCo model values.
    render_seg: Whether to render segmentation (per-pixel geom IDs). If None,
      uses the MuJoCo model values.
    use_textures: Whether to use textures.
    use_shadows: Whether to use shadows.
    enabled_geom_groups: The geom groups to render (any sequence of ints).
    cam_active: List of booleans indicating which cameras to include in rendering.
      If None, all cameras are included.
    flex_render_smooth: Whether to render flex meshes smoothly.
    use_precomputed_rays: Use precomputed rays instead of computing during rendering.
      When using domain randomization for camera intrinsics, set to False.

  Returns:
    The render context containing rendering fields and output arrays on device.
  """
  mjd = mujoco.MjData(mjm)
  mujoco.mj_forward(mjm, mjd)

  constructor = "sah"
  if check_version("warp>=1.13.0.dev20260325"):
    # TODO: The cubql constructor and is_cubql_available exist only in
    # recent Warp 1.13+ builds, modify this after warp is updated to 1.13+.
    _cubql_avail = getattr(wp, "is_cubql_available", None)
    if callable(_cubql_avail) and _cubql_avail():
      constructor = "cubql"

  # Mesh BVHs – build for all meshes so per-world variants are available
  nmesh = mjm.nmesh
  geom_enabled_mask = np.isin(mjm.geom_group, list(enabled_geom_groups))
  geom_enabled_idx = np.nonzero(geom_enabled_mask)[0]
  mesh_registry = {}
  mesh_bvh_id = [wp.uint64(0) for _ in range(nmesh)]
  mesh_bounds_size = [wp.vec3(0.0, 0.0, 0.0) for _ in range(nmesh)]
  for mid in range(nmesh):
    mesh, half = bvh.build_mesh_bvh(mjm, mid, constructor=constructor)
    mesh_registry[mesh.id] = mesh
    mesh_bvh_id[mid] = mesh.id
    mesh_bounds_size[mid] = half
  mesh_bvh_id_arr = wp.array(mesh_bvh_id, dtype=wp.uint64)
  mesh_bounds_size_arr = wp.array(mesh_bounds_size, dtype=wp.vec3)

  # HField BVHs – only for hfields referenced by an enabled hfield geom
  nhfield = mjm.nhfield
  hfield_geom_mask = geom_enabled_mask & (mjm.geom_type == types.GeomType.HFIELD) & (mjm.geom_dataid >= 0)
  used_hfield_id = set(mjm.geom_dataid[hfield_geom_mask].astype(int))
  hfield_registry = {}
  hfield_bvh_id = [wp.uint64(0) for _ in range(nhfield)]
  hfield_bounds_size = [wp.vec3(0.0, 0.0, 0.0) for _ in range(nhfield)]
  for hid in used_hfield_id:
    hmesh, hhalf = bvh.build_hfield_bvh(mjm, hid, constructor=constructor)
    hfield_registry[hmesh.id] = hmesh
    hfield_bvh_id[hid] = hmesh.id
    hfield_bounds_size[hid] = hhalf
  hfield_bvh_id_arr = wp.array(hfield_bvh_id, dtype=wp.uint64)
  hfield_bounds_size_arr = wp.array(hfield_bounds_size, dtype=wp.vec3)

  # Flex BVHs
  nflex = mjm.nflex
  flex_registry = {}
  # Scene BVH flex primitives: 1D → one capsule per edge, 2D/3D → one box per flex
  flex_geom_flexid = []
  flex_geom_edgeid = []
  # numpy needs a numpy dtype (np.uint64, not wp.uint64) to hold 64-bit BVH ids
  flex_bvh_id = np.zeros(nflex, dtype=np.uint64)
  # 1D flexes have no flex BVH, so their group roots stay zero
  flex_group_root = np.zeros((nflex, nworld), dtype=int)
  for f in range(nflex):
    if mjm.flex_dim[f] == 1:
      edge_adr = mjm.flex_edgeadr[f]
      flex_geom_flexid.extend([f] * mjm.flex_edgenum[f])
      flex_geom_edgeid.extend([edge_adr + e for e in range(mjm.flex_edgenum[f])])
    else:
      flex_geom_flexid.append(f)
      flex_geom_edgeid.append(-1)
      fmesh, group_root = bvh.build_flex_bvh(mjm, mjd, nworld, f)
      flex_registry[f] = fmesh
      flex_bvh_id[f] = fmesh.id
      flex_group_root[f] = group_root.numpy()

  textures_registry = [render_util.create_warp_texture(mjm, i) for i in range(mjm.ntex)]
  textures = wp.array(textures_registry, dtype=wp.Texture2D)

  # Filter active cameras
  if cam_active is not None:
    assert len(cam_active) == mjm.ncam, f"cam_active must have length {mjm.ncam} (got {len(cam_active)})"
    active_cam_indices = np.nonzero(cam_active)[0]
  else:
    active_cam_indices = list(range(mjm.ncam))
  ncam = len(active_cam_indices)

  if cam_res is not None:
    if isinstance(cam_res, tuple):
      cam_res = [cam_res] * ncam
    assert len(cam_res) == ncam, (
      f"Camera resolutions must be provided for all active cameras (got {len(cam_res)}, expected {ncam})"
    )
    active_cam_res = cam_res
  else:
    active_cam_res = mjm.cam_resolution[active_cam_indices]
  cam_res_arr = wp.array(active_cam_res, dtype=wp.vec2i)

  # per-camera output flags: explicit argument, a single bool for all cameras,
  # or (None) the model's cam_output bits
  if render_rgb is None:
    render_rgb = [bool(mjm.cam_output[i] & mujoco.mjtCamOutBit.mjCAMOUT_RGB) for i in active_cam_indices]
  elif isinstance(render_rgb, bool):
    render_rgb = [render_rgb] * ncam
  if render_depth is None:
    render_depth = [bool(mjm.cam_output[i] & mujoco.mjtCamOutBit.mjCAMOUT_DEPTH) for i in active_cam_indices]
  elif isinstance(render_depth, bool):
    render_depth = [render_depth] * ncam
  if render_seg is None:
    render_seg = [bool(mjm.cam_output[i] & mujoco.mjtCamOutBit.mjCAMOUT_SEG) for i in active_cam_indices]
  elif isinstance(render_seg, bool):
    render_seg = [render_seg] * ncam
  assert len(render_rgb) == ncam and len(render_depth) == ncam and len(render_seg) == ncam, (
    f"render_rgb, render_depth, and render_seg must be a bool or a list of bools with length {ncam}"
  )

  # pack each enabled output of each camera contiguously; -1 marks "not rendered"
  rgb_adr = np.full(ncam, -1, dtype=int)
  depth_adr = np.full(ncam, -1, dtype=int)
  seg_adr = np.full(ncam, -1, dtype=int)
  cam_res_np = cam_res_arr.numpy()
  ri = 0
  di = 0
  si = 0
  total = 0
  for idx in range(ncam):
    npixel = cam_res_np[idx][0] * cam_res_np[idx][1]
    if render_rgb[idx]:
      rgb_adr[idx] = ri
      ri += npixel
    if render_depth[idx]:
      depth_adr[idx] = di
      di += npixel
    if render_seg[idx]:
      seg_adr[idx] = si
      si += npixel
    total += npixel

  znear = mjm.vis.map.znear * mjm.stat.extent
  ray = wp.zeros(int(total), dtype=wp.vec3)
  cam_projection = mjm.cam_projection
  offset = 0
  for idx, cam_id in enumerate(active_cam_indices):
    img_w = cam_res_np[idx][0]
    img_h = cam_res_np[idx][1]
    wp.launch(
      kernel=_build_rays,
      dim=(img_w, img_h),
      inputs=[
        offset,
        img_w,
        img_h,
        int(cam_projection[cam_id]),
        float(mjm.cam_fovy[cam_id]),
        wp.vec2(mjm.cam_sensorsize[cam_id]),
        wp.vec4(mjm.cam_intrinsic[cam_id]),
        znear,
      ],
      outputs=[ray],
    )
    offset += img_w * img_h

  bvh_ngeom = len(geom_enabled_idx)
  nprim = bvh_ngeom + len(flex_geom_flexid)
  rc = types.RenderContext(
    nrender=ncam,
    cam_res=cam_res_arr,
    cam_id_map=wp.array(active_cam_indices, dtype=int),
    use_textures=use_textures,
    use_shadows=use_shadows,
    background_color=render_util.pack_rgba_to_uint32(0.1 * 255.0, 0.1 * 255.0, 0.2 * 255.0, 1.0 * 255.0),
    use_precomputed_rays=use_precomputed_rays,
    bvh_ngeom=bvh_ngeom,
    enabled_geom_ids=wp.array(geom_enabled_idx, dtype=int),
    mesh_registry=mesh_registry,
    mesh_bvh_id=mesh_bvh_id_arr,
    mesh_bounds_size=mesh_bounds_size_arr,
    mesh_texcoord=wp.array(mjm.mesh_texcoord, dtype=wp.vec2),
    mesh_texcoord_offsets=wp.array(mjm.mesh_texcoordadr, dtype=int),
    mesh_facetexcoord=wp.array(mjm.mesh_facetexcoord, dtype=wp.vec3i),
    textures=textures,
    textures_registry=textures_registry,
    hfield_registry=hfield_registry,
    hfield_bvh_id=hfield_bvh_id_arr,
    hfield_bounds_size=hfield_bounds_size_arr,
    flex_mesh_registry=flex_registry,
    flex_rgba=wp.array(mjm.flex_rgba, dtype=wp.vec4),
    flex_bvh_id=wp.array(flex_bvh_id, dtype=wp.uint64),
    flex_group_root=wp.array(flex_group_root, dtype=int),
    flex_render_smooth=flex_render_smooth,
    bvh_nflexgeom=len(flex_geom_flexid),
    flex_dim_np=mjm.flex_dim,
    flex_geom_flexid=wp.array(flex_geom_flexid, dtype=int),
    flex_geom_edgeid=wp.array(flex_geom_edgeid, dtype=int),
    bvh=None,
    bvh_id=None,
    lower=wp.zeros(nworld * nprim, dtype=wp.vec3),
    upper=wp.zeros(nworld * nprim, dtype=wp.vec3),
    group=wp.zeros(nworld * nprim, dtype=int),
    group_root=wp.zeros(nworld, dtype=int),
    ray=ray,
    rgb_data=wp.zeros((nworld, ri), dtype=wp.uint32),
    rgb_adr=wp.array(rgb_adr, dtype=int),
    depth_data=wp.zeros((nworld, di), dtype=wp.float32),
    depth_adr=wp.array(depth_adr, dtype=int),
    render_rgb=wp.array(render_rgb, dtype=bool),
    render_depth=wp.array(render_depth, dtype=bool),
    seg_data=wp.zeros((nworld, max(si, 1)), dtype=int),
    seg_adr=wp.array(seg_adr, dtype=int),
    render_seg=wp.array(render_seg, dtype=bool),
    znear=znear,
    total_rays=int(total),
  )
  bvh.build_scene_bvh(mjm, mjd, rc, nworld)
  return rc