1: # Copyright 2023 DeepMind Technologies Limited
2: #
3: # Licensed under the Apache License, Version 2.0 (the "License");
4: # you may not use this file except in compliance with the License.
5: # You may obtain a copy of the License at
6: #
7: # http://www.apache.org/licenses/LICENSE-2.0
8: #
9: # Unless required by applicable law or agreed to in writing, software
10: # distributed under the License is distributed on an "AS IS" BASIS,
11: # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12: # See the License for the specific language governing permissions and
13: # limitations under the License.
14: # ==============================================================================
15: """Scan across data ordered by body joint types and kinematic tree order."""
16:
17: from typing import Any, Callable, TypeVar
18:
19: import jax
20: from jax import numpy as jp
21: # pylint: disable=g-importing-member
22: from mujoco.mjx._src.types import JointType
23: from mujoco.mjx._src.types import Model
24: from mujoco.mjx._src.types import TrnType
25: # pylint: enable=g-importing-member
26: import numpy as np
27:
28:
# Generic type variable for the pytrees accepted and returned by the scan
# helpers in this module (``_take``, ``_nvmap``, ``flat``, ``body_tree``).
Y = TypeVar('Y')
30:
31:
32: # TODO(erikfrey): re-check if this really helps perf
def _take(obj: Y, idx: np.ndarray) -> Y:
  """Takes idxs on any pytree given to it.

  XLA executes x[jp.array([1, 2, 3])] slower than x[1:4], so we detect when
  take indices are contiguous, and convert them to slices.

  Leaves are gathered along their leading axis with wrap-around semantics
  (``mode='wrap'``), so negative indices such as -1 select from the end. The
  slice fast path is only used when it yields exactly the same result as that
  gather, i.e. when the contiguous run lies entirely inside ``[0, len(x))``.

  Args:
    obj: an input pytree, or a numpy array (which is indexed directly)
    idx: indices to take

  Returns:
    obj pytree with leaves taken by idxs
  """

  if isinstance(obj, np.ndarray):
    return obj[idx]

  # contiguity only depends on idx, so test it once instead of once per leaf
  contiguous = (
      len(idx.shape) == 1
      and idx.size > 0
      and (idx == np.arange(idx[0], idx[0] + idx.size)).all()
  )

  def take(x):
    # TODO(erikfrey): if this helps perf, add support for striding too
    if not x.shape[0]:
      return x
    # a slice only matches the wrapping gather when no index wraps around:
    # the run must start at >= 0 (not just > 0) and end before len(x)
    if contiguous and idx[0] >= 0 and idx[-1] < x.shape[0]:
      return x[idx[0] : idx[-1] + 1]
    return x.take(jp.array(idx), axis=0, mode='wrap')

  return jax.tree_util.tree_map(take, obj)
66:
67:
def _q_bodyid(m: Model) -> np.ndarray:
  """Maps each qpos address to the id of the body whose joint owns it.

  Free joints span 7 qpos entries, ball joints 4, and every other joint type a
  single entry; joints are laid out in joint-id order.

  Args:
    m: an mjx model

  Returns:
    an array holding, for every qpos address, the owning body id
  """
  qpos_width = {JointType.FREE: 7, JointType.BALL: 4}
  chunks = [np.array([], dtype=np.int32)]
  chunks.extend(
      np.repeat(body, qpos_width.get(jtype, 1))
      for jtype, body in zip(m.jnt_type, m.jnt_bodyid)
  )
  return np.concatenate(chunks)
75:
76:
def _q_jointid(m: Model) -> np.ndarray:
  """Maps each qpos address to the id of the joint that owns it.

  Free joints span 7 qpos entries, ball joints 4, and every other joint type a
  single entry; joints are laid out in joint-id order.

  Args:
    m: an mjx model

  Returns:
    an array holding, for every qpos address, the owning joint id
  """
  qpos_width = {JointType.FREE: 7, JointType.BALL: 4}
  chunks = [np.array([], dtype=np.int32)]
  chunks.extend(
      np.repeat(jnt_id, qpos_width.get(jtype, 1))
      for jnt_id, jtype in enumerate(m.jnt_type)
  )
  return np.concatenate(chunks)
