Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion server/src/main/java/org/elasticsearch/TransportVersion.java
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ private static TransportVersion registerTransportVersion(int id, String uniqueId
public static final TransportVersion V_8_500_061 = registerTransportVersion(8_500_061, "4e07f830-8be4-448c-851e-62b3d2f0bf0a");
public static final TransportVersion V_8_500_062 = registerTransportVersion(8_500_062, "09CD9C9B-3207-4B40-8756-B7A12001A885");
public static final TransportVersion V_8_500_063 = registerTransportVersion(8_500_063, "31dedced-0055-4f34-b952-2f6919be7488");
public static final TransportVersion V_8_500_064 = registerTransportVersion(8_500_064, "3a795175-5e6f-40ff-90fe-5571ea8ab04e");

/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand All @@ -199,7 +201,7 @@ private static TransportVersion registerTransportVersion(int id, String uniqueId
*/

private static class CurrentHolder {
private static final TransportVersion CURRENT = findCurrent(V_8_500_063);
private static final TransportVersion CURRENT = findCurrent(V_8_500_064);

// finds the pluggable current version, or uses the given fallback
private static TransportVersion findCurrent(TransportVersion fallback) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ public static boolean mayAssignToNode(DiscoveryNode node) {
}

public static final MlConfigVersion VERSION_INTRODUCED = MlConfigVersion.V_8_0_0;

private static final ParseField MODEL_BYTES = new ParseField("model_bytes");
public static final ParseField NUMBER_OF_ALLOCATIONS = new ParseField("number_of_allocations");
public static final ParseField THREADS_PER_ALLOCATION = new ParseField("threads_per_allocation");
Expand All @@ -389,6 +390,8 @@ public static boolean mayAssignToNode(DiscoveryNode node) {
public static final ParseField QUEUE_CAPACITY = new ParseField("queue_capacity");
public static final ParseField CACHE_SIZE = new ParseField("cache_size");
public static final ParseField PRIORITY = new ParseField("priority");
public static final ParseField PER_DEPLOYMENT_MEMORY_BYTES = new ParseField("per_deployment_memory_bytes");
public static final ParseField PER_ALLOCATION_MEMORY_BYTES = new ParseField("per_allocation_memory_bytes");

private static final ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
"trained_model_deployment_params",
Expand All @@ -403,7 +406,9 @@ public static boolean mayAssignToNode(DiscoveryNode node) {
(ByteSizeValue) a[6],
(Integer) a[7],
(Integer) a[8],
a[9] == null ? null : Priority.fromString((String) a[9])
a[9] == null ? null : Priority.fromString((String) a[9]),
(Long) a[10],
(Long) a[11]
)
);

Expand All @@ -423,6 +428,8 @@ public static boolean mayAssignToNode(DiscoveryNode node) {
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_MODEL_THREADS);
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), LEGACY_INFERENCE_THREADS);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PRIORITY);
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), PER_DEPLOYMENT_MEMORY_BYTES);
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), PER_ALLOCATION_MEMORY_BYTES);
}

