Dynamic Inference Headers with Prediction Trie Integration #1483
Conversation
Introduced context management using DynamoPrefixContext to optimize KV cache by setting unique prefix IDs per workflow run. This includes adding lazy imports to avoid circular dependencies and ensuring the context is cleared in `finally` blocks to prevent leaks. Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Introduces comprehensive test cases to ensure `DynamoPrefixContext` is properly set, cleared, and associated with unique workflow run IDs during execution. This includes tests for normal operations, streaming results, error handling, and pre-set workflow run IDs, enhancing reliability and coverage for the Runner class. Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Design for a prediction system that provides Dynamo inference server with expected workload characteristics (remaining calls, inter-arrival time, output length) for each LLM call, enabling smarter routing. Key components: - PredictionTrie: hierarchical structure storing metrics at every path granularity - TrieBuilder: processes profiler traces into trie - Runtime lookup with graceful fallback to less specific matches - Header injection in dynamo_langchain LLM client Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Detailed 10-task TDD implementation plan: 1. Data models (PredictionMetrics, LLMCallPrediction, PredictionTrieNode) 2. Metrics accumulator for computing statistics 3. Trie builder from profiler traces 4. Trie lookup with fallback 5. JSON serialization 6. Runtime call tracker (contextvars) 7. Profiler integration (config + generation) 8. Dynamo header injection 9. LangChain integration 10. End-to-end test Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Add Pydantic models for the prediction trie: - PredictionMetrics: aggregated stats (mean, p50, p90, p95) - LLMCallPrediction: predictions for remaining calls, interarrival time, output tokens - PredictionTrieNode: trie node with children and predictions by call index Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Accumulates sample values and computes aggregated statistics (mean, p50, p90, p95) using linear interpolation for percentiles. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
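Linear-interpolation percentiles need no external dependency; a minimal sketch of the computation described (the actual `MetricsAccumulator` API may differ):

```python
def percentile(samples: list[float], q: float) -> float:
    """Return the q-th percentile (0-100) of samples using linear
    interpolation between closest ranks (numpy's default method)."""
    if not samples:
        raise ValueError("no samples")
    ordered = sorted(samples)
    # Fractional rank into the sorted samples
    rank = (q / 100.0) * (len(ordered) - 1)
    lo = int(rank)
    hi = min(lo + 1, len(ordered) - 1)
    frac = rank - lo
    return ordered[lo] + (ordered[hi] - ordered[lo]) * frac
```

For samples 1..10 this yields p50 = 5.5 and p90 = 9.1, which is why exact-equality assertions on p90 are fragile and `pytest.approx` is suggested further down.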
Builds prediction trie from profiler execution traces: - Extracts LLM call contexts (path, call index, remaining, interarrival, output tokens) - Aggregates metrics at every node along the path - Computes stats by call index and aggregated fallback Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Walks the trie to find best matching prediction: - Exact path + exact call_index (most specific) - Partial path + exact call_index - Falls back to aggregated predictions when call_index not found Signed-off-by: Claude <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
JSON serialization with metadata: - version, generated_at, workflow_name - Recursive node serialization/deserialization - Handles predictions_by_call_index int keys Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
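The int-key handling is needed because JSON object keys are always strings; a tiny sketch of the round-trip concern (helper names are illustrative):

```python
import json


def dump_predictions(preds: dict[int, float]) -> str:
    # json.dumps coerces int keys to strings on the way out
    return json.dumps(preds)


def load_predictions(raw: str) -> dict[int, float]:
    # Restore the int keys that serialization turned into strings
    return {int(k): v for k, v in json.loads(raw).items()}
```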
Context variable-based tracking of LLM call indices per function invocation. Thread/async-safe using contextvars. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
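A minimal sketch of contextvars-based call tracking (names are illustrative; the PR's `LLMCallTracker` API may differ). Each thread or async task gets its own view of the variable, so concurrent workflow runs do not interfere:

```python
import contextvars

# Per-context mapping of function_id -> number of LLM calls seen so far
_tracker: contextvars.ContextVar = contextvars.ContextVar("llm_call_tracker")


def increment(function_id: str) -> int:
    """Bump and return the 1-based call index for function_id,
    lazily initializing the per-context counter dict."""
    counts = _tracker.get(None)
    if counts is None:
        counts = {}
        _tracker.set(counts)
    counts[function_id] = counts.get(function_id, 0) + 1
    return counts[function_id]
```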
Add PredictionTrieConfig to ProfilerConfig with enable flag. ProfilerRunner now builds and saves prediction_trie.json when enabled. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Injects x-nat-remaining-llm-calls, x-nat-interarrival-ms, and x-nat-expected-output-tokens headers for server routing optimization. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Allows specifying a prediction_trie.json file path in workflow config. When set, predictions are looked up and injected as headers. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Validates complete flow: profiler traces -> trie generation -> lookup with different agents and call indices. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Design document for integrating the prediction trie with runtime workflow execution: - Add function_path_stack ContextVar for full ancestry tracking - Increment call tracker in IntermediateStepManager on LLM_START - Dynamic httpx hook for per-request prediction lookup - Fallback chain to root aggregates when no match found Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Detailed TDD implementation plan for integrating prediction trie lookups at runtime: - Task 1: Add function_path_stack ContextVar - Task 2: Track path in push_active_function - Task 3: Increment call tracker in IntermediateStepManager - Task 4: Create dynamic prediction hook - Task 5: Update httpx client creation - Task 6: Load trie in LangChain Dynamo client - Task 7: End-to-end integration test Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Add a new ContextVar to track the full function ancestry path as a list of function names. This will be used by the runtime prediction trie integration to perform prediction lookups using the full path (e.g., ["my_workflow", "react_agent", "tool"]). The implementation follows the existing pattern of active_span_id_stack, using a private ContextVar with None default and a property that lazily initializes to an empty list. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Update the push_active_function context manager to push/pop function names on the function_path_stack ContextVar. This enables tracking the complete ancestry of the currently executing function from root to leaf. Changes: - Push function name onto path stack when entering push_active_function - Pop function name using ContextVar.reset(token) when exiting - Add Context.function_path property that returns a copy of the path stack Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
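The push/pop-with-token pattern described can be sketched as below (a simplified stand-in for `push_active_function`, not the actual implementation); `ContextVar.reset(token)` guarantees the previous stack is restored even if the body raises:

```python
import contextvars
from contextlib import contextmanager

_path_stack: contextvars.ContextVar = contextvars.ContextVar("function_path_stack")


@contextmanager
def push_function(name: str):
    """Push a function name onto the ancestry path for the duration of the block."""
    stack = _path_stack.get([])
    # Set a new list rather than mutating in place, so earlier snapshots stay stable
    token = _path_stack.set(stack + [name])
    try:
        yield
    finally:
        _path_stack.reset(token)  # restore the previous stack on exit


def current_path() -> list[str]:
    """Return a copy of the current function ancestry path, root to leaf."""
    return list(_path_stack.get([]))
```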
Update IntermediateStepManager.push_intermediate_step() to increment the LLMCallTracker whenever an LLM_START event is pushed. This ensures call indices are tracked for all LLM frameworks (LangChain, LlamaIndex, etc.) since they all push events through this manager. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Create _create_dynamic_prediction_hook function that dynamically looks up predictions from the trie based on current context (function path + call index) and injects headers for Dynamo optimization. The hook: - Reads Context.function_path to get current ancestry - Reads LLMCallTracker.counts to get current call index - Looks up prediction in trie using trie_lookup.find(path, call_index) - Injects headers: x-nat-remaining-llm-calls, x-nat-interarrival-ms, x-nat-expected-output-tokens This is part of the dynamic inference headers feature for KV cache optimization with NVIDIA Dynamo. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
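The hook's shape can be sketched as follows; `lookup`, `get_path`, and `get_call_index` are hypothetical stand-ins for the trie lookup, `Context.function_path`, and the call tracker, and the hook signature matches the httpx event-hook convention of `hook(request)`:

```python
def make_prediction_hook(lookup, get_path, get_call_index):
    """Build a per-request hook that injects Dynamo prediction headers.

    All three collaborators are injected so the sketch stays self-contained:
    lookup(path, call_index) returns a prediction dict or None."""
    def hook(request):
        prediction = lookup(get_path(), get_call_index())
        if prediction is None:
            return  # no match at any trie level: leave the request untouched
        request.headers["x-nat-remaining-llm-calls"] = str(prediction["remaining_calls"])
        request.headers["x-nat-interarrival-ms"] = str(prediction["interarrival_ms"])
        request.headers["x-nat-expected-output-tokens"] = str(prediction["output_tokens"])
    return hook
```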
Update test expectation to include prediction_trie_path which was added to DynamoModelConfig.get_dynamo_field_names() in a previous task. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Update the function to accept an optional PredictionTrieLookup parameter. When provided, adds the dynamic prediction hook to the list of hooks, enabling runtime header injection based on trie predictions. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Validates the complete flow from function path tracking through header injection: - function_path_stack updates on push_active_function - IntermediateStepManager increments call tracker on LLM_START - Dynamic hook reads context and looks up predictions - Correct headers injected based on call index Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com> Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Design for two-phase Dynamo optimization workflow: - Phase 1: Profile with prediction_trie.enable to build trie - Phase 2: Run with prediction_trie_path for dynamic headers Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Add two-phase workflow for Dynamo optimization using prediction trie: Phase 1 (profiling): - Enable prediction_trie.enable in profile_rethinking_full_test.yml - Builds trie from profiled execution data Phase 2 (runtime): - New run_with_prediction_trie.yml config - Loads trie and injects dynamic headers per LLM call - Headers: x-nat-remaining-llm-calls, x-nat-interarrival-ms, x-nat-expected-output-tokens Documentation: - README_PREDICTION_TRIE.md with quick start, how it works, configuration reference, and troubleshooting Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Updated x-prefix-* headers to use categorical values (LOW/MEDIUM/HIGH) derived from prediction metrics. Introduced support for loading and handling prediction trie files for dynamic header overrides, ensuring consistent and contextually accurate LLM request annotations. Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Moved logic for setting Dynamo prefix ID into `DynamoPrefixContext.get` for better reusability and clarity. Removed redundant code from request header injection, ensuring consistent prefix generation and logging behavior. Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Introduced depth-aware prefix ID generation for more granular control of prefix IDs across nested function calls. Replaced the previous context variable approach with a depth mapping mechanism and added support for override prefixes. Updated relevant tests for clarity and alignment with the new depth-based behavior. Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
…dynamic-inference-headers
/ok to test a611564
Eliminated the setup and cleanup of DynamoPrefixContext from the Runner class as it is no longer required for KV cache optimization. This simplifies the workflow logic and reduces dependencies, ensuring cleaner and more maintainable code. Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
Actionable comments posted: 11
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tests/nat/llm/test_dynamo_llm.py (2)
259-387: Remove `@pytest.mark.asyncio` decorators from async tests. Async tests are automatically detected and run by pytest-asyncio when `asyncio_mode = "auto"` is configured (as it is in your `pyproject.toml`). The decorator is unnecessary clutter and violates coding guidelines. This applies to all 7 async tests in this range (lines 260, 283, 304, 331, 350, 369, 387).
🧹 Suggested diff (apply to all async tests in this file)
- @pytest.mark.asyncio
  async def test_hook_injects_headers(self):
252-257: Rename fixture to follow the naming convention specified in coding guidelines. Class-based fixtures should use the `fixture_` prefix and specify `name=` in the decorator. Per the coding guidelines: fixture functions should be named using the `fixture_` prefix with snake_case, and include a `name` argument in the decorator.
Suggested diff
- @pytest.fixture(autouse=True)
- def clean_context(self):
+ @pytest.fixture(name="clean_context", autouse=True)
+ def fixture_clean_context(self):
🤖 Fix all issues with AI agents
In `@examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md`:
- Line 1: Add the required SPDX Apache-2.0 header at the very top of the
README_PREDICTION_TRIE.md file (above the "# Prediction Trie Optimization for
Dynamo" title); insert the standard two-line header "SPDX-License-Identifier:
Apache-2.0" (and optionally the copyright holder line if project policy requires
it) so the markdown begins with the SPDX identifier before any other content.
- Around line 7-11: Replace the possessive forms "Dynamo's KV-aware routing" and
"Dynamo's Thompson Sampling router" in the README_PREDICTION_TRIE.md text with
non-possessive alternatives (e.g., "Dynamo KV-aware routing", "the Dynamo
KV-aware routing", or "the Dynamo Thompson Sampling router") so the doc avoids
using ’s for an inanimate object; update the two phrases in the paragraph
describing dynamic header injection and the router to use the chosen
non-possessive wording consistently.
- Around line 23-26: The fenced code block under the "Output location:" section
currently has no language specified; update that Markdown fence to include a
language token (e.g., change ``` to ```text) so the block becomes ```text and
contains the path
"outputs/dynamo_evals/rethinking_full_test_for_profiling/<job_id>/prediction_trie.json";
edit the README_PREDICTION_TRIE.md file and modify the fenced block shown in the
diff accordingly to satisfy markdownlint MD040.
In `@packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py`:
- Around line 226-238: The exception handler around load_prediction_trie should
log the full traceback instead of just the exception message; modify the except
Exception block in the LLM loading logic (where load_prediction_trie and
PredictionTrieLookup are used and llm_config.prediction_trie_path is referenced)
to call logger.exception(...) with a descriptive message so the stack trace is
captured while still allowing execution to continue.
In `@src/nat/llm/prediction_context.py`:
- Around line 1-8: Replace the abbreviated SPDX header at the top of
prediction_context.py with the full 14-line Apache-2.0 license header template
used across the repo: remove the short SPDX lines and insert the standard Apache
License, Version 2.0 header block (including copyright line, license notice, and
link to http://www.apache.org/licenses/LICENSE-2.0) as a contiguous header at
the very top of the file so the file matches other files in the PR.
In `@src/nat/profiler/prediction_trie/serialization.py`:
- Around line 1-25: This module is missing the required module-level docstring;
add a concise Google-style docstring immediately after the license header in
`serialization.py` that describes the module's purpose
(serializing/deserializing prediction trie structures) and mentions the key
public identifiers in backticks such as `LLMCallPrediction`,
`PredictionMetrics`, and `PredictionTrieNode`; ensure the docstring is a
top-level string literal (not a comment) and follows the license block so
linters and documentation tools will pick it up.
- Around line 53-66: load_prediction_trie currently ignores the saved version;
update it to read and validate the JSON's "version" field (the same version
emitted by save_prediction_trie) before deserializing: retrieve data["version"],
compare it against the module's expected version constant or literal used by
save_prediction_trie, and if missing or mismatched raise a clear error (e.g.,
ValueError) that aborts loading; only call _deserialize_node(data["root"]) after
the version check passes.
In `@src/nat/profiler/prediction_trie/trie_lookup.py`:
- Around line 1-18: Add a Google-style module docstring at the top of
`trie_lookup.py` (above the imports) consisting of a one-line summary ending
with a period and a short description that mentions the public entities (wrap
`LLMCallPrediction` and `PredictionTrieNode` in backticks). Keep it concise and
follow the project's docstring style (first line summary + optional short
paragraph), ensuring the module now has a proper docstring as required for
public modules.
In `@tests/nat/builder/test_function_path_stack.py`:
- Around line 1-6: This file currently only has SPDX identifiers at the top and
is missing the full Apache-2.0 license boilerplate; add the standard full Apache
2.0 header (the same multi-line boilerplate used in the other test files) at the
very top of the source before the existing imports (so it precedes the imports
of Context and ContextState and the module-level code in
test_function_path_stack.py); ensure the header matches the project's canonical
Apache-2.0 template including copyright holder and license text.
In `@tests/nat/llm/test_dynamo_prediction_headers.py`:
- Around line 4-6: The tests import and use
create_httpx_client_with_prediction_headers but currently catch a broad
Exception when the request is expected to fail; narrow the except clauses to
catch httpx.RequestError instead. Update the two failing-requests test blocks
(the except clauses around the client.request calls) to "except
httpx.RequestError as e" and import httpx.RequestError at the top so only
connection/request failures are handled, leaving other exceptions to surface.
In `@tests/nat/profiler/prediction_trie/test_metrics_accumulator.py`:
- Around line 20-29: The p90 assertion in
test_accumulator_add_multiple_samples() is using exact equality which can fail
due to floating-point precision; update the assertion to use pytest.approx for
metrics.p90 (similar to the existing p95 assertion) so that the test compares
metrics.p90 to pytest.approx(9.1) after calling
MetricsAccumulator().add_sample(...) and compute_metrics().
🧹 Nitpick comments (12)
tests/nat/llm/test_prediction_context.py (1)
24-30: Potential test isolation issue with `ContextVar` state. The `get_call_tracker()` function uses a module-level `ContextVar` that persists across test runs in the same process. If `test_tracker_context_variable` runs after other tests that call `get_call_tracker()`, the tracker may already exist with stale state. Consider resetting the `ContextVar` or creating a fresh context to ensure test isolation.
Proposed fix using contextvars.copy_context
+import contextvars
+
 from nat.llm.prediction_context import LLMCallTracker
 from nat.llm.prediction_context import get_call_tracker

+def test_tracker_context_variable():
+    """Test that get_call_tracker returns consistent tracker in a fresh context."""
+    def _inner():
+        tracker1 = get_call_tracker()
+        tracker1.increment("func-a")
+
+        tracker2 = get_call_tracker()
+        # Should be the same tracker in the same context
+        assert tracker2.increment("func-a") == 2
+
+    # Run in a fresh context to ensure isolation
+    ctx = contextvars.copy_context()
+    ctx.run(_inner)
-def test_tracker_context_variable():
-    tracker1 = get_call_tracker()
-    tracker1.increment("func-a")
-
-    tracker2 = get_call_tracker()
-    # Should be the same tracker in the same context
-    assert tracker2.increment("func-a") == 2

src/nat/builder/intermediate_step_manager.py (1)
100-108: Minor inefficiency: use return value from `increment()` instead of re-fetching. The `increment()` method returns the new call index (as shown in the relevant snippets), but the code calls `tracker.counts.get()` again for logging. Consider using the return value directly.
♻️ Suggested improvement
 if payload.event_type == IntermediateStepType.LLM_START:
     active_function = self._context_state.active_function.get()
     if active_function and active_function.function_id != "root":
         tracker = get_call_tracker()
-        tracker.increment(active_function.function_id)
+        call_index = tracker.increment(active_function.function_id)
         logger.debug("Incremented LLM call tracker for %s to %d",
                      active_function.function_id,
-                     tracker.counts.get(active_function.function_id, 0))
+                     call_index)

tests/nat/profiler/prediction_trie/test_serialization.py (1)
60-61: Add explicit encoding parameter for consistency. The `open()` call should specify `encoding="utf-8"` for consistency with the coding guidelines and the save function, which uses explicit encoding.
♻️ Suggested fix
-    with open(path) as f:
+    with open(path, encoding="utf-8") as f:
         data = json.load(f)

src/nat/profiler/prediction_trie/data_models.py (1)
16-64: Add Google-style docstrings (module + public models). The module and public model docstrings should follow Google style and wrap field names in backticks. This is required for public APIs and docs linting. As per coding guidelines, ensure Google-style docstrings and backticks for code identifiers.
Proposed docstring pattern
+"""Prediction trie data models. + +Attributes: + `PredictionMetrics`: Aggregated statistics for a single metric. + `LLMCallPrediction`: Predictions for an LLM call at a given position. + `PredictionTrieNode`: Trie node holding per-call predictions. +""" @@ -class PredictionMetrics(BaseModel): - """Aggregated statistics for a single metric from profiler data.""" +class PredictionMetrics(BaseModel): + """Aggregated statistics for a single metric from profiler data. + + Attributes: + `sample_count`: Number of samples. + `mean`: Mean value. + `p50`: 50th percentile (median). + `p90`: 90th percentile. + `p95`: 95th percentile. + """src/nat/builder/context.py (1)
119-124: Add a Google-style docstring for `function_path_stack`. This is a public API property and should include a compliant docstring (with backticks for identifiers). As per coding guidelines, public APIs need Google-style docstrings.
Docstring example
 @property
 def function_path_stack(self) -> ContextVar[list[str]]:
+    """
+    Return the ContextVar holding the current function path stack.
+
+    Returns:
+        ContextVar[list[str]]: Stack of function names for the current execution path.
+    """
     if self._function_path_stack.get() is None:
         self._function_path_stack.set([])
     return typing.cast(ContextVar[list[str]], self._function_path_stack)

src/nat/profiler/prediction_trie/__init__.py (1)
16-32: Add a module docstring describing the public export surface. This is a public package initializer and should include a Google-style module docstring. As per coding guidelines, public modules require Google-style docstrings.
Example module docstring
+"""Prediction trie public API exports. + +Attributes: + `LLMCallPrediction`: Prediction data for a single LLM call. + `PredictionMetrics`: Aggregated metric statistics. + `PredictionTrieNode`: Trie node for hierarchical predictions. + `PredictionTrieBuilder`: Trie construction utility. + `PredictionTrieLookup`: Prediction lookup utility. + `load_prediction_trie`: Load a trie from JSON. + `save_prediction_trie`: Save a trie to JSON. +"""tests/nat/llm/test_dynamo_llm.py (2)
212-215: Avoid message text in the raised `ValueError` per TRY003. Ruff flags this pattern; consider raising a plain `ValueError` here since the message isn't used in assertions.
🧹 Suggested diff
- raise ValueError("Test exception")
+ raise ValueError
318-329: Prefer `next(iter(prefix_ids))` for single-element sets. Ruff suggests this pattern as it avoids unnecessary list creation.
♻️ Suggested diff
- assert "-d0" in list(prefix_ids)[0] + assert "-d0" in next(iter(prefix_ids))tests/nat/plugins/langchain/test_dynamo_trie_loading.py (1)
149-179: Consider using pytest's `tmp_path` fixture for cleaner temp file handling. The current implementation uses `tempfile.NamedTemporaryFile` with manual cleanup in a try/finally block. Using pytest's built-in `tmp_path` fixture would be more idiomatic and handle cleanup automatically.
♻️ Suggested refactor using tmp_path
 @patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks")
 @patch("langchain_openai.ChatOpenAI")
-async def test_dynamo_langchain_handles_invalid_trie_file_gracefully(mock_chat, mock_create_client, mock_builder):
+async def test_dynamo_langchain_handles_invalid_trie_file_gracefully(mock_chat, mock_create_client, mock_builder, tmp_path):
     """Test that dynamo_langchain logs warning and continues when trie file is invalid JSON."""
     mock_httpx_client = MagicMock()
     mock_httpx_client.aclose = AsyncMock()
     mock_create_client.return_value = mock_httpx_client

-    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
-        f.write("not valid json {{{")
-        invalid_trie_path = f.name
+    invalid_trie_path = tmp_path / "invalid_trie.json"
+    invalid_trie_path.write_text("not valid json {{{")

-    try:
-        config = DynamoModelConfig(
-            base_url="http://localhost:8000/v1",
-            model_name="test-model",
-            api_key="test-key",
-            prefix_template="test-{uuid}",
-            prediction_trie_path=invalid_trie_path,
-        )
+    config = DynamoModelConfig(
+        base_url="http://localhost:8000/v1",
+        model_name="test-model",
+        api_key="test-key",
+        prefix_template="test-{uuid}",
+        prediction_trie_path=str(invalid_trie_path),
+    )

-        # Should not raise an exception
-        async with dynamo_langchain(config, mock_builder):
-            # Verify httpx client was created with prediction_lookup=None
-            mock_create_client.assert_called_once()
-            call_kwargs = mock_create_client.call_args.kwargs
-            assert call_kwargs["prediction_lookup"] is None
+    # Should not raise an exception
+    async with dynamo_langchain(config, mock_builder):
+        # Verify httpx client was created with prediction_lookup=None
+        mock_create_client.assert_called_once()
+        call_kwargs = mock_create_client.call_args.kwargs
+        assert call_kwargs["prediction_lookup"] is None

-        mock_httpx_client.aclose.assert_awaited_once()
-    finally:
-        Path(invalid_trie_path).unlink(missing_ok=True)
+    mock_httpx_client.aclose.assert_awaited_once()

src/nat/profiler/prediction_trie/trie_builder.py (1)
85-88: Consider optimizing the next LLM_START lookup for large traces. The current implementation iterates through all `llm_starts` for each `llm_ends` item, resulting in O(n*m) complexity. For traces with many LLM calls, this could be slow. Since both lists are sorted by timestamp, you could use binary search or a single-pass approach with a pointer to avoid repeated iteration.
♻️ Suggested optimization using a pointer
+    # Use pointer for efficient next-start lookup
+    start_idx = 0
+
     for i, end_step in enumerate(llm_ends):
         # ... existing code ...
         # Time to next LLM start (if any)
         time_to_next_ms: float | None = None
         current_end_time = end_step.event_timestamp
-        # Find next LLM_START after this LLM_END
-        for start_step in llm_starts:
-            if start_step.event_timestamp > current_end_time:
-                time_to_next_ms = (start_step.event_timestamp - current_end_time) * 1000.0
-                break
+        # Find next LLM_START after this LLM_END (pointer advances monotonically)
+        while start_idx < len(llm_starts) and llm_starts[start_idx].event_timestamp <= current_end_time:
+            start_idx += 1
+        if start_idx < len(llm_starts):
+            time_to_next_ms = (llm_starts[start_idx].event_timestamp - current_end_time) * 1000.0

src/nat/llm/dynamo_llm.py (2)
153-160: Broad exception handling is acceptable for graceful fallback. The static analysis flags catching bare `Exception` at lines 159 and 217. In both cases, this is intentional:
- Line 159: Returns `0` as fallback depth when context is unavailable
- Line 217: Falls back to a generated UUID when `workflow_run_id` is unavailable
These are resilience patterns to ensure the system continues working even when context is not fully initialized. The behavior is acceptable, but consider logging at `debug` level in `_get_current_depth` for observability.
♻️ Optional: Add debug logging in _get_current_depth
 @classmethod
 def _get_current_depth(cls) -> int:
     """Get the current function call stack depth from Context."""
     try:
         ctx = Context.get()
         return len(ctx.function_path)
-    except Exception:
+    except Exception as e:
+        logger.debug("Could not get context for depth calculation: %s", e)
         return 0

Also applies to: 214-222
377-389: Document the deprecation of the `prefix_template` parameter. The `prefix_template` parameter is kept for API compatibility but is no longer used (line 389). Consider adding a deprecation warning or updating the docstring to more explicitly indicate this is deprecated.
♻️ Suggested docstring update
 Args:
-    prefix_template: Template string with {uuid} placeholder (currently unused,
-        kept for API compatibility)
+    prefix_template: Template string with {uuid} placeholder. **Deprecated**: This
+        parameter is no longer used; prefix IDs are now managed by DynamoPrefixContext
+        with depth-awareness. Kept for API backward compatibility.
     total_requests: Expected number of requests for this prefix
     osl: Output sequence length hint (LOW/MEDIUM/HIGH)
     iat: Inter-arrival time hint (LOW/MEDIUM/HIGH)
examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md
bbednarski9 left a comment
@dnandakumar-nv approving this so we can merge with my work as it is non-breaking. We will need to bring in the follow-up #1486 shortly and validate end to end, revisiting diffs from this PR and that one as well for a complete implementation review.
…e-headers
# Conflicts:
#	packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py
#	src/nat/llm/dynamo_llm.py
#	tests/nat/llm/test_dynamo_llm.py
/ok to test 5bec15d
Updated files to include full Apache 2.0 license text, ensuring clarity on usage and distribution under the license. This change ensures compliance with legal requirements and improves consistency across the repository. Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
/ok to test 2cba23c
This update includes "Trie(s)" in the NAT vocabulary file for Vale. It ensures that the term is recognized as valid during linting. Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
/ok to test 7d2c087
Updated headers and descriptions in the README to improve terminology clarity. Removed an unnecessary blank line in the test file for better formatting. Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
/ok to test 727f564
Corrected the formatting of `job_id` to use code style for consistency and clarity. This improves readability and aligns with standard documentation practices. Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
/ok to test 74c191d
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md`:
- Around lines 130-165: the README incorrectly says `prefix_template` is required. Update the documentation to mark the `prefix_template` config key as deprecated (or remove "required") and clarify current behavior: `prediction_trie_path` and the `dynamo` LLM type are the primary requirements for trie-based predictions. State when `prefix_template` is still honored for backward compatibility (if any), and adjust the "Headers not being injected" checklist to remove or qualify the "`prefix_template` is set (required for Dynamo hooks)" line. Reference the config keys `prefix_template`, `prediction_trie_path`, and the `dynamo` LLM type in the revised text.
examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md
Included the full Apache 2.0 license header in two test files for compliance. This ensures proper licensing alignment and clarifies usage terms for these files. Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
/ok to test c412dc8
Updated imports to fetch `PredictionTrieLookup` directly from the `trie_lookup` submodule for better clarity and modularity. Adjusted `__init__.py` to avoid re-exporting `PredictionTrieLookup` to prevent Sphinx cross-reference warnings. Additionally, reformatted and clarified docstrings and field descriptions for improved readability. Signed-off-by: dnandakumar-nv <dnandakumar@nvidia.com>
/ok to test 22327b9
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@src/nat/llm/dynamo_llm.py`:
- Around lines 560-562: replace the current bare warning in the exception handler that swallows prediction-hook errors with a `logger.exception` call so the full traceback is preserved. Specifically, in the `except Exception as e:` block inside the prediction header override logic (which currently calls `logger.warning("Failed to override prefix headers from prediction: %s", e)`), log the exception using `logger.exception` with a descriptive message so the stack trace and error context are recorded.
🧹 Nitpick comments (2)
src/nat/llm/dynamo_llm.py (1)
154-161: Narrow the Context access exception handling.
Catching `Exception` here can mask unexpected bugs; prefer the specific exceptions raised by `Context.get()` / context vars. Also applies to: 215-219.
tests/nat/plugins/langchain/test_dynamo_trie_loading.py (1)
86-109: Prefix unused mock parameter with underscore to indicate intentional non-use.

The `mock_chat` parameter is required to apply the patch but isn't used in the test body. Static analysis flags this as unused. Prefix with `_` to signal intent and silence the lint warning.

Suggested fix:

```diff
 @patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks")
 @patch("langchain_openai.ChatOpenAI")
-async def test_dynamo_langchain_loads_trie_and_passes_to_client(mock_chat, mock_create_client, trie_file, mock_builder):
+async def test_dynamo_langchain_loads_trie_and_passes_to_client(_mock_chat, mock_create_client, trie_file, mock_builder):
```

The same applies to the other three async tests at lines 114, 140, and 163.
/merge
Summary
This PR introduces runtime prediction-based header injection for Dynamo LLM integration, enabling intelligent KV cache optimization through profiling-derived predictions.
What This Enables
When running workflows against Dynamo-optimized inference endpoints, NAT can now attach prediction headers describing expected workload characteristics (remaining calls, inter-arrival time, output length) for each LLM call. This enables Dynamo's router to make better scheduling and caching decisions.
Key Features
1. Prediction Trie System
New module: `src/nat/profiler/prediction_trie/`

A data structure for building and querying prediction data:

- `data_models.py` - `PredictionMetrics`, `LLMCallPrediction`, `PredictionTrieNode`
- `metrics_accumulator.py` - computes statistics from profiler traces
- `trie_builder.py` - builds the trie from profiler traces
- `trie_lookup.py` - `PredictionTrieLookup` for O(depth) path-based queries
- `serialization.py` - JSON serialization

Trie Structure:
2. Depth-Aware Prefix IDs
Modified: `DynamoPrefixContext` in `src/nat/llm/dynamo_llm.py`

Prefix IDs are now unique per call stack depth within a workflow run: `{workflow_id}-d{depth}`.
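One way to realize depth-aware IDs of the form `{workflow_id}-d{depth}` is with contextvars; the names below are hypothetical stand-ins for the real `DynamoPrefixContext`, not its actual implementation:

```python
import contextvars
import uuid

# Hypothetical sketch; NAT's real context lives in src/nat/llm/dynamo_llm.py.
_workflow_id = contextvars.ContextVar("workflow_id", default=None)
_depth = contextvars.ContextVar("call_depth", default=0)


def get_prefix_id() -> str:
    """Return a depth-aware prefix ID, auto-generating a workflow id if unset.

    Mirrors the PR's behavior change: get() always returns a str.
    """
    wf = _workflow_id.get()
    if wf is None:
        wf = uuid.uuid4().hex[:8]
        _workflow_id.set(wf)
    return f"{wf}-d{_depth.get()}"


_workflow_id.set("run42")
_depth.set(0)
top = get_prefix_id()      # "run42-d0"
_depth.set(1)
nested = get_prefix_id()   # "run42-d1"
```

Each call stack depth thus gets a distinct cache key within the same workflow run.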
3. Unified Header Format
All prediction sources (static config or dynamic trie) use consistent `x-prefix-*` headers:

| Header | Static source (config) | Dynamic source (trie) |
|---|---|---|
| `x-prefix-id` | `{workflow_id}-d{depth}` | `{workflow_id}-d{depth}` |
| `x-prefix-total-requests` | `prefix_total_requests` config | `remaining_calls.mean` |
| `x-prefix-osl` | `prefix_osl` (LOW/MEDIUM/HIGH) | `output_tokens.p90` |
| `x-prefix-iat` | `prefix_iat` (LOW/MEDIUM/HIGH) | `interarrival_ms.mean` |

4. Category Conversion Thresholds
Dynamic numeric predictions are converted to categorical hints:
OSL (Output Sequence Length):
IAT (Inter-Arrival Time):
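The conversion itself is simple threshold bucketing. The cutoff values in this sketch are placeholders, since the PR's actual thresholds are not reproduced here:

```python
def to_category(value: float, low_cutoff: float, high_cutoff: float) -> str:
    """Map a numeric prediction to a LOW/MEDIUM/HIGH hint."""
    if value < low_cutoff:
        return "LOW"
    if value < high_cutoff:
        return "MEDIUM"
    return "HIGH"


# OSL hint from output_tokens.p90, IAT hint from interarrival_ms.mean.
# Cutoffs (256/1024 tokens, 50/500 ms) are illustrative placeholders only.
osl = to_category(1200.0, 256, 1024)
iat = to_category(80.0, 50, 500)
```

With these placeholder cutoffs, a p90 of 1200 output tokens maps to HIGH and an 80 ms mean inter-arrival time maps to MEDIUM.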
5. Function Path Tracking
Modified: `src/nat/builder/context.py`

A new `function_path_stack` ContextVar tracks the current execution path. This enables the prediction trie lookup to find the right predictions for the current call site.
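A self-contained sketch of a ContextVar-based path stack (illustrative only; the real implementation is in `src/nat/builder/context.py`):

```python
import contextvars
from contextlib import contextmanager

# Hypothetical stand-in for NAT's function_path_stack ContextVar.
function_path_stack = contextvars.ContextVar("function_path_stack", default=())


@contextmanager
def push_function(name: str):
    """Push a function name onto the path for the duration of a call."""
    token = function_path_stack.set(function_path_stack.get() + (name,))
    try:
        yield
    finally:
        # Restoring the token (rather than popping) keeps async tasks isolated.
        function_path_stack.reset(token)


with push_function("workflow"):
    with push_function("react_agent"):
        path_inside = list(function_path_stack.get())
path_after = list(function_path_stack.get())
```

Using an immutable tuple plus `reset(token)` avoids cross-task mutation, which matters when concurrent async functions share the same workflow run.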
6. LLM Call Tracker
New: `src/nat/llm/prediction_context.py`

Tracks LLM call counts per function for accurate trie lookups.
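A minimal sketch of such a tracker (hypothetical names; the real one lives in `src/nat/llm/prediction_context.py`):

```python
import contextvars
from collections import Counter

# Hypothetical per-run counter keyed by function path.
_call_counts = contextvars.ContextVar("_call_counts", default=None)


def record_llm_call(function_path: str) -> int:
    """Increment the LLM call count for a call site and return its new index.

    The trie lookup can use this index to estimate remaining calls at
    the current position in the workflow.
    """
    counts = _call_counts.get()
    if counts is None:
        counts = Counter()
        _call_counts.set(counts)
    counts[function_path] += 1
    return counts[function_path]


first = record_llm_call("workflow/react_agent")
second = record_llm_call("workflow/react_agent")
other = record_llm_call("workflow/tool_call")
```

Counts are scoped to the current context, so concurrent workflow runs do not interfere with each other's tallies.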
Configuration
Enable Prediction Trie at Runtime
Generate Prediction Trie from Profiling
Architecture
Data Flow
Header Injection Pipeline
Files Changed
New Files
Prediction Trie Module:
- `src/nat/profiler/prediction_trie/__init__.py`
- `src/nat/profiler/prediction_trie/data_models.py`
- `src/nat/profiler/prediction_trie/metrics_accumulator.py`
- `src/nat/profiler/prediction_trie/trie_builder.py`
- `src/nat/profiler/prediction_trie/trie_lookup.py`
- `src/nat/profiler/prediction_trie/serialization.py`

Runtime Support:
- `src/nat/llm/prediction_context.py`

Tests:
- `tests/nat/profiler/prediction_trie/__init__.py`
- `tests/nat/profiler/prediction_trie/test_data_models.py`
- `tests/nat/profiler/prediction_trie/test_metrics_accumulator.py`
- `tests/nat/profiler/prediction_trie/test_trie_builder.py`
- `tests/nat/profiler/prediction_trie/test_trie_lookup.py`
- `tests/nat/profiler/prediction_trie/test_serialization.py`
- `tests/nat/profiler/test_prediction_trie_e2e.py`
- `tests/nat/profiler/test_prediction_trie_integration.py`
- `tests/nat/llm/test_dynamic_prediction_hook.py`
- `tests/nat/llm/test_runtime_prediction_e2e.py`
- `tests/nat/llm/test_dynamo_prediction_headers.py`
- `tests/nat/llm/test_dynamo_prediction_trie.py`
- `tests/nat/llm/test_prediction_context.py`
- `tests/nat/plugins/langchain/test_dynamo_trie_loading.py`
- `tests/nat/builder/test_call_tracker_integration.py`
- `tests/nat/builder/test_function_path_stack.py`
- `tests/nat/runtime/test_runner_dynamo_prefix.py`
Core:
- `src/nat/llm/dynamo_llm.py` - major refactor for depth-aware prefixes and prediction hooks
- `src/nat/builder/context.py` - add `function_path_stack` ContextVar
- `src/nat/builder/intermediate_step_manager.py` - increment LLM call tracker on LLM_START
- `src/nat/profiler/profile_runner.py` - integrate prediction trie generation
- `src/nat/runtime/runner.py` - set/clear `DynamoPrefixContext` for workflows
- `packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py` - load prediction trie in `dynamo_langchain`

Tests:
- `tests/nat/llm/test_dynamo_llm.py` - updated for new depth-aware behavior

Test Plan
- `pytest tests/nat/profiler/prediction_trie/` - prediction trie unit tests
- `pytest tests/nat/profiler/test_prediction_trie_e2e.py` - end-to-end trie test
- `pytest tests/nat/profiler/test_prediction_trie_integration.py` - integration test
- `pytest tests/nat/llm/test_dynamo_llm.py` - Dynamo LLM tests (28 tests)
- `pytest tests/nat/llm/test_dynamic_prediction_hook.py` - dynamic hook tests (6 tests)
- `pytest tests/nat/llm/test_runtime_prediction_e2e.py` - runtime E2E test
- `pytest tests/nat/plugins/langchain/test_dynamo_trie_loading.py` - trie loading tests (6 tests)
- `pytest packages/nvidia_nat_langchain/tests/test_llm_langchain.py -k dynamo` - LangChain tests (4 tests)
- `pytest tests/nat/builder/test_function_path_stack.py` - function path tracking
- `pytest tests/nat/builder/test_call_tracker_integration.py` - call tracker tests
- `pytest tests/nat/runtime/test_runner_dynamo_prefix.py` - runner integration tests

API Changes
Breaking Changes
- `DynamoPrefixContext.get()` return type changed: `str | None` → `str` (always returns a value, auto-generates if no override)
- `prefix_template` parameter deprecated: the `prefix_template` parameter in `_create_dynamo_request_hook` is kept for API compatibility but no longer used; prefix IDs now come from `DynamoPrefixContext` with depth-awareness

New APIs
Example Usage
Full Workflow with Prediction Trie
Related Documentation
examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.mdexamples/dynamo_integration/react_benchmark_agent/configs/run_with_prediction_trie.ymlBy Submitting this PR I confirm:
Summary by CodeRabbit
New Features
Improvements
Documentation
Tests