// Package jwt provides HTTP middleware that authenticates requests with JSON Web Tokens.
package jwt

import (
	"context"
	"errors"
	"net/http"

	"github.com/lestrrat-go/jwx/v2/jwa"
	"github.com/lestrrat-go/jwx/v2/jwt"
	"go.pkg.cx/middleware"
)

// Errors passed to the response handler when authentication fails.
var (
	// ErrTokenInvalid is reported when the token validation function rejects a token.
	ErrTokenInvalid = errors.New("token invalid")
	// ErrTokenNotFound is reported when none of the token find functions yields a token.
	ErrTokenNotFound = errors.New("token not found")
)

// Context keys under which the middleware stores request-scoped values.
var (
	// JWTCtxKey holds the parsed and validated jwt.Token.
	JWTCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/jwt", Name: "JWT"}
	// DataCtxKey holds the data returned by the token validation function
	// (nil with the default validation function).
	DataCtxKey = &middleware.CtxKey{Pkg: "go.pkg.cx/middleware/auth/jwt", Name: "Data"}
)

// DefaultOptions represents default jwt auth middleware options.
//
// The token is looked up, in order, with middleware.TokenFromAuthorizationHeader,
// the "jwt" query parameter and the "jwt" cookie. Failures are answered with
// middleware.RespondWithUnauthorized, and every token that passes parsing and
// claim validation is accepted.
var DefaultOptions = Options(
	WithFindTokenFn(middleware.TokenFromAuthorizationHeader),
	WithFindTokenFn(middleware.TokenFromQuery("jwt")),
	WithFindTokenFn(middleware.TokenFromCookie("jwt")),
	SetResponseHandler(middleware.RespondWithUnauthorized),
	SetValidateTokenFn(allowAll),
)

// Options turns a list of option instances into a single option that applies
// them in order. Nil options are ignored.
func Options(opts ...Option) Option {
	return func(a *auth) {
		for _, opt := range opts {
			if opt != nil {
				opt(a)
			}
		}
	}
}

// Option configures jwt auth middleware.
type Option func(a *auth)

// SetKey sets the key used to verify the jwt token signature.
func SetKey(key interface{}) Option {
	return func(a *auth) {
		a.key = key
	}
}

// SetAlgorithm sets the algorithm used to verify the jwt token signature.
func SetAlgorithm(alg jwa.SignatureAlgorithm) Option {
	return func(a *auth) {
		a.algorithm = alg
	}
}

// WithFindTokenFn appends a jwt token find function to the list. Find functions
// are tried in order; the first non-empty result is used.
func WithFindTokenFn(fn func(r *http.Request) string) Option {
	return func(a *auth) {
		a.findTokenFns = append(a.findTokenFns, fn)
	}
}

// SetFindTokenFns replaces the list of jwt token find functions.
func SetFindTokenFns(fns ...func(r *http.Request) string) Option {
	return func(a *auth) {
		a.findTokenFns = fns
	}
}

// WithVerifyOption appends a jwt validate option to the list passed to jwt.Validate.
func WithVerifyOption(opt jwt.ValidateOption) Option {
	return func(a *auth) {
		a.verifyOptions = append(a.verifyOptions, opt)
	}
}

// SetVerifyOptions replaces the list of jwt validate options passed to jwt.Validate.
func SetVerifyOptions(opts ...jwt.ValidateOption) Option {
	return func(a *auth) {
		a.verifyOptions = opts
	}
}

// SetValidateTokenFn sets the token validation function. It receives the parsed
// and claim-validated token and returns whether the token is accepted, plus
// arbitrary data that is stored in the request context under DataCtxKey.
// A nil function falls back to accepting every token.
func SetValidateTokenFn(fn func(token jwt.Token) (bool, interface{})) Option {
	return func(a *auth) {
		a.validateTokenFn = fn
	}
}

// SetResponseHandler sets the handler invoked with the error whenever
// authentication fails. A nil handler falls back to
// middleware.RespondWithUnauthorized.
func SetResponseHandler(fn middleware.ResponseHandle) Option {
	return func(a *auth) {
		a.responseHandler = fn
	}
}

// auth holds the resolved middleware configuration.
type auth struct {
	key             interface{}
	algorithm       jwa.SignatureAlgorithm
	verifyOptions   []jwt.ValidateOption
	findTokenFns    []func(r *http.Request) string
	validateTokenFn func(token jwt.Token) (bool, interface{})
	responseHandler middleware.ResponseHandle
}

// findToken returns the first non-empty token produced by the configured find
// functions, or "" if none yields one. Nil find functions are skipped.
func (a *auth) findToken(r *http.Request) string {
	for _, fn := range a.findTokenFns {
		if fn == nil {
			continue
		}
		if token := fn(r); token != "" {
			return token
		}
	}
	return ""
}

// Middleware returns jwt auth middleware.
//
// DefaultOptions are applied first, then opts, then key and alg. Options in
// opts can therefore extend or replace the defaults, while key and alg always
// take precedence over any SetKey/SetAlgorithm passed in opts.
//
// For every request the middleware:
//   - tries each token find function in order and takes the first non-empty token;
//   - parses the token and verifies its signature with key and alg;
//   - validates the token claims with the configured verify options;
//   - runs the token validation function.
//
// On failure the response handler is called with ErrTokenNotFound,
// ErrTokenInvalid, or the parse/validation error, and next is not invoked.
// On success the token is stored under JWTCtxKey and the validation data under
// DataCtxKey in the request context before next is called.
func Middleware(key interface{}, alg jwa.SignatureAlgorithm, opts ...Option) func(next http.Handler) http.Handler {
	// Build a fresh slice: appending to opts directly could overwrite the
	// backing array of a slice the caller passed in as opts... .
	all := make([]Option, 0, len(opts)+3)
	all = append(all, DefaultOptions)
	all = append(all, opts...)
	all = append(all, SetKey(key), SetAlgorithm(alg))

	a := &auth{}
	Options(all...)(a)

	// Fall back to the defaults if a caller explicitly configured nil, which
	// would otherwise panic on every request.
	if a.validateTokenFn == nil {
		a.validateTokenFn = allowAll
	}
	if a.responseHandler == nil {
		a.responseHandler = middleware.RespondWithUnauthorized
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			token := a.findToken(r)
			if token == "" {
				a.responseHandler(w, r, ErrTokenNotFound)
				return
			}

			// Parse-time validation is disabled on purpose: jwx v2 validates
			// claims during parsing with default options, which would reject
			// tokens before the configured verifyOptions (e.g. acceptable skew,
			// custom clock) are applied. Claims are validated exactly once below.
			jwtToken, err := jwt.ParseString(token, jwt.WithKey(a.algorithm, a.key), jwt.WithValidate(false))
			if err != nil {
				a.responseHandler(w, r, err)
				return
			}

			if err := jwt.Validate(jwtToken, a.verifyOptions...); err != nil {
				a.responseHandler(w, r, err)
				return
			}

			valid, data := a.validateTokenFn(jwtToken)
			if !valid {
				a.responseHandler(w, r, ErrTokenInvalid)
				return
			}

			ctx := context.WithValue(r.Context(), JWTCtxKey, jwtToken)
			ctx = context.WithValue(ctx, DataCtxKey, data)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// allowAll is the default token validation function: it accepts every token
// and returns no data.
func allowAll(_ jwt.Token) (bool, interface{}) {
	return true, nil
}