public static TaskParams fromXContent(XContentParser parser) {
Expand All @@ -439,6 +446,8 @@ public static TaskParams fromXContent(XContentParser parser) {
private final int numberOfAllocations;
private final int queueCapacity;
private final Priority priority;
private final long perDeploymentMemoryBytes;
private final long perAllocationMemoryBytes;

private TaskParams(
String modelId,
Expand All @@ -450,7 +459,9 @@ private TaskParams(
ByteSizeValue cacheSizeValue,
Integer legacyModelThreads,
Integer legacyInferenceThreads,
Priority priority
Priority priority,
Long perDeploymentMemoryBytes,
Long perAllocationMemoryBytes
) {
this(
modelId,
Expand All @@ -462,7 +473,9 @@ private TaskParams(
threadsPerAllocation == null ? legacyInferenceThreads : threadsPerAllocation,
queueCapacity,
cacheSizeValue,
priority == null ? Priority.NORMAL : priority
priority == null ? Priority.NORMAL : priority,
perDeploymentMemoryBytes == null ? 0 : perDeploymentMemoryBytes,
perAllocationMemoryBytes == null ? 0 : perAllocationMemoryBytes
);
}

Expand All @@ -474,7 +487,9 @@ public TaskParams(
int threadsPerAllocation,
int queueCapacity,
@Nullable ByteSizeValue cacheSize,
Priority priority
Priority priority,
long perDeploymentMemoryBytes,
long perAllocationMemoryBytes
) {
this.modelId = Objects.requireNonNull(modelId);
this.deploymentId = Objects.requireNonNull(deploymentId);
Expand All @@ -484,6 +499,8 @@ public TaskParams(
this.queueCapacity = queueCapacity;
this.cacheSize = cacheSize;
this.priority = Objects.requireNonNull(priority);
this.perDeploymentMemoryBytes = perDeploymentMemoryBytes;
this.perAllocationMemoryBytes = perAllocationMemoryBytes;
}

public TaskParams(StreamInput in) throws IOException {
Expand All @@ -507,6 +524,15 @@ public TaskParams(StreamInput in) throws IOException {
} else {
this.deploymentId = modelId;
}

if (in.getTransportVersion().onOrAfter(TrainedModelConfig.VERSION_ALLOCATION_MEMORY_ADDED)) {
// We store additional model usage per allocation in the task params.
this.perDeploymentMemoryBytes = in.readLong();
this.perAllocationMemoryBytes = in.readLong();
} else {
this.perDeploymentMemoryBytes = 0L;
this.perAllocationMemoryBytes = 0L;
}
}

public String getModelId() {
Expand All @@ -521,10 +547,21 @@ public long estimateMemoryUsageBytes() {
// We already take into account 2x the model bytes. If the cache size is larger than the model bytes, then
// we need to take it into account when returning the estimate.
if (cacheSize != null && cacheSize.getBytes() > modelBytes) {
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelId, modelBytes) + (cacheSize.getBytes()
- modelBytes);
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
modelId,
modelBytes,
perDeploymentMemoryBytes,
perAllocationMemoryBytes,
numberOfAllocations
) + (cacheSize.getBytes() - modelBytes);
}
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelId, modelBytes);
return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(
modelId,
modelBytes,
perDeploymentMemoryBytes,
perAllocationMemoryBytes,
numberOfAllocations
);
}

public MlConfigVersion getMinimalSupportedVersion() {
Expand All @@ -547,6 +584,10 @@ public void writeTo(StreamOutput out) throws IOException {
if (out.getTransportVersion().onOrAfter(TransportVersion.V_8_8_0)) {
out.writeString(deploymentId);
}
if (out.getTransportVersion().onOrAfter(TrainedModelConfig.VERSION_ALLOCATION_MEMORY_ADDED)) {
out.writeLong(perDeploymentMemoryBytes);
out.writeLong(perAllocationMemoryBytes);
}
}

@Override
Expand All @@ -562,6 +603,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(CACHE_SIZE.getPreferredName(), cacheSize.getStringRep());
}
builder.field(PRIORITY.getPreferredName(), priority);
builder.field(PER_DEPLOYMENT_MEMORY_BYTES.getPreferredName(), perDeploymentMemoryBytes);
builder.field(PER_ALLOCATION_MEMORY_BYTES.getPreferredName(), perAllocationMemoryBytes);
builder.endObject();
return builder;
}
Expand All @@ -576,7 +619,9 @@ public int hashCode() {
numberOfAllocations,
queueCapacity,
cacheSize,
priority
priority,
perDeploymentMemoryBytes,
perAllocationMemoryBytes
);
}

Expand All @@ -593,7 +638,9 @@ public boolean equals(Object o) {
&& numberOfAllocations == other.numberOfAllocations
&& Objects.equals(cacheSize, other.cacheSize)
&& queueCapacity == other.queueCapacity
&& priority == other.priority;
&& priority == other.priority
&& perDeploymentMemoryBytes == other.perDeploymentMemoryBytes
&& perAllocationMemoryBytes == other.perAllocationMemoryBytes;
}

@Override
Expand Down Expand Up @@ -629,6 +676,14 @@ public Priority getPriority() {
return priority;
}

/**
 * Memory required per model allocation, in bytes.
 * Defaults to 0 when not supplied (pre-8.500.064 nodes never send it).
 */
public long getPerAllocationMemoryBytes() {
    return perAllocationMemoryBytes;
}

/**
 * Memory required once per deployment, in bytes.
 * Defaults to 0 when not supplied (pre-8.500.064 nodes never send it).
 */
public long getPerDeploymentMemoryBytes() {
    return perDeploymentMemoryBytes;
}

@Override
public String toString() {
return Strings.toString(this);
Expand All @@ -649,12 +704,37 @@ static boolean match(Task task, String expectedId) {
}
}

public static long estimateMemoryUsageBytes(String modelId, long totalDefinitionLength) {
public static long estimateMemoryUsageBytes(
String modelId,
long totalDefinitionLength,
long perDeploymentMemoryBytes,
long perAllocationMemoryBytes,
int numberOfAllocations
) {
// While loading the model in the process we need twice the model size.
return isElserModel(modelId) ? ELSER_1_MEMORY_USAGE.getBytes() : MEMORY_OVERHEAD.getBytes() + 2 * totalDefinitionLength;

// 1. If ELSER v1 then 2004MB
// 2. If static memory and dynamic memory are not set then 240MB + 2 * model size
// 3. Else static memory + dynamic memory * allocations + model size

// The model size is still added in option 3 to account for the temporary requirement to hold the zip file in memory
// in `pytorch_inference`.
if (isElserV1Model(modelId)) {
return ELSER_1_MEMORY_USAGE.getBytes();
} else {
long baseSize = MEMORY_OVERHEAD.getBytes() + 2 * totalDefinitionLength;
if (perDeploymentMemoryBytes == 0 && perAllocationMemoryBytes == 0) {
return baseSize;
} else {
return Math.max(
baseSize,
perDeploymentMemoryBytes + perAllocationMemoryBytes * numberOfAllocations + totalDefinitionLength
);
}
}
}

private static boolean isElserModel(String modelId) {
private static boolean isElserV1Model(String modelId) {
return modelId.startsWith(".elser_model_1");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final ParseField LOCATION = new ParseField("location");
public static final ParseField MODEL_PACKAGE = new ParseField("model_package");

public static final ParseField PER_DEPLOYMENT_MEMORY_BYTES = new ParseField("per_deployment_memory_bytes");
public static final ParseField PER_ALLOCATION_MEMORY_BYTES = new ParseField("per_allocation_memory_bytes");

public static final TransportVersion VERSION_3RD_PARTY_CONFIG_ADDED = TransportVersion.V_8_0_0;
public static final TransportVersion VERSION_ALLOCATION_MEMORY_ADDED = TransportVersion.V_8_500_064;

// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);
Expand Down Expand Up @@ -163,6 +167,7 @@ private static ObjectParser<TrainedModelConfig.Builder, Void> createParser(boole
(p, c) -> ignoreUnknownFields ? ModelPackageConfig.fromXContentLenient(p) : ModelPackageConfig.fromXContentStrict(p),
MODEL_PACKAGE
);

return parser;
}

Expand Down Expand Up @@ -403,6 +408,18 @@ public void setFullDefinition(boolean fullDefinition) {
this.fullDefinition = fullDefinition;
}

/**
 * Static (per-deployment) memory requirement read from the model metadata,
 * in bytes; 0 when the metadata is absent or does not contain the key.
 */
public long getPerDeploymentMemoryBytes() {
    String key = PER_DEPLOYMENT_MEMORY_BYTES.getPreferredName();
    if (metadata != null && metadata.containsKey(key)) {
        return ((Number) metadata.get(key)).longValue();
    }
    return 0L;
}

/**
 * Dynamic (per-allocation) memory requirement read from the model metadata,
 * in bytes; 0 when the metadata is absent or does not contain the key.
 */
public long getPerAllocationMemoryBytes() {
    String key = PER_ALLOCATION_MEMORY_BYTES.getPreferredName();
    if (metadata != null && metadata.containsKey(key)) {
        return ((Number) metadata.get(key)).longValue();
    }
    return 0L;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
Expand Down Expand Up @@ -570,6 +587,8 @@ public static class Builder {
private InferenceConfig inferenceConfig;
private TrainedModelLocation location;
private ModelPackageConfig modelPackageConfig;
private Long perDeploymentMemoryBytes;
private Long perAllocationMemoryBytes;

public Builder() {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,9 @@ public Builder setNumberOfAllocations(int numberOfAllocations) {
taskParams.getThreadsPerAllocation(),
taskParams.getQueueCapacity(),
taskParams.getCacheSize().orElse(null),
taskParams.getPriority()
taskParams.getPriority(),
taskParams.getPerDeploymentMemoryBytes(),
taskParams.getPerAllocationMemoryBytes()
);
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ public static StartTrainedModelDeploymentAction.TaskParams createRandom() {
randomIntBetween(1, 8),
randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomNonNegativeLong()),
randomFrom(Priority.values())
randomFrom(Priority.values()),
randomNonNegativeLong(),
randomNonNegativeLong()
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,9 @@ private static StartTrainedModelDeploymentAction.TaskParams randomTaskParams(int
randomIntBetween(1, 8),
randomIntBetween(1, 10000),
randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(0, modelSize + 1)),
randomFrom(Priority.values())
randomFrom(Priority.values()),
randomNonNegativeLong(),
randomNonNegativeLong()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,71 @@ public void testDeploymentStats() throws IOException {
assertAtLeast.accept(modelStarted, AllocationStatus.State.FULLY_ALLOCATED);
}

@SuppressWarnings("unchecked")
public void testRequiredMemoryEstimation() throws IOException {
    // One model with explicit memory requirements in its metadata, one without.
    String modelWithMetadata = "model_with_metadata";
    createPassThroughModel(modelWithMetadata, randomLongBetween(0, 10000000), randomLongBetween(0, 10000000));
    putVocabulary(List.of("once", "twice"), modelWithMetadata);
    putModelDefinition(modelWithMetadata);
    String modelNoMetadata = "model_no_metadata";
    createPassThroughModel(modelNoMetadata);
    putVocabulary(List.of("once", "twice"), modelNoMetadata);
    putModelDefinition(modelNoMetadata);

    CheckedBiConsumer<String, AllocationStatus.State, IOException> assertAtLeast = (modelId, state) -> {
        startDeployment(modelId, state);
        Response response = getTrainedModelStats(modelId);
        var responseMap = entityAsMap(response);
        List<Map<String, Object>> stats = (List<Map<String, Object>>) responseMap.get("trained_model_stats");
        assertThat(stats, hasSize(1));
        String statusState = (String) XContentMapValues.extractValue("deployment_stats.allocation_status.state", stats.get(0));
        assertThat(responseMap.toString(), statusState, is(not(nullValue())));
        assertThat(AllocationStatus.State.fromString(statusState), greaterThanOrEqualTo(state));
        assertThat(XContentMapValues.extractValue("inference_stats", stats.get(0)), is(not(nullValue())));
        Integer numberOfAllocations = (Integer) XContentMapValues.extractValue("deployment_stats.number_of_allocations", stats.get(0));
        assertThat(numberOfAllocations, greaterThanOrEqualTo(0));

        Integer byteSize = (Integer) XContentMapValues.extractValue("model_size_stats.model_size_bytes", stats.get(0));
        assertThat(responseMap.toString(), byteSize, is(not(nullValue())));
        assertThat(byteSize, equalTo((int) RAW_MODEL_SIZE));

        Integer requiredNativeMemory = (Integer) XContentMapValues.extractValue(
            "model_size_stats.required_native_memory_bytes",
            stats.get(0)
        );
        assertThat(responseMap.toString(), requiredNativeMemory, is(not(nullValue())));

        Response trainedModelConfigResponse = getTrainedModelConfigs(modelId);
        List<Map<String, Object>> configs = (List<Map<String, Object>>) entityAsMap(trainedModelConfigResponse).get(
            "trained_model_configs"
        );
        assertThat(configs, hasSize(1));
        Map<String, Object> metadata = (Map<String, Object>) configs.get(0).get("metadata");
        // Do the arithmetic in long and narrow to int only for the final
        // comparison, so no intermediate value is truncated by an early cast.
        long canonicalRequiredMemory = ByteSizeValue.ofMb(240).getBytes() + 2 * RAW_MODEL_SIZE;
        if (metadata != null) {
            // Required memory estimation for a model with metadata memory requirements.
            assertThat(metadata.containsKey("per_deployment_memory_bytes"), is(true));
            long perDeploymentMemoryBytes = ((Number) metadata.get("per_deployment_memory_bytes")).longValue();
            assertThat(metadata.containsKey("per_allocation_memory_bytes"), is(true));
            long perAllocationMemoryBytes = ((Number) metadata.get("per_allocation_memory_bytes")).longValue();
            long expectedRequiredMemory = Math.max(
                canonicalRequiredMemory,
                perDeploymentMemoryBytes + perAllocationMemoryBytes * numberOfAllocations + RAW_MODEL_SIZE
            );
            assertThat(requiredNativeMemory, equalTo((int) expectedRequiredMemory));
        } else {
            // Required memory estimation for a model without metadata memory requirements.
            assertThat(requiredNativeMemory, equalTo((int) canonicalRequiredMemory));
        }

        stopDeployment(modelId);
    };

    assertAtLeast.accept(modelWithMetadata, AllocationStatus.State.STARTING);
    assertAtLeast.accept(modelNoMetadata, AllocationStatus.State.STARTING);
}

@SuppressWarnings("unchecked")
public void testLiveDeploymentStats() throws IOException {
String modelId = "live_deployment_stats";
Expand Down
Loading