2 changes: 1 addition & 1 deletion flax/core/lift.py
@@ -1171,7 +1171,7 @@ def cond(
Note that this constraint is violated when
creating variables or submodules in only one branch.
Because initializing variables in just one branch
- causes the paramater structure to be different.
+ causes the parameter structure to be different.

Example::

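For reference, the constraint described above is easiest to see through ``nn.cond``, the Linen wrapper around this lifted ``cond``. A minimal sketch assuming current Linen APIs; ``CondBlock`` and its shapes are illustrative only::

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class CondBlock(nn.Module):
      @nn.compact
      def __call__(self, x, pred: bool):
        # Both branches create the same Dense (same name, same shape), so the
        # parameter structure is identical regardless of which branch runs.
        def true_fn(mdl, x):
          return nn.Dense(4, name='dense')(x)
        def false_fn(mdl, x):
          return -nn.Dense(4, name='dense')(x)
        return nn.cond(pred, true_fn, false_fn, self, x)

    y, variables = CondBlock().init_with_output(
        jax.random.key(0), jnp.ones((2, 8)), True)
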
2 changes: 1 addition & 1 deletion flax/core/scope.py
@@ -1126,7 +1126,7 @@ def lazy_init(
) -> Callable[..., Any]:
"""Functionalizes a `Scope` function for lazy initialization.

- Similair to ``init`` except that the init function now accepts
+ Similar to ``init`` except that the init function now accepts
``jax.ShapeDtypeStruct`` instances for arguments that do not
affect the variable initialization (typically this is all the input data).

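
The same mechanism is exposed on Linen modules as ``Module.lazy_init``. A minimal sketch, assuming current Flax/JAX APIs (the layer and shapes are made up)::

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    model = nn.Dense(features=4)
    # Only the shape/dtype of the input affects initialization, so a
    # ShapeDtypeStruct stands in for real data and no compute happens on it.
    x_spec = jax.ShapeDtypeStruct((1, 8), jnp.float32)
    variables = model.lazy_init(jax.random.key(0), x_spec)
    # variables['params']['kernel'].shape == (8, 4)
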
2 changes: 1 addition & 1 deletion flax/linen/experimental/layers_with_named_axes.py
@@ -248,7 +248,7 @@ def _normalize(
):
""" "Normalizes the input of a normalization layer and optionally applies a learned scale and bias.

- A seperate bias and scale is learned for each feature as specified by
+ A separate bias and scale is learned for each feature as specified by
feature_axes.
"""
reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
4 changes: 2 additions & 2 deletions flax/linen/kw_only_dataclasses.py
@@ -23,7 +23,7 @@

For earlier Python versions, when constructing a dataclass, any fields that have
been marked as keyword-only (including inherited fields) will be moved to the
- end of the constuctor's argument list. This makes it possible to have a base
+ end of the constructor's argument list. This makes it possible to have a base
class that defines a field with a default, and a subclass that defines a field
without a default. E.g.:

@@ -77,7 +77,7 @@ def __repr__(self):


def field(*, metadata=None, kw_only=dataclasses.MISSING, **kwargs):
"""Wrapper for dataclassess.field that adds support for kw_only fields.
"""Wrapper for dataclasses.field that adds support for kw_only fields.

Args:
metadata: A mapping or None, containing metadata for the field.
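
The behaviour being emulated here is the stdlib one available from Python 3.10 onward; the sketch below uses plain ``dataclasses`` (not this module) purely to show why moving keyword-only fields to the end helps::

    import dataclasses

    @dataclasses.dataclass
    class Base:
      # Defaulted but keyword-only, so it does not block required
      # positional fields added by subclasses.
      name: str = dataclasses.field(default='base', kw_only=True)

    @dataclasses.dataclass
    class Child(Base):
      size: int  # no default, yet no "non-default follows default" error

    print(Child(3))            # Child(name='base', size=3)
    print(Child(3, name='x'))  # Child(name='x', size=3)
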
2 changes: 1 addition & 1 deletion flax/linen/module.py
@@ -510,7 +510,7 @@ def nowrap(fun: _CallableT) -> _CallableT:

This is needed in several concrete instances:
- if you're subclassing a method like Module.param and don't want this
- overriden core function decorated with the state management wrapper.
+ overridden core function decorated with the state management wrapper.
- If you want a method to be callable from an unbound Module (e.g.: a
function of construction of arguments that doesn't depend on params/RNGs).
If you want to learn more about how Flax Modules manage their state read the
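
A minimal sketch of the second case (calling a helper on an unbound Module); ``MLP`` and ``_hidden_size`` are made-up names::

    import flax.linen as nn

    class MLP(nn.Module):
      features: int

      @nn.nowrap
      def _hidden_size(self) -> int:
        # Touches no params/RNGs, so it is left out of the state-management
        # wrapping and stays callable on an unbound Module.
        return 2 * self.features

      @nn.compact
      def __call__(self, x):
        x = nn.Dense(self._hidden_size())(x)
        return nn.Dense(self.features)(x)

    print(MLP(features=8)._hidden_size())  # 16, no bind/init needed
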
4 changes: 2 additions & 2 deletions flax/linen/recurrent.py
@@ -914,7 +914,7 @@ class RNN(Module):

By default RNN expect the time dimension after the batch dimension (``(*batch,
time, *features)``),
- if you set ``time_major=True`` RNN will instead expect the time dimesion to be
+ if you set ``time_major=True`` RNN will instead expect the time dimension to be
at the beginning
(``(time, *batch, *features)``)::

@@ -955,7 +955,7 @@ class RNN(Module):
RNN also accepts some of the arguments of :func:`flax.linen.scan`, by default
they are set to
work with cells like :class:`LSTMCell` and :class:`GRUCell` but they can be
- overriden as needed.
+ overridden as needed.
Overriding default values to scan looks like this::

>>> lstm = nn.RNN(
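
A short sketch of what ``time_major`` changes, assuming the current ``nn.RNN``/``nn.LSTMCell`` signatures (shapes are arbitrary)::

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    cell = nn.LSTMCell(features=32)

    x_bm = jnp.ones((8, 20, 4))   # (batch, time, features): the default layout
    x_tm = jnp.ones((20, 8, 4))   # (time, batch, features)

    y_bm, _ = nn.RNN(cell).init_with_output(jax.random.key(0), x_bm)
    y_tm, _ = nn.RNN(cell, time_major=True).init_with_output(
        jax.random.key(1), x_tm)

    assert y_bm.shape == (8, 20, 32) and y_tm.shape == (20, 8, 32)
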
2 changes: 1 addition & 1 deletion flax/linen/transforms.py
@@ -2066,7 +2066,7 @@ def custom_vjp(
Args:
fn: The function to define a custom_vjp for.
forward_fn: A function with the same arguments as ``fn`` returning an tuple
- with the original output and the residuals that will be passsed to
+ with the original output and the residuals that will be passed to
``backward_fn``.
backward_fn: arguments are passed as
``(*nondiff_args, residuals, tangents)`` The function should return a
6 changes: 3 additions & 3 deletions flax/nnx/nn/linear.py
@@ -669,8 +669,8 @@ class Conv(Module):
``'CIRCULAR'`` (periodic boundary conditions), the string `'REFLECT'`
(reflection across the padding boundary), or a sequence of ``n``
``(low, high)`` integer pairs that give the padding to apply before and after each
- spatial dimension. A single int is interpeted as applying the same padding
- in all dims and passign a single int in a sequence causes the same padding
+ spatial dimension. A single int is interpreted as applying the same padding
+ in all dims and passing a single int in a sequence causes the same padding
to be used on both sides. ``'CAUSAL'`` padding for a 1D convolution will
left-pad the convolution axis, resulting in same-sized output.
input_dilation: an integer or a sequence of ``n`` integers, giving the
@@ -951,7 +951,7 @@ class ConvTranspose(Module):
inter-window strides (default: 1).
padding: either a string indicating a specialized padding mode,
or a sequence of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each
- spatial dimension. A single int is interpeted as applying the same padding
+ spatial dimension. A single int is interpreted as applying the same padding
in all dims and a single int in a sequence causes the same padding
to be used on both sides.

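The padding forms listed above, spelled out for a 1D ``nnx.Conv`` (a sketch assuming the current NNX constructor; sizes are arbitrary)::

    import jax.numpy as jnp
    from flax import nnx

    x = jnp.ones((4, 16, 3))  # (batch, spatial, features)

    # One int for every dim, an explicit (low, high) pair per spatial dim,
    # or 'CAUSAL', which only left-pads the single spatial axis.
    conv_int = nnx.Conv(3, 8, kernel_size=(3,), padding=1, rngs=nnx.Rngs(0))
    conv_pair = nnx.Conv(3, 8, kernel_size=(3,), padding=[(1, 1)], rngs=nnx.Rngs(0))
    conv_causal = nnx.Conv(3, 8, kernel_size=(3,), padding='CAUSAL', rngs=nnx.Rngs(0))

    assert conv_int(x).shape == conv_pair(x).shape == conv_causal(x).shape == (4, 16, 8)
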
4 changes: 2 additions & 2 deletions flax/nnx/transforms/autodiff.py
@@ -738,10 +738,10 @@ def custom_vjp(

``nnx.custom_vjp`` accepts Modules and other Flax NNX objects as arguments. The main difference
with the JAX version is that, because Modules follow reference semantics, they propagate the State
- updates for the inputs as auxiliary outputs. This means that the incomming gradients in the ``bwd`` function
+ updates for the inputs as auxiliary outputs. This means that the incoming gradients in the ``bwd`` function
will have the form ``(input_updates_g, out_g)`` where ``input_updates_g`` is the gradient updated state of
the inputs w.r.t. to the inputs. All Module terms on the inputs will an associated ``State`` term in
- ``input_updates_g``, while all non-Module terms will appear as None. The shape of the tanget will be
+ ``input_updates_g``, while all non-Module terms will appear as None. The shape of the tangent will be
expected to have the same shape as the input, with ``State`` terms in place of the corresponding Module terms.

Example::
2 changes: 1 addition & 1 deletion flax/training/common_utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common utilty functions used in data-parallel Flax examples.
"""Common utility functions used in data-parallel Flax examples.

This module is a historical grab-bag of utility functions primarily concerned
with helping write pmap-based data-parallel training loops.
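
For context, two of these helpers as they typically appear in such loops (a sketch; the batch contents are made up)::

    import jax.numpy as jnp
    from flax.training import common_utils

    batch = {'image': jnp.ones((32, 28, 28, 1)),
             'label': jnp.zeros((32,), jnp.int32)}

    # shard() splits the leading axis across local devices so the batch can
    # be fed straight into a pmapped train step.
    sharded_batch = common_utils.shard(batch)

    # onehot() is the usual companion for cross-entropy losses in the examples.
    one_hot = common_utils.onehot(batch['label'], num_classes=10)
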
6 changes: 3 additions & 3 deletions tests/linen/linen_attention_test.py
@@ -206,7 +206,7 @@ def test_multihead_self_attention_w_dropout_disabled(self):
np.testing.assert_allclose(y7, y8)

def test_causal_mask_1d(self):
"""Tests autoregresive masking for 1d attention."""
"""Tests autoregressive masking for 1d attention."""
x = jnp.ones((3, 16)) # (bs1, length)
mask_1d = nn.attention.make_causal_mask(x)
ts = np.arange(16)
@@ -257,8 +257,8 @@ def body_fn(state, x):

np.testing.assert_allclose(y_ref, y, atol=1e-5)

- def test_autoregresive_receptive_field_1d(self):
- """Tests the autoregresive self-attention receptive field."""
+ def test_autoregressive_receptive_field_1d(self):
+ """Tests the autoregressive self-attention receptive field."""
rng = random.key(0)
rng1, rng2 = random.split(rng, num=2)

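The causal mask being tested above, in miniature (a sketch with made-up shapes)::

    import jax.numpy as jnp
    import flax.linen as nn

    x = jnp.ones((3, 16))            # (batch, length)
    mask = nn.make_causal_mask(x)    # (3, 1, 16, 16)

    # Autoregressive masking: position i may attend to position j only if j <= i.
    assert bool(mask[0, 0, 1, 0]) and not bool(mask[0, 0, 0, 1])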