package jwt import ( "io" "net/http" "net/http/httptest" "testing" "time" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.pkg.cx/middleware" ) var testHandler = func(t *testing.T) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte("resp")) require.NoError(t, err) }) } func TestJWTAuthDefaults(t *testing.T) { auth := Middleware( []byte("changethis"), jwa.HS256, ) server := httptest.NewServer(auth(testHandler(t))) defer server.Close() res, err := http.Get(server.URL) require.Nil(t, err) defer res.Body.Close() assert.Equal(t, http.StatusUnauthorized, res.StatusCode) b, err := io.ReadAll(res.Body) assert.NoError(t, err) assert.Equal(t, "Unauthorized\n", string(b)) token := jwt.New() payload, err := jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis"))) assert.NoError(t, err) res, err = http.Get(server.URL + "/?jwt=" + string(payload)) require.Nil(t, err) defer res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode) b, err = io.ReadAll(res.Body) assert.NoError(t, err) assert.Equal(t, "resp", string(b)) } func TestJWTAuthParseVerify(t *testing.T) { auth := Middleware( []byte("tkey"), jwa.HS512, SetVerifyOptions( jwt.WithSubject("tsubj"), jwt.WithAudience("taud"), ), SetFindTokenFns( middleware.TokenFromQuery("token"), ), ) server := httptest.NewServer(auth(testHandler(t))) defer server.Close() token := jwt.New() payload, err := jwt.Sign(token, jwt.WithKey(jwa.HS512, []byte("tkey"))) assert.NoError(t, err) res, err := http.Get(server.URL + "/?token=" + string(payload)) require.Nil(t, err) defer res.Body.Close() assert.Equal(t, http.StatusUnauthorized, res.StatusCode) payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("tkey"))) assert.NoError(t, err) res, err = http.Get(server.URL + "/?token=" + string(payload)) require.Nil(t, err) defer res.Body.Close() assert.Equal(t, http.StatusUnauthorized, res.StatusCode) payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS512, []byte("wrongkey"))) assert.NoError(t, err) res, err = http.Get(server.URL + "/?token=" + string(payload)) require.Nil(t, err) defer res.Body.Close() assert.Equal(t, http.StatusUnauthorized, res.StatusCode) now := time.Now() token.Set(jwt.IssuedAtKey, now.Add(-1*time.Hour)) // nolint:errcheck // No need to check error here token.Set(jwt.ExpirationKey, now.Add(-58*time.Minute)) // nolint:errcheck // No need to check error here payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS512, []byte("tkey"))) assert.NoError(t, err) res, err = http.Get(server.URL + "/?token=" + string(payload)) require.Nil(t, err) defer res.Body.Close() assert.Equal(t, http.StatusUnauthorized, res.StatusCode) } func TestJWTAuthVerifyOptions(t *testing.T) { auth := Middleware( []byte("changethis"), jwa.HS256, WithVerifyOption(jwt.WithIssuer("tissuer")), ) server := httptest.NewServer(auth(testHandler(t))) defer server.Close() token := jwt.New() token.Set(jwt.IssuerKey, "wrongissuer") // nolint:errcheck // No need to check error here payload, err := jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis"))) assert.NoError(t, err) res, err := http.Get(server.URL + "/?jwt=" + string(payload)) require.Nil(t, err) defer res.Body.Close() assert.Equal(t, http.StatusUnauthorized, res.StatusCode) token.Set(jwt.IssuerKey, "tissuer") // nolint:errcheck // No need to check error here payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis"))) assert.NoError(t, err) res, err = http.Get(server.URL + "/?jwt=" + string(payload)) require.Nil(t, err) defer res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode) } func TestJWTAuthContext(t *testing.T) { type data struct { inner string } validateWithDataFn := func(token jwt.Token) (bool, interface{}) { return token.JwtID() == "tid", &data{"test data"} } auth := Middleware( []byte("changethis"), jwa.HS256, SetValidateTokenFn(validateWithDataFn), ) token := jwt.New() token.Set(jwt.JwtIDKey, "tid") // nolint:errcheck // No need to check error here testCtxHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { data, ok := r.Context().Value(DataCtxKey).(*data) require.True(t, ok) assert.Equal(t, "test data", data.inner) ctxToken, ok := r.Context().Value(JWTCtxKey).(jwt.Token) require.True(t, ok) assert.Equal(t, "tid", ctxToken.JwtID()) _, err := w.Write([]byte("resp")) require.NoError(t, err) }) server := httptest.NewServer(auth(testCtxHandler)) defer server.Close() payload, err := jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis"))) assert.NoError(t, err) res, err := http.Get(server.URL + "/?jwt=" + string(payload)) require.Nil(t, err) defer res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode) token.Set(jwt.JwtIDKey, "invalid") // nolint:errcheck // No need to check error here payload, err = jwt.Sign(token, jwt.WithKey(jwa.HS256, []byte("changethis"))) assert.NoError(t, err) res, err = http.Get(server.URL + "/?jwt=" + string(payload)) require.Nil(t, err) defer res.Body.Close() assert.Equal(t, http.StatusUnauthorized, res.StatusCode) }