List vs. nested structure for weights, gradients, and updates (dualized grads) #11

@EIFY

What's the rationale for using a flat list for weights / gradients / updates instead of a nested structure? I believe the latter is more standard in JAX (see Working with pytrees) and avoids the need for counting and slicing like this:

modula/modula/abstract.py, lines 107 to 109 in ede2ba7:

m0, m1 = self.children
w0 = w[:m0.atoms]
w1 = w[m0.atoms:]
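
To make the comparison concrete, here is a minimal sketch of the two conventions side by side. It is not modula's actual code: `forward_flat`, `forward_nested`, the `atoms0` argument, and the toy two-layer linear model are all hypothetical stand-ins, assuming each child owns exactly one weight matrix. The only library calls used are `jax.grad` and `jax.tree_util.tree_map`, both of which accept arbitrarily nested lists and tuples of arrays.

```python
import jax
import jax.numpy as jnp

# Hypothetical composite of two children, each a single linear map.

def forward_flat(w, x, atoms0):
    # Flat-list convention: the parent needs each child's weight count
    # ("atoms") to slice out that child's share of the list.
    w0 = w[:atoms0]
    w1 = w[atoms0:]
    return (x @ w0[0]) @ w1[0]

def forward_nested(w, x):
    # Nested convention: the weights mirror the module tree, so the parent
    # simply unpacks one sub-tree per child. No counting is needed.
    w0, w1 = w
    return (x @ w0[0]) @ w1[0]

key0, key1 = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key0, (3, 4))
b = jax.random.normal(key1, (4, 2))
x = jnp.ones((5, 3))

flat_w = [a, b]        # flat list of all weights
nested_w = ([a], [b])  # one sub-list per child

def loss_flat(w):
    return jnp.sum(forward_flat(w, x, 1) ** 2)

def loss_nested(w):
    return jnp.sum(forward_nested(w, x) ** 2)

# jax.grad returns gradients with the same structure as its input.
flat_grads = jax.grad(loss_flat)(flat_w)
nested_grads = jax.grad(loss_nested)(nested_w)

# An update step written once for any nesting depth.
new_w = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, nested_w, nested_grads)
```

In this sketch the gradients and updates come back with the same nesting as the weights, so nothing outside the parent module needs to know how many atoms each child has.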
