Merge some middlewares

This commit is contained in:
Anton Zadvorny 2020-11-07 14:59:33 +03:00
parent 7c06c3c6d8
commit f46dc4094f
30 changed files with 2613 additions and 10 deletions

View File

@ -6,14 +6,12 @@ import (
// AppInfo adds application name and version headers to the response // AppInfo adds application name and version headers to the response
func AppInfo(name string, version string) func(http.Handler) http.Handler { func AppInfo(name string, version string) func(http.Handler) http.Handler {
return func(h http.Handler) http.Handler { return func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("App-Name", name) w.Header().Set("App-Name", name)
w.Header().Set("App-Version", version) w.Header().Set("App-Version", version)
h.ServeHTTP(w, r) next.ServeHTTP(w, r)
} })
return http.HandlerFunc(fn)
} }
} }

32
app_info_test.go Normal file
View File

@ -0,0 +1,32 @@
package middleware
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestAppInfo(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(AppInfo("tapp", "tversion")(handler))
defer server.Close()
res, err := http.Get(server.URL)
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "tapp", res.Header.Get("App-Name"))
assert.Equal(t, "tversion", res.Header.Get("App-Version"))
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}

171
auth/jwt/jwt.go Normal file
View File

@ -0,0 +1,171 @@
package jwt
import (
"context"
"errors"
"net/http"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
"go.pkg.cx/middleware"
)
// Errors
var (
ErrTokenInvalid = errors.New("token invalid")
ErrTokenNotFound = errors.New("token not found")
)
// Context keys
var (
JWTCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/jwt", Name: "JWT"}
DataCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/token", Name: "Data"}
)
// DefaultOptions represents default jwt auth middleware options
var DefaultOptions = Options(
WithFindTokenFn(middleware.TokenFromAuthorizationHeader),
WithFindTokenFn(middleware.TokenFromQuery("jwt")),
WithFindTokenFn(middleware.TokenFromCookie("jwt")),
SetResponseHandler(RespondWithUnauthorized),
SetValidateTokenFn(allowAll),
)
// Options turns a list of option instances into an option.
func Options(opts ...Option) Option {
return func(a *auth) {
for _, opt := range opts {
opt(a)
}
}
}
// Option configures jwt auth middleware
type Option func(a *auth)
// SetKey sets key to verify jwt token with
func SetKey(key interface{}) Option {
return func(a *auth) {
a.key = key
}
}
// SetAlgorithm sets algorithm to verify jwt token with
func SetAlgorithm(alg jwa.SignatureAlgorithm) Option {
return func(a *auth) {
a.algorithm = alg
}
}
// WithFindTokenFn adds jwt token find function to the list
func WithFindTokenFn(fn func(r *http.Request) string) Option {
return func(a *auth) {
a.findTokenFns = append(a.findTokenFns, fn)
}
}
// SetFindTokenFns sets jwt token find functions list
func SetFindTokenFns(fns ...func(r *http.Request) string) Option {
return func(a *auth) {
a.findTokenFns = fns
}
}
// WithVerifyOption adds jwt verify option to the list
func WithVerifyOption(opt jwt.Option) Option {
return func(a *auth) {
a.verifyOptions = append(a.verifyOptions, opt)
}
}
// SetVerifyOptions sets jwt verify options list
func SetVerifyOptions(opts ...jwt.Option) Option {
return func(a *auth) {
a.verifyOptions = opts
}
}
// SetValidateTokenFn sets token validation function
func SetValidateTokenFn(fn func(token jwt.Token) (bool, interface{})) Option {
return func(a *auth) {
a.validateTokenFn = fn
}
}
// SetResponseHandler sets response handler
func SetResponseHandler(fn middleware.ResponseHandle) Option {
return func(a *auth) {
a.responseHandler = fn
}
}
type auth struct {
key interface{}
algorithm jwa.SignatureAlgorithm
verifyOptions []jwt.Option
findTokenFns []func(r *http.Request) string
validateTokenFn func(token jwt.Token) (bool, interface{})
responseHandler middleware.ResponseHandle
}
// Middleware returns jwt auth middleware
func Middleware(key interface{}, alg jwa.SignatureAlgorithm, opts ...Option) func(next http.Handler) http.Handler {
a := &auth{}
opts = append(opts, SetKey(key))
opts = append(opts, SetAlgorithm(alg))
opts = append([]Option{DefaultOptions}, opts...)
for _, opt := range opts {
opt(a)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var token string
for _, findTokenFn := range a.findTokenFns {
token = findTokenFn(r)
if token != "" {
break
}
}
if token == "" {
a.responseHandler(w, r, ErrTokenNotFound)
return
}
jwtToken, err := jwt.ParseString(token, jwt.WithVerify(a.algorithm, a.key))
if err != nil {
a.responseHandler(w, r, err)
return
}
if err := jwt.Verify(jwtToken, a.verifyOptions...); err != nil {
a.responseHandler(w, r, err)
return
}
valid, data := a.validateTokenFn(jwtToken)
if !valid {
a.responseHandler(w, r, ErrTokenInvalid)
return
}
ctx := r.Context()
ctx = context.WithValue(ctx, JWTCtxKey, jwtToken)
ctx = context.WithValue(ctx, DataCtxKey, data)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// RespondWithUnauthorized is a default response handler
func RespondWithUnauthorized(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
func allowAll(token jwt.Token) (bool, interface{}) {
return true, nil
}

192
auth/jwt/jwt_test.go Normal file
View File

@ -0,0 +1,192 @@
package jwt
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.pkg.cx/middleware"
)
var testHandler = func(t *testing.T) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
}
func TestJWTAuthDefaults(t *testing.T) {
auth := Middleware(
[]byte("changethis"),
jwa.HS256,
)
server := httptest.NewServer(auth(testHandler(t)))
defer server.Close()
res, err := http.Get(server.URL)
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "Unauthorized\n", string(b))
token := jwt.New()
payload, err := jwt.Sign(token, jwa.HS256, []byte("changethis"))
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?jwt=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err = ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}
func TestJWTAuthParseVerify(t *testing.T) {
auth := Middleware(
[]byte("tkey"),
jwa.HS512,
SetVerifyOptions(
jwt.WithSubject("tsubj"),
jwt.WithAudience("taud"),
),
SetFindTokenFns(
middleware.TokenFromQuery("token"),
),
)
server := httptest.NewServer(auth(testHandler(t)))
defer server.Close()
token := jwt.New()
payload, err := jwt.Sign(token, jwa.HS512, []byte("tkey"))
assert.NoError(t, err)
res, err := http.Get(server.URL + "/?token=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
payload, err = jwt.Sign(token, jwa.HS256, []byte("tkey"))
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?token=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
payload, err = jwt.Sign(token, jwa.HS512, []byte("wrongkey"))
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?token=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
now := time.Now()
token.Set(jwt.IssuedAtKey, now.Add(-1*time.Hour)) // nolint:errcheck
token.Set(jwt.ExpirationKey, now.Add(-58*time.Minute)) // nolint:errcheck
payload, err = jwt.Sign(token, jwa.HS512, []byte("tkey"))
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?token=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
}
func TestJWTAuthVerifyOptions(t *testing.T) {
auth := Middleware(
[]byte("changethis"),
jwa.HS256,
WithVerifyOption(jwt.WithIssuer("tissuer")),
)
server := httptest.NewServer(auth(testHandler(t)))
defer server.Close()
token := jwt.New()
token.Set(jwt.IssuerKey, "wrongissuer") // nolint:errcheck
payload, err := jwt.Sign(token, jwa.HS256, []byte("changethis"))
assert.NoError(t, err)
res, err := http.Get(server.URL + "/?jwt=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
token.Set(jwt.IssuerKey, "tissuer") // nolint:errcheck
payload, err = jwt.Sign(token, jwa.HS256, []byte("changethis"))
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?jwt=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func TestJWTAuthContext(t *testing.T) {
type data struct {
inner string
}
validateWithDataFn := func(token jwt.Token) (bool, interface{}) {
return token.JwtID() == "tid", &data{"test data"}
}
auth := Middleware(
[]byte("changethis"),
jwa.HS256,
SetValidateTokenFn(validateWithDataFn),
)
token := jwt.New()
token.Set(jwt.JwtIDKey, "tid") // nolint:errcheck
testCtxHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, ok := r.Context().Value(DataCtxKey).(*data)
require.True(t, ok)
assert.Equal(t, "test data", data.inner)
token, ok := r.Context().Value(JWTCtxKey).(jwt.Token)
require.True(t, ok)
assert.Equal(t, "tid", token.JwtID())
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(auth(testCtxHandler))
defer server.Close()
payload, err := jwt.Sign(token, jwa.HS256, []byte("changethis"))
assert.NoError(t, err)
res, err := http.Get(server.URL + "/?jwt=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
token.Set(jwt.JwtIDKey, "invalid") // nolint:errcheck
payload, err = jwt.Sign(token, jwa.HS256, []byte("changethis"))
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?jwt=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
}

123
auth/token/token.go Normal file
View File

@ -0,0 +1,123 @@
package token
import (
"context"
"errors"
"net/http"
"go.pkg.cx/middleware"
)
// Errors
var (
ErrTokenInvalid = errors.New("token invalid")
ErrTokenNotFound = errors.New("token not found")
)
// Context keys
var (
DataCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/token", Name: "Data"}
TokenCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/token", Name: "Token"}
)
// DefaultOptions represents default token auth middleware options
var DefaultOptions = Options(
WithFindTokenFn(middleware.TokenFromHeader("X-Token")),
WithFindTokenFn(middleware.TokenFromQuery("token")),
SetValidateTokenFn(rejectAll),
SetResponseHandler(RespondWithUnauthorized),
)
// Options turns a list of option instances into an option.
func Options(opts ...Option) Option {
return func(a *auth) {
for _, opt := range opts {
opt(a)
}
}
}
// Option configures token auth middleware
type Option func(a *auth)
// WithFindTokenFn adds token find function to the list
func WithFindTokenFn(fn func(r *http.Request) string) Option {
return func(a *auth) {
a.findTokenFns = append(a.findTokenFns, fn)
}
}
// SetFindTokenFns sets token find functions list
func SetFindTokenFns(fns ...func(r *http.Request) string) Option {
return func(a *auth) {
a.findTokenFns = fns
}
}
// SetValidateTokenFn sets token validation function
func SetValidateTokenFn(fn func(token string) (bool, interface{})) Option {
return func(a *auth) {
a.validateTokenFn = fn
}
}
// SetResponseHandler sets response handler
func SetResponseHandler(fn middleware.ResponseHandle) Option {
return func(a *auth) {
a.responseHandler = fn
}
}
type auth struct {
findTokenFns []func(r *http.Request) string
validateTokenFn func(token string) (bool, interface{})
responseHandler middleware.ResponseHandle
}
// Middleware returns token auth middleware
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
a := &auth{}
opts = append([]Option{DefaultOptions}, opts...)
for _, opt := range opts {
opt(a)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var token string
for _, findTokenFn := range a.findTokenFns {
token = findTokenFn(r)
if token != "" {
break
}
}
if token == "" {
a.responseHandler(w, r, ErrTokenNotFound)
return
}
valid, data := a.validateTokenFn(token)
if !valid {
a.responseHandler(w, r, ErrTokenInvalid)
return
}
ctx := r.Context()
ctx = context.WithValue(ctx, DataCtxKey, data)
ctx = context.WithValue(ctx, TokenCtxKey, token)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// RespondWithUnauthorized is a default response handler
func RespondWithUnauthorized(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
func rejectAll(token string) (bool, interface{}) {
return false, nil
}

123
auth/token/token_test.go Normal file
View File

@ -0,0 +1,123 @@
package token
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.pkg.cx/middleware"
)
var testHandler = func(t *testing.T) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
}
var validateFn = func(token string) (bool, interface{}) {
return token == "valid", nil
}
func TestTokenAuthDefaults(t *testing.T) {
auth := Middleware()
server := httptest.NewServer(auth(testHandler(t)))
defer server.Close()
res, err := http.Get(server.URL)
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "Unauthorized\n", string(b))
res, err = http.Get(server.URL + "?token=invalid")
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
}
func TestTokenAuthValidateFn(t *testing.T) {
auth := Middleware(
SetValidateTokenFn(validateFn),
)
server := httptest.NewServer(auth(testHandler(t)))
defer server.Close()
res, err := http.Get(server.URL + "?token=invalid")
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
res, err = http.Get(server.URL + "?token=valid")
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}
func TestTokenAuthFindTokenFns(t *testing.T) {
auth := Middleware(
SetValidateTokenFn(validateFn),
SetFindTokenFns(middleware.TokenFromQuery("api_key")),
)
server := httptest.NewServer(auth(testHandler(t)))
defer server.Close()
res, err := http.Get(server.URL + "?token=valid")
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
res, err = http.Get(server.URL + "?api_key=valid")
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func TestTokenAuthContext(t *testing.T) {
type data struct {
inner string
}
validateWithDataFn := func(token string) (bool, interface{}) {
return token == "valid", &data{"test data"}
}
auth := Middleware(
SetValidateTokenFn(validateWithDataFn),
)
testCtxHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, ok := r.Context().Value(DataCtxKey).(*data)
require.True(t, ok)
assert.Equal(t, "test data", data.inner)
token, ok := r.Context().Value(TokenCtxKey).(string)
require.True(t, ok)
assert.Equal(t, "valid", token)
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(auth(testCtxHandler))
defer server.Close()
res, err := http.Get(server.URL + "?token=valid")
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}

44
find_token.go Normal file
View File

@ -0,0 +1,44 @@
package middleware
import (
"net/http"
"strings"
)
// TokenFromAuthorizationHeader tries to retrieve the token string from the
// "Authorization" request header and strips BEARER prefix if necessary
func TokenFromAuthorizationHeader(r *http.Request) string {
header := r.Header.Get("Authorization")
if len(header) > 7 && strings.ToUpper(header[0:6]) == "BEARER" {
return header[7:]
}
return header
}
// TokenFromHeader tries to retrieve the token string from the given header
func TokenFromHeader(headerKey string) func(r *http.Request) string {
return func(r *http.Request) string {
return r.Header.Get(headerKey)
}
}
// TokenFromQuery tries to retrieve the token string from the given query param
func TokenFromQuery(param string) func(r *http.Request) string {
return func(r *http.Request) string {
return r.URL.Query().Get(param)
}
}
// TokenFromCookie tries to retrieve the token string from a given cookie
func TokenFromCookie(cookieName string) func(r *http.Request) string {
return func(r *http.Request) string {
cookie, err := r.Cookie(cookieName)
if err != nil {
return ""
}
return cookie.Value
}
}

50
find_token_test.go Normal file
View File

@ -0,0 +1,50 @@
package middleware
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTokenFromAuthorizationHeader(t *testing.T) {
req, err := http.NewRequest("GET", "/", nil)
require.Nil(t, err)
req.Header.Set("Authorization", "abc")
assert.Equal(t, "abc", TokenFromAuthorizationHeader(req))
req, err = http.NewRequest("GET", "/", nil)
require.Nil(t, err)
req.Header.Set("Authorization", "abcdefghe")
assert.Equal(t, "abcdefghe", TokenFromAuthorizationHeader(req))
req, err = http.NewRequest("GET", "/", nil)
require.Nil(t, err)
req.Header.Set("Authorization", "Bearer abc")
assert.Equal(t, "abc", TokenFromAuthorizationHeader(req))
}
func TestTokenFromHeader(t *testing.T) {
req, err := http.NewRequest("GET", "/", nil)
require.Nil(t, err)
req.Header.Set("X-Token", "abc")
assert.Equal(t, "abc", TokenFromHeader("X-Token")(req))
}
func TestTokenFromQuery(t *testing.T) {
req, err := http.NewRequest("GET", "/?token=abc", nil)
require.Nil(t, err)
assert.Equal(t, "abc", TokenFromQuery("token")(req))
}
func TestTokenFromCookie(t *testing.T) {
req, err := http.NewRequest("GET", "/", nil)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "token", Value: "abc"})
assert.Equal(t, "abc", TokenFromCookie("token")(req))
req, err = http.NewRequest("GET", "/", nil)
require.Nil(t, err)
assert.Equal(t, "", TokenFromCookie("token")(req))
}

6
go.mod
View File

@ -1,3 +1,9 @@
module go.pkg.cx/middleware module go.pkg.cx/middleware
go 1.15 go 1.15
require (
github.com/go-pkgz/expirable-cache v0.0.3
github.com/lestrrat-go/jwx v1.0.5
github.com/stretchr/testify v1.6.1
)

40
go.sum Normal file
View File

@ -0,0 +1,40 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-pkgz/expirable-cache v0.0.3 h1:rTh6qNPp78z0bQE6HDhXBHUwqnV9i09Vm6dksJLXQDc=
github.com/go-pkgz/expirable-cache v0.0.3/go.mod h1:+IauqN00R2FqNRLCLA+X5YljQJrwB179PfiAoMPlTlQ=
github.com/lestrrat-go/iter v0.0.0-20200422075355-fc1769541911 h1:FvnrqecqX4zT0wOIbYK1gNgTm0677INEWiFY8UEYggY=
github.com/lestrrat-go/iter v0.0.0-20200422075355-fc1769541911/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbqPDrJ/OJc=
github.com/lestrrat-go/jwx v1.0.5 h1:8bVUGXXkR3+YQNwuFof3lLxSJMLtrscHJfGI6ZIBRD0=
github.com/lestrrat-go/jwx v1.0.5/go.mod h1:TPF17WiSFegZo+c20fdpw49QD+/7n4/IsGvEmCSWwT0=
github.com/lestrrat-go/pdebug v0.0.0-20200204225717-4d6bd78da58d/go.mod h1:B06CSso/AWxiPejj+fheUINGeBKeeEZNt8w+EoU7+L8=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200417140056-c07e33ef3290/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

31
logger/body_util.go Normal file
View File

@ -0,0 +1,31 @@
package logger
import (
"bytes"
"io"
"regexp"
)
var regexpMultiWhitespace = regexp.MustCompile(`[\s\p{Zs}]{2,}`)
func peek(r io.Reader, n int64) (io.Reader, string, bool, error) {
if n < 0 {
n = 0
}
var buf bytes.Buffer
_, err := io.CopyN(&buf, r, n+1)
if err == io.EOF {
return &buf, buf.String(), false, nil
}
if err != nil {
return nil, "", false, err
}
s := buf.String()
s = s[:len(s)-1]
return io.MultiReader(&buf, r), s, true, nil
}

54
logger/body_util_test.go Normal file
View File

@ -0,0 +1,54 @@
package logger
import (
"errors"
"io/ioutil"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
type errReader struct{}
func (errReader) Read(_ []byte) (n int, err error) {
return 0, errors.New("test error")
}
func TestLoggerPeek(t *testing.T) {
cases := []struct {
body string
n int64
excerpt string
hasMore bool
}{
{"", -1, "", false},
{"", 0, "", false},
{"", 1024, "", false},
{"123456", -1, "", true},
{"123456", 0, "", true},
{"123456", 4, "1234", true},
{"123456", 5, "12345", true},
{"123456", 6, "123456", false},
{"123456", 7, "123456", false},
}
for _, c := range cases {
r, excerpt, hasMore, err := peek(strings.NewReader(c.body), c.n)
if !assert.NoError(t, err) {
continue
}
body, err := ioutil.ReadAll(r)
if !assert.NoError(t, err) {
continue
}
assert.Equal(t, c.body, string(body))
assert.Equal(t, c.excerpt, excerpt)
assert.Equal(t, c.hasMore, hasMore)
}
_, _, _, err := peek(errReader{}, 1024)
assert.Error(t, err)
}

203
logger/logger.go Normal file
View File

@ -0,0 +1,203 @@
package logger
import (
"io/ioutil"
"log"
"net/http"
"net/url"
"strings"
"time"
)
// DefaultOptions represents default logger middleware options
var DefaultOptions = Options(
WithBody(),
SetMaxBodySize(1*1024*1024),
SetIPFn(defaultIPFn),
SetLogHandler(DefaultLogHandler),
)
// Options turns a list of option instances into an option.
func Options(opts ...Option) Option {
return func(l *logger) {
for _, opt := range opts {
opt(l)
}
}
}
// Option configures logger middleware
type Option func(l *logger)
// WithBody enables request body logging
func WithBody() Option {
return func(l *logger) {
l.logBody = true
}
}
// WithoutBody disables request body logging
func WithoutBody() Option {
return func(l *logger) {
l.logBody = false
}
}
// SetMaxBodySize sets maximum request body size
func SetMaxBodySize(size int) Option {
return func(l *logger) {
l.maxBodySize = size
}
}
// SetIPFn sets function that extracts ip from request
func SetIPFn(fn func(r *http.Request) string) Option {
return func(l *logger) {
l.ipFn = fn
}
}
// SetUserFn sets function that extracts user from request
func SetUserFn(fn func(r *http.Request) string) Option {
return func(l *logger) {
l.userFn = fn
}
}
// SetSanitizeFn sets function that sanitizes request query or body
func SetSanitizeFn(fn func(input string) string) Option {
return func(l *logger) {
l.sanitizeFn = fn
}
}
// SetLogHandler sets log handler
func SetLogHandler(fn func(entry LogEntry)) Option {
return func(l *logger) {
l.logHandler = fn
}
}
// LogEntry is a http log entry
type LogEntry struct {
Method string
RawURL string
Body string
RemoteIP string
StatusCode int
Written int
Duration time.Duration
User string
TraceID string
}
type logger struct {
logBody bool
maxBodySize int
ipFn func(r *http.Request) string
userFn func(r *http.Request) string
sanitizeFn func(input string) string
logHandler func(entry LogEntry)
}
func (s *logger) body(r *http.Request) string {
if !s.logBody {
return ""
}
rdr, body, hasMore, err := peek(r.Body, int64(s.maxBodySize))
if err != nil {
return ""
}
r.Body = ioutil.NopCloser(rdr)
if len(body) > 0 {
body = strings.Replace(body, "\n", " ", -1)
body = regexpMultiWhitespace.ReplaceAllString(body, " ")
}
if hasMore {
body += "..."
}
return body
}
// Middleware is a http logger middleware
func Middleware(opts ...Option) func(http.Handler) http.Handler {
l := &logger{}
opts = append([]Option{DefaultOptions}, opts...)
for _, opt := range opts {
opt(l)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ww := newTrackingResponseWriter(w)
body := l.body(r)
if l.sanitizeFn != nil {
body = l.sanitizeFn(body)
}
startTS := time.Now()
defer func() {
stopTS := time.Now()
query := r.URL.String()
if qun, err := url.QueryUnescape(query); err == nil {
query = qun
}
if l.sanitizeFn != nil {
query = l.sanitizeFn(query)
}
var ip string
if l.ipFn != nil {
ip = l.ipFn(r)
}
var user string
if l.userFn != nil {
user = l.userFn(r)
}
entry := LogEntry{
Method: r.Method,
RawURL: query,
RemoteIP: ip,
Body: body,
StatusCode: ww.status,
Written: ww.size,
Duration: stopTS.Sub(startTS),
User: user,
TraceID: r.Header.Get("X-Request-ID"),
}
l.logHandler(entry)
}()
next.ServeHTTP(ww, r)
})
}
}
// DefaultLogHandler is a default log handler
func DefaultLogHandler(entry LogEntry) {
log.Printf(
"%s - %s - %s - %d (%d) - %v",
entry.Method,
entry.RawURL,
entry.RemoteIP,
entry.StatusCode,
entry.Written,
entry.Duration,
)
}
func defaultIPFn(r *http.Request) string {
return strings.Split(r.RemoteAddr, ":")[0]
}

134
logger/logger_test.go Normal file
View File

@ -0,0 +1,134 @@
package logger
import (
"bytes"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLogger(t *testing.T) {
logHandler := func(entry LogEntry) {
assert.Equal(t, entry.Method, "GET")
assert.Equal(t, entry.RawURL, "/")
assert.Equal(t, entry.Body, "")
assert.Equal(t, entry.RemoteIP, "127.0.0.1")
assert.Equal(t, entry.StatusCode, http.StatusOK)
assert.Equal(t, entry.Written, len([]byte("resp")))
}
logger := Middleware(
WithoutBody(),
SetLogHandler(logHandler),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(logger(handler))
defer server.Close()
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}
func TestLoggerBody(t *testing.T) {
logHandler := func(entry LogEntry) {
assert.Equal(t, entry.Method, "POST")
assert.Equal(t, entry.Body, "1234567890...")
}
logger := Middleware(
SetMaxBodySize(10),
SetLogHandler(logHandler),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
assert.Equal(t, "12345678901234567890", string(body))
_, err = w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(logger(handler))
defer server.Close()
res, err := http.Post(server.URL, "", bytes.NewBufferString("12345678901234567890"))
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func TestLoggerUser(t *testing.T) {
logHandler := func(entry LogEntry) {
assert.Equal(t, entry.User, "user")
}
logger := Middleware(
SetUserFn(func(req *http.Request) string { return "user" }),
SetLogHandler(logHandler),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(logger(handler))
defer server.Close()
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func TestLoggerSanitize(t *testing.T) {
logHandler := func(entry LogEntry) {
assert.Equal(t, entry.RawURL, "/?param=*****")
assert.Equal(t, entry.Body, "body|*****")
}
sanitizeFn := func(input string) string {
return strings.ReplaceAll(input, "password", "*****")
}
logger := Middleware(
SetSanitizeFn(sanitizeFn),
SetLogHandler(logHandler),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.URL.Query().Get("param"), "password")
body, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
assert.Equal(t, "body|password", string(body))
_, err = w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(logger(handler))
defer server.Close()
res, err := http.Post(server.URL+"?param=password", "", bytes.NewBufferString("body|password"))
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}

View File

@ -0,0 +1,53 @@
package logger
import (
"bufio"
"errors"
"net"
"net/http"
)
var (
errWriterNotImplentsHijacker = errors.New("ResponseWriter does not implement the Hijacker interface") // nolint:golint
)
type trackingResponseWriter struct {
http.ResponseWriter
size int
status int
}
func newTrackingResponseWriter(w http.ResponseWriter) *trackingResponseWriter {
return &trackingResponseWriter{ResponseWriter: w, status: 200}
}
// WriteHeader implements http.ResponseWriter and saves status
func (s *trackingResponseWriter) WriteHeader(status int) {
s.status = status
s.ResponseWriter.WriteHeader(status)
}
// Write implements http.ResponseWriter and tracks number of bytes written
func (s *trackingResponseWriter) Write(b []byte) (int, error) {
size, err := s.ResponseWriter.Write(b)
s.size += size
return size, err
}
// Flush implements http.Flusher
func (s *trackingResponseWriter) Flush() {
if f, ok := s.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
// Hijack implements http.Hijacker
func (s *trackingResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hj, ok := s.ResponseWriter.(http.Hijacker); ok {
return hj.Hijack()
}
return nil, nil, errWriterNotImplentsHijacker
}

41
no_cache.go Normal file
View File

@ -0,0 +1,41 @@
package middleware
import (
"net/http"
"time"
)
var epoch = time.Unix(0, 0).Format(time.RFC1123)
var etagHeaders = []string{
"ETag",
"If-Modified-Since",
"If-Match",
"If-None-Match",
"If-Range",
"If-Unmodified-Since",
}
var noCacheHeaders = map[string]string{
"Expires": epoch,
"Cache-Control": "no-cache, no-store, no-transform, must-revalidate, private, max-age=0",
"Pragma": "no-cache",
"X-Accel-Expires": "0",
}
// NoCache sets a number of HTTP headers to prevent a router from being cached
func NoCache(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, v := range etagHeaders {
if r.Header.Get(v) != "" {
r.Header.Del(v)
}
}
for k, v := range noCacheHeaders {
w.Header().Set(k, v)
}
next.ServeHTTP(w, r)
})
}

43
no_cache_test.go Normal file
View File

@ -0,0 +1,43 @@
package middleware
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNoCache(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "", r.Header.Get("ETag"))
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(NoCache(handler))
defer server.Close()
client := http.Client{}
req, err := http.NewRequest("GET", server.URL, nil)
require.Nil(t, err)
req.Header.Set("ETag", "ETagValue")
res, err := client.Do(req)
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, time.Unix(0, 0).Format(time.RFC1123), res.Header.Get("Expires"))
assert.Equal(t, "no-cache, no-store, no-transform, must-revalidate, private, max-age=0", res.Header.Get("Cache-Control"))
assert.Equal(t, "no-cache", res.Header.Get("Pragma"))
assert.Equal(t, "0", res.Header.Get("X-Accel-Expires"))
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}

View File

@ -7,7 +7,7 @@ import (
// Ping responses with pong to /ping request and stops chain // Ping responses with pong to /ping request and stops chain
func Ping(next http.Handler) http.Handler { func Ping(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" && strings.HasSuffix(strings.ToLower(r.URL.Path), "/ping") { if r.Method == "GET" && strings.HasSuffix(strings.ToLower(r.URL.Path), "/ping") {
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@ -17,7 +17,5 @@ func Ping(next http.Handler) http.Handler {
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
} })
return http.HandlerFunc(fn)
} }

48
ping_test.go Normal file
View File

@ -0,0 +1,48 @@
package middleware
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPing(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(Ping(handler))
defer server.Close()
res, err := http.Get(server.URL)
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
res, err = http.Get(server.URL + "/ping")
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err = ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "pong", string(b))
res, err = http.Get(server.URL + "/a/b/c/ping")
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err = ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "pong", string(b))
}

175
ratelimit/rate_limit.go Normal file
View File

@ -0,0 +1,175 @@
package ratelimit
import (
"errors"
"net/http"
"strconv"
"strings"
"sync"
"time"
cache "github.com/go-pkgz/expirable-cache"
"go.pkg.cx/middleware"
)
// Errors
var (
ErrLimitReached = errors.New("limit reached")
)
// DefaultOptions represents default timeout middleware options
var DefaultOptions = Options(
SetLimit(100),
SetPeriod(time.Minute*1),
SetKeyFn(defaultKeyFn),
SetResponseHandler(RespondWithTooManyRequests),
)
// Options turns a list of option instances into an option.
func Options(opts ...Option) Option {
return func(l *limiter) {
for _, opt := range opts {
opt(l)
}
}
}
// Option configures timeout middleware
type Option func(l *limiter)
// SetLimit sets request limit
func SetLimit(limit int) Option {
if limit < 1 {
panic("rate limit middleware expects limit > 0")
}
return func(l *limiter) {
l.limit = limit
}
}
// SetPeriod sets limiter period
func SetPeriod(period time.Duration) Option {
return func(l *limiter) {
l.period = period
}
}
// SetKeyFn sets limiter key extraction function
func SetKeyFn(fn func(r *http.Request) string) Option {
return func(l *limiter) {
l.keyFn = fn
}
}
// SetResponseHandler sets response handler
func SetResponseHandler(fn middleware.ResponseHandle) Option {
return func(l *limiter) {
l.responseHandler = fn
}
}
type info struct {
limit int
remaining int
reset int64
reached bool
}
type entry struct {
count int
expiration time.Time
}
type limiter struct {
limit int
period time.Duration
keyFn func(r *http.Request) string
responseHandler middleware.ResponseHandle
lock sync.Mutex
cache cache.Cache
}
func (s *limiter) initCache() {
c, err := cache.NewCache(cache.TTL(s.period))
if err != nil {
panic(err)
}
s.cache = c
}
func (s *limiter) try(key string) info {
s.lock.Lock()
defer s.lock.Unlock()
now := time.Now()
if e, ok := s.cache.Get(key); ok {
e.(*entry).count++
s.cache.Set(key, e, 0)
return s.infoFromEntry(e.(*entry))
}
e := &entry{count: 1, expiration: now.Add(s.period)}
s.cache.Set(key, e, 0)
return s.infoFromEntry(e)
}
func (s *limiter) infoFromEntry(e *entry) info {
reached := true
remaining := 0
if e.count <= s.limit {
reached = false
remaining = s.limit - e.count
}
return info{
limit: s.limit,
remaining: remaining,
reset: e.expiration.Unix(),
reached: reached,
}
}
// Middleware is a rate limiter middleware
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
l := &limiter{}
opts = append([]Option{DefaultOptions}, opts...)
for _, opt := range opts {
opt(l)
}
l.initCache()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
info := l.try(l.keyFn(r))
w.Header().Add("X-RateLimit-Limit", strconv.Itoa(info.limit))
w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(info.remaining))
w.Header().Add("X-RateLimit-Reset", strconv.FormatInt(info.reset, 10))
if info.reached {
l.responseHandler(w, r, ErrLimitReached)
return
}
next.ServeHTTP(w, r)
})
}
}
// RespondWithTooManyRequests is a default response handler
func RespondWithTooManyRequests(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
}
func defaultKeyFn(r *http.Request) string {
return strings.Split(r.RemoteAddr, ":")[0]
}

