From 9f6b430f7833e8a3b0c892c9e6b972e72771164e Mon Sep 17 00:00:00 2001 From: Brad Baker Date: Thu, 17 Mar 2022 08:17:14 +1100 Subject: [PATCH 1/3] Investigation of NonNullableValueCoercedAsNullException is being thrown in MaxQueryDepthInstrumentation --- .../MaxQueryDepthInstrumentationTest.groovy | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/test/groovy/graphql/analysis/MaxQueryDepthInstrumentationTest.groovy b/src/test/groovy/graphql/analysis/MaxQueryDepthInstrumentationTest.groovy index 5a386786e5..4f2301f470 100644 --- a/src/test/groovy/graphql/analysis/MaxQueryDepthInstrumentationTest.groovy +++ b/src/test/groovy/graphql/analysis/MaxQueryDepthInstrumentationTest.groovy @@ -1,6 +1,8 @@ package graphql.analysis import graphql.ExecutionInput +import graphql.ExecutionResult +import graphql.GraphQL import graphql.TestUtil import graphql.execution.AbortExecutionException import graphql.execution.instrumentation.InstrumentationContext @@ -161,4 +163,30 @@ class MaxQueryDepthInstrumentationTest extends Specification { test == true notThrown(Exception) } + + def "coercing null variables that are marked as non nullable"() { + + given: + def schema = TestUtil.schema(""" + type Query { + field(arg : String!) : String + } + """) + + MaxQueryDepthInstrumentation maximumQueryDepthInstrumentation = new MaxQueryDepthInstrumentation(6) + def graphQL = GraphQL.newGraphQL(schema).instrumentation(maximumQueryDepthInstrumentation).build() + + when: + def query = ''' + query x($var : String!) { + field(arg : $var) + } + ''' + def executionInput = ExecutionInput.newExecutionInput(query).variables(["var": null]).build() + def er = graphQL.execute(executionInput) + + then: + ! er.errors.isEmpty() + + } } From 1ed54d983009b4c743fd422e242e7dacb467443d Mon Sep 17 00:00:00 2001 From: Brad Baker Date: Sun, 27 Mar 2022 10:51:41 +1100 Subject: [PATCH 2/3] MaxQueryDepthInstrumentation and MaxQueryComplexityInstrumentation now are called at execution time not after validation --- .../MaxQueryComplexityInstrumentation.java | 93 ++++++++------ .../MaxQueryDepthInstrumentation.java | 52 ++++---- .../graphql/analysis/QueryComplexityInfo.java | 30 ++++- ...xQueryComplexityInstrumentationTest.groovy | 113 ++++++------------ .../MaxQueryDepthInstrumentationTest.groovy | 102 ++++------------ 5 files changed, 167 insertions(+), 223 deletions(-) diff --git a/src/main/java/graphql/analysis/MaxQueryComplexityInstrumentation.java b/src/main/java/graphql/analysis/MaxQueryComplexityInstrumentation.java index 762d13aa62..cba57bea28 100644 --- a/src/main/java/graphql/analysis/MaxQueryComplexityInstrumentation.java +++ b/src/main/java/graphql/analysis/MaxQueryComplexityInstrumentation.java @@ -1,9 +1,14 @@ package graphql.analysis; +import graphql.ExecutionResult; import graphql.PublicApi; import graphql.execution.AbortExecutionException; +import graphql.execution.ExecutionContext; import graphql.execution.instrumentation.InstrumentationContext; +import graphql.execution.instrumentation.InstrumentationState; import graphql.execution.instrumentation.SimpleInstrumentation; +import graphql.execution.instrumentation.parameters.InstrumentationCreateStateParameters; +import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters; import graphql.execution.instrumentation.parameters.InstrumentationValidationParameters; import graphql.validation.ValidationError; import org.slf4j.Logger; @@ -15,7 +20,7 @@ import java.util.function.Function; import static graphql.Assert.assertNotNull; -import static graphql.execution.instrumentation.SimpleInstrumentationContext.whenCompleted; +import static graphql.execution.instrumentation.SimpleInstrumentationContext.noOp; import static java.util.Optional.ofNullable; /** @@ -76,41 +81,52 @@ public MaxQueryComplexityInstrumentation(int maxComplexity, FieldComplexityCalcu this.maxQueryComplexityExceededFunction = maxQueryComplexityExceededFunction; } + @Override + public InstrumentationState createState(InstrumentationCreateStateParameters parameters) { + return new State(); + } + @Override public InstrumentationContext> beginValidation(InstrumentationValidationParameters parameters) { - return whenCompleted((errors, throwable) -> { - if ((errors != null && errors.size() > 0) || throwable != null) { - return; - } - QueryTraverser queryTraverser = newQueryTraverser(parameters); - - Map valuesByParent = new LinkedHashMap<>(); - queryTraverser.visitPostOrder(new QueryVisitorStub() { - @Override - public void visitField(QueryVisitorFieldEnvironment env) { - int childsComplexity = valuesByParent.getOrDefault(env, 0); - int value = calculateComplexity(env, childsComplexity); - - valuesByParent.compute(env.getParentEnvironment(), (key, oldValue) -> - ofNullable(oldValue).orElse(0) + value - ); - } - }); - int totalComplexity = valuesByParent.getOrDefault(null, 0); - if (log.isDebugEnabled()) { - log.debug("Query complexity: {}", totalComplexity); - } - if (totalComplexity > maxComplexity) { - QueryComplexityInfo queryComplexityInfo = QueryComplexityInfo.newQueryComplexityInfo() - .complexity(totalComplexity) - .instrumentationValidationParameters(parameters) - .build(); - boolean throwAbortException = maxQueryComplexityExceededFunction.apply(queryComplexityInfo); - if (throwAbortException) { - throw mkAbortException(totalComplexity, maxComplexity); - } + State state = parameters.getInstrumentationState(); + // for API backwards compatibility reasons we capture the validation parameters, so we can put them into QueryComplexityInfo + state.instrumentationValidationParameters = parameters; + return noOp(); + } + + @Override + public InstrumentationContext beginExecuteOperation(InstrumentationExecuteOperationParameters instrumentationExecuteOperationParameters) { + State state = instrumentationExecuteOperationParameters.getInstrumentationState(); + QueryTraverser queryTraverser = newQueryTraverser(instrumentationExecuteOperationParameters.getExecutionContext()); + + Map valuesByParent = new LinkedHashMap<>(); + queryTraverser.visitPostOrder(new QueryVisitorStub() { + @Override + public void visitField(QueryVisitorFieldEnvironment env) { + int childsComplexity = valuesByParent.getOrDefault(env, 0); + int value = calculateComplexity(env, childsComplexity); + + valuesByParent.compute(env.getParentEnvironment(), (key, oldValue) -> + ofNullable(oldValue).orElse(0) + value + ); } }); + int totalComplexity = valuesByParent.getOrDefault(null, 0); + if (log.isDebugEnabled()) { + log.debug("Query complexity: {}", totalComplexity); + } + if (totalComplexity > maxComplexity) { + QueryComplexityInfo queryComplexityInfo = QueryComplexityInfo.newQueryComplexityInfo() + .complexity(totalComplexity) + .instrumentationValidationParameters(state.instrumentationValidationParameters) + .instrumentationExecuteOperationParameters(instrumentationExecuteOperationParameters) + .build(); + boolean throwAbortException = maxQueryComplexityExceededFunction.apply(queryComplexityInfo); + if (throwAbortException) { + throw mkAbortException(totalComplexity, maxComplexity); + } + } + return noOp(); } /** @@ -125,12 +141,12 @@ protected AbortExecutionException mkAbortException(int totalComplexity, int maxC return new AbortExecutionException("maximum query complexity exceeded " + totalComplexity + " > " + maxComplexity); } - QueryTraverser newQueryTraverser(InstrumentationValidationParameters parameters) { + QueryTraverser newQueryTraverser(ExecutionContext executionContext) { return QueryTraverser.newQueryTraverser() - .schema(parameters.getSchema()) - .document(parameters.getDocument()) - .operationName(parameters.getOperation()) - .variables(parameters.getVariables()) + .schema(executionContext.getGraphQLSchema()) + .document(executionContext.getDocument()) + .operationName(executionContext.getExecutionInput().getOperationName()) + .variables(executionContext.getVariables()) .build(); } @@ -156,5 +172,8 @@ private FieldComplexityEnvironment convertEnv(QueryVisitorFieldEnvironment query ); } + private static class State implements InstrumentationState { + InstrumentationValidationParameters instrumentationValidationParameters; + } } diff --git a/src/main/java/graphql/analysis/MaxQueryDepthInstrumentation.java b/src/main/java/graphql/analysis/MaxQueryDepthInstrumentation.java index 785fd5d14f..991a830a64 100644 --- a/src/main/java/graphql/analysis/MaxQueryDepthInstrumentation.java +++ b/src/main/java/graphql/analysis/MaxQueryDepthInstrumentation.java @@ -1,18 +1,18 @@ package graphql.analysis; +import graphql.ExecutionResult; import graphql.PublicApi; import graphql.execution.AbortExecutionException; +import graphql.execution.ExecutionContext; import graphql.execution.instrumentation.InstrumentationContext; import graphql.execution.instrumentation.SimpleInstrumentation; -import graphql.execution.instrumentation.parameters.InstrumentationValidationParameters; -import graphql.validation.ValidationError; +import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.util.List; import java.util.function.Function; -import static graphql.execution.instrumentation.SimpleInstrumentationContext.whenCompleted; +import static graphql.execution.instrumentation.SimpleInstrumentationContext.noOp; /** * Prevents execution if the query depth is greater than the specified maxDepth. @@ -49,26 +49,22 @@ public MaxQueryDepthInstrumentation(int maxDepth, Function> beginValidation(InstrumentationValidationParameters parameters) { - return whenCompleted((errors, throwable) -> { - if ((errors != null && errors.size() > 0) || throwable != null) { - return; - } - QueryTraverser queryTraverser = newQueryTraverser(parameters); - int depth = queryTraverser.reducePreOrder((env, acc) -> Math.max(getPathLength(env.getParentEnvironment()), acc), 0); - if (log.isDebugEnabled()) { - log.debug("Query depth info: {}", depth); - } - if (depth > maxDepth) { - QueryDepthInfo queryDepthInfo = QueryDepthInfo.newQueryDepthInfo() - .depth(depth) - .build(); - boolean throwAbortException = maxQueryDepthExceededFunction.apply(queryDepthInfo); - if (throwAbortException) { - throw mkAbortException(depth, maxDepth); - } + public InstrumentationContext beginExecuteOperation(InstrumentationExecuteOperationParameters parameters) { + QueryTraverser queryTraverser = newQueryTraverser(parameters.getExecutionContext()); + int depth = queryTraverser.reducePreOrder((env, acc) -> Math.max(getPathLength(env.getParentEnvironment()), acc), 0); + if (log.isDebugEnabled()) { + log.debug("Query depth info: {}", depth); + } + if (depth > maxDepth) { + QueryDepthInfo queryDepthInfo = QueryDepthInfo.newQueryDepthInfo() + .depth(depth) + .build(); + boolean throwAbortException = maxQueryDepthExceededFunction.apply(queryDepthInfo); + if (throwAbortException) { + throw mkAbortException(depth, maxDepth); } - }); + } + return noOp(); } /** @@ -83,12 +79,12 @@ protected AbortExecutionException mkAbortException(int depth, int maxDepth) { return new AbortExecutionException("maximum query depth exceeded " + depth + " > " + maxDepth); } - QueryTraverser newQueryTraverser(InstrumentationValidationParameters parameters) { + QueryTraverser newQueryTraverser(ExecutionContext executionContext) { return QueryTraverser.newQueryTraverser() - .schema(parameters.getSchema()) - .document(parameters.getDocument()) - .operationName(parameters.getOperation()) - .variables(parameters.getVariables()) + .schema(executionContext.getGraphQLSchema()) + .document(executionContext.getDocument()) + .operationName(executionContext.getExecutionInput().getOperationName()) + .variables(executionContext.getVariables()) .build(); } diff --git a/src/main/java/graphql/analysis/QueryComplexityInfo.java b/src/main/java/graphql/analysis/QueryComplexityInfo.java index 0d4e213bd8..f4e86d0be1 100644 --- a/src/main/java/graphql/analysis/QueryComplexityInfo.java +++ b/src/main/java/graphql/analysis/QueryComplexityInfo.java @@ -1,6 +1,7 @@ package graphql.analysis; import graphql.PublicApi; +import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters; import graphql.execution.instrumentation.parameters.InstrumentationValidationParameters; /** @@ -10,11 +11,13 @@ public class QueryComplexityInfo { private final int complexity; - private InstrumentationValidationParameters instrumentationValidationParameters; + private final InstrumentationValidationParameters instrumentationValidationParameters; + private final InstrumentationExecuteOperationParameters instrumentationExecuteOperationParameters; - private QueryComplexityInfo(int complexity, InstrumentationValidationParameters parameters) { - this.complexity = complexity; - this.instrumentationValidationParameters = parameters; + private QueryComplexityInfo(Builder builder) { + this.complexity = builder.complexity; + this.instrumentationValidationParameters = builder.instrumentationValidationParameters; + this.instrumentationExecuteOperationParameters = builder.instrumentationExecuteOperationParameters; } /** @@ -35,6 +38,15 @@ public InstrumentationValidationParameters getInstrumentationValidationParameter return instrumentationValidationParameters; } + /** + * This returns the instrumentation execute operation parameters. + * + * @return the instrumentation execute operation parameters. + */ + public InstrumentationExecuteOperationParameters getInstrumentationExecuteOperationParameters() { + return instrumentationExecuteOperationParameters; + } + @Override public String toString() { return "QueryComplexityInfo{" + @@ -54,6 +66,7 @@ public static class Builder { private int complexity; private InstrumentationValidationParameters instrumentationValidationParameters; + private InstrumentationExecuteOperationParameters instrumentationExecuteOperationParameters; private Builder() { } @@ -62,6 +75,7 @@ private Builder() { * The query complexity. * * @param complexity the query complexity + * * @return this builder */ public Builder complexity(int complexity) { @@ -73,6 +87,7 @@ public Builder complexity(int complexity) { * The instrumentation validation parameters. * * @param parameters the instrumentation validation parameters. + * * @return this builder */ public Builder instrumentationValidationParameters(InstrumentationValidationParameters parameters) { @@ -80,11 +95,16 @@ public Builder instrumentationValidationParameters(InstrumentationValidationPara return this; } + public Builder instrumentationExecuteOperationParameters(InstrumentationExecuteOperationParameters instrumentationExecuteOperationParameters) { + this.instrumentationExecuteOperationParameters = instrumentationExecuteOperationParameters; + return this; + } + /** * @return a built {@link QueryComplexityInfo} object */ public QueryComplexityInfo build() { - return new QueryComplexityInfo(complexity, instrumentationValidationParameters); + return new QueryComplexityInfo(this); } } } diff --git a/src/test/groovy/graphql/analysis/MaxQueryComplexityInstrumentationTest.groovy b/src/test/groovy/graphql/analysis/MaxQueryComplexityInstrumentationTest.groovy index 704a1c9c6d..3670b4fe1d 100644 --- a/src/test/groovy/graphql/analysis/MaxQueryComplexityInstrumentationTest.groovy +++ b/src/test/groovy/graphql/analysis/MaxQueryComplexityInstrumentationTest.groovy @@ -3,12 +3,14 @@ package graphql.analysis import graphql.ExecutionInput import graphql.TestUtil import graphql.execution.AbortExecutionException -import graphql.execution.instrumentation.InstrumentationContext +import graphql.execution.ExecutionContext +import graphql.execution.ExecutionContextBuilder +import graphql.execution.ExecutionId +import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters import graphql.execution.instrumentation.parameters.InstrumentationValidationParameters import graphql.language.Document import graphql.parser.Parser -import graphql.validation.ValidationError -import graphql.validation.ValidationErrorType +import graphql.schema.GraphQLSchema import spock.lang.Specification import java.util.function.Function @@ -20,61 +22,6 @@ class MaxQueryComplexityInstrumentationTest extends Specification { parser.parseDocument(query) } - def "doesn't do anything if validation errors occur"() { - given: - def schema = TestUtil.schema(""" - type Query{ - bar: String - } - """) - def query = createQuery(""" - { bar { thisIsWrong } } - """) - def queryTraversal = Mock(QueryTraverser) - MaxQueryComplexityInstrumentation maxQueryComplexityInstrumentation = new MaxQueryComplexityInstrumentation(6) { - - @Override - QueryTraverser newQueryTraverser(InstrumentationValidationParameters parameters) { - return queryTraversal - } - } - ExecutionInput executionInput = Mock(ExecutionInput) - InstrumentationValidationParameters validationParameters = new InstrumentationValidationParameters(executionInput, query, schema, null) - InstrumentationContext instrumentationContext = maxQueryComplexityInstrumentation.beginValidation(validationParameters) - when: - instrumentationContext.onCompleted([new ValidationError(ValidationErrorType.SubSelectionNotAllowed)], null) - then: - 0 * queryTraversal._(_) - - } - - def "doesn't do anything if exception was thrown"() { - given: - def schema = TestUtil.schema(""" - type Query{ - bar: String - } - """) - def query = createQuery(""" - { bar { thisIsWrong } } - """) - def queryTraversal = Mock(QueryTraverser) - MaxQueryComplexityInstrumentation maxQueryComplexityInstrumentation = new MaxQueryComplexityInstrumentation(6) { - - @Override - QueryTraverser newQueryTraverser(InstrumentationValidationParameters parameters) { - return queryTraversal - } - } - ExecutionInput executionInput = Mock(ExecutionInput) - InstrumentationValidationParameters validationParameters = new InstrumentationValidationParameters(executionInput, query, schema, null) - InstrumentationContext instrumentationContext = maxQueryComplexityInstrumentation.beginValidation(validationParameters) - when: - instrumentationContext.onCompleted(null, new RuntimeException()) - then: - 0 * queryTraversal._(_) - - } def "default complexity calculator"() { given: @@ -93,16 +40,16 @@ class MaxQueryComplexityInstrumentationTest extends Specification { """) MaxQueryComplexityInstrumentation queryComplexityInstrumentation = new MaxQueryComplexityInstrumentation(10) ExecutionInput executionInput = Mock(ExecutionInput) - InstrumentationValidationParameters validationParameters = new InstrumentationValidationParameters(executionInput, query, schema, null) - InstrumentationContext instrumentationContext = queryComplexityInstrumentation.beginValidation(validationParameters) + InstrumentationExecuteOperationParameters executeOperationParameters = createExecuteOperationParameters(queryComplexityInstrumentation, executionInput, query, schema) when: - instrumentationContext.onCompleted(null, null) + queryComplexityInstrumentation.beginExecuteOperation(executeOperationParameters) then: def e = thrown(AbortExecutionException) e.message == "maximum query complexity exceeded 11 > 10" } + def "complexity calculator works with __typename field with score 0"() { given: def schema = TestUtil.schema(""" @@ -115,10 +62,9 @@ class MaxQueryComplexityInstrumentationTest extends Specification { """) MaxQueryComplexityInstrumentation queryComplexityInstrumentation = new MaxQueryComplexityInstrumentation(1) ExecutionInput executionInput = Mock(ExecutionInput) - InstrumentationValidationParameters validationParameters = new InstrumentationValidationParameters(executionInput, query, schema, null) - InstrumentationContext instrumentationContext = queryComplexityInstrumentation.beginValidation(validationParameters) + InstrumentationExecuteOperationParameters executeOperationParameters = createExecuteOperationParameters(queryComplexityInstrumentation, executionInput, query, schema) when: - instrumentationContext.onCompleted(null, null) + queryComplexityInstrumentation.beginExecuteOperation(executeOperationParameters) then: def e = thrown(AbortExecutionException) e.message == "maximum query complexity exceeded 2 > 1" @@ -143,10 +89,9 @@ class MaxQueryComplexityInstrumentationTest extends Specification { def calculator = Mock(FieldComplexityCalculator) MaxQueryComplexityInstrumentation queryComplexityInstrumentation = new MaxQueryComplexityInstrumentation(5, calculator) ExecutionInput executionInput = Mock(ExecutionInput) - InstrumentationValidationParameters validationParameters = new InstrumentationValidationParameters(executionInput, query, schema, null) - InstrumentationContext instrumentationContext = queryComplexityInstrumentation.beginValidation(validationParameters) + InstrumentationExecuteOperationParameters executeOperationParameters = createExecuteOperationParameters(queryComplexityInstrumentation, executionInput, query, schema) when: - instrumentationContext.onCompleted(null, null) + queryComplexityInstrumentation.beginExecuteOperation(executeOperationParameters) then: 1 * calculator.calculate({ FieldComplexityEnvironment env -> env.field.name == "scalar" }, 0) >> 10 @@ -171,22 +116,23 @@ class MaxQueryComplexityInstrumentationTest extends Specification { def query = createQuery(""" {f2: foo {scalar foo{scalar}} f1: foo { foo {foo {foo {foo{foo{scalar}}}}}} } """) - Boolean test = false + Boolean customFunctionCalled = false Function maxQueryComplexityExceededFunction = new Function() { @Override Boolean apply(final QueryComplexityInfo queryComplexityInfo) { - test = true + assert queryComplexityInfo.instrumentationExecuteOperationParameters != null + assert queryComplexityInfo.instrumentationValidationParameters != null + customFunctionCalled = true return false } } MaxQueryComplexityInstrumentation queryComplexityInstrumentation = new MaxQueryComplexityInstrumentation(10, maxQueryComplexityExceededFunction) ExecutionInput executionInput = Mock(ExecutionInput) - InstrumentationValidationParameters validationParameters = new InstrumentationValidationParameters(executionInput, query, schema, null) - InstrumentationContext instrumentationContext = queryComplexityInstrumentation.beginValidation(validationParameters) + InstrumentationExecuteOperationParameters executeOperationParameters = createExecuteOperationParameters(queryComplexityInstrumentation, executionInput, query, schema) when: - instrumentationContext.onCompleted(null, null) + queryComplexityInstrumentation.beginExecuteOperation(executeOperationParameters) then: - test == true + customFunctionCalled notThrown(Exception) } @@ -205,14 +151,29 @@ class MaxQueryComplexityInstrumentationTest extends Specification { MaxQueryComplexityInstrumentation queryComplexityInstrumentation = new MaxQueryComplexityInstrumentation(0) ExecutionInput executionInput = Mock(ExecutionInput) - InstrumentationValidationParameters validationParameters = new InstrumentationValidationParameters(executionInput, query, schema, null) - InstrumentationContext instrumentationContext = queryComplexityInstrumentation.beginValidation(validationParameters) + InstrumentationExecuteOperationParameters executeOperationParameters = createExecuteOperationParameters(queryComplexityInstrumentation, executionInput, query, schema) when: - instrumentationContext.onCompleted(null, null) + queryComplexityInstrumentation.beginExecuteOperation(executeOperationParameters) then: def e = thrown(AbortExecutionException) e.message == "maximum query complexity exceeded 1 > 0" } + + private InstrumentationExecuteOperationParameters createExecuteOperationParameters(MaxQueryComplexityInstrumentation queryComplexityInstrumentation, ExecutionInput executionInput, Document query, GraphQLSchema schema) { + // we need to run N steps to create instrumentation state + def instrumentationState = queryComplexityInstrumentation.createState(null) + def validationParameters = new InstrumentationValidationParameters(executionInput, query, schema, instrumentationState) + queryComplexityInstrumentation.beginValidation(validationParameters) + def executionContext = executionCtx(executionInput, query, schema) + def executeOperationParameters = new InstrumentationExecuteOperationParameters(executionContext).withNewState(instrumentationState) + executeOperationParameters + } + + private ExecutionContext executionCtx(ExecutionInput executionInput, Document query, GraphQLSchema schema) { + ExecutionContextBuilder.newExecutionContextBuilder() + .executionInput(executionInput).document(query).graphQLSchema(schema).executionId(ExecutionId.generate()) + .build() + } } diff --git a/src/test/groovy/graphql/analysis/MaxQueryDepthInstrumentationTest.groovy b/src/test/groovy/graphql/analysis/MaxQueryDepthInstrumentationTest.groovy index 4f2301f470..11ac79bc7f 100644 --- a/src/test/groovy/graphql/analysis/MaxQueryDepthInstrumentationTest.groovy +++ b/src/test/groovy/graphql/analysis/MaxQueryDepthInstrumentationTest.groovy @@ -1,16 +1,16 @@ package graphql.analysis import graphql.ExecutionInput -import graphql.ExecutionResult import graphql.GraphQL import graphql.TestUtil import graphql.execution.AbortExecutionException -import graphql.execution.instrumentation.InstrumentationContext -import graphql.execution.instrumentation.parameters.InstrumentationValidationParameters +import graphql.execution.ExecutionContext +import graphql.execution.ExecutionContextBuilder +import graphql.execution.ExecutionId +import graphql.execution.instrumentation.parameters.InstrumentationExecuteOperationParameters import graphql.language.Document import graphql.parser.Parser -import graphql.validation.ValidationError -import graphql.validation.ValidationErrorType +import graphql.schema.GraphQLSchema import spock.lang.Specification import java.util.function.Function @@ -23,63 +23,7 @@ class MaxQueryDepthInstrumentationTest extends Specification { } - def "doesn't do anything if validation errors occur"() { - given: - def schema = TestUtil.schema(""" - type Query{ - bar: String - } - """) - def query = createQuery(""" - { bar { thisIsWrong } } - """) - def queryTraversal = Mock(QueryTraverser) - MaxQueryDepthInstrumentation maximumQueryDepthInstrumentation = new MaxQueryDepthInstrumentation(6) { - - @Override - QueryTraverser newQueryTraverser(InstrumentationValidationParameters parameters) { - return queryTraversal - } - } - ExecutionInput executionInput = Mock(ExecutionInput) - InstrumentationValidationParameters validationParameters = new InstrumentationValidationParameters(executionInput, query, schema, null) - InstrumentationContext instrumentationContext = maximumQueryDepthInstrumentation.beginValidation(validationParameters) - when: - instrumentationContext.onCompleted([new ValidationError(ValidationErrorType.SubSelectionNotAllowed)], null) - then: - 0 * queryTraversal._(_) - - } - - def "doesn't do anything if exception was thrown"() { - given: - def schema = TestUtil.schema(""" - type Query{ - bar: String - } - """) - def query = createQuery(""" - { bar { thisIsWrong } } - """) - def queryTraversal = Mock(QueryTraverser) - MaxQueryDepthInstrumentation maximumQueryDepthInstrumentation = new MaxQueryDepthInstrumentation(6) { - - @Override - QueryTraverser newQueryTraverser(InstrumentationValidationParameters parameters) { - return queryTraversal - } - } - ExecutionInput executionInput = Mock(ExecutionInput) - InstrumentationValidationParameters validationParameters = new InstrumentationValidationParameters(executionInput, query, schema, null) - InstrumentationContext instrumentationContext = maximumQueryDepthInstrumentation.beginValidation(validationParameters) - when: - instrumentationContext.onCompleted(null, new RuntimeException()) - then: - 0 * queryTraversal._(_) - - } - - def "throws exception"() { + def "throws exception if too deep"() { given: def schema = TestUtil.schema(""" type Query{ @@ -96,16 +40,16 @@ class MaxQueryDepthInstrumentationTest extends Specification { """) MaxQueryDepthInstrumentation maximumQueryDepthInstrumentation = new MaxQueryDepthInstrumentation(6) ExecutionInput executionInput = Mock(ExecutionInput) - InstrumentationValidationParameters validationParameters = new InstrumentationValidationParameters(executionInput, query, schema, null) - InstrumentationContext instrumentationContext = maximumQueryDepthInstrumentation.beginValidation(validationParameters) + def executionContext = executionCtx(executionInput, query, schema) + def executeOperationParameters = new InstrumentationExecuteOperationParameters(executionContext) when: - instrumentationContext.onCompleted(null, null) + maximumQueryDepthInstrumentation.beginExecuteOperation(executeOperationParameters) then: def e = thrown(AbortExecutionException) e.message.contains("maximum query depth exceeded 7 > 6") } - def "doesn't throw exception"() { + def "doesn't throw exception if not deep enough"() { given: def schema = TestUtil.schema(""" type Query{ @@ -122,10 +66,10 @@ class MaxQueryDepthInstrumentationTest extends Specification { """) MaxQueryDepthInstrumentation maximumQueryDepthInstrumentation = new MaxQueryDepthInstrumentation(7) ExecutionInput executionInput = Mock(ExecutionInput) - InstrumentationValidationParameters validationParameters = new InstrumentationValidationParameters(executionInput, query, schema, null) - InstrumentationContext instrumentationContext = maximumQueryDepthInstrumentation.beginValidation(validationParameters) + def executionContext = executionCtx(executionInput, query, schema) + def executeOperationParameters = new InstrumentationExecuteOperationParameters(executionContext) when: - instrumentationContext.onCompleted(null, null) + maximumQueryDepthInstrumentation.beginExecuteOperation(executeOperationParameters) then: notThrown(Exception) } @@ -145,26 +89,26 @@ class MaxQueryDepthInstrumentationTest extends Specification { def query = createQuery(""" {f1: foo {foo {foo {scalar}}} f2: foo { foo {foo {foo {foo{foo{scalar}}}}}} } """) - Boolean test = false + Boolean calledFunction = false Function maxQueryDepthExceededFunction = new Function() { @Override Boolean apply(final QueryDepthInfo queryDepthInfo) { - test = true + calledFunction = true return false } } MaxQueryDepthInstrumentation maximumQueryDepthInstrumentation = new MaxQueryDepthInstrumentation(6, maxQueryDepthExceededFunction) ExecutionInput executionInput = Mock(ExecutionInput) - InstrumentationValidationParameters validationParameters = new InstrumentationValidationParameters(executionInput, query, schema, null) - InstrumentationContext instrumentationContext = maximumQueryDepthInstrumentation.beginValidation(validationParameters) + def executionContext = executionCtx(executionInput, query, schema) + def executeOperationParameters = new InstrumentationExecuteOperationParameters(executionContext) when: - instrumentationContext.onCompleted(null, null) + maximumQueryDepthInstrumentation.beginExecuteOperation(executeOperationParameters) then: - test == true + calledFunction notThrown(Exception) } - def "coercing null variables that are marked as non nullable"() { + def "coercing null variables that are marked as non nullable wont blow up early"() { given: def schema = TestUtil.schema(""" @@ -186,7 +130,11 @@ class MaxQueryDepthInstrumentationTest extends Specification { def er = graphQL.execute(executionInput) then: - ! er.errors.isEmpty() + !er.errors.isEmpty() + } + private ExecutionContext executionCtx(ExecutionInput executionInput, Document query, GraphQLSchema schema) { + ExecutionContextBuilder.newExecutionContextBuilder() + .executionInput(executionInput).document(query).graphQLSchema(schema).executionId(ExecutionId.generate()).build() } } From 46d4e53addbee9be66415b74c458d42380bdbefd Mon Sep 17 00:00:00 2001 From: Brad Baker Date: Fri, 29 Apr 2022 11:10:21 +1000 Subject: [PATCH 3/3] use atomic ref --- .../analysis/MaxQueryComplexityInstrumentation.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/main/java/graphql/analysis/MaxQueryComplexityInstrumentation.java b/src/main/java/graphql/analysis/MaxQueryComplexityInstrumentation.java index cba57bea28..66523d70e8 100644 --- a/src/main/java/graphql/analysis/MaxQueryComplexityInstrumentation.java +++ b/src/main/java/graphql/analysis/MaxQueryComplexityInstrumentation.java @@ -17,6 +17,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import static graphql.Assert.assertNotNull; @@ -90,7 +91,7 @@ public InstrumentationState createState(InstrumentationCreateStateParameters par public InstrumentationContext> beginValidation(InstrumentationValidationParameters parameters) { State state = parameters.getInstrumentationState(); // for API backwards compatibility reasons we capture the validation parameters, so we can put them into QueryComplexityInfo - state.instrumentationValidationParameters = parameters; + state.instrumentationValidationParameters.set(parameters); return noOp(); } @@ -118,7 +119,7 @@ public void visitField(QueryVisitorFieldEnvironment env) { if (totalComplexity > maxComplexity) { QueryComplexityInfo queryComplexityInfo = QueryComplexityInfo.newQueryComplexityInfo() .complexity(totalComplexity) - .instrumentationValidationParameters(state.instrumentationValidationParameters) + .instrumentationValidationParameters(state.instrumentationValidationParameters.get()) .instrumentationExecuteOperationParameters(instrumentationExecuteOperationParameters) .build(); boolean throwAbortException = maxQueryComplexityExceededFunction.apply(queryComplexityInfo); @@ -173,7 +174,7 @@ private FieldComplexityEnvironment convertEnv(QueryVisitorFieldEnvironment query } private static class State implements InstrumentationState { - InstrumentationValidationParameters instrumentationValidationParameters; + AtomicReference instrumentationValidationParameters = new AtomicReference<>(); } }