177 lines
4.0 KiB
Go
177 lines
4.0 KiB
Go
|
package throttle
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"net/http"
|
||
|
"strconv"
|
||
|
"time"
|
||
|
|
||
|
"go.pkg.cx/middleware"
|
||
|
)
|
||
|
|
||
|
// Errors handed to the configured response handler to explain why a request
// was rejected by the throttle middleware.
var (
	// ErrCapacityExceeded is reported when every backlog slot is already taken,
	// so the request cannot even wait for a processing slot.
	ErrCapacityExceeded = errors.New("server capacity exceeded")

	// ErrTimedOut is reported when a backlogged request did not obtain a
	// processing slot before the backlog timeout expired.
	ErrTimedOut = errors.New("timed out while waiting for a pending request to complete")

	// ErrContextCanceled is reported when the request's context is done before
	// the request could be processed.
	ErrContextCanceled = errors.New("context was canceled")
)
|
||
|
|
||
|
// DefaultOptions represents default throttle middleware options
|
||
|
var DefaultOptions = Options(
|
||
|
SetLimit(100),
|
||
|
SetResponseHandler(RespondWithTooManyRequests),
|
||
|
)
|
||
|
|
||
|
// Options turns a list of option instances into an option.
|
||
|
func Options(opts ...Option) Option {
|
||
|
return func(t *throttle) {
|
||
|
for _, opt := range opts {
|
||
|
opt(t)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Option configures throttle middleware
|
||
|
type Option func(t *throttle)
|
||
|
|
||
|
// SetLimit sets requests limit
|
||
|
func SetLimit(limit int) Option {
|
||
|
if limit < 1 {
|
||
|
panic("throttle middleware expects limit > 0")
|
||
|
}
|
||
|
|
||
|
return func(t *throttle) {
|
||
|
t.limit = limit
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// SetBacklogLimit sets backlog requests limit
|
||
|
func SetBacklogLimit(backlogLimit int) Option {
|
||
|
if backlogLimit < 0 {
|
||
|
panic("throttle middleware expects backlogLimit >= 0")
|
||
|
}
|
||
|
|
||
|
return func(t *throttle) {
|
||
|
t.backlogLimit = backlogLimit
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// SetBacklogTimeout sets backlog timeout
|
||
|
func SetBacklogTimeout(backlogTimeout time.Duration) Option {
|
||
|
return func(t *throttle) {
|
||
|
t.backlogTimeout = backlogTimeout
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// SetRetryAfterFn sets retry after function
|
||
|
func SetRetryAfterFn(fn func(ctxDone bool) time.Duration) Option {
|
||
|
return func(t *throttle) {
|
||
|
t.retryAfterFn = fn
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// SetResponseHandler sets response handler
|
||
|
func SetResponseHandler(fn middleware.ResponseHandle) Option {
|
||
|
return func(t *throttle) {
|
||
|
t.responseHandler = fn
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type throttle struct {
|
||
|
limit int
|
||
|
backlogLimit int
|
||
|
backlogTimeout time.Duration
|
||
|
retryAfterFn func(ctxDone bool) time.Duration
|
||
|
responseHandler middleware.ResponseHandle
|
||
|
|
||
|
tokens chan struct{}
|
||
|
backlogTokens chan struct{}
|
||
|
}
|
||
|
|
||
|
func (s *throttle) initTokens() {
|
||
|
s.tokens = make(chan struct{}, s.limit)
|
||
|
s.backlogTokens = make(chan struct{}, s.limit+s.backlogLimit)
|
||
|
|
||
|
for i := 0; i < s.limit+s.backlogLimit; i++ {
|
||
|
if i < s.limit {
|
||
|
s.tokens <- struct{}{}
|
||
|
}
|
||
|
|
||
|
s.backlogTokens <- struct{}{}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (s *throttle) setRetryAfterHeader(w http.ResponseWriter, ctxDone bool) {
|
||
|
if s.retryAfterFn == nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
retryAfterSeconds := strconv.Itoa(int(s.retryAfterFn(ctxDone).Seconds()))
|
||
|
w.Header().Set("Retry-After", retryAfterSeconds)
|
||
|
}
|
||
|
|
||
|
// Middleware is a throttle middleware that limits number of currently processed requests
|
||
|
// at a time across all users. Note: Throttle is not a rate-limiter per user,
|
||
|
// instead it just puts a ceiling on the number of currentl in-flight requests
|
||
|
// being processed from the point from where the Throttle middleware is mounted
|
||
|
func Middleware(opts ...Option) func(http.Handler) http.Handler {
|
||
|
t := &throttle{}
|
||
|
opts = append([]Option{DefaultOptions}, opts...)
|
||
|
|
||
|
for _, opt := range opts {
|
||
|
opt(t)
|
||
|
}
|
||
|
|
||
|
t.initTokens()
|
||
|
|
||
|
return func(next http.Handler) http.Handler {
|
||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
|
select {
|
||
|
case <-r.Context().Done():
|
||
|
t.setRetryAfterHeader(w, true)
|
||
|
t.responseHandler(w, r, ErrContextCanceled)
|
||
|
return
|
||
|
|
||
|
case btok := <-t.backlogTokens:
|
||
|
timer := time.NewTimer(t.backlogTimeout)
|
||
|
|
||
|
defer func() {
|
||
|
t.backlogTokens <- btok
|
||
|
}()
|
||
|
|
||
|
select {
|
||
|
case <-timer.C:
|
||
|
t.setRetryAfterHeader(w, false)
|
||
|
t.responseHandler(w, r, ErrTimedOut)
|
||
|
return
|
||
|
|
||
|
case <-r.Context().Done():
|
||
|
timer.Stop()
|
||
|
t.setRetryAfterHeader(w, true)
|
||
|
t.responseHandler(w, r, ErrContextCanceled)
|
||
|
return
|
||
|
|
||
|
case tok := <-t.tokens:
|
||
|
defer func() {
|
||
|
timer.Stop()
|
||
|
t.tokens <- tok
|
||
|
}()
|
||
|
|
||
|
next.ServeHTTP(w, r)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
default:
|
||
|
t.setRetryAfterHeader(w, false)
|
||
|
t.responseHandler(w, r, ErrCapacityExceeded)
|
||
|
return
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// RespondWithTooManyRequests is the default response handler for rejected
// requests. It ignores the request and the rejection reason, and replies with
// HTTP 429 Too Many Requests via http.Error. The standard status text is used
// as the body.
func RespondWithTooManyRequests(w http.ResponseWriter, r *http.Request, err error) {
	const status = http.StatusTooManyRequests
	http.Error(w, http.StatusText(status), status)
}
|