View File

@ -0,0 +1,163 @@
package ratelimit
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRateLimitSequential(t *testing.T) {
rateLimit := Middleware(
SetLimit(5),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(rateLimit(handler))
defer server.Close()
for i := 0; i < 10; i++ {
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
if i < 5 {
assert.Equal(t, http.StatusOK, res.StatusCode)
} else {
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
}
}
}
func TestRateLimitConcurrent(t *testing.T) {
rateLimit := Middleware(
SetLimit(5),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(rateLimit(handler))
defer server.Close()
counter := int64(0)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
if res.StatusCode == http.StatusOK {
atomic.AddInt64(&counter, 1)
}
}()
}
wg.Wait()
assert.Equal(t, int64(5), atomic.LoadInt64(&counter))
}
func TestRateLimitHeaders(t *testing.T) {
rateLimit := Middleware(
SetLimit(1),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(rateLimit(handler))
defer server.Close()
now := time.Now()
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit"))
assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining"))
resetTS, err := strconv.Atoi(res.Header.Get("X-RateLimit-Reset"))
assert.NoError(t, err)
assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
res, err = http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit"))
assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining"))
resetTS, err = strconv.Atoi(res.Header.Get("X-RateLimit-Reset"))
assert.NoError(t, err)
assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1)
b, err = ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "Too Many Requests\n", string(b))
}
func TestRateLimitExpiration(t *testing.T) {
rateLimit := Middleware(
SetLimit(5),
SetPeriod(time.Millisecond*500),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(rateLimit(handler))
defer server.Close()
for i := 0; i < 10; i++ {
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
if i < 5 {
assert.Equal(t, http.StatusOK, res.StatusCode)
} else {
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
}
}
time.Sleep(time.Millisecond * 500)
for i := 0; i < 10; i++ {
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
if i < 5 {
assert.Equal(t, http.StatusOK, res.StatusCode)
} else {
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
}
}
}

48
real_ip.go Normal file
View File

@ -0,0 +1,48 @@
package middleware
import (
"net/http"
"strings"
)
// RealIP is a middleware that sets a http.Request's RemoteAddr to the results
// of parsing either the X-Forwarded-For header or the X-Real-IP header (in that
// order).
//
// This middleware should be inserted fairly early in the middleware stack to
// ensure that subsequent layers (e.g., request loggers) which examine the
// RemoteAddr will see the intended value.
//
// You should only use this middleware if you can trust the headers passed to
// you (in particular, the two headers this middleware uses), for example
// because you have placed a reverse proxy like HAProxy or nginx in front of
// chi. If your reverse proxies are configured to pass along arbitrary header
// values from the client, or if you use this middleware without a reverse
// proxy, malicious clients will be able to make you very sad (or, depending on
// how you're using RemoteAddr, vulnerable to an attack of some sort).
func RealIP(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if rip := realIP(r); rip != "" {
r.RemoteAddr = rip
}
next.ServeHTTP(w, r)
})
}
func realIP(r *http.Request) string {
if xrip := r.Header.Get("X-Real-IP"); xrip != "" {
return xrip
}
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
i := strings.Index(xff, ", ")
if i == -1 {
i = len(xff)
}
return xff[:i]
}
return ""
}

