diff --git a/middleware.go b/middleware.go new file mode 100644 index 0000000..3aaae11 --- /dev/null +++ b/middleware.go @@ -0,0 +1,19 @@ +package middleware + +import ( + "net/http" +) + +// Wrap converts a list of middlewares to nested calls in reverse order +func Wrap(handler http.Handler, mws ...func(http.Handler) http.Handler) http.Handler { + if len(mws) == 0 { + return handler + } + + res := handler + for i := len(mws) - 1; i >= 0; i-- { + res = mws[i](res) + } + + return res +} diff --git a/middleware_test.go b/middleware_test.go new file mode 100644 index 0000000..7d9e8db --- /dev/null +++ b/middleware_test.go @@ -0,0 +1,36 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWrap(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/something/1/2", r.URL.Path) + }) + + mw1 := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.URL.Path += "/1" + h.ServeHTTP(w, r) + }) + } + mw2 := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.URL.Path += "/2" + h.ServeHTTP(w, r) + }) + } + + server := httptest.NewServer(Wrap(handler, mw1, mw2)) + defer server.Close() + + res, err := http.Get(server.URL + "/something") + require.NoError(t, err) + assert.Equal(t, 200, res.StatusCode) +}