84:
85:
def _index(haystack: np.ndarray, needle: np.ndarray) -> np.ndarray:
  """Returns indexes in haystack for elements in needle.

  Args:
    haystack: values to search in (assumed 1-D)
    needle: values to locate in ``haystack``

  Returns:
    an integer array shaped like ``needle``; each entry is an index ``i`` with
    ``haystack[i]`` equal to the corresponding needle value, or -1 if the value
    does not occur in ``haystack``. If a value occurs several times, the index
    of one of its occurrences is returned.
  """
  needle = np.asarray(needle)
  if haystack.size == 0:
    # np.take cannot gather from an empty array; nothing can be found anyway
    return np.full(needle.shape, -1, dtype=np.intp)

  order = np.argsort(haystack)
  pos = np.searchsorted(haystack[order], needle)
  # needles larger than every haystack value get pos == len(haystack); clip
  # them to a valid slot; the equality test below turns every miss into -1
  idx = np.take(order, pos, mode='clip')
  idx[haystack[idx] != needle] = -1

  return idx
95:
96:
def _nvmap(f: Callable[..., Y], *args) -> Y:
  """Vmaps ``f`` over ``args``, treating numpy array arguments as static.

  Each numpy array argument must hold identical elements; its first element
  is handed to ``f`` unbatched, as a plain static value. Every other argument
  is vmapped along axis 0, except leaves whose leading dimension is zero:
  those are replaced with None and not mapped over.

  Args:
    f: function to be mapped over
    *args: args to be mapped along, passed to f

  Returns:
    the result of vmapping f over args

  Raises:
    RuntimeError: if the elements of a numpy argument differ
  """
  for arg in args:
    if isinstance(arg, np.ndarray) and not np.all(arg == arg[0]):
      raise RuntimeError(f'numpy arg elements do not match: {arg}')

  # static values are taken from numpy args; the remaining args are traced
  statics = [a[0] if isinstance(a, np.ndarray) else None for a in args]
  dynamic = [None if s is not None else a for s, a in zip(statics, args)]

  # zero-length leaves cannot be vmapped over, so they are passed as None
  dynamic = jax.tree_util.tree_map(
      lambda leaf: leaf if leaf.shape[0] else None, dynamic
  )
  axes = [0 if d is not None else None for d in dynamic]

  def with_statics(*mapped, statics=statics):
    merged = [s if d is None else d for d, s in zip(mapped, statics)]
    return f(*merged)

  return jax.vmap(with_statics, in_axes=axes)(*dynamic)
131:
132:
def _check_input(m: Model, args: Any, in_types: str) -> None:
  """Validates that every scan argument has the length its type implies.

  Args:
    m: an mjx model, supplying the expected size of each type
    args: the scan input arguments
    in_types: one type character per argument

  Raises:
    IndexError: if an argument's length differs from the size of its type
  """
  expected_len = dict(
      b=m.nbody,
      j=m.njnt,
      q=m.nq,
      v=m.nv,
      u=m.nu,
      a=m.na,
      s=m.nsite,
      c=m.ncam,
  )
  for idx, (arg, typ) in enumerate(zip(args, in_types)):
    want = expected_len[typ]
    if len(arg) == want:
      continue
    raise IndexError(
        f'f argument "{idx}" with type "{typ}" has length "{len(arg)}"'
        f' which does not match the in_types[{idx}] expected length of '
        f'"{want}".'
    )
152:
153:
def _check_output(
    y: jax.Array, take_ids: np.ndarray, typ: str, idx: int
) -> None:
  """Validates that a scan output's leading dimension matches its take ids.

  Args:
    y: the output array produced for ``out_types[idx]``
    take_ids: the ids that will be used to reorder ``y``
    typ: the output type character, used in the error message
    idx: the output position, used in the error message

  Raises:
    IndexError: if ``y`` and ``take_ids`` differ in leading dimension
  """
  got, want = y.shape[0], take_ids.shape[0]
  if got == want:
    return
  raise IndexError(
      f'f output "{idx}" with type "{typ}" has shape "{got}" '
      f'which does not match the out_types[{idx}] expected size of'
      f' "{want}".'
  )
