google-deepmind/mujoco/mjx/mujoco/mjx/_src/scan.py
1: # Copyright 2023 DeepMind Technologies Limited
2: #
3: # Licensed under the Apache License, Version 2.0 (the "License");
4: # you may not use this file except in compliance with the License.
5: # You may obtain a copy of the License at
6: #
7: #     http://www.apache.org/licenses/LICENSE-2.0
8: #
9: # Unless required by applicable law or agreed to in writing, software
10: # distributed under the License is distributed on an "AS IS" BASIS,
11: # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12: # See the License for the specific language governing permissions and
13: # limitations under the License.
14: # ==============================================================================
15: """Scan across data ordered by body joint types and kinematic tree order."""
16: 
17: from typing import Any, Callable, TypeVar
18: 
19: import jax
20: from jax import numpy as jp
21: # pylint: disable=g-importing-member
22: from mujoco.mjx._src.types import JointType
23: from mujoco.mjx._src.types import Model
24: from mujoco.mjx._src.types import TrnType
25: # pylint: enable=g-importing-member
26: import numpy as np
27: 
28: 
# Generic type for the pytrees that the scan helpers below accept and return.
Y = TypeVar('Y')
30: 
31: 
# TODO(erikfrey): re-check if this really helps perf
def _take(obj: Y, idx: np.ndarray) -> Y:
  """Takes idxs on any pytree given to it.

  XLA executes x[jp.array([1, 2, 3])] slower than x[1:4], so we detect when
  take indices are contiguous, and convert them to slices.

  Args:
    obj: an input pytree (or a numpy array, which is indexed directly)
    idx: indices to take along the leading axis of every leaf

  Returns:
    obj pytree with leaves taken by idxs
  """

  if isinstance(obj, np.ndarray):
    return obj[idx]

  # A non-empty 1-D run of consecutive, non-negative ids can be expressed as a
  # slice.  Runs starting at 0 qualify too (x[0:n] is an ordinary slice).
  # Negative ids (e.g. -1 sentinels) must keep using `take`: a slice such as
  # x[-1:0] would be empty instead of wrapping around to the last element.
  is_contiguous = (
      idx.ndim == 1
      and idx.size > 0
      and idx[0] >= 0
      and (idx == np.arange(idx[0], idx[0] + idx.size)).all()
  )
  if is_contiguous:
    # plain Python ints keep the slice static for JAX indexing
    start, stop = int(idx[0]), int(idx[-1]) + 1

  def take(x):
    # TODO(erikfrey): if this helps perf, add support for striding too
    if not x.shape[0]:
      return x
    if is_contiguous:
      return x[start:stop]
    return x.take(jp.array(idx), axis=0, mode='wrap')

  return jax.tree_util.tree_map(take, obj)
66: 
67: 
def _q_bodyid(m: Model) -> np.ndarray:
  """Returns the id of the body that owns each qpos address.

  Free joints occupy 7 qpos entries, ball joints 4, and any other joint 1.
  """
  qpos_width = {JointType.FREE: 7, JointType.BALL: 4}
  per_joint = [
      np.repeat(body_id, qpos_width.get(joint_type, 1))
      for joint_type, body_id in zip(m.jnt_type, m.jnt_bodyid)
  ]
  return np.concatenate([np.array([], dtype=np.int32), *per_joint])
75: 
76: 
def _q_jointid(m: Model) -> np.ndarray:
  """Returns the id of the joint that owns each qpos address.

  Free joints occupy 7 qpos entries, ball joints 4, and any other joint 1.
  """
  qpos_width = {JointType.FREE: 7, JointType.BALL: 4}
  per_joint = [
      np.repeat(joint_id, qpos_width.get(joint_type, 1))
      for joint_id, joint_type in enumerate(m.jnt_type)
  ]
  return np.concatenate([np.array([], dtype=np.int32), *per_joint])
84: 
85: 
def _index(haystack: np.ndarray, needle: np.ndarray) -> np.ndarray:
  """Returns indexes in haystack for elements in needle.

  Args:
    haystack: array of values to search in (treated as a flat sequence)
    needle: values to look up in ``haystack``

  Returns:
    An integer array shaped like ``needle``.  Each entry is a position in
    ``haystack`` holding the same value, or -1 when the value does not occur
    in ``haystack``.
  """
  if haystack.size == 0:
    # nothing can be found, and np.take cannot gather from an empty axis
    return np.full(np.shape(needle), -1, dtype=np.intp)

  order = np.argsort(haystack)
  pos = np.searchsorted(haystack[order], needle)
  # clip keeps lookups past the largest element in bounds; the equality test
  # below turns those (and every other miss) into -1
  idx = np.take(order, pos, mode='clip')
  idx[haystack[idx] != needle] = -1

  return idx