100
real_ip_test.go Normal file
View File

@ -0,0 +1,100 @@
package middleware
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var realIPTestHandler = func(t *testing.T) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.RemoteAddr, "3.3.3.3")
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
}
func TestRealIPXRealIP(t *testing.T) {
server := httptest.NewServer(RealIP(realIPTestHandler(t)))
defer server.Close()
client := http.Client{}
req, err := http.NewRequest("GET", server.URL, nil)
require.Nil(t, err)
req.Header.Set("X-Real-IP", "3.3.3.3")
res, err := client.Do(req)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}
func TestRealIPXForwardedFor(t *testing.T) {
server := httptest.NewServer(RealIP(realIPTestHandler(t)))
defer server.Close()
client := http.Client{}
req, err := http.NewRequest("GET", server.URL, nil)
require.Nil(t, err)
req.Header.Set("X-Forwarded-For", "3.3.3.3")
res, err := client.Do(req)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
req, err = http.NewRequest("GET", server.URL, nil)
require.Nil(t, err)
req.Header.Set("X-Forwarded-For", "3.3.3.3, 4.4.4.4, 5.5.5.5")
res, err = client.Do(req)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func TestRealIPBothHeaders(t *testing.T) {
server := httptest.NewServer(RealIP(realIPTestHandler(t)))
defer server.Close()
client := http.Client{}
req, err := http.NewRequest("GET", server.URL, nil)
require.Nil(t, err)
req.Header.Set("X-Real-IP", "3.3.3.3")
req.Header.Set("X-Forwarded-For", "4.4.4.4")
res, err := client.Do(req)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func TestRealIPNoHeaders(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, strings.Split(r.RemoteAddr, ":")[0], "127.0.0.1")
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(RealIP(handler))
defer server.Close()
res, err := http.Get(server.URL)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}

92
recoverer/recoverer.go Normal file
View File

@ -0,0 +1,92 @@
package recoverer
import (
"fmt"
"net/http"
"runtime/debug"
"go.pkg.cx/middleware"
)
// DefaultOptions represents default recoverer middleware options
var DefaultOptions = Options(
SetLogStackFn(defaultLogStackFn),
SetLogRecoverFn(defaultLogRecoverFn),
SetResponseHandler(RespondWithInternalServerError),
)
// Options turns a list of option instances into an option.
func Options(opts ...Option) Option {
return func(r *recoverer) {
for _, opt := range opts {
opt(r)
}
}
}
// SetLogStackFn sets log stack function
func SetLogStackFn(fn func(stack []byte)) Option {
return func(r *recoverer) {
r.logStackFn = fn
}
}
// SetLogRecoverFn sets log recover information function
func SetLogRecoverFn(fn func(rec interface{})) Option {
return func(r *recoverer) {
r.logRecoverFn = fn
}
}
// SetResponseHandler sets response handler
func SetResponseHandler(fn middleware.ResponseHandle) Option {
return func(r *recoverer) {
r.responseHandler = fn
}
}
// Option configures recoverer middleware
type Option func(r *recoverer)
type recoverer struct {
logStackFn func(stack []byte)
logRecoverFn func(rec interface{})
responseHandler middleware.ResponseHandle
}
// Middleware is a recoverer middleware that recovers from panic
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
rvr := &recoverer{}
opts = append([]Option{DefaultOptions}, opts...)
for _, opt := range opts {
opt(rvr)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if rec := recover(); rec != nil {
rvr.logRecoverFn(rec)
rvr.logStackFn(debug.Stack())
rvr.responseHandler(w, r, nil)
}
}()
next.ServeHTTP(w, r)
})
}
}
// RespondWithInternalServerError is a default response handler
func RespondWithInternalServerError(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
func defaultLogStackFn(stack []byte) {
fmt.Println(string(stack))
}
func defaultLogRecoverFn(rec interface{}) {
fmt.Printf("request panic, %v", rec)
}

View File

@ -0,0 +1,58 @@
package recoverer
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRecoverer(t *testing.T) {
logStackFn := func(stack []byte) {
s := string(stack)
assert.Contains(t, s, "goroutine")
assert.Contains(t, s, "TestRecoverer")
}
logRecoverFn := func(rec interface{}) {
s, ok := rec.(string)
assert.True(t, ok)
assert.Contains(t, s, "panic message")
}
recoverer := Middleware(
SetLogStackFn(logStackFn),
SetLogRecoverFn(logRecoverFn),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/panic" {
panic("panic message")
}
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(recoverer(handler))
defer server.Close()
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
res, err = http.Get(server.URL + "/panic")
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
}

176
throttle/throttle.go Normal file
View File

@ -0,0 +1,176 @@
package throttle
import (
"errors"
"net/http"
"strconv"
"time"
"go.pkg.cx/middleware"
)
// Errors
var (
ErrCapacityExceeded = errors.New("server capacity exceeded")
ErrTimedOut = errors.New("timed out while waiting for a pending request to complete")
ErrContextCanceled = errors.New("context was canceled")
)
// DefaultOptions represents default throttle middleware options
var DefaultOptions = Options(
SetLimit(100),
SetResponseHandler(RespondWithTooManyRequests),
)
// Options turns a list of option instances into an option.
func Options(opts ...Option) Option {
return func(t *throttle) {
for _, opt := range opts {
opt(t)
}
}
}
// Option configures throttle middleware
type Option func(t *throttle)
// SetLimit sets requests limit
func SetLimit(limit int) Option {
if limit < 1 {
panic("throttle middleware expects limit > 0")
}
return func(t *throttle) {
t.limit = limit
}
}
// SetBacklogLimit sets backlog requests limit
func SetBacklogLimit(backlogLimit int) Option {
if backlogLimit < 0 {
panic("throttle middleware expects backlogLimit >= 0")
}
return func(t *throttle) {
t.backlogLimit = backlogLimit
}
}
// SetBacklogTimeout sets backlog timeout
func SetBacklogTimeout(backlogTimeout time.Duration) Option {
return func(t *throttle) {
t.backlogTimeout = backlogTimeout
}
}
// SetRetryAfterFn sets retry after function
func SetRetryAfterFn(fn func(ctxDone bool) time.Duration) Option {
return func(t *throttle) {
t.retryAfterFn = fn
}
}
// SetResponseHandler sets response handler
func SetResponseHandler(fn middleware.ResponseHandle) Option {
return func(t *throttle) {
t.responseHandler = fn
}
}
type throttle struct {
limit int
backlogLimit int
backlogTimeout time.Duration
retryAfterFn func(ctxDone bool) time.Duration
responseHandler middleware.ResponseHandle
tokens chan struct{}
backlogTokens chan struct{}
}
func (s *throttle) initTokens() {
s.tokens = make(chan struct{}, s.limit)
s.backlogTokens = make(chan struct{}, s.limit+s.backlogLimit)
for i := 0; i < s.limit+s.backlogLimit; i++ {
if i < s.limit {
s.tokens <- struct{}{}
}
s.backlogTokens <- struct{}{}
}
}
func (s *throttle) setRetryAfterHeader(w http.ResponseWriter, ctxDone bool) {
if s.retryAfterFn == nil {
return
}
retryAfterSeconds := strconv.Itoa(int(s.retryAfterFn(ctxDone).Seconds()))
w.Header().Set("Retry-After", retryAfterSeconds)
}
// Middleware is a throttle middleware that limits number of currently processed requests
// at a time across all users. Note: Throttle is not a rate-limiter per user,
// instead it just puts a ceiling on the number of currentl in-flight requests
// being processed from the point from where the Throttle middleware is mounted
func Middleware(opts ...Option) func(http.Handler) http.Handler {
t := &throttle{}
opts = append([]Option{DefaultOptions}, opts...)
for _, opt := range opts {
opt(t)
}
t.initTokens()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select {
case <-r.Context().Done():
t.setRetryAfterHeader(w, true)
t.responseHandler(w, r, ErrContextCanceled)
return
case btok := <-t.backlogTokens:
timer := time.NewTimer(t.backlogTimeout)
defer func() {
t.backlogTokens <- btok
}()
select {
case <-timer.C:
t.setRetryAfterHeader(w, false)
t.responseHandler(w, r, ErrTimedOut)
return
case <-r.Context().Done():
timer.Stop()
t.setRetryAfterHeader(w, true)
t.responseHandler(w, r, ErrContextCanceled)
return
case tok := <-t.tokens:
defer func() {
timer.Stop()
t.tokens <- tok
}()
next.ServeHTTP(w, r)
return
}
default:
t.setRetryAfterHeader(w, false)
t.responseHandler(w, r, ErrCapacityExceeded)
return
}
})
}
}
// RespondWithTooManyRequests is a default response handler
func RespondWithTooManyRequests(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
}

245
throttle/throttle_test.go Normal file
View File

@ -0,0 +1,245 @@
package throttle
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestThrottleBacklog(t *testing.T) {
throttle := Middleware(
SetLimit(10),
SetBacklogLimit(50),
SetBacklogTimeout(time.Second*10),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
time.Sleep(time.Second * 1)
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(throttle(handler))
defer server.Close()
client := http.Client{Timeout: time.Second * 5}
var wg sync.WaitGroup
for i := 0; i < 1; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}(i)
}
wg.Wait()
}
func TestThrottleClientTimeout(t *testing.T) {
throttle := Middleware(
SetLimit(10),
SetBacklogLimit(50),
SetBacklogTimeout(time.Second*10),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
time.Sleep(time.Second * 5)
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(throttle(handler))
defer server.Close()
client := http.Client{Timeout: time.Second * 3}
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
_, err := client.Get(server.URL)
assert.Error(t, err)
}(i)
}
wg.Wait()
}
func TestThrottleTriggerGatewayTimeout(t *testing.T) {
throttle := Middleware(
SetLimit(50),
SetBacklogLimit(100),
SetBacklogTimeout(time.Second*5),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
time.Sleep(time.Second * 10)
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(throttle(handler))
defer server.Close()
client := http.Client{Timeout: time.Second * 60}
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}(i)
}
time.Sleep(time.Second * 1)
for i := 0; i < 50; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
}(i)
}
wg.Wait()
}
func TestThrottleMaximum(t *testing.T) {
throttle := Middleware(
SetLimit(10),
SetBacklogLimit(10),
SetBacklogTimeout(time.Second*5),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
time.Sleep(time.Second * 2)
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(throttle(handler))
defer server.Close()
client := http.Client{Timeout: time.Second * 60}
var wg sync.WaitGroup
for i := 0; i < 20; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}(i)
}
time.Sleep(time.Second * 1)
for i := 0; i < 20; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
}(i)
}
wg.Wait()
}
func TestThrottleRetryAfter(t *testing.T) {
throttle := Middleware(
SetLimit(10),
SetRetryAfterFn(func(_ bool) time.Duration { return time.Hour * 1 }),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
time.Sleep(time.Second * 3)
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(throttle(handler))
defer server.Close()
client := http.Client{Timeout: time.Second * 60}
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}(i)
}
time.Sleep(time.Second * 1)
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
assert.Equal(t, res.Header.Get("Retry-After"), "3600")
}(i)
}
wg.Wait()
}

