Add IPPROTO_RAW, which allows raw sockets to write IP headers.
iptables also relies on IPPROTO_RAW in a way. It opens such a socket to manipulate the kernel's tables, but it doesn't actually use any of the functionality. Blegh. PiperOrigin-RevId: 257903078
This commit is contained in:
parent
17bab652af
commit
9b4d3280e1
|
@ -40,42 +40,49 @@ type provider struct {
|
|||
}
|
||||
|
||||
// getTransportProtocol figures out transport protocol. Currently only TCP,
|
||||
// UDP, and ICMP are supported.
|
||||
func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) {
|
||||
// UDP, and ICMP are supported. The bool return value is true when this socket
|
||||
// is associated with a transport protocol. This is only false for SOCK_RAW,
|
||||
// IPPROTO_RAW sockets.
|
||||
func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, bool, *syserr.Error) {
|
||||
switch stype {
|
||||
case linux.SOCK_STREAM:
|
||||
if protocol != 0 && protocol != syscall.IPPROTO_TCP {
|
||||
return 0, syserr.ErrInvalidArgument
|
||||
return 0, true, syserr.ErrInvalidArgument
|
||||
}
|
||||
return tcp.ProtocolNumber, nil
|
||||
return tcp.ProtocolNumber, true, nil
|
||||
|
||||
case linux.SOCK_DGRAM:
|
||||
switch protocol {
|
||||
case 0, syscall.IPPROTO_UDP:
|
||||
return udp.ProtocolNumber, nil
|
||||
return udp.ProtocolNumber, true, nil
|
||||
case syscall.IPPROTO_ICMP:
|
||||
return header.ICMPv4ProtocolNumber, nil
|
||||
return header.ICMPv4ProtocolNumber, true, nil
|
||||
case syscall.IPPROTO_ICMPV6:
|
||||
return header.ICMPv6ProtocolNumber, nil
|
||||
return header.ICMPv6ProtocolNumber, true, nil
|
||||
}
|
||||
|
||||
case linux.SOCK_RAW:
|
||||
// Raw sockets require CAP_NET_RAW.
|
||||
creds := auth.CredentialsFromContext(ctx)
|
||||
if !creds.HasCapability(linux.CAP_NET_RAW) {
|
||||
return 0, syserr.ErrPermissionDenied
|
||||
return 0, true, syserr.ErrPermissionDenied
|
||||
}
|
||||
|
||||
switch protocol {
|
||||
case syscall.IPPROTO_ICMP:
|
||||
return header.ICMPv4ProtocolNumber, nil
|
||||
return header.ICMPv4ProtocolNumber, true, nil
|
||||
case syscall.IPPROTO_UDP:
|
||||
return header.UDPProtocolNumber, nil
|
||||
return header.UDPProtocolNumber, true, nil
|
||||
case syscall.IPPROTO_TCP:
|
||||
return header.TCPProtocolNumber, nil
|
||||
return header.TCPProtocolNumber, true, nil
|
||||
// IPPROTO_RAW signifies that the raw socket isn't assigned to
|
||||
// a transport protocol. Users will be able to write packets'
|
||||
// IP headers and won't receive anything.
|
||||
case syscall.IPPROTO_RAW:
|
||||
return tcpip.TransportProtocolNumber(0), false, nil
|
||||
}
|
||||
}
|
||||
return 0, syserr.ErrProtocolNotSupported
|
||||
return 0, true, syserr.ErrProtocolNotSupported
|
||||
}
|
||||
|
||||
// Socket creates a new socket object for the AF_INET or AF_INET6 family.
|
||||
|
@ -93,7 +100,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
|
|||
}
|
||||
|
||||
// Figure out the transport protocol.
|
||||
transProto, err := getTransportProtocol(t, stype, protocol)
|
||||
transProto, associated, err := getTransportProtocol(t, stype, protocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -103,7 +110,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
|
|||
var e *tcpip.Error
|
||||
wq := &waiter.Queue{}
|
||||
if stype == linux.SOCK_RAW {
|
||||
ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq)
|
||||
ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated)
|
||||
} else {
|
||||
ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
|
||||
}
|
||||
|
|
|
@ -83,6 +83,10 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buf
|
|||
return tcpip.ErrNotSupported
|
||||
}
|
||||
|
||||
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
|
||||
return tcpip.ErrNotSupported
|
||||
}
|
||||
|
||||
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
|
||||
v := vv.First()
|
||||
h := header.ARP(v)
|
||||
|
|
|
@ -232,6 +232,55 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
|
|||
return nil
|
||||
}
|
||||
|
||||
// WriteHeaderIncludedPacket writes a packet already containing a network
|
||||
// header through the given route.
|
||||
func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
|
||||
// The packet already has an IP header, but there are a few required
|
||||
// checks.
|
||||
ip := header.IPv4(payload.First())
|
||||
if !ip.IsValid(payload.Size()) {
|
||||
return tcpip.ErrInvalidOptionValue
|
||||
}
|
||||
|
||||
// Always set the total length.
|
||||
ip.SetTotalLength(uint16(payload.Size()))
|
||||
|
||||
// Set the source address when zero.
|
||||
if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) {
|
||||
ip.SetSourceAddress(r.LocalAddress)
|
||||
}
|
||||
|
||||
// Set the destination. If the packet already included a destination,
|
||||
// it will be part of the route.
|
||||
ip.SetDestinationAddress(r.RemoteAddress)
|
||||
|
||||
// Set the packet ID when zero.
|
||||
if ip.ID() == 0 {
|
||||
id := uint32(0)
|
||||
if payload.Size() > header.IPv4MaximumHeaderSize+8 {
|
||||
// Packets of 68 bytes or less are required by RFC 791 to not be
|
||||
// fragmented, so we only assign ids to larger packets.
|
||||
id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1)
|
||||
}
|
||||
ip.SetID(uint16(id))
|
||||
}
|
||||
|
||||
// Always set the checksum.
|
||||
ip.SetChecksum(0)
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
if loop&stack.PacketLoop != 0 {
|
||||
e.HandlePacket(r, payload)
|
||||
}
|
||||
if loop&stack.PacketOut == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
hdr := buffer.NewPrependableFromView(payload.ToView())
|
||||
r.Stats().IP.PacketsSent.Increment()
|
||||
return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber)
|
||||
}
|
||||
|
||||
// HandlePacket is called by the link layer when new ipv4 packets arrive for
|
||||
// this endpoint.
|
||||
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
|
||||
|
|
|
@ -120,6 +120,13 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen
|
|||
return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber)
|
||||
}
|
||||
|
||||
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. It is not yet
|
||||
// supported by IPv6.
|
||||
func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
|
||||
// TODO(b/119580726): Support IPv6 header-included packets.
|
||||
return tcpip.ErrNotSupported
|
||||
}
|
||||
|
||||
// HandlePacket is called by the link layer when new ipv6 packets arrive for
|
||||
// this endpoint.
|
||||
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
|
||||
|
|
|
@ -174,6 +174,10 @@ type NetworkEndpoint interface {
|
|||
// protocol.
|
||||
WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error
|
||||
|
||||
// WriteHeaderIncludedPacket writes a packet that includes a network
|
||||
// header to the given destination address.
|
||||
WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error
|
||||
|
||||
// ID returns the network protocol endpoint ID.
|
||||
ID() *NetworkEndpointID
|
||||
|
||||
|
@ -357,10 +361,19 @@ type TransportProtocolFactory func() TransportProtocol
|
|||
// instantiate network protocols.
|
||||
type NetworkProtocolFactory func() NetworkProtocol
|
||||
|
||||
// UnassociatedEndpointFactory produces endpoints for writing packets not
|
||||
// associated with a particular transport protocol. Such endpoints can be used
|
||||
// to write arbitrary packets that include the IP header.
|
||||
type UnassociatedEndpointFactory interface {
|
||||
NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
|
||||
}
|
||||
|
||||
var (
|
||||
transportProtocols = make(map[string]TransportProtocolFactory)
|
||||
networkProtocols = make(map[string]NetworkProtocolFactory)
|
||||
|
||||
unassociatedFactory UnassociatedEndpointFactory
|
||||
|
||||
linkEPMu sync.RWMutex
|
||||
nextLinkEndpointID tcpip.LinkEndpointID = 1
|
||||
linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint)
|
||||
|
@ -380,6 +393,13 @@ func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
|
|||
networkProtocols[name] = p
|
||||
}
|
||||
|
||||
// RegisterUnassociatedFactory registers a factory to produce endpoints not
|
||||
// associated with any particular transport protocol. This function is intended
|
||||
// to be called by init() functions of the protocols.
|
||||
func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) {
|
||||
unassociatedFactory = f
|
||||
}
|
||||
|
||||
// RegisterLinkEndpoint registers a link-layer protocol endpoint and returns an
|
||||
// ID that can be used to refer to it.
|
||||
func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID {
|
||||
|
|
|
@ -163,6 +163,18 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec
|
|||
return err
|
||||
}
|
||||
|
||||
// WriteHeaderIncludedPacket writes a packet already containing a network
|
||||
// header through the given route.
|
||||
func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error {
|
||||
if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil {
|
||||
r.Stats().IP.OutgoingPacketErrors.Increment()
|
||||
return err
|
||||
}
|
||||
r.ref.nic.stats.Tx.Packets.Increment()
|
||||
r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payload.Size()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultTTL returns the default TTL of the underlying network endpoint.
|
||||
func (r *Route) DefaultTTL() uint8 {
|
||||
return r.ref.ep.DefaultTTL()
|
||||
|
|
|
@ -340,6 +340,8 @@ type Stack struct {
|
|||
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
|
||||
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
|
||||
|
||||
unassociatedFactory UnassociatedEndpointFactory
|
||||
|
||||
demux *transportDemuxer
|
||||
|
||||
stats tcpip.Stats
|
||||
|
@ -442,6 +444,8 @@ func New(network []string, transport []string, opts Options) *Stack {
|
|||
}
|
||||
}
|
||||
|
||||
s.unassociatedFactory = unassociatedFactory
|
||||
|
||||
// Create the global transport demuxer.
|
||||
s.demux = newTransportDemuxer(s)
|
||||
|
||||
|
@ -574,11 +578,15 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp
|
|||
// NewRawEndpoint creates a new raw transport layer endpoint of the given
|
||||
// protocol. Raw endpoints receive all traffic for a given protocol regardless
|
||||
// of address.
|
||||
func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
|
||||
func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
|
||||
if !s.raw {
|
||||
return nil, tcpip.ErrNotPermitted
|
||||
}
|
||||
|
||||
if !associated {
|
||||
return s.unassociatedFactory.NewUnassociatedRawEndpoint(s, network, transport, waiterQueue)
|
||||
}
|
||||
|
||||
t, ok := s.transportProtocols[transport]
|
||||
if !ok {
|
||||
return nil, tcpip.ErrUnknownProtocol
|
||||
|
|
|
@ -137,6 +137,10 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu
|
|||
return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber)
|
||||
}
|
||||
|
||||
func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error {
|
||||
return tcpip.ErrNotSupported
|
||||
}
|
||||
|
||||
func (*fakeNetworkEndpoint) Close() {}
|
||||
|
||||
type fakeNetGoodOption bool
|
||||
|
|
|
@ -21,6 +21,7 @@ go_library(
|
|||
"endpoint.go",
|
||||
"endpoint_state.go",
|
||||
"packet_list.go",
|
||||
"protocol.go",
|
||||
],
|
||||
importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw",
|
||||
imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
|
||||
|
|
|
@ -67,6 +67,7 @@ type endpoint struct {
|
|||
netProto tcpip.NetworkProtocolNumber
|
||||
transProto tcpip.TransportProtocolNumber
|
||||
waiterQueue *waiter.Queue
|
||||
associated bool
|
||||
|
||||
// The following fields are used to manage the receive queue and are
|
||||
// protected by rcvMu.
|
||||
|
@ -97,8 +98,12 @@ type endpoint struct {
|
|||
}
|
||||
|
||||
// NewEndpoint returns a raw endpoint for the given protocols.
|
||||
// TODO(b/129292371): IP_HDRINCL, IPPROTO_RAW, and AF_PACKET.
|
||||
// TODO(b/129292371): IP_HDRINCL and AF_PACKET.
|
||||
func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
|
||||
return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */)
|
||||
}
|
||||
|
||||
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
|
||||
if netProto != header.IPv4ProtocolNumber {
|
||||
return nil, tcpip.ErrUnknownProtocol
|
||||
}
|
||||
|
@ -110,6 +115,16 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
|
|||
waiterQueue: waiterQueue,
|
||||
rcvBufSizeMax: 32 * 1024,
|
||||
sndBufSize: 32 * 1024,
|
||||
associated: associated,
|
||||
}
|
||||
|
||||
// Unassociated endpoints are write-only and users call Write() with IP
|
||||
// headers included. Because they're write-only, we don't need to
|
||||
// register with the stack.
|
||||
if !associated {
|
||||
ep.rcvBufSizeMax = 0
|
||||
ep.waiterQueue = nil
|
||||
return ep, nil
|
||||
}
|
||||
|
||||
if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
|
||||
|
@ -124,7 +139,7 @@ func (ep *endpoint) Close() {
|
|||
ep.mu.Lock()
|
||||
defer ep.mu.Unlock()
|
||||
|
||||
if ep.closed {
|
||||
if ep.closed || !ep.associated {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -142,8 +157,11 @@ func (ep *endpoint) Close() {
|
|||
|
||||
if ep.connected {
|
||||
ep.route.Release()
|
||||
ep.connected = false
|
||||
}
|
||||
|
||||
ep.closed = true
|
||||
|
||||
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
|
||||
}
|
||||
|
||||
|
@ -152,6 +170,10 @@ func (ep *endpoint) ModerateRecvBuf(copied int) {}
|
|||
|
||||
// Read implements tcpip.Endpoint.Read.
|
||||
func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
|
||||
if !ep.associated {
|
||||
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue
|
||||
}
|
||||
|
||||
ep.rcvMu.Lock()
|
||||
|
||||
// If there's no data to read, return that read would block or that the
|
||||
|
@ -192,6 +214,33 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
|
|||
return 0, nil, tcpip.ErrInvalidEndpointState
|
||||
}
|
||||
|
||||
payloadBytes, err := payload.Get(payload.Size())
|
||||
if err != nil {
|
||||
ep.mu.RUnlock()
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
// If this is an unassociated socket and the caller provided a nonzero
|
||||
// destination address, route using that address.
|
||||
if !ep.associated {
|
||||
ip := header.IPv4(payloadBytes)
|
||||
if !ip.IsValid(payload.Size()) {
|
||||
ep.mu.RUnlock()
|
||||
return 0, nil, tcpip.ErrInvalidOptionValue
|
||||
}
|
||||
dstAddr := ip.DestinationAddress()
|
||||
// Update dstAddr with the address in the IP header, unless
|
||||
// opts.To is set (e.g. if sendto specifies a specific
|
||||
// address).
|
||||
if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil {
|
||||
opts.To = &tcpip.FullAddress{
|
||||
NIC: 0, // NIC is unset.
|
||||
Addr: dstAddr, // The address from the payload.
|
||||
Port: 0, // There are no ports here.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Did the caller provide a destination? If not, use the connected
|
||||
// destination.
|
||||
if opts.To == nil {
|
||||
|
@ -216,12 +265,12 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
|
|||
return 0, nil, tcpip.ErrInvalidEndpointState
|
||||
}
|
||||
|
||||
n, ch, err := ep.finishWrite(payload, savedRoute)
|
||||
n, ch, err := ep.finishWrite(payloadBytes, savedRoute)
|
||||
ep.mu.Unlock()
|
||||
return n, ch, err
|
||||
}
|
||||
|
||||
n, ch, err := ep.finishWrite(payload, &ep.route)
|
||||
n, ch, err := ep.finishWrite(payloadBytes, &ep.route)
|
||||
ep.mu.RUnlock()
|
||||
return n, ch, err
|
||||
}
|
||||
|
@ -248,7 +297,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
|
|||
return 0, nil, err
|
||||
}
|
||||
|
||||
n, ch, err := ep.finishWrite(payload, &route)
|
||||
n, ch, err := ep.finishWrite(payloadBytes, &route)
|
||||
route.Release()
|
||||
ep.mu.RUnlock()
|
||||
return n, ch, err
|
||||
|
@ -256,7 +305,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp
|
|||
|
||||
// finishWrite writes the payload to a route. It resolves the route if
|
||||
// necessary. It's really just a helper to make defer unnecessary in Write.
|
||||
func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) {
|
||||
func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) {
|
||||
// We may need to resolve the route (match a link layer address to the
|
||||
// network address). If that requires blocking (e.g. to use ARP),
|
||||
// return a channel on which the caller can wait.
|
||||
|
@ -269,13 +318,14 @@ func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uint
|
|||
}
|
||||
}
|
||||
|
||||
payloadBytes, err := payload.Get(payload.Size())
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
switch ep.netProto {
|
||||
case header.IPv4ProtocolNumber:
|
||||
if !ep.associated {
|
||||
if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
break
|
||||
}
|
||||
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
|
||||
if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil {
|
||||
return 0, nil, err
|
||||
|
@ -335,15 +385,17 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
|
|||
}
|
||||
defer route.Release()
|
||||
|
||||
// Re-register the endpoint with the appropriate NIC.
|
||||
if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
|
||||
return err
|
||||
if ep.associated {
|
||||
// Re-register the endpoint with the appropriate NIC.
|
||||
if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
|
||||
return err
|
||||
}
|
||||
ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
|
||||
ep.registeredNIC = nic
|
||||
}
|
||||
ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
|
||||
|
||||
// Save the route and NIC we've connected via.
|
||||
// Save the route we've connected via.
|
||||
ep.route = route.Clone()
|
||||
ep.registeredNIC = nic
|
||||
ep.connected = true
|
||||
|
||||
return nil
|
||||
|
@ -386,14 +438,16 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
|
|||
return tcpip.ErrBadLocalAddress
|
||||
}
|
||||
|
||||
// Re-register the endpoint with the appropriate NIC.
|
||||
if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
|
||||
return err
|
||||
if ep.associated {
|
||||
// Re-register the endpoint with the appropriate NIC.
|
||||
if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
|
||||
return err
|
||||
}
|
||||
ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
|
||||
ep.registeredNIC = addr.NIC
|
||||
ep.boundNIC = addr.NIC
|
||||
}
|
||||
ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
|
||||
|
||||
ep.registeredNIC = addr.NIC
|
||||
ep.boundNIC = addr.NIC
|
||||
ep.boundAddr = addr.Addr
|
||||
ep.bound = true
|
||||
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
// Copyright 2019 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package raw
|
||||
|
||||
import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
type factory struct{}
|
||||
|
||||
// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory.
|
||||
func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
|
||||
return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */)
|
||||
}
|
||||
|
||||
func init() {
|
||||
stack.RegisterUnassociatedFactory(factory{})
|
||||
}
|
|
@ -319,6 +319,10 @@ syscall_test(
|
|||
test = "//test/syscalls/linux:pwrite64_test",
|
||||
)
|
||||
|
||||
syscall_test(test = "//test/syscalls/linux:raw_socket_hdrincl_test")
|
||||
|
||||
syscall_test(test = "//test/syscalls/linux:raw_socket_icmp_test")
|
||||
|
||||
syscall_test(test = "//test/syscalls/linux:raw_socket_ipv4_test")
|
||||
|
||||
syscall_test(
|
||||
|
|
|
@ -1561,6 +1561,24 @@ cc_binary(
|
|||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "raw_socket_hdrincl_test",
|
||||
testonly = 1,
|
||||
srcs = ["raw_socket_hdrincl.cc"],
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
":socket_test_util",
|
||||
":unix_domain_socket_test_util",
|
||||
"//test/util:capability_util",
|
||||
"//test/util:file_descriptor",
|
||||
"//test/util:test_main",
|
||||
"//test/util:test_util",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/base:endian",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_binary(
|
||||
name = "raw_socket_ipv4_test",
|
||||
testonly = 1,
|
||||
|
|
|
@ -0,0 +1,408 @@
|
|||
// Copyright 2019 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <linux/capability.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/ip.h>
|
||||
#include <netinet/ip_icmp.h>
|
||||
#include <netinet/udp.h>
|
||||
#include <poll.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "absl/base/internal/endian.h"
|
||||
#include "test/syscalls/linux/socket_test_util.h"
|
||||
#include "test/syscalls/linux/unix_domain_socket_test_util.h"
|
||||
#include "test/util/capability_util.h"
|
||||
#include "test/util/file_descriptor.h"
|
||||
#include "test/util/test_util.h"
|
||||
|
||||
namespace gvisor {
|
||||
namespace testing {
|
||||
|
||||
namespace {
|
||||
|
||||
// Tests for IPPROTO_RAW raw sockets, which implies IP_HDRINCL.
|
||||
class RawHDRINCL : public ::testing::Test {
|
||||
protected:
|
||||
// Creates a socket to be used in tests.
|
||||
void SetUp() override;
|
||||
|
||||
// Closes the socket created by SetUp().
|
||||
void TearDown() override;
|
||||
|
||||
// Returns a valid loopback IP header with no payload.
|
||||
struct iphdr LoopbackHeader();
|
||||
|
||||
// Fills in buf with an IP header, UDP header, and payload. Returns false if
|
||||
// buf_size isn't large enough to hold everything.
|
||||
bool FillPacket(char* buf, size_t buf_size, int port, const char* payload,
|
||||
uint16_t payload_size);
|
||||
|
||||
// The socket used for both reading and writing.
|
||||
int socket_;
|
||||
|
||||
// The loopback address.
|
||||
struct sockaddr_in addr_;
|
||||
};
|
||||
|
||||
void RawHDRINCL::SetUp() {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
ASSERT_THAT(socket_ = socket(AF_INET, SOCK_RAW, IPPROTO_RAW),
|
||||
SyscallSucceeds());
|
||||
|
||||
addr_ = {};
|
||||
|
||||
addr_.sin_port = IPPROTO_IP;
|
||||
addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
|
||||
addr_.sin_family = AF_INET;
|
||||
}
|
||||
|
||||
void RawHDRINCL::TearDown() {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
EXPECT_THAT(close(socket_), SyscallSucceeds());
|
||||
}
|
||||
|
||||
struct iphdr RawHDRINCL::LoopbackHeader() {
|
||||
struct iphdr hdr = {};
|
||||
hdr.ihl = 5;
|
||||
hdr.version = 4;
|
||||
hdr.tos = 0;
|
||||
hdr.tot_len = absl::gbswap_16(sizeof(hdr));
|
||||
hdr.id = 0;
|
||||
hdr.frag_off = 0;
|
||||
hdr.ttl = 7;
|
||||
hdr.protocol = 1;
|
||||
hdr.daddr = htonl(INADDR_LOOPBACK);
|
||||
// hdr.check is set by the network stack.
|
||||
// hdr.tot_len is set by the network stack.
|
||||
// hdr.saddr is set by the network stack.
|
||||
return hdr;
|
||||
}
|
||||
|
||||
bool RawHDRINCL::FillPacket(char* buf, size_t buf_size, int port,
|
||||
const char* payload, uint16_t payload_size) {
|
||||
if (buf_size < sizeof(struct iphdr) + sizeof(struct udphdr) + payload_size) {
|
||||
return false;
|
||||
}
|
||||
|
||||
struct iphdr ip = LoopbackHeader();
|
||||
ip.protocol = IPPROTO_UDP;
|
||||
|
||||
struct udphdr udp = {};
|
||||
udp.source = absl::gbswap_16(port);
|
||||
udp.dest = absl::gbswap_16(port);
|
||||
udp.len = absl::gbswap_16(sizeof(udp) + payload_size);
|
||||
udp.check = 0;
|
||||
|
||||
memcpy(buf, reinterpret_cast<char*>(&ip), sizeof(ip));
|
||||
memcpy(buf + sizeof(ip), reinterpret_cast<char*>(&udp), sizeof(udp));
|
||||
memcpy(buf + sizeof(ip) + sizeof(udp), payload, payload_size);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// We should be able to create multiple IPPROTO_RAW sockets. RawHDRINCL::SetUp
|
||||
// creates the first one, so we only have to create one more here.
|
||||
TEST_F(RawHDRINCL, MultipleCreation) {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
int s2;
|
||||
ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, IPPROTO_RAW), SyscallSucceeds());
|
||||
|
||||
ASSERT_THAT(close(s2), SyscallSucceeds());
|
||||
}
|
||||
|
||||
// Test that shutting down an unconnected socket fails.
|
||||
TEST_F(RawHDRINCL, FailShutdownWithoutConnect) {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
ASSERT_THAT(shutdown(socket_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN));
|
||||
ASSERT_THAT(shutdown(socket_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN));
|
||||
}
|
||||
|
||||
// Test that listen() fails.
|
||||
TEST_F(RawHDRINCL, FailListen) {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
ASSERT_THAT(listen(socket_, 1), SyscallFailsWithErrno(ENOTSUP));
|
||||
}
|
||||
|
||||
// Test that accept() fails.
|
||||
TEST_F(RawHDRINCL, FailAccept) {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
struct sockaddr saddr;
|
||||
socklen_t addrlen;
|
||||
ASSERT_THAT(accept(socket_, &saddr, &addrlen),
|
||||
SyscallFailsWithErrno(ENOTSUP));
|
||||
}
|
||||
|
||||
// Test that the socket is writable immediately.
|
||||
TEST_F(RawHDRINCL, PollWritableImmediately) {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
struct pollfd pfd = {};
|
||||
pfd.fd = socket_;
|
||||
pfd.events = POLLOUT;
|
||||
ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 0), SyscallSucceedsWithValue(1));
|
||||
}
|
||||
|
||||
// Test that the socket isn't readable.
|
||||
TEST_F(RawHDRINCL, NotReadable) {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
// Try to receive data with MSG_DONTWAIT, which returns immediately if there's
|
||||
// nothing to be read.
|
||||
char buf[117];
|
||||
ASSERT_THAT(RetryEINTR(recv)(socket_, buf, sizeof(buf), MSG_DONTWAIT),
|
||||
SyscallFailsWithErrno(EINVAL));
|
||||
}
|
||||
|
||||
// Test that we can connect() to a valid IP (loopback).
|
||||
TEST_F(RawHDRINCL, ConnectToLoopback) {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
|
||||
sizeof(addr_)),
|
||||
SyscallSucceeds());
|
||||
}
|
||||
|
||||
TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
struct iphdr hdr = LoopbackHeader();
|
||||
ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0),
|
||||
SyscallSucceedsWithValue(sizeof(hdr)));
|
||||
}
|
||||
|
||||
// HDRINCL implies write-only. Verify that we can't read a packet sent to
|
||||
// loopback.
|
||||
TEST_F(RawHDRINCL, NotReadableAfterWrite) {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_),
|
||||
sizeof(addr_)),
|
||||
SyscallSucceeds());
|
||||
|
||||
// Construct a packet with an IP header, UDP header, and payload.
|
||||
constexpr char kPayload[] = "odst";
|
||||
char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)];
|
||||
ASSERT_TRUE(FillPacket(packet, sizeof(packet), 40000 /* port */, kPayload,
|
||||
sizeof(kPayload)));
|
||||
|
||||
socklen_t addrlen = sizeof(addr_);
|
||||
ASSERT_NO_FATAL_FAILURE(
|
||||
sendto(socket_, reinterpret_cast<void*>(&packet), sizeof(packet), 0,
|
||||
reinterpret_cast<struct sockaddr*>(&addr_), addrlen));
|
||||
|
||||
struct pollfd pfd = {};
|
||||
pfd.fd = socket_;
|
||||
pfd.events = POLLIN;
|
||||
ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0));
|
||||
}
|
||||
|
||||
// Verifies that send() fails with EINVAL when the buffer is shorter than an IP
// header.
TEST_F(RawHDRINCL, WriteTooSmall) {
  SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));

  auto* const dest = reinterpret_cast<struct sockaddr*>(&addr_);
  ASSERT_THAT(connect(socket_, dest, sizeof(addr_)), SyscallSucceeds());

  // Three characters plus the terminating NUL: smaller than the size of an IP
  // header.
  constexpr char kTooShort[] = "JP5";
  const size_t too_short_len = sizeof(kTooShort);
  ASSERT_THAT(send(socket_, kTooShort, too_short_len, 0),
              SyscallFailsWithErrno(EINVAL));
}
|
||||
|
||||
// Bind to localhost.
|
||||
TEST_F(RawHDRINCL, BindToLocalhost) {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
ASSERT_THAT(
|
||||
bind(socket_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)),
|
||||
SyscallSucceeds());
|
||||
}
|
||||
|
||||
// Bind to a different address.
|
||||
TEST_F(RawHDRINCL, BindToInvalid) {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
struct sockaddr_in bind_addr = {};
|
||||
bind_addr.sin_family = AF_INET;
|
||||
bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to.
|
||||
ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr),
|
||||
sizeof(bind_addr)),
|
||||
SyscallFailsWithErrno(EADDRNOTAVAIL));
|
||||
}
|
||||
|
||||
// Send and receive a packet.
|
||||
TEST_F(RawHDRINCL, SendAndReceive) {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
int port = 40000;
|
||||
if (!IsRunningOnGvisor()) {
|
||||
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
|
||||
PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false)));
|
||||
}
|
||||
|
||||
// IPPROTO_RAW sockets are write-only. We'll have to open another socket to
|
||||
// read what we write.
|
||||
FileDescriptor udp_sock =
|
||||
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
|
||||
|
||||
// Construct a packet with an IP header, UDP header, and payload.
|
||||
constexpr char kPayload[] = "toto";
|
||||
char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)];
|
||||
ASSERT_TRUE(
|
||||
FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload)));
|
||||
|
||||
socklen_t addrlen = sizeof(addr_);
|
||||
ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0,
|
||||
reinterpret_cast<struct sockaddr*>(&addr_),
|
||||
addrlen));
|
||||
|
||||
// Receive the payload.
|
||||
char recv_buf[sizeof(packet)];
|
||||
struct sockaddr_in src;
|
||||
socklen_t src_size = sizeof(src);
|
||||
ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0,
|
||||
reinterpret_cast<struct sockaddr*>(&src), &src_size),
|
||||
SyscallSucceedsWithValue(sizeof(packet)));
|
||||
EXPECT_EQ(
|
||||
memcmp(kPayload, recv_buf + sizeof(struct iphdr) + sizeof(struct udphdr),
|
||||
sizeof(kPayload)),
|
||||
0);
|
||||
// The network stack should have set the source address.
|
||||
EXPECT_EQ(src.sin_family, AF_INET);
|
||||
EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK);
|
||||
// The packet ID should be 0, as the packet is less than 68 bytes.
|
||||
struct iphdr iphdr = {};
|
||||
memcpy(&iphdr, recv_buf, sizeof(iphdr));
|
||||
EXPECT_EQ(iphdr.id, 0);
|
||||
}
|
||||
|
||||
// Send and receive a packet with nonzero IP ID.
|
||||
TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
int port = 40000;
|
||||
if (!IsRunningOnGvisor()) {
|
||||
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
|
||||
PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false)));
|
||||
}
|
||||
|
||||
// IPPROTO_RAW sockets are write-only. We'll have to open another socket to
|
||||
// read what we write.
|
||||
FileDescriptor udp_sock =
|
||||
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
|
||||
|
||||
// Construct a packet with an IP header, UDP header, and payload. Make the
|
||||
// payload large enough to force an IP ID to be assigned.
|
||||
constexpr char kPayload[128] = {};
|
||||
char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)];
|
||||
ASSERT_TRUE(
|
||||
FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload)));
|
||||
|
||||
socklen_t addrlen = sizeof(addr_);
|
||||
ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0,
|
||||
reinterpret_cast<struct sockaddr*>(&addr_),
|
||||
addrlen));
|
||||
|
||||
// Receive the payload.
|
||||
char recv_buf[sizeof(packet)];
|
||||
struct sockaddr_in src;
|
||||
socklen_t src_size = sizeof(src);
|
||||
ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0,
|
||||
reinterpret_cast<struct sockaddr*>(&src), &src_size),
|
||||
SyscallSucceedsWithValue(sizeof(packet)));
|
||||
EXPECT_EQ(
|
||||
memcmp(kPayload, recv_buf + sizeof(struct iphdr) + sizeof(struct udphdr),
|
||||
sizeof(kPayload)),
|
||||
0);
|
||||
// The network stack should have set the source address.
|
||||
EXPECT_EQ(src.sin_family, AF_INET);
|
||||
EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK);
|
||||
// The packet ID should not be 0, as the packet was more than 68 bytes.
|
||||
struct iphdr* iphdr = reinterpret_cast<struct iphdr*>(recv_buf);
|
||||
EXPECT_NE(iphdr->id, 0);
|
||||
}
|
||||
|
||||
// Send and receive a packet where the sendto address is not the same as the
|
||||
// provided destination.
|
||||
TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) {
|
||||
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
|
||||
|
||||
int port = 40000;
|
||||
if (!IsRunningOnGvisor()) {
|
||||
port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE(
|
||||
PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false)));
|
||||
}
|
||||
|
||||
// IPPROTO_RAW sockets are write-only. We'll have to open another socket to
|
||||
// read what we write.
|
||||
FileDescriptor udp_sock =
|
||||
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP));
|
||||
|
||||
// Construct a packet with an IP header, UDP header, and payload.
|
||||
constexpr char kPayload[] = "toto";
|
||||
char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)];
|
||||
ASSERT_TRUE(
|
||||
FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload)));
|
||||
// Overwrite the IP destination address with an IP we can't get to.
|
||||
struct iphdr iphdr = {};
|
||||
memcpy(&iphdr, packet, sizeof(iphdr));
|
||||
iphdr.daddr = 42;
|
||||
memcpy(packet, &iphdr, sizeof(iphdr));
|
||||
|
||||
socklen_t addrlen = sizeof(addr_);
|
||||
ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0,
|
||||
reinterpret_cast<struct sockaddr*>(&addr_),
|
||||
addrlen));
|
||||
|
||||
// Receive the payload, since sendto should replace the bad destination with
|
||||
// localhost.
|
||||
char recv_buf[sizeof(packet)];
|
||||
struct sockaddr_in src;
|
||||
socklen_t src_size = sizeof(src);
|
||||
ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0,
|
||||
reinterpret_cast<struct sockaddr*>(&src), &src_size),
|
||||
SyscallSucceedsWithValue(sizeof(packet)));
|
||||
EXPECT_EQ(
|
||||
memcmp(kPayload, recv_buf + sizeof(struct iphdr) + sizeof(struct udphdr),
|
||||
sizeof(kPayload)),
|
||||
0);
|
||||
// The network stack should have set the source address.
|
||||
EXPECT_EQ(src.sin_family, AF_INET);
|
||||
EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK);
|
||||
// The packet ID should be 0, as the packet is less than 68 bytes.
|
||||
struct iphdr recv_iphdr = {};
|
||||
memcpy(&recv_iphdr, recv_buf, sizeof(recv_iphdr));
|
||||
EXPECT_EQ(recv_iphdr.id, 0);
|
||||
// The destination address should be localhost, not the bad IP we set
|
||||
// initially.
|
||||
EXPECT_EQ(absl::gbswap_32(recv_iphdr.daddr), INADDR_LOOPBACK);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace testing
|
||||
} // namespace gvisor
|
Loading…
Reference in New Issue