diff --git a/flax/core/lift.py b/flax/core/lift.py
index 7500c65bd..678110029 100644
--- a/flax/core/lift.py
+++ b/flax/core/lift.py
@@ -1171,7 +1171,7 @@ def cond(
 
   Note that this constraint is violated when creating variables or submodules
   in only one branch. Because initializing variables in just one branch
-  causes the paramater structure to be different.
+  causes the parameter structure to be different.
 
   Example::
diff --git a/flax/core/scope.py b/flax/core/scope.py
index 3e207d3e5..e9b2cf9c5 100644
--- a/flax/core/scope.py
+++ b/flax/core/scope.py
@@ -1126,7 +1126,7 @@ def lazy_init(
 ) -> Callable[..., Any]:
   """Functionalizes a `Scope` function for lazy initialization.
 
-  Similair to ``init`` except that the init function now accepts
+  Similar to ``init`` except that the init function now accepts
   ``jax.ShapeDtypeStruct`` instances for arguments that do not affect
   the variable initialization (typically this is all the input data).
diff --git a/flax/linen/experimental/layers_with_named_axes.py b/flax/linen/experimental/layers_with_named_axes.py
index 795aa845f..7b2e2314c 100644
--- a/flax/linen/experimental/layers_with_named_axes.py
+++ b/flax/linen/experimental/layers_with_named_axes.py
@@ -248,7 +248,7 @@ def _normalize(
 ):
   """ "Normalizes the input of a normalization layer and optionally applies a learned scale and bias.
 
-  A seperate bias and scale is learned for each feature as specified by
+  A separate bias and scale is learned for each feature as specified by
   feature_axes.
   """
   reduction_axes = _canonicalize_axes(x.ndim, reduction_axes)
diff --git a/flax/linen/kw_only_dataclasses.py b/flax/linen/kw_only_dataclasses.py
index 95542fafe..40575e613 100644
--- a/flax/linen/kw_only_dataclasses.py
+++ b/flax/linen/kw_only_dataclasses.py
@@ -23,7 +23,7 @@
 
 For earlier Python versions, when constructing a dataclass, any fields that
 have been marked as keyword-only (including inherited fields) will be moved to the
-end of the constuctor's argument list. This makes it possible to have a base
+end of the constructor's argument list. This makes it possible to have a base
 class that defines a field with a default, and a subclass that defines a field
 without a default. E.g.:
@@ -77,7 +77,7 @@ def __repr__(self):
 
 
 def field(*, metadata=None, kw_only=dataclasses.MISSING, **kwargs):
-  """Wrapper for dataclassess.field that adds support for kw_only fields.
+  """Wrapper for dataclasses.field that adds support for kw_only fields.
 
   Args:
     metadata: A mapping or None, containing metadata for the field.
diff --git a/flax/linen/module.py b/flax/linen/module.py
index f4dc1bd14..21b3548bd 100644
--- a/flax/linen/module.py
+++ b/flax/linen/module.py
@@ -510,7 +510,7 @@ def nowrap(fun: _CallableT) -> _CallableT:
   This is needed in several concrete instances:
   - if you're subclassing a method like Module.param and don't want this
-    overriden core function decorated with the state management wrapper.
+    overridden core function decorated with the state management wrapper.
   - If you want a method to be callable from an unbound Module (e.g.: a
     function of construction of arguments that doesn't depend on params/RNGs).
 
   If you want to learn more about how Flax Modules manage their state read the
diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py
index 274546fc5..b55b39ba7 100644
--- a/flax/linen/recurrent.py
+++ b/flax/linen/recurrent.py
@@ -914,7 +914,7 @@ class RNN(Module):
 
   By default RNN expect the time dimension after the batch dimension (``(*batch, time, *features)``),
-  if you set ``time_major=True`` RNN will instead expect the time dimesion to be
+  if you set ``time_major=True`` RNN will instead expect the time dimension to be
   at the beginning (``(time, *batch, *features)``)::
@@ -955,7 +955,7 @@ class RNN(Module):
 
   RNN also accepts some of the arguments of :func:`flax.linen.scan`, by default they are set to work with
   cells like :class:`LSTMCell` and :class:`GRUCell` but they can be
-  overriden as needed.
+  overridden as needed.
   Overriding default values to scan looks like this::
 
   >>> lstm = nn.RNN(
diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py
index 48007200a..eb695bac8 100644
--- a/flax/linen/transforms.py
+++ b/flax/linen/transforms.py
@@ -2066,7 +2066,7 @@ def custom_vjp(
   Args:
     fn: The function to define a custom_vjp for.
     forward_fn: A function with the same arguments as ``fn`` returning an tuple
-      with the original output and the residuals that will be passsed to
+      with the original output and the residuals that will be passed to
       ``backward_fn``.
     backward_fn: arguments are passed as ``(*nondiff_args, residuals, tangents)``
       The function should return a
diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py
index f25433c96..1fecb8316 100644
--- a/flax/nnx/nn/linear.py
+++ b/flax/nnx/nn/linear.py
@@ -669,8 +669,8 @@ class Conv(Module):
       ``'CIRCULAR'`` (periodic boundary conditions), the string `'REFLECT'`
       (reflection across the padding boundary), or a sequence of ``n`` ``(low,
       high)`` integer pairs that give the padding to apply before and after each
-      spatial dimension. A single int is interpeted as applying the same padding
-      in all dims and passign a single int in a sequence causes the same padding
+      spatial dimension. A single int is interpreted as applying the same padding
+      in all dims and passing a single int in a sequence causes the same padding
       to be used on both sides. ``'CAUSAL'`` padding for a 1D convolution will
       left-pad the convolution axis, resulting in same-sized output.
     input_dilation: an integer or a sequence of ``n`` integers, giving the
@@ -951,7 +951,7 @@ class ConvTranspose(Module):
       inter-window strides (default: 1).
     padding: either a string indicating a specialized padding mode, or a sequence
       of ``n`` ``(low, high)`` integer pairs that give the padding to apply before and after each
-      spatial dimension. A single int is interpeted as applying the same padding
+      spatial dimension. A single int is interpreted as applying the same padding
       in all dims and a single int in a sequence causes the same padding
       to be used on both sides.
diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py
index a4a1afdf9..a9556e9a7 100644
--- a/flax/nnx/transforms/autodiff.py
+++ b/flax/nnx/transforms/autodiff.py
@@ -738,10 +738,10 @@ def custom_vjp(
 
   ``nnx.custom_vjp`` accepts Modules and other Flax NNX objects as arguments. The main difference
   with the JAX version is that, because Modules follow reference semantics, they propagate the State
-  updates for the inputs as auxiliary outputs. This means that the incomming gradients in the ``bwd`` function
+  updates for the inputs as auxiliary outputs. This means that the incoming gradients in the ``bwd`` function
   will have the form ``(input_updates_g, out_g)`` where ``input_updates_g`` is the gradient updated state of the
   inputs w.r.t. to the inputs. All Module terms on the inputs will an associated ``State`` term in
-  ``input_updates_g``, while all non-Module terms will appear as None. The shape of the tanget will be
+  ``input_updates_g``, while all non-Module terms will appear as None. The shape of the tangent will be
   expected to have the same shape as the input, with ``State`` terms in place of the corresponding Module terms.
 
   Example::
diff --git a/flax/training/common_utils.py b/flax/training/common_utils.py
index 1000b28ae..a8c08cd32 100644
--- a/flax/training/common_utils.py
+++ b/flax/training/common_utils.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Common utilty functions used in data-parallel Flax examples.
+"""Common utility functions used in data-parallel Flax examples.
 
 This module is a historical grab-bag of utility functions primarily concerned
 with helping write pmap-based data-parallel training loops.
diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py
index e8ce47991..d66428a4c 100644
--- a/tests/linen/linen_attention_test.py
+++ b/tests/linen/linen_attention_test.py
@@ -206,7 +206,7 @@ def test_multihead_self_attention_w_dropout_disabled(self):
     np.testing.assert_allclose(y7, y8)
 
   def test_causal_mask_1d(self):
-    """Tests autoregresive masking for 1d attention."""
+    """Tests autoregressive masking for 1d attention."""
     x = jnp.ones((3, 16))  # (bs1, length)
     mask_1d = nn.attention.make_causal_mask(x)
     ts = np.arange(16)
@@ -257,8 +257,8 @@ def body_fn(state, x):
 
     np.testing.assert_allclose(y_ref, y, atol=1e-5)
 
-  def test_autoregresive_receptive_field_1d(self):
-    """Tests the autoregresive self-attention receptive field."""
+  def test_autoregressive_receptive_field_1d(self):
+    """Tests the autoregressive self-attention receptive field."""
     rng = random.key(0)
     rng1, rng2 = random.split(rng, num=2)