// middleware/auth/jwt/jwt_test.go

package jwt
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.pkg.cx/middleware"
)
var testHandler = func(t *testing.T) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
}
func TestJWTAuthDefaults(t *testing.T) {
auth := Middleware(
[]byte("changethis"),
jwa.HS256,
)
server := httptest.NewServer(auth(testHandler(t)))
defer server.Close()
res, err := http.Get(server.URL)
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
2023-12-25 00:47:12 +00:00
b, err := io.ReadAll(res.Body)
2020-11-07 11:59:33 +00:00
assert.NoError(t, err)
assert.Equal(t, "Unauthorized\n", string(b))
token := jwt.New()
2023-12-02 04:44:50 +00:00
payload, err := jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis")))
2020-11-07 11:59:33 +00:00
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?jwt=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
2023-12-25 00:47:12 +00:00
b, err = io.ReadAll(res.Body)
2020-11-07 11:59:33 +00:00
assert.NoError(t, err)
assert.Equal(t, "resp", string(b))
}
func TestJWTAuthParseVerify(t *testing.T) {
auth := Middleware(
[]byte("tkey"),
jwa.HS512,
SetVerifyOptions(
jwt.WithSubject("tsubj"),
jwt.WithAudience("taud"),
),
SetFindTokenFns(
middleware.TokenFromQuery("token"),
),
)
server := httptest.NewServer(auth(testHandler(t)))
defer server.Close()
token := jwt.New()
2023-12-02 04:44:50 +00:00
payload, err := jwt.Sign(token, jwt.WithKey(jwa.HS512, []byte("tkey")))
2020-11-07 11:59:33 +00:00
assert.NoError(t, err)
res, err := http.Get(server.URL + "/?token=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
2023-12-02 04:44:50 +00:00
payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("tkey")))
2020-11-07 11:59:33 +00:00
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?token=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
2023-12-02 04:44:50 +00:00
payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS512, []byte("wrongkey")))
2020-11-07 11:59:33 +00:00
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?token=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
now := time.Now()
2023-06-25 20:10:30 +00:00
token.Set(jwt.IssuedAtKey, now.Add(-1*time.Hour)) // nolint:errcheck // No need to check error here
token.Set(jwt.ExpirationKey, now.Add(-58*time.Minute)) // nolint:errcheck // No need to check error here
2023-12-02 04:44:50 +00:00
payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS512, []byte("tkey")))
2020-11-07 11:59:33 +00:00
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?token=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
}
func TestJWTAuthVerifyOptions(t *testing.T) {
auth := Middleware(
[]byte("changethis"),
jwa.HS256,
WithVerifyOption(jwt.WithIssuer("tissuer")),
)
server := httptest.NewServer(auth(testHandler(t)))
defer server.Close()
token := jwt.New()
2023-06-25 20:10:30 +00:00
token.Set(jwt.IssuerKey, "wrongissuer") // nolint:errcheck // No need to check error here
2023-12-02 04:44:50 +00:00
payload, err := jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis")))
2020-11-07 11:59:33 +00:00
assert.NoError(t, err)
res, err := http.Get(server.URL + "/?jwt=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
2023-06-25 20:10:30 +00:00
token.Set(jwt.IssuerKey, "tissuer") // nolint:errcheck // No need to check error here
2023-12-02 04:44:50 +00:00
payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis")))
2020-11-07 11:59:33 +00:00
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?jwt=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}
func TestJWTAuthContext(t *testing.T) {
type data struct {
inner string
}
validateWithDataFn := func(token jwt.Token) (bool, interface{}) {
return token.JwtID() == "tid", &data{"test data"}
}
auth := Middleware(
[]byte("changethis"),
jwa.HS256,
SetValidateTokenFn(validateWithDataFn),
)
token := jwt.New()
2023-06-25 20:10:30 +00:00
token.Set(jwt.JwtIDKey, "tid") // nolint:errcheck // No need to check error here
2020-11-07 11:59:33 +00:00
testCtxHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
data, ok := r.Context().Value(DataCtxKey).(*data)
require.True(t, ok)
assert.Equal(t, "test data", data.inner)
2023-06-25 20:10:30 +00:00
ctxToken, ok := r.Context().Value(JWTCtxKey).(jwt.Token)
2020-11-07 11:59:33 +00:00
require.True(t, ok)
2023-06-25 20:10:30 +00:00
assert.Equal(t, "tid", ctxToken.JwtID())
2020-11-07 11:59:33 +00:00
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(auth(testCtxHandler))
defer server.Close()
2023-12-02 04:44:50 +00:00
payload, err := jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis")))
2020-11-07 11:59:33 +00:00
assert.NoError(t, err)
res, err := http.Get(server.URL + "/?jwt=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
2023-06-25 20:10:30 +00:00
token.Set(jwt.JwtIDKey, "invalid") // nolint:errcheck // No need to check error here
2023-12-02 04:44:50 +00:00
payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis")))
2020-11-07 11:59:33 +00:00
assert.NoError(t, err)
res, err = http.Get(server.URL + "/?jwt=" + string(payload))
require.Nil(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
}