81
timeout/timeout.go Normal file
View File

@ -0,0 +1,81 @@
package timeout
import (
"context"
"net/http"
"time"
"go.pkg.cx/middleware"
)
// DefaultOptions represents default timeout middleware options
var DefaultOptions = Options(
SetTimeout(time.Second*30),
SetResponseHandler(RespondWithTimeout),
)
// Options turns a list of option instances into an option.
func Options(opts ...Option) Option {
return func(t *timeout) {
for _, opt := range opts {
opt(t)
}
}
}
// Option configures timeout middleware
type Option func(t *timeout)
// SetTimeout sets request timeout
func SetTimeout(limit time.Duration) Option {
return func(t *timeout) {
t.timeout = limit
}
}
// SetResponseHandler sets response handler
func SetResponseHandler(fn middleware.ResponseHandle) Option {
return func(t *timeout) {
t.responseHandler = fn
}
}
type timeout struct {
timeout time.Duration
responseHandler middleware.ResponseHandle
}
// Middleware is a timeout middleware that cancels ctx after a given timeout
// and return a 504 Gateway Timeout error to the client.
//
// It's required that you select the ctx.Done() channel to check for the signal
// if the context has reached its deadline and return, otherwise the timeout
// signal will be just ignored
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
t := &timeout{}
opts = append([]Option{DefaultOptions}, opts...)
for _, opt := range opts {
opt(t)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), t.timeout)
defer func() {
cancel()
if ctx.Err() == context.DeadlineExceeded {
t.responseHandler(w, r, ctx.Err())
return
}
}()
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// RespondWithTimeout is a default response handler
func RespondWithTimeout(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, http.StatusText(http.StatusGatewayTimeout), http.StatusGatewayTimeout)
}

62
timeout/timeout_test.go Normal file
View File

@ -0,0 +1,62 @@
package timeout
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTimeoutPass(t *testing.T) {
timeout := Middleware(
SetTimeout(time.Second * 1),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
time.Sleep(time.Millisecond * 100)
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(timeout(handler))
defer server.Close()
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}
func TestTimeoutTimedOut(t *testing.T) {
timeout := Middleware(
SetTimeout(time.Millisecond * 300),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select {
case <-r.Context().Done():
return
case <-time.After(time.Second * 1):
w.Write([]byte("resp")) // nolint:errcheck
}
})
server := httptest.NewServer(timeout(handler))
defer server.Close()
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusGatewayTimeout, res.StatusCode)
}

19
types.go Normal file
View File

@ -0,0 +1,19 @@
package middleware
import (
"net/http"
)
// ResponseHandle is a function that middleware call in case of stop chain
type ResponseHandle func(w http.ResponseWriter, r *http.Request, err error)
// CtxKey is a key to use with context.WithValue
type CtxKey struct {
Pkg string
Name string
}
// String returns string representation
func (s *CtxKey) String() string {
return s.Pkg + " context value " + s.Name
}