package throttle

import (
	"io"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// require is imported by this file; the blank reference keeps that import in
// use now that the handler goroutines report through assert instead.
var _ = require.NoError

// slowHandler returns a handler that sends a 200 header, sleeps for delay and
// then writes the body "resp".
//
// The write error is reported with assert rather than require: the handler
// runs on a server goroutine, and require would call t.FailNow, which is only
// valid on the goroutine running the test function.
func slowHandler(t *testing.T, delay time.Duration) http.HandlerFunc {
	return func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		time.Sleep(delay)
		_, err := w.Write([]byte("resp"))
		assert.NoError(t, err)
	}
}

// runConcurrently starts n goroutines that each call fn once and registers
// every one of them with wg so the caller can wait for completion.
func runConcurrently(wg *sync.WaitGroup, n int, fn func()) {
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			fn()
		}()
	}
}

// expectStatus issues a GET against url and asserts that it succeeds with
// wantStatus. When wantBody is non-empty the whole response body is read and
// compared against it as well.
func expectStatus(t *testing.T, client *http.Client, url string, wantStatus int, wantBody string) {
	res, err := client.Get(url)
	if !assert.NoError(t, err) {
		// res is nil when the request failed; bail out instead of
		// dereferencing it below.
		return
	}
	defer res.Body.Close()

	assert.Equal(t, wantStatus, res.StatusCode)
	if wantBody == "" {
		return
	}

	b, err := io.ReadAll(res.Body)
	assert.NoError(t, err)
	assert.Equal(t, wantBody, string(b))
}

// TestThrottleBacklog sends a request through a throttle configured with
// limit 10, backlog limit 50 and a 10s backlog timeout, in front of a handler
// that takes 1s, and expects a 200 response carrying the body "resp".
func TestThrottleBacklog(t *testing.T) {
	throttle := Middleware(
		SetLimit(10),
		SetBacklogLimit(50),
		SetBacklogTimeout(time.Second*10),
	)

	server := httptest.NewServer(throttle(slowHandler(t, time.Second*1)))
	defer server.Close()

	client := &http.Client{Timeout: time.Second * 5}

	var wg sync.WaitGroup
	// NOTE(review): only a single request is sent, which stays below the
	// limit of 10, so the backlog itself does not appear to be exercised —
	// confirm whether a larger request count was intended.
	runConcurrently(&wg, 1, func() {
		expectStatus(t, client, server.URL, http.StatusOK, "resp")
	})
	wg.Wait()
}

// TestThrottleClientTimeout issues 10 concurrent requests with a 3s client
// timeout against a handler that takes 5s, and expects every request to fail
// on the client side.
func TestThrottleClientTimeout(t *testing.T) {
	throttle := Middleware(
		SetLimit(10),
		SetBacklogLimit(50),
		SetBacklogTimeout(time.Second*10),
	)

	server := httptest.NewServer(throttle(slowHandler(t, time.Second*5)))
	defer server.Close()

	client := &http.Client{Timeout: time.Second * 3}

	var wg sync.WaitGroup
	runConcurrently(&wg, 10, func() {
		res, err := client.Get(server.URL)
		if !assert.Error(t, err) {
			// Unexpected success: the failure is already recorded, so the
			// close error is irrelevant; just release the connection.
			_ = res.Body.Close()
		}
	})
	wg.Wait()
}

// TestThrottleTriggerGatewayTimeout uses limit 50, backlog limit 100 and a 5s
// backlog timeout in front of a handler that takes 10s. The first 50 requests
// are expected to return 200; 50 more requests started one second later are
// expected to return 429 (presumably because they wait in the backlog longer
// than the 5s backlog timeout — confirm against the middleware).
func TestThrottleTriggerGatewayTimeout(t *testing.T) {
	throttle := Middleware(
		SetLimit(50),
		SetBacklogLimit(100),
		SetBacklogTimeout(time.Second*5),
	)

	server := httptest.NewServer(throttle(slowHandler(t, time.Second*10)))
	defer server.Close()

	client := &http.Client{Timeout: time.Second * 60}

	var wg sync.WaitGroup

	// First wave: expected to be served.
	runConcurrently(&wg, 50, func() {
		expectStatus(t, client, server.URL, http.StatusOK, "")
	})

	// Give the first wave time to reach the handler before sending more.
	time.Sleep(time.Second * 1)

	// Second wave: expected to be rejected.
	runConcurrently(&wg, 50, func() {
		expectStatus(t, client, server.URL, http.StatusTooManyRequests, "")
	})

	wg.Wait()
}

// TestThrottleMaximum uses limit 10 and backlog limit 10 in front of a handler
// that takes 2s. The first 20 requests are expected to return 200 with the
// body "resp"; 20 more requests started one second later are expected to
// return 429 (presumably because both the limit and the backlog are full at
// that point — confirm against the middleware).
func TestThrottleMaximum(t *testing.T) {
	throttle := Middleware(
		SetLimit(10),
		SetBacklogLimit(10),
		SetBacklogTimeout(time.Second*5),
	)

	server := httptest.NewServer(throttle(slowHandler(t, time.Second*2)))
	defer server.Close()

	client := &http.Client{Timeout: time.Second * 60}

	var wg sync.WaitGroup

	// First wave: expected to be served.
	runConcurrently(&wg, 20, func() {
		expectStatus(t, client, server.URL, http.StatusOK, "resp")
	})

	// Give the first wave time to occupy the throttle before sending more.
	time.Sleep(time.Second * 1)

	// Second wave: expected to be rejected.
	runConcurrently(&wg, 20, func() {
		expectStatus(t, client, server.URL, http.StatusTooManyRequests, "")
	})

	wg.Wait()
}