diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index ea11125..e2fdf29 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -100,7 +100,7 @@ def dataclass( order=False, unsafe_hash=False, frozen=False, - kw_only: bool = False, + kw_only: bool = True, mappable_dataclass=True, # pylint: disable=redefined-outer-name ): """JAX-friendly wrapper for :py:func:`dataclasses.dataclass`. @@ -154,7 +154,7 @@ def __init__( order=False, unsafe_hash=False, frozen=False, - kw_only=False, + kw_only=True, mappable_dataclass=True, # pylint: disable=redefined-outer-name ): self.init = init