// Package token provides HTTP middleware that authenticates requests by a
// token extracted from the request and checked by a validation function.
package token

import (
	"context"
	"errors"
	"net/http"

	"go.pkg.cx/middleware"
)

// Errors passed to the response handler when authentication fails.
var (
	// ErrTokenInvalid is reported when the validation function rejects the
	// token that was found.
	ErrTokenInvalid = errors.New("token invalid")

	// ErrTokenNotFound is reported when none of the find functions returns a
	// non-empty token.
	ErrTokenNotFound = errors.New("token not found")
)

// Context keys under which the middleware stores authentication results.
var (
	// DataCtxKey holds the data returned by the validation function for an
	// accepted token.
	DataCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/token", Name: "Data"}

	// TokenCtxKey holds the accepted token string.
	TokenCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/token", Name: "Token"}
)

// DefaultOptions represents the default token auth middleware options.
// Middleware always applies them before any caller-supplied options.
//
// Tokens are looked up first with middleware.TokenFromHeader("X-Token"),
// then with middleware.TokenFromQuery("token"). Every token is rejected until
// SetValidateTokenFn installs a real validator. Failures are answered by
// middleware.RespondWithUnauthorized.
var DefaultOptions = Options(
	WithFindTokenFn(middleware.TokenFromHeader("X-Token")),
	WithFindTokenFn(middleware.TokenFromQuery("token")),
	SetValidateTokenFn(rejectAll),
	SetResponseHandler(middleware.RespondWithUnauthorized),
)

// Options turns a list of option instances into a single option that applies
// them in order. Nil options are skipped.
func Options(opts ...Option) Option {
	return func(a *auth) {
		for _, opt := range opts {
			if opt != nil {
				opt(a)
			}
		}
	}
}

// Option configures token auth middleware.
type Option func(a *auth)

// WithFindTokenFn appends a token find function to the list. Find functions
// are tried in order and the first non-empty result is used as the token.
func WithFindTokenFn(fn func(r *http.Request) string) Option {
	return func(a *auth) {
		a.findTokenFns = append(a.findTokenFns, fn)
	}
}

// SetFindTokenFns replaces the token find function list, including the
// finders installed by DefaultOptions.
func SetFindTokenFns(fns ...func(r *http.Request) string) Option {
	// Take a private copy so later changes to the caller's slice cannot alter
	// this option.
	owned := make([]func(r *http.Request) string, len(fns))
	copy(owned, fns)

	return func(a *auth) {
		// The full slice expression caps capacity at length. A subsequent
		// WithFindTokenFn append therefore always reallocates instead of
		// writing into storage shared by every auth this option is applied to.
		a.findTokenFns = owned[:len(owned):len(owned)]
	}
}

// SetValidateTokenFn sets the token validation function. It reports whether
// the token is accepted and returns data that Middleware stores in the request
// context under DataCtxKey. A nil fn falls back to rejecting every token.
func SetValidateTokenFn(fn func(token string) (bool, interface{})) Option {
	return func(a *auth) {
		a.validateTokenFn = fn
	}
}

// SetResponseHandler sets the handler invoked with ErrTokenNotFound or
// ErrTokenInvalid when authentication fails. A nil fn falls back to
// middleware.RespondWithUnauthorized.
func SetResponseHandler(fn middleware.ResponseHandle) Option {
	return func(a *auth) {
		a.responseHandler = fn
	}
}

// auth holds the resolved middleware configuration. It is built once per
// Middleware call and only read while serving requests.
//
// Fields:
//   - findTokenFns: tried in order; the first non-empty result is the token.
//   - validateTokenFn: accepts or rejects a token and returns its data.
//   - responseHandler: handles authentication failures.
type auth struct {
	findTokenFns    []func(r *http.Request) string
	validateTokenFn func(token string) (bool, interface{})
	responseHandler middleware.ResponseHandle
}

// normalize drops nil find functions and replaces nil hooks left by options
// (e.g. SetValidateTokenFn(nil)) with fail-closed defaults. This keeps a
// misconfiguration from panicking on every request.
func (a *auth) normalize() {
	fns := make([]func(r *http.Request) string, 0, len(a.findTokenFns))
	for _, fn := range a.findTokenFns {
		if fn != nil {
			fns = append(fns, fn)
		}
	}
	a.findTokenFns = fns

	if a.validateTokenFn == nil {
		a.validateTokenFn = rejectAll
	}
	if a.responseHandler == nil {
		a.responseHandler = middleware.RespondWithUnauthorized
	}
}

// findToken returns the first non-empty token produced by the find functions.
// It returns "" when none of them yields one.
func (a *auth) findToken(r *http.Request) string {
	for _, find := range a.findTokenFns {
		if token := find(r); token != "" {
			return token
		}
	}
	return ""
}

// Middleware returns token auth middleware. DefaultOptions are applied first,
// followed by opts in order, so later options override earlier ones.
//
// For each request:
//   - The find functions are tried in order and the first non-empty token
//     is used.
//   - If no token is found, the response handler receives ErrTokenNotFound.
//   - If the validation function rejects the token, the response handler
//     receives ErrTokenInvalid.
//   - Otherwise the validation data and the token are stored in the request
//     context under DataCtxKey and TokenCtxKey, and next is called with the
//     derived request.
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
	a := &auth{}
	Options(append([]Option{DefaultOptions}, opts...)...)(a)
	a.normalize()

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			token := a.findToken(r)
			if token == "" {
				a.responseHandler(w, r, ErrTokenNotFound)
				return
			}

			valid, data := a.validateTokenFn(token)
			if !valid {
				a.responseHandler(w, r, ErrTokenInvalid)
				return
			}

			ctx := context.WithValue(r.Context(), DataCtxKey, data)
			ctx = context.WithValue(ctx, TokenCtxKey, token)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// rejectAll is the default validation function. It rejects every token, so
// the middleware fails closed until a validator is configured.
func rejectAll(_ string) (bool, interface{}) {
	return false, nil
}