diff --git a/src/main/java/graphql/EngineRunningState.java b/src/main/java/graphql/EngineRunningState.java index 43b584805..0806f5880 100644 --- a/src/main/java/graphql/EngineRunningState.java +++ b/src/main/java/graphql/EngineRunningState.java @@ -1,7 +1,9 @@ package graphql; +import graphql.execution.AbortExecutionException; import graphql.execution.EngineRunningObserver; import graphql.execution.ExecutionId; +import org.jspecify.annotations.NullMarked; import org.jspecify.annotations.Nullable; import java.util.concurrent.CompletableFuture; @@ -13,19 +15,20 @@ import java.util.function.Supplier; import static graphql.Assert.assertTrue; +import static graphql.execution.EngineRunningObserver.RunningState.CANCELLED; import static graphql.execution.EngineRunningObserver.RunningState.NOT_RUNNING; import static graphql.execution.EngineRunningObserver.RunningState.NOT_RUNNING_FINISH; import static graphql.execution.EngineRunningObserver.RunningState.RUNNING; import static graphql.execution.EngineRunningObserver.RunningState.RUNNING_START; @Internal +@NullMarked public class EngineRunningState { @Nullable private final EngineRunningObserver engineRunningObserver; - @Nullable + private volatile ExecutionInput executionInput; private final GraphQLContext graphQLContext; - @Nullable private volatile ExecutionId executionId; // if true the last decrementRunning() call will be ignored @@ -33,24 +36,11 @@ public class EngineRunningState { private final AtomicInteger isRunning = new AtomicInteger(0); - @VisibleForTesting - public EngineRunningState() { - this.engineRunningObserver = null; - this.graphQLContext = null; - this.executionId = null; - } - public EngineRunningState(ExecutionInput executionInput) { - EngineRunningObserver engineRunningObserver = executionInput.getGraphQLContext().get(EngineRunningObserver.ENGINE_RUNNING_OBSERVER_KEY); - if (engineRunningObserver != null) { - this.engineRunningObserver = engineRunningObserver; - this.graphQLContext = executionInput.getGraphQLContext(); - this.executionId = executionInput.getExecutionId(); - } else { - this.engineRunningObserver = null; - this.graphQLContext = null; - this.executionId = null; - } + this.executionInput = executionInput; + this.graphQLContext = executionInput.getGraphQLContext(); + this.executionId = executionInput.getExecutionId(); + this.engineRunningObserver = executionInput.getGraphQLContext().get(EngineRunningObserver.ENGINE_RUNNING_OBSERVER_KEY); } public CompletableFuture handle(CompletableFuture src, BiFunction fn) { @@ -64,6 +54,7 @@ public CompletableFuture handle(CompletableFuture src, BiFunction engineRun(Supplier getExtensions() { return extensions; } + + /** + * The graphql engine will check this frequently and if that is true, it will + * throw a {@link graphql.execution.AbortExecutionException} to cancel the execution. + *

+ * This is a cooperative cancellation. Some asynchronous data fetching code may still continue to + * run but there will be no more efforts run future field fetches say. + * + * @return true if the execution should be cancelled + */ + public boolean isCancelled() { + return cancelled.get(); + } + + /** + * This can be called to cancel the graphql execution. Remember this is a cooperative cancellation + * and the graphql engine needs to be running on a thread to allow is to respect this flag. + */ + public void cancel() { + cancelled.set(true); + } + /** * This helps you transform the current ExecutionInput object into another one by starting a builder with all * the current values and allows you to transform it how you want. diff --git a/src/main/java/graphql/GraphQL.java b/src/main/java/graphql/GraphQL.java index d7207d15e..bee767ae4 100644 --- a/src/main/java/graphql/GraphQL.java +++ b/src/main/java/graphql/GraphQL.java @@ -481,7 +481,7 @@ public CompletableFuture executeAsync(ExecutionInput executionI EngineRunningState engineRunningState = new EngineRunningState(executionInput); return engineRunningState.engineRun(() -> { ExecutionInput executionInputWithId = ensureInputHasId(executionInput); - engineRunningState.updateExecutionId(executionInputWithId.getExecutionId()); + engineRunningState.updateExecutionInput(executionInputWithId); CompletableFuture instrumentationStateCF = instrumentation.createStateAsync(new InstrumentationCreateStateParameters(this.graphQLSchema, executionInputWithId)); instrumentationStateCF = Async.orNullCompletedFuture(instrumentationStateCF); diff --git a/src/main/java/graphql/execution/AbstractAsyncExecutionStrategy.java b/src/main/java/graphql/execution/AbstractAsyncExecutionStrategy.java index d04ba001c..25f2036cb 100644 --- a/src/main/java/graphql/execution/AbstractAsyncExecutionStrategy.java +++ b/src/main/java/graphql/execution/AbstractAsyncExecutionStrategy.java @@ -22,6 +22,8 @@ public AbstractAsyncExecutionStrategy(DataFetcherExceptionHandler dataFetcherExc protected BiConsumer, Throwable> handleResults(ExecutionContext executionContext, List fieldNames, CompletableFuture overallResult) { return (List results, Throwable exception) -> { + exception = executionContext.possibleCancellation(exception); + if (exception != null) { handleNonNullException(executionContext, overallResult, exception); return; diff --git a/src/main/java/graphql/execution/AsyncExecutionStrategy.java b/src/main/java/graphql/execution/AsyncExecutionStrategy.java index 7f5114908..d355bfd24 100644 --- a/src/main/java/graphql/execution/AsyncExecutionStrategy.java +++ b/src/main/java/graphql/execution/AsyncExecutionStrategy.java @@ -65,6 +65,8 @@ public CompletableFuture execute(ExecutionContext executionCont List fieldsExecutedOnInitialResult = deferredExecutionSupport.getNonDeferredFieldNames(fieldNames); BiConsumer, Throwable> handleResultsConsumer = handleResults(executionContext, fieldsExecutedOnInitialResult, overallResult); + throwable = executionContext.possibleCancellation(throwable); + if (throwable != null) { handleResultsConsumer.accept(null, throwable.getCause()); return; diff --git a/src/main/java/graphql/execution/EngineRunningObserver.java b/src/main/java/graphql/execution/EngineRunningObserver.java index c75f47706..a13fe7701 100644 --- a/src/main/java/graphql/execution/EngineRunningObserver.java +++ b/src/main/java/graphql/execution/EngineRunningObserver.java @@ -1,5 +1,6 @@ package graphql.execution; +import graphql.ExecutionInput; import graphql.ExperimentalApi; import graphql.GraphQLContext; import org.jspecify.annotations.NullMarked; @@ -8,6 +9,8 @@ * This class lets you observe the running state of the graphql-java engine. As it processes and dispatches graphql fields, * the engine moves in and out of a running and not running state. As it does this, the callback is called with information telling you the current * state. + *

+ * If the engine is cancelled via {@link ExecutionInput#cancel()} then the observer will also be called to indicate that. */ @ExperimentalApi @NullMarked @@ -29,7 +32,11 @@ enum RunningState { /** * Represents that the engine is finished */ - NOT_RUNNING_FINISH + NOT_RUNNING_FINISH, + /** + * Represents that the engine code has been cancelled via {@link ExecutionInput#cancel()} + */ + CANCELLED } diff --git a/src/main/java/graphql/execution/Execution.java b/src/main/java/graphql/execution/Execution.java index c1220779c..4c8cb7da3 100644 --- a/src/main/java/graphql/execution/Execution.java +++ b/src/main/java/graphql/execution/Execution.java @@ -86,6 +86,12 @@ public CompletableFuture execute(Document document, GraphQLSche throw rte; } + // before we get started - did they ask us to cancel? + AbortExecutionException abortExecutionException = engineRunningState.ifCancelledMakeException(); + if (abortExecutionException != null) { + return completedFuture(abortExecutionException.toExecutionResult()); + } + boolean propagateErrorsOnNonNullContractFailure = propagateErrorsOnNonNullContractFailure(getOperationResult.operationDefinition.getDirectives()); ResponseMapFactory responseMapFactory = GraphQL.unusualConfiguration(executionInput.getGraphQLContext()) diff --git a/src/main/java/graphql/execution/ExecutionContext.java b/src/main/java/graphql/execution/ExecutionContext.java index 92a375af6..5bdc076d3 100644 --- a/src/main/java/graphql/execution/ExecutionContext.java +++ b/src/main/java/graphql/execution/ExecutionContext.java @@ -23,6 +23,7 @@ import graphql.util.FpKit; import graphql.util.LockKit; import org.dataloader.DataLoaderRegistry; +import org.jspecify.annotations.Nullable; import java.util.HashSet; import java.util.List; @@ -377,4 +378,15 @@ public boolean hasIncrementalSupport() { public EngineRunningState getEngineRunningState() { return engineRunningState; } -} + + @Internal + @Nullable + Throwable possibleCancellation(@Nullable Throwable currentThrowable) { + return engineRunningState.possibleCancellation(currentThrowable); + } + + @Internal + void throwIfCancelled() throws AbortExecutionException { + engineRunningState.throwIfCancelled(); + } +} \ No newline at end of file diff --git a/src/main/java/graphql/execution/ExecutionStrategy.java b/src/main/java/graphql/execution/ExecutionStrategy.java index c1323f0b8..563d4d527 100644 --- a/src/main/java/graphql/execution/ExecutionStrategy.java +++ b/src/main/java/graphql/execution/ExecutionStrategy.java @@ -194,6 +194,8 @@ public static String mkNameForPath(List currentField) { @SuppressWarnings("unchecked") @DuckTyped(shape = "CompletableFuture> | Map") protected Object executeObject(ExecutionContext executionContext, ExecutionStrategyParameters parameters) throws NonNullableFieldWasNullException { + executionContext.throwIfCancelled(); + DataLoaderDispatchStrategy dataLoaderDispatcherStrategy = executionContext.getDataLoaderDispatcherStrategy(); Instrumentation instrumentation = executionContext.getInstrumentation(); InstrumentationExecutionStrategyParameters instrumentationParameters = new InstrumentationExecutionStrategyParameters(executionContext, parameters); @@ -218,6 +220,8 @@ protected Object executeObject(ExecutionContext executionContext, ExecutionStrat if (fieldValueInfosResult instanceof CompletableFuture) { CompletableFuture> fieldValueInfos = (CompletableFuture>) fieldValueInfosResult; fieldValueInfos.whenComplete((completeValueInfos, throwable) -> { + throwable = executionContext.possibleCancellation(throwable); + if (throwable != null) { handleResultsConsumer.accept(null, throwable); return; @@ -269,6 +273,8 @@ protected Object executeObject(ExecutionContext executionContext, ExecutionStrat private BiConsumer, Throwable> buildFieldValueMap(List fieldNames, CompletableFuture> overallResult, ExecutionContext executionContext) { return (List results, Throwable exception) -> { + exception = executionContext.possibleCancellation(exception); + if (exception != null) { handleValueException(overallResult, exception, executionContext); return; @@ -296,6 +302,8 @@ DeferredExecutionSupport createDeferredExecutionSupport(ExecutionContext executi ExecutionStrategyParameters parameters, DeferredExecutionSupport deferredExecutionSupport ) { + executionContext.throwIfCancelled(); + MergedSelectionSet fields = parameters.getFields(); executionContext.getIncrementalCallState().enqueue(deferredExecutionSupport.createCalls()); @@ -305,6 +313,8 @@ DeferredExecutionSupport createDeferredExecutionSupport(ExecutionContext executi .ofExpectedSize(fields.size() - deferredExecutionSupport.deferredFieldsCount()); for (String fieldName : fields.getKeys()) { + executionContext.throwIfCancelled(); + MergedField currentField = fields.getSubField(fieldName); ResultPath fieldPath = parameters.getPath().segment(mkNameForPath(currentField)); @@ -392,6 +402,7 @@ protected Object fetchField(ExecutionContext executionContext, ExecutionStrategy @DuckTyped(shape = "CompletableFuture | FetchedValue") private Object fetchField(GraphQLFieldDefinition fieldDef, ExecutionContext executionContext, ExecutionStrategyParameters parameters) { + executionContext.throwIfCancelled(); if (incrementAndCheckMaxNodesExceeded(executionContext)) { return new FetchedValue(null, Collections.emptyList(), null); @@ -465,9 +476,10 @@ private Object fetchField(GraphQLFieldDefinition fieldDef, ExecutionContext exec CompletableFuture> handleCF = engineRunningState.handle(fetchedValue, (result, exception) -> { // because we added an artificial CF, we need to unwrap the exception fetchCtx.onCompleted(result, exception); + exception = engineRunningState.possibleCancellation(exception); + if (exception != null) { - CompletableFuture handleFetchingExceptionResult = handleFetchingException(dataFetchingEnvironment.get(), parameters, exception); - return handleFetchingExceptionResult; + return handleFetchingException(dataFetchingEnvironment.get(), parameters, exception); } else { // we can simply return the fetched value CF and avoid a allocation return fetchedValue; @@ -588,6 +600,8 @@ private CompletableFuture asyncHandleException(DataFetcherExceptionHandle * if a nonnull field resolves to a null value */ protected FieldValueInfo completeField(ExecutionContext executionContext, ExecutionStrategyParameters parameters, FetchedValue fetchedValue) { + executionContext.throwIfCancelled(); + Field field = parameters.getField().getSingleField(); GraphQLObjectType parentType = (GraphQLObjectType) parameters.getExecutionStepInfo().getUnwrappedNonNullType(); GraphQLFieldDefinition fieldDef = getFieldDef(executionContext.getGraphQLSchema(), parentType, field); @@ -784,6 +798,8 @@ protected FieldValueInfo completeValueForList(ExecutionContext executionContext, overallResult.whenComplete(completeListCtx::onCompleted); resultsFuture.whenComplete((results, exception) -> { + exception = executionContext.possibleCancellation(exception); + if (exception != null) { handleValueException(overallResult, exception, executionContext); return; diff --git a/src/test/groovy/graphql/ExecutionInputTest.groovy b/src/test/groovy/graphql/ExecutionInputTest.groovy index d2cb3bc93..54ff68d73 100644 --- a/src/test/groovy/graphql/ExecutionInputTest.groovy +++ b/src/test/groovy/graphql/ExecutionInputTest.groovy @@ -1,9 +1,15 @@ package graphql import graphql.execution.ExecutionId +import graphql.execution.instrumentation.ExecutionStrategyInstrumentationContext +import graphql.execution.instrumentation.Instrumentation +import graphql.execution.instrumentation.InstrumentationContext +import graphql.execution.instrumentation.InstrumentationState +import graphql.execution.instrumentation.parameters.InstrumentationExecutionStrategyParameters +import graphql.execution.instrumentation.parameters.InstrumentationFieldCompleteParameters +import graphql.execution.instrumentation.parameters.InstrumentationFieldParameters import graphql.schema.DataFetcher import graphql.schema.DataFetchingEnvironment -import org.awaitility.Awaitility import org.dataloader.DataLoaderRegistry import spock.lang.Specification @@ -11,6 +17,8 @@ import java.time.Duration import java.util.concurrent.CompletableFuture import java.util.concurrent.CountDownLatch +import static org.awaitility.Awaitility.* + class ExecutionInputTest extends Specification { def query = "query { hello }" @@ -167,4 +175,248 @@ class ExecutionInputTest extends Specification { er.errors.isEmpty() er.data["fetch"] == "{locale=German, executionId=ID123, graphqlContext=b}" } + + def "can cancel the execution"() { + def sdl = ''' + type Query { + fetch1 : Inner + fetch2 : Inner + } + + type Inner { + f : String + } + + ''' + + CountDownLatch fieldLatch = new CountDownLatch(1) + + DataFetcher df1Sec = { DataFetchingEnvironment env -> + println("Entering DF1") + return CompletableFuture.supplyAsync { + println("DF1 async run") + fieldLatch.await() + Thread.sleep(1000) + return [f: "x"] + } + } + DataFetcher df10Sec = { DataFetchingEnvironment env -> + println("Entering DF10") + return CompletableFuture.supplyAsync { + println("DF10 async run") + fieldLatch.await() + Thread.sleep(10000) + return "x" + } + } + + def fetcherMap = ["Query": ["fetch1": df1Sec, "fetch2": df1Sec], + "Inner": ["f": df10Sec] + ] + def schema = TestUtil.schema(sdl, fetcherMap) + def graphQL = GraphQL.newGraphQL(schema).build() + + when: + ExecutionInput executionInput = ExecutionInput.newExecutionInput() + .query("query q { fetch1 { f } fetch2 { f } }") + .build() + + def cf = graphQL.executeAsync(executionInput) + + Thread.sleep(250) // let it get into the field fetching say + + // lets cancel it + println("cancelling") + executionInput.cancel() + + // let the DFs run + println("make the fields run") + fieldLatch.countDown() + + println("and await for the overall CF to complete") + await().atMost(Duration.ofSeconds(60)).until({ -> cf.isDone() }) + + def er = cf.join() + + then: + !er.errors.isEmpty() + er.errors[0]["message"] == "Execution has been asked to be cancelled" + } + + def "can cancel request at random times (#testName)"() { + def sdl = ''' + type Query { + fetch1 : Inner + fetch2 : Inner + } + + type Inner { + inner : Inner + f : String + } + + ''' + + when: + + CountDownLatch fetcherLatch = new CountDownLatch(1) + + DataFetcher df = { DataFetchingEnvironment env -> + return CompletableFuture.supplyAsync { + fetcherLatch.countDown() + def delay = plusOrMinus(dfDelay) + println("DF ${env.getExecutionStepInfo().getPath()} sleeping for $delay") + Thread.sleep(delay) + return [inner: [f: "x"], f: "x"] + } + } + + def fetcherMap = ["Query": ["fetch1": df, "fetch2": df], + "Inner": ["inner": df] + ] + def schema = TestUtil.schema(sdl, fetcherMap) + def graphQL = GraphQL.newGraphQL(schema).build() + + ExecutionInput executionInput = ExecutionInput.newExecutionInput() + .query("query q { fetch1 { inner { inner { inner { f }}}} fetch2 { inner { inner { inner { f }}}} }") + .build() + + def cf = graphQL.executeAsync(executionInput) + + // wait for at least one fetcher to run + fetcherLatch.await() + + // using a random number MAY make this test flaky - but so be it. We want ot make sure that + // if we cancel then we dont lock up. We have deterministic tests for that so lets hav + // some random ones + // + def randomCancel = plusOrMinus(dfDelay) + Thread.sleep(randomCancel) + + // now make it cancel + println("Cancelling after $randomCancel") + executionInput.cancel() + + await().atMost(Duration.ofSeconds(10)).until({ -> cf.isDone() }) + + def er = cf.join() + + then: + !er.errors.isEmpty() + er.errors[0]["message"] == "Execution has been asked to be cancelled" + + where: + testName | dfDelay + "50 ms" | plusOrMinus(50) + "100 ms" | plusOrMinus(100) + "200 ms" | plusOrMinus(200) + "500 ms" | plusOrMinus(500) + "1000 ms" | plusOrMinus(1000) + } + + def "can cancel at specific places"() { + def sdl = ''' + type Query { + fetch1 : Inner + fetch2 : Inner + } + + type Inner { + inner : Inner + f : String + } + + ''' + + when: + + DataFetcher df = { DataFetchingEnvironment env -> + return CompletableFuture.supplyAsync { + return [inner: [f: "x"], f: "x"] + } + } + + def fetcherMap = ["Query": ["fetch1": df, "fetch2": df], + "Inner": ["inner": df] + ] + + + def queryText = "query q { fetch1 { inner { inner { inner { f }}}} fetch2 { inner { inner { inner { f }}}} }" + ExecutionInput executionInput = ExecutionInput.newExecutionInput() + .query(queryText) + .build() + + Instrumentation instrumentation = new Instrumentation() { + @Override + ExecutionStrategyInstrumentationContext beginExecutionStrategy(InstrumentationExecutionStrategyParameters parameters, InstrumentationState state) { + executionInput.cancel() + return null + } + } + def schema = TestUtil.schema(sdl, fetcherMap) + def graphQL = GraphQL.newGraphQL(schema).instrumentation(instrumentation).build() + + + def er = awaitAsync(graphQL, executionInput) + + then: + !er.errors.isEmpty() + er.errors[0]["message"] == "Execution has been asked to be cancelled" + + when: + executionInput = ExecutionInput.newExecutionInput() + .query(queryText) + .build() + + instrumentation = new Instrumentation() { + @Override + InstrumentationContext beginFieldExecution(InstrumentationFieldParameters parameters, InstrumentationState state) { + executionInput.cancel() + return null + } + } + schema = TestUtil.schema(sdl, fetcherMap) + graphQL = GraphQL.newGraphQL(schema).instrumentation(instrumentation).build() + + er = awaitAsync(graphQL, executionInput) + + then: + !er.errors.isEmpty() + er.errors[0]["message"] == "Execution has been asked to be cancelled" + + when: + executionInput = ExecutionInput.newExecutionInput() + .query(queryText) + .build() + + instrumentation = new Instrumentation() { + + @Override + InstrumentationContext beginFieldCompletion(InstrumentationFieldCompleteParameters parameters, InstrumentationState state) { + executionInput.cancel() + return null + } + } + schema = TestUtil.schema(sdl, fetcherMap) + graphQL = GraphQL.newGraphQL(schema).instrumentation(instrumentation).build() + + er = awaitAsync(graphQL, executionInput) + + then: + !er.errors.isEmpty() + er.errors[0]["message"] == "Execution has been asked to be cancelled" + + } + + private static ExecutionResult awaitAsync(GraphQL graphQL, ExecutionInput executionInput) { + def cf = graphQL.executeAsync(executionInput) + await().atMost(Duration.ofSeconds(10)).until({ -> cf.isDone() }) + return cf.join() + } + + private static int plusOrMinus(int integer) { + int half = (int) (integer / 2) + def intVal = TestUtil.rand((integer - half), (integer + half)) + return intVal + } } diff --git a/src/test/groovy/graphql/execution/AsyncExecutionStrategyTest.groovy b/src/test/groovy/graphql/execution/AsyncExecutionStrategyTest.groovy index 9d99fbbfb..18c35d260 100644 --- a/src/test/groovy/graphql/execution/AsyncExecutionStrategyTest.groovy +++ b/src/test/groovy/graphql/execution/AsyncExecutionStrategyTest.groovy @@ -102,6 +102,7 @@ abstract class AsyncExecutionStrategyTest extends Specification { .type(schema.getQueryType()) .build() + def ei = ExecutionInput.newExecutionInput("{}").build() ExecutionContext executionContext = new ExecutionContextBuilder() .graphQLSchema(schema) .executionId(ExecutionId.generate()) @@ -109,9 +110,9 @@ abstract class AsyncExecutionStrategyTest extends Specification { .instrumentation(SimplePerformantInstrumentation.INSTANCE) .valueUnboxer(ValueUnboxer.DEFAULT) .graphQLContext(graphqlContextMock) - .executionInput(ExecutionInput.newExecutionInput("{}").build()) + .executionInput(ei) .locale(Locale.getDefault()) - .engineRunningState(new EngineRunningState()) + .engineRunningState(new EngineRunningState(ei)) .build() ExecutionStrategyParameters executionStrategyParameters = ExecutionStrategyParameters .newParameters() @@ -146,6 +147,7 @@ abstract class AsyncExecutionStrategyTest extends Specification { .type(schema.getQueryType()) .build() + def ei = ExecutionInput.newExecutionInput("{}").build() ExecutionContext executionContext = new ExecutionContextBuilder() .graphQLSchema(schema) .executionId(ExecutionId.generate()) @@ -154,8 +156,8 @@ abstract class AsyncExecutionStrategyTest extends Specification { .instrumentation(SimplePerformantInstrumentation.INSTANCE) .locale(Locale.getDefault()) .graphQLContext(graphqlContextMock) - .executionInput(ExecutionInput.newExecutionInput("{}").build()) - .engineRunningState(new EngineRunningState()) + .executionInput(ei) + .engineRunningState(new EngineRunningState(ei)) .build() ExecutionStrategyParameters executionStrategyParameters = ExecutionStrategyParameters .newParameters() @@ -192,6 +194,7 @@ abstract class AsyncExecutionStrategyTest extends Specification { .type(schema.getQueryType()) .build() + def ei = ExecutionInput.newExecutionInput("{}").build() ExecutionContext executionContext = new ExecutionContextBuilder() .graphQLSchema(schema) .executionId(ExecutionId.generate()) @@ -199,8 +202,8 @@ abstract class AsyncExecutionStrategyTest extends Specification { .valueUnboxer(ValueUnboxer.DEFAULT) .instrumentation(SimplePerformantInstrumentation.INSTANCE) .graphQLContext(graphqlContextMock) - .executionInput(ExecutionInput.newExecutionInput("{}").build()) - .engineRunningState(new EngineRunningState()) + .executionInput(ei) + .engineRunningState(new EngineRunningState(ei)) .locale(Locale.getDefault()) .build() ExecutionStrategyParameters executionStrategyParameters = ExecutionStrategyParameters @@ -237,6 +240,7 @@ abstract class AsyncExecutionStrategyTest extends Specification { .type(schema.getQueryType()) .build() + def ei = ExecutionInput.newExecutionInput("{}").build() ExecutionContext executionContext = new ExecutionContextBuilder() .graphQLSchema(schema) .executionId(ExecutionId.generate()) @@ -245,8 +249,8 @@ abstract class AsyncExecutionStrategyTest extends Specification { .valueUnboxer(ValueUnboxer.DEFAULT) .locale(Locale.getDefault()) .graphQLContext(graphqlContextMock) - .executionInput(ExecutionInput.newExecutionInput("{}").build()) - .engineRunningState(new EngineRunningState()) + .executionInput(ei) + .engineRunningState(new EngineRunningState(ei)) .build() ExecutionStrategyParameters executionStrategyParameters = ExecutionStrategyParameters .newParameters() @@ -281,15 +285,16 @@ abstract class AsyncExecutionStrategyTest extends Specification { .type(schema.getQueryType()) .build() + def ei = ExecutionInput.newExecutionInput("{}").build() ExecutionContext executionContext = new ExecutionContextBuilder() .graphQLSchema(schema) .executionId(ExecutionId.generate()) .operationDefinition(operation) .valueUnboxer(ValueUnboxer.DEFAULT) .graphQLContext(graphqlContextMock) - .executionInput(ExecutionInput.newExecutionInput("{}").build()) + .executionInput(ei) .locale(Locale.getDefault()) - .engineRunningState(new EngineRunningState()) + .engineRunningState(new EngineRunningState(ei)) .instrumentation(new SimplePerformantInstrumentation() { @Override diff --git a/src/test/groovy/graphql/execution/AsyncSerialExecutionStrategyTest.groovy b/src/test/groovy/graphql/execution/AsyncSerialExecutionStrategyTest.groovy index 937c99c70..ba1f83e18 100644 --- a/src/test/groovy/graphql/execution/AsyncSerialExecutionStrategyTest.groovy +++ b/src/test/groovy/graphql/execution/AsyncSerialExecutionStrategyTest.groovy @@ -100,6 +100,7 @@ class AsyncSerialExecutionStrategyTest extends Specification { .type(schema.getQueryType()) .build() + def ei = ExecutionInput.newExecutionInput("{}").build() ExecutionContext executionContext = new ExecutionContextBuilder() .graphQLSchema(schema) .executionId(ExecutionId.generate()) @@ -108,8 +109,8 @@ class AsyncSerialExecutionStrategyTest extends Specification { .valueUnboxer(ValueUnboxer.DEFAULT) .locale(Locale.getDefault()) .graphQLContext(GraphQLContext.getDefault()) - .executionInput(ExecutionInput.newExecutionInput("{}").build()) - .engineRunningState(new EngineRunningState()) + .executionInput(ei) + .engineRunningState(new EngineRunningState(ei)) .build() ExecutionStrategyParameters executionStrategyParameters = ExecutionStrategyParameters .newParameters() @@ -149,6 +150,7 @@ class AsyncSerialExecutionStrategyTest extends Specification { .type(schema.getQueryType()) .build() + def ei = ExecutionInput.newExecutionInput("{}").build() ExecutionContext executionContext = new ExecutionContextBuilder() .graphQLSchema(schema) .executionId(ExecutionId.generate()) @@ -157,8 +159,8 @@ class AsyncSerialExecutionStrategyTest extends Specification { .valueUnboxer(ValueUnboxer.DEFAULT) .locale(Locale.getDefault()) .graphQLContext(GraphQLContext.getDefault()) - .executionInput(ExecutionInput.newExecutionInput("{}").build()) - .engineRunningState(new EngineRunningState()) + .executionInput(ei) + .engineRunningState(new EngineRunningState(ei)) .build() ExecutionStrategyParameters executionStrategyParameters = ExecutionStrategyParameters .newParameters() diff --git a/src/test/groovy/graphql/execution/ExecutionStrategyTest.groovy b/src/test/groovy/graphql/execution/ExecutionStrategyTest.groovy index a8de454c0..bb6db0bd0 100644 --- a/src/test/groovy/graphql/execution/ExecutionStrategyTest.groovy +++ b/src/test/groovy/graphql/execution/ExecutionStrategyTest.groovy @@ -70,6 +70,7 @@ class ExecutionStrategyTest extends Specification { def buildContext(GraphQLSchema schema = null) { ExecutionId executionId = ExecutionId.from("executionId123") + ExecutionInput ei = ExecutionInput.newExecutionInput("{}").build() def variables = [arg1: "value1"] def builder = ExecutionContextBuilder.newExecutionContextBuilder() .instrumentation(SimplePerformantInstrumentation.INSTANCE) @@ -80,12 +81,12 @@ class ExecutionStrategyTest extends Specification { .subscriptionStrategy(executionStrategy) .coercedVariables(CoercedVariables.of(variables)) .graphQLContext(GraphQLContext.newContext().of("key", "context").build()) - .executionInput(ExecutionInput.newExecutionInput("{}").build()) + .executionInput(ei) .root("root") .dataLoaderRegistry(new DataLoaderRegistry()) .locale(Locale.getDefault()) .valueUnboxer(ValueUnboxer.DEFAULT) - .engineRunningState(new EngineRunningState()) + .engineRunningState(new EngineRunningState(ei)) new ExecutionContext(builder) } diff --git a/src/test/groovy/graphql/execution/SubscriptionExecutionStrategyTest.groovy b/src/test/groovy/graphql/execution/SubscriptionExecutionStrategyTest.groovy index a92a669fd..2b59d3325 100644 --- a/src/test/groovy/graphql/execution/SubscriptionExecutionStrategyTest.groovy +++ b/src/test/groovy/graphql/execution/SubscriptionExecutionStrategyTest.groovy @@ -31,6 +31,7 @@ import spock.lang.Unroll import java.util.concurrent.CompletableFuture import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.CopyOnWriteArrayList import static graphql.schema.idl.TypeRuntimeWiring.newTypeWiring @@ -714,6 +715,49 @@ class SubscriptionExecutionStrategyTest extends Specification { } } + def "we can cancel the operation and the upstream publisher is told"() { + List promises = new CopyOnWriteArrayList<>() + RxJavaMessagePublisher publisher = new RxJavaMessagePublisher(10) + + DataFetcher newMessageDF = { env -> return publisher } + DataFetcher senderDF = dfThatDoesNotComplete("sender", promises) + DataFetcher textDF = PropertyDataFetcher.fetching("text") + + GraphQL graphQL = buildSubscriptionQL(newMessageDF, senderDF, textDF) + + def executionInput = ExecutionInput.newExecutionInput().query(""" + subscription NewMessages { + newMessage(roomId: 123) { + sender + text + } + } + """).graphQLContext([(SubscriptionExecutionStrategy.KEEP_SUBSCRIPTION_EVENTS_ORDERED): true]).build() + + def executionResult = graphQL.execute(executionInput) + + when: + Publisher msgStream = executionResult.getData() + def capturingSubscriber = new CapturingSubscriber(1) + msgStream.subscribe(capturingSubscriber) + + // now cancel the operation + executionInput.cancel() + + // make things over the subscription + promises.forEach {it.run()} + + + then: + Awaitility.await().untilTrue(capturingSubscriber.isDone()) + + def messages = capturingSubscriber.events + messages.size() == 1 + def error = messages[0].errors[0] + assert error.message.contains("Execution has been asked to be cancelled") + publisher.counter == 2 + } + private static DataFetcher dfThatDoesNotComplete(String propertyName, List promises) { { env -> def df = PropertyDataFetcher.fetching(propertyName) diff --git a/src/test/groovy/graphql/execution/instrumentation/fieldvalidation/FieldValidationTest.groovy b/src/test/groovy/graphql/execution/instrumentation/fieldvalidation/FieldValidationTest.groovy index 508fc56d1..2852d262b 100644 --- a/src/test/groovy/graphql/execution/instrumentation/fieldvalidation/FieldValidationTest.groovy +++ b/src/test/groovy/graphql/execution/instrumentation/fieldvalidation/FieldValidationTest.groovy @@ -310,7 +310,7 @@ class FieldValidationTest extends Specification { def execution = new Execution(strategy, strategy, strategy, instrumentation, ValueUnboxer.DEFAULT, false) def executionInput = ExecutionInput.newExecutionInput().query(query).variables(variables).build() - execution.execute(document, schema, ExecutionId.generate(), executionInput, null, new EngineRunningState()) + execution.execute(document, schema, ExecutionId.generate(), executionInput, null, new EngineRunningState(executionInput)) } def "test graphql from end to end with chained instrumentation"() {