diff --git a/kgraphql/api/kgraphql.api b/kgraphql/api/kgraphql.api index 51825e9f..73ca2f45 100644 --- a/kgraphql/api/kgraphql.api +++ b/kgraphql/api/kgraphql.api @@ -868,10 +868,9 @@ public final class com/apurebase/kgraphql/schema/execution/ParallelRequestExecut } public final class com/apurebase/kgraphql/schema/execution/ParallelRequestExecutor$ExecutionContext { - public fun (Lkotlinx/coroutines/CoroutineScope;Lcom/apurebase/kgraphql/request/Variables;Lcom/apurebase/kgraphql/Context;Ljava/util/Map;)V + public fun (Lcom/apurebase/kgraphql/request/Variables;Lcom/apurebase/kgraphql/Context;Ljava/util/Map;)V public final fun getLoaders ()Ljava/util/Map; public final fun getRequestContext ()Lcom/apurebase/kgraphql/Context; - public final fun getScope ()Lkotlinx/coroutines/CoroutineScope; public final fun getVariables ()Lcom/apurebase/kgraphql/request/Variables; } diff --git a/kgraphql/src/main/kotlin/com/apurebase/kgraphql/Extensions.kt b/kgraphql/src/main/kotlin/com/apurebase/kgraphql/Extensions.kt index b658aa86..abdfb684 100644 --- a/kgraphql/src/main/kotlin/com/apurebase/kgraphql/Extensions.kt +++ b/kgraphql/src/main/kotlin/com/apurebase/kgraphql/Extensions.kt @@ -1,11 +1,8 @@ package com.apurebase.kgraphql -import kotlinx.coroutines.CoroutineDispatcher -import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.async import kotlinx.coroutines.awaitAll import kotlinx.coroutines.coroutineScope -import kotlinx.coroutines.withContext import kotlin.reflect.KClass import kotlin.reflect.KType import kotlin.reflect.full.isSubclassOf @@ -30,12 +27,6 @@ internal fun KType.getIterableElementType(): KType { return arguments.firstOrNull()?.type ?: throw NoSuchElementException("KType $this has no type arguments") } -internal suspend fun Iterable.mapIndexedParallel( - dispatcher: CoroutineDispatcher = Dispatchers.Default, - block: suspend (Int, T) -> R -): List = - withContext(dispatcher) { - coroutineScope { - this@mapIndexedParallel.mapIndexed { index, i -> async { block(index, i) } }.awaitAll() - } - } +internal suspend fun Iterable.mapIndexedParallel(block: suspend (Int, T) -> R): List = coroutineScope { + this@mapIndexedParallel.mapIndexed { index, i -> async { block(index, i) } }.awaitAll() +} diff --git a/kgraphql/src/main/kotlin/com/apurebase/kgraphql/schema/DefaultSchema.kt b/kgraphql/src/main/kotlin/com/apurebase/kgraphql/schema/DefaultSchema.kt index c31372da..92516eb5 100644 --- a/kgraphql/src/main/kotlin/com/apurebase/kgraphql/schema/DefaultSchema.kt +++ b/kgraphql/src/main/kotlin/com/apurebase/kgraphql/schema/DefaultSchema.kt @@ -16,7 +16,7 @@ import com.apurebase.kgraphql.schema.structure.LookupSchema import com.apurebase.kgraphql.schema.structure.RequestInterpreter import com.apurebase.kgraphql.schema.structure.SchemaModel import com.apurebase.kgraphql.schema.structure.Type -import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.withContext import kotlin.reflect.KClass class DefaultSchema( @@ -42,7 +42,7 @@ class DefaultSchema( variables: String?, context: Context, operationName: String?, - ): String = coroutineScope { + ): String = withContext(configuration.coroutineDispatcher) { if (!configuration.introspection && Introspection.isIntrospection(request)) { throw ValidationException("GraphQL introspection is not allowed") } diff --git a/kgraphql/src/main/kotlin/com/apurebase/kgraphql/schema/execution/ParallelRequestExecutor.kt b/kgraphql/src/main/kotlin/com/apurebase/kgraphql/schema/execution/ParallelRequestExecutor.kt index 5997d19a..18ac1111 100644 --- a/kgraphql/src/main/kotlin/com/apurebase/kgraphql/schema/execution/ParallelRequestExecutor.kt +++ b/kgraphql/src/main/kotlin/com/apurebase/kgraphql/schema/execution/ParallelRequestExecutor.kt @@ -19,9 +19,9 @@ import com.fasterxml.jackson.databind.node.ArrayNode import com.fasterxml.jackson.databind.node.NullNode import com.fasterxml.jackson.databind.node.ObjectNode import kotlinx.coroutines.CompletableDeferred -import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Deferred import kotlinx.coroutines.async +import kotlinx.coroutines.awaitAll import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.job import nidomiro.kdataloader.DataLoader @@ -31,7 +31,6 @@ import kotlin.reflect.KProperty1 class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor { class ExecutionContext( - val scope: CoroutineScope, val variables: Variables, val requestContext: Context, val loaders: Map, DataLoader> @@ -63,8 +62,6 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor { private val jsonNodeFactory = schema.configuration.objectMapper.nodeFactory - private val dispatcher = schema.configuration.coroutineDispatcher - private val objectWriter = schema.configuration.objectMapper.writer().let { if (schema.configuration.useDefaultPrettyPrinter) { it.withDefaultPrettyPrinter() @@ -79,19 +76,17 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor { val data = root.putObject("data") val loaders = plan.constructLoaders() - val resultMap = plan.mapIndexedParallel(dispatcher) { _, operation -> - coroutineScope { - val ctx = ExecutionContext(this, Variables(variables, operation.variables), context, loaders) - if (shouldInclude(ctx, operation)) { - operation to writeOperation( - isSubscription = plan.isSubscription, - ctx = ctx, - node = operation, - operation = operation.field as Field.Function<*, *> - ) - } else { - operation to null - } + val resultMap = plan.mapIndexedParallel { _, operation -> + val ctx = ExecutionContext(Variables(variables, operation.variables), context, loaders) + if (shouldInclude(ctx, operation)) { + operation to writeOperation( + isSubscription = plan.isSubscription, + ctx = ctx, + node = operation, + operation = operation.field as Field.Function<*, *> + ) + } else { + operation to null } }.toMap() @@ -162,19 +157,19 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor { value: T?, node: Execution.Node, returnType: Type - ): Deferred { + ): Deferred = coroutineScope { if (value == null || value is NullNode) { - return CompletableDeferred(createNullNode(node, returnType)) + return@coroutineScope CompletableDeferred(createNullNode(node, returnType)) } val unboxed = schema.configuration.genericTypeResolver.unbox(value) if (unboxed !== value) { - return createNode(ctx, unboxed, node, returnType) + return@coroutineScope createNode(ctx, unboxed, node, returnType) } - return when { + return@coroutineScope when { // Check value, not returnType, because this method can be invoked with element value - value is Collection<*> || value is Array<*> || value is ArrayNode -> ctx.scope.async { + value is Collection<*> || value is Array<*> || value is ArrayNode -> async { val values: Collection<*> = when (value) { is Array<*> -> value.toList() is ArrayNode -> value.toList() @@ -182,7 +177,7 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor { } if (returnType.isList()) { val unwrappedReturnType = returnType.unwrapList() - val valuesMap = values.mapIndexedParallel(dispatcher) { i, value -> + val valuesMap = values.mapIndexedParallel { i, value -> value to createNode(ctx, value, node.withIndex(i), unwrappedReturnType) }.toMap() values.fold(jsonNodeFactory.arrayNode(values.size)) { array, v -> @@ -238,31 +233,33 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor { } } - private fun createObjectNode( + private suspend fun createObjectNode( ctx: ExecutionContext, value: T, node: Execution.Node, type: Type - ): Deferred = ctx.scope.async { - val objectNode = jsonNodeFactory.objectNode() - val deferreds = mutableListOf?>>>() - for (child in node.children) { - when (child) { - is Execution.Fragment -> deferreds.add(ctx.scope.async { - handleFragment(ctx, value, child.withParent(node)) - }) - - else -> deferreds.add(ctx.scope.async { - handleProperty(ctx, value, child.withParent(node), type)?.let { mapOf(it) } ?: emptyMap() - }) + ): Deferred = coroutineScope { + async { + val objectNode = jsonNodeFactory.objectNode() + val deferreds = mutableListOf?>>>() + for (child in node.children) { + when (child) { + is Execution.Fragment -> deferreds.add(async { + handleFragment(ctx, value, child.withParent(node)) + }) + + else -> deferreds.add(async { + handleProperty(ctx, value, child.withParent(node), type)?.let { mapOf(it) } ?: emptyMap() + }) + } } - } - deferreds.forEach { - it.await().forEach { (key, value) -> - objectNode.merge(key, value?.await()) + deferreds.awaitAll().forEach { + it.forEach { (key, value) -> + objectNode.merge(key, value?.await()) + } } + objectNode } - objectNode } private suspend fun handleProperty( @@ -445,7 +442,7 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor { }?.reduce { acc, b -> acc && b } ?: true } - internal fun FunctionWrapper.invoke( + internal suspend fun FunctionWrapper.invoke( isSubscription: Boolean = false, children: Collection = emptyList(), funName: String, @@ -454,7 +451,7 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor { args: ArgumentNodes?, executionNode: Execution, ctx: ExecutionContext - ): Deferred { + ): Deferred = coroutineScope { val transformedArgs = argumentsHandler.transformArguments( funName, receiver, @@ -463,11 +460,10 @@ class ParallelRequestExecutor(val schema: DefaultSchema) : RequestExecutor { ctx.variables, executionNode, ctx.requestContext, - this - ) ?: return CompletableDeferred(value = null) + this@invoke + ) ?: return@coroutineScope CompletableDeferred(value = null) - // exceptions are not caught on purpose to pass up business logic errors - return ctx.scope.async { + async { try { when { hasReceiver -> invoke(receiver, *transformedArgs.toTypedArray())