Skip to content

Commit

Permalink
Rename function and store path pattern inside context
Browse files Browse the repository at this point in the history
  • Loading branch information
RohanPadmanabhan committed Dec 1, 2023
1 parent 84a107d commit e9f9213
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 23 deletions.
12 changes: 3 additions & 9 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,9 @@ func (r Request) ResponseWithCode(body interface{}, statusCode int) Response {
return rsp
}

// RouterEndpointPattern finds the router pattern that matches the request. This is only callable while the request
// is being served.
func (r Request) RouterEndpointPattern() string {
if router := RouterForRequest(r); router != nil {
if pathPattern := router.Pattern(r); pathPattern != "" {
return pathPattern
}
}
return ""
// RequestPathPattern finds the router entry pattern that matches the request
func (r Request) RequestPathPattern() string {
return routerEntryPathPatternForRequest(r)
}

// RequestMethod returns the HTTP method of the request
Expand Down
4 changes: 2 additions & 2 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,13 @@ func TestRequestSetMetadata(t *testing.T) {

func TestRouterEndpointPattern(t *testing.T) {
req := NewRequest(context.Background(), http.MethodGet, "/foo/some-url-identifier", nil)
assert.Equal(t, "", req.RouterEndpointPattern()) // should be empty if request has not been served by a router
assert.Equal(t, "", req.RequestPathPattern()) // should be empty if request has not been served by a router

router := Router{}
routerEndpointPattern := "/foo/:id"
router.GET(routerEndpointPattern, func(req Request) Response {
// as we are currently serving the request, we should be able to get the router endpoint pattern
assert.Equal(t, routerEndpointPattern, req.RouterEndpointPattern())
assert.Equal(t, routerEndpointPattern, req.RequestPathPattern())
return req.Response(nil)
})

Expand Down
43 changes: 31 additions & 12 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import (
// directly means we'd get a collision with any other package that does the same.
// https://play.golang.org/p/MxhRiL37R-9
type routerContextKeyType struct{}
type routerPathContextKeyType struct{}

var (
routerContextKey = routerContextKeyType{}
routerComponentsRe = regexp.MustCompile(`(?:^|/)(\*\w*|:\w+)`)
routerContextKey = routerContextKeyType{}
routerPathContextKey = routerPathContextKeyType{}
routerComponentsRe = regexp.MustCompile(`(?:^|/)(\*\w*|:\w+)`)
)

type routerEntry struct {
Expand Down Expand Up @@ -44,6 +46,13 @@ func RouterForRequest(r Request) *Router {
return nil
}

func routerEntryPathPatternForRequest(r Request) string {
if v := r.Context.Value(routerPathContextKey); v != nil {
return v.(string)
}
return ""
}

func (r *Router) compile(pattern string) *regexp.Regexp {
re, pos := ``, 0
for _, m := range routerComponentsRe.FindAllStringSubmatchIndex(pattern, -1) {
Expand Down Expand Up @@ -116,14 +125,15 @@ func (r Router) Lookup(method, path string) (Service, string, map[string]string,
// Serve returns a Service which will route inbound requests to the enclosed routes.
func (r Router) Serve() Service {
return func(req Request) Response {
svc, _, ok := r.lookup(req.Method, req.URL.Path, nil)
svc, pathPattern, ok := r.lookup(req.Method, req.URL.Path, nil)
if !ok {
txt := fmt.Sprintf("No handler for %s %s", req.Method, req.URL.Path)
rsp := NewResponse(req)
rsp.Error = terrors.NotFound("no_handler", txt, nil)
return rsp
}
req.Context = context.WithValue(req.Context, routerContextKey, &r)
req.Context = context.WithValue(req.Context, routerPathContextKey, pathPattern)
rsp := svc(req)
if rsp.Request == nil {
rsp.Request = &req
Expand All @@ -147,37 +157,46 @@ func (r Router) Params(req Request) map[string]string {
// Sugar

// GET is shorthand for:
// r.Register("GET", pattern, svc)
//
// r.Register("GET", pattern, svc)
func (r *Router) GET(pattern string, svc Service) { r.Register("GET", pattern, svc) }

// CONNECT is shorthand for:
// r.Register("CONNECT", pattern, svc)
//
// r.Register("CONNECT", pattern, svc)
func (r *Router) CONNECT(pattern string, svc Service) { r.Register("CONNECT", pattern, svc) }

// DELETE is shorthand for:
// r.Register("DELETE", pattern, svc)
//
// r.Register("DELETE", pattern, svc)
func (r *Router) DELETE(pattern string, svc Service) { r.Register("DELETE", pattern, svc) }

// HEAD is shorthand for:
// r.Register("HEAD", pattern, svc)
//
// r.Register("HEAD", pattern, svc)
func (r *Router) HEAD(pattern string, svc Service) { r.Register("HEAD", pattern, svc) }

// OPTIONS is shorthand for:
// r.Register("OPTIONS", pattern, svc)
//
// r.Register("OPTIONS", pattern, svc)
func (r *Router) OPTIONS(pattern string, svc Service) { r.Register("OPTIONS", pattern, svc) }

// PATCH is shorthand for:
// r.Register("PATCH", pattern, svc)
//
// r.Register("PATCH", pattern, svc)
func (r *Router) PATCH(pattern string, svc Service) { r.Register("PATCH", pattern, svc) }

// POST is shorthand for:
// r.Register("POST", pattern, svc)
//
// r.Register("POST", pattern, svc)
func (r *Router) POST(pattern string, svc Service) { r.Register("POST", pattern, svc) }

// PUT is shorthand for:
// r.Register("PUT", pattern, svc)
//
// r.Register("PUT", pattern, svc)
func (r *Router) PUT(pattern string, svc Service) { r.Register("PUT", pattern, svc) }

// TRACE is shorthand for:
// r.Register("TRACE", pattern, svc)
//
// r.Register("TRACE", pattern, svc)
func (r *Router) TRACE(pattern string, svc Service) { r.Register("TRACE", pattern, svc) }

0 comments on commit e9f9213

Please sign in to comment.