Skip to content

Commit

Permalink
Merge pull request #169 from monzo/add-router-endpoint-pattern
Browse files Browse the repository at this point in the history
Add router endpoint pattern method to Typhon Request
  • Loading branch information
RohanPadmanabhan authored Dec 4, 2023
2 parents ea17ae8 + fc0f840 commit ab6ccb7
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 4 deletions.
12 changes: 11 additions & 1 deletion request.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (r *Request) BodyBytes(consume bool) ([]byte, error) {
// Send round-trips the request via the default Client. It does not block, instead returning a ResponseFuture
// representing the asynchronous operation to produce the response. It is equivalent to:
//
// r.SendVia(Client)
// r.SendVia(Client)
func (r Request) Send() *ResponseFuture {
return Send(r)
}
Expand Down Expand Up @@ -213,6 +213,16 @@ func (r Request) ResponseWithCode(body interface{}, statusCode int) Response {
return rsp
}

// RequestPathPattern finds the router entry pattern that matches the request
func (r Request) RequestPathPattern() string {
return routerPathPatternForRequest(r)
}

// RequestMethod returns the HTTP method of the request
func (r Request) RequestMethod() string {
return r.Method
}

func (r Request) String() string {
if r.URL == nil {
return "Request(Unknown)"
Expand Down
22 changes: 22 additions & 0 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"io/ioutil"
"math"
"net/http"
"strings"
"testing"

Expand Down Expand Up @@ -223,6 +224,27 @@ func TestRequestSetMetadata(t *testing.T) {
assert.Equal(t, []string{"data"}, req.Request.Header["meta"])
}

func TestRouterEndpointPattern(t *testing.T) {
req := NewRequest(context.Background(), http.MethodGet, "/foo/some-url-identifier", nil)
assert.Equal(t, "", req.RequestPathPattern()) // should be empty if request has not been served by a router

router := Router{}
routerEndpointPattern := "/foo/:id"
router.GET(routerEndpointPattern, func(req Request) Response {
// as we are currently serving the request, we should be able to get the router endpoint pattern
assert.Equal(t, routerEndpointPattern, req.RequestPathPattern())
return req.Response(nil)
})

rsp := req.SendVia(router.Serve()).Response()
require.NoError(t, rsp.Error) // check we didn't get a "route not found" error
}

func TestRequestMethod(t *testing.T) {
req := NewRequest(context.Background(), http.MethodGet, "", nil)
assert.Equal(t, http.MethodGet, req.RequestMethod())
}

func jsonStreamMarshal(v interface{}) ([]byte, error) {
var buffer bytes.Buffer
writer := bufio.NewWriter(&buffer)
Expand Down
16 changes: 13 additions & 3 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ import (
// directly means we'd get a collision with any other package that does the same.
// https://play.golang.org/p/MxhRiL37R-9
type routerContextKeyType struct{}
type routerRequestPatternContextKeyType struct{}

var (
routerContextKey = routerContextKeyType{}
routerComponentsRe = regexp.MustCompile(`(?:^|/)(\*\w*|:\w+)`)
routerContextKey = routerContextKeyType{}
routerRequestPatternContextKey = routerRequestPatternContextKeyType{}
routerComponentsRe = regexp.MustCompile(`(?:^|/)(\*\w*|:\w+)`)
)

type routerEntry struct {
Expand Down Expand Up @@ -44,6 +46,13 @@ func RouterForRequest(r Request) *Router {
return nil
}

func routerPathPatternForRequest(r Request) string {
if v := r.Context.Value(routerRequestPatternContextKey); v != nil {
return v.(string)
}
return ""
}

func (r *Router) compile(pattern string) *regexp.Regexp {
re, pos := ``, 0
for _, m := range routerComponentsRe.FindAllStringSubmatchIndex(pattern, -1) {
Expand Down Expand Up @@ -116,14 +125,15 @@ func (r Router) Lookup(method, path string) (Service, string, map[string]string,
// Serve returns a Service which will route inbound requests to the enclosed routes.
func (r Router) Serve() Service {
return func(req Request) Response {
svc, _, ok := r.lookup(req.Method, req.URL.Path, nil)
svc, pathPattern, ok := r.lookup(req.Method, req.URL.Path, nil)
if !ok {
txt := fmt.Sprintf("No handler for %s %s", req.Method, req.URL.Path)
rsp := NewResponse(req)
rsp.Error = terrors.NotFound("no_handler", txt, nil)
return rsp
}
req.Context = context.WithValue(req.Context, routerContextKey, &r)
req.Context = context.WithValue(req.Context, routerRequestPatternContextKey, pathPattern)
rsp := svc(req)
if rsp.Request == nil {
rsp.Request = &req
Expand Down

0 comments on commit ab6ccb7

Please sign in to comment.