164 lines
3.6 KiB
Go
164 lines
3.6 KiB
Go
package ratelimit
|
|
|
|
import (
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestRateLimitSequential(t *testing.T) {
|
|
rateLimit := Middleware(
|
|
SetLimit(5),
|
|
)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, err := w.Write([]byte("resp"))
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
server := httptest.NewServer(rateLimit(handler))
|
|
defer server.Close()
|
|
|
|
for i := 0; i < 10; i++ {
|
|
res, err := http.Get(server.URL)
|
|
assert.NoError(t, err)
|
|
defer res.Body.Close()
|
|
|
|
if i < 5 {
|
|
assert.Equal(t, http.StatusOK, res.StatusCode)
|
|
} else {
|
|
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRateLimitConcurrent(t *testing.T) {
|
|
rateLimit := Middleware(
|
|
SetLimit(5),
|
|
)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, err := w.Write([]byte("resp"))
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
server := httptest.NewServer(rateLimit(handler))
|
|
defer server.Close()
|
|
|
|
counter := int64(0)
|
|
var wg sync.WaitGroup
|
|
|
|
for i := 0; i < 10; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
|
|
res, err := http.Get(server.URL)
|
|
assert.NoError(t, err)
|
|
defer res.Body.Close()
|
|
|
|
if res.StatusCode == http.StatusOK {
|
|
atomic.AddInt64(&counter, 1)
|
|
}
|
|
}()
|
|
}
|
|
|
|
wg.Wait()
|
|
assert.Equal(t, int64(5), atomic.LoadInt64(&counter))
|
|
}
|
|
|
|
func TestRateLimitHeaders(t *testing.T) {
|
|
rateLimit := Middleware(
|
|
SetLimit(1),
|
|
)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, err := w.Write([]byte("resp"))
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
server := httptest.NewServer(rateLimit(handler))
|
|
defer server.Close()
|
|
|
|
now := time.Now()
|
|
|
|
res, err := http.Get(server.URL)
|
|
assert.NoError(t, err)
|
|
defer res.Body.Close()
|
|
assert.Equal(t, http.StatusOK, res.StatusCode)
|
|
assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit"))
|
|
assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining"))
|
|
|
|
resetTS, err := strconv.Atoi(res.Header.Get("X-RateLimit-Reset"))
|
|
assert.NoError(t, err)
|
|
assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1)
|
|
|
|
b, err := io.ReadAll(res.Body)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "resp", string(b))
|
|
|
|
res, err = http.Get(server.URL)
|
|
assert.NoError(t, err)
|
|
defer res.Body.Close()
|
|
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
|
assert.Equal(t, "1", res.Header.Get("X-RateLimit-Limit"))
|
|
assert.Equal(t, "0", res.Header.Get("X-RateLimit-Remaining"))
|
|
|
|
resetTS, err = strconv.Atoi(res.Header.Get("X-RateLimit-Reset"))
|
|
assert.NoError(t, err)
|
|
assert.InDelta(t, now.Add(time.Minute*1).Unix(), resetTS, 1)
|
|
|
|
b, err = io.ReadAll(res.Body)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "Too Many Requests\n", string(b))
|
|
}
|
|
|
|
func TestRateLimitExpiration(t *testing.T) {
|
|
rateLimit := Middleware(
|
|
SetLimit(5),
|
|
SetPeriod(time.Millisecond*500),
|
|
)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
_, err := w.Write([]byte("resp"))
|
|
require.NoError(t, err)
|
|
})
|
|
|
|
server := httptest.NewServer(rateLimit(handler))
|
|
defer server.Close()
|
|
|
|
for i := 0; i < 10; i++ {
|
|
res, err := http.Get(server.URL)
|
|
assert.NoError(t, err)
|
|
defer res.Body.Close()
|
|
|
|
if i < 5 {
|
|
assert.Equal(t, http.StatusOK, res.StatusCode)
|
|
} else {
|
|
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
|
}
|
|
}
|
|
|
|
time.Sleep(time.Millisecond * 500)
|
|
|
|
for i := 0; i < 10; i++ {
|
|
res, err := http.Get(server.URL)
|
|
assert.NoError(t, err)
|
|
defer res.Body.Close()
|
|
|
|
if i < 5 {
|
|
assert.Equal(t, http.StatusOK, res.StatusCode)
|
|
} else {
|
|
assert.Equal(t, http.StatusTooManyRequests, res.StatusCode)
|
|
}
|
|
}
|
|
}
|