Skip to content

Commit

Permalink
Merge pull request #15 from vardius/hotfix/middleware-by-path
Browse files Browse the repository at this point in the history
Fix middleware for wildcard routes
  • Loading branch information
vardius authored Jan 30, 2020
2 parents 0d328e6 + 7eb5db7 commit 071a686
Show file tree
Hide file tree
Showing 23 changed files with 943 additions and 590 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,8 @@
.glide/

.vscode
.idea

vendor/
vendor/

.history/
24 changes: 14 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,23 @@ import (
"github.com/vardius/gorouter/v4/context"
)

func Index(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "Welcome!\n")
func index(w http.ResponseWriter, _ *http.Request) {
if _, err := fmt.Fprint(w, "Welcome!\n"); err != nil {
panic(err)
}
}

func Hello(w http.ResponseWriter, r *http.Request) {
func hello(w http.ResponseWriter, r *http.Request) {
params, _ := context.Parameters(r.Context())
fmt.Fprintf(w, "hello, %s!\n", params.Value("name"))
if _, err := fmt.Fprintf(w, "hello, %s!\n", params.Value("name")); err != nil {
panic(err)
}
}

func main() {
router := gorouter.New()
router.GET("/", http.HandlerFunc(Index))
router.GET("/hello/{name}", http.HandlerFunc(Hello))
router.GET("/", http.HandlerFunc(index))
router.GET("/hello/{name}", http.HandlerFunc(hello))

log.Fatal(http.ListenAndServe(":8080", router))
}
Expand All @@ -71,19 +75,19 @@ import (
"github.com/vardius/gorouter/v4"
)

func Index(ctx *fasthttp.RequestCtx) {
func index(_ *fasthttp.RequestCtx) {
fmt.Print("Welcome!\n")
}

func Hello(ctx *fasthttp.RequestCtx) {
func hello(ctx *fasthttp.RequestCtx) {
params := ctx.UserValue("params").(context.Params)
fmt.Printf("Hello, %s!\n", params.Value("name"))
}

func main() {
router := gorouter.NewFastHTTPRouter()
router.GET("/", Index)
router.GET("/hello/{name}", Hello)
router.GET("/", index)
router.GET("/hello/{name}", hello)

log.Fatal(fasthttp.ListenAndServe(":8080", router.HandleFastHTTP))
}
Expand Down
4 changes: 2 additions & 2 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Package gorouter provide request router with middleware
Router
The router determines how to handle that request.
The router determines how to handle http request.
GoRouter uses a routing tree. Once one branch of the tree matches, only routes inside that branch are considered,
not any routes after that branch. When instantiating router, the root node of tree is created.
Expand Down Expand Up @@ -31,7 +31,7 @@ A full route definition contain up to three parts:
2. The URL path route. This is matched against the URL passed to the router,
and can contain named wildcard placeholders *(e.g. {placeholder})* to match dynamic parts in the URL.
3. `http.HandleFunc`, which tells the router to handle matched requests to the router with handler.
3. `http.HandlerFunc`, which tells the router to handle matched requests to the router with handler.
Take the following example:
Expand Down
4 changes: 2 additions & 2 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func ExampleMiddlewareFunc_second() {
router := gorouter.New()
router.GET("/hello/{name}", http.HandlerFunc(hello))

// apply middleware to route and all it children
// apply middleware to route and all its children
// can pass as many as you want
router.USE("GET", "/hello/{name}", logger)

Expand Down Expand Up @@ -206,7 +206,7 @@ func ExampleFastHTTPMiddlewareFunc_second() {
router := gorouter.NewFastHTTPRouter()
router.GET("/hello/{name}", hello)

// apply middleware to route and all it children
// apply middleware to route and all its children
// can pass as many as you want
router.USE("GET", "/hello/{name}", logger)

Expand Down
86 changes: 48 additions & 38 deletions fasthttp.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,29 @@ import (
"github.com/valyala/fasthttp"
"github.com/vardius/gorouter/v4/middleware"
"github.com/vardius/gorouter/v4/mux"
pathutils "github.com/vardius/gorouter/v4/path"
)

// NewFastHTTPRouter creates new Router instance, returns pointer
func NewFastHTTPRouter(fs ...FastHTTPMiddlewareFunc) FastHTTPRouter {
globalMiddleware := transformFastHTTPMiddlewareFunc(fs...)
return &fastHTTPRouter{
routes: mux.NewTree(),
middleware: transformFastHTTPMiddlewareFunc(fs...),
tree: mux.NewTree(),
globalMiddleware: globalMiddleware,
middlewareCounter: uint(len(globalMiddleware)),
}
}

type fastHTTPRouter struct {
routes mux.Tree
middleware middleware.Middleware
fileServer fasthttp.RequestHandler
notFound fasthttp.RequestHandler
notAllowed fasthttp.RequestHandler
tree mux.Tree
globalMiddleware middleware.Collection
fileServer fasthttp.RequestHandler
notFound fasthttp.RequestHandler
notAllowed fasthttp.RequestHandler
middlewareCounter uint
}

func (r *fastHTTPRouter) PrettyPrint() string {
return r.routes.PrettyPrint()
return r.tree.PrettyPrint()
}

func (r *fastHTTPRouter) POST(p string, f fasthttp.RequestHandler) {
Expand Down Expand Up @@ -65,17 +67,20 @@ func (r *fastHTTPRouter) TRACE(p string, f fasthttp.RequestHandler) {
r.Handle(http.MethodTrace, p, f)
}

func (r *fastHTTPRouter) USE(method, p string, fs ...FastHTTPMiddlewareFunc) {
func (r *fastHTTPRouter) USE(method, path string, fs ...FastHTTPMiddlewareFunc) {
m := transformFastHTTPMiddlewareFunc(fs...)
for i, mf := range m {
m[i] = middleware.WithPriority(mf, r.middlewareCounter)
}

addMiddleware(r.routes, method, p, m)
r.tree = r.tree.WithMiddleware(method+path, m, 0)
r.middlewareCounter += uint(len(m))
}

func (r *fastHTTPRouter) Handle(method, path string, h fasthttp.RequestHandler) {
route := newRoute(h)
route.PrependMiddleware(r.middleware)

r.routes = r.routes.WithRoute(method+path, route, 0)
r.tree = r.tree.WithRoute(method+path, route, 0)
}

func (r *fastHTTPRouter) Mount(path string, h fasthttp.RequestHandler) {
Expand All @@ -91,15 +96,14 @@ func (r *fastHTTPRouter) Mount(path string, h fasthttp.RequestHandler) {
http.MethodTrace,
} {
route := newRoute(h)
route.PrependMiddleware(r.middleware)

r.routes = r.routes.WithSubrouter(method+path, route, 0)
r.tree = r.tree.WithSubrouter(method+path, route, 0)
}
}

func (r *fastHTTPRouter) Compile() {
for i, methodNode := range r.routes {
r.routes[i].WithChildren(methodNode.Tree().Compile())
for i, methodNode := range r.tree {
r.tree[i].WithChildren(methodNode.Tree().Compile())
}
}

Expand All @@ -121,32 +125,38 @@ func (r *fastHTTPRouter) ServeFiles(root string, stripSlashes int) {

func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) {
method := string(ctx.Method())
pathAsString := string(ctx.Path())
path := pathutils.TrimSlash(pathAsString)

if root := r.routes.Find(method); root != nil {
if node, params, subPath := root.Tree().Match(path); node != nil && node.Route() != nil {
if len(params) > 0 {
ctx.SetUserValue("params", params)
path := string(ctx.Path())

if route, params, subPath := r.tree.MatchRoute(method + path); route != nil {
var h fasthttp.RequestHandler
if r.middlewareCounter > 0 {
allMiddleware := r.globalMiddleware
if treeMiddleware := r.tree.MatchMiddleware(method + path); len(treeMiddleware) > 0 {
allMiddleware = allMiddleware.Merge(treeMiddleware.Sort())
}

if subPath != "" {
ctx.URI().SetPathBytes(fasthttp.NewPathPrefixStripper(len("/" + subPath))(ctx))
}
computedHandler := allMiddleware.Compose(route.Handler())

node.Route().Handler().(fasthttp.RequestHandler)(ctx)
return
h = computedHandler.(fasthttp.RequestHandler)
} else {
h = route.Handler().(fasthttp.RequestHandler)
}

if pathAsString == "/" && root.Route() != nil {
root.Route().Handler().(fasthttp.RequestHandler)(ctx)
return
if len(params) > 0 {
ctx.SetUserValue("params", params)
}

if subPath != "" {
ctx.URI().SetPathBytes(fasthttp.NewPathPrefixStripper(len("/" + subPath))(ctx))
}

h(ctx)
return
}

// Handle OPTIONS
if method == http.MethodOptions {
if allow := allowed(r.routes, method, path); len(allow) > 0 {
if allow := allowed(r.tree, method, path); len(allow) > 0 {
ctx.Response.Header.Set("Allow", allow)
return
}
Expand All @@ -156,7 +166,7 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) {
return
} else {
// Handle 405
if allow := allowed(r.routes, method, path); len(allow) > 0 {
if allow := allowed(r.tree, method, path); len(allow) > 0 {
ctx.Response.Header.Set("Allow", allow)
r.serveNotAllowed(ctx)
return
Expand All @@ -183,12 +193,12 @@ func (r *fastHTTPRouter) serveNotAllowed(ctx *fasthttp.RequestCtx) {
}
}

func transformFastHTTPMiddlewareFunc(fs ...FastHTTPMiddlewareFunc) middleware.Middleware {
m := make(middleware.Middleware, len(fs))
func transformFastHTTPMiddlewareFunc(fs ...FastHTTPMiddlewareFunc) middleware.Collection {
m := make(middleware.Collection, len(fs))

for i, f := range fs {
m[i] = func(mf FastHTTPMiddlewareFunc) middleware.MiddlewareFunc {
return func(h interface{}) interface{} {
m[i] = func(mf FastHTTPMiddlewareFunc) middleware.WrapperFunc {
return func(h middleware.Handler) middleware.Handler {
return mf(h.(fasthttp.RequestHandler))
}
}(f) // f is a reference to function so we have to wrap if with that callback
Expand Down
Loading

0 comments on commit 071a686

Please sign in to comment.