diff --git a/go.mod b/go.mod index 6f9a7bb..8256374 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module go.pkg.cx/middleware go 1.21 require ( - github.com/go-pkgz/expirable-cache v1.0.0 github.com/lestrrat-go/jwx/v2 v2.0.18 github.com/stretchr/testify v1.8.4 ) diff --git a/go.sum b/go.sum index 2c1db99..a639f91 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,6 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= -github.com/go-pkgz/expirable-cache v1.0.0 h1:ns5+1hjY8hntGv8bPaQd9Gr7Jyo+Uw5SLyII40aQdtA= -github.com/go-pkgz/expirable-cache v1.0.0/go.mod h1:GTrEl0X+q0mPNqN6dtcQXksACnzCBQ5k/k1SwXJsZKs= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k= @@ -16,8 +14,6 @@ github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJG github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo= github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI= github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4= -github.com/lestrrat-go/jwx/v2 v2.0.17 h1:+WavkdKVWO90ECnIzUetOnjY+kcqqw4WXEUmil7sMCE= -github.com/lestrrat-go/jwx/v2 v2.0.17/go.mod h1:G8randPHLGAqhcNCqtt6/V/7E6fvJRl3Sf9z777eTQ0= github.com/lestrrat-go/jwx/v2 v2.0.18 h1:HHZkYS5wWDDyAiNBwztEtDoX07WDhGEdixm8G06R50o= github.com/lestrrat-go/jwx/v2 v2.0.18/go.mod h1:fAJ+k5eTgKdDqanzCuK6DAt3W7n3cs2/FX7JhQdk83U= github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I= @@ -38,8 +34,6 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= -golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= @@ -61,14 +55,12 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww= golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= diff --git a/ratelimit/rate_limit.go b/ratelimit/rate_limit.go deleted file mode 100644 index 206558d..0000000 --- a/ratelimit/rate_limit.go +++ /dev/null @@ -1,175 +0,0 @@ -package ratelimit - -import ( - "errors" - "net/http" - "strconv" - "strings" - "sync" - "time" - - cache "github.com/go-pkgz/expirable-cache" - - "go.pkg.cx/middleware" -) - -// Errors -var ( - ErrLimitReached = errors.New("limit reached") -) - -// DefaultOptions represents default timeout middleware options -var DefaultOptions = Options( - SetLimit(100), - SetPeriod(time.Minute*1), - SetKeyFn(defaultKeyFn), - SetResponseHandler(RespondWithTooManyRequests), -) - -// Options turns a list of option instances into an option -func Options(opts ...Option) Option { - return func(l *limiter) { - for _, opt := range opts { - opt(l) - } - } -} - -// Option configures timeout middleware -type Option func(l *limiter) - -// SetLimit sets request limit -func SetLimit(limit int) Option { - if limit < 1 { - panic("rate limit middleware expects limit > 0") - } - - return func(l *limiter) { - l.limit = limit - } -} - -// SetPeriod sets limiter period -func SetPeriod(period time.Duration) Option { - return func(l *limiter) { - l.period = period - } -} - -// SetKeyFn sets limiter key extraction function -func SetKeyFn(fn func(r *http.Request) string) Option { - return func(l *limiter) { - l.keyFn = fn - } -} - -// SetResponseHandler sets response handler -func SetResponseHandler(fn middleware.ResponseHandle) Option { - return func(l *limiter) { - l.responseHandler = fn - } -} - -type info struct { - limit int - remaining int - reset int64 - reached bool -} - -type entry struct { - count int - expiration time.Time -} - -type limiter struct { - limit int - period time.Duration - keyFn func(r *http.Request) string - responseHandler middleware.ResponseHandle - - lock sync.Mutex - cache cache.Cache -} - -func (s *limiter) initCache() { - c, err := cache.NewCache(cache.TTL(s.period)) - if err != nil { - panic(err) - } - s.cache = c -} - -func (s *limiter) try(key string) info { - s.lock.Lock() - defer s.lock.Unlock() - - now := time.Now() - - if e, ok := s.cache.Get(key); ok { - e.(*entry).count++ - s.cache.Set(key, e, 0) - - return s.infoFromEntry(e.(*entry)) - } - - e := &entry{count: 1, expiration: now.Add(s.period)} - s.cache.Set(key, e, 0) - - return s.infoFromEntry(e) -} - -func (s *limiter) infoFromEntry(e *entry) info { - reached := true - remaining := 0 - - if e.count <= s.limit { - reached = false - remaining = s.limit - e.count - } - - return info{ - limit: s.limit, - remaining: remaining, - reset: e.expiration.Unix(), - reached: reached, - } -} - -// Middleware is a rate limiter middleware -func Middleware(opts ...Option) func(next http.Handler) http.Handler { - l := &limiter{} - opts = append([]Option{DefaultOptions}, opts...) - - for _, opt := range opts { - opt(l) - } - - l.initCache() - - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - info := l.try(l.keyFn(r)) - - w.Header().Add("X-RateLimit-Limit", strconv.Itoa(info.limit)) - w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(info.remaining)) - w.Header().Add("X-RateLimit-Reset", strconv.FormatInt(info.reset, 10)) - - if info.reached { - l.responseHandler(w, r, ErrLimitReached) - return - } - - next.ServeHTTP(w, r) - }) - } -} - -// RespondWithTooManyRequests is a default response handler -func RespondWithTooManyRequests(w http.ResponseWriter, r *http.Request, err error) { - http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) -} - -func defaultKeyFn(r *http.Request) string { - return strings.Split(r.RemoteAddr, ":")[0] -} diff --git a/ratelimit/rate_limit_test.go b/ratelimit/rate_limit_test.go deleted file mode 100644 index e666768..0000000 --- a/ratelimit/rate_limit_test.go +++ /dev/null @@ -1,163 +0,0 @@ -package ratelimit - -import ( - "io" - "net/http" - "net/http/httptest" - "strconv" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestRateLimitSequential(t *testing.T) { - rateLimit := Middleware( - SetLimit(5), - ) - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("resp")) - require.NoError(t, err) - }) - - server := httptest.NewServer(rateLimit(handler)) - defer server.Close() - - for i := 0; i < 10; i++ { - res, err := http.Get(server.URL) - assert.NoError(t, err) - defer res.Body.Close() - - if i < 5 { - assert.Equal(t, http.StatusOK, res.StatusCode) - } else { - assert.Equal(t, http.StatusTooManyRequests, res.StatusCode) - } - } -} - -func TestRateLimitConcurrent(t *testing.T) { - rateLimit := Middleware( - SetLimit(5), - ) - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("resp")) - require.NoError(t, err) - }) - - server := httptest.NewServer(rateLimit(handler)) - defer server.Close() - - counter := int64(0) - var wg sync.WaitGroup - - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - - res, err := http.Get(server.URL) - assert.NoError(t, err) - defer res.Body.Close() - - if res.StatusCode == http.StatusOK { - atomic.AddInt64(&counter, 1) - } - }() - } - - wg.Wait() - assert.Equal(t, int64(5), atomic.LoadInt64(&counter)) -} - -func TestRateLimitHeaders(t *testing.T) { - rateLimit := Middleware( - SetLimit(1), - ) - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("resp")) - require.NoError(t, err) - }) - - server := httptest.NewServer(rateLimit(handler)) - defer server.Close() - - now := time.Now() - - res, err := http.Get(server.URL) - assert.NoError(t, err) - defer res.Body.Close() - assert.Equal(t, http.StatusOK, res.StatusCode) - assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit")) - assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining")) - - resetTS, err := strconv.Atoi(res.Header.Get("X-RateLimit-Reset")) - assert.NoError(t, err) - assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1) - - b, err := io.ReadAll(res.Body) - assert.NoError(t, err) - assert.Equal(t, "resp", string(b)) - - res, err = http.Get(server.URL) - assert.NoError(t, err) - defer res.Body.Close() - assert.Equal(t, http.StatusTooManyRequests, res.StatusCode) - assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit")) - assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining")) - - resetTS, err = strconv.Atoi(res.Header.Get("X-RateLimit-Reset")) - assert.NoError(t, err) - assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1) - - b, err = io.ReadAll(res.Body) - assert.NoError(t, err) - assert.Equal(t, "Too Many Requests\n", string(b)) -} - -func TestRateLimitExpiration(t *testing.T) { - rateLimit := Middleware( - SetLimit(5), - SetPeriod(time.Millisecond*500), - ) - - handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("resp")) - require.NoError(t, err) - }) - - server := httptest.NewServer(rateLimit(handler)) - defer server.Close() - - for i := 0; i < 10; i++ { - res, err := http.Get(server.URL) - assert.NoError(t, err) - defer res.Body.Close() - - if i < 5 { - assert.Equal(t, http.StatusOK, res.StatusCode) - } else { - assert.Equal(t, http.StatusTooManyRequests, res.StatusCode) - } - } - - time.Sleep(time.Millisecond * 500) - - for i := 0; i < 10; i++ { - res, err := http.Get(server.URL) - assert.NoError(t, err) - defer res.Body.Close() - - if i < 5 { - assert.Equal(t, http.StatusOK, res.StatusCode) - } else { - assert.Equal(t, http.StatusTooManyRequests, res.StatusCode) - } - } -}