middleware/auth/token/token.go
2024-01-07 06:09:11 +03:00

119 lines
2.7 KiB
Go

package token
import (
"context"
"errors"
"net/http"
"go.pkg.cx/middleware"
)
// ErrTokenInvalid is handed to the response handler when a token was found
// but the configured validation function rejected it.
var ErrTokenInvalid = errors.New("token invalid")

// ErrTokenNotFound is handed to the response handler when none of the
// configured find functions produced a non-empty token.
var ErrTokenNotFound = errors.New("token not found")
// ctxKeyPkg is the package path recorded in every context key of this package.
const ctxKeyPkg = "go.pkg.cx/middleware/auth/token"

// Context keys under which Middleware stores values for downstream handlers.
var (
	// DataCtxKey holds the data returned by the validation function for the
	// accepted token.
	DataCtxKey = &middleware.CtxKey{Pkg: ctxKeyPkg, Name: "Data"}

	// TokenCtxKey holds the token string that passed validation.
	TokenCtxKey = &middleware.CtxKey{Pkg: ctxKeyPkg, Name: "Token"}
)
// DefaultOptions represents default token auth middleware options
var DefaultOptions = Options(
WithFindTokenFn(middleware.TokenFromHeader("X-Token")),
WithFindTokenFn(middleware.TokenFromQuery("token")),
SetValidateTokenFn(rejectAll),
SetResponseHandler(middleware.RespondWithUnauthorized),
)
// Options bundles several options into one Option. Applying the result
// applies each bundled option in the order given, so a later option can
// override what an earlier one configured.
func Options(opts ...Option) Option {
	return func(a *auth) {
		for i := range opts {
			opts[i](a)
		}
	}
}
// Option mutates the configuration of a token auth middleware while
// Middleware is constructing it.
type Option func(*auth)
// WithFindTokenFn returns an Option that appends fn to the ordered list of
// token find functions and keeps every function already registered.
// Middleware uses the first non-empty string any of them returns.
func WithFindTokenFn(fn func(r *http.Request) string) Option {
	add := func(a *auth) {
		existing := a.findTokenFns
		a.findTokenFns = append(existing, fn)
	}
	return add
}
// SetFindTokenFns returns an Option that replaces the whole list of token
// find functions with fns, in the given order. Previously registered
// functions are discarded, for example the defaults Middleware applies first.
//
// fns is copied every time the option is applied. Without the copy, a later
// WithFindTokenFn could append into spare capacity of the caller's backing
// array. Several middlewares built from the same Option would then share
// that array and overwrite each other's extra find functions.
func SetFindTokenFns(fns ...func(r *http.Request) string) Option {
	return func(a *auth) {
		a.findTokenFns = append(fns[:0:0], fns...)
	}
}
// SetValidateTokenFn returns an Option that installs fn as the token
// validation function and replaces any previously configured one. fn reports
// whether a token is accepted. For accepted tokens, Middleware stores fn's
// second result in the request context under DataCtxKey.
func SetValidateTokenFn(fn func(token string) (bool, interface{})) Option {
	install := func(a *auth) { a.validateTokenFn = fn }
	return install
}
// SetResponseHandler returns an Option that installs fn as the handler
// Middleware calls, with ErrTokenNotFound or ErrTokenInvalid, when it rejects
// a request. It replaces any previously configured handler.
func SetResponseHandler(fn middleware.ResponseHandle) Option {
	install := func(a *auth) { a.responseHandler = fn }
	return install
}
// auth is the resolved configuration of a single token auth middleware. It
// is filled in by applying Options inside Middleware.
type auth struct {
	// findTokenFns are consulted in order; the first non-empty result is
	// treated as the request's token.
	findTokenFns []func(r *http.Request) string

	// validateTokenFn decides whether a found token is accepted and returns
	// data that is attached to the request context.
	validateTokenFn func(token string) (bool, interface{})

	// responseHandler answers requests whose token is missing or rejected.
	responseHandler middleware.ResponseHandle
}
// Middleware returns token auth middleware.
//
// DefaultOptions are applied first and opts after them, in order. A later
// option therefore overrides an earlier one; WithFindTokenFn adds to the
// default find functions rather than replacing them.
//
// For each request the find functions are tried in order and the first
// non-empty result is the token:
//   - With no token, the response handler is called with ErrTokenNotFound.
//   - If the validation function rejects the token, the handler is called
//     with ErrTokenInvalid.
//   - Otherwise the validation data and the token are stored in the request
//     context under DataCtxKey and TokenCtxKey, and the request is passed
//     to next.
//
// If the options leave the validation function or the response handler nil,
// the middleware uses a reject-everything validator and
// middleware.RespondWithUnauthorized instead. A misconfiguration then fails
// closed rather than panicking on every request. Nil find functions are
// skipped.
func Middleware(opts ...Option) func(next http.Handler) http.Handler {
	a := &auth{}
	// Defaults first, so caller options can extend or override them.
	DefaultOptions(a)
	for _, opt := range opts {
		opt(a)
	}
	// Resolve nil hooks once at construction time instead of letting every
	// request panic on a nil func call.
	if a.validateTokenFn == nil {
		a.validateTokenFn = rejectAll
	}
	if a.responseHandler == nil {
		a.responseHandler = middleware.RespondWithUnauthorized
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			token := a.findToken(r)
			if token == "" {
				a.responseHandler(w, r, ErrTokenNotFound)
				return
			}

			valid, data := a.validateTokenFn(token)
			if !valid {
				a.responseHandler(w, r, ErrTokenInvalid)
				return
			}

			ctx := context.WithValue(r.Context(), DataCtxKey, data)
			ctx = context.WithValue(ctx, TokenCtxKey, token)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// findToken returns the first non-empty token produced by a's find
// functions, in registration order, or "" when none finds one. Nil entries
// are skipped.
func (a *auth) findToken(r *http.Request) string {
	for _, find := range a.findTokenFns {
		if find == nil {
			continue
		}
		if token := find(r); token != "" {
			return token
		}
	}
	return ""
}
// rejectAll is the default validation function. It reports every token,
// including the empty one, as invalid and returns no data.
func rejectAll(string) (valid bool, data interface{}) {
	return false, nil
}