[ML] Add per allocation and per deployment memory metadata fields to … #6
MitchLewis930 wants to merge 1 commit into pr_016_before
Conversation
…the trained models config (elastic#98139)

To improve the required memory estimation of NLP models, this PR introduces two new metadata fields: `per_deployment_memory_bytes` and `per_allocation_memory_bytes`.

- `per_deployment_memory_bytes` is the memory required to load the model in the deployment.
- `per_allocation_memory_bytes` is the temporary additional memory used during inference for every allocation.

This PR extends the memory usage estimation logic while ensuring backward compatibility. In a follow-up PR, I will adjust the assignment planner to use the refined memory usage information.
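The estimate the description above implies (a fixed per-deployment cost plus a per-allocation cost scaled by the allocation count) can be sketched as a standalone method. The overhead constant, the legacy double-the-model-size fallback, and the method shape are illustrative assumptions, not the exact Elasticsearch implementation:

```java
public final class MemoryEstimateSketch {

    // Illustrative fixed per-process overhead; the real constant differs.
    private static final long NATIVE_EXECUTABLE_CODE_OVERHEAD = 30L * 1024 * 1024;

    static long estimateMemoryUsageBytes(
        long modelBytes,
        long perDeploymentMemoryBytes,
        long perAllocationMemoryBytes,
        int numberOfAllocations
    ) {
        if (perDeploymentMemoryBytes > 0 && perAllocationMemoryBytes > 0) {
            // New path: use the metadata shipped with the model config.
            return NATIVE_EXECUTABLE_CODE_OVERHEAD
                + modelBytes
                + perDeploymentMemoryBytes
                + perAllocationMemoryBytes * numberOfAllocations;
        }
        // Backward-compatible fallback when the metadata fields are absent:
        // a coarse heuristic based only on the model size.
        return NATIVE_EXECUTABLE_CODE_OVERHEAD + 2 * modelBytes;
    }

    public static void main(String[] args) {
        System.out.println(estimateMemoryUsageBytes(1_000, 500, 100, 4));
        System.out.println(estimateMemoryUsageBytes(1_000, 0, 0, 4));
    }
}
```

Because the per-allocation term scales with `numberOfAllocations`, the planner needs the allocation count alongside the model size, which is exactly what the review comments below focus on.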
📝 Walkthrough

This pull request introduces memory accounting capabilities for trained model deployments. It adds a new transport version (V_8_500_064), extends the TaskParams configuration to track per-deployment and per-allocation memory bytes, wires these values from model metadata through deployment flows, and updates memory estimation logic to incorporate allocation counts alongside the new memory fields.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User/Client
    participant API as StartTrainedModelDeploymentAction
    participant Config as TrainedModelConfig
    participant TaskParams as TaskParams
    participant Memory as Memory Estimator
    participant Deployment as Deployment Task
    User->>API: Start deployment request
    API->>Config: Fetch trained model config
    Config-->>API: Return config with memory metadata
    API->>TaskParams: Create TaskParams with<br/>perDeploymentMemoryBytes,<br/>perAllocationMemoryBytes
    TaskParams-->>API: TaskParams instantiated
    API->>Memory: estimateMemoryUsageBytes(<br/>modelId, totalLength,<br/>perDeploymentMem,<br/>perAllocationMem,<br/>numberOfAllocations)
    Memory-->>API: Computed memory requirement
    API->>Deployment: Initialize with memory-aware<br/>TaskParams
    Deployment-->>User: Deployment started with<br/>memory tracking
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks | ✅ 1 | ❌ 2 (failed checks: 1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In
`@x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java`:
- Around line 127-135: The code currently sums all AssignmentStats into a single
numberOfAllocations and passes that same value into modelSizeStats for every
model, overestimating memory when multiple models/deployments exist; change the
logic to compute allocations per model (e.g., build a Map<String,int> from
deploymentStats.getStats().results() keyed by modelId or deployment id using
AssignmentStats::getNumberOfAllocations aggregated per model) and then iterate
over responseBuilder.getExpandedModelIdsWithAliases(), calling
modelSizeStats(...) once per model with its corresponding per-model allocation
count (still passing request.isAllowNoResources(), parentTaskId,
modelSizeStatsListener as before).
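The per-model aggregation the prompt describes can be sketched in isolation. `AssignmentStats` is reduced here to a hypothetical record carrying only the two fields the aggregation needs; the real class lives in the x-pack core module:

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public final class AllocationAggregationSketch {

    // Stand-in for the real AssignmentStats: only the fields we aggregate on.
    record AssignmentStats(String modelId, int numberOfAllocations) {}

    // Aggregate allocation counts per model instead of summing across every
    // deployment, so each model's memory estimate uses its own count.
    static Map<String, Integer> allocationsByModelId(List<AssignmentStats> stats) {
        return stats.stream()
            .collect(Collectors.groupingBy(
                AssignmentStats::modelId,
                Collectors.summingInt(AssignmentStats::numberOfAllocations)
            ));
    }

    public static void main(String[] args) {
        List<AssignmentStats> stats = List.of(
            new AssignmentStats("model-a", 2),
            new AssignmentStats("model-a", 3), // second deployment of model-a
            new AssignmentStats("model-b", 4)
        );
        Map<String, Integer> byModel = allocationsByModelId(stats);
        System.out.println(byModel.get("model-a")); // 5, not the global sum of 9
        System.out.println(byModel.get("model-b")); // 4
    }
}
```

With the global sum, `model-b` above would be charged for 9 allocations instead of its actual 4, which is the overestimation the comment flags.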
🧹 Nitpick comments (3)
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java (1)
98-101: Consider adding assertions for the new memory fields.

The TaskParams construction correctly includes the new memory fields. However, the `testUpdateNumberOfAllocations` test verifies that various parameters are preserved after calling `updateNumberOfAllocations()` (modelId, deploymentId, modelBytes, threadsPerAllocation, cacheSize, priority), but it does not verify that `perDeploymentMemoryBytes` and `perAllocationMemoryBytes` are preserved. Consider adding assertions to ensure these new fields are correctly preserved:
💡 Suggested addition for completeness

```diff
 assertThat(updatedParams.getCacheSize(), equalTo(initialParams.getCacheSize()));
 assertThat(updatedParams.getPriority(), equalTo(initialParams.getPriority()));
+assertThat(updatedParams.getPerDeploymentMemoryBytes(), equalTo(initialParams.getPerDeploymentMemoryBytes()));
+assertThat(updatedParams.getPerAllocationMemoryBytes(), equalTo(initialParams.getPerAllocationMemoryBytes()));
 }
```

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java (1)
76-76: Consider validating per-deployment/per-allocation memory values before use.

These values come from model metadata; if they can be negative, memory accounting could under-estimate usage. A small guard (or clamping) would make the deployment flow more robust.
🔧 Suggested guard

```diff
-perDeploymentMemoryBytes.set(trainedModelConfig.getPerDeploymentMemoryBytes());
-perAllocationMemoryBytes.set(trainedModelConfig.getPerAllocationMemoryBytes());
+long perDeployment = trainedModelConfig.getPerDeploymentMemoryBytes();
+long perAllocation = trainedModelConfig.getPerAllocationMemoryBytes();
+if (perDeployment < 0 || perAllocation < 0) {
+    listener.onFailure(
+        ExceptionsHelper.badRequestException(
+            "model [{}] has negative memory metadata (per_deployment: [{}], per_allocation: [{}])",
+            trainedModelConfig.getModelId(),
+            perDeployment,
+            perAllocation
+        )
+    );
+    return;
+}
+perDeploymentMemoryBytes.set(perDeployment);
+perAllocationMemoryBytes.set(perAllocation);
```

Also applies to: 175-176, 197-209, 244-245
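A standalone version of the same guard can be exercised directly; throwing here replaces failing the action listener, and the exception type and message format are illustrative assumptions:

```java
public final class MemoryMetadataGuardSketch {

    // Reject negative memory metadata before it reaches memory accounting.
    // Hypothetical helper; the real check would live in the transport action.
    static void validateMemoryMetadata(String modelId, long perDeploymentBytes, long perAllocationBytes) {
        if (perDeploymentBytes < 0 || perAllocationBytes < 0) {
            throw new IllegalArgumentException(
                "model [" + modelId + "] has negative memory metadata (per_deployment: ["
                    + perDeploymentBytes + "], per_allocation: [" + perAllocationBytes + "])"
            );
        }
    }

    public static void main(String[] args) {
        validateMemoryMetadata("ok-model", 1024, 512); // passes silently
        try {
            validateMemoryMetadata("bad-model", -1, 512);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```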
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java (1)
193-223: Allow metadata emission when either memory field is set.

Right now metadata is included only when both values are > 0; this prevents testing a non-zero per-deployment value with a zero per-allocation value (or vice versa). If the API allows those independently, consider `||` instead.

🔧 Suggested tweak

```diff
-if (perDeploymentMemoryBytes > 0 && perAllocationMemoryBytes > 0) {
+if (perDeploymentMemoryBytes > 0 || perAllocationMemoryBytes > 0) {
     metadata = Strings.format("""
         "metadata": {
             "per_deployment_memory_bytes": %d,
             "per_allocation_memory_bytes": %d
         },""", perDeploymentMemoryBytes, perAllocationMemoryBytes);
 } else {
```
📜 Review details
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (23)

- server/src/main/java/org/elasticsearch/TransportVersion.java
- x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java
- x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java
- x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java
- x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java
- x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java
- x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
- x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java
- x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java
- x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java
- x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java
- x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java
- x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java
- x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlMemoryAutoscalingDeciderTests.java
- x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlProcessorAutoscalingDeciderTests.java
- x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java
- x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java
- x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java
- x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java
- x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java
- x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java
- x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilderTests.java
- x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java
🧰 Additional context used
🧬 Code graph analysis (2)
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java (1)
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java (1)
StartTrainedModelDeploymentAction(46-740)
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java (1)
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java (1)
TrainedModelConfig(65-1142)
🔇 Additional comments (30)
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java (1)

133-136: LGTM! The TaskParams constructor is correctly updated with the two new memory metadata fields (`perDeploymentMemoryBytes` and `perAllocationMemoryBytes`). Using `0L` for both values is appropriate here since this test focuses on node load detection behavior rather than memory estimation.

server/src/main/java/org/elasticsearch/TransportVersion.java (2)

180-181: LGTM! The new transport version `V_8_500_064` is correctly added following the established incrementing pattern with a unique UUID. The `CURRENT` holder is properly updated to reference this new version.

204-204: LGTM! `CurrentHolder.CURRENT` is correctly updated to use `V_8_500_064` as the fallback version.

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java (1)

50-53: LGTM! The `createRandom()` method correctly uses `randomNonNegativeLong()` for both new memory fields, providing good coverage for serialization testing of `perDeploymentMemoryBytes` and `perAllocationMemoryBytes`.

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java (1)

85-88: LGTM! The `updateNumberOfAllocations` method correctly preserves the new memory metadata fields (`perDeploymentMemoryBytes` and `perAllocationMemoryBytes`) when creating a new `TaskParams` instance. This ensures no data loss during allocation count updates.

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java (1)

299-302: LGTM! The `randomTaskParams` helper method correctly includes `randomNonNegativeLong()` for both new memory fields, ensuring comprehensive test coverage across all test methods that use this helper.

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java (1)

1541-1544: LGTM! The `newParams` helper method correctly adds `0L` for both new memory fields. Using zero values is appropriate here since these tests focus on cluster service assignment logic rather than memory estimation behavior.

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java (1)

181-184: LGTM! The `createAssignment` helper method correctly adds `randomNonNegativeLong()` for both new memory fields, providing good test coverage while testing allocation reduction logic.

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java (1)

62-65: LGTM! The TaskParams construction correctly includes `randomNonNegativeLong()` for both new memory fields.

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilderTests.java (1)
51-55: LGTM: TaskParams invocations updated for new memory fields.

Also applies to: 73-77, 94-98
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java (1)
107-119: LGTM: random TaskParams now cover the new memory fields.

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlProcessorAutoscalingDeciderTests.java (1)
79-81: LGTM: test fixtures aligned with expanded TaskParams signature.

Also applies to: 96-98, 153-155, 170-172, 227-229, 244-246, 301-303, 318-320, 386-388, 403-405, 459-461, 476-478, 536-538, 553-555
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java (1)
631-643: LGTM: helper now includes the new memory fields.

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java (1)
1121-1132: LGTM: TaskParams helper methods updated for new fields.

Also applies to: 1151-1162
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java (1)
309-312: LGTM: handy helper for fetching trained model configs.

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlMemoryAutoscalingDeciderTests.java (1)
1058-1069: LGTM: TaskParams updates align with new memory fields.

Also applies to: 1072-1083, 1094-1105, 1108-1119, 1130-1141, 1144-1155
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java (1)
365-376: LGTM: memory metadata is propagated into TaskParams.

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java (1)
160-172: LGTM: test helper updated for new TaskParams signature.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java (1)
468-480: LGTM: allocation updates retain memory metadata.

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java (1)
243-306: LGTM: coverage added for required memory estimation.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java (3)
103-108: LGTM: new parse fields and transport version constant added.
411-421: LGTM: metadata-backed accessors are straightforward.
590-591: LGTM: builder extended for allocation memory fields.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java (7)
381-433: LGTM: TaskParams parsing updated for allocation memory fields.
449-504: LGTM: new TaskParams fields and constructors are wired correctly.
527-535: LGTM: transport-version gated deserialization looks correct.
546-564: LGTM: TaskParams memory estimation uses new metadata.
571-608: LGTM: serialization/XContent include allocation memory fields.
612-685: LGTM: equality and accessors cover new fields.
707-739: LGTM: updated estimate logic reflects allocation-aware memory.
```diff
+int numberOfAllocations = deploymentStats.getStats().results().stream().mapToInt(AssignmentStats::getNumberOfAllocations).sum();
 modelSizeStats(
     responseBuilder.getExpandedModelIdsWithAliases(),
     request.isAllowNoResources(),
     parentTaskId,
-    modelSizeStatsListener
+    modelSizeStatsListener,
+    numberOfAllocations
 );
```
Compute allocations per model, not a global sum.
The global sum feeds the same allocation count into every model’s memory estimate, which overestimates memory when multiple deployments/models are returned. Aggregate per model (or per deployment) instead.
🔧 Suggested fix (aggregate by model)

```diff
-int numberOfAllocations = deploymentStats.getStats().results().stream().mapToInt(AssignmentStats::getNumberOfAllocations).sum();
+Map<String, Integer> allocationsByModelId = deploymentStats.getStats().results().stream()
+    .collect(Collectors.groupingBy(AssignmentStats::getModelId, Collectors.summingInt(AssignmentStats::getNumberOfAllocations)));
 modelSizeStats(
     responseBuilder.getExpandedModelIdsWithAliases(),
     request.isAllowNoResources(),
     parentTaskId,
     modelSizeStatsListener,
-    numberOfAllocations
+    allocationsByModelId
 );
```

```diff
-private void modelSizeStats(..., ActionListener<Map<String, TrainedModelSizeStats>> listener, int numberOfAllocations) {
+private void modelSizeStats(..., ActionListener<Map<String, TrainedModelSizeStats>> listener,
+    Map<String, Integer> allocationsByModelId) {
     for (TrainedModelConfig model : models) {
+        int numberOfAllocations = allocationsByModelId.getOrDefault(model.getModelId(), 0);
         ...
     }
```
In
`@x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java`
around lines 127 - 135, The code currently sums all AssignmentStats into a
single numberOfAllocations and passes that same value into modelSizeStats for
every model, overestimating memory when multiple models/deployments exist;
change the logic to compute allocations per model (e.g., build a Map<String,int>
from deploymentStats.getStats().results() keyed by modelId or deployment id
using AssignmentStats::getNumberOfAllocations aggregated per model) and then
iterate over responseBuilder.getExpandedModelIdsWithAliases(), calling
modelSizeStats(...) once per model with its corresponding per-model allocation
count (still passing request.isAllowNoResources(), parentTaskId,
modelSizeStatsListener as before).
Summary by CodeRabbit
New Features
Tests