diff --git a/server/src/main/java/org/elasticsearch/TransportVersion.java b/server/src/main/java/org/elasticsearch/TransportVersion.java index c8ca78388ff14..e569f50d9f84f 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersion.java +++ b/server/src/main/java/org/elasticsearch/TransportVersion.java @@ -177,6 +177,8 @@ private static TransportVersion registerTransportVersion(int id, String uniqueId public static final TransportVersion V_8_500_061 = registerTransportVersion(8_500_061, "4e07f830-8be4-448c-851e-62b3d2f0bf0a"); public static final TransportVersion V_8_500_062 = registerTransportVersion(8_500_062, "09CD9C9B-3207-4B40-8756-B7A12001A885"); public static final TransportVersion V_8_500_063 = registerTransportVersion(8_500_063, "31dedced-0055-4f34-b952-2f6919be7488"); + public static final TransportVersion V_8_500_064 = registerTransportVersion(8_500_064, "3a795175-5e6f-40ff-90fe-5571ea8ab04e"); + /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ @@ -199,7 +201,7 @@ private static TransportVersion registerTransportVersion(int id, String uniqueId */ private static class CurrentHolder { - private static final TransportVersion CURRENT = findCurrent(V_8_500_063); + private static final TransportVersion CURRENT = findCurrent(V_8_500_064); // finds the pluggable current version, or uses the given fallback private static TransportVersion findCurrent(TransportVersion fallback) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index b8085abbd9536..fa219e233cfb2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -379,6 +379,7 @@ public static boolean mayAssignToNode(DiscoveryNode node) { } public static final MlConfigVersion VERSION_INTRODUCED = MlConfigVersion.V_8_0_0; + private static final ParseField MODEL_BYTES = new ParseField("model_bytes"); public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations"); public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation"); @@ -389,6 +390,8 @@ public static boolean mayAssignToNode(DiscoveryNode node) { public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity"); public static final ParseField CACHE_SIZE = new ParseField("cache_size"); public static final ParseField PRIORITY = new ParseField("priority"); + public static final ParseField PER_DEPLOYMENT_MEMORY_BYTES = new ParseField("per_deployment_memory_bytes"); + public static final ParseField PER_ALLOCATION_MEMORY_BYTES = new ParseField("per_allocation_memory_bytes"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "trained_model_deployment_params", @@ -403,7 +406,9 @@ public static boolean mayAssignToNode(DiscoveryNode node) { (ByteSizeValue) a[6], (Integer) a[7], (Integer) a[8], - a[9] == null ? null : Priority.fromString((String) a[9]) + a[9] == null ? null : Priority.fromString((String) a[9]), + (Long) a[10], + (Long) a[11] ) ); @@ -423,6 +428,8 @@ public static boolean mayAssignToNode(DiscoveryNode node) { PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PRIORITY); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), PER_DEPLOYMENT_MEMORY_BYTES); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), PER_ALLOCATION_MEMORY_BYTES); } public static TaskParams fromXContent(XContentParser parser) { @@ -439,6 +446,8 @@ public static TaskParams fromXContent(XContentParser parser) { private final int numberOfAllocations; private final int queueCapacity; private final Priority priority; + private final long perDeploymentMemoryBytes; + private final long perAllocationMemoryBytes; private TaskParams( String modelId, @@ -450,7 +459,9 @@ private TaskParams( ByteSizeValue cacheSizeValue, Integer legacyModelThreads, Integer legacyInferenceThreads, - Priority priority + Priority priority, + Long perDeploymentMemoryBytes, + Long perAllocationMemoryBytes ) { this( modelId, @@ -462,7 +473,9 @@ private TaskParams( threadsPerAllocation == null ? legacyInferenceThreads : threadsPerAllocation, queueCapacity, cacheSizeValue, - priority == null ? Priority.NORMAL : priority + priority == null ? Priority.NORMAL : priority, + perDeploymentMemoryBytes == null ? 0 : perDeploymentMemoryBytes, + perAllocationMemoryBytes == null ? 0 : perAllocationMemoryBytes ); } @@ -474,7 +487,9 @@ public TaskParams( int threadsPerAllocation, int queueCapacity, @Nullable ByteSizeValue cacheSize, - Priority priority + Priority priority, + long perDeploymentMemoryBytes, + long perAllocationMemoryBytes ) { this.modelId = Objects.requireNonNull(modelId); this.deploymentId = Objects.requireNonNull(deploymentId); @@ -484,6 +499,8 @@ public TaskParams( this.queueCapacity = queueCapacity; this.cacheSize = cacheSize; this.priority = Objects.requireNonNull(priority); + this.perDeploymentMemoryBytes = perDeploymentMemoryBytes; + this.perAllocationMemoryBytes = perAllocationMemoryBytes; } public TaskParams(StreamInput in) throws IOException { @@ -507,6 +524,15 @@ public TaskParams(StreamInput in) throws IOException { } else { this.deploymentId = modelId; } + + if (in.getTransportVersion().onOrAfter(TrainedModelConfig.VERSION_ALLOCATION_MEMORY_ADDED)) { + // We store additional model usage per allocation in the task params. + this.perDeploymentMemoryBytes = in.readLong(); + this.perAllocationMemoryBytes = in.readLong(); + } else { + this.perDeploymentMemoryBytes = 0L; + this.perAllocationMemoryBytes = 0L; + } } public String getModelId() { @@ -521,10 +547,21 @@ public long estimateMemoryUsageBytes() { // We already take into account 2x the model bytes. If the cache size is larger than the model bytes, then // we need to take it into account when returning the estimate. if (cacheSize != null && cacheSize.getBytes() > modelBytes) { - return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelId, modelBytes) + (cacheSize.getBytes() - - modelBytes); + return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + modelId, + modelBytes, + perDeploymentMemoryBytes, + perAllocationMemoryBytes, + numberOfAllocations + ) + (cacheSize.getBytes() - modelBytes); } - return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelId, modelBytes); + return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + modelId, + modelBytes, + perDeploymentMemoryBytes, + perAllocationMemoryBytes, + numberOfAllocations + ); } public MlConfigVersion getMinimalSupportedVersion() { @@ -547,6 +584,10 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) { out.writeString(deploymentId); } + if (out.getTransportVersion().onOrAfter(TrainedModelConfig.VERSION_ALLOCATION_MEMORY_ADDED)) { + out.writeLong(perDeploymentMemoryBytes); + out.writeLong(perAllocationMemoryBytes); + } } @Override @@ -562,6 +603,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(CACHE_SIZE.getPreferredName(), cacheSize.getStringRep()); } builder.field(PRIORITY.getPreferredName(), priority); + builder.field(PER_DEPLOYMENT_MEMORY_BYTES.getPreferredName(), perDeploymentMemoryBytes); + builder.field(PER_ALLOCATION_MEMORY_BYTES.getPreferredName(), perAllocationMemoryBytes); builder.endObject(); return builder; } @@ -576,7 +619,9 @@ public int hashCode() { numberOfAllocations, queueCapacity, cacheSize, - priority + priority, + perDeploymentMemoryBytes, + perAllocationMemoryBytes ); } @@ -593,7 +638,9 @@ public boolean equals(Object o) { && numberOfAllocations == other.numberOfAllocations && Objects.equals(cacheSize, other.cacheSize) && queueCapacity == other.queueCapacity - && priority == other.priority; + && priority == other.priority + && perDeploymentMemoryBytes == other.perDeploymentMemoryBytes + && perAllocationMemoryBytes == other.perAllocationMemoryBytes; } @Override @@ -629,6 +676,14 @@ public Priority getPriority() { return priority; } + public long getPerAllocationMemoryBytes() { + return perAllocationMemoryBytes; + } + + public long getPerDeploymentMemoryBytes() { + return perDeploymentMemoryBytes; + } + @Override public String toString() { return Strings.toString(this); @@ -649,12 +704,37 @@ static boolean match(Task task, String expectedId) { } } - public static long estimateMemoryUsageBytes(String modelId, long totalDefinitionLength) { + public static long estimateMemoryUsageBytes( + String modelId, + long totalDefinitionLength, + long perDeploymentMemoryBytes, + long perAllocationMemoryBytes, + int numberOfAllocations + ) { // While loading the model in the process we need twice the model size. - return isElserModel(modelId) ? ELSER_1_MEMORY_USAGE.getBytes() : MEMORY_OVERHEAD.getBytes() + 2 * totalDefinitionLength; + + // 1. If ELSER v1 then 2004MB + // 2. If static memory and dynamic memory are not set then 240MB + 2 * model size + // 3. Else static memory + dynamic memory * allocations + model size + + // The model size is still added in option 3 to account for the temporary requirement to hold the zip file in memory + // inĀ `pytorch_inference`. + if (isElserV1Model(modelId)) { + return ELSER_1_MEMORY_USAGE.getBytes(); + } else { + long baseSize = MEMORY_OVERHEAD.getBytes() + 2 * totalDefinitionLength; + if (perDeploymentMemoryBytes == 0 && perAllocationMemoryBytes == 0) { + return baseSize; + } else { + return Math.max( + baseSize, + perDeploymentMemoryBytes + perAllocationMemoryBytes * numberOfAllocations + totalDefinitionLength + ); + } + } } - private static boolean isElserModel(String modelId) { + private static boolean isElserV1Model(String modelId) { return modelId.startsWith(".elser_model_1"); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 545b9e6c260b3..77f3959b4a758 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -100,7 +100,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { public static final ParseField LOCATION = new ParseField("location"); public static final ParseField MODEL_PACKAGE = new ParseField("model_package"); + public static final ParseField PER_DEPLOYMENT_MEMORY_BYTES = new ParseField("per_deployment_memory_bytes"); + public static final ParseField PER_ALLOCATION_MEMORY_BYTES = new ParseField("per_allocation_memory_bytes"); + public static final TransportVersion VERSION_3RD_PARTY_CONFIG_ADDED = TransportVersion.V_8_0_0; + public static final TransportVersion VERSION_ALLOCATION_MEMORY_ADDED = TransportVersion.V_8_500_064; // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly public static final ObjectParser LENIENT_PARSER = createParser(true); @@ -163,6 +167,7 @@ private static ObjectParser createParser(boole (p, c) -> ignoreUnknownFields ? ModelPackageConfig.fromXContentLenient(p) : ModelPackageConfig.fromXContentStrict(p), MODEL_PACKAGE ); + return parser; } @@ -403,6 +408,18 @@ public void setFullDefinition(boolean fullDefinition) { this.fullDefinition = fullDefinition; } + public long getPerDeploymentMemoryBytes() { + return metadata != null && metadata.containsKey(PER_DEPLOYMENT_MEMORY_BYTES.getPreferredName()) + ? ((Number) metadata.get(PER_DEPLOYMENT_MEMORY_BYTES.getPreferredName())).longValue() + : 0L; + } + + public long getPerAllocationMemoryBytes() { + return metadata != null && metadata.containsKey(PER_ALLOCATION_MEMORY_BYTES.getPreferredName()) + ? ((Number) metadata.get(PER_ALLOCATION_MEMORY_BYTES.getPreferredName())).longValue() + : 0L; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); @@ -570,6 +587,8 @@ public static class Builder { private InferenceConfig inferenceConfig; private TrainedModelLocation location; private ModelPackageConfig modelPackageConfig; + private Long perDeploymentMemoryBytes; + private Long perAllocationMemoryBytes; public Builder() {} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java index 79e7004a49960..14cdc5639a0ea 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java @@ -474,7 +474,9 @@ public Builder setNumberOfAllocations(int numberOfAllocations) { taskParams.getThreadsPerAllocation(), taskParams.getQueueCapacity(), taskParams.getCacheSize().orElse(null), - taskParams.getPriority() + taskParams.getPriority(), + taskParams.getPerDeploymentMemoryBytes(), + taskParams.getPerAllocationMemoryBytes() ); return this; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java index d2072c44aa6ed..50cf59489c900 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentTaskParamsTests.java @@ -47,7 +47,9 @@ public static StartTrainedModelDeploymentAction.TaskParams createRandom() { randomIntBetween(1, 8), randomIntBetween(1, 10000), randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong()), - randomFrom(Priority.values()) + randomFrom(Priority.values()), + randomNonNegativeLong(), + randomNonNegativeLong() ); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java index 096942cc004a5..c85729b5a6311 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java @@ -296,7 +296,9 @@ private static StartTrainedModelDeploymentAction.TaskParams randomTaskParams(int randomIntBetween(1, 8), randomIntBetween(1, 10000), randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(0, modelSize + 1)), - randomFrom(Priority.values()) + randomFrom(Priority.values()), + randomNonNegativeLong(), + randomNonNegativeLong() ); } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java index d8e357d320cd3..f125b274830ae 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java @@ -240,6 +240,71 @@ public void testDeploymentStats() throws IOException { assertAtLeast.accept(modelStarted, AllocationStatus.State.FULLY_ALLOCATED); } + @SuppressWarnings("unchecked") + public void testRequiredMemoryEstimation() throws IOException { + String modelWithMetadata = "model_with_metadata"; + createPassThroughModel(modelWithMetadata, randomLongBetween(0, 10000000), randomLongBetween(0, 10000000)); + putVocabulary(List.of("once", "twice"), modelWithMetadata); + putModelDefinition(modelWithMetadata); + String modelNoMetadata = "model_no_metadata"; + createPassThroughModel(modelNoMetadata); + putVocabulary(List.of("once", "twice"), modelNoMetadata); + putModelDefinition(modelNoMetadata); + + CheckedBiConsumer assertAtLeast = (modelId, state) -> { + startDeployment(modelId, state); + Response response = getTrainedModelStats(modelId); + var responseMap = entityAsMap(response); + List> stats = (List>) responseMap.get("trained_model_stats"); + assertThat(stats, hasSize(1)); + String statusState = (String) XContentMapValues.extractValue("deployment_stats.allocation_status.state", stats.get(0)); + assertThat(responseMap.toString(), statusState, is(not(nullValue()))); + assertThat(AllocationStatus.State.fromString(statusState), greaterThanOrEqualTo(state)); + assertThat(XContentMapValues.extractValue("inference_stats", stats.get(0)), is(not(nullValue()))); + Integer numberOfAllocations = (Integer) XContentMapValues.extractValue("deployment_stats.number_of_allocations", stats.get(0)); + assertThat(numberOfAllocations, greaterThanOrEqualTo(0)); + + Integer byteSize = (Integer) XContentMapValues.extractValue("model_size_stats.model_size_bytes", stats.get(0)); + assertThat(responseMap.toString(), byteSize, is(not(nullValue()))); + assertThat(byteSize, equalTo((int) RAW_MODEL_SIZE)); + + Integer requiredNativeMemory = (Integer) XContentMapValues.extractValue( + "model_size_stats.required_native_memory_bytes", + stats.get(0) + ); + assertThat(responseMap.toString(), requiredNativeMemory, is(not(nullValue()))); + + Response trainedModelConfigResponse = getTrainedModelConfigs(modelId); + List> configs = (List>) entityAsMap(trainedModelConfigResponse).get( + "trained_model_configs" + ); + assertThat(configs, hasSize(1)); + Map metadata = (Map) configs.get(0).get("metadata"); + Integer canonicalRequiredMemory = (int) (ByteSizeValue.ofMb(240).getBytes() + 2 * RAW_MODEL_SIZE); + if (metadata != null) { + // test required memory estimation for a model with metadata memory requirements + assertThat(metadata, is(not(nullValue()))); + assertThat(metadata.containsKey("per_deployment_memory_bytes"), is(true)); + long perDeploymentMemoryBytes = ((Number) metadata.get("per_deployment_memory_bytes")).longValue(); + assertThat(metadata.containsKey("per_allocation_memory_bytes"), is(true)); + long perAllocationMemoryBytes = ((Number) metadata.get("per_allocation_memory_bytes")).longValue(); + Integer expectedRequiredMemory = Math.max( + canonicalRequiredMemory, + (int) (perDeploymentMemoryBytes + perAllocationMemoryBytes * numberOfAllocations + RAW_MODEL_SIZE) + ); + assertThat(requiredNativeMemory, equalTo(expectedRequiredMemory)); + } else { + // test required memory estimation for a model without metadata memory requirements + assertThat(requiredNativeMemory, equalTo(canonicalRequiredMemory)); + } + + stopDeployment(modelId); + }; + + assertAtLeast.accept(modelWithMetadata, AllocationStatus.State.STARTING); + assertAtLeast.accept(modelNoMetadata, AllocationStatus.State.STARTING); + } + @SuppressWarnings("unchecked") public void testLiveDeploymentStats() throws IOException { String modelId = "live_deployment_stats"; diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java index d895608ca0d6d..cbd90c26df3b2 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java @@ -191,11 +191,26 @@ protected void putVocabulary(List vocabulary, String modelId) throws IOE } protected void createPassThroughModel(String modelId) throws IOException { + createPassThroughModel(modelId, 0, 0); + } + + protected void createPassThroughModel(String modelId, long perDeploymentMemoryBytes, long perAllocationMemoryBytes) throws IOException { Request request = new Request("PUT", "/_ml/trained_models/" + modelId); - request.setJsonEntity(""" + String metadata; + if (perDeploymentMemoryBytes > 0 && perAllocationMemoryBytes > 0) { + metadata = Strings.format(""" + "metadata": { + "per_deployment_memory_bytes": %d, + "per_allocation_memory_bytes": %d + },""", perDeploymentMemoryBytes, perAllocationMemoryBytes); + } else { + metadata = ""; + } + request.setJsonEntity(Strings.format(""" { "description": "simple model for testing", "model_type": "pytorch", + %s "inference_config": { "pass_through": { "tokenization": { @@ -205,7 +220,7 @@ protected void createPassThroughModel(String modelId) throws IOException { } } } - }"""); + }""", metadata)); client().performRequest(request); } @@ -291,6 +306,11 @@ protected Response getTrainedModelStats(String modelId) throws IOException { return client().performRequest(request); } + protected Response getTrainedModelConfigs(String modelId) throws IOException { + Request request = new Request("GET", "/_ml/trained_models/" + modelId); + return client().performRequest(request); + } + protected Response infer(String input, String modelId, TimeValue timeout) throws IOException { Request request = new Request("POST", "/_ml/trained_models/" + modelId + "/_infer?timeout=" + timeout.toString()); request.setJsonEntity(Strings.format(""" diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java index 3d25302e6e5b4..266ec75007a2c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java @@ -124,11 +124,14 @@ protected void doExecute( .stream() .collect(Collectors.toMap(AssignmentStats::getDeploymentId, Function.identity())) ); + + int numberOfAllocations = deploymentStats.getStats().results().stream().mapToInt(AssignmentStats::getNumberOfAllocations).sum(); modelSizeStats( responseBuilder.getExpandedModelIdsWithAliases(), request.isAllowNoResources(), parentTaskId, - modelSizeStatsListener + modelSizeStatsListener, + numberOfAllocations ); })); @@ -273,7 +276,8 @@ private void modelSizeStats( Map> expandedIdsWithAliases, boolean allowNoResources, TaskId parentTaskId, - ActionListener> listener + ActionListener> listener, + int numberOfAllocations ) { ActionListener> modelsListener = ActionListener.wrap(models -> { final List pytorchModelIds = models.stream() @@ -285,12 +289,27 @@ private void modelSizeStats( for (TrainedModelConfig model : models) { if (model.getModelType() == TrainedModelType.PYTORCH) { long totalDefinitionLength = pytorchTotalDefinitionLengthsByModelId.getOrDefault(model.getModelId(), 0L); + long estimatedMemoryUsageBytes = totalDefinitionLength > 0L + ? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + model.getModelId(), + totalDefinitionLength, + model.getPerDeploymentMemoryBytes(), + model.getPerAllocationMemoryBytes(), + numberOfAllocations + ) + : 0L; modelSizeStatsByModelId.put( model.getModelId(), new TrainedModelSizeStats( totalDefinitionLength, totalDefinitionLength > 0L - ? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(model.getModelId(), totalDefinitionLength) + ? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes( + model.getModelId(), + totalDefinitionLength, + model.getPerDeploymentMemoryBytes(), + model.getPerAllocationMemoryBytes(), + numberOfAllocations + ) : 0L ) ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java index e59a8a29612e4..795b67fb19aca 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java @@ -73,6 +73,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.Predicate; import java.util.function.Supplier; @@ -171,6 +172,9 @@ protected void masterOperation( return; } + AtomicLong perDeploymentMemoryBytes = new AtomicLong(); + AtomicLong perAllocationMemoryBytes = new AtomicLong(); + ActionListener waitForDeploymentToStart = ActionListener.wrap( modelAssignment -> waitForDeploymentState(request.getDeploymentId(), request.getTimeout(), request.getWaitForState(), listener), e -> { @@ -199,7 +203,9 @@ protected void masterOperation( request.getThreadsPerAllocation(), request.getQueueCapacity(), Optional.ofNullable(request.getCacheSize()).orElse(ByteSizeValue.ofBytes(modelIdAndSizeInBytes.v2())), - request.getPriority() + request.getPriority(), + perDeploymentMemoryBytes.get(), + perAllocationMemoryBytes.get() ); PersistentTasksCustomMetadata persistentTasks = clusterService.state().getMetadata().custom(PersistentTasksCustomMetadata.TYPE); memoryTracker.refresh( @@ -235,6 +241,9 @@ protected void masterOperation( return; } + perDeploymentMemoryBytes.set(trainedModelConfig.getPerDeploymentMemoryBytes()); + perAllocationMemoryBytes.set(trainedModelConfig.getPerAllocationMemoryBytes()); + if (trainedModelConfig.getLocation() == null) { listener.onFailure(ExceptionsHelper.serverError("model [{}] does not have location", trainedModelConfig.getModelId())); return; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java index c39b4662b6568..8554e4120775a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java @@ -370,7 +370,9 @@ public void clusterChanged(ClusterChangedEvent event) { trainedModelAssignment.getTaskParams().getThreadsPerAllocation(), trainedModelAssignment.getTaskParams().getQueueCapacity(), trainedModelAssignment.getTaskParams().getCacheSize().orElse(null), - trainedModelAssignment.getTaskParams().getPriority() + trainedModelAssignment.getTaskParams().getPriority(), + trainedModelAssignment.getTaskParams().getPerDeploymentMemoryBytes(), + trainedModelAssignment.getTaskParams().getPerAllocationMemoryBytes() ) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java index cabc3dee0e5a7..85e0feb71b704 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java @@ -82,7 +82,9 @@ public void updateNumberOfAllocations(int numberOfAllocations) { params.getThreadsPerAllocation(), params.getQueueCapacity(), params.getCacheSize().orElse(null), - params.getPriority() + params.getPriority(), + params.getPerDeploymentMemoryBytes(), + params.getPerAllocationMemoryBytes() ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java index 6e1d41f968431..b8dd3559253ee 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java @@ -158,7 +158,18 @@ private DiscoveryNodes buildNodes(String... nodeIds) throws UnknownHostException private static TrainedModelAssignment createAssignment(String modelId) { return TrainedModelAssignment.Builder.empty( - new StartTrainedModelDeploymentAction.TaskParams(modelId, modelId, 1024, 1, 1, 1, ByteSizeValue.ofBytes(1024), Priority.NORMAL) + new StartTrainedModelDeploymentAction.TaskParams( + modelId, + modelId, + 1024, + 1, + 1, + 1, + ByteSizeValue.ofBytes(1024), + Priority.NORMAL, + 0L, + 0L + ) ).build(); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlMemoryAutoscalingDeciderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlMemoryAutoscalingDeciderTests.java index afad2b76f3d62..cf986c3cc5709 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlMemoryAutoscalingDeciderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlMemoryAutoscalingDeciderTests.java @@ -1063,7 +1063,9 @@ public void testCpuModelAssignmentRequirements() { 3, 100, null, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ).build(), TrainedModelAssignment.Builder.empty( @@ -1075,7 +1077,9 @@ public void testCpuModelAssignmentRequirements() { 1, 100, null, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ).build() ), @@ -1095,7 +1099,9 @@ public void testCpuModelAssignmentRequirements() { 3, 100, null, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ).build(), TrainedModelAssignment.Builder.empty( @@ -1107,7 +1113,9 @@ public void testCpuModelAssignmentRequirements() { 1, 100, null, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ).build() ), @@ -1127,7 +1135,9 @@ public void testCpuModelAssignmentRequirements() { 3, 100, null, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ).build(), TrainedModelAssignment.Builder.empty( @@ -1139,7 +1149,9 @@ public void testCpuModelAssignmentRequirements() { 1, 100, null, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ).build() ), diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlProcessorAutoscalingDeciderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlProcessorAutoscalingDeciderTests.java index 7aa3714c6ff2f..a56ad515690cf 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlProcessorAutoscalingDeciderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlProcessorAutoscalingDeciderTests.java @@ -76,7 +76,9 @@ public void testScale_GivenCurrentCapacityIsUsedExactly() { 3, 1024, ByteSizeValue.ONE, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ).addRoutingEntry(mlNodeId1, new RoutingInfo(2, 2, RoutingState.STARTED, "")) ) @@ -91,7 +93,9 @@ public void testScale_GivenCurrentCapacityIsUsedExactly() { 1, 1024, ByteSizeValue.ONE, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ) .addRoutingEntry(mlNodeId1, new RoutingInfo(2, 2, RoutingState.STARTED, "")) @@ -146,7 +150,9 @@ public void testScale_GivenUnsatisfiedDeployments() { 8, 1024, ByteSizeValue.ONE, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ) ) @@ -161,7 +167,9 @@ public void testScale_GivenUnsatisfiedDeployments() { 4, 1024, ByteSizeValue.ONE, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ) .addRoutingEntry(mlNodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) @@ -216,7 +224,9 @@ public void testScale_GivenUnsatisfiedDeploymentIsLowPriority_ShouldNotScaleUp() 1, 1024, ByteSizeValue.ONE, - Priority.LOW + Priority.LOW, + 0L, + 0L ) ) ) @@ -231,7 +241,9 @@ public void testScale_GivenUnsatisfiedDeploymentIsLowPriority_ShouldNotScaleUp() 4, 1024, ByteSizeValue.ONE, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ) .addRoutingEntry(mlNodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) @@ -286,7 +298,9 @@ public void testScale_GivenMoreThanHalfProcessorsAreUsed() { 2, 1024, ByteSizeValue.ONE, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ).addRoutingEntry(mlNodeId1, new RoutingInfo(2, 2, RoutingState.STARTED, "")) ) @@ -301,7 +315,9 @@ public void testScale_GivenMoreThanHalfProcessorsAreUsed() { 1, 1024, ByteSizeValue.ONE, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ).addRoutingEntry(mlNodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -367,7 +383,9 @@ public void testScale_GivenDownScalePossible_DelayNotSatisfied() { 2, 1024, ByteSizeValue.ONE, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ).addRoutingEntry(mlNodeId1, new RoutingInfo(2, 2, RoutingState.STARTED, "")) ) @@ -382,7 +400,9 @@ public void testScale_GivenDownScalePossible_DelayNotSatisfied() { 1, 1024, ByteSizeValue.ONE, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ).addRoutingEntry(mlNodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -436,7 +456,9 @@ public void testScale_GivenDownScalePossible_DelaySatisfied() { 2, 1024, ByteSizeValue.ONE, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ).addRoutingEntry(mlNodeId1, new RoutingInfo(2, 2, RoutingState.STARTED, "")) ) @@ -451,7 +473,9 @@ public void testScale_GivenDownScalePossible_DelaySatisfied() { 1, 1024, ByteSizeValue.ONE, - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ).addRoutingEntry(mlNodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -509,7 +533,9 @@ public void testScale_GivenLowPriorityDeploymentsOnly() { 1, 1024, ByteSizeValue.ONE, - Priority.LOW + Priority.LOW, + 0L, + 0L ) ).addRoutingEntry(mlNodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -524,7 +550,9 @@ public void testScale_GivenLowPriorityDeploymentsOnly() { 1, 1024, ByteSizeValue.ONE, - Priority.LOW + Priority.LOW, + 0L, + 0L ) ).addRoutingEntry(mlNodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java index f32c7970001f3..3471dc6a91958 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java @@ -1538,7 +1538,9 @@ private static StartTrainedModelDeploymentAction.TaskParams newParams( threadsPerAllocation, 1024, ByteSizeValue.ofBytes(modelSize), - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java index 00e4f48447ef7..3057da83d11e9 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java @@ -113,7 +113,9 @@ private static StartTrainedModelDeploymentAction.TaskParams randomParams(String randomIntBetween(1, 8), randomIntBetween(1, 10000), randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong()), - randomFrom(Priority.values()) + randomFrom(Priority.values()), + randomNonNegativeLong(), + randomNonNegativeLong() ); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java index cfdf0e87e7f78..06f486d5ab259 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java @@ -637,7 +637,9 @@ private static StartTrainedModelDeploymentAction.TaskParams newParams(String dep 1, 1024, randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong()), - randomFrom(Priority.values()) + randomFrom(Priority.values()), + randomNonNegativeLong(), + randomNonNegativeLong() ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java index a563f4b9a5a66..8ccf8839cfc08 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java @@ -1126,7 +1126,9 @@ private static StartTrainedModelDeploymentAction.TaskParams lowPriorityParams(St 1, 1024, ByteSizeValue.ofBytes(modelSize), - Priority.LOW + Priority.LOW, + 0, + 0 ); } @@ -1154,7 +1156,9 @@ private static StartTrainedModelDeploymentAction.TaskParams normalPriorityParams threadsPerAllocation, 1024, ByteSizeValue.ofBytes(modelSize), - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java index cce59d9a989ba..85fc83f775670 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java @@ -178,7 +178,9 @@ private static TrainedModelAssignment createAssignment( randomIntBetween(1, 16), 1024, null, - Priority.NORMAL + Priority.NORMAL, + randomNonNegativeLong(), + randomNonNegativeLong() ) ); allocationsByNode.entrySet() diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java index d6f237de35d3d..44dc44971bf38 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTaskTests.java @@ -59,7 +59,9 @@ void assertTrackingComplete(Consumer method, String randomInt(5), randomInt(5), randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, Long.MAX_VALUE)), - Priority.NORMAL + Priority.NORMAL, + randomNonNegativeLong(), + randomNonNegativeLong() ), nodeService, licenseState, @@ -93,7 +95,9 @@ public void testUpdateNumberOfAllocations() { randomIntBetween(1, 32), randomInt(5), randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, Long.MAX_VALUE)), - randomFrom(Priority.values()) + randomFrom(Priority.values()), + randomNonNegativeLong(), + randomNonNegativeLong() ); TrainedModelDeploymentTask task = new TrainedModelDeploymentTask( diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilderTests.java index c3223ed7ad67c..b14b2fe71ab17 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchBuilderTests.java @@ -51,7 +51,7 @@ public void testBuild() throws IOException, InterruptedException { new PyTorchBuilder( nativeController, processPipes, - new TaskParams("my_model", "my_deployment", 42L, 4, 2, 1024, ByteSizeValue.ofBytes(12), Priority.NORMAL) + new TaskParams("my_model", "my_deployment", 42L, 4, 2, 1024, ByteSizeValue.ofBytes(12), Priority.NORMAL, 0L, 0L) ).build(); verify(nativeController).startProcess(commandCaptor.capture()); @@ -73,7 +73,7 @@ public void testBuildWithNoCache() throws IOException, InterruptedException { new PyTorchBuilder( nativeController, processPipes, - new TaskParams("my_model", "my_deployment", 42L, 4, 2, 1024, ByteSizeValue.ZERO, Priority.NORMAL) + new TaskParams("my_model", "my_deployment", 42L, 4, 2, 1024, ByteSizeValue.ZERO, Priority.NORMAL, 0L, 0L) ).build(); verify(nativeController).startProcess(commandCaptor.capture()); @@ -94,7 +94,7 @@ public void testBuildWithLowPriority() throws IOException, InterruptedException new PyTorchBuilder( nativeController, processPipes, - new TaskParams("my_model", "my_deployment", 42L, 1, 1, 1024, ByteSizeValue.ofBytes(42), Priority.LOW) + new TaskParams("my_model", "my_deployment", 42L, 1, 1, 1024, ByteSizeValue.ofBytes(42), Priority.LOW, 0L, 0L) ).build(); verify(nativeController).startProcess(commandCaptor.capture()); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java index 063f6fc860e20..ccc7f14d2264e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java @@ -130,7 +130,9 @@ public void testNodeLoadDetection() { 1, 1024, ByteSizeValue.ofBytes(MODEL_MEMORY_REQUIREMENT), - Priority.NORMAL + Priority.NORMAL, + 0L, + 0L ) ) .addRoutingEntry("_node_id4", new RoutingInfo(1, 1, RoutingState.STARTING, ""))