Compare commits

..

No commits in common. "a5f83606cd640977bbb689cd54a57737ab4d5b93" and "fbff18885ff5026ec7a17025ebca6390154590ca" have entirely different histories.

30 changed files with 551 additions and 536 deletions

View File

@ -1,52 +0,0 @@
run:
timeout: 5m
output:
format: tab
linters-settings:
govet:
check-shadowing: true
golint:
min-confidence: 0.1
maligned:
suggest-new: true
goconst:
min-len: 2
min-occurrences: 2
misspell:
locale: US
lll:
line-length: 140
gocritic:
enabled-tags:
- performance
- style
- experimental
disabled-checks:
- unnamedResult
- paramTypeCombine
linters:
enable:
- megacheck
- revive
- govet
- unconvert
- megacheck
- unused
- gas
- gocyclo
- dupl
- misspell
- unparam
- typecheck
- ineffassign
- stylecheck
- gochecknoinits
- exportloopref
- gocritic
- nakedret
- gosimple
- prealloc
fast: false
disable-all: true

View File

@ -1,7 +1,7 @@
package middleware
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
@ -26,7 +26,7 @@ func TestAppInfo(t *testing.T) {
assert.Equal(t, "tapp", res.Header.Get("App-Name"))
assert.Equal(t, "tversion", res.Header.Get("App-Version"))
b, err := io.ReadAll(res.Body)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}

View File

@ -5,8 +5,8 @@ import (
"errors"
"net/http"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
"go.pkg.cx/middleware"
)
@ -20,7 +20,7 @@ var (
// Context keys
var (
JWTCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/jwt", Name: "JWT"}
DataCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/jwt", Name: "Data"}
DataCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/token", Name: "Data"}
)
// DefaultOptions represents default jwt auth middleware options
@ -32,7 +32,7 @@ var DefaultOptions = Options(
SetValidateTokenFn(allowAll),
)
// Options turns a list of option instances into an option
// Options turns a list of option instances into an option.
func Options(opts ...Option) Option {
return func(a *auth) {
for _, opt := range opts {
@ -112,7 +112,8 @@ type auth struct {
// Middleware returns jwt auth middleware
func Middleware(key interface{}, alg jwa.SignatureAlgorithm, opts ...Option) func(next http.Handler) http.Handler {
a := &auth{}
opts = append(opts, SetKey(key), SetAlgorithm(alg))
opts = append(opts, SetKey(key))
opts = append(opts, SetAlgorithm(alg))
opts = append([]Option{DefaultOptions}, opts...)
for _, opt := range opts {
@ -134,7 +135,7 @@ func Middleware(key interface{}, alg jwa.SignatureAlgorithm, opts ...Option) fun
return
}
jwtToken, err := jwt.ParseString(token, jwt.WithKey(a.algorithm, a.key))
jwtToken, err := jwt.ParseString(token, jwt.WithVerify(a.algorithm, a.key))
if err != nil {
a.responseHandler(w, r, err)
return
@ -161,10 +162,10 @@ func Middleware(key interface{}, alg jwa.SignatureAlgorithm, opts ...Option) fun
}
// RespondWithUnauthorized is a default response handler
func RespondWithUnauthorized(w http.ResponseWriter, _ *http.Request, _ error) {
func RespondWithUnauthorized(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
func allowAll(_ jwt.Token) (bool, interface{}) {
func allowAll(token jwt.Token) (bool, interface{}) {
return true, nil
}

View File

@ -1,14 +1,14 @@
package jwt
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -36,12 +36,12 @@ func TestJWTAuthDefaults(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
b, err := io.ReadAll(res.Body)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "Unauthorized\n", string(b))
token := jwt.New()
payload, err := jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis")))
payload, err := jwt.Sign(token, jwa.HS256, []byte("changethis"))
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?jwt=" + string(payload))
@ -49,7 +49,7 @@ func TestJWTAuthDefaults(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err = io.ReadAll(res.Body)
b, err = ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}
@ -72,7 +72,7 @@ func TestJWTAuthParseVerify(t *testing.T) {
token := jwt.New()
payload, err := jwt.Sign(token, jwt.WithKey(jwa.HS512, []byte("tkey")))
payload, err := jwt.Sign(token, jwa.HS512, []byte("tkey"))
assert.NoError(t, err)
res, err := http.Get(server.URL + "/?token=" + string(payload))
@ -80,7 +80,7 @@ func TestJWTAuthParseVerify(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("tkey")))
payload, err = jwt.Sign(token, jwa.HS256, []byte("tkey"))
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?token=" + string(payload))
@ -88,7 +88,7 @@ func TestJWTAuthParseVerify(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS512, []byte("wrongkey")))
payload, err = jwt.Sign(token, jwa.HS512, []byte("wrongkey"))
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?token=" + string(payload))
@ -97,9 +97,9 @@ func TestJWTAuthParseVerify(t *testing.T) {
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
now := time.Now()
token.Set(jwt.IssuedAtKey, now.Add(-1*time.Hour)) // nolint:errcheck // No need to check error here
token.Set(jwt.ExpirationKey, now.Add(-58*time.Minute)) // nolint:errcheck // No need to check error here
payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS512, []byte("tkey")))
token.Set(jwt.IssuedAtKey, now.Add(-1*time.Hour)) // nolint:errcheck
token.Set(jwt.ExpirationKey, now.Add(-58*time.Minute)) // nolint:errcheck
payload, err = jwt.Sign(token, jwa.HS512, []byte("tkey"))
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?token=" + string(payload))
@ -120,8 +120,8 @@ func TestJWTAuthVerifyOptions(t *testing.T) {
token := jwt.New()
token.Set(jwt.IssuerKey, "wrongissuer") // nolint:errcheck // No need to check error here
payload, err := jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis")))
token.Set(jwt.IssuerKey, "wrongissuer") // nolint:errcheck
payload, err := jwt.Sign(token, jwa.HS256, []byte("changethis"))
assert.NoError(t, err)
res, err := http.Get(server.URL + "/?jwt=" + string(payload))
@ -129,8 +129,8 @@ func TestJWTAuthVerifyOptions(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
token.Set(jwt.IssuerKey, "tissuer") // nolint:errcheck // No need to check error here
payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis")))
token.Set(jwt.IssuerKey, "tissuer") // nolint:errcheck
payload, err = jwt.Sign(token, jwa.HS256, []byte("changethis"))
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?jwt=" + string(payload))
@ -155,16 +155,16 @@ func TestJWTAuthContext(t *testing.T) {
)
token := jwt.New()
token.Set(jwt.JwtIDKey, "tid") // nolint:errcheck // No need to check error here
token.Set(jwt.JwtIDKey, "tid") // nolint:errcheck
testCtxHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, ok := r.Context().Value(DataCtxKey).(*data)
require.True(t, ok)
assert.Equal(t, "test data", data.inner)
ctxToken, ok := r.Context().Value(JWTCtxKey).(jwt.Token)
token, ok := r.Context().Value(JWTCtxKey).(jwt.Token)
require.True(t, ok)
assert.Equal(t, "tid", ctxToken.JwtID())
assert.Equal(t, "tid", token.JwtID())
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
@ -173,7 +173,7 @@ func TestJWTAuthContext(t *testing.T) {
server := httptest.NewServer(auth(testCtxHandler))
defer server.Close()
payload, err := jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis")))
payload, err := jwt.Sign(token, jwa.HS256, []byte("changethis"))
assert.NoError(t, err)
res, err := http.Get(server.URL + "/?jwt=" + string(payload))
@ -181,8 +181,8 @@ func TestJWTAuthContext(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
token.Set(jwt.JwtIDKey, "invalid") // nolint:errcheck // No need to check error here
payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis")))
token.Set(jwt.JwtIDKey, "invalid") // nolint:errcheck
payload, err = jwt.Sign(token, jwa.HS256, []byte("changethis"))
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?jwt=" + string(payload))

View File

@ -28,7 +28,7 @@ var DefaultOptions = Options(
SetResponseHandler(RespondWithUnauthorized),
)
// Options turns a list of option instances into an option
// Options turns a list of option instances into an option.
func Options(opts ...Option) Option {
return func(a *auth) {
for _, opt := range opts {
@ -114,10 +114,10 @@ func Middleware(opts ...Option) func(next http.Handler) http.Handler {
}
// RespondWithUnauthorized is a default response handler
func RespondWithUnauthorized(w http.ResponseWriter, _ *http.Request, _ error) {
func RespondWithUnauthorized(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
}
func rejectAll(_ string) (bool, interface{}) {
func rejectAll(token string) (bool, interface{}) {
return false, nil
}

View File

@ -1,7 +1,7 @@
package token
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
@ -34,7 +34,7 @@ func TestTokenAuthDefaults(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
b, err := io.ReadAll(res.Body)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "Unauthorized\n", string(b))
@ -62,7 +62,7 @@ func TestTokenAuthValidateFn(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := io.ReadAll(res.Body)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}

View File

@ -9,42 +9,42 @@ import (
)
func TestTokenFromAuthorizationHeader(t *testing.T) {
req, err := http.NewRequest("GET", "/", http.NoBody)
req, err := http.NewRequest("GET", "/", nil)
require.Nil(t, err)
req.Header.Set("Authorization", "abc")
assert.Equal(t, "abc", TokenFromAuthorizationHeader(req))
req, err = http.NewRequest("GET", "/", http.NoBody)
req, err = http.NewRequest("GET", "/", nil)
require.Nil(t, err)
req.Header.Set("Authorization", "abcdefghe")
assert.Equal(t, "abcdefghe", TokenFromAuthorizationHeader(req))
req, err = http.NewRequest("GET", "/", http.NoBody)
req, err = http.NewRequest("GET", "/", nil)
require.Nil(t, err)
req.Header.Set("Authorization", "Bearer abc")
assert.Equal(t, "abc", TokenFromAuthorizationHeader(req))
}
func TestTokenFromHeader(t *testing.T) {
req, err := http.NewRequest("GET", "/", http.NoBody)
req, err := http.NewRequest("GET", "/", nil)
require.Nil(t, err)
req.Header.Set("X-Token", "abc")
assert.Equal(t, "abc", TokenFromHeader("X-Token")(req))
}
func TestTokenFromQuery(t *testing.T) {
req, err := http.NewRequest("GET", "/?token=abc", http.NoBody)
req, err := http.NewRequest("GET", "/?token=abc", nil)
require.Nil(t, err)
assert.Equal(t, "abc", TokenFromQuery("token")(req))
}
func TestTokenFromCookie(t *testing.T) {
req, err := http.NewRequest("GET", "/", http.NoBody)
req, err := http.NewRequest("GET", "/", nil)
require.Nil(t, err)
req.AddCookie(&http.Cookie{Name: "token", Value: "abc"})
assert.Equal(t, "abc", TokenFromCookie("token")(req))
req, err = http.NewRequest("GET", "/", http.NoBody)
req, err = http.NewRequest("GET", "/", nil)
require.Nil(t, err)
assert.Equal(t, "", TokenFromCookie("token")(req))
}

26
go.mod
View File

@ -1,24 +1,14 @@
module go.pkg.cx/middleware
go 1.21
require (
github.com/lestrrat-go/jwx/v2 v2.0.18
github.com/stretchr/testify v1.8.4
)
go 1.15
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/lestrrat-go/blackmagic v1.0.2 // indirect
github.com/lestrrat-go/httpcc v1.0.1 // indirect
github.com/lestrrat-go/httprc v1.0.4 // indirect
github.com/lestrrat-go/iter v1.0.2 // indirect
github.com/lestrrat-go/option v1.0.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/segmentio/asm v1.2.0 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/sys v0.15.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
github.com/go-pkgz/expirable-cache v0.0.3
github.com/goccy/go-json v0.7.3 // indirect
github.com/lestrrat-go/backoff/v2 v2.0.8 // indirect
github.com/lestrrat-go/jwx v1.2.1
github.com/stretchr/testify v1.7.0
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
)

111
go.sum
View File

@ -1,80 +1,79 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
github.com/lestrrat-go/blackmagic v1.0.2/go.mod h1:UrEqBzIR2U6CnzVyUtfM6oZNMt/7O7Vohk2J0OGSAtU=
github.com/lestrrat-go/httpcc v1.0.1 h1:ydWCStUeJLkpYyjLDHihupbn2tYmZ7m22BGkcvZZrIE=
github.com/lestrrat-go/httpcc v1.0.1/go.mod h1:qiltp3Mt56+55GPVCbTdM9MlqhvzyuL6W/NMDA8vA5E=
github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJGdI8=
github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
github.com/lestrrat-go/jwx/v2 v2.0.18 h1:HHZkYS5wWDDyAiNBwztEtDoX07WDhGEdixm8G06R50o=
github.com/lestrrat-go/jwx/v2 v2.0.18/go.mod h1:fAJ+k5eTgKdDqanzCuK6DAt3W7n3cs2/FX7JhQdk83U=
github.com/decred/dcrd/chaincfg/chainhash v1.0.2/go.mod h1:BpbrGgrPTr3YJYRN3Bm+D9NuaFd+zGyNeIKgrhCXK60=
github.com/decred/dcrd/crypto/blake256 v1.0.0/go.mod h1:sQl2p6Y26YV+ZOcSTP6thNdn47hh8kt6rqSlvmrXFAc=
github.com/decred/dcrd/dcrec/secp256k1/v3 v3.0.0 h1:sgNeV1VRMDzs6rzyPpxyM0jp317hnwiq58Filgag2xw=
github.com/decred/dcrd/dcrec/secp256k1/v3 v3.0.0/go.mod h1:J70FGZSbzsjecRTiTzER+3f1KZLNaXkuv+yeFTKoxM8=
github.com/go-pkgz/expirable-cache v0.0.3 h1:rTh6qNPp78z0bQE6HDhXBHUwqnV9i09Vm6dksJLXQDc=
github.com/go-pkgz/expirable-cache v0.0.3/go.mod h1:+IauqN00R2FqNRLCLA+X5YljQJrwB179PfiAoMPlTlQ=
github.com/goccy/go-json v0.4.8/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/goccy/go-json v0.7.3 h1:Pznres7bC8RRKT9yOn3EZ7fK+8Kle6K9rW2U33QlXZI=
github.com/goccy/go-json v0.7.3/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/lestrrat-go/backoff/v2 v2.0.7/go.mod h1:rHP/q/r9aT27n24JQLa7JhSQZCKBBOiM/uP402WwN8Y=
github.com/lestrrat-go/backoff/v2 v2.0.8 h1:oNb5E5isby2kiro9AgdHLv5N5tint1AnDVVf2E2un5A=
github.com/lestrrat-go/backoff/v2 v2.0.8/go.mod h1:rHP/q/r9aT27n24JQLa7JhSQZCKBBOiM/uP402WwN8Y=
github.com/lestrrat-go/blackmagic v1.0.0 h1:XzdxDbuQTz0RZZEmdU7cnQxUtFUzgCSPq8RCz4BxIi4=
github.com/lestrrat-go/blackmagic v1.0.0/go.mod h1:TNgH//0vYSs8VXDCfkZLgIrVTTXQELZffUV0tz3MtdQ=
github.com/lestrrat-go/codegen v1.0.0/go.mod h1:JhJw6OQAuPEfVKUCLItpaVLumDGWQznd1VaXrBk9TdM=
github.com/lestrrat-go/httpcc v1.0.0 h1:FszVC6cKfDvBKcJv646+lkh4GydQg2Z29scgUfkOpYc=
github.com/lestrrat-go/httpcc v1.0.0/go.mod h1:tGS/u00Vh5N6FHNkExqGGNId8e0Big+++0Gf8MBnAvE=
github.com/lestrrat-go/iter v1.0.1 h1:q8faalr2dY6o8bV45uwrxq12bRa1ezKrB6oM9FUgN4A=
github.com/lestrrat-go/iter v1.0.1/go.mod h1:zIdgO1mRKhn8l9vrZJZz9TUMMFbQbLeTsbqPDrJ/OJc=
github.com/lestrrat-go/jwx v1.2.1 h1:WJ/3tiPUz1wV24KiwMEanbENwHnYub9UqzCbQ82mv9c=
github.com/lestrrat-go/jwx v1.2.1/go.mod h1:Tg2uP7bpxEHUDtuWjap/PxroJ4okxGzkQznXiG+a5Dc=
github.com/lestrrat-go/option v0.0.0-20210103042652-6f1ecfceda35/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lestrrat-go/option v1.0.0 h1:WqAWL8kh8VcSoD6xjSH34/1m8yxluXQbDeKNfvFeEO4=
github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lestrrat-go/option v1.0.1 h1:oAzP2fvZGQKWkvHa1/SAcFolBEca1oN+mQ7eooNBEYU=
github.com/lestrrat-go/option v1.0.1/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
github.com/lestrrat-go/pdebug/v3 v3.0.1 h1:3G5sX/aw/TbMTtVc9U7IHBWRZtMvwvBziF1e4HoQtv8=
github.com/lestrrat-go/pdebug/v3 v3.0.1/go.mod h1:za+m+Ve24yCxTEhR59N7UlnJomWwCiIqbJRmKeiADU4=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201217014255-9d1352758620/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e h1:gsTQYXdTw2Gq7RBsWvlQ91b+aEQ6bXFUngBGuR8sPpI=
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.0.0-20200918232735-d647fc253266/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU=
golang.org/x/tools v0.0.0-20210114065538-d78b04bdf963/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -2,7 +2,7 @@ package logger
import (
"errors"
"io"
"io/ioutil"
"strings"
"testing"
@ -39,7 +39,7 @@ func TestLoggerPeek(t *testing.T) {
continue
}
body, err := io.ReadAll(r)
body, err := ioutil.ReadAll(r)
if !assert.NoError(t, err) {
continue
}

View File

@ -1,7 +1,7 @@
package logger
import (
"io"
"io/ioutil"
"log"
"net/http"
"net/url"
@ -17,7 +17,7 @@ var DefaultOptions = Options(
SetLogHandler(DefaultLogHandler),
)
// Options turns a list of option instances into an option
// Options turns a list of option instances into an option.
func Options(opts ...Option) Option {
return func(l *logger) {
for _, opt := range opts {
@ -110,10 +110,10 @@ func (s *logger) body(r *http.Request) string {
return ""
}
r.Body = io.NopCloser(rdr)
r.Body = ioutil.NopCloser(rdr)
if len(body) > 0 {
body = strings.ReplaceAll(body, "\n", " ")
body = strings.Replace(body, "\n", " ", -1)
body = regexpMultiWhitespace.ReplaceAllString(body, " ")
}
@ -186,7 +186,7 @@ func Middleware(opts ...Option) func(http.Handler) http.Handler {
}
// DefaultLogHandler is a default log handler
func DefaultLogHandler(entry LogEntry) { // nolint:gocritic // For backwards compatibility
func DefaultLogHandler(entry LogEntry) {
log.Printf(
"%s - %s - %s - %d (%d) - %v",
entry.Method,

View File

@ -2,7 +2,7 @@ package logger
import (
"bytes"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
@ -40,7 +40,7 @@ func TestLogger(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := io.ReadAll(res.Body)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}
@ -57,7 +57,7 @@ func TestLoggerBody(t *testing.T) {
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
assert.Equal(t, "12345678901234567890", string(body))
@ -116,7 +116,7 @@ func TestLoggerSanitize(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, r.URL.Query().Get("param"), "password")
body, err := io.ReadAll(r.Body)
body, err := ioutil.ReadAll(r.Body)
assert.NoError(t, err)
assert.Equal(t, "body|password", string(body))

View File

@ -8,7 +8,7 @@ import (
)
var (
errWriterNotImplentsHijacker = errors.New("ResponseWriter does not implement the Hijacker interface")
errWriterNotImplentsHijacker = errors.New("ResponseWriter does not implement the Hijacker interface") // nolint:golint
)
type trackingResponseWriter struct {

View File

@ -1,19 +0,0 @@
package middleware
import (
"net/http"
)
// Wrap converts a list of middlewares to nested calls in reverse order
func Wrap(handler http.Handler, mws ...func(http.Handler) http.Handler) http.Handler {
if len(mws) == 0 {
return handler
}
res := handler
for i := len(mws) - 1; i >= 0; i-- {
res = mws[i](res)
}
return res
}

View File

@ -1,36 +0,0 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestWrap(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "/something/1/2", r.URL.Path)
})
mw1 := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.URL.Path += "/1"
h.ServeHTTP(w, r)
})
}
mw2 := func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.URL.Path += "/2"
h.ServeHTTP(w, r)
})
}
server := httptest.NewServer(Wrap(handler, mw1, mw2))
defer server.Close()
res, err := http.Get(server.URL + "/something")
require.NoError(t, err)
assert.Equal(t, 200, res.StatusCode)
}

View File

@ -1,7 +1,7 @@
package middleware
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
@ -24,7 +24,7 @@ func TestNoCache(t *testing.T) {
client := http.Client{}
req, err := http.NewRequest("GET", server.URL, http.NoBody)
req, err := http.NewRequest("GET", server.URL, nil)
require.Nil(t, err)
req.Header.Set("ETag", "ETagValue")
@ -37,7 +37,7 @@ func TestNoCache(t *testing.T) {
assert.Equal(t, "no-cache", res.Header.Get("Pragma"))
assert.Equal(t, "0", res.Header.Get("X-Accel-Expires"))
b, err := io.ReadAll(res.Body)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}

View File

@ -1,159 +0,0 @@
package paginate
import (
"context"
"errors"
"net/http"
"strconv"
"go.pkg.cx/middleware"
)
// Errors
var (
ErrPaginationDefaults = errors.New("pagination defaults are nil")
)
// Context keys
var (
PaginationCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/paginate", Name: "Pagination"}
)
// DefaultOptions represents default paginate middleware options
var DefaultOptions = Options(
WithFindPaginationFn(PaginationFromQuery("page", "pageSize")),
SetPaginationDefaults(1, 50),
SetValidatePaginationFn(allowAll),
SetResponseHandler(RespondWithBadRequest),
)
// Pagination represents pagination info
type Pagination struct {
Page int
PageSize int
}
// Options turns a list of option instances into an option
func Options(opts ...Option) Option {
return func(p *paginate) {
for _, opt := range opts {
opt(p)
}
}
}
// Option configures paginate middleware
type Option func(p *paginate)
// WithFindPaginationFn adds pagination find function to the list
func WithFindPaginationFn(fn func(r *http.Request, p *Pagination) *Pagination) Option {
return func(p *paginate) {
p.findPaginationFns = append(p.findPaginationFns, fn)
}
}
// SetFindPaginationFns sets pagination find functions list
func SetFindPaginationFns(fns ...func(r *http.Request, p *Pagination) *Pagination) Option {
return func(p *paginate) {
p.findPaginationFns = fns
}
}
// SetPaginationDefaults sets pagination defaults function
func SetPaginationDefaults(page int, pageSize int) Option {
return func(p *paginate) {
p.paginationDefaultsFn = func() *Pagination {
return &Pagination{Page: page, PageSize: pageSize}
}
}
}
// SetValidatePaginationFn sets pagination validation function
func SetValidatePaginationFn(fn func(p *Pagination) error) Option {
return func(p *paginate) {
p.validatePaginationFn = fn
}
}
// SetResponseHandler sets response handler
func SetResponseHandler(fn middleware.ResponseHandle) Option {
return func(p *paginate) {
p.responseHandler = fn
}
}
type paginate struct {
findPaginationFns []func(r *http.Request, p *Pagination) *Pagination
paginationDefaultsFn func() *Pagination
validatePaginationFn func(p *Pagination) error
responseHandler middleware.ResponseHandle
}
// Middleware returns paginate middleware
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
p := &paginate{}
opts = append([]Option{DefaultOptions}, opts...)
for _, opt := range opts {
opt(p)
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pagination := p.paginationDefaultsFn()
if pagination == nil {
p.responseHandler(w, r, ErrPaginationDefaults)
return
}
for _, fn := range p.findPaginationFns {
if nextPagination := fn(r, pagination); nextPagination != nil {
pagination = nextPagination
}
}
if err := p.validatePaginationFn(pagination); err != nil {
p.responseHandler(w, r, err)
return
}
ctx := r.Context()
ctx = context.WithValue(ctx, PaginationCtxKey, pagination)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// PaginationFromContext returns pagination from context
func PaginationFromContext(ctx context.Context) *Pagination {
if pagination, ok := ctx.Value(PaginationCtxKey).(*Pagination); ok {
return pagination
}
return nil
}
// PaginationFromQuery returns pagination from query params
func PaginationFromQuery(pageParam string, pageSizeParam string) func(r *http.Request, p *Pagination) *Pagination {
return func(r *http.Request, p *Pagination) *Pagination {
if page, err := strconv.Atoi(r.URL.Query().Get(pageParam)); err == nil {
p.Page = page
}
if pageSize, err := strconv.Atoi(r.URL.Query().Get(pageSizeParam)); err == nil {
p.PageSize = pageSize
}
return p
}
}
// RespondWithBadRequest is a default response handler
func RespondWithBadRequest(w http.ResponseWriter, _ *http.Request, _ error) {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
}
func allowAll(_ *Pagination) error {
return nil
}

View File

@ -1,68 +0,0 @@
package paginate
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestPaginationFromQuery(t *testing.T) {
req, err := http.NewRequest("GET", "/?page=2&pageSize=10", http.NoBody)
assert.NoError(t, err)
pagination := PaginationFromQuery("page", "pageSize")(req, &Pagination{})
assert.Equal(t, &Pagination{Page: 2, PageSize: 10}, pagination)
}
func TestValidPagination(t *testing.T) {
opts := []Option{
SetPaginationDefaults(1, 10),
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pagination := PaginationFromContext(r.Context())
assert.NotNil(t, pagination)
assert.Equal(t, 2, pagination.Page)
assert.Equal(t, 10, pagination.PageSize)
})
middleware := Middleware(opts...)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/?page=2&pageSize=10", http.NoBody)
middleware.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
}
func TestInvalidPagination(t *testing.T) {
opts := []Option{
SetPaginationDefaults(1, 10),
SetValidatePaginationFn(func(p *Pagination) error {
if p.Page < 1 || p.PageSize < 1 {
return errors.New("invalid pagination")
}
return nil
}),
}
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
pagination := PaginationFromContext(r.Context())
assert.NotNil(t, pagination)
assert.Equal(t, -1, pagination.Page)
assert.Equal(t, 10, pagination.PageSize)
})
middleware := Middleware(opts...)(handler)
w := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/?page=-1&pageSize=10", http.NoBody)
middleware.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadRequest, w.Code)
}

View File

@ -11,7 +11,7 @@ func Ping(next http.Handler) http.Handler {
if r.Method == "GET" && strings.HasSuffix(strings.ToLower(r.URL.Path), "/ping") {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte("pong")) // nolint:errcheck // No need to check error here
w.Write([]byte("pong")) // nolint:errcheck
return
}

View File

@ -1,7 +1,7 @@
package middleware
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
@ -24,7 +24,7 @@ func TestPing(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := io.ReadAll(res.Body)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
@ -33,7 +33,7 @@ func TestPing(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err = io.ReadAll(res.Body)
b, err = ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "pong", string(b))
@ -42,7 +42,7 @@ func TestPing(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err = io.ReadAll(res.Body)
b, err = ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "pong", string(b))
}

View File

@ -1,29 +0,0 @@
package middleware
import (
"expvar"
"net/http"
"net/http/pprof"
)
// Profiler is a convenient subrouter used for mounting net/http/pprof
func Profiler() http.Handler {
mux := http.NewServeMux()
mux.HandleFunc("/pprof/", pprof.Index)
mux.HandleFunc("/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/pprof/profile", pprof.Profile)
mux.HandleFunc("/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/pprof/trace", pprof.Trace)
mux.Handle("/pprof/goroutine", pprof.Handler("goroutine"))
mux.Handle("/pprof/threadcreate", pprof.Handler("threadcreate"))
mux.Handle("/pprof/mutex", pprof.Handler("mutex"))
mux.Handle("/pprof/heap", pprof.Handler("heap"))
mux.Handle("/pprof/block", pprof.Handler("block"))
mux.Handle("/pprof/allocs", pprof.Handler("allocs"))
mux.Handle("/vars", expvar.Handler())
return Wrap(mux, NoCache)
}

175
ratelimit/rate_limit.go Normal file
View File

@ -0,0 +1,175 @@
package ratelimit
import (
"errors"
"net/http"
"strconv"
"strings"
"sync"
"time"
cache "github.com/go-pkgz/expirable-cache"
"go.pkg.cx/middleware"
)
// Errors
var (
ErrLimitReached = errors.New("limit reached")
)
// DefaultOptions represents default timeout middleware options
var DefaultOptions = Options(
SetLimit(100),
SetPeriod(time.Minute*1),
SetKeyFn(defaultKeyFn),
SetResponseHandler(RespondWithTooManyRequests),
)
// Options turns a list of option instances into an option.
func Options(opts ...Option) Option {
return func(l *limiter) {
for _, opt := range opts {
opt(l)
}
}
}
// Option configures timeout middleware
type Option func(l *limiter)
// SetLimit sets request limit
func SetLimit(limit int) Option {
if limit < 1 {
panic("rate limit middleware expects limit > 0")
}
return func(l *limiter) {
l.limit = limit
}
}
// SetPeriod sets limiter period
func SetPeriod(period time.Duration) Option {
return func(l *limiter) {
l.period = period
}
}
// SetKeyFn sets limiter key extraction function
func SetKeyFn(fn func(r *http.Request) string) Option {
return func(l *limiter) {
l.keyFn = fn
}
}
// SetResponseHandler sets response handler
func SetResponseHandler(fn middleware.ResponseHandle) Option {
return func(l *limiter) {
l.responseHandler = fn
}
}
type info struct {
limit int
remaining int
reset int64
reached bool
}
type entry struct {
count int
expiration time.Time
}
type limiter struct {
limit int
period time.Duration
keyFn func(r *http.Request) string
responseHandler middleware.ResponseHandle
lock sync.Mutex
cache cache.Cache
}
func (s *limiter) initCache() {
c, err := cache.NewCache(cache.TTL(s.period))
if err != nil {
panic(err)
}
s.cache = c
}
func (s *limiter) try(key string) info {
s.lock.Lock()
defer s.lock.Unlock()
now := time.Now()
if e, ok := s.cache.Get(key); ok {
e.(*entry).count++
s.cache.Set(key, e, 0)
return s.infoFromEntry(e.(*entry))
}
e := &entry{count: 1, expiration: now.Add(s.period)}
s.cache.Set(key, e, 0)
return s.infoFromEntry(e)
}
func (s *limiter) infoFromEntry(e *entry) info {
reached := true
remaining := 0
if e.count <= s.limit {
reached = false
remaining = s.limit - e.count
}
return info{
limit: s.limit,
remaining: remaining,
reset: e.expiration.Unix(),
reached: reached,
}
}
// Middleware is a rate limiter middleware
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
l := &limiter{}
opts = append([]Option{DefaultOptions}, opts...)
for _, opt := range opts {
opt(l)
}
l.initCache()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
info := l.try(l.keyFn(r))
w.Header().Add("X-RateLimit-Limit", strconv.Itoa(info.limit))
w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(info.remaining))
w.Header().Add("X-RateLimit-Reset", strconv.FormatInt(info.reset, 10))
if info.reached {
l.responseHandler(w, r, ErrLimitReached)
return
}
next.ServeHTTP(w, r)
})
}
}
// RespondWithTooManyRequests is a default response handler
func RespondWithTooManyRequests(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
}
func defaultKeyFn(r *http.Request) string {
return strings.Split(r.RemoteAddr, ":")[0]
}

View File

@ -0,0 +1,163 @@
package ratelimit
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRateLimitSequential(t *testing.T) {
rateLimit := Middleware(
SetLimit(5),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(rateLimit(handler))
defer server.Close()
for i := 0; i < 10; i++ {
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
if i < 5 {
assert.Equal(t, http.StatusOK, res.StatusCode)
} else {
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
}
}
}
func TestRateLimitConcurrent(t *testing.T) {
rateLimit := Middleware(
SetLimit(5),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(rateLimit(handler))
defer server.Close()
counter := int64(0)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
if res.StatusCode == http.StatusOK {
atomic.AddInt64(&counter, 1)
}
}()
}
wg.Wait()
assert.Equal(t, int64(5), atomic.LoadInt64(&counter))
}
func TestRateLimitHeaders(t *testing.T) {
rateLimit := Middleware(
SetLimit(1),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(rateLimit(handler))
defer server.Close()
now := time.Now()
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit"))
assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining"))
resetTS, err := strconv.Atoi(res.Header.Get("X-RateLimit-Reset"))
assert.NoError(t, err)
assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
res, err = http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit"))
assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining"))
resetTS, err = strconv.Atoi(res.Header.Get("X-RateLimit-Reset"))
assert.NoError(t, err)
assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1)
b, err = ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "Too Many Requests\n", string(b))
}
func TestRateLimitExpiration(t *testing.T) {
rateLimit := Middleware(
SetLimit(5),
SetPeriod(time.Millisecond*500),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(rateLimit(handler))
defer server.Close()
for i := 0; i < 10; i++ {
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
if i < 5 {
assert.Equal(t, http.StatusOK, res.StatusCode)
} else {
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
}
}
time.Sleep(time.Millisecond * 500)
for i := 0; i < 10; i++ {
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
if i < 5 {
assert.Equal(t, http.StatusOK, res.StatusCode)
} else {
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
}
}
}

View File

@ -1,7 +1,7 @@
package middleware
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
@ -26,7 +26,7 @@ func TestRealIPXRealIP(t *testing.T) {
client := http.Client{}
req, err := http.NewRequest("GET", server.URL, http.NoBody)
req, err := http.NewRequest("GET", server.URL, nil)
require.Nil(t, err)
req.Header.Set("X-Real-IP", "3.3.3.3")
@ -35,7 +35,7 @@ func TestRealIPXRealIP(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := io.ReadAll(res.Body)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}
@ -46,7 +46,7 @@ func TestRealIPXForwardedFor(t *testing.T) {
client := http.Client{}
req, err := http.NewRequest("GET", server.URL, http.NoBody)
req, err := http.NewRequest("GET", server.URL, nil)
require.Nil(t, err)
req.Header.Set("X-Forwarded-For", "3.3.3.3")
@ -55,7 +55,7 @@ func TestRealIPXForwardedFor(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
req, err = http.NewRequest("GET", server.URL, http.NoBody)
req, err = http.NewRequest("GET", server.URL, nil)
require.Nil(t, err)
req.Header.Set("X-Forwarded-For", "3.3.3.3, 4.4.4.4, 5.5.5.5")
@ -71,7 +71,7 @@ func TestRealIPBothHeaders(t *testing.T) {
client := http.Client{}
req, err := http.NewRequest("GET", server.URL, http.NoBody)
req, err := http.NewRequest("GET", server.URL, nil)
require.Nil(t, err)
req.Header.Set("X-Real-IP", "3.3.3.3")
req.Header.Set("X-Forwarded-For", "4.4.4.4")

View File

@ -15,7 +15,7 @@ var DefaultOptions = Options(
SetResponseHandler(RespondWithInternalServerError),
)
// Options turns a list of option instances into an option
// Options turns a list of option instances into an option.
func Options(opts ...Option) Option {
return func(r *recoverer) {
for _, opt := range opts {
@ -79,7 +79,7 @@ func Middleware(opts ...Option) func(next http.Handler) http.Handler {
}
// RespondWithInternalServerError is a default response handler
func RespondWithInternalServerError(w http.ResponseWriter, _ *http.Request, _ error) {
func RespondWithInternalServerError(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}

View File

@ -1,7 +1,7 @@
package recoverer
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
@ -47,7 +47,7 @@ func TestRecoverer(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := io.ReadAll(res.Body)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))

View File

@ -22,7 +22,7 @@ var DefaultOptions = Options(
SetResponseHandler(RespondWithTooManyRequests),
)
// Options turns a list of option instances into an option
// Options turns a list of option instances into an option.
func Options(opts ...Option) Option {
return func(t *throttle) {
for _, opt := range opts {
@ -112,7 +112,7 @@ func (s *throttle) setRetryAfterHeader(w http.ResponseWriter, ctxDone bool) {
// Middleware is a throttle middleware that limits number of currently processed requests
// at a time across all users. Note: Throttle is not a rate-limiter per user,
// instead it just puts a ceiling on the number of currently in-flight requests
// instead it just puts a ceiling on the number of currentl in-flight requests
// being processed from the point from where the Throttle middleware is mounted
func Middleware(opts ...Option) func(http.Handler) http.Handler {
t := &throttle{}
@ -171,6 +171,6 @@ func Middleware(opts ...Option) func(http.Handler) http.Handler {
}
// RespondWithTooManyRequests is a default response handler
func RespondWithTooManyRequests(w http.ResponseWriter, _ *http.Request, _ error) {
func RespondWithTooManyRequests(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
}

View File

@ -1,7 +1,7 @@
package throttle
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"sync"
@ -35,7 +35,7 @@ func TestThrottleBacklog(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < 1; i++ {
wg.Add(1)
go func() {
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
@ -43,10 +43,10 @@ func TestThrottleBacklog(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := io.ReadAll(res.Body)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}()
}(i)
}
wg.Wait()
@ -75,12 +75,12 @@ func TestThrottleClientTimeout(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
go func(i int) {
defer wg.Done()
_, err := client.Get(server.URL)
assert.Error(t, err)
}()
}(i)
}
wg.Wait()
@ -110,7 +110,7 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
@ -118,21 +118,21 @@ func TestThrottleTriggerGatewayTimeout(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}()
}(i)
}
time.Sleep(time.Second * 1)
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
}()
}(i)
}
wg.Wait()
@ -162,7 +162,7 @@ func TestThrottleMaximum(t *testing.T) {
for i := 0; i < 20; i++ {
wg.Add(1)
go func() {
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
@ -170,24 +170,75 @@ func TestThrottleMaximum(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := io.ReadAll(res.Body)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}()
}(i)
}
time.Sleep(time.Second * 1)
for i := 0; i < 20; i++ {
wg.Add(1)
go func() {
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
}()
}(i)
}
wg.Wait()
}
func TestThrottleRetryAfter(t *testing.T) {
throttle := Middleware(
SetLimit(10),
SetRetryAfterFn(func(_ bool) time.Duration { return time.Hour * 1 }),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
time.Sleep(time.Second * 3)
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(throttle(handler))
defer server.Close()
client := http.Client{Timeout: time.Second * 60}
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}(i)
}
time.Sleep(time.Second * 1)
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
res, err := client.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
assert.Equal(t, res.Header.Get("Retry-After"), "3600")
}(i)
}
wg.Wait()

View File

@ -14,7 +14,7 @@ var DefaultOptions = Options(
SetResponseHandler(RespondWithTimeout),
)
// Options turns a list of option instances into an option
// Options turns a list of option instances into an option.
func Options(opts ...Option) Option {
return func(t *timeout) {
for _, opt := range opts {
@ -76,6 +76,6 @@ func Middleware(opts ...Option) func(next http.Handler) http.Handler {
}
// RespondWithTimeout is a default response handler
func RespondWithTimeout(w http.ResponseWriter, _ *http.Request, _ error) {
func RespondWithTimeout(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, http.StatusText(http.StatusGatewayTimeout), http.StatusGatewayTimeout)
}

View File

@ -1,7 +1,7 @@
package timeout
import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
@ -32,7 +32,7 @@ func TestTimeoutPass(t *testing.T) {
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
b, err := io.ReadAll(res.Body)
b, err := ioutil.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}
@ -48,8 +48,7 @@ func TestTimeoutTimedOut(t *testing.T) {
return
case <-time.After(time.Second * 1):
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
w.Write([]byte("resp")) // nolint:errcheck
}
})