diff --git a/src/main/java/graphql/execution/AsyncExecutionStrategy.java b/src/main/java/graphql/execution/AsyncExecutionStrategy.java index f7734df9fb..7f51149089 100644 --- a/src/main/java/graphql/execution/AsyncExecutionStrategy.java +++ b/src/main/java/graphql/execution/AsyncExecutionStrategy.java @@ -39,7 +39,6 @@ public AsyncExecutionStrategy(DataFetcherExceptionHandler exceptionHandler) { @SuppressWarnings("FutureReturnValueIgnored") public CompletableFuture execute(ExecutionContext executionContext, ExecutionStrategyParameters parameters) throws NonNullableFieldWasNullException { DataLoaderDispatchStrategy dataLoaderDispatcherStrategy = executionContext.getDataLoaderDispatcherStrategy(); - dataLoaderDispatcherStrategy.executionStrategy(executionContext, parameters); Instrumentation instrumentation = executionContext.getInstrumentation(); InstrumentationExecutionStrategyParameters instrumentationParameters = new InstrumentationExecutionStrategyParameters(executionContext, parameters); @@ -54,6 +53,9 @@ public CompletableFuture execute(ExecutionContext executionCont } DeferredExecutionSupport deferredExecutionSupport = createDeferredExecutionSupport(executionContext, parameters); + + dataLoaderDispatcherStrategy.executionStrategy(executionContext, parameters, deferredExecutionSupport.getNonDeferredFieldNames(fieldNames).size()); + Async.CombinedBuilder futures = getAsyncFieldValueInfo(executionContext, parameters, deferredExecutionSupport); CompletableFuture overallResult = new CompletableFuture<>(); @@ -72,14 +74,14 @@ public CompletableFuture execute(ExecutionContext executionCont for (FieldValueInfo completeValueInfo : completeValueInfos) { fieldValuesFutures.addObject(completeValueInfo.getFieldValueObject()); } - dataLoaderDispatcherStrategy.executionStrategyOnFieldValuesInfo(completeValueInfos); + dataLoaderDispatcherStrategy.executionStrategyOnFieldValuesInfo(completeValueInfos, parameters); executionStrategyCtx.onFieldValuesInfo(completeValueInfos); fieldValuesFutures.await().whenComplete(handleResultsConsumer); }).exceptionally((ex) -> { // if there are any issues with combining/handling the field results, // complete the future at all costs and bubble up any thrown exception so // the execution does not hang. - dataLoaderDispatcherStrategy.executionStrategyOnFieldValuesException(ex); + dataLoaderDispatcherStrategy.executionStrategyOnFieldValuesException(ex, parameters); executionStrategyCtx.onFieldValuesException(); overallResult.completeExceptionally(ex); return null; diff --git a/src/main/java/graphql/execution/AsyncSerialExecutionStrategy.java b/src/main/java/graphql/execution/AsyncSerialExecutionStrategy.java index 98c6ce478b..665777731d 100644 --- a/src/main/java/graphql/execution/AsyncSerialExecutionStrategy.java +++ b/src/main/java/graphql/execution/AsyncSerialExecutionStrategy.java @@ -74,13 +74,13 @@ private Object resolveSerialField(ExecutionContext executionContext, if (fieldWithInfo instanceof CompletableFuture) { //noinspection unchecked return ((CompletableFuture) fieldWithInfo).thenCompose(fvi -> { - dataLoaderDispatcherStrategy.executionStrategyOnFieldValuesInfo(List.of(fvi)); + dataLoaderDispatcherStrategy.executionStrategyOnFieldValuesInfo(List.of(fvi), newParameters); CompletableFuture fieldValueFuture = fvi.getFieldValueFuture(); return fieldValueFuture; }); } else { FieldValueInfo fvi = (FieldValueInfo) fieldWithInfo; - dataLoaderDispatcherStrategy.executionStrategyOnFieldValuesInfo(List.of(fvi)); + dataLoaderDispatcherStrategy.executionStrategyOnFieldValuesInfo(List.of(fvi), newParameters); return fvi.getFieldValueObject(); } } diff --git a/src/main/java/graphql/execution/DataLoaderDispatchStrategy.java b/src/main/java/graphql/execution/DataLoaderDispatchStrategy.java index d91bf46814..8799797d1f 100644 --- a/src/main/java/graphql/execution/DataLoaderDispatchStrategy.java +++ b/src/main/java/graphql/execution/DataLoaderDispatchStrategy.java @@ -14,7 +14,7 @@ public interface DataLoaderDispatchStrategy { }; - default void executionStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters) { + default void executionStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters, int fieldCount) { } @@ -22,16 +22,16 @@ default void executionSerialStrategy(ExecutionContext executionContext, Executio } - default void executionStrategyOnFieldValuesInfo(List fieldValueInfoList) { + default void executionStrategyOnFieldValuesInfo(List fieldValueInfoList, ExecutionStrategyParameters parameters) { } - default void executionStrategyOnFieldValuesException(Throwable t) { + default void executionStrategyOnFieldValuesException(Throwable t, ExecutionStrategyParameters parameters) { } - default void executeObject(ExecutionContext executionContext, ExecutionStrategyParameters executionStrategyParameters) { + default void executeObject(ExecutionContext executionContext, ExecutionStrategyParameters executionStrategyParameters, int fieldCount) { } @@ -39,6 +39,10 @@ default void executeObjectOnFieldValuesInfo(List fieldValueInfoL } + default void deferredOnFieldValue(String resultKey, FieldValueInfo fieldValueInfo, Throwable throwable, ExecutionStrategyParameters parameters) { + + } + default void executeObjectOnFieldValuesException(Throwable t, ExecutionStrategyParameters parameters) { } @@ -55,8 +59,4 @@ default void fieldFetched(ExecutionContext executionContext, default DataFetcher modifyDataFetcher(DataFetcher dataFetcher) { return dataFetcher; } - - default void executeDeferredOnFieldValueInfo(FieldValueInfo fieldValueInfo, ExecutionStrategyParameters executionStrategyParameters) { - - } } diff --git a/src/main/java/graphql/execution/Execution.java b/src/main/java/graphql/execution/Execution.java index f35854a188..60447c9996 100644 --- a/src/main/java/graphql/execution/Execution.java +++ b/src/main/java/graphql/execution/Execution.java @@ -6,7 +6,6 @@ import graphql.ExecutionInput; import graphql.ExecutionResult; import graphql.ExecutionResultImpl; -import graphql.ExperimentalApi; import graphql.GraphQLContext; import graphql.GraphQLError; import graphql.Internal; @@ -16,7 +15,6 @@ import graphql.execution.instrumentation.InstrumentationState; import graphql.execution.instrumentation.dataloader.FallbackDataLoaderDispatchStrategy; import graphql.execution.instrumentation.dataloader.PerLevelDataLoaderDispatchStrategy; -import graphql.execution.instrumentation.dataloader.PerLevelDataLoaderDispatchStrategyWithDeferAlwaysDispatch; import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters; import graphql.execution.instrumentation.parameters.InstrumentationExecutionParameters; import graphql.extensions.ExtensionsBuilder; @@ -37,7 +35,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.function.Supplier; @@ -258,12 +255,7 @@ private DataLoaderDispatchStrategy createDataLoaderDispatchStrategy(ExecutionCon return DataLoaderDispatchStrategy.NO_OP; } if (!executionContext.isSubscriptionOperation()) { - boolean deferEnabled = executionContext.hasIncrementalSupport(); - - // Dedicated strategy for defer support, for safety purposes. - return deferEnabled ? - new PerLevelDataLoaderDispatchStrategyWithDeferAlwaysDispatch(executionContext) : - new PerLevelDataLoaderDispatchStrategy(executionContext); + return new PerLevelDataLoaderDispatchStrategy(executionContext); } else { return new FallbackDataLoaderDispatchStrategy(executionContext); } diff --git a/src/main/java/graphql/execution/ExecutionStrategy.java b/src/main/java/graphql/execution/ExecutionStrategy.java index 355f13106b..c1323f0b84 100644 --- a/src/main/java/graphql/execution/ExecutionStrategy.java +++ b/src/main/java/graphql/execution/ExecutionStrategy.java @@ -5,7 +5,6 @@ import graphql.EngineRunningState; import graphql.ExecutionResult; import graphql.ExecutionResultImpl; -import graphql.ExperimentalApi; import graphql.GraphQLError; import graphql.Internal; import graphql.PublicSpi; @@ -50,7 +49,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.OptionalInt; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; @@ -197,7 +195,6 @@ public static String mkNameForPath(List currentField) { @DuckTyped(shape = "CompletableFuture> | Map") protected Object executeObject(ExecutionContext executionContext, ExecutionStrategyParameters parameters) throws NonNullableFieldWasNullException { DataLoaderDispatchStrategy dataLoaderDispatcherStrategy = executionContext.getDataLoaderDispatcherStrategy(); - dataLoaderDispatcherStrategy.executeObject(executionContext, parameters); Instrumentation instrumentation = executionContext.getInstrumentation(); InstrumentationExecutionStrategyParameters instrumentationParameters = new InstrumentationExecutionStrategyParameters(executionContext, parameters); @@ -212,6 +209,7 @@ protected Object executeObject(ExecutionContext executionContext, ExecutionStrat CompletableFuture> overallResult = new CompletableFuture<>(); List fieldsExecutedOnInitialResult = deferredExecutionSupport.getNonDeferredFieldNames(fieldNames); + dataLoaderDispatcherStrategy.executeObject(executionContext, parameters, fieldsExecutedOnInitialResult.size()); BiConsumer, Throwable> handleResultsConsumer = buildFieldValueMap(fieldsExecutedOnInitialResult, overallResult, executionContext); resolveObjectCtx.onDispatched(); @@ -300,7 +298,7 @@ DeferredExecutionSupport createDeferredExecutionSupport(ExecutionContext executi ) { MergedSelectionSet fields = parameters.getFields(); - executionContext.getIncrementalCallState().enqueue(deferredExecutionSupport.createCalls(parameters)); + executionContext.getIncrementalCallState().enqueue(deferredExecutionSupport.createCalls()); // Only non-deferred fields should be considered for calculating the expected size of futures. Async.CombinedBuilder futures = Async @@ -400,7 +398,6 @@ private Object fetchField(GraphQLFieldDefinition fieldDef, ExecutionContext exec } MergedField field = parameters.getField(); - String pathString = parameters.getPath().toString(); GraphQLObjectType parentType = (GraphQLObjectType) parameters.getExecutionStepInfo().getUnwrappedNonNullType(); // if the DF (like PropertyDataFetcher) does not use the arguments or execution step info then dont build any @@ -435,6 +432,7 @@ private Object fetchField(GraphQLFieldDefinition fieldDef, ExecutionContext exec .parentType(parentType) .selectionSet(fieldCollector) .queryDirectives(queryDirectives) + .deferredCallContext(parameters.getDeferredCallContext()) .build(); }); diff --git a/src/main/java/graphql/execution/ExecutionStrategyParameters.java b/src/main/java/graphql/execution/ExecutionStrategyParameters.java index 58eb3d1767..87dd7057ae 100644 --- a/src/main/java/graphql/execution/ExecutionStrategyParameters.java +++ b/src/main/java/graphql/execution/ExecutionStrategyParameters.java @@ -94,6 +94,7 @@ public ExecutionStrategyParameters getParent() { * @return the deferred call context or null if we're not in the scope of a deferred call */ @Nullable + @Internal public DeferredCallContext getDeferredCallContext() { return deferredCallContext; } diff --git a/src/main/java/graphql/execution/incremental/DeferredCallContext.java b/src/main/java/graphql/execution/incremental/DeferredCallContext.java index d7d494aace..e7e2ec0658 100644 --- a/src/main/java/graphql/execution/incremental/DeferredCallContext.java +++ b/src/main/java/graphql/execution/incremental/DeferredCallContext.java @@ -2,6 +2,7 @@ import graphql.GraphQLError; import graphql.Internal; +import graphql.VisibleForTesting; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; @@ -18,8 +19,31 @@ @Internal public class DeferredCallContext { + private final int startLevel; + private final int fields; + private final List errors = new CopyOnWriteArrayList<>(); + public DeferredCallContext(int startLevel, int fields) { + this.startLevel = startLevel; + this.fields = fields; + } + + @VisibleForTesting + public DeferredCallContext() { + this.startLevel = 0; + this.fields = 0; + } + + public int getStartLevel() { + return startLevel; + } + + public int getFields() { + return fields; + } + + public void addErrors(List errors) { this.errors.addAll(errors); } @@ -34,4 +58,6 @@ public void addError(GraphQLError graphqlError) { public List getErrors() { return errors; } + + } diff --git a/src/main/java/graphql/execution/incremental/DeferredExecutionSupport.java b/src/main/java/graphql/execution/incremental/DeferredExecutionSupport.java index a347d9b0cd..ade6242d24 100644 --- a/src/main/java/graphql/execution/incremental/DeferredExecutionSupport.java +++ b/src/main/java/graphql/execution/incremental/DeferredExecutionSupport.java @@ -45,7 +45,7 @@ public interface DeferredExecutionSupport { List getNonDeferredFieldNames(List allFieldNames); - Set> createCalls(ExecutionStrategyParameters executionStrategyParameters); + Set> createCalls(); DeferredExecutionSupport NOOP = new DeferredExecutionSupport.NoOp(); @@ -106,23 +106,24 @@ public List getNonDeferredFieldNames(List allFieldNames) { } @Override - public Set> createCalls(ExecutionStrategyParameters executionStrategyParameters) { + public Set> createCalls() { ImmutableSet deferredExecutions = deferredExecutionToFields.keySet(); Set> set = new HashSet<>(deferredExecutions.size()); for (DeferredExecution deferredExecution : deferredExecutions) { - set.add(this.createDeferredFragmentCall(deferredExecution, executionStrategyParameters)); + set.add(this.createDeferredFragmentCall(deferredExecution)); } return set; } - private DeferredFragmentCall createDeferredFragmentCall(DeferredExecution deferredExecution, ExecutionStrategyParameters executionStrategyParameters) { - DeferredCallContext deferredCallContext = new DeferredCallContext(); + private DeferredFragmentCall createDeferredFragmentCall(DeferredExecution deferredExecution) { + int level = parameters.getPath().getLevel() + 1; + DeferredCallContext deferredCallContext = new DeferredCallContext(level, deferredFields.size()); List mergedFields = deferredExecutionToFields.get(deferredExecution); List>> calls = FpKit.arrayListSizedTo(mergedFields); for (MergedField currentField : mergedFields) { - calls.add(this.createResultSupplier(currentField, deferredCallContext, executionStrategyParameters)); + calls.add(this.createResultSupplier(currentField, deferredCallContext)); } return new DeferredFragmentCall( @@ -135,13 +136,12 @@ private DeferredFragmentCall createDeferredFragmentCall(DeferredExecution deferr private Supplier> createResultSupplier( MergedField currentField, - DeferredCallContext deferredCallContext, - ExecutionStrategyParameters executionStrategyParameters + DeferredCallContext deferredCallContext ) { Map fields = new LinkedHashMap<>(); fields.put(currentField.getResultKey(), currentField); - ExecutionStrategyParameters callParameters = parameters.transform(builder -> + ExecutionStrategyParameters executionStrategyParameters = parameters.transform(builder -> { MergedSelectionSet mergedSelectionSet = MergedSelectionSet.newMergedSelectionSet().subFields(fields).build(); ResultPath path = parameters.getPath().segment(currentField.getResultKey()); @@ -158,22 +158,23 @@ private Supplier FpKit.interThreadMemoize(() -> { - CompletableFuture fieldValueResult = resolveFieldWithInfoFn - .apply(executionContext, callParameters); + CompletableFuture fieldValueResult = resolveFieldWithInfoFn.apply(executionContext, executionStrategyParameters); + + fieldValueResult.whenComplete((fieldValueInfo, throwable) -> { + executionContext.getDataLoaderDispatcherStrategy().deferredOnFieldValue(currentField.getResultKey(), fieldValueInfo, throwable, executionStrategyParameters); + }); - CompletableFuture executionResultCF = fieldValueResult - .thenCompose(fvi -> { - executionContext.getDataLoaderDispatcherStrategy().executeDeferredOnFieldValueInfo(fvi, executionStrategyParameters); - return fvi - .getFieldValueFuture() - .thenApply(fv -> ExecutionResultImpl.newExecutionResult().data(fv).build()); - } + CompletableFuture executionResultCF = fieldValueResult + .thenCompose(fvi -> fvi + .getFieldValueFuture() + .thenApply(fv -> ExecutionResultImpl.newExecutionResult().data(fv).build()) ); return executionResultCF @@ -207,7 +208,7 @@ public List getNonDeferredFieldNames(List allFieldNames) { } @Override - public Set> createCalls(ExecutionStrategyParameters executionStrategyParameters) { + public Set> createCalls() { return Collections.emptySet(); } } diff --git a/src/main/java/graphql/execution/incremental/IncrementalCallState.java b/src/main/java/graphql/execution/incremental/IncrementalCallState.java index 2f5c9742be..f2c0b9dbc7 100644 --- a/src/main/java/graphql/execution/incremental/IncrementalCallState.java +++ b/src/main/java/graphql/execution/incremental/IncrementalCallState.java @@ -103,4 +103,5 @@ private Supplier> cre public Publisher startDeferredCalls() { return publisher.get(); } + } diff --git a/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java b/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java index 30ccd838d4..ee80d4aa61 100644 --- a/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java +++ b/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategy.java @@ -7,12 +7,15 @@ import graphql.execution.ExecutionContext; import graphql.execution.ExecutionStrategyParameters; import graphql.execution.FieldValueInfo; +import graphql.execution.incremental.DeferredCallContext; import graphql.schema.DataFetcher; import graphql.schema.DataFetchingEnvironment; import graphql.util.InterThreadMemoizedSupplier; import graphql.util.LockKit; import org.dataloader.DataLoader; import org.dataloader.DataLoaderRegistry; +import org.jspecify.annotations.NullMarked; +import org.jspecify.annotations.Nullable; import java.util.ArrayList; import java.util.Collections; @@ -30,9 +33,10 @@ import java.util.stream.Collectors; @Internal +@NullMarked public class PerLevelDataLoaderDispatchStrategy implements DataLoaderDispatchStrategy { - private final CallStack callStack; + private final CallStack initialCallStack; private final ExecutionContext executionContext; private final long batchWindowNs; private final boolean enableDataLoaderChaining; @@ -44,16 +48,40 @@ public class PerLevelDataLoaderDispatchStrategy implements DataLoaderDispatchStr static final long DEFAULT_BATCH_WINDOW_NANO_SECONDS_DEFAULT = 500_000L; + private final Map deferredCallStackMap = new ConcurrentHashMap<>(); + private static class CallStack { private final LockKit.ReentrantLock lock = new LockKit.ReentrantLock(); /** - * A level is ready when all fields in this level are fetched - * The expected field fetch count is accurate when all execute object calls happened - * The expected execute object count is accurate when all sub selections fetched - * are done in the previous level + * A general overview of teh tracked data: + * There are three aspects tracked per level: + * - number of execute object calls (executeObject) + * - number of fetches + * - number of sub selections finished fetching + *

+ * The level for an execute object call is the level of the field in the query: for + * { a {b {c}}} the level of a is 1, b is 2 and c is not an object + *

+ * For fetches the level is the level of the field fetched + *

+ * For sub selections finished it is the level of the fields inside the sub selection: + * {a1 { b c} a2 } the level of {a1 a2} is 1, the level of {b c} is 2 + *

+ *

+ * A finished subselection means we can predict the number of execute object calls in the same level as the subselection: + * { a {x} b {y} } + * If a is a list of 3 objects and b is a list of 2 objects we expect 3 + 2 = 5 execute object calls on the level 1 to be happening + *

+ * An executed object call again means we can predict the number of fetches in the next level: + * Execute Object a with { a {f1 f2 f3} } means we expect 3 fetches on level 2. + *

+ * This means we know a level is ready to be dispatched if: + * - all subselections done in the parent level + * - all execute objects calls in the parent level are done + * - all expected fetched happened in the current level */ private final LevelMap expectedFetchCountPerLevel = new LevelMap(); @@ -84,10 +112,12 @@ private static class CallStack { private boolean batchWindowOpen; + private final List deferredFragmentRootFieldsFetched = new ArrayList<>(); + public CallStack() { // in the first level there is only one sub selection, // so we only expect one execute object call (which is actually an executionStrategy call) - expectedExecuteObjectCallsPerLevel.set(1, 1); + expectedExecuteObjectCallsPerLevel.set(0, 1); } public void addResultPathWithDataLoader(int level, ResultPathWithDataLoader resultPathWithDataLoader) { @@ -107,6 +137,7 @@ void increaseFetchCount(int level) { fetchCountPerLevel.increment(level, 1); } + void clearFetchCount() { fetchCountPerLevel.clear(); } @@ -139,8 +170,8 @@ boolean allExecuteObjectCallsHappened(int level) { return happenedExecuteObjectCallsPerLevel.get(level) == expectedExecuteObjectCallsPerLevel.get(level); } - boolean allSubSelectionsFetchingHappened(int level) { - return happenedOnFieldValueCallsPerLevel.get(level) == expectedExecuteObjectCallsPerLevel.get(level); + boolean allSubSelectionsFetchingHappened(int subSelectionLevel) { + return happenedOnFieldValueCallsPerLevel.get(subSelectionLevel) == expectedExecuteObjectCallsPerLevel.get(subSelectionLevel - 1); } boolean allFetchesHappened(int level) { @@ -172,7 +203,7 @@ public void setDispatchedLevel(int level) { } public PerLevelDataLoaderDispatchStrategy(ExecutionContext executionContext) { - this.callStack = new CallStack(); + this.initialCallStack = new CallStack(); this.executionContext = executionContext; GraphQLContext graphQLContext = executionContext.getGraphQLContext(); @@ -189,68 +220,111 @@ public PerLevelDataLoaderDispatchStrategy(ExecutionContext executionContext) { this.enableDataLoaderChaining = graphQLContext.getBoolean(DataLoaderDispatchingContextKeys.ENABLE_DATA_LOADER_CHAINING, false); } - @Override - public void executeDeferredOnFieldValueInfo(FieldValueInfo fieldValueInfo, ExecutionStrategyParameters executionStrategyParameters) { - throw new UnsupportedOperationException("Data Loaders cannot be used to resolve deferred fields"); - } @Override - public void executionStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters) { + public void executionStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters, int fieldCount) { Assert.assertTrue(parameters.getExecutionStepInfo().getPath().isRootPath()); - increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(1, parameters); + increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(0, fieldCount, initialCallStack); } @Override public void executionSerialStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters) { - resetCallStack(); - increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(1, 1); + CallStack callStack = getCallStack(parameters); + resetCallStack(callStack); + increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(0, 1, callStack); } @Override - public void executionStrategyOnFieldValuesInfo(List fieldValueInfoList) { - onFieldValuesInfoDispatchIfNeeded(fieldValueInfoList, 1); + public void executionStrategyOnFieldValuesInfo(List fieldValueInfoList, ExecutionStrategyParameters parameters) { + CallStack callStack = getCallStack(parameters); + // the root fields are the root sub selection on level 1 + onFieldValuesInfoDispatchIfNeeded(fieldValueInfoList, 1, callStack); } - public void executionStrategyOnFieldValuesException(Throwable t) { + @Override + public void executionStrategyOnFieldValuesException(Throwable t, ExecutionStrategyParameters parameters) { + CallStack callStack = getCallStack(parameters); callStack.lock.runLocked(() -> callStack.increaseHappenedOnFieldValueCalls(1) ); } + private CallStack getCallStack(ExecutionStrategyParameters parameters) { + return getCallStack(parameters.getDeferredCallContext()); + } + + private CallStack getCallStack(@Nullable DeferredCallContext deferredCallContext) { + if (deferredCallContext == null) { + return this.initialCallStack; + } else { + return deferredCallStackMap.computeIfAbsent(deferredCallContext, k -> { + CallStack callStack = new CallStack(); + int startLevel = deferredCallContext.getStartLevel(); + int fields = deferredCallContext.getFields(); + callStack.lock.runLocked(() -> { + // we make sure that startLevel-1 is considered done + callStack.expectedExecuteObjectCallsPerLevel.set(0, 0); // set to 1 in the constructor of CallStack + callStack.expectedExecuteObjectCallsPerLevel.set(startLevel - 1, 1); + callStack.happenedExecuteObjectCallsPerLevel.set(startLevel - 1, 1); + callStack.highestReadyLevel = startLevel - 1; + callStack.increaseExpectedFetchCount(startLevel, fields); + }); + return callStack; + }); + } + } @Override - public void executeObject(ExecutionContext executionContext, ExecutionStrategyParameters parameters) { - int curLevel = parameters.getExecutionStepInfo().getPath().getLevel() + 1; - increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(curLevel, parameters); + public void executeObject(ExecutionContext executionContext, ExecutionStrategyParameters parameters, int fieldCount) { + CallStack callStack = getCallStack(parameters); + int curLevel = parameters.getPath().getLevel(); + increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(curLevel, fieldCount, callStack); } @Override - public void executeObjectOnFieldValuesInfo(List fieldValueInfoList, ExecutionStrategyParameters parameters) { + public void executeObjectOnFieldValuesInfo + (List fieldValueInfoList, ExecutionStrategyParameters parameters) { + // the level of the sub selection that is fully fetched is one level more than parameters level int curLevel = parameters.getPath().getLevel() + 1; - onFieldValuesInfoDispatchIfNeeded(fieldValueInfoList, curLevel); + CallStack callStack = getCallStack(parameters); + onFieldValuesInfoDispatchIfNeeded(fieldValueInfoList, curLevel, callStack); } + @Override + public void deferredOnFieldValue(String resultKey, FieldValueInfo fieldValueInfo, Throwable + throwable, ExecutionStrategyParameters parameters) { + CallStack callStack = getCallStack(parameters); + boolean ready = callStack.lock.callLocked(() -> { + callStack.deferredFragmentRootFieldsFetched.add(fieldValueInfo); + return callStack.deferredFragmentRootFieldsFetched.size() == parameters.getDeferredCallContext().getFields(); + }); + if (ready) { + int curLevel = parameters.getPath().getLevel(); + onFieldValuesInfoDispatchIfNeeded(callStack.deferredFragmentRootFieldsFetched, curLevel, callStack); + } + } @Override public void executeObjectOnFieldValuesException(Throwable t, ExecutionStrategyParameters parameters) { + CallStack callStack = getCallStack(parameters); + // the level of the sub selection that is errored is one level more than parameters level int curLevel = parameters.getPath().getLevel() + 1; callStack.lock.runLocked(() -> callStack.increaseHappenedOnFieldValueCalls(curLevel) ); } - private void increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(int curLevel, ExecutionStrategyParameters executionStrategyParameters) { - increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(curLevel, executionStrategyParameters.getFields().size()); - } - private void increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(int curLevel, int fieldCount) { + private void increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(int curLevel, + int fieldCount, + CallStack callStack) { callStack.lock.runLocked(() -> { callStack.increaseHappenedExecuteObjectCalls(curLevel); - callStack.increaseExpectedFetchCount(curLevel, fieldCount); + callStack.increaseExpectedFetchCount(curLevel + 1, fieldCount); }); } - private void resetCallStack() { + private void resetCallStack(CallStack callStack) { callStack.lock.runLocked(() -> { callStack.clearDispatchLevels(); callStack.clearExpectedObjectCalls(); @@ -258,7 +332,7 @@ private void resetCallStack() { callStack.clearFetchCount(); callStack.clearHappenedExecuteObjectCalls(); callStack.clearHappenedOnFieldValueCalls(); - callStack.expectedExecuteObjectCallsPerLevel.set(1, 1); + callStack.expectedExecuteObjectCallsPerLevel.set(0, 1); callStack.dispatchingFinishedPerLevel.clear(); callStack.dispatchingStartedPerLevel.clear(); callStack.allResultPathWithDataLoader.clear(); @@ -269,24 +343,27 @@ private void resetCallStack() { }); } - private void onFieldValuesInfoDispatchIfNeeded(List fieldValueInfoList, int curLevel) { + private void onFieldValuesInfoDispatchIfNeeded(List fieldValueInfoList, + int subSelectionLevel, + CallStack callStack) { Integer dispatchLevel = callStack.lock.callLocked(() -> - handleOnFieldValuesInfo(fieldValueInfoList, curLevel) + handleSubSelectionFetched(fieldValueInfoList, subSelectionLevel, callStack) ); // the handle on field values check for the next level if it is ready if (dispatchLevel != null) { - dispatch(dispatchLevel); + dispatch(dispatchLevel, callStack); } } // // thread safety: called with callStack.lock // - private Integer handleOnFieldValuesInfo(List fieldValueInfos, int curLevel) { - callStack.increaseHappenedOnFieldValueCalls(curLevel); + private Integer handleSubSelectionFetched(List fieldValueInfos, int subSelectionLevel, CallStack + callStack) { + callStack.increaseHappenedOnFieldValueCalls(subSelectionLevel); int expectedOnObjectCalls = getObjectCountForList(fieldValueInfos); - // on the next level we expect the following on object calls because we found non null objects - callStack.increaseExpectedExecuteObjectCalls(curLevel + 1, expectedOnObjectCalls); + // we expect on the level of the current sub selection #expectedOnObjectCalls execute object calls + callStack.increaseExpectedExecuteObjectCalls(subSelectionLevel, expectedOnObjectCalls); // maybe the object calls happened already (because the DataFetcher return directly values synchronously) // therefore we check the next levels if they are ready // this means we could skip some level because the higher level is also already ready, @@ -296,7 +373,7 @@ private Integer handleOnFieldValuesInfo(List fieldValueInfos, in // if data loader chaining is disabled (the old algo) the level we dispatch is not really relevant as // we dispatch the whole registry anyway - return getHighestReadyLevel(curLevel + 1); + return getHighestReadyLevel(subSelectionLevel + 1, callStack); } /** @@ -321,13 +398,14 @@ public void fieldFetched(ExecutionContext executionContext, DataFetcher dataFetcher, Object fetchedValue, Supplier dataFetchingEnvironment) { + CallStack callStack = getCallStack(executionStrategyParameters); int level = executionStrategyParameters.getPath().getLevel(); boolean dispatchNeeded = callStack.lock.callLocked(() -> { callStack.increaseFetchCount(level); - return dispatchIfNeeded(level); + return dispatchIfNeeded(level, callStack); }); if (dispatchNeeded) { - dispatch(level); + dispatch(level, callStack); } } @@ -336,8 +414,8 @@ public void fieldFetched(ExecutionContext executionContext, // // thread safety : called with callStack.lock // - private boolean dispatchIfNeeded(int level) { - boolean ready = checkLevelBeingReady(level); + private boolean dispatchIfNeeded(int level, CallStack callStack) { + boolean ready = checkLevelBeingReady(level, callStack); if (ready) { callStack.setDispatchedLevel(level); return true; @@ -348,10 +426,10 @@ private boolean dispatchIfNeeded(int level) { // // thread safety: called with callStack.lock // - private Integer getHighestReadyLevel(int startFrom) { + private Integer getHighestReadyLevel(int startFrom, CallStack callStack) { int curLevel = callStack.highestReadyLevel; while (true) { - if (!checkLevelImpl(curLevel + 1)) { + if (!checkLevelImpl(curLevel + 1, callStack)) { callStack.highestReadyLevel = curLevel; return curLevel >= startFrom ? curLevel : null; } @@ -359,14 +437,14 @@ private Integer getHighestReadyLevel(int startFrom) { } } - private boolean checkLevelBeingReady(int level) { + private boolean checkLevelBeingReady(int level, CallStack callStack) { Assert.assertTrue(level > 0); if (level <= callStack.highestReadyLevel) { return true; } for (int i = callStack.highestReadyLevel + 1; i <= level; i++) { - if (!checkLevelImpl(i)) { + if (!checkLevelImpl(i, callStack)) { return false; } } @@ -374,26 +452,28 @@ private boolean checkLevelBeingReady(int level) { return true; } - private boolean checkLevelImpl(int level) { + private boolean checkLevelImpl(int level, CallStack callStack) { // a level with zero expectations can't be ready if (callStack.expectedFetchCountPerLevel.get(level) == 0) { return false; } - // level 1 is special: there is no previous sub selections - // and the expected execution object calls is always 1 - if (level > 1 && !callStack.allSubSelectionsFetchingHappened(level - 1)) { + + // first we make sure that the expected fetch count is correct + // by verifying that the parent level all execute object + sub selection were fetched + if (!callStack.allExecuteObjectCallsHappened(level - 1)) { return false; } - if (!callStack.allExecuteObjectCallsHappened(level)) { + if (level > 1 && !callStack.allSubSelectionsFetchingHappened(level - 1)) { return false; } + // the main check: all fetches must have happened if (!callStack.allFetchesHappened(level)) { return false; } return true; } - void dispatch(int level) { + void dispatch(int level, CallStack callStack) { if (!enableDataLoaderChaining) { DataLoaderRegistry dataLoaderRegistry = executionContext.getDataLoaderRegistry(); dataLoaderRegistry.dispatchAll(); @@ -409,7 +489,7 @@ void dispatch(int level) { .map(resultPathWithDataLoader -> resultPathWithDataLoader.resultPath) .collect(Collectors.toSet()); }); - dispatchDLCFImpl(resultPathToDispatch, level); + dispatchDLCFImpl(resultPathToDispatch, level, callStack); } else { callStack.lock.runLocked(() -> { callStack.dispatchingStartedPerLevel.add(level); @@ -419,7 +499,7 @@ void dispatch(int level) { } - public void dispatchDLCFImpl(Set resultPathsToDispatch, Integer level) { + public void dispatchDLCFImpl(Set resultPathsToDispatch, Integer level, CallStack callStack) { // filter out all DataLoaderCFS that are matching the fields we want to dispatch List relevantResultPathWithDataLoader = new ArrayList<>(); @@ -444,18 +524,20 @@ public void dispatchDLCFImpl(Set resultPathsToDispatch, Integer level) { } CompletableFuture.allOf(allDispatchedCFs.toArray(new CompletableFuture[0])) .whenComplete((unused, throwable) -> { - dispatchDLCFImpl(resultPathsToDispatch, level); + dispatchDLCFImpl(resultPathsToDispatch, level, callStack); } ); } - public void newDataLoaderLoadCall(String resultPath, int level, DataLoader dataLoader, String dataLoaderName, Object key) { + public void newDataLoaderLoadCall(String resultPath, int level, DataLoader dataLoader, String + dataLoaderName, Object key, @Nullable DeferredCallContext deferredCallContext) { if (!enableDataLoaderChaining) { return; } ResultPathWithDataLoader resultPathWithDataLoader = new ResultPathWithDataLoader(resultPath, level, dataLoader, dataLoaderName, key); + CallStack callStack = getCallStack(deferredCallContext); boolean levelFinished = callStack.lock.callLocked(() -> { boolean finished = callStack.dispatchingFinishedPerLevel.contains(level); callStack.allResultPathWithDataLoader.add(resultPathWithDataLoader); @@ -466,7 +548,7 @@ public void newDataLoaderLoadCall(String resultPath, int level, DataLoader dataL return finished; }); if (levelFinished) { - newDelayedDataLoader(resultPathWithDataLoader); + newDelayedDataLoader(resultPathWithDataLoader, callStack); } @@ -474,6 +556,12 @@ public void newDataLoaderLoadCall(String resultPath, int level, DataLoader dataL class DispatchDelayedDataloader implements Runnable { + private final CallStack callStack; + + public DispatchDelayedDataloader(CallStack callStack) { + this.callStack = callStack; + } + @Override public void run() { AtomicReference> resultPathToDispatch = new AtomicReference<>(); @@ -482,16 +570,16 @@ public void run() { callStack.batchWindowOfDelayedDataLoaderToDispatch.clear(); callStack.batchWindowOpen = false; }); - dispatchDLCFImpl(resultPathToDispatch.get(), null); + dispatchDLCFImpl(resultPathToDispatch.get(), null, callStack); } } - private void newDelayedDataLoader(ResultPathWithDataLoader resultPathWithDataLoader) { + private void newDelayedDataLoader(ResultPathWithDataLoader resultPathWithDataLoader, CallStack callStack) { callStack.lock.runLocked(() -> { callStack.batchWindowOfDelayedDataLoaderToDispatch.add(resultPathWithDataLoader.resultPath); if (!callStack.batchWindowOpen) { callStack.batchWindowOpen = true; - delayedDataLoaderDispatchExecutor.get().schedule(new DispatchDelayedDataloader(), this.batchWindowNs, TimeUnit.NANOSECONDS); + delayedDataLoaderDispatchExecutor.get().schedule(new DispatchDelayedDataloader(callStack), this.batchWindowNs, TimeUnit.NANOSECONDS); } }); diff --git a/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategyWithDeferAlwaysDispatch.java b/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategyWithDeferAlwaysDispatch.java deleted file mode 100644 index 7f996f664b..0000000000 --- a/src/main/java/graphql/execution/instrumentation/dataloader/PerLevelDataLoaderDispatchStrategyWithDeferAlwaysDispatch.java +++ /dev/null @@ -1,279 +0,0 @@ -package graphql.execution.instrumentation.dataloader; - -import graphql.Assert; -import graphql.Internal; -import graphql.execution.DataLoaderDispatchStrategy; -import graphql.execution.ExecutionContext; -import graphql.execution.ExecutionStrategyParameters; -import graphql.execution.FieldValueInfo; -import graphql.execution.MergedField; -import graphql.schema.DataFetcher; -import graphql.schema.DataFetchingEnvironment; -import graphql.util.LockKit; -import org.dataloader.DataLoaderRegistry; - -import java.util.LinkedHashSet; -import java.util.List; -import java.util.Set; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Supplier; - -/** - * The execution of a query can be divided into 2 phases: first, the non-deferred fields are executed and only once - * they are completely resolved, we start to execute the deferred fields. - * The behavior of this Data Loader strategy is quite different during those 2 phases. During the execution of the - * deferred fields the Data Loader will not attempt to dispatch in a optimal way. It will essentially dispatch for - * every field fetched, which is quite ineffective. - * This is the first iteration of the Data Loader strategy with support for @defer, and it will be improved in the - * future. - */ -@Internal -public class PerLevelDataLoaderDispatchStrategyWithDeferAlwaysDispatch implements DataLoaderDispatchStrategy { - - private final CallStack callStack; - private final ExecutionContext executionContext; - - /** - * This flag is used to determine if we have started the deferred execution. - * The value of this flag is set to true as soon as we identified that a deferred field is being executed, and then - * the flag stays on that state for the remainder of the execution. - */ - private final AtomicBoolean startedDeferredExecution = new AtomicBoolean(false); - - - private static class CallStack { - - private final LockKit.ReentrantLock lock = new LockKit.ReentrantLock(); - private final LevelMap expectedFetchCountPerLevel = new LevelMap(); - private final LevelMap fetchCountPerLevel = new LevelMap(); - private final LevelMap expectedStrategyCallsPerLevel = new LevelMap(); - private final LevelMap happenedStrategyCallsPerLevel = new LevelMap(); - private final LevelMap happenedOnFieldValueCallsPerLevel = new LevelMap(); - - private final Set dispatchedLevels = new LinkedHashSet<>(); - - public CallStack() { - expectedStrategyCallsPerLevel.set(1, 1); - } - - void increaseExpectedFetchCount(int level, int count) { - expectedFetchCountPerLevel.increment(level, count); - } - - void increaseFetchCount(int level) { - fetchCountPerLevel.increment(level, 1); - } - - void increaseExpectedStrategyCalls(int level, int count) { - expectedStrategyCallsPerLevel.increment(level, count); - } - - void increaseHappenedStrategyCalls(int level) { - happenedStrategyCallsPerLevel.increment(level, 1); - } - - void increaseHappenedOnFieldValueCalls(int level) { - happenedOnFieldValueCallsPerLevel.increment(level, 1); - } - - boolean allStrategyCallsHappened(int level) { - return happenedStrategyCallsPerLevel.get(level) == expectedStrategyCallsPerLevel.get(level); - } - - boolean allOnFieldCallsHappened(int level) { - return happenedOnFieldValueCallsPerLevel.get(level) == expectedStrategyCallsPerLevel.get(level); - } - - boolean allFetchesHappened(int level) { - return fetchCountPerLevel.get(level) == expectedFetchCountPerLevel.get(level); - } - - @Override - public String toString() { - return "CallStack{" + - "expectedFetchCountPerLevel=" + expectedFetchCountPerLevel + - ", fetchCountPerLevel=" + fetchCountPerLevel + - ", expectedStrategyCallsPerLevel=" + expectedStrategyCallsPerLevel + - ", happenedStrategyCallsPerLevel=" + happenedStrategyCallsPerLevel + - ", happenedOnFieldValueCallsPerLevel=" + happenedOnFieldValueCallsPerLevel + - ", dispatchedLevels" + dispatchedLevels + - '}'; - } - - - public boolean dispatchIfNotDispatchedBefore(int level) { - if (dispatchedLevels.contains(level)) { - Assert.assertShouldNeverHappen("level " + level + " already dispatched"); - return false; - } - dispatchedLevels.add(level); - return true; - } - } - - public PerLevelDataLoaderDispatchStrategyWithDeferAlwaysDispatch(ExecutionContext executionContext) { - this.callStack = new CallStack(); - this.executionContext = executionContext; - } - - @Override - public void executeDeferredOnFieldValueInfo(FieldValueInfo fieldValueInfo, ExecutionStrategyParameters executionStrategyParameters) { - this.startedDeferredExecution.set(true); - } - - @Override - public void executionStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters) { - if (this.startedDeferredExecution.get()) { - return; - } - int curLevel = parameters.getExecutionStepInfo().getPath().getLevel() + 1; - increaseCallCounts(curLevel, parameters); - } - - @Override - public void executeObject(ExecutionContext executionContext, ExecutionStrategyParameters parameters) { - if (this.startedDeferredExecution.get()) { - return; - } - int curLevel = parameters.getExecutionStepInfo().getPath().getLevel() + 1; - increaseCallCounts(curLevel, parameters); - } - - @Override - public void executionStrategyOnFieldValuesInfo(List fieldValueInfoList) { - if (this.startedDeferredExecution.get()) { - this.dispatch(); - } - onFieldValuesInfoDispatchIfNeeded(fieldValueInfoList, 1); - } - - @Override - public void executionStrategyOnFieldValuesException(Throwable t) { - callStack.lock.runLocked(() -> - callStack.increaseHappenedOnFieldValueCalls(1) - ); - } - - @Override - public void executeObjectOnFieldValuesInfo(List fieldValueInfoList, ExecutionStrategyParameters parameters) { - if (this.startedDeferredExecution.get()) { - this.dispatch(); - } - int curLevel = parameters.getPath().getLevel() + 1; - onFieldValuesInfoDispatchIfNeeded(fieldValueInfoList, curLevel); - } - - - @Override - public void executeObjectOnFieldValuesException(Throwable t, ExecutionStrategyParameters parameters) { - int curLevel = parameters.getPath().getLevel() + 1; - callStack.lock.runLocked(() -> - callStack.increaseHappenedOnFieldValueCalls(curLevel) - ); - } - - @Override - public void fieldFetched(ExecutionContext executionContext, - ExecutionStrategyParameters parameters, - DataFetcher dataFetcher, - Object fetchedValue, - Supplier dataFetchingEnvironment) { - - final boolean dispatchNeeded; - - if (parameters.getField().isDeferred() || this.startedDeferredExecution.get()) { - this.startedDeferredExecution.set(true); - dispatchNeeded = true; - } else { - int level = parameters.getPath().getLevel(); - dispatchNeeded = callStack.lock.callLocked(() -> { - callStack.increaseFetchCount(level); - return dispatchIfNeeded(level); - }); - } - - if (dispatchNeeded) { - dispatch(); - } - - } - - private void increaseCallCounts(int curLevel, ExecutionStrategyParameters parameters) { - int count = 0; - for (MergedField field : parameters.getFields().getSubFieldsList()) { - if (!field.isDeferred()) { - count++; - } - } - int nonDeferredFieldCount = count; - callStack.lock.runLocked(() -> { - callStack.increaseExpectedFetchCount(curLevel, nonDeferredFieldCount); - callStack.increaseHappenedStrategyCalls(curLevel); - }); - } - - private void onFieldValuesInfoDispatchIfNeeded(List fieldValueInfoList, int curLevel) { - boolean dispatchNeeded = callStack.lock.callLocked(() -> - handleOnFieldValuesInfo(fieldValueInfoList, curLevel) - ); - if (dispatchNeeded) { - dispatch(); - } - } - - // - // thread safety: called with callStack.lock - // - private boolean handleOnFieldValuesInfo(List fieldValueInfos, int curLevel) { - callStack.increaseHappenedOnFieldValueCalls(curLevel); - int expectedStrategyCalls = getCountForList(fieldValueInfos); - callStack.increaseExpectedStrategyCalls(curLevel + 1, expectedStrategyCalls); - return dispatchIfNeeded(curLevel + 1); - } - - private int getCountForList(List fieldValueInfos) { - int result = 0; - for (FieldValueInfo fieldValueInfo : fieldValueInfos) { - if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.OBJECT) { - result += 1; - } else if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.LIST) { - result += getCountForList(fieldValueInfo.getFieldValueInfos()); - } - } - return result; - } - - // - // thread safety : called with callStack.lock - // - private boolean dispatchIfNeeded(int level) { - boolean ready = levelReady(level); - if (ready) { - return callStack.dispatchIfNotDispatchedBefore(level); - } - return false; - } - - // - // thread safety: called with callStack.lock - // - private boolean levelReady(int level) { - if (level == 1) { - // level 1 is special: there is only one strategy call and that's it - return callStack.allFetchesHappened(1); - } - if (levelReady(level - 1) && callStack.allOnFieldCallsHappened(level - 1) - && callStack.allStrategyCallsHappened(level) && callStack.allFetchesHappened(level)) { - - return true; - } - return false; - } - - void dispatch() { - DataLoaderRegistry dataLoaderRegistry = executionContext.getDataLoaderRegistry(); - dataLoaderRegistry.dispatchAll(); - } - -} - diff --git a/src/main/java/graphql/schema/DataFetchingEnvironmentImpl.java b/src/main/java/graphql/schema/DataFetchingEnvironmentImpl.java index 0dd0e30674..dc9c3776cd 100644 --- a/src/main/java/graphql/schema/DataFetchingEnvironmentImpl.java +++ b/src/main/java/graphql/schema/DataFetchingEnvironmentImpl.java @@ -12,6 +12,7 @@ import graphql.execution.ExecutionStepInfo; import graphql.execution.MergedField; import graphql.execution.directives.QueryDirectives; +import graphql.execution.incremental.DeferredCallContext; import graphql.language.Document; import graphql.language.Field; import graphql.language.FragmentDefinition; @@ -78,7 +79,7 @@ private DataFetchingEnvironmentImpl(Builder builder) { this.queryDirectives = builder.queryDirectives; // internal state - this.dfeInternalState = new DFEInternalState(builder.dataLoaderDispatchStrategy); + this.dfeInternalState = new DFEInternalState(builder.dataLoaderDispatchStrategy, builder.deferredCallContext); } /** @@ -106,7 +107,6 @@ public static Builder newDataFetchingEnvironment(ExecutionContext executionConte .variables(executionContext.getCoercedVariables().toMap()) .executionId(executionContext.getExecutionId()) .dataLoaderDispatchStrategy(executionContext.getDataLoaderDispatcherStrategy()); - } @Override @@ -282,6 +282,7 @@ public static class Builder { private ImmutableMapWithNullValues variables; private QueryDirectives queryDirectives; private DataLoaderDispatchStrategy dataLoaderDispatchStrategy; + private DeferredCallContext deferredCallContext; public Builder(DataFetchingEnvironmentImpl env) { this.source = env.source; @@ -306,6 +307,7 @@ public Builder(DataFetchingEnvironmentImpl env) { this.variables = env.variables; this.queryDirectives = env.queryDirectives; this.dataLoaderDispatchStrategy = env.dfeInternalState.dataLoaderDispatchStrategy; + this.deferredCallContext = env.dfeInternalState.deferredCallContext; } public Builder() { @@ -425,6 +427,11 @@ public Builder queryDirectives(QueryDirectives queryDirectives) { return this; } + public Builder deferredCallContext(DeferredCallContext deferredCallContext) { + this.deferredCallContext = deferredCallContext; + return this; + } + public DataFetchingEnvironment build() { return new DataFetchingEnvironmentImpl(this); } @@ -438,13 +445,19 @@ public Builder dataLoaderDispatchStrategy(DataLoaderDispatchStrategy dataLoaderD @Internal public static class DFEInternalState { final DataLoaderDispatchStrategy dataLoaderDispatchStrategy; + final DeferredCallContext deferredCallContext; - public DFEInternalState(DataLoaderDispatchStrategy dataLoaderDispatchStrategy) { + public DFEInternalState(DataLoaderDispatchStrategy dataLoaderDispatchStrategy, DeferredCallContext deferredCallContext) { this.dataLoaderDispatchStrategy = dataLoaderDispatchStrategy; + this.deferredCallContext = deferredCallContext; } public DataLoaderDispatchStrategy getDataLoaderDispatchStrategy() { return dataLoaderDispatchStrategy; } + + public DeferredCallContext getDeferredCallContext() { + return deferredCallContext; + } } } diff --git a/src/main/java/graphql/schema/DataLoaderWithContext.java b/src/main/java/graphql/schema/DataLoaderWithContext.java index a4b56814ca..0c6ae9e1d7 100644 --- a/src/main/java/graphql/schema/DataLoaderWithContext.java +++ b/src/main/java/graphql/schema/DataLoaderWithContext.java @@ -1,6 +1,7 @@ package graphql.schema; import graphql.Internal; +import graphql.execution.incremental.DeferredCallContext; import graphql.execution.instrumentation.dataloader.PerLevelDataLoaderDispatchStrategy; import org.dataloader.DataLoader; import org.dataloader.DelegatingDataLoader; @@ -32,7 +33,8 @@ public CompletableFuture load(@NonNull K key, @Nullable Object keyContext) { String path = dfe.getExecutionStepInfo().getPath().toString(); DataFetchingEnvironmentImpl.DFEInternalState dfeInternalState = (DataFetchingEnvironmentImpl.DFEInternalState) dfeImpl.toInternal(); if (dfeInternalState.getDataLoaderDispatchStrategy() instanceof PerLevelDataLoaderDispatchStrategy) { - ((PerLevelDataLoaderDispatchStrategy) dfeInternalState.dataLoaderDispatchStrategy).newDataLoaderLoadCall(path, level, delegate, dataLoaderName, key); + DeferredCallContext deferredCallContext = dfeInternalState.getDeferredCallContext(); + ((PerLevelDataLoaderDispatchStrategy) dfeInternalState.dataLoaderDispatchStrategy).newDataLoaderLoadCall(path, level, delegate, dataLoaderName, key, deferredCallContext); } return result; } diff --git a/src/test/groovy/graphql/execution/incremental/DeferExecutionSupportIntegrationTest.groovy b/src/test/groovy/graphql/execution/incremental/DeferExecutionSupportIntegrationTest.groovy index 49e51eaf5a..b3b522d90b 100644 --- a/src/test/groovy/graphql/execution/incremental/DeferExecutionSupportIntegrationTest.groovy +++ b/src/test/groovy/graphql/execution/incremental/DeferExecutionSupportIntegrationTest.groovy @@ -18,11 +18,16 @@ import graphql.schema.DataFetchingEnvironment import graphql.schema.TypeResolver import graphql.schema.idl.RuntimeWiring import org.awaitility.Awaitility +import org.dataloader.BatchLoader +import org.dataloader.DataLoader +import org.dataloader.DataLoaderFactory +import org.dataloader.DataLoaderRegistry import org.reactivestreams.Publisher import spock.lang.Specification import spock.lang.Unroll import java.util.concurrent.CompletableFuture +import java.util.concurrent.atomic.AtomicInteger import static graphql.schema.idl.TypeRuntimeWiring.newTypeWiring @@ -62,6 +67,8 @@ class DeferExecutionSupportIntegrationTest extends Specification { typeMismatchError: [String] nonNullableError: String! wordCount: Int + fieldWithDataLoader1: String + fieldWithDataLoader2: String } type Comment { @@ -90,6 +97,13 @@ class DeferExecutionSupportIntegrationTest extends Specification { return resolve(value, sleepMs, false) } + private static DataFetcher fieldWithDataLoader(String key) { + return (dfe) -> { + def dataLoader = dfe.getDataLoader("someDataLoader") + return dataLoader.load(dfe.getSource().id + "-" + key) + }; + } + private static DataFetcher resolve(Object value, Integer sleepMs, boolean allowMultipleCalls) { return new DataFetcher() { boolean executed = false @@ -163,6 +177,8 @@ class DeferExecutionSupportIntegrationTest extends Specification { .dataFetcher("item", resolveItem()) ) .type(newTypeWiring("Post").dataFetcher("summary", resolve("A summary", 10))) + .type(newTypeWiring("Post").dataFetcher("fieldWithDataLoader1", fieldWithDataLoader("fieldWithDataLoader1"))) + .type(newTypeWiring("Post").dataFetcher("fieldWithDataLoader2", fieldWithDataLoader("fieldWithDataLoader2"))) .type(newTypeWiring("Post").dataFetcher("text", resolve("The full text", 100))) .type(newTypeWiring("Post").dataFetcher("wordCount", resolve(45999, 10, true))) .type(newTypeWiring("Post").dataFetcher("latestComment", resolve([title: "Comment title"], 10))) @@ -1674,6 +1690,42 @@ class DeferExecutionSupportIntegrationTest extends Specification { } + def "dataloader used inside defer"() { + given: + def query = ''' + query { + post { + id + ...@defer { + fieldWithDataLoader1 + fieldWithDataLoader2 + } + } + } + ''' + + def batchLoaderCallCount = new AtomicInteger(0) + when: + def initialResult = executeQuery(query, true, [:], batchLoaderCallCount) + + then: + initialResult.toSpecification() == [ + data : [post: [id: "1001"]], + hasNext: true + ] + + when: + def incrementalResults = getIncrementalResults(initialResult) + + then: + batchLoaderCallCount.get() == 1 + incrementalResults.size() == 1 + incrementalResults[0] == [incremental: [[path: ["post"], data: [fieldWithDataLoader1: "1001-fieldWithDataLoader1", fieldWithDataLoader2: "1001-fieldWithDataLoader2"]]], + hasNext : false + ] + + } + private ExecutionResult executeQuery(String query) { return this.executeQuery(query, true, [:]) @@ -1683,12 +1735,22 @@ class DeferExecutionSupportIntegrationTest extends Specification { return this.executeQuery(query, true, variables) } - private ExecutionResult executeQuery(String query, boolean incrementalSupport, Map variables) { + private ExecutionResult executeQuery(String query, boolean incrementalSupport, Map variables, AtomicInteger batchLoaderCallCount = null) { + BatchLoader batchLoader = { keys -> + if (batchLoaderCallCount != null) { + batchLoaderCallCount.incrementAndGet() + } + return CompletableFuture.completedFuture(keys) + } + DataLoader dl = DataLoaderFactory.newDataLoader(batchLoader) + DataLoaderRegistry dataLoaderRegistry = new DataLoaderRegistry(); + dataLoaderRegistry.register("someDataLoader", dl) return graphQL.execute( ExecutionInput.newExecutionInput() .graphQLContext([(ExperimentalApi.ENABLE_INCREMENTAL_SUPPORT): incrementalSupport]) .query(query) .variables(variables) + .dataLoaderRegistry(dataLoaderRegistry) .build() ) } diff --git a/src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderPerformanceData.groovy b/src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderPerformanceData.groovy index 7366ded562..5e72e5f2ad 100644 --- a/src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderPerformanceData.groovy +++ b/src/test/groovy/graphql/execution/instrumentation/dataloader/DataLoaderPerformanceData.groovy @@ -68,14 +68,14 @@ class DataLoaderPerformanceData { static String getQuery(boolean deferDepartments, boolean deferProducts) { return """ query { - shops { - id name + shops { # 1 + id name # 2 ... @defer(if: $deferDepartments) { - departments { - id name + departments { # 2 + id name # 3 ... @defer(if: $deferProducts) { - products { - id name + products { # 3 + id name # 4 } } } diff --git a/src/test/groovy/graphql/execution/instrumentation/dataloader/DeferWithDataLoaderTest.groovy b/src/test/groovy/graphql/execution/instrumentation/dataloader/DeferWithDataLoaderTest.groovy index 6da9489c76..5427f7e504 100644 --- a/src/test/groovy/graphql/execution/instrumentation/dataloader/DeferWithDataLoaderTest.groovy +++ b/src/test/groovy/graphql/execution/instrumentation/dataloader/DeferWithDataLoaderTest.groovy @@ -2,12 +2,19 @@ package graphql.execution.instrumentation.dataloader import graphql.ExecutionInput import graphql.ExecutionResult +import graphql.ExperimentalApi import graphql.GraphQL +import graphql.TestUtil import graphql.incremental.IncrementalExecutionResult +import graphql.schema.DataFetcher +import org.awaitility.Awaitility +import org.dataloader.BatchLoader +import org.dataloader.DataLoaderFactory import org.dataloader.DataLoaderRegistry import spock.lang.Specification -import java.util.stream.Collectors +import java.time.Duration +import java.util.concurrent.CompletableFuture import static graphql.ExperimentalApi.ENABLE_INCREMENTAL_SUPPORT import static graphql.execution.instrumentation.dataloader.DataLoaderPerformanceData.combineExecutionResults @@ -36,7 +43,7 @@ class DeferWithDataLoaderTest extends Specification { * @param results a list of the incremental results from the execution * @param expectedPaths a list of the expected paths in the incremental results. The order of the elements in the list is not important. */ - private static void assertIncrementalResults(List> results, List> expectedPaths) { + private static void assertIncrementalResults(List> results, List> expectedPaths, List expectedData = null) { assert results.size() == expectedPaths.size(), "Expected ${expectedPaths.size()} results, got ${results.size()}" assert results.dropRight(1).every { it.hasNext == true }, "Expected all but the last result to have hasNext=true" @@ -44,8 +51,12 @@ class DeferWithDataLoaderTest extends Specification { assert results.every { it.incremental.size() == 1 }, "Expected every result to have exactly one incremental item" - expectedPaths.each { path -> - assert results.any { it.incremental[0].path == path }, "Expected path $path not found in $results" + expectedPaths.eachWithIndex { path, index -> + def result = results.find { it.incremental[0].path == path } + assert result != null, "Expected path $path not found in $results" + if (expectedData != null) { + assert result.incremental[0].data == expectedData[index], "Expected data $expectedData[index] for path $path, got ${result.incremental[0].data}" + } } } @@ -90,7 +101,7 @@ class DeferWithDataLoaderTest extends Specification { combined.data == expectedData batchCompareDataFetchers.departmentsForShopsBatchLoaderCounter.get() == 3 - batchCompareDataFetchers.productsForDepartmentsBatchLoaderCounter.get() == 9 + batchCompareDataFetchers.productsForDepartmentsBatchLoaderCounter.get() == 3 } def "multiple fields on same defer block"() { @@ -320,10 +331,10 @@ class DeferWithDataLoaderTest extends Specification { [ ["expensiveShops", 0], ["expensiveShops", 1], ["expensiveShops", 2], ["shops", 0], ["shops", 1], ["shops", 2], - ["shops", 0, "departments", 0], ["shops", 0, "departments", 1],["shops", 0, "departments", 2], ["shops", 1, "departments", 0],["shops", 1, "departments", 1], ["shops", 1, "departments", 2], ["shops", 2, "departments", 0],["shops", 2, "departments", 1],["shops", 2, "departments", 2], - ["shops", 0, "expensiveDepartments", 0], ["shops", 0, "expensiveDepartments", 1], ["shops", 0, "expensiveDepartments", 2], ["shops", 1, "expensiveDepartments", 0], ["shops", 1, "expensiveDepartments", 1], ["shops", 1, "expensiveDepartments", 2], ["shops", 2, "expensiveDepartments", 0], ["shops", 2, "expensiveDepartments", 1],["shops", 2, "expensiveDepartments", 2], - ["expensiveShops", 0, "expensiveDepartments", 0], ["expensiveShops", 0, "expensiveDepartments", 1], ["expensiveShops", 0, "expensiveDepartments", 2], ["expensiveShops", 1, "expensiveDepartments", 0], ["expensiveShops", 1, "expensiveDepartments", 1], ["expensiveShops", 1, "expensiveDepartments", 2], ["expensiveShops", 2, "expensiveDepartments", 0], ["expensiveShops", 2, "expensiveDepartments", 1],["expensiveShops", 2, "expensiveDepartments", 2], - ["expensiveShops", 0, "departments", 0], ["expensiveShops", 0, "departments", 1], ["expensiveShops", 0, "departments", 2], ["expensiveShops", 1, "departments", 0], ["expensiveShops", 1, "departments", 1], ["expensiveShops", 1, "departments", 2], ["expensiveShops", 2, "departments", 0], ["expensiveShops", 2, "departments", 1],["expensiveShops", 2, "departments", 2]] + ["shops", 0, "departments", 0], ["shops", 0, "departments", 1], ["shops", 0, "departments", 2], ["shops", 1, "departments", 0], ["shops", 1, "departments", 1], ["shops", 1, "departments", 2], ["shops", 2, "departments", 0], ["shops", 2, "departments", 1], ["shops", 2, "departments", 2], + ["shops", 0, "expensiveDepartments", 0], ["shops", 0, "expensiveDepartments", 1], ["shops", 0, "expensiveDepartments", 2], ["shops", 1, "expensiveDepartments", 0], ["shops", 1, "expensiveDepartments", 1], ["shops", 1, "expensiveDepartments", 2], ["shops", 2, "expensiveDepartments", 0], ["shops", 2, "expensiveDepartments", 1], ["shops", 2, "expensiveDepartments", 2], + ["expensiveShops", 0, "expensiveDepartments", 0], ["expensiveShops", 0, "expensiveDepartments", 1], ["expensiveShops", 0, "expensiveDepartments", 2], ["expensiveShops", 1, "expensiveDepartments", 0], ["expensiveShops", 1, "expensiveDepartments", 1], ["expensiveShops", 1, "expensiveDepartments", 2], ["expensiveShops", 2, "expensiveDepartments", 0], ["expensiveShops", 2, "expensiveDepartments", 1], ["expensiveShops", 2, "expensiveDepartments", 2], + ["expensiveShops", 0, "departments", 0], ["expensiveShops", 0, "departments", 1], ["expensiveShops", 0, "departments", 2], ["expensiveShops", 1, "departments", 0], ["expensiveShops", 1, "departments", 1], ["expensiveShops", 1, "departments", 2], ["expensiveShops", 2, "departments", 0], ["expensiveShops", 2, "departments", 1], ["expensiveShops", 2, "departments", 2]] ) when: @@ -337,4 +348,136 @@ class DeferWithDataLoaderTest extends Specification { batchCompareDataFetchers.productsForDepartmentsBatchLoaderCounter.get() == 1 } + def "dataloader in initial result and chained dataloader inside nested defer block"() { + given: + def sdl = ''' + type Query { + pets: [Pet] + } + + type Pet { + name: String + owner: Owner + } + type Owner { + name: String + address: String + } + + ''' + + def query = ''' + query { + pets { + name + ... @defer { + owner { + name + ... @defer { + address + } + } + } + } + } + ''' + + BatchLoader petNameBatchLoader = { List keys -> + println "petNameBatchLoader called with $keys" + assert keys.size() == 3 + return CompletableFuture.completedFuture(["Pet 1", "Pet 2", "Pet 3"]) + } + BatchLoader addressBatchLoader = { List keys -> + println "addressBatchLoader called with $keys" + assert keys.size() == 3 + return CompletableFuture.completedFuture(keys.collect { it -> + if (it == "owner-1") { + return "Address 1" + } else if (it == "owner-2") { + return "Address 2" + } else if (it == "owner-3") { + return "Address 3" + } + }) + } + + DataLoaderRegistry dataLoaderRegistry = new DataLoaderRegistry() + def petNameDL = DataLoaderFactory.newDataLoader("petName", petNameBatchLoader) + def addressDL = DataLoaderFactory.newDataLoader("address", addressBatchLoader) + dataLoaderRegistry.register("petName", petNameDL) + dataLoaderRegistry.register("address", addressDL) + + DataFetcher petsDF = { env -> + return [ + [id: "pet-1"], + [id: "pet-2"], + [id: "pet-3"] + ] + } + DataFetcher petNameDF = { env -> + env.getDataLoader("petName").load(env.getSource().id) + } + + DataFetcher petOwnerDF = { env -> + String id = env.getSource().id + if (id == "pet-1") { + return [id: "owner-1", name: "Owner 1"] + } else if (id == "pet-2") { + return [id: "owner-2", name: "Owner 2"] + } else if (id == "pet-3") { + return [id: "owner-3", name: "Owner 3"] + } + } + DataFetcher ownerAddressDF = { env -> + return CompletableFuture.supplyAsync { + Thread.sleep(500) + return "foo" + }.thenCompose { + return env.getDataLoader("address").load(env.getSource().id) + } + .thenCompose { + return env.getDataLoader("address").load(env.getSource().id) + } + } + + def schema = TestUtil.schema(sdl, [Query: [pets: petsDF], + Pet : [name: petNameDF, owner: petOwnerDF], + Owner: [address: ownerAddressDF]]) + def graphQL = GraphQL.newGraphQL(schema).build() + def ei = ExecutionInput.newExecutionInput(query).dataLoaderRegistry(dataLoaderRegistry).build() + ei.getGraphQLContext().put(ExperimentalApi.ENABLE_INCREMENTAL_SUPPORT, true) + ei.getGraphQLContext().put(DataLoaderDispatchingContextKeys.ENABLE_DATA_LOADER_CHAINING, true) + ei.getGraphQLContext().put(DataLoaderDispatchingContextKeys.DELAYED_DATA_LOADER_BATCH_WINDOW_SIZE_NANO_SECONDS, Duration.ofSeconds(1).toNanos()) + + when: + CompletableFuture erCF = graphQL.executeAsync(ei) + Awaitility.await().until { erCF.isDone() } + def er = erCF.get() + + then: + er.toSpecification() == [data : [pets: [[name: "Pet 1"], [name: "Pet 2"], [name: "Pet 3"]]], + hasNext: true] + + when: + def incrementalResults = getIncrementalResults(er) + println "incrementalResults: $incrementalResults" + + then: + assertIncrementalResults(incrementalResults, + [ + ["pets", 0], ["pets", 1], ["pets", 2], + ["pets", 0, "owner"], ["pets", 1, "owner"], ["pets", 2, "owner"], + ], + [ + [owner: [name: "Owner 1"]], + [owner: [name: "Owner 2"]], + [owner: [name: "Owner 3"]], + [address: "Address 1"], + [address: "Address 2"], + [address: "Address 3"] + ] + ) + + } + }