164:
165:
def flat(
    m: Model,
    f: Callable[..., Y],
    in_types: str,
    out_types: str,
    *args,
    group_by: str = 'j',
) -> Y:
  r"""Scan a function across bodies, actuators, or cameras.

  Scan group data according to type and batch shape then calls vmap(f) on it.\

  Args:
    m: an mjx model
    f: a function to be scanned with the following type signature:
        def f(*args) -> y
      where
        ``*args`` are input arguments with types matching ``in_types``; numpy
          array arguments are passed as static (unbatched) values
        ``y`` is an output (or a list/tuple of outputs) with types matching
          ``out_types``
    in_types: string specifying the type of each input arg:
      'b': split according to bodies
      'j': split according to joint types
      'q': split according to generalized coordinates (len(qpos))
      'v': split according to degrees of freedom (len(qvel))
      'u': split according to actuators
      'a': split according to actuator activations
      's': split according to sites
      'c': split according to camera
    out_types: string specifying the types the output dimension matches
    *args: the input arguments corresponding to ``in_types``
    group_by: the type to group by: 'j' iterates over bodies (grouped by their
      joint types when joint, dof or qpos data is involved), 'u' over
      actuators, 'c' over cameras

  Returns:
    The stacked outputs of ``f`` matching the model's order.

  Raises:
    NotImplementedError: if ``group_by`` is not one of 'j', 'u', 'c'
    IndexError: if an input length or a function output shape does not match
      ``in_types`` / ``out_types``
  """
  _check_input(m, args, in_types)

  if group_by not in {'j', 'u', 'c'}:
    raise NotImplementedError(f'group by type "{group_by}" not implemented.')

  # the qpos -> body / joint maps only depend on the model: compute the one
  # the chosen grouping needs once, rather than once per body / actuator
  q_bodyid = _q_bodyid(m) if group_by == 'j' else None
  q_jointid = _q_jointid(m) if group_by == 'u' else None

  def key_j(type_ids):
    if any(t in 'jqv' for t in in_types + out_types):
      return tuple(m.jnt_type[type_ids['j']])
    return ()

  def type_ids_j(m, i):
    return {
        'b': i,
        'j': np.nonzero(m.jnt_bodyid == i)[0],
        'v': np.nonzero(m.dof_bodyid == i)[0],
        'q': np.nonzero(q_bodyid == i)[0],
    }

  def key_u(type_ids):
    ids_u, ids_j = type_ids['u'], type_ids['j']
    return (
        m.actuator_biastype[ids_u],
        m.actuator_gaintype[ids_u],
        m.actuator_dyntype[ids_u],
        m.actuator_trntype[ids_u],
        m.jnt_type[ids_j],
        # key by refsite presence: True when trnid[1] == -1 (no refsite)
        m.actuator_trnid[ids_u, 1] == -1,
    )

  def type_ids_u(m, i):
    typ_ids = {
        'u': i,
        'a': m.actuator_actadr[i],
        'j': (
            m.actuator_trnid[i, 0]
            if m.actuator_trntype[i] in (TrnType.JOINT, TrnType.JOINTINPARENT)
            else -1
        ),
        's': (
            m.actuator_trnid[i]
            if m.actuator_trntype[i] == TrnType.SITE
            else np.array([-1, -1])
        ),
    }
    v, q = np.array([-1]), np.array([-1])
    if m.actuator_trntype[i] in (TrnType.JOINT, TrnType.JOINTINPARENT):
      # v/q are associated with the joint transmissions only
      v = np.nonzero(m.dof_jntid == typ_ids['j'])[0]
      q = np.nonzero(q_jointid == typ_ids['j'])[0]

    typ_ids.update({'v': v, 'q': q})

    return typ_ids

  def key_c(type_ids):
    return m.cam_mode[type_ids['c']], m.cam_targetbodyid[type_ids['c']] >= 0

  def type_ids_c(unused_m, i):
    return {
        'c': i,
    }

  type_ids_fn = {'j': type_ids_j, 'u': type_ids_u, 'c': type_ids_c}[group_by]
  key_fn = {'j': key_j, 'u': key_u, 'c': key_c}[group_by]

  # build up a grouping of type take-ids in body/actuator/camera order
  key_typ_ids, order = {}, []
  all_types = set(in_types + out_types)
  n_items = {'j': m.nbody, 'u': m.nu, 'c': m.ncam}[group_by]
  for i in np.arange(n_items, dtype=np.int32):
    typ_ids = type_ids_fn(m, i)

    # create grouping key
    key = key_fn(typ_ids)
    order.append((key, typ_ids))

    # add ids per type to the corresponding group
    for t in all_types:
      out = key_typ_ids.setdefault(key, {})
      val = np.expand_dims(typ_ids[t], axis=0)
      out[t] = np.concatenate((out[t], val)) if t in out else val

  key_typ_ids = sorted(key_typ_ids.items())

  # use this grouping to take the right data subsets and call vmap(f)
  ys = []
  for _, typ_ids in key_typ_ids:
    # only execute f if we would actually take something from the result
    if any(typ_ids[v].size > 0 for v in out_types):
      f_args = [_take(arg, typ_ids[typ]) for arg, typ in zip(args, in_types)]
      y = _nvmap(f, *f_args)
      ys.append(y)
    else:
      ys.append(None)

  # remove None results from the final output
  key_typ_ids = [v for y, v in zip(ys, key_typ_ids) if y is not None]
  ys = [y for y in ys if y is not None]
  ys_keys = {k for k, *_ in key_typ_ids}
  order = [o for k, o in order if k in ys_keys]

  # get the original input order
  order = [[o[t] for o in order] for t in all_types]
  order = [
      np.concatenate(o) if isinstance(o[0], np.ndarray) else np.array(o)
      for o in order
  ]
  order = dict(zip(all_types, order))

  # concatenate back to a single tree and drop the grouping dimension
  f_ret_is_seq = isinstance(ys[0], (list, tuple))
  ys = ys if f_ret_is_seq else [[y] for y in ys]
  # types whose per-item ids are scalars: they carry no extra grouping axis
  flat_ = {'j': 'b', 'u': 'uaj', 'c': 'c'}[group_by]
  ys = [
      [v if typ in flat_ else jp.concatenate(v) for v, typ in zip(y, out_types)]
      for y in ys
  ]
  ys = jax.tree_util.tree_map(lambda *x: jp.concatenate(x), *ys)

  # put concatenated results back in order
  reordered_ys = []
  for i, (y, typ) in enumerate(zip(ys, out_types)):
    _check_output(y, order[typ], typ, i)
    ids = np.concatenate([np.hstack(v[typ]) for _, v in key_typ_ids])
    # -1 marks "no id of this type" (e.g. non-joint transmissions): skip them
    input_order = order[typ][np.where(order[typ] != -1)]
    reordered_ys.append(_take(y, _index(ids, input_order)))
  y = reordered_ys if f_ret_is_seq else reordered_ys[0]

  return y
