package ratelimit import ( "errors" "net/http" "strconv" "strings" "sync" "time" cache "github.com/go-pkgz/expirable-cache" "go.pkg.cx/middleware" ) // Errors var ( ErrLimitReached = errors.New("limit reached") ) // DefaultOptions represents default timeout middleware options var DefaultOptions = Options( SetLimit(100), SetPeriod(time.Minute*1), SetKeyFn(defaultKeyFn), SetResponseHandler(RespondWithTooManyRequests), ) // Options turns a list of option instances into an option. func Options(opts ...Option) Option { return func(l *limiter) { for _, opt := range opts { opt(l) } } } // Option configures timeout middleware type Option func(l *limiter) // SetLimit sets request limit func SetLimit(limit int) Option { if limit < 1 { panic("rate limit middleware expects limit > 0") } return func(l *limiter) { l.limit = limit } } // SetPeriod sets limiter period func SetPeriod(period time.Duration) Option { return func(l *limiter) { l.period = period } } // SetKeyFn sets limiter key extraction function func SetKeyFn(fn func(r *http.Request) string) Option { return func(l *limiter) { l.keyFn = fn } } // SetResponseHandler sets response handler func SetResponseHandler(fn middleware.ResponseHandle) Option { return func(l *limiter) { l.responseHandler = fn } } type info struct { limit int remaining int reset int64 reached bool } type entry struct { count int expiration time.Time } type limiter struct { limit int period time.Duration keyFn func(r *http.Request) string responseHandler middleware.ResponseHandle lock sync.Mutex cache cache.Cache } func (s *limiter) initCache() { c, err := cache.NewCache(cache.TTL(s.period)) if err != nil { panic(err) } s.cache = c } func (s *limiter) try(key string) info { s.lock.Lock() defer s.lock.Unlock() now := time.Now() if e, ok := s.cache.Get(key); ok { e.(*entry).count++ s.cache.Set(key, e, 0) return s.infoFromEntry(e.(*entry)) } e := &entry{count: 1, expiration: now.Add(s.period)} s.cache.Set(key, e, 0) return s.infoFromEntry(e) } func (s *limiter) infoFromEntry(e *entry) info { reached := true remaining := 0 if e.count <= s.limit { reached = false remaining = s.limit - e.count } return info{ limit: s.limit, remaining: remaining, reset: e.expiration.Unix(), reached: reached, } } // Middleware is a rate limiter middleware func Middleware(opts ...Option) func(next http.Handler) http.Handler { l := &limiter{} opts = append([]Option{DefaultOptions}, opts...) for _, opt := range opts { opt(l) } l.initCache() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { info := l.try(l.keyFn(r)) w.Header().Add("X-RateLimit-Limit", strconv.Itoa(info.limit)) w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(info.remaining)) w.Header().Add("X-RateLimit-Reset", strconv.FormatInt(info.reset, 10)) if info.reached { l.responseHandler(w, r, ErrLimitReached) return } next.ServeHTTP(w, r) }) } } // RespondWithTooManyRequests is a default response handler func RespondWithTooManyRequests(w http.ResponseWriter, r *http.Request, err error) { http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) } func defaultKeyFn(r *http.Request) string { return strings.Split(r.RemoteAddr, ":")[0] }