middleware/ratelimit/rate_limit_test.go
2023-12-25 03:47:26 +03:00

164 lines
3.6 KiB
Go

package ratelimit
import (
"io"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRateLimitSequential checks that, with a limit of 5, the first five
// sequential requests are admitted and every later request is rejected with
// 429 Too Many Requests.
func TestRateLimitSequential(t *testing.T) {
	rateLimit := Middleware(
		SetLimit(5),
	)

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, err := w.Write([]byte("resp"))
		// The handler runs on a server goroutine; require (FailNow) may only be
		// called from the test goroutine, so only mark the test as failed here.
		assert.NoError(t, err)
	})

	server := httptest.NewServer(rateLimit(handler))
	defer server.Close()

	for i := 0; i < 10; i++ {
		res, err := http.Get(server.URL)
		// Stop on a transport error: res is nil and would panic below.
		require.NoError(t, err)

		want := http.StatusOK
		if i >= 5 {
			want = http.StatusTooManyRequests
		}
		assert.Equal(t, want, res.StatusCode, "request %d", i)

		// Close per iteration rather than deferring, so response bodies are
		// not all held open until the test function returns.
		assert.NoError(t, res.Body.Close())
	}
}
// TestRateLimitConcurrent fires 10 requests in parallel against a limit of 5
// and checks that exactly five are admitted and the other five are rejected
// with 429 Too Many Requests.
func TestRateLimitConcurrent(t *testing.T) {
	rateLimit := Middleware(
		SetLimit(5),
	)

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, err := w.Write([]byte("resp"))
		// Server goroutine: use assert, not require (FailNow is test-goroutine only).
		assert.NoError(t, err)
	})

	server := httptest.NewServer(rateLimit(handler))
	defer server.Close()

	var admitted, rejected int64
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			res, err := http.Get(server.URL)
			// require cannot be used off the test goroutine; report the error
			// and bail out, since res is nil when err != nil.
			if !assert.NoError(t, err) {
				return
			}
			defer res.Body.Close()

			switch res.StatusCode {
			case http.StatusOK:
				atomic.AddInt64(&admitted, 1)
			case http.StatusTooManyRequests:
				atomic.AddInt64(&rejected, 1)
			}
		}()
	}
	wg.Wait()

	assert.Equal(t, int64(5), atomic.LoadInt64(&admitted))
	assert.Equal(t, int64(5), atomic.LoadInt64(&rejected))
}
// TestRateLimitHeaders checks, with a limit of 1, the X-RateLimit-* headers and
// bodies of the admitted first request and of the rejected second request.
func TestRateLimitHeaders(t *testing.T) {
	rateLimit := Middleware(
		SetLimit(1),
	)

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, err := w.Write([]byte("resp"))
		// Server goroutine: use assert, not require (FailNow is test-goroutine only).
		assert.NoError(t, err)
	})

	server := httptest.NewServer(rateLimit(handler))
	defer server.Close()

	// No SetPeriod is given; this test expects the window to reset one minute
	// after the first request (±1s for second-granularity timestamps).
	wantReset := time.Now().Add(time.Minute * 1).Unix()

	// checkHeaders asserts the headers both responses are expected to carry.
	checkHeaders := func(res *http.Response) {
		t.Helper()
		assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit"))
		assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining"))
		resetTS, err := strconv.Atoi(res.Header.Get("X-RateLimit-Reset"))
		assert.NoError(t, err)
		assert.InDelta(t, wantReset, resetTS, 1)
	}

	first, err := http.Get(server.URL)
	// Stop on a transport error: the response would be nil below.
	require.NoError(t, err)
	defer first.Body.Close()

	assert.Equal(t, http.StatusOK, first.StatusCode)
	checkHeaders(first)
	b, err := io.ReadAll(first.Body)
	assert.NoError(t, err)
	assert.Equal(t, "resp", string(b))

	second, err := http.Get(server.URL)
	require.NoError(t, err)
	defer second.Body.Close()

	assert.Equal(t, http.StatusTooManyRequests, second.StatusCode)
	checkHeaders(second)
	b, err = io.ReadAll(second.Body)
	assert.NoError(t, err)
	assert.Equal(t, "Too Many Requests\n", string(b))
}
// TestRateLimitExpiration checks that the request budget is restored after the
// configured period: two bursts of 10 sequential requests, separated by one
// period, each see the first `limit` admitted and the rest rejected with 429.
func TestRateLimitExpiration(t *testing.T) {
	const (
		limit  = 5
		period = time.Millisecond * 500
	)

	rateLimit := Middleware(
		SetLimit(limit),
		SetPeriod(period),
	)

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, err := w.Write([]byte("resp"))
		// Server goroutine: use assert, not require (FailNow is test-goroutine only).
		assert.NoError(t, err)
	})

	server := httptest.NewServer(rateLimit(handler))
	defer server.Close()

	// burst sends 10 sequential requests; the first `limit` must be admitted.
	burst := func() {
		t.Helper()
		for i := 0; i < 10; i++ {
			res, err := http.Get(server.URL)
			// Stop on a transport error: res is nil and would panic below.
			require.NoError(t, err)

			want := http.StatusOK
			if i >= limit {
				want = http.StatusTooManyRequests
			}
			assert.Equal(t, want, res.StatusCode, "request %d", i)

			// Close now instead of deferring to test exit.
			assert.NoError(t, res.Body.Close())
		}
	}

	burst()
	time.Sleep(period)
	burst()
}