diff --git a/mypyc/irbuild/format_str_tokenizer.py b/mypyc/irbuild/format_str_tokenizer.py index 2596e8e050a6..d34ecfe00801 100644 --- a/mypyc/irbuild/format_str_tokenizer.py +++ b/mypyc/irbuild/format_str_tokenizer.py @@ -26,7 +26,7 @@ from mypyc.irbuild.constant_fold import constant_fold_expr from mypyc.primitives.bytes_ops import bytes_build_op from mypyc.primitives.int_ops import int_to_str_op -from mypyc.primitives.str_ops import str_build_op, str_op +from mypyc.primitives.str_ops import ascii_op, str_build_op, str_op @unique @@ -42,6 +42,7 @@ class FormatOp(Enum): STR = "s" INT = "d" + ASCII = "a" BYTES = "b" @@ -53,14 +54,25 @@ def generate_format_ops(specifiers: list[ConversionSpecifier]) -> list[FormatOp] format_ops = [] for spec in specifiers: # TODO: Match specifiers instead of using whole_seq - if spec.whole_seq == "%s" or spec.whole_seq == "{:{}}": + # Conversion flags for str.format/f-strings (e.g. {!a}); only if no format spec. + if spec.conversion and not spec.format_spec: + if spec.conversion == "!a": + format_op = FormatOp.ASCII + else: + return None + # printf-style tokens and special f-string lowering patterns. + elif spec.whole_seq == "%s" or spec.whole_seq == "{:{}}": format_op = FormatOp.STR elif spec.whole_seq == "%d": format_op = FormatOp.INT + elif spec.whole_seq == "%a": + format_op = FormatOp.ASCII elif spec.whole_seq == "%b": format_op = FormatOp.BYTES + # Any other non-empty spec means we can't optimize; fall back to runtime formatting. elif spec.whole_seq: return None + # Empty spec ("{}") defaults to str(). else: format_op = FormatOp.STR format_ops.append(format_op) @@ -152,6 +164,11 @@ def convert_format_expr_to_str( var_str = builder.primitive_op(int_to_str_op, [builder.accept(x)], line) else: var_str = builder.primitive_op(str_op, [builder.accept(x)], line) + elif format_op == FormatOp.ASCII: + if (folded := constant_fold_expr(builder, x)) is not None: + var_str = builder.load_literal_value(ascii(folded)) + else: + var_str = builder.primitive_op(ascii_op, [builder.accept(x)], line) elif format_op == FormatOp.INT: if isinstance(folded := constant_fold_expr(builder, x), int): var_str = builder.load_literal_value(str(folded)) diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index f6d3f722dd7b..6a178accca8a 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -51,6 +51,15 @@ error_kind=ERR_MAGIC, ) +# ascii(obj) +ascii_op = function_op( + name="builtins.ascii", + arg_types=[object_rprimitive], + return_type=str_rprimitive, + c_function_name="PyObject_ASCII", + error_kind=ERR_MAGIC, +) + # translate isinstance(obj, str) isinstance_str = function_op( name="builtins.isinstance", diff --git a/mypyc/test-data/irbuild-constant-fold.test b/mypyc/test-data/irbuild-constant-fold.test index 009a4e68f678..ba6cc3b73ae1 100644 --- a/mypyc/test-data/irbuild-constant-fold.test +++ b/mypyc/test-data/irbuild-constant-fold.test @@ -192,9 +192,14 @@ L0: from typing import Any, Final FMT: Final = "{} {}" +FMT_A: Final = "{!a}" def f() -> str: return FMT.format(400 + 20, "roll" + "up") +def g() -> str: + return FMT_A.format("\u00e9") +def g2() -> str: + return FMT_A.format("\u2603") [out] def f(): r0, r1, r2, r3 :: str @@ -204,6 +209,18 @@ L0: r2 = ' ' r3 = CPyStr_Build(3, r0, r2, r1) return r3 +def g(): + r0, r1 :: str +L0: + r0 = "'\\xe9'" + r1 = CPyStr_Build(1, r0) + return r1 +def g2(): + r0, r1 :: str +L0: + r0 = "'\\u2603'" + r1 = CPyStr_Build(1, r0) + return r1 [case testIntConstantFoldingFinal] from typing import Final diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index ee618bb34f65..16aed004e02f 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -291,6 +291,7 @@ def f(var: Union[str, NewStr], num: int) -> None: s2 = "I am %d years old." % num s3 = "Hi! I'm %s. I am %d years old." % (var, num) s4 = "Float: %f" % num + s5 = "Ascii: %a" % var [typing fixtures/typing-full.pyi] [out] def f(var, num): @@ -298,7 +299,7 @@ def f(var, num): num :: int r0, r1, r2, s1, r3, r4, r5, r6, s2, r7, r8, r9, r10, r11, s3, r12 :: str r13, r14 :: object - r15, s4 :: str + r15, s4, r16, r17, r18, s5 :: str L0: r0 = "Hi! I'm " r1 = '.' @@ -320,6 +321,10 @@ L0: r14 = PyNumber_Remainder(r12, r13) r15 = cast(str, r14) s4 = r15 + r16 = PyObject_ASCII(var) + r17 = 'Ascii: ' + r18 = CPyStr_Build(2, r17, r16) + s5 = r18 return 1 [case testDecode]