diff --git a/request.go b/request.go index 3ebbe003..c01d2ebe 100644 --- a/request.go +++ b/request.go @@ -213,15 +213,9 @@ func (r Request) ResponseWithCode(body interface{}, statusCode int) Response { return rsp } -// RouterEndpointPattern finds the router pattern that matches the request. This is only callable while the request -// is being served. -func (r Request) RouterEndpointPattern() string { - if router := RouterForRequest(r); router != nil { - if pathPattern := router.Pattern(r); pathPattern != "" { - return pathPattern - } - } - return "" +// RequestPathPattern finds the router entry pattern that matches the request +func (r Request) RequestPathPattern() string { + return routerEntryPathPatternForRequest(r) } // RequestMethod returns the HTTP method of the request diff --git a/request_test.go b/request_test.go index 591b22b4..e83ca296 100644 --- a/request_test.go +++ b/request_test.go @@ -226,13 +226,13 @@ func TestRequestSetMetadata(t *testing.T) { func TestRouterEndpointPattern(t *testing.T) { req := NewRequest(context.Background(), http.MethodGet, "/foo/some-url-identifier", nil) - assert.Equal(t, "", req.RouterEndpointPattern()) // should be empty if request has not been served by a router + assert.Equal(t, "", req.RequestPathPattern()) // should be empty if request has not been served by a router router := Router{} routerEndpointPattern := "/foo/:id" router.GET(routerEndpointPattern, func(req Request) Response { // as we are currently serving the request, we should be able to get the router endpoint pattern - assert.Equal(t, routerEndpointPattern, req.RouterEndpointPattern()) + assert.Equal(t, routerEndpointPattern, req.RequestPathPattern()) return req.Response(nil) }) diff --git a/router.go b/router.go index ff3e3f3e..7746076f 100644 --- a/router.go +++ b/router.go @@ -13,10 +13,12 @@ import ( // directly means we'd get a collision with any other package that does the same. // https://play.golang.org/p/MxhRiL37R-9 type routerContextKeyType struct{} +type routerPathContextKeyType struct{} var ( - routerContextKey = routerContextKeyType{} - routerComponentsRe = regexp.MustCompile(`(?:^|/)(\*\w*|:\w+)`) + routerContextKey = routerContextKeyType{} + routerPathContextKey = routerPathContextKeyType{} + routerComponentsRe = regexp.MustCompile(`(?:^|/)(\*\w*|:\w+)`) ) type routerEntry struct { @@ -44,6 +46,13 @@ func RouterForRequest(r Request) *Router { return nil } +func routerEntryPathPatternForRequest(r Request) string { + if v := r.Context.Value(routerPathContextKey); v != nil { + return v.(string) + } + return "" +} + func (r *Router) compile(pattern string) *regexp.Regexp { re, pos := ``, 0 for _, m := range routerComponentsRe.FindAllStringSubmatchIndex(pattern, -1) { @@ -116,7 +125,7 @@ func (r Router) Lookup(method, path string) (Service, string, map[string]string, // Serve returns a Service which will route inbound requests to the enclosed routes. func (r Router) Serve() Service { return func(req Request) Response { - svc, _, ok := r.lookup(req.Method, req.URL.Path, nil) + svc, pathPattern, ok := r.lookup(req.Method, req.URL.Path, nil) if !ok { txt := fmt.Sprintf("No handler for %s %s", req.Method, req.URL.Path) rsp := NewResponse(req) @@ -124,6 +133,7 @@ func (r Router) Serve() Service { return rsp } req.Context = context.WithValue(req.Context, routerContextKey, &r) + req.Context = context.WithValue(req.Context, routerPathContextKey, pathPattern) rsp := svc(req) if rsp.Request == nil { rsp.Request = &req @@ -147,37 +157,46 @@ func (r Router) Params(req Request) map[string]string { // Sugar // GET is shorthand for: -// r.Register("GET", pattern, svc) +// +// r.Register("GET", pattern, svc) func (r *Router) GET(pattern string, svc Service) { r.Register("GET", pattern, svc) } // CONNECT is shorthand for: -// r.Register("CONNECT", pattern, svc) +// +// r.Register("CONNECT", pattern, svc) func (r *Router) CONNECT(pattern string, svc Service) { r.Register("CONNECT", pattern, svc) } // DELETE is shorthand for: -// r.Register("DELETE", pattern, svc) +// +// r.Register("DELETE", pattern, svc) func (r *Router) DELETE(pattern string, svc Service) { r.Register("DELETE", pattern, svc) } // HEAD is shorthand for: -// r.Register("HEAD", pattern, svc) +// +// r.Register("HEAD", pattern, svc) func (r *Router) HEAD(pattern string, svc Service) { r.Register("HEAD", pattern, svc) } // OPTIONS is shorthand for: -// r.Register("OPTIONS", pattern, svc) +// +// r.Register("OPTIONS", pattern, svc) func (r *Router) OPTIONS(pattern string, svc Service) { r.Register("OPTIONS", pattern, svc) } // PATCH is shorthand for: -// r.Register("PATCH", pattern, svc) +// +// r.Register("PATCH", pattern, svc) func (r *Router) PATCH(pattern string, svc Service) { r.Register("PATCH", pattern, svc) } // POST is shorthand for: -// r.Register("POST", pattern, svc) +// +// r.Register("POST", pattern, svc) func (r *Router) POST(pattern string, svc Service) { r.Register("POST", pattern, svc) } // PUT is shorthand for: -// r.Register("PUT", pattern, svc) +// +// r.Register("PUT", pattern, svc) func (r *Router) PUT(pattern string, svc Service) { r.Register("PUT", pattern, svc) } // TRACE is shorthand for: -// r.Register("TRACE", pattern, svc) +// +// r.Register("TRACE", pattern, svc) func (r *Router) TRACE(pattern string, svc Service) { r.Register("TRACE", pattern, svc) }