diff --git a/pkg/codegen/schema.go b/pkg/codegen/schema.go index c3983d66e1..5494f85594 100644 --- a/pkg/codegen/schema.go +++ b/pkg/codegen/schema.go @@ -1,7 +1,9 @@ package codegen import ( + "encoding/base64" "fmt" + "strconv" "strings" "github.com/getkin/kin-openapi/openapi3" @@ -21,6 +23,8 @@ type Schema struct { AdditionalTypes []TypeDefinition // We may need to generate auxiliary helper types, stored here SkipOptionalPointer bool // Some types don't need a * in front when they're optional + ValidationTags []string //holds validation tags + ValidationPattern string //holds a regex pattern, if it defined } func (s Schema) IsRef() bool { @@ -169,6 +173,12 @@ func GenerateGoSchema(sref *openapi3.SchemaRef, path []string) (Schema, error) { required := StringInArray(pName, schema.Required) + //append validation tags + pSchema.ValidationTags = GenerateValidationTags(p.Value,required) + if p.Value.Pattern != ""{ + pSchema.ValidationPattern = p.Value.Pattern + } + if pSchema.HasAdditionalProperties && pSchema.RefType == "" { // If we have fields present which have additional properties, // but are not a pre-defined type, we need to define a type @@ -278,6 +288,24 @@ func GenerateGoSchema(sref *openapi3.SchemaRef, path []string) (Schema, error) { return outSchema, nil } +//add validation tags that can be processed by the go validator https://github.com/go-playground/validator +func GenerateValidationTags(schema *openapi3.Schema, required bool) []string { + var validationTags []string + if len(schema.Format) > 1{ + validationTags = append(validationTags,schema.Format) + } + if schema.MinLength > 0{ + validationTags = append(validationTags, "min="+strconv.FormatUint(schema.MinLength,10)) + } + if schema.MaxLength != nil && *schema.MaxLength > 0{ + validationTags = append(validationTags, "max="+strconv.FormatUint(*schema.MaxLength,10)) + } + if required{ + validationTags = append(validationTags,"required") + } + return validationTags +} + // This describes a Schema, a type definition. type SchemaDescriptor struct { Fields []FieldDescriptor @@ -306,11 +334,20 @@ func GenFieldsFromProperties(props []Property) []string { field += fmt.Sprintf("\n%s\n", StringToGoComment(p.Description)) } field += fmt.Sprintf(" %s %s", p.GoFieldName(), p.GoTypeDef()) - if p.Required || p.Nullable { - field += fmt.Sprintf(" `json:\"%s\"`", p.JsonFieldName) - } else { - field += fmt.Sprintf(" `json:\"%s,omitempty\"`", p.JsonFieldName) + validationTags := "" + validationPattern := "" + omitEmpty := ",omitempty" + if p.Required || p.Nullable{ + omitEmpty = "" + } + if p.Schema.ValidationPattern != ""{ + validationPattern = fmt.Sprintf(" pattern:\"%s\"", base64.StdEncoding.EncodeToString([]byte(p.Schema.ValidationPattern))) + p.Schema.ValidationTags = append(p.Schema.ValidationTags,"patternbase64") + } + if p.Schema.ValidationTags != nil{ + validationTags = fmt.Sprintf(" validate:\"%s\"", strings.Join(p.Schema.ValidationTags,",")) } + field += fmt.Sprintf(" `json:\"%s%s\"%s%s`", p.JsonFieldName,omitEmpty,validationTags,validationPattern) fields = append(fields, field) } return fields diff --git a/pkg/middleware/oapi_validate.go b/pkg/middleware/oapi_validate.go index b990333b41..00ea5430e5 100644 --- a/pkg/middleware/oapi_validate.go +++ b/pkg/middleware/oapi_validate.go @@ -17,14 +17,13 @@ package middleware import ( "context" "fmt" - "io/ioutil" - "net/http" - "strings" - "github.com/getkin/kin-openapi/openapi3" "github.com/getkin/kin-openapi/openapi3filter" "github.com/labstack/echo/v4" echomiddleware "github.com/labstack/echo/v4/middleware" + "io/ioutil" + "net/http" + "strings" ) const EchoContextKey = "oapi-codegen/echo-context" @@ -82,6 +81,90 @@ func OapiRequestValidatorWithOptions(swagger *openapi3.Swagger, options *Options } } +// Create a validator from a swagger object, with validation options +//func(next http.Handler) http.Handler +func OapiRequestValidatorWithOptionsHttpHandler(swagger *openapi3.Swagger, options *Options, errorHandler func(err error)http.Handler) func(next http.Handler) http.Handler { + if errorHandler == nil{ + errorHandler = DefaultValidationErrorHandler + } + router := openapi3filter.NewRouter().WithSwagger(swagger) + return func(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + err := ValidateHTTPRequest(r, router, options) + if err != nil { + errorHandler(err).ServeHTTP(w,r) + } + next.ServeHTTP(w,r) + } + return http.HandlerFunc(fn) + } + +} + +func DefaultValidationErrorHandler(err error) http.Handler{ + switch e := err.(type) { + case *openapi3filter.RequestError: + // We've got a bad request + fn := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(e.Error())) + } + return http.HandlerFunc(fn) + case *openapi3filter.SecurityRequirementsError: + fn := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + w.Write([]byte(e.Error())) + } + return http.HandlerFunc(fn) + default: + fn := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(e.Error())) + } + return http.HandlerFunc(fn) + } +} + +// This function is called from the middleware above and actually does the work +// of validating a request. +func ValidateHTTPRequest(r *http.Request, router *openapi3filter.Router, options *Options) error { + + route, pathParams, err := router.FindRoute(r.Method, r.URL) + + // We failed to find a matching route for the request. + if err != nil { + switch e := err.(type) { + case *openapi3filter.RouteError: + // We've got a bad request, the path requested doesn't match + // either server, or path, or something. + return echo.NewHTTPError(http.StatusBadRequest, e.Reason) + default: + // This should never happen today, but if our upstream code changes, + // we don't want to crash the server, so handle the unexpected error. + return echo.NewHTTPError(http.StatusInternalServerError, + fmt.Sprintf("error validating route: %s", err.Error())) + } + } + + validationInput := &openapi3filter.RequestValidationInput{ + Request: r, + PathParams: pathParams, + Route: route, + } + + requestContext := r.Context() + + if options != nil { + validationInput.Options = &options.Options + validationInput.ParamDecoder = options.ParamDecoder + requestContext = context.WithValue(requestContext, UserDataKey, options.UserData) + } + + err = openapi3filter.ValidateRequest(requestContext, validationInput) + //for default http handler errors have to be handled by the error handler function + return err +} + // This function is called from the middleware above and actually does the work // of validating a request. func ValidateRequestFromContext(ctx echo.Context, router *openapi3filter.Router, options *Options) error {