middleware/ratelimit/rate_limit.go
2023-12-25 03:47:26 +03:00

176 lines
3.4 KiB
Go

package ratelimit
import (
"errors"
"net/http"
"strconv"
"strings"
"sync"
"time"
cache "github.com/go-pkgz/expirable-cache"
"go.pkg.cx/middleware"
)
// ErrLimitReached is the error handed to the configured response handler when
// a request exceeds the limit for its key.
var ErrLimitReached = errors.New("limit reached")
// DefaultOptions are the rate limit middleware defaults: 100 requests per
// one-minute period, keyed by the host part of the request's RemoteAddr, with
// over-limit requests answered by RespondWithTooManyRequests. Middleware
// applies these before any caller-supplied options, so callers only need to
// pass the settings they want to change.
var DefaultOptions = Options(
	SetLimit(100),
	SetPeriod(time.Minute),
	SetKeyFn(defaultKeyFn),
	SetResponseHandler(RespondWithTooManyRequests),
)
// Options combines opts into a single Option that applies each of them to the
// limiter in the order given; when two options set the same field, the later
// one wins.
func Options(opts ...Option) Option {
	return func(l *limiter) {
		for i := range opts {
			opts[i](l)
		}
	}
}
// Option configures the rate limit middleware. Options are applied, in order,
// to the limiter that Middleware builds.
type Option func(l *limiter)
// SetLimit sets how many requests a single key may make per period.
// It panics at call time (not when the option is applied) if limit < 1.
func SetLimit(limit int) Option {
	if limit > 0 {
		return func(l *limiter) { l.limit = limit }
	}
	panic("rate limit middleware expects limit > 0")
}
// SetPeriod sets the length of the window over which requests are counted.
//
// It panics at call time if period is not positive, mirroring SetLimit's
// validation: with period <= 0 a counter's reset time (now + period) is
// already due when the counter is created, so the limit could never be
// enforced meaningfully.
func SetPeriod(period time.Duration) Option {
	if period <= 0 {
		panic("rate limit middleware expects period > 0")
	}
	return func(l *limiter) {
		l.period = period
	}
}
// SetKeyFn sets the function that maps a request to the key it is counted
// under; requests producing the same key share one quota.
func SetKeyFn(fn func(r *http.Request) string) Option {
	var apply Option = func(l *limiter) { l.keyFn = fn }
	return apply
}
// SetResponseHandler sets the handler that Middleware invokes, with
// ErrLimitReached, for requests that exceed the limit.
func SetResponseHandler(fn middleware.ResponseHandle) Option {
	apply := func(l *limiter) { l.responseHandler = fn }
	return apply
}
// info is the quota state produced for one request; Middleware copies it
// into the X-RateLimit-* response headers.
type info struct {
	limit, remaining int   // configured limit, and requests left (0 once reached)
	reset            int64 // end of the counting window as Unix seconds
	reached          bool  // true when this request went over the limit
}
// entry is the per-key counter kept in the limiter's cache.
type entry struct {
	count      int       // requests recorded so far, including the current one
	expiration time.Time // when the counting window for this key ends
}
// limiter holds the configuration and per-key counters of one Middleware
// instance. It contains a mutex and must only be used through a pointer.
type limiter struct {
	// Configuration, populated by Option functions.
	limit           int                          // max requests per key per period
	period          time.Duration                // window length, also passed as the cache TTL
	keyFn           func(r *http.Request) string // derives the counting key from a request
	responseHandler middleware.ResponseHandle    // called for over-limit requests

	// Runtime state.
	lock  sync.Mutex  // serializes try's read-modify-write of cache entries
	cache cache.Cache // key -> *entry
}
// initCache creates the cache holding per-key entries, passing the configured
// period as its TTL. It panics if the cache cannot be created, since the
// middleware cannot operate without it.
//
// NOTE(review): nothing in this file removes expired entries or caps the
// number of keys, so memory may grow with the number of distinct keys seen —
// confirm whether expirable-cache evicts on its own, or add periodic cleanup
// / a key cap.
func (s *limiter) initCache() {
	c, err := cache.NewCache(cache.TTL(s.period))
	if err == nil {
		s.cache = c
		return
	}
	panic(err)
}
// try records one request for key and returns the resulting quota state.
//
// Counting uses fixed windows. The first request for a key stores an entry
// expiring one period later. Later requests increment that same entry until
// its expiration passes, after which a fresh window starts at count 1.
//
// The expiration is checked explicitly, and an existing entry is mutated in
// place rather than stored again with Set. Re-storing it would refresh its
// cache TTL (NOTE(review): per expirable-cache Set semantics — confirm for the
// pinned version). The window would then slide: a client sending at least one
// request per period would never be reset, while X-RateLimit-Reset kept
// reporting the original expiration.
func (s *limiter) try(key string) info {
	s.lock.Lock()
	defer s.lock.Unlock()

	now := time.Now()
	if v, ok := s.cache.Get(key); ok {
		if e := v.(*entry); now.Before(e.expiration) {
			// The cache stores the *entry pointer and all access happens under
			// s.lock, so incrementing in place updates the cached value
			// without touching its TTL.
			e.count++
			return s.infoFromEntry(e)
		}
	}

	// No live window for this key: start a new one.
	e := &entry{count: 1, expiration: now.Add(s.period)}
	s.cache.Set(key, e, s.period)
	return s.infoFromEntry(e)
}
// infoFromEntry converts a counter entry into the quota state reported to the
// client. The limit counts as reached only once count exceeds it; in that case
// remaining is reported as 0 rather than going negative.
func (s *limiter) infoFromEntry(e *entry) info {
	res := info{
		limit:   s.limit,
		reset:   e.expiration.Unix(),
		reached: e.count > s.limit,
	}
	if !res.reached {
		res.remaining = s.limit - e.count
	}
	return res
}
// Middleware returns a rate limiting middleware. DefaultOptions are applied
// first and then opts in order, so caller options override the defaults.
//
// Every response carries X-RateLimit-Limit, X-RateLimit-Remaining and
// X-RateLimit-Reset (Unix seconds) headers. A request over the limit is passed
// to the configured response handler with ErrLimitReached and does not reach
// next.
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
	l := &limiter{}
	all := make([]Option, 0, len(opts)+1)
	all = append(all, DefaultOptions)
	all = append(all, opts...)
	Options(all...)(l)
	l.initCache()

	return func(next http.Handler) http.Handler {
		handle := func(w http.ResponseWriter, r *http.Request) {
			res := l.try(l.keyFn(r))

			h := w.Header()
			h.Add("X-RateLimit-Limit", strconv.Itoa(res.limit))
			h.Add("X-RateLimit-Remaining", strconv.Itoa(res.remaining))
			h.Add("X-RateLimit-Reset", strconv.FormatInt(res.reset, 10))

			if !res.reached {
				next.ServeHTTP(w, r)
				return
			}
			l.responseHandler(w, r, ErrLimitReached)
		}
		return http.HandlerFunc(handle)
	}
}
// RespondWithTooManyRequests is the default over-limit response handler: it
// replies with HTTP 429 and the standard status text. The request and error
// arguments are ignored.
func RespondWithTooManyRequests(w http.ResponseWriter, r *http.Request, err error) {
	const code = http.StatusTooManyRequests
	http.Error(w, http.StatusText(code), code)
}
// defaultKeyFn keys requests by the host part of r.RemoteAddr.
//
// RemoteAddr has the form "host:port" for IPv4/hostnames and "[host]:port"
// for IPv6. Naively splitting on the first ':' would turn every bracketed
// IPv6 address into "[" and make all IPv6 clients share one quota, so both
// forms are parsed explicitly. Addresses without a port are returned as-is.
func defaultKeyFn(r *http.Request) string {
	addr := r.RemoteAddr

	// Bracketed IPv6: "[2001:db8::1]:443" or "[2001:db8::1]" -> "2001:db8::1".
	if strings.HasPrefix(addr, "[") {
		if end := strings.IndexByte(addr, ']'); end > 0 {
			return addr[1:end]
		}
		return addr
	}

	// "host:port" has exactly one colon. More than one colon means an
	// unbracketed IPv6 literal without a port, which is already a bare host.
	if strings.Count(addr, ":") == 1 {
		return addr[:strings.IndexByte(addr, ':')]
	}
	return addr
}