Do not handle TCP packets that include a non-unicast IP address
This change drops TCP packets with a non-unicast IP address as the source or destination address as TCP is meant for communication between two endpoints. Test: Make sure that if the source or destination address contains a non-unicast address, no TCP packet is sent in response and the packet is dropped. PiperOrigin-RevId: 280073731
This commit is contained in:
parent
5398530e45
commit
3f51bef8cd
|
@ -389,8 +389,8 @@ var loopbackSubnet = func() tcpip.Subnet {
|
|||
}()
|
||||
|
||||
// deliverPacket attempts to find one or more matching transport endpoints, and
|
||||
// then, if matches are found, delivers the packet to them. Returns true if it
|
||||
// found one or more endpoints, false otherwise.
|
||||
// then, if matches are found, delivers the packet to them. Returns true if
|
||||
// the packet no longer needs to be handled.
|
||||
func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt tcpip.PacketBuffer, id TransportEndpointID) bool {
|
||||
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
|
||||
if !ok {
|
||||
|
@ -400,15 +400,40 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto
|
|||
eps.mu.RLock()
|
||||
|
||||
// Determine which transport endpoint or endpoints to deliver this packet to.
|
||||
// If the packet is a broadcast or multicast, then find all matching
|
||||
// transport endpoints.
|
||||
// If the packet is a UDP broadcast or multicast, then find all matching
|
||||
// transport endpoints. If the packet is a TCP packet with a non-unicast
|
||||
// source or destination address, then do nothing further and instruct
|
||||
// the caller to do the same.
|
||||
var destEps []*endpointsByNic
|
||||
if protocol == header.UDPProtocolNumber && isMulticastOrBroadcast(id.LocalAddress) {
|
||||
switch protocol {
|
||||
case header.UDPProtocolNumber:
|
||||
if isMulticastOrBroadcast(id.LocalAddress) {
|
||||
destEps = d.findAllEndpointsLocked(eps, id)
|
||||
} else if ep := d.findEndpointLocked(eps, id); ep != nil {
|
||||
break
|
||||
}
|
||||
|
||||
if ep := d.findEndpointLocked(eps, id); ep != nil {
|
||||
destEps = append(destEps, ep)
|
||||
}
|
||||
|
||||
case header.TCPProtocolNumber:
|
||||
if !(isUnicast(r.LocalAddress) && isUnicast(r.RemoteAddress)) {
|
||||
// TCP can only be used to communicate between a single
|
||||
// source and a single destination; the addresses must
|
||||
// be unicast.
|
||||
eps.mu.RUnlock()
|
||||
r.Stats().TCP.InvalidSegmentsReceived.Increment()
|
||||
return true
|
||||
}
|
||||
|
||||
fallthrough
|
||||
|
||||
default:
|
||||
if ep := d.findEndpointLocked(eps, id); ep != nil {
|
||||
destEps = append(destEps, ep)
|
||||
}
|
||||
}
|
||||
|
||||
eps.mu.RUnlock()
|
||||
|
||||
// Fail if we didn't find at least one matching transport endpoint.
|
||||
|
@ -587,3 +612,7 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN
|
|||
func isMulticastOrBroadcast(addr tcpip.Address) bool {
|
||||
return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr)
|
||||
}
|
||||
|
||||
func isUnicast(addr tcpip.Address) bool {
|
||||
return addr != header.IPv4Any && addr != header.IPv6Any && !isMulticastOrBroadcast(addr)
|
||||
}
|
||||
|
|
|
@ -4242,6 +4242,210 @@ func TestListenBacklogFull(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TestListenNoAcceptMulticastBroadcastV4 makes sure that TCP segments with a
|
||||
// non unicast IPv4 address are not accepted.
|
||||
func TestListenNoAcceptNonUnicastV4(t *testing.T) {
|
||||
multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
|
||||
otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
srcAddr tcpip.Address
|
||||
dstAddr tcpip.Address
|
||||
}{
|
||||
{
|
||||
"SourceUnspecified",
|
||||
header.IPv4Any,
|
||||
context.StackAddr,
|
||||
},
|
||||
{
|
||||
"SourceBroadcast",
|
||||
header.IPv4Broadcast,
|
||||
context.StackAddr,
|
||||
},
|
||||
{
|
||||
"SourceOurMulticast",
|
||||
multicastAddr,
|
||||
context.StackAddr,
|
||||
},
|
||||
{
|
||||
"SourceOtherMulticast",
|
||||
otherMulticastAddr,
|
||||
context.StackAddr,
|
||||
},
|
||||
{
|
||||
"DestUnspecified",
|
||||
context.TestAddr,
|
||||
header.IPv4Any,
|
||||
},
|
||||
{
|
||||
"DestBroadcast",
|
||||
context.TestAddr,
|
||||
header.IPv4Broadcast,
|
||||
},
|
||||
{
|
||||
"DestOurMulticast",
|
||||
context.TestAddr,
|
||||
multicastAddr,
|
||||
},
|
||||
{
|
||||
"DestOtherMulticast",
|
||||
context.TestAddr,
|
||||
otherMulticastAddr,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := context.New(t, defaultMTU)
|
||||
defer c.Cleanup()
|
||||
|
||||
c.Create(-1)
|
||||
|
||||
if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil {
|
||||
t.Fatalf("JoinGroup failed: %s", err)
|
||||
}
|
||||
|
||||
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
|
||||
t.Fatalf("Bind failed: %s", err)
|
||||
}
|
||||
|
||||
if err := c.EP.Listen(1); err != nil {
|
||||
t.Fatalf("Listen failed: %s", err)
|
||||
}
|
||||
|
||||
irs := seqnum.Value(789)
|
||||
c.SendPacketWithAddrs(nil, &context.Headers{
|
||||
SrcPort: context.TestPort,
|
||||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagSyn,
|
||||
SeqNum: irs,
|
||||
RcvWnd: 30000,
|
||||
}, test.srcAddr, test.dstAddr)
|
||||
c.CheckNoPacket("Should not have received a response")
|
||||
|
||||
// Handle normal packet.
|
||||
c.SendPacketWithAddrs(nil, &context.Headers{
|
||||
SrcPort: context.TestPort,
|
||||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagSyn,
|
||||
SeqNum: irs,
|
||||
RcvWnd: 30000,
|
||||
}, context.TestAddr, context.StackAddr)
|
||||
checker.IPv4(t, c.GetPacket(),
|
||||
checker.TCP(
|
||||
checker.SrcPort(context.StackPort),
|
||||
checker.DstPort(context.TestPort),
|
||||
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
|
||||
checker.AckNum(uint32(irs)+1)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestListenNoAcceptMulticastBroadcastV6 makes sure that TCP segments with a
|
||||
// non unicast IPv6 address are not accepted.
|
||||
func TestListenNoAcceptNonUnicastV6(t *testing.T) {
|
||||
multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
|
||||
otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
srcAddr tcpip.Address
|
||||
dstAddr tcpip.Address
|
||||
}{
|
||||
{
|
||||
"SourceUnspecified",
|
||||
header.IPv6Any,
|
||||
context.StackV6Addr,
|
||||
},
|
||||
{
|
||||
"SourceAllNodes",
|
||||
header.IPv6AllNodesMulticastAddress,
|
||||
context.StackV6Addr,
|
||||
},
|
||||
{
|
||||
"SourceOurMulticast",
|
||||
multicastAddr,
|
||||
context.StackV6Addr,
|
||||
},
|
||||
{
|
||||
"SourceOtherMulticast",
|
||||
otherMulticastAddr,
|
||||
context.StackV6Addr,
|
||||
},
|
||||
{
|
||||
"DestUnspecified",
|
||||
context.TestV6Addr,
|
||||
header.IPv6Any,
|
||||
},
|
||||
{
|
||||
"DestAllNodes",
|
||||
context.TestV6Addr,
|
||||
header.IPv6AllNodesMulticastAddress,
|
||||
},
|
||||
{
|
||||
"DestOurMulticast",
|
||||
context.TestV6Addr,
|
||||
multicastAddr,
|
||||
},
|
||||
{
|
||||
"DestOtherMulticast",
|
||||
context.TestV6Addr,
|
||||
otherMulticastAddr,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c := context.New(t, defaultMTU)
|
||||
defer c.Cleanup()
|
||||
|
||||
c.CreateV6Endpoint(true)
|
||||
|
||||
if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil {
|
||||
t.Fatalf("JoinGroup failed: %s", err)
|
||||
}
|
||||
|
||||
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
|
||||
t.Fatalf("Bind failed: %s", err)
|
||||
}
|
||||
|
||||
if err := c.EP.Listen(1); err != nil {
|
||||
t.Fatalf("Listen failed: %s", err)
|
||||
}
|
||||
|
||||
irs := seqnum.Value(789)
|
||||
c.SendV6PacketWithAddrs(nil, &context.Headers{
|
||||
SrcPort: context.TestPort,
|
||||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagSyn,
|
||||
SeqNum: irs,
|
||||
RcvWnd: 30000,
|
||||
}, test.srcAddr, test.dstAddr)
|
||||
c.CheckNoPacket("Should not have received a response")
|
||||
|
||||
// Handle normal packet.
|
||||
c.SendV6PacketWithAddrs(nil, &context.Headers{
|
||||
SrcPort: context.TestPort,
|
||||
DstPort: context.StackPort,
|
||||
Flags: header.TCPFlagSyn,
|
||||
SeqNum: irs,
|
||||
RcvWnd: 30000,
|
||||
}, context.TestV6Addr, context.StackV6Addr)
|
||||
checker.IPv6(t, c.GetV6Packet(),
|
||||
checker.TCP(
|
||||
checker.SrcPort(context.StackPort),
|
||||
checker.DstPort(context.TestPort),
|
||||
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
|
||||
checker.AckNum(uint32(irs)+1)))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenSynRcvdQueueFull(t *testing.T) {
|
||||
c := context.New(t, defaultMTU)
|
||||
defer c.Cleanup()
|
||||
|
|
|
@ -309,6 +309,12 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt
|
|||
|
||||
// BuildSegment builds a TCP segment based on the given Headers and payload.
|
||||
func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView {
|
||||
return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr)
|
||||
}
|
||||
|
||||
// BuildSegmentWithAddrs builds a TCP segment based on the given Headers,
|
||||
// payload and source and destination IPv4 addresses.
|
||||
func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView {
|
||||
// Allocate a buffer for data and headers.
|
||||
buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
|
||||
copy(buf[len(buf)-len(payload):], payload)
|
||||
|
@ -321,8 +327,8 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
|
|||
TotalLength: uint16(len(buf)),
|
||||
TTL: 65,
|
||||
Protocol: uint8(tcp.ProtocolNumber),
|
||||
SrcAddr: TestAddr,
|
||||
DstAddr: StackAddr,
|
||||
SrcAddr: src,
|
||||
DstAddr: dst,
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
|
@ -339,7 +345,7 @@ func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView
|
|||
})
|
||||
|
||||
// Calculate the TCP pseudo-header checksum.
|
||||
xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestAddr, StackAddr, uint16(len(t)))
|
||||
xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
|
||||
|
||||
// Calculate the TCP checksum and set it.
|
||||
xsum = header.Checksum(payload, xsum)
|
||||
|
@ -365,6 +371,15 @@ func (c *Context) SendPacket(payload []byte, h *Headers) {
|
|||
})
|
||||
}
|
||||
|
||||
// SendPacketWithAddrs builds and sends a TCP segment(with the provided payload
|
||||
// & TCPheaders) in an IPv4 packet via the link layer endpoint using the
|
||||
// provided source and destination IPv4 addresses.
|
||||
func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
|
||||
c.linkEP.InjectInbound(ipv4.ProtocolNumber, tcpip.PacketBuffer{
|
||||
Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
|
||||
})
|
||||
}
|
||||
|
||||
// SendAck sends an ACK packet.
|
||||
func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) {
|
||||
c.SendAckWithSACK(seq, bytesReceived, nil)
|
||||
|
@ -490,6 +505,13 @@ func (c *Context) GetV6Packet() []byte {
|
|||
// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
|
||||
// the context.
|
||||
func (c *Context) SendV6Packet(payload []byte, h *Headers) {
|
||||
c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr)
|
||||
}
|
||||
|
||||
// SendV6PacketWithAddrs builds and sends an IPv6 Packet via the link layer
|
||||
// endpoint of the context using the provided source and destination IPv6
|
||||
// addresses.
|
||||
func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
|
||||
// Allocate a buffer for data and headers.
|
||||
buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
|
||||
copy(buf[len(buf)-len(payload):], payload)
|
||||
|
@ -500,8 +522,8 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
|
|||
PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
|
||||
NextHeader: uint8(tcp.ProtocolNumber),
|
||||
HopLimit: 65,
|
||||
SrcAddr: TestV6Addr,
|
||||
DstAddr: StackV6Addr,
|
||||
SrcAddr: src,
|
||||
DstAddr: dst,
|
||||
})
|
||||
|
||||
// Initialize the TCP header.
|
||||
|
@ -517,7 +539,7 @@ func (c *Context) SendV6Packet(payload []byte, h *Headers) {
|
|||
})
|
||||
|
||||
// Calculate the TCP pseudo-header checksum.
|
||||
xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, TestV6Addr, StackV6Addr, uint16(len(t)))
|
||||
xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
|
||||
|
||||
// Calculate the TCP checksum and set it.
|
||||
xsum = header.Checksum(payload, xsum)
|
||||
|
|
Loading…
Reference in New Issue