333:
334:
def body_tree(
    m: Model,
    f: Callable[..., Y],
    in_types: str,
    out_types: str,
    *args,
    reverse: bool = False,
) -> Y:
  r"""Scan ``f`` across bodies in tree order, carrying results up/down the tree.

  This function groups bodies according to level and attached joints, then calls
  vmap(f) on them.\

  Args:
    m: an mjx model
    f: a function to be scanned with the following type signature:
        def f(y, *args) -> y
      where
        ``y`` is the carry value and return value. Scanning root to leaves,
          the carry is the output of each body's parent; scanning leaves to
          root (``reverse=True``) it is the sum of the outputs of each body's
          children (zeros for bodies of the group without children). It is
          None for groups that have no parent (resp. child) group.
        ``*args`` are input arguments with types matching ``in_types``
    in_types: string specifying the type of each input arg:
      'b': split according to bodies
      'j': split according to joint types
      'q': split according to generalized coordinates (len(qpos))
      'v': split according to degrees of freedom (len(qvel))
    out_types: string specifying the types the output dimension matches
    *args: the input arguments corresponding to ``in_types``
    reverse: if True, scans up the body tree from leaves to root, otherwise
      root to leaves

  Returns:
    The stacked outputs of ``f`` matching the model's body order.

  Raises:
    IndexError: if an input length or a function output shape does not match
      ``in_types`` / ``out_types``
    ValueError: if ``in_types`` / ``out_types`` contain a type other than 'b',
      'j', 'q', 'v' (types unknown to the input check may surface earlier as
      a KeyError)
  """
  _check_input(m, args, in_types)

  # the qpos -> body map only depends on the model: compute it once (and only
  # when a 'q' type is involved) instead of once per body in the loops below
  q_bodyid = _q_bodyid(m) if 'q' in in_types + out_types else None

  # group together bodies that will be processed together. grouping key:
  # 1) the tree depth: parent bodies are processed first, so that they are
  #    available as carry input to child bodies (or reverse if reverse=True)
  # 2) the types of arguments passed to f, both carry and *args:
  #    * for 'b' arguments, there is no extra grouping
  #    * for 'j' arguments, we group by joint type
  #    * for 'q' arguments, we group by q width
  #    * for 'v' arguments, we group by dof width
  depths = np.zeros(m.nbody, dtype=np.int32)

  # map key => body id
  key_body_ids = {}
  for body_id in range(m.nbody):
    parent_id = -1
    if body_id > 0:
      parent_id = m.body_parentid[body_id]
      depths[body_id] = 1 + depths[parent_id]

    # create grouping key: depth, carry, args
    key = (depths[body_id],)

    for i, t in enumerate(out_types + in_types):
      # carry (out_types) entries are keyed by the parent, args by the body
      id_ = parent_id if i < len(out_types) else body_id
      if t == 'b':
        continue
      elif t == 'j':
        key += tuple(m.jnt_type[np.nonzero(m.jnt_bodyid == id_)[0]])
      elif t == 'v':
        key += (len(np.nonzero(m.dof_bodyid == id_)[0]),)
      elif t == 'q':
        key += (len(np.nonzero(q_bodyid == id_)[0]),)

    body_ids = key_body_ids.get(key, np.array([], dtype=np.int32))
    key_body_ids[key] = np.append(body_ids, body_id)

  # find parent keys of each key. a key may have multiple parents if the
  # carry output keys of distinct parents are the same. e.g.:
  # - depth 0 body 1 (slide joint)
  # -- depth 1 body 1 (hinge joint)
  # - depth 0 body 2 (ball joint)
  # -- depth 1 body 2 (hinge joint)
  # given a scan with 'j' in the in_types, we would group depth 0 bodies
  # separately but we may group depth 1 bodies together
  key_parents = {}

  for key, body_ids in key_body_ids.items():
    body_ids = body_ids[body_ids != 0]  # ignore worldbody, has no parent
    if body_ids.size == 0:
      continue
    # find any key which has a body id that is a parent of these body_ids
    pids = m.body_parentid[body_ids]
    parents = {k for k, v in key_body_ids.items() if np.isin(v, pids).any()}
    key_parents[key] = list(sorted(parents))

  # key => take indices
  key_in_take, key_y_take = {}, {}
  for key, body_ids in key_body_ids.items():
    for i, typ in enumerate(in_types + out_types):
      if typ == 'b':
        ids = body_ids
      elif typ == 'j':
        ids = np.stack([np.nonzero(m.jnt_bodyid == b)[0] for b in body_ids])
      elif typ == 'v':
        ids = np.stack([np.nonzero(m.dof_bodyid == b)[0] for b in body_ids])
      elif typ == 'q':
        ids = np.stack([np.nonzero(q_bodyid == b)[0] for b in body_ids])
      else:
        raise ValueError(f'Unknown in_type: {typ}')
      if i < len(in_types):
        key_in_take.setdefault(key, []).append(ids)
      else:
        key_y_take.setdefault(key, []).append(np.hstack(ids))

  # use this grouping to take the right data subsets and call vmap(f)
  keys = sorted(key_body_ids, reverse=reverse)
  key_y = {}
  for key in keys:
    carry = None

    if reverse:
      child_keys = [k for k, v in key_parents.items() if key in v]

      for child_key in child_keys:
        y = key_y[child_key]
        body_ids = key_body_ids[key]
        parent_ids = m.body_parentid[key_body_ids[child_key]]
        # children whose parent is not in this group map to -1, which
        # segment_sum drops as an out-of-range segment id
        id_map = _index(body_ids, parent_ids)

        def index_sum(x, i=id_map, s=body_ids.size):
          return jax.ops.segment_sum(x, i, s)

        y = jax.tree_util.tree_map(index_sum, y)
        carry = y if carry is None else jax.tree_util.tree_map(jp.add, carry, y)
    elif key in key_parents:
      ys = [key_y[p] for p in key_parents[key]]
      y = jax.tree_util.tree_map(lambda *x: jp.concatenate(x), *ys)
      body_ids = np.concatenate([key_body_ids[p] for p in key_parents[key]])
      parent_ids = m.body_parentid[key_body_ids[key]]
      take_fn = lambda x, i=_index(body_ids, parent_ids): _take(x, i)
      carry = jax.tree_util.tree_map(take_fn, y)

    f_args = [_take(arg, ids) for arg, ids in zip(args, key_in_take[key])]
    key_y[key] = _nvmap(f, carry, *f_args)

  # slice None results from the final output
  keys = [k for k in keys if key_y[k] is not None]

  # concatenate ys, drop grouping dimensions, put back in order
  y = []
  for i, typ in enumerate(out_types):
    y_typ = [key_y[key] for key in keys]
    if len(out_types) > 1:
      y_typ = [y_[i] for y_ in y_typ]
    if typ != 'b':
      y_typ = jax.tree_util.tree_map(jp.concatenate, y_typ)
    y_typ = jax.tree_util.tree_map(lambda *x: jp.concatenate(x), *y_typ)
    y_take = np.argsort(np.concatenate([key_y_take[key][i] for key in keys]))
    _check_output(y_typ, y_take, typ, i)
    y.append(_take(y_typ, y_take))

  y = y[0] if len(out_types) == 1 else y

  return y
496: