package throttle import ( "errors" "net/http" "strconv" "time" "go.pkg.cx/middleware" ) // Errors var ( ErrCapacityExceeded = errors.New("server capacity exceeded") ErrTimedOut = errors.New("timed out while waiting for a pending request to complete") ErrContextCanceled = errors.New("context was canceled") ) // DefaultOptions represents default throttle middleware options var DefaultOptions = Options( SetLimit(100), SetResponseHandler(middleware.RespondWithTooManyRequests), ) // Options turns a list of option instances into an option func Options(opts ...Option) Option { return func(t *throttle) { for _, opt := range opts { opt(t) } } } // Option configures throttle middleware type Option func(t *throttle) // SetLimit sets requests limit func SetLimit(limit int) Option { if limit < 1 { panic("throttle middleware expects limit > 0") } return func(t *throttle) { t.limit = limit } } // SetBacklogLimit sets backlog requests limit func SetBacklogLimit(backlogLimit int) Option { if backlogLimit < 0 { panic("throttle middleware expects backlogLimit >= 0") } return func(t *throttle) { t.backlogLimit = backlogLimit } } // SetBacklogTimeout sets backlog timeout func SetBacklogTimeout(backlogTimeout time.Duration) Option { return func(t *throttle) { t.backlogTimeout = backlogTimeout } } // SetRetryAfterFn sets retry after function func SetRetryAfterFn(fn func(ctxDone bool) time.Duration) Option { return func(t *throttle) { t.retryAfterFn = fn } } // SetResponseHandler sets response handler func SetResponseHandler(fn middleware.ResponseHandle) Option { return func(t *throttle) { t.responseHandler = fn } } type throttle struct { limit int backlogLimit int backlogTimeout time.Duration retryAfterFn func(ctxDone bool) time.Duration responseHandler middleware.ResponseHandle tokens chan struct{} backlogTokens chan struct{} } func (s *throttle) initTokens() { s.tokens = make(chan struct{}, s.limit) s.backlogTokens = make(chan struct{}, s.limit+s.backlogLimit) for i := 0; i < s.limit+s.backlogLimit; i++ { if i < s.limit { s.tokens <- struct{}{} } s.backlogTokens <- struct{}{} } } func (s *throttle) setRetryAfterHeader(w http.ResponseWriter, ctxDone bool) { if s.retryAfterFn == nil { return } retryAfterSeconds := strconv.Itoa(int(s.retryAfterFn(ctxDone).Seconds())) w.Header().Set("Retry-After", retryAfterSeconds) } // Middleware is a throttle middleware that limits number of currently processed requests // at a time across all users. Note: Throttle is not a rate-limiter per user, // instead it just puts a ceiling on the number of currently in-flight requests // being processed from the point from where the Throttle middleware is mounted func Middleware(opts ...Option) func(http.Handler) http.Handler { t := &throttle{} opts = append([]Option{DefaultOptions}, opts...) for _, opt := range opts { opt(t) } t.initTokens() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { select { case <-r.Context().Done(): t.setRetryAfterHeader(w, true) t.responseHandler(w, r, ErrContextCanceled) return case btok := <-t.backlogTokens: timer := time.NewTimer(t.backlogTimeout) defer func() { t.backlogTokens <- btok }() select { case <-timer.C: t.setRetryAfterHeader(w, false) t.responseHandler(w, r, ErrTimedOut) return case <-r.Context().Done(): timer.Stop() t.setRetryAfterHeader(w, true) t.responseHandler(w, r, ErrContextCanceled) return case tok := <-t.tokens: defer func() { timer.Stop() t.tokens <- tok }() next.ServeHTTP(w, r) return } default: t.setRetryAfterHeader(w, false) t.responseHandler(w, r, ErrCapacityExceeded) return } }) } }