Do not handle TCP packets that include a non-unicast IP address

This change drops TCP packets with a non-unicast IP address as the source or
destination address as TCP is meant for communication between two endpoints.

Test: Make sure that if the source or destination address contains a non-unicast
address, no TCP packet is sent in response and the packet is dropped.
PiperOrigin-RevId: 280073731
This commit is contained in:
Ghanan Gowripalan 2019-11-12 15:48:34 -08:00 committed by gVisor bot
parent 5398530e45
commit 3f51bef8cd
3 changed files with 269 additions and 14 deletions

View File

@ -389,8 +389,8 @@ var loopbackSubnet = func() tcpip.Subnet {
}()
// deliverPacket attempts to find one or more matching transport endpoints, and
// then, if matches are found, delivers the packet to them. Returns true if it
// found one or more endpoints, false otherwise.
// then, if matches are found, delivers the packet to them. Returns true if
// the packet no longer needs to be handled.
func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
@ -400,15 +400,40 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
eps.mu.RLock()
// Determine which transport endpoint or endpoints to deliver this packet to.
// If the packet is a broadcast or multicast, then find all matching
// transport endpoints.
// If the packet is a UDP broadcast or multicast, then find all matching
// transport endpoints. If the packet is a TCP packet with a non-unicast
// source or destination address, then do nothing further and instruct
// the caller to do the same.
var destEps []*endpointsByNic
if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
switch protocol {
case header.UDPProtocolNumber:
if isMulticastOrBroadcast(id.LocalAddress) {
destEps = d.findAllEndpointsLocked(eps, id)
} else if ep := d.findEndpointLocked(eps, id); ep != nil {
break
}
if ep := d.findEndpointLocked(eps, id); ep != nil {
destEps = append(destEps, ep)
}
case header.TCPProtocolNumber:
if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) {
// TCP can only be used to communicate between a single
// source and a single destination; the addresses must
// be unicast.
eps.mu.RUnlock()
r.Stats().TCP.InvalidSegmentsReceived.Increment()
return true
}
fallthrough
default:
if ep := d.findEndpointLocked(eps, id); ep != nil {
destEps = append(destEps, ep)
}
}
eps.mu.RUnlock()
// Fail if we didn't find at least one matching transport endpoint.
@ -587,3 +612,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
func isMulticastOrBroadcast(addr tcpip.Address) bool {
return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
}
func isUnicast(addr tcpip.Address) bool {
return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
}

View File

