package token

import (
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"go.pkg.cx/middleware"
)

// testHandler returns a handler that always writes the body "resp".
//
// The handler runs on the httptest server's goroutine, not the test
// goroutine, so it reports problems with assert (t.Errorf) rather than
// require (t.FailNow): FailNow must only be called from the goroutine
// running the test.
var testHandler = func(t *testing.T) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, err := w.Write([]byte("resp"))
		assert.NoError(t, err)
	})
}

// validateFn accepts only the literal token "valid" and attaches no data.
var validateFn = func(token string) (bool, interface{}) {
	return token == "valid", nil
}

// TestTokenAuthDefaults checks that Middleware built with no options rejects
// a request without a token (with the "Unauthorized\n" body) and a request
// carrying token=invalid.
func TestTokenAuthDefaults(t *testing.T) {
	auth := Middleware()
	server := httptest.NewServer(auth(testHandler(t)))
	defer server.Close()

	res, err := http.Get(server.URL)
	require.NoError(t, err)
	// The receiver res.Body is evaluated when the defer statement runs, so
	// reassigning res below does not affect which body this closes.
	defer res.Body.Close()
	assert.Equal(t, http.StatusUnauthorized, res.StatusCode)

	b, err := ioutil.ReadAll(res.Body)
	assert.NoError(t, err)
	assert.Equal(t, "Unauthorized\n", string(b))

	res, err = http.Get(server.URL + "?token=invalid")
	require.NoError(t, err)
	defer res.Body.Close()
	assert.Equal(t, http.StatusUnauthorized, res.StatusCode)
}

// TestTokenAuthValidateFn checks that a custom validation function is used:
// token=invalid is rejected, token=valid reaches the wrapped handler.
func TestTokenAuthValidateFn(t *testing.T) {
	auth := Middleware(
		SetValidateTokenFn(validateFn),
	)
	server := httptest.NewServer(auth(testHandler(t)))
	defer server.Close()

	res, err := http.Get(server.URL + "?token=invalid")
	require.NoError(t, err)
	defer res.Body.Close()
	assert.Equal(t, http.StatusUnauthorized, res.StatusCode)

	res, err = http.Get(server.URL + "?token=valid")
	require.NoError(t, err)
	defer res.Body.Close()
	assert.Equal(t, http.StatusOK, res.StatusCode)

	b, err := ioutil.ReadAll(res.Body)
	assert.NoError(t, err)
	assert.Equal(t, "resp", string(b))
}

// TestTokenAuthFindTokenFns checks that custom token finders replace the
// defaults: with only the "api_key" query finder configured, a valid token
// passed as ?token= is not found, while ?api_key= is accepted.
func TestTokenAuthFindTokenFns(t *testing.T) {
	auth := Middleware(
		SetValidateTokenFn(validateFn),
		SetFindTokenFns(middleware.TokenFromQuery("api_key")),
	)
	server := httptest.NewServer(auth(testHandler(t)))
	defer server.Close()

	res, err := http.Get(server.URL + "?token=valid")
	require.NoError(t, err)
	defer res.Body.Close()
	assert.Equal(t, http.StatusUnauthorized, res.StatusCode)

	res, err = http.Get(server.URL + "?api_key=valid")
	require.NoError(t, err)
	defer res.Body.Close()
	assert.Equal(t, http.StatusOK, res.StatusCode)
}

// TestTokenAuthContext checks that, on successful validation, the middleware
// stores the validator's data under DataCtxKey and the raw token under
// TokenCtxKey in the request context seen by the wrapped handler.
func TestTokenAuthContext(t *testing.T) {
	type data struct {
		inner string
	}

	validateWithDataFn := func(token string) (bool, interface{}) {
		return token == "valid", &data{"test data"}
	}

	auth := Middleware(
		SetValidateTokenFn(validateWithDataFn),
	)

	// Runs on the server goroutine: use assert and bail out early instead of
	// require, and never dereference d unless the type assertion succeeded.
	testCtxHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		d, ok := r.Context().Value(DataCtxKey).(*data)
		if !assert.True(t, ok, "expected *data under DataCtxKey") {
			return
		}
		assert.Equal(t, "test data", d.inner)

		token, ok := r.Context().Value(TokenCtxKey).(string)
		if assert.True(t, ok, "expected string under TokenCtxKey") {
			assert.Equal(t, "valid", token)
		}

		_, err := w.Write([]byte("resp"))
		assert.NoError(t, err)
	})

	server := httptest.NewServer(auth(testCtxHandler))
	defer server.Close()

	res, err := http.Get(server.URL + "?token=valid")
	require.NoError(t, err)
	defer res.Body.Close()
	assert.Equal(t, http.StatusOK, res.StatusCode)
}