From 39e5d7811e6268ad569af0f6a101a442839dd32a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Thu, 2 Jan 2020 17:12:32 +1100 Subject: [PATCH 01/41] Update middleware test cases to handle more complex use cases --- fasthttp_test.go | 13 +++++++++++-- nethttp_test.go | 17 +++++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/fasthttp_test.go b/fasthttp_test.go index b5c1ce2..f1d677e 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -383,13 +383,22 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { fmt.Fprintf(ctx, "%s", params.Value("param")) }) - router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("m")) + router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("m1")) + router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m2")) ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/y") router.HandleFastHTTP(ctx) - if string(ctx.Response.Body()) != "my" { + if string(ctx.Response.Body()) != "m1y" { + t.Errorf("Use global middleware error %s", string(ctx.Response.Body())) + } + + ctx = buildFastHTTPRequestContext(http.MethodGet, "/x/x") + + router.HandleFastHTTP(ctx) + + if string(ctx.Response.Body()) != "m1m2x" { t.Errorf("Use global middleware error %s", string(ctx.Response.Body())) } } diff --git a/nethttp_test.go b/nethttp_test.go index e9f61f9..8ef945e 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -409,7 +409,8 @@ func TestNodeApplyMiddleware(t *testing.T) { w.Write([]byte(params.Value("param"))) })) - router.USE(http.MethodGet, "/x/{param}", mockMiddleware("m")) + router.USE(http.MethodGet, "/x/{param}", mockMiddleware("m1")) + router.USE(http.MethodGet, "/x/x", mockMiddleware("m2")) w := httptest.NewRecorder() req, err := http.NewRequest(http.MethodGet, "/x/y", nil) @@ -419,7 +420,19 @@ func TestNodeApplyMiddleware(t *testing.T) { router.ServeHTTP(w, req) - if w.Body.String() != "my" { + if w.Body.String() != "m1y" { + t.Errorf("Use global middleware error %s", w.Body.String()) + } + + w = httptest.NewRecorder() + req, err = http.NewRequest(http.MethodGet, "/x/x", nil) + if err != nil { + t.Fatal(err) + } + + router.ServeHTTP(w, req) + + if w.Body.String() != "m1m2x" { t.Errorf("Use global middleware error %s", w.Body.String()) } } From a2896ff481c84fb0490bfb0e71e2604b93bb539a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Sat, 11 Jan 2020 13:28:49 +1100 Subject: [PATCH 02/41] Move middleware to mux tree --- .gitignore | 1 + fasthttp.go | 17 +++++++----- mux/benchmark_test.go | 2 +- mux/node.go | 63 ++++++++++++++++++++++++++++++++----------- mux/route.go | 6 ----- mux/tree.go | 38 ++++++++++++++++++++++---- mux/tree_test.go | 2 +- nethttp.go | 15 ++++++----- route.go | 24 +++-------------- route_test.go | 5 ++-- tree.go | 29 +------------------- 11 files changed, 109 insertions(+), 93 deletions(-) diff --git a/.gitignore b/.gitignore index 3b0d410..f7e9dd9 100644 --- a/.gitignore +++ b/.gitignore @@ -14,5 +14,6 @@ .glide/ .vscode +.idea vendor/ \ No newline at end of file diff --git a/fasthttp.go b/fasthttp.go index 80946ec..8cbe4a9 100644 --- a/fasthttp.go +++ b/fasthttp.go @@ -65,15 +65,14 @@ func (r *fastHTTPRouter) TRACE(p string, f fasthttp.RequestHandler) { r.Handle(http.MethodTrace, p, f) } -func (r *fastHTTPRouter) USE(method, p string, fs ...FastHTTPMiddlewareFunc) { +func (r *fastHTTPRouter) USE(method, path string, fs ...FastHTTPMiddlewareFunc) { m := transformFastHTTPMiddlewareFunc(fs...) - addMiddleware(r.routes, method, p, m) + r.routes = r.routes.WithMiddleware(method+path, m, 0) } func (r *fastHTTPRouter) Handle(method, path string, h fasthttp.RequestHandler) { route := newRoute(h) - route.PrependMiddleware(r.middleware) r.routes = r.routes.WithRoute(method+path, route, 0) } @@ -91,7 +90,6 @@ func (r *fastHTTPRouter) Mount(path string, h fasthttp.RequestHandler) { http.MethodTrace, } { route := newRoute(h) - route.PrependMiddleware(r.middleware) r.routes = r.routes.WithSubrouter(method+path, route, 0) } @@ -125,7 +123,14 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) { path := pathutils.TrimSlash(pathAsString) if root := r.routes.Find(method); root != nil { - if node, params, subPath := root.Tree().Match(path); node != nil && node.Route() != nil { + if node, treeMiddleware, params, subPath := root.Tree().Match(path); node != nil && node.Route() != nil { + route := node.Route() + handler := route.Handler() + middleware := r.middleware.Merge(treeMiddleware) + computedHandler := middleware.Compose(handler) + + h := computedHandler.(fasthttp.RequestHandler) + if len(params) > 0 { ctx.SetUserValue("params", params) } @@ -134,7 +139,7 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) { ctx.URI().SetPathBytes(fasthttp.NewPathPrefixStripper(len("/" + subPath))(ctx)) } - node.Route().Handler().(fasthttp.RequestHandler)(ctx) + h(ctx) return } diff --git a/mux/benchmark_test.go b/mux/benchmark_test.go index 82a7863..a862a31 100644 --- a/mux/benchmark_test.go +++ b/mux/benchmark_test.go @@ -37,7 +37,7 @@ func BenchmarkMux(b *testing.B) { b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - n, _, _ := root.Tree().Match("pl/blog/comments/123/new") + n, _, _, _ := root.Tree().Match("pl/blog/comments/123/new") if n == nil { b.Fatalf("%v", n) diff --git a/mux/node.go b/mux/node.go index e4eecf6..aa7c89a 100644 --- a/mux/node.go +++ b/mux/node.go @@ -4,6 +4,7 @@ import ( "regexp" "github.com/vardius/gorouter/v4/context" + "github.com/vardius/gorouter/v4/middleware" pathutils "github.com/vardius/gorouter/v4/path" ) @@ -17,6 +18,7 @@ func NewNode(pathPart string, maxParamsSize uint8) Node { static := &staticNode{ name: name, children: NewTree(), + middleware: middleware.New(), maxParamsSize: maxParamsSize, } @@ -39,7 +41,7 @@ func NewNode(pathPart string, maxParamsSize uint8) Node { // Can match path and provide routes type Node interface { // Match matches given path to Node within Node and its Tree - Match(path string) (Node, context.Params, string) + Match(path string) (Node, middleware.Middleware, context.Params, string) // Name provides Node name Name() string @@ -47,6 +49,8 @@ type Node interface { Tree() Tree // Route provides Node's Route if assigned Route() Route + // Middleware provides Node's middleware + Middleware() middleware.Middleware // Name provides maximum number of parameters Route can have for given Node MaxParamsSize() uint8 @@ -55,6 +59,10 @@ type Node interface { WithRoute(r Route) // WithChildren sets Node's Tree WithChildren(t Tree) + // AppendMiddleware appends middleware to Node + AppendMiddleware(m middleware.Middleware) + // PrependMiddleware prepends middleware to Node + PrependMiddleware(m middleware.Middleware) // SkipSubPath sets skipSubPath node property to true // will skip children match search and return current node directly @@ -66,29 +74,32 @@ type staticNode struct { name string children Tree - route Route + route Route + middleware middleware.Middleware maxParamsSize uint8 skipSubPath bool } -func (n *staticNode) Match(path string) (Node, context.Params, string) { +func (n *staticNode) Match(path string) (Node, middleware.Middleware, context.Params, string) { nameLength := len(n.name) pathLength := len(path) if pathLength >= nameLength && n.name == path[:nameLength] { if nameLength+1 >= pathLength { - return n, make(context.Params, n.maxParamsSize), "" + return n, n.middleware, make(context.Params, n.maxParamsSize), "" } if n.skipSubPath { - return n, make(context.Params, n.maxParamsSize), path[nameLength+1:] + return n, n.middleware, make(context.Params, n.maxParamsSize), path[nameLength+1:] } - return n.children.Match(path[nameLength+1:]) // +1 because we wan to skip slash as well + node, treeMiddleware, params, p := n.children.Match(path[nameLength+1:]) // +1 because we wan to skip slash as well + + return node, n.middleware.Merge(treeMiddleware), params, p } - return nil, nil, "" + return nil, nil, nil, "" } func (n *staticNode) Name() string { @@ -103,6 +114,10 @@ func (n *staticNode) Route() Route { return n.route } +func (n *staticNode) Middleware() middleware.Middleware { + return n.middleware +} + func (n *staticNode) MaxParamsSize() uint8 { return n.maxParamsSize } @@ -115,6 +130,14 @@ func (n *staticNode) WithRoute(r Route) { n.route = r } +func (n *staticNode) AppendMiddleware(m middleware.Middleware) { + n.middleware = n.middleware.Merge(m) +} + +func (n *staticNode) PrependMiddleware(m middleware.Middleware) { + n.middleware = m.Merge(n.middleware) +} + func (n *staticNode) SkipSubPath() { n.skipSubPath = true } @@ -127,27 +150,31 @@ type wildcardNode struct { *staticNode } -func (n *wildcardNode) Match(path string) (Node, context.Params, string) { +func (n *wildcardNode) Match(path string) (Node, middleware.Middleware, context.Params, string) { pathPart, subPath := pathutils.GetPart(path) maxParamsSize := n.MaxParamsSize() var node Node + var treeMiddleware middleware.Middleware var params context.Params if subPath == "" || n.staticNode.skipSubPath { node = n + treeMiddleware = n.Middleware() params = make(context.Params, maxParamsSize) } else { - node, params, subPath = n.children.Match(subPath) + node, treeMiddleware, params, subPath = n.children.Match(subPath) if node == nil { - return nil, nil, "" + return nil, nil, nil, "" } + + treeMiddleware = n.middleware.Merge(treeMiddleware) } params.Set(maxParamsSize-1, n.name, pathPart) - return node, params, subPath + return node, treeMiddleware, params, subPath } func withRegexp(parent *staticNode, regexp *regexp.Regexp) *regexpNode { @@ -163,31 +190,35 @@ type regexpNode struct { regexp *regexp.Regexp } -func (n *regexpNode) Match(path string) (Node, context.Params, string) { +func (n *regexpNode) Match(path string) (Node, middleware.Middleware, context.Params, string) { pathPart, subPath := pathutils.GetPart(path) if !n.regexp.MatchString(pathPart) { - return nil, nil, "" + return nil, nil, nil, "" } maxParamsSize := n.MaxParamsSize() var node Node + var treeMiddleware middleware.Middleware var params context.Params if subPath == "" || n.staticNode.skipSubPath { node = n + treeMiddleware = n.Middleware() params = make(context.Params, maxParamsSize) } else { - node, params, subPath = n.children.Match(subPath) + node, treeMiddleware, params, subPath = n.children.Match(subPath) if node == nil { - return nil, nil, "" + return nil, nil, nil, "" } + + treeMiddleware = n.middleware.Merge(treeMiddleware) } params.Set(maxParamsSize-1, n.name, pathPart) - return node, params, subPath + return node, treeMiddleware, params, subPath } func withSubrouter(parent Node) *subrouterNode { diff --git a/mux/route.go b/mux/route.go index 2279e36..60541d7 100644 --- a/mux/route.go +++ b/mux/route.go @@ -1,12 +1,6 @@ package mux -import ( - "github.com/vardius/gorouter/v4/middleware" -) - // Route is an middleware aware route interface type Route interface { Handler() interface{} - AppendMiddleware(m middleware.Middleware) - PrependMiddleware(m middleware.Middleware) } diff --git a/mux/tree.go b/mux/tree.go index 7357d90..bddaa0f 100644 --- a/mux/tree.go +++ b/mux/tree.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/vardius/gorouter/v4/context" + "github.com/vardius/gorouter/v4/middleware" pathutils "github.com/vardius/gorouter/v4/path" ) @@ -68,14 +69,14 @@ func (t Tree) Compile() Tree { } // Match path to Node -func (t Tree) Match(path string) (Node, context.Params, string) { +func (t Tree) Match(path string) (Node, middleware.Middleware, context.Params, string) { for _, child := range t { - if node, params, subPath := child.Match(path); node != nil { - return node, params, subPath + if node, m, params, subPath := child.Match(path); node != nil { + return node, m, params, subPath } } - return nil, nil, "" + return nil, nil, nil, "" } // Find Node inside a tree by name @@ -94,7 +95,7 @@ func (t Tree) Find(name string) Node { } // WithRoute returns new Tree with Route set to Node -// Route is set to Node under the give path, ff Node does not exist it is created +// Route is set to Node under the give path, if Node does not exist it is created func (t Tree) WithRoute(path string, route Route, maxParamsSize uint8) Tree { path = pathutils.TrimSlash(path) if path == "" { @@ -120,6 +121,33 @@ func (t Tree) WithRoute(path string, route Route, maxParamsSize uint8) Tree { return newTree } +// WithMiddleware returns new Tree with Middleware appended to given Node +// Middleware is appended to Node under the give path, if Node does not exist it is created +func (t Tree) WithMiddleware(path string, m middleware.Middleware, maxParamsSize uint8) Tree { + path = pathutils.TrimSlash(path) + if path == "" { + return t + } + + parts := strings.Split(path, "/") + name, _ := pathutils.GetNameFromPart(parts[0]) + node := t.Find(name) + newTree := t + + if node == nil { + node = NewNode(parts[0], maxParamsSize) + newTree = t.withNode(node) + } + + if len(parts) == 1 { + node.AppendMiddleware(m) + } else { + node.WithChildren(node.Tree().WithMiddleware(strings.Join(parts[1:], "/"), m, node.MaxParamsSize())) + } + + return newTree +} + // WithSubrouter returns new Tree with new Route set to Subrouter Node // Route is set to Node under the give path, ff Node does not exist it is created func (t Tree) WithSubrouter(path string, route Route, maxParamsSize uint8) Tree { diff --git a/mux/tree_test.go b/mux/tree_test.go index 12646a7..c5dc3b7 100644 --- a/mux/tree_test.go +++ b/mux/tree_test.go @@ -37,7 +37,7 @@ func TestTreeMatch(t *testing.T) { root.WithChildren(root.Tree().Compile()) - n, _, _ := root.Tree().Match("pl/blog/comments/123/new") + n, _, _, _ := root.Tree().Match("pl/blog/comments/123/new") if n == nil { t.Fatalf("%v", n) diff --git a/nethttp.go b/nethttp.go index bb7170e..c012321 100644 --- a/nethttp.go +++ b/nethttp.go @@ -66,15 +66,14 @@ func (r *router) TRACE(p string, f http.Handler) { r.Handle(http.MethodTrace, p, f) } -func (r *router) USE(method, p string, fs ...MiddlewareFunc) { +func (r *router) USE(method, path string, fs ...MiddlewareFunc) { m := transformMiddlewareFunc(fs...) - addMiddleware(r.routes, method, p, m) + r.routes = r.routes.WithMiddleware(method+path, m, 0) } func (r *router) Handle(method, path string, h http.Handler) { route := newRoute(h) - route.PrependMiddleware(r.middleware) r.routes = r.routes.WithRoute(method+path, route, 0) } @@ -92,7 +91,6 @@ func (r *router) Mount(path string, h http.Handler) { http.MethodTrace, } { route := newRoute(h) - route.PrependMiddleware(r.middleware) r.routes = r.routes.WithSubrouter(method+path, route, 0) } @@ -127,8 +125,13 @@ func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { path := pathutils.TrimSlash(req.URL.Path) if root := r.routes.Find(req.Method); root != nil { - if node, params, subPath := root.Tree().Match(path); node != nil && node.Route() != nil { - h := node.Route().Handler().(http.Handler) + if node, treeMiddleware, params, subPath := root.Tree().Match(path); node != nil && node.Route() != nil { + route := node.Route() + handler := route.Handler() + middleware := r.middleware.Merge(treeMiddleware) + computedHandler := middleware.Compose(handler) + + h := computedHandler.(http.Handler) if len(params) > 0 { req = req.WithContext(context.WithParams(req.Context(), params)) diff --git a/route.go b/route.go index 73cf71a..fc464b0 100644 --- a/route.go +++ b/route.go @@ -1,14 +1,7 @@ package gorouter -import ( - "github.com/vardius/gorouter/v4/middleware" -) - type route struct { - middleware middleware.Middleware - handler interface{} - // computedHandler is an optimization to improve performance - computedHandler interface{} + handler interface{} } func newRoute(h interface{}) *route { @@ -17,22 +10,11 @@ func newRoute(h interface{}) *route { } return &route{ - handler: h, - middleware: middleware.New(), + handler: h, } } func (r *route) Handler() interface{} { // returns already cached computed handler - return r.computedHandler -} - -func (r *route) AppendMiddleware(m middleware.Middleware) { - r.middleware = r.middleware.Merge(m) - r.computedHandler = r.middleware.Compose(r.handler) -} - -func (r *route) PrependMiddleware(m middleware.Middleware) { - r.middleware = m.Merge(r.middleware) - r.computedHandler = r.middleware.Compose(r.handler) + return r.handler } diff --git a/route_test.go b/route_test.go index c71bf8a..c188bb3 100644 --- a/route_test.go +++ b/route_test.go @@ -30,9 +30,8 @@ func TestRouter(t *testing.T) { m3 := buildMiddlewareFunc("3") r := newRoute(handler) - r.AppendMiddleware(middleware.New(m1, m2, m3)) - - h := r.Handler() + m := middleware.New(m1, m2, m3) + h := m.Compose(r.Handler()) w := httptest.NewRecorder() req, err := http.NewRequest("GET", "/", nil) diff --git a/tree.go b/tree.go index 416a662..eb46996 100644 --- a/tree.go +++ b/tree.go @@ -2,38 +2,11 @@ package gorouter import ( "net/http" - "strings" - "github.com/vardius/gorouter/v4/middleware" "github.com/vardius/gorouter/v4/mux" pathutils "github.com/vardius/gorouter/v4/path" ) -func addMiddleware(t mux.Tree, method, path string, mid middleware.Middleware) { - type recFunc func(recFunc, mux.Node, middleware.Middleware) - - c := func(c recFunc, n mux.Node, m middleware.Middleware) { - if n.Route() != nil { - n.Route().AppendMiddleware(m) - } - for _, child := range n.Tree() { - c(c, child, m) - } - } - - // routes tree roots should be http method nodes only - if root := t.Find(method); root != nil { - if path != "" { - node := findNode(root, strings.Split(pathutils.TrimSlash(path), "/")) - if node != nil { - c(c, node, mid) - } - } else { - c(c, root, mid) - } - } -} - func findNode(n mux.Node, parts []string) mux.Node { if len(parts) == 0 { return n @@ -68,7 +41,7 @@ func allowed(t mux.Tree, method, path string) (allow string) { continue } - if n, _, _ := root.Tree().Match(path); n != nil && n.Route() != nil { + if n, _, _, _ := root.Tree().Match(path); n != nil && n.Route() != nil { if len(allow) == 0 { allow = root.Name() } else { From 0cee9c5fb43671c7b1764bb9990d8d681b5d1cc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Wed, 15 Jan 2020 12:09:53 +1100 Subject: [PATCH 03/41] Apply root node (method) middleware --- fasthttp.go | 3 ++- nethttp.go | 5 ++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/fasthttp.go b/fasthttp.go index 8cbe4a9..884b367 100644 --- a/fasthttp.go +++ b/fasthttp.go @@ -126,7 +126,8 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) { if node, treeMiddleware, params, subPath := root.Tree().Match(path); node != nil && node.Route() != nil { route := node.Route() handler := route.Handler() - middleware := r.middleware.Merge(treeMiddleware) + middleware := root.Middleware().Merge(treeMiddleware) + middleware = r.middleware.Merge(middleware) computedHandler := middleware.Compose(handler) h := computedHandler.(fasthttp.RequestHandler) diff --git a/nethttp.go b/nethttp.go index c012321..88e3c88 100644 --- a/nethttp.go +++ b/nethttp.go @@ -126,9 +126,12 @@ func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { if root := r.routes.Find(req.Method); root != nil { if node, treeMiddleware, params, subPath := root.Tree().Match(path); node != nil && node.Route() != nil { + // @FIXME @TODO: issue when adding middleware to static path node while route was added to wildcard node + // Found node will (matching middleware path) will not have assigned route route := node.Route() handler := route.Handler() - middleware := r.middleware.Merge(treeMiddleware) + middleware := root.Middleware().Merge(treeMiddleware) + middleware = r.middleware.Merge(middleware) computedHandler := middleware.Compose(handler) h := computedHandler.(http.Handler) From 54afa4cfdeb0d8700407f2f514d46fa7a086c1b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Wed, 15 Jan 2020 12:10:07 +1100 Subject: [PATCH 04/41] Append middleware when compiling --- mux/tree.go | 1 + 1 file changed, 1 insertion(+) diff --git a/mux/tree.go b/mux/tree.go index bddaa0f..730ccf4 100644 --- a/mux/tree.go +++ b/mux/tree.go @@ -53,6 +53,7 @@ func (t Tree) Compile() Tree { case *staticNode: if staticNode, ok := node.Tree()[0].(*staticNode); ok { node.WithChildren(staticNode.Tree()) + node.AppendMiddleware(staticNode.Middleware()) node.name = fmt.Sprintf("%s/%s", node.name, staticNode.name) t[i] = node From 209e3724c93e93101082ba6dbf8c8b647234fa4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Wed, 15 Jan 2020 12:35:02 +1100 Subject: [PATCH 05/41] Throw exception while middleware path is invalid --- fasthttp_test.go | 29 ++++++++++++++++++++++++----- mux/tree.go | 5 ++--- nethttp.go | 2 -- nethttp_test.go | 35 +++++++++++++++++++++++++++++------ 4 files changed, 55 insertions(+), 16 deletions(-) diff --git a/fasthttp_test.go b/fasthttp_test.go index f1d677e..32e46ee 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -384,22 +384,41 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { }) router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("m1")) - router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m2")) ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/y") router.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "m1y" { - t.Errorf("Use global middleware error %s", string(ctx.Response.Body())) + t.Errorf("Use middleware error %s", string(ctx.Response.Body())) } +} + +func TestFastHTTPNodeApplyMiddlewareInvalidPath(t *testing.T) { + t.Parallel() + + panicked := false + defer func() { + if rcv := recover(); rcv != nil { + panicked = true + } + }() - ctx = buildFastHTTPRequestContext(http.MethodGet, "/x/x") + router := NewFastHTTPRouter().(*fastHTTPRouter) + + router.GET("/x/{param}", func(ctx *fasthttp.RequestCtx) { + params := ctx.UserValue("params").(context.Params) + fmt.Fprintf(ctx, "%s", params.Value("param")) + }) + + router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m2")) + + ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/x") router.HandleFastHTTP(ctx) - if string(ctx.Response.Body()) != "m1m2x" { - t.Errorf("Use global middleware error %s", string(ctx.Response.Body())) + if panicked != true { + t.Error("Router should panic for invalid middleware path") } } diff --git a/mux/tree.go b/mux/tree.go index 730ccf4..dc9a7b1 100644 --- a/mux/tree.go +++ b/mux/tree.go @@ -123,7 +123,7 @@ func (t Tree) WithRoute(path string, route Route, maxParamsSize uint8) Tree { } // WithMiddleware returns new Tree with Middleware appended to given Node -// Middleware is appended to Node under the give path, if Node does not exist it is created +// Middleware is appended to Node under the give path, if Node does not exist it will panic func (t Tree) WithMiddleware(path string, m middleware.Middleware, maxParamsSize uint8) Tree { path = pathutils.TrimSlash(path) if path == "" { @@ -136,8 +136,7 @@ func (t Tree) WithMiddleware(path string, m middleware.Middleware, maxParamsSize newTree := t if node == nil { - node = NewNode(parts[0], maxParamsSize) - newTree = t.withNode(node) + panic("Could not find node for given path") } if len(parts) == 1 { diff --git a/nethttp.go b/nethttp.go index 88e3c88..6ce8495 100644 --- a/nethttp.go +++ b/nethttp.go @@ -126,8 +126,6 @@ func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { if root := r.routes.Find(req.Method); root != nil { if node, treeMiddleware, params, subPath := root.Tree().Match(path); node != nil && node.Route() != nil { - // @FIXME @TODO: issue when adding middleware to static path node while route was added to wildcard node - // Found node will (matching middleware path) will not have assigned route route := node.Route() handler := route.Handler() middleware := root.Middleware().Merge(treeMiddleware) diff --git a/nethttp_test.go b/nethttp_test.go index 8ef945e..1390690 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -410,7 +410,6 @@ func TestNodeApplyMiddleware(t *testing.T) { })) router.USE(http.MethodGet, "/x/{param}", mockMiddleware("m1")) - router.USE(http.MethodGet, "/x/x", mockMiddleware("m2")) w := httptest.NewRecorder() req, err := http.NewRequest(http.MethodGet, "/x/y", nil) @@ -421,19 +420,43 @@ func TestNodeApplyMiddleware(t *testing.T) { router.ServeHTTP(w, req) if w.Body.String() != "m1y" { - t.Errorf("Use global middleware error %s", w.Body.String()) + t.Errorf("Use middleware error %s", w.Body.String()) } +} - w = httptest.NewRecorder() - req, err = http.NewRequest(http.MethodGet, "/x/x", nil) +func TestNodeApplyMiddlewareInvalidPath(t *testing.T) { + t.Parallel() + + panicked := false + defer func() { + if rcv := recover(); rcv != nil { + panicked = true + } + }() + + router := New().(*router) + + router.GET("/x/{param}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + params, ok := context.Parameters(r.Context()) + if !ok { + t.Fatal("Error while reading param") + } + + w.Write([]byte(params.Value("param"))) + })) + + router.USE(http.MethodGet, "/x/x", mockMiddleware("m2")) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/x/x", nil) if err != nil { t.Fatal(err) } router.ServeHTTP(w, req) - if w.Body.String() != "m1m2x" { - t.Errorf("Use global middleware error %s", w.Body.String()) + if panicked != true { + t.Error("Router should panic for invalid middleware path") } } From c22bcae8d6186a3c74318804814947cc5af76a37 Mon Sep 17 00:00:00 2001 From: mar1n3r0 Date: Wed, 15 Jan 2020 18:26:53 +0200 Subject: [PATCH 06/41] Map new nodes with middleware missing routes to parameterized node --- mux/node.go | 10 ++++------ mux/tree.go | 12 ++++++++---- nethttp.go | 1 - path/path.go | 20 +++++++------------- 4 files changed, 19 insertions(+), 24 deletions(-) diff --git a/mux/node.go b/mux/node.go index aa7c89a..872f8c6 100644 --- a/mux/node.go +++ b/mux/node.go @@ -84,18 +84,16 @@ type staticNode struct { func (n *staticNode) Match(path string) (Node, middleware.Middleware, context.Params, string) { nameLength := len(n.name) pathLength := len(path) - - if pathLength >= nameLength && n.name == path[:nameLength] { + if pathLength >= nameLength && n.name == path[:nameLength] || regexp.MustCompile(`{|}`).MatchString(n.name) { if nameLength+1 >= pathLength { - return n, n.middleware, make(context.Params, n.maxParamsSize), "" + // is there a better solution here ? it stopped working once the braces were included in node.name + return n, n.middleware, context.Params{{Key: "param", Value: path}}, "" } if n.skipSubPath { - return n, n.middleware, make(context.Params, n.maxParamsSize), path[nameLength+1:] + return n, n.middleware, context.Params{{Key: "param", Value: path}}, path[nameLength+1:] } - node, treeMiddleware, params, p := n.children.Match(path[nameLength+1:]) // +1 because we wan to skip slash as well - return node, n.middleware.Merge(treeMiddleware), params, p } diff --git a/mux/tree.go b/mux/tree.go index bddaa0f..267214a 100644 --- a/mux/tree.go +++ b/mux/tree.go @@ -3,6 +3,7 @@ package mux import ( "bytes" "fmt" + "regexp" "sort" "strings" @@ -71,7 +72,7 @@ func (t Tree) Compile() Tree { // Match path to Node func (t Tree) Match(path string) (Node, middleware.Middleware, context.Params, string) { for _, child := range t { - if node, m, params, subPath := child.Match(path); node != nil { + if node, m, params, subPath := child.Match(path); node != nil && node.Route() != nil { return node, m, params, subPath } } @@ -79,14 +80,14 @@ func (t Tree) Match(path string) (Node, middleware.Middleware, context.Params, s return nil, nil, nil, "" } -// Find Node inside a tree by name +// Find Node inside a tree by name, if name == {} we are specifically looking for a parameterized route and it's regex checked func (t Tree) Find(name string) Node { if name == "" { return nil } for _, child := range t { - if child.Name() == name { + if child.Name() == name || name == "{}" && regexp.MustCompile(`{|}`).MatchString(child.Name()) { return child } } @@ -101,7 +102,6 @@ func (t Tree) WithRoute(path string, route Route, maxParamsSize uint8) Tree { if path == "" { return t } - parts := strings.Split(path, "/") name, _ := pathutils.GetNameFromPart(parts[0]) node := t.Find(name) @@ -134,8 +134,12 @@ func (t Tree) WithMiddleware(path string, m middleware.Middleware, maxParamsSize node := t.Find(name) newTree := t + // If there is no node get the route of the previous pameterized one + // and append middleware on top of its own middleware if node == nil { node = NewNode(parts[0], maxParamsSize) + node.WithRoute(t.Find("{}").Route()) + node.AppendMiddleware(t.Find("{}").Middleware()) newTree = t.withNode(node) } diff --git a/nethttp.go b/nethttp.go index c012321..11f6ef5 100644 --- a/nethttp.go +++ b/nethttp.go @@ -68,7 +68,6 @@ func (r *router) TRACE(p string, f http.Handler) { func (r *router) USE(method, path string, fs ...MiddlewareFunc) { m := transformMiddlewareFunc(fs...) - r.routes = r.routes.WithMiddleware(method+path, m, 0) } diff --git a/path/path.go b/path/path.go index 45c6cc5..917baf6 100644 --- a/path/path.go +++ b/path/path.go @@ -30,23 +30,17 @@ func GetPart(path string) (part string, nextPath string) { return } -// GetNameFromPart gets node name from path part +// GetNameFromPart gets node name from path part, braces are not stripped so that we can check which node is associated with a parameter func GetNameFromPart(pathPart string) (name string, exp string) { name = pathPart - if pathPart[0] == '{' { - name = pathPart[1 : len(pathPart)-1] - - if parts := strings.Split(name, ":"); len(parts) == 2 { - name = parts[0] - exp = parts[1] - } - - if name == "" { - panic("Empty wildcard name") - } + if parts := strings.Split(name, ":"); len(parts) == 2 { + name = parts[0] + exp = parts[1] + } - return + if name == "" { + panic("Empty wildcard name") } return From af63271a5cdc163ed4656eda136ce82270d38f6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Thu, 16 Jan 2020 17:23:40 +1100 Subject: [PATCH 07/41] Revert "Map new nodes with middleware missing routes to parameterized node" This reverts commit c22bcae8d6186a3c74318804814947cc5af76a37. # Conflicts: # mux/tree.go --- mux/node.go | 10 ++++++---- nethttp.go | 1 + path/path.go | 20 +++++++++++++------- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/mux/node.go b/mux/node.go index 872f8c6..aa7c89a 100644 --- a/mux/node.go +++ b/mux/node.go @@ -84,16 +84,18 @@ type staticNode struct { func (n *staticNode) Match(path string) (Node, middleware.Middleware, context.Params, string) { nameLength := len(n.name) pathLength := len(path) - if pathLength >= nameLength && n.name == path[:nameLength] || regexp.MustCompile(`{|}`).MatchString(n.name) { + + if pathLength >= nameLength && n.name == path[:nameLength] { if nameLength+1 >= pathLength { - // is there a better solution here ? it stopped working once the braces were included in node.name - return n, n.middleware, context.Params{{Key: "param", Value: path}}, "" + return n, n.middleware, make(context.Params, n.maxParamsSize), "" } if n.skipSubPath { - return n, n.middleware, context.Params{{Key: "param", Value: path}}, path[nameLength+1:] + return n, n.middleware, make(context.Params, n.maxParamsSize), path[nameLength+1:] } + node, treeMiddleware, params, p := n.children.Match(path[nameLength+1:]) // +1 because we wan to skip slash as well + return node, n.middleware.Merge(treeMiddleware), params, p } diff --git a/nethttp.go b/nethttp.go index f7c490a..6ce8495 100644 --- a/nethttp.go +++ b/nethttp.go @@ -68,6 +68,7 @@ func (r *router) TRACE(p string, f http.Handler) { func (r *router) USE(method, path string, fs ...MiddlewareFunc) { m := transformMiddlewareFunc(fs...) + r.routes = r.routes.WithMiddleware(method+path, m, 0) } diff --git a/path/path.go b/path/path.go index 917baf6..45c6cc5 100644 --- a/path/path.go +++ b/path/path.go @@ -30,17 +30,23 @@ func GetPart(path string) (part string, nextPath string) { return } -// GetNameFromPart gets node name from path part, braces are not stripped so that we can check which node is associated with a parameter +// GetNameFromPart gets node name from path part func GetNameFromPart(pathPart string) (name string, exp string) { name = pathPart - if parts := strings.Split(name, ":"); len(parts) == 2 { - name = parts[0] - exp = parts[1] - } + if pathPart[0] == '{' { + name = pathPart[1 : len(pathPart)-1] + + if parts := strings.Split(name, ":"); len(parts) == 2 { + name = parts[0] + exp = parts[1] + } + + if name == "" { + panic("Empty wildcard name") + } - if name == "" { - panic("Empty wildcard name") + return } return From 68d8a3640ac3962feca8a05d3dbdea514c9f78a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Sat, 18 Jan 2020 12:01:08 +1100 Subject: [PATCH 08/41] Try to find node by matching it name against tree, do not keep orphan nodes --- fasthttp.go | 7 +++---- fasthttp_test.go | 36 +++++++++++++++++++++++++++++------- mux/tree.go | 19 ++++++++++++------- nethttp.go | 7 +++---- nethttp_test.go | 39 ++++++++++++++++++++++++++++++++------- 5 files changed, 79 insertions(+), 29 deletions(-) diff --git a/fasthttp.go b/fasthttp.go index 884b367..ed82e90 100644 --- a/fasthttp.go +++ b/fasthttp.go @@ -68,7 +68,7 @@ func (r *fastHTTPRouter) TRACE(p string, f fasthttp.RequestHandler) { func (r *fastHTTPRouter) USE(method, path string, fs ...FastHTTPMiddlewareFunc) { m := transformFastHTTPMiddlewareFunc(fs...) - r.routes = r.routes.WithMiddleware(method+path, m, 0) + r.routes = r.routes.WithMiddleware(method+path, m) } func (r *fastHTTPRouter) Handle(method, path string, h fasthttp.RequestHandler) { @@ -126,9 +126,8 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) { if node, treeMiddleware, params, subPath := root.Tree().Match(path); node != nil && node.Route() != nil { route := node.Route() handler := route.Handler() - middleware := root.Middleware().Merge(treeMiddleware) - middleware = r.middleware.Merge(middleware) - computedHandler := middleware.Compose(handler) + allMiddleware := r.middleware.Merge(root.Middleware().Merge(treeMiddleware)) + computedHandler := allMiddleware.Compose(handler) h := computedHandler.(fasthttp.RequestHandler) diff --git a/fasthttp_test.go b/fasthttp_test.go index 32e46ee..a3eae7d 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -159,7 +159,9 @@ func TestFastHTTPNotFound(t *testing.T) { } router.NotFound(func(ctx *fasthttp.RequestCtx) { - fmt.Fprintf(ctx, "test") + if _, err := fmt.Fprintf(ctx, "test"); err != nil { + t.Fatal(err) + } }) if router.notFound == nil { @@ -193,7 +195,9 @@ func TestFastHTTPNotAllowed(t *testing.T) { } router.NotAllowed(func(ctx *fasthttp.RequestCtx) { - fmt.Fprintf(ctx, "test") + if _, err := fmt.Fprintf(ctx, "test"); err != nil { + t.Fatal(err) + } }) if router.notAllowed == nil { @@ -327,7 +331,9 @@ func TestFastHTTPNilMiddleware(t *testing.T) { router := NewFastHTTPRouter().(*fastHTTPRouter) router.GET("/x/{param}", func(ctx *fasthttp.RequestCtx) { - fmt.Fprintf(ctx, "test") + if _, err := fmt.Fprintf(ctx, "test"); err != nil { + t.Fatal(err) + } }) ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/y") @@ -380,7 +386,9 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { router.GET("/x/{param}", func(ctx *fasthttp.RequestCtx) { params := ctx.UserValue("params").(context.Params) - fmt.Fprintf(ctx, "%s", params.Value("param")) + if _, err := fmt.Fprintf(ctx, "%s", params.Value("param")); err != nil { + t.Fatal(err) + } }) router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("m1")) @@ -392,6 +400,16 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { if string(ctx.Response.Body()) != "m1y" { t.Errorf("Use middleware error %s", string(ctx.Response.Body())) } + + router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m2")) + + ctx = buildFastHTTPRequestContext(http.MethodGet, "/x/x") + + router.HandleFastHTTP(ctx) + + if string(ctx.Response.Body()) != "m1m2x" { + t.Errorf("Use middleware error %s", string(ctx.Response.Body())) + } } func TestFastHTTPNodeApplyMiddlewareInvalidPath(t *testing.T) { @@ -406,9 +424,11 @@ func TestFastHTTPNodeApplyMiddlewareInvalidPath(t *testing.T) { router := NewFastHTTPRouter().(*fastHTTPRouter) - router.GET("/x/{param}", func(ctx *fasthttp.RequestCtx) { + router.GET("/x/{param:[0-9]+}", func(ctx *fasthttp.RequestCtx) { params := ctx.UserValue("params").(context.Params) - fmt.Fprintf(ctx, "%s", params.Value("param")) + if _, err := fmt.Fprintf(ctx, "%s", params.Value("param")); err != nil { + t.Fatal(err) + } }) router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m2")) @@ -546,7 +566,9 @@ func TestFastHTTPMountSubRouter(t *testing.T) { ).(*fastHTTPRouter) subRouter.GET("/y", func(ctx *fasthttp.RequestCtx) { - fmt.Fprintf(ctx, "[s]") + if _, err := fmt.Fprintf(ctx, "[s]"); err != nil { + t.Fatal(err) + } }) mainRouter.Mount("/{param}", subRouter.HandleFastHTTP) diff --git a/mux/tree.go b/mux/tree.go index dc9a7b1..303e369 100644 --- a/mux/tree.go +++ b/mux/tree.go @@ -26,17 +26,17 @@ func (t Tree) PrettyPrint() string { for _, child := range t { switch node := child.(type) { case *staticNode: - fmt.Fprintf(buff, "\t%s\n", node.Name()) + _, _ = fmt.Fprintf(buff, "\t%s\n", node.Name()) case *wildcardNode: - fmt.Fprintf(buff, "\t{%s}\n", node.Name()) + _, _ = fmt.Fprintf(buff, "\t{%s}\n", node.Name()) case *regexpNode: - fmt.Fprintf(buff, "\t{%s:%s}\n", node.Name(), node.regexp.String()) + _, _ = fmt.Fprintf(buff, "\t{%s:%s}\n", node.Name(), node.regexp.String()) case *subrouterNode: - fmt.Fprintf(buff, "\t_%s\n", node.Name()) + _, _ = fmt.Fprintf(buff, "\t_%s\n", node.Name()) } if len(child.Tree()) > 0 { - fmt.Fprintf(buff, "\t%s", child.Tree().PrettyPrint()) + _, _ = fmt.Fprintf(buff, "\t%s", child.Tree().PrettyPrint()) } } @@ -124,7 +124,7 @@ func (t Tree) WithRoute(path string, route Route, maxParamsSize uint8) Tree { // WithMiddleware returns new Tree with Middleware appended to given Node // Middleware is appended to Node under the give path, if Node does not exist it will panic -func (t Tree) WithMiddleware(path string, m middleware.Middleware, maxParamsSize uint8) Tree { +func (t Tree) WithMiddleware(path string, m middleware.Middleware) Tree { path = pathutils.TrimSlash(path) if path == "" { return t @@ -135,6 +135,11 @@ func (t Tree) WithMiddleware(path string, m middleware.Middleware, maxParamsSize node := t.Find(name) newTree := t + // try to find node by matching name against nodes + if node == nil { + node, _, _, _ = t.Match(name) + } + if node == nil { panic("Could not find node for given path") } @@ -142,7 +147,7 @@ func (t Tree) WithMiddleware(path string, m middleware.Middleware, maxParamsSize if len(parts) == 1 { node.AppendMiddleware(m) } else { - node.WithChildren(node.Tree().WithMiddleware(strings.Join(parts[1:], "/"), m, node.MaxParamsSize())) + node.WithChildren(node.Tree().WithMiddleware(strings.Join(parts[1:], "/"), m)) } return newTree diff --git a/nethttp.go b/nethttp.go index 6ce8495..db8ac15 100644 --- a/nethttp.go +++ b/nethttp.go @@ -69,7 +69,7 @@ func (r *router) TRACE(p string, f http.Handler) { func (r *router) USE(method, path string, fs ...MiddlewareFunc) { m := transformMiddlewareFunc(fs...) - r.routes = r.routes.WithMiddleware(method+path, m, 0) + r.routes = r.routes.WithMiddleware(method+path, m) } func (r *router) Handle(method, path string, h http.Handler) { @@ -128,9 +128,8 @@ func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { if node, treeMiddleware, params, subPath := root.Tree().Match(path); node != nil && node.Route() != nil { route := node.Route() handler := route.Handler() - middleware := root.Middleware().Merge(treeMiddleware) - middleware = r.middleware.Merge(middleware) - computedHandler := middleware.Compose(handler) + allMiddleware := r.middleware.Merge(root.Middleware().Merge(treeMiddleware)) + computedHandler := allMiddleware.Compose(handler) h := computedHandler.(http.Handler) diff --git a/nethttp_test.go b/nethttp_test.go index 1390690..141ad0d 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -164,7 +164,9 @@ func TestNotFound(t *testing.T) { } router.NotFound(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("test")) + if _, err := w.Write([]byte("test")); err != nil { + t.Fatal(err) + } })) if router.notFound == nil { @@ -200,7 +202,9 @@ func TestNotAllowed(t *testing.T) { } router.NotAllowed(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("test")) + if _, err := w.Write([]byte("test")); err != nil { + t.Fatal(err) + } })) if router.notAllowed == nil { @@ -345,7 +349,9 @@ func TestNilMiddleware(t *testing.T) { router := New().(*router) router.GET("/x/{param}", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("test")) + if _, err := w.Write([]byte("test")); err != nil { + t.Fatal(err) + } })) w := httptest.NewRecorder() @@ -406,7 +412,9 @@ func TestNodeApplyMiddleware(t *testing.T) { t.Fatal("Error while reading param") } - w.Write([]byte(params.Value("param"))) + if _, err := w.Write([]byte(params.Value("param"))); err != nil { + t.Fatal(err) + } })) router.USE(http.MethodGet, "/x/{param}", mockMiddleware("m1")) @@ -422,6 +430,19 @@ func TestNodeApplyMiddleware(t *testing.T) { if w.Body.String() != "m1y" { t.Errorf("Use middleware error %s", w.Body.String()) } + + router.USE(http.MethodGet, "/x/x", mockMiddleware("m2")) + + req, err = http.NewRequest(http.MethodGet, "/x/x", nil) + if err != nil { + t.Fatal(err) + } + + router.ServeHTTP(w, req) + + if w.Body.String() != "m1m2x" { + t.Errorf("Use middleware error %s", w.Body.String()) + } } func TestNodeApplyMiddlewareInvalidPath(t *testing.T) { @@ -436,13 +457,15 @@ func TestNodeApplyMiddlewareInvalidPath(t *testing.T) { router := New().(*router) - router.GET("/x/{param}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + router.GET("/x/{param:[0-9]+}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { params, ok := context.Parameters(r.Context()) if !ok { t.Fatal("Error while reading param") } - w.Write([]byte(params.Value("param"))) + if _, err := w.Write([]byte(params.Value("param"))); err != nil { + t.Fatal(err) + } })) router.USE(http.MethodGet, "/x/x", mockMiddleware("m2")) @@ -604,7 +627,9 @@ func TestMountSubRouter(t *testing.T) { ).(*router) subRouter.GET("/y", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("[s]")) + if _, err := w.Write([]byte("[s]")); err != nil { + t.Fatal(err) + } })) mainRouter.Mount("/{param}", subRouter) From bbde4644b453244c713e05a2b7d740e77976941f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Sat, 18 Jan 2020 12:10:47 +1100 Subject: [PATCH 09/41] Fix test, separate outputs --- nethttp_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/nethttp_test.go b/nethttp_test.go index 141ad0d..489471b 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -433,6 +433,7 @@ func TestNodeApplyMiddleware(t *testing.T) { router.USE(http.MethodGet, "/x/x", mockMiddleware("m2")) + w = httptest.NewRecorder() req, err = http.NewRequest(http.MethodGet, "/x/x", nil) if err != nil { t.Fatal(err) From 931376e831b11852ad64618bd1acab00d78e1cc9 Mon Sep 17 00:00:00 2001 From: mar1n3r0 Date: Sun, 19 Jan 2020 13:23:02 +0200 Subject: [PATCH 10/41] Add invalid node reference case to middleware tests --- fasthttp_test.go | 23 +++++++++++++++++++++++ nethttp_test.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/fasthttp_test.go b/fasthttp_test.go index a3eae7d..67c903b 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -412,6 +412,29 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { } } +func TestFastHTTPNodeApplyMiddlewareInvalidNodeReference(t *testing.T) { + t.Parallel() + + router := NewFastHTTPRouter().(*fastHTTPRouter) + + router.GET("/x/{param}", func(ctx *fasthttp.RequestCtx) { + params := ctx.UserValue("params").(context.Params) + if _, err := fmt.Fprintf(ctx, "%s", params.Value("param")); err != nil { + t.Fatal(err) + } + }) + + router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m1")) + + ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/y") + + router.HandleFastHTTP(ctx) + + if string(ctx.Response.Body()) != "y" { + t.Errorf("Use middleware error %s", string(ctx.Response.Body())) + } +} + func TestFastHTTPNodeApplyMiddlewareInvalidPath(t *testing.T) { t.Parallel() diff --git a/nethttp_test.go b/nethttp_test.go index 489471b..4e91403 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -446,6 +446,37 @@ func TestNodeApplyMiddleware(t *testing.T) { } } +func TestNodeApplyMiddlewareInvalidNodeReference(t *testing.T) { + t.Parallel() + + router := New().(*router) + + router.GET("/x/{param}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + params, ok := context.Parameters(r.Context()) + if !ok { + t.Fatal("Error while reading param") + } + + if _, err := w.Write([]byte(params.Value("param"))); err != nil { + t.Fatal(err) + } + })) + + router.USE(http.MethodGet, "/x/x", mockMiddleware("m1")) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/x/y", nil) + if err != nil { + t.Fatal(err) + } + + router.ServeHTTP(w, req) + + if w.Body.String() != "y" { + t.Errorf("Use middleware error %s", w.Body.String()) + } +} + func TestNodeApplyMiddlewareInvalidPath(t *testing.T) { t.Parallel() From cfe703a18303ea97b61e9d9d5d0b83cdf01539b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Mon, 20 Jan 2020 22:02:45 +1100 Subject: [PATCH 11/41] Consider orphan nodes middleware while finding route node match --- fasthttp.go | 2 +- mux/tree.go | 49 +++++++++++++++++++++++++++++++++++++++---------- nethttp.go | 2 +- 3 files changed, 41 insertions(+), 12 deletions(-) diff --git a/fasthttp.go b/fasthttp.go index ed82e90..81958c3 100644 --- a/fasthttp.go +++ b/fasthttp.go @@ -68,7 +68,7 @@ func (r *fastHTTPRouter) TRACE(p string, f fasthttp.RequestHandler) { func (r *fastHTTPRouter) USE(method, path string, fs ...FastHTTPMiddlewareFunc) { m := transformFastHTTPMiddlewareFunc(fs...) - r.routes = r.routes.WithMiddleware(method+path, m) + r.routes = r.routes.WithMiddleware(method+path, m, 0) } func (r *fastHTTPRouter) Handle(method, path string, h fasthttp.RequestHandler) { diff --git a/mux/tree.go b/mux/tree.go index 303e369..a3df685 100644 --- a/mux/tree.go +++ b/mux/tree.go @@ -19,6 +19,14 @@ func NewTree() Tree { // Tree slice of mux Nodes type Tree []Node +// Match represents path match data struct +type match struct { + node Node + middleware middleware.Middleware + params context.Params + subPath string +} + // PrettyPrint prints the tree text representation to console func (t Tree) PrettyPrint() string { buff := &bytes.Buffer{} @@ -71,16 +79,41 @@ func (t Tree) Compile() Tree { // Match path to Node func (t Tree) Match(path string) (Node, middleware.Middleware, context.Params, string) { + var orphanMatches []match + for _, child := range t { if node, m, params, subPath := child.Match(path); node != nil { - return node, m, params, subPath + if node.Route() != nil { + if len(orphanMatches) > 0 { + for i := len(orphanMatches) - 1; i >= 0; i-- { + m = orphanMatches[i].node.Middleware().Merge(m) + } + } + + return node, m, params, subPath + } + + orphanMatch := match{ + node: node, + middleware: m, + params: params, + subPath: subPath, + } + orphanMatches = append(orphanMatches, orphanMatch) } } + // no route found, return first orphan match + if len(orphanMatches) > 0 { + firstOrphanMatch := orphanMatches[0] + + return firstOrphanMatch.node, firstOrphanMatch.middleware, firstOrphanMatch.params, firstOrphanMatch.subPath + } + return nil, nil, nil, "" } -// Find Node inside a tree by name +// Find finds Node inside a tree by name func (t Tree) Find(name string) Node { if name == "" { return nil @@ -124,7 +157,7 @@ func (t Tree) WithRoute(path string, route Route, maxParamsSize uint8) Tree { // WithMiddleware returns new Tree with Middleware appended to given Node // Middleware is appended to Node under the give path, if Node does not exist it will panic -func (t Tree) WithMiddleware(path string, m middleware.Middleware) Tree { +func (t Tree) WithMiddleware(path string, m middleware.Middleware, maxParamsSize uint8) Tree { path = pathutils.TrimSlash(path) if path == "" { return t @@ -135,19 +168,15 @@ func (t Tree) WithMiddleware(path string, m middleware.Middleware) Tree { node := t.Find(name) newTree := t - // try to find node by matching name against nodes - if node == nil { - node, _, _, _ = t.Match(name) - } - if node == nil { - panic("Could not find node for given path") + node = NewNode(parts[0], maxParamsSize) + newTree = t.withNode(node) } if len(parts) == 1 { node.AppendMiddleware(m) } else { - node.WithChildren(node.Tree().WithMiddleware(strings.Join(parts[1:], "/"), m)) + node.WithChildren(node.Tree().WithMiddleware(strings.Join(parts[1:], "/"), m, node.MaxParamsSize())) } return newTree diff --git a/nethttp.go b/nethttp.go index db8ac15..cc7ee20 100644 --- a/nethttp.go +++ b/nethttp.go @@ -69,7 +69,7 @@ func (r *router) TRACE(p string, f http.Handler) { func (r *router) USE(method, path string, fs ...MiddlewareFunc) { m := transformMiddlewareFunc(fs...) - r.routes = r.routes.WithMiddleware(method+path, m) + r.routes = r.routes.WithMiddleware(method+path, m, 0) } func (r *router) Handle(method, path string, h http.Handler) { From 1853262ea0bcff88033c604a040dba4fe44dd5d4 Mon Sep 17 00:00:00 2001 From: mar1n3r0 Date: Mon, 20 Jan 2020 13:37:49 +0200 Subject: [PATCH 12/41] Add static case to middleware tests --- fasthttp_test.go | 19 +++++++++++++++++++ nethttp_test.go | 23 +++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/fasthttp_test.go b/fasthttp_test.go index 67c903b..4d6d11c 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -412,6 +412,25 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { } } +func TestFastHTTPNodeApplyMiddlewareStatic(t *testing.T) { + t.Parallel() + + router := NewFastHTTPRouter().(*fastHTTPRouter) + + router.GET("/x/x", func(ctx *fasthttp.RequestCtx) { + }) + + router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m1")) + + ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/x") + + router.HandleFastHTTP(ctx) + + if string(ctx.Response.Body()) != "m1x" { + t.Errorf("Use middleware error %s", string(ctx.Response.Body())) + } +} + func TestFastHTTPNodeApplyMiddlewareInvalidNodeReference(t *testing.T) { t.Parallel() diff --git a/nethttp_test.go b/nethttp_test.go index 4e91403..d50194e 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -446,6 +446,29 @@ func TestNodeApplyMiddleware(t *testing.T) { } } +func TestNodeApplyMiddlewareStatic(t *testing.T) { + t.Parallel() + + router := New().(*router) + + router.GET("/x/{param}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + + router.USE(http.MethodGet, "/x/x", mockMiddleware("m1")) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/x/x", nil) + if err != nil { + t.Fatal(err) + } + + router.ServeHTTP(w, req) + + if w.Body.String() != "m1x" { + t.Errorf("Use middleware error %s", w.Body.String()) + } +} + func TestNodeApplyMiddlewareInvalidNodeReference(t *testing.T) { t.Parallel() From 9c15d4839af7069e4b41c2ff1eae5409a1a94a6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Tue, 21 Jan 2020 08:11:15 +1100 Subject: [PATCH 13/41] Revers order of orphan middleware --- mux/tree.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mux/tree.go b/mux/tree.go index a3df685..fb925d1 100644 --- a/mux/tree.go +++ b/mux/tree.go @@ -85,7 +85,7 @@ func (t Tree) Match(path string) (Node, middleware.Middleware, context.Params, s if node, m, params, subPath := child.Match(path); node != nil { if node.Route() != nil { if len(orphanMatches) > 0 { - for i := len(orphanMatches) - 1; i >= 0; i-- { + for i := 0; i < len(orphanMatches); i++ { m = orphanMatches[i].node.Middleware().Merge(m) } } From 2035e11d53e10abf1537e70787e7543f5ea4dbde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Tue, 21 Jan 2020 21:21:44 +1100 Subject: [PATCH 14/41] Fix handler body --- fasthttp_test.go | 3 +++ nethttp_test.go | 3 +++ 2 files changed, 6 insertions(+) diff --git a/fasthttp_test.go b/fasthttp_test.go index 4d6d11c..de85759 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -418,6 +418,9 @@ func TestFastHTTPNodeApplyMiddlewareStatic(t *testing.T) { router := NewFastHTTPRouter().(*fastHTTPRouter) router.GET("/x/x", func(ctx *fasthttp.RequestCtx) { + if _, err := fmt.Fprintf(ctx, "x"); err != nil { + t.Fatal(err) + } }) router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m1")) diff --git a/nethttp_test.go b/nethttp_test.go index d50194e..4ca8079 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -452,6 +452,9 @@ func TestNodeApplyMiddlewareStatic(t *testing.T) { router := New().(*router) router.GET("/x/{param}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, err := w.Write([]byte("x")); err != nil { + t.Fatal(err) + } })) router.USE(http.MethodGet, "/x/x", mockMiddleware("m1")) From c571661533cf426f9f9c52d98e715b4b7a9a1284 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Tue, 21 Jan 2020 21:29:11 +1100 Subject: [PATCH 15/41] Static middleware is applied before wildcard one (due to nodes sorting in the tree) --- fasthttp_test.go | 35 ++--------------------------------- nethttp_test.go | 43 ++----------------------------------------- 2 files changed, 4 insertions(+), 74 deletions(-) diff --git a/fasthttp_test.go b/fasthttp_test.go index de85759..cf4004b 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -392,6 +392,7 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { }) router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("m1")) + router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m2")) ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/y") @@ -401,13 +402,11 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { t.Errorf("Use middleware error %s", string(ctx.Response.Body())) } - router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m2")) - ctx = buildFastHTTPRequestContext(http.MethodGet, "/x/x") router.HandleFastHTTP(ctx) - if string(ctx.Response.Body()) != "m1m2x" { + if string(ctx.Response.Body()) != "m2m1x" { t.Errorf("Use middleware error %s", string(ctx.Response.Body())) } } @@ -457,36 +456,6 @@ func TestFastHTTPNodeApplyMiddlewareInvalidNodeReference(t *testing.T) { } } -func TestFastHTTPNodeApplyMiddlewareInvalidPath(t *testing.T) { - t.Parallel() - - panicked := false - defer func() { - if rcv := recover(); rcv != nil { - panicked = true - } - }() - - router := NewFastHTTPRouter().(*fastHTTPRouter) - - router.GET("/x/{param:[0-9]+}", func(ctx *fasthttp.RequestCtx) { - params := ctx.UserValue("params").(context.Params) - if _, err := fmt.Fprintf(ctx, "%s", params.Value("param")); err != nil { - t.Fatal(err) - } - }) - - router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m2")) - - ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/x") - - router.HandleFastHTTP(ctx) - - if panicked != true { - t.Error("Router should panic for invalid middleware path") - } -} - func TestFastHTTPChainCalls(t *testing.T) { t.Parallel() diff --git a/nethttp_test.go b/nethttp_test.go index 4ca8079..5984938 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -418,6 +418,7 @@ func TestNodeApplyMiddleware(t *testing.T) { })) router.USE(http.MethodGet, "/x/{param}", mockMiddleware("m1")) + router.USE(http.MethodGet, "/x/x", mockMiddleware("m2")) w := httptest.NewRecorder() req, err := http.NewRequest(http.MethodGet, "/x/y", nil) @@ -431,8 +432,6 @@ func TestNodeApplyMiddleware(t *testing.T) { t.Errorf("Use middleware error %s", w.Body.String()) } - router.USE(http.MethodGet, "/x/x", mockMiddleware("m2")) - w = httptest.NewRecorder() req, err = http.NewRequest(http.MethodGet, "/x/x", nil) if err != nil { @@ -441,7 +440,7 @@ func TestNodeApplyMiddleware(t *testing.T) { router.ServeHTTP(w, req) - if w.Body.String() != "m1m2x" { + if w.Body.String() != "m2m1x" { t.Errorf("Use middleware error %s", w.Body.String()) } } @@ -503,44 +502,6 @@ func TestNodeApplyMiddlewareInvalidNodeReference(t *testing.T) { } } -func TestNodeApplyMiddlewareInvalidPath(t *testing.T) { - t.Parallel() - - panicked := false - defer func() { - if rcv := recover(); rcv != nil { - panicked = true - } - }() - - router := New().(*router) - - router.GET("/x/{param:[0-9]+}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - params, ok := context.Parameters(r.Context()) - if !ok { - t.Fatal("Error while reading param") - } - - if _, err := w.Write([]byte(params.Value("param"))); err != nil { - t.Fatal(err) - } - })) - - router.USE(http.MethodGet, "/x/x", mockMiddleware("m2")) - - w := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/x/x", nil) - if err != nil { - t.Fatal(err) - } - - router.ServeHTTP(w, req) - - if panicked != true { - t.Error("Router should panic for invalid middleware path") - } -} - func TestChainCalls(t *testing.T) { t.Parallel() From 2ce51cd25a392a2f5779ce2111db247f805018cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Thu, 23 Jan 2020 21:43:49 +1100 Subject: [PATCH 16/41] Append orphan nodes middleware instead of prepending --- fasthttp_test.go | 33 +++++++++++++++++++++++++++++++++ mux/tree.go | 2 +- nethttp_test.go | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/fasthttp_test.go b/fasthttp_test.go index cf4004b..37c301a 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -411,6 +411,39 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { } } +func TestFastHTTPTreeOrphanMiddlewareOrder(t *testing.T) { + t.Parallel() + + router := NewFastHTTPRouter().(*fastHTTPRouter) + + router.GET("/x/{param}", func(ctx *fasthttp.RequestCtx) { + if _, err := fmt.Fprintf(ctx, "handler"); err != nil { + t.Fatal(err) + } + }) + + // Method global middleware + router.USE(http.MethodGet, "/", mockFastHTTPMiddleware("m1->")) + router.USE(http.MethodGet, "/", mockFastHTTPMiddleware("m2->")) + // Path middleware + router.USE(http.MethodGet, "/x", mockFastHTTPMiddleware("mx1->")) + router.USE(http.MethodGet, "/x", mockFastHTTPMiddleware("mx2->")) + router.USE(http.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy1->")) + router.USE(http.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy2->")) + router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("mparam1->")) + router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("mparam2->")) + router.USE(http.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy3->")) + router.USE(http.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy4->")) + + ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/y") + + router.HandleFastHTTP(ctx) + + if string(ctx.Response.Body()) != "m1->m2->mx1->mx2->mparam1->mparam2->mxy1->mxy2->mxy3->mxy4->handler" { + t.Errorf("Use middleware error %s", string(ctx.Response.Body())) + } +} + func TestFastHTTPNodeApplyMiddlewareStatic(t *testing.T) { t.Parallel() diff --git a/mux/tree.go b/mux/tree.go index fb925d1..f273d26 100644 --- a/mux/tree.go +++ b/mux/tree.go @@ -86,7 +86,7 @@ func (t Tree) Match(path string) (Node, middleware.Middleware, context.Params, s if node.Route() != nil { if len(orphanMatches) > 0 { for i := 0; i < len(orphanMatches); i++ { - m = orphanMatches[i].node.Middleware().Merge(m) + m = m.Merge(orphanMatches[i].node.Middleware()) } } diff --git a/nethttp_test.go b/nethttp_test.go index 5984938..645e8e2 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -445,6 +445,43 @@ func TestNodeApplyMiddleware(t *testing.T) { } } +func TestTreeOrphanMiddlewareOrder(t *testing.T) { + t.Parallel() + + router := New().(*router) + + router.GET("/x/{param}", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, err := w.Write([]byte("handler")); err != nil { + t.Fatal(err) + } + })) + + // Method global middleware + router.USE(http.MethodGet, "/", mockMiddleware("m1->")) + router.USE(http.MethodGet, "/", mockMiddleware("m2->")) + // Path middleware + router.USE(http.MethodGet, "/x", mockMiddleware("mx1->")) + router.USE(http.MethodGet, "/x", mockMiddleware("mx2->")) + router.USE(http.MethodGet, "/x/y", mockMiddleware("mxy1->")) + router.USE(http.MethodGet, "/x/y", mockMiddleware("mxy2->")) + router.USE(http.MethodGet, "/x/{param}", mockMiddleware("mparam1->")) + router.USE(http.MethodGet, "/x/{param}", mockMiddleware("mparam2->")) + router.USE(http.MethodGet, "/x/y", mockMiddleware("mxy3->")) + router.USE(http.MethodGet, "/x/y", mockMiddleware("mxy4->")) + + w := httptest.NewRecorder() + req, err := http.NewRequest(http.MethodGet, "/x/y", nil) + if err != nil { + t.Fatal(err) + } + + router.ServeHTTP(w, req) + + if w.Body.String() != "m1->m2->mx1->mx2->mparam1->mparam2->mxy1->mxy2->mxy3->mxy4->handler" { + t.Errorf("Use middleware error %s", w.Body.String()) + } +} + func TestNodeApplyMiddlewareStatic(t *testing.T) { t.Parallel() From 5468041f673d2287a962657ccd5c63820a68cecc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Fri, 24 Jan 2020 22:49:39 +1100 Subject: [PATCH 17/41] Update test to comply with new order --- fasthttp_test.go | 2 +- nethttp_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fasthttp_test.go b/fasthttp_test.go index 37c301a..dd4879c 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -406,7 +406,7 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { router.HandleFastHTTP(ctx) - if string(ctx.Response.Body()) != "m2m1x" { + if string(ctx.Response.Body()) != "m1m2x" { t.Errorf("Use middleware error %s", string(ctx.Response.Body())) } } diff --git a/nethttp_test.go b/nethttp_test.go index 645e8e2..8366c22 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -440,7 +440,7 @@ func TestNodeApplyMiddleware(t *testing.T) { router.ServeHTTP(w, req) - if w.Body.String() != "m2m1x" { + if w.Body.String() != "m1m2x" { t.Errorf("Use middleware error %s", w.Body.String()) } } From 3cd9b3237d92375f4bcfdf60093f953dc84d2fe2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Fri, 24 Jan 2020 23:07:38 +1100 Subject: [PATCH 18/41] Extract sort method --- mux/benchmark_test.go | 20 ++++++++++---------- mux/example_test.go | 22 +++++++++++----------- mux/tree.go | 20 ++++++++++++-------- mux/tree_test.go | 22 +++++++++++----------- 4 files changed, 44 insertions(+), 40 deletions(-) diff --git a/mux/benchmark_test.go b/mux/benchmark_test.go index a862a31..c1b142a 100644 --- a/mux/benchmark_test.go +++ b/mux/benchmark_test.go @@ -21,16 +21,16 @@ func BenchmarkMux(b *testing.B) { commentID := NewNode(`{commentId:\d+}`, comments.MaxParamsSize()) commentNew := NewNode("new", commentID.MaxParamsSize()) - root.WithChildren(root.Tree().withNode(lang)) - lang.WithChildren(lang.Tree().withNode(blog)) - blog.WithChildren(blog.Tree().withNode(search)) - blog.WithChildren(blog.Tree().withNode(page)) - blog.WithChildren(blog.Tree().withNode(posts)) - blog.WithChildren(blog.Tree().withNode(comments)) - page.WithChildren(page.Tree().withNode(pageID)) - posts.WithChildren(posts.Tree().withNode(postsID)) - comments.WithChildren(comments.Tree().withNode(commentID)) - commentID.WithChildren(commentID.Tree().withNode(commentNew)) + root.WithChildren(root.Tree().withNode(lang).sort()) + lang.WithChildren(lang.Tree().withNode(blog).sort()) + blog.WithChildren(blog.Tree().withNode(search).sort()) + blog.WithChildren(blog.Tree().withNode(page).sort()) + blog.WithChildren(blog.Tree().withNode(posts).sort()) + blog.WithChildren(blog.Tree().withNode(comments).sort()) + page.WithChildren(page.Tree().withNode(pageID).sort()) + posts.WithChildren(posts.Tree().withNode(postsID).sort()) + comments.WithChildren(comments.Tree().withNode(commentID).sort()) + commentID.WithChildren(commentID.Tree().withNode(commentNew).sort()) root.WithChildren(root.Tree().Compile()) diff --git a/mux/example_test.go b/mux/example_test.go index ef2ed7b..b5b14ee 100644 --- a/mux/example_test.go +++ b/mux/example_test.go @@ -21,17 +21,17 @@ func Example() { commentID := NewNode(`{commentId:\d+}`, comments.MaxParamsSize()) commentNew := NewNode("new", commentID.MaxParamsSize()) - root.WithChildren(root.Tree().withNode(lang)) - lang.WithChildren(lang.Tree().withNode(blog)) - blog.WithChildren(blog.Tree().withNode(search)) - blog.WithChildren(blog.Tree().withNode(page)) - blog.WithChildren(blog.Tree().withNode(posts)) - blog.WithChildren(blog.Tree().withNode(comments)) - search.WithChildren(search.Tree().withNode(searchAuthor)) - page.WithChildren(page.Tree().withNode(pageID)) - posts.WithChildren(posts.Tree().withNode(postsID)) - comments.WithChildren(comments.Tree().withNode(commentID)) - commentID.WithChildren(commentID.Tree().withNode(commentNew)) + root.WithChildren(root.Tree().withNode(lang).sort()) + lang.WithChildren(lang.Tree().withNode(blog).sort()) + blog.WithChildren(blog.Tree().withNode(search).sort()) + blog.WithChildren(blog.Tree().withNode(page).sort()) + blog.WithChildren(blog.Tree().withNode(posts).sort()) + blog.WithChildren(blog.Tree().withNode(comments).sort()) + search.WithChildren(search.Tree().withNode(searchAuthor).sort()) + page.WithChildren(page.Tree().withNode(pageID).sort()) + posts.WithChildren(posts.Tree().withNode(postsID).sort()) + comments.WithChildren(comments.Tree().withNode(commentID).sort()) + commentID.WithChildren(commentID.Tree().withNode(commentNew).sort()) fmt.Printf("Raw tree:\n") fmt.Print(root.Tree().PrettyPrint()) diff --git a/mux/tree.go b/mux/tree.go index f273d26..83bc8c6 100644 --- a/mux/tree.go +++ b/mux/tree.go @@ -77,7 +77,7 @@ func (t Tree) Compile() Tree { return t } -// Match path to Node +// Match path to first Node func (t Tree) Match(path string) (Node, middleware.Middleware, context.Params, string) { var orphanMatches []match @@ -143,7 +143,7 @@ func (t Tree) WithRoute(path string, route Route, maxParamsSize uint8) Tree { if node == nil { node = NewNode(parts[0], maxParamsSize) - newTree = t.withNode(node) + newTree = t.withNode(node).sort() } if len(parts) == 1 { @@ -170,7 +170,7 @@ func (t Tree) WithMiddleware(path string, m middleware.Middleware, maxParamsSize if node == nil { node = NewNode(parts[0], maxParamsSize) - newTree = t.withNode(node) + newTree = t.withNode(node).sort() } if len(parts) == 1 { @@ -200,7 +200,7 @@ func (t Tree) WithSubrouter(path string, route Route, maxParamsSize uint8) Tree if len(parts) == 1 { node = withSubrouter(node) } - newTree = t.withNode(node) + newTree = t.withNode(node).sort() } if len(parts) == 1 { @@ -213,7 +213,6 @@ func (t Tree) WithSubrouter(path string, route Route, maxParamsSize uint8) Tree } // withNode inserts node to Tree -// Nodes are sorted static, regexp, wildcard func (t Tree) withNode(node Node) Tree { if node == nil { return t @@ -221,12 +220,17 @@ func (t Tree) withNode(node Node) Tree { newTree := append(t, node) + return newTree +} + +// Sort sorts nodes in order: static, regexp, wildcard +func (t Tree) sort() Tree { // Sort Nodes in order [statics, regexps, wildcards] - sort.Slice(newTree, func(i, j int) bool { - return isMoreImportant(newTree[i], newTree[j]) + sort.Slice(t, func(i, j int) bool { + return isMoreImportant(t[i], t[j]) }) - return newTree + return t } func isMoreImportant(left Node, right Node) bool { diff --git a/mux/tree_test.go b/mux/tree_test.go index c5dc3b7..6b2294a 100644 --- a/mux/tree_test.go +++ b/mux/tree_test.go @@ -23,17 +23,17 @@ func TestTreeMatch(t *testing.T) { commentID := NewNode(`{commentId:\d+}`, comments.MaxParamsSize()) commentNew := NewNode("new", commentID.MaxParamsSize()) - root.WithChildren(root.Tree().withNode(lang)) - lang.WithChildren(lang.Tree().withNode(blog)) - blog.WithChildren(blog.Tree().withNode(search)) - blog.WithChildren(blog.Tree().withNode(page)) - blog.WithChildren(blog.Tree().withNode(posts)) - blog.WithChildren(blog.Tree().withNode(comments)) - search.WithChildren(search.Tree().withNode(searchAuthor)) - page.WithChildren(page.Tree().withNode(pageID)) - posts.WithChildren(posts.Tree().withNode(postsID)) - comments.WithChildren(comments.Tree().withNode(commentID)) - commentID.WithChildren(commentID.Tree().withNode(commentNew)) + root.WithChildren(root.Tree().withNode(lang).sort()) + lang.WithChildren(lang.Tree().withNode(blog).sort()) + blog.WithChildren(blog.Tree().withNode(search).sort()) + blog.WithChildren(blog.Tree().withNode(page).sort()) + blog.WithChildren(blog.Tree().withNode(posts).sort()) + blog.WithChildren(blog.Tree().withNode(comments).sort()) + search.WithChildren(search.Tree().withNode(searchAuthor).sort()) + page.WithChildren(page.Tree().withNode(pageID).sort()) + posts.WithChildren(posts.Tree().withNode(postsID).sort()) + comments.WithChildren(comments.Tree().withNode(commentID).sort()) + commentID.WithChildren(commentID.Tree().withNode(commentNew).sort()) root.WithChildren(root.Tree().Compile()) From e6c790ab3380969da62d21b29a53e4c3bfbdadea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Mon, 27 Jan 2020 08:34:17 +1100 Subject: [PATCH 19/41] Handle errors --- middleware/middleware_test.go | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 15a7dad..742e887 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -9,7 +9,9 @@ import ( func mockMiddleware(body string) MiddlewareFunc { fn := func(h interface{}) interface{} { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(body)) + if _, err := w.Write([]byte(body)); err != nil { + panic(err) + } h.(http.Handler).ServeHTTP(w, r) }) } @@ -22,7 +24,9 @@ func TestOrders(t *testing.T) { m2 := mockMiddleware("2") m3 := mockMiddleware("3") fn := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("4")) + if _, err := w.Write([]byte("4")); err != nil { + t.Fatal(err) + } }) m := New(m1, m2, m3) @@ -46,7 +50,9 @@ func TestAppend(t *testing.T) { m2 := mockMiddleware("2") m3 := mockMiddleware("3") fn := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("4")) + if _, err := w.Write([]byte("4")); err != nil { + t.Fatal(err) + } }) m := New(m1) @@ -71,7 +77,9 @@ func TestMerge(t *testing.T) { m2 := mockMiddleware("2") m3 := mockMiddleware("3") fn := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("4")) + if _, err := w.Write([]byte("4")); err != nil { + t.Fatal(err) + } }) m := New(m1) From 532090c79797ef780711bcac78edf0c33247de14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Mon, 27 Jan 2020 08:35:11 +1100 Subject: [PATCH 20/41] Split interface, match route and middleware separately [WIP] --- README.md | 24 ++++--- doc.go | 2 +- example_test.go | 22 +++---- fasthttp.go | 37 +++++++---- fasthttp_test.go | 18 +++--- mocks_test.go | 8 ++- mux/benchmark_test.go | 10 +-- mux/node.go | 144 +++++++++++++++++++++++++++--------------- mux/tree.go | 58 ++++++----------- mux/tree_test.go | 10 +-- nethttp.go | 37 +++++++---- nethttp_test.go | 18 +++--- route_test.go | 8 ++- router.go | 8 +-- tree.go | 17 +---- 15 files changed, 230 insertions(+), 191 deletions(-) diff --git a/README.md b/README.md index c1e8094..8b20a7c 100644 --- a/README.md +++ b/README.md @@ -42,19 +42,23 @@ import ( "github.com/vardius/gorouter/v4/context" ) -func Index(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "Welcome!\n") +func index(w http.ResponseWriter, _ *http.Request) { + if _, err := fmt.Fprint(w, "Welcome!\n"); err != nil { + panic(err) + } } -func Hello(w http.ResponseWriter, r *http.Request) { +func hello(w http.ResponseWriter, r *http.Request) { params, _ := context.Parameters(r.Context()) - fmt.Fprintf(w, "hello, %s!\n", params.Value("name")) + if _, err := fmt.Fprintf(w, "hello, %s!\n", params.Value("name")); err != nil { + panic(err) + } } func main() { router := gorouter.New() - router.GET("/", http.HandlerFunc(Index)) - router.GET("/hello/{name}", http.HandlerFunc(Hello)) + router.GET("/", http.HandlerFunc(index)) + router.GET("/hello/{name}", http.HandlerFunc(hello)) log.Fatal(http.ListenAndServe(":8080", router)) } @@ -71,19 +75,19 @@ import ( "github.com/vardius/gorouter/v4" ) -func Index(ctx *fasthttp.RequestCtx) { +func index(_ *fasthttp.RequestCtx) { fmt.Print("Welcome!\n") } -func Hello(ctx *fasthttp.RequestCtx) { +func hello(ctx *fasthttp.RequestCtx) { params := ctx.UserValue("params").(context.Params) fmt.Printf("Hello, %s!\n", params.Value("name")) } func main() { router := gorouter.NewFastHTTPRouter() - router.GET("/", Index) - router.GET("/hello/{name}", Hello) + router.GET("/", index) + router.GET("/hello/{name}", hello) log.Fatal(fasthttp.ListenAndServe(":8080", router.HandleFastHTTP)) } diff --git a/doc.go b/doc.go index f1f0705..95a2154 100644 --- a/doc.go +++ b/doc.go @@ -1,5 +1,5 @@ /* -Package gorouter provide request router with middleware +Package gorouter provide request router with globalMiddleware Router diff --git a/example_test.go b/example_test.go index 6e261ca..2e5d731 100644 --- a/example_test.go +++ b/example_test.go @@ -65,7 +65,7 @@ func Example_second() { } func ExampleMiddlewareFunc() { - // Global middleware example + // Global globalMiddleware example // applies to all routes hello := func(w http.ResponseWriter, r *http.Request) { params, _ := context.Parameters(r.Context()) @@ -81,7 +81,7 @@ func ExampleMiddlewareFunc() { return http.HandlerFunc(fn) } - // apply middleware to all routes + // apply globalMiddleware to all routes // can pass as many as you want router := gorouter.New(logger) router.GET("/hello/{name}", http.HandlerFunc(hello)) @@ -95,7 +95,7 @@ func ExampleMiddlewareFunc() { } func ExampleMiddlewareFunc_second() { - // Route level middleware example + // Route level globalMiddleware example // applies to route and its lower tree hello := func(w http.ResponseWriter, r *http.Request) { params, _ := context.Parameters(r.Context()) @@ -114,7 +114,7 @@ func ExampleMiddlewareFunc_second() { router := gorouter.New() router.GET("/hello/{name}", http.HandlerFunc(hello)) - // apply middleware to route and all it children + // apply globalMiddleware to route and all it children // can pass as many as you want router.USE("GET", "/hello/{name}", logger) @@ -127,7 +127,7 @@ func ExampleMiddlewareFunc_second() { } func ExampleMiddlewareFunc_third() { - // Http method middleware example + // Http method globalMiddleware example // applies to all routes under this method hello := func(w http.ResponseWriter, r *http.Request) { params, _ := context.Parameters(r.Context()) @@ -146,7 +146,7 @@ func ExampleMiddlewareFunc_third() { router := gorouter.New() router.GET("/hello/{name}", http.HandlerFunc(hello)) - // apply middleware to all routes with GET method + // apply globalMiddleware to all routes with GET method // can pass as many as you want router.USE("GET", "", logger) @@ -159,7 +159,7 @@ func ExampleMiddlewareFunc_third() { } func ExampleFastHTTPMiddlewareFunc() { - // Global middleware example + // Global globalMiddleware example // applies to all routes hello := func(ctx *fasthttp.RequestCtx) { params := ctx.UserValue("params").(context.Params) @@ -187,7 +187,7 @@ func ExampleFastHTTPMiddlewareFunc() { } func ExampleFastHTTPMiddlewareFunc_second() { - // Route level middleware example + // Route level globalMiddleware example // applies to route and its lower tree hello := func(ctx *fasthttp.RequestCtx) { params := ctx.UserValue("params").(context.Params) @@ -206,7 +206,7 @@ func ExampleFastHTTPMiddlewareFunc_second() { router := gorouter.NewFastHTTPRouter() router.GET("/hello/{name}", hello) - // apply middleware to route and all it children + // apply globalMiddleware to route and all it children // can pass as many as you want router.USE("GET", "/hello/{name}", logger) @@ -219,7 +219,7 @@ func ExampleFastHTTPMiddlewareFunc_second() { } func ExampleFastHTTPMiddlewareFunc_third() { - // Http method middleware example + // Http method globalMiddleware example // applies to all routes under this method hello := func(ctx *fasthttp.RequestCtx) { params := ctx.UserValue("params").(context.Params) @@ -238,7 +238,7 @@ func ExampleFastHTTPMiddlewareFunc_third() { router := gorouter.NewFastHTTPRouter() router.GET("/hello/{name}", hello) - // apply middleware to all routes with GET method + // apply globalMiddleware to all routes with GET method // can pass as many as you want router.USE("GET", "", logger) diff --git a/fasthttp.go b/fasthttp.go index 81958c3..6c442c9 100644 --- a/fasthttp.go +++ b/fasthttp.go @@ -12,17 +12,19 @@ import ( // NewFastHTTPRouter creates new Router instance, returns pointer func NewFastHTTPRouter(fs ...FastHTTPMiddlewareFunc) FastHTTPRouter { return &fastHTTPRouter{ - routes: mux.NewTree(), - middleware: transformFastHTTPMiddlewareFunc(fs...), + routes: mux.NewTree(), + middleware: mux.NewTree(), + globalMiddleware: transformFastHTTPMiddlewareFunc(fs...), } } type fastHTTPRouter struct { - routes mux.Tree - middleware middleware.Middleware - fileServer fasthttp.RequestHandler - notFound fasthttp.RequestHandler - notAllowed fasthttp.RequestHandler + routes mux.Tree // mux.RouteAware tree + middleware mux.Tree // mux.MiddlewareAware tree + globalMiddleware middleware.Middleware + fileServer fasthttp.RequestHandler + notFound fasthttp.RequestHandler + notAllowed fasthttp.RequestHandler } func (r *fastHTTPRouter) PrettyPrint() string { @@ -68,7 +70,7 @@ func (r *fastHTTPRouter) TRACE(p string, f fasthttp.RequestHandler) { func (r *fastHTTPRouter) USE(method, path string, fs ...FastHTTPMiddlewareFunc) { m := transformFastHTTPMiddlewareFunc(fs...) - r.routes = r.routes.WithMiddleware(method+path, m, 0) + r.middleware = r.middleware.WithMiddleware(method+path, m) } func (r *fastHTTPRouter) Handle(method, path string, h fasthttp.RequestHandler) { @@ -123,11 +125,13 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) { path := pathutils.TrimSlash(pathAsString) if root := r.routes.Find(method); root != nil { - if node, treeMiddleware, params, subPath := root.Tree().Match(path); node != nil && node.Route() != nil { - route := node.Route() - handler := route.Handler() - allMiddleware := r.middleware.Merge(root.Middleware().Merge(treeMiddleware)) - computedHandler := allMiddleware.Compose(handler) + if route, params, subPath := root.Tree().MatchRoute(path); route != nil { + allMiddleware := r.globalMiddleware + if treeMiddleware := r.middleware.MatchMiddleware(method + path); treeMiddleware != nil { + allMiddleware = allMiddleware.Merge(treeMiddleware) + } + + computedHandler := allMiddleware.Compose(route.Handler()) h := computedHandler.(fasthttp.RequestHandler) @@ -144,7 +148,12 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) { } if pathAsString == "/" && root.Route() != nil { - root.Route().Handler().(fasthttp.RequestHandler)(ctx) + rootMiddleware := r.globalMiddleware + if root.Middleware() != nil { + rootMiddleware = rootMiddleware.Merge(root.Middleware()) + } + rootHandler := rootMiddleware.Compose(root.Route().Handler()) + rootHandler.(fasthttp.RequestHandler)(ctx) return } } diff --git a/fasthttp_test.go b/fasthttp_test.go index dd4879c..a1bd0fa 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -341,7 +341,7 @@ func TestFastHTTPNilMiddleware(t *testing.T) { router.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "test" { - t.Error("Nil middleware works") + t.Error("Nil globalMiddleware works") } } @@ -399,7 +399,7 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { router.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "m1y" { - t.Errorf("Use middleware error %s", string(ctx.Response.Body())) + t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) } ctx = buildFastHTTPRequestContext(http.MethodGet, "/x/x") @@ -407,7 +407,7 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { router.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "m1m2x" { - t.Errorf("Use middleware error %s", string(ctx.Response.Body())) + t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) } } @@ -422,10 +422,10 @@ func TestFastHTTPTreeOrphanMiddlewareOrder(t *testing.T) { } }) - // Method global middleware + // Method global globalMiddleware router.USE(http.MethodGet, "/", mockFastHTTPMiddleware("m1->")) router.USE(http.MethodGet, "/", mockFastHTTPMiddleware("m2->")) - // Path middleware + // Path globalMiddleware router.USE(http.MethodGet, "/x", mockFastHTTPMiddleware("mx1->")) router.USE(http.MethodGet, "/x", mockFastHTTPMiddleware("mx2->")) router.USE(http.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy1->")) @@ -440,7 +440,7 @@ func TestFastHTTPTreeOrphanMiddlewareOrder(t *testing.T) { router.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "m1->m2->mx1->mx2->mparam1->mparam2->mxy1->mxy2->mxy3->mxy4->handler" { - t.Errorf("Use middleware error %s", string(ctx.Response.Body())) + t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) } } @@ -462,7 +462,7 @@ func TestFastHTTPNodeApplyMiddlewareStatic(t *testing.T) { router.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "m1x" { - t.Errorf("Use middleware error %s", string(ctx.Response.Body())) + t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) } } @@ -485,7 +485,7 @@ func TestFastHTTPNodeApplyMiddlewareInvalidNodeReference(t *testing.T) { router.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "y" { - t.Errorf("Use middleware error %s", string(ctx.Response.Body())) + t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) } } @@ -631,6 +631,6 @@ func TestFastHTTPMountSubRouter(t *testing.T) { mainRouter.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "[rg1][rg2][r1][r2][sg1][sg2][s1][s2][s]" { - t.Errorf("Router mount sub router middleware error: %s", string(ctx.Response.Body())) + t.Errorf("Router mount sub router globalMiddleware error: %s", string(ctx.Response.Body())) } } diff --git a/mocks_test.go b/mocks_test.go index 0a120ef..b7c27b3 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -42,7 +42,9 @@ func (mfs *mockFileSystem) Open(_ string) (http.File, error) { func mockMiddleware(body string) MiddlewareFunc { fn := func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(body)) + if _, err := w.Write([]byte(body)); err != nil { + panic(err) + } h.ServeHTTP(w, r) }) } @@ -65,7 +67,9 @@ func mockServeHTTP(h http.Handler, method, path string) error { func mockFastHTTPMiddleware(body string) FastHTTPMiddlewareFunc { fn := func(h fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { - fmt.Fprintf(ctx, body) + if _, err := fmt.Fprintf(ctx, body); err != nil { + panic(err) + } h(ctx) } diff --git a/mux/benchmark_test.go b/mux/benchmark_test.go index c1b142a..fa32acb 100644 --- a/mux/benchmark_test.go +++ b/mux/benchmark_test.go @@ -37,14 +37,14 @@ func BenchmarkMux(b *testing.B) { b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - n, _, _, _ := root.Tree().Match("pl/blog/comments/123/new") + route, _, _ := root.Tree().MatchRoute("pl/blog/comments/123/new") - if n == nil { - b.Fatalf("%v", n) + if route == nil { + b.Fatalf("%v", route) } - if n.Name() != commentNew.Name() { - b.Fatalf("%s != %s", n.Name(), commentNew.Name()) + if route != commentNew.Route() { + b.Fatalf("%s != %s (%s)", route, commentNew.Route(), commentNew.Name()) } } }) diff --git a/mux/node.go b/mux/node.go index aa7c89a..19e3cf3 100644 --- a/mux/node.go +++ b/mux/node.go @@ -37,37 +37,48 @@ func NewNode(pathPart string, maxParamsSize uint8) Node { return node } -// Node represents mux Node -// Can match path and provide routes -type Node interface { - // Match matches given path to Node within Node and its Tree - Match(path string) (Node, middleware.Middleware, context.Params, string) +type RouteAware interface { + // MatchRoute matches given path to Route within Node and its Tree + MatchRoute(path string) (Route, context.Params, string) - // Name provides Node name - Name() string - // Tree provides next level Node Tree - Tree() Tree // Route provides Node's Route if assigned Route() Route - // Middleware provides Node's middleware - Middleware() middleware.Middleware + // WithRoute assigns Route to given Node + WithRoute(r Route) // Name provides maximum number of parameters Route can have for given Node MaxParamsSize() uint8 + // SkipSubPath sets skipSubPath node property to true + // will skip children match search and return current node directly + // this value is used when matching subrouter + SkipSubPath() +} - // WithRoute assigns Route to given Node - WithRoute(r Route) - // WithChildren sets Node's Tree - WithChildren(t Tree) +type MiddlewareAware interface { + // MatchMiddleware collects middleware from all nodes within tree matching given path + // middleware is merged in order nodes where created, collecting from top to bottom + MatchMiddleware(path string) middleware.Middleware + + // Middleware provides Node's middleware + Middleware() middleware.Middleware // AppendMiddleware appends middleware to Node AppendMiddleware(m middleware.Middleware) // PrependMiddleware prepends middleware to Node PrependMiddleware(m middleware.Middleware) +} - // SkipSubPath sets skipSubPath node property to true - // will skip children match search and return current node directly - // this value is used when matching subrouter - SkipSubPath() +// Node represents mux Node +// Can match path and provide routes +type Node interface { + RouteAware + MiddlewareAware + + // Name provides Node name + Name() string + // Tree provides next level Node Tree + Tree() Tree + // WithChildren sets Node's Tree + WithChildren(t Tree) } type staticNode struct { @@ -81,25 +92,43 @@ type staticNode struct { skipSubPath bool } -func (n *staticNode) Match(path string) (Node, middleware.Middleware, context.Params, string) { +func (n *staticNode) MatchRoute(path string) (Route, context.Params, string) { nameLength := len(n.name) pathLength := len(path) if pathLength >= nameLength && n.name == path[:nameLength] { if nameLength+1 >= pathLength { - return n, n.middleware, make(context.Params, n.maxParamsSize), "" + return n.route, make(context.Params, n.maxParamsSize), "" } if n.skipSubPath { - return n, n.middleware, make(context.Params, n.maxParamsSize), path[nameLength+1:] + return n.route, make(context.Params, n.maxParamsSize), path[nameLength+1:] } - node, treeMiddleware, params, p := n.children.Match(path[nameLength+1:]) // +1 because we wan to skip slash as well + return n.children.MatchRoute(path[nameLength+1:]) // +1 because we wan to skip slash as well + } + + return nil, nil, "" +} + +func (n *staticNode) MatchMiddleware(path string) middleware.Middleware { + nameLength := len(n.name) + pathLength := len(path) + + if pathLength >= nameLength && n.name == path[:nameLength] { + if nameLength+1 >= pathLength { + return n.middleware + } - return node, n.middleware.Merge(treeMiddleware), params, p + if treeMiddleware := n.children.MatchMiddleware(path[nameLength+1:]); treeMiddleware != nil { // +1 because we wan to skip slash as well + + return n.middleware.Merge(treeMiddleware) + } + + return n.middleware } - return nil, nil, nil, "" + return nil } func (n *staticNode) Name() string { @@ -150,31 +179,37 @@ type wildcardNode struct { *staticNode } -func (n *wildcardNode) Match(path string) (Node, middleware.Middleware, context.Params, string) { +func (n *wildcardNode) MatchRoute(path string) (Route, context.Params, string) { pathPart, subPath := pathutils.GetPart(path) maxParamsSize := n.MaxParamsSize() - var node Node - var treeMiddleware middleware.Middleware + var route Route var params context.Params if subPath == "" || n.staticNode.skipSubPath { - node = n - treeMiddleware = n.Middleware() + route = n.route params = make(context.Params, maxParamsSize) } else { - node, treeMiddleware, params, subPath = n.children.Match(subPath) - - if node == nil { - return nil, nil, nil, "" + route, params, subPath = n.children.MatchRoute(subPath) + if route == nil { + return nil, nil, "" } - - treeMiddleware = n.middleware.Merge(treeMiddleware) } params.Set(maxParamsSize-1, n.name, pathPart) - return node, treeMiddleware, params, subPath + return route, params, subPath +} + +func (n *wildcardNode) MatchMiddleware(path string) middleware.Middleware { + _, subPath := pathutils.GetPart(path) + + if treeMiddleware := n.children.MatchMiddleware(subPath); treeMiddleware != nil { + + return n.middleware.Merge(treeMiddleware) + } + + return n.middleware } func withRegexp(parent *staticNode, regexp *regexp.Regexp) *regexpNode { @@ -190,35 +225,44 @@ type regexpNode struct { regexp *regexp.Regexp } -func (n *regexpNode) Match(path string) (Node, middleware.Middleware, context.Params, string) { +func (n *regexpNode) MatchRoute(path string) (Route, context.Params, string) { pathPart, subPath := pathutils.GetPart(path) if !n.regexp.MatchString(pathPart) { - return nil, nil, nil, "" + return nil, nil, "" } maxParamsSize := n.MaxParamsSize() - var node Node - var treeMiddleware middleware.Middleware + var route Route var params context.Params if subPath == "" || n.staticNode.skipSubPath { - node = n - treeMiddleware = n.Middleware() + route = n.route params = make(context.Params, maxParamsSize) } else { - node, treeMiddleware, params, subPath = n.children.Match(subPath) - - if node == nil { - return nil, nil, nil, "" + route, params, subPath = n.children.MatchRoute(subPath) + if route == nil { + return nil, nil, "" } - - treeMiddleware = n.middleware.Merge(treeMiddleware) } params.Set(maxParamsSize-1, n.name, pathPart) - return node, treeMiddleware, params, subPath + return route, params, subPath +} + +func (n *regexpNode) MatchMiddleware(path string) middleware.Middleware { + pathPart, subPath := pathutils.GetPart(path) + if !n.regexp.MatchString(pathPart) { + return nil + } + + if treeMiddleware := n.children.MatchMiddleware(subPath); treeMiddleware != nil { + + return n.middleware.Merge(treeMiddleware) + } + + return n.middleware } func withSubrouter(parent Node) *subrouterNode { diff --git a/mux/tree.go b/mux/tree.go index 83bc8c6..9273052 100644 --- a/mux/tree.go +++ b/mux/tree.go @@ -19,14 +19,6 @@ func NewTree() Tree { // Tree slice of mux Nodes type Tree []Node -// Match represents path match data struct -type match struct { - node Node - middleware middleware.Middleware - params context.Params - subPath string -} - // PrettyPrint prints the tree text representation to console func (t Tree) PrettyPrint() string { buff := &bytes.Buffer{} @@ -77,40 +69,28 @@ func (t Tree) Compile() Tree { return t } -// Match path to first Node -func (t Tree) Match(path string) (Node, middleware.Middleware, context.Params, string) { - var orphanMatches []match - +// MatchRoute path to first Node +func (t Tree) MatchRoute(path string) (Route, context.Params, string) { for _, child := range t { - if node, m, params, subPath := child.Match(path); node != nil { - if node.Route() != nil { - if len(orphanMatches) > 0 { - for i := 0; i < len(orphanMatches); i++ { - m = m.Merge(orphanMatches[i].node.Middleware()) - } - } - - return node, m, params, subPath - } - - orphanMatch := match{ - node: node, - middleware: m, - params: params, - subPath: subPath, - } - orphanMatches = append(orphanMatches, orphanMatch) + if route, params, subPath := child.MatchRoute(path); route != nil { + return route, params, subPath } } - // no route found, return first orphan match - if len(orphanMatches) > 0 { - firstOrphanMatch := orphanMatches[0] + return nil, nil, "" +} + +// MatchMiddleware collects middleware from all nodes that match path +func (t Tree) MatchMiddleware(path string) middleware.Middleware { + var treeMiddleware = make(middleware.Middleware, 0) - return firstOrphanMatch.node, firstOrphanMatch.middleware, firstOrphanMatch.params, firstOrphanMatch.subPath + for _, child := range t { + if m := child.MatchMiddleware(path); m != nil { + treeMiddleware = treeMiddleware.Merge(m) + } } - return nil, nil, nil, "" + return treeMiddleware } // Find finds Node inside a tree by name @@ -157,7 +137,7 @@ func (t Tree) WithRoute(path string, route Route, maxParamsSize uint8) Tree { // WithMiddleware returns new Tree with Middleware appended to given Node // Middleware is appended to Node under the give path, if Node does not exist it will panic -func (t Tree) WithMiddleware(path string, m middleware.Middleware, maxParamsSize uint8) Tree { +func (t Tree) WithMiddleware(path string, m middleware.Middleware) Tree { path = pathutils.TrimSlash(path) if path == "" { return t @@ -169,14 +149,14 @@ func (t Tree) WithMiddleware(path string, m middleware.Middleware, maxParamsSize newTree := t if node == nil { - node = NewNode(parts[0], maxParamsSize) - newTree = t.withNode(node).sort() + node = NewNode(parts[0], 0) + newTree = t.withNode(node) } if len(parts) == 1 { node.AppendMiddleware(m) } else { - node.WithChildren(node.Tree().WithMiddleware(strings.Join(parts[1:], "/"), m, node.MaxParamsSize())) + node.WithChildren(node.Tree().WithMiddleware(strings.Join(parts[1:], "/"), m)) } return newTree diff --git a/mux/tree_test.go b/mux/tree_test.go index 6b2294a..4becac3 100644 --- a/mux/tree_test.go +++ b/mux/tree_test.go @@ -37,13 +37,13 @@ func TestTreeMatch(t *testing.T) { root.WithChildren(root.Tree().Compile()) - n, _, _, _ := root.Tree().Match("pl/blog/comments/123/new") + route, _, _ := root.Tree().MatchRoute("pl/blog/comments/123/new") - if n == nil { - t.Fatalf("%v", n) + if route == nil { + t.Fatalf("%v", route) } - if n.Name() != commentNew.Name() { - t.Fatalf("%s != %s", n.Name(), commentNew.Name()) + if route != commentNew.Route() { + t.Fatalf("%s != %s (%s)", route, commentNew.Route(), commentNew.Name()) } } diff --git a/nethttp.go b/nethttp.go index cc7ee20..d855836 100644 --- a/nethttp.go +++ b/nethttp.go @@ -13,17 +13,19 @@ import ( // New creates new net/http Router instance, returns pointer func New(fs ...MiddlewareFunc) Router { return &router{ - routes: mux.NewTree(), - middleware: transformMiddlewareFunc(fs...), + routes: mux.NewTree(), + middleware: mux.NewTree(), + globalMiddleware: transformMiddlewareFunc(fs...), } } type router struct { - routes mux.Tree - middleware middleware.Middleware - fileServer http.Handler - notFound http.Handler - notAllowed http.Handler + routes mux.Tree // mux.RouteAware tree + middleware mux.Tree // mux.MiddlewareAware tree + globalMiddleware middleware.Middleware + fileServer http.Handler + notFound http.Handler + notAllowed http.Handler } func (r *router) PrettyPrint() string { @@ -69,7 +71,7 @@ func (r *router) TRACE(p string, f http.Handler) { func (r *router) USE(method, path string, fs ...MiddlewareFunc) { m := transformMiddlewareFunc(fs...) - r.routes = r.routes.WithMiddleware(method+path, m, 0) + r.middleware = r.middleware.WithMiddleware(method+path, m) } func (r *router) Handle(method, path string, h http.Handler) { @@ -125,11 +127,13 @@ func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { path := pathutils.TrimSlash(req.URL.Path) if root := r.routes.Find(req.Method); root != nil { - if node, treeMiddleware, params, subPath := root.Tree().Match(path); node != nil && node.Route() != nil { - route := node.Route() - handler := route.Handler() - allMiddleware := r.middleware.Merge(root.Middleware().Merge(treeMiddleware)) - computedHandler := allMiddleware.Compose(handler) + if route, params, subPath := root.Tree().MatchRoute(req.Method + path); route != nil { + allMiddleware := r.globalMiddleware + if treeMiddleware := r.middleware.MatchMiddleware(req.Method + path); treeMiddleware != nil { + allMiddleware = allMiddleware.Merge(treeMiddleware) + } + + computedHandler := allMiddleware.Compose(route.Handler()) h := computedHandler.(http.Handler) @@ -146,7 +150,12 @@ func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { } if req.URL.Path == "/" && root.Route() != nil { - root.Route().Handler().(http.Handler).ServeHTTP(w, req) + rootMiddleware := r.globalMiddleware + if root.Middleware() != nil { + rootMiddleware = rootMiddleware.Merge(root.Middleware()) + } + rootHandler := rootMiddleware.Compose(root.Route().Handler()) + rootHandler.(http.Handler).ServeHTTP(w, req) return } } diff --git a/nethttp_test.go b/nethttp_test.go index 8366c22..7a570b7 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -363,7 +363,7 @@ func TestNilMiddleware(t *testing.T) { router.ServeHTTP(w, req) if w.Body.String() != "test" { - t.Error("Nil middleware works") + t.Error("Nil globalMiddleware works") } } @@ -429,7 +429,7 @@ func TestNodeApplyMiddleware(t *testing.T) { router.ServeHTTP(w, req) if w.Body.String() != "m1y" { - t.Errorf("Use middleware error %s", w.Body.String()) + t.Errorf("Use globalMiddleware error %s", w.Body.String()) } w = httptest.NewRecorder() @@ -441,7 +441,7 @@ func TestNodeApplyMiddleware(t *testing.T) { router.ServeHTTP(w, req) if w.Body.String() != "m1m2x" { - t.Errorf("Use middleware error %s", w.Body.String()) + t.Errorf("Use globalMiddleware error %s", w.Body.String()) } } @@ -456,10 +456,10 @@ func TestTreeOrphanMiddlewareOrder(t *testing.T) { } })) - // Method global middleware + // Method global globalMiddleware router.USE(http.MethodGet, "/", mockMiddleware("m1->")) router.USE(http.MethodGet, "/", mockMiddleware("m2->")) - // Path middleware + // Path globalMiddleware router.USE(http.MethodGet, "/x", mockMiddleware("mx1->")) router.USE(http.MethodGet, "/x", mockMiddleware("mx2->")) router.USE(http.MethodGet, "/x/y", mockMiddleware("mxy1->")) @@ -478,7 +478,7 @@ func TestTreeOrphanMiddlewareOrder(t *testing.T) { router.ServeHTTP(w, req) if w.Body.String() != "m1->m2->mx1->mx2->mparam1->mparam2->mxy1->mxy2->mxy3->mxy4->handler" { - t.Errorf("Use middleware error %s", w.Body.String()) + t.Errorf("Use globalMiddleware error %s", w.Body.String()) } } @@ -504,7 +504,7 @@ func TestNodeApplyMiddlewareStatic(t *testing.T) { router.ServeHTTP(w, req) if w.Body.String() != "m1x" { - t.Errorf("Use middleware error %s", w.Body.String()) + t.Errorf("Use globalMiddleware error %s", w.Body.String()) } } @@ -535,7 +535,7 @@ func TestNodeApplyMiddlewareInvalidNodeReference(t *testing.T) { router.ServeHTTP(w, req) if w.Body.String() != "y" { - t.Errorf("Use middleware error %s", w.Body.String()) + t.Errorf("Use globalMiddleware error %s", w.Body.String()) } } @@ -705,6 +705,6 @@ func TestMountSubRouter(t *testing.T) { mainRouter.ServeHTTP(w, req) if w.Body.String() != "[rg1][rg2][r1][r2][sg1][sg2][s1][s2][s]" { - t.Errorf("Router mount sub router middleware error: %s", w.Body.String()) + t.Errorf("Router mount sub router globalMiddleware error: %s", w.Body.String()) } } diff --git a/route_test.go b/route_test.go index c188bb3..47b2b7a 100644 --- a/route_test.go +++ b/route_test.go @@ -11,13 +11,17 @@ import ( func TestRouter(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.Write([]byte("4")) + if _, err := w.Write([]byte("4")); err != nil { + t.Fatal(err) + } }) buildMiddlewareFunc := func(body string) middleware.MiddlewareFunc { fn := func(h interface{}) interface{} { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(body)) + if _, err := w.Write([]byte(body)); err != nil { + t.Fatal(err) + } h.(http.Handler).ServeHTTP(w, r) }) } diff --git a/router.go b/router.go index 788c4c4..fa4766e 100644 --- a/router.go +++ b/router.go @@ -6,10 +6,10 @@ import ( "github.com/valyala/fasthttp" ) -// MiddlewareFunc is a http middleware function type +// MiddlewareFunc is a http globalMiddleware function type type MiddlewareFunc func(http.Handler) http.Handler -// FastHTTPMiddlewareFunc is a fasthttp middleware function type +// FastHTTPMiddlewareFunc is a fasthttp globalMiddleware function type type FastHTTPMiddlewareFunc func(fasthttp.RequestHandler) fasthttp.RequestHandler // Router is a micro framework, HTTP request router, multiplexer, mux @@ -53,7 +53,7 @@ type Router interface { // under TRACE method and given patter TRACE(pattern string, handler http.Handler) - // USE adds middleware functions ([]MiddlewareFunc) + // USE adds globalMiddleware functions ([]MiddlewareFunc) // to whole router branch under given method and patter USE(method, pattern string, fs ...MiddlewareFunc) @@ -125,7 +125,7 @@ type FastHTTPRouter interface { // under TRACE method and given patter TRACE(pattern string, handler fasthttp.RequestHandler) - // USE adds middleware functions ([]MiddlewareFunc) + // USE adds globalMiddleware functions ([]MiddlewareFunc) // to whole router branch under given method and patter USE(method, pattern string, fs ...FastHTTPMiddlewareFunc) diff --git a/tree.go b/tree.go index eb46996..3ce6b19 100644 --- a/tree.go +++ b/tree.go @@ -4,23 +4,8 @@ import ( "net/http" "github.com/vardius/gorouter/v4/mux" - pathutils "github.com/vardius/gorouter/v4/path" ) -func findNode(n mux.Node, parts []string) mux.Node { - if len(parts) == 0 { - return n - } - - name, _ := pathutils.GetNameFromPart(parts[0]) - - if node := n.Tree().Find(name); node != nil { - return findNode(node, parts[1:]) - } - - return n -} - func allowed(t mux.Tree, method, path string) (allow string) { if path == "*" { // routes tree roots should be http method nodes only @@ -41,7 +26,7 @@ func allowed(t mux.Tree, method, path string) (allow string) { continue } - if n, _, _, _ := root.Tree().Match(path); n != nil && n.Route() != nil { + if route, _, _ := root.Tree().MatchRoute(path); route != nil { if len(allow) == 0 { allow = root.Name() } else { From ee91bf830742f7f9c8009f1c3aabac700e170bc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Mon, 27 Jan 2020 08:36:50 +1100 Subject: [PATCH 21/41] Add comments to interfaces --- mux/node.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mux/node.go b/mux/node.go index 19e3cf3..6313364 100644 --- a/mux/node.go +++ b/mux/node.go @@ -37,6 +37,7 @@ func NewNode(pathPart string, maxParamsSize uint8) Node { return node } +// RouteAware represents route aware Node type RouteAware interface { // MatchRoute matches given path to Route within Node and its Tree MatchRoute(path string) (Route, context.Params, string) @@ -54,6 +55,7 @@ type RouteAware interface { SkipSubPath() } +// MiddlewareAware represents middleware aware node type MiddlewareAware interface { // MatchMiddleware collects middleware from all nodes within tree matching given path // middleware is merged in order nodes where created, collecting from top to bottom From 7b6b026b04cc7a72b9a6bc009a3b1ed8430fd51c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Mon, 27 Jan 2020 08:45:23 +1100 Subject: [PATCH 22/41] Use print-style function --- mocks_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mocks_test.go b/mocks_test.go index b7c27b3..99feba4 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -67,7 +67,7 @@ func mockServeHTTP(h http.Handler, method, path string) error { func mockFastHTTPMiddleware(body string) FastHTTPMiddlewareFunc { fn := func(h fasthttp.RequestHandler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { - if _, err := fmt.Fprintf(ctx, body); err != nil { + if _, err := fmt.Fprint(ctx, body); err != nil { panic(err) } From 6d709b563ee348d4a2d32830b2b7dd1a1c1b11dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Mon, 27 Jan 2020 19:57:05 +1100 Subject: [PATCH 23/41] Fix router serve method, simplify logic --- fasthttp.go | 44 +++++++++++++++----------------------------- fasthttp_test.go | 2 +- nethttp.go | 47 ++++++++++++++++------------------------------- nethttp_test.go | 2 +- tree.go | 3 +++ 5 files changed, 36 insertions(+), 62 deletions(-) diff --git a/fasthttp.go b/fasthttp.go index 6c442c9..ff68e19 100644 --- a/fasthttp.go +++ b/fasthttp.go @@ -6,7 +6,6 @@ import ( "github.com/valyala/fasthttp" "github.com/vardius/gorouter/v4/middleware" "github.com/vardius/gorouter/v4/mux" - pathutils "github.com/vardius/gorouter/v4/path" ) // NewFastHTTPRouter creates new Router instance, returns pointer @@ -121,41 +120,28 @@ func (r *fastHTTPRouter) ServeFiles(root string, stripSlashes int) { func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) { method := string(ctx.Method()) - pathAsString := string(ctx.Path()) - path := pathutils.TrimSlash(pathAsString) - - if root := r.routes.Find(method); root != nil { - if route, params, subPath := root.Tree().MatchRoute(path); route != nil { - allMiddleware := r.globalMiddleware - if treeMiddleware := r.middleware.MatchMiddleware(method + path); treeMiddleware != nil { - allMiddleware = allMiddleware.Merge(treeMiddleware) - } + path := string(ctx.Path()) - computedHandler := allMiddleware.Compose(route.Handler()) + if route, params, subPath := r.routes.MatchRoute(method + path); route != nil { + allMiddleware := r.globalMiddleware + if treeMiddleware := r.middleware.MatchMiddleware(method + path); treeMiddleware != nil { + allMiddleware = allMiddleware.Merge(treeMiddleware) + } - h := computedHandler.(fasthttp.RequestHandler) + computedHandler := allMiddleware.Compose(route.Handler()) - if len(params) > 0 { - ctx.SetUserValue("params", params) - } - - if subPath != "" { - ctx.URI().SetPathBytes(fasthttp.NewPathPrefixStripper(len("/" + subPath))(ctx)) - } + h := computedHandler.(fasthttp.RequestHandler) - h(ctx) - return + if len(params) > 0 { + ctx.SetUserValue("params", params) } - if pathAsString == "/" && root.Route() != nil { - rootMiddleware := r.globalMiddleware - if root.Middleware() != nil { - rootMiddleware = rootMiddleware.Merge(root.Middleware()) - } - rootHandler := rootMiddleware.Compose(root.Route().Handler()) - rootHandler.(fasthttp.RequestHandler)(ctx) - return + if subPath != "" { + ctx.URI().SetPathBytes(fasthttp.NewPathPrefixStripper(len("/" + subPath))(ctx)) } + + h(ctx) + return } // Handle OPTIONS diff --git a/fasthttp_test.go b/fasthttp_test.go index a1bd0fa..0bf2ecc 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -439,7 +439,7 @@ func TestFastHTTPTreeOrphanMiddlewareOrder(t *testing.T) { router.HandleFastHTTP(ctx) - if string(ctx.Response.Body()) != "m1->m2->mx1->mx2->mparam1->mparam2->mxy1->mxy2->mxy3->mxy4->handler" { + if string(ctx.Response.Body()) != "m1->m2->mx1->mx2->mxy1->mxy2->mxy3->mxy4->mparam1->mparam2->handler" { t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) } } diff --git a/nethttp.go b/nethttp.go index d855836..e161054 100644 --- a/nethttp.go +++ b/nethttp.go @@ -7,7 +7,6 @@ import ( "github.com/vardius/gorouter/v4/context" "github.com/vardius/gorouter/v4/middleware" "github.com/vardius/gorouter/v4/mux" - pathutils "github.com/vardius/gorouter/v4/path" ) // New creates new net/http Router instance, returns pointer @@ -124,45 +123,31 @@ func (r *router) ServeFiles(fs http.FileSystem, root string, strip bool) { } func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { - path := pathutils.TrimSlash(req.URL.Path) - - if root := r.routes.Find(req.Method); root != nil { - if route, params, subPath := root.Tree().MatchRoute(req.Method + path); route != nil { - allMiddleware := r.globalMiddleware - if treeMiddleware := r.middleware.MatchMiddleware(req.Method + path); treeMiddleware != nil { - allMiddleware = allMiddleware.Merge(treeMiddleware) - } - - computedHandler := allMiddleware.Compose(route.Handler()) - - h := computedHandler.(http.Handler) + if route, params, subPath := r.routes.MatchRoute(req.Method + req.URL.Path); route != nil { + allMiddleware := r.globalMiddleware + if treeMiddleware := r.middleware.MatchMiddleware(req.Method + req.URL.Path); treeMiddleware != nil { + allMiddleware = allMiddleware.Merge(treeMiddleware) + } - if len(params) > 0 { - req = req.WithContext(context.WithParams(req.Context(), params)) - } + computedHandler := allMiddleware.Compose(route.Handler()) - if subPath != "" { - h = http.StripPrefix(strings.TrimSuffix(req.URL.Path, "/"+subPath), h) - } + h := computedHandler.(http.Handler) - h.ServeHTTP(w, req) - return + if len(params) > 0 { + req = req.WithContext(context.WithParams(req.Context(), params)) } - if req.URL.Path == "/" && root.Route() != nil { - rootMiddleware := r.globalMiddleware - if root.Middleware() != nil { - rootMiddleware = rootMiddleware.Merge(root.Middleware()) - } - rootHandler := rootMiddleware.Compose(root.Route().Handler()) - rootHandler.(http.Handler).ServeHTTP(w, req) - return + if subPath != "" { + h = http.StripPrefix(strings.TrimSuffix(req.URL.Path, "/"+subPath), h) } + + h.ServeHTTP(w, req) + return } // Handle OPTIONS if req.Method == http.MethodOptions { - if allow := allowed(r.routes, req.Method, path); len(allow) > 0 { + if allow := allowed(r.routes, req.Method, req.URL.Path); len(allow) > 0 { w.Header().Set("Allow", allow) return } @@ -172,7 +157,7 @@ func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } else { // Handle 405 - if allow := allowed(r.routes, req.Method, path); len(allow) > 0 { + if allow := allowed(r.routes, req.Method, req.URL.Path); len(allow) > 0 { w.Header().Set("Allow", allow) r.serveNotAllowed(w, req) return diff --git a/nethttp_test.go b/nethttp_test.go index 7a570b7..5551857 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -477,7 +477,7 @@ func TestTreeOrphanMiddlewareOrder(t *testing.T) { router.ServeHTTP(w, req) - if w.Body.String() != "m1->m2->mx1->mx2->mparam1->mparam2->mxy1->mxy2->mxy3->mxy4->handler" { + if w.Body.String() != "m1->m2->mx1->mx2->mxy1->mxy2->mxy3->mxy4->mparam1->mparam2->handler" { t.Errorf("Use globalMiddleware error %s", w.Body.String()) } } diff --git a/tree.go b/tree.go index 3ce6b19..34488c7 100644 --- a/tree.go +++ b/tree.go @@ -1,12 +1,15 @@ package gorouter import ( + pathutils "github.com/vardius/gorouter/v4/path" "net/http" "github.com/vardius/gorouter/v4/mux" ) func allowed(t mux.Tree, method, path string) (allow string) { + path = pathutils.TrimSlash(path) + if path == "*" { // routes tree roots should be http method nodes only for _, root := range t { From 1cc5ab9330c32eec5ee3106469597df190242f43 Mon Sep 17 00:00:00 2001 From: mar1n3r0 Date: Mon, 27 Jan 2020 12:13:23 +0200 Subject: [PATCH 24/41] demonstrate static priority over wildcard --- fasthttp_test.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/fasthttp_test.go b/fasthttp_test.go index 0bf2ecc..99ea8b5 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -391,14 +391,16 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { } }) - router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("m1")) - router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m2")) + router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m1")) + router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("m2")) + router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m3")) + router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("m4")) ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/y") router.HandleFastHTTP(ctx) - if string(ctx.Response.Body()) != "m1y" { + if string(ctx.Response.Body()) != "m2m4y" { t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) } @@ -406,7 +408,7 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { router.HandleFastHTTP(ctx) - if string(ctx.Response.Body()) != "m1m2x" { + if string(ctx.Response.Body()) != "m1m2m3m4x" { t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) } } @@ -439,7 +441,7 @@ func TestFastHTTPTreeOrphanMiddlewareOrder(t *testing.T) { router.HandleFastHTTP(ctx) - if string(ctx.Response.Body()) != "m1->m2->mx1->mx2->mxy1->mxy2->mxy3->mxy4->mparam1->mparam2->handler" { + if string(ctx.Response.Body()) != "m1->m2->mx1->mx2->mxy1->mxy2->mparam1->mparam2->mxy3->mxy4->handler" { t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) } } From ff933670ac51bd7e61592565e321483550585dad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Tue, 28 Jan 2020 22:41:48 +1100 Subject: [PATCH 25/41] Add middleware priority, sort middleware before composing with handler --- doc.go | 6 +- example_test.go | 14 ++-- fasthttp.go | 50 +++++++------ fasthttp_test.go | 2 +- middleware/collection.go | 55 ++++++++++++++ middleware/collection_test.go | 126 +++++++++++++++++++++++++++++++ middleware/middleware.go | 57 ++++++++------ middleware/middleware_test.go | 137 ++++++++++++---------------------- mocks_test.go | 2 +- mux/node.go | 32 ++++---- mux/route.go | 2 +- mux/tree.go | 18 ++--- mux/tree_test.go | 10 +-- nethttp.go | 52 ++++++------- nethttp_test.go | 4 +- route_test.go | 8 +- tree.go | 4 +- 17 files changed, 364 insertions(+), 215 deletions(-) create mode 100644 middleware/collection.go create mode 100644 middleware/collection_test.go diff --git a/doc.go b/doc.go index 95a2154..c6e6dbc 100644 --- a/doc.go +++ b/doc.go @@ -4,8 +4,8 @@ Package gorouter provide request router with globalMiddleware Router The router determines how to handle that request. -GoRouter uses a routing tree. Once one branch of the tree matches, only routes inside that branch are considered, -not any routes after that branch. When instantiating router, the root node of tree is created. +GoRouter uses a routing tree. Once one branch of the tree matches, only tree inside that branch are considered, +not any tree after that branch. When instantiating router, the root node of tree is created. Route types @@ -31,7 +31,7 @@ A full route definition contain up to three parts: 2. The URL path route. This is matched against the URL passed to the router, and can contain named wildcard placeholders *(e.g. {placeholder})* to match dynamic parts in the URL. -3. `http.HandleFunc`, which tells the router to handle matched requests to the router with handler. +3. `http.HandlerFunc`, which tells the router to handle matched requests to the router with handler. Take the following example: diff --git a/example_test.go b/example_test.go index 2e5d731..66ebf33 100644 --- a/example_test.go +++ b/example_test.go @@ -66,7 +66,7 @@ func Example_second() { func ExampleMiddlewareFunc() { // Global globalMiddleware example - // applies to all routes + // applies to all tree hello := func(w http.ResponseWriter, r *http.Request) { params, _ := context.Parameters(r.Context()) fmt.Printf("Hello, %s!\n", params.Value("name")) @@ -81,7 +81,7 @@ func ExampleMiddlewareFunc() { return http.HandlerFunc(fn) } - // apply globalMiddleware to all routes + // apply globalMiddleware to all tree // can pass as many as you want router := gorouter.New(logger) router.GET("/hello/{name}", http.HandlerFunc(hello)) @@ -128,7 +128,7 @@ func ExampleMiddlewareFunc_second() { func ExampleMiddlewareFunc_third() { // Http method globalMiddleware example - // applies to all routes under this method + // applies to all tree under this method hello := func(w http.ResponseWriter, r *http.Request) { params, _ := context.Parameters(r.Context()) fmt.Printf("Hello, %s!\n", params.Value("name")) @@ -146,7 +146,7 @@ func ExampleMiddlewareFunc_third() { router := gorouter.New() router.GET("/hello/{name}", http.HandlerFunc(hello)) - // apply globalMiddleware to all routes with GET method + // apply globalMiddleware to all tree with GET method // can pass as many as you want router.USE("GET", "", logger) @@ -160,7 +160,7 @@ func ExampleMiddlewareFunc_third() { func ExampleFastHTTPMiddlewareFunc() { // Global globalMiddleware example - // applies to all routes + // applies to all tree hello := func(ctx *fasthttp.RequestCtx) { params := ctx.UserValue("params").(context.Params) fmt.Printf("Hello, %s!\n", params.Value("name")) @@ -220,7 +220,7 @@ func ExampleFastHTTPMiddlewareFunc_second() { func ExampleFastHTTPMiddlewareFunc_third() { // Http method globalMiddleware example - // applies to all routes under this method + // applies to all tree under this method hello := func(ctx *fasthttp.RequestCtx) { params := ctx.UserValue("params").(context.Params) fmt.Printf("Hello, %s!\n", params.Value("name")) @@ -238,7 +238,7 @@ func ExampleFastHTTPMiddlewareFunc_third() { router := gorouter.NewFastHTTPRouter() router.GET("/hello/{name}", hello) - // apply globalMiddleware to all routes with GET method + // apply globalMiddleware to all tree with GET method // can pass as many as you want router.USE("GET", "", logger) diff --git a/fasthttp.go b/fasthttp.go index ff68e19..72e6e3f 100644 --- a/fasthttp.go +++ b/fasthttp.go @@ -10,24 +10,25 @@ import ( // NewFastHTTPRouter creates new Router instance, returns pointer func NewFastHTTPRouter(fs ...FastHTTPMiddlewareFunc) FastHTTPRouter { + globalMiddleware := middleware.NewCollectionFromWrappers(0, transformFastHTTPMiddlewareFunc(fs...)...) return &fastHTTPRouter{ - routes: mux.NewTree(), - middleware: mux.NewTree(), - globalMiddleware: transformFastHTTPMiddlewareFunc(fs...), + tree: mux.NewTree(), + globalMiddleware: globalMiddleware, + middlewareCounter: uint(len(globalMiddleware)), } } type fastHTTPRouter struct { - routes mux.Tree // mux.RouteAware tree - middleware mux.Tree // mux.MiddlewareAware tree - globalMiddleware middleware.Middleware - fileServer fasthttp.RequestHandler - notFound fasthttp.RequestHandler - notAllowed fasthttp.RequestHandler + tree mux.Tree + globalMiddleware middleware.Collection + fileServer fasthttp.RequestHandler + notFound fasthttp.RequestHandler + notAllowed fasthttp.RequestHandler + middlewareCounter uint } func (r *fastHTTPRouter) PrettyPrint() string { - return r.routes.PrettyPrint() + return r.tree.PrettyPrint() } func (r *fastHTTPRouter) POST(p string, f fasthttp.RequestHandler) { @@ -69,13 +70,14 @@ func (r *fastHTTPRouter) TRACE(p string, f fasthttp.RequestHandler) { func (r *fastHTTPRouter) USE(method, path string, fs ...FastHTTPMiddlewareFunc) { m := transformFastHTTPMiddlewareFunc(fs...) - r.middleware = r.middleware.WithMiddleware(method+path, m) + r.tree = r.tree.WithMiddleware(method+path, m, r.middlewareCounter, 0) + r.middlewareCounter += uint(len(m)) } func (r *fastHTTPRouter) Handle(method, path string, h fasthttp.RequestHandler) { route := newRoute(h) - r.routes = r.routes.WithRoute(method+path, route, 0) + r.tree = r.tree.WithRoute(method+path, route, 0) } func (r *fastHTTPRouter) Mount(path string, h fasthttp.RequestHandler) { @@ -92,13 +94,13 @@ func (r *fastHTTPRouter) Mount(path string, h fasthttp.RequestHandler) { } { route := newRoute(h) - r.routes = r.routes.WithSubrouter(method+path, route, 0) + r.tree = r.tree.WithSubrouter(method+path, route, 0) } } func (r *fastHTTPRouter) Compile() { - for i, methodNode := range r.routes { - r.routes[i].WithChildren(methodNode.Tree().Compile()) + for i, methodNode := range r.tree { + r.tree[i].WithChildren(methodNode.Tree().Compile()) } } @@ -122,13 +124,13 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) { method := string(ctx.Method()) path := string(ctx.Path()) - if route, params, subPath := r.routes.MatchRoute(method + path); route != nil { + if route, params, subPath := r.tree.MatchRoute(method + path); route != nil { allMiddleware := r.globalMiddleware - if treeMiddleware := r.middleware.MatchMiddleware(method + path); treeMiddleware != nil { + if treeMiddleware := r.tree.MatchMiddleware(method + path); treeMiddleware != nil { allMiddleware = allMiddleware.Merge(treeMiddleware) } - computedHandler := allMiddleware.Compose(route.Handler()) + computedHandler := allMiddleware.Sort().Compose(route.Handler()) h := computedHandler.(fasthttp.RequestHandler) @@ -146,7 +148,7 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) { // Handle OPTIONS if method == http.MethodOptions { - if allow := allowed(r.routes, method, path); len(allow) > 0 { + if allow := allowed(r.tree, method, path); len(allow) > 0 { ctx.Response.Header.Set("Allow", allow) return } @@ -156,7 +158,7 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) { return } else { // Handle 405 - if allow := allowed(r.routes, method, path); len(allow) > 0 { + if allow := allowed(r.tree, method, path); len(allow) > 0 { ctx.Response.Header.Set("Allow", allow) r.serveNotAllowed(ctx) return @@ -183,12 +185,12 @@ func (r *fastHTTPRouter) serveNotAllowed(ctx *fasthttp.RequestCtx) { } } -func transformFastHTTPMiddlewareFunc(fs ...FastHTTPMiddlewareFunc) middleware.Middleware { - m := make(middleware.Middleware, len(fs)) +func transformFastHTTPMiddlewareFunc(fs ...FastHTTPMiddlewareFunc) []middleware.Wrapper { + m := make([]middleware.Wrapper, len(fs)) for i, f := range fs { - m[i] = func(mf FastHTTPMiddlewareFunc) middleware.MiddlewareFunc { - return func(h interface{}) interface{} { + m[i] = func(mf FastHTTPMiddlewareFunc) middleware.WrapperFunc { + return func(h middleware.Handler) middleware.Handler { return mf(h.(fasthttp.RequestHandler)) } }(f) // f is a reference to function so we have to wrap if with that callback diff --git a/fasthttp_test.go b/fasthttp_test.go index 0bf2ecc..c546897 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -439,7 +439,7 @@ func TestFastHTTPTreeOrphanMiddlewareOrder(t *testing.T) { router.HandleFastHTTP(ctx) - if string(ctx.Response.Body()) != "m1->m2->mx1->mx2->mxy1->mxy2->mxy3->mxy4->mparam1->mparam2->handler" { + if string(ctx.Response.Body()) != "m1->m2->mx1->mx2->mxy1->mxy2->mparam1->mparam2->mxy3->mxy4->handler" { t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) } } diff --git a/middleware/collection.go b/middleware/collection.go new file mode 100644 index 0000000..c64eeae --- /dev/null +++ b/middleware/collection.go @@ -0,0 +1,55 @@ +package middleware + +import ( + "sort" +) + +// Collection is a slice of handler wrappers functions +type Collection []Middleware + +// NewCollection provides new middleware +func NewCollection(ms ...Middleware) Collection { + return ms +} + +// NewCollectionFromWrappers provides new middleware +// with order priority preset to provided value +func NewCollectionFromWrappers(priority uint, ws ...Wrapper) Collection { + c := make(Collection, len(ws)) + + for i, w := range ws { + c[i] = Middleware{ + wrapper: w, + priority: priority, + } + } + + return c +} + +// Merge merges another middleware +func (c Collection) Merge(m Collection) Collection { + return append(c, m...) +} + +// Compose returns middleware composed to single WrapperFunc +func (c Collection) Compose(h Handler) Handler { + if h == nil { + return nil + } + + for i := range c { + h = c[len(c)-1-i].Wrap(h) + } + + return h +} + +// Merge merges another middleware +func (c Collection) Sort() Collection { + sort.SliceStable(c, func(i, j int) bool { + return c[i].Priority() < c[j].Priority() + }) + + return c +} diff --git a/middleware/collection_test.go b/middleware/collection_test.go new file mode 100644 index 0000000..6d2193e --- /dev/null +++ b/middleware/collection_test.go @@ -0,0 +1,126 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func mockMiddleware(body string, priority uint) Middleware { + fn := func(h Handler) Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if _, err := w.Write([]byte(body)); err != nil { + panic(err) + } + h.(http.Handler).ServeHTTP(w, r) + }) + } + + return New(WrapperFunc(fn), priority) +} + +func TestNewCollection(t *testing.T) { + middlewareFactory := func(body string, priority uint) Middleware { + fn := func(h Handler) Handler { + return func() string { return body + h.(func() string)() } + } + + return New(WrapperFunc(fn), priority) + } + type test struct { + name string + m []Middleware + output string + sortedOutput string + } + tests := []test{ + test{"Empty", []Middleware{}, "h", "h"}, + test{"Single middleware", []Middleware{middlewareFactory("0", 0)}, "0h", "0h"}, + test{"Multiple unsorted middleware", []Middleware{middlewareFactory("3", 3), middlewareFactory("1", 1), middlewareFactory("2", 2)}, "312h", "123h"}, + test{"Multiple unsorted middleware 2", []Middleware{middlewareFactory("2", 2), middlewareFactory("1", 1), middlewareFactory("3", 3)}, "213h", "123h"}, + test{"Multiple unsorted middleware 3", []Middleware{middlewareFactory("1", 1), middlewareFactory("3", 3), middlewareFactory("2", 2)}, "132h", "123h"}, + test{"Multiple sorted middleware", []Middleware{middlewareFactory("1", 1), middlewareFactory("2", 2), middlewareFactory("3", 3)}, "123h", "123h"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := NewCollection(tt.m...) + h := m.Compose(func() string { return "h" }) + + result := h.(func() string)() + + if h.(func() string)() != tt.output { + t.Errorf("NewCollection: h() = %v, want %v", result, tt.output) + } + + h = m.Sort().Compose(func() string { return "h" }) + + result = h.(func() string)() + + if h.(func() string)() != tt.sortedOutput { + t.Errorf("NewCollection: h() = %v, want %v", result, tt.sortedOutput) + } + }) + } +} + +func TestOrders(t *testing.T) { + m1 := mockMiddleware("1", 3) + m2 := mockMiddleware("2", 2) + m3 := mockMiddleware("3", 1) + fn := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + if _, err := w.Write([]byte("4")); err != nil { + t.Fatal(err) + } + }) + + m := NewCollection(m1, m2, m3) + h := m.Sort().Compose(fn).(http.Handler) + + w := httptest.NewRecorder() + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + h.ServeHTTP(w, r) + + if w.Body.String() != "3214" { + t.Error("The order is incorrect") + } +} + +func TestMerge(t *testing.T) { + m1 := mockMiddleware("1", 0) + m2 := mockMiddleware("2", 0) + m3 := mockMiddleware("3", 0) + fn := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + if _, err := w.Write([]byte("4")); err != nil { + t.Fatal(err) + } + }) + + m := NewCollection(m1) + m = m.Merge(NewCollection(m2, m3)) + h := m.Compose(fn).(http.Handler) + + w := httptest.NewRecorder() + r, err := http.NewRequest("GET", "/", nil) + if err != nil { + t.Fatal(err) + } + + h.ServeHTTP(w, r) + + if w.Body.String() != "1234" { + t.Errorf("The order is incorrect expected: 1234 actual: %s", w.Body.String()) + } +} + +func TestCompose(t *testing.T) { + m := NewCollection(mockMiddleware("1", 0)) + h := m.Compose(nil) + + if h != nil { + t.Fail() + } +} diff --git a/middleware/middleware.go b/middleware/middleware.go index 4f4fbb3..e42138d 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -1,36 +1,47 @@ package middleware -// MiddlewareFunc is a middleware function type. -// Long story - short: it is a handler wrapper -type MiddlewareFunc func(interface{}) interface{} +// Handler represents wrapped function +type Handler interface{} -// Middleware is a slice of handler functions -type Middleware []MiddlewareFunc +// Wrapper wraps Handler +type Wrapper interface { + // Wrap Handler with middleware + Wrap(Handler) Handler +} -// New provides new middleware -func New(fs ...MiddlewareFunc) Middleware { - return fs +// Sortable allows Collection to be sorted by priority +type Sortable interface { + // Priority provides a value for sorting Collection, lower values come first + Priority() uint } -// Append appends handlers to middleware -func (m Middleware) Append(fs ...MiddlewareFunc) Middleware { - return m.Merge(fs) +// WrapperFunc is an adapter to allow the use of +// handler wrapper functions as middleware functions. +type WrapperFunc func(Handler) Handler + +// Wrap implements Wrapper interface +func (f WrapperFunc) Wrap(h Handler) Handler { + return f(h) } -// Merge merges another middleware -func (m Middleware) Merge(n Middleware) Middleware { - return append(m, n...) +// Middleware is a slice of handler wrappers functions +type Middleware struct { + wrapper Wrapper + priority uint } -// Compose returns middleware composed to single MiddlewareFunc -func (m Middleware) Compose(h interface{}) interface{} { - if h == nil { - return nil - } +func (m Middleware) Wrap(h Handler) Handler { + return m.wrapper.Wrap(h) +} - for i := range m { - h = m[len(m)-1-i](h) - } +func (m Middleware) Priority() uint { + return m.priority +} - return h +// New provides new Middleware +func New(w Wrapper, priority uint) Middleware { + return Middleware{ + wrapper: w, + priority: priority, + } } diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 742e887..58fa08d 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -1,109 +1,66 @@ package middleware import ( - "net/http" - "net/http/httptest" "testing" ) -func mockMiddleware(body string) MiddlewareFunc { - fn := func(h interface{}) interface{} { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if _, err := w.Write([]byte(body)); err != nil { - panic(err) - } - h.(http.Handler).ServeHTTP(w, r) - }) - } +type mockWrapper struct{} - return fn +func (*mockWrapper) Wrap(h Handler) Handler { + return h } -func TestOrders(t *testing.T) { - m1 := mockMiddleware("1") - m2 := mockMiddleware("2") - m3 := mockMiddleware("3") - fn := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - if _, err := w.Write([]byte("4")); err != nil { - t.Fatal(err) - } - }) - - m := New(m1, m2, m3) - h := m.Compose(fn).(http.Handler) - - w := httptest.NewRecorder() - r, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatal(err) +func TestNew(t *testing.T) { + type args struct { + w Wrapper + priority uint } - - h.ServeHTTP(w, r) - - if w.Body.String() != "1234" { - t.Error("The order is incorrect") + type test struct { + name string + args args } -} - -func TestAppend(t *testing.T) { - m1 := mockMiddleware("1") - m2 := mockMiddleware("2") - m3 := mockMiddleware("3") - fn := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - if _, err := w.Write([]byte("4")); err != nil { - t.Fatal(err) - } - }) - - m := New(m1) - m = m.Append(m2, m3) - h := m.Compose(fn).(http.Handler) - - w := httptest.NewRecorder() - r, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatal(err) + tests := []test{ + test{"From Wrapper", args{&mockWrapper{}, 0}}, + test{"From WrapperFunc", args{WrapperFunc(func(h Handler) Handler { return func() {} }), 0}}, } - - h.ServeHTTP(w, r) - - if w.Body.String() != "1234" { - t.Errorf("The order is incorrect expected: 1234 actual: %s", w.Body.String()) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + panicked := false + defer func() { + if rcv := recover(); rcv != nil { + panicked = true + } + }() + + got := New(tt.args.w, tt.args.priority) + + if panicked { + t.Errorf("Panic: New() = %v", got) + } + }) } } -func TestMerge(t *testing.T) { - m1 := mockMiddleware("1") - m2 := mockMiddleware("2") - m3 := mockMiddleware("3") - fn := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - if _, err := w.Write([]byte("4")); err != nil { - t.Fatal(err) - } - }) - - m := New(m1) - m = m.Merge(New(m2, m3)) - h := m.Compose(fn).(http.Handler) - - w := httptest.NewRecorder() - r, err := http.NewRequest("GET", "/", nil) - if err != nil { - t.Fatal(err) +func TestMiddleware_Priority(t *testing.T) { + type test struct { + name string + middleware Middleware + want uint } - - h.ServeHTTP(w, r) - - if w.Body.String() != "1234" { - t.Errorf("The order is incorrect expected: 1234 actual: %s", w.Body.String()) + tests := []test{ + test{"Zero", mockMiddleware("TestMiddleware_Priority 1", 0), 0}, + test{"Positive", mockMiddleware("TestMiddleware_Priority 1", 1), 1}, + test{"Positive Large", mockMiddleware("TestMiddleware_Priority 1", 999), 999}, } -} - -func TestCompose(t *testing.T) { - m := New(mockMiddleware("1")) - h := m.Compose(nil) - - if h != nil { - t.Fail() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := Middleware{ + wrapper: tt.middleware.wrapper, + priority: tt.middleware.priority, + } + if got := m.Priority(); got != tt.want { + t.Errorf("Priority() = %v, want %v", got, tt.want) + } + }) } } diff --git a/mocks_test.go b/mocks_test.go index 99feba4..4d586df 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -92,7 +92,7 @@ func checkIfHasRootRoute(t *testing.T, r interface{}, method string) { switch v := r.(type) { case *router: case *fastHTTPRouter: - if rootRoute := v.routes.Find(method); rootRoute == nil { + if rootRoute := v.tree.Find(method); rootRoute == nil { t.Error("Route not found") } default: diff --git a/mux/node.go b/mux/node.go index 6313364..84de64d 100644 --- a/mux/node.go +++ b/mux/node.go @@ -18,7 +18,7 @@ func NewNode(pathPart string, maxParamsSize uint8) Node { static := &staticNode{ name: name, children: NewTree(), - middleware: middleware.New(), + middleware: middleware.NewCollection(), maxParamsSize: maxParamsSize, } @@ -59,14 +59,14 @@ type RouteAware interface { type MiddlewareAware interface { // MatchMiddleware collects middleware from all nodes within tree matching given path // middleware is merged in order nodes where created, collecting from top to bottom - MatchMiddleware(path string) middleware.Middleware - - // Middleware provides Node's middleware - Middleware() middleware.Middleware - // AppendMiddleware appends middleware to Node - AppendMiddleware(m middleware.Middleware) - // PrependMiddleware prepends middleware to Node - PrependMiddleware(m middleware.Middleware) + MatchMiddleware(path string) middleware.Collection + + // Middleware provides Node's middleware collection + Middleware() middleware.Collection + // AppendMiddleware appends middleware collection to Node + AppendMiddleware(m middleware.Collection) + // PrependMiddleware prepends middleware collection to Node + PrependMiddleware(m middleware.Collection) } // Node represents mux Node @@ -88,7 +88,7 @@ type staticNode struct { children Tree route Route - middleware middleware.Middleware + middleware middleware.Collection maxParamsSize uint8 skipSubPath bool @@ -113,7 +113,7 @@ func (n *staticNode) MatchRoute(path string) (Route, context.Params, string) { return nil, nil, "" } -func (n *staticNode) MatchMiddleware(path string) middleware.Middleware { +func (n *staticNode) MatchMiddleware(path string) middleware.Collection { nameLength := len(n.name) pathLength := len(path) @@ -145,7 +145,7 @@ func (n *staticNode) Route() Route { return n.route } -func (n *staticNode) Middleware() middleware.Middleware { +func (n *staticNode) Middleware() middleware.Collection { return n.middleware } @@ -161,11 +161,11 @@ func (n *staticNode) WithRoute(r Route) { n.route = r } -func (n *staticNode) AppendMiddleware(m middleware.Middleware) { +func (n *staticNode) AppendMiddleware(m middleware.Collection) { n.middleware = n.middleware.Merge(m) } -func (n *staticNode) PrependMiddleware(m middleware.Middleware) { +func (n *staticNode) PrependMiddleware(m middleware.Collection) { n.middleware = m.Merge(n.middleware) } @@ -203,7 +203,7 @@ func (n *wildcardNode) MatchRoute(path string) (Route, context.Params, string) { return route, params, subPath } -func (n *wildcardNode) MatchMiddleware(path string) middleware.Middleware { +func (n *wildcardNode) MatchMiddleware(path string) middleware.Collection { _, subPath := pathutils.GetPart(path) if treeMiddleware := n.children.MatchMiddleware(subPath); treeMiddleware != nil { @@ -253,7 +253,7 @@ func (n *regexpNode) MatchRoute(path string) (Route, context.Params, string) { return route, params, subPath } -func (n *regexpNode) MatchMiddleware(path string) middleware.Middleware { +func (n *regexpNode) MatchMiddleware(path string) middleware.Collection { pathPart, subPath := pathutils.GetPart(path) if !n.regexp.MatchString(pathPart) { return nil diff --git a/mux/route.go b/mux/route.go index 60541d7..2c98993 100644 --- a/mux/route.go +++ b/mux/route.go @@ -1,6 +1,6 @@ package mux -// Route is an middleware aware route interface +// Route is an handler aware route interface type Route interface { Handler() interface{} } diff --git a/mux/tree.go b/mux/tree.go index 9273052..fda8299 100644 --- a/mux/tree.go +++ b/mux/tree.go @@ -81,8 +81,8 @@ func (t Tree) MatchRoute(path string) (Route, context.Params, string) { } // MatchMiddleware collects middleware from all nodes that match path -func (t Tree) MatchMiddleware(path string) middleware.Middleware { - var treeMiddleware = make(middleware.Middleware, 0) +func (t Tree) MatchMiddleware(path string) middleware.Collection { + var treeMiddleware = make(middleware.Collection, 0) for _, child := range t { if m := child.MatchMiddleware(path); m != nil { @@ -135,9 +135,9 @@ func (t Tree) WithRoute(path string, route Route, maxParamsSize uint8) Tree { return newTree } -// WithMiddleware returns new Tree with Middleware appended to given Node -// Middleware is appended to Node under the give path, if Node does not exist it will panic -func (t Tree) WithMiddleware(path string, m middleware.Middleware) Tree { +// WithMiddleware returns new Tree with Collection appended to given Node +// Collection is appended to Node under the give path, if Node does not exist it will panic +func (t Tree) WithMiddleware(path string, ws []middleware.Wrapper, priority uint, maxParamsSize uint8) Tree { path = pathutils.TrimSlash(path) if path == "" { return t @@ -149,14 +149,14 @@ func (t Tree) WithMiddleware(path string, m middleware.Middleware) Tree { newTree := t if node == nil { - node = NewNode(parts[0], 0) + node = NewNode(parts[0], maxParamsSize) newTree = t.withNode(node) } if len(parts) == 1 { - node.AppendMiddleware(m) + node.AppendMiddleware(middleware.NewCollectionFromWrappers(priority, ws...)) } else { - node.WithChildren(node.Tree().WithMiddleware(strings.Join(parts[1:], "/"), m)) + node.WithChildren(node.Tree().WithMiddleware(strings.Join(parts[1:], "/"), ws, priority, maxParamsSize)) } return newTree @@ -206,7 +206,7 @@ func (t Tree) withNode(node Node) Tree { // Sort sorts nodes in order: static, regexp, wildcard func (t Tree) sort() Tree { // Sort Nodes in order [statics, regexps, wildcards] - sort.Slice(t, func(i, j int) bool { + sort.SliceStable(t, func(i, j int) bool { return isMoreImportant(t[i], t[j]) }) diff --git a/mux/tree_test.go b/mux/tree_test.go index 4becac3..7d05b00 100644 --- a/mux/tree_test.go +++ b/mux/tree_test.go @@ -37,13 +37,9 @@ func TestTreeMatch(t *testing.T) { root.WithChildren(root.Tree().Compile()) - route, _, _ := root.Tree().MatchRoute("pl/blog/comments/123/new") + _, _, subPath := root.Tree().MatchRoute("pl/blog/comments/123/new") - if route == nil { - t.Fatalf("%v", route) - } - - if route != commentNew.Route() { - t.Fatalf("%s != %s (%s)", route, commentNew.Route(), commentNew.Name()) + if subPath != "" { + t.Fatalf("%s != %s (%s)", subPath, "pl/blog/comments/123/new", commentNew.Name()) } } diff --git a/nethttp.go b/nethttp.go index e161054..b39fec3 100644 --- a/nethttp.go +++ b/nethttp.go @@ -9,26 +9,27 @@ import ( "github.com/vardius/gorouter/v4/mux" ) -// New creates new net/http Router instance, returns pointer +// NewCollection creates new net/http Router instance, returns pointer func New(fs ...MiddlewareFunc) Router { + globalMiddleware := middleware.NewCollectionFromWrappers(0, transformMiddlewareFunc(fs...)...) return &router{ - routes: mux.NewTree(), - middleware: mux.NewTree(), - globalMiddleware: transformMiddlewareFunc(fs...), + tree: mux.NewTree(), + globalMiddleware: globalMiddleware, + middlewareCounter: uint(len(globalMiddleware)), } } type router struct { - routes mux.Tree // mux.RouteAware tree - middleware mux.Tree // mux.MiddlewareAware tree - globalMiddleware middleware.Middleware - fileServer http.Handler - notFound http.Handler - notAllowed http.Handler + tree mux.Tree + globalMiddleware middleware.Collection + fileServer http.Handler + notFound http.Handler + notAllowed http.Handler + middlewareCounter uint } func (r *router) PrettyPrint() string { - return r.routes.PrettyPrint() + return r.tree.PrettyPrint() } func (r *router) POST(p string, f http.Handler) { @@ -70,13 +71,14 @@ func (r *router) TRACE(p string, f http.Handler) { func (r *router) USE(method, path string, fs ...MiddlewareFunc) { m := transformMiddlewareFunc(fs...) - r.middleware = r.middleware.WithMiddleware(method+path, m) + r.tree = r.tree.WithMiddleware(method+path, m, r.middlewareCounter, 0) + r.middlewareCounter += uint(len(m)) } func (r *router) Handle(method, path string, h http.Handler) { route := newRoute(h) - r.routes = r.routes.WithRoute(method+path, route, 0) + r.tree = r.tree.WithRoute(method+path, route, 0) } func (r *router) Mount(path string, h http.Handler) { @@ -93,13 +95,13 @@ func (r *router) Mount(path string, h http.Handler) { } { route := newRoute(h) - r.routes = r.routes.WithSubrouter(method+path, route, 0) + r.tree = r.tree.WithSubrouter(method+path, route, 0) } } func (r *router) Compile() { - for i, methodNode := range r.routes { - r.routes[i].WithChildren(methodNode.Tree().Compile()) + for i, methodNode := range r.tree { + r.tree[i].WithChildren(methodNode.Tree().Compile()) } } @@ -123,13 +125,13 @@ func (r *router) ServeFiles(fs http.FileSystem, root string, strip bool) { } func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { - if route, params, subPath := r.routes.MatchRoute(req.Method + req.URL.Path); route != nil { + if route, params, subPath := r.tree.MatchRoute(req.Method + req.URL.Path); route != nil { allMiddleware := r.globalMiddleware - if treeMiddleware := r.middleware.MatchMiddleware(req.Method + req.URL.Path); treeMiddleware != nil { + if treeMiddleware := r.tree.MatchMiddleware(req.Method + req.URL.Path); treeMiddleware != nil { allMiddleware = allMiddleware.Merge(treeMiddleware) } - computedHandler := allMiddleware.Compose(route.Handler()) + computedHandler := allMiddleware.Sort().Compose(route.Handler()) h := computedHandler.(http.Handler) @@ -147,7 +149,7 @@ func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { // Handle OPTIONS if req.Method == http.MethodOptions { - if allow := allowed(r.routes, req.Method, req.URL.Path); len(allow) > 0 { + if allow := allowed(r.tree, req.Method, req.URL.Path); len(allow) > 0 { w.Header().Set("Allow", allow) return } @@ -157,7 +159,7 @@ func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { return } else { // Handle 405 - if allow := allowed(r.routes, req.Method, req.URL.Path); len(allow) > 0 { + if allow := allowed(r.tree, req.Method, req.URL.Path); len(allow) > 0 { w.Header().Set("Allow", allow) r.serveNotAllowed(w, req) return @@ -187,12 +189,12 @@ func (r *router) serveNotAllowed(w http.ResponseWriter, req *http.Request) { } } -func transformMiddlewareFunc(fs ...MiddlewareFunc) middleware.Middleware { - m := make(middleware.Middleware, len(fs)) +func transformMiddlewareFunc(fs ...MiddlewareFunc) []middleware.Wrapper { + m := make([]middleware.Wrapper, len(fs)) for i, f := range fs { - m[i] = func(mf MiddlewareFunc) middleware.MiddlewareFunc { - return func(h interface{}) interface{} { + m[i] = func(mf MiddlewareFunc) middleware.WrapperFunc { + return func(h middleware.Handler) middleware.Handler { return mf(h.(http.Handler)) } }(f) // f is a reference to function so we have to wrap if with that callback diff --git a/nethttp_test.go b/nethttp_test.go index 5551857..0b47920 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -118,7 +118,7 @@ func TestOPTIONS(t *testing.T) { w := httptest.NewRecorder() - // test all routes "*" paths + // test all tree "*" paths req, err := http.NewRequest(http.MethodOptions, "*", nil) if err != nil { t.Fatal(err) @@ -477,7 +477,7 @@ func TestTreeOrphanMiddlewareOrder(t *testing.T) { router.ServeHTTP(w, req) - if w.Body.String() != "m1->m2->mx1->mx2->mxy1->mxy2->mxy3->mxy4->mparam1->mparam2->handler" { + if w.Body.String() != "m1->m2->mx1->mx2->mxy1->mxy2->mparam1->mparam2->mxy3->mxy4->handler" { t.Errorf("Use globalMiddleware error %s", w.Body.String()) } } diff --git a/route_test.go b/route_test.go index 47b2b7a..e58e3b0 100644 --- a/route_test.go +++ b/route_test.go @@ -16,8 +16,8 @@ func TestRouter(t *testing.T) { } }) - buildMiddlewareFunc := func(body string) middleware.MiddlewareFunc { - fn := func(h interface{}) interface{} { + buildMiddlewareFunc := func(body string) middleware.Middleware { + fn := func(h middleware.Handler) middleware.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte(body)); err != nil { t.Fatal(err) @@ -26,7 +26,7 @@ func TestRouter(t *testing.T) { }) } - return fn + return middleware.New(middleware.WrapperFunc(fn), 0) } m1 := buildMiddlewareFunc("1") @@ -34,7 +34,7 @@ func TestRouter(t *testing.T) { m3 := buildMiddlewareFunc("3") r := newRoute(handler) - m := middleware.New(m1, m2, m3) + m := middleware.NewCollection(m1, m2, m3) h := m.Compose(r.Handler()) w := httptest.NewRecorder() diff --git a/tree.go b/tree.go index 34488c7..6156c5e 100644 --- a/tree.go +++ b/tree.go @@ -11,7 +11,7 @@ func allowed(t mux.Tree, method, path string) (allow string) { path = pathutils.TrimSlash(path) if path == "*" { - // routes tree roots should be http method nodes only + // tree tree roots should be http method nodes only for _, root := range t { if root.Name() == http.MethodOptions { continue @@ -23,7 +23,7 @@ func allowed(t mux.Tree, method, path string) (allow string) { } } } else { - // routes tree roots should be http method nodes only + // tree tree roots should be http method nodes only for _, root := range t { if root.Name() == method || root.Name() == http.MethodOptions { continue From 3a8674b3da4ae69bb2f200f6c9e91a0040c24bfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Tue, 28 Jan 2020 22:43:52 +1100 Subject: [PATCH 26/41] Remove redundant types --- middleware/collection_test.go | 12 ++++++------ middleware/middleware_test.go | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/middleware/collection_test.go b/middleware/collection_test.go index 6d2193e..db223fb 100644 --- a/middleware/collection_test.go +++ b/middleware/collection_test.go @@ -34,12 +34,12 @@ func TestNewCollection(t *testing.T) { sortedOutput string } tests := []test{ - test{"Empty", []Middleware{}, "h", "h"}, - test{"Single middleware", []Middleware{middlewareFactory("0", 0)}, "0h", "0h"}, - test{"Multiple unsorted middleware", []Middleware{middlewareFactory("3", 3), middlewareFactory("1", 1), middlewareFactory("2", 2)}, "312h", "123h"}, - test{"Multiple unsorted middleware 2", []Middleware{middlewareFactory("2", 2), middlewareFactory("1", 1), middlewareFactory("3", 3)}, "213h", "123h"}, - test{"Multiple unsorted middleware 3", []Middleware{middlewareFactory("1", 1), middlewareFactory("3", 3), middlewareFactory("2", 2)}, "132h", "123h"}, - test{"Multiple sorted middleware", []Middleware{middlewareFactory("1", 1), middlewareFactory("2", 2), middlewareFactory("3", 3)}, "123h", "123h"}, + {"Empty", []Middleware{}, "h", "h"}, + {"Single middleware", []Middleware{middlewareFactory("0", 0)}, "0h", "0h"}, + {"Multiple unsorted middleware", []Middleware{middlewareFactory("3", 3), middlewareFactory("1", 1), middlewareFactory("2", 2)}, "312h", "123h"}, + {"Multiple unsorted middleware 2", []Middleware{middlewareFactory("2", 2), middlewareFactory("1", 1), middlewareFactory("3", 3)}, "213h", "123h"}, + {"Multiple unsorted middleware 3", []Middleware{middlewareFactory("1", 1), middlewareFactory("3", 3), middlewareFactory("2", 2)}, "132h", "123h"}, + {"Multiple sorted middleware", []Middleware{middlewareFactory("1", 1), middlewareFactory("2", 2), middlewareFactory("3", 3)}, "123h", "123h"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 58fa08d..4d9b49f 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -20,8 +20,8 @@ func TestNew(t *testing.T) { args args } tests := []test{ - test{"From Wrapper", args{&mockWrapper{}, 0}}, - test{"From WrapperFunc", args{WrapperFunc(func(h Handler) Handler { return func() {} }), 0}}, + {"From Wrapper", args{&mockWrapper{}, 0}}, + {"From WrapperFunc", args{WrapperFunc(func(h Handler) Handler { return func() {} }), 0}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -48,9 +48,9 @@ func TestMiddleware_Priority(t *testing.T) { want uint } tests := []test{ - test{"Zero", mockMiddleware("TestMiddleware_Priority 1", 0), 0}, - test{"Positive", mockMiddleware("TestMiddleware_Priority 1", 1), 1}, - test{"Positive Large", mockMiddleware("TestMiddleware_Priority 1", 999), 999}, + {"Zero", mockMiddleware("TestMiddleware_Priority 1", 0), 0}, + {"Positive", mockMiddleware("TestMiddleware_Priority 1", 1), 1}, + {"Positive Large", mockMiddleware("TestMiddleware_Priority 1", 999), 999}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { From c38ddf4483aa0be294ffd79b85957ebc771b570d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Tue, 28 Jan 2020 22:46:45 +1100 Subject: [PATCH 27/41] Revert test, case covered in other test case --- fasthttp_test.go | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/fasthttp_test.go b/fasthttp_test.go index 99ea8b5..c546897 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -391,16 +391,14 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { } }) - router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m1")) - router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("m2")) - router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m3")) - router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("m4")) + router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("m1")) + router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m2")) ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/y") router.HandleFastHTTP(ctx) - if string(ctx.Response.Body()) != "m2m4y" { + if string(ctx.Response.Body()) != "m1y" { t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) } @@ -408,7 +406,7 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { router.HandleFastHTTP(ctx) - if string(ctx.Response.Body()) != "m1m2m3m4x" { + if string(ctx.Response.Body()) != "m1m2x" { t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) } } From 6b55fe312daaccdb6e8b610bcfdb8cf04ab1b351 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Tue, 28 Jan 2020 22:50:53 +1100 Subject: [PATCH 28/41] Update method comments --- middleware/collection.go | 2 +- middleware/middleware.go | 2 ++ nethttp.go | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/middleware/collection.go b/middleware/collection.go index c64eeae..7052c04 100644 --- a/middleware/collection.go +++ b/middleware/collection.go @@ -45,7 +45,7 @@ func (c Collection) Compose(h Handler) Handler { return h } -// Merge merges another middleware +// Sort sorts collection by priority func (c Collection) Sort() Collection { sort.SliceStable(c, func(i, j int) bool { return c[i].Priority() < c[j].Priority() diff --git a/middleware/middleware.go b/middleware/middleware.go index e42138d..9801707 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -30,10 +30,12 @@ type Middleware struct { priority uint } +// Wrap Handler with middleware func (m Middleware) Wrap(h Handler) Handler { return m.wrapper.Wrap(h) } +// Priority provides a value for sorting Collection, lower values come first func (m Middleware) Priority() uint { return m.priority } diff --git a/nethttp.go b/nethttp.go index b39fec3..1bb8154 100644 --- a/nethttp.go +++ b/nethttp.go @@ -9,7 +9,7 @@ import ( "github.com/vardius/gorouter/v4/mux" ) -// NewCollection creates new net/http Router instance, returns pointer +// New creates new net/http Router instance, returns pointer func New(fs ...MiddlewareFunc) Router { globalMiddleware := middleware.NewCollectionFromWrappers(0, transformMiddlewareFunc(fs...)...) return &router{ From ba9a1e9a22e50caf60c720f884b32b2e18108c46 Mon Sep 17 00:00:00 2001 From: Rafal Lorenz Date: Wed, 29 Jan 2020 09:02:45 +1100 Subject: [PATCH 29/41] Refactor middleware package --- fasthttp.go | 11 ++++--- middleware/collection.go | 15 ---------- middleware/collection_test.go | 38 ++++++++++++------------ middleware/middleware.go | 32 +++++++++----------- middleware/middleware_test.go | 56 +++++------------------------------ mux/tree.go | 6 ++-- nethttp.go | 11 ++++--- route_test.go | 2 +- 8 files changed, 59 insertions(+), 112 deletions(-) diff --git a/fasthttp.go b/fasthttp.go index 72e6e3f..0b86ffc 100644 --- a/fasthttp.go +++ b/fasthttp.go @@ -10,7 +10,7 @@ import ( // NewFastHTTPRouter creates new Router instance, returns pointer func NewFastHTTPRouter(fs ...FastHTTPMiddlewareFunc) FastHTTPRouter { - globalMiddleware := middleware.NewCollectionFromWrappers(0, transformFastHTTPMiddlewareFunc(fs...)...) + globalMiddleware := transformFastHTTPMiddlewareFunc(fs...) return &fastHTTPRouter{ tree: mux.NewTree(), globalMiddleware: globalMiddleware, @@ -69,8 +69,11 @@ func (r *fastHTTPRouter) TRACE(p string, f fasthttp.RequestHandler) { func (r *fastHTTPRouter) USE(method, path string, fs ...FastHTTPMiddlewareFunc) { m := transformFastHTTPMiddlewareFunc(fs...) + for i, mf := range m { + m[i] = middleware.WithPriority(mf, r.middlewareCounter) + } - r.tree = r.tree.WithMiddleware(method+path, m, r.middlewareCounter, 0) + r.tree = r.tree.WithMiddleware(method+path, m, 0) r.middlewareCounter += uint(len(m)) } @@ -185,8 +188,8 @@ func (r *fastHTTPRouter) serveNotAllowed(ctx *fasthttp.RequestCtx) { } } -func transformFastHTTPMiddlewareFunc(fs ...FastHTTPMiddlewareFunc) []middleware.Wrapper { - m := make([]middleware.Wrapper, len(fs)) +func transformFastHTTPMiddlewareFunc(fs ...FastHTTPMiddlewareFunc) middleware.Collection { + m := make(middleware.Collection, len(fs)) for i, f := range fs { m[i] = func(mf FastHTTPMiddlewareFunc) middleware.WrapperFunc { diff --git a/middleware/collection.go b/middleware/collection.go index 7052c04..8b1e5ed 100644 --- a/middleware/collection.go +++ b/middleware/collection.go @@ -12,21 +12,6 @@ func NewCollection(ms ...Middleware) Collection { return ms } -// NewCollectionFromWrappers provides new middleware -// with order priority preset to provided value -func NewCollectionFromWrappers(priority uint, ws ...Wrapper) Collection { - c := make(Collection, len(ws)) - - for i, w := range ws { - c[i] = Middleware{ - wrapper: w, - priority: priority, - } - } - - return c -} - // Merge merges another middleware func (c Collection) Merge(m Collection) Collection { return append(c, m...) diff --git a/middleware/collection_test.go b/middleware/collection_test.go index db223fb..0898b59 100644 --- a/middleware/collection_test.go +++ b/middleware/collection_test.go @@ -6,7 +6,7 @@ import ( "testing" ) -func mockMiddleware(body string, priority uint) Middleware { +func mockMiddleware(body string) Middleware { fn := func(h Handler) Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte(body)); err != nil { @@ -16,30 +16,30 @@ func mockMiddleware(body string, priority uint) Middleware { }) } - return New(WrapperFunc(fn), priority) + return WrapperFunc(fn) } -func TestNewCollection(t *testing.T) { +func TestCollection(t *testing.T) { middlewareFactory := func(body string, priority uint) Middleware { fn := func(h Handler) Handler { return func() string { return body + h.(func() string)() } } - return New(WrapperFunc(fn), priority) + return WithPriority(WrapperFunc(fn), priority) } type test struct { name string - m []Middleware + m Collection output string sortedOutput string } tests := []test{ - {"Empty", []Middleware{}, "h", "h"}, - {"Single middleware", []Middleware{middlewareFactory("0", 0)}, "0h", "0h"}, - {"Multiple unsorted middleware", []Middleware{middlewareFactory("3", 3), middlewareFactory("1", 1), middlewareFactory("2", 2)}, "312h", "123h"}, - {"Multiple unsorted middleware 2", []Middleware{middlewareFactory("2", 2), middlewareFactory("1", 1), middlewareFactory("3", 3)}, "213h", "123h"}, - {"Multiple unsorted middleware 3", []Middleware{middlewareFactory("1", 1), middlewareFactory("3", 3), middlewareFactory("2", 2)}, "132h", "123h"}, - {"Multiple sorted middleware", []Middleware{middlewareFactory("1", 1), middlewareFactory("2", 2), middlewareFactory("3", 3)}, "123h", "123h"}, + {"Empty", NewCollection(), "h", "h"}, + {"Single middleware", NewCollection(middlewareFactory("0", 0)), "0h", "0h"}, + {"Multiple unsorted middleware", NewCollection(middlewareFactory("3", 3), middlewareFactory("1", 1), middlewareFactory("2", 2)), "312h", "123h"}, + {"Multiple unsorted middleware 2", NewCollection(middlewareFactory("2", 2), middlewareFactory("1", 1), middlewareFactory("3", 3)), "213h", "123h"}, + {"Multiple unsorted middleware 3", NewCollection(middlewareFactory("1", 1), middlewareFactory("3", 3), middlewareFactory("2", 2)), "132h", "123h"}, + {"Multiple sorted middleware", NewCollection(middlewareFactory("1", 1), middlewareFactory("2", 2), middlewareFactory("3", 3)), "123h", "123h"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -63,10 +63,10 @@ func TestNewCollection(t *testing.T) { } } -func TestOrders(t *testing.T) { - m1 := mockMiddleware("1", 3) - m2 := mockMiddleware("2", 2) - m3 := mockMiddleware("3", 1) +func TestWithPriority(t *testing.T) { + m1 := WithPriority(mockMiddleware("1"), 3) + m2 := WithPriority(mockMiddleware("2"), 2) + m3 := WithPriority(mockMiddleware("3"), 1) fn := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { if _, err := w.Write([]byte("4")); err != nil { t.Fatal(err) @@ -90,9 +90,9 @@ func TestOrders(t *testing.T) { } func TestMerge(t *testing.T) { - m1 := mockMiddleware("1", 0) - m2 := mockMiddleware("2", 0) - m3 := mockMiddleware("3", 0) + m1 := mockMiddleware("1") + m2 := mockMiddleware("2") + m3 := mockMiddleware("3") fn := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { if _, err := w.Write([]byte("4")); err != nil { t.Fatal(err) @@ -117,7 +117,7 @@ func TestMerge(t *testing.T) { } func TestCompose(t *testing.T) { - m := NewCollection(mockMiddleware("1", 0)) + m := NewCollection(mockMiddleware("1")) h := m.Compose(nil) if h != nil { diff --git a/middleware/middleware.go b/middleware/middleware.go index 9801707..8459b28 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -3,14 +3,10 @@ package middleware // Handler represents wrapped function type Handler interface{} -// Wrapper wraps Handler -type Wrapper interface { +// Middleware wraps Handler +type Middleware interface { // Wrap Handler with middleware Wrap(Handler) Handler -} - -// Sortable allows Collection to be sorted by priority -type Sortable interface { // Priority provides a value for sorting Collection, lower values come first Priority() uint } @@ -24,26 +20,26 @@ func (f WrapperFunc) Wrap(h Handler) Handler { return f(h) } -// Middleware is a slice of handler wrappers functions -type Middleware struct { - wrapper Wrapper - priority uint +// Priority provides a value for sorting Collection, lower values come first +func (f WrapperFunc) Priority() (priority uint) { + return } -// Wrap Handler with middleware -func (m Middleware) Wrap(h Handler) Handler { - return m.wrapper.Wrap(h) +// Middleware is a slice of handler wrappers functions +type sortableMiddleware struct { + Middleware + priority uint } // Priority provides a value for sorting Collection, lower values come first -func (m Middleware) Priority() uint { +func (m *sortableMiddleware) Priority() uint { return m.priority } -// New provides new Middleware -func New(w Wrapper, priority uint) Middleware { - return Middleware{ - wrapper: w, +// WithPriority provides new Middleware with priority +func WithPriority(middleware Middleware, priority uint) Middleware { + return &sortableMiddleware{ + Middleware: middleware, priority: priority, } } diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 4d9b49f..6591ec0 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -4,62 +4,22 @@ import ( "testing" ) -type mockWrapper struct{} - -func (*mockWrapper) Wrap(h Handler) Handler { - return h -} - -func TestNew(t *testing.T) { - type args struct { - w Wrapper - priority uint - } - type test struct { - name string - args args - } - tests := []test{ - {"From Wrapper", args{&mockWrapper{}, 0}}, - {"From WrapperFunc", args{WrapperFunc(func(h Handler) Handler { return func() {} }), 0}}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - panicked := false - defer func() { - if rcv := recover(); rcv != nil { - panicked = true - } - }() - - got := New(tt.args.w, tt.args.priority) - - if panicked { - t.Errorf("Panic: New() = %v", got) - } - }) - } -} - -func TestMiddleware_Priority(t *testing.T) { +func TestMiddleware_WithPriority(t *testing.T) { type test struct { name string middleware Middleware - want uint + priority uint } tests := []test{ - {"Zero", mockMiddleware("TestMiddleware_Priority 1", 0), 0}, - {"Positive", mockMiddleware("TestMiddleware_Priority 1", 1), 1}, - {"Positive Large", mockMiddleware("TestMiddleware_Priority 1", 999), 999}, + {"Zero", mockMiddleware("Zero"), 0}, + {"Positive", mockMiddleware("Positive"), 1}, + {"Positive Large", mockMiddleware("Positive Large"), 999}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - m := Middleware{ - wrapper: tt.middleware.wrapper, - priority: tt.middleware.priority, - } - if got := m.Priority(); got != tt.want { - t.Errorf("Priority() = %v, want %v", got, tt.want) + m := WithPriority(tt.middleware, tt.priority) + if got := m.Priority(); got != tt.priority { + t.Errorf("Priority() = %v, want %v", got, tt.priority) } }) } diff --git a/mux/tree.go b/mux/tree.go index fda8299..8d4ecfb 100644 --- a/mux/tree.go +++ b/mux/tree.go @@ -137,7 +137,7 @@ func (t Tree) WithRoute(path string, route Route, maxParamsSize uint8) Tree { // WithMiddleware returns new Tree with Collection appended to given Node // Collection is appended to Node under the give path, if Node does not exist it will panic -func (t Tree) WithMiddleware(path string, ws []middleware.Wrapper, priority uint, maxParamsSize uint8) Tree { +func (t Tree) WithMiddleware(path string, m middleware.Collection, maxParamsSize uint8) Tree { path = pathutils.TrimSlash(path) if path == "" { return t @@ -154,9 +154,9 @@ func (t Tree) WithMiddleware(path string, ws []middleware.Wrapper, priority uint } if len(parts) == 1 { - node.AppendMiddleware(middleware.NewCollectionFromWrappers(priority, ws...)) + node.AppendMiddleware(m) } else { - node.WithChildren(node.Tree().WithMiddleware(strings.Join(parts[1:], "/"), ws, priority, maxParamsSize)) + node.WithChildren(node.Tree().WithMiddleware(strings.Join(parts[1:], "/"), m, maxParamsSize)) } return newTree diff --git a/nethttp.go b/nethttp.go index 1bb8154..d462158 100644 --- a/nethttp.go +++ b/nethttp.go @@ -11,7 +11,7 @@ import ( // New creates new net/http Router instance, returns pointer func New(fs ...MiddlewareFunc) Router { - globalMiddleware := middleware.NewCollectionFromWrappers(0, transformMiddlewareFunc(fs...)...) + globalMiddleware := transformMiddlewareFunc(fs...) return &router{ tree: mux.NewTree(), globalMiddleware: globalMiddleware, @@ -70,8 +70,11 @@ func (r *router) TRACE(p string, f http.Handler) { func (r *router) USE(method, path string, fs ...MiddlewareFunc) { m := transformMiddlewareFunc(fs...) + for i, mf := range m { + m[i] = middleware.WithPriority(mf, r.middlewareCounter) + } - r.tree = r.tree.WithMiddleware(method+path, m, r.middlewareCounter, 0) + r.tree = r.tree.WithMiddleware(method+path, m, 0) r.middlewareCounter += uint(len(m)) } @@ -189,8 +192,8 @@ func (r *router) serveNotAllowed(w http.ResponseWriter, req *http.Request) { } } -func transformMiddlewareFunc(fs ...MiddlewareFunc) []middleware.Wrapper { - m := make([]middleware.Wrapper, len(fs)) +func transformMiddlewareFunc(fs ...MiddlewareFunc) middleware.Collection { + m := make(middleware.Collection, len(fs)) for i, f := range fs { m[i] = func(mf MiddlewareFunc) middleware.WrapperFunc { diff --git a/route_test.go b/route_test.go index e58e3b0..978d5de 100644 --- a/route_test.go +++ b/route_test.go @@ -26,7 +26,7 @@ func TestRouter(t *testing.T) { }) } - return middleware.New(middleware.WrapperFunc(fn), 0) + return middleware.WrapperFunc(fn) } m1 := buildMiddlewareFunc("1") From c388c4197c3896ffb1502d7fec475f83452b2cdc Mon Sep 17 00:00:00 2001 From: Rafal Lorenz Date: Wed, 29 Jan 2020 09:12:08 +1100 Subject: [PATCH 30/41] Sort only tree middleware (global always comes first), comput handler only if we have middleware --- fasthttp.go | 14 ++++++++++---- nethttp.go | 14 ++++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/fasthttp.go b/fasthttp.go index 0b86ffc..387e91c 100644 --- a/fasthttp.go +++ b/fasthttp.go @@ -129,13 +129,19 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) { if route, params, subPath := r.tree.MatchRoute(method + path); route != nil { allMiddleware := r.globalMiddleware - if treeMiddleware := r.tree.MatchMiddleware(method + path); treeMiddleware != nil { - allMiddleware = allMiddleware.Merge(treeMiddleware) + if treeMiddleware := r.tree.MatchMiddleware(method + path); treeMiddleware != nil && len(treeMiddleware) > 0 { + allMiddleware = allMiddleware.Merge(treeMiddleware.Sort()) } - computedHandler := allMiddleware.Sort().Compose(route.Handler()) + var h fasthttp.RequestHandler + if len(allMiddleware) > 0 { + computedHandler := allMiddleware.Compose(route.Handler()) + + h = computedHandler.(fasthttp.RequestHandler) - h := computedHandler.(fasthttp.RequestHandler) + } + + h = route.Handler().(fasthttp.RequestHandler) if len(params) > 0 { ctx.SetUserValue("params", params) diff --git a/nethttp.go b/nethttp.go index d462158..d8af8c8 100644 --- a/nethttp.go +++ b/nethttp.go @@ -130,13 +130,19 @@ func (r *router) ServeFiles(fs http.FileSystem, root string, strip bool) { func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { if route, params, subPath := r.tree.MatchRoute(req.Method + req.URL.Path); route != nil { allMiddleware := r.globalMiddleware - if treeMiddleware := r.tree.MatchMiddleware(req.Method + req.URL.Path); treeMiddleware != nil { - allMiddleware = allMiddleware.Merge(treeMiddleware) + if treeMiddleware := r.tree.MatchMiddleware(req.Method + req.URL.Path); treeMiddleware != nil && len(treeMiddleware) > 0 { + allMiddleware = allMiddleware.Merge(treeMiddleware.Sort()) } - computedHandler := allMiddleware.Sort().Compose(route.Handler()) + var h http.Handler + if len(allMiddleware) > 0 { + computedHandler := allMiddleware.Compose(route.Handler()) + + h = computedHandler.(http.Handler) - h := computedHandler.(http.Handler) + } + + h = route.Handler().(http.Handler) if len(params) > 0 { req = req.WithContext(context.WithParams(req.Context(), params)) From 46b853f5008d2e6a7588b7f1d334c8fe5773852e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Wed, 29 Jan 2020 09:15:50 +1100 Subject: [PATCH 31/41] Update fasthttp.go --- fasthttp.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fasthttp.go b/fasthttp.go index 387e91c..fc2205e 100644 --- a/fasthttp.go +++ b/fasthttp.go @@ -139,10 +139,10 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) { h = computedHandler.(fasthttp.RequestHandler) + } else { + h = route.Handler().(fasthttp.RequestHandler) } - h = route.Handler().(fasthttp.RequestHandler) - if len(params) > 0 { ctx.SetUserValue("params", params) } From c855097a01ab7625a0ecbba28dc42691a6794e56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Wed, 29 Jan 2020 09:15:53 +1100 Subject: [PATCH 32/41] Update nethttp.go --- nethttp.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nethttp.go b/nethttp.go index d8af8c8..5999c31 100644 --- a/nethttp.go +++ b/nethttp.go @@ -140,10 +140,10 @@ func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { h = computedHandler.(http.Handler) + } else { + h = route.Handler().(http.Handler) } - h = route.Handler().(http.Handler) - if len(params) > 0 { req = req.WithContext(context.WithParams(req.Context(), params)) } From 6ecb25bd90c38f5ebc01a77c697ee4f332a7fd14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Wed, 29 Jan 2020 09:43:48 +1100 Subject: [PATCH 33/41] Update doc.go --- doc.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/doc.go b/doc.go index c6e6dbc..e217722 100644 --- a/doc.go +++ b/doc.go @@ -1,11 +1,11 @@ /* -Package gorouter provide request router with globalMiddleware +Package gorouter provide request router with middleware Router -The router determines how to handle that request. -GoRouter uses a routing tree. Once one branch of the tree matches, only tree inside that branch are considered, -not any tree after that branch. When instantiating router, the root node of tree is created. +The router determines how to handle http request. +GoRouter uses a routing tree. Once one branch of the tree matches, only routes inside that branch are considered, +not any routes after that branch. When instantiating router, the root node of tree is created. Route types From 11ccf293b74e20d85ec0a00eadadf1b6fa3abefa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Wed, 29 Jan 2020 17:37:13 +1100 Subject: [PATCH 34/41] Fox some comments --- example_test.go | 30 +++++++++++++++--------------- fasthttp.go | 15 +++++++-------- fasthttp_test.go | 18 +++++++++--------- nethttp.go | 15 +++++++-------- nethttp_test.go | 18 +++++++++--------- router.go | 8 ++++---- 6 files changed, 51 insertions(+), 53 deletions(-) diff --git a/example_test.go b/example_test.go index 66ebf33..c6e1db6 100644 --- a/example_test.go +++ b/example_test.go @@ -65,8 +65,8 @@ func Example_second() { } func ExampleMiddlewareFunc() { - // Global globalMiddleware example - // applies to all tree + // Global middleware example + // applies to all routes hello := func(w http.ResponseWriter, r *http.Request) { params, _ := context.Parameters(r.Context()) fmt.Printf("Hello, %s!\n", params.Value("name")) @@ -81,7 +81,7 @@ func ExampleMiddlewareFunc() { return http.HandlerFunc(fn) } - // apply globalMiddleware to all tree + // apply middleware to all routes // can pass as many as you want router := gorouter.New(logger) router.GET("/hello/{name}", http.HandlerFunc(hello)) @@ -95,7 +95,7 @@ func ExampleMiddlewareFunc() { } func ExampleMiddlewareFunc_second() { - // Route level globalMiddleware example + // Route level middleware example // applies to route and its lower tree hello := func(w http.ResponseWriter, r *http.Request) { params, _ := context.Parameters(r.Context()) @@ -114,7 +114,7 @@ func ExampleMiddlewareFunc_second() { router := gorouter.New() router.GET("/hello/{name}", http.HandlerFunc(hello)) - // apply globalMiddleware to route and all it children + // apply middleware to route and all its children // can pass as many as you want router.USE("GET", "/hello/{name}", logger) @@ -127,8 +127,8 @@ func ExampleMiddlewareFunc_second() { } func ExampleMiddlewareFunc_third() { - // Http method globalMiddleware example - // applies to all tree under this method + // Http method middleware example + // applies to all routes under this method hello := func(w http.ResponseWriter, r *http.Request) { params, _ := context.Parameters(r.Context()) fmt.Printf("Hello, %s!\n", params.Value("name")) @@ -146,7 +146,7 @@ func ExampleMiddlewareFunc_third() { router := gorouter.New() router.GET("/hello/{name}", http.HandlerFunc(hello)) - // apply globalMiddleware to all tree with GET method + // apply middleware to all routes with GET method // can pass as many as you want router.USE("GET", "", logger) @@ -159,8 +159,8 @@ func ExampleMiddlewareFunc_third() { } func ExampleFastHTTPMiddlewareFunc() { - // Global globalMiddleware example - // applies to all tree + // Global middleware example + // applies to all routes hello := func(ctx *fasthttp.RequestCtx) { params := ctx.UserValue("params").(context.Params) fmt.Printf("Hello, %s!\n", params.Value("name")) @@ -187,7 +187,7 @@ func ExampleFastHTTPMiddlewareFunc() { } func ExampleFastHTTPMiddlewareFunc_second() { - // Route level globalMiddleware example + // Route level middleware example // applies to route and its lower tree hello := func(ctx *fasthttp.RequestCtx) { params := ctx.UserValue("params").(context.Params) @@ -206,7 +206,7 @@ func ExampleFastHTTPMiddlewareFunc_second() { router := gorouter.NewFastHTTPRouter() router.GET("/hello/{name}", hello) - // apply globalMiddleware to route and all it children + // apply middleware to route and all its children // can pass as many as you want router.USE("GET", "/hello/{name}", logger) @@ -219,8 +219,8 @@ func ExampleFastHTTPMiddlewareFunc_second() { } func ExampleFastHTTPMiddlewareFunc_third() { - // Http method globalMiddleware example - // applies to all tree under this method + // Http method middleware example + // applies to all routes under this method hello := func(ctx *fasthttp.RequestCtx) { params := ctx.UserValue("params").(context.Params) fmt.Printf("Hello, %s!\n", params.Value("name")) @@ -238,7 +238,7 @@ func ExampleFastHTTPMiddlewareFunc_third() { router := gorouter.NewFastHTTPRouter() router.GET("/hello/{name}", hello) - // apply globalMiddleware to all tree with GET method + // apply middleware to all routes with GET method // can pass as many as you want router.USE("GET", "", logger) diff --git a/fasthttp.go b/fasthttp.go index fc2205e..7de68c3 100644 --- a/fasthttp.go +++ b/fasthttp.go @@ -128,17 +128,16 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) { path := string(ctx.Path()) if route, params, subPath := r.tree.MatchRoute(method + path); route != nil { - allMiddleware := r.globalMiddleware - if treeMiddleware := r.tree.MatchMiddleware(method + path); treeMiddleware != nil && len(treeMiddleware) > 0 { - allMiddleware = allMiddleware.Merge(treeMiddleware.Sort()) - } - var h fasthttp.RequestHandler - if len(allMiddleware) > 0 { + if r.middlewareCounter > 0 { + allMiddleware := r.globalMiddleware + if treeMiddleware := r.tree.MatchMiddleware(method + path); treeMiddleware != nil && len(treeMiddleware) > 0 { + allMiddleware = allMiddleware.Merge(treeMiddleware.Sort()) + } + computedHandler := allMiddleware.Compose(route.Handler()) - - h = computedHandler.(fasthttp.RequestHandler) + h = computedHandler.(fasthttp.RequestHandler) } else { h = route.Handler().(fasthttp.RequestHandler) } diff --git a/fasthttp_test.go b/fasthttp_test.go index c546897..b4b11cb 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -341,7 +341,7 @@ func TestFastHTTPNilMiddleware(t *testing.T) { router.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "test" { - t.Error("Nil globalMiddleware works") + t.Error("Nil middleware works") } } @@ -399,7 +399,7 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { router.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "m1y" { - t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) + t.Errorf("Use middleware error %s", string(ctx.Response.Body())) } ctx = buildFastHTTPRequestContext(http.MethodGet, "/x/x") @@ -407,7 +407,7 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { router.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "m1m2x" { - t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) + t.Errorf("Use middleware error %s", string(ctx.Response.Body())) } } @@ -422,10 +422,10 @@ func TestFastHTTPTreeOrphanMiddlewareOrder(t *testing.T) { } }) - // Method global globalMiddleware + // Method global middleware router.USE(http.MethodGet, "/", mockFastHTTPMiddleware("m1->")) router.USE(http.MethodGet, "/", mockFastHTTPMiddleware("m2->")) - // Path globalMiddleware + // Path middleware router.USE(http.MethodGet, "/x", mockFastHTTPMiddleware("mx1->")) router.USE(http.MethodGet, "/x", mockFastHTTPMiddleware("mx2->")) router.USE(http.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy1->")) @@ -440,7 +440,7 @@ func TestFastHTTPTreeOrphanMiddlewareOrder(t *testing.T) { router.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "m1->m2->mx1->mx2->mxy1->mxy2->mparam1->mparam2->mxy3->mxy4->handler" { - t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) + t.Errorf("Use middleware error %s", string(ctx.Response.Body())) } } @@ -462,7 +462,7 @@ func TestFastHTTPNodeApplyMiddlewareStatic(t *testing.T) { router.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "m1x" { - t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) + t.Errorf("Use middleware error %s", string(ctx.Response.Body())) } } @@ -485,7 +485,7 @@ func TestFastHTTPNodeApplyMiddlewareInvalidNodeReference(t *testing.T) { router.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "y" { - t.Errorf("Use globalMiddleware error %s", string(ctx.Response.Body())) + t.Errorf("Use middleware error %s", string(ctx.Response.Body())) } } @@ -631,6 +631,6 @@ func TestFastHTTPMountSubRouter(t *testing.T) { mainRouter.HandleFastHTTP(ctx) if string(ctx.Response.Body()) != "[rg1][rg2][r1][r2][sg1][sg2][s1][s2][s]" { - t.Errorf("Router mount sub router globalMiddleware error: %s", string(ctx.Response.Body())) + t.Errorf("Router mount sub router middleware error: %s", string(ctx.Response.Body())) } } diff --git a/nethttp.go b/nethttp.go index 5999c31..17a872c 100644 --- a/nethttp.go +++ b/nethttp.go @@ -129,17 +129,16 @@ func (r *router) ServeFiles(fs http.FileSystem, root string, strip bool) { func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { if route, params, subPath := r.tree.MatchRoute(req.Method + req.URL.Path); route != nil { - allMiddleware := r.globalMiddleware - if treeMiddleware := r.tree.MatchMiddleware(req.Method + req.URL.Path); treeMiddleware != nil && len(treeMiddleware) > 0 { - allMiddleware = allMiddleware.Merge(treeMiddleware.Sort()) - } - var h http.Handler - if len(allMiddleware) > 0 { + if r.middlewareCounter > 0 { + allMiddleware := r.globalMiddleware + if treeMiddleware := r.tree.MatchMiddleware(req.Method + req.URL.Path); treeMiddleware != nil && len(treeMiddleware) > 0 { + allMiddleware = allMiddleware.Merge(treeMiddleware.Sort()) + } + computedHandler := allMiddleware.Compose(route.Handler()) - - h = computedHandler.(http.Handler) + h = computedHandler.(http.Handler) } else { h = route.Handler().(http.Handler) } diff --git a/nethttp_test.go b/nethttp_test.go index 0b47920..abce73b 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -363,7 +363,7 @@ func TestNilMiddleware(t *testing.T) { router.ServeHTTP(w, req) if w.Body.String() != "test" { - t.Error("Nil globalMiddleware works") + t.Error("Nil middleware works") } } @@ -429,7 +429,7 @@ func TestNodeApplyMiddleware(t *testing.T) { router.ServeHTTP(w, req) if w.Body.String() != "m1y" { - t.Errorf("Use globalMiddleware error %s", w.Body.String()) + t.Errorf("Use middleware error %s", w.Body.String()) } w = httptest.NewRecorder() @@ -441,7 +441,7 @@ func TestNodeApplyMiddleware(t *testing.T) { router.ServeHTTP(w, req) if w.Body.String() != "m1m2x" { - t.Errorf("Use globalMiddleware error %s", w.Body.String()) + t.Errorf("Use middleware error %s", w.Body.String()) } } @@ -456,10 +456,10 @@ func TestTreeOrphanMiddlewareOrder(t *testing.T) { } })) - // Method global globalMiddleware + // Method global middleware router.USE(http.MethodGet, "/", mockMiddleware("m1->")) router.USE(http.MethodGet, "/", mockMiddleware("m2->")) - // Path globalMiddleware + // Path middleware router.USE(http.MethodGet, "/x", mockMiddleware("mx1->")) router.USE(http.MethodGet, "/x", mockMiddleware("mx2->")) router.USE(http.MethodGet, "/x/y", mockMiddleware("mxy1->")) @@ -478,7 +478,7 @@ func TestTreeOrphanMiddlewareOrder(t *testing.T) { router.ServeHTTP(w, req) if w.Body.String() != "m1->m2->mx1->mx2->mxy1->mxy2->mparam1->mparam2->mxy3->mxy4->handler" { - t.Errorf("Use globalMiddleware error %s", w.Body.String()) + t.Errorf("Use middleware error %s", w.Body.String()) } } @@ -504,7 +504,7 @@ func TestNodeApplyMiddlewareStatic(t *testing.T) { router.ServeHTTP(w, req) if w.Body.String() != "m1x" { - t.Errorf("Use globalMiddleware error %s", w.Body.String()) + t.Errorf("Use middleware error %s", w.Body.String()) } } @@ -535,7 +535,7 @@ func TestNodeApplyMiddlewareInvalidNodeReference(t *testing.T) { router.ServeHTTP(w, req) if w.Body.String() != "y" { - t.Errorf("Use globalMiddleware error %s", w.Body.String()) + t.Errorf("Use middleware error %s", w.Body.String()) } } @@ -705,6 +705,6 @@ func TestMountSubRouter(t *testing.T) { mainRouter.ServeHTTP(w, req) if w.Body.String() != "[rg1][rg2][r1][r2][sg1][sg2][s1][s2][s]" { - t.Errorf("Router mount sub router globalMiddleware error: %s", w.Body.String()) + t.Errorf("Router mount subrouter middleware error: %s", w.Body.String()) } } diff --git a/router.go b/router.go index fa4766e..788c4c4 100644 --- a/router.go +++ b/router.go @@ -6,10 +6,10 @@ import ( "github.com/valyala/fasthttp" ) -// MiddlewareFunc is a http globalMiddleware function type +// MiddlewareFunc is a http middleware function type type MiddlewareFunc func(http.Handler) http.Handler -// FastHTTPMiddlewareFunc is a fasthttp globalMiddleware function type +// FastHTTPMiddlewareFunc is a fasthttp middleware function type type FastHTTPMiddlewareFunc func(fasthttp.RequestHandler) fasthttp.RequestHandler // Router is a micro framework, HTTP request router, multiplexer, mux @@ -53,7 +53,7 @@ type Router interface { // under TRACE method and given patter TRACE(pattern string, handler http.Handler) - // USE adds globalMiddleware functions ([]MiddlewareFunc) + // USE adds middleware functions ([]MiddlewareFunc) // to whole router branch under given method and patter USE(method, pattern string, fs ...MiddlewareFunc) @@ -125,7 +125,7 @@ type FastHTTPRouter interface { // under TRACE method and given patter TRACE(pattern string, handler fasthttp.RequestHandler) - // USE adds globalMiddleware functions ([]MiddlewareFunc) + // USE adds middleware functions ([]MiddlewareFunc) // to whole router branch under given method and patter USE(method, pattern string, fs ...FastHTTPMiddlewareFunc) From 1334191adcbed5f5298737f659fd9211cbb81620 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Wed, 29 Jan 2020 17:40:18 +1100 Subject: [PATCH 35/41] Lint code --- middleware/middleware.go | 2 +- middleware/middleware_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/middleware/middleware.go b/middleware/middleware.go index 8459b28..1882752 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -40,6 +40,6 @@ func (m *sortableMiddleware) Priority() uint { func WithPriority(middleware Middleware, priority uint) Middleware { return &sortableMiddleware{ Middleware: middleware, - priority: priority, + priority: priority, } } diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 6591ec0..b417d0c 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -8,7 +8,7 @@ func TestMiddleware_WithPriority(t *testing.T) { type test struct { name string middleware Middleware - priority uint + priority uint } tests := []test{ {"Zero", mockMiddleware("Zero"), 0}, From d551eebbddd4f9092bda52719a4dd3f09e7db7ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Wed, 29 Jan 2020 17:44:25 +1100 Subject: [PATCH 36/41] Remove unnecessary check --- fasthttp.go | 2 +- nethttp.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fasthttp.go b/fasthttp.go index 7de68c3..90fc35c 100644 --- a/fasthttp.go +++ b/fasthttp.go @@ -131,7 +131,7 @@ func (r *fastHTTPRouter) HandleFastHTTP(ctx *fasthttp.RequestCtx) { var h fasthttp.RequestHandler if r.middlewareCounter > 0 { allMiddleware := r.globalMiddleware - if treeMiddleware := r.tree.MatchMiddleware(method + path); treeMiddleware != nil && len(treeMiddleware) > 0 { + if treeMiddleware := r.tree.MatchMiddleware(method + path); len(treeMiddleware) > 0 { allMiddleware = allMiddleware.Merge(treeMiddleware.Sort()) } diff --git a/nethttp.go b/nethttp.go index 17a872c..48fe980 100644 --- a/nethttp.go +++ b/nethttp.go @@ -132,7 +132,7 @@ func (r *router) ServeHTTP(w http.ResponseWriter, req *http.Request) { var h http.Handler if r.middlewareCounter > 0 { allMiddleware := r.globalMiddleware - if treeMiddleware := r.tree.MatchMiddleware(req.Method + req.URL.Path); treeMiddleware != nil && len(treeMiddleware) > 0 { + if treeMiddleware := r.tree.MatchMiddleware(req.Method + req.URL.Path); len(treeMiddleware) > 0 { allMiddleware = allMiddleware.Merge(treeMiddleware.Sort()) } From f5d525d7da13790370735ce08021903fa030088f Mon Sep 17 00:00:00 2001 From: mar1n3r0 Date: Wed, 29 Jan 2020 12:29:44 +0200 Subject: [PATCH 37/41] Add node tests. TestNewNode added. --- mux/node_test.go | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 mux/node_test.go diff --git a/mux/node_test.go b/mux/node_test.go new file mode 100644 index 0000000..5645bd7 --- /dev/null +++ b/mux/node_test.go @@ -0,0 +1,35 @@ +package mux + +import ( + "testing" +) + +func TestNewNode(t *testing.T) { + node := NewNode("lang", 0) + + switch node := node.(type) { + case *regexpNode: + t.Fatalf("Expecting: *mux.staticNode. Wrong node type: %T\n", node) + case *wildcardNode: + t.Fatalf("Expecting: *mux.staticNode. Wrong node type: %T\n", node) + } + + node = NewNode("{lang:en|pl}", 0) + + switch node := node.(type) { + case *staticNode: + t.Fatalf("Expecting: *mux.staticNode. Wrong node type: %T\n", node) + case *wildcardNode: + t.Fatalf("Expecting: *mux.staticNode. Wrong node type: %T\n", node) + } + + node = NewNode("{lang}", 0) + + switch node := node.(type) { + case *staticNode: + t.Fatalf("Expecting: *mux.staticNode. Wrong node type: %T\n", node) + case *regexpNode: + t.Fatalf("Expecting: *mux.staticNode. Wrong node type: %T\n", node) + } + +} From 78338eb135a00e1e3603cf0ec690eacd4c7e0fc8 Mon Sep 17 00:00:00 2001 From: mar1n3r0 Date: Wed, 29 Jan 2020 15:42:16 +0200 Subject: [PATCH 38/41] Align BenchmarkMux with TestTreeMatch since we can not use routes in tests without request handlers --- mux/benchmark_test.go | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/mux/benchmark_test.go b/mux/benchmark_test.go index fa32acb..6d6d9df 100644 --- a/mux/benchmark_test.go +++ b/mux/benchmark_test.go @@ -37,14 +37,10 @@ func BenchmarkMux(b *testing.B) { b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - route, _, _ := root.Tree().MatchRoute("pl/blog/comments/123/new") + _, _, subPath := root.Tree().MatchRoute("pl/blog/comments/123/new") - if route == nil { - b.Fatalf("%v", route) - } - - if route != commentNew.Route() { - b.Fatalf("%s != %s (%s)", route, commentNew.Route(), commentNew.Name()) + if subPath != "" { + b.Fatalf("%s != %s (%s)", subPath, "pl/blog/comments/123/new", commentNew.Name()) } } }) From 1cf98aaaf605c4855ba293dca3976458494c85dc Mon Sep 17 00:00:00 2001 From: mar1n3r0 Date: Wed, 29 Jan 2020 22:09:12 +0200 Subject: [PATCH 39/41] Refactor and optimize method and mocks tests --- .gitignore | 4 +- fasthttp_test.go | 139 ++++++++++++++--------------------------- mocks_test.go | 45 ++++++++++++-- nethttp_test.go | 157 +++++++++++++++++------------------------------ 4 files changed, 145 insertions(+), 200 deletions(-) diff --git a/.gitignore b/.gitignore index f7e9dd9..b276d24 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,6 @@ .vscode .idea -vendor/ \ No newline at end of file +vendor/ + +.history/ \ No newline at end of file diff --git a/fasthttp_test.go b/fasthttp_test.go index b4b11cb..0b7e9c4 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -3,6 +3,7 @@ package gorouter import ( "fmt" "net/http" + "reflect" "strings" "testing" @@ -18,19 +19,52 @@ func buildFastHTTPRequestContext(method, path string) *fasthttp.RequestCtx { return ctx } -func testBasicFastHTTPMethod(t *testing.T, router *fastHTTPRouter, h func(pattern string, handler fasthttp.RequestHandler), method string) { - handler := &mockHandler{} - h("/x/y", handler.HandleFastHTTP) +func TestBasicFastHTTPMethod(t *testing.T) { + for _, m := range CreateHTTPMethodsMap() { + if m == "OPTIONS" { + router := NewFastHTTPRouter().(*fastHTTPRouter) - checkIfHasRootRoute(t, router, method) + handler := &mockHandler{} + router.GET("/x/y", handler.HandleFastHTTP) + router.POST("/x/y", handler.HandleFastHTTP) - err := mockHandleFastHTTP(router.HandleFastHTTP, method, "/x/y") - if err != nil { - t.Fatal(err) - } + TestIfHasRootRoute(t) - if handler.served != true { - t.Error("Handler has not been served") + ctx := buildFastHTTPRequestContext(http.MethodOptions, "*") + + router.HandleFastHTTP(ctx) + + if allow := string(ctx.Response.Header.Peek("Allow")); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { + t.Errorf("Allow header incorrect value: %s", allow) + } + + ctx2 := buildFastHTTPRequestContext(http.MethodOptions, "/x/y") + + router.HandleFastHTTP(ctx2) + + if allow := string(ctx.Response.Header.Peek("Allow")); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { + t.Errorf("Allow header incorrect value: %s", allow) + } + } + + handler := &mockHandler{} + router := NewFastHTTPRouter().(*fastHTTPRouter) + + in := make([]reflect.Value, 2) + in[0] = reflect.ValueOf("/x/y") + in[1] = reflect.ValueOf(handler.HandleFastHTTP) + reflect.ValueOf(router).MethodByName(m).Call(in) + + TestIfHasRootRoute(t) + + err := mockHandleFastHTTP(router.HandleFastHTTP, m, "/x/y") + if err != nil { + t.Fatal(err) + } + + if handler.served != true { + t.Error("Handler has not been served") + } } } @@ -45,7 +79,7 @@ func TestFastHTTPHandle(t *testing.T) { router := NewFastHTTPRouter().(*fastHTTPRouter) router.Handle(http.MethodPost, "/x/y", handler.HandleFastHTTP) - checkIfHasRootRoute(t, router, http.MethodPost) + TestIfHasRootRoute(t) err := mockHandleFastHTTP(router.HandleFastHTTP, http.MethodPost, "/x/y") if err != nil { @@ -57,89 +91,10 @@ func TestFastHTTPHandle(t *testing.T) { } } -func TestFastHTTPPOST(t *testing.T) { - t.Parallel() - - router := NewFastHTTPRouter().(*fastHTTPRouter) - testBasicFastHTTPMethod(t, router, router.POST, http.MethodPost) -} - -func TestFastHTTPGET(t *testing.T) { - t.Parallel() - - router := NewFastHTTPRouter().(*fastHTTPRouter) - testBasicFastHTTPMethod(t, router, router.GET, http.MethodGet) -} - -func TestFastHTTPPUT(t *testing.T) { - t.Parallel() - - router := NewFastHTTPRouter().(*fastHTTPRouter) - testBasicFastHTTPMethod(t, router, router.PUT, http.MethodPut) -} - -func TestFastHTTPDELETE(t *testing.T) { - t.Parallel() - - router := NewFastHTTPRouter().(*fastHTTPRouter) - testBasicFastHTTPMethod(t, router, router.DELETE, http.MethodDelete) -} - -func TestFastHTTPPATCH(t *testing.T) { - t.Parallel() - - router := NewFastHTTPRouter().(*fastHTTPRouter) - testBasicFastHTTPMethod(t, router, router.PATCH, http.MethodPatch) -} - -func TestFastHTTPHEAD(t *testing.T) { +func TestFastHTTPMethods(t *testing.T) { t.Parallel() - router := NewFastHTTPRouter().(*fastHTTPRouter) - testBasicFastHTTPMethod(t, router, router.HEAD, http.MethodHead) -} - -func TestFastHTTPCONNECT(t *testing.T) { - t.Parallel() - - router := NewFastHTTPRouter().(*fastHTTPRouter) - testBasicFastHTTPMethod(t, router, router.CONNECT, http.MethodConnect) -} - -func TestFastHTTPTRACE(t *testing.T) { - t.Parallel() - - router := NewFastHTTPRouter().(*fastHTTPRouter) - testBasicFastHTTPMethod(t, router, router.TRACE, http.MethodTrace) -} - -func TestFastHTTPOPTIONS(t *testing.T) { - t.Parallel() - - router := NewFastHTTPRouter().(*fastHTTPRouter) - testBasicFastHTTPMethod(t, router, router.OPTIONS, http.MethodOptions) - - handler := &mockHandler{} - router.GET("/x/y", handler.HandleFastHTTP) - router.POST("/x/y", handler.HandleFastHTTP) - - checkIfHasRootRoute(t, router, http.MethodGet) - - ctx := buildFastHTTPRequestContext(http.MethodOptions, "*") - - router.HandleFastHTTP(ctx) - - if allow := string(ctx.Response.Header.Peek("Allow")); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { - t.Errorf("Allow header incorrect value: %s", allow) - } - - ctx2 := buildFastHTTPRequestContext(http.MethodOptions, "/x/y") - - router.HandleFastHTTP(ctx2) - - if allow := string(ctx.Response.Header.Peek("Allow")); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { - t.Errorf("Allow header incorrect value: %s", allow) - } + TestBasicFastHTTPMethod(t) } func TestFastHTTPNotFound(t *testing.T) { diff --git a/mocks_test.go b/mocks_test.go index 4d586df..3c16450 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -88,14 +88,49 @@ func mockHandleFastHTTP(h fasthttp.RequestHandler, method, path string) error { return nil } -func checkIfHasRootRoute(t *testing.T, r interface{}, method string) { - switch v := r.(type) { - case *router: +func TestIfHasRootRoute(t *testing.T) { + r := routerInterface() + f := fastHTTProuterInterface() + switch v := f.(type) { case *fastHTTPRouter: - if rootRoute := v.tree.Find(method); rootRoute == nil { - t.Error("Route not found") + if rootRoute := v.tree.Find(fasthttp.MethodPost); rootRoute == nil { + switch v := r.(type) { + case *router: + if rootRoute := v.tree.Find(fasthttp.MethodPost); rootRoute == nil { + t.Error("Route not found") + } + } } default: t.Error("Unsupported type") } } + +func routerInterface() interface{} { + handler := &mockHandler{} + router := New().(*router) + router.POST("/x/y", handler) + return router +} + +func fastHTTProuterInterface() interface{} { + handler := &mockHandler{} + router := NewFastHTTPRouter().(*fastHTTPRouter) + router.POST("/x/y", handler.HandleFastHTTP) + return router +} + +func CreateHTTPMethodsMap() []string { + m := []string{ + fasthttp.MethodPost, + fasthttp.MethodGet, + fasthttp.MethodPut, + fasthttp.MethodDelete, + fasthttp.MethodPatch, + fasthttp.MethodHead, + fasthttp.MethodConnect, + fasthttp.MethodTrace, + fasthttp.MethodOptions, + } + return m +} diff --git a/nethttp_test.go b/nethttp_test.go index abce73b..24d835f 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -3,25 +3,67 @@ package gorouter import ( "net/http" "net/http/httptest" + "reflect" "strings" "testing" "github.com/vardius/gorouter/v4/context" ) -func testBasicMethod(t *testing.T, router *router, h func(pattern string, handler http.Handler), method string) { - handler := &mockHandler{} - h("/x/y", handler) +func TestBasicMethod(t *testing.T) { + for _, m := range CreateHTTPMethodsMap() { + if m == "OPTIONS" { + router := New().(*router) + handler := &mockHandler{} + router.GET("/x/y", handler) + router.POST("/x/y", handler) - checkIfHasRootRoute(t, router, method) + TestIfHasRootRoute(t) - err := mockServeHTTP(router, method, "/x/y") - if err != nil { - t.Fatal(err) - } + w := httptest.NewRecorder() - if handler.served != true { - t.Error("Handler has not been served") + // test all tree "*" paths + req, err := http.NewRequest(http.MethodOptions, "*", nil) + if err != nil { + t.Fatal(err) + } + + router.ServeHTTP(w, req) + + if allow := w.Header().Get("Allow"); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { + t.Errorf("Allow header incorrect value: %s", allow) + } + + // test specific path + req, err = http.NewRequest(http.MethodOptions, "/x/y", nil) + if err != nil { + t.Fatal(err) + } + + router.ServeHTTP(w, req) + + if allow := w.Header().Get("Allow"); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { + t.Errorf("Allow header incorrect value: %s", allow) + } + } + handler := &mockHandler{} + router := New().(*router) + + in := make([]reflect.Value, 2) + in[0] = reflect.ValueOf("/x/y") + in[1] = reflect.ValueOf(handler) + reflect.ValueOf(router).MethodByName(m).Call(in) + + TestIfHasRootRoute(t) + + err := mockServeHTTP(router, m, "/x/y") + if err != nil { + t.Fatal(err) + } + + if handler.served != true { + t.Error("Handler has not been served") + } } } @@ -36,7 +78,7 @@ func TestHandle(t *testing.T) { router := New().(*router) router.Handle(http.MethodPost, "/x/y", handler) - checkIfHasRootRoute(t, router, http.MethodPost) + TestIfHasRootRoute(t) err := mockServeHTTP(router, http.MethodPost, "/x/y") if err != nil { @@ -48,99 +90,10 @@ func TestHandle(t *testing.T) { } } -func TestPOST(t *testing.T) { - t.Parallel() - - router := New().(*router) - testBasicMethod(t, router, router.POST, http.MethodPost) -} - -func TestGET(t *testing.T) { - t.Parallel() - - router := New().(*router) - testBasicMethod(t, router, router.GET, http.MethodGet) -} - -func TestPUT(t *testing.T) { - t.Parallel() - - router := New().(*router) - testBasicMethod(t, router, router.PUT, http.MethodPut) -} - -func TestDELETE(t *testing.T) { - t.Parallel() - - router := New().(*router) - testBasicMethod(t, router, router.DELETE, http.MethodDelete) -} - -func TestPATCH(t *testing.T) { - t.Parallel() - - router := New().(*router) - testBasicMethod(t, router, router.PATCH, http.MethodPatch) -} - -func TestHEAD(t *testing.T) { +func TestMethods(t *testing.T) { t.Parallel() - router := New().(*router) - testBasicMethod(t, router, router.HEAD, http.MethodHead) -} - -func TestCONNECT(t *testing.T) { - t.Parallel() - - router := New().(*router) - testBasicMethod(t, router, router.CONNECT, http.MethodConnect) -} - -func TestTRACE(t *testing.T) { - t.Parallel() - - router := New().(*router) - testBasicMethod(t, router, router.TRACE, http.MethodTrace) -} - -func TestOPTIONS(t *testing.T) { - t.Parallel() - - router := New().(*router) - testBasicMethod(t, router, router.OPTIONS, http.MethodOptions) - - handler := &mockHandler{} - router.GET("/x/y", handler) - router.POST("/x/y", handler) - - checkIfHasRootRoute(t, router, http.MethodGet) - - w := httptest.NewRecorder() - - // test all tree "*" paths - req, err := http.NewRequest(http.MethodOptions, "*", nil) - if err != nil { - t.Fatal(err) - } - - router.ServeHTTP(w, req) - - if allow := w.Header().Get("Allow"); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { - t.Errorf("Allow header incorrect value: %s", allow) - } - - // test specific path - req, err = http.NewRequest(http.MethodOptions, "/x/y", nil) - if err != nil { - t.Fatal(err) - } - - router.ServeHTTP(w, req) - - if allow := w.Header().Get("Allow"); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { - t.Errorf("Allow header incorrect value: %s", allow) - } + TestBasicMethod(t) } func TestNotFound(t *testing.T) { From 966b1f9b7007672de6845e658f8ec403ffa3a3e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Thu, 30 Jan 2020 17:47:28 +1100 Subject: [PATCH 40/41] Remove cross dependencies between http test implementations --- fasthttp_test.go | 188 ++++++++++++++++++++++++----------------------- mocks_test.go | 45 ++---------- nethttp_test.go | 123 ++++++++++++++++--------------- 3 files changed, 164 insertions(+), 192 deletions(-) diff --git a/fasthttp_test.go b/fasthttp_test.go index 0b7e9c4..d84189b 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -2,7 +2,6 @@ package gorouter import ( "fmt" - "net/http" "reflect" "strings" "testing" @@ -19,55 +18,6 @@ func buildFastHTTPRequestContext(method, path string) *fasthttp.RequestCtx { return ctx } -func TestBasicFastHTTPMethod(t *testing.T) { - for _, m := range CreateHTTPMethodsMap() { - if m == "OPTIONS" { - router := NewFastHTTPRouter().(*fastHTTPRouter) - - handler := &mockHandler{} - router.GET("/x/y", handler.HandleFastHTTP) - router.POST("/x/y", handler.HandleFastHTTP) - - TestIfHasRootRoute(t) - - ctx := buildFastHTTPRequestContext(http.MethodOptions, "*") - - router.HandleFastHTTP(ctx) - - if allow := string(ctx.Response.Header.Peek("Allow")); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { - t.Errorf("Allow header incorrect value: %s", allow) - } - - ctx2 := buildFastHTTPRequestContext(http.MethodOptions, "/x/y") - - router.HandleFastHTTP(ctx2) - - if allow := string(ctx.Response.Header.Peek("Allow")); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { - t.Errorf("Allow header incorrect value: %s", allow) - } - } - - handler := &mockHandler{} - router := NewFastHTTPRouter().(*fastHTTPRouter) - - in := make([]reflect.Value, 2) - in[0] = reflect.ValueOf("/x/y") - in[1] = reflect.ValueOf(handler.HandleFastHTTP) - reflect.ValueOf(router).MethodByName(m).Call(in) - - TestIfHasRootRoute(t) - - err := mockHandleFastHTTP(router.HandleFastHTTP, m, "/x/y") - if err != nil { - t.Fatal(err) - } - - if handler.served != true { - t.Error("Handler has not been served") - } - } -} - func TestFastHTTPInterface(t *testing.T) { var _ fasthttp.RequestHandler = NewFastHTTPRouter().HandleFastHTTP } @@ -77,11 +27,11 @@ func TestFastHTTPHandle(t *testing.T) { handler := &mockHandler{} router := NewFastHTTPRouter().(*fastHTTPRouter) - router.Handle(http.MethodPost, "/x/y", handler.HandleFastHTTP) + router.Handle(fasthttp.MethodPost, "/x/y", handler.HandleFastHTTP) - TestIfHasRootRoute(t) + checkIfHasRootRoute(t, router, fasthttp.MethodPost) - err := mockHandleFastHTTP(router.HandleFastHTTP, http.MethodPost, "/x/y") + err := mockHandleFastHTTP(router.HandleFastHTTP, fasthttp.MethodPost, "/x/y") if err != nil { t.Fatal(err) } @@ -91,10 +41,62 @@ func TestFastHTTPHandle(t *testing.T) { } } +func TestFastHTTPOPTIONSHeaders(t *testing.T) { + handler := &mockHandler{} + router := NewFastHTTPRouter().(*fastHTTPRouter) + + router.GET("/x/y", handler.HandleFastHTTP) + router.POST("/x/y", handler.HandleFastHTTP) + + checkIfHasRootRoute(t, router, fasthttp.MethodGet) + + ctx := buildFastHTTPRequestContext(fasthttp.MethodOptions, "*") + + router.HandleFastHTTP(ctx) + + if allow := string(ctx.Response.Header.Peek("Allow")); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { + t.Errorf("Allow header incorrect value: %s", allow) + } + + ctx2 := buildFastHTTPRequestContext(fasthttp.MethodOptions, "/x/y") + + router.HandleFastHTTP(ctx2) + + if allow := string(ctx.Response.Header.Peek("Allow")); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { + t.Errorf("Allow header incorrect value: %s", allow) + } +} + func TestFastHTTPMethods(t *testing.T) { t.Parallel() - TestBasicFastHTTPMethod(t) + for _, method := range []string{ + fasthttp.MethodPost, + fasthttp.MethodGet, + fasthttp.MethodPut, + fasthttp.MethodDelete, + fasthttp.MethodPatch, + fasthttp.MethodHead, + fasthttp.MethodConnect, + fasthttp.MethodTrace, + fasthttp.MethodOptions, + } { + handler := &mockHandler{} + router := NewFastHTTPRouter().(*fastHTTPRouter) + + reflect.ValueOf(router).MethodByName(method).Call([]reflect.Value{reflect.ValueOf("/x/y"), reflect.ValueOf(handler.HandleFastHTTP)}) + + checkIfHasRootRoute(t, router, method) + + err := mockHandleFastHTTP(router.HandleFastHTTP, method, "/x/y") + if err != nil { + t.Fatal(err) + } + + if handler.served != true { + t.Error("Handler has not been served") + } + } } func TestFastHTTPNotFound(t *testing.T) { @@ -105,11 +107,11 @@ func TestFastHTTPNotFound(t *testing.T) { router.GET("/x", handler.HandleFastHTTP) router.GET("/x/y", handler.HandleFastHTTP) - ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/x") + ctx := buildFastHTTPRequestContext(fasthttp.MethodGet, "/x/x") router.HandleFastHTTP(ctx) - if ctx.Response.StatusCode() != http.StatusNotFound { + if ctx.Response.StatusCode() != fasthttp.StatusNotFound { t.Errorf("NotFound error, actual code: %d", ctx.Response.StatusCode()) } @@ -141,11 +143,11 @@ func TestFastHTTPNotAllowed(t *testing.T) { router := NewFastHTTPRouter().(*fastHTTPRouter) router.GET("/x/y", handler.HandleFastHTTP) - ctx := buildFastHTTPRequestContext(http.MethodPost, "/x/y") + ctx := buildFastHTTPRequestContext(fasthttp.MethodPost, "/x/y") router.HandleFastHTTP(ctx) - if ctx.Response.StatusCode() != http.StatusMethodNotAllowed { + if ctx.Response.StatusCode() != fasthttp.StatusMethodNotAllowed { t.Error("NotAllowed doesn't work") } @@ -191,7 +193,7 @@ func TestFastHTTPParam(t *testing.T) { } }) - err := mockHandleFastHTTP(router.HandleFastHTTP, http.MethodGet, "/x/y") + err := mockHandleFastHTTP(router.HandleFastHTTP, fasthttp.MethodGet, "/x/y") if err != nil { t.Fatal(err) } @@ -216,7 +218,7 @@ func TestFastHTTPRegexpParam(t *testing.T) { } }) - err := mockHandleFastHTTP(router.HandleFastHTTP, http.MethodGet, "/x/rxgo") + err := mockHandleFastHTTP(router.HandleFastHTTP, fasthttp.MethodGet, "/x/rxgo") if err != nil { t.Fatal(err) } @@ -259,7 +261,7 @@ func TestFastHTTPServeFiles(t *testing.T) { var ctx fasthttp.RequestCtx var req fasthttp.Request ctx.Init(&req, nil, testLogger{t}) - ctx.Request.Header.SetMethod(http.MethodGet) + ctx.Request.Header.SetMethod(fasthttp.MethodGet) // will serve files from /var/www/static/favicon.ico // because strip 1 value ServeFiles("/var/www/static", 1) // /static/favicon.ico -> /favicon.ico @@ -267,7 +269,7 @@ func TestFastHTTPServeFiles(t *testing.T) { router.HandleFastHTTP(&ctx) - if ctx.Response.StatusCode() != http.StatusNotFound { + if ctx.Response.StatusCode() != fasthttp.StatusNotFound { t.Error("File should not exist") } @@ -291,7 +293,7 @@ func TestFastHTTPNilMiddleware(t *testing.T) { } }) - ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/y") + ctx := buildFastHTTPRequestContext(fasthttp.MethodGet, "/x/y") router.HandleFastHTTP(ctx) @@ -324,7 +326,7 @@ func TestFastHTTPPanicMiddleware(t *testing.T) { panic("test panic recover") }) - err := mockHandleFastHTTP(router.HandleFastHTTP, http.MethodGet, "/x/y") + err := mockHandleFastHTTP(router.HandleFastHTTP, fasthttp.MethodGet, "/x/y") if err != nil { t.Fatal(err) } @@ -346,10 +348,10 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { } }) - router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("m1")) - router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m2")) + router.USE(fasthttp.MethodGet, "/x/{param}", mockFastHTTPMiddleware("m1")) + router.USE(fasthttp.MethodGet, "/x/x", mockFastHTTPMiddleware("m2")) - ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/y") + ctx := buildFastHTTPRequestContext(fasthttp.MethodGet, "/x/y") router.HandleFastHTTP(ctx) @@ -357,7 +359,7 @@ func TestFastHTTPNodeApplyMiddleware(t *testing.T) { t.Errorf("Use middleware error %s", string(ctx.Response.Body())) } - ctx = buildFastHTTPRequestContext(http.MethodGet, "/x/x") + ctx = buildFastHTTPRequestContext(fasthttp.MethodGet, "/x/x") router.HandleFastHTTP(ctx) @@ -378,19 +380,19 @@ func TestFastHTTPTreeOrphanMiddlewareOrder(t *testing.T) { }) // Method global middleware - router.USE(http.MethodGet, "/", mockFastHTTPMiddleware("m1->")) - router.USE(http.MethodGet, "/", mockFastHTTPMiddleware("m2->")) + router.USE(fasthttp.MethodGet, "/", mockFastHTTPMiddleware("m1->")) + router.USE(fasthttp.MethodGet, "/", mockFastHTTPMiddleware("m2->")) // Path middleware - router.USE(http.MethodGet, "/x", mockFastHTTPMiddleware("mx1->")) - router.USE(http.MethodGet, "/x", mockFastHTTPMiddleware("mx2->")) - router.USE(http.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy1->")) - router.USE(http.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy2->")) - router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("mparam1->")) - router.USE(http.MethodGet, "/x/{param}", mockFastHTTPMiddleware("mparam2->")) - router.USE(http.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy3->")) - router.USE(http.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy4->")) + router.USE(fasthttp.MethodGet, "/x", mockFastHTTPMiddleware("mx1->")) + router.USE(fasthttp.MethodGet, "/x", mockFastHTTPMiddleware("mx2->")) + router.USE(fasthttp.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy1->")) + router.USE(fasthttp.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy2->")) + router.USE(fasthttp.MethodGet, "/x/{param}", mockFastHTTPMiddleware("mparam1->")) + router.USE(fasthttp.MethodGet, "/x/{param}", mockFastHTTPMiddleware("mparam2->")) + router.USE(fasthttp.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy3->")) + router.USE(fasthttp.MethodGet, "/x/y", mockFastHTTPMiddleware("mxy4->")) - ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/y") + ctx := buildFastHTTPRequestContext(fasthttp.MethodGet, "/x/y") router.HandleFastHTTP(ctx) @@ -410,9 +412,9 @@ func TestFastHTTPNodeApplyMiddlewareStatic(t *testing.T) { } }) - router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m1")) + router.USE(fasthttp.MethodGet, "/x/x", mockFastHTTPMiddleware("m1")) - ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/x") + ctx := buildFastHTTPRequestContext(fasthttp.MethodGet, "/x/x") router.HandleFastHTTP(ctx) @@ -433,9 +435,9 @@ func TestFastHTTPNodeApplyMiddlewareInvalidNodeReference(t *testing.T) { } }) - router.USE(http.MethodGet, "/x/x", mockFastHTTPMiddleware("m1")) + router.USE(fasthttp.MethodGet, "/x/x", mockFastHTTPMiddleware("m1")) - ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/y") + ctx := buildFastHTTPRequestContext(fasthttp.MethodGet, "/x/y") router.HandleFastHTTP(ctx) @@ -500,7 +502,7 @@ func TestFastHTTPChainCalls(t *testing.T) { }) // //FIRST CALL - err := mockHandleFastHTTP(router.HandleFastHTTP, http.MethodGet, "/users/x/starred") + err := mockHandleFastHTTP(router.HandleFastHTTP, fasthttp.MethodGet, "/users/x/starred") if err != nil { t.Fatal(err) } @@ -511,7 +513,7 @@ func TestFastHTTPChainCalls(t *testing.T) { //SECOND CALL served = false - err = mockHandleFastHTTP(router.HandleFastHTTP, http.MethodGet, "/applications/client_id/tokens") + err = mockHandleFastHTTP(router.HandleFastHTTP, fasthttp.MethodGet, "/applications/client_id/tokens") if err != nil { t.Fatal(err) } @@ -522,7 +524,7 @@ func TestFastHTTPChainCalls(t *testing.T) { //THIRD CALL served = false - err = mockHandleFastHTTP(router.HandleFastHTTP, http.MethodGet, "/applications/client_id/tokens/access_token") + err = mockHandleFastHTTP(router.HandleFastHTTP, fasthttp.MethodGet, "/applications/client_id/tokens/access_token") if err != nil { t.Fatal(err) } @@ -533,7 +535,7 @@ func TestFastHTTPChainCalls(t *testing.T) { //FOURTH CALL served = false - err = mockHandleFastHTTP(router.HandleFastHTTP, http.MethodGet, "/users/user1/received_events") + err = mockHandleFastHTTP(router.HandleFastHTTP, fasthttp.MethodGet, "/users/user1/received_events") if err != nil { t.Fatal(err) } @@ -544,7 +546,7 @@ func TestFastHTTPChainCalls(t *testing.T) { //FIFTH CALL served = false - err = mockHandleFastHTTP(router.HandleFastHTTP, http.MethodGet, "/users/user2/received_events/public") + err = mockHandleFastHTTP(router.HandleFastHTTP, fasthttp.MethodGet, "/users/user2/received_events/public") if err != nil { t.Fatal(err) } @@ -575,13 +577,13 @@ func TestFastHTTPMountSubRouter(t *testing.T) { mainRouter.Mount("/{param}", subRouter.HandleFastHTTP) - mainRouter.USE(http.MethodGet, "/{param}", mockFastHTTPMiddleware("[r1]")) - mainRouter.USE(http.MethodGet, "/{param}", mockFastHTTPMiddleware("[r2]")) + mainRouter.USE(fasthttp.MethodGet, "/{param}", mockFastHTTPMiddleware("[r1]")) + mainRouter.USE(fasthttp.MethodGet, "/{param}", mockFastHTTPMiddleware("[r2]")) - subRouter.USE(http.MethodGet, "/y", mockFastHTTPMiddleware("[s1]")) - subRouter.USE(http.MethodGet, "/y", mockFastHTTPMiddleware("[s2]")) + subRouter.USE(fasthttp.MethodGet, "/y", mockFastHTTPMiddleware("[s1]")) + subRouter.USE(fasthttp.MethodGet, "/y", mockFastHTTPMiddleware("[s2]")) - ctx := buildFastHTTPRequestContext(http.MethodGet, "/x/y") + ctx := buildFastHTTPRequestContext(fasthttp.MethodGet, "/x/y") mainRouter.HandleFastHTTP(ctx) diff --git a/mocks_test.go b/mocks_test.go index 3c16450..4d586df 100644 --- a/mocks_test.go +++ b/mocks_test.go @@ -88,49 +88,14 @@ func mockHandleFastHTTP(h fasthttp.RequestHandler, method, path string) error { return nil } -func TestIfHasRootRoute(t *testing.T) { - r := routerInterface() - f := fastHTTProuterInterface() - switch v := f.(type) { +func checkIfHasRootRoute(t *testing.T, r interface{}, method string) { + switch v := r.(type) { + case *router: case *fastHTTPRouter: - if rootRoute := v.tree.Find(fasthttp.MethodPost); rootRoute == nil { - switch v := r.(type) { - case *router: - if rootRoute := v.tree.Find(fasthttp.MethodPost); rootRoute == nil { - t.Error("Route not found") - } - } + if rootRoute := v.tree.Find(method); rootRoute == nil { + t.Error("Route not found") } default: t.Error("Unsupported type") } } - -func routerInterface() interface{} { - handler := &mockHandler{} - router := New().(*router) - router.POST("/x/y", handler) - return router -} - -func fastHTTProuterInterface() interface{} { - handler := &mockHandler{} - router := NewFastHTTPRouter().(*fastHTTPRouter) - router.POST("/x/y", handler.HandleFastHTTP) - return router -} - -func CreateHTTPMethodsMap() []string { - m := []string{ - fasthttp.MethodPost, - fasthttp.MethodGet, - fasthttp.MethodPut, - fasthttp.MethodDelete, - fasthttp.MethodPatch, - fasthttp.MethodHead, - fasthttp.MethodConnect, - fasthttp.MethodTrace, - fasthttp.MethodOptions, - } - return m -} diff --git a/nethttp_test.go b/nethttp_test.go index 24d835f..a4ac97c 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -10,63 +10,6 @@ import ( "github.com/vardius/gorouter/v4/context" ) -func TestBasicMethod(t *testing.T) { - for _, m := range CreateHTTPMethodsMap() { - if m == "OPTIONS" { - router := New().(*router) - handler := &mockHandler{} - router.GET("/x/y", handler) - router.POST("/x/y", handler) - - TestIfHasRootRoute(t) - - w := httptest.NewRecorder() - - // test all tree "*" paths - req, err := http.NewRequest(http.MethodOptions, "*", nil) - if err != nil { - t.Fatal(err) - } - - router.ServeHTTP(w, req) - - if allow := w.Header().Get("Allow"); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { - t.Errorf("Allow header incorrect value: %s", allow) - } - - // test specific path - req, err = http.NewRequest(http.MethodOptions, "/x/y", nil) - if err != nil { - t.Fatal(err) - } - - router.ServeHTTP(w, req) - - if allow := w.Header().Get("Allow"); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { - t.Errorf("Allow header incorrect value: %s", allow) - } - } - handler := &mockHandler{} - router := New().(*router) - - in := make([]reflect.Value, 2) - in[0] = reflect.ValueOf("/x/y") - in[1] = reflect.ValueOf(handler) - reflect.ValueOf(router).MethodByName(m).Call(in) - - TestIfHasRootRoute(t) - - err := mockServeHTTP(router, m, "/x/y") - if err != nil { - t.Fatal(err) - } - - if handler.served != true { - t.Error("Handler has not been served") - } - } -} - func TestInterface(t *testing.T) { var _ http.Handler = New() } @@ -78,7 +21,7 @@ func TestHandle(t *testing.T) { router := New().(*router) router.Handle(http.MethodPost, "/x/y", handler) - TestIfHasRootRoute(t) + checkIfHasRootRoute(t, router, http.MethodPost) err := mockServeHTTP(router, http.MethodPost, "/x/y") if err != nil { @@ -90,10 +33,72 @@ func TestHandle(t *testing.T) { } } +func TestOPTIONSHeaders(t *testing.T) { + handler := &mockHandler{} + router := New().(*router) + + router.GET("/x/y", handler) + router.POST("/x/y", handler) + + checkIfHasRootRoute(t, router, http.MethodGet) + + w := httptest.NewRecorder() + + // test all tree "*" paths + req, err := http.NewRequest(http.MethodOptions, "*", nil) + if err != nil { + t.Fatal(err) + } + + router.ServeHTTP(w, req) + + if allow := w.Header().Get("Allow"); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { + t.Errorf("Allow header incorrect value: %s", allow) + } + + // test specific path + req, err = http.NewRequest(http.MethodOptions, "/x/y", nil) + if err != nil { + t.Fatal(err) + } + + router.ServeHTTP(w, req) + + if allow := w.Header().Get("Allow"); !strings.Contains(allow, "POST") || !strings.Contains(allow, "GET") || !strings.Contains(allow, "OPTIONS") { + t.Errorf("Allow header incorrect value: %s", allow) + } +} + func TestMethods(t *testing.T) { t.Parallel() - TestBasicMethod(t) + for _, method := range []string{ + http.MethodPost, + http.MethodGet, + http.MethodPut, + http.MethodDelete, + http.MethodPatch, + http.MethodHead, + http.MethodConnect, + http.MethodTrace, + http.MethodOptions, + } { + handler := &mockHandler{} + router := New().(*router) + + reflect.ValueOf(router).MethodByName(method).Call([]reflect.Value{reflect.ValueOf("/x/y"), reflect.ValueOf(handler)}) + + checkIfHasRootRoute(t, router, method) + + err := mockServeHTTP(router, method, "/x/y") + if err != nil { + t.Fatal(err) + } + + if handler.served != true { + t.Error("Handler has not been served") + } + } } func TestNotFound(t *testing.T) { From 7eb5db7c7b95c40ead10cafc61be9c22be4db776 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Lorenz?= Date: Thu, 30 Jan 2020 17:49:37 +1100 Subject: [PATCH 41/41] Run inner test for methods --- fasthttp_test.go | 26 +++++++++++++++----------- nethttp_test.go | 26 +++++++++++++++----------- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/fasthttp_test.go b/fasthttp_test.go index d84189b..a78f002 100644 --- a/fasthttp_test.go +++ b/fasthttp_test.go @@ -81,21 +81,25 @@ func TestFastHTTPMethods(t *testing.T) { fasthttp.MethodTrace, fasthttp.MethodOptions, } { - handler := &mockHandler{} - router := NewFastHTTPRouter().(*fastHTTPRouter) + t.Run(method, func(t *testing.T) { + t.Parallel() - reflect.ValueOf(router).MethodByName(method).Call([]reflect.Value{reflect.ValueOf("/x/y"), reflect.ValueOf(handler.HandleFastHTTP)}) + handler := &mockHandler{} + router := NewFastHTTPRouter().(*fastHTTPRouter) - checkIfHasRootRoute(t, router, method) + reflect.ValueOf(router).MethodByName(method).Call([]reflect.Value{reflect.ValueOf("/x/y"), reflect.ValueOf(handler.HandleFastHTTP)}) - err := mockHandleFastHTTP(router.HandleFastHTTP, method, "/x/y") - if err != nil { - t.Fatal(err) - } + checkIfHasRootRoute(t, router, method) - if handler.served != true { - t.Error("Handler has not been served") - } + err := mockHandleFastHTTP(router.HandleFastHTTP, method, "/x/y") + if err != nil { + t.Fatal(err) + } + + if handler.served != true { + t.Error("Handler has not been served") + } + }) } } diff --git a/nethttp_test.go b/nethttp_test.go index a4ac97c..6d1b20c 100644 --- a/nethttp_test.go +++ b/nethttp_test.go @@ -83,21 +83,25 @@ func TestMethods(t *testing.T) { http.MethodTrace, http.MethodOptions, } { - handler := &mockHandler{} - router := New().(*router) + t.Run(method, func(t *testing.T) { + t.Parallel() - reflect.ValueOf(router).MethodByName(method).Call([]reflect.Value{reflect.ValueOf("/x/y"), reflect.ValueOf(handler)}) + handler := &mockHandler{} + router := New().(*router) - checkIfHasRootRoute(t, router, method) + reflect.ValueOf(router).MethodByName(method).Call([]reflect.Value{reflect.ValueOf("/x/y"), reflect.ValueOf(handler)}) - err := mockServeHTTP(router, method, "/x/y") - if err != nil { - t.Fatal(err) - } + checkIfHasRootRoute(t, router, method) - if handler.served != true { - t.Error("Handler has not been served") - } + err := mockServeHTTP(router, method, "/x/y") + if err != nil { + t.Fatal(err) + } + + if handler.served != true { + t.Error("Handler has not been served") + } + }) } }