Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions kgraphql/api/kgraphql.api
Original file line number Diff line number Diff line change
Expand Up @@ -868,10 +868,9 @@ public final class com/apurebase/kgraphql/schema/execution/ParallelRequestExecut
}

public final class com/apurebase/kgraphql/schema/execution/ParallelRequestExecutor$ExecutionContext {
public fun <init> (Lkotlinx/coroutines/CoroutineScope;Lcom/apurebase/kgraphql/request/Variables;Lcom/apurebase/kgraphql/Context;Ljava/util/Map;)V
public fun <init> (Lcom/apurebase/kgraphql/request/Variables;Lcom/apurebase/kgraphql/Context;Ljava/util/Map;)V
public final fun getLoaders ()Ljava/util/Map;
public final fun getRequestContext ()Lcom/apurebase/kgraphql/Context;
public final fun getScope ()Lkotlinx/coroutines/CoroutineScope;
public final fun getVariables ()Lcom/apurebase/kgraphql/request/Variables;
}

Expand Down
15 changes: 3 additions & 12 deletions kgraphql/src/main/kotlin/com/apurebase/kgraphql/Extensions.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package com.apurebase.kgraphql

import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.withContext
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.full.isSubclassOf
Expand All @@ -30,12 +27,6 @@ internal fun KType.getIterableElementType(): KType {
return arguments.firstOrNull()?.type ?: throw NoSuchElementException("KType $this has no type arguments")
}

