Set RealIP context value

This commit is contained in:
Anton Zadvorny 2022-05-12 14:29:33 +03:00
parent a0c0fbfc2e
commit fbff18885f
2 changed files with 25 additions and 1 deletions

View File

@ -1,10 +1,15 @@
package middleware package middleware
import ( import (
"context"
"net/http" "net/http"
"strings" "strings"
) )
var (
RealIPCtxKey = &CtxKey{Pkg: "go.pkg.cx/middleware", Name: "RealIP"}
)
// RealIP is a middleware that sets a http.Request's RemoteAddr to the results // RealIP is a middleware that sets a http.Request's RemoteAddr to the results
// of parsing either the X-Forwarded-For header or the X-Real-IP header (in that // of parsing either the X-Forwarded-For header or the X-Real-IP header (in that
// order). // order).
@ -26,7 +31,8 @@ func RealIP(next http.Handler) http.Handler {
r.RemoteAddr = rip r.RemoteAddr = rip
} }
next.ServeHTTP(w, r) ctx := context.WithValue(r.Context(), RealIPCtxKey, strings.Split(r.RemoteAddr, ":")[0])
next.ServeHTTP(w, r.WithContext(ctx))
}) })
} }

View File

@ -98,3 +98,21 @@ func TestRealIPNoHeaders(t *testing.T) {
defer res.Body.Close() defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode) assert.Equal(t, http.StatusOK, res.StatusCode)
} }
func TestRealIPContext(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, strings.Split(r.RemoteAddr, ":")[0], "127.0.0.1")
assert.Equal(t, r.Context().Value(RealIPCtxKey), "127.0.0.1")
_, err := w.Write([]byte("resp"))
require.NoError(t, err)
})
server := httptest.NewServer(RealIP(handler))
defer server.Close()
res, err := http.Get(server.URL)
require.NoError(t, err)
defer res.Body.Close()
assert.Equal(t, http.StatusOK, res.StatusCode)
}