
loop_fn in search.py becomes slow as the tree depth increases #107

@HeavyCrab

Description


I am using mctx to implement MuZero, where selecting an action at each node requires a forward pass of a large neural network.
Initially, the main time cost is the model forward pass, which is expected and similar to what happens with C++-based MCTS.
However, after a few steps of training, as the model starts to learn certain biases, the search tree becomes deeper. At that point, the bottleneck shifts to the loop_fn used in the selection and backpropagation phases.
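
For reference, here is a minimal sketch of this kind of setup. It is not my actual code: the batch size, action space, embedding size, and the placeholder `recurrent_fn` are stand-ins for the real networks, but the structure (one `recurrent_fn` call, i.e. one model forward pass, per expanded node, driven by `mctx.muzero_policy`) is the same.

```python
import jax
import jax.numpy as jnp
import mctx

batch_size, num_actions, embedding_dim = 4, 8, 16

def recurrent_fn(params, rng_key, action, embedding):
  # Stand-in for the dynamics + prediction networks; in the real setup this
  # is a forward pass of a large neural network for every expanded node.
  del params, rng_key, action
  output = mctx.RecurrentFnOutput(
      reward=jnp.zeros(batch_size),
      discount=jnp.full(batch_size, 0.99),
      prior_logits=jnp.zeros((batch_size, num_actions)),
      value=jnp.zeros(batch_size),
  )
  return output, embedding

root = mctx.RootFnOutput(
    prior_logits=jnp.zeros((batch_size, num_actions)),
    value=jnp.zeros(batch_size),
    embedding=jnp.zeros((batch_size, embedding_dim)),
)

policy_output = mctx.muzero_policy(
    params=(), rng_key=jax.random.PRNGKey(0), root=root,
    recurrent_fn=recurrent_fn, num_simulations=64)
```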

Here's the trace of one search. You can see that most of the time is spent in search.py:291 and search.py:181 (search.py was slightly adjusted to fit my codebase, so the line numbers don't match upstream), which correspond to:

  • tree, _, _ = jax.lax.while_loop(cond_fun, body_fun, loop_state) in backward
  • end_state = jax.lax.while_loop(cond_fun, body_fun, initial_state) in simulate
[profiler trace screenshots]
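
To make the scaling behaviour concrete, here is a heavily simplified sketch of the selection-phase pattern (not the actual mctx implementation, which scores actions with UCB and tracks many more statistics): a `jax.lax.while_loop` that descends one tree level per iteration. Its trip count therefore grows with tree depth, and on GPU each iteration amounts to a few tiny kernels whose cost is dominated by launch overhead.

```python
import jax
import jax.numpy as jnp

UNVISITED = -1  # marker for an unexpanded child slot

def simulate_sketch(children, root_index=0):
  """Descends from the root until reaching a node with no expanded children.

  `children` is a [num_nodes, num_actions] int32 array of child node indices,
  with UNVISITED for unexpanded slots. The loop body does only a few scalar
  reads per iteration, so on GPU the cost is dominated by kernel launches.
  """

  def cond_fun(state):
    node, _ = state
    # Keep descending while the current node has at least one expanded child.
    return jnp.any(children[node] != UNVISITED)

  def body_fun(state):
    node, depth = state
    # Stand-in for the real UCB-style action selection.
    action = jnp.argmax(children[node] != UNVISITED)
    return children[node, action], depth + 1

  return jax.lax.while_loop(
      cond_fun, body_fun,
      (jnp.asarray(root_index, jnp.int32), jnp.asarray(0, jnp.int32)))

# A three-node chain: node 0 -> node 1 -> node 2 (a leaf).
children = jnp.array([[1, UNVISITED], [2, UNVISITED], [UNVISITED, UNVISITED]],
                     dtype=jnp.int32)
leaf, depth = simulate_sketch(children)  # leaf == 2, depth == 2
```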

My question is: is there any way to improve the time efficiency of this part?
In practice, even with a tree depth of only a few dozen, performance is heavily affected by the huge efficiency gap between jax.lax.fori_loop on GPU and on CPU (the GPU version being thousands of times slower).
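
For completeness, the two partial workarounds I have considered are sketched below (reusing `root` and `recurrent_fn` from the sketch above; neither removes the per-iteration launch overhead): capping the tree depth with the `max_depth` argument of `mctx.muzero_policy`, and running the whole search on CPU.

```python
import jax
import mctx

# (a) Bound the while_loop trip count by capping how deep the search may go.
shallow_output = mctx.muzero_policy(
    params=(), rng_key=jax.random.PRNGKey(0), root=root,
    recurrent_fn=recurrent_fn, num_simulations=64,
    max_depth=16)  # selection/backward walk at most 16 levels

# (b) Run the whole search on CPU, where many tiny sequential kernels are
# cheap; only attractive if the model forward pass stays fast enough there.
# (Inputs already committed to GPU may need jax.device_put(..., cpu) first.)
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
  cpu_output = mctx.muzero_policy(
      params=(), rng_key=jax.random.PRNGKey(0), root=root,
      recurrent_fn=recurrent_fn, num_simulations=64)
```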
