-
Notifications
You must be signed in to change notification settings - Fork 1
feat[cartesian] Literal precision #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
romanc
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me. Thanks for adding all the docstrings. Some nitpicks about the types.
Co-authored-by: Roman Cattaneo <romanc@users.noreply.github.com>
Co-authored-by: Roman Cattaneo <romanc@users.noreply.github.com>
Co-authored-by: Roman Cattaneo <romanc@users.noreply.github.com>
stubbiali
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @twicki, LGTM! I left some minor comments, but the PR could be merged as is.
| literal_precision = attribute(of=int, default=64) | ||
| "Specify the literal precision for automatic casts. Defaults to 64-bit" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To make it immediately clear that literal_precision can either be 32 or 64, one might define the following type descriptor in gt4py.cartesian.utils.attrib:
def _make_literal_validator(args):
if not isinstance(args, tuple):
args = tuple([args])
def _is_literal_validator(instance, attribute, value):
if value not in args:
raise ValueError(
f"Invalid literal value {value} for attribute `{attribute.name}`; "
f"expected one of {args}."
)
return _is_literal_validator
class _LiteralDescriptor:
def __getitem__(self, values):
return _TypeDescriptor(
"Literal",
args=values,
make_validator_func=_make_literal_validator,
type_hint=typing.Literal,
)
Literal = _LiteralDescriptor()so that here we could have:
| literal_precision = attribute(of=int, default=64) | |
| "Specify the literal precision for automatic casts. Defaults to 64-bit" | |
| literal_precision = attribute(of=Literal[32, 64], default=64) | |
| "Specify the literal precision for automatic casts. Defaults to 64-bit" |
Just thinking out loud, this comment should not be blocking.
| "f32": DataType.FLOAT32, | ||
| "f64": DataType.FLOAT64, | ||
| } | ||
| def frontend_type_to_native_type(literal_precision: int = 64) -> dict[str, DataType]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| def frontend_type_to_native_type(literal_precision: int = 64) -> dict[str, DataType]: | |
| def frontend_type_to_native_type(literal_precision: Literal[32, 64] = 64) -> dict[str, DataType]: |
| if not isinstance(literal_precision, int) and literal_precision not in (32, 64): | ||
| raise ValueError(f"Invalid 'literal_precision' ('{literal_precision}')") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check would be redundant with the proposed custom type descriptor for literals.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some tests of this suite are failing for reasons that go beyond this PR. I'm pretty sure you're well aware of this, but would like to double-check anyway.
Co-authored-by: Stefano Ubbiali <subbiali@phys.ethz.ch>
Co-authored-by: Stefano Ubbiali <subbiali@phys.ethz.ch>
Co-authored-by: Stefano Ubbiali <subbiali@phys.ethz.ch>
Co-authored-by: Stefano Ubbiali <subbiali@phys.ethz.ch>
Description
As we realized there is more mixed-precision work in fv3, this PR allows for generic casts.
Allow for generic casts
intandfloatto a literal precision that is given to any given stencil. This allows for better re-use instead of having to rely on the hard-codedi32andi64between stencils.