95: 
96: 
def _nvmap(f: Callable[..., Y], *args) -> Y:
  """vmaps ``f`` over ``args``, passing numpy array args through statically.

  A numpy array argument is not batched.  All of its elements along the
  leading axis must be equal, and that shared first element is handed to
  ``f`` unchanged for every batch entry.  Every other argument is batched
  along axis 0, except that pytree leaves with an empty leading dimension are
  replaced by ``None`` (and therefore reach ``f`` as ``None``).

  Args:
    f: function to be mapped over
    *args: args to be mapped along, passed to f

  Returns:
    the result of vmapping f over args

  Raises:
    RuntimeError: if numpy arg elements do not match
  """
  # validate all numpy args before building anything
  for arg in args:
    if not isinstance(arg, np.ndarray):
      continue
    if not np.all(arg == arg[0]):
      raise RuntimeError(f'numpy arg elements do not match: {arg}')

  # statics[k] is passed to f unbatched (None unless position k is a numpy
  # array); batched[k] is handed to vmap (None when position k is static)
  statics, batched = [], []
  for candidate in args:
    static = candidate[0] if isinstance(candidate, np.ndarray) else None
    statics.append(static)
    batched.append(candidate if static is None else None)

  # leaves with nothing along the leading axis cannot be vmapped over
  batched = jax.tree_util.tree_map(
      lambda leaf: None if not leaf.shape[0] else leaf, batched
  )
  in_axes = [0 if b is not None else None for b in batched]

  def _with_statics(*mapped, statics=statics):
    merged = [s if x is None else x for x, s in zip(mapped, statics)]
    return f(*merged)

  return jax.vmap(_with_statics, in_axes=in_axes)(*batched)
131: 
132: 
def _check_input(m: Model, args: Any, in_types: str) -> None:
  """Raises IndexError if a scan input's length disagrees with its type.

  Each character of ``in_types`` names the model dimension whose size the
  matching positional entry of ``args`` must have as its length.
  """
  expected_len = dict(
      b=m.nbody,
      j=m.njnt,
      q=m.nq,
      v=m.nv,
      u=m.nu,
      a=m.na,
      s=m.nsite,
      c=m.ncam,
  )
  for arg_idx, (arg, typ) in enumerate(zip(args, in_types)):
    got = len(arg)
    want = expected_len[typ]
    if got == want:
      continue
    raise IndexError(
        f'f argument "{arg_idx}" with type "{typ}" has length "{got}"'
        f' which does not match the in_types[{arg_idx}] expected length of '
        f'"{want}".'
    )
152: 
153: 
def _check_output(
    y: jax.Array, take_ids: np.ndarray, typ: str, idx: int
) -> None:
  """Raises IndexError if output ``idx`` of f has the wrong leading size."""
  produced, expected = y.shape[0], take_ids.shape[0]
  if produced == expected:
    return
  raise IndexError(
      f'f output "{idx}" with type "{typ}" has shape "{produced}" '
      f'which does not match the out_types[{idx}] expected size of'
      f' "{expected}".'
  )
