diff --git a/real_ip.go b/real_ip.go index 3caf724..2ce948d 100644 --- a/real_ip.go +++ b/real_ip.go @@ -1,10 +1,15 @@ package middleware import ( + "context" "net/http" "strings" ) +var ( + RealIPCtxKey = &CtxKey{Pkg: "go.pkg.cx/middleware", Name: "RealIP"} +) + // RealIP is a middleware that sets a http.Request's RemoteAddr to the results // of parsing either the X-Forwarded-For header or the X-Real-IP header (in that // order). @@ -26,7 +31,8 @@ func RealIP(next http.Handler) http.Handler { r.RemoteAddr = rip } - next.ServeHTTP(w, r) + ctx := context.WithValue(r.Context(), RealIPCtxKey, strings.Split(r.RemoteAddr, ":")[0]) + next.ServeHTTP(w, r.WithContext(ctx)) }) } diff --git a/real_ip_test.go b/real_ip_test.go index eedab6b..ffd31b1 100644 --- a/real_ip_test.go +++ b/real_ip_test.go @@ -98,3 +98,21 @@ func TestRealIPNoHeaders(t *testing.T) { defer res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode) } + +func TestRealIPContext(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, strings.Split(r.RemoteAddr, ":")[0], "127.0.0.1") + assert.Equal(t, r.Context().Value(RealIPCtxKey), "127.0.0.1") + + _, err := w.Write([]byte("resp")) + require.NoError(t, err) + }) + + server := httptest.NewServer(RealIP(handler)) + defer server.Close() + + res, err := http.Get(server.URL) + require.NoError(t, err) + defer res.Body.Close() + assert.Equal(t, http.StatusOK, res.StatusCode) +}