Compare commits
2 Commits
a5f83606cd
...
fe3c788fba
Author | SHA1 | Date | |
---|---|---|---|
|
fe3c788fba | ||
|
e34a8aa7fd |
@ -28,7 +28,7 @@ var DefaultOptions = Options(
|
|||||||
WithFindTokenFn(middleware.TokenFromAuthorizationHeader),
|
WithFindTokenFn(middleware.TokenFromAuthorizationHeader),
|
||||||
WithFindTokenFn(middleware.TokenFromQuery("jwt")),
|
WithFindTokenFn(middleware.TokenFromQuery("jwt")),
|
||||||
WithFindTokenFn(middleware.TokenFromCookie("jwt")),
|
WithFindTokenFn(middleware.TokenFromCookie("jwt")),
|
||||||
SetResponseHandler(RespondWithUnauthorized),
|
SetResponseHandler(middleware.RespondWithUnauthorized),
|
||||||
SetValidateTokenFn(allowAll),
|
SetValidateTokenFn(allowAll),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -160,11 +160,6 @@ func Middleware(key interface{}, alg jwa.SignatureAlgorithm, opts ...Option) fun
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RespondWithUnauthorized is a default response handler
|
|
||||||
func RespondWithUnauthorized(w http.ResponseWriter, _ *http.Request, _ error) {
|
|
||||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
||||||
}
|
|
||||||
|
|
||||||
func allowAll(_ jwt.Token) (bool, interface{}) {
|
func allowAll(_ jwt.Token) (bool, interface{}) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
@ -25,7 +25,7 @@ var DefaultOptions = Options(
|
|||||||
WithFindTokenFn(middleware.TokenFromHeader("X-Token")),
|
WithFindTokenFn(middleware.TokenFromHeader("X-Token")),
|
||||||
WithFindTokenFn(middleware.TokenFromQuery("token")),
|
WithFindTokenFn(middleware.TokenFromQuery("token")),
|
||||||
SetValidateTokenFn(rejectAll),
|
SetValidateTokenFn(rejectAll),
|
||||||
SetResponseHandler(RespondWithUnauthorized),
|
SetResponseHandler(middleware.RespondWithUnauthorized),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Options turns a list of option instances into an option
|
// Options turns a list of option instances into an option
|
||||||
@ -113,11 +113,6 @@ func Middleware(opts ...Option) func(next http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RespondWithUnauthorized is a default response handler
|
|
||||||
func RespondWithUnauthorized(w http.ResponseWriter, _ *http.Request, _ error) {
|
|
||||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
||||||
}
|
|
||||||
|
|
||||||
func rejectAll(_ string) (bool, interface{}) {
|
func rejectAll(_ string) (bool, interface{}) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,5 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ResponseHandle is a function that middleware call in case of stop chain
|
|
||||||
type ResponseHandle func(w http.ResponseWriter, r *http.Request, err error)
|
|
||||||
|
|
||||||
// CtxKey is a key to use with context.WithValue
|
// CtxKey is a key to use with context.WithValue
|
||||||
type CtxKey struct {
|
type CtxKey struct {
|
||||||
Pkg string
|
Pkg string
|
@ -24,7 +24,7 @@ var DefaultOptions = Options(
|
|||||||
WithFindPaginationFn(PaginationFromQuery("page", "pageSize")),
|
WithFindPaginationFn(PaginationFromQuery("page", "pageSize")),
|
||||||
SetPaginationDefaults(1, 50),
|
SetPaginationDefaults(1, 50),
|
||||||
SetValidatePaginationFn(allowAll),
|
SetValidatePaginationFn(allowAll),
|
||||||
SetResponseHandler(RespondWithBadRequest),
|
SetResponseHandler(middleware.RespondWithBadRequest),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Pagination represents pagination info
|
// Pagination represents pagination info
|
||||||
@ -149,11 +149,6 @@ func PaginationFromQuery(pageParam string, pageSizeParam string) func(r *http.Re
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RespondWithBadRequest is a default response handler
|
|
||||||
func RespondWithBadRequest(w http.ResponseWriter, _ *http.Request, _ error) {
|
|
||||||
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
|
|
||||||
func allowAll(_ *Pagination) error {
|
func allowAll(_ *Pagination) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
104
query/filter.go
Normal file
104
query/filter.go
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
package query
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"go.pkg.cx/middleware"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Context keys
|
||||||
|
var (
|
||||||
|
FilterCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/query", Name: "Filter"}
|
||||||
|
)
|
||||||
|
|
||||||
|
// FilterOptions turns a list of option instances into an option
|
||||||
|
func FilterOptions[T any](opts ...FilterOption[T]) FilterOption[T] {
|
||||||
|
return func(f *filter[T]) {
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterOption configures filter middleware
|
||||||
|
type FilterOption[T any] func(f *filter[T])
|
||||||
|
|
||||||
|
// WithParseFilterFn adds filter parse function to the list
|
||||||
|
func WithParseFilterFn[T any](fn func(r *http.Request) (T, error)) FilterOption[T] {
|
||||||
|
return func(f *filter[T]) {
|
||||||
|
f.parseFilterFns = append(f.parseFilterFns, fn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFilterDefault sets filter default value
|
||||||
|
func SetFilterDefault[T any](defaultValue T) FilterOption[T] {
|
||||||
|
return func(f *filter[T]) {
|
||||||
|
f.defaultFilter = defaultValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetResponseHandler sets filter response handler
|
||||||
|
func SetResponseHandler[T any](fn middleware.ResponseHandle) FilterOption[T] {
|
||||||
|
return func(f *filter[T]) {
|
||||||
|
f.responseHandler = fn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type filter[T any] struct {
|
||||||
|
defaultFilter T
|
||||||
|
parseFilterFns []func(r *http.Request) (T, error)
|
||||||
|
responseHandler middleware.ResponseHandle
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter returns filter middleware
|
||||||
|
func Filter[T any](opts ...FilterOption[T]) func(next http.Handler) http.Handler {
|
||||||
|
f := &filter[T]{
|
||||||
|
defaultFilter: *new(T),
|
||||||
|
parseFilterFns: []func(r *http.Request) (T, error){},
|
||||||
|
responseHandler: middleware.RespondWithBadRequest,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
filter := f.defaultFilter
|
||||||
|
|
||||||
|
var err error
|
||||||
|
for _, fn := range f.parseFilterFns {
|
||||||
|
if filter, err = fn(r); err == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
f.responseHandler(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, FilterCtxKey, filter)
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterFromContext returns filter from context
|
||||||
|
func FilterFromContext[T any](ctx context.Context) T {
|
||||||
|
if filter, ok := ctx.Value(FilterCtxKey).(T); ok {
|
||||||
|
return filter
|
||||||
|
}
|
||||||
|
|
||||||
|
return *new(T)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterFromQuery returns filter string from query params
|
||||||
|
func FilterFromQuery(filterParam string) func(r *http.Request) string {
|
||||||
|
return func(r *http.Request) string {
|
||||||
|
return r.URL.Query().Get(filterParam)
|
||||||
|
}
|
||||||
|
}
|
98
query/filter_test.go
Normal file
98
query/filter_test.go
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
package query
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFilterFromQuery(t *testing.T) {
|
||||||
|
req, err := http.NewRequest("GET", "/?filter=foo", http.NoBody)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
filter := FilterFromQuery("filter")(req)
|
||||||
|
assert.Equal(t, "foo", filter)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterFromContext(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
ctx = context.WithValue(ctx, FilterCtxKey, "foo")
|
||||||
|
|
||||||
|
assert.Equal(t, 0, FilterFromContext[int](ctx))
|
||||||
|
assert.Equal(t, "foo", FilterFromContext[string](ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterNoOptions(t *testing.T) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
filter := FilterFromContext[int](r.Context())
|
||||||
|
assert.NotNil(t, filter)
|
||||||
|
assert.Equal(t, 0, filter)
|
||||||
|
})
|
||||||
|
|
||||||
|
middleware := Filter[int]()(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("GET", "/", http.NoBody)
|
||||||
|
middleware.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterDefault(t *testing.T) {
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
filter := FilterFromContext[string](r.Context())
|
||||||
|
assert.NotNil(t, filter)
|
||||||
|
assert.Equal(t, "foo", filter)
|
||||||
|
})
|
||||||
|
|
||||||
|
middleware := Filter(SetFilterDefault("foo"))(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("GET", "/", http.NoBody)
|
||||||
|
middleware.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterValid(t *testing.T) {
|
||||||
|
parseFilter := func(r *http.Request) (string, error) {
|
||||||
|
return r.URL.Query().Get("filter"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
filter := FilterFromContext[string](r.Context())
|
||||||
|
assert.NotNil(t, filter)
|
||||||
|
assert.Equal(t, "foo", filter)
|
||||||
|
})
|
||||||
|
|
||||||
|
middleware := Filter(WithParseFilterFn(parseFilter))(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("GET", "/?filter=foo", http.NoBody)
|
||||||
|
middleware.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFilterInvalid(t *testing.T) {
|
||||||
|
parseFilter := func(r *http.Request) (string, error) {
|
||||||
|
return "", errors.New("invalid filter")
|
||||||
|
}
|
||||||
|
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
filter := FilterFromContext[string](r.Context())
|
||||||
|
assert.Nil(t, filter)
|
||||||
|
})
|
||||||
|
|
||||||
|
middleware := Filter(WithParseFilterFn(parseFilter))(handler)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest("GET", "/?filter=foo", http.NoBody)
|
||||||
|
middleware.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
}
|
@ -12,7 +12,7 @@ import (
|
|||||||
var DefaultOptions = Options(
|
var DefaultOptions = Options(
|
||||||
SetLogStackFn(defaultLogStackFn),
|
SetLogStackFn(defaultLogStackFn),
|
||||||
SetLogRecoverFn(defaultLogRecoverFn),
|
SetLogRecoverFn(defaultLogRecoverFn),
|
||||||
SetResponseHandler(RespondWithInternalServerError),
|
SetResponseHandler(middleware.RespondWithInternalServerError),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Options turns a list of option instances into an option
|
// Options turns a list of option instances into an option
|
||||||
@ -78,11 +78,6 @@ func Middleware(opts ...Option) func(next http.Handler) http.Handler {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RespondWithInternalServerError is a default response handler
|
|
||||||
func RespondWithInternalServerError(w http.ResponseWriter, _ *http.Request, _ error) {
|
|
||||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
|
|
||||||
func defaultLogStackFn(stack []byte) {
|
func defaultLogStackFn(stack []byte) {
|
||||||
fmt.Println(string(stack))
|
fmt.Println(string(stack))
|
||||||
}
|
}
|
||||||
|
33
response.go
Normal file
33
response.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ResponseHandle is a function that middleware call in case of stop chain
|
||||||
|
type ResponseHandle func(w http.ResponseWriter, r *http.Request, err error)
|
||||||
|
|
||||||
|
// RespondWithBadRequest is a default bad request response handler
|
||||||
|
func RespondWithBadRequest(w http.ResponseWriter, _ *http.Request, _ error) {
|
||||||
|
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RespondWithUnauthorized is a default unathorized response handler
|
||||||
|
func RespondWithUnauthorized(w http.ResponseWriter, _ *http.Request, _ error) {
|
||||||
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RespondWithTooManyRequests is a default too many requests response handler
|
||||||
|
func RespondWithTooManyRequests(w http.ResponseWriter, _ *http.Request, _ error) {
|
||||||
|
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RespondWithInternalServerError is a default internal server error response handler
|
||||||
|
func RespondWithInternalServerError(w http.ResponseWriter, _ *http.Request, _ error) {
|
||||||
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RespondWithTimeout is a default gateway timeout response handler
|
||||||
|
func RespondWithGatewayTimeout(w http.ResponseWriter, _ *http.Request, _ error) {
|
||||||
|
http.Error(w, http.StatusText(http.StatusGatewayTimeout), http.StatusGatewayTimeout)
|
||||||
|
}
|
@ -19,7 +19,7 @@ var (
|
|||||||
// DefaultOptions represents default throttle middleware options
|
// DefaultOptions represents default throttle middleware options
|
||||||
var DefaultOptions = Options(
|
var DefaultOptions = Options(
|
||||||
SetLimit(100),
|
SetLimit(100),
|
||||||
SetResponseHandler(RespondWithTooManyRequests),
|
SetResponseHandler(middleware.RespondWithTooManyRequests),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Options turns a list of option instances into an option
|
// Options turns a list of option instances into an option
|
||||||
@ -169,8 +169,3 @@ func Middleware(opts ...Option) func(http.Handler) http.Handler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RespondWithTooManyRequests is a default response handler
|
|
||||||
func RespondWithTooManyRequests(w http.ResponseWriter, _ *http.Request, _ error) {
|
|
||||||
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
|
|
||||||
}
|
|
||||||
|
@ -11,7 +11,7 @@ import (
|
|||||||
// DefaultOptions represents default timeout middleware options
|
// DefaultOptions represents default timeout middleware options
|
||||||
var DefaultOptions = Options(
|
var DefaultOptions = Options(
|
||||||
SetTimeout(time.Second*30),
|
SetTimeout(time.Second*30),
|
||||||
SetResponseHandler(RespondWithTimeout),
|
SetResponseHandler(middleware.RespondWithGatewayTimeout),
|
||||||
)
|
)
|
||||||
|
|
||||||
// Options turns a list of option instances into an option
|
// Options turns a list of option instances into an option
|
||||||
@ -74,8 +74,3 @@ func Middleware(opts ...Option) func(next http.Handler) http.Handler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RespondWithTimeout is a default response handler
|
|
||||||
func RespondWithTimeout(w http.ResponseWriter, _ *http.Request, _ error) {
|
|
||||||
http.Error(w, http.StatusText(http.StatusGatewayTimeout), http.StatusGatewayTimeout)
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user