164: 
165: 
def flat(
    m: Model,
    f: Callable[..., Y],
    in_types: str,
    out_types: str,
    *args,
    group_by: str = 'j',
) -> Y:
  r"""Scan a function across bodies, actuators, or cameras.

  Scan group data according to type and batch shape then calls vmap(f) on it.
  Items (bodies, actuators or cameras, depending on ``group_by``) that share a
  grouping key are batched together, ``f`` is vmapped over each group, and the
  results are stitched back into the model's order.

  Args:
    m: an mjx model
    f: a function to be scanned with the following type signature:
        def f(*args) -> y
      where
        ``*args`` are input arguments with types matching ``in_types``;
          numpy-array inputs are passed unbatched (see ``_nvmap``) and inputs
          whose slice for a group is empty arrive as ``None``
        ``y`` is an output, or a list/tuple of outputs, with types matching
          ``out_types``
    in_types: string specifying the type of each input arg:
      'b': split according to bodies
      'j': split according to joint types
      'q': split according to generalized coordinates (len(qpos))
      'v': split according to degrees of freedom (len(qvel))
      'u': split according to actuators
      'a': split according to actuator activations
      's': split according to sites
      'c': split according to camera
    out_types: string specifying the types the output dimension matches
    *args: the input arguments corresponding to ``in_types``
    group_by: the type to group by: 'j' (bodies and their joints), 'u'
      (actuators) or 'c' (cameras)

  Returns:
    The stacked outputs of ``f`` matching the model's order.

  Raises:
      IndexError: if an input length or a function output shape does not
        match in_types / out_types
      NotImplementedError: if ``group_by`` is not one of 'j', 'u', 'c'
  """
  _check_input(m, args, in_types)

  if group_by not in {'j', 'u', 'c'}:
    raise NotImplementedError(f'group by type "{group_by}" not implemented.')

  # qpos-address lookup tables: each helper loops over every joint, so build
  # them once per call instead of once per body / actuator
  q_bodyid = _q_bodyid(m) if group_by == 'j' else None
  q_jointid = _q_jointid(m) if group_by == 'u' else None

  def key_j(type_ids):
    if any(t in 'jqv' for t in in_types + out_types):
      return tuple(m.jnt_type[type_ids['j']])
    return ()

  def type_ids_j(m, i):
    return {
        'b': i,
        'j': np.nonzero(m.jnt_bodyid == i)[0],
        'v': np.nonzero(m.dof_bodyid == i)[0],
        'q': np.nonzero(q_bodyid == i)[0],
    }

  def key_u(type_ids):
    ids_u, ids_j = type_ids['u'], type_ids['j']
    return (
        m.actuator_biastype[ids_u],
        m.actuator_gaintype[ids_u],
        m.actuator_dyntype[ids_u],
        m.actuator_trntype[ids_u],
        # ids_j is -1 for non-joint transmissions; indexing jnt_type with it
        # raises when the model has no joints, so key those by -1 instead
        m.jnt_type[ids_j] if ids_j >= 0 else -1,
        m.actuator_trnid[ids_u, 1] == -1,  # key by refsite presence (True=none)
    )

  def type_ids_u(m, i):
    typ_ids = {
        'u': i,
        'a': m.actuator_actadr[i],
        'j': (
            m.actuator_trnid[i, 0]
            if m.actuator_trntype[i] in (TrnType.JOINT, TrnType.JOINTINPARENT)
            else -1
        ),
        's': (
            m.actuator_trnid[i]
            if m.actuator_trntype[i] == TrnType.SITE
            else np.array([-1, -1])
        ),
    }
    v, q = np.array([-1]), np.array([-1])
    if m.actuator_trntype[i] in (TrnType.JOINT, TrnType.JOINTINPARENT):
      # v/q are associated with the joint transmissions only
      v = np.nonzero(m.dof_jntid == typ_ids['j'])[0]
      q = np.nonzero(q_jointid == typ_ids['j'])[0]

    typ_ids.update({'v': v, 'q': q})

    return typ_ids

  def key_c(type_ids):
    return m.cam_mode[type_ids['c']], m.cam_targetbodyid[type_ids['c']] >= 0

  def type_ids_c(unused_m, i):
    return {
        'c': i,
    }

  type_ids_fn = {'j': type_ids_j, 'u': type_ids_u, 'c': type_ids_c}[group_by]
  key_fn = {'j': key_j, 'u': key_u, 'c': key_c}[group_by]

  # build up a grouping of type take-ids in body/actuator order
  key_typ_ids, order = {}, []
  all_types = set(in_types + out_types)
  n_items = {'j': m.nbody, 'u': m.nu, 'c': m.ncam}[group_by]
  for i in np.arange(n_items, dtype=np.int32):
    typ_ids = type_ids_fn(m, i)

    # create grouping key
    key = key_fn(typ_ids)
    order.append((key, typ_ids))

    # add ids per type to the corresponding group, stacked along a new leading
    # (group member) axis
    for t in all_types:
      out = key_typ_ids.setdefault(key, {})
      val = np.expand_dims(typ_ids[t], axis=0)
      out[t] = np.concatenate((out[t], val)) if t in out else val

  key_typ_ids = list(sorted(key_typ_ids.items()))

  # use this grouping to take the right data subsets and call vmap(f)
  ys = []
  for _, typ_ids in key_typ_ids:
    # only execute f if we would actually take something from the result
    if any(typ_ids[v].size > 0 for v in out_types):
      f_args = [_take(arg, typ_ids[typ]) for arg, typ in zip(args, in_types)]
      y = _nvmap(f, *f_args)
      ys.append(y)
    else:
      ys.append(None)

  # remove None results from the final output
  key_typ_ids = [v for y, v in zip(ys, key_typ_ids) if y is not None]
  ys = [y for y in ys if y is not None]
  ys_keys = set([k for k, *_ in key_typ_ids])
  order = [o for k, o in order if k in ys_keys]

  # get the original input order
  order = [[o[t] for o in order] for t in all_types]
  order = [
      np.concatenate(o) if isinstance(o[0], np.ndarray) else np.array(o)
      for o in order
  ]
  order = dict(zip(all_types, order))

  # concatenate back to a single tree and drop the grouping dimension
  f_ret_is_seq = isinstance(ys[0], (list, tuple))
  ys = ys if f_ret_is_seq else [[y] for y in ys]
  # types whose per-item ids are scalars: their vmapped outputs carry no
  # extra per-item axis that would need flattening
  flat_ = {'j': 'b', 'u': 'uaj', 'c': 'c'}[group_by]
  ys = [
      [v if typ in flat_ else jp.concatenate(v) for v, typ in zip(y, out_types)]
      for y in ys
  ]
  ys = jax.tree_util.tree_map(lambda *x: jp.concatenate(x), *ys)

  # put concatenated results back in order
  reordered_ys = []
  for i, (y, typ) in enumerate(zip(ys, out_types)):
    _check_output(y, order[typ], typ, i)
    ids = np.concatenate([np.hstack(v[typ]) for _, v in key_typ_ids])
    # -1 marks placeholder ids (e.g. no activation, non-joint transmission)
    input_order = order[typ][np.where(order[typ] != -1)]
    reordered_ys.append(_take(y, _index(ids, input_order)))
  y = reordered_ys if f_ret_is_seq else reordered_ys[0]

  return y
333: 
334: 
def body_tree(
    m: Model,
    f: Callable[..., Y],
    in_types: str,
    out_types: str,
    *args,
    reverse: bool = False,
) -> Y:
  r"""Scan ``f`` across bodies in tree order, carrying results up/down the tree.

  This function groups bodies according to level and attached joints, then calls
  vmap(f) on them.

  Args:
    m: an mjx mjmodel
    f: a function to be scanned with the following type signature:
        def f(y, *args) -> y
      where
        ``y`` is the carry value and return value.  Scanning root to leaves,
          the carry is the parent body's output (None for bodies without a
          processed parent group); with ``reverse=True`` it is the sum of the
          children's outputs (None when there are no children).
        ``*args`` are input arguments with types matching ``in_types``
      When ``out_types`` has more than one entry, ``f`` returns a sequence with
      one output per entry.
    in_types: string specifying the type of each input arg:
      'b': split according to bodies
      'j': split according to joint types
      'q': split according to generalized coordinates (len(qpos))
      'v': split according to degrees of freedom (len(qvel))
    out_types: string specifying the types the output dimension matches
    *args: the input arguments corresponding to ``in_types``
    reverse: if True, scans up the body tree from leaves to root, otherwise
      root to leaves

  Returns:
    The stacked outputs of ``f`` matching the model's body order.

  Raises:
      IndexError: if an input length or a function output shape does not
        match in_types / out_types
      ValueError: if out_types (or in_types) holds a type outside 'bjqv'
  """
  _check_input(m, args, in_types)

  # qpos address -> body id; built once instead of once per body and type
  q_bodyid = _q_bodyid(m) if 'q' in in_types + out_types else None

  # group together bodies that will be processed together.  grouping key:
  #   1) the tree depth: parent bodies are processed first, so that they are
  #      available as carry input to child bodies (or reverse if reverse=True)
  #   2) the types of arguments passed to f, both carry and *args:
  #      * for 'b' arguments, there is no extra grouping
  #      * for 'j' arguments, we group by joint type
  #      * for 'q' arguments, we group by q width
  #      * for 'v' arguments, we group by dof width
  depths = np.zeros(m.nbody, dtype=np.int32)

  # map key => body id
  key_body_ids = {}
  for body_id in range(m.nbody):
    parent_id = -1
    if body_id > 0:
      parent_id = m.body_parentid[body_id]
      depths[body_id] = 1 + depths[parent_id]

    # create grouping key: depth, carry, args
    key = (depths[body_id],)

    for i, t in enumerate(out_types + in_types):
      # carry (out_types) entries describe the parent, whose output this body
      # receives; arg entries describe the body itself
      id_ = parent_id if i < len(out_types) else body_id
      if t == 'b':
        continue
      elif t == 'j':
        key += tuple(m.jnt_type[np.nonzero(m.jnt_bodyid == id_)[0]])
      elif t == 'v':
        key += (len(np.nonzero(m.dof_bodyid == id_)[0]),)
      elif t == 'q':
        key += (len(np.nonzero(q_bodyid == id_)[0]),)

    body_ids = key_body_ids.get(key, np.array([], dtype=np.int32))
    key_body_ids[key] = np.append(body_ids, body_id)

  # find parent keys of each key.  a key may have multiple parents if the
  # carry output keys of distinct parents are the same. e.g.:
  # - depth 0 body 1 (slide joint)
  # -- depth 1 body 1 (hinge joint)
  # - depth 0 body 2 (ball joint)
  # -- depth 1 body 2 (hinge joint)
  # given a scan with 'j' in the in_types, we would group depth 0 bodies
  # separately but we may group depth 1 bodies together
  key_parents = {}

  for key, body_ids in key_body_ids.items():
    body_ids = body_ids[body_ids != 0]  # ignore worldbody, has no parent
    if body_ids.size == 0:
      continue
    # find any key which has a body id that is a parent of these body_ids
    pids = m.body_parentid[body_ids]
    parents = {k for k, v in key_body_ids.items() if np.isin(v, pids).any()}
    key_parents[key] = list(sorted(parents))

  # key => take indices
  key_in_take, key_y_take = {}, {}
  for key, body_ids in key_body_ids.items():
    for i, typ in enumerate(in_types + out_types):
      if typ == 'b':
        ids = body_ids
      elif typ == 'j':
        ids = np.stack([np.nonzero(m.jnt_bodyid == b)[0] for b in body_ids])
      elif typ == 'v':
        ids = np.stack([np.nonzero(m.dof_bodyid == b)[0] for b in body_ids])
      elif typ == 'q':
        ids = np.stack([np.nonzero(q_bodyid == b)[0] for b in body_ids])
      else:
        raise ValueError(f'Unknown in_type: {typ}')
      if i < len(in_types):
        key_in_take.setdefault(key, []).append(ids)
      else:
        key_y_take.setdefault(key, []).append(np.hstack(ids))

  # use this grouping to take the right data subsets and call vmap(f)
  keys = sorted(key_body_ids, reverse=reverse)
  key_y = {}
  for key in keys:
    carry = None

    if reverse:
      child_keys = [k for k, v in key_parents.items() if key in v]

      for child_key in child_keys:
        y = key_y[child_key]
        body_ids = key_body_ids[key]
        parent_ids = m.body_parentid[key_body_ids[child_key]]
        # children whose parent lives in another group map to -1, which lies
        # outside segment_sum's [0, num_segments) range
        id_map = _index(body_ids, parent_ids)

        def index_sum(x, i=id_map, s=body_ids.size):
          return jax.ops.segment_sum(x, i, s)

        y = jax.tree_util.tree_map(index_sum, y)
        carry = y if carry is None else jax.tree_util.tree_map(jp.add, carry, y)
    elif key in key_parents:
      ys = [key_y[p] for p in key_parents[key]]
      y = jax.tree_util.tree_map(lambda *x: jp.concatenate(x), *ys)
      body_ids = np.concatenate([key_body_ids[p] for p in key_parents[key]])
      parent_ids = m.body_parentid[key_body_ids[key]]
      take_fn = lambda x, i=_index(body_ids, parent_ids): _take(x, i)
      carry = jax.tree_util.tree_map(take_fn, y)

    f_args = [_take(arg, ids) for arg, ids in zip(args, key_in_take[key])]
    key_y[key] = _nvmap(f, carry, *f_args)

  # slice None results from the final output
  keys = [k for k in keys if key_y[k] is not None]

  # concatenate ys, drop grouping dimensions, put back in order
  y = []
  for i, typ in enumerate(out_types):
    y_typ = [key_y[key] for key in keys]
    if len(out_types) > 1:
      y_typ = [y_[i] for y_ in y_typ]
    if typ != 'b':
      y_typ = jax.tree_util.tree_map(jp.concatenate, y_typ)
    y_typ = jax.tree_util.tree_map(lambda *x: jp.concatenate(x), *y_typ)
    y_take = np.argsort(np.concatenate([key_y_take[key][i] for key in keys]))
    _check_output(y_typ, y_take, typ, i)
    y.append(_take(y_typ, y_take))

  y = y[0] if len(out_types) == 1 else y

  return y
496: