Remove ratelimit middleware
This commit is contained in:
parent
5a4fe6a63f
commit
b2fbf1f4d7
1
go.mod
1
go.mod
@ -3,7 +3,6 @@ module go.pkg.cx/middleware
|
||||
go 1.21
|
||||
|
||||
require (
|
||||
github.com/go-pkgz/expirable-cache v1.0.0
|
||||
github.com/lestrrat-go/jwx/v2 v2.0.18
|
||||
github.com/stretchr/testify v1.8.4
|
||||
)
|
||||
|
8
go.sum
8
go.sum
@ -4,8 +4,6 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs
|
||||
github.com/decred/dcrd/crypto/blake256 v1.0.1/go.mod h1:2OfgNZ5wDpcsFmHmCK5gZTPcCXqlm2ArzUIkw9czNJo=
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0 h1:8UrgZ3GkP4i/CLijOJx79Yu+etlyjdBU4sfcs2WYQMs=
|
||||
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0=
|
||||
github.com/go-pkgz/expirable-cache v1.0.0 h1:ns5+1hjY8hntGv8bPaQd9Gr7Jyo+Uw5SLyII40aQdtA=
|
||||
github.com/go-pkgz/expirable-cache v1.0.0/go.mod h1:GTrEl0X+q0mPNqN6dtcQXksACnzCBQ5k/k1SwXJsZKs=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/lestrrat-go/blackmagic v1.0.2 h1:Cg2gVSc9h7sz9NOByczrbUvLopQmXrfFx//N+AkAr5k=
|
||||
@ -16,8 +14,6 @@ github.com/lestrrat-go/httprc v1.0.4 h1:bAZymwoZQb+Oq8MEbyipag7iSq6YIga8Wj6GOiJG
|
||||
github.com/lestrrat-go/httprc v1.0.4/go.mod h1:mwwz3JMTPBjHUkkDv/IGJ39aALInZLrhBp0X7KGUZlo=
|
||||
github.com/lestrrat-go/iter v1.0.2 h1:gMXo1q4c2pHmC3dn8LzRhJfP1ceCbgSiT9lUydIzltI=
|
||||
github.com/lestrrat-go/iter v1.0.2/go.mod h1:Momfcq3AnRlRjI5b5O8/G5/BvpzrhoFTZcn06fEOPt4=
|
||||
github.com/lestrrat-go/jwx/v2 v2.0.17 h1:+WavkdKVWO90ECnIzUetOnjY+kcqqw4WXEUmil7sMCE=
|
||||
github.com/lestrrat-go/jwx/v2 v2.0.17/go.mod h1:G8randPHLGAqhcNCqtt6/V/7E6fvJRl3Sf9z777eTQ0=
|
||||
github.com/lestrrat-go/jwx/v2 v2.0.18 h1:HHZkYS5wWDDyAiNBwztEtDoX07WDhGEdixm8G06R50o=
|
||||
github.com/lestrrat-go/jwx/v2 v2.0.18/go.mod h1:fAJ+k5eTgKdDqanzCuK6DAt3W7n3cs2/FX7JhQdk83U=
|
||||
github.com/lestrrat-go/option v1.0.0/go.mod h1:5ZHFbivi4xwXxhxY9XHDe2FHo6/Z7WWmtT7T5nBBp3I=
|
||||
@ -38,8 +34,6 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g=
|
||||
golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY=
|
||||
golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
|
||||
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
|
||||
@ -61,14 +55,12 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
|
||||
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.14.0/go.mod h1:TySc+nGkYR6qt8km8wUhuFRTVSMIX3XPR58y2lC8vww=
|
||||
golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
|
@ -1,175 +0,0 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
cache "github.com/go-pkgz/expirable-cache"
|
||||
|
||||
"go.pkg.cx/middleware"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrLimitReached = errors.New("limit reached")
|
||||
)
|
||||
|
||||
// DefaultOptions represents default timeout middleware options
|
||||
var DefaultOptions = Options(
|
||||
SetLimit(100),
|
||||
SetPeriod(time.Minute*1),
|
||||
SetKeyFn(defaultKeyFn),
|
||||
SetResponseHandler(RespondWithTooManyRequests),
|
||||
)
|
||||
|
||||
// Options turns a list of option instances into an option
|
||||
func Options(opts ...Option) Option {
|
||||
return func(l *limiter) {
|
||||
for _, opt := range opts {
|
||||
opt(l)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Option configures timeout middleware
|
||||
type Option func(l *limiter)
|
||||
|
||||
// SetLimit sets request limit
|
||||
func SetLimit(limit int) Option {
|
||||
if limit < 1 {
|
||||
panic("rate limit middleware expects limit > 0")
|
||||
}
|
||||
|
||||
return func(l *limiter) {
|
||||
l.limit = limit
|
||||
}
|
||||
}
|
||||
|
||||
// SetPeriod sets limiter period
|
||||
func SetPeriod(period time.Duration) Option {
|
||||
return func(l *limiter) {
|
||||
l.period = period
|
||||
}
|
||||
}
|
||||
|
||||
// SetKeyFn sets limiter key extraction function
|
||||
func SetKeyFn(fn func(r *http.Request) string) Option {
|
||||
return func(l *limiter) {
|
||||
l.keyFn = fn
|
||||
}
|
||||
}
|
||||
|
||||
// SetResponseHandler sets response handler
|
||||
func SetResponseHandler(fn middleware.ResponseHandle) Option {
|
||||
return func(l *limiter) {
|
||||
l.responseHandler = fn
|
||||
}
|
||||
}
|
||||
|
||||
type info struct {
|
||||
limit int
|
||||
remaining int
|
||||
reset int64
|
||||
reached bool
|
||||
}
|
||||
|
||||
type entry struct {
|
||||
count int
|
||||
expiration time.Time
|
||||
}
|
||||
|
||||
type limiter struct {
|
||||
limit int
|
||||
period time.Duration
|
||||
keyFn func(r *http.Request) string
|
||||
responseHandler middleware.ResponseHandle
|
||||
|
||||
lock sync.Mutex
|
||||
cache cache.Cache
|
||||
}
|
||||
|
||||
func (s *limiter) initCache() {
|
||||
c, err := cache.NewCache(cache.TTL(s.period))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
s.cache = c
|
||||
}
|
||||
|
||||
func (s *limiter) try(key string) info {
|
||||
s.lock.Lock()
|
||||
defer s.lock.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
if e, ok := s.cache.Get(key); ok {
|
||||
e.(*entry).count++
|
||||
s.cache.Set(key, e, 0)
|
||||
|
||||
return s.infoFromEntry(e.(*entry))
|
||||
}
|
||||
|
||||
e := &entry{count: 1, expiration: now.Add(s.period)}
|
||||
s.cache.Set(key, e, 0)
|
||||
|
||||
return s.infoFromEntry(e)
|
||||
}
|
||||
|
||||
func (s *limiter) infoFromEntry(e *entry) info {
|
||||
reached := true
|
||||
remaining := 0
|
||||
|
||||
if e.count <= s.limit {
|
||||
reached = false
|
||||
remaining = s.limit - e.count
|
||||
}
|
||||
|
||||
return info{
|
||||
limit: s.limit,
|
||||
remaining: remaining,
|
||||
reset: e.expiration.Unix(),
|
||||
reached: reached,
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware is a rate limiter middleware
|
||||
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
|
||||
l := &limiter{}
|
||||
opts = append([]Option{DefaultOptions}, opts...)
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(l)
|
||||
}
|
||||
|
||||
l.initCache()
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
info := l.try(l.keyFn(r))
|
||||
|
||||
w.Header().Add("X-RateLimit-Limit", strconv.Itoa(info.limit))
|
||||
w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(info.remaining))
|
||||
w.Header().Add("X-RateLimit-Reset", strconv.FormatInt(info.reset, 10))
|
||||
|
||||
if info.reached {
|
||||
l.responseHandler(w, r, ErrLimitReached)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RespondWithTooManyRequests is a default response handler
|
||||
func RespondWithTooManyRequests(w http.ResponseWriter, r *http.Request, err error) {
|
||||
http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
func defaultKeyFn(r *http.Request) string {
|
||||
return strings.Split(r.RemoteAddr, ":")[0]
|
||||
}
|
@ -1,163 +0,0 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRateLimitSequential(t *testing.T) {
|
||||
rateLimit := Middleware(
|
||||
SetLimit(5),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(rateLimit(handler))
|
||||
defer server.Close()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
if i < 5 {
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
} else {
|
||||
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitConcurrent(t *testing.T) {
|
||||
rateLimit := Middleware(
|
||||
SetLimit(5),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(rateLimit(handler))
|
||||
defer server.Close()
|
||||
|
||||
counter := int64(0)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode == http.StatusOK {
|
||||
atomic.AddInt64(&counter, 1)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
assert.Equal(t, int64(5), atomic.LoadInt64(&counter))
|
||||
}
|
||||
|
||||
func TestRateLimitHeaders(t *testing.T) {
|
||||
rateLimit := Middleware(
|
||||
SetLimit(1),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(rateLimit(handler))
|
||||
defer server.Close()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit"))
|
||||
assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining"))
|
||||
|
||||
resetTS, err := strconv.Atoi(res.Header.Get("X-RateLimit-Reset"))
|
||||
assert.NoError(t, err)
|
||||
assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1)
|
||||
|
||||
b, err := io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "resp", string(b))
|
||||
|
||||
res, err = http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
||||
assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit"))
|
||||
assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining"))
|
||||
|
||||
resetTS, err = strconv.Atoi(res.Header.Get("X-RateLimit-Reset"))
|
||||
assert.NoError(t, err)
|
||||
assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1)
|
||||
|
||||
b, err = io.ReadAll(res.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Too Many Requests\n", string(b))
|
||||
}
|
||||
|
||||
func TestRateLimitExpiration(t *testing.T) {
|
||||
rateLimit := Middleware(
|
||||
SetLimit(5),
|
||||
SetPeriod(time.Millisecond*500),
|
||||
)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(rateLimit(handler))
|
||||
defer server.Close()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
if i < 5 {
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
} else {
|
||||
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond * 500)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
res, err := http.Get(server.URL)
|
||||
assert.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
|
||||
if i < 5 {
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
} else {
|
||||
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user