diff --git a/src/main/java/graphql/execution/ValuesResolver.java b/src/main/java/graphql/execution/ValuesResolver.java index cc59de6f8..27487ca55 100644 --- a/src/main/java/graphql/execution/ValuesResolver.java +++ b/src/main/java/graphql/execution/ValuesResolver.java @@ -27,6 +27,7 @@ import graphql.schema.GraphQLScalarType; import graphql.schema.GraphQLSchema; import graphql.schema.GraphQLType; +import graphql.schema.GraphQLTypeUtil; import graphql.schema.InputValueWithState; import graphql.schema.visibility.GraphqlFieldVisibility; import org.jetbrains.annotations.NotNull; @@ -376,8 +377,9 @@ private static Map getArgumentValuesImpl( coercedValues.put(argumentName, value); } // @oneOf input must be checked now that all variables and literals have been converted - if (argumentType instanceof GraphQLInputObjectType) { - GraphQLInputObjectType inputObjectType = (GraphQLInputObjectType) argumentType; + GraphQLType unwrappedType = GraphQLTypeUtil.unwrapNonNull(argumentType); + if (unwrappedType instanceof GraphQLInputObjectType) { + GraphQLInputObjectType inputObjectType = (GraphQLInputObjectType) unwrappedType; if (inputObjectType.isOneOf() && ! ValuesResolverConversion.isNullValue(value)) { validateOneOfInputTypes(inputObjectType, argumentValue, argumentName, value, locale); } diff --git a/src/test/groovy/graphql/execution/ValuesResolverTest.groovy b/src/test/groovy/graphql/execution/ValuesResolverTest.groovy index dea17a2c4..0e53791a3 100644 --- a/src/test/groovy/graphql/execution/ValuesResolverTest.groovy +++ b/src/test/groovy/graphql/execution/ValuesResolverTest.groovy @@ -23,6 +23,7 @@ import graphql.language.Value import graphql.language.VariableDefinition import graphql.language.VariableReference import graphql.schema.CoercingParseValueException +import graphql.schema.GraphQLNonNull import spock.lang.Specification import spock.lang.Unroll @@ -360,16 +361,26 @@ class ValuesResolverTest extends Specification { .type(GraphQLInt) .build()) .build() - def fieldArgument = newArgument().name("arg").type(inputObjectType).build() - when: def argument = new Argument("arg", inputValue) + + when: + def fieldArgument = newArgument().name("arg").type(inputObjectType).build() ValuesResolver.getArgumentValues([fieldArgument], [argument], variables, graphQLContext, locale) then: def e = thrown(OneOfTooManyKeysException) e.message == "Exactly one key must be specified for OneOf type 'oneOfInputObject'." + when: "input type is wrapped in non-null" + def nonNullInputObjectType = GraphQLNonNull.nonNull(inputObjectType) + def fieldArgumentNonNull = newArgument().name("arg").type(nonNullInputObjectType).build() + ValuesResolver.getArgumentValues([fieldArgumentNonNull], [argument], variables, graphQLContext, locale) + + then: + def eNonNull = thrown(OneOfTooManyKeysException) + eNonNull.message == "Exactly one key must be specified for OneOf type 'oneOfInputObject'." + where: // from https://github.com/graphql/graphql-spec/pull/825/files#diff-30a69c5a5eded8e1aea52e53dad1181e6ec8f549ca2c50570b035153e2de1c43R1692 testCase | inputValue | variables @@ -502,6 +513,44 @@ class ValuesResolverTest extends Specification { } + def "getArgumentValues: invalid oneOf input no values where passed - #testCase"() { + given: "schema defining input object" + def inputObjectType = newInputObject() + .name("oneOfInputObject") + .withAppliedDirective(Directives.OneOfDirective.toAppliedDirective()) + .field(newInputObjectField() + .name("a") + .type(GraphQLString) + .build()) + .field(newInputObjectField() + .name("b") + .type(GraphQLInt) + .build()) + .build() + def fieldArgument = newArgument().name("arg").type(inputObjectType).build() + + when: + def argument = new Argument("arg", inputValue) + ValuesResolver.getArgumentValues([fieldArgument], [argument], variables, graphQLContext, locale) + + then: + def e = thrown(OneOfNullValueException) + e.message == "OneOf type field 'oneOfInputObject.a' must be non-null." + + where: + // from https://github.com/graphql/graphql-spec/pull/825/files#diff-30a69c5a5eded8e1aea52e53dad1181e6ec8f549ca2c50570b035153e2de1c43R1692 + testCase | inputValue | variables + + '`{ a: null }` {}' | buildObjectLiteral([ + a: NullValue.of() + ]) | CoercedVariables.emptyVariables() + + '`{ a: $var }` { var : null}' | buildObjectLiteral([ + a: VariableReference.of("var") + ]) | CoercedVariables.of(["var": null]) + + } + def "getVariableValues: enum as variable input"() { given: def enumDef = newEnum() @@ -839,4 +888,4 @@ class ValuesResolverTest extends Specification { executionResult.errors[0].message == "Variable 'input' has an invalid value: Expected a value that can be converted to type 'Float' but it was a 'String'" executionResult.errors[0].locations == [new SourceLocation(2, 35)] } -} \ No newline at end of file +}