124 lines
2.9 KiB
Go
124 lines
2.9 KiB
Go
package token
|
|
|
|
import (
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"go.pkg.cx/middleware"
|
|
)
|
|
|
|
var testHandler = func(t *testing.T) http.HandlerFunc {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, err := w.Write([]byte("resp"))
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
// validateFn is a stub token validator: only the literal token "valid" is
// accepted, and no per-token data is ever attached (the second result is
// always nil).
var validateFn = func(token string) (bool, interface{}) {
	accepted := token == "valid"
	return accepted, nil
}
|
|
|
|
// TestTokenAuthDefaults checks the middleware built with no options: a
// request with no token and a request with ?token=invalid are both
// rejected with 401, and the first carries the body "Unauthorized\n".
func TestTokenAuthDefaults(t *testing.T) {
	auth := Middleware()

	server := httptest.NewServer(auth(testHandler(t)))
	defer server.Close()

	res, err := http.Get(server.URL)
	require.NoError(t, err)
	defer res.Body.Close()
	assert.Equal(t, http.StatusUnauthorized, res.StatusCode)

	b, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	assert.Equal(t, "Unauthorized\n", string(b))

	res, err = http.Get(server.URL + "?token=invalid")
	require.NoError(t, err)
	// res.Body is evaluated when this defer statement executes, so it
	// closes the second response; the earlier defer still closes the first.
	defer res.Body.Close()
	assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
}
|
|
|
|
// TestTokenAuthValidateFn checks that a custom validator installed with
// SetValidateTokenFn decides access: ?token=invalid is rejected with 401,
// while ?token=valid reaches the wrapped handler and returns its "resp" body.
func TestTokenAuthValidateFn(t *testing.T) {
	auth := Middleware(
		SetValidateTokenFn(validateFn),
	)

	server := httptest.NewServer(auth(testHandler(t)))
	defer server.Close()

	res, err := http.Get(server.URL + "?token=invalid")
	require.NoError(t, err)
	defer res.Body.Close()
	assert.Equal(t, http.StatusUnauthorized, res.StatusCode)

	res, err = http.Get(server.URL + "?token=valid")
	require.NoError(t, err)
	// Evaluated at defer time: closes this (second) response body.
	defer res.Body.Close()
	assert.Equal(t, http.StatusOK, res.StatusCode)

	b, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	assert.Equal(t, "resp", string(b))
}
|
|
|
|
// TestTokenAuthFindTokenFns checks that SetFindTokenFns controls where the
// token is read from. With only middleware.TokenFromQuery("api_key")
// configured, the test expects ?token=valid to be rejected with 401 and
// ?api_key=valid to be accepted with 200.
func TestTokenAuthFindTokenFns(t *testing.T) {
	auth := Middleware(
		SetValidateTokenFn(validateFn),
		SetFindTokenFns(middleware.TokenFromQuery("api_key")),
	)

	server := httptest.NewServer(auth(testHandler(t)))
	defer server.Close()

	res, err := http.Get(server.URL + "?token=valid")
	require.NoError(t, err)
	defer res.Body.Close()
	assert.Equal(t, http.StatusUnauthorized, res.StatusCode)

	res, err = http.Get(server.URL + "?api_key=valid")
	require.NoError(t, err)
	// Evaluated at defer time: closes this (second) response body.
	defer res.Body.Close()
	assert.Equal(t, http.StatusOK, res.StatusCode)
}
|
|
|
|
// TestTokenAuthContext checks that, for an accepted token, the middleware
// stores the validator's data under DataCtxKey and the raw token under
// TokenCtxKey in the request context seen by the wrapped handler.
func TestTokenAuthContext(t *testing.T) {
	type data struct {
		inner string
	}

	validateWithDataFn := func(token string) (bool, interface{}) {
		return token == "valid", &data{"test data"}
	}

	auth := Middleware(
		SetValidateTokenFn(validateWithDataFn),
	)

	// This handler runs on the server's goroutine, so it reports failures
	// with assert.* (t.Errorf) only; require.* would call t.FailNow, which
	// must be called from the test goroutine.
	testCtxHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		d, ok := r.Context().Value(DataCtxKey).(*data)
		// Guard the dereference: d is nil when the assertion fails.
		if assert.True(t, ok, "data missing from request context") {
			assert.Equal(t, "test data", d.inner)
		}

		token, ok := r.Context().Value(TokenCtxKey).(string)
		assert.True(t, ok, "token missing from request context")
		assert.Equal(t, "valid", token)

		_, err := w.Write([]byte("resp"))
		assert.NoError(t, err)
	})

	server := httptest.NewServer(auth(testCtxHandler))
	defer server.Close()

	res, err := http.Get(server.URL + "?token=valid")
	require.NoError(t, err)
	defer res.Body.Close()
	assert.Equal(t, http.StatusOK, res.StatusCode)

	// Confirms the handler (and therefore its context assertions) ran.
	b, err := io.ReadAll(res.Body)
	require.NoError(t, err)
	assert.Equal(t, "resp", string(b))
}
|