Description
I am using mctx to implement MuZero, where actions are selected at each node via a forward pass of a large neural network.
Initially, the main bottleneck is the model forward pass, which is expected and matches what I see with C++-based MCTS implementations.
However, after a few training steps, as the model starts to learn certain biases, the search tree becomes deeper. At that point, the bottleneck shifts to the loop_fn used in the selection and backpropagation phases.
Here's the trace of one search. You can see that most of the time is spent in search.py:291 and search.py:181 (search.py was slightly adjusted to fit my codebase, so the line numbers don't match upstream), which correspond to `tree, _, _ = jax.lax.while_loop(cond_fun, body_fun, loop_state)` in `backward` and `end_state = jax.lax.while_loop(cond_fun, body_fun, initial_state)` in `simulate`.
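To make the cost structure concrete, here is a minimal sketch of the kind of selection loop involved. The names (`simulate_sketch`, `cond_fun`, `body_fun`) and the toy child-selection rule are illustrative assumptions, not mctx's actual internals; the point is that each iteration descends one tree level, so the `while_loop` runs depth-many dependent steps that cannot be parallelized:

```python
import jax
import jax.numpy as jnp

# Illustrative sketch (not mctx's real code) of a sequential selection loop:
# the carry is (current depth, current node index), and each body step
# depends on the previous one, so iterations cannot be batched.
def simulate_sketch(tree_depth):
    def cond_fun(state):
        depth, node = state
        return depth < tree_depth  # real code would also stop at unexpanded nodes

    def body_fun(state):
        depth, node = state
        # Placeholder child selection; MuZero would instead take the argmax of
        # a PUCT-style score over children stored in the tree arrays.
        child = node * 2 + 1
        return depth + 1, child

    return jax.lax.while_loop(cond_fun, body_fun, (jnp.int32(0), jnp.int32(0)))

depth, node = jax.jit(simulate_sketch, static_argnums=0)(20)
```

Each of those 20 iterations launches work on the accelerator that depends on the previous iteration's result, which is where the per-step overhead accumulates.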
My question is: is there any way to improve the time efficiency of this part?
In practice, even with a tree depth of only a few tens, performance is heavily affected by the huge efficiency gap between `jax.lax.fori_loop` on GPU and on CPU (with the GPU version being thousands of times slower).
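The gap is easy to reproduce with a rough microbenchmark sketch: a trivial sequential `while_loop` whose body does almost no work, so the measured time is dominated by per-iteration overhead rather than compute. Running it with JAX configured for CPU versus GPU (the numbers depend entirely on your hardware) shows the effect:

```python
import time
import jax
import jax.numpy as jnp

# Trivial sequential loop: the body is a single integer increment, so nearly
# all of the measured time is per-iteration loop overhead, not arithmetic.
@jax.jit
def count_to(n):
    def cond_fun(i):
        return i < n

    def body_fun(i):
        return i + 1

    return jax.lax.while_loop(cond_fun, body_fun, jnp.int32(0))

n = jnp.int32(100_000)
count_to(n).block_until_ready()           # warm up: compile once
start = time.perf_counter()
result = count_to(n).block_until_ready()  # measure steady-state time
elapsed = time.perf_counter() - start
print(f"{int(result)} iterations in {elapsed * 1e3:.2f} ms "
      f"on {jax.default_backend()}")
```

Comparing the printed time between `JAX_PLATFORMS=cpu` and a GPU run gives a rough per-iteration cost for each backend.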