From 8a23dbf6757be529a1bafad8d7ed1a695ad17040 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Sat, 13 Dec 2025 00:15:05 -0600 Subject: [PATCH 1/5] Improve BDD sifting (2x speed, more reduction) - Add AdaptiveEffort to dynamically adjust optimization parameters based on observed improvement rates (increase effort when making progress, decrease when plateauing) - Add block moves optimization that moves groups of dependent conditions together to escape local minima - Add cost-based tie-breaking using BddCostEstimator when multiple positions have the same node count --- ...b03b74b1054ac2632fcae533d727eace1d26e.json | 7 + .../logic/bdd/SiftingOptimization.java | 629 +++++++++++------- 2 files changed, 386 insertions(+), 250 deletions(-) create mode 100644 .changes/next-release/feature-5e4b03b74b1054ac2632fcae533d727eace1d26e.json diff --git a/.changes/next-release/feature-5e4b03b74b1054ac2632fcae533d727eace1d26e.json b/.changes/next-release/feature-5e4b03b74b1054ac2632fcae533d727eace1d26e.json new file mode 100644 index 00000000000..cd940a23a40 --- /dev/null +++ b/.changes/next-release/feature-5e4b03b74b1054ac2632fcae533d727eace1d26e.json @@ -0,0 +1,7 @@ +{ + "type": "feature", + "description": "Improve BDD sifting (2x speed, more reduction)", + "pull_requests": [ + "[#2890](https://github.com/smithy-lang/smithy/pull/2890)" + ] +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java index 47666837125..8f0d6686b9c 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -6,16 +6,14 @@ import java.util.ArrayList; import java.util.Arrays; -import java.util.Comparator; -import java.util.IdentityHashMap; import java.util.List; -import java.util.Map; import java.util.function.Function; import java.util.logging.Logger; import java.util.stream.Collectors; import java.util.stream.IntStream; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.logic.ConditionCostModel; import software.amazon.smithy.rulesengine.logic.cfg.Cfg; import software.amazon.smithy.rulesengine.logic.cfg.ConditionDependencyGraph; import software.amazon.smithy.rulesengine.traits.EndpointBddTrait; @@ -34,14 +32,21 @@ public final class SiftingOptimization implements Function { private static final Logger LOGGER = Logger.getLogger(SiftingOptimization.class.getName()); - // When to use a parallel stream private static final int PARALLEL_THRESHOLD = 7; + // Early termination: number of passes to track for plateau detection + private static final int PLATEAU_HISTORY_SIZE = 3; + private static final double PLATEAU_THRESHOLD = 0.5; + // Thread-local BDD builders to avoid allocation overhead private final ThreadLocal threadBuilder = ThreadLocal.withInitial(BddBuilder::new); private final Cfg cfg; private final ConditionDependencyGraph dependencyGraph; + private final ConditionCostModel costModel = ConditionCostModel.createDefault();; + + // Reusable cost estimator, created once per optimization run + private BddCostEstimator costEstimator; // Tiered optimization settings private final int coarseMinNodes; @@ -81,6 +86,44 @@ private enum OptimizationEffort { } } + /** + * Mutable effort tracker that adapts parameters based on observed improvement. + */ + private static final class AdaptiveEffort { + static final double HIGH_THRESHOLD = 10.0; + static final double LOW_THRESHOLD = 2.0; + + final OptimizationEffort base; + int sampleRate; + int maxPositions; + int nearbyRadius; + int bonusPasses; + + AdaptiveEffort(OptimizationEffort effort) { + this.base = effort; + this.sampleRate = effort.sampleRate; + this.maxPositions = effort.maxPositions; + this.nearbyRadius = effort.nearbyRadius; + } + + /** Adapts effort based on improvement. Returns true if effort increased. */ + boolean adapt(double reductionPercent) { + if (reductionPercent >= HIGH_THRESHOLD) { + sampleRate = Math.max(1, sampleRate - 1); + maxPositions = Math.min(base.maxPositions * 2, maxPositions + 5); + nearbyRadius = Math.min(base.nearbyRadius + 6, nearbyRadius + 2); + bonusPasses = Math.min(bonusPasses + 2, 6); + return true; + } else if (reductionPercent < LOW_THRESHOLD) { + sampleRate = Math.min(base.sampleRate * 2, sampleRate + 2); + maxPositions = Math.max(base.maxPositions / 2, maxPositions - 3); + nearbyRadius = Math.max(0, nearbyRadius - 2); + bonusPasses = Math.max(0, bonusPasses - 2); + } + return false; + } + } + private SiftingOptimization(Builder builder) { this.cfg = SmithyBuilder.requiredState("cfg", builder.cfg); this.coarseMinNodes = builder.coarseMinNodes; @@ -108,382 +151,468 @@ public EndpointBddTrait apply(EndpointBddTrait trait) { private EndpointBddTrait doApply(EndpointBddTrait trait) { LOGGER.info("Starting BDD sifting optimization"); long startTime = System.currentTimeMillis(); - OptimizationState state = initializeOptimization(trait); + State state = initializeOptimization(trait); LOGGER.info(String.format("Initial size: %d nodes", state.initialSize)); - state = runOptimizationStage("Coarse", state, OptimizationEffort.COARSE, coarseMinNodes, coarseMaxPasses, 4.0); - state = runOptimizationStage("Medium", state, OptimizationEffort.MEDIUM, mediumMinNodes, mediumMaxPasses, 1.5); + // Create cost estimator once for the entire optimization run + this.costEstimator = new BddCostEstimator(state.orderView, costModel, null); + + runOptimizationStage("Coarse", state, OptimizationEffort.COARSE, coarseMinNodes, coarseMaxPasses, 4.0); + runOptimizationStage("Medium", state, OptimizationEffort.MEDIUM, mediumMinNodes, mediumMaxPasses, 1.5); if (state.currentSize <= granularMaxNodes) { - state = runOptimizationStage("Granular", state, OptimizationEffort.GRANULAR, 0, granularMaxPasses, 0.0); - } else { - LOGGER.info("Skipping granular stage - too large"); + runOptimizationStage("Granular", state, OptimizationEffort.GRANULAR, 0, granularMaxPasses, 0.0); } - state = runAdjacentSwaps(state); + runBlockMoves(state); + runAdjacentSwaps(state); double totalTimeInSeconds = (System.currentTimeMillis() - startTime) / 1000.0; - if (state.bestSize >= state.initialSize) { + if (state.currentSize >= state.initialSize) { LOGGER.info(String.format("No improvements found in %fs", totalTimeInSeconds)); return trait; } LOGGER.info(String.format("Optimization complete: %d -> %d nodes (%.1f%% total reduction) in %fs", state.initialSize, - state.bestSize, - (1.0 - (double) state.bestSize / state.initialSize) * 100, + state.currentSize, + (1.0 - (double) state.currentSize / state.initialSize) * 100, totalTimeInSeconds)); return trait.toBuilder().conditions(state.orderView).results(state.results).bdd(state.bestBdd).build(); } - private OptimizationState initializeOptimization(EndpointBddTrait trait) { - // Use the trait's existing ordering as the starting point + private State initializeOptimization(EndpointBddTrait trait) { List initialOrder = new ArrayList<>(trait.getConditions()); Condition[] order = initialOrder.toArray(new Condition[0]); List orderView = Arrays.asList(order); Bdd bdd = trait.getBdd(); int initialSize = bdd.getNodeCount() - 1; - return new OptimizationState(order, orderView, bdd, initialSize, initialSize, trait.getResults()); + return new State(order, orderView, bdd, initialSize, trait.getResults()); } - private OptimizationState runOptimizationStage( + private void runOptimizationStage( String stageName, - OptimizationState state, + State state, OptimizationEffort effort, - int targetNodeCount, + int targetNodes, int maxPasses, - double minReductionPercent + double minReduction ) { - if (targetNodeCount > 0 && state.currentSize <= targetNodeCount) { - return state; + if (targetNodes > 0 && state.currentSize <= targetNodes) { + return; } - LOGGER.info(String.format("Stage: %s optimization (%d nodes%s)", - stageName, - state.currentSize, - targetNodeCount > 0 ? String.format(", target < %d", targetNodeCount) : "")); + LOGGER.info(String.format("Stage: %s (%d nodes)", stageName, state.currentSize)); + + AdaptiveEffort ae = new AdaptiveEffort(effort); + double[] history = new double[PLATEAU_HISTORY_SIZE]; + int historyIdx = 0, consecutiveLow = 0; + + for (int pass = 1; pass <= maxPasses + ae.bonusPasses; pass++) { + if (targetNodes > 0 && state.currentSize <= targetNodes) { + break; + } - OptimizationState currentState = state; - for (int pass = 1; pass <= maxPasses; pass++) { - if (targetNodeCount > 0 && currentState.currentSize <= targetNodeCount) { + int startSize = state.currentSize; + PassContext result = runPass(state, ae); + if (result.improvements == 0) { break; } - int passStartSize = currentState.currentSize; - OptimizationResult result = runPass(currentState, effort); - if (result.improved) { - currentState = currentState.withResult(result.bdd, result.size, result.results); - double reduction = (1.0 - (double) result.size / passStartSize) * 100; - LOGGER.fine(String.format("%s pass %d: %d -> %d nodes (%.1f%% reduction)", - stageName, - pass, - passStartSize, - result.size, - reduction)); - if (minReductionPercent > 0 && reduction < minReductionPercent) { - LOGGER.fine(String.format("%s optimization yielding diminishing returns", stageName)); + state.update(result.bestBdd, result.bestSize, result.bestResults); + double reduction = (1.0 - (double) result.bestSize / startSize) * 100; + + history[historyIdx++ % PLATEAU_HISTORY_SIZE] = reduction; + if (historyIdx >= PLATEAU_HISTORY_SIZE) { + boolean plateau = true; + for (double r : history) { + if (r >= PLATEAU_THRESHOLD) { + plateau = false; + break; + } + } + if (plateau) { break; } - } else { - LOGGER.fine(String.format("%s pass %d found no improvements", stageName, pass)); + } + + consecutiveLow = ae.adapt(reduction) ? 0 : (reduction < 2.0 ? consecutiveLow + 1 : 0); + if (consecutiveLow >= 2 || (minReduction > 0 && reduction < minReduction)) { break; } } - - return currentState; } - private OptimizationState runAdjacentSwaps(OptimizationState state) { + private void runBlockMoves(State state) { if (state.currentSize > granularMaxNodes) { - return state; + return; } + LOGGER.info("Running block moves"); - LOGGER.info("Running adjacent swaps optimization"); - OptimizationState currentState = state; - - // Run multiple sweeps until no improvement - for (int sweep = 1; sweep <= 3; sweep++) { - OptimizationContext context = new OptimizationContext(currentState, dependencyGraph); - int startSize = currentState.currentSize; + List> blocks = findDependencyBlocks(state.orderView).stream() + .filter(b -> b.size() >= 2 && b.size() <= 5) + .collect(Collectors.toList()); - for (int i = 0; i < currentState.order.length - 1; i++) { - // Adjacent swap requires both elements to be able to occupy each other's positions - if (context.constraints.canMove(i, i + 1) && context.constraints.canMove(i + 1, i)) { - BddCompilerSupport.move(currentState.order, i, i + 1); - BddCompilerSupport.BddCompilationResult compilationResult = - BddCompilerSupport.compile(cfg, currentState.orderView, threadBuilder.get()); - int swappedSize = compilationResult.bdd.getNodeCount() - 1; - if (swappedSize < context.bestSize) { - context = context.withImprovement( - new PositionResult(i + 1, - swappedSize, - compilationResult.bdd, - compilationResult.results)); - } else { - BddCompilerSupport.move(currentState.order, i + 1, i); // Swap back - } - } + for (List block : blocks) { + PassContext ctx = new PassContext(state, dependencyGraph); + Result r = tryBlockMove(block, ctx); + if (r != null && r.size < ctx.bestSize) { + state.update(r.bdd, r.size, r.results); } + } + } + + private List> findDependencyBlocks(List ordering) { + List> blocks = new ArrayList<>(); + if (ordering.isEmpty()) { + return blocks; + } - if (context.improvements > 0) { - currentState = currentState.withResult(context.bestBdd, context.bestSize, context.bestResults); - LOGGER.fine(String.format("Adjacent swaps sweep %d: %d -> %d nodes", - sweep, - startSize, - context.bestSize)); + List curr = new ArrayList<>(); + curr.add(0); + for (int i = 1; i < ordering.size(); i++) { + if (dependencyGraph.getDependencies(ordering.get(i)).contains(ordering.get(i - 1))) { + curr.add(i); } else { - break; + if (curr.size() >= 2) { + blocks.add(curr); + } + curr = new ArrayList<>(); + curr.add(i); } } - return currentState; + if (curr.size() >= 2) { + blocks.add(curr); + } + + return blocks; } - private OptimizationResult runPass(OptimizationState state, OptimizationEffort effort) { - OptimizationContext context = new OptimizationContext(state, dependencyGraph); + private Result tryBlockMove(List block, PassContext ctx) { + int blockStart = block.get(0), blockEnd = block.get(block.size() - 1), blockSize = block.size(); - List selectedConditions = IntStream.range(0, state.orderView.size()) - .filter(i -> i % effort.sampleRate == 0) - .mapToObj(state.orderView::get) - .collect(Collectors.toList()); + // Compute valid range considering all block members' constraints + int minPos = 0, maxPos = ctx.order.length - blockSize; + for (int idx : block) { + int offset = idx - blockStart; + minPos = Math.max(minPos, ctx.constraints.getMinValidPosition(idx) - offset); + maxPos = Math.min(maxPos, ctx.constraints.getMaxValidPosition(idx) - offset); + } + + if (minPos >= maxPos) { + return null; + } - for (Condition condition : selectedConditions) { - Integer varIdx = context.liveIndex.get(condition); - if (varIdx == null) { + // Try a few strategic positions: min, max, mid + int[] targets = {minPos, maxPos, minPos + (maxPos - minPos) / 2}; + Result best = null; + + for (int target : targets) { + if (target == blockStart) { continue; } - List positions = getStrategicPositions(varIdx, context.constraints, effort); - if (positions.isEmpty()) { + Condition[] candidate = ctx.order.clone(); + moveBlock(candidate, blockStart, blockEnd, target); + List candidateList = Arrays.asList(candidate); + + // Validate constraints + ConditionDependencyGraph.OrderConstraints nc = dependencyGraph.createOrderConstraints(candidateList); + boolean valid = true; + for (int j = 0; j < candidate.length; j++) { + if (nc.getMinValidPosition(j) > j || nc.getMaxValidPosition(j) < j) { + valid = false; + break; + } + } + + if (!valid) { continue; } - context = tryImprovePosition(context, varIdx, positions); + BddCompilerSupport.BddCompilationResult cr = + BddCompilerSupport.compile(cfg, candidateList, threadBuilder.get()); + int size = cr.bdd.getNodeCount() - 1; + double cost = computeCost(cr.bdd, candidateList); + if (best == null || size < best.size || (size == best.size && cost < best.cost)) { + best = new Result(target, size, cost, cr.bdd, cr.results); + } + } + return best; + } + + /** + * Moves a contiguous block of elements from [start, end] to begin at targetStart. + */ + private static void moveBlock(Condition[] order, int start, int end, int targetStart) { + if (targetStart == start) { + return; } - return context.toResult(); + int blockSize = end - start + 1; + Condition[] block = new Condition[blockSize]; + System.arraycopy(order, start, block, 0, blockSize); + + if (targetStart < start) { + // Move block earlier: shift elements [targetStart, start) to the right + System.arraycopy(order, targetStart, order, targetStart + blockSize, start - targetStart); + System.arraycopy(block, 0, order, targetStart, blockSize); + } else { + // Move block later: shift elements (end, targetStart + blockSize) to the left + int shiftStart = end + 1; + int shiftEnd = targetStart + blockSize; + if (shiftEnd > order.length) { + shiftEnd = order.length; + } + System.arraycopy(order, shiftStart, order, start, shiftEnd - shiftStart); + System.arraycopy(block, 0, order, targetStart, blockSize); + } } - private OptimizationContext tryImprovePosition(OptimizationContext context, int varIdx, List positions) { - PositionResult best = findBestPosition(positions, context, varIdx); - if (best != null && best.count <= context.bestSize) { // Accept ties - BddCompilerSupport.move(context.order, varIdx, best.position); - return context.withImprovement(best); + private void runAdjacentSwaps(State state) { + if (state.currentSize > granularMaxNodes) { + return; } - return context; + for (int sweep = 0; sweep < 3; sweep++) { + PassContext ctx = new PassContext(state, dependencyGraph); + for (int i = 0; i < state.order.length - 1; i++) { + // Adjacent swap requires both elements to be able to occupy each other's positions + if (ctx.constraints.canMove(i, i + 1) && ctx.constraints.canMove(i + 1, i)) { + BddCompilerSupport.move(state.order, i, i + 1); + BddCompilerSupport.BddCompilationResult cr = BddCompilerSupport.compile( + cfg, + state.orderView, + threadBuilder.get()); + int size = cr.bdd.getNodeCount() - 1; + if (size < ctx.bestSize) { + ctx.recordImprovement(new Result(i + 1, size, cr.bdd, cr.results, null)); + } else { + BddCompilerSupport.move(state.order, i + 1, i); + } + } + } + if (ctx.improvements == 0) { + break; + } + state.update(ctx.bestBdd, ctx.bestSize, ctx.bestResults); + } + } + + private PassContext runPass(State state, AdaptiveEffort effort) { + PassContext ctx = new PassContext(state, dependencyGraph); + int[] nodeCounts = computeNodeCountsPerVariable(state.bestBdd); + int[] selectedIndices = selectConditionsByPriority(state.orderView.size(), nodeCounts, effort.sampleRate); + + for (int varIdx : selectedIndices) { + List positions = getStrategicPositions(varIdx, ctx.constraints, effort, state.orderView.size()); + if (positions.isEmpty()) { + continue; + } + Result best = findBestPosition(positions, ctx, varIdx); + if (best != null && best.size <= ctx.bestSize) { + BddCompilerSupport.move(ctx.order, varIdx, best.position); + ctx.recordImprovement(best); + } + } + return ctx; + } + + /** + * Computes the number of BDD nodes testing each variable. + */ + private static int[] computeNodeCountsPerVariable(Bdd bdd) { + int[] counts = new int[bdd.getConditionCount()]; + for (int i = 0; i < bdd.getNodeCount(); i++) { + int v = bdd.getVariable(i); + if (v >= 0 && v < counts.length) { + counts[v]++; + } + } + return counts; + } + + private static int[] selectConditionsByPriority(int n, int[] nodeCounts, int sampleRate) { + int[] indices = IntStream.range(0, n) + .boxed() + .sorted((a, b) -> Integer.compare(nodeCounts[b], nodeCounts[a])) + .mapToInt(i -> i) + .toArray(); + return sampleRate <= 1 ? indices : Arrays.copyOf(indices, Math.max(1, n / sampleRate)); } - private PositionResult findBestPosition(List positions, OptimizationContext ctx, int varIdx) { - return (positions.size() > PARALLEL_THRESHOLD ? positions.parallelStream() : positions.stream()) + /** Two-pass position finder: compile candidates, then cost-break ties among min-size. */ + private Result findBestPosition(List positions, PassContext ctx, int varIdx) { + // First pass: compile all candidates + List candidates = (positions.size() > PARALLEL_THRESHOLD + ? positions.parallelStream() + : positions.stream()) .map(pos -> { Condition[] order = ctx.order.clone(); BddCompilerSupport.move(order, varIdx, pos); + List orderList = Arrays.asList(order); BddCompilerSupport.BddCompilationResult cr = - BddCompilerSupport.compile(cfg, Arrays.asList(order), threadBuilder.get()); - return new PositionResult(pos, cr.bdd.getNodeCount() - 1, cr.bdd, cr.results); + BddCompilerSupport.compile(cfg, orderList, threadBuilder.get()); + return new Result(pos, cr.bdd.getNodeCount() - 1, cr.bdd, cr.results, orderList); }) - .filter(pr -> pr.count <= ctx.bestSize) - .min(Comparator.comparingInt((PositionResult pr) -> pr.count).thenComparingInt(pr -> pr.position)) - .orElse(null); + .filter(c -> c.size <= ctx.bestSize) + .collect(Collectors.toList()); + + if (candidates.isEmpty()) { + return null; + } + + // Second pass: among min-size candidates, pick lowest cost + int minSize = candidates.stream().mapToInt(c -> c.size).min().orElse(Integer.MAX_VALUE); + Result best = null; + for (Result c : candidates) { + if (c.size == minSize) { + double cost = computeCost(c.bdd, c.orderList); + if (best == null || cost < best.cost || (cost == best.cost && c.position < best.position)) { + best = new Result(c.position, c.size, cost, c.bdd, c.results); + } + } + } + return best; + } + + private double computeCost(Bdd bdd, List ordering) { + return costEstimator.expectedCost(bdd, ordering); } private static List getStrategicPositions( int varIdx, - ConditionDependencyGraph.OrderConstraints constraints, - OptimizationEffort effort + ConditionDependencyGraph.OrderConstraints c, + AdaptiveEffort ae, + int orderSize ) { - int min = constraints.getMinValidPosition(varIdx); - int max = constraints.getMaxValidPosition(varIdx); + int min = c.getMinValidPosition(varIdx); + int max = c.getMaxValidPosition(varIdx); int range = max - min; - if (range <= effort.exhaustiveThreshold) { - List positions = new ArrayList<>(range); + // Exhaustive for small ranges + if (range <= ae.base.exhaustiveThreshold) { + List pos = new ArrayList<>(range); for (int p = min; p < max; p++) { - if (p != varIdx && constraints.canMove(varIdx, p)) { - positions.add(p); + if (p != varIdx && c.canMove(varIdx, p)) { + pos.add(p); } } - return positions; + return pos; } - List positions = new ArrayList<>(effort.maxPositions); + List pos = new ArrayList<>(ae.maxPositions); + boolean[] seen = new boolean[orderSize]; - // Test extremes first since they often yield the best improvements - if (min != varIdx && constraints.canMove(varIdx, min)) { - positions.add(min); - } - if (positions.size() >= effort.maxPositions) { - return positions; + // Extremes + if (min != varIdx && c.canMove(varIdx, min)) { + pos.add(min); + seen[min] = true; } - if (max - 1 != varIdx && constraints.canMove(varIdx, max - 1)) { - positions.add(max - 1); - } - if (positions.size() >= effort.maxPositions) { - return positions; + if (max - 1 != varIdx && c.canMove(varIdx, max - 1)) { + pos.add(max - 1); + seen[max - 1] = true; } - // Test local moves that preserve relative ordering with neighbors - for (int offset = -effort.nearbyRadius; offset <= effort.nearbyRadius; offset++) { - if (offset != 0) { - if (positions.size() >= effort.maxPositions) { - return positions; - } - int p = varIdx + offset; - if (p >= min && p < max && !positions.contains(p) && constraints.canMove(varIdx, p)) { - positions.add(p); - } + // Global sampling + int step = Math.max(1, range / Math.min(15, ae.maxPositions / 2)); + for (int p = min + step; p < max - step && pos.size() < ae.maxPositions; p += step) { + if (p != varIdx && !seen[p] && c.canMove(varIdx, p)) { + pos.add(p); + seen[p] = true; } } - // Sample intermediate positions to find global improvements - if (positions.size() >= effort.maxPositions) { - return positions; - } - - int maxSamples = Math.min(15, effort.maxPositions / 2); - int samples = Math.min(maxSamples, Math.max(2, range / 4)); - int step = Math.max(1, range / samples); - - for (int p = min + step; p < max - step && positions.size() < effort.maxPositions; p += step) { - if (p != varIdx && !positions.contains(p) && constraints.canMove(varIdx, p)) { - positions.add(p); + // Local neighborhood + for (int off = -ae.nearbyRadius; off <= ae.nearbyRadius && pos.size() < ae.maxPositions; off++) { + int p = varIdx + off; + if (off != 0 && p >= min && p < max && !seen[p] && c.canMove(varIdx, p)) { + pos.add(p); + seen[p] = true; } } - return positions; - } - - private static Map rebuildIndex(List orderView) { - Map index = new IdentityHashMap<>(); - for (int i = 0; i < orderView.size(); i++) { - index.put(orderView.get(i), i); - } - return index; + return pos; } - // Helper class to track optimization context within a pass - private static final class OptimizationContext { + /** Mutable context for tracking optimization progress within a pass. */ + private static final class PassContext { final Condition[] order; final List orderView; final ConditionDependencyGraph dependencyGraph; - final ConditionDependencyGraph.OrderConstraints constraints; - final Map liveIndex; - final Bdd bestBdd; - final int bestSize; - final List bestResults; - final int improvements; - - OptimizationContext(OptimizationState state, ConditionDependencyGraph dependencyGraph) { + ConditionDependencyGraph.OrderConstraints constraints; + Bdd bestBdd; + int bestSize; + List bestResults; + int improvements; + + PassContext(State state, ConditionDependencyGraph dependencyGraph) { this.order = state.order; this.orderView = state.orderView; - this.dependencyGraph = dependencyGraph; - this.constraints = dependencyGraph.createOrderConstraints(orderView); - this.liveIndex = rebuildIndex(orderView); - this.bestBdd = null; this.bestSize = state.currentSize; - this.bestResults = null; - this.improvements = 0; - } - - private OptimizationContext( - Condition[] order, - List orderView, - ConditionDependencyGraph dependencyGraph, - ConditionDependencyGraph.OrderConstraints constraints, - Map liveIndex, - Bdd bestBdd, - int bestSize, - List bestResults, - int improvements - ) { - this.order = order; - this.orderView = orderView; this.dependencyGraph = dependencyGraph; - this.constraints = constraints; - this.liveIndex = liveIndex; - this.bestBdd = bestBdd; - this.bestSize = bestSize; - this.bestResults = bestResults; - this.improvements = improvements; - } - - OptimizationContext withImprovement(PositionResult result) { - ConditionDependencyGraph.OrderConstraints newConstraints = - dependencyGraph.createOrderConstraints(orderView); - Map newIndex = rebuildIndex(orderView); - return new OptimizationContext(order, - orderView, - dependencyGraph, - newConstraints, - newIndex, - result.bdd, - result.count, - result.results, - improvements + 1); - } - - OptimizationResult toResult() { - return new OptimizationResult(bestBdd, bestSize, improvements > 0, bestResults); + this.constraints = dependencyGraph.createOrderConstraints(orderView); + } + + void recordImprovement(Result result) { + this.bestBdd = result.bdd; + this.bestSize = result.size; + this.bestResults = result.results; + this.constraints = dependencyGraph.createOrderConstraints(orderView); + this.improvements++; } } - private static final class PositionResult { + /** Result holder for BDD compilation with optional position/cost metadata. */ + private static final class Result { final int position; - final int count; + final int size; + final double cost; final Bdd bdd; final List results; + final List orderList; // For deferred cost computation - PositionResult(int position, int count, Bdd bdd, List results) { - this.position = position; - this.count = count; - this.bdd = bdd; - this.results = results; + Result(int position, int size, Bdd bdd, List results, List orderList) { + this(position, size, Double.MAX_VALUE, bdd, results, orderList); } - } - private static final class OptimizationResult { - final Bdd bdd; - final int size; - final boolean improved; - final List results; + Result(int position, int size, double cost, Bdd bdd, List results) { + this(position, size, cost, bdd, results, null); + } - OptimizationResult(Bdd bdd, int size, boolean improved, List results) { - this.bdd = bdd; + Result(int position, int size, double cost, Bdd bdd, List results, List orderList) { + this.position = position; this.size = size; - this.improved = improved; + this.cost = cost; + this.bdd = bdd; this.results = results; + this.orderList = orderList; } } - private static final class OptimizationState { + /** Tracks overall optimization state across stages. */ + private static final class State { final Condition[] order; final List orderView; - final Bdd bestBdd; - final int currentSize; - final int bestSize; final int initialSize; - final List results; + Bdd bestBdd; + int currentSize; + List results; - OptimizationState( - Condition[] order, - List orderView, - Bdd bestBdd, - int currentSize, - int initialSize, - List results - ) { + State(Condition[] order, List orderView, Bdd bdd, int size, List results) { this.order = order; this.orderView = orderView; - this.bestBdd = bestBdd; - this.currentSize = currentSize; - this.bestSize = currentSize; - this.initialSize = initialSize; + this.bestBdd = bdd; + this.currentSize = size; + this.initialSize = size; this.results = results; } - OptimizationState withResult(Bdd newBdd, int newSize, List newResults) { - return new OptimizationState(order, orderView, newBdd, newSize, initialSize, newResults); + void update(Bdd bdd, int size, List results) { + this.bestBdd = bdd; + this.currentSize = size; + this.results = results; } } From 0aaea908f13798436aba568af078c7639362186a Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Wed, 17 Dec 2025 11:27:08 -0600 Subject: [PATCH 2/5] Implement rules engine ITE fn and S3 tree transform This commit adds a new function to the rules engine, ite, that performs an if-then-else check on a boolean expression without branching. By not needing to branch in the decision tree, we avoid SSA transforms on divergent branches which would create syntactically different but semantically identical expressions that the BDD cannot deduplicate. This commit also adds an S3-specific decision tree transform that canonicalizes S3Express rules for better BDD compilation: 1. AZ extraction: Rewrites position-dependent substring operations to use a single split(Bucket, "--")[1] expression across all branches 2. URL canonicalization: Uses ITE to compute FIPS/DualStack URL segments, collapsing 4 URL variants into a single template with {_s3e_fips} and {_s3e_ds} placeholders 3. Auth scheme canonicalization: Uses ITE to select sigv4 vs sigv4-s3express based on DisableS3ExpressSessionAuth The transform makes the rules tree ~30% larger but enables dramatic BDD compression by making URL templates identical across FIPS/DualStack/auth variants. Endpoints that previously appeared distinct now collapse into single BDD results, reducing nodes and results by ~43%. --- ...6b965bf2f51d85aa0171be6206ae2029137c4.json | 7 + .../rules-engine/standard-library.rst | 97 +++ smithy-aws-endpoints/build.gradle.kts | 54 ++ .../functions/S3TreeRewriterTest.java | 52 ++ .../aws/AwsConditionProbability.java | 18 +- .../language/functions/S3TreeRewriter.java | 633 ++++++++++++++++++ .../rulesengine/language/CoreExtension.java | 2 + .../language/evaluation/RuleEvaluator.java | 6 + .../syntax/expressions/ExpressionVisitor.java | 19 + .../syntax/expressions/functions/Ite.java | 174 +++++ .../logic/bdd/CostOptimization.java | 4 +- .../logic/bdd/SiftingOptimization.java | 2 +- .../cfg/VariableConsolidationTransform.java | 47 +- .../rulesengine/traits/EndpointBddTrait.java | 10 + .../RuleSetAuthSchemesValidator.java | 15 +- .../language/syntax/functions/IteTest.java | 234 +++++++ .../errorfiles/valid/ite-basic.errors | 1 + .../errorfiles/valid/ite-basic.smithy | 80 +++ 18 files changed, 1436 insertions(+), 19 deletions(-) create mode 100644 .changes/next-release/feature-1c36b965bf2f51d85aa0171be6206ae2029137c4.json create mode 100644 smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java create mode 100644 smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java create mode 100644 smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java create mode 100644 smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.errors create mode 100644 smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.smithy diff --git a/.changes/next-release/feature-1c36b965bf2f51d85aa0171be6206ae2029137c4.json b/.changes/next-release/feature-1c36b965bf2f51d85aa0171be6206ae2029137c4.json new file mode 100644 index 00000000000..3153bd1730d --- /dev/null +++ b/.changes/next-release/feature-1c36b965bf2f51d85aa0171be6206ae2029137c4.json @@ -0,0 +1,7 @@ +{ + "type": "feature", + "description": "Implement rules engine ITE fn and S3 tree transform", + "pull_requests": [ + "[#2903](https://github.com/smithy-lang/smithy/pull/2903)" + ] +} diff --git a/docs/source-2.0/additional-specs/rules-engine/standard-library.rst b/docs/source-2.0/additional-specs/rules-engine/standard-library.rst index 6950b914468..8cb5f7d51d8 100644 --- a/docs/source-2.0/additional-specs/rules-engine/standard-library.rst +++ b/docs/source-2.0/additional-specs/rules-engine/standard-library.rst @@ -208,6 +208,103 @@ The following example uses ``isValidHostLabel`` to check if the value of the } +.. _rules-engine-standard-library-ite: + +``ite`` function +================ + +Summary + An if-then-else function that returns one of two values based on a boolean condition. +Argument types + * condition: ``bool`` + * trueValue: ``T`` or ``option`` + * falseValue: ``T`` or ``option`` +Return type + * ``ite(bool, T, T)`` → ``T`` (both non-optional, result is non-optional) + * ``ite(bool, T, option)`` → ``option`` (any optional makes result optional) + * ``ite(bool, option, T)`` → ``option`` (any optional makes result optional) + * ``ite(bool, option, option)`` → ``option`` (both optional, result is optional) +Since + 1.1 + +The ``ite`` (if-then-else) function evaluates a boolean condition and returns one of two values based on +the result. If the condition is ``true``, it returns ``trueValue``; if ``false``, it returns ``falseValue``. +This function is particularly useful for computing conditional values without branching in the rule tree, resulting +in fewer result nodes, and enabling better BDD optimizations as a result of reduced fragmentation. + +.. important:: + Both ``trueValue`` and ``falseValue`` must have the same base type ``T``. The result type follows + the "least upper bound" rule: if either branch is optional, the result is optional. + +The following example uses ``ite`` to compute a URL suffix based on whether FIPS is enabled: + +.. code-block:: json + + { + "fn": "ite", + "argv": [ + {"ref": "UseFIPS"}, + "-fips", + "" + ], + "assign": "fipsSuffix" + } + +The following example uses ``ite`` with ``coalesce`` to handle an optional boolean parameter: + +.. code-block:: json + + { + "fn": "ite", + "argv": [ + { + "fn": "coalesce", + "argv": [ + {"ref": "DisableFeature"}, + false + ] + }, + "disabled", + "enabled" + ], + "assign": "featureState" + } + + +.. _rules-engine-standard-library-ite-examples: + +-------- +Examples +-------- + +The following table shows various inputs and their corresponding outputs for the ``ite`` function: + +.. list-table:: + :header-rows: 1 + :widths: 20 25 25 30 + + * - Condition + - True Value + - False Value + - Output + * - ``true`` + - ``"-fips"`` + - ``""`` + - ``"-fips"`` + * - ``false`` + - ``"-fips"`` + - ``""`` + - ``""`` + * - ``true`` + - ``"sigv4"`` + - ``"sigv4-s3express"`` + - ``"sigv4"`` + * - ``false`` + - ``"sigv4"`` + - ``"sigv4-s3express"`` + - ``"sigv4-s3express"`` + + .. _rules-engine-standard-library-not: ``not`` function diff --git a/smithy-aws-endpoints/build.gradle.kts b/smithy-aws-endpoints/build.gradle.kts index c142cd4ee32..35c213e1e70 100644 --- a/smithy-aws-endpoints/build.gradle.kts +++ b/smithy-aws-endpoints/build.gradle.kts @@ -11,10 +11,64 @@ description = "AWS specific components for managing endpoints in Smithy" extra["displayName"] = "Smithy :: AWS Endpoints Components" extra["moduleName"] = "software.amazon.smithy.aws.endpoints" +// Custom configuration for S3 model - kept separate from test classpath to avoid +// polluting other tests with S3 model discovery +val s3Model: Configuration by configurations.creating + dependencies { api(project(":smithy-aws-traits")) api(project(":smithy-diff")) api(project(":smithy-rules-engine")) api(project(":smithy-model")) api(project(":smithy-utils")) + + s3Model("software.amazon.api.models:s3:1.0.11") +} + +// Integration test source set for tests that require the S3 model +// These tests require JDK 17+ due to the S3 model dependency +sourceSets { + create("it") { + compileClasspath += sourceSets["main"].output + sourceSets["test"].output + runtimeClasspath += sourceSets["main"].output + sourceSets["test"].output + } +} + +configurations["itImplementation"].extendsFrom(configurations["testImplementation"]) +configurations["itRuntimeOnly"].extendsFrom(configurations["testRuntimeOnly"]) +configurations["itImplementation"].extendsFrom(s3Model) + +// Configure IT source set to compile with JDK 17 +tasks.named("compileItJava") { + javaCompiler.set( + javaToolchains.compilerFor { + languageVersion.set(JavaLanguageVersion.of(17)) + }, + ) + sourceCompatibility = "17" + targetCompatibility = "17" +} + +val integrationTest by tasks.registering(Test::class) { + description = "Runs integration tests that require external models like S3" + group = "verification" + testClassesDirs = sourceSets["it"].output.classesDirs + classpath = sourceSets["it"].runtimeClasspath + dependsOn(tasks.jar) + shouldRunAfter(tasks.test) + + // Run with JDK 17 + javaLauncher.set( + javaToolchains.launcherFor { + languageVersion.set(JavaLanguageVersion.of(17)) + }, + ) +} + +tasks.test { + finalizedBy(integrationTest) +} + +tasks.named("check") { + dependsOn(integrationTest) } diff --git a/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java new file mode 100644 index 00000000000..dd5e88140a7 --- /dev/null +++ b/smithy-aws-endpoints/src/it/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriterTest.java @@ -0,0 +1,52 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.language.functions; + +import static org.junit.jupiter.api.Assertions.assertFalse; + +import java.util.List; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import software.amazon.smithy.model.Model; +import software.amazon.smithy.model.shapes.ServiceShape; +import software.amazon.smithy.model.shapes.ShapeId; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.evaluation.TestEvaluator; +import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait; +import software.amazon.smithy.rulesengine.traits.EndpointTestCase; +import software.amazon.smithy.rulesengine.traits.EndpointTestsTrait; + +/** + * Runs the endpoint test cases against the transformed S3 model. We're fixed to a specific version for this test, + * but could periodically bump the version if needed. + */ +class S3TreeRewriterTest { + private static final ShapeId S3_SERVICE_ID = ShapeId.from("com.amazonaws.s3#AmazonS3"); + + private static EndpointRuleSet originalRules; + private static List testCases; + + @BeforeAll + static void loadS3Model() { + Model model = Model.assembler() + .discoverModels() + .assemble() + .unwrap(); + + ServiceShape s3Service = model.expectShape(S3_SERVICE_ID, ServiceShape.class); + originalRules = s3Service.expectTrait(EndpointRuleSetTrait.class).getEndpointRuleSet(); + testCases = s3Service.expectTrait(EndpointTestsTrait.class).getTestCases(); + } + + @Test + void transformPreservesEndpointTestSemantics() { + assertFalse(testCases.isEmpty(), "S3 model should have endpoint test cases"); + + EndpointRuleSet transformed = S3TreeRewriter.transform(originalRules); + for (EndpointTestCase testCase : testCases) { + TestEvaluator.evaluate(transformed, testCase); + } + } +} diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java index 42b2344e8ef..02dbe32fa4e 100644 --- a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/AwsConditionProbability.java @@ -26,12 +26,21 @@ public double applyAsDouble(Condition condition) { // Region is almost always provided if (s.contains("isSet(Region)")) { - return 0.95; + return 0.96; } // Endpoint override is rare if (s.contains("isSet(Endpoint)")) { - return 0.1; + return 0.2; + } + + // S3 Express is rare (includes ITE variables from S3TreeRewriter) + if (s.contains("S3Express") || s.contains("--x-s3") + || s.contains("--xa-s3") + || s.contains("s3e_fips") + || s.contains("s3e_ds") + || s.contains("s3e_auth")) { + return 0.001; } // Most isSet checks on optional params succeed moderately @@ -48,11 +57,6 @@ public double applyAsDouble(Condition condition) { return 0.05; } - // S3 Express is relatively rare - if (s.contains("S3Express") || s.contains("--x-s3") || s.contains("--xa-s3")) { - return 0.1; - } - // ARN-based buckets are uncommon if (s.contains("parseArn") || s.contains("arn:")) { return 0.15; diff --git a/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java new file mode 100644 index 00000000000..f748a289f56 --- /dev/null +++ b/smithy-aws-endpoints/src/main/java/software/amazon/smithy/rulesengine/aws/language/functions/S3TreeRewriter.java @@ -0,0 +1,633 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.aws.language.functions; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import software.amazon.smithy.model.node.StringNode; +import software.amazon.smithy.rulesengine.language.Endpoint; +import software.amazon.smithy.rulesengine.language.EndpointRuleSet; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Reference; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Template; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Split; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Substring; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; +import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; +import software.amazon.smithy.rulesengine.language.syntax.rule.Rule; +import software.amazon.smithy.rulesengine.language.syntax.rule.TreeRule; +import software.amazon.smithy.utils.SmithyInternalApi; + +/** + * Rewrites S3 endpoint rules to use canonical, position-independent expressions. + * + *

This is a BDD pre-processing transform that makes the rules tree larger but enables dramatically better + * BDD compilation. It solves the "SSA Trap" problem where semantically identical operations appear as syntactically + * different expressions, preventing the BDD compiler from recognizing sharing opportunities. + * + *

Internal use only

+ *

Ideally this transform is deleted one day, and the rules that source it adopt these techniques (hopefully we + * don't look back on this comment and laugh in 5 years). If/when that happens, this class will be deleted, whether + * it breaks a consumer that uses it or not. + * + *

Trade-off: Larger Rules, Smaller BDD

+ *

This transform would be counterproductive for rule tree interpretation, but is highly beneficial when a + * BDD compiler processes the output. It adds ITE (if-then-else) conditions to compute URL segments and auth scheme + * names, increasing rule tree size by ~30%. However, this enables the BDD compiler to deduplicate endpoints that + * were previously considered distinct, as of writing, reducing BDD results and node counts both by ~43%. + * + *

The key insight is that the BDD deduplicates by endpoint identity (URL template + properties). By making + * URL templates identical through variable substitution, endpoints that differed only in FIPS/DualStack/auth variants + * collapse into a single BDD result. + * + *

Transformations performed:

+ * + *

AZ Extraction Canonicalization

+ * + *

The original rules extract the availability zone ID using position-dependent substring operations. + * Different bucket name lengths result in different extraction positions, creating 10+ SSA variants that can't + * be shared in the BDD. + * + *

Before: Position-dependent substring extraction + *

{@code
+ * {
+ *   "conditions": [
+ *     {
+ *       "fn": "substring",
+ *       "argv": [{"ref": "Bucket"}, 6, 14, true],
+ *       "assign": "s3expressAvailabilityZoneId"
+ *     }
+ *   ],
+ *   "rules": [...]
+ * }
+ * // Another branch with different positions:
+ * {
+ *   "conditions": [
+ *     {
+ *       "fn": "substring",
+ *       "argv": [{"ref": "Bucket"}, 6, 20, true],
+ *       "assign": "s3expressAvailabilityZoneId"
+ *     }
+ *   ],
+ *   "rules": [...]
+ * }
+ * }
+ * + *

After: Position-independent split-based extraction + *

{@code
+ * {
+ *   "conditions": [
+ *     {
+ *       "fn": "getAttr",
+ *       "argv": [
+ *         {"fn": "split", "argv": [{"ref": "Bucket"}, "--", 0]},
+ *         "[1]"
+ *       ],
+ *       "assign": "s3expressAvailabilityZoneId"
+ *     }
+ *   ],
+ *   "rules": [...]
+ * }
+ * }
+ * + *

All branches now use the identical expression {@code split(Bucket, "--")[1]}, enabling + * the BDD compiler to share nodes across all S3Express bucket handling paths. Because the expression only interacts + * with Bucket, a constant value, there's no SSA transform performed on these expressions. + * + *

URL Canonicalization

+ * + *

S3Express endpoints (currently) have 4 URL variants based on UseFIPS and UseDualStack flags. This creates + * duplicate endpoints that differ only in URL structure. + * + *

Before: Separate endpoints for each FIPS/DualStack combination + *

{@code
+ * // Branch 1: FIPS + DualStack
+ * {
+ *   "conditions": [
+ *     {"fn": "booleanEquals", "argv": [{"ref": "UseFIPS"}, true]},
+ *     {"fn": "booleanEquals", "argv": [{"ref": "UseDualStack"}, true]}
+ *   ],
+ *   "endpoint": {
+ *     "url": "https://{Bucket}.s3express-fips-{s3expressAvailabilityZoneId}.dualstack.{Region}.amazonaws.com"
+ *   }
+ * }
+ * // Branch 2: FIPS only
+ * {
+ *   "conditions": [
+ *     {"fn": "booleanEquals", "argv": [{"ref": "UseFIPS"}, true]}
+ *   ],
+ *   "endpoint": {
+ *     "url": "https://{Bucket}.s3express-fips-{s3expressAvailabilityZoneId}.{Region}.amazonaws.com"
+ *   }
+ * }
+ * // Branch 3: DualStack only
+ * // Branch 4: Neither
+ * }
+ * + *

After: Single endpoint with ITE-computed URL segments + *

{@code
+ * {
+ *   "conditions": [
+ *     {"fn": "ite", "argv": [{"ref": "UseFIPS"}, "-fips", ""], "assign": "_s3e_fips"},
+ *     {"fn": "ite", "argv": [{"ref": "UseDualStack"}, ".dualstack", ""], "assign": "_s3e_ds"}
+ *   ],
+ *   "endpoint": {
+ *     "url": "https://{Bucket}.s3express{_s3e_fips}-{s3expressAvailabilityZoneId}{_s3e_ds}.{Region}.amazonaws.com"
+ *   }
+ * }
+ * }
+ * + *

The ITE conditions compute values branchlessly. The BDD sifting optimization naturally places these rare + * S3Express-specific conditions late in the decision tree. + * + *

Auth Scheme Canonicalization

+ * + *

S3Express endpoints use different auth schemes based on DisableS3ExpressSessionAuth. + * This creates duplicate endpoints differing only in auth scheme name. + * + *

Before: Separate auth scheme names + *

{@code
+ * // When DisableS3ExpressSessionAuth is true:
+ * "authSchemes": [{"name": "sigv4", "signingName": "s3express", ...}]
+ *
+ * // When DisableS3ExpressSessionAuth is false/unset:
+ * "authSchemes": [{"name": "sigv4-s3express", "signingName": "s3express", ...}]
+ * }
+ * + *

After: ITE-computed auth scheme name + *

{@code
+ * {
+ *   "conditions": [
+ *     {
+ *       "fn": "ite",
+ *       "argv": [
+ *         {"fn": "coalesce", "argv": [{"ref": "DisableS3ExpressSessionAuth"}, false]},
+ *         "sigv4",
+ *         "sigv4-s3express"
+ *       ],
+ *       "assign": "_s3e_auth"
+ *     }
+ *   ],
+ *   "endpoint": {
+ *     "properties": {
+ *       "authSchemes": [{"name": "{_s3e_auth}", "signingName": "s3express", ...}]
+ *     }
+ *   }
+ * }
+ * }
+ */ +@SmithyInternalApi +public final class S3TreeRewriter { + private static final Logger LOGGER = Logger.getLogger(S3TreeRewriter.class.getName()); + + // Variable names for the computed suffixes + private static final String VAR_FIPS = "_s3e_fips"; + private static final String VAR_DS = "_s3e_ds"; + private static final String VAR_AUTH = "_s3e_auth"; + + // Suffix values used in the URI templates + private static final String FIPS_SUFFIX = "-fips"; + private static final String DS_SUFFIX = ".dualstack"; + private static final String EMPTY_SUFFIX = ""; + + // Auth scheme values used with s3-express + private static final String AUTH_SIGV4 = "sigv4"; + private static final String AUTH_SIGV4_S3EXPRESS = "sigv4-s3express"; + + // Property and parameter identifiers + private static final Identifier ID_AUTH_SCHEMES = Identifier.of("authSchemes"); + private static final Identifier ID_NAME = Identifier.of("name"); + private static final Identifier ID_BACKEND = Identifier.of("backend"); + private static final Identifier ID_BUCKET = Identifier.of("Bucket"); + private static final Identifier ID_AZ_ID = Identifier.of("s3expressAvailabilityZoneId"); + private static final Identifier ID_USE_FIPS = Identifier.of("UseFIPS"); + private static final Identifier ID_USE_DUAL_STACK = Identifier.of("UseDualStack"); + private static final Identifier ID_DISABLE_S3EXPRESS_SESSION_AUTH = Identifier.of("DisableS3ExpressSessionAuth"); + + // Auth scheme name literal shared across all rewritten endpoints + private static final Literal AUTH_NAME_LITERAL = Literal.stringLiteral(Template.fromString("{" + VAR_AUTH + "}")); + + // Patterns to match S3Express bucket endpoint URLs (with AZ) + // Format: https://{Bucket}.s3express[-fips]-{AZ}[.dualstack].{Region}.amazonaws.com + // (negative lookahead (?!dualstack) prevents matching dualstack variants in non-DS patterns) + private static final Pattern S3EXPRESS_FIPS_DS = Pattern.compile("(s3express)-fips-([^.]+)\\.dualstack\\.(.+)$"); + private static final Pattern S3EXPRESS_FIPS = Pattern.compile("(s3express)-fips-([^.]+)\\.(?!dualstack)(.+)$"); + private static final Pattern S3EXPRESS_DS = Pattern.compile("(s3express)-([^.]+)\\.dualstack\\.(.+)$"); + private static final Pattern S3EXPRESS_PLAIN = Pattern.compile("(s3express)-([^.]+)\\.(?!dualstack)(.+)$"); + + // Patterns to match S3Express control plane URLs (no AZ) + // Format: https://s3express-control[-fips][.dualstack].{Region}.amazonaws.com + private static final Pattern S3EXPRESS_CONTROL_FIPS_DS = Pattern.compile( + "(s3express-control)-fips\\.dualstack\\.(.+)$"); + private static final Pattern S3EXPRESS_CONTROL_FIPS = Pattern.compile( + "(s3express-control)-fips\\.(?!dualstack)(.+)$"); + private static final Pattern S3EXPRESS_CONTROL_DS = Pattern.compile( + "(s3express-control)\\.dualstack\\.(.+)$"); + private static final Pattern S3EXPRESS_CONTROL_PLAIN = Pattern.compile( + "(s3express-control)\\.(?!dualstack)(.+)$"); + + // Cached canonical expression for AZ extraction: split(Bucket, "--", 0) + private static final Split BUCKET_SPLIT = Split.ofExpressions( + Expression.getReference(ID_BUCKET), + Expression.of("--"), + Expression.of(0)); + + private int rewrittenCount = 0; + private int totalS3ExpressCount = 0; + + private S3TreeRewriter() {} + + /** + * Transforms the given endpoint rule set using canonical expressions. + * + * @param ruleSet the rule set to transform + * @return the transformed rule set + */ + public static EndpointRuleSet transform(EndpointRuleSet ruleSet) { + return new S3TreeRewriter().run(ruleSet); + } + + private EndpointRuleSet run(EndpointRuleSet ruleSet) { + List transformedRules = new ArrayList<>(); + for (Rule rule : ruleSet.getRules()) { + transformedRules.add(transformRule(rule)); + } + + LOGGER.info(() -> String.format( + "S3 tree rewriter: %s/%s S3Express endpoints rewritten", + rewrittenCount, + totalS3ExpressCount)); + + return EndpointRuleSet.builder() + .sourceLocation(ruleSet.getSourceLocation()) + .parameters(ruleSet.getParameters()) + .rules(transformedRules) + .version(ruleSet.getVersion()) + .build(); + } + + private Rule transformRule(Rule rule) { + if (rule instanceof TreeRule) { + TreeRule tr = (TreeRule) rule; + // Transform conditions + List transformedConditions = transformConditions(tr.getConditions()); + List transformedChildren = new ArrayList<>(); + for (Rule child : tr.getRules()) { + transformedChildren.add(transformRule(child)); + } + return Rule.builder().conditions(transformedConditions).treeRule(transformedChildren); + } else if (rule instanceof EndpointRule) { + return rewriteEndpoint((EndpointRule) rule); + } else { + // Error rules pass through unchanged + return rule; + } + } + + private List transformConditions(List conditions) { + List result = new ArrayList<>(conditions.size()); + for (Condition cond : conditions) { + result.add(transformCondition(cond)); + } + return result; + } + + /** + * Transforms a single condition. + * + *

Handles: + *

+     * AZ extraction: substring(Bucket, N, M) -> split(Bucket, "--")[1]
+     * 
+ * + *

Note: Delimiter checks (s3expressAvailabilityZoneDelim) are not currently transformed because they're part + * of a complex fallback structure, and changing them breaks control flow. Possibly something we can improve, or + * wait until the upstream rules are optimized. + */ + private Condition transformCondition(Condition cond) { + // Is this a condition fishing for delimiters? + if (cond.getResult().isPresent() + && ID_AZ_ID.equals(cond.getResult().get()) + && cond.getFunction() instanceof Substring + && isSubstringOnBucket((Substring) cond.getFunction())) { + // Replace with split-based extraction: split(Bucket, "--")[1] + GetAttr azExpr = GetAttr.ofExpressions(BUCKET_SPLIT, "[1]"); + return cond.toBuilder().fn(azExpr).build(); + } + + return cond; + } + + private boolean isSubstringOnBucket(Substring substring) { + List args = substring.getArguments(); + if (args.isEmpty()) { + return false; + } + + Expression target = args.get(0); + return target instanceof Reference && ID_BUCKET.equals(((Reference) target).getName()); + } + + // Creates ITE conditions for branchless S3Express variable computation. + private List createIteConditions() { + List conditions = new ArrayList<>(); + conditions.add(createIteAssignment(VAR_FIPS, Expression.getReference(ID_USE_FIPS), FIPS_SUFFIX, EMPTY_SUFFIX)); + conditions.add(createIteAssignment( + VAR_DS, + Expression.getReference(ID_USE_DUAL_STACK), + DS_SUFFIX, + EMPTY_SUFFIX)); + // Auth scheme: sigv4 when session auth disabled, sigv4-s3express otherwise + Expression sessionAuthDisabled = Coalesce.ofExpressions( + Expression.getReference(ID_DISABLE_S3EXPRESS_SESSION_AUTH), + Expression.of(false)); + conditions.add(createIteAssignment(VAR_AUTH, sessionAuthDisabled, AUTH_SIGV4, AUTH_SIGV4_S3EXPRESS)); + return conditions; + } + + // Creates an ITE-based assignment condition. + private Condition createIteAssignment(String varName, Expression condition, String trueValue, String falseValue) { + return Condition.builder() + .fn(Ite.ofStrings(condition, trueValue, falseValue)) + .result(varName) + .build(); + } + + // Rewrites an endpoint rule to use canonical S3Express URLs and auth schemes. + private Rule rewriteEndpoint(EndpointRule rule) { + Endpoint endpoint = rule.getEndpoint(); + Expression urlExpr = endpoint.getUrl(); + + // Extract the raw URL string from the expression (IFF it's a static string, rarely is anything else). + String urlStr = extractUrlString(urlExpr); + if (urlStr == null) { + return rule; + } + + // Check if this is an S3Express endpoint by URL or backend property. + // Note: while `contains("s3express")` is broad and could theoretically match path/query components, + // the subsequent matchUrl() call validates the hostname pattern before any rewriting occurs. + boolean isS3ExpressUrl = urlStr.contains("s3express"); + boolean isS3ExpressBackend = isS3ExpressBackend(endpoint); + + if (!isS3ExpressUrl && !isS3ExpressBackend) { + return rule; + } + + totalS3ExpressCount++; + + // For URL override endpoints (backend=S3Express but URL doesn't match s3express hostname), + // just canonicalize the auth scheme - no URL rewriting needed + if (isS3ExpressBackend && !isS3ExpressUrl) { + // Canonicalize auth scheme to use {_s3e_auth} + Map newProperties = canonicalizeAuthScheme(endpoint.getProperties()); + + if (newProperties == endpoint.getProperties()) { + // No changes needed + return rule; + } + + rewrittenCount++; + + Endpoint newEndpoint = Endpoint.builder() + .url(urlExpr) + .headers(endpoint.getHeaders()) + .properties(newProperties) + .sourceLocation(endpoint.getSourceLocation()) + .build(); + + // Add auth ITE condition for URL override endpoints + List allConditions = new ArrayList<>(rule.getConditions()); + allConditions.add(createAuthIteCondition()); + + return Rule.builder() + .conditions(allConditions) + .endpoint(newEndpoint); + } + + // Standard S3Express URL - match and rewrite + UrlMatchResult match = matchUrl(urlStr); + if (match == null) { + return rule; + } + + rewrittenCount++; + + // Rewrite the URL to use the ITE-assigned variables + String newUrl = match.rewriteUrl(); + + // Canonicalize auth scheme for bucket endpoints (not control plane) + // Control plane always uses sigv4, bucket endpoints vary based on DisableS3ExpressSessionAuth + Map newProperties = endpoint.getProperties(); + if (match instanceof BucketUrlMatchResult) { + newProperties = canonicalizeAuthScheme(endpoint.getProperties()); + } + + // Build the new endpoint with canonicalized URL and properties + Endpoint newEndpoint = Endpoint.builder() + .url(Expression.of(newUrl)) + .headers(endpoint.getHeaders()) + .properties(newProperties) + .sourceLocation(endpoint.getSourceLocation()) + .build(); + + // Add ITE conditions: original conditions first, then ITE conditions at the end. + List allConditions = new ArrayList<>(rule.getConditions()); + allConditions.addAll(createIteConditions()); + + return Rule.builder() + .conditions(allConditions) + .endpoint(newEndpoint); + } + + // Checks if the endpoint has `backend` property set to "S3Express". + private boolean isS3ExpressBackend(Endpoint endpoint) { + Literal backend = endpoint.getProperties().get(ID_BACKEND); + if (backend == null) { + return false; + } + + return backend.asStringLiteral() + .filter(Template::isStatic) + .map(t -> "S3Express".equalsIgnoreCase(t.expectLiteral())) + .orElse(false); + } + + // Creates just the auth ITE condition for URL override endpoints. + private Condition createAuthIteCondition() { + // `DisableS3ExpressSessionAuth` is nullable, so we need to coalesce it to have a false default. Fix upstream? + Expression isSessionAuthDisabled = Coalesce.ofExpressions( + Expression.getReference(ID_DISABLE_S3EXPRESS_SESSION_AUTH), + Expression.of(false)); + return createIteAssignment(VAR_AUTH, isSessionAuthDisabled, AUTH_SIGV4, AUTH_SIGV4_S3EXPRESS); + } + + // Canonicalizes the authScheme name in endpoint properties to use the ITE variable. + private Map canonicalizeAuthScheme(Map properties) { + Literal authSchemes = properties.get(ID_AUTH_SCHEMES); + if (authSchemes == null) { + return properties; + } + + List schemes = authSchemes.asTupleLiteral().orElse(null); + if (schemes == null || schemes.isEmpty()) { + return properties; + } + + // Rewrite each auth scheme's name field + List newSchemes = new ArrayList<>(); + for (Literal scheme : schemes) { + Map record = scheme.asRecordLiteral().orElse(null); + if (record == null) { + // Auth is always a record, but maybe that changes in the future, so pass it through. + newSchemes.add(scheme); + continue; + } + + Literal nameLiteral = record.get(ID_NAME); + if (nameLiteral == null) { + // "name" should always be set, but pass through if not. + newSchemes.add(scheme); + continue; + } + + // Only transform string literals we recognize. + String name = nameLiteral.asStringLiteral() + .filter(Template::isStatic) + .map(Template::expectLiteral) + .orElse(null); + + // Only rewrite if it's one of the S3Express auth schemes + if (AUTH_SIGV4.equals(name) || AUTH_SIGV4_S3EXPRESS.equals(name)) { + Map newRecord = new LinkedHashMap<>(record); + newRecord.put(ID_NAME, AUTH_NAME_LITERAL); + newSchemes.add(Literal.recordLiteral(newRecord)); + } else { + newSchemes.add(scheme); + } + } + + Map newProperties = new LinkedHashMap<>(properties); + newProperties.put(ID_AUTH_SCHEMES, Literal.tupleLiteral(newSchemes)); + return newProperties; + } + + // Extracts the raw URL string from a URL expression. + private String extractUrlString(Expression urlExpr) { + return urlExpr.toNode().asStringNode().map(StringNode::getValue).orElse(null); + } + + // Matches an S3Express URL and returns the pattern match info. Tries to match in most specific order. + private UrlMatchResult matchUrl(String url) { + Matcher m; + + // First try control plane patterns (no AZ) since these are more specific + m = S3EXPRESS_CONTROL_FIPS_DS.matcher(url); + if (m.find()) { + return new ControlPlaneUrlMatchResult(url, m); + } + + m = S3EXPRESS_CONTROL_FIPS.matcher(url); + if (m.find()) { + return new ControlPlaneUrlMatchResult(url, m); + } + + m = S3EXPRESS_CONTROL_DS.matcher(url); + if (m.find()) { + return new ControlPlaneUrlMatchResult(url, m); + } + + m = S3EXPRESS_CONTROL_PLAIN.matcher(url); + if (m.find()) { + return new ControlPlaneUrlMatchResult(url, m); + } + + // Next, try bucket endpoint patterns (with AZ) + m = S3EXPRESS_FIPS_DS.matcher(url); + if (m.find()) { + return new BucketUrlMatchResult(url, m); + } + + m = S3EXPRESS_FIPS.matcher(url); + if (m.find()) { + return new BucketUrlMatchResult(url, m); + } + + m = S3EXPRESS_DS.matcher(url); + if (m.find()) { + return new BucketUrlMatchResult(url, m); + } + + m = S3EXPRESS_PLAIN.matcher(url); + if (m.find()) { + return new BucketUrlMatchResult(url, m); + } + + return null; + } + + /** + * Result of matching an S3Express URL pattern. + */ + private abstract static class UrlMatchResult { + protected final String prefix; + + UrlMatchResult(String prefix) { + this.prefix = prefix; + } + + abstract String rewriteUrl(); + } + + /** + * Match result for bucket endpoints (with AZ): {prefix}s3express{fips}-{AZ}{ds}.{region} + */ + private static final class BucketUrlMatchResult extends UrlMatchResult { + private final String s3express; + private final String az; + private final String regionSuffix; + + BucketUrlMatchResult(String url, Matcher m) { + super(url.substring(0, m.start())); + this.s3express = m.group(1); + this.az = m.group(2); + this.regionSuffix = m.group(3); + } + + @Override + String rewriteUrl() { + return String.format("%s%s{%s}-%s{%s}.%s", prefix, s3express, VAR_FIPS, az, VAR_DS, regionSuffix); + } + } + + /** + * Match result for control plane endpoints (no AZ): {prefix}s3express-control{fips}{ds}.{region} + */ + private static final class ControlPlaneUrlMatchResult extends UrlMatchResult { + private final String s3expressControl; + private final String regionSuffix; + + ControlPlaneUrlMatchResult(String url, Matcher m) { + super(url.substring(0, m.start())); + this.s3expressControl = m.group(1); + this.regionSuffix = m.group(2); + } + + @Override + String rewriteUrl() { + return String.format("%s%s{%s}{%s}.%s", prefix, s3expressControl, VAR_FIPS, VAR_DS, regionSuffix); + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java index 9e0dfb0fd39..2dda13db308 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/CoreExtension.java @@ -11,6 +11,7 @@ import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsSet; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.IsValidHostLabel; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Not; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.ParseUrl; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Split; @@ -43,6 +44,7 @@ public List getLibraryFunctions() { Split.getDefinition(), StringEquals.getDefinition(), Substring.getDefinition(), + Ite.getDefinition(), UriEncode.getDefinition()); } } diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java index 6e8d70a8771..efc82026a30 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/evaluation/RuleEvaluator.java @@ -215,6 +215,12 @@ public Value visitStringEquals(Expression left, Expression right) { .equals(right.accept(this).expectStringValue())); } + @Override + public Value visitIte(Expression condition, Expression trueValue, Expression falseValue) { + boolean cond = condition.accept(this).expectBooleanValue().getValue(); + return cond ? trueValue.accept(this) : falseValue.accept(this); + } + @Override public Value visitGetAttr(GetAttr getAttr) { return getAttr.evaluate(getAttr.getTarget().accept(this)); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java index 1557b529b52..b4bbc93868f 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/ExpressionVisitor.java @@ -4,10 +4,12 @@ */ package software.amazon.smithy.rulesengine.language.syntax.expressions; +import java.util.Arrays; import java.util.List; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.FunctionDefinition; import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.GetAttr; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; import software.amazon.smithy.utils.SmithyUnstableApi; @@ -86,6 +88,18 @@ default R visitCoalesce(List expressions) { */ R visitStringEquals(Expression left, Expression right); + /** + * Visits an if-then-else (ITE) function. + * + * @param condition the boolean condition expression. + * @param trueValue the value if condition is true. + * @param falseValue the value if condition is false. + * @return the value from the visitor. + */ + default R visitIte(Expression condition, Expression trueValue, Expression falseValue) { + return visitLibraryFunction(Ite.getDefinition(), Arrays.asList(condition, trueValue, falseValue)); + } + /** * Visits a library function. * @@ -138,6 +152,11 @@ public R visitStringEquals(Expression left, Expression right) { return getDefault(); } + @Override + public R visitIte(Expression condition, Expression trueValue, Expression falseValue) { + return getDefault(); + } + @Override public R visitLibraryFunction(FunctionDefinition fn, List args) { return getDefault(); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java new file mode 100644 index 00000000000..30d383cbcd9 --- /dev/null +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java @@ -0,0 +1,174 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.syntax.expressions.functions; + +import java.util.Arrays; +import java.util.List; +import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.evaluation.value.Value; +import software.amazon.smithy.rulesengine.language.syntax.ToExpression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.ExpressionVisitor; +import software.amazon.smithy.utils.SmithyUnstableApi; + +/** + * An if-then-else (ITE) function that returns one of two values based on a boolean condition. + * + *

This function is critical for avoiding SSA (Static Single Assignment) fragmentation in BDD compilation. + * By computing conditional values atomically without branching, it prevents the graph explosion that occurs when + * boolean flags like UseFips or UseDualStack create divergent paths with distinct variable identities. + * + *

Semantics: {@code ite(condition, trueValue, falseValue)} + *

    + *
  • If condition is true, returns trueValue
  • + *
  • If condition is false, returns falseValue
  • + *
  • The condition must be a non-optional boolean (use coalesce to provide a default if needed)
  • + *
+ * + *

Type checking rules (least upper bound of nullability): + *

    + *
  • {@code ite(Boolean, T, T) => T} - both non-optional, result is non-optional
  • + *
  • {@code ite(Boolean, T, Optional) => Optional} - any optional makes result optional
  • + *
  • {@code ite(Boolean, Optional, T) => Optional} - any optional makes result optional
  • + *
  • {@code ite(Boolean, Optional, Optional) => Optional} - both optional, result is optional
  • + *
+ * + *

Available since: rules engine 1.1. + */ +@SmithyUnstableApi +public final class Ite extends LibraryFunction { + public static final String ID = "ite"; + private static final Definition DEFINITION = new Definition(); + + private Ite(FunctionNode functionNode) { + super(DEFINITION, functionNode); + } + + /** + * Gets the {@link FunctionDefinition} implementation. + * + * @return the function definition. + */ + public static Definition getDefinition() { + return DEFINITION; + } + + /** + * Creates a {@link Ite} function from the given expressions. + * + * @param condition the boolean condition to evaluate + * @param trueValue the value to return if condition is true + * @param falseValue the value to return if condition is false + * @return The resulting {@link Ite} function. + */ + public static Ite ofExpressions(ToExpression condition, ToExpression trueValue, ToExpression falseValue) { + return DEFINITION.createFunction(FunctionNode.ofExpressions(ID, condition, trueValue, falseValue)); + } + + /** + * Creates a {@link Ite} function with a reference condition and string values. + * + * @param conditionRef the reference to a boolean parameter + * @param trueValue the string value if condition is true + * @param falseValue the string value if condition is false + * @return The resulting {@link Ite} function. + */ + public static Ite ofStrings(ToExpression conditionRef, String trueValue, String falseValue) { + return ofExpressions(conditionRef, Expression.of(trueValue), Expression.of(falseValue)); + } + + @Override + public RulesVersion availableSince() { + return RulesVersion.V1_1; + } + + @Override + public R accept(ExpressionVisitor visitor) { + return visitor.visitIte(getArguments().get(0), getArguments().get(1), getArguments().get(2)); + } + + @Override + public Type typeCheck(Scope scope) { + List args = getArguments(); + if (args.size() != 3) { + throw new IllegalArgumentException("ITE requires exactly 3 arguments, got " + args.size()); + } + + // Check condition is a boolean (non-optional) + Type conditionType = args.get(0).typeCheck(scope); + if (!conditionType.equals(Type.booleanType())) { + throw new IllegalArgumentException(String.format( + "ITE condition must be a non-optional Boolean, got %s. " + + "Use coalesce to provide a default for optional booleans.", + conditionType)); + } + + // Get trueValue and falseValue types + Type trueType = args.get(1).typeCheck(scope); + Type falseType = args.get(2).typeCheck(scope); + + // Extract base types (unwrap Optional if present) + Type trueBaseType = getInnerType(trueType); + Type falseBaseType = getInnerType(falseType); + + // Base types must match + if (!trueBaseType.equals(falseBaseType)) { + throw new IllegalArgumentException(String.format( + "ITE branches must have the same base type: true branch is %s, false branch is %s", + trueBaseType, + falseBaseType)); + } + + // Result is optional if EITHER branch is optional (least upper bound) + boolean resultIsOptional = (trueType instanceof OptionalType) || (falseType instanceof OptionalType); + return resultIsOptional ? Type.optionalType(trueBaseType) : trueBaseType; + } + + private static Type getInnerType(Type t) { + return (t instanceof OptionalType) ? ((OptionalType) t).inner() : t; + } + + /** + * A {@link FunctionDefinition} for the {@link Ite} function. + */ + public static final class Definition implements FunctionDefinition { + private Definition() {} + + @Override + public String getId() { + return ID; + } + + @Override + public List getArguments() { + // Actual type checking is done in typeCheck override + return Arrays.asList(Type.booleanType(), Type.anyType(), Type.anyType()); + } + + @Override + public Type getReturnType() { + // Actual return type is computed in typeCheck override + return Type.anyType(); + } + + @Override + public Value evaluate(List arguments) { + throw new UnsupportedOperationException("ITE evaluation is handled by ExpressionVisitor"); + } + + @Override + public Ite createFunction(FunctionNode functionNode) { + return new Ite(functionNode); + } + + @Override + public int getCost() { + return 10; + } + } +} diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java index 11dbcc8b119..ea80fb56fbb 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/CostOptimization.java @@ -293,7 +293,7 @@ public static final class Builder implements SmithyBuilder { private Cfg cfg; private ConditionCostModel costModel; private ToDoubleFunction trueProbability; - private double maxAllowedGrowth = 0.1; + private double maxAllowedGrowth = 0.08; private int maxRounds = 30; private int topK = 50; @@ -333,7 +333,7 @@ public Builder trueProbability(ToDoubleFunction trueProbability) { } /** - * Sets the maximum allowed node growth as a fraction (default 0.1 or 10%). + * Sets the maximum allowed node growth as a fraction (default 0.08 or 8%). * * @param maxAllowedGrowth maximum growth (0.0 = no growth, 0.1 = 10% growth) * @return the builder diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java index 8f0d6686b9c..eda9157d1d0 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -59,7 +59,7 @@ public final class SiftingOptimization implements Function '%s' (different base names) for: %s", + varName, + globalVar, + canonical)); + } else if (!wouldCauseShadowing(globalVar, path, ancestorVars)) { variableRenameMap.put(varName, globalVar); consolidatedCount++; LOGGER.info(String.format("Consolidating '%s' -> '%s' for: %s", @@ -177,6 +184,42 @@ private void discoverBindingsInRule( } } + /** + * Checks if two variable names have the same base name. + * For SSA-style variables like "foo_1" and "foo_2", the base name is "foo". + * Variables without SSA suffix (like "s3e_fips" and "s3e_ds") are considered + * to have their full name as the base. + */ + private boolean hasSameBaseName(String var1, String var2) { + String base1 = getSsaBaseName(var1); + String base2 = getSsaBaseName(var2); + return base1.equals(base2); + } + + /** + * Extracts the SSA base name from a variable. + * If the variable ends with _N (where N is a number), strips the suffix. + * Otherwise returns the full name. + */ + private String getSsaBaseName(String varName) { + int lastUnderscore = varName.lastIndexOf('_'); + if (lastUnderscore > 0 && lastUnderscore < varName.length() - 1) { + String suffix = varName.substring(lastUnderscore + 1); + // Check if suffix is all digits + boolean allDigits = true; + for (int i = 0; i < suffix.length(); i++) { + if (!Character.isDigit(suffix.charAt(i))) { + allDigits = false; + break; + } + } + if (allDigits) { + return varName.substring(0, lastUnderscore); + } + } + return varName; + } + private boolean wouldCauseShadowing(String varName, String currentPath, Set ancestorVars) { // Check if using this variable name would shadow an ancestor variable if (ancestorVars.contains(varName)) { diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java index 9cd9d627c36..619ca5994ee 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java @@ -206,6 +206,16 @@ public static EndpointBddTrait fromNode(Node node) { results.add(NoMatchRule.INSTANCE); // Always add no-match at index 0 results.addAll(serializedResults); + // Validate that results have no conditions (all conditions are hoisted into the BDD) + for (int i = 1; i < results.size(); i++) { + Rule rule = results.get(i); + if (!rule.getConditions().isEmpty()) { + throw new IllegalArgumentException( + "BDD result at index " + i + " has conditions, but BDD results must not have conditions. " + + "All conditions should be hoisted into the BDD decision structure."); + } + } + String nodesBase64 = obj.expectStringMember("nodes").getValue(); int nodeCount = obj.expectNumberMember("nodeCount").getValue().intValue(); int rootRef = obj.expectNumberMember("root").getValue().intValue(); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java index c79ab5fbeb8..62e06b43742 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/validators/RuleSetAuthSchemesValidator.java @@ -124,7 +124,7 @@ private String validateAuthSchemeName( FromSourceLocation sourceLocation ) { Literal nameLiteral = authScheme.get(NAME); - if (nameLiteral == null) { + if (nameLiteral == null || nameLiteral.asStringLiteral().isEmpty()) { events.add(error(service, sourceLocation, String.format( @@ -133,13 +133,14 @@ private String validateAuthSchemeName( return null; } - String name = nameLiteral.asStringLiteral().map(s -> s.expectLiteral()).orElse(null); + // Try to get the name as a literal string. If the template contains variables + // (e.g., from branchless transforms like "{s3e_auth}"), we can't statically validate. + String name = nameLiteral.asStringLiteral() + .filter(t -> t.isStatic()) + .map(t -> t.expectLiteral()) + .orElse(null); if (name == null) { - events.add(error(service, - sourceLocation, - String.format( - "Expected `authSchemes` to have a `name` key with a string value but it did not: `%s`", - authScheme))); + // String literal with template variables - skip static validation return null; } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java new file mode 100644 index 00000000000..580fa7168e6 --- /dev/null +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java @@ -0,0 +1,234 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package software.amazon.smithy.rulesengine.language.syntax.functions; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Ite; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; + +public class IteTest { + + @Test + void testIteBothBranchesNonOptionalString() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Literal.of("-fips"); + Expression falseValue = Literal.of(""); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + // Both non-optional String => non-optional String + assertEquals(Type.stringType(), resultType); + } + + @Test + void testIteBothBranchesNonOptionalInteger() { + Expression condition = Expression.getReference(Identifier.of("useNewValue")); + Expression trueValue = Literal.of(100); + Expression falseValue = Literal.of(0); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("useNewValue", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + assertEquals(Type.integerType(), resultType); + } + + @Test + void testIteTrueBranchOptional() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeValue")); + Expression falseValue = Literal.of("default"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeValue", Type.optionalType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + // True branch optional => result is optional + assertEquals(Type.optionalType(Type.stringType()), resultType); + } + + @Test + void testIteFalseBranchOptional() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Literal.of("value"); + Expression falseValue = Expression.getReference(Identifier.of("maybeDefault")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeDefault", Type.optionalType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + // False branch optional => result is optional + assertEquals(Type.optionalType(Type.stringType()), resultType); + } + + @Test + void testIteBothBranchesOptional() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybe1")); + Expression falseValue = Expression.getReference(Identifier.of("maybe2")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybe1", Type.optionalType(Type.stringType())); + scope.insert("maybe2", Type.optionalType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + // Both optional => result is optional + assertEquals(Type.optionalType(Type.stringType()), resultType); + } + + @Test + void testIteWithOfStringsHelper() { + Expression condition = Expression.getReference(Identifier.of("UseFIPS")); + Ite ite = Ite.ofStrings(condition, "-fips", ""); + + Scope scope = new Scope<>(); + scope.insert("UseFIPS", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + // Both literal strings => non-optional String + assertEquals(Type.stringType(), resultType); + } + + @Test + void testIteTypeMismatchBetweenBranches() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Literal.of("string"); + Expression falseValue = Literal.of(42); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("same base type")); + assertTrue(ex.getMessage().contains("true branch")); + assertTrue(ex.getMessage().contains("false branch")); + } + + @Test + void testIteConditionMustBeBoolean() { + Expression condition = Literal.of("not a boolean"); + Expression trueValue = Literal.of("yes"); + Expression falseValue = Literal.of("no"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("non-optional Boolean")); + } + + @Test + void testIteConditionCannotBeOptionalBoolean() { + Expression condition = Expression.getReference(Identifier.of("maybeFlag")); + Expression trueValue = Literal.of("yes"); + Expression falseValue = Literal.of("no"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("maybeFlag", Type.optionalType(Type.booleanType())); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("non-optional Boolean")); + assertTrue(ex.getMessage().contains("coalesce")); + } + + @Test + void testIteWithArrayTypes() { + Expression condition = Expression.getReference(Identifier.of("useFirst")); + Expression trueValue = Expression.getReference(Identifier.of("array1")); + Expression falseValue = Expression.getReference(Identifier.of("array2")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("useFirst", Type.booleanType()); + scope.insert("array1", Type.arrayType(Type.stringType())); + scope.insert("array2", Type.arrayType(Type.stringType())); + + Type resultType = ite.typeCheck(scope); + + assertEquals(Type.arrayType(Type.stringType()), resultType); + } + + @Test + void testIteWithOptionalArrayType() { + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeArray")); + Expression falseValue = Expression.getReference(Identifier.of("definiteArray")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeArray", Type.optionalType(Type.arrayType(Type.integerType()))); + scope.insert("definiteArray", Type.arrayType(Type.integerType())); + + Type resultType = ite.typeCheck(scope); + + // One optional array => result is optional array + assertEquals(Type.optionalType(Type.arrayType(Type.integerType())), resultType); + } + + @Test + void testIteWithBooleanValues() { + Expression condition = Expression.getReference(Identifier.of("invertFlag")); + Expression trueValue = Literal.of(false); + Expression falseValue = Literal.of(true); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("invertFlag", Type.booleanType()); + + Type resultType = ite.typeCheck(scope); + + assertEquals(Type.booleanType(), resultType); + } + + @Test + void testIteTypeMismatchWithOptionalUnwrapping() { + // Even with optional wrapping, base types must match + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeString")); + Expression falseValue = Expression.getReference(Identifier.of("maybeInt")); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeString", Type.optionalType(Type.stringType())); + scope.insert("maybeInt", Type.optionalType(Type.integerType())); + + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + assertTrue(ex.getMessage().contains("same base type")); + } + + @Test + void testIteReturnsCorrectId() { + assertEquals("ite", Ite.ID); + assertEquals("ite", Ite.getDefinition().getId()); + } +} diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.errors b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.errors new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.errors @@ -0,0 +1 @@ + diff --git a/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.smithy b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.smithy new file mode 100644 index 00000000000..75a4d4d050c --- /dev/null +++ b/smithy-rules-engine/src/test/resources/software/amazon/smithy/rulesengine/language/errorfiles/valid/ite-basic.smithy @@ -0,0 +1,80 @@ +$version: "2.0" + +namespace example + +use smithy.rules#clientContextParams +use smithy.rules#endpointRuleSet +use smithy.rules#endpointTests + +@clientContextParams( + useFips: {type: "boolean", documentation: "Use FIPS endpoints"} +) +@endpointRuleSet({ + version: "1.1", + parameters: { + useFips: { + type: "boolean", + documentation: "Use FIPS endpoints", + default: false, + required: true + } + }, + rules: [ + { + "documentation": "Use ite to select endpoint suffix" + "conditions": [ + { + "fn": "ite" + "argv": [{"ref": "useFips"}, "-fips", ""] + "assign": "suffix" + } + ] + "endpoint": { + "url": "https://example{suffix}.com" + } + "type": "endpoint" + } + ] +}) +@endpointTests({ + "version": "1.0", + "testCases": [ + { + "documentation": "When useFips is true, returns trueValue" + "params": { + "useFips": true + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example-fips.com" + } + } + } + { + "documentation": "When useFips is false, returns falseValue" + "params": { + "useFips": false + } + "operationInputs": [{ + "operationName": "GetThing" + }], + "expect": { + "endpoint": { + "url": "https://example.com" + } + } + } + ] +}) +@suppress(["UnstableTrait.smithy"]) +service FizzBuzz { + version: "2022-01-01", + operations: [GetThing] +} + +operation GetThing { + input := {} +} From cffa45f14ef8b6a1b4c99da96869e9aae3a74e64 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Thu, 18 Dec 2025 10:27:47 -0600 Subject: [PATCH 3/5] Run type checking on BDD so type() works --- .../expressions/functions/Coalesce.java | 7 +-- .../syntax/expressions/functions/Ite.java | 9 ++-- .../language/syntax/rule/NoMatchRule.java | 2 +- .../rulesengine/traits/EndpointBddTrait.java | 18 ++++++- .../syntax/functions/CoalesceTest.java | 28 ++++++++-- .../language/syntax/functions/IteTest.java | 52 +++++++++++++++++-- .../rulesengine/traits/BddTraitTest.java | 48 ++++++++++++++++- 7 files changed, 146 insertions(+), 18 deletions(-) diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java index 931e5d9f9dd..039d461350e 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Coalesce.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Optional; import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.error.InnerParseError; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; @@ -81,10 +82,10 @@ public R accept(ExpressionVisitor visitor) { } @Override - public Type typeCheck(Scope scope) { + protected Type typeCheckLocal(Scope scope) throws InnerParseError { List args = getArguments(); if (args.size() < 2) { - throw new IllegalArgumentException("Coalesce requires at least 2 arguments, got " + args.size()); + throw new InnerParseError("Coalesce requires at least 2 arguments, got " + args.size()); } // Get the first argument's type as the baseline @@ -98,7 +99,7 @@ public Type typeCheck(Scope scope) { Type innerType = getInnerType(argType); if (!innerType.equals(baseInnerType)) { - throw new IllegalArgumentException(String.format( + throw new InnerParseError(String.format( "Type mismatch in coalesce at argument %d: expected %s but got %s", i + 1, baseInnerType, diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java index 30d383cbcd9..90acc71da01 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/expressions/functions/Ite.java @@ -7,6 +7,7 @@ import java.util.Arrays; import java.util.List; import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.error.InnerParseError; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.OptionalType; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; @@ -93,16 +94,16 @@ public R accept(ExpressionVisitor visitor) { } @Override - public Type typeCheck(Scope scope) { + protected Type typeCheckLocal(Scope scope) throws InnerParseError { List args = getArguments(); if (args.size() != 3) { - throw new IllegalArgumentException("ITE requires exactly 3 arguments, got " + args.size()); + throw new InnerParseError("ITE requires exactly 3 arguments, got " + args.size()); } // Check condition is a boolean (non-optional) Type conditionType = args.get(0).typeCheck(scope); if (!conditionType.equals(Type.booleanType())) { - throw new IllegalArgumentException(String.format( + throw new InnerParseError(String.format( "ITE condition must be a non-optional Boolean, got %s. " + "Use coalesce to provide a default for optional booleans.", conditionType)); @@ -118,7 +119,7 @@ public Type typeCheck(Scope scope) { // Base types must match if (!trueBaseType.equals(falseBaseType)) { - throw new IllegalArgumentException(String.format( + throw new InnerParseError(String.format( "ITE branches must have the same base type: true branch is %s, false branch is %s", trueBaseType, falseBaseType)); diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java index d7c76f7feec..be58bb2415d 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/language/syntax/rule/NoMatchRule.java @@ -28,7 +28,7 @@ public T accept(RuleValueVisitor visitor) { @Override protected Type typecheckValue(Scope scope) { - throw new UnsupportedOperationException("NO_MATCH is a sentinel"); + return Type.anyType(); } @Override diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java index 619ca5994ee..52ffd2b4f36 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/traits/EndpointBddTrait.java @@ -23,6 +23,8 @@ import software.amazon.smithy.model.traits.AbstractTraitBuilder; import software.amazon.smithy.model.traits.Trait; import software.amazon.smithy.rulesengine.language.RulesVersion; +import software.amazon.smithy.rulesengine.language.evaluation.Scope; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.NoMatchRule; @@ -360,7 +362,21 @@ public Builder bdd(Bdd bdd) { @Override public EndpointBddTrait build() { - return new EndpointBddTrait(this); + EndpointBddTrait trait = new EndpointBddTrait(this); + + // Type-check conditions and results so expression.type() works. Note that using a shared scope across + // each check is ok, because BDD evaluation always runs conditions in a fixed order and could in theory + // try every condition for a single path to a result. + Scope scope = new Scope<>(); + trait.getParameters().writeToScope(scope); + for (Condition condition : trait.getConditions()) { + condition.typeCheck(scope); + } + for (Rule result : trait.getResults()) { + result.typeCheck(scope); + } + + return trait; } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java index bf6ac4bb9da..7dd1d118883 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/CoalesceTest.java @@ -10,6 +10,7 @@ import java.util.Arrays; import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.error.RuleError; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.Identifier; @@ -135,7 +136,7 @@ void testCoalesceWithIncompatibleTypes() { Scope scope = new Scope<>(); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> coalesce.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> coalesce.typeCheck(scope)); assertTrue(ex.getMessage().contains("Type mismatch in coalesce")); assertTrue(ex.getMessage().contains("argument 2")); } @@ -151,7 +152,7 @@ void testCoalesceWithIncompatibleTypesInMiddle() { Scope scope = new Scope<>(); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> coalesce.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> coalesce.typeCheck(scope)); assertTrue(ex.getMessage().contains("Type mismatch in coalesce")); assertTrue(ex.getMessage().contains("argument 3")); } @@ -160,8 +161,7 @@ void testCoalesceWithIncompatibleTypesInMiddle() { void testCoalesceWithLessThanTwoArguments() { Expression single = Literal.of("only"); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, - () -> Coalesce.ofExpressions(single).typeCheck(new Scope<>())); + RuleError ex = assertThrows(RuleError.class, () -> Coalesce.ofExpressions(single).typeCheck(new Scope<>())); assertTrue(ex.getMessage().contains("at least 2 arguments")); } @@ -215,4 +215,24 @@ void testCoalesceWithBooleanTypes() { assertEquals(Type.booleanType(), resultType); } + + @Test + void testTypeMethodReturnsInferredTypeAfterTypeCheck() { + // Verify that type() returns the correct inferred type after typeCheck() + Expression optional1 = Expression.getReference(Identifier.of("maybeValue1")); + Expression optional2 = Expression.getReference(Identifier.of("maybeValue2")); + Expression definite = Literal.of("default"); + Coalesce coalesce = Coalesce.ofExpressions(optional1, optional2, definite); + + Scope scope = new Scope<>(); + scope.insert("maybeValue1", Type.optionalType(Type.stringType())); + scope.insert("maybeValue2", Type.optionalType(Type.stringType())); + + // Call typeCheck to cache the type + coalesce.typeCheck(scope); + + // Now type() should return the inferred type (non-optional since last arg is definite) + Type cachedType = coalesce.type(); + assertEquals(Type.stringType(), cachedType); + } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java index 580fa7168e6..5c57ed7dc6a 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/language/syntax/functions/IteTest.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import org.junit.jupiter.api.Test; +import software.amazon.smithy.rulesengine.language.error.RuleError; import software.amazon.smithy.rulesengine.language.evaluation.Scope; import software.amazon.smithy.rulesengine.language.evaluation.type.Type; import software.amazon.smithy.rulesengine.language.syntax.Identifier; @@ -125,7 +126,7 @@ void testIteTypeMismatchBetweenBranches() { Scope scope = new Scope<>(); scope.insert("flag", Type.booleanType()); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); assertTrue(ex.getMessage().contains("same base type")); assertTrue(ex.getMessage().contains("true branch")); assertTrue(ex.getMessage().contains("false branch")); @@ -140,7 +141,7 @@ void testIteConditionMustBeBoolean() { Scope scope = new Scope<>(); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); assertTrue(ex.getMessage().contains("non-optional Boolean")); } @@ -154,7 +155,7 @@ void testIteConditionCannotBeOptionalBoolean() { Scope scope = new Scope<>(); scope.insert("maybeFlag", Type.optionalType(Type.booleanType())); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); assertTrue(ex.getMessage().contains("non-optional Boolean")); assertTrue(ex.getMessage().contains("coalesce")); } @@ -222,7 +223,7 @@ void testIteTypeMismatchWithOptionalUnwrapping() { scope.insert("maybeString", Type.optionalType(Type.stringType())); scope.insert("maybeInt", Type.optionalType(Type.integerType())); - IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, () -> ite.typeCheck(scope)); + RuleError ex = assertThrows(RuleError.class, () -> ite.typeCheck(scope)); assertTrue(ex.getMessage().contains("same base type")); } @@ -231,4 +232,47 @@ void testIteReturnsCorrectId() { assertEquals("ite", Ite.ID); assertEquals("ite", Ite.getDefinition().getId()); } + + @Test + void testTypeMethodReturnsInferredTypeAfterTypeCheck() { + // Verify that type() returns the correct inferred type after typeCheck() + Expression condition = Expression.getReference(Identifier.of("flag")); + Expression trueValue = Expression.getReference(Identifier.of("maybeValue")); + Expression falseValue = Literal.of("default"); + Ite ite = Ite.ofExpressions(condition, trueValue, falseValue); + + Scope scope = new Scope<>(); + scope.insert("flag", Type.booleanType()); + scope.insert("maybeValue", Type.optionalType(Type.stringType())); + + // Call typeCheck to cache the type + ite.typeCheck(scope); + + // Now type() should return the inferred type + Type cachedType = ite.type(); + assertEquals(Type.optionalType(Type.stringType()), cachedType); + } + + @Test + void testNestedIteTypeInference() { + // Test that nested Ite expressions have correct type inference + Expression outerCondition = Expression.getReference(Identifier.of("outer")); + Expression innerCondition = Expression.getReference(Identifier.of("inner")); + + // Inner ITE: ite(inner, "a", "b") => String + Ite innerIte = Ite.ofExpressions(innerCondition, Literal.of("a"), Literal.of("b")); + + // Outer ITE: ite(outer, innerIte, "c") => String + Ite outerIte = Ite.ofExpressions(outerCondition, innerIte, Literal.of("c")); + + Scope scope = new Scope<>(); + scope.insert("outer", Type.booleanType()); + scope.insert("inner", Type.booleanType()); + + outerIte.typeCheck(scope); + + // Both inner and outer should have String type + assertEquals(Type.stringType(), innerIte.type()); + assertEquals(Type.stringType(), outerIte.type()); + } } diff --git a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java index 88c2aab778e..a4defa53c43 100644 --- a/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java +++ b/smithy-rules-engine/src/test/java/software/amazon/smithy/rulesengine/traits/BddTraitTest.java @@ -12,6 +12,13 @@ import java.util.List; import org.junit.jupiter.api.Test; import software.amazon.smithy.model.node.Node; +import software.amazon.smithy.rulesengine.language.evaluation.type.Type; +import software.amazon.smithy.rulesengine.language.syntax.Identifier; +import software.amazon.smithy.rulesengine.language.syntax.expressions.Expression; +import software.amazon.smithy.rulesengine.language.syntax.expressions.functions.Coalesce; +import software.amazon.smithy.rulesengine.language.syntax.expressions.literal.Literal; +import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter; +import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType; import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameters; import software.amazon.smithy.rulesengine.language.syntax.rule.Condition; import software.amazon.smithy.rulesengine.language.syntax.rule.EndpointRule; @@ -25,7 +32,11 @@ public class BddTraitTest { @Test void testBddTraitSerialization() { // Create a BddTrait with full context - Parameters params = Parameters.builder().build(); + Parameter regionParam = Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build(); + Parameters params = Parameters.builder().addParameter(regionParam).build(); Condition cond = Condition.builder().fn(TestHelpers.isSet("Region")).build(); Rule endpoint = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); @@ -99,4 +110,39 @@ void testEmptyBddTrait() { assertEquals(1, trait.getResults().size()); assertEquals(-1, trait.getBdd().getRootRef()); // FALSE terminal } + + @Test + void testBuildTypeChecksExpressionsForCodegen() { + // Verify that after building an EndpointBddTrait, expression.type() works + // This is important for codegen to infer types without a scope + Parameter regionParam = Parameter.builder() + .name("Region") + .type(ParameterType.STRING) + .build(); + Parameters params = Parameters.builder().addParameter(regionParam).build(); + + // Create a condition with a coalesce that infers to String + Expression regionRef = Expression.getReference(Identifier.of("Region")); + Expression fallback = Literal.of("us-east-1"); + Coalesce coalesce = Coalesce.ofExpressions(regionRef, fallback); + Condition cond = Condition.builder().fn(coalesce).result(Identifier.of("resolvedRegion")).build(); + + Rule endpoint = EndpointRule.builder().endpoint(TestHelpers.endpoint("https://example.com")); + + List results = new ArrayList<>(); + results.add(NoMatchRule.INSTANCE); + results.add(endpoint); + + EndpointBddTrait trait = EndpointBddTrait.builder() + .parameters(params) + .conditions(ListUtils.of(cond)) + .results(results) + .bdd(createSimpleBdd()) + .build(); + + // After build(), type() should work on the coalesce expression + // Region is Optional, fallback is String, so result is String (non-optional) + Coalesce builtCoalesce = (Coalesce) trait.getConditions().get(0).getFunction(); + assertEquals(Type.stringType(), builtCoalesce.type()); + } } From 3f2c363a46b72c767cfb3d1e2df5f5b5e8b1bb30 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Fri, 19 Dec 2025 10:31:03 -0600 Subject: [PATCH 4/5] Improve some loops --- .../logic/bdd/SiftingOptimization.java | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java index eda9157d1d0..b779eceb4f3 100644 --- a/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java +++ b/smithy-rules-engine/src/main/java/software/amazon/smithy/rulesengine/logic/bdd/SiftingOptimization.java @@ -249,9 +249,12 @@ private void runBlockMoves(State state) { } LOGGER.info("Running block moves"); - List> blocks = findDependencyBlocks(state.orderView).stream() - .filter(b -> b.size() >= 2 && b.size() <= 5) - .collect(Collectors.toList()); + List> blocks = new ArrayList<>(); + for (List b : findDependencyBlocks(state.orderView)) { + if (b.size() >= 2 && b.size() <= 5) { + blocks.add(b); + } + } for (List block : blocks) { PassContext ctx = new PassContext(state, dependencyGraph); @@ -464,7 +467,13 @@ private Result findBestPosition(List positions, PassContext ctx, int va } // Second pass: among min-size candidates, pick lowest cost - int minSize = candidates.stream().mapToInt(c -> c.size).min().orElse(Integer.MAX_VALUE); + int minSize = Integer.MAX_VALUE; + for (Result c : candidates) { + if (c.size < minSize) { + minSize = c.size; + } + } + Result best = null; for (Result c : candidates) { if (c.size == minSize) { From 789c130c8cbb287e8eb4b0900f6a2a739b651747 Mon Sep 17 00:00:00 2001 From: Michael Dowling Date: Fri, 19 Dec 2025 10:56:39 -0600 Subject: [PATCH 5/5] Attempt to fix the windows build --- settings.gradle.kts | 2 ++ smithy-aws-endpoints/build.gradle.kts | 14 +++++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/settings.gradle.kts b/settings.gradle.kts index f3c9eba093a..bffb67b89a7 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -6,6 +6,8 @@ pluginManagement { } } + + rootProject.name = "smithy" include(":smithy-aws-iam-traits") diff --git a/smithy-aws-endpoints/build.gradle.kts b/smithy-aws-endpoints/build.gradle.kts index 35c213e1e70..559731a51a3 100644 --- a/smithy-aws-endpoints/build.gradle.kts +++ b/smithy-aws-endpoints/build.gradle.kts @@ -26,7 +26,7 @@ dependencies { } // Integration test source set for tests that require the S3 model -// These tests require JDK 17+ due to the S3 model dependency +// These tests require JDK 21+ due to the S3 model dependency sourceSets { create("it") { compileClasspath += sourceSets["main"].output + sourceSets["test"].output @@ -38,15 +38,15 @@ configurations["itImplementation"].extendsFrom(configurations["testImplementatio configurations["itRuntimeOnly"].extendsFrom(configurations["testRuntimeOnly"]) configurations["itImplementation"].extendsFrom(s3Model) -// Configure IT source set to compile with JDK 17 +// Configure IT source set to compile with JDK 21 tasks.named("compileItJava") { javaCompiler.set( javaToolchains.compilerFor { - languageVersion.set(JavaLanguageVersion.of(17)) + languageVersion.set(JavaLanguageVersion.of(21)) }, ) - sourceCompatibility = "17" - targetCompatibility = "17" + sourceCompatibility = "21" + targetCompatibility = "21" } val integrationTest by tasks.registering(Test::class) { @@ -57,10 +57,10 @@ val integrationTest by tasks.registering(Test::class) { dependsOn(tasks.jar) shouldRunAfter(tasks.test) - // Run with JDK 17 + // Run with JDK 21 javaLauncher.set( javaToolchains.launcherFor { - languageVersion.set(JavaLanguageVersion.of(17)) + languageVersion.set(JavaLanguageVersion.of(21)) }, ) }