internal suspend fun <T, R> Iterable<T>.mapIndexedParallel(
dispatcher: CoroutineDispatcher = Dispatchers.Default,
block: suspend (Int, T) -> R
): List<R> =
withContext(dispatcher) {
coroutineScope {
this@mapIndexedParallel.mapIndexed { index, i -> async { block(index, i) } }.awaitAll()
}
}
internal suspend fun <T, R> Iterable<T>.mapIndexedParallel(block: suspend (Int, T) -> R): List<R> = coroutineScope {
this@mapIndexedParallel.mapIndexed { index, i -> async { block(index, i) } }.awaitAll()
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import com.apurebase.kgraphql.schema.structure.LookupSchema
import com.apurebase.kgraphql.schema.structure.RequestInterpreter
import com.apurebase.kgraphql.schema.structure.SchemaModel
import com.apurebase.kgraphql.schema.structure.Type
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.withContext
import kotlin.reflect.KClass

class DefaultSchema(
Expand All @@ -42,7 +42,7 @@ class DefaultSchema(
variables: String?,
context: Context,
operationName: String?,
): String = coroutineScope {
): String = withContext(configuration.coroutineDispatcher) {
if (!configuration.introspection && Introspection.isIntrospection(request)) {
throw ValidationException("GraphQL introspection is not allowed")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ import com.fasterxml.jackson.databind.node.ArrayNode
import com.fasterxml.jackson.databind.node.NullNode
import com.fasterxml.jackson.databind.node.ObjectNode
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.job
import nidomiro.kdataloader.DataLoader
Expand All @@ -31,7 +31,6 @@ import kotlin.reflect.KProperty1
class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor {

class ExecutionContext(
val scope: CoroutineScope,
val variables: Variables,
val requestContext: Context,
val loaders: Map<Field.DataLoader<*, *, *>, DataLoader<Any?, *>>
Expand Down Expand Up @@ -63,8 +62,6 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor {

private val jsonNodeFactory = schema.configuration.objectMapper.nodeFactory

private val dispatcher = schema.configuration.coroutineDispatcher

private val objectWriter = schema.configuration.objectMapper.writer().let {
if (schema.configuration.useDefaultPrettyPrinter) {
it.withDefaultPrettyPrinter()
Expand All @@ -79,19 +76,17 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor {
val data = root.putObject("data")
val loaders = plan.constructLoaders()

val resultMap = plan.mapIndexedParallel(dispatcher) { _, operation ->
coroutineScope {
val ctx = ExecutionContext(this, Variables(variables, operation.variables), context, loaders)
if (shouldInclude(ctx, operation)) {
operation to writeOperation(
isSubscription = plan.isSubscription,
ctx = ctx,
node = operation,
operation = operation.field as Field.Function<*, *>
)
} else {
operation to null
}
val resultMap = plan.mapIndexedParallel { _, operation ->
val ctx = ExecutionContext(Variables(variables, operation.variables), context, loaders)
if (shouldInclude(ctx, operation)) {
operation to writeOperation(
isSubscription = plan.isSubscription,
ctx = ctx,
node = operation,
operation = operation.field as Field.Function<*, *>
)
} else {
operation to null
}
}.toMap()

Expand Down Expand Up @@ -162,27 +157,27 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor {
value: T?,
node: Execution.Node,
returnType: Type
): Deferred<JsonNode> {
): Deferred<JsonNode> = coroutineScope {
if (value == null || value is NullNode) {
return CompletableDeferred(createNullNode(node, returnType))
return@coroutineScope CompletableDeferred(createNullNode(node, returnType))
}

val unboxed = schema.configuration.genericTypeResolver.unbox(value)
if (unboxed !== value) {
return createNode(ctx, unboxed, node, returnType)
return@coroutineScope createNode(ctx, unboxed, node, returnType)
}

return when {
return@coroutineScope when {
// Check value, not returnType, because this method can be invoked with element value
value is Collection<*> || value is Array<*> || value is ArrayNode -> ctx.scope.async {
value is Collection<*> || value is Array<*> || value is ArrayNode -> async {
val values: Collection<*> = when (value) {
is Array<*> -> value.toList()
is ArrayNode -> value.toList()
else -> value as Collection<*>
}
if (returnType.isList()) {
val unwrappedReturnType = returnType.unwrapList()
val valuesMap = values.mapIndexedParallel(dispatcher) { i, value ->
val valuesMap = values.mapIndexedParallel { i, value ->
value to createNode(ctx, value, node.withIndex(i), unwrappedReturnType)
}.toMap()
values.fold(jsonNodeFactory.arrayNode(values.size)) { array, v ->
Expand Down Expand Up @@ -238,31 +233,33 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor {
}
}

private fun <T> createObjectNode(
private suspend fun <T> createObjectNode(
ctx: ExecutionContext,
value: T,
node: Execution.Node,
type: Type
): Deferred<ObjectNode> = ctx.scope.async {
val objectNode = jsonNodeFactory.objectNode()
val deferreds = mutableListOf<Deferred<Map<String, Deferred<JsonNode?>?>>>()
for (child in node.children) {
when (child) {
is Execution.Fragment -> deferreds.add(ctx.scope.async {
handleFragment(ctx, value, child.withParent(node))
})

else -> deferreds.add(ctx.scope.async {
handleProperty(ctx, value, child.withParent(node), type)?.let { mapOf(it) } ?: emptyMap()
})
): Deferred<ObjectNode> = coroutineScope {
async {
val objectNode = jsonNodeFactory.objectNode()
val deferreds = mutableListOf<Deferred<Map<String, Deferred<JsonNode?>?>>>()
for (child in node.children) {
when (child) {
is Execution.Fragment -> deferreds.add(async {
handleFragment(ctx, value, child.withParent(node))
})

else -> deferreds.add(async {
handleProperty(ctx, value, child.withParent(node), type)?.let { mapOf(it) } ?: emptyMap()
})
}
}
}
deferreds.forEach {
it.await().forEach { (key, value) ->
objectNode.merge(key, value?.await())
deferreds.awaitAll().forEach {
it.forEach { (key, value) ->
objectNode.merge(key, value?.await())
}
}
objectNode
}
objectNode
}

private suspend fun <T> handleProperty(
Expand Down Expand Up @@ -445,7 +442,7 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor {
}?.reduce { acc, b -> acc && b } ?: true
}

internal fun <T> FunctionWrapper<T>.invoke(
internal suspend fun <T> FunctionWrapper<T>.invoke(
isSubscription: Boolean = false,
children: Collection<Execution> = emptyList(),
funName: String,
Expand All @@ -454,7 +451,7 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor {
args: ArgumentNodes?,
executionNode: Execution,
ctx: ExecutionContext
): Deferred<T?> {
): Deferred<T?> = coroutineScope {
val transformedArgs = argumentsHandler.transformArguments(
funName,
receiver,
Expand All @@ -463,11 +460,10 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor {
ctx.variables,
executionNode,
ctx.requestContext,
this
) ?: return CompletableDeferred(value = null)
this@invoke
) ?: return@coroutineScope CompletableDeferred(value = null)

// exceptions are not caught on purpose to pass up business logic errors
return ctx.scope.async {
async {
try {
when {
hasReceiver -> invoke(receiver, *transformedArgs.toTypedArray())
Expand Down
Loading