From e34a8aa7fde2de7b801346b3a58bd2b2227830db Mon Sep 17 00:00:00 2001 From: Anton Zadvorny Date: Sun, 7 Jan 2024 06:00:59 +0300 Subject: [PATCH] Add Filter middleware --- query/filter.go | 104 +++++++++++++++++++++++++++++++++++++++++++ query/filter_test.go | 98 ++++++++++++++++++++++++++++++++++++++++ query/response.go | 10 +++++ 3 files changed, 212 insertions(+) create mode 100644 query/filter.go create mode 100644 query/filter_test.go create mode 100644 query/response.go diff --git a/query/filter.go b/query/filter.go new file mode 100644 index 0000000..40e4fbc --- /dev/null +++ b/query/filter.go @@ -0,0 +1,104 @@ +package query + +import ( + "context" + "net/http" + + "go.pkg.cx/middleware" +) + +// Context keys +var ( + FilterCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/query", Name: "Filter"} +) + +// FilterOptions turns a list of option instances into an option +func FilterOptions[T any](opts ...FilterOption[T]) FilterOption[T] { + return func(f *filter[T]) { + for _, opt := range opts { + opt(f) + } + } +} + +// FilterOption configures filter middleware +type FilterOption[T any] func(f *filter[T]) + +// WithParseFilterFn adds filter parse function to the list +func WithParseFilterFn[T any](fn func(r *http.Request) (T, error)) FilterOption[T] { + return func(f *filter[T]) { + f.parseFilterFns = append(f.parseFilterFns, fn) + } +} + +// SetFilterDefault sets filter default value +func SetFilterDefault[T any](defaultValue T) FilterOption[T] { + return func(f *filter[T]) { + f.defaultFilter = defaultValue + } +} + +// SetResponseHandler sets filter response handler +func SetResponseHandler[T any](fn middleware.ResponseHandle) FilterOption[T] { + return func(f *filter[T]) { + f.responseHandler = fn + } +} + +type filter[T any] struct { + defaultFilter T + parseFilterFns []func(r *http.Request) (T, error) + responseHandler middleware.ResponseHandle +} + +// Filter returns filter middleware +func Filter[T any](opts ...FilterOption[T]) func(next http.Handler) http.Handler { + f := &filter[T]{ + defaultFilter: *new(T), + parseFilterFns: []func(r *http.Request) (T, error){}, + responseHandler: RespondWithBadRequest, + } + + for _, opt := range opts { + opt(f) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + filter := f.defaultFilter + + var err error + for _, fn := range f.parseFilterFns { + if filter, err = fn(r); err == nil { + break + } + } + + if err != nil { + f.responseHandler(w, r, err) + return + } + + ctx := r.Context() + ctx = context.WithValue(ctx, FilterCtxKey, filter) + + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// FilterFromContext returns filter from context +func FilterFromContext[T any](ctx context.Context) T { + if filter, ok := ctx.Value(FilterCtxKey).(T); ok { + return filter + } + + return *new(T) +} + +// FilterFromQuery returns filter string from query params +func FilterFromQuery(filterParam string) func(r *http.Request) string { + return func(r *http.Request) string { + return r.URL.Query().Get(filterParam) + } +} diff --git a/query/filter_test.go b/query/filter_test.go new file mode 100644 index 0000000..108e505 --- /dev/null +++ b/query/filter_test.go @@ -0,0 +1,98 @@ +package query + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFilterFromQuery(t *testing.T) { + req, err := http.NewRequest("GET", "/?filter=foo", http.NoBody) + assert.NoError(t, err) + + filter := FilterFromQuery("filter")(req) + assert.Equal(t, "foo", filter) +} + +func TestFilterFromContext(t *testing.T) { + ctx := context.Background() + ctx = context.WithValue(ctx, FilterCtxKey, "foo") + + assert.Equal(t, 0, FilterFromContext[int](ctx)) + assert.Equal(t, "foo", FilterFromContext[string](ctx)) +} + +func TestFilterNoOptions(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + filter := FilterFromContext[int](r.Context()) + assert.NotNil(t, filter) + assert.Equal(t, 0, filter) + }) + + middleware := Filter[int]()(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", http.NoBody) + middleware.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestFilterDefault(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + filter := FilterFromContext[string](r.Context()) + assert.NotNil(t, filter) + assert.Equal(t, "foo", filter) + }) + + middleware := Filter(SetFilterDefault("foo"))(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/", http.NoBody) + middleware.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestFilterValid(t *testing.T) { + parseFilter := func(r *http.Request) (string, error) { + return r.URL.Query().Get("filter"), nil + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + filter := FilterFromContext[string](r.Context()) + assert.NotNil(t, filter) + assert.Equal(t, "foo", filter) + }) + + middleware := Filter(WithParseFilterFn(parseFilter))(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/?filter=foo", http.NoBody) + middleware.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) +} + +func TestFilterInvalid(t *testing.T) { + parseFilter := func(r *http.Request) (string, error) { + return "", errors.New("invalid filter") + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + filter := FilterFromContext[string](r.Context()) + assert.Nil(t, filter) + }) + + middleware := Filter(WithParseFilterFn(parseFilter))(handler) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/?filter=foo", http.NoBody) + middleware.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) +} diff --git a/query/response.go b/query/response.go new file mode 100644 index 0000000..52695a0 --- /dev/null +++ b/query/response.go @@ -0,0 +1,10 @@ +package query + +import ( + "net/http" +) + +func RespondWithBadRequest(w http.ResponseWriter, _ *http.Request, err error) { + text := http.StatusText(http.StatusBadRequest) + ": " + err.Error() + http.Error(w, text, http.StatusBadRequest) +}