Add Filter middleware
This commit is contained in:
parent
a5f83606cd
commit
e34a8aa7fd
104
query/filter.go
Normal file
104
query/filter.go
Normal file
@ -0,0 +1,104 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"go.pkg.cx/middleware"
|
||||
)
|
||||
|
||||
// Context keys
|
||||
var (
|
||||
FilterCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/query", Name: "Filter"}
|
||||
)
|
||||
|
||||
// FilterOptions turns a list of option instances into an option
|
||||
func FilterOptions[T any](opts ...FilterOption[T]) FilterOption[T] {
|
||||
return func(f *filter[T]) {
|
||||
for _, opt := range opts {
|
||||
opt(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FilterOption configures filter middleware
|
||||
type FilterOption[T any] func(f *filter[T])
|
||||
|
||||
// WithParseFilterFn adds filter parse function to the list
|
||||
func WithParseFilterFn[T any](fn func(r *http.Request) (T, error)) FilterOption[T] {
|
||||
return func(f *filter[T]) {
|
||||
f.parseFilterFns = append(f.parseFilterFns, fn)
|
||||
}
|
||||
}
|
||||
|
||||
// SetFilterDefault sets filter default value
|
||||
func SetFilterDefault[T any](defaultValue T) FilterOption[T] {
|
||||
return func(f *filter[T]) {
|
||||
f.defaultFilter = defaultValue
|
||||
}
|
||||
}
|
||||
|
||||
// SetResponseHandler sets filter response handler
|
||||
func SetResponseHandler[T any](fn middleware.ResponseHandle) FilterOption[T] {
|
||||
return func(f *filter[T]) {
|
||||
f.responseHandler = fn
|
||||
}
|
||||
}
|
||||
|
||||
type filter[T any] struct {
|
||||
defaultFilter T
|
||||
parseFilterFns []func(r *http.Request) (T, error)
|
||||
responseHandler middleware.ResponseHandle
|
||||
}
|
||||
|
||||
// Filter returns filter middleware
|
||||
func Filter[T any](opts ...FilterOption[T]) func(next http.Handler) http.Handler {
|
||||
f := &filter[T]{
|
||||
defaultFilter: *new(T),
|
||||
parseFilterFns: []func(r *http.Request) (T, error){},
|
||||
responseHandler: RespondWithBadRequest,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(f)
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
filter := f.defaultFilter
|
||||
|
||||
var err error
|
||||
for _, fn := range f.parseFilterFns {
|
||||
if filter, err = fn(r); err == nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
f.responseHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, FilterCtxKey, filter)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// FilterFromContext returns filter from context
|
||||
func FilterFromContext[T any](ctx context.Context) T {
|
||||
if filter, ok := ctx.Value(FilterCtxKey).(T); ok {
|
||||
return filter
|
||||
}
|
||||
|
||||
return *new(T)
|
||||
}
|
||||
|
||||
// FilterFromQuery returns filter string from query params
|
||||
func FilterFromQuery(filterParam string) func(r *http.Request) string {
|
||||
return func(r *http.Request) string {
|
||||
return r.URL.Query().Get(filterParam)
|
||||
}
|
||||
}
|
98
query/filter_test.go
Normal file
98
query/filter_test.go
Normal file
@ -0,0 +1,98 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFilterFromQuery(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", "/?filter=foo", http.NoBody)
|
||||
assert.NoError(t, err)
|
||||
|
||||
filter := FilterFromQuery("filter")(req)
|
||||
assert.Equal(t, "foo", filter)
|
||||
}
|
||||
|
||||
func TestFilterFromContext(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, FilterCtxKey, "foo")
|
||||
|
||||
assert.Equal(t, 0, FilterFromContext[int](ctx))
|
||||
assert.Equal(t, "foo", FilterFromContext[string](ctx))
|
||||
}
|
||||
|
||||
func TestFilterNoOptions(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
filter := FilterFromContext[int](r.Context())
|
||||
assert.NotNil(t, filter)
|
||||
assert.Equal(t, 0, filter)
|
||||
})
|
||||
|
||||
middleware := Filter[int]()(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", http.NoBody)
|
||||
middleware.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestFilterDefault(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
filter := FilterFromContext[string](r.Context())
|
||||
assert.NotNil(t, filter)
|
||||
assert.Equal(t, "foo", filter)
|
||||
})
|
||||
|
||||
middleware := Filter(SetFilterDefault("foo"))(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/", http.NoBody)
|
||||
middleware.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestFilterValid(t *testing.T) {
|
||||
parseFilter := func(r *http.Request) (string, error) {
|
||||
return r.URL.Query().Get("filter"), nil
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
filter := FilterFromContext[string](r.Context())
|
||||
assert.NotNil(t, filter)
|
||||
assert.Equal(t, "foo", filter)
|
||||
})
|
||||
|
||||
middleware := Filter(WithParseFilterFn(parseFilter))(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/?filter=foo", http.NoBody)
|
||||
middleware.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestFilterInvalid(t *testing.T) {
|
||||
parseFilter := func(r *http.Request) (string, error) {
|
||||
return "", errors.New("invalid filter")
|
||||
}
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
filter := FilterFromContext[string](r.Context())
|
||||
assert.Nil(t, filter)
|
||||
})
|
||||
|
||||
middleware := Filter(WithParseFilterFn(parseFilter))(handler)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest("GET", "/?filter=foo", http.NoBody)
|
||||
middleware.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
10
query/response.go
Normal file
10
query/response.go
Normal file
@ -0,0 +1,10 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func RespondWithBadRequest(w http.ResponseWriter, _ *http.Request, err error) {
|
||||
text := http.StatusText(http.StatusBadRequest) + ": " + err.Error()
|
||||
http.Error(w, text, http.StatusBadRequest)
|
||||
}
|
Loading…
Reference in New Issue
Block a user