diff --git a/.changes/next-release/feature-1c36b965bf2f51d85aa0171be6206ae2029137c4.json b/.changes/next-release/feature-1c36b965bf2f51d85aa0171be6206ae2029137c4.json new file mode 100644 index 00000000000..3153bd1730d --- /dev/null +++ b/.changes/next-release/feature-1c36b965bf2f51d85aa0171be6206ae2029137c4.json @@ -0,0 +1,7 @@ +{ + "type": "feature", + "description": "Implement rules engine ITE fn and S3 tree transform", + "pull_requests": [ + "[#2903](https://github.com/smithy-lang/smithy/pull/2903)" + ] +} diff --git a/.changes/next-release/feature-5e4b03b74b1054ac2632fcae533d727eace1d26e.json b/.changes/next-release/feature-5e4b03b74b1054ac2632fcae533d727eace1d26e.json new file mode 100644 index 00000000000..cd940a23a40 --- /dev/null +++ b/.changes/next-release/feature-5e4b03b74b1054ac2632fcae533d727eace1d26e.json @@ -0,0 +1,7 @@ +{ + "type": "feature", + "description": "Improve BDD sifting (2x speed, more reduction)", + "pull_requests": [ + "[#2890](https://github.com/smithy-lang/smithy/pull/2890)" + ] +} diff --git a/docs/source-2.0/additional-specs/rules-engine/standard-library.rst b/docs/source-2.0/additional-specs/rules-engine/standard-library.rst index 6950b914468..8cb5f7d51d8 100644 --- a/docs/source-2.0/additional-specs/rules-engine/standard-library.rst +++ b/docs/source-2.0/additional-specs/rules-engine/standard-library.rst @@ -208,6 +208,103 @@ The following example uses ``isValidHostLabel`` to check if the value of the } +.. _rules-engine-standard-library-ite: + +``ite`` function +================ + +Summary + An if-then-else function that returns one of two values based on a boolean condition. +Argument types + * condition: ``bool`` + * trueValue: ``T`` or ``option`` + * falseValue: ``T`` or ``option`` +Return type + * ``ite(bool, T, T)`` → ``T`` (both non-optional, result is non-optional) + * ``ite(bool, T, option)`` → ``option`` (any optional makes result optional) + * ``ite(bool, option, T)`` → ``option`` (any optional makes result optional) + * ``ite(bool, option, option)`` → ``option`` (both optional, result is optional) +Since + 1.1 + +The ``ite`` (if-then-else) function evaluates a boolean condition and returns one of two values based on +the result. If the condition is ``true``, it returns ``trueValue``; if ``false``, it returns ``falseValue``. +This function is particularly useful for computing conditional values without branching in the rule tree, resulting +in fewer result nodes, and enabling better BDD optimizations as a result of reduced fragmentation. + +.. important:: + Both ``trueValue`` and ``falseValue`` must have the same base type ``T``. The result type follows + the "least upper bound" rule: if either branch is optional, the result is optional. + +The following example uses ``ite`` to compute a URL suffix based on whether FIPS is enabled: + +.. code-block:: json + + { + "fn": "ite", + "argv": [ + {"ref": "UseFIPS"}, + "-fips", + "" + ], + "assign": "fipsSuffix" + } + +The following example uses ``ite`` with ``coalesce`` to handle an optional boolean parameter: + +.. code-block:: json + + { + "fn": "ite", + "argv": [ + { + "fn": "coalesce", + "argv": [ + {"ref": "DisableFeature"}, + false + ] + }, + "disabled", + "enabled" + ], + "assign": "featureState" + } + + +.. _rules-engine-standard-library-ite-examples: + +-------- +Examples +-------- + +The following table shows various inputs and their corresponding outputs for the ``ite`` function: + +.. list-table:: + :header-rows: 1 + :widths: 20 25 25 30 + + * - Condition + - True Value + - False Value + - Output + * - ``true`` + - ``"-fips"`` + - ``""`` + - ``"-fips"`` + * - ``false`` + - ``"-fips"`` + - ``""`` + - ``""`` + * - ``true`` + - ``"sigv4"`` + - ``"sigv4-s3express"`` + - ``"sigv4"`` + * - ``false`` + - ``"sigv4"`` + - ``"sigv4-s3express"`` + - ``"sigv4-s3express"`` + + .. _rules-engine-standard-library-not: ``not`` function diff --git a/settings.gradle.kts b/settings.gradle.kts index f3c9eba093a..bffb67b89a7 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -6,6 +6,8 @@ pluginManagement { } } + + rootProject.name = "smithy" include(":smithy-aws-iam-traits") diff --git a/smithy-aws-endpoints/build.gradle.kts b/smithy-aws-endpoints/build.gradle.kts index c142cd4ee32..559731a51a3 100644 --- a/smithy-aws-endpoints/build.gradle.kts +++ b/smithy-aws-endpoints/build.gradle.kts @@ -11,10 +11,64 @@ description = "AWS specific components for managing endpoints in Smithy" extra["displayName"] = "Smithy :: AWS Endpoints Components" extra["moduleName"] = "software.amazon.smithy.aws.endpoints" +// Custom configuration for S3 model - kept separate from test classpath to avoid +// polluting other tests with S3 model discovery +val s3Model: Configuration by configurations.creating + dependencies { api(project(":smithy-aws-traits")) api(project(":smithy-diff")) api(project(":smithy-rules-engine")) api(project(":smithy-model")) api(project(":smithy-utils")) + + s3Model("software.amazon.api.models:s3:1.0.11") +} + +// Integration test source set for tests that require the S3 model +// These tests require JDK 21+ due to the S3 model dependency +sourceSets { + create("it") { + compileClasspath += sourceSets["main"].output + sourceSets["test"].output + runtimeClasspath += sourceSets["main"].output + sourceSets["test"].output + } +} + +configurations["itImplementation"].extendsFrom(configurations["testImplementation"]) +configurations["itRuntimeOnly"].extendsFrom(configurations["testRuntimeOnly"]) +configurations["itImplementation"].extendsFrom(s3Model) + +// Configure IT source set to compile with JDK 21 +tasks.named("compileItJava") { + javaCompiler.set( + javaToolchains.compilerFor { + languageVersion.set(JavaLanguageVersion.of(21)) + }, + ) + sourceCompatibility = "21" + targetCompatibility = "21" +} + +val integrationTest by tasks.registering(Test::class) { + description = "Runs integration tests that require external models like S3" + group = "verification" + testClassesDirs = sourceSets["it"].output.classesDirs + classpath = sourceSets["it"].runtimeClasspath + dependsOn(tasks.jar) + shouldRunAfter(tasks.test) + + // Run with JDK 21 + javaLauncher.set( + javaToolchains.launcherFor { + languageVersion.set(JavaLanguageVersion.of(21)) + }, + ) +} + +tasks.test { + finalizedBy(integrationTest) +} + +tasks.named("check") { + dependsOn(integrationTest) } diff --git a/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java new file mode 100644 index 00000000000..dd5e88140a7 --- /dev/null +++ b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java @@ -0,0 +1,52 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.language.functions; + +import static org.junit.jupiter.api.Assertions.assertFalse; + +import java.util.List; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.evaluation.TestEvaluator; +import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; +import software.amazon.smithy.rulesengine.traits.EndpointTestCase; +import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait; + +/** + * Runs the endpoint test cases against the transformed S3 model. We're fixed to a specific version for this test, + * but could periodically bump the version if needed. + */ +class S3TreeRewriterTest { + private static final ShapeId S3_SERVICE_ID = ShapeId.from("com.amazonaws.s3#AmazonS3"); + + private static EndpointRuleSet originalRules; + private static List testCases; + + @BeforeAll + static void loadS3Model() { + Model model = Model.assembler() + .discoverModels() + .assemble() + .unwrap(); + + ServiceShape s3Service = model.expectShape(S3_SERVICE_ID, ServiceShape.class); + originalRules = s3Service.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); + testCases = s3Service.expectTrait(EndpointTestsTrait.class).getTestCases(); + } + + @Test + void transformPreservesEndpointTestSemantics() { + assertFalse(testCases.isEmpty(), "S3 model should have endpoint test cases"); + + EndpointRuleSet transformed = S3TreeRewriter.transform(originalRules); + for (EndpointTestCase testCase : testCases) { + TestEvaluator.evaluate(transformed, testCase); + } + } +} diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java index 42b2344e8ef..02dbe32fa4e 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java @@ -26,12 +26,21 @@ public double applyAsDouble(Condition condition) { // Region is almost always provided if (s.contains("isSet(Region)")) { - return 0.95; + return 0.96; } // Endpoint override is rare if (s.contains("isSet(Endpoint)")) { - return 0.1; + return 0.2; + } + + // S3 Express is rare (includes ITE variables from S3TreeRewriter) + if (s.contains("S3Express") || s.contains("--x-s3") + || s.contains("--xa-s3") + || s.contains("s3e_fips") + || s.contains("s3e_ds") + || s.contains("s3e_auth")) { + return 0.001; } // Most isSet checks on optional params succeed moderately @@ -48,11 +57,6 @@ public double applyAsDouble(Condition condition) { return 0.05; } - // S3 Express is relatively rare - if (s.contains("S3Express") || s.contains("--x-s3") || s.contains("--xa-s3")) { - return 0.1; - } - // ARN-based buckets are uncommon if (s.contains("parseArn") || s.contains("arn:")) { return 0.15; diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java new file mode 100644 index 00000000000..f748a289f56 --- /dev/null +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java @@ -0,0 +1,633 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.language.functions; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import software.amazon.smithy.model.node.StringNode; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Split; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Substring; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.utils.SmithyInternalApi; + +/** + * Rewrites S3 endpoint rules to use canonical, position-independent expressions. + * + *

This is a BDD pre-processing transform that makes the rules tree larger but enables dramatically better + * BDD compilation. It solves the "SSA Trap" problem where semantically identical operations appear as syntactically + * different expressions, preventing the BDD compiler from recognizing sharing opportunities. + * + *

Internal use only

+ *

