Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,55 +44,81 @@ private static class CallStack {
private final LockKit.ReentrantLock lock = new LockKit.ReentrantLock();
private final LevelMap expectedFetchCountPerLevel = new LevelMap();
private final LevelMap fetchCountPerLevel = new LevelMap();
private final LevelMap expectedStrategyCallsPerLevel = new LevelMap();
private final LevelMap happenedStrategyCallsPerLevel = new LevelMap();

private final LevelMap expectedExecuteObjectCallsPerLevel = new LevelMap();
private final LevelMap happenedExecuteObjectCallsPerLevel = new LevelMap();

private final LevelMap happenedOnFieldValueCallsPerLevel = new LevelMap();

private final Set<Integer> dispatchedLevels = new LinkedHashSet<>();

public CallStack() {
expectedStrategyCallsPerLevel.set(1, 1);
expectedExecuteObjectCallsPerLevel.set(1, 1);
}

void increaseExpectedFetchCount(int level, int count) {
expectedFetchCountPerLevel.increment(level, count);
}

void clearExpectedFetchCount() {
expectedFetchCountPerLevel.clear();
}

void increaseFetchCount(int level) {
fetchCountPerLevel.increment(level, 1);
}

void increaseExpectedStrategyCalls(int level, int count) {
expectedStrategyCallsPerLevel.increment(level, count);
void clearFetchCount() {
fetchCountPerLevel.clear();
}

void increaseExpectedExecuteObjectCalls(int level, int count) {
expectedExecuteObjectCallsPerLevel.increment(level, count);
}

void increaseHappenedStrategyCalls(int level) {
happenedStrategyCallsPerLevel.increment(level, 1);
void clearExpectedObjectCalls() {
expectedExecuteObjectCallsPerLevel.clear();
}

void increaseHappenedExecuteObjectCalls(int level) {
happenedExecuteObjectCallsPerLevel.increment(level, 1);
}

void clearHappenedExecuteObjectCalls() {
happenedExecuteObjectCallsPerLevel.clear();
}

void increaseHappenedOnFieldValueCalls(int level) {
happenedOnFieldValueCallsPerLevel.increment(level, 1);
}

boolean allStrategyCallsHappened(int level) {
return happenedStrategyCallsPerLevel.get(level) == expectedStrategyCallsPerLevel.get(level);
void clearHappenedOnFieldValueCalls() {
happenedOnFieldValueCallsPerLevel.clear();
}

boolean allExecuteObjectCallsHappened(int level) {
return happenedExecuteObjectCallsPerLevel.get(level) == expectedExecuteObjectCallsPerLevel.get(level);
}

boolean allOnFieldCallsHappened(int level) {
return happenedOnFieldValueCallsPerLevel.get(level) == expectedStrategyCallsPerLevel.get(level);
return happenedOnFieldValueCallsPerLevel.get(level) == expectedExecuteObjectCallsPerLevel.get(level);
}

boolean allFetchesHappened(int level) {
return fetchCountPerLevel.get(level) == expectedFetchCountPerLevel.get(level);
}

void clearDispatchLevels() {
dispatchedLevels.clear();
}

@Override
public String toString() {
return "CallStack{" +
"expectedFetchCountPerLevel=" + expectedFetchCountPerLevel +
", fetchCountPerLevel=" + fetchCountPerLevel +
", expectedStrategyCallsPerLevel=" + expectedStrategyCallsPerLevel +
", happenedStrategyCallsPerLevel=" + happenedStrategyCallsPerLevel +
", expectedExecuteObjectCallsPerLevel=" + expectedExecuteObjectCallsPerLevel +
", happenedExecuteObjectCallsPerLevel=" + happenedExecuteObjectCallsPerLevel +
", happenedOnFieldValueCallsPerLevel=" + happenedOnFieldValueCallsPerLevel +
", dispatchedLevels" + dispatchedLevels +
'}';
Expand Down Expand Up @@ -125,16 +151,14 @@ public void executionStrategy(ExecutionContext executionContext, ExecutionStrate
return;
}
int curLevel = parameters.getExecutionStepInfo().getPath().getLevel() + 1;
increaseCallCounts(curLevel, parameters);
increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(curLevel, parameters);

}

@Override
public void executeObject(ExecutionContext executionContext, ExecutionStrategyParameters parameters) {
if (this.startedDeferredExecution.get()) {
return;
}
int curLevel = parameters.getExecutionStepInfo().getPath().getLevel() + 1;
increaseCallCounts(curLevel, parameters);
public void executionSerialStrategy(ExecutionContext executionContext, ExecutionStrategyParameters parameters) {
resetCallStack();
increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(1, 1);
}

@Override
Expand All @@ -145,13 +169,24 @@ public void executionStrategyOnFieldValuesInfo(List<FieldValueInfo> fieldValueIn
onFieldValuesInfoDispatchIfNeeded(fieldValueInfoList, 1);
}

@Override
public void executionStrategyOnFieldValuesException(Throwable t) {
callStack.lock.runLocked(() ->
callStack.increaseHappenedOnFieldValueCalls(1)
);
}


@Override
public void executeObject(ExecutionContext executionContext, ExecutionStrategyParameters parameters) {
if (this.startedDeferredExecution.get()) {
return;
}
int curLevel = parameters.getExecutionStepInfo().getPath().getLevel() + 1;
increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(curLevel, parameters);
}



@Override
public void executeObjectOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfoList, ExecutionStrategyParameters parameters) {
if (this.startedDeferredExecution.get()) {
Expand All @@ -170,45 +205,34 @@ public void executeObjectOnFieldValuesException(Throwable t, ExecutionStrategyPa
);
}

@Override
public void fieldFetched(ExecutionContext executionContext,
ExecutionStrategyParameters parameters,
DataFetcher<?> dataFetcher,
Object fetchedValue) {

final boolean dispatchNeeded;

if (parameters.getField().isDeferred() || this.startedDeferredExecution.get()) {
this.startedDeferredExecution.set(true);
dispatchNeeded = true;
} else {
int level = parameters.getPath().getLevel();
dispatchNeeded = callStack.lock.callLocked(() -> {
callStack.increaseFetchCount(level);
return dispatchIfNeeded(level);
});
}

if (dispatchNeeded) {
dispatch();
}

}

private void increaseCallCounts(int curLevel, ExecutionStrategyParameters parameters) {
int count = 0;
private void increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(int curLevel, ExecutionStrategyParameters parameters) {
int nonDeferredFields = 0;
for (MergedField field : parameters.getFields().getSubFieldsList()) {
if (!field.isDeferred()) {
count++;
nonDeferredFields++;
}
}
int nonDeferredFieldCount = count;
increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(curLevel, nonDeferredFields);
}

private void increaseHappenedExecuteObjectAndIncreaseExpectedFetchCount(int curLevel, int fieldCount) {
callStack.lock.runLocked(() -> {
callStack.increaseExpectedFetchCount(curLevel, nonDeferredFieldCount);
callStack.increaseHappenedStrategyCalls(curLevel);
callStack.increaseHappenedExecuteObjectCalls(curLevel);
callStack.increaseExpectedFetchCount(curLevel, fieldCount);
});
}

private void resetCallStack() {
callStack.lock.runLocked(() -> {
callStack.clearDispatchLevels();
callStack.clearExpectedObjectCalls();
callStack.clearExpectedFetchCount();
callStack.clearFetchCount();
callStack.clearHappenedExecuteObjectCalls();
callStack.clearHappenedOnFieldValueCalls();
callStack.expectedExecuteObjectCallsPerLevel.set(1, 1);
});
}
private void onFieldValuesInfoDispatchIfNeeded(List<FieldValueInfo> fieldValueInfoList, int curLevel) {
boolean dispatchNeeded = callStack.lock.callLocked(() ->
handleOnFieldValuesInfo(fieldValueInfoList, curLevel)
Expand All @@ -223,23 +247,53 @@ private void onFieldValuesInfoDispatchIfNeeded(List<FieldValueInfo> fieldValueIn
//
private boolean handleOnFieldValuesInfo(List<FieldValueInfo> fieldValueInfos, int curLevel) {
callStack.increaseHappenedOnFieldValueCalls(curLevel);
int expectedStrategyCalls = getCountForList(fieldValueInfos);
callStack.increaseExpectedStrategyCalls(curLevel + 1, expectedStrategyCalls);
int expectedStrategyCalls = getObjectCountForList(fieldValueInfos);
callStack.increaseExpectedExecuteObjectCalls(curLevel + 1, expectedStrategyCalls);
return dispatchIfNeeded(curLevel + 1);
}

private int getCountForList(List<FieldValueInfo> fieldValueInfos) {
/**
* the amount of (non nullable) objects that will require an execute object call
*/
private int getObjectCountForList(List<FieldValueInfo> fieldValueInfos) {
int result = 0;
for (FieldValueInfo fieldValueInfo : fieldValueInfos) {
if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.OBJECT) {
result += 1;
} else if (fieldValueInfo.getCompleteValueType() == FieldValueInfo.CompleteValueType.LIST) {
result += getCountForList(fieldValueInfo.getFieldValueInfos());
result += getObjectCountForList(fieldValueInfo.getFieldValueInfos());
}
}
return result;
}


@Override
public void fieldFetched(ExecutionContext executionContext,
ExecutionStrategyParameters executionStrategyParameters,
DataFetcher<?> dataFetcher,
Object fetchedValue) {

final boolean dispatchNeeded;

if (executionStrategyParameters.getField().isDeferred() || this.startedDeferredExecution.get()) {
this.startedDeferredExecution.set(true);
dispatchNeeded = true;
} else {
int level = executionStrategyParameters.getPath().getLevel();
dispatchNeeded = callStack.lock.callLocked(() -> {
callStack.increaseFetchCount(level);
return dispatchIfNeeded(level);
});
}

if (dispatchNeeded) {
dispatch();
}

}


//
// thread safety : called with callStack.lock
//
Expand All @@ -260,7 +314,7 @@ private boolean levelReady(int level) {
return callStack.allFetchesHappened(1);
}
if (levelReady(level - 1) && callStack.allOnFieldCallsHappened(level - 1)
&& callStack.allStrategyCallsHappened(level) && callStack.allFetchesHappened(level)) {
&& callStack.allExecuteObjectCallsHappened(level) && callStack.allFetchesHappened(level)) {

return true;
}
Expand Down
20 changes: 16 additions & 4 deletions src/test/groovy/graphql/MutationTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -141,23 +141,26 @@ class MutationTest extends Specification {
]])

def graphQL = GraphQL.newGraphQL(schema).build()

when:
def er = graphQL.execute("""
def ei = ExecutionInput.newExecutionInput("""
mutation m {
plus1(arg:10)
plus2(arg:10)
plus3(arg:10)
}
""")
""").build()
ei.getGraphQLContext().put(ExperimentalApi.ENABLE_INCREMENTAL_SUPPORT, defeEnabled)

when:
def er = graphQL.execute(ei)
then:
er.errors.isEmpty()
er.data == [
plus1: 11,
plus2: 12,
plus3: 13,
]
where:
defeEnabled << [true, false]
}

def "simple async mutation with DataLoader"() {
Expand Down Expand Up @@ -213,6 +216,7 @@ class MutationTest extends Specification {
plus3(arg:10)
}
""").dataLoaderRegistry(dlReg).build()
ei.getGraphQLContext().put(ExperimentalApi.ENABLE_INCREMENTAL_SUPPORT, defeEnabled)
when:
def er = graphQL.execute(ei)

Expand All @@ -223,12 +227,16 @@ class MutationTest extends Specification {
plus2: 12,
plus3: 13,
]

where:
defeEnabled << [true, false]
}

/*
This test shows a dataloader being called at the mutation field level, in serial via AsyncSerialExecutionStrategy, and then
again at the sub field level, in parallel, via AsyncExecutionStrategy.
*/

def "more complex async mutation with DataLoader"() {
def sdl = """
type Query {
Expand Down Expand Up @@ -436,6 +444,7 @@ class MutationTest extends Specification {
}
}
""").dataLoaderRegistry(dlReg).build()
ei.getGraphQLContext().put(ExperimentalApi.ENABLE_INCREMENTAL_SUPPORT, defeEnabled)
when:
def cf = graphQL.executeAsync(ei)

Expand All @@ -459,5 +468,8 @@ class MutationTest extends Specification {
topLevelF3: expectedMap,
topLevelF4: expectedMap,
]

where:
defeEnabled << [true, false]
}
}
Loading