@ -4242,6 +4242,210 @@ func TestListenBacklogFull(t *testing.T) {
}
}
// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a
// non unicast IPv4 address are not accepted.
func TestListenNoAcceptNonUnicastV4(t *testing.T) {
multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
tests := []struct {
name string
srcAddr tcpip.Address
dstAddr tcpip.Address
}{
{
"SourceUnspecified",
header.IPv4Any,
context.StackAddr,
},
{
"SourceBroadcast",
header.IPv4Broadcast,
context.StackAddr,
},
{
"SourceOurMulticast",
multicastAddr,
context.StackAddr,
},
{
"SourceOtherMulticast",
otherMulticastAddr,
context.StackAddr,
},
{
"DestUnspecified",
context.TestAddr,
header.IPv4Any,
},
{
"DestBroadcast",
context.TestAddr,
header.IPv4Broadcast,
},
{
"DestOurMulticast",
context.TestAddr,
multicastAddr,
},
{
"DestOtherMulticast",
context.TestAddr,
otherMulticastAddr,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.Create(-1)
if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil {
t.Fatalf("JoinGroup failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(1); err != nil {
t.Fatalf("Listen failed: %s", err)
}
irs := seqnum.Value(789)
c.SendPacketWithAddrs(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: irs,
RcvWnd: 30000,
}, test.srcAddr, test.dstAddr)
c.CheckNoPacket("Should not have received a response")
// Handle normal packet.
c.SendPacketWithAddrs(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: irs,
RcvWnd: 30000,
}, context.TestAddr, context.StackAddr)
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
checker.AckNum(uint32(irs)+1)))
})
}
}
// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a
// non unicast IPv6 address are not accepted.
func TestListenNoAcceptNonUnicastV6(t *testing.T) {
multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
tests := []struct {
name string
srcAddr tcpip.Address
dstAddr tcpip.Address
}{
{
"SourceUnspecified",
header.IPv6Any,
context.StackV6Addr,
},
{
"SourceAllNodes",
header.IPv6AllNodesMulticastAddress,
context.StackV6Addr,
},
{
"SourceOurMulticast",
multicastAddr,
context.StackV6Addr,
},
{
"SourceOtherMulticast",
otherMulticastAddr,
context.StackV6Addr,
},
{
"DestUnspecified",
context.TestV6Addr,
header.IPv6Any,
},
{
"DestAllNodes",
context.TestV6Addr,
header.IPv6AllNodesMulticastAddress,
},
{
"DestOurMulticast",
context.TestV6Addr,
multicastAddr,
},
{
"DestOtherMulticast",
context.TestV6Addr,
otherMulticastAddr,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(true)
if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil {
t.Fatalf("JoinGroup failed: %s", err)
}
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
t.Fatalf("Bind failed: %s", err)
}
if err := c.EP.Listen(1); err != nil {
t.Fatalf("Listen failed: %s", err)
}
irs := seqnum.Value(789)
c.SendV6PacketWithAddrs(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: irs,
RcvWnd: 30000,
}, test.srcAddr, test.dstAddr)
c.CheckNoPacket("Should not have received a response")
// Handle normal packet.
c.SendV6PacketWithAddrs(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: irs,
RcvWnd: 30000,
}, context.TestV6Addr, context.StackV6Addr)
checker.IPv6(t, c.GetV6Packet(),
checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
checker.AckNum(uint32(irs)+1)))
})
}
}
func TestListenSynRcvdQueueFull(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()

View File

@ -309,6 +309,12 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
// BuildSegment builds a TCP segment based on the given Headers and payload.
func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView {
return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr)
}
// BuildSegmentWithAddrs builds a TCP segment based on the given Headers,
// payload and source and destination IPv4 addresses.
func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@ -321,8 +327,8 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(tcp.ProtocolNumber),
SrcAddr: TestAddr,
DstAddr: StackAddr,
SrcAddr: src,
DstAddr: dst,
})
ip.SetChecksum(^ip.CalculateChecksum())
@ -339,7 +345,7 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
})
// Calculate the TCP pseudo-header checksum.
xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestAddr, StackAddr, uint16(len(t)))
xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
// Calculate the TCP checksum and set it.
xsum = header.Checksum(payload, xsum)
@ -365,6 +371,15 @@ func (c *Context) SendPacket(payload []byte, h *Headers) {
})
}
// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload
// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
// provided source and destination IPv4 addresses.
func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
})
}
// SendAck sends an ACK packet.
func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) {
c.SendAckWithSACK(seq, bytesReceived, nil)
@ -490,6 +505,13 @@ func (c *Context) GetV6Packet() []byte {
// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
// the context.
func (c *Context) SendV6Packet(payload []byte, h *Headers) {
c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr)
}
// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer
// endpoint of the context using the provided source and destination IPv6
// addresses.
func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
@ -500,8 +522,8 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
NextHeader: uint8(tcp.ProtocolNumber),
HopLimit: 65,
SrcAddr: TestV6Addr,
DstAddr: StackV6Addr,
SrcAddr: src,
DstAddr: dst,
})
// Initialize the TCP header.
@ -517,7 +539,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
})
// Calculate the TCP pseudo-header checksum.
xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestV6Addr, StackV6Addr, uint16(len(t)))
xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
// Calculate the TCP checksum and set it.
xsum = header.Checksum(payload, xsum)