diff --git a/src/main/java/graphql/analysis/NodeVisitorWithTypeTracking.java b/src/main/java/graphql/analysis/NodeVisitorWithTypeTracking.java index 7b65c8e9e..972a6f8e9 100644 --- a/src/main/java/graphql/analysis/NodeVisitorWithTypeTracking.java +++ b/src/main/java/graphql/analysis/NodeVisitorWithTypeTracking.java @@ -3,8 +3,8 @@ import graphql.GraphQLContext; import graphql.Internal; import graphql.execution.CoercedVariables; -import graphql.execution.ConditionalNodes; import graphql.execution.ValuesResolver; +import graphql.execution.conditional.ConditionalNodes; import graphql.introspection.Introspection; import graphql.language.Argument; import graphql.language.Directive; @@ -68,7 +68,9 @@ public TraversalControl visitDirective(Directive node, TraverserContext co @Override public TraversalControl visitInlineFragment(InlineFragment inlineFragment, TraverserContext context) { - if (!conditionalNodes.shouldInclude(variables, inlineFragment.getDirectives())) { + QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class); + GraphQLContext graphQLContext = parentEnv.getGraphQLContext(); + if (!conditionalNodes.shouldInclude(inlineFragment, variables, null, graphQLContext)) { return TraversalControl.ABORT; } @@ -82,7 +84,6 @@ public TraversalControl visitInlineFragment(InlineFragment inlineFragment, Trave preOrderCallback.visitInlineFragment(inlineFragmentEnvironment); // inline fragments are allowed not have type conditions, if so the parent type counts - QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class); GraphQLCompositeType fragmentCondition; if (inlineFragment.getTypeCondition() != null) { @@ -92,17 +93,19 @@ public TraversalControl visitInlineFragment(InlineFragment inlineFragment, Trave fragmentCondition = parentEnv.getUnwrappedOutputType(); } // for unions we only have other fragments inside - context.setVar(QueryTraversalContext.class, new QueryTraversalContext(fragmentCondition, parentEnv.getEnvironment(), inlineFragment)); + context.setVar(QueryTraversalContext.class, new QueryTraversalContext(fragmentCondition, parentEnv.getEnvironment(), inlineFragment, graphQLContext)); return TraversalControl.CONTINUE; } @Override - public TraversalControl visitFragmentDefinition(FragmentDefinition node, TraverserContext context) { - if (!conditionalNodes.shouldInclude(variables, node.getDirectives())) { + public TraversalControl visitFragmentDefinition(FragmentDefinition fragmentDefinition, TraverserContext context) { + QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class); + GraphQLContext graphQLContext = parentEnv.getGraphQLContext(); + if (!conditionalNodes.shouldInclude(fragmentDefinition, variables, null, graphQLContext)) { return TraversalControl.ABORT; } - QueryVisitorFragmentDefinitionEnvironment fragmentEnvironment = new QueryVisitorFragmentDefinitionEnvironmentImpl(node, context, schema); + QueryVisitorFragmentDefinitionEnvironment fragmentEnvironment = new QueryVisitorFragmentDefinitionEnvironmentImpl(fragmentDefinition, context, schema); if (context.getPhase() == LEAVE) { postOrderCallback.visitFragmentDefinition(fragmentEnvironment); @@ -110,20 +113,21 @@ public TraversalControl visitFragmentDefinition(FragmentDefinition node, Travers } preOrderCallback.visitFragmentDefinition(fragmentEnvironment); - QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class); - GraphQLCompositeType typeCondition = (GraphQLCompositeType) schema.getType(node.getTypeCondition().getName()); - context.setVar(QueryTraversalContext.class, new QueryTraversalContext(typeCondition, parentEnv.getEnvironment(), node)); + GraphQLCompositeType typeCondition = (GraphQLCompositeType) schema.getType(fragmentDefinition.getTypeCondition().getName()); + context.setVar(QueryTraversalContext.class, new QueryTraversalContext(typeCondition, parentEnv.getEnvironment(), fragmentDefinition, graphQLContext)); return TraversalControl.CONTINUE; } @Override public TraversalControl visitFragmentSpread(FragmentSpread fragmentSpread, TraverserContext context) { - if (!conditionalNodes.shouldInclude(variables, fragmentSpread.getDirectives())) { + QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class); + GraphQLContext graphQLContext = parentEnv.getGraphQLContext(); + if (!conditionalNodes.shouldInclude(fragmentSpread, variables, null, graphQLContext)) { return TraversalControl.ABORT; } FragmentDefinition fragmentDefinition = fragmentsByName.get(fragmentSpread.getName()); - if (!conditionalNodes.shouldInclude(variables, fragmentDefinition.getDirectives())) { + if (!conditionalNodes.shouldInclude(fragmentDefinition, variables, null, graphQLContext)) { return TraversalControl.ABORT; } @@ -135,19 +139,19 @@ public TraversalControl visitFragmentSpread(FragmentSpread fragmentSpread, Trave preOrderCallback.visitFragmentSpread(fragmentSpreadEnvironment); - QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class); GraphQLCompositeType typeCondition = (GraphQLCompositeType) schema.getType(fragmentDefinition.getTypeCondition().getName()); assertNotNull(typeCondition, () -> format("Invalid type condition '%s' in fragment '%s'", fragmentDefinition.getTypeCondition().getName(), fragmentDefinition.getName())); - context.setVar(QueryTraversalContext.class, new QueryTraversalContext(typeCondition, parentEnv.getEnvironment(), fragmentDefinition)); + context.setVar(QueryTraversalContext.class, new QueryTraversalContext(typeCondition, parentEnv.getEnvironment(), fragmentDefinition, graphQLContext)); return TraversalControl.CONTINUE; } @Override public TraversalControl visitField(Field field, TraverserContext context) { QueryTraversalContext parentEnv = context.getVarFromParents(QueryTraversalContext.class); + GraphQLContext graphQLContext = parentEnv.getGraphQLContext(); GraphQLFieldDefinition fieldDefinition = Introspection.getFieldDef(schema, (GraphQLCompositeType) unwrapAll(parentEnv.getOutputType()), field.getName()); boolean isTypeNameIntrospectionField = fieldDefinition == schema.getIntrospectionTypenameFieldDefinition(); @@ -174,7 +178,7 @@ public TraversalControl visitField(Field field, TraverserContext context) return TraversalControl.CONTINUE; } - if (!conditionalNodes.shouldInclude(variables, field.getDirectives())) { + if (!conditionalNodes.shouldInclude(field, variables, null, graphQLContext)) { return TraversalControl.ABORT; } @@ -182,8 +186,8 @@ public TraversalControl visitField(Field field, TraverserContext context) GraphQLUnmodifiedType unmodifiedType = unwrapAll(fieldDefinition.getType()); QueryTraversalContext fieldEnv = (unmodifiedType instanceof GraphQLCompositeType) - ? new QueryTraversalContext(fieldDefinition.getType(), environment, field) - : new QueryTraversalContext(null, environment, field);// Terminal (scalar) node, EMPTY FRAME + ? new QueryTraversalContext(fieldDefinition.getType(), environment, field, graphQLContext) + : new QueryTraversalContext(null, environment, field, graphQLContext);// Terminal (scalar) node, EMPTY FRAME context.setVar(QueryTraversalContext.class, fieldEnv); diff --git a/src/main/java/graphql/analysis/QueryTransformer.java b/src/main/java/graphql/analysis/QueryTransformer.java index 35c840bb0..9c45902da 100644 --- a/src/main/java/graphql/analysis/QueryTransformer.java +++ b/src/main/java/graphql/analysis/QueryTransformer.java @@ -1,5 +1,6 @@ package graphql.analysis; +import graphql.GraphQLContext; import graphql.PublicApi; import graphql.language.FragmentDefinition; import graphql.language.Node; @@ -67,7 +68,7 @@ public Node transform(QueryVisitor queryVisitor) { NodeVisitorWithTypeTracking nodeVisitor = new NodeVisitorWithTypeTracking(queryVisitor, noOp, variables, schema, fragmentsByName); Map, Object> rootVars = new LinkedHashMap<>(); - rootVars.put(QueryTraversalContext.class, new QueryTraversalContext(rootParentType, null, null)); + rootVars.put(QueryTraversalContext.class, new QueryTraversalContext(rootParentType, null, null, GraphQLContext.getDefault())); TraverserVisitor nodeTraverserVisitor = new TraverserVisitor() { diff --git a/src/main/java/graphql/analysis/QueryTraversalContext.java b/src/main/java/graphql/analysis/QueryTraversalContext.java index de591141c..8fc02fd58 100644 --- a/src/main/java/graphql/analysis/QueryTraversalContext.java +++ b/src/main/java/graphql/analysis/QueryTraversalContext.java @@ -1,5 +1,6 @@ package graphql.analysis; +import graphql.GraphQLContext; import graphql.Internal; import graphql.language.SelectionSetContainer; import graphql.schema.GraphQLCompositeType; @@ -16,14 +17,17 @@ class QueryTraversalContext { // never used for scalars/enums, always a possibly wrapped composite type private final GraphQLOutputType outputType; private final QueryVisitorFieldEnvironment environment; - private final SelectionSetContainer selectionSetContainer; + private final SelectionSetContainer selectionSetContainer; + private final GraphQLContext graphQLContext; QueryTraversalContext(GraphQLOutputType outputType, QueryVisitorFieldEnvironment environment, - SelectionSetContainer selectionSetContainer) { + SelectionSetContainer selectionSetContainer, + GraphQLContext graphQLContext) { this.outputType = outputType; this.environment = environment; this.selectionSetContainer = selectionSetContainer; + this.graphQLContext = graphQLContext; } public GraphQLOutputType getOutputType() { @@ -34,13 +38,15 @@ public GraphQLCompositeType getUnwrappedOutputType() { return (GraphQLCompositeType) GraphQLTypeUtil.unwrapAll(outputType); } - public QueryVisitorFieldEnvironment getEnvironment() { return environment; } - public SelectionSetContainer getSelectionSetContainer() { - + public SelectionSetContainer getSelectionSetContainer() { return selectionSetContainer; } + + public GraphQLContext getGraphQLContext() { + return graphQLContext; + } } diff --git a/src/main/java/graphql/analysis/QueryTraverser.java b/src/main/java/graphql/analysis/QueryTraverser.java index 14d873f59..0ec067595 100644 --- a/src/main/java/graphql/analysis/QueryTraverser.java +++ b/src/main/java/graphql/analysis/QueryTraverser.java @@ -177,7 +177,7 @@ private List childrenOf(Node node) { private Object visitImpl(QueryVisitor visitFieldCallback, Boolean preOrder) { Map, Object> rootVars = new LinkedHashMap<>(); - rootVars.put(QueryTraversalContext.class, new QueryTraversalContext(rootParentType, null, null)); + rootVars.put(QueryTraversalContext.class, new QueryTraversalContext(rootParentType, null, null, GraphQLContext.getDefault())); QueryVisitor preOrderCallback; QueryVisitor postOrderCallback; diff --git a/src/main/java/graphql/execution/ConditionalNodes.java b/src/main/java/graphql/execution/ConditionalNodes.java deleted file mode 100644 index a9e3ca733..000000000 --- a/src/main/java/graphql/execution/ConditionalNodes.java +++ /dev/null @@ -1,43 +0,0 @@ -package graphql.execution; - -import graphql.Assert; -import graphql.GraphQLContext; -import graphql.Internal; -import graphql.language.Directive; -import graphql.language.NodeUtil; - -import java.util.List; -import java.util.Locale; -import java.util.Map; - -import static graphql.Directives.IncludeDirective; -import static graphql.Directives.SkipDirective; - -@Internal -public class ConditionalNodes { - - public boolean shouldInclude(Map variables, List directives) { - // shortcut on no directives - if (directives.isEmpty()) { - return true; - } - boolean skip = getDirectiveResult(variables, directives, SkipDirective.getName(), false); - if (skip) { - return false; - } - - return getDirectiveResult(variables, directives, IncludeDirective.getName(), true); - } - - private boolean getDirectiveResult(Map variables, List directives, String directiveName, boolean defaultValue) { - Directive foundDirective = NodeUtil.findNodeByName(directives, directiveName); - if (foundDirective != null) { - Map argumentValues = ValuesResolver.getArgumentValues(SkipDirective.getArguments(), foundDirective.getArguments(), CoercedVariables.of(variables), GraphQLContext.getDefault(), Locale.getDefault()); - Object flag = argumentValues.get("if"); - Assert.assertTrue(flag instanceof Boolean, () -> String.format("The '%s' directive MUST have a value for the 'if' argument", directiveName)); - return (Boolean) flag; - } - return defaultValue; - } - -} diff --git a/src/main/java/graphql/execution/Execution.java b/src/main/java/graphql/execution/Execution.java index 401258ded..d352e4c81 100644 --- a/src/main/java/graphql/execution/Execution.java +++ b/src/main/java/graphql/execution/Execution.java @@ -134,7 +134,8 @@ private CompletableFuture executeOperation(ExecutionContext exe .schema(executionContext.getGraphQLSchema()) .objectType(operationRootType) .fragments(executionContext.getFragmentsByName()) - .variables(executionContext.getVariables()) + .variables(executionContext.getCoercedVariables().toMap()) + .graphQLContext(graphQLContext) .build(); MergedSelectionSet fields = fieldCollector.collectFields(collectorParameters, operationDefinition.getSelectionSet()); diff --git a/src/main/java/graphql/execution/FieldCollector.java b/src/main/java/graphql/execution/FieldCollector.java index d32218bfc..a6f1310a8 100644 --- a/src/main/java/graphql/execution/FieldCollector.java +++ b/src/main/java/graphql/execution/FieldCollector.java @@ -2,6 +2,7 @@ import graphql.Internal; +import graphql.execution.conditional.ConditionalNodes; import graphql.language.Field; import graphql.language.FragmentDefinition; import graphql.language.FragmentSpread; @@ -76,13 +77,19 @@ private void collectFragmentSpread(FieldCollectorParameters parameters, Set visitedFragments, Map fields, InlineFragment inlineFragment) { - if (!conditionalNodes.shouldInclude(parameters.getVariables(), inlineFragment.getDirectives()) || + if (!conditionalNodes.shouldInclude(inlineFragment, + parameters.getVariables(), + parameters.getGraphQLSchema(), + parameters.getGraphQLContext()) || !doesFragmentConditionMatch(parameters, inlineFragment)) { return; } @@ -100,7 +110,10 @@ private void collectInlineFragment(FieldCollectorParameters parameters, Set fields, Field field) { - if (!conditionalNodes.shouldInclude(parameters.getVariables(), field.getDirectives())) { + if (!conditionalNodes.shouldInclude(field, + parameters.getVariables(), + parameters.getGraphQLSchema(), + parameters.getGraphQLContext())) { return; } String name = field.getResultKey(); diff --git a/src/main/java/graphql/execution/FieldCollectorParameters.java b/src/main/java/graphql/execution/FieldCollectorParameters.java index b1878ff2a..c0a23404a 100644 --- a/src/main/java/graphql/execution/FieldCollectorParameters.java +++ b/src/main/java/graphql/execution/FieldCollectorParameters.java @@ -1,6 +1,7 @@ package graphql.execution; import graphql.Assert; +import graphql.GraphQLContext; import graphql.Internal; import graphql.language.FragmentDefinition; import graphql.schema.GraphQLObjectType; @@ -17,6 +18,7 @@ public class FieldCollectorParameters { private final Map fragmentsByName; private final Map variables; private final GraphQLObjectType objectType; + private final GraphQLContext graphQLContext; public GraphQLSchema getGraphQLSchema() { return graphQLSchema; @@ -34,11 +36,16 @@ public GraphQLObjectType getObjectType() { return objectType; } - private FieldCollectorParameters(GraphQLSchema graphQLSchema, Map variables, Map fragmentsByName, GraphQLObjectType objectType) { - this.fragmentsByName = fragmentsByName; - this.graphQLSchema = graphQLSchema; - this.variables = variables; - this.objectType = objectType; + public GraphQLContext getGraphQLContext() { + return graphQLContext; + } + + private FieldCollectorParameters(Builder builder) { + this.fragmentsByName = builder.fragmentsByName; + this.graphQLSchema = builder.graphQLSchema; + this.variables = builder.variables; + this.objectType = builder.objectType; + this.graphQLContext = builder.graphQLContext; } public static Builder newParameters() { @@ -50,6 +57,7 @@ public static class Builder { private Map fragmentsByName; private Map variables; private GraphQLObjectType objectType; + private GraphQLContext graphQLContext = GraphQLContext.getDefault(); /** * @see FieldCollectorParameters#newParameters() @@ -68,6 +76,11 @@ public Builder objectType(GraphQLObjectType objectType) { return this; } + public Builder graphQLContext(GraphQLContext graphQLContext) { + this.graphQLContext = graphQLContext; + return this; + } + public Builder fragments(Map fragmentsByName) { this.fragmentsByName = fragmentsByName; return this; @@ -80,7 +93,7 @@ public Builder variables(Map variables) { public FieldCollectorParameters build() { Assert.assertNotNull(graphQLSchema, () -> "You must provide a schema"); - return new FieldCollectorParameters(graphQLSchema, variables, fragmentsByName, objectType); + return new FieldCollectorParameters(this); } } diff --git a/src/main/java/graphql/execution/conditional/ConditionalNodeDecision.java b/src/main/java/graphql/execution/conditional/ConditionalNodeDecision.java new file mode 100644 index 000000000..69afc6bbc --- /dev/null +++ b/src/main/java/graphql/execution/conditional/ConditionalNodeDecision.java @@ -0,0 +1,23 @@ +package graphql.execution.conditional; + +import graphql.ExperimentalApi; + +/** + * This callback interface allows custom implementations to decide if a field is included in a query or not. + *

+ * The default `@skip / @include` is built in, but you can create your own implementations to allow you to make + * decisions on whether fields are considered part of a query. + */ +@ExperimentalApi +public interface ConditionalNodeDecision { + + /** + * This is called to decide if a {@link graphql.language.Node} should be included or not + * + * @param decisionEnv ghe environment you can use to make the decision + * + * @return true if the node should be included or false if it should be excluded + */ + boolean shouldInclude(ConditionalNodeDecisionEnvironment decisionEnv); +} + diff --git a/src/main/java/graphql/execution/conditional/ConditionalNodeDecisionEnvironment.java b/src/main/java/graphql/execution/conditional/ConditionalNodeDecisionEnvironment.java new file mode 100644 index 000000000..e0e116a1b --- /dev/null +++ b/src/main/java/graphql/execution/conditional/ConditionalNodeDecisionEnvironment.java @@ -0,0 +1,48 @@ +package graphql.execution.conditional; + +import graphql.GraphQLContext; +import graphql.execution.CoercedVariables; +import graphql.language.Directive; +import graphql.language.DirectivesContainer; +import graphql.schema.GraphQLSchema; +import org.jetbrains.annotations.Nullable; + +import java.util.List; + +/** + * The parameters given to a {@link ConditionalNodeDecision} + */ +public interface ConditionalNodeDecisionEnvironment { + + /** + * This is an AST {@link graphql.language.Node} that has directives on it. + * {@link graphql.language.Field}, @{@link graphql.language.FragmentSpread} and + * {@link graphql.language.InlineFragment} are examples of nodes + * that can be conditionally included. + * + * @return the AST element in question + */ + DirectivesContainer getDirectivesContainer(); + + /** + * @return the list of directives associated with the {@link #getDirectivesContainer()} + */ + default List getDirectives() { + return getDirectivesContainer().getDirectives(); + } + + /** + * @return a map of the current variables + */ + CoercedVariables getVariables(); + + /** + * @return the {@link GraphQLSchema} in question - this can be null for certain call paths + */ + @Nullable GraphQLSchema getGraphQlSchema(); + + /** + * @return a graphql context + */ + GraphQLContext getGraphQLContext(); +} diff --git a/src/main/java/graphql/execution/conditional/ConditionalNodes.java b/src/main/java/graphql/execution/conditional/ConditionalNodes.java new file mode 100644 index 000000000..9c90deead --- /dev/null +++ b/src/main/java/graphql/execution/conditional/ConditionalNodes.java @@ -0,0 +1,102 @@ +package graphql.execution.conditional; + +import graphql.Assert; +import graphql.GraphQLContext; +import graphql.Internal; +import graphql.execution.CoercedVariables; +import graphql.execution.ValuesResolver; +import graphql.language.Directive; +import graphql.language.DirectivesContainer; +import graphql.language.NodeUtil; +import graphql.schema.GraphQLSchema; + +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import static graphql.Directives.IncludeDirective; +import static graphql.Directives.SkipDirective; + +@Internal +public class ConditionalNodes { + + + public boolean shouldInclude(DirectivesContainer element, + Map variables, + GraphQLSchema graphQLSchema, + GraphQLContext graphQLContext + ) { + // + // call the base @include / @skip first + if (!shouldInclude(variables, element.getDirectives())) { + return false; + } + // + // if they have declared a decision callback, then we will use it but we expect this to be mostly + // empty and hence the cost is a map lookup. + if (graphQLContext != null) { + ConditionalNodeDecision conditionalDecision = graphQLContext.get(ConditionalNodeDecision.class); + if (conditionalDecision != null) { + return customShouldInclude(variables, element, graphQLSchema, graphQLContext, conditionalDecision); + } + } + // if no one says otherwise, the node is considered included + return true; + } + + private boolean customShouldInclude(Map variables, + DirectivesContainer element, + GraphQLSchema graphQLSchema, + GraphQLContext graphQLContext, + ConditionalNodeDecision conditionalDecision + ) { + CoercedVariables coercedVariables = CoercedVariables.of(variables); + return conditionalDecision.shouldInclude(new ConditionalNodeDecisionEnvironment() { + @Override + public DirectivesContainer getDirectivesContainer() { + return element; + } + + @Override + public CoercedVariables getVariables() { + return coercedVariables; + } + + @Override + public GraphQLSchema getGraphQlSchema() { + return graphQLSchema; + } + + @Override + public GraphQLContext getGraphQLContext() { + return graphQLContext; + } + }); + } + + + private boolean shouldInclude(Map variables, List directives) { + // shortcut on no directives + if (directives.isEmpty()) { + return true; + } + boolean skip = getDirectiveResult(variables, directives, SkipDirective.getName(), false); + if (skip) { + return false; + } + + return getDirectiveResult(variables, directives, IncludeDirective.getName(), true); + } + + private boolean getDirectiveResult(Map variables, List directives, String directiveName, boolean defaultValue) { + Directive foundDirective = NodeUtil.findNodeByName(directives, directiveName); + if (foundDirective != null) { + Map argumentValues = ValuesResolver.getArgumentValues(SkipDirective.getArguments(), foundDirective.getArguments(), CoercedVariables.of(variables), GraphQLContext.getDefault(), Locale.getDefault()); + Object flag = argumentValues.get("if"); + Assert.assertTrue(flag instanceof Boolean, () -> String.format("The '%s' directive MUST have a value for the 'if' argument", directiveName)); + return (Boolean) flag; + } + return defaultValue; + } + +} diff --git a/src/main/java/graphql/normalized/ExecutableNormalizedOperationFactory.java b/src/main/java/graphql/normalized/ExecutableNormalizedOperationFactory.java index d3d603568..d5e0d4db8 100644 --- a/src/main/java/graphql/normalized/ExecutableNormalizedOperationFactory.java +++ b/src/main/java/graphql/normalized/ExecutableNormalizedOperationFactory.java @@ -8,10 +8,10 @@ import graphql.PublicApi; import graphql.collect.ImmutableKit; import graphql.execution.CoercedVariables; -import graphql.execution.ConditionalNodes; import graphql.execution.MergedField; import graphql.execution.RawVariables; import graphql.execution.ValuesResolver; +import graphql.execution.conditional.ConditionalNodes; import graphql.execution.directives.QueryDirectives; import graphql.execution.directives.QueryDirectivesImpl; import graphql.introspection.Introspection; @@ -520,12 +520,18 @@ private void collectFragmentSpread(FieldCollectorNormalizedQueryParams parameter FragmentSpread fragmentSpread, Set possibleObjects ) { - if (!conditionalNodes.shouldInclude(parameters.getCoercedVariableValues(), fragmentSpread.getDirectives())) { + if (!conditionalNodes.shouldInclude(fragmentSpread, + parameters.getCoercedVariableValues(), + parameters.getGraphQLSchema(), + parameters.getGraphQLContext())) { return; } FragmentDefinition fragmentDefinition = assertNotNull(parameters.getFragmentsByName().get(fragmentSpread.getName())); - if (!conditionalNodes.shouldInclude(parameters.getCoercedVariableValues(), fragmentDefinition.getDirectives())) { + if (!conditionalNodes.shouldInclude(fragmentDefinition, + parameters.getCoercedVariableValues(), + parameters.getGraphQLSchema(), + parameters.getGraphQLContext())) { return; } GraphQLCompositeType newAstTypeCondition = (GraphQLCompositeType) assertNotNull(parameters.getGraphQLSchema().getType(fragmentDefinition.getTypeCondition().getName())); @@ -540,7 +546,7 @@ private void collectInlineFragment(FieldCollectorNormalizedQueryParams parameter Set possibleObjects, GraphQLCompositeType astTypeCondition ) { - if (!conditionalNodes.shouldInclude(parameters.getCoercedVariableValues(), inlineFragment.getDirectives())) { + if (!conditionalNodes.shouldInclude(inlineFragment, parameters.getCoercedVariableValues(), parameters.getGraphQLSchema(), parameters.getGraphQLContext())) { return; } Set newPossibleObjects = possibleObjects; @@ -560,7 +566,10 @@ private void collectField(FieldCollectorNormalizedQueryParams parameters, Set possibleObjectTypes, GraphQLCompositeType astTypeCondition ) { - if (!conditionalNodes.shouldInclude(parameters.getCoercedVariableValues(), field.getDirectives())) { + if (!conditionalNodes.shouldInclude(field, + parameters.getCoercedVariableValues(), + parameters.getGraphQLSchema(), + parameters.getGraphQLContext())) { return; } // this means there is actually no possible type for this field, and we are done diff --git a/src/test/groovy/graphql/execution/ConditionalNodesTest.groovy b/src/test/groovy/graphql/execution/ConditionalNodesTest.groovy index 7c7660072..629f1fde9 100644 --- a/src/test/groovy/graphql/execution/ConditionalNodesTest.groovy +++ b/src/test/groovy/graphql/execution/ConditionalNodesTest.groovy @@ -1,9 +1,18 @@ package graphql.execution - +import graphql.ExecutionInput +import graphql.GraphQLContext +import graphql.TestUtil +import graphql.execution.conditional.ConditionalNodeDecision +import graphql.execution.conditional.ConditionalNodeDecisionEnvironment +import graphql.execution.conditional.ConditionalNodes import graphql.language.Argument import graphql.language.BooleanValue import graphql.language.Directive +import graphql.language.Field +import graphql.language.NodeUtil +import graphql.schema.DataFetcher +import graphql.schema.DataFetchingEnvironment import spock.lang.Specification class ConditionalNodesTest extends Specification { @@ -13,11 +22,43 @@ class ConditionalNodesTest extends Specification { def variables = new LinkedHashMap() ConditionalNodes conditionalNodes = new ConditionalNodes() - def argument = Argument.newArgument("if", new BooleanValue(true)).build() - def directives = [Directive.newDirective().name("skip").arguments([argument]).build()] + def directives = directive("skip", ifArg(true)) + + expect: + !conditionalNodes.shouldInclude(mkField(directives), variables, null, GraphQLContext.getDefault()) + } + + def "should include true for skip = false"() { + given: + def variables = new LinkedHashMap() + ConditionalNodes conditionalNodes = new ConditionalNodes() + + def directives = directive("skip", ifArg(false)) expect: - !conditionalNodes.shouldInclude(variables, directives) + conditionalNodes.shouldInclude(mkField(directives), variables, null, GraphQLContext.getDefault()) + } + + def "should include false for include = false"() { + given: + def variables = new LinkedHashMap() + ConditionalNodes conditionalNodes = new ConditionalNodes() + + def directives = directive("include", ifArg(false)) + + expect: + !conditionalNodes.shouldInclude(mkField(directives), variables, null, GraphQLContext.getDefault()) + } + + def "should include true for include = true"() { + given: + def variables = new LinkedHashMap() + ConditionalNodes conditionalNodes = new ConditionalNodes() + + def directives = directive("include", ifArg(true)) + + expect: + conditionalNodes.shouldInclude(mkField(directives), variables, null, GraphQLContext.getDefault()) } def "no directives means include"() { @@ -26,6 +67,133 @@ class ConditionalNodesTest extends Specification { ConditionalNodes conditionalNodes = new ConditionalNodes() expect: - conditionalNodes.shouldInclude(variables, []) + conditionalNodes.shouldInclude(mkField([]), variables, null, GraphQLContext.getDefault()) + } + + + def "allows a custom implementation to check conditional nodes"() { + given: + def variables = ["x": "y"] + def graphQLSchema = TestUtil.schema("type Query { f : String} ") + ConditionalNodes conditionalNodes = new ConditionalNodes() + + def graphQLContext = GraphQLContext.getDefault() + + def directives = directive("featureFlag", ifArg(true)) + def field = mkField(directives) + + def called = false + ConditionalNodeDecision conditionalDecision = new ConditionalNodeDecision() { + @Override + boolean shouldInclude(ConditionalNodeDecisionEnvironment env) { + called = true + assert env.variables.toMap() == variables + assert env.directivesContainer == field + assert env.graphQlSchema == graphQLSchema + assert env.graphQLContext.get("assert") != null + return false + } + } + graphQLContext.put(ConditionalNodeDecision.class, conditionalDecision) + graphQLContext.put("assert", true) + expect: + + !conditionalNodes.shouldInclude(field, variables, graphQLSchema, graphQLContext) + called == true + } + + def "integration test showing conditional nodes can be custom included"() { + + def sdl = """ + + directive @featureFlag(flagName: String!) repeatable on FIELD + + type Query { + in : String + out : String + } + """ + DataFetcher df = { DataFetchingEnvironment env -> env.getFieldDefinition().name } + def graphQL = TestUtil.graphQL(sdl, [Query: ["in": df, "out": df]]).build() + ConditionalNodeDecision customDecision = new ConditionalNodeDecision() { + @Override + boolean shouldInclude(ConditionalNodeDecisionEnvironment env) { + + Directive foundDirective = NodeUtil.findNodeByName(env.getDirectives(), "featureFlag") + if (foundDirective != null) { + + def arguments = env.getGraphQlSchema().getDirective("featureFlag") + .getArguments() + Map argumentValues = ValuesResolver.getArgumentValues( + arguments, foundDirective.getArguments(), + env.variables, env.graphQLContext, Locale.getDefault()) + Object flagName = argumentValues.get("flagName") + return String.valueOf(flagName) == "ON" + } + return true + } + } + + def contextMap = [:] + contextMap.put(ConditionalNodeDecision.class, customDecision) + + when: + def ei = ExecutionInput.newExecutionInput() + .graphQLContext(contextMap) + .query(""" + query q { + in + out @featureFlag(flagName : "OFF") + } + """ + ).build() + def er = graphQL.execute(ei) + + then: + er["data"] == ["in": "in"] + + when: + ei = ExecutionInput.newExecutionInput() + .graphQLContext(contextMap) + .query(""" + query q { + in + out @featureFlag(flagName : "ON") + } + """ + ).build() + er = graphQL.execute(ei) + + then: + er["data"] == ["in": "in", "out": "out"] + + when: + ei = ExecutionInput.newExecutionInput() + .graphQLContext(contextMap) + .query(''' + query vars_should_work($v : String!) { + in + out @featureFlag(flagName : $v) + } + ''' + ) + .variables([v: "ON"]) + .build() + er = graphQL.execute(ei) + + then: + er["data"] == ["in": "in", "out": "out"] + } + + private ArrayList directive(String name, Argument argument) { + [Directive.newDirective().name(name).arguments([argument]).build()] + } + + private Argument ifArg(Boolean b) { + Argument.newArgument("if", new BooleanValue(b)).build() + } + + Field mkField(List directives) { + return Field.newField("name").directives(directives).build() } }