diff --git a/src/main/java/graphql/GraphqlErrorBuilder.java b/src/main/java/graphql/GraphqlErrorBuilder.java
index 4cef5beab..db71eb20e 100644
--- a/src/main/java/graphql/GraphqlErrorBuilder.java
+++ b/src/main/java/graphql/GraphqlErrorBuilder.java
@@ -9,6 +9,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import static graphql.Assert.assertNotNull;
@@ -132,6 +133,18 @@ public GraphQLError build() {
return new GraphqlErrorImpl(message, locations, errorType, path, extensions);
}
+ /**
+ * A simple implementation of a {@link GraphQLError}.
+ *
+ * This provides {@link #hashCode()} and {@link #equals(Object)} methods that afford comparison with other
+ * {@link GraphQLError} implementations. However, the values provided in the following fields must
+ * in turn implement {@link #hashCode()} and {@link #equals(Object)} for this to function correctly:
+ *
+ * the values in the {@link #getPath()} {@link List}.
+ * the {@link #getErrorType()} {@link ErrorClassification}.
+ * the values in the {@link #getExtensions()} {@link Map}.
+ *
+ */
private static class GraphqlErrorImpl implements GraphQLError {
private final String message;
private final List locations;
@@ -176,6 +189,28 @@ public Map getExtensions() {
public String toString() {
return message;
}
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (!(o instanceof GraphQLError)) return false;
+ GraphQLError that = (GraphQLError) o;
+ return Objects.equals(getMessage(), that.getMessage())
+ && Objects.equals(getLocations(), that.getLocations())
+ && Objects.equals(getErrorType(), that.getErrorType())
+ && Objects.equals(getPath(), that.getPath())
+ && Objects.equals(getExtensions(), that.getExtensions());
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(
+ getMessage(),
+ getLocations(),
+ getErrorType(),
+ getPath(),
+ getExtensions());
+ }
}
/**
diff --git a/src/test/groovy/graphql/GraphqlErrorBuilderTest.groovy b/src/test/groovy/graphql/GraphqlErrorBuilderTest.groovy
index 944e1fef3..8713100d0 100644
--- a/src/test/groovy/graphql/GraphqlErrorBuilderTest.groovy
+++ b/src/test/groovy/graphql/GraphqlErrorBuilderTest.groovy
@@ -152,4 +152,66 @@ class GraphqlErrorBuilderTest extends Specification {
error.path == null
}
-}
\ No newline at end of file
+
+ def "implements equals/hashCode correctly for matching errors"() {
+ when:
+ def firstError = toGraphQLError(first)
+ def secondError = toGraphQLError(second)
+
+ then:
+ firstError == secondError
+ firstError.hashCode() == secondError.hashCode()
+
+ where:
+ first | second
+ [message: "msg"] | [message: "msg"]
+ [message: "msg", locations: [new SourceLocation(1, 2)]] | [message: "msg", locations: [new SourceLocation(1, 2)]]
+ [message: "msg", errorType: ErrorType.InvalidSyntax] | [message: "msg", errorType: ErrorType.InvalidSyntax]
+ [message: "msg", path: ["items", 1, "item"]] | [message: "msg", path: ["items", 1, "item"]]
+ [message: "msg", extensions: [aBoolean: true, aString: "foo"]] | [message: "msg", extensions: [aBoolean: true, aString: "foo"]]
+ }
+
+ def "implements equals/hashCode correctly for different errors"() {
+ when:
+ def firstError = toGraphQLError(first)
+ def secondError = toGraphQLError(second)
+
+ then:
+ firstError != secondError
+ firstError.hashCode() != secondError.hashCode()
+
+ where:
+ first | second
+ [message: "msg"] | [message: "different msg"]
+ [message: "msg", locations: [new SourceLocation(1, 2)]] | [message: "msg", locations: [new SourceLocation(3, 4)]]
+ [message: "msg", errorType: ErrorType.InvalidSyntax] | [message: "msg", errorType: ErrorType.DataFetchingException]
+ [message: "msg", path: ["items", "1", "item"]] | [message: "msg", path: ["items"]]
+ [message: "msg", extensions: [aBoolean: false]] | [message: "msg", extensions: [aString: "foo"]]
+ }
+
+ private static GraphQLError toGraphQLError(Map errorFields) {
+ def errorBuilder = GraphQLError.newError();
+ errorFields.forEach { key, value ->
+ if (value != null) {
+ switch (key) {
+ case "message":
+ errorBuilder.message(value as String);
+ break;
+ case "locations":
+ errorBuilder.locations(value as List);
+ break;
+ case "errorType":
+ errorBuilder.errorType(value as ErrorClassification);
+ break;
+ case "path":
+ errorBuilder.path(value as List);
+ break;
+ case "extensions":
+ errorBuilder.extensions(value as Map);
+ break;
+ }
+ }
+ }
+ return errorBuilder.build();
+ }
+}