package timeout import ( "io/ioutil" "net/http" "net/http/httptest" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestTimeoutPass(t *testing.T) { timeout := Middleware( SetTimeout(time.Second * 1), ) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) time.Sleep(time.Millisecond * 100) _, err := w.Write([]byte("resp")) require.NoError(t, err) }) server := httptest.NewServer(timeout(handler)) defer server.Close() res, err := http.Get(server.URL) assert.NoError(t, err) defer res.Body.Close() assert.Equal(t, http.StatusOK, res.StatusCode) b, err := ioutil.ReadAll(res.Body) assert.NoError(t, err) assert.Equal(t, "resp", string(b)) } func TestTimeoutTimedOut(t *testing.T) { timeout := Middleware( SetTimeout(time.Millisecond * 300), ) handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { select { case <-r.Context().Done(): return case <-time.After(time.Second * 1): _, err := w.Write([]byte("resp")) require.NoError(t, err) } }) server := httptest.NewServer(timeout(handler)) defer server.Close() res, err := http.Get(server.URL) assert.NoError(t, err) defer res.Body.Close() assert.Equal(t, http.StatusGatewayTimeout, res.StatusCode) }