Ideally this transform is deleted one day, and the rules that source it adopt these techniques (hopefully we + * don't look back on this comment and laugh in 5 years). If/when that happens, this class will be deleted, whether + * it breaks a consumer that uses it or not. + * + *

Trade-off: Larger Rules, Smaller BDD

+ *

This transform would be counterproductive for rule tree interpretation, but is highly beneficial when a + * BDD compiler processes the output. It adds ITE (if-then-else) conditions to compute URL segments and auth scheme + * names, increasing rule tree size by ~30%. However, this enables the BDD compiler to deduplicate endpoints that + * were previously considered distinct, as of writing, reducing BDD results and node counts both by ~43%. + * + *

The key insight is that the BDD deduplicates by endpoint identity (URL template + properties). By making + * URL templates identical through variable substitution, endpoints that differed only in FIPS/DualStack/auth variants + * collapse into a single BDD result. + * + *

Transformations performed:

+ * + *

AZ Extraction Canonicalization

+ * + *

The original rules extract the availability zone ID using position-dependent substring operations. + * Different bucket name lengths result in different extraction positions, creating 10+ SSA variants that can't + * be shared in the BDD. + * + *

Before: Position-dependent substring extraction + *

{@code
+ * {
+ *   "conditions": [
+ *     {
+ *       "fn": "substring",
+ *       "argv": [{"ref": "Bucket"}, 6, 14, true],
+ *       "assign": "s3expressAvailabilityZoneId"
+ *     }
+ *   ],
+ *   "rules": [...]
+ * }
+ * // Another branch with different positions:
+ * {
+ *   "conditions": [
+ *     {
+ *       "fn": "substring",
+ *       "argv": [{"ref": "Bucket"}, 6, 20, true],
+ *       "assign": "s3expressAvailabilityZoneId"
+ *     }
+ *   ],
+ *   "rules": [...]
+ * }
+ * }
+ * + *

After: Position-independent split-based extraction + *

{@code
+ * {
+ *   "conditions": [
+ *     {
+ *       "fn": "getAttr",
+ *       "argv": [
+ *         {"fn": "split", "argv": [{"ref": "Bucket"}, "--", 0]},
+ *         "[1]"
+ *       ],
+ *       "assign": "s3expressAvailabilityZoneId"
+ *     }
+ *   ],
+ *   "rules": [...]
+ * }
+ * }
+ * + *

All branches now use the identical expression {@code split(Bucket, "--")[1]}, enabling + * the BDD compiler to share nodes across all S3Express bucket handling paths. Because the expression only interacts + * with Bucket, a constant value, there's no SSA transform performed on these expressions. + * + *

URL Canonicalization

+ * + *

S3Express endpoints (currently) have 4 URL variants based on UseFIPS and UseDualStack flags. This creates + * duplicate endpoints that differ only in URL structure. + * + *

Before: Separate endpoints for each FIPS/DualStack combination + *

{@code
+ * // Branch 1: FIPS + DualStack
+ * {
+ *   "conditions": [
+ *     {"fn": "booleanEquals", "argv": [{"ref": "UseFIPS"}, true]},
+ *     {"fn": "booleanEquals", "argv": [{"ref": "UseDualStack"}, true]}
+ *   ],
+ *   "endpoint": {
+ *     "url": "https://{Bucket}.s3express-fips-{s3expressAvailabilityZoneId}.dualstack.{Region}.amazonaws.com"
+ *   }
+ * }
+ * // Branch 2: FIPS only
+ * {
+ *   "conditions": [
+ *     {"fn": "booleanEquals", "argv": [{"ref": "UseFIPS"}, true]}
+ *   ],
+ *   "endpoint": {
+ *     "url": "https://{Bucket}.s3express-fips-{s3expressAvailabilityZoneId}.{Region}.amazonaws.com"
+ *   }
+ * }
+ * // Branch 3: DualStack only
+ * // Branch 4: Neither
+ * }
+ * + *

After: Single endpoint with ITE-computed URL segments + *

{@code
+ * {
+ *   "conditions": [
+ *     {"fn": "ite", "argv": [{"ref": "UseFIPS"}, "-fips", ""], "assign": "_s3e_fips"},
+ *     {"fn": "ite", "argv": [{"ref": "UseDualStack"}, ".dualstack", ""], "assign": "_s3e_ds"}
+ *   ],
+ *   "endpoint": {
+ *     "url": "https://{Bucket}.s3express{_s3e_fips}-{s3expressAvailabilityZoneId}{_s3e_ds}.{Region}.amazonaws.com"
+ *   }
+ * }
+ * }
+ * + *

The ITE conditions compute values branchlessly. The BDD sifting optimization naturally places these rare + * S3Express-specific conditions late in the decision tree. + * + *

Auth Scheme Canonicalization

+ * + *

S3Express endpoints use different auth schemes based on DisableS3ExpressSessionAuth. + * This creates duplicate endpoints differing only in auth scheme name. + * + *

Before: Separate auth scheme names + *

{@code
+ * // When DisableS3ExpressSessionAuth is true:
+ * "authSchemes": [{"name": "sigv4", "signingName": "s3express", ...}]
+ *
+ * // When DisableS3ExpressSessionAuth is false/unset:
+ * "authSchemes": [{"name": "sigv4-s3express", "signingName": "s3express", ...}]
+ * }
+ * + *

After: ITE-computed auth scheme name + *

{@code
+ * {
+ *   "conditions": [
+ *     {
+ *       "fn": "ite",
+ *       "argv": [
+ *         {"fn": "coalesce", "argv": [{"ref": "DisableS3ExpressSessionAuth"}, false]},
+ *         "sigv4",
+ *         "sigv4-s3express"
+ *       ],
+ *       "assign": "_s3e_auth"
+ *     }
+ *   ],
+ *   "endpoint": {
+ *     "properties": {
+ *       "authSchemes": [{"name": "{_s3e_auth}", "signingName": "s3express", ...}]
+ *     }
+ *   }
+ * }
+ * }
+ */ +@SmithyInternalApi +public final class S3TreeRewriter { + private static final Logger LOGGER = Logger.getLogger(S3TreeRewriter.class.getName()); + + // Variable names for the computed suffixes + private static final String VAR_FIPS = "_s3e_fips"; + private static final String VAR_DS = "_s3e_ds"; + private static final String VAR_AUTH = "_s3e_auth"; + + // Suffix values used in the URI templates + private static final String FIPS_SUFFIX = "-fips"; + private static final String DS_SUFFIX = ".dualstack"; + private static final String EMPTY_SUFFIX = ""; + + // Auth scheme values used with s3-express + private static final String AUTH_SIGV4 = "sigv4"; + private static final String AUTH_SIGV4_S3EXPRESS = "sigv4-s3express"; + + // Property and parameter identifiers + private static final Identifier ID_AUTH_SCHEMES = Identifier.of("authSchemes"); + private static final Identifier ID_NAME = Identifier.of("name"); + private static final Identifier ID_BACKEND = Identifier.of("backend"); + private static final Identifier ID_BUCKET = Identifier.of("Bucket"); + private static final Identifier ID_AZ_ID = Identifier.of("s3expressAvailabilityZoneId"); + private static final Identifier ID_USE_FIPS = Identifier.of("UseFIPS"); + private static final Identifier ID_USE_DUAL_STACK = Identifier.of("UseDualStack"); + private static final Identifier ID_DISABLE_S3EXPRESS_SESSION_AUTH = Identifier.of("DisableS3ExpressSessionAuth"); + + // Auth scheme name literal shared across all rewritten endpoints + private static final Literal AUTH_NAME_LITERAL = Literal.stringLiteral(Template.fromString("{" + VAR_AUTH + "}")); + + // Patterns to match S3Express bucket endpoint URLs (with AZ) + // Format: https://{Bucket}.s3express[-fips]-{AZ}[.dualstack].{Region}.amazonaws.com + // (negative lookahead (?!dualstack) prevents matching dualstack variants in non-DS patterns) + private static final Pattern S3EXPRESS_FIPS_DS = Pattern.compile("(s3express)-fips-([^.]+)\\.dualstack\\.(.+)$"); + private static final Pattern S3EXPRESS_FIPS = Pattern.compile("(s3express)-fips-([^.]+)\\.(?!dualstack)(.+)$"); + private static final Pattern S3EXPRESS_DS = Pattern.compile("(s3express)-([^.]+)\\.dualstack\\.(.+)$"); + private static final Pattern S3EXPRESS_PLAIN = Pattern.compile("(s3express)-([^.]+)\\.(?!dualstack)(.+)$"); + + // Patterns to match S3Express control plane URLs (no AZ) + // Format: https://s3express-control[-fips][.dualstack].{Region}.amazonaws.com + private static final Pattern S3EXPRESS_CONTROL_FIPS_DS = Pattern.compile( + "(s3express-control)-fips\\.dualstack\\.(.+)$"); + private static final Pattern S3EXPRESS_CONTROL_FIPS = Pattern.compile( + "(s3express-control)-fips\\.(?!dualstack)(.+)$"); + private static final Pattern S3EXPRESS_CONTROL_DS = Pattern.compile( + "(s3express-control)\\.dualstack\\.(.+)$"); + private static final Pattern S3EXPRESS_CONTROL_PLAIN = Pattern.compile( + "(s3express-control)\\.(?!dualstack)(.+)$"); + + // Cached canonical expression for AZ extraction: split(Bucket, "--", 0) + private static final Split BUCKET_SPLIT = Split.ofExpressions( + Expression.getReference(ID_BUCKET), + Expression.of("--"), + Expression.of(0)); + + private int rewrittenCount = 0; + private int totalS3ExpressCount = 0; + + private S3TreeRewriter() {} + + /** + * Transforms the given endpoint rule set using canonical expressions. + * + * @param ruleSet the rule set to transform + * @return the transformed rule set + */ + public static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + return new S3TreeRewriter().run(ruleSet); + } + + private EndpointRuleSet run(EndpointRuleSet ruleSet) { + List transformedRules = new ArrayList<>(); + for (Rule rule : ruleSet.getRules()) { + transformedRules.add(transformRule(rule)); + } + + LOGGER.info(() -> String.format( + "S3 tree rewriter: %s/%s S3Express endpoints rewritten", + rewrittenCount, + totalS3ExpressCount)); + + return EndpointRuleSet.builder() + .sourceLocation(ruleSet.getSourceLocation()) + .parameters(ruleSet.getParameters()) + .rules(transformedRules) + .version(ruleSet.getVersion()) + .build(); + } + + private Rule transformRule(Rule rule) { + if (rule instanceof TreeRule) { + TreeRule tr = (TreeRule) rule; + // Transform conditions + List transformedConditions = transformConditions(tr.getConditions()); + List transformedChildren = new ArrayList<>(); + for (Rule child : tr.getRules()) { + transformedChildren.add(transformRule(child)); + } + return Rule.builder().conditions(transformedConditions).treeRule(transformedChildren); + } else if (rule instanceof EndpointRule) { + return rewriteEndpoint((EndpointRule) rule); + } else { + // Error rules pass through unchanged + return rule; + } + } + + private List transformConditions(List conditions) { + List result = new ArrayList<>(conditions.size()); + for (Condition cond : conditions) { + result.add(transformCondition(cond)); + } + return result; + } + + /** + * Transforms a single condition. + * + *

Handles: + *

+     * AZ extraction: substring(Bucket, N, M) -> split(Bucket, "--")[1]
+     * 
+ * + *

Note: Delimiter checks (s3expressAvailabilityZoneDelim) are not currently transformed because they're part + * of a complex fallback structure, and changing them breaks control flow. Possibly something we can improve, or + * wait until the upstream rules are optimized. + */ + private Condition transformCondition(Condition cond) { + // Is this a condition fishing for delimiters? + if (cond.getResult().isPresent() + && ID_AZ_ID.equals(cond.getResult().get()) + && cond.getFunction() instanceof Substring + && isSubstringOnBucket((Substring) cond.getFunction())) { + // Replace with split-based extraction: split(Bucket, "--")[1] + GetAttr azExpr = GetAttr.ofExpressions(BUCKET_SPLIT, "[1]"); + return cond.toBuilder().fn(azExpr).build(); + } + + return cond; + } + + private boolean isSubstringOnBucket(Substring substring) { + List args = substring.getArguments(); + if (args.isEmpty()) { + return false; + } + + Expression target = args.get(0); + return target instanceof Reference && ID_BUCKET.equals(((Reference) target).getName()); + } + + // Creates ITE conditions for branchless S3Express variable computation. + private List createIteConditions() { + List conditions = new ArrayList<>(); + conditions.add(createIteAssignment(VAR_FIPS, Expression.getReference(ID_USE_FIPS), FIPS_SUFFIX, EMPTY_SUFFIX)); + conditions.add(createIteAssignment( + VAR_DS, + Expression.getReference(ID_USE_DUAL_STACK), + DS_SUFFIX, + EMPTY_SUFFIX)); + // Auth scheme: sigv4 when session auth disabled, sigv4-s3express otherwise + Expression sessionAuthDisabled = Coalesce.ofExpressions( + Expression.getReference(ID_DISABLE_S3EXPRESS_SESSION_AUTH), + Expression.of(false)); + conditions.add(createIteAssignment(VAR_AUTH, sessionAuthDisabled, AUTH_SIGV4, AUTH_SIGV4_S3EXPRESS)); + return conditions; + } + + // Creates an ITE-based assignment condition. + private Condition createIteAssignment(String varName, Expression condition, String trueValue, String falseValue) { + return Condition.builder() + .fn(Ite.ofStrings(condition, trueValue, falseValue)) + .result(varName) + .build(); + } + + // Rewrites an endpoint rule to use canonical S3Express URLs and auth schemes. + private Rule rewriteEndpoint(EndpointRule rule) { + Endpoint endpoint = rule.getEndpoint(); + Expression urlExpr = endpoint.getUrl(); + + // Extract the raw URL string from the expression (IFF it's a static string, rarely is anything else). + String urlStr = extractUrlString(urlExpr); + if (urlStr == null) { + return rule; + } + + // Check if this is an S3Express endpoint by URL or backend property. + // Note: while `contains("s3express")` is broad and could theoretically match path/query components, + // the subsequent matchUrl() call validates the hostname pattern before any rewriting occurs. + boolean isS3ExpressUrl = urlStr.contains("s3express"); + boolean isS3ExpressBackend = isS3ExpressBackend(endpoint); + + if (!isS3ExpressUrl && !isS3ExpressBackend) { + return rule; + } + + totalS3ExpressCount++; + + // For URL override endpoints (backend=S3Express but URL doesn't match s3express hostname), + // just canonicalize the auth scheme - no URL rewriting needed + if (isS3ExpressBackend && !isS3ExpressUrl) { + // Canonicalize auth scheme to use {_s3e_auth} + Map newProperties = canonicalizeAuthScheme(endpoint.getProperties()); + + if (newProperties == endpoint.getProperties()) { + // No changes needed + return rule; + } + + rewrittenCount++; + + Endpoint newEndpoint = Endpoint.builder() + .url(urlExpr) + .headers(endpoint.getHeaders()) + .properties(newProperties) + .sourceLocation(endpoint.getSourceLocation()) + .build(); + + // Add auth ITE condition for URL override endpoints + List allConditions = new ArrayList<>(rule.getConditions()); + allConditions.add(createAuthIteCondition()); + + return Rule.builder() + .conditions(allConditions) + .endpoint(newEndpoint); + } + + // Standard S3Express URL - match and rewrite + UrlMatchResult match = matchUrl(urlStr); + if (match == null) { + return rule; + } + + rewrittenCount++; + + // Rewrite the URL to use the ITE-assigned variables + String newUrl = match.rewriteUrl(); + + // Canonicalize auth scheme for bucket endpoints (not control plane) + // Control plane always uses sigv4, bucket endpoints vary based on DisableS3ExpressSessionAuth + Map newProperties = endpoint.getProperties(); + if (match instanceof BucketUrlMatchResult) { + newProperties = canonicalizeAuthScheme(endpoint.getProperties()); + } + + // Build the new endpoint with canonicalized URL and properties + Endpoint newEndpoint = Endpoint.builder() + .url(Expression.of(newUrl)) + .headers(endpoint.getHeaders()) + .properties(newProperties) + .sourceLocation(endpoint.getSourceLocation()) + .build(); + + // Add ITE conditions: original conditions first, then ITE conditions at the end. + List allConditions = new ArrayList<>(rule.getConditions()); + allConditions.addAll(createIteConditions()); + + return Rule.builder() + .conditions(allConditions) + .endpoint(newEndpoint); + } + + // Checks if the endpoint has `backend` property set to "S3Express". + private boolean isS3ExpressBackend(Endpoint endpoint) { + Literal backend = endpoint.getProperties().get(ID_BACKEND); + if (backend == null) { + return false; + } + + return backend.asStringLiteral() + .filter(Template::isStatic) + .map(t -> "S3Express".equalsIgnoreCase(t.expectLiteral())) + .orElse(false); + } + + // Creates just the auth ITE condition for URL override endpoints. + private Condition createAuthIteCondition() { + // `DisableS3ExpressSessionAuth` is nullable, so we need to coalesce it to have a false default. Fix upstream? + Expression isSessionAuthDisabled = Coalesce.ofExpressions( + Expression.getReference(ID_DISABLE_S3EXPRESS_SESSION_AUTH), + Expression.of(false)); + return createIteAssignment(VAR_AUTH, isSessionAuthDisabled, AUTH_SIGV4, AUTH_SIGV4_S3EXPRESS); + } + + // Canonicalizes the authScheme name in endpoint properties to use the ITE variable. + private Map canonicalizeAuthScheme(Map properties) { + Literal authSchemes = properties.get(ID_AUTH_SCHEMES); + if (authSchemes == null) { + return properties; + } + + List schemes = authSchemes.asTupleLiteral().orElse(null); + if (schemes == null || schemes.isEmpty()) { + return properties; + } + + // Rewrite each auth scheme's name field + List newSchemes = new ArrayList<>(); + for (Literal scheme : schemes) { + Map record = scheme.asRecordLiteral().orElse(null); + if (record == null) { + // Auth is always a record, but maybe that changes in the future, so pass it through. + newSchemes.add(scheme); + continue; + } + + Literal nameLiteral = record.get(ID_NAME); + if (nameLiteral == null) { + // "name" should always be set, but pass through if not. + newSchemes.add(scheme); + continue; + } + + // Only transform string literals we recognize. + String name = nameLiteral.asStringLiteral() + .filter(Template::isStatic) + .map(Template::expectLiteral) + .orElse(null); + + // Only rewrite if it's one of the S3Express auth schemes + if (AUTH_SIGV4.equals(name) || AUTH_SIGV4_S3EXPRESS.equals(name)) { + Map newRecord = new LinkedHashMap<>(record); + newRecord.put(ID_NAME, AUTH_NAME_LITERAL); + newSchemes.add(Literal.recordLiteral(newRecord)); + } else { + newSchemes.add(scheme); + } + } + + Map newProperties = new LinkedHashMap<>(properties); + newProperties.put(ID_AUTH_SCHEMES, Literal.tupleLiteral(newSchemes)); + return newProperties; + } + + // Extracts the raw URL string from a URL expression. + private String extractUrlString(Expression urlExpr) { + return urlExpr.toNode().asStringNode().map(StringNode::getValue).orElse(null); + } + + // Matches an S3Express URL and returns the pattern match info. Tries to match in most specific order. + private UrlMatchResult matchUrl(String url) { + Matcher m; + + // First try control plane patterns (no AZ) since these are more specific + m = S3EXPRESS_CONTROL_FIPS_DS.matcher(url); + if (m.find()) { + return new ControlPlaneUrlMatchResult(url, m); + } + + m = S3EXPRESS_CONTROL_FIPS.matcher(url); + if (m.find()) { + return new ControlPlaneUrlMatchResult(url, m); + } + + m = S3EXPRESS_CONTROL_DS.matcher(url); + if (m.find()) { + return new ControlPlaneUrlMatchResult(url, m); + } + + m = S3EXPRESS_CONTROL_PLAIN.matcher(url); + if (m.find()) { + return new ControlPlaneUrlMatchResult(url, m); + } + + // Next, try bucket endpoint patterns (with AZ) + m = S3EXPRESS_FIPS_DS.matcher(url); + if (m.find()) { + return new BucketUrlMatchResult(url, m); + } + + m = S3EXPRESS_FIPS.matcher(url); + if (m.find()) { + return new BucketUrlMatchResult(url, m); + } + + m = S3EXPRESS_DS.matcher(url); + if (m.find()) { + return new BucketUrlMatchResult(url, m); + } + + m = S3EXPRESS_PLAIN.matcher(url); + if (m.find()) { + return new BucketUrlMatchResult(url, m); + } + + return null; + } + + /** + * Result of matching an S3Express URL pattern. + */ + private abstract static class UrlMatchResult { + protected final String prefix; + + UrlMatchResult(String prefix) { + this.prefix = prefix; + } + + abstract String rewriteUrl(); + } + + /** + * Match result for bucket endpoints (with AZ): {prefix}s3express{fips}-{AZ}{ds}.{region} + */ + private static final class BucketUrlMatchResult extends UrlMatchResult { + private final String s3express; + private final String az; + private final String regionSuffix; + + BucketUrlMatchResult(String url, Matcher m) { + super(url.substring(0, m.start())); + this.s3express = m.group(1); + this.az = m.group(2); + this.regionSuffix = m.group(3); + } + + @Override + String rewriteUrl() { + return String.format("%s%s{%s}-%s{%s}.%s", prefix, s3express, VAR_FIPS, az, VAR_DS, regionSuffix); + } + } + + /** + * Match result for control plane endpoints (no AZ): {prefix}s3express-control{fips}{ds}.{region} + */ + private static final class ControlPlaneUrlMatchResult extends UrlMatchResult { + private final String s3expressControl; + private final String regionSuffix; + + ControlPlaneUrlMatchResult(String url, Matcher m) { + super(url.substring(0, m.start())); + this.s3expressControl = m.group(1); + this.regionSuffix = m.group(2); + } + + @Override + String rewriteUrl() { + return String.format("%s%s{%s}{%s}.%s", prefix, s3expressControl, VAR_FIPS, VAR_DS, regionSuffix); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java index 9e0dfb0fd39..2dda13db308 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java @@ -11,6 +11,7 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsValidHostLabel; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.ParseUrl; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Split; @@ -43,6 +44,7 @@ public List getLibraryFunctions() { Split.getDefinition(), StringEquals.getDefinition(), Substring.getDefinition(), + Ite.getDefinition(), UriEncode.getDefinition()); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java index 6e8d70a8771..efc82026a30 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java @@ -215,6 +215,12 @@ public Value visitStringEquals(Expression left, Expression right) { .equals(right.accept(this).expectStringValue())); } + @Override + public Value visitIte(Expression condition, Expression trueValue, Expression falseValue) { + boolean cond = condition.accept(this).expectBooleanValue().getValue(); + return cond ? trueValue.accept(this) : falseValue.accept(this); + } + @Override public Value visitGetAttr(GetAttr getAttr) { return getAttr.evaluate(getAttr.getTarget().accept(this)); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java index 1557b529b52..b4bbc93868f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java @@ -4,10 +4,12 @@ */ package software.amazon.smithy.rulesengine.language.syntax.expressions; +import java.util.Arrays; import java.util.List; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionDefinition; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -86,6 +88,18 @@ default R visitCoalesce(List expressions) { */ R visitStringEquals(Expression left, Expression right); + /** + * Visits an if-then-else (ITE) function. + * + * @param condition the boolean condition expression. + * @param trueValue the value if condition is true. + * @param falseValue the value if condition is false. + * @return the value from the visitor. + */ + default R visitIte(Expression condition, Expression trueValue, Expression falseValue) { + return visitLibraryFunction(Ite.getDefinition(), Arrays.asList(condition, trueValue, falseValue)); + } + /** * Visits a library function. * @@ -138,6 +152,11 @@ public R visitStringEquals(Expression left, Expression right) { return getDefault(); } + @Override + public R visitIte(Expression condition, Expression trueValue, Expression falseValue) { + return getDefault(); + } + @Override public R visitLibraryFunction(FunctionDefinition fn, List args) { return getDefault(); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java index 931e5d9f9dd..039d461350e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Optional; import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.error.InnerParseError; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; @@ -81,10 +82,10 @@ public R accept(ExpressionVisitor visitor) { } @Override - public Type typeCheck(Scope scope) { + protected Type typeCheckLocal(Scope scope) throws InnerParseError { List args = getArguments(); if (args.size() < 2) { - throw new IllegalArgumentException("Coalesce requires at least 2 arguments, got " + args.size()); + throw new InnerParseError("Coalesce requires at least 2 arguments, got " + args.size()); } // Get the first argument's type as the baseline @@ -98,7 +99,7 @@ public Type typeCheck(Scope scope) { Type innerType = getInnerType(argType); if (!innerType.equals(baseInnerType)) { - throw new IllegalArgumentException(String.format( + throw new InnerParseError(String.format( "Type mismatch in coalesce at argument %d: expected %s but got %s", i + 1, baseInnerType, diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java new file mode 100644 index 00000000000..90acc71da01 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java @@ -0,0 +1,175 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.syntax.expressions.functions; + +import java.util.Arrays; +import java.util.List; +import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.error.InnerParseError; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.ToExpression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.ExpressionVisitor; +import software.amazon.smithy.utils.SmithyUnstableApi; + +/** + * An if-then-else (ITE) function that returns one of two values based on a boolean condition. + * + *

This function is critical for avoiding SSA (Static Single Assignment) fragmentation in BDD compilation. + * By computing conditional values atomically without branching, it prevents the graph explosion that occurs when + * boolean flags like UseFips or UseDualStack create divergent paths with distinct variable identities. + * + *

Semantics: {@code ite(condition, trueValue, falseValue)} + *

    + *
  • If condition is true, returns trueValue
  • + *
  • If condition is false, returns falseValue
  • + *
  • The condition must be a non-optional boolean (use coalesce to provide a default if needed)
  • + *
+ * + *

Type checking rules (least upper bound of nullability): + *

    + *
  • {@code ite(Boolean, T, T) => T} - both non-optional, result is non-optional
  • + *
  • {@code ite(Boolean, T, Optional) => Optional} - any optional makes result optional
  • + *
  • {@code ite(Boolean, Optional, T) => Optional} - any optional makes result optional
  • + *
  • {@code ite(Boolean, Optional, Optional) => Optional} - both optional, result is optional
  • + *
+ * + *

Available since: rules engine 1.1. + */ +@SmithyUnstableApi +public final class Ite extends LibraryFunction { + public static final String ID = "ite"; + private static final Definition DEFINITION = new Definition(); + + private Ite(FunctionNode functionNode) { + super(DEFINITION, functionNode); + } + + /** + * Gets the {@link FunctionDefinition} implementation. + * + * @return the function definition. + */ + public static Definition getDefinition() { + return DEFINITION; + } + + /** + * Creates a {@link Ite} function from the given expressions. + * + * @param condition the boolean condition to evaluate + * @param trueValue the value to return if condition is true + * @param falseValue the value to return if condition is false + * @return The resulting {@link Ite} function. + */ + public static Ite ofExpressions(ToExpression condition, ToExpression trueValue, ToExpression falseValue) { + return DEFINITION.createFunction(FunctionNode.ofExpressions(ID, condition, trueValue, falseValue)); + } + + /** + * Creates a {@link Ite} function with a reference condition and string values. + * + * @param conditionRef the reference to a boolean parameter + * @param trueValue the string value if condition is true + * @param falseValue the string value if condition is false + * @return The resulting {@link Ite} function. + */ + public static Ite ofStrings(ToExpression conditionRef, String trueValue, String falseValue) { + return ofExpressions(conditionRef, Expression.of(trueValue), Expression.of(falseValue)); + } + + @Override + public RulesVersion availableSince() { + return RulesVersion.V1_1; + } + + @Override + public R accept(ExpressionVisitor visitor) { + return visitor.visitIte(getArguments().get(0), getArguments().get(1), getArguments().get(2)); + } + + @Override + protected Type typeCheckLocal(Scope scope) throws InnerParseError { + List args = getArguments(); + if (args.size() != 3) { + throw new InnerParseError("ITE requires exactly 3 arguments, got " + args.size()); + } + + // Check condition is a boolean (non-optional) + Type conditionType = args.get(0).typeCheck(scope); + if (!conditionType.equals(Type.booleanType())) { + throw new InnerParseError(String.format( + "ITE condition must be a non-optional Boolean, got %s. " + + "Use coalesce to provide a default for optional booleans.", + conditionType)); + } + + // Get trueValue and falseValue types + Type trueType = args.get(1).typeCheck(scope); + Type falseType = args.get(2).typeCheck(scope); + + // Extract base types (unwrap Optional if present) + Type trueBaseType = getInnerType(trueType); + Type falseBaseType = getInnerType(falseType); + + // Base types must match + if (!trueBaseType.equals(falseBaseType)) { + throw new InnerParseError(String.format( + "ITE branches must have the same base type: true branch is %s, false branch is %s", + trueBaseType, + falseBaseType)); + } + + // Result is optional if EITHER branch is optional (least upper bound) + boolean resultIsOptional = (trueType instanceof OptionalType) || (falseType instanceof OptionalType); + return resultIsOptional ? Type.optionalType(trueBaseType) : trueBaseType; + } + + private static Type getInnerType(Type t) { + return (t instanceof OptionalType) ? ((OptionalType) t).inner() : t; + } + + /** + * A {@link FunctionDefinition} for the {@link Ite} function. + */ + public static final class Definition implements FunctionDefinition { + private Definition() {} + + @Override + public String getId() { + return ID; + } + + @Override + public List getArguments() { + // Actual type checking is done in typeCheck override + return Arrays.asList(Type.booleanType(), Type.anyType(), Type.anyType()); + } + + @Override + public Type getReturnType() { + // Actual return type is computed in typeCheck override + return Type.anyType(); + } + + @Override + public Value evaluate(List arguments) { + throw new UnsupportedOperationException("ITE evaluation is handled by ExpressionVisitor"); + } + + @Override + public Ite createFunction(FunctionNode functionNode) { + return new Ite(functionNode); + } + + @Override + public int getCost() { + return 10; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java index d7c76f7feec..be58bb2415d 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java @@ -28,7 +28,7 @@ public T accept(RuleValueVisitor visitor) { @Override protected Type typecheckValue(Scope scope) { - throw new UnsupportedOperationException("NO_MATCH is a sentinel"); + return Type.anyType(); } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java index 11dbcc8b119..ea80fb56fbb 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java @@ -293,7 +293,7 @@ public static final class Builder implements SmithyBuilder { private Cfg cfg; private ConditionCostModel costModel; private ToDoubleFunction trueProbability; - private double maxAllowedGrowth = 0.1; + private double maxAllowedGrowth = 0.08; private int maxRounds = 30; private int topK = 50; @@ -333,7 +333,7 @@ public Builder trueProbability(ToDoubleFunction trueProbability) { } /** - * Sets the maximum allowed node growth as a fraction (default 0.1 or 10%). + * Sets the maximum allowed node growth as a fraction (default 0.08 or 8%). * * @param maxAllowedGrowth maximum growth (0.0 = no growth, 0.1 = 10% growth) * @return the builder diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java index 47666837125..b779eceb4f3 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -6,16 +6,14 @@ import java.util.ArrayList; import java.util.Arrays; -import java.util.Comparator; -import java.util.IdentityHashMap; import java.util.List; -import java.util.Map; import java.util.function.Function; import java.util.logging.Logger; import java.util.stream.Collectors; import java.util.stream.IntStream; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.ConditionCostModel; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; import software.amazon.smithy.rulesengine.logic.cfg.ConditionDependencyGraph; import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; @@ -34,14 +32,21 @@ public final class SiftingOptimization implements Function { private static final Logger LOGGER = Logger.getLogger(SiftingOptimization.class.getName()); - // When to use a parallel stream private static final int PARALLEL_THRESHOLD = 7; + // Early termination: number of passes to track for plateau detection + private static final int PLATEAU_HISTORY_SIZE = 3; + private static final double PLATEAU_THRESHOLD = 0.5; + // Thread-local BDD builders to avoid allocation overhead private final ThreadLocal threadBuilder = ThreadLocal.withInitial(BddBuilder::new); private final Cfg cfg; private final ConditionDependencyGraph dependencyGraph; + private final ConditionCostModel costModel = ConditionCostModel.createDefault();; + + // Reusable cost estimator, created once per optimization run + private BddCostEstimator costEstimator; // Tiered optimization settings private final int coarseMinNodes; @@ -54,7 +59,7 @@ public final class SiftingOptimization implements Function= HIGH_THRESHOLD) { + sampleRate = Math.max(1, sampleRate - 1); + maxPositions = Math.min(base.maxPositions * 2, maxPositions + 5); + nearbyRadius = Math.min(base.nearbyRadius + 6, nearbyRadius + 2); + bonusPasses = Math.min(bonusPasses + 2, 6); + return true; + } else if (reductionPercent < LOW_THRESHOLD) { + sampleRate = Math.min(base.sampleRate * 2, sampleRate + 2); + maxPositions = Math.max(base.maxPositions / 2, maxPositions - 3); + nearbyRadius = Math.max(0, nearbyRadius - 2); + bonusPasses = Math.max(0, bonusPasses - 2); + } + return false; + } + } + private SiftingOptimization(Builder builder) { this.cfg = SmithyBuilder.requiredState("cfg", builder.cfg); this.coarseMinNodes = builder.coarseMinNodes; @@ -108,382 +151,477 @@ public EndpointBddTrait apply(EndpointBddTrait trait) { private EndpointBddTrait doApply(EndpointBddTrait trait) { LOGGER.info("Starting BDD sifting optimization"); long startTime = System.currentTimeMillis(); - OptimizationState state = initializeOptimization(trait); + State state = initializeOptimization(trait); LOGGER.info(String.format("Initial size: %d nodes", state.initialSize)); - state = runOptimizationStage("Coarse", state, OptimizationEffort.COARSE, coarseMinNodes, coarseMaxPasses, 4.0); - state = runOptimizationStage("Medium", state, OptimizationEffort.MEDIUM, mediumMinNodes, mediumMaxPasses, 1.5); + // Create cost estimator once for the entire optimization run + this.costEstimator = new BddCostEstimator(state.orderView, costModel, null); + + runOptimizationStage("Coarse", state, OptimizationEffort.COARSE, coarseMinNodes, coarseMaxPasses, 4.0); + runOptimizationStage("Medium", state, OptimizationEffort.MEDIUM, mediumMinNodes, mediumMaxPasses, 1.5); if (state.currentSize <= granularMaxNodes) { - state = runOptimizationStage("Granular", state, OptimizationEffort.GRANULAR, 0, granularMaxPasses, 0.0); - } else { - LOGGER.info("Skipping granular stage - too large"); + runOptimizationStage("Granular", state, OptimizationEffort.GRANULAR, 0, granularMaxPasses, 0.0); } - state = runAdjacentSwaps(state); + runBlockMoves(state); + runAdjacentSwaps(state); double totalTimeInSeconds = (System.currentTimeMillis() - startTime) / 1000.0; - if (state.bestSize >= state.initialSize) { + if (state.currentSize >= state.initialSize) { LOGGER.info(String.format("No improvements found in %fs", totalTimeInSeconds)); return trait; } LOGGER.info(String.format("Optimization complete: %d -> %d nodes (%.1f%% total reduction) in %fs", state.initialSize, - state.bestSize, - (1.0 - (double) state.bestSize / state.initialSize) * 100, + state.currentSize, + (1.0 - (double) state.currentSize / state.initialSize) * 100, totalTimeInSeconds)); return trait.toBuilder().conditions(state.orderView).results(state.results).bdd(state.bestBdd).build(); } - private OptimizationState initializeOptimization(EndpointBddTrait trait) { - // Use the trait's existing ordering as the starting point + private State initializeOptimization(EndpointBddTrait trait) { List initialOrder = new ArrayList<>(trait.getConditions()); Condition[] order = initialOrder.toArray(new Condition[0]); List orderView = Arrays.asList(order); Bdd bdd = trait.getBdd(); int initialSize = bdd.getNodeCount() - 1; - return new OptimizationState(order, orderView, bdd, initialSize, initialSize, trait.getResults()); + return new State(order, orderView, bdd, initialSize, trait.getResults()); } - private OptimizationState runOptimizationStage( + private void runOptimizationStage( String stageName, - OptimizationState state, + State state, OptimizationEffort effort, - int targetNodeCount, + int targetNodes, int maxPasses, - double minReductionPercent + double minReduction ) { - if (targetNodeCount > 0 && state.currentSize <= targetNodeCount) { - return state; + if (targetNodes > 0 && state.currentSize <= targetNodes) { + return; } - LOGGER.info(String.format("Stage: %s optimization (%d nodes%s)", - stageName, - state.currentSize, - targetNodeCount > 0 ? String.format(", target < %d", targetNodeCount) : "")); + LOGGER.info(String.format("Stage: %s (%d nodes)", stageName, state.currentSize)); + + AdaptiveEffort ae = new AdaptiveEffort(effort); + double[] history = new double[PLATEAU_HISTORY_SIZE]; + int historyIdx = 0, consecutiveLow = 0; + + for (int pass = 1; pass <= maxPasses + ae.bonusPasses; pass++) { + if (targetNodes > 0 && state.currentSize <= targetNodes) { + break; + } - OptimizationState currentState = state; - for (int pass = 1; pass <= maxPasses; pass++) { - if (targetNodeCount > 0 && currentState.currentSize <= targetNodeCount) { + int startSize = state.currentSize; + PassContext result = runPass(state, ae); + if (result.improvements == 0) { break; } - int passStartSize = currentState.currentSize; - OptimizationResult result = runPass(currentState, effort); - if (result.improved) { - currentState = currentState.withResult(result.bdd, result.size, result.results); - double reduction = (1.0 - (double) result.size / passStartSize) * 100; - LOGGER.fine(String.format("%s pass %d: %d -> %d nodes (%.1f%% reduction)", - stageName, - pass, - passStartSize, - result.size, - reduction)); - if (minReductionPercent > 0 && reduction < minReductionPercent) { - LOGGER.fine(String.format("%s optimization yielding diminishing returns", stageName)); + state.update(result.bestBdd, result.bestSize, result.bestResults); + double reduction = (1.0 - (double) result.bestSize / startSize) * 100; + + history[historyIdx++ % PLATEAU_HISTORY_SIZE] = reduction; + if (historyIdx >= PLATEAU_HISTORY_SIZE) { + boolean plateau = true; + for (double r : history) { + if (r >= PLATEAU_THRESHOLD) { + plateau = false; + break; + } + } + if (plateau) { break; } - } else { - LOGGER.fine(String.format("%s pass %d found no improvements", stageName, pass)); + } + + consecutiveLow = ae.adapt(reduction) ? 0 : (reduction < 2.0 ? consecutiveLow + 1 : 0); + if (consecutiveLow >= 2 || (minReduction > 0 && reduction < minReduction)) { break; } } - - return currentState; } - private OptimizationState runAdjacentSwaps(OptimizationState state) { + private void runBlockMoves(State state) { if (state.currentSize > granularMaxNodes) { - return state; + return; } + LOGGER.info("Running block moves"); - LOGGER.info("Running adjacent swaps optimization"); - OptimizationState currentState = state; - - // Run multiple sweeps until no improvement - for (int sweep = 1; sweep <= 3; sweep++) { - OptimizationContext context = new OptimizationContext(currentState, dependencyGraph); - int startSize = currentState.currentSize; + List> blocks = new ArrayList<>(); + for (List b : findDependencyBlocks(state.orderView)) { + if (b.size() >= 2 && b.size() <= 5) { + blocks.add(b); + } + } - for (int i = 0; i < currentState.order.length - 1; i++) { - // Adjacent swap requires both elements to be able to occupy each other's positions - if (context.constraints.canMove(i, i + 1) && context.constraints.canMove(i + 1, i)) { - BddCompilerSupport.move(currentState.order, i, i + 1); - BddCompilerSupport.BddCompilationResult compilationResult = - BddCompilerSupport.compile(cfg, currentState.orderView, threadBuilder.get()); - int swappedSize = compilationResult.bdd.getNodeCount() - 1; - if (swappedSize < context.bestSize) { - context = context.withImprovement( - new PositionResult(i + 1, - swappedSize, - compilationResult.bdd, - compilationResult.results)); - } else { - BddCompilerSupport.move(currentState.order, i + 1, i); // Swap back - } - } + for (List block : blocks) { + PassContext ctx = new PassContext(state, dependencyGraph); + Result r = tryBlockMove(block, ctx); + if (r != null && r.size < ctx.bestSize) { + state.update(r.bdd, r.size, r.results); } + } + } + + private List> findDependencyBlocks(List ordering) { + List> blocks = new ArrayList<>(); + if (ordering.isEmpty()) { + return blocks; + } - if (context.improvements > 0) { - currentState = currentState.withResult(context.bestBdd, context.bestSize, context.bestResults); - LOGGER.fine(String.format("Adjacent swaps sweep %d: %d -> %d nodes", - sweep, - startSize, - context.bestSize)); + List curr = new ArrayList<>(); + curr.add(0); + for (int i = 1; i < ordering.size(); i++) { + if (dependencyGraph.getDependencies(ordering.get(i)).contains(ordering.get(i - 1))) { + curr.add(i); } else { - break; + if (curr.size() >= 2) { + blocks.add(curr); + } + curr = new ArrayList<>(); + curr.add(i); } } - return currentState; + if (curr.size() >= 2) { + blocks.add(curr); + } + + return blocks; } - private OptimizationResult runPass(OptimizationState state, OptimizationEffort effort) { - OptimizationContext context = new OptimizationContext(state, dependencyGraph); + private Result tryBlockMove(List block, PassContext ctx) { + int blockStart = block.get(0), blockEnd = block.get(block.size() - 1), blockSize = block.size(); - List selectedConditions = IntStream.range(0, state.orderView.size()) - .filter(i -> i % effort.sampleRate == 0) - .mapToObj(state.orderView::get) - .collect(Collectors.toList()); + // Compute valid range considering all block members' constraints + int minPos = 0, maxPos = ctx.order.length - blockSize; + for (int idx : block) { + int offset = idx - blockStart; + minPos = Math.max(minPos, ctx.constraints.getMinValidPosition(idx) - offset); + maxPos = Math.min(maxPos, ctx.constraints.getMaxValidPosition(idx) - offset); + } + + if (minPos >= maxPos) { + return null; + } - for (Condition condition : selectedConditions) { - Integer varIdx = context.liveIndex.get(condition); - if (varIdx == null) { + // Try a few strategic positions: min, max, mid + int[] targets = {minPos, maxPos, minPos + (maxPos - minPos) / 2}; + Result best = null; + + for (int target : targets) { + if (target == blockStart) { continue; } - List positions = getStrategicPositions(varIdx, context.constraints, effort); - if (positions.isEmpty()) { + Condition[] candidate = ctx.order.clone(); + moveBlock(candidate, blockStart, blockEnd, target); + List candidateList = Arrays.asList(candidate); + + // Validate constraints + ConditionDependencyGraph.OrderConstraints nc = dependencyGraph.createOrderConstraints(candidateList); + boolean valid = true; + for (int j = 0; j < candidate.length; j++) { + if (nc.getMinValidPosition(j) > j || nc.getMaxValidPosition(j) < j) { + valid = false; + break; + } + } + + if (!valid) { continue; } - context = tryImprovePosition(context, varIdx, positions); + BddCompilerSupport.BddCompilationResult cr = + BddCompilerSupport.compile(cfg, candidateList, threadBuilder.get()); + int size = cr.bdd.getNodeCount() - 1; + double cost = computeCost(cr.bdd, candidateList); + if (best == null || size < best.size || (size == best.size && cost < best.cost)) { + best = new Result(target, size, cost, cr.bdd, cr.results); + } } + return best; + } + + /** + * Moves a contiguous block of elements from [start, end] to begin at targetStart. + */ + private static void moveBlock(Condition[] order, int start, int end, int targetStart) { + if (targetStart == start) { + return; + } + + int blockSize = end - start + 1; + Condition[] block = new Condition[blockSize]; + System.arraycopy(order, start, block, 0, blockSize); - return context.toResult(); + if (targetStart < start) { + // Move block earlier: shift elements [targetStart, start) to the right + System.arraycopy(order, targetStart, order, targetStart + blockSize, start - targetStart); + System.arraycopy(block, 0, order, targetStart, blockSize); + } else { + // Move block later: shift elements (end, targetStart + blockSize) to the left + int shiftStart = end + 1; + int shiftEnd = targetStart + blockSize; + if (shiftEnd > order.length) { + shiftEnd = order.length; + } + System.arraycopy(order, shiftStart, order, start, shiftEnd - shiftStart); + System.arraycopy(block, 0, order, targetStart, blockSize); + } } - private OptimizationContext tryImprovePosition(OptimizationContext context, int varIdx, List positions) { - PositionResult best = findBestPosition(positions, context, varIdx); - if (best != null && best.count <= context.bestSize) { // Accept ties - BddCompilerSupport.move(context.order, varIdx, best.position); - return context.withImprovement(best); + private void runAdjacentSwaps(State state) { + if (state.currentSize > granularMaxNodes) { + return; } - return context; + for (int sweep = 0; sweep < 3; sweep++) { + PassContext ctx = new PassContext(state, dependencyGraph); + for (int i = 0; i < state.order.length - 1; i++) { + // Adjacent swap requires both elements to be able to occupy each other's positions + if (ctx.constraints.canMove(i, i + 1) && ctx.constraints.canMove(i + 1, i)) { + BddCompilerSupport.move(state.order, i, i + 1); + BddCompilerSupport.BddCompilationResult cr = BddCompilerSupport.compile( + cfg, + state.orderView, + threadBuilder.get()); + int size = cr.bdd.getNodeCount() - 1; + if (size < ctx.bestSize) { + ctx.recordImprovement(new Result(i + 1, size, cr.bdd, cr.results, null)); + } else { + BddCompilerSupport.move(state.order, i + 1, i); + } + } + } + if (ctx.improvements == 0) { + break; + } + state.update(ctx.bestBdd, ctx.bestSize, ctx.bestResults); + } + } + + private PassContext runPass(State state, AdaptiveEffort effort) { + PassContext ctx = new PassContext(state, dependencyGraph); + int[] nodeCounts = computeNodeCountsPerVariable(state.bestBdd); + int[] selectedIndices = selectConditionsByPriority(state.orderView.size(), nodeCounts, effort.sampleRate); + + for (int varIdx : selectedIndices) { + List positions = getStrategicPositions(varIdx, ctx.constraints, effort, state.orderView.size()); + if (positions.isEmpty()) { + continue; + } + Result best = findBestPosition(positions, ctx, varIdx); + if (best != null && best.size <= ctx.bestSize) { + BddCompilerSupport.move(ctx.order, varIdx, best.position); + ctx.recordImprovement(best); + } + } + return ctx; + } + + /** + * Computes the number of BDD nodes testing each variable. + */ + private static int[] computeNodeCountsPerVariable(Bdd bdd) { + int[] counts = new int[bdd.getConditionCount()]; + for (int i = 0; i < bdd.getNodeCount(); i++) { + int v = bdd.getVariable(i); + if (v >= 0 && v < counts.length) { + counts[v]++; + } + } + return counts; + } + + private static int[] selectConditionsByPriority(int n, int[] nodeCounts, int sampleRate) { + int[] indices = IntStream.range(0, n) + .boxed() + .sorted((a, b) -> Integer.compare(nodeCounts[b], nodeCounts[a])) + .mapToInt(i -> i) + .toArray(); + return sampleRate <= 1 ? indices : Arrays.copyOf(indices, Math.max(1, n / sampleRate)); } - private PositionResult findBestPosition(List positions, OptimizationContext ctx, int varIdx) { - return (positions.size() > PARALLEL_THRESHOLD ? positions.parallelStream() : positions.stream()) + /** Two-pass position finder: compile candidates, then cost-break ties among min-size. */ + private Result findBestPosition(List positions, PassContext ctx, int varIdx) { + // First pass: compile all candidates + List candidates = (positions.size() > PARALLEL_THRESHOLD + ? positions.parallelStream() + : positions.stream()) .map(pos -> { Condition[] order = ctx.order.clone(); BddCompilerSupport.move(order, varIdx, pos); + List orderList = Arrays.asList(order); BddCompilerSupport.BddCompilationResult cr = - BddCompilerSupport.compile(cfg, Arrays.asList(order), threadBuilder.get()); - return new PositionResult(pos, cr.bdd.getNodeCount() - 1, cr.bdd, cr.results); + BddCompilerSupport.compile(cfg, orderList, threadBuilder.get()); + return new Result(pos, cr.bdd.getNodeCount() - 1, cr.bdd, cr.results, orderList); }) - .filter(pr -> pr.count <= ctx.bestSize) - .min(Comparator.comparingInt((PositionResult pr) -> pr.count).thenComparingInt(pr -> pr.position)) - .orElse(null); + .filter(c -> c.size <= ctx.bestSize) + .collect(Collectors.toList()); + + if (candidates.isEmpty()) { + return null; + } + + // Second pass: among min-size candidates, pick lowest cost + int minSize = Integer.MAX_VALUE; + for (Result c : candidates) { + if (c.size < minSize) { + minSize = c.size; + } + } + + Result best = null; + for (Result c : candidates) { + if (c.size == minSize) { + double cost = computeCost(c.bdd, c.orderList); + if (best == null || cost < best.cost || (cost == best.cost && c.position < best.position)) { + best = new Result(c.position, c.size, cost, c.bdd, c.results); + } + } + } + return best; + } + + private double computeCost(Bdd bdd, List ordering) { + return costEstimator.expectedCost(bdd, ordering); } private static List getStrategicPositions( int varIdx, - ConditionDependencyGraph.OrderConstraints constraints, - OptimizationEffort effort + ConditionDependencyGraph.OrderConstraints c, + AdaptiveEffort ae, + int orderSize ) { - int min = constraints.getMinValidPosition(varIdx); - int max = constraints.getMaxValidPosition(varIdx); + int min = c.getMinValidPosition(varIdx); + int max = c.getMaxValidPosition(varIdx); int range = max - min; - if (range <= effort.exhaustiveThreshold) { - List positions = new ArrayList<>(range); + // Exhaustive for small ranges + if (range <= ae.base.exhaustiveThreshold) { + List pos = new ArrayList<>(range); for (int p = min; p < max; p++) { - if (p != varIdx && constraints.canMove(varIdx, p)) { - positions.add(p); + if (p != varIdx && c.canMove(varIdx, p)) { + pos.add(p); } } - return positions; + return pos; } - List positions = new ArrayList<>(effort.maxPositions); + List pos = new ArrayList<>(ae.maxPositions); + boolean[] seen = new boolean[orderSize]; - // Test extremes first since they often yield the best improvements - if (min != varIdx && constraints.canMove(varIdx, min)) { - positions.add(min); - } - if (positions.size() >= effort.maxPositions) { - return positions; + // Extremes + if (min != varIdx && c.canMove(varIdx, min)) { + pos.add(min); + seen[min] = true; } - if (max - 1 != varIdx && constraints.canMove(varIdx, max - 1)) { - positions.add(max - 1); - } - if (positions.size() >= effort.maxPositions) { - return positions; + if (max - 1 != varIdx && c.canMove(varIdx, max - 1)) { + pos.add(max - 1); + seen[max - 1] = true; } - // Test local moves that preserve relative ordering with neighbors - for (int offset = -effort.nearbyRadius; offset <= effort.nearbyRadius; offset++) { - if (offset != 0) { - if (positions.size() >= effort.maxPositions) { - return positions; - } - int p = varIdx + offset; - if (p >= min && p < max && !positions.contains(p) && constraints.canMove(varIdx, p)) { - positions.add(p); - } + // Global sampling + int step = Math.max(1, range / Math.min(15, ae.maxPositions / 2)); + for (int p = min + step; p < max - step && pos.size() < ae.maxPositions; p += step) { + if (p != varIdx && !seen[p] && c.canMove(varIdx, p)) { + pos.add(p); + seen[p] = true; } } - // Sample intermediate positions to find global improvements - if (positions.size() >= effort.maxPositions) { - return positions; - } - - int maxSamples = Math.min(15, effort.maxPositions / 2); - int samples = Math.min(maxSamples, Math.max(2, range / 4)); - int step = Math.max(1, range / samples); - - for (int p = min + step; p < max - step && positions.size() < effort.maxPositions; p += step) { - if (p != varIdx && !positions.contains(p) && constraints.canMove(varIdx, p)) { - positions.add(p); + // Local neighborhood + for (int off = -ae.nearbyRadius; off <= ae.nearbyRadius && pos.size() < ae.maxPositions; off++) { + int p = varIdx + off; + if (off != 0 && p >= min && p < max && !seen[p] && c.canMove(varIdx, p)) { + pos.add(p); + seen[p] = true; } } - return positions; - } - - private static Map rebuildIndex(List orderView) { - Map index = new IdentityHashMap<>(); - for (int i = 0; i < orderView.size(); i++) { - index.put(orderView.get(i), i); - } - return index; + return pos; } - // Helper class to track optimization context within a pass - private static final class OptimizationContext { + /** Mutable context for tracking optimization progress within a pass. */ + private static final class PassContext { final Condition[] order; final List orderView; final ConditionDependencyGraph dependencyGraph; - final ConditionDependencyGraph.OrderConstraints constraints; - final Map liveIndex; - final Bdd bestBdd; - final int bestSize; - final List bestResults; - final int improvements; - - OptimizationContext(OptimizationState state, ConditionDependencyGraph dependencyGraph) { + ConditionDependencyGraph.OrderConstraints constraints; + Bdd bestBdd; + int bestSize; + List bestResults; + int improvements; + + PassContext(State state, ConditionDependencyGraph dependencyGraph) { this.order = state.order; this.orderView = state.orderView; - this.dependencyGraph = dependencyGraph; - this.constraints = dependencyGraph.createOrderConstraints(orderView); - this.liveIndex = rebuildIndex(orderView); - this.bestBdd = null; this.bestSize = state.currentSize; - this.bestResults = null; - this.improvements = 0; - } - - private OptimizationContext( - Condition[] order, - List orderView, - ConditionDependencyGraph dependencyGraph, - ConditionDependencyGraph.OrderConstraints constraints, - Map liveIndex, - Bdd bestBdd, - int bestSize, - List bestResults, - int improvements - ) { - this.order = order; - this.orderView = orderView; this.dependencyGraph = dependencyGraph; - this.constraints = constraints; - this.liveIndex = liveIndex; - this.bestBdd = bestBdd; - this.bestSize = bestSize; - this.bestResults = bestResults; - this.improvements = improvements; - } - - OptimizationContext withImprovement(PositionResult result) { - ConditionDependencyGraph.OrderConstraints newConstraints = - dependencyGraph.createOrderConstraints(orderView); - Map newIndex = rebuildIndex(orderView); - return new OptimizationContext(order, - orderView, - dependencyGraph, - newConstraints, - newIndex, - result.bdd, - result.count, - result.results, - improvements + 1); - } - - OptimizationResult toResult() { - return new OptimizationResult(bestBdd, bestSize, improvements > 0, bestResults); + this.constraints = dependencyGraph.createOrderConstraints(orderView); + } + + void recordImprovement(Result result) { + this.bestBdd = result.bdd; + this.bestSize = result.size; + this.bestResults = result.results; + this.constraints = dependencyGraph.createOrderConstraints(orderView); + this.improvements++; } } - private static final class PositionResult { + /** Result holder for BDD compilation with optional position/cost metadata. */ + private static final class Result { final int position; - final int count; + final int size; + final double cost; final Bdd bdd; final List results; + final List orderList; // For deferred cost computation - PositionResult(int position, int count, Bdd bdd, List results) { - this.position = position; - this.count = count; - this.bdd = bdd; - this.results = results; + Result(int position, int size, Bdd bdd, List results, List orderList) { + this(position, size, Double.MAX_VALUE, bdd, results, orderList); } - } - private static final class OptimizationResult { - final Bdd bdd; - final int size; - final boolean improved; - final List results; + Result(int position, int size, double cost, Bdd bdd, List results) { + this(position, size, cost, bdd, results, null); + } - OptimizationResult(Bdd bdd, int size, boolean improved, List results) { - this.bdd = bdd; + Result(int position, int size, double cost, Bdd bdd, List results, List orderList) { + this.position = position; this.size = size; - this.improved = improved; + this.cost = cost; + this.bdd = bdd; this.results = results; + this.orderList = orderList; } } - private static final class OptimizationState { + /** Tracks overall optimization state across stages. */ + private static final class State { final Condition[] order; final List orderView; - final Bdd bestBdd; - final int currentSize; - final int bestSize; final int initialSize; - final List results; + Bdd bestBdd; + int currentSize; + List results; - OptimizationState( - Condition[] order, - List orderView, - Bdd bestBdd, - int currentSize, - int initialSize, - List results - ) { + State(Condition[] order, List orderView, Bdd bdd, int size, List results) { this.order = order; this.orderView = orderView; - this.bestBdd = bestBdd; - this.currentSize = currentSize; - this.bestSize = currentSize; - this.initialSize = initialSize; + this.bestBdd = bdd; + this.currentSize = size; + this.initialSize = size; this.results = results; } - OptimizationState withResult(Bdd newBdd, int newSize, List newResults) { - return new OptimizationState(order, orderView, newBdd, newSize, initialSize, newResults); + void update(Bdd bdd, int size, List results) { + this.bestBdd = bdd; + this.currentSize = size; + this.results = results; } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java index b0846e786b3..dd23a537cc1 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/cfg/VariableConsolidationTransform.java @@ -142,8 +142,15 @@ private void discoverBindingsInRule( String globalVar = globalExpressionToVar.get(canonical); if (globalVar != null && !globalVar.equals(varName)) { // Same expression elsewhere with different name - // Check if consolidation would cause shadowing - if (!wouldCauseShadowing(globalVar, path, ancestorVars)) { + // Only consolidate if both variables follow SSA naming (same base, different suffix) + // This prevents consolidating semantically different variables that happen to have the same value + if (!hasSameBaseName(varName, globalVar)) { + LOGGER.fine( + String.format("Skipping consolidation '%s' -> '%s' (different base names) for: %s", + varName, + globalVar, + canonical)); + } else if (!wouldCauseShadowing(globalVar, path, ancestorVars)) { variableRenameMap.put(varName, globalVar); consolidatedCount++; LOGGER.info(String.format("Consolidating '%s' -> '%s' for: %s", @@ -177,6 +184,42 @@ private void discoverBindingsInRule( } } + /** + * Checks if two variable names have the same base name. + * For SSA-style variables like "foo_1" and "foo_2", the base name is "foo". + * Variables without SSA suffix (like "s3e_fips" and "s3e_ds") are considered + * to have their full name as the base. + */ + private boolean hasSameBaseName(String var1, String var2) { + String base1 = getSsaBaseName(var1); + String base2 = getSsaBaseName(var2); + return base1.equals(base2); + } + + /** + * Extracts the SSA base name from a variable. + * If the variable ends with _N (where N is a number), strips the suffix. + * Otherwise returns the full name. + */ + private String getSsaBaseName(String varName) { + int lastUnderscore = varName.lastIndexOf('_'); + if (lastUnderscore > 0 && lastUnderscore < varName.length() - 1) { + String suffix = varName.substring(lastUnderscore + 1); + // Check if suffix is all digits + boolean allDigits = true; + for (int i = 0; i < suffix.length(); i++) { + if (!Character.isDigit(suffix.charAt(i))) { + allDigits = false; + break; + } + } + if (allDigits) { + return varName.substring(0, lastUnderscore); + } + } + return varName; + } + private boolean wouldCauseShadowing(String varName, String currentPath, Set ancestorVars) { // Check if using this variable name would shadow an ancestor variable if (ancestorVars.contains(varName)) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java index 9cd9d627c36..52ffd2b4f36 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java @@ -23,6 +23,8 @@ import software.amazon.smithy.model.traits.AbstractTraitBuilder; import software.amazon.smithy.model.traits.Trait; import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; @@ -206,6 +208,16 @@ public static EndpointBddTrait fromNode(Node node) { results.add(NoMatchRule.INSTANCE); // Always add no-match at index 0 results.addAll(serializedResults); + // Validate that results have no conditions (all conditions are hoisted into the BDD) + for (int i = 1; i < results.size(); i++) { + Rule rule = results.get(i); + if (!rule.getConditions().isEmpty()) { + throw new IllegalArgumentException( + "BDD result at index " + i + " has conditions, but BDD results must not have conditions. " + + "All conditions should be hoisted into the BDD decision structure."); + } + } + String nodesBase64 = obj.expectStringMember("nodes").getValue(); int nodeCount = obj.expectNumberMember("nodeCount").getValue().intValue(); int rootRef = obj.expectNumberMember("root").getValue().intValue(); @@ -350,7 +362,21 @@ public Builder bdd(Bdd bdd) { @Override public EndpointBddTrait build() { - return new EndpointBddTrait(this); + EndpointBddTrait trait = new EndpointBddTrait(this); + + // Type-check conditions and results so expression.type() works. Note that using a shared scope across + // each check is ok, because BDD evaluation always runs conditions in a fixed order and could in theory + // try every condition for a single path to a result. + Scope scope = new Scope<>(); + trait.getParameters().writeToScope(scope); + for (Condition condition : trait.getConditions()) { + condition.typeCheck(scope); + } + for (Rule result : trait.getResults()) { + result.typeCheck(scope); + } + + return trait; } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java index c79ab5fbeb8..62e06b43742 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java @@ -124,7 +124,7 @@ private String validateAuthSchemeName( FromSourceLocation sourceLocation ) { Literal nameLiteral = authScheme.get(NAME); - if (nameLiteral == null) { + if (nameLiteral == null || nameLiteral.asStringLiteral().isEmpty()) { events.add(error(service, sourceLocation, String.format( @@ -133,13 +133,14 @@ private String validateAuthSchemeName( return null; } - String name = nameLiteral.asStringLiteral().map(s -> s.expectLiteral()).orElse(null); + // Try to get the name as a literal string. If the template contains variables + // (e.g., from branchless transforms like "{s3e_auth}"), we can't statically validate. + String name = nameLiteral.asStringLiteral() + .filter(t -> t.isStatic()) + .map(t -> t.expectLiteral()) + .orElse(null); if (name == null) { - events.add(error(service, - sourceLocation, - String.format( - "Expected `authSchemes` to have a `name` key with a string value but it did not: `%s`", - authScheme))); + // String literal with template variables - skip static validation return null; } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java index bf6ac4bb9da..7dd1d118883 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java @@ -10,6 +10,7 @@ import java.util.Arrays; import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.error.RuleError; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.Identifier; @@ -135,7 +136,7 @@ void testCoalesceWithIncompatibleTypes() { Scope scope = new Scope<>(); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> coalesce.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> coalesce.typeCheck(scope)); assertTrue(ex.getMessage().contains("Type mismatch in coalesce")); assertTrue(ex.getMessage().contains("argument 2")); } @@ -151,7 +152,7 @@ void testCoalesceWithIncompatibleTypesInMiddle() { Scope scope = new Scope<>(); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> coalesce.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> coalesce.typeCheck(scope)); assertTrue(ex.getMessage().contains("Type mismatch in coalesce")); assertTrue(ex.getMessage().contains("argument 3")); } @@ -160,8 +161,7 @@ void testCoalesceWithIncompatibleTypesInMiddle() { void testCoalesceWithLessThanTwoArguments() { Expression single = Literal.of("only"); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, - () -> Coalesce.ofExpressions(single).typeCheck(new Scope<>())); + RuleError ex = assertThrows(RuleError.class, () -> Coalesce.ofExpressions(single).typeCheck(new Scope<>())); assertTrue(ex.getMessage().contains("at least 2 arguments")); } @@ -215,4 +215,24 @@ void testCoalesceWithBooleanTypes() { assertEquals(Type.booleanType(), resultType); } + + @Test + void testTypeMethodReturnsInferredTypeAfterTypeCheck() { + // Verify that type() returns the correct inferred type after typeCheck() + Expression optional1 = Expression.getReference(Identifier.of("maybeValue1")); + Expression optional2 = Expression.getReference(Identifier.of("maybeValue2")); + Expression definite = Literal.of("default"); + Coalesce coalesce = Coalesce.ofExpressions(optional1, optional2, definite); + + Scope scope = new Scope<>(); + scope.insert("maybeValue1", Type.optionalType(Type.stringType())); + scope.insert("maybeValue2", Type.optionalType(Type.stringType())); + + // Call typeCheck to cache the type + coalesce.typeCheck(scope); + + // Now type() should return the inferred type (non-optional since last arg is definite) + Type cachedType = coalesce.type(); + assertEquals(Type.stringType(), cachedType); + } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java new file mode 100644 index 00000000000..5c57ed7dc6a --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java @@ -0,0 +1,278 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.syntax.functions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.error.RuleError; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; + +public class IteTest { + + @Test + void testIteBothBranchesNonOptionalString() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Literal.of("-fips"); + Expression falseValue = Literal.of(""); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + // Both non-optional String => non-optional String + assertEquals(Type.stringType(), resultType); + } + + @Test + void testIteBothBranchesNonOptionalInteger() { + Expression condition = Expression.getReference(Identifier.of("useNewValue")); + Expression trueValue = Literal.of(100); + Expression falseValue = Literal.of(0); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("useNewValue", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + assertEquals(Type.integerType(), resultType); + } + + @Test + void testIteTrueBranchOptional() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeValue")); + Expression falseValue = Literal.of("default"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeValue", Type.optionalType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + // True branch optional => result is optional + assertEquals(Type.optionalType(Type.stringType()), resultType); + } + + @Test + void testIteFalseBranchOptional() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Literal.of("value"); + Expression falseValue = Expression.getReference(Identifier.of("maybeDefault")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeDefault", Type.optionalType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + // False branch optional => result is optional + assertEquals(Type.optionalType(Type.stringType()), resultType); + } + + @Test + void testIteBothBranchesOptional() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybe1")); + Expression falseValue = Expression.getReference(Identifier.of("maybe2")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybe1", Type.optionalType(Type.stringType())); + scope.insert("maybe2", Type.optionalType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + // Both optional => result is optional + assertEquals(Type.optionalType(Type.stringType()), resultType); + } + + @Test + void testIteWithOfStringsHelper() { + Expression condition = Expression.getReference(Identifier.of("UseFIPS")); + Ite ite = Ite.ofStrings(condition, "-fips", ""); + + Scope scope = new Scope<>(); + scope.insert("UseFIPS", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + // Both literal strings => non-optional String + assertEquals(Type.stringType(), resultType); + } + + @Test + void testIteTypeMismatchBetweenBranches() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Literal.of("string"); + Expression falseValue = Literal.of(42); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("same base type")); + assertTrue(ex.getMessage().contains("true branch")); + assertTrue(ex.getMessage().contains("false branch")); + } + + @Test + void testIteConditionMustBeBoolean() { + Expression condition = Literal.of("not a boolean"); + Expression trueValue = Literal.of("yes"); + Expression falseValue = Literal.of("no"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("non-optional Boolean")); + } + + @Test + void testIteConditionCannotBeOptionalBoolean() { + Expression condition = Expression.getReference(Identifier.of("maybeFlag")); + Expression trueValue = Literal.of("yes"); + Expression falseValue = Literal.of("no"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("maybeFlag", Type.optionalType(Type.booleanType())); + + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("non-optional Boolean")); + assertTrue(ex.getMessage().contains("coalesce")); + } + + @Test + void testIteWithArrayTypes() { + Expression condition = Expression.getReference(Identifier.of("useFirst")); + Expression trueValue = Expression.getReference(Identifier.of("array1")); + Expression falseValue = Expression.getReference(Identifier.of("array2")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("useFirst", Type.booleanType()); + scope.insert("array1", Type.arrayType(Type.stringType())); + scope.insert("array2", Type.arrayType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + assertEquals(Type.arrayType(Type.stringType()), resultType); + } + + @Test + void testIteWithOptionalArrayType() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeArray")); + Expression falseValue = Expression.getReference(Identifier.of("definiteArray")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeArray", Type.optionalType(Type.arrayType(Type.integerType()))); + scope.insert("definiteArray", Type.arrayType(Type.integerType())); + + Type resultType = ite.typeCheck(scope); + + // One optional array => result is optional array + assertEquals(Type.optionalType(Type.arrayType(Type.integerType())), resultType); + } + + @Test + void testIteWithBooleanValues() { + Expression condition = Expression.getReference(Identifier.of("invertFlag")); + Expression trueValue = Literal.of(false); + Expression falseValue = Literal.of(true); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("invertFlag", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + assertEquals(Type.booleanType(), resultType); + } + + @Test + void testIteTypeMismatchWithOptionalUnwrapping() { + // Even with optional wrapping, base types must match + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeString")); + Expression falseValue = Expression.getReference(Identifier.of("maybeInt")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeString", Type.optionalType(Type.stringType())); + scope.insert("maybeInt", Type.optionalType(Type.integerType())); + + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("same base type")); + } + + @Test + void testIteReturnsCorrectId() { + assertEquals("ite", Ite.ID); + assertEquals("ite", Ite.getDefinition().getId()); + } + + @Test + void testTypeMethodReturnsInferredTypeAfterTypeCheck() { + // Verify that type() returns the correct inferred type after typeCheck() + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeValue")); + Expression falseValue = Literal.of("default"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeValue", Type.optionalType(Type.stringType())); + + // Call typeCheck to cache the type + ite.typeCheck(scope); + + // Now type() should return the inferred type + Type cachedType = ite.type(); + assertEquals(Type.optionalType(Type.stringType()), cachedType); + } + + @Test + void testNestedIteTypeInference() { + // Test that nested Ite expressions have correct type inference + Expression outerCondition = Expression.getReference(Identifier.of("outer")); + Expression innerCondition = Expression.getReference(Identifier.of("inner")); + + // Inner ITE: ite(inner, "a", "b") => String + Ite innerIte = Ite.ofExpressions(innerCondition, Literal.of("a"), Literal.of("b")); + + // Outer ITE: ite(outer, innerIte, "c") => String + Ite outerIte = Ite.ofExpressions(outerCondition, innerIte, Literal.of("c")); + + Scope scope = new Scope<>(); + scope.insert("outer", Type.booleanType()); + scope.insert("inner", Type.booleanType()); + + outerIte.typeCheck(scope); + + // Both inner and outer should have String type + assertEquals(Type.stringType(), innerIte.type()); + assertEquals(Type.stringType(), outerIte.type()); + } +} diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java index 88c2aab778e..a4defa53c43 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java @@ -12,6 +12,13 @@ import java.util.List; import org.junit.jupiter.api.Test; import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; @@ -25,7 +32,11 @@ public class BddTraitTest { @Test void testBddTraitSerialization() { // Create a BddTrait with full context - Parameters params = Parameters.builder().build(); + Parameter regionParam = Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build(); + Parameters params = Parameters.builder().addParameter(regionParam).build(); Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); Rule endpoint = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); @@ -99,4 +110,39 @@ void testEmptyBddTrait() { assertEquals(1, trait.getResults().size()); assertEquals(-1, trait.getBdd().getRootRef()); // FALSE terminal } + + @Test + void testBuildTypeChecksExpressionsForCodegen() { + // Verify that after building an EndpointBddTrait, expression.type() works + // This is important for codegen to infer types without a scope + Parameter regionParam = Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build(); + Parameters params = Parameters.builder().addParameter(regionParam).build(); + + // Create a condition with a coalesce that infers to String + Expression regionRef = Expression.getReference(Identifier.of("Region")); + Expression fallback = Literal.of("us-east-1"); + Coalesce coalesce = Coalesce.ofExpressions(regionRef, fallback); + Condition cond = Condition.builder().fn(coalesce).result(Identifier.of("resolvedRegion")).build(); + + Rule endpoint = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + + List results = new ArrayList<>(); + results.add(NoMatchRule.INSTANCE); + results.add(endpoint); + + EndpointBddTrait trait = EndpointBddTrait.builder() + .parameters(params) + .conditions(ListUtils.of(cond)) + .results(results) + .bdd(createSimpleBdd()) + .build(); + + // After build(), type() should work on the coalesce expression + // Region is Optional, fallback is String, so result is String (non-optional) + Coalesce builtCoalesce = (Coalesce) trait.getConditions().get(0).getFunction(); + assertEquals(Type.stringType(), builtCoalesce.type()); + } } diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.errors new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.errors @@ -0,0 +1 @@ + diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.smithy new file mode 100644 index 00000000000..75a4d4d050c --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.smithy @@ -0,0 +1,80 @@ +$version: "2.0" + +namespace example + +use smithy.rules#clientContextParams +use smithy.rules#endpointRuleSet +use smithy.rules#endpointTests + +@clientContextParams( + useFips: {type: "boolean", documentation: "Use FIPS endpoints"} +) +@endpointRuleSet({ + version: "1.1", + parameters: { + useFips: { + type: "boolean", + documentation: "Use FIPS endpoints", + default: false, + required: true + } + }, + rules: [ + { + "documentation": "Use ite to select endpoint suffix" + "conditions": [ + { + "fn": "ite" + "argv": [{"ref": "useFips"}, "-fips", ""] + "assign": "suffix" + } + ] + "endpoint": { + "url": "https://example{suffix}.com" + } + "type": "endpoint" + } + ] +}) +@endpointTests({ + "version": "1.0", + "testCases": [ + { + "documentation": "When useFips is true, returns trueValue" + "params": { + "useFips": true + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example-fips.com" + } + } + } + { + "documentation": "When useFips is false, returns falseValue" + "params": { + "useFips": false + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example.com" + } + } + } + ] +}) +@suppress(["UnstableTrait.smithy"]) +service FizzBuzz { + version: "2022-01-01", + operations: [GetThing] +} + +operation GetThing { + input := {} +}