middleware/throttle/throttle.go
2024-01-07 06:09:11 +03:00

172 lines
3.8 KiB
Go

package throttle
import (
"errors"
"net/http"
"strconv"
"time"
"go.pkg.cx/middleware"
)
// Errors passed to the response handler when a request is rejected.
var (
	// ErrCapacityExceeded is reported when every admission slot (processing
	// plus backlog) is already taken.
	ErrCapacityExceeded = errors.New("server capacity exceeded")
	// ErrTimedOut is reported when a backlogged request gives up waiting for a
	// processing slot.
	ErrTimedOut = errors.New("timed out while waiting for a pending request to complete")
	// ErrContextCanceled is reported when the request context is done before a
	// processing slot is obtained.
	ErrContextCanceled = errors.New("context was canceled")
)
// DefaultOptions represents default throttle middleware options: at most 100
// requests processed at a time, a one-minute backlog timeout, and rejections
// answered by middleware.RespondWithTooManyRequests. The backlog limit is not
// set here, so it stays at its zero value unless SetBacklogLimit is used.
//
// The backlog timeout default matters once SetBacklogLimit enables queueing:
// with a zero timeout, queued requests would be rejected with ErrTimedOut
// immediately instead of waiting for a free slot.
var DefaultOptions = Options(
	SetLimit(100),
	SetBacklogTimeout(time.Minute),
	SetResponseHandler(middleware.RespondWithTooManyRequests),
)
// Option configures the throttle middleware by mutating its settings.
type Option func(t *throttle)

// Options bundles several options into a single Option. When the result is
// applied, each bundled option is applied in the order given, so a later
// option overrides an earlier one that sets the same field.
func Options(opts ...Option) Option {
	return func(t *throttle) {
		for i := range opts {
			opts[i](t)
		}
	}
}
// SetLimit sets the maximum number of requests processed concurrently.
// It panics when limit is less than 1.
func SetLimit(limit int) Option {
	if limit >= 1 {
		return func(t *throttle) { t.limit = limit }
	}
	panic("throttle middleware expects limit > 0")
}
// SetBacklogLimit sets how many requests, beyond the processing limit, may be
// admitted to wait for a free processing slot. It panics when backlogLimit is
// negative.
func SetBacklogLimit(backlogLimit int) Option {
	if backlogLimit >= 0 {
		return func(t *throttle) { t.backlogLimit = backlogLimit }
	}
	panic("throttle middleware expects backlogLimit >= 0")
}
// SetBacklogTimeout sets how long an admitted request waits for a free
// processing slot before it is rejected with ErrTimedOut.
func SetBacklogTimeout(backlogTimeout time.Duration) Option {
	apply := func(t *throttle) {
		t.backlogTimeout = backlogTimeout
	}
	return apply
}
// SetRetryAfterFn sets the function that computes the Retry-After header for
// rejected requests. fn is told whether the rejection happened because the
// request context was done. A nil fn means no Retry-After header is set.
func SetRetryAfterFn(fn func(ctxDone bool) time.Duration) Option {
	return func(t *throttle) { t.retryAfterFn = fn }
}
// SetResponseHandler sets the handler that writes the response for a rejected
// request. It receives the rejection reason: ErrCapacityExceeded, ErrTimedOut
// or ErrContextCanceled.
//
// It panics when fn is nil. The middleware calls the handler unconditionally,
// so a nil handler would otherwise only fail later, while serving a request.
func SetResponseHandler(fn middleware.ResponseHandle) Option {
	if fn == nil {
		panic("throttle middleware expects non-nil response handler")
	}
	return func(t *throttle) {
		t.responseHandler = fn
	}
}
// throttle holds the middleware configuration together with the token pools
// used to bound how many requests are admitted and processed.
type throttle struct {
	// limit is the number of processing slots (size of tokens).
	limit int
	// backlogLimit is how many extra requests may be admitted to wait for a
	// processing slot.
	backlogLimit int
	// backlogTimeout bounds how long an admitted request waits for a slot.
	backlogTimeout time.Duration
	// retryAfterFn, when non-nil, computes the Retry-After header for rejected
	// requests; its argument reports whether the request context was done.
	retryAfterFn func(ctxDone bool) time.Duration
	// responseHandler writes the response for rejected requests.
	responseHandler middleware.ResponseHandle
	// tokens holds one token per free processing slot (capacity limit).
	tokens chan struct{}
	// backlogTokens holds one token per free admission slot, processing or
	// waiting (capacity limit+backlogLimit).
	backlogTokens chan struct{}
}
// initTokens allocates both token pools and fills them completely: tokens
// receives limit tokens and backlogTokens receives limit+backlogLimit tokens.
// Must run after all options have been applied.
func (s *throttle) initTokens() {
	admitted := s.limit + s.backlogLimit

	s.tokens = make(chan struct{}, s.limit)
	for n := 0; n < s.limit; n++ {
		s.tokens <- struct{}{}
	}

	s.backlogTokens = make(chan struct{}, admitted)
	for n := 0; n < admitted; n++ {
		s.backlogTokens <- struct{}{}
	}
}
// setRetryAfterHeader sets the Retry-After response header, in whole seconds,
// from retryAfterFn. It does nothing when no retry-after function is
// configured. ctxDone is forwarded to retryAfterFn and reports whether the
// request was rejected because its context was done.
//
// Retry-After delay-seconds must be a non-negative integer. Negative durations
// are therefore clamped to 0, and fractional seconds are rounded up. Truncating
// would turn e.g. 500ms into "0", which tells clients to retry immediately.
func (s *throttle) setRetryAfterHeader(w http.ResponseWriter, ctxDone bool) {
	if s.retryAfterFn == nil {
		return
	}

	delay := s.retryAfterFn(ctxDone)
	if delay < 0 {
		delay = 0
	}
	seconds := delay / time.Second
	if delay%time.Second != 0 {
		seconds++
	}

	w.Header().Set("Retry-After", strconv.FormatInt(int64(seconds), 10))
}
// Middleware is a throttle middleware that limits the number of requests
// processed at a time across all users. Note: Throttle is not a per-user
// rate-limiter; it only puts a ceiling on the number of in-flight requests
// being processed from the point where the Throttle middleware is mounted.
//
// DefaultOptions are applied first and then opts in order, so opts override
// the defaults. Each request is handled as follows:
//
//   - If all limit+backlogLimit admission slots are taken, it is rejected with
//     ErrCapacityExceeded.
//   - Otherwise, if a processing slot is free, it is served immediately.
//   - Otherwise it waits up to backlogTimeout for a processing slot. It is
//     rejected with ErrTimedOut on timeout, or with ErrContextCanceled if its
//     context is done first.
//
// Every rejection sets the Retry-After header (when a retry-after function is
// configured) and then calls the configured response handler.
func Middleware(opts ...Option) func(http.Handler) http.Handler {
	t := &throttle{}
	// Defaults go first so that caller-supplied options override them.
	for _, opt := range append([]Option{DefaultOptions}, opts...) {
		opt(t)
	}
	t.initTokens()

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			// Admission: claim a backlog slot without blocking.
			select {
			case <-ctx.Done():
				t.setRetryAfterHeader(w, true)
				t.responseHandler(w, r, ErrContextCanceled)
				return
			case btok := <-t.backlogTokens:
				defer func() { t.backlogTokens <- btok }()
			default:
				t.setRetryAfterHeader(w, false)
				t.responseHandler(w, r, ErrCapacityExceeded)
				return
			}

			// Fast path: a processing slot is free right now, so serve without
			// arming the backlog timer. Without this, a zero or very short
			// backlogTimeout lets an already-expired timer compete with a free
			// token in the select below. select picks randomly among ready
			// cases, so requests that could be served would be rejected at
			// random.
			select {
			case tok := <-t.tokens:
				defer func() { t.tokens <- tok }()
				next.ServeHTTP(w, r)
				return
			default:
			}

			// Slow path: wait for a processing slot, bounded by backlogTimeout.
			timer := time.NewTimer(t.backlogTimeout)
			select {
			case <-timer.C:
				t.setRetryAfterHeader(w, false)
				t.responseHandler(w, r, ErrTimedOut)
			case <-ctx.Done():
				timer.Stop()
				t.setRetryAfterHeader(w, true)
				t.responseHandler(w, r, ErrContextCanceled)
			case tok := <-t.tokens:
				// Release the timer before serving rather than after the
				// (possibly long) handler returns.
				timer.Stop()
				defer func() { t.tokens <- tok }()
				next.ServeHTTP(w, r)
			}
		})
	}
}