middleware/timeout/timeout.go

82 lines
1.9 KiB
Go
Raw Normal View History

2020-11-07 11:59:33 +00:00
package timeout
import (
"context"
"net/http"
"time"
"go.pkg.cx/middleware"
)
// DefaultOptions represents default timeout middleware options
var DefaultOptions = Options(
SetTimeout(time.Second*30),
SetResponseHandler(RespondWithTimeout),
)
2023-06-25 21:40:26 +00:00
// Options turns a list of option instances into an option
2020-11-07 11:59:33 +00:00
func Options(opts ...Option) Option {
return func(t *timeout) {
for _, opt := range opts {
opt(t)
}
}
}
// Option is a functional option that mutates the timeout middleware
// configuration; see SetTimeout, SetResponseHandler and Options.
type Option func(*timeout)
// SetTimeout returns an Option that sets how long a request may run before
// its context deadline is reached. The value is handed to context.WithTimeout
// unchanged, so a zero or negative duration makes the request context expire
// immediately.
func SetTimeout(limit time.Duration) Option {
	setLimit := func(cfg *timeout) { cfg.timeout = limit }
	return setLimit
}
// SetResponseHandler returns an Option that replaces the handler Middleware
// calls when a request's context deadline has been exceeded.
func SetResponseHandler(fn middleware.ResponseHandle) Option {
	setHandler := func(cfg *timeout) { cfg.responseHandler = fn }
	return setHandler
}
// timeout holds the configuration that Options assemble for one Middleware.
type timeout struct {
	// timeout is the per-request duration passed to context.WithTimeout.
	timeout time.Duration
	// responseHandler is called when a request's context deadline was exceeded.
	responseHandler middleware.ResponseHandle
}
// Middleware returns a timeout middleware. For every request it derives a
// context that expires after the configured timeout (DefaultOptions sets 30
// seconds) and passes it to the next handler.
//
// The timeout is cooperative: the next handler must select on ctx.Done() and
// return once the deadline passes, otherwise the signal is ignored and the
// handler runs to completion. When the handler returns and the context
// deadline has been exceeded, the configured response handler is called (by
// default RespondWithTimeout, which replies 504 Gateway Timeout). It is called
// even if the handler has already written a response; with the standard
// net/http server the second status code is then ignored and any body is
// appended to what was already sent.
//
// Options are applied after DefaultOptions. A nil response handler (for
// example from SetResponseHandler(nil)) falls back to the handler installed by
// DefaultOptions instead of causing a nil-func panic on timeout.
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
	cfg := &timeout{}
	DefaultOptions(cfg)
	// Remember the default handler so a caller-supplied nil can fall back to it.
	fallback := cfg.responseHandler
	for _, opt := range opts {
		opt(cfg)
	}

	// Resolve the configuration once; it is never mutated after this point.
	respond := cfg.responseHandler
	if respond == nil {
		respond = fallback
	}
	limit := cfg.timeout

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx, cancel := context.WithTimeout(r.Context(), limit)
			defer func() {
				cancel()
				// cancel does not overwrite an earlier error, so Err is
				// DeadlineExceeded only if a deadline (ours or one inherited
				// from r.Context()) passed before the handler returned.
				if ctx.Err() != context.DeadlineExceeded {
					return
				}
				if respond == nil {
					// DefaultOptions was replaced by one that installs no
					// handler; still report the timeout instead of panicking.
					w.WriteHeader(http.StatusGatewayTimeout)
					return
				}
				respond(w, r, ctx.Err())
			}()
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// RespondWithTimeout is the default timeout response handler. It replies with
// status 504 Gateway Timeout and the standard status text as a plain-text
// body; the request and the error argument are ignored.
func RespondWithTimeout(w http.ResponseWriter, _ *http.Request, _ error) {
	const status = http.StatusGatewayTimeout
	http.Error(w, http.StatusText(status), status)
}