diff --git a/src/main/java/graphql/GraphQL.java b/src/main/java/graphql/GraphQL.java index 5b4961725..2ccb54b95 100644 --- a/src/main/java/graphql/GraphQL.java +++ b/src/main/java/graphql/GraphQL.java @@ -1,6 +1,7 @@ package graphql; import graphql.execution.AbortExecutionException; +import graphql.execution.Async; import graphql.execution.AsyncExecutionStrategy; import graphql.execution.AsyncSerialExecutionStrategy; import graphql.execution.DataFetcherExceptionHandler; @@ -421,31 +422,33 @@ public CompletableFuture executeAsync(ExecutionInput executionI if (logNotSafe.isDebugEnabled()) { logNotSafe.debug("Executing request. operation name: '{}'. query: '{}'. variables '{}'", executionInput.getOperationName(), executionInput.getQuery(), executionInput.getVariables()); } - executionInput = ensureInputHasId(executionInput); + ExecutionInput executionInputWithId = ensureInputHasId(executionInput); - InstrumentationState instrumentationState = instrumentation.createState(new InstrumentationCreateStateParameters(this.graphQLSchema, executionInput)); - try { - InstrumentationExecutionParameters inputInstrumentationParameters = new InstrumentationExecutionParameters(executionInput, this.graphQLSchema, instrumentationState); - executionInput = instrumentation.instrumentExecutionInput(executionInput, inputInstrumentationParameters, instrumentationState); - - CompletableFuture beginExecutionCF = new CompletableFuture<>(); - InstrumentationExecutionParameters instrumentationParameters = new InstrumentationExecutionParameters(executionInput, this.graphQLSchema, instrumentationState); - InstrumentationContext executionInstrumentation = nonNullCtx(instrumentation.beginExecution(instrumentationParameters, instrumentationState)); - executionInstrumentation.onDispatched(beginExecutionCF); - - GraphQLSchema graphQLSchema = instrumentation.instrumentSchema(this.graphQLSchema, instrumentationParameters, instrumentationState); - - CompletableFuture executionResult = parseValidateAndExecute(executionInput, graphQLSchema, instrumentationState); - // - // finish up instrumentation - executionResult = executionResult.whenComplete(completeInstrumentationCtxCF(executionInstrumentation, beginExecutionCF)); - // - // allow instrumentation to tweak the result - executionResult = executionResult.thenCompose(result -> instrumentation.instrumentExecutionResult(result, instrumentationParameters, instrumentationState)); - return executionResult; - } catch (AbortExecutionException abortException) { - return handleAbortException(executionInput, instrumentationState, abortException); - } + CompletableFuture instrumentationStateCF = instrumentation.createStateAsync(new InstrumentationCreateStateParameters(this.graphQLSchema, executionInput)); + return Async.orNullCompletedFuture(instrumentationStateCF).thenCompose(instrumentationState -> { + try { + InstrumentationExecutionParameters inputInstrumentationParameters = new InstrumentationExecutionParameters(executionInputWithId, this.graphQLSchema, instrumentationState); + ExecutionInput instrumentedExecutionInput = instrumentation.instrumentExecutionInput(executionInputWithId, inputInstrumentationParameters, instrumentationState); + + CompletableFuture beginExecutionCF = new CompletableFuture<>(); + InstrumentationExecutionParameters instrumentationParameters = new InstrumentationExecutionParameters(instrumentedExecutionInput, this.graphQLSchema, instrumentationState); + InstrumentationContext executionInstrumentation = nonNullCtx(instrumentation.beginExecution(instrumentationParameters, instrumentationState)); + executionInstrumentation.onDispatched(beginExecutionCF); + + GraphQLSchema graphQLSchema = instrumentation.instrumentSchema(this.graphQLSchema, instrumentationParameters, instrumentationState); + + CompletableFuture executionResult = parseValidateAndExecute(instrumentedExecutionInput, graphQLSchema, instrumentationState); + // + // finish up instrumentation + executionResult = executionResult.whenComplete(completeInstrumentationCtxCF(executionInstrumentation, beginExecutionCF)); + // + // allow instrumentation to tweak the result + executionResult = executionResult.thenCompose(result -> instrumentation.instrumentExecutionResult(result, instrumentationParameters, instrumentationState)); + return executionResult; + } catch (AbortExecutionException abortException) { + return handleAbortException(executionInput, instrumentationState, abortException); + } + }); } private CompletableFuture handleAbortException(ExecutionInput executionInput, InstrumentationState instrumentationState, AbortExecutionException abortException) { diff --git a/src/main/java/graphql/execution/Async.java b/src/main/java/graphql/execution/Async.java index ec71e2bdc..56f7a2f9b 100644 --- a/src/main/java/graphql/execution/Async.java +++ b/src/main/java/graphql/execution/Async.java @@ -2,6 +2,8 @@ import graphql.Assert; import graphql.Internal; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; import java.util.ArrayList; import java.util.Collection; @@ -207,4 +209,15 @@ public static CompletableFuture exceptionallyCompletedFuture(Throwable ex return result; } + /** + * If the passed in CompletableFuture is null then it creates a CompletableFuture that resolves to null + * + * @param completableFuture the CF to use + * @param for two + * + * @return the completableFuture if it's not null or one that always resoles to null + */ + public static @NotNull CompletableFuture orNullCompletedFuture(@Nullable CompletableFuture completableFuture) { + return completableFuture != null ? completableFuture : CompletableFuture.completedFuture(null); + } } diff --git a/src/main/java/graphql/execution/instrumentation/ChainedInstrumentation.java b/src/main/java/graphql/execution/instrumentation/ChainedInstrumentation.java index bfafc49ed..70e7bd063 100644 --- a/src/main/java/graphql/execution/instrumentation/ChainedInstrumentation.java +++ b/src/main/java/graphql/execution/instrumentation/ChainedInstrumentation.java @@ -22,6 +22,7 @@ import graphql.schema.GraphQLSchema; import graphql.validation.ValidationError; import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; import java.util.Arrays; import java.util.List; @@ -80,10 +81,19 @@ private InstrumentationContext chainedCtx(Function(mapAndDropNulls(instrumentations, mapper)); } + @Override + public InstrumentationState createState() { + return Assert.assertShouldNeverHappen("createStateAsync should only ever be used"); + } + + @Override + public @Nullable InstrumentationState createState(InstrumentationCreateStateParameters parameters) { + return Assert.assertShouldNeverHappen("createStateAsync should only ever be used"); + } @Override - public InstrumentationState createState(InstrumentationCreateStateParameters parameters) { - return new ChainedInstrumentationState(instrumentations, parameters); + public @NotNull CompletableFuture createStateAsync(InstrumentationCreateStateParameters parameters) { + return ChainedInstrumentationState.combineAll(instrumentations, parameters); } @Override @@ -349,18 +359,31 @@ public CompletableFuture instrumentExecutionResult(ExecutionRes } static class ChainedInstrumentationState implements InstrumentationState { - private final Map instrumentationStates; + private final Map instrumentationToStates; - private ChainedInstrumentationState(List instrumentations, InstrumentationCreateStateParameters parameters) { - instrumentationStates = Maps.newLinkedHashMapWithExpectedSize(instrumentations.size()); - instrumentations.forEach(i -> instrumentationStates.put(i, i.createState(parameters))); + private ChainedInstrumentationState(List instrumentations, List instrumentationStates) { + instrumentationToStates = Maps.newLinkedHashMapWithExpectedSize(instrumentations.size()); + for (int i = 0; i < instrumentations.size(); i++) { + Instrumentation instrumentation = instrumentations.get(i); + InstrumentationState instrumentationState = instrumentationStates.get(i); + instrumentationToStates.put(instrumentation, instrumentationState); + } } private InstrumentationState getState(Instrumentation instrumentation) { - return instrumentationStates.get(instrumentation); + return instrumentationToStates.get(instrumentation); } + private static CompletableFuture combineAll(List instrumentations, InstrumentationCreateStateParameters parameters) { + Async.CombinedBuilder builder = Async.ofExpectedSize(instrumentations.size()); + for (Instrumentation instrumentation : instrumentations) { + // state can be null including the CF so handle that + CompletableFuture stateCF = Async.orNullCompletedFuture(instrumentation.createStateAsync(parameters)); + builder.add(stateCF); + } + return builder.await().thenApply(instrumentationStates -> new ChainedInstrumentationState(instrumentations, instrumentationStates)); + } } private static class ChainedInstrumentationContext implements InstrumentationContext { diff --git a/src/main/java/graphql/execution/instrumentation/Instrumentation.java b/src/main/java/graphql/execution/instrumentation/Instrumentation.java index 989364c48..77c4c6bd8 100644 --- a/src/main/java/graphql/execution/instrumentation/Instrumentation.java +++ b/src/main/java/graphql/execution/instrumentation/Instrumentation.java @@ -63,11 +63,26 @@ default InstrumentationState createState() { * * @return a state object that is passed to each method */ + @Deprecated + @DeprecatedAt("2023-08-25") @Nullable default InstrumentationState createState(InstrumentationCreateStateParameters parameters) { return createState(); } + /** + * This will be called just before execution to create an object, in an asynchronous manner, that is given back to all instrumentation methods + * to allow them to have per execution request state + * + * @param parameters the parameters to this step + * + * @return a state object that is passed to each method + */ + @Nullable + default CompletableFuture createStateAsync(InstrumentationCreateStateParameters parameters) { + return CompletableFuture.completedFuture(createState(parameters)); + } + /** * This is called right at the start of query execution, and it's the first step in the instrumentation chain. * diff --git a/src/main/java/graphql/execution/instrumentation/SimplePerformantInstrumentation.java b/src/main/java/graphql/execution/instrumentation/SimplePerformantInstrumentation.java index b7835ffdd..8ad5f8eef 100644 --- a/src/main/java/graphql/execution/instrumentation/SimplePerformantInstrumentation.java +++ b/src/main/java/graphql/execution/instrumentation/SimplePerformantInstrumentation.java @@ -56,6 +56,12 @@ public InstrumentationState createState() { return null; } + @Override + public @Nullable CompletableFuture createStateAsync(InstrumentationCreateStateParameters parameters) { + InstrumentationState state = createState(parameters); + return state == null ? null : CompletableFuture.completedFuture(state); + } + @Override public @NotNull InstrumentationContext beginExecution(InstrumentationExecutionParameters parameters) { return assertShouldNeverHappen("The deprecated " + "beginExecution" + " was called"); diff --git a/src/test/groovy/graphql/execution/instrumentation/ChainedInstrumentationStateTest.groovy b/src/test/groovy/graphql/execution/instrumentation/ChainedInstrumentationStateTest.groovy index c2c1ffb8c..f1812e395 100644 --- a/src/test/groovy/graphql/execution/instrumentation/ChainedInstrumentationStateTest.groovy +++ b/src/test/groovy/graphql/execution/instrumentation/ChainedInstrumentationStateTest.groovy @@ -1,17 +1,13 @@ package graphql.execution.instrumentation +import graphql.ExecutionInput import graphql.ExecutionResult import graphql.GraphQL import graphql.StarWarsSchema import graphql.execution.AsyncExecutionStrategy -import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters +import graphql.execution.instrumentation.parameters.InstrumentationCreateStateParameters import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters -import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters -import graphql.execution.instrumentation.parameters.InstrumentationFieldFetchParameters -import graphql.execution.instrumentation.parameters.InstrumentationFieldParameters import graphql.execution.instrumentation.parameters.InstrumentationValidationParameters -import graphql.language.Document -import graphql.schema.DataFetcher import graphql.validation.ValidationError import spock.lang.Specification @@ -279,6 +275,75 @@ class ChainedInstrumentationStateTest extends Specification { } + + class StringInstrumentationState implements InstrumentationState { + StringInstrumentationState(String value) { + this.value = value + } + + String value + } + + def "can have an multiple async createState() calls in play"() { + + + given: + + def query = '''query Q($var: String!) { + human(id: $var) { + id + name + } + } + ''' + + + def instrumentation1 = new SimplePerformantInstrumentation() { + @Override + CompletableFuture createStateAsync(InstrumentationCreateStateParameters parameters) { + return CompletableFuture.supplyAsync { + return new StringInstrumentationState("I1") + } as CompletableFuture + } + + @Override + CompletableFuture instrumentExecutionResult(ExecutionResult executionResult, InstrumentationExecutionParameters parameters, InstrumentationState state) { + return CompletableFuture.completedFuture( + executionResult.transform { it.addExtension("i1", ((StringInstrumentationState) state).value) } + ) + } + } + def instrumentation2 = new SimplePerformantInstrumentation() { + @Override + CompletableFuture createStateAsync(InstrumentationCreateStateParameters parameters) { + return CompletableFuture.supplyAsync { + return new StringInstrumentationState("I2") + } as CompletableFuture + } + + @Override + CompletableFuture instrumentExecutionResult(ExecutionResult executionResult, InstrumentationExecutionParameters parameters, InstrumentationState state) { + return CompletableFuture.completedFuture( + executionResult.transform { it.addExtension("i2", ((StringInstrumentationState) state).value) } + ) + } + + } + + def graphQL = GraphQL + .newGraphQL(StarWarsSchema.starWarsSchema) + .instrumentation(new ChainedInstrumentation([instrumentation1, instrumentation2])) + .doNotAddDefaultInstrumentations() // important, otherwise a chained one wil be used + .build() + + when: + def variables = [var: "1001"] + def er = graphQL.execute(ExecutionInput.newExecutionInput().query(query).variables(variables)) // Luke + + then: + er.extensions == [i1: "I1", i2: "I2"] + } + private void assertCalls(NamedInstrumentation instrumentation) { assert instrumentation.dfInvocations[0].getFieldDefinition().name == 'hero' assert instrumentation.dfInvocations[0].getExecutionStepInfo().getPath().toList() == ['hero'] diff --git a/src/test/groovy/graphql/execution/instrumentation/InstrumentationTest.groovy b/src/test/groovy/graphql/execution/instrumentation/InstrumentationTest.groovy index 694e100c0..85b274089 100644 --- a/src/test/groovy/graphql/execution/instrumentation/InstrumentationTest.groovy +++ b/src/test/groovy/graphql/execution/instrumentation/InstrumentationTest.groovy @@ -5,6 +5,7 @@ import graphql.ExecutionResult import graphql.GraphQL import graphql.StarWarsSchema import graphql.execution.AsyncExecutionStrategy +import graphql.execution.instrumentation.parameters.InstrumentationCreateStateParameters import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters import graphql.execution.instrumentation.parameters.InstrumentationFieldFetchParameters @@ -404,4 +405,56 @@ class InstrumentationTest extends Specification { instrumentation.executionList == expected } + + class StringInstrumentationState implements InstrumentationState { + StringInstrumentationState(String value) { + this.value = value + } + + String value + } + + def "can have an single async createState() in play"() { + + + given: + + def query = '''query Q($var: String!) { + human(id: $var) { + id + name + } + } + ''' + + + def instrumentation1 = new SimplePerformantInstrumentation() { + @Override + CompletableFuture createStateAsync(InstrumentationCreateStateParameters parameters) { + return CompletableFuture.supplyAsync { + return new StringInstrumentationState("I1") + } as CompletableFuture + } + + @Override + CompletableFuture instrumentExecutionResult(ExecutionResult executionResult, InstrumentationExecutionParameters parameters, InstrumentationState state) { + return CompletableFuture.completedFuture( + executionResult.transform { it.addExtension("i1", ((StringInstrumentationState) state).value) } + ) + } + } + + def graphQL = GraphQL + .newGraphQL(StarWarsSchema.starWarsSchema) + .instrumentation(instrumentation1) + .doNotAddDefaultInstrumentations() // important, otherwise a chained one wil be used + .build() + + when: + def variables = [var: "1001"] + def er = graphQL.execute(ExecutionInput.newExecutionInput().query(query).variables(variables)) // Luke + + then: + er.extensions == [i1: "I1"] + } }