99 lines
2.5 KiB
Go
99 lines
2.5 KiB
Go
package query
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestFilterFromQuery(t *testing.T) {
|
|
req, err := http.NewRequest("GET", "/?filter=foo", http.NoBody)
|
|
assert.NoError(t, err)
|
|
|
|
filter := FilterFromQuery("filter")(req)
|
|
assert.Equal(t, "foo", filter)
|
|
}
|
|
|
|
func TestFilterFromContext(t *testing.T) {
|
|
ctx := context.Background()
|
|
ctx = context.WithValue(ctx, FilterCtxKey, "foo")
|
|
|
|
assert.Equal(t, 0, FilterFromContext[int](ctx))
|
|
assert.Equal(t, "foo", FilterFromContext[string](ctx))
|
|
}
|
|
|
|
func TestFilterNoOptions(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
filter := FilterFromContext[int](r.Context())
|
|
assert.NotNil(t, filter)
|
|
assert.Equal(t, 0, filter)
|
|
})
|
|
|
|
middleware := Filter[int]()(handler)
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest("GET", "/", http.NoBody)
|
|
middleware.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
|
|
func TestFilterDefault(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
filter := FilterFromContext[string](r.Context())
|
|
assert.NotNil(t, filter)
|
|
assert.Equal(t, "foo", filter)
|
|
})
|
|
|
|
middleware := Filter(SetFilterDefault("foo"))(handler)
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest("GET", "/", http.NoBody)
|
|
middleware.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
|
|
func TestFilterValid(t *testing.T) {
|
|
parseFilter := func(r *http.Request) (string, error) {
|
|
return r.URL.Query().Get("filter"), nil
|
|
}
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
filter := FilterFromContext[string](r.Context())
|
|
assert.NotNil(t, filter)
|
|
assert.Equal(t, "foo", filter)
|
|
})
|
|
|
|
middleware := Filter(WithParseFilterFn(parseFilter))(handler)
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest("GET", "/?filter=foo", http.NoBody)
|
|
middleware.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
}
|
|
|
|
func TestFilterInvalid(t *testing.T) {
|
|
parseFilter := func(r *http.Request) (string, error) {
|
|
return "", errors.New("invalid filter")
|
|
}
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
filter := FilterFromContext[string](r.Context())
|
|
assert.Nil(t, filter)
|
|
})
|
|
|
|
middleware := Filter(WithParseFilterFn(parseFilter))(handler)
|
|
|
|
w := httptest.NewRecorder()
|
|
req := httptest.NewRequest("GET", "/?filter=foo", http.NoBody)
|
|
middleware.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
|
}
|