Set RealIP context value
This commit is contained in:
parent
a0c0fbfc2e
commit
fbff18885f
@ -1,10 +1,15 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
RealIPCtxKey = &CtxKey{Pkg: "go.pkg.cx/middleware", Name: "RealIP"}
|
||||
)
|
||||
|
||||
// RealIP is a middleware that sets a http.Request's RemoteAddr to the results
|
||||
// of parsing either the X-Forwarded-For header or the X-Real-IP header (in that
|
||||
// order).
|
||||
@ -26,7 +31,8 @@ func RealIP(next http.Handler) http.Handler {
|
||||
r.RemoteAddr = rip
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
ctx := context.WithValue(r.Context(), RealIPCtxKey, strings.Split(r.RemoteAddr, ":")[0])
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -98,3 +98,21 @@ func TestRealIPNoHeaders(t *testing.T) {
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
||||
func TestRealIPContext(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, strings.Split(r.RemoteAddr, ":")[0], "127.0.0.1")
|
||||
assert.Equal(t, r.Context().Value(RealIPCtxKey), "127.0.0.1")
|
||||
|
||||
_, err := w.Write([]byte("resp"))
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
server := httptest.NewServer(RealIP(handler))
|
||||
defer server.Close()
|
||||
|
||||
res, err := http.Get(server.URL)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, res.StatusCode)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user