package query import ( "context" "errors" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" ) func TestFilterFromQuery(t *testing.T) { req, err := http.NewRequest("GET", "/?filter=foo", http.NoBody) assert.NoError(t, err) filter := FilterFromQuery("filter")(req) assert.Equal(t, "foo", filter) } func TestFilterFromContext(t *testing.T) { ctx := context.Background() ctx = context.WithValue(ctx, FilterCtxKey, "foo") assert.Equal(t, 0, FilterFromContext[int](ctx)) assert.Equal(t, "foo", FilterFromContext[string](ctx)) } func TestFilterNoOptions(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { filter := FilterFromContext[int](r.Context()) assert.NotNil(t, filter) assert.Equal(t, 0, filter) }) middleware := Filter[int]()(handler) w := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", http.NoBody) middleware.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) } func TestFilterDefault(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { filter := FilterFromContext[string](r.Context()) assert.NotNil(t, filter) assert.Equal(t, "foo", filter) }) middleware := Filter(SetFilterDefault("foo"))(handler) w := httptest.NewRecorder() req := httptest.NewRequest("GET", "/", http.NoBody) middleware.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) } func TestFilterValid(t *testing.T) { parseFilter := func(r *http.Request) (string, error) { return r.URL.Query().Get("filter"), nil } handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { filter := FilterFromContext[string](r.Context()) assert.NotNil(t, filter) assert.Equal(t, "foo", filter) }) middleware := Filter(WithParseFilterFn(parseFilter))(handler) w := httptest.NewRecorder() req := httptest.NewRequest("GET", "/?filter=foo", http.NoBody) middleware.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) } func TestFilterInvalid(t *testing.T) { parseFilter := func(r *http.Request) (string, error) { return "", errors.New("invalid filter") } handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { filter := FilterFromContext[string](r.Context()) assert.Nil(t, filter) }) middleware := Filter(WithParseFilterFn(parseFilter))(handler) w := httptest.NewRecorder() req := httptest.NewRequest("GET", "/?filter=foo", http.NoBody) middleware.ServeHTTP(w, req) assert.Equal(t, http.StatusBadRequest, w.Code) }