Merge some middlewares
This commit is contained in:
parent
7c06c3c6d8
commit
f46dc4094f
10
app_info.go
10
app_info.go
@ -6,14 +6,12 @@ import (
|
||||
|
||||
// AppInfo adds application name and version headers to the response
|
||||
func AppInfo(name string, version string) func(http.Handler) http.Handler {
|
||||
return func(h http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("App-Name", name)
|
||||
w.Header().Set("App-Version", version)
|
||||
|
||||
h.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return http.HandlerFunc(fn)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
32
app_info_test.go
Normal file
32
app_info_test.go
Normal file
@ -0,0 +1,32 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAppInfo(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(AppInfo("tapp", "tversion")(handler))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, "tapp", res.Header.Get("App-Name"))
|
||||
assert.Equal(t, "tversion", res.Header.Get("App-Version"))
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "resp", string(b))
|
||||
}
|
171
auth/jwt/jwt.go
Normal file
171
auth/jwt/jwt.go
Normal file
@ -0,0 +1,171 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
|
||||
"go.pkg.cx/middleware"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrTokenInvalid = errors.New("token invalid")
|
||||
ErrTokenNotFound = errors.New("token not found")
|
||||
)
|
||||
|
||||
// Context keys
|
||||
var (
|
||||
JWTCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/jwt", Name: "JWT"}
|
||||
DataCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/token", Name: "Data"}
|
||||
)
|
||||
|
||||
// DefaultOptions represents default jwt auth middleware options
|
||||
var DefaultOptions = Options(
|
||||
WithFindTokenFn(middleware.TokenFromAuthorizationHeader),
|
||||
WithFindTokenFn(middleware.TokenFromQuery("jwt")),
|
||||
WithFindTokenFn(middleware.TokenFromCookie("jwt")),
|
||||
SetResponseHandler(RespondWithUnauthorized),
|
||||
SetValidateTokenFn(allowAll),
|
||||
)
|
||||
|
||||
// Options turns a list of option instances into an option.
|
||||
func Options(opts ...Option) Option {
|
||||
return func(a *auth) {
|
||||
for _, opt := range opts {
|
||||
opt(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Option configures jwt auth middleware
|
||||
type Option func(a *auth)
|
||||
|
||||
// SetKey sets key to verify jwt token with
|
||||
func SetKey(key interface{}) Option {
|
||||
return func(a *auth) {
|
||||
a.key = key
|
||||
}
|
||||
}
|
||||
|
||||
// SetAlgorithm sets algorithm to verify jwt token with
|
||||
func SetAlgorithm(alg jwa.SignatureAlgorithm) Option {
|
||||
return func(a *auth) {
|
||||
a.algorithm = alg
|
||||
}
|
||||
}
|
||||
|
||||
// WithFindTokenFn adds jwt token find function to the list
|
||||
func WithFindTokenFn(fn func(r *http.Request) string) Option {
|
||||
return func(a *auth) {
|
||||
a.findTokenFns = append(a.findTokenFns, fn)
|
||||
}
|
||||
}
|
||||
|
||||
// SetFindTokenFns sets jwt token find functions list
|
||||
func SetFindTokenFns(fns ...func(r *http.Request) string) Option {
|
||||
return func(a *auth) {
|
||||
a.findTokenFns = fns
|
||||
}
|
||||
}
|
||||
|
||||
// WithVerifyOption adds jwt verify option to the list
|
||||
func WithVerifyOption(opt jwt.Option) Option {
|
||||
return func(a *auth) {
|
||||
a.verifyOptions = append(a.verifyOptions, opt)
|
||||
}
|
||||
}
|
||||
|
||||
// SetVerifyOptions sets jwt verify options list
|
||||
func SetVerifyOptions(opts ...jwt.Option) Option {
|
||||
return func(a *auth) {
|
||||
a.verifyOptions = opts
|
||||
}
|
||||
}
|
||||
|
||||
// SetValidateTokenFn sets token validation function
|
||||
func SetValidateTokenFn(fn func(token jwt.Token) (bool, interface{})) Option {
|
||||
return func(a *auth) {
|
||||
a.validateTokenFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// SetResponseHandler sets response handler
|
||||
func SetResponseHandler(fn middleware.ResponseHandle) Option {
|
||||
return func(a *auth) {
|
||||
a.responseHandler = fn
|
||||
}
|
||||
}
|
||||
|
||||
type auth struct {
|
||||
key interface{}
|
||||
algorithm jwa.SignatureAlgorithm
|
||||
verifyOptions []jwt.Option
|
||||
findTokenFns []func(r *http.Request) string
|
||||
validateTokenFn func(token jwt.Token) (bool, interface{})
|
||||
responseHandler middleware.ResponseHandle
|
||||
}
|
||||
|
||||
// Middleware returns jwt auth middleware
|
||||
func Middleware(key interface{}, alg jwa.SignatureAlgorithm, opts ...Option) func(next http.Handler) http.Handler {
|
||||
a := &auth{}
|
||||
opts = append(opts, SetKey(key))
|
||||
opts = append(opts, SetAlgorithm(alg))
|
||||
opts = append([]Option{DefaultOptions}, opts...)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(a)
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var token string
|
||||
for _, findTokenFn := range a.findTokenFns {
|
||||
token = findTokenFn(r)
|
||||
if token != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
a.responseHandler(w, r, ErrTokenNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
jwtToken, err := jwt.ParseString(token, jwt.WithVerify(a.algorithm, a.key))
|
||||
if err != nil {
|
||||
a.responseHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := jwt.Verify(jwtToken, a.verifyOptions...); err != nil {
|
||||
a.responseHandler(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
valid, data := a.validateTokenFn(jwtToken)
|
||||
if !valid {
|
||||
a.responseHandler(w, r, ErrTokenInvalid)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, JWTCtxKey, jwtToken)
|
||||
ctx = context.WithValue(ctx, DataCtxKey, data)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RespondWithUnauthorized is a default response handler
|
||||
func RespondWithUnauthorized(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func allowAll(token jwt.Token) (bool, interface{}) {
|
||||
return true, nil
|
||||
}
|
192
auth/jwt/jwt_test.go
Normal file
192
auth/jwt/jwt_test.go
Normal file
@ -0,0 +1,192 @@
|
||||
package jwt
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwa"
|
||||
"github.com/lestrrat-go/jwx/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.pkg.cx/middleware"
|
||||
)
|
||||
|
||||
var testHandler = func(t *testing.T) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestJWTAuthDefaults(t *testing.T) {
|
||||
auth := Middleware(
|
||||
[]byte("changethis"),
|
||||
jwa.HS256,
|
||||
)
|
||||
|
||||
server := httptest.NewServer(auth(testHandler(t)))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Unauthorized\n", string(b))
|
||||
|
||||
token := jwt.New()
|
||||
payload, err := jwt.Sign(token, jwa.HS256, []byte("changethis"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
res, err = http.Get(server.URL + "/?jwt=" + string(payload))
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
b, err = ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "resp", string(b))
|
||||
}
|
||||
|
||||
func TestJWTAuthParseVerify(t *testing.T) {
|
||||
auth := Middleware(
|
||||
[]byte("tkey"),
|
||||
jwa.HS512,
|
||||
SetVerifyOptions(
|
||||
jwt.WithSubject("tsubj"),
|
||||
jwt.WithAudience("taud"),
|
||||
),
|
||||
SetFindTokenFns(
|
||||
middleware.TokenFromQuery("token"),
|
||||
),
|
||||
)
|
||||
|
||||
server := httptest.NewServer(auth(testHandler(t)))
|
||||
defer server.Close()
|
||||
|
||||
token := jwt.New()
|
||||
|
||||
payload, err := jwt.Sign(token, jwa.HS512, []byte("tkey"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
res, err := http.Get(server.URL + "/?token=" + string(payload))
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
|
||||
payload, err = jwt.Sign(token, jwa.HS256, []byte("tkey"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
res, err = http.Get(server.URL + "/?token=" + string(payload))
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
|
||||
payload, err = jwt.Sign(token, jwa.HS512, []byte("wrongkey"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
res, err = http.Get(server.URL + "/?token=" + string(payload))
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
|
||||
now := time.Now()
|
||||
token.Set(jwt.IssuedAtKey, now.Add(-1*time.Hour)) // nolint:errcheck
|
||||
token.Set(jwt.ExpirationKey, now.Add(-58*time.Minute)) // nolint:errcheck
|
||||
payload, err = jwt.Sign(token, jwa.HS512, []byte("tkey"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
res, err = http.Get(server.URL + "/?token=" + string(payload))
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestJWTAuthVerifyOptions(t *testing.T) {
|
||||
auth := Middleware(
|
||||
[]byte("changethis"),
|
||||
jwa.HS256,
|
||||
WithVerifyOption(jwt.WithIssuer("tissuer")),
|
||||
)
|
||||
|
||||
server := httptest.NewServer(auth(testHandler(t)))
|
||||
defer server.Close()
|
||||
|
||||
token := jwt.New()
|
||||
|
||||
token.Set(jwt.IssuerKey, "wrongissuer") // nolint:errcheck
|
||||
payload, err := jwt.Sign(token, jwa.HS256, []byte("changethis"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
res, err := http.Get(server.URL + "/?jwt=" + string(payload))
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
|
||||
token.Set(jwt.IssuerKey, "tissuer") // nolint:errcheck
|
||||
payload, err = jwt.Sign(token, jwa.HS256, []byte("changethis"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
res, err = http.Get(server.URL + "/?jwt=" + string(payload))
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestJWTAuthContext(t *testing.T) {
|
||||
type data struct {
|
||||
inner string
|
||||
}
|
||||
|
||||
validateWithDataFn := func(token jwt.Token) (bool, interface{}) {
|
||||
return token.JwtID() == "tid", &data{"test data"}
|
||||
}
|
||||
|
||||
auth := Middleware(
|
||||
[]byte("changethis"),
|
||||
jwa.HS256,
|
||||
SetValidateTokenFn(validateWithDataFn),
|
||||
)
|
||||
|
||||
token := jwt.New()
|
||||
token.Set(jwt.JwtIDKey, "tid") // nolint:errcheck
|
||||
|
||||
testCtxHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
data, ok := r.Context().Value(DataCtxKey).(*data)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "test data", data.inner)
|
||||
|
||||
token, ok := r.Context().Value(JWTCtxKey).(jwt.Token)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "tid", token.JwtID())
|
||||
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(auth(testCtxHandler))
|
||||
defer server.Close()
|
||||
|
||||
payload, err := jwt.Sign(token, jwa.HS256, []byte("changethis"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
res, err := http.Get(server.URL + "/?jwt=" + string(payload))
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
token.Set(jwt.JwtIDKey, "invalid") // nolint:errcheck
|
||||
payload, err = jwt.Sign(token, jwa.HS256, []byte("changethis"))
|
||||
assert.NoError(t, err)
|
||||
|
||||
res, err = http.Get(server.URL + "/?jwt=" + string(payload))
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
}
|
123
auth/token/token.go
Normal file
123
auth/token/token.go
Normal file
@ -0,0 +1,123 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"go.pkg.cx/middleware"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrTokenInvalid = errors.New("token invalid")
|
||||
ErrTokenNotFound = errors.New("token not found")
|
||||
)
|
||||
|
||||
// Context keys
|
||||
var (
|
||||
DataCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/token", Name: "Data"}
|
||||
TokenCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/token", Name: "Token"}
|
||||
)
|
||||
|
||||
// DefaultOptions represents default token auth middleware options
|
||||
var DefaultOptions = Options(
|
||||
WithFindTokenFn(middleware.TokenFromHeader("X-Token")),
|
||||
WithFindTokenFn(middleware.TokenFromQuery("token")),
|
||||
SetValidateTokenFn(rejectAll),
|
||||
SetResponseHandler(RespondWithUnauthorized),
|
||||
)
|
||||
|
||||
// Options turns a list of option instances into an option.
|
||||
func Options(opts ...Option) Option {
|
||||
return func(a *auth) {
|
||||
for _, opt := range opts {
|
||||
opt(a)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Option configures token auth middleware
|
||||
type Option func(a *auth)
|
||||
|
||||
// WithFindTokenFn adds token find function to the list
|
||||
func WithFindTokenFn(fn func(r *http.Request) string) Option {
|
||||
return func(a *auth) {
|
||||
a.findTokenFns = append(a.findTokenFns, fn)
|
||||
}
|
||||
}
|
||||
|
||||
// SetFindTokenFns sets token find functions list
|
||||
func SetFindTokenFns(fns ...func(r *http.Request) string) Option {
|
||||
return func(a *auth) {
|
||||
a.findTokenFns = fns
|
||||
}
|
||||
}
|
||||
|
||||
// SetValidateTokenFn sets token validation function
|
||||
func SetValidateTokenFn(fn func(token string) (bool, interface{})) Option {
|
||||
return func(a *auth) {
|
||||
a.validateTokenFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// SetResponseHandler sets response handler
|
||||
func SetResponseHandler(fn middleware.ResponseHandle) Option {
|
||||
return func(a *auth) {
|
||||
a.responseHandler = fn
|
||||
}
|
||||
}
|
||||
|
||||
type auth struct {
|
||||
findTokenFns []func(r *http.Request) string
|
||||
validateTokenFn func(token string) (bool, interface{})
|
||||
responseHandler middleware.ResponseHandle
|
||||
}
|
||||
|
||||
// Middleware returns token auth middleware
|
||||
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
|
||||
a := &auth{}
|
||||
opts = append([]Option{DefaultOptions}, opts...)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(a)
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
var token string
|
||||
for _, findTokenFn := range a.findTokenFns {
|
||||
token = findTokenFn(r)
|
||||
if token != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
a.responseHandler(w, r, ErrTokenNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
valid, data := a.validateTokenFn(token)
|
||||
if !valid {
|
||||
a.responseHandler(w, r, ErrTokenInvalid)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
ctx = context.WithValue(ctx, DataCtxKey, data)
|
||||
ctx = context.WithValue(ctx, TokenCtxKey, token)
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RespondWithUnauthorized is a default response handler
|
||||
func RespondWithUnauthorized(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
func rejectAll(token string) (bool, interface{}) {
|
||||
return false, nil
|
||||
}
|
123
auth/token/token_test.go
Normal file
123
auth/token/token_test.go
Normal file
@ -0,0 +1,123 @@
|
||||
package token
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.pkg.cx/middleware"
|
||||
)
|
||||
|
||||
var testHandler = func(t *testing.T) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
var validateFn = func(token string) (bool, interface{}) {
|
||||
return token == "valid", nil
|
||||
}
|
||||
|
||||
func TestTokenAuthDefaults(t *testing.T) {
|
||||
auth := Middleware()
|
||||
|
||||
server := httptest.NewServer(auth(testHandler(t)))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Unauthorized\n", string(b))
|
||||
|
||||
res, err = http.Get(server.URL + "?token=invalid")
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestTokenAuthValidateFn(t *testing.T) {
|
||||
auth := Middleware(
|
||||
SetValidateTokenFn(validateFn),
|
||||
)
|
||||
|
||||
server := httptest.NewServer(auth(testHandler(t)))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Get(server.URL + "?token=invalid")
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
|
||||
res, err = http.Get(server.URL + "?token=valid")
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "resp", string(b))
|
||||
}
|
||||
|
||||
func TestTokenAuthFindTokenFns(t *testing.T) {
|
||||
auth := Middleware(
|
||||
SetValidateTokenFn(validateFn),
|
||||
SetFindTokenFns(middleware.TokenFromQuery("api_key")),
|
||||
)
|
||||
|
||||
server := httptest.NewServer(auth(testHandler(t)))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Get(server.URL + "?token=valid")
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
|
||||
|
||||
res, err = http.Get(server.URL + "?api_key=valid")
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestTokenAuthContext(t *testing.T) {
|
||||
type data struct {
|
||||
inner string
|
||||
}
|
||||
|
||||
validateWithDataFn := func(token string) (bool, interface{}) {
|
||||
return token == "valid", &data{"test data"}
|
||||
}
|
||||
|
||||
auth := Middleware(
|
||||
SetValidateTokenFn(validateWithDataFn),
|
||||
)
|
||||
|
||||
testCtxHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
data, ok := r.Context().Value(DataCtxKey).(*data)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "test data", data.inner)
|
||||
|
||||
token, ok := r.Context().Value(TokenCtxKey).(string)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "valid", token)
|
||||
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(auth(testCtxHandler))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Get(server.URL + "?token=valid")
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
44
find_token.go
Normal file
44
find_token.go
Normal file
@ -0,0 +1,44 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// TokenFromAuthorizationHeader tries to retrieve the token string from the
|
||||
// "Authorization" request header and strips BEARER prefix if necessary
|
||||
func TokenFromAuthorizationHeader(r *http.Request) string {
|
||||
header := r.Header.Get("Authorization")
|
||||
|
||||
if len(header) > 7 && strings.ToUpper(header[0:6]) == "BEARER" {
|
||||
return header[7:]
|
||||
}
|
||||
|
||||
return header
|
||||
}
|
||||
|
||||
// TokenFromHeader tries to retrieve the token string from the given header
|
||||
func TokenFromHeader(headerKey string) func(r *http.Request) string {
|
||||
return func(r *http.Request) string {
|
||||
return r.Header.Get(headerKey)
|
||||
}
|
||||
}
|
||||
|
||||
// TokenFromQuery tries to retrieve the token string from the given query param
|
||||
func TokenFromQuery(param string) func(r *http.Request) string {
|
||||
return func(r *http.Request) string {
|
||||
return r.URL.Query().Get(param)
|
||||
}
|
||||
}
|
||||
|
||||
// TokenFromCookie tries to retrieve the token string from a given cookie
|
||||
func TokenFromCookie(cookieName string) func(r *http.Request) string {
|
||||
return func(r *http.Request) string {
|
||||
cookie, err := r.Cookie(cookieName)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return cookie.Value
|
||||
}
|
||||
}
|
50
find_token_test.go
Normal file
50
find_token_test.go
Normal file
@ -0,0 +1,50 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTokenFromAuthorizationHeader(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", "/", nil)
|
||||
require.Nil(t, err)
|
||||
req.Header.Set("Authorization", "abc")
|
||||
assert.Equal(t, "abc", TokenFromAuthorizationHeader(req))
|
||||
|
||||
req, err = http.NewRequest("GET", "/", nil)
|
||||
require.Nil(t, err)
|
||||
req.Header.Set("Authorization", "abcdefghe")
|
||||
assert.Equal(t, "abcdefghe", TokenFromAuthorizationHeader(req))
|
||||
|
||||
req, err = http.NewRequest("GET", "/", nil)
|
||||
require.Nil(t, err)
|
||||
req.Header.Set("Authorization", "Bearer abc")
|
||||
assert.Equal(t, "abc", TokenFromAuthorizationHeader(req))
|
||||
}
|
||||
|
||||
func TestTokenFromHeader(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", "/", nil)
|
||||
require.Nil(t, err)
|
||||
req.Header.Set("X-Token", "abc")
|
||||
assert.Equal(t, "abc", TokenFromHeader("X-Token")(req))
|
||||
}
|
||||
|
||||
func TestTokenFromQuery(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", "/?token=abc", nil)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, "abc", TokenFromQuery("token")(req))
|
||||
}
|
||||
|
||||
func TestTokenFromCookie(t *testing.T) {
|
||||
req, err := http.NewRequest("GET", "/", nil)
|
||||
require.Nil(t, err)
|
||||
req.AddCookie(&http.Cookie{Name: "token", Value: "abc"})
|
||||
assert.Equal(t, "abc", TokenFromCookie("token")(req))
|
||||
|
||||
req, err = http.NewRequest("GET", "/", nil)
|
||||
require.Nil(t, err)
|
||||
assert.Equal(t, "", TokenFromCookie("token")(req))
|
||||
}
|
6
go.mod
6
go.mod
@ -1,3 +1,9 @@
|
||||
module go.pkg.cx/middleware
|
||||
|
||||
go 1.15
|
||||
|
||||
require (
|
||||
github.com/go-pkgz/expirable-cache v0.0.3
|
||||
github.com/lestrrat-go/jwx v1.0.5
|
||||
github.com/stretchr/testify v1.6.1
|
||||
)
|
||||
|
40
go.sum
Normal file
40
go.sum
Normal file
@ -0,0 +1,40 @@
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-pkgz/expirable-cache v0.0.3 h1:rTh6qNPp78z0bQE6HDhXBHUwqnV9i09Vm6dksJLXQDc=
|
||||
github.com/go-pkgz/expirable-cache v0.0.3/go.mod h1:+IauqN00R2FqNRLCLA+X5YljQJrwB179PfiAoMPlTlQ=
|
||||
github.com/lestrrat-go/iter v0.0.0-20200422075355-fc1769541911 h1:FvnrqecqX4zT0wOIbYK1gNgTm0677INEWiFY8UEYggY=
|
||||
github.com/lestrrat-go/iter v0.0.0-20200422075355-fc1769541911/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbqPDrJ/OJc=
|
||||
github.com/lestrrat-go/jwx v1.0.5 h1:8bVUGXXkR3+YQNwuFof3lLxSJMLtrscHJfGI6ZIBRD0=
|
||||
github.com/lestrrat-go/jwx v1.0.5/go.mod h1:TPF17WiSFegZo+c20fdpw49QD+/7n4/IsGvEmCSWwT0=
|
||||
github.com/lestrrat-go/pdebug v0.0.0-20200204225717-4d6bd78da58d/go.mod h1:B06CSso/AWxiPejj+fheUINGeBKeeEZNt8w+EoU7+L8=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20200417140056-c07e33ef3290/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
31
logger/body_util.go
Normal file
31
logger/body_util.go
Normal file
@ -0,0 +1,31 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var regexpMultiWhitespace = regexp.MustCompile(`[\s\p{Zs}]{2,}`)
|
||||
|
||||
func peek(r io.Reader, n int64) (io.Reader, string, bool, error) {
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
_, err := io.CopyN(&buf, r, n+1)
|
||||
|
||||
if err == io.EOF {
|
||||
return &buf, buf.String(), false, nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, "", false, err
|
||||
}
|
||||
|
||||
s := buf.String()
|
||||
s = s[:len(s)-1]
|
||||
|
||||
return io.MultiReader(&buf, r), s, true, nil
|
||||
}
|
54
logger/body_util_test.go
Normal file
54
logger/body_util_test.go
Normal file
@ -0,0 +1,54 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type errReader struct{}
|
||||
|
||||
func (errReader) Read(_ []byte) (n int, err error) {
|
||||
return 0, errors.New("test error")
|
||||
}
|
||||
|
||||
func TestLoggerPeek(t *testing.T) {
|
||||
cases := []struct {
|
||||
body string
|
||||
n int64
|
||||
excerpt string
|
||||
hasMore bool
|
||||
}{
|
||||
{"", -1, "", false},
|
||||
{"", 0, "", false},
|
||||
{"", 1024, "", false},
|
||||
{"123456", -1, "", true},
|
||||
{"123456", 0, "", true},
|
||||
{"123456", 4, "1234", true},
|
||||
{"123456", 5, "12345", true},
|
||||
{"123456", 6, "123456", false},
|
||||
{"123456", 7, "123456", false},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
r, excerpt, hasMore, err := peek(strings.NewReader(c.body), c.n)
|
||||
if !assert.NoError(t, err) {
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(r)
|
||||
if !assert.NoError(t, err) {
|
||||
continue
|
||||
}
|
||||
|
||||
assert.Equal(t, c.body, string(body))
|
||||
assert.Equal(t, c.excerpt, excerpt)
|
||||
assert.Equal(t, c.hasMore, hasMore)
|
||||
}
|
||||
|
||||
_, _, _, err := peek(errReader{}, 1024)
|
||||
assert.Error(t, err)
|
||||
}
|
203
logger/logger.go
Normal file
203
logger/logger.go
Normal file
@ -0,0 +1,203 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DefaultOptions represents default logger middleware options
|
||||
var DefaultOptions = Options(
|
||||
WithBody(),
|
||||
SetMaxBodySize(1*1024*1024),
|
||||
SetIPFn(defaultIPFn),
|
||||
SetLogHandler(DefaultLogHandler),
|
||||
)
|
||||
|
||||
// Options turns a list of option instances into an option.
|
||||
func Options(opts ...Option) Option {
|
||||
return func(l *logger) {
|
||||
for _, opt := range opts {
|
||||
opt(l)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Option configures logger middleware
|
||||
type Option func(l *logger)
|
||||
|
||||
// WithBody enables request body logging
|
||||
func WithBody() Option {
|
||||
return func(l *logger) {
|
||||
l.logBody = true
|
||||
}
|
||||
}
|
||||
|
||||
// WithoutBody disables request body logging
|
||||
func WithoutBody() Option {
|
||||
return func(l *logger) {
|
||||
l.logBody = false
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxBodySize sets maximum request body size
|
||||
func SetMaxBodySize(size int) Option {
|
||||
return func(l *logger) {
|
||||
l.maxBodySize = size
|
||||
}
|
||||
}
|
||||
|
||||
// SetIPFn sets function that extracts ip from request
|
||||
func SetIPFn(fn func(r *http.Request) string) Option {
|
||||
return func(l *logger) {
|
||||
l.ipFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// SetUserFn sets function that extracts user from request
|
||||
func SetUserFn(fn func(r *http.Request) string) Option {
|
||||
return func(l *logger) {
|
||||
l.userFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// SetSanitizeFn sets function that sanitizes request query or body
|
||||
func SetSanitizeFn(fn func(input string) string) Option {
|
||||
return func(l *logger) {
|
||||
l.sanitizeFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogHandler sets log handler
|
||||
func SetLogHandler(fn func(entry LogEntry)) Option {
|
||||
return func(l *logger) {
|
||||
l.logHandler = fn
|
||||
}
|
||||
}
|
||||
|
||||
// LogEntry is a http log entry
|
||||
type LogEntry struct {
|
||||
Method string
|
||||
RawURL string
|
||||
Body string
|
||||
RemoteIP string
|
||||
StatusCode int
|
||||
Written int
|
||||
Duration time.Duration
|
||||
User string
|
||||
TraceID string
|
||||
}
|
||||
|
||||
type logger struct {
|
||||
logBody bool
|
||||
maxBodySize int
|
||||
ipFn func(r *http.Request) string
|
||||
userFn func(r *http.Request) string
|
||||
sanitizeFn func(input string) string
|
||||
logHandler func(entry LogEntry)
|
||||
}
|
||||
|
||||
func (s *logger) body(r *http.Request) string {
|
||||
if !s.logBody {
|
||||
return ""
|
||||
}
|
||||
|
||||
rdr, body, hasMore, err := peek(r.Body, int64(s.maxBodySize))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
r.Body = ioutil.NopCloser(rdr)
|
||||
|
||||
if len(body) > 0 {
|
||||
body = strings.Replace(body, "\n", " ", -1)
|
||||
body = regexpMultiWhitespace.ReplaceAllString(body, " ")
|
||||
}
|
||||
|
||||
if hasMore {
|
||||
body += "..."
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// Middleware is a http logger middleware
|
||||
func Middleware(opts ...Option) func(http.Handler) http.Handler {
|
||||
l := &logger{}
|
||||
opts = append([]Option{DefaultOptions}, opts...)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(l)
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ww := newTrackingResponseWriter(w)
|
||||
|
||||
body := l.body(r)
|
||||
if l.sanitizeFn != nil {
|
||||
body = l.sanitizeFn(body)
|
||||
}
|
||||
|
||||
startTS := time.Now()
|
||||
defer func() {
|
||||
stopTS := time.Now()
|
||||
|
||||
query := r.URL.String()
|
||||
if qun, err := url.QueryUnescape(query); err == nil {
|
||||
query = qun
|
||||
}
|
||||
|
||||
if l.sanitizeFn != nil {
|
||||
query = l.sanitizeFn(query)
|
||||
}
|
||||
|
||||
var ip string
|
||||
if l.ipFn != nil {
|
||||
ip = l.ipFn(r)
|
||||
}
|
||||
|
||||
var user string
|
||||
if l.userFn != nil {
|
||||
user = l.userFn(r)
|
||||
}
|
||||
|
||||
entry := LogEntry{
|
||||
Method: r.Method,
|
||||
RawURL: query,
|
||||
RemoteIP: ip,
|
||||
Body: body,
|
||||
StatusCode: ww.status,
|
||||
Written: ww.size,
|
||||
Duration: stopTS.Sub(startTS),
|
||||
User: user,
|
||||
TraceID: r.Header.Get("X-Request-ID"),
|
||||
}
|
||||
|
||||
l.logHandler(entry)
|
||||
}()
|
||||
|
||||
next.ServeHTTP(ww, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultLogHandler is a default log handler
|
||||
func DefaultLogHandler(entry LogEntry) {
|
||||
log.Printf(
|
||||
"%s - %s - %s - %d (%d) - %v",
|
||||
entry.Method,
|
||||
entry.RawURL,
|
||||
entry.RemoteIP,
|
||||
entry.StatusCode,
|
||||
entry.Written,
|
||||
entry.Duration,
|
||||
)
|
||||
}
|
||||
|
||||
func defaultIPFn(r *http.Request) string {
|
||||
return strings.Split(r.RemoteAddr, ":")[0]
|
||||
}
|
134
logger/logger_test.go
Normal file
134
logger/logger_test.go
Normal file
@ -0,0 +1,134 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLogger(t *testing.T) {
|
||||
logHandler := func(entry LogEntry) {
|
||||
assert.Equal(t, entry.Method, "GET")
|
||||
assert.Equal(t, entry.RawURL, "/")
|
||||
assert.Equal(t, entry.Body, "")
|
||||
assert.Equal(t, entry.RemoteIP, "127.0.0.1")
|
||||
assert.Equal(t, entry.StatusCode, http.StatusOK)
|
||||
assert.Equal(t, entry.Written, len([]byte("resp")))
|
||||
}
|
||||
|
||||
logger := Middleware(
|
||||
WithoutBody(),
|
||||
SetLogHandler(logHandler),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(logger(handler))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "resp", string(b))
|
||||
}
|
||||
|
||||
func TestLoggerBody(t *testing.T) {
|
||||
logHandler := func(entry LogEntry) {
|
||||
assert.Equal(t, entry.Method, "POST")
|
||||
assert.Equal(t, entry.Body, "1234567890...")
|
||||
}
|
||||
|
||||
logger := Middleware(
|
||||
SetMaxBodySize(10),
|
||||
SetLogHandler(logHandler),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "12345678901234567890", string(body))
|
||||
|
||||
_, err = w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(logger(handler))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Post(server.URL, "", bytes.NewBufferString("12345678901234567890"))
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestLoggerUser(t *testing.T) {
|
||||
logHandler := func(entry LogEntry) {
|
||||
assert.Equal(t, entry.User, "user")
|
||||
}
|
||||
|
||||
logger := Middleware(
|
||||
SetUserFn(func(req *http.Request) string { return "user" }),
|
||||
SetLogHandler(logHandler),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(logger(handler))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestLoggerSanitize(t *testing.T) {
|
||||
logHandler := func(entry LogEntry) {
|
||||
assert.Equal(t, entry.RawURL, "/?param=*****")
|
||||
assert.Equal(t, entry.Body, "body|*****")
|
||||
}
|
||||
|
||||
sanitizeFn := func(input string) string {
|
||||
return strings.ReplaceAll(input, "password", "*****")
|
||||
}
|
||||
|
||||
logger := Middleware(
|
||||
SetSanitizeFn(sanitizeFn),
|
||||
SetLogHandler(logHandler),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, r.URL.Query().Get("param"), "password")
|
||||
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "body|password", string(body))
|
||||
|
||||
_, err = w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(logger(handler))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Post(server.URL+"?param=password", "", bytes.NewBufferString("body|password"))
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
53
logger/tracking_response_writer.go
Normal file
53
logger/tracking_response_writer.go
Normal file
@ -0,0 +1,53 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
var (
|
||||
errWriterNotImplentsHijacker = errors.New("ResponseWriter does not implement the Hijacker interface") // nolint:golint
|
||||
)
|
||||
|
||||
type trackingResponseWriter struct {
|
||||
http.ResponseWriter
|
||||
|
||||
size int
|
||||
status int
|
||||
}
|
||||
|
||||
func newTrackingResponseWriter(w http.ResponseWriter) *trackingResponseWriter {
|
||||
return &trackingResponseWriter{ResponseWriter: w, status: 200}
|
||||
}
|
||||
|
||||
// WriteHeader implements http.ResponseWriter and saves status
|
||||
func (s *trackingResponseWriter) WriteHeader(status int) {
|
||||
s.status = status
|
||||
s.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
// Write implements http.ResponseWriter and tracks number of bytes written
|
||||
func (s *trackingResponseWriter) Write(b []byte) (int, error) {
|
||||
size, err := s.ResponseWriter.Write(b)
|
||||
s.size += size
|
||||
|
||||
return size, err
|
||||
}
|
||||
|
||||
// Flush implements http.Flusher
|
||||
func (s *trackingResponseWriter) Flush() {
|
||||
if f, ok := s.ResponseWriter.(http.Flusher); ok {
|
||||
f.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack implements http.Hijacker
|
||||
func (s *trackingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hj, ok := s.ResponseWriter.(http.Hijacker); ok {
|
||||
return hj.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, errWriterNotImplentsHijacker
|
||||
}
|
41
no_cache.go
Normal file
41
no_cache.go
Normal file
@ -0,0 +1,41 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
var epoch = time.Unix(0, 0).Format(time.RFC1123)
|
||||
|
||||
var etagHeaders = []string{
|
||||
"ETag",
|
||||
"If-Modified-Since",
|
||||
"If-Match",
|
||||
"If-None-Match",
|
||||
"If-Range",
|
||||
"If-Unmodified-Since",
|
||||
}
|
||||
|
||||
var noCacheHeaders = map[string]string{
|
||||
"Expires": epoch,
|
||||
"Cache-Control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0",
|
||||
"Pragma": "no-cache",
|
||||
"X-Accel-Expires": "0",
|
||||
}
|
||||
|
||||
// NoCache sets a number of HTTP headers to prevent a router from being cached
|
||||
func NoCache(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for _, v := range etagHeaders {
|
||||
if r.Header.Get(v) != "" {
|
||||
r.Header.Del(v)
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range noCacheHeaders {
|
||||
w.Header().Set(k, v)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
43
no_cache_test.go
Normal file
43
no_cache_test.go
Normal file
@ -0,0 +1,43 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNoCache(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "", r.Header.Get("ETag"))
|
||||
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(NoCache(handler))
|
||||
defer server.Close()
|
||||
|
||||
client := http.Client{}
|
||||
|
||||
req, err := http.NewRequest("GET", server.URL, nil)
|
||||
require.Nil(t, err)
|
||||
req.Header.Set("ETag", "ETagValue")
|
||||
|
||||
res, err := client.Do(req)
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, time.Unix(0, 0).Format(time.RFC1123), res.Header.Get("Expires"))
|
||||
assert.Equal(t, "no-cache, no-store, no-transform, must-revalidate, private, max-age=0", res.Header.Get("Cache-Control"))
|
||||
assert.Equal(t, "no-cache", res.Header.Get("Pragma"))
|
||||
assert.Equal(t, "0", res.Header.Get("X-Accel-Expires"))
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "resp", string(b))
|
||||
}
|
6
ping.go
6
ping.go
@ -7,7 +7,7 @@ import (
|
||||
|
||||
// Ping responses with pong to /ping request and stops chain
|
||||
func Ping(next http.Handler) http.Handler {
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == "GET" && strings.HasSuffix(strings.ToLower(r.URL.Path), "/ping") {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
@ -17,7 +17,5 @@ func Ping(next http.Handler) http.Handler {
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return http.HandlerFunc(fn)
|
||||
})
|
||||
}
|
||||
|
48
ping_test.go
Normal file
48
ping_test.go
Normal file
@ -0,0 +1,48 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPing(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(Ping(handler))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "resp", string(b))
|
||||
|
||||
res, err = http.Get(server.URL + "/ping")
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
b, err = ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "pong", string(b))
|
||||
|
||||
res, err = http.Get(server.URL + "/a/b/c/ping")
|
||||
require.Nil(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
b, err = ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "pong", string(b))
|
||||
}
|
175
ratelimit/rate_limit.go
Normal file
175
ratelimit/rate_limit.go
Normal file
@ -0,0 +1,175 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
cache "github.com/go-pkgz/expirable-cache"
|
||||
|
||||
"go.pkg.cx/middleware"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrLimitReached = errors.New("limit reached")
|
||||
)
|
||||
|
||||
// DefaultOptions represents default timeout middleware options
|
||||
var DefaultOptions = Options(
|
||||
SetLimit(100),
|
||||
SetPeriod(time.Minute*1),
|
||||
SetKeyFn(defaultKeyFn),
|
||||
SetResponseHandler(RespondWithTooManyRequests),
|
||||
)
|
||||
|
||||
// Options turns a list of option instances into an option.
|
||||
func Options(opts ...Option) Option {
|
||||
return func(l *limiter) {
|
||||
for _, opt := range opts {
|
||||
opt(l)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Option configures timeout middleware
|
||||
type Option func(l *limiter)
|
||||
|
||||
// SetLimit sets request limit
|
||||
func SetLimit(limit int) Option {
|
||||
if limit < 1 {
|
||||
panic("rate limit middleware expects limit > 0")
|
||||
}
|
||||
|
||||
return func(l *limiter) {
|
||||
l.limit = limit
|
||||
}
|
||||
}
|
||||
|
||||
// SetPeriod sets limiter period
|
||||
func SetPeriod(period time.Duration) Option {
|
||||
return func(l *limiter) {
|
||||
l.period = period
|
||||
}
|
||||
}
|
||||
|
||||
// SetKeyFn sets limiter key extraction function
|
||||
func SetKeyFn(fn func(r *http.Request) string) Option {
|
||||
return func(l *limiter) {
|
||||
l.keyFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// SetResponseHandler sets response handler
|
||||
func SetResponseHandler(fn middleware.ResponseHandle) Option {
|
||||
return func(l *limiter) {
|
||||
l.responseHandler = fn
|
||||
}
|
||||
}
|
||||
|
||||
type info struct {
|
||||
limit int
|
||||
remaining int
|
||||
reset int64
|
||||
reached bool
|
||||
}
|
||||
|
||||
type entry struct {
|
||||
count int
|
||||
expiration time.Time
|
||||
}
|
||||
|
||||
type limiter struct {
|
||||
limit int
|
||||
period time.Duration
|
||||
keyFn func(r *http.Request) string
|
||||
responseHandler middleware.ResponseHandle
|
||||
|
||||
lock sync.Mutex
|
||||
cache cache.Cache
|
||||
}
|
||||
|
||||
func (s *limiter) initCache() {
|
||||
c, err := cache.NewCache(cache.TTL(s.period))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
s.cache = c
|
||||
}
|
||||
|
||||
func (s *limiter) try(key string) info {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
if e, ok := s.cache.Get(key); ok {
|
||||
e.(*entry).count++
|
||||
s.cache.Set(key, e, 0)
|
||||
|
||||
return s.infoFromEntry(e.(*entry))
|
||||
}
|
||||
|
||||
e := &entry{count: 1, expiration: now.Add(s.period)}
|
||||
s.cache.Set(key, e, 0)
|
||||
|
||||
return s.infoFromEntry(e)
|
||||
}
|
||||
|
||||
func (s *limiter) infoFromEntry(e *entry) info {
|
||||
reached := true
|
||||
remaining := 0
|
||||
|
||||
if e.count <= s.limit {
|
||||
reached = false
|
||||
remaining = s.limit - e.count
|
||||
}
|
||||
|
||||
return info{
|
||||
limit: s.limit,
|
||||
remaining: remaining,
|
||||
reset: e.expiration.Unix(),
|
||||
reached: reached,
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware is a rate limiter middleware
|
||||
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
|
||||
l := &limiter{}
|
||||
opts = append([]Option{DefaultOptions}, opts...)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(l)
|
||||
}
|
||||
|
||||
l.initCache()
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
info := l.try(l.keyFn(r))
|
||||
|
||||
w.Header().Add("X-RateLimit-Limit", strconv.Itoa(info.limit))
|
||||
w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(info.remaining))
|
||||
w.Header().Add("X-RateLimit-Reset", strconv.FormatInt(info.reset, 10))
|
||||
|
||||
if info.reached {
|
||||
l.responseHandler(w, r, ErrLimitReached)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RespondWithTooManyRequests is a default response handler
|
||||
func RespondWithTooManyRequests(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
func defaultKeyFn(r *http.Request) string {
|
||||
return strings.Split(r.RemoteAddr, ":")[0]
|
||||
}
|
163
ratelimit/rate_limit_test.go
Normal file
163
ratelimit/rate_limit_test.go
Normal file
@ -0,0 +1,163 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRateLimitSequential(t *testing.T) {
|
||||
rateLimit := Middleware(
|
||||
SetLimit(5),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(rateLimit(handler))
|
||||
defer server.Close()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
if i < 5 {
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
} else {
|
||||
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitConcurrent(t *testing.T) {
|
||||
rateLimit := Middleware(
|
||||
SetLimit(5),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(rateLimit(handler))
|
||||
defer server.Close()
|
||||
|
||||
counter := int64(0)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode == http.StatusOK {
|
||||
atomic.AddInt64(&counter, 1)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
assert.Equal(t, int64(5), atomic.LoadInt64(&counter))
|
||||
}
|
||||
|
||||
func TestRateLimitHeaders(t *testing.T) {
|
||||
rateLimit := Middleware(
|
||||
SetLimit(1),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(rateLimit(handler))
|
||||
defer server.Close()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit"))
|
||||
assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining"))
|
||||
|
||||
resetTS, err := strconv.Atoi(res.Header.Get("X-RateLimit-Reset"))
|
||||
assert.NoError(t, err)
|
||||
assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1)
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "resp", string(b))
|
||||
|
||||
res, err = http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
||||
assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit"))
|
||||
assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining"))
|
||||
|
||||
resetTS, err = strconv.Atoi(res.Header.Get("X-RateLimit-Reset"))
|
||||
assert.NoError(t, err)
|
||||
assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1)
|
||||
|
||||
b, err = ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Too Many Requests\n", string(b))
|
||||
}
|
||||
|
||||
func TestRateLimitExpiration(t *testing.T) {
|
||||
rateLimit := Middleware(
|
||||
SetLimit(5),
|
||||
SetPeriod(time.Millisecond*500),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(rateLimit(handler))
|
||||
defer server.Close()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
if i < 5 {
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
} else {
|
||||
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
if i < 5 {
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
} else {
|
||||
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
48
real_ip.go
Normal file
48
real_ip.go
Normal file
@ -0,0 +1,48 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// RealIP is a middleware that sets a http.Request's RemoteAddr to the results
|
||||
// of parsing either the X-Forwarded-For header or the X-Real-IP header (in that
|
||||
// order).
|
||||
//
|
||||
// This middleware should be inserted fairly early in the middleware stack to
|
||||
// ensure that subsequent layers (e.g., request loggers) which examine the
|
||||
// RemoteAddr will see the intended value.
|
||||
//
|
||||
// You should only use this middleware if you can trust the headers passed to
|
||||
// you (in particular, the two headers this middleware uses), for example
|
||||
// because you have placed a reverse proxy like HAProxy or nginx in front of
|
||||
// chi. If your reverse proxies are configured to pass along arbitrary header
|
||||
// values from the client, or if you use this middleware without a reverse
|
||||
// proxy, malicious clients will be able to make you very sad (or, depending on
|
||||
// how you're using RemoteAddr, vulnerable to an attack of some sort).
|
||||
func RealIP(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if rip := realIP(r); rip != "" {
|
||||
r.RemoteAddr = rip
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func realIP(r *http.Request) string {
|
||||
if xrip := r.Header.Get("X-Real-IP"); xrip != "" {
|
||||
return xrip
|
||||
}
|
||||
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
i := strings.Index(xff, ", ")
|
||||
if i == -1 {
|
||||
i = len(xff)
|
||||
}
|
||||
|
||||
return xff[:i]
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
100
real_ip_test.go
Normal file
100
real_ip_test.go
Normal file
@ -0,0 +1,100 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var realIPTestHandler = func(t *testing.T) http.HandlerFunc {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, r.RemoteAddr, "3.3.3.3")
|
||||
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRealIPXRealIP(t *testing.T) {
|
||||
server := httptest.NewServer(RealIP(realIPTestHandler(t)))
|
||||
defer server.Close()
|
||||
|
||||
client := http.Client{}
|
||||
|
||||
req, err := http.NewRequest("GET", server.URL, nil)
|
||||
require.Nil(t, err)
|
||||
req.Header.Set("X-Real-IP", "3.3.3.3")
|
||||
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "resp", string(b))
|
||||
}
|
||||
|
||||
func TestRealIPXForwardedFor(t *testing.T) {
|
||||
server := httptest.NewServer(RealIP(realIPTestHandler(t)))
|
||||
defer server.Close()
|
||||
|
||||
client := http.Client{}
|
||||
|
||||
req, err := http.NewRequest("GET", server.URL, nil)
|
||||
require.Nil(t, err)
|
||||
req.Header.Set("X-Forwarded-For", "3.3.3.3")
|
||||
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
req, err = http.NewRequest("GET", server.URL, nil)
|
||||
require.Nil(t, err)
|
||||
req.Header.Set("X-Forwarded-For", "3.3.3.3, 4.4.4.4, 5.5.5.5")
|
||||
|
||||
res, err = client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestRealIPBothHeaders(t *testing.T) {
|
||||
server := httptest.NewServer(RealIP(realIPTestHandler(t)))
|
||||
defer server.Close()
|
||||
|
||||
client := http.Client{}
|
||||
|
||||
req, err := http.NewRequest("GET", server.URL, nil)
|
||||
require.Nil(t, err)
|
||||
req.Header.Set("X-Real-IP", "3.3.3.3")
|
||||
req.Header.Set("X-Forwarded-For", "4.4.4.4")
|
||||
|
||||
res, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestRealIPNoHeaders(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, strings.Split(r.RemoteAddr, ":")[0], "127.0.0.1")
|
||||
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(RealIP(handler))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
92
recoverer/recoverer.go
Normal file
92
recoverer/recoverer.go
Normal file
@ -0,0 +1,92 @@
|
||||
package recoverer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"go.pkg.cx/middleware"
|
||||
)
|
||||
|
||||
// DefaultOptions represents default recoverer middleware options
|
||||
var DefaultOptions = Options(
|
||||
SetLogStackFn(defaultLogStackFn),
|
||||
SetLogRecoverFn(defaultLogRecoverFn),
|
||||
SetResponseHandler(RespondWithInternalServerError),
|
||||
)
|
||||
|
||||
// Options turns a list of option instances into an option.
|
||||
func Options(opts ...Option) Option {
|
||||
return func(r *recoverer) {
|
||||
for _, opt := range opts {
|
||||
opt(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogStackFn sets log stack function
|
||||
func SetLogStackFn(fn func(stack []byte)) Option {
|
||||
return func(r *recoverer) {
|
||||
r.logStackFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogRecoverFn sets log recover information function
|
||||
func SetLogRecoverFn(fn func(rec interface{})) Option {
|
||||
return func(r *recoverer) {
|
||||
r.logRecoverFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// SetResponseHandler sets response handler
|
||||
func SetResponseHandler(fn middleware.ResponseHandle) Option {
|
||||
return func(r *recoverer) {
|
||||
r.responseHandler = fn
|
||||
}
|
||||
}
|
||||
|
||||
// Option configures recoverer middleware
|
||||
type Option func(r *recoverer)
|
||||
|
||||
type recoverer struct {
|
||||
logStackFn func(stack []byte)
|
||||
logRecoverFn func(rec interface{})
|
||||
responseHandler middleware.ResponseHandle
|
||||
}
|
||||
|
||||
// Middleware is a recoverer middleware that recovers from panic
|
||||
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
|
||||
rvr := &recoverer{}
|
||||
opts = append([]Option{DefaultOptions}, opts...)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(rvr)
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
rvr.logRecoverFn(rec)
|
||||
rvr.logStackFn(debug.Stack())
|
||||
rvr.responseHandler(w, r, nil)
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RespondWithInternalServerError is a default response handler
|
||||
func RespondWithInternalServerError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
func defaultLogStackFn(stack []byte) {
|
||||
fmt.Println(string(stack))
|
||||
}
|
||||
|
||||
func defaultLogRecoverFn(rec interface{}) {
|
||||
fmt.Printf("request panic, %v", rec)
|
||||
}
|
58
recoverer/recoverer_test.go
Normal file
58
recoverer/recoverer_test.go
Normal file
@ -0,0 +1,58 @@
|
||||
package recoverer
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRecoverer(t *testing.T) {
|
||||
logStackFn := func(stack []byte) {
|
||||
s := string(stack)
|
||||
assert.Contains(t, s, "goroutine")
|
||||
assert.Contains(t, s, "TestRecoverer")
|
||||
}
|
||||
|
||||
logRecoverFn := func(rec interface{}) {
|
||||
s, ok := rec.(string)
|
||||
assert.True(t, ok)
|
||||
assert.Contains(t, s, "panic message")
|
||||
}
|
||||
|
||||
recoverer := Middleware(
|
||||
SetLogStackFn(logStackFn),
|
||||
SetLogRecoverFn(logRecoverFn),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/panic" {
|
||||
panic("panic message")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(recoverer(handler))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "resp", string(b))
|
||||
|
||||
res, err = http.Get(server.URL + "/panic")
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
|
||||
}
|
176
throttle/throttle.go
Normal file
176
throttle/throttle.go
Normal file
@ -0,0 +1,176 @@
|
||||
package throttle
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"go.pkg.cx/middleware"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrCapacityExceeded = errors.New("server capacity exceeded")
|
||||
ErrTimedOut = errors.New("timed out while waiting for a pending request to complete")
|
||||
ErrContextCanceled = errors.New("context was canceled")
|
||||
)
|
||||
|
||||
// DefaultOptions represents default throttle middleware options
|
||||
var DefaultOptions = Options(
|
||||
SetLimit(100),
|
||||
SetResponseHandler(RespondWithTooManyRequests),
|
||||
)
|
||||
|
||||
// Options turns a list of option instances into an option.
|
||||
func Options(opts ...Option) Option {
|
||||
return func(t *throttle) {
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Option configures throttle middleware
|
||||
type Option func(t *throttle)
|
||||
|
||||
// SetLimit sets requests limit
|
||||
func SetLimit(limit int) Option {
|
||||
if limit < 1 {
|
||||
panic("throttle middleware expects limit > 0")
|
||||
}
|
||||
|
||||
return func(t *throttle) {
|
||||
t.limit = limit
|
||||
}
|
||||
}
|
||||
|
||||
// SetBacklogLimit sets backlog requests limit
|
||||
func SetBacklogLimit(backlogLimit int) Option {
|
||||
if backlogLimit < 0 {
|
||||
panic("throttle middleware expects backlogLimit >= 0")
|
||||
}
|
||||
|
||||
return func(t *throttle) {
|
||||
t.backlogLimit = backlogLimit
|
||||
}
|
||||
}
|
||||
|
||||
// SetBacklogTimeout sets backlog timeout
|
||||
func SetBacklogTimeout(backlogTimeout time.Duration) Option {
|
||||
return func(t *throttle) {
|
||||
t.backlogTimeout = backlogTimeout
|
||||
}
|
||||
}
|
||||
|
||||
// SetRetryAfterFn sets retry after function
|
||||
func SetRetryAfterFn(fn func(ctxDone bool) time.Duration) Option {
|
||||
return func(t *throttle) {
|
||||
t.retryAfterFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// SetResponseHandler sets response handler
|
||||
func SetResponseHandler(fn middleware.ResponseHandle) Option {
|
||||
return func(t *throttle) {
|
||||
t.responseHandler = fn
|
||||
}
|
||||
}
|
||||
|
||||
type throttle struct {
|
||||
limit int
|
||||
backlogLimit int
|
||||
backlogTimeout time.Duration
|
||||
retryAfterFn func(ctxDone bool) time.Duration
|
||||
responseHandler middleware.ResponseHandle
|
||||
|
||||
tokens chan struct{}
|
||||
backlogTokens chan struct{}
|
||||
}
|
||||
|
||||
func (s *throttle) initTokens() {
|
||||
s.tokens = make(chan struct{}, s.limit)
|
||||
s.backlogTokens = make(chan struct{}, s.limit+s.backlogLimit)
|
||||
|
||||
for i := 0; i < s.limit+s.backlogLimit; i++ {
|
||||
if i < s.limit {
|
||||
s.tokens <- struct{}{}
|
||||
}
|
||||
|
||||
s.backlogTokens <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *throttle) setRetryAfterHeader(w http.ResponseWriter, ctxDone bool) {
|
||||
if s.retryAfterFn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
retryAfterSeconds := strconv.Itoa(int(s.retryAfterFn(ctxDone).Seconds()))
|
||||
w.Header().Set("Retry-After", retryAfterSeconds)
|
||||
}
|
||||
|
||||
// Middleware is a throttle middleware that limits number of currently processed requests
|
||||
// at a time across all users. Note: Throttle is not a rate-limiter per user,
|
||||
// instead it just puts a ceiling on the number of currentl in-flight requests
|
||||
// being processed from the point from where the Throttle middleware is mounted
|
||||
func Middleware(opts ...Option) func(http.Handler) http.Handler {
|
||||
t := &throttle{}
|
||||
opts = append([]Option{DefaultOptions}, opts...)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
|
||||
t.initTokens()
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
t.setRetryAfterHeader(w, true)
|
||||
t.responseHandler(w, r, ErrContextCanceled)
|
||||
return
|
||||
|
||||
case btok := <-t.backlogTokens:
|
||||
timer := time.NewTimer(t.backlogTimeout)
|
||||
|
||||
defer func() {
|
||||
t.backlogTokens <- btok
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.setRetryAfterHeader(w, false)
|
||||
t.responseHandler(w, r, ErrTimedOut)
|
||||
return
|
||||
|
||||
case <-r.Context().Done():
|
||||
timer.Stop()
|
||||
t.setRetryAfterHeader(w, true)
|
||||
t.responseHandler(w, r, ErrContextCanceled)
|
||||
return
|
||||
|
||||
case tok := <-t.tokens:
|
||||
defer func() {
|
||||
timer.Stop()
|
||||
t.tokens <- tok
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
default:
|
||||
t.setRetryAfterHeader(w, false)
|
||||
t.responseHandler(w, r, ErrCapacityExceeded)
|
||||
return
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RespondWithTooManyRequests is a default response handler
|
||||
func RespondWithTooManyRequests(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
|
||||
}
|
245
throttle/throttle_test.go
Normal file
245
throttle/throttle_test.go
Normal file
@ -0,0 +1,245 @@
|
||||
package throttle
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestThrottleBacklog(t *testing.T) {
|
||||
throttle := Middleware(
|
||||
SetLimit(10),
|
||||
SetBacklogLimit(50),
|
||||
SetBacklogTimeout(time.Second*10),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
time.Sleep(time.Second * 1)
|
||||
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(throttle(handler))
|
||||
defer server.Close()
|
||||
|
||||
client := http.Client{Timeout: time.Second * 5}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 1; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "resp", string(b))
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestThrottleClientTimeout(t *testing.T) {
|
||||
throttle := Middleware(
|
||||
SetLimit(10),
|
||||
SetBacklogLimit(50),
|
||||
SetBacklogTimeout(time.Second*10),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
time.Sleep(time.Second * 5)
|
||||
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(throttle(handler))
|
||||
defer server.Close()
|
||||
|
||||
client := http.Client{Timeout: time.Second * 3}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
_, err := client.Get(server.URL)
|
||||
assert.Error(t, err)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestThrottleTriggerGatewayTimeout(t *testing.T) {
|
||||
throttle := Middleware(
|
||||
SetLimit(50),
|
||||
SetBacklogLimit(100),
|
||||
SetBacklogTimeout(time.Second*5),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
time.Sleep(time.Second * 10)
|
||||
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(throttle(handler))
|
||||
defer server.Close()
|
||||
|
||||
client := http.Client{Timeout: time.Second * 60}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
}(i)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 1)
|
||||
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestThrottleMaximum(t *testing.T) {
|
||||
throttle := Middleware(
|
||||
SetLimit(10),
|
||||
SetBacklogLimit(10),
|
||||
SetBacklogTimeout(time.Second*5),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
time.Sleep(time.Second * 2)
|
||||
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(throttle(handler))
|
||||
defer server.Close()
|
||||
|
||||
client := http.Client{Timeout: time.Second * 60}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "resp", string(b))
|
||||
}(i)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 1)
|
||||
|
||||
for i := 0; i < 20; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestThrottleRetryAfter(t *testing.T) {
|
||||
throttle := Middleware(
|
||||
SetLimit(10),
|
||||
SetRetryAfterFn(func(_ bool) time.Duration { return time.Hour * 1 }),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
time.Sleep(time.Second * 3)
|
||||
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(throttle(handler))
|
||||
defer server.Close()
|
||||
|
||||
client := http.Client{Timeout: time.Second * 60}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}(i)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second * 1)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
res, err := client.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
||||
assert.Equal(t, res.Header.Get("Retry-After"), "3600")
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
81
timeout/timeout.go
Normal file
81
timeout/timeout.go
Normal file
@ -0,0 +1,81 @@
|
||||
package timeout
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"go.pkg.cx/middleware"
|
||||
)
|
||||
|
||||
// DefaultOptions represents default timeout middleware options
|
||||
var DefaultOptions = Options(
|
||||
SetTimeout(time.Second*30),
|
||||
SetResponseHandler(RespondWithTimeout),
|
||||
)
|
||||
|
||||
// Options turns a list of option instances into an option.
|
||||
func Options(opts ...Option) Option {
|
||||
return func(t *timeout) {
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Option configures timeout middleware
|
||||
type Option func(t *timeout)
|
||||
|
||||
// SetTimeout sets request timeout
|
||||
func SetTimeout(limit time.Duration) Option {
|
||||
return func(t *timeout) {
|
||||
t.timeout = limit
|
||||
}
|
||||
}
|
||||
|
||||
// SetResponseHandler sets response handler
|
||||
func SetResponseHandler(fn middleware.ResponseHandle) Option {
|
||||
return func(t *timeout) {
|
||||
t.responseHandler = fn
|
||||
}
|
||||
}
|
||||
|
||||
type timeout struct {
|
||||
timeout time.Duration
|
||||
responseHandler middleware.ResponseHandle
|
||||
}
|
||||
|
||||
// Middleware is a timeout middleware that cancels ctx after a given timeout
|
||||
// and return a 504 Gateway Timeout error to the client.
|
||||
//
|
||||
// It's required that you select the ctx.Done() channel to check for the signal
|
||||
// if the context has reached its deadline and return, otherwise the timeout
|
||||
// signal will be just ignored
|
||||
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
|
||||
t := &timeout{}
|
||||
opts = append([]Option{DefaultOptions}, opts...)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), t.timeout)
|
||||
defer func() {
|
||||
cancel()
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
t.responseHandler(w, r, ctx.Err())
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RespondWithTimeout is a default response handler
|
||||
func RespondWithTimeout(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, http.StatusText(http.StatusGatewayTimeout), http.StatusGatewayTimeout)
|
||||
}
|
62
timeout/timeout_test.go
Normal file
62
timeout/timeout_test.go
Normal file
@ -0,0 +1,62 @@
|
||||
package timeout
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTimeoutPass(t *testing.T) {
|
||||
timeout := Middleware(
|
||||
SetTimeout(time.Second * 1),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(timeout(handler))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
|
||||
b, err := ioutil.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "resp", string(b))
|
||||
}
|
||||
|
||||
func TestTimeoutTimedOut(t *testing.T) {
|
||||
timeout := Middleware(
|
||||
SetTimeout(time.Millisecond * 300),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
|
||||
case <-time.After(time.Second * 1):
|
||||
w.Write([]byte("resp")) // nolint:errcheck
|
||||
}
|
||||
})
|
||||
|
||||
server := httptest.NewServer(timeout(handler))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusGatewayTimeout, res.StatusCode)
|
||||
}
|
19
types.go
Normal file
19
types.go
Normal file
@ -0,0 +1,19 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// ResponseHandle is a function that middleware call in case of stop chain
|
||||
type ResponseHandle func(w http.ResponseWriter, r *http.Request, err error)
|
||||
|
||||
// CtxKey is a key to use with context.WithValue
|
||||
type CtxKey struct {
|
||||
Pkg string
|
||||
Name string
|
||||
}
|
||||
|
||||
// String returns string representation
|
||||
func (s *CtxKey) String() string {
|
||||
return s.Pkg + " context value " + s.Name
|
||||
}
|
Loading…
Reference in New Issue
Block a user