Remove ratelimit middleware

This commit is contained in:
Anton Zadvorny 2023-12-25 03:54:03 +03:00
parent 5a4fe6a63f
commit b2fbf1f4d7
4 changed files with 0 additions and 347 deletions

1
go.mod
View File

@ -3,7 +3,6 @@ module go.pkg.cx/middleware
go 1.21
require (
github.com/go-pkgz/expirable-cache v1.0.0
github.com/lestrrat-go/jwx/v2 v2.0.18
github.com/stretchr/testify v1.8.4
)

8
go.sum
View File

@ -4,8 +4,6 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs=
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
github.com/go-pkgz/expirable-cache v1.0.0 h1:ns5+1hjY8hntGv8bPaQd9Gr7Jyo+Uw5SLyII40aQdtA=
github.com/go-pkgz/expirable-cache v1.0.0/go.mod h1:GTrEl0X+q0mPNqN6dtcQXksACnzCBQ5k/k1SwXJsZKs=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
@ -16,8 +14,6 @@ github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJG
github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
github.com/lestrrat-go/jwx/v2 v2.0.17 h1:+WavkdKVWO90ECnIzUetOnjY+kcqqw4WXEUmil7sMCE=
github.com/lestrrat-go/jwx/v2 v2.0.17/go.mod h1:G8randPHLGAqhcNCqtt6/V/7E6fvJRl3Sf9z777eTQ0=
github.com/lestrrat-go/jwx/v2 v2.0.18 h1:HHZkYS5wWDDyAiNBwztEtDoX07WDhGEdixm8G06R50o=
github.com/lestrrat-go/jwx/v2 v2.0.18/go.mod h1:fAJ+k5eTgKdDqanzCuK6DAt3W7n3cs2/FX7JhQdk83U=
github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
@ -38,8 +34,6 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g=
golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY=
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
@ -61,14 +55,12 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww=
golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=

View File

@ -1,175 +0,0 @@
package ratelimit
import (
"errors"
"net/http"
"strconv"
"strings"
"sync"
"time"
cache "github.com/go-pkgz/expirable-cache"
"go.pkg.cx/middleware"
)
// Errors
var (
ErrLimitReached = errors.New("limit reached")
)
// DefaultOptions represents default timeout middleware options
var DefaultOptions = Options(
SetLimit(100),
SetPeriod(time.Minute*1),
SetKeyFn(defaultKeyFn),
SetResponseHandler(RespondWithTooManyRequests),
)
// Options turns a list of option instances into an option
func Options(opts ...Option) Option {
return func(l *limiter) {
for _, opt := range opts {
opt(l)
}
}
}
// Option configures timeout middleware
type Option func(l *limiter)
// SetLimit sets request limit
func SetLimit(limit int) Option {
if limit < 1 {
panic("rate limit middleware expects limit > 0")
}
return func(l *limiter) {
l.limit = limit
}
}
// SetPeriod sets limiter period
func SetPeriod(period time.Duration) Option {
return func(l *limiter) {
l.period = period
}
}
// SetKeyFn sets limiter key extraction function
func SetKeyFn(fn func(r *http.Request) string) Option {
return func(l *limiter) {
l.keyFn = fn
}
}
// SetResponseHandler sets response handler
func SetResponseHandler(fn middleware.ResponseHandle) Option {
return func(l *limiter) {
l.responseHandler = fn
}
}
type info struct {
limit int
remaining int
reset int64
reached bool
}
type entry struct {
count int
expiration time.Time
}
type limiter struct {
limit int
period time.Duration
keyFn func(r *http.Request) string
responseHandler middleware.ResponseHandle
lock sync.Mutex
cache cache.Cache
}
func (s *limiter) initCache() {
c, err := cache.NewCache(cache.TTL(s.period))
if err != nil {
panic(err)
}
s.cache = c
}
func (s *limiter) try(key string) info {
s.lock.Lock()
defer s.lock.Unlock()
now := time.Now()
if e, ok := s.cache.Get(key); ok {
e.(*entry).count++
s.cache.Set(key, e, 0)
return s.infoFromEntry(e.(*entry))
}
e := &entry{count: 1, expiration: now.Add(s.period)}
s.cache.Set(key, e, 0)
return s.infoFromEntry(e)
}
func (s *limiter) infoFromEntry(e *entry) info {
reached := true
remaining := 0
if e.count <= s.limit {
reached = false
remaining = s.limit - e.count
}
return info{
limit: s.limit,
remaining: remaining,
reset: e.expiration.Unix(),
reached: reached,
}
}
// Middleware is a rate limiter middleware
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
l := &limiter{}
opts = append([]Option{DefaultOptions}, opts...)
for _, opt := range opts {
opt(l)
}
l.initCache()
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
info := l.try(l.keyFn(r))
w.Header().Add("X-RateLimit-Limit", strconv.Itoa(info.limit))
w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(info.remaining))
w.Header().Add("X-RateLimit-Reset", strconv.FormatInt(info.reset, 10))
if info.reached {
l.responseHandler(w, r, ErrLimitReached)
return
}
next.ServeHTTP(w, r)
})
}
}
// RespondWithTooManyRequests is a default response handler
func RespondWithTooManyRequests(w http.ResponseWriter, r *http.Request, err error) {
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
}
func defaultKeyFn(r *http.Request) string {
return strings.Split(r.RemoteAddr, ":")[0]
}

View File

@ -1,163 +0,0 @@
package ratelimit
import (
"io"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRateLimitSequential(t *testing.T) {
rateLimit := Middleware(
SetLimit(5),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(rateLimit(handler))
defer server.Close()
for i := 0; i < 10; i++ {
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
if i < 5 {
assert.Equal(t, http.StatusOK, res.StatusCode)
} else {
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
}
}
}
func TestRateLimitConcurrent(t *testing.T) {
rateLimit := Middleware(
SetLimit(5),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(rateLimit(handler))
defer server.Close()
counter := int64(0)
var wg sync.WaitGroup
for i := 0; i < 10; i++ {
wg.Add(1)
go func() {
defer wg.Done()
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
if res.StatusCode == http.StatusOK {
atomic.AddInt64(&counter, 1)
}
}()
}
wg.Wait()
assert.Equal(t, int64(5), atomic.LoadInt64(&counter))
}
func TestRateLimitHeaders(t *testing.T) {
rateLimit := Middleware(
SetLimit(1),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(rateLimit(handler))
defer server.Close()
now := time.Now()
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit"))
assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining"))
resetTS, err := strconv.Atoi(res.Header.Get("X-RateLimit-Reset"))
assert.NoError(t, err)
assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1)
b, err := io.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
res, err = http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit"))
assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining"))
resetTS, err = strconv.Atoi(res.Header.Get("X-RateLimit-Reset"))
assert.NoError(t, err)
assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1)
b, err = io.ReadAll(res.Body)
assert.NoError(t, err)
assert.Equal(t, "Too Many Requests\n", string(b))
}
func TestRateLimitExpiration(t *testing.T) {
rateLimit := Middleware(
SetLimit(5),
SetPeriod(time.Millisecond*500),
)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(rateLimit(handler))
defer server.Close()
for i := 0; i < 10; i++ {
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
if i < 5 {
assert.Equal(t, http.StatusOK, res.StatusCode)
} else {
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
}
}
time.Sleep(time.Millisecond * 500)
for i := 0; i < 10; i++ {
res, err := http.Get(server.URL)
assert.NoError(t, err)
defer res.Body.Close()
if i < 5 {
assert.Equal(t, http.StatusOK, res.StatusCode)
} else {
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
}
}
}