diff --git a/src/julax/layers.py b/src/julax/layers.py index c1c6d58..0cdb0fb 100644 --- a/src/julax/layers.py +++ b/src/julax/layers.py @@ -38,6 +38,12 @@ def forward(self, x: PyTree, p: Param, s: State) -> tuple[PyTree, State]: class Repeated(LayerBase): + """Applies the same layer n times sequentially. + + Each repetition has independent, separately initialized parameters. + This is typical for transformer blocks where each layer has its own weights. + """ + n: int layer: LayerLike