diff --git a/middleware/csrf/custom_errorfunc/main.go b/middleware/csrf/custom_errorfunc/main.go index 1be6481c..f4557e75 100644 --- a/middleware/csrf/custom_errorfunc/main.go +++ b/middleware/csrf/custom_errorfunc/main.go @@ -41,15 +41,25 @@ var ( // myErrFunc is executed when an error occurs in csrf middleware. func myErrFunc(_ context.Context, ctx *app.RequestContext) { err := ctx.Errors.Last() - switch err { - case errMissingForm, errMissingParam, errMissingHeader, errMissingQuery: - ctx.String(http.StatusBadRequest, err.Error()) // extract csrf-token failed - case errMissingSalt: - fmt.Println(err.Error()) - ctx.String(http.StatusInternalServerError, err.Error()) // get salt failed,which is unexpected - case errInvalidToken: - ctx.String(http.StatusBadRequest, err.Error()) // csrf-token is invalid + if err == nil { + return } + + switch err.Err.(type) { + case error: + switch { + case errors.Is(err, errMissingForm), errors.Is(err, errMissingParam), errors.Is(err, errMissingHeader), errors.Is(err, errMissingQuery): + ctx.String(http.StatusBadRequest, err.Error()) // extract csrf-token failed + case errors.Is(err, errMissingSalt): + fmt.Println(err.Error()) + ctx.String(http.StatusInternalServerError, err.Error()) // get salt failed, which is unexpected + case errors.Is(err, errInvalidToken): + ctx.String(http.StatusBadRequest, err.Error()) // csrf-token is invalid + default: + ctx.String(http.StatusInternalServerError, "Unknown error") // handle unknown errors + } + } + ctx.Abort() }