Automated rollback of changelist 346565589

PiperOrigin-RevId: 347911316
This commit is contained in:
gVisor bot 2020-12-16 15:38:39 -08:00
parent 2ec6e44c9e
commit 0ac6636aaf
4 changed files with 29 additions and 110 deletions

View File

@ -36,10 +36,10 @@ const (
// UDPFields contains the fields of a UDP packet. It is used to describe the // UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded. // fields of a packet that needs to be encoded.
type UDPFields struct { type UDPFields struct {
// SrcPort is the "Source Port" field of a UDP packet. // SrcPort is the "source port" field of a UDP packet.
SrcPort uint16 SrcPort uint16
// DstPort is the "Destination Port" field of a UDP packet. // DstPort is the "destination port" field of a UDP packet.
DstPort uint16 DstPort uint16
// Length is the "length" field of a UDP packet. // Length is the "length" field of a UDP packet.
@ -64,57 +64,52 @@ const (
UDPProtocolNumber tcpip.TransportProtocolNumber = 17 UDPProtocolNumber tcpip.TransportProtocolNumber = 17
) )
// SourcePort returns the "Source Port" field of the UDP header. // SourcePort returns the "source port" field of the udp header.
func (b UDP) SourcePort() uint16 { func (b UDP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[udpSrcPort:]) return binary.BigEndian.Uint16(b[udpSrcPort:])
} }
// DestinationPort returns the "Destination Port" field of the UDP header. // DestinationPort returns the "destination port" field of the udp header.
func (b UDP) DestinationPort() uint16 { func (b UDP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[udpDstPort:]) return binary.BigEndian.Uint16(b[udpDstPort:])
} }
// Length returns the "Length" field of the UDP header. // Length returns the "length" field of the udp header.
func (b UDP) Length() uint16 { func (b UDP) Length() uint16 {
return binary.BigEndian.Uint16(b[udpLength:]) return binary.BigEndian.Uint16(b[udpLength:])
} }
// Payload returns the data contained in the UDP datagram. // Payload returns the data contained in the UDP datagram.
func (b UDP) Payload() []byte { func (b UDP) Payload() []byte {
return b[:b.Length()][UDPMinimumSize:] return b[UDPMinimumSize:]
} }
// Checksum returns the "checksum" field of the UDP header. // Checksum returns the "checksum" field of the udp header.
func (b UDP) Checksum() uint16 { func (b UDP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[udpChecksum:]) return binary.BigEndian.Uint16(b[udpChecksum:])
} }
// SetSourcePort sets the "source port" field of the UDP header. // SetSourcePort sets the "source port" field of the udp header.
func (b UDP) SetSourcePort(port uint16) { func (b UDP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[udpSrcPort:], port) binary.BigEndian.PutUint16(b[udpSrcPort:], port)
} }
// SetDestinationPort sets the "destination port" field of the UDP header. // SetDestinationPort sets the "destination port" field of the udp header.
func (b UDP) SetDestinationPort(port uint16) { func (b UDP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[udpDstPort:], port) binary.BigEndian.PutUint16(b[udpDstPort:], port)
} }
// SetChecksum sets the "checksum" field of the UDP header. // SetChecksum sets the "checksum" field of the udp header.
func (b UDP) SetChecksum(checksum uint16) { func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum) binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
} }
// SetLength sets the "length" field of the UDP header. // SetLength sets the "length" field of the udp header.
func (b UDP) SetLength(length uint16) { func (b UDP) SetLength(length uint16) {
binary.BigEndian.PutUint16(b[udpLength:], length) binary.BigEndian.PutUint16(b[udpLength:], length)
} }
// PayloadLength returns the length of the payload following the UDP header. // CalculateChecksum calculates the checksum of the udp packet, given the
func (b UDP) PayloadLength() uint16 {
return b.Length() - UDPMinimumSize
}
// CalculateChecksum calculates the checksum of the UDP packet, given the
// checksum of the network-layer pseudo-header and the checksum of the payload. // checksum of the network-layer pseudo-header and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 { func (b UDP) CalculateChecksum(partialChecksum uint16) uint16 {
// Calculate the rest of the checksum. // Calculate the rest of the checksum.

View File

@ -58,6 +58,5 @@ go_test(
"//pkg/tcpip/stack", "//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp", "//pkg/tcpip/transport/icmp",
"//pkg/waiter", "//pkg/waiter",
"@com_github_google_go_cmp//cmp:go_default_library",
], ],
) )

View File

@ -1259,6 +1259,7 @@ func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
// HandlePacket is called by the stack when new packets arrive to this transport // HandlePacket is called by the stack when new packets arrive to this transport
// endpoint. // endpoint.
func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
// Get the header then trim it from the view.
hdr := header.UDP(pkt.TransportHeader().View()) hdr := header.UDP(pkt.TransportHeader().View())
if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
// Malformed packet. // Malformed packet.
@ -1267,10 +1268,6 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
return return
} }
// TODO(gvisor.dev/issues/5033): We should mirror the Network layer and cap
// packets at "Parse" instead of when handling a packet.
pkt.Data.CapLength(int(hdr.PayloadLength()))
if !verifyChecksum(hdr, pkt) { if !verifyChecksum(hdr, pkt) {
// Checksum Error. // Checksum Error.
e.stack.Stats().UDP.ChecksumErrors.Increment() e.stack.Stats().UDP.ChecksumErrors.Increment()
@ -1304,7 +1301,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
senderAddress: tcpip.FullAddress{ senderAddress: tcpip.FullAddress{
NIC: pkt.NICID, NIC: pkt.NICID,
Addr: id.RemoteAddress, Addr: id.RemoteAddress,
Port: hdr.SourcePort(), Port: header.UDP(hdr).SourcePort(),
}, },
destinationAddress: tcpip.FullAddress{ destinationAddress: tcpip.FullAddress{
NIC: pkt.NICID, NIC: pkt.NICID,

View File

@ -22,7 +22,6 @@ import (
"testing" "testing"
"time" "time"
"github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/checker"
@ -1915,31 +1914,27 @@ func TestV4UnknownDestination(t *testing.T) {
icmpPkt := header.ICMPv4(hdr.Payload()) icmpPkt := header.ICMPv4(hdr.Payload())
payloadIPHeader := header.IPv4(icmpPkt.Payload()) payloadIPHeader := header.IPv4(icmpPkt.Payload())
incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize incomingHeaderLength := header.IPv4MinimumSize + header.UDPMinimumSize
wantPayloadLen := len(payload) wantLen := len(payload)
if tc.largePayload { if tc.largePayload {
// To work out the data size we need to simulate what the sender would // To work out the data size we need to simulate what the sender would
// have done. The wanted size is the total available minus the sum of // have done. The wanted size is the total available minus the sum of
// the headers in the UDP AND ICMP packets, given that we know the test // the headers in the UDP AND ICMP packets, given that we know the test
// had only a minimal IP header but the ICMP sender will have allowed // had only a minimal IP header but the ICMP sender will have allowed
// for a maximally sized packet header. // for a maximally sized packet header.
wantPayloadLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength
} }
// In the case of large payloads the IP packet may be truncated. Update // In the case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload. // the length field before retrieving the udp datagram payload.
// Add back the two headers within the payload. // Add back the two headers within the payload.
payloadIPHeader.SetTotalLength(uint16(wantPayloadLen + incomingHeaderLength)) payloadIPHeader.SetTotalLength(uint16(wantLen + incomingHeaderLength))
origDgram := header.UDP(payloadIPHeader.Payload()) origDgram := header.UDP(payloadIPHeader.Payload())
wantDgramLen := wantPayloadLen + header.UDPMinimumSize if got, want := len(origDgram.Payload()), wantLen; got != want {
t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
if got, want := len(origDgram), wantDgramLen; got != want {
t.Fatalf("got len(origDgram) = %d, want = %d", got, want)
} }
// Correct UDP length to access payload. if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
origDgram.SetLength(uint16(wantDgramLen)) t.Fatalf("unexpected payload got: %d, want: %d", got, want)
if got, want := origDgram.Payload(), payload[:wantPayloadLen]; !bytes.Equal(got, want) {
t.Fatalf("got origDgram.Payload() = %x, want = %x", got, want)
} }
}) })
} }
@ -2014,23 +2009,20 @@ func TestV6UnknownDestination(t *testing.T) {
icmpPkt := header.ICMPv6(hdr.Payload()) icmpPkt := header.ICMPv6(hdr.Payload())
payloadIPHeader := header.IPv6(icmpPkt.Payload()) payloadIPHeader := header.IPv6(icmpPkt.Payload())
wantPayloadLen := len(payload) wantLen := len(payload)
if tc.largePayload { if tc.largePayload {
wantPayloadLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
} }
wantDgramLen := wantPayloadLen + header.UDPMinimumSize
// In case of large payloads the IP packet may be truncated. Update // In case of large payloads the IP packet may be truncated. Update
// the length field before retrieving the udp datagram payload. // the length field before retrieving the udp datagram payload.
payloadIPHeader.SetPayloadLength(uint16(wantDgramLen)) payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
origDgram := header.UDP(payloadIPHeader.Payload()) origDgram := header.UDP(payloadIPHeader.Payload())
if got, want := len(origDgram), wantPayloadLen+header.UDPMinimumSize; got != want { if got, want := len(origDgram.Payload()), wantLen; got != want {
t.Fatalf("got len(origDgram) = %d, want = %d", got, want) t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
} }
// Correct UDP length to access payload. if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
origDgram.SetLength(uint16(wantPayloadLen + header.UDPMinimumSize)) t.Fatalf("unexpected payload got: %v, want: %v", got, want)
if diff := cmp.Diff(payload[:wantPayloadLen], origDgram.Payload()); diff != "" {
t.Fatalf("origDgram.Payload() mismatch (-want +got):\n%s", diff)
} }
}) })
} }
@ -2543,67 +2535,3 @@ func TestOutgoingSubnetBroadcast(t *testing.T) {
}) })
} }
} }
func TestReceiveShortLength(t *testing.T) {
flows := []testFlow{unicastV4, unicastV6}
for _, flow := range flows {
t.Run(flow.String(), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createEndpointForFlow(flow)
// Bind to wildcard.
bindAddr := tcpip.FullAddress{Port: stackPort}
if err := c.ep.Bind(bindAddr); err != nil {
c.t.Fatalf("c.ep.Bind(%#v): %s", bindAddr, err)
}
payload := newPayload()
extraBytes := []byte{1, 2, 3, 4}
h := flow.header4Tuple(incoming)
var buf buffer.View
var proto tcpip.NetworkProtocolNumber
// Build packets with extra bytes not accounted for in the UDP length
// field.
var udp header.UDP
if flow.isV4() {
buf = c.buildV4Packet(payload, &h)
buf = append(buf, extraBytes...)
ip := header.IPv4(buf)
ip.SetTotalLength(ip.TotalLength() + uint16(len(extraBytes)))
ip.SetChecksum(0)
ip.SetChecksum(^ip.CalculateChecksum())
proto = ipv4.ProtocolNumber
udp = ip.Payload()
} else {
buf = c.buildV6Packet(payload, &h)
buf = append(buf, extraBytes...)
ip := header.IPv6(buf)
ip.SetPayloadLength(ip.PayloadLength() + uint16(len(extraBytes)))
proto = ipv6.ProtocolNumber
udp = ip.Payload()
}
if diff := cmp.Diff(payload, udp.Payload()); diff != "" {
t.Errorf("udp.Payload() mismatch (-want +got):\n%s", diff)
}
c.linkEP.InjectInbound(proto, stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buf.ToVectorisedView(),
}))
// Try to receive the data.
v, _, err := c.ep.Read(nil)
if err != nil {
t.Fatalf("c.ep.Read(nil): %s", err)
}
// Check the payload is read back without extra bytes.
if diff := cmp.Diff(buffer.View(payload), v); diff != "" {
t.Errorf("c.ep.Read(nil) mismatch (-want +got):\n%s", diff)
}
})
}
}