diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go index 3a20839da..99e5d2df7 100644 --- a/pkg/tcpip/stack/icmp_rate_limit.go +++ b/pkg/tcpip/stack/icmp_rate_limit.go @@ -16,6 +16,7 @@ package stack import ( "golang.org/x/time/rate" + "gvisor.dev/gvisor/pkg/tcpip" ) const ( @@ -31,11 +32,41 @@ const ( // ICMPRateLimiter is a global rate limiter that controls the generation of // ICMP messages generated by the stack. type ICMPRateLimiter struct { - *rate.Limiter + limiter *rate.Limiter + clock tcpip.Clock } // NewICMPRateLimiter returns a global rate limiter for controlling the rate -// at which ICMP messages are generated by the stack. -func NewICMPRateLimiter() *ICMPRateLimiter { - return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)} +// at which ICMP messages are generated by the stack. The returned limiter +// does not apply limits to any ICMP types by default. +func NewICMPRateLimiter(clock tcpip.Clock) *ICMPRateLimiter { + return &ICMPRateLimiter{ + clock: clock, + limiter: rate.NewLimiter(icmpLimit, icmpBurst), + } +} + +// SetLimit sets a new Limit for the limiter. +func (l *ICMPRateLimiter) SetLimit(limit rate.Limit) { + l.limiter.SetLimitAt(l.clock.Now(), limit) +} + +// Limit returns the maximum overall event rate. +func (l *ICMPRateLimiter) Limit() rate.Limit { + return l.limiter.Limit() +} + +// SetBurst sets a new burst size for the limiter. +func (l *ICMPRateLimiter) SetBurst(burst int) { + l.limiter.SetBurstAt(l.clock.Now(), burst) +} + +// Burst returns the maximum burst size. +func (l *ICMPRateLimiter) Burst() int { + return l.limiter.Burst() +} + +// Allow reports whether one ICMP message may be sent now. +func (l *ICMPRateLimiter) Allow() bool { + return l.limiter.AllowN(l.clock.Now(), 1) } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 98867a828..428350f31 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -375,7 +375,7 @@ func New(opts Options) *Stack { stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, tables: opts.IPTables, - icmpRateLimiter: NewICMPRateLimiter(), + icmpRateLimiter: NewICMPRateLimiter(clock), seed: seed, nudConfigs: opts.NUDConfigs, uniqueIDGenerator: opts.UniqueID, diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 5cc7a2886..d2c0963b0 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -63,5 +63,6 @@ go_test( "//pkg/tcpip/transport/icmp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_time//rate:go_default_library", ], ) diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 2d15830a7..3719b0dc7 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -313,6 +314,9 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo Clock: &faketime.NullClock{}, } s := stack.New(options) + // Disable ICMP rate limiter because we're using Null clock, which never advances time and thus + // never allows ICMP messages. + s.SetICMPLimit(rate.Inf) ep := channel.New(256, mtu, "") wep := stack.LinkEndpoint(ep)