What's the rationale for using a flat list for weights / gradients / updates instead of a nested structure? I believe the latter is more standard in JAX (see Working with pytrees) and avoids the need for counting and slicing like this:

    m0, m1 = self.children
    w0 = w[:m0.atoms]
    w1 = w[m0.atoms:]
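
For comparison, here is a rough sketch of the nested alternative. It is not taken from this codebase: the dict keys `m0` / `m1`, the `loss` function, and the array shapes are made up for illustration, and I'm assuming each child can simply own its own entry in a dict of arrays.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical parameters for a two-child module: one entry per child,
# so no child needs to know another child's size or offset.
params = {
    "m0": jnp.ones(3),
    "m1": jnp.ones(2),
}

def loss(p):
    # Illustrative loss only: each child reads its own weights by key.
    return jnp.sum(p["m0"] ** 2) + jnp.sum(3.0 * p["m1"])

# jax.grad returns gradients with the same nested structure as params.
grads = jax.grad(loss)(params)

# A plain SGD step applied leaf by leaf across both trees.
lr = 0.1
new_params = jax.tree_util.tree_map(lambda w, g: w - lr * g, params, grads)

# If a flat vector is still needed somewhere (e.g. for a solver),
# ravel_pytree returns it together with a function that restores the tree.
flat, unravel = ravel_pytree(params)
restored = unravel(flat)
```

With this layout the slicing bookkeeping goes away, and a flat view is still available on demand through `ravel_pytree` where an API genuinely requires one.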