// Copyright 2018 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ipv4_test import ( "bytes" "encoding/hex" "math/rand" "testing" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) func TestExcludeBroadcast(t *testing.T) { s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) const defaultMTU = 65536 id, _ := channel.New(256, defaultMTU, "") if testing.Verbose() { id = sniffer.New(id) } if err := s.CreateNIC(1, id); err != nil { t.Fatalf("CreateNIC failed: %v", err) } if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil { t.Fatalf("AddAddress failed: %v", err) } if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil { t.Fatalf("AddAddress failed: %v", err) } s.SetRouteTable([]tcpip.Route{{ Destination: "\x00\x00\x00\x00", Mask: "\x00\x00\x00\x00", Gateway: "", NIC: 1, }}) randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53} var wq waiter.Queue t.Run("WithoutPrimaryAddress", func(t *testing.T) { ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { t.Fatal(err) } defer ep.Close() // Cannot connect using a broadcast address as the source. if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute { t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute) } // However, we can bind to a broadcast address to listen. if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}); err != nil { t.Errorf("Bind failed: %v", err) } }) t.Run("WithPrimaryAddress", func(t *testing.T) { ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { t.Fatal(err) } defer ep.Close() // Add a valid primary endpoint address, now we can connect. if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil { t.Fatalf("AddAddress failed: %v", err) } if err := ep.Connect(randomAddr); err != nil { t.Errorf("Connect failed: %v", err) } }) } // makeHdrAndPayload generates a randomize packet. hdrLength indicates how much // data should already be in the header before WritePacket. extraLength // indicates how much extra space should be in the header. The payload is made // from many Views of the sizes listed in viewSizes. func makeHdrAndPayload(hdrLength int, extraLength int, viewSizes []int) (buffer.Prependable, buffer.VectorisedView) { hdr := buffer.NewPrependable(hdrLength + extraLength) hdr.Prepend(hdrLength) rand.Read(hdr.View()) var views []buffer.View totalLength := 0 for _, s := range viewSizes { newView := buffer.NewView(s) rand.Read(newView) views = append(views, newView) totalLength += s } payload := buffer.NewVectorisedView(totalLength, views) return hdr, payload } // comparePayloads compared the contents of all the packets against the contents // of the source packet. func compareFragments(t *testing.T, packets []packetInfo, sourcePacketInfo packetInfo, mtu uint32) { t.Helper() // Make a complete array of the sourcePacketInfo packet. source := header.IPv4(packets[0].Header.View()[:header.IPv4MinimumSize]) source = append(source, sourcePacketInfo.Header.View()...) source = append(source, sourcePacketInfo.Payload.ToView()...) // Make a copy of the IP header, which will be modified in some fields to make // an expected header. sourceCopy := header.IPv4(append(buffer.View(nil), source[:source.HeaderLength()]...)) sourceCopy.SetChecksum(0) sourceCopy.SetFlagsFragmentOffset(0, 0) sourceCopy.SetTotalLength(0) var offset uint16 // Build up an array of the bytes sent. var reassembledPayload []byte for i, packet := range packets { // Confirm that the packet is valid. allBytes := packet.Header.View().ToVectorisedView() allBytes.Append(packet.Payload) ip := header.IPv4(allBytes.ToView()) if !ip.IsValid(len(ip)) { t.Errorf("IP packet is invalid:\n%s", hex.Dump(ip)) } if got, want := ip.CalculateChecksum(), uint16(0xffff); got != want { t.Errorf("ip.CalculateChecksum() got %#x, want %#x", got, want) } if got, want := len(ip), int(mtu); got > want { t.Errorf("fragment is too large, got %d want %d", got, want) } if got, want := packet.Header.UsedLength(), sourcePacketInfo.Header.UsedLength()+header.IPv4MinimumSize; i == 0 && want < int(mtu) && got != want { t.Errorf("first fragment hdr parts should have unmodified length if possible: got %d, want %d", got, want) } if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) } if i < len(packets)-1 { sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) } else { sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()&^header.IPv4FlagMoreFragments, offset) } reassembledPayload = append(reassembledPayload, ip.Payload()...) offset += ip.TotalLength() - uint16(ip.HeaderLength()) // Clear out the checksum and length from the ip because we can't compare // it. sourceCopy.SetTotalLength(uint16(len(ip))) sourceCopy.SetChecksum(0) sourceCopy.SetChecksum(^sourceCopy.CalculateChecksum()) if !bytes.Equal(ip[:ip.HeaderLength()], sourceCopy[:sourceCopy.HeaderLength()]) { t.Errorf("ip[:ip.HeaderLength()] got:\n%s\nwant:\n%s", hex.Dump(ip[:ip.HeaderLength()]), hex.Dump(sourceCopy[:sourceCopy.HeaderLength()])) } } expected := source[source.HeaderLength():] if !bytes.Equal(reassembledPayload, expected) { t.Errorf("reassembledPayload got:\n%s\nwant:\n%s", hex.Dump(reassembledPayload), hex.Dump(expected)) } } type errorChannel struct { *channel.Endpoint Ch chan packetInfo packetCollectorErrors []*tcpip.Error } // newErrorChannel creates a new errorChannel endpoint. Each call to WritePacket // will return successive errors from packetCollectorErrors until the list is // empty and then return nil each time. func newErrorChannel(size int, mtu uint32, linkAddr tcpip.LinkAddress, packetCollectorErrors []*tcpip.Error) (tcpip.LinkEndpointID, *errorChannel) { _, e := channel.New(size, mtu, linkAddr) ec := errorChannel{ Endpoint: e, Ch: make(chan packetInfo, size), packetCollectorErrors: packetCollectorErrors, } return stack.RegisterLinkEndpoint(e), &ec } // packetInfo holds all the information about an outbound packet. type packetInfo struct { Header buffer.Prependable Payload buffer.VectorisedView } // Drain removes all outbound packets from the channel and counts them. func (e *errorChannel) Drain() int { c := 0 for { select { case <-e.Ch: c++ default: return c } } } // WritePacket stores outbound packets into the channel. func (e *errorChannel) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { p := packetInfo{ Header: hdr, Payload: payload, } select { case e.Ch <- p: default: } nextError := (*tcpip.Error)(nil) if len(e.packetCollectorErrors) > 0 { nextError = e.packetCollectorErrors[0] e.packetCollectorErrors = e.packetCollectorErrors[1:] } return nextError } type context struct { stack.Route linkEP *errorChannel } func buildContext(t *testing.T, packetCollectorErrors []*tcpip.Error, mtu uint32) context { // Make the packet and write it. s := stack.New([]string{ipv4.ProtocolName}, []string{}, stack.Options{}) _, linkEP := newErrorChannel(100 /* Enough for all tests. */, mtu, "", packetCollectorErrors) linkEPId := stack.RegisterLinkEndpoint(linkEP) s.CreateNIC(1, linkEPId) s.AddAddress(1, ipv4.ProtocolNumber, "\x10\x00\x00\x01") s.SetRouteTable([]tcpip.Route{{ Destination: "\x10\x00\x00\x02", Mask: "\xff\xff\xff\xff", Gateway: "", NIC: 1, }}) r, err := s.FindRoute(0, "\x10\x00\x00\x01", "\x10\x00\x00\x02", ipv4.ProtocolNumber, false /* multicastLoop */) if err != nil { t.Fatalf("s.FindRoute got %v, want %v", err, nil) } return context{ Route: r, linkEP: linkEP, } } func TestFragmentation(t *testing.T) { var manyPayloadViewsSizes [1000]int for i := range manyPayloadViewsSizes { manyPayloadViewsSizes[i] = 7 } fragTests := []struct { description string mtu uint32 gso *stack.GSO hdrLength int extraLength int payloadViewsSizes []int expectedFrags int }{ {"NoFragmentation", 2000, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 1}, {"NoFragmentationWithBigHeader", 2000, &stack.GSO{}, 16, header.IPv4MinimumSize, []int{1000}, 1}, {"Fragmented", 800, &stack.GSO{}, 0, header.IPv4MinimumSize, []int{1000}, 2}, {"FragmentedWithGsoNil", 800, nil, 0, header.IPv4MinimumSize, []int{1000}, 2}, {"FragmentedWithManyViews", 300, &stack.GSO{}, 0, header.IPv4MinimumSize, manyPayloadViewsSizes[:], 25}, {"FragmentedWithManyViewsAndPrependableBytes", 300, &stack.GSO{}, 0, header.IPv4MinimumSize + 55, manyPayloadViewsSizes[:], 25}, {"FragmentedWithBigHeader", 800, &stack.GSO{}, 20, header.IPv4MinimumSize, []int{1000}, 2}, {"FragmentedWithBigHeaderAndPrependableBytes", 800, &stack.GSO{}, 20, header.IPv4MinimumSize + 66, []int{1000}, 2}, {"FragmentedWithMTUSmallerThanHeaderAndPrependableBytes", 300, &stack.GSO{}, 1000, header.IPv4MinimumSize + 77, []int{500}, 6}, } for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, ft.extraLength, ft.payloadViewsSizes) source := packetInfo{ Header: hdr, // Save the source payload because WritePacket will modify it. Payload: payload.Clone([]buffer.View{}), } c := buildContext(t, nil, ft.mtu) err := c.Route.WritePacket(ft.gso, hdr, payload, tcp.ProtocolNumber, 42) if err != nil { t.Errorf("err got %v, want %v", err, nil) } var results []packetInfo L: for { select { case pi := <-c.linkEP.Ch: results = append(results, pi) default: break L } } if got, want := len(results), ft.expectedFrags; got != want { t.Errorf("len(result) got %d, want %d", got, want) } if got, want := len(results), int(c.Route.Stats().IP.PacketsSent.Value()); got != want { t.Errorf("no errors yet len(result) got %d, want %d", got, want) } compareFragments(t, results, source, ft.mtu) }) } } // TestFragmentationErrors checks that errors are returned from write packet // correctly. func TestFragmentationErrors(t *testing.T) { fragTests := []struct { description string mtu uint32 hdrLength int payloadViewsSizes []int packetCollectorErrors []*tcpip.Error }{ {"NoFrag", 2000, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, {"ErrorOnFirstFrag", 500, 0, []int{1000}, []*tcpip.Error{tcpip.ErrAborted}}, {"ErrorOnSecondFrag", 500, 0, []int{1000}, []*tcpip.Error{nil, tcpip.ErrAborted}}, {"ErrorOnFirstFragMTUSmallerThanHdr", 500, 1000, []int{500}, []*tcpip.Error{tcpip.ErrAborted}}, } for _, ft := range fragTests { t.Run(ft.description, func(t *testing.T) { hdr, payload := makeHdrAndPayload(ft.hdrLength, header.IPv4MinimumSize, ft.payloadViewsSizes) c := buildContext(t, ft.packetCollectorErrors, ft.mtu) err := c.Route.WritePacket(&stack.GSO{}, hdr, payload, tcp.ProtocolNumber, 42) for i := 0; i < len(ft.packetCollectorErrors)-1; i++ { if got, want := ft.packetCollectorErrors[i], (*tcpip.Error)(nil); got != want { t.Errorf("ft.packetCollectorErrors[%d] got %v, want %v", i, got, want) } } // We only need to check that last error because all the ones before are // nil. if got, want := err, ft.packetCollectorErrors[len(ft.packetCollectorErrors)-1]; got != want { t.Errorf("err got %v, want %v", got, want) } if got, want := c.linkEP.Drain(), int(c.Route.Stats().IP.PacketsSent.Value())+1; err != nil && got != want { t.Errorf("after linkEP error len(result) got %d, want %d", got, want) } }) } }