Skip to content

Conversation

@twicki
Copy link
Collaborator

@twicki twicki commented Mar 20, 2025

Description

As we realized there is more mixed-precision work in fv3, this PR allows for generic casts.

Allow for generic casts int and float to a literal precision that is given to any given stencil. This allows for better re-use instead of having to rely on the hard-coded i32 and i64 between stencils.

@twicki twicki marked this pull request as draft March 20, 2025 14:05
@twicki twicki marked this pull request as ready for review March 20, 2025 17:27
Copy link
Collaborator

@romanc romanc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me. Thanks for adding all the docstrings. Some nitpicks about the types.

twicki and others added 4 commits March 21, 2025 09:52
Co-authored-by: Roman Cattaneo <romanc@users.noreply.github.com>
Co-authored-by: Roman Cattaneo <romanc@users.noreply.github.com>
Co-authored-by: Roman Cattaneo <romanc@users.noreply.github.com>
Copy link
Owner

@stubbiali stubbiali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @twicki, LGTM! I left some minor comments, but the PR could be merged as is.

Comment on lines +104 to +105
literal_precision = attribute(of=int, default=64)
"Specify the literal precision for automatic casts. Defaults to 64-bit"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make it immediately clear that literal_precision can either be 32 or 64, one might define the following type descriptor in gt4py.cartesian.utils.attrib:

def _make_literal_validator(args):
    if not isinstance(args, tuple):
        args = tuple([args])

    def _is_literal_validator(instance, attribute, value):
        if value not in args:
            raise ValueError(
                f"Invalid literal value {value} for attribute `{attribute.name}`; "
                f"expected one of {args}."
            )

    return _is_literal_validator


class _LiteralDescriptor:
    def __getitem__(self, values):
        return _TypeDescriptor(
            "Literal",
            args=values,
            make_validator_func=_make_literal_validator,
            type_hint=typing.Literal,
        )


Literal = _LiteralDescriptor()

so that here we could have:

Suggested change
literal_precision = attribute(of=int, default=64)
"Specify the literal precision for automatic casts. Defaults to 64-bit"
literal_precision = attribute(of=Literal[32, 64], default=64)
"Specify the literal precision for automatic casts. Defaults to 64-bit"

Just thinking out loud, this comment should not be blocking.

"f32": DataType.FLOAT32,
"f64": DataType.FLOAT64,
}
def frontend_type_to_native_type(literal_precision: int = 64) -> dict[str, DataType]:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def frontend_type_to_native_type(literal_precision: int = 64) -> dict[str, DataType]:
def frontend_type_to_native_type(literal_precision: Literal[32, 64] = 64) -> dict[str, DataType]:

Comment on lines +289 to +290
if not isinstance(literal_precision, int) and literal_precision not in (32, 64):
raise ValueError(f"Invalid 'literal_precision' ('{literal_precision}')")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check would be redundant with the proposed custom type descriptor for literals.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some tests of this suite are failing for reasons that go beyond this PR. I'm pretty sure you're well aware of this, but would like to double-check anyway.

twicki and others added 4 commits March 28, 2025 10:26
Co-authored-by: Stefano Ubbiali <subbiali@phys.ethz.ch>
Co-authored-by: Stefano Ubbiali <subbiali@phys.ethz.ch>
Co-authored-by: Stefano Ubbiali <subbiali@phys.ethz.ch>
Co-authored-by: Stefano Ubbiali <subbiali@phys.ethz.ch>
@twicki twicki merged commit 606b310 into stubbiali:physics Apr 9, 2025
0 of 5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants