package ratelimit import ( "io" "net/http" "net/http/httptest" "strconv" "sync" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestRateLimitSequential(t *testing.T) { rateLimit := Middleware( SetLimit(5), ) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte("resp")) require.NoError(t, err) }) server := httptest.NewServer(rateLimit(handler)) defer server.Close() for i := 0; i < 10; i++ { res, err := http.Get(server.URL) assert.NoError(t, err) defer res.Body.Close() if i < 5 { assert.Equal(t, http.StatusOK, res.StatusCode) } else { assert.Equal(t, http.StatusTooManyRequests, res.StatusCode) } } } func TestRateLimitConcurrent(t *testing.T) { rateLimit := Middleware( SetLimit(5), ) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte("resp")) require.NoError(t, err) }) server := httptest.NewServer(rateLimit(handler)) defer server.Close() counter := int64(0) var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() res, err := http.Get(server.URL) assert.NoError(t, err) defer res.Body.Close() if res.StatusCode == http.StatusOK { atomic.AddInt64(&counter, 1) } }() } wg.Wait() assert.Equal(t, int64(5), atomic.LoadInt64(&counter)) } func TestRateLimitHeaders(t *testing.T) { rateLimit := Middleware( SetLimit(1), ) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte("resp")) require.NoError(t, err) }) server := httptest.NewServer(rateLimit(handler)) defer server.Close() now := time.Now() res, err := http.Get(server.URL) assert.NoError(t, err) defer res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode) assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit")) assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining")) resetTS, err := strconv.Atoi(res.Header.Get("X-RateLimit-Reset")) assert.NoError(t, err) assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1) b, err := io.ReadAll(res.Body) assert.NoError(t, err) assert.Equal(t, "resp", string(b)) res, err = http.Get(server.URL) assert.NoError(t, err) defer res.Body.Close() assert.Equal(t, http.StatusTooManyRequests, res.StatusCode) assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit")) assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining")) resetTS, err = strconv.Atoi(res.Header.Get("X-RateLimit-Reset")) assert.NoError(t, err) assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1) b, err = io.ReadAll(res.Body) assert.NoError(t, err) assert.Equal(t, "Too Many Requests\n", string(b)) } func TestRateLimitExpiration(t *testing.T) { rateLimit := Middleware( SetLimit(5), SetPeriod(time.Millisecond*500), ) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, err := w.Write([]byte("resp")) require.NoError(t, err) }) server := httptest.NewServer(rateLimit(handler)) defer server.Close() for i := 0; i < 10; i++ { res, err := http.Get(server.URL) assert.NoError(t, err) defer res.Body.Close() if i < 5 { assert.Equal(t, http.StatusOK, res.StatusCode) } else { assert.Equal(t, http.StatusTooManyRequests, res.StatusCode) } } time.Sleep(time.Millisecond * 500) for i := 0; i < 10; i++ { res, err := http.Get(server.URL) assert.NoError(t, err) defer res.Body.Close() if i < 5 { assert.Equal(t, http.StatusOK, res.StatusCode) } else { assert.Equal(t, http.StatusTooManyRequests, res.StatusCode) } } }