176 lines
3.4 KiB
Go
176 lines
3.4 KiB
Go
package ratelimit
|
|
|
|
import (
|
|
"errors"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
cache "github.com/go-pkgz/expirable-cache"
|
|
|
|
"go.pkg.cx/middleware"
|
|
)
|
|
|
|
// ErrLimitReached is the error handed to the response handler when a request
// exceeds the configured limit for its key.
var ErrLimitReached = errors.New("limit reached")
|
|
|
|
// DefaultOptions represents default timeout middleware options
|
|
var DefaultOptions = Options(
|
|
SetLimit(100),
|
|
SetPeriod(time.Minute*1),
|
|
SetKeyFn(defaultKeyFn),
|
|
SetResponseHandler(RespondWithTooManyRequests),
|
|
)
|
|
|
|
// Options turns a list of option instances into an option
|
|
func Options(opts ...Option) Option {
|
|
return func(l *limiter) {
|
|
for _, opt := range opts {
|
|
opt(l)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Option configures timeout middleware
|
|
type Option func(l *limiter)
|
|
|
|
// SetLimit sets request limit
|
|
func SetLimit(limit int) Option {
|
|
if limit < 1 {
|
|
panic("rate limit middleware expects limit > 0")
|
|
}
|
|
|
|
return func(l *limiter) {
|
|
l.limit = limit
|
|
}
|
|
}
|
|
|
|
// SetPeriod sets limiter period
|
|
func SetPeriod(period time.Duration) Option {
|
|
return func(l *limiter) {
|
|
l.period = period
|
|
}
|
|
}
|
|
|
|
// SetKeyFn sets limiter key extraction function
|
|
func SetKeyFn(fn func(r *http.Request) string) Option {
|
|
return func(l *limiter) {
|
|
l.keyFn = fn
|
|
}
|
|
}
|
|
|
|
// SetResponseHandler sets response handler
|
|
func SetResponseHandler(fn middleware.ResponseHandle) Option {
|
|
return func(l *limiter) {
|
|
l.responseHandler = fn
|
|
}
|
|
}
|
|
|
|
// info is the result of counting one request against a key; the middleware
// turns it into the X-RateLimit-* response headers.
type info struct {
	// limit is the configured allowance; remaining is what is left of it
	// (0 once the allowance is used up).
	limit, remaining int
	// reset is the Unix time, in seconds, at which the key's counter expires.
	reset int64
	// reached is true when the request went over the limit.
	reached bool
}
|
|
|
|
// entry is the per-key counter kept in the cache.
type entry struct {
	// count is the number of requests seen for the key so far.
	count int
	// expiration is the moment the counter stops being valid; it is reported
	// to clients as the reset time.
	expiration time.Time
}
|
|
|
|
type limiter struct {
|
|
limit int
|
|
period time.Duration
|
|
keyFn func(r *http.Request) string
|
|
responseHandler middleware.ResponseHandle
|
|
|
|
lock sync.Mutex
|
|
cache cache.Cache
|
|
}
|
|
|
|
func (s *limiter) initCache() {
|
|
c, err := cache.NewCache(cache.TTL(s.period))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
s.cache = c
|
|
}
|
|
|
|
func (s *limiter) try(key string) info {
|
|
s.lock.Lock()
|
|
defer s.lock.Unlock()
|
|
|
|
now := time.Now()
|
|
|
|
if e, ok := s.cache.Get(key); ok {
|
|
e.(*entry).count++
|
|
s.cache.Set(key, e, 0)
|
|
|
|
return s.infoFromEntry(e.(*entry))
|
|
}
|
|
|
|
e := &entry{count: 1, expiration: now.Add(s.period)}
|
|
s.cache.Set(key, e, 0)
|
|
|
|
return s.infoFromEntry(e)
|
|
}
|
|
|
|
func (s *limiter) infoFromEntry(e *entry) info {
|
|
reached := true
|
|
remaining := 0
|
|
|
|
if e.count <= s.limit {
|
|
reached = false
|
|
remaining = s.limit - e.count
|
|
}
|
|
|
|
return info{
|
|
limit: s.limit,
|
|
remaining: remaining,
|
|
reset: e.expiration.Unix(),
|
|
reached: reached,
|
|
}
|
|
}
|
|
|
|
// Middleware is a rate limiter middleware
|
|
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
|
|
l := &limiter{}
|
|
opts = append([]Option{DefaultOptions}, opts...)
|
|
|
|
for _, opt := range opts {
|
|
opt(l)
|
|
}
|
|
|
|
l.initCache()
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
info := l.try(l.keyFn(r))
|
|
|
|
w.Header().Add("X-RateLimit-Limit", strconv.Itoa(info.limit))
|
|
w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(info.remaining))
|
|
w.Header().Add("X-RateLimit-Reset", strconv.FormatInt(info.reset, 10))
|
|
|
|
if info.reached {
|
|
l.responseHandler(w, r, ErrLimitReached)
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
// RespondWithTooManyRequests is the default response handler. It replies with
// status 429 and the standard status text as the body; r and err are not
// used.
func RespondWithTooManyRequests(w http.ResponseWriter, r *http.Request, err error) {
	const status = http.StatusTooManyRequests

	http.Error(w, http.StatusText(status), status)
}
|
|
|
|
// defaultKeyFn is the default key extraction function: it keys a request by
// the host part of r.RemoteAddr, with the port removed.
//
// Supported forms:
//   - "host:port" and "host"     -> "host"
//   - "[ipv6]:port" and "[ipv6]" -> "ipv6" (brackets removed)
//   - bare IPv6 without a port   -> returned unchanged
//
// Splitting at the first colon, as was done before, turned every bracketed
// IPv6 address into the key "[", so all IPv6 clients shared one counter.
func defaultKeyFn(r *http.Request) string {
	addr := r.RemoteAddr

	// Bracketed IPv6 literal: the host is whatever sits between the brackets.
	if strings.HasPrefix(addr, "[") {
		if end := strings.IndexByte(addr, ']'); end > 0 {
			return addr[1:end]
		}

		return addr
	}

	// Several colons without brackets cannot be host:port; treat the whole
	// value as an IPv6 address that carries no port.
	if strings.Count(addr, ":") > 1 {
		return addr
	}

	if sep := strings.IndexByte(addr, ':'); sep >= 0 {
		return addr[:sep]
	}

	return addr
}
|