diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index dcf898c0a..57f224120 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -309,11 +309,6 @@ type socketOpsCommon struct { // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` - // readCM holds control message information for the last packet read - // from Endpoint. - readCM socket.IPControlMessages - sender tcpip.FullAddress - linkPacketInfo tcpip.LinkPacketInfo // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps // of returned messages can be returned via control messages. When @@ -368,25 +363,6 @@ func (s *socketOpsCommon) isPacketBased() bool { return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW } -// Precondition: s.readMu must be held. -func (s *socketOpsCommon) readLocked(dst io.Writer, count int, peek bool) (numRead, numTotal int, serr *syserr.Error) { - res, err := s.Endpoint.Read(dst, count, tcpip.ReadOptions{ - Peek: peek, - NeedRemoteAddr: true, - NeedLinkPacketInfo: true, - }) - - // Assign these anyways. - s.readCM = socket.NewIPControlMessages(s.family, res.ControlMessages) - s.sender = res.RemoteAddr - s.linkPacketInfo = res.LinkPacketInfo - - if err != nil { - return 0, 0, syserr.TranslateNetstackError(err) - } - return res.Count, res.Total, nil -} - // Release implements fs.FileOperations.Release. func (s *socketOpsCommon) Release(ctx context.Context) { e, ch := waiter.NewChannelEntry(nil) @@ -436,11 +412,13 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write defer s.readMu.Unlock() // This may return a blocking error. - n, _, err := s.readLocked(dst, int(count), dup /* peek */) + res, err := s.Endpoint.Read(dst, int(count), tcpip.ReadOptions{ + Peek: dup, + }) if err != nil { - return 0, err.ToError() + return 0, syserr.TranslateNetstackError(err).ToError() } - return int64(n), nil + return int64(res.Count), nil } // ioSequencePayload implements tcpip.Payload. @@ -2557,22 +2535,6 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return a, l, nil } -// streamRead is the fast path for non-blocking, non-peek, stream-based socket. -// -// Precondition: s.readMu must be locked. -func (s *socketOpsCommon) streamRead(ctx context.Context, dst io.Writer, count int) (int, *syserr.Error) { - // Always do at least one read, even if the number of bytes to read is 0. - var n int - n, _, err := s.readLocked(dst, count, false /* peek */) - if err != nil { - return 0, err - } - if n > 0 { - s.Endpoint.ModerateRecvBuf(n) - } - return n, nil -} - func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { if !s.sockOptInq { return @@ -2608,133 +2570,102 @@ func toLinuxPacketType(pktType tcpip.PacketType) uint8 { func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { isPacket := s.isPacketBased() - // Fast path for regular reads from stream (e.g., TCP) endpoints. Note - // that senderRequested is ignored for stream sockets. - if !peek && !isPacket { - // TCP sockets discard the data if MSG_TRUNC is set. - // - // This behavior is documented in man 7 tcp: - // Since version 2.4, Linux supports the use of MSG_TRUNC in the flags - // argument of recv(2) (and recvmsg(2)). This flag causes the received - // bytes of data to be discarded, rather than passed back in a - // caller-supplied buffer. - s.readMu.Lock() - - var w io.Writer - if trunc { - w = ioutil.Discard - } else { - w = dst.Writer(ctx) - } - - n, err := s.streamRead(ctx, w, int(dst.NumBytes())) - - if err == nil && !trunc { - // Set the control message, even if 0 bytes were read. - s.updateTimestamp() - } - - cmsg := s.controlMessages() - s.fillCmsgInq(&cmsg) - s.readMu.Unlock() - return n, 0, nil, 0, cmsg, err + readOptions := tcpip.ReadOptions{ + Peek: peek, + NeedRemoteAddr: senderRequested, + NeedLinkPacketInfo: isPacket, } - s.readMu.Lock() - defer s.readMu.Unlock() - - // MSG_TRUNC with MSG_PEEK on a TCP socket returns the - // amount that could be read, and does not write to buffer. - isTCPPeekTrunc := !isPacket && peek && trunc - + // TCP sockets discard the data if MSG_TRUNC is set. + // + // This behavior is documented in man 7 tcp: + // Since version 2.4, Linux supports the use of MSG_TRUNC in the flags + // argument of recv(2) (and recvmsg(2)). This flag causes the received + // bytes of data to be discarded, rather than passed back in a + // caller-supplied buffer. var w io.Writer - if isTCPPeekTrunc { + if !isPacket && trunc { w = ioutil.Discard } else { w = dst.Writer(ctx) } - var numRead, numTotal int - var err *syserr.Error - numRead, numTotal, err = s.readLocked(w, int(dst.NumBytes()), peek) + s.readMu.Lock() + defer s.readMu.Unlock() + + res, err := s.Endpoint.Read(w, int(dst.NumBytes()), readOptions) if err != nil { - return 0, 0, nil, 0, socket.ControlMessages{}, err + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) } - - if isTCPPeekTrunc { - // TCP endpoint does not return the total bytes in buffer as numTotal. - // We need to query it from socket option. - rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) - if err != nil { - return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) - } - available := int(rql) - bufLen := int(dst.NumBytes()) - if available < bufLen { - return available, 0, nil, 0, socket.ControlMessages{}, nil - } - return bufLen, 0, nil, 0, socket.ControlMessages{}, nil - } - // Set the control message, even if 0 bytes were read. - s.updateTimestamp() + s.updateTimestamp(res.ControlMessages) - var addr linux.SockAddr - var addrLen uint32 - if isPacket && senderRequested { - addr, addrLen = socket.ConvertAddress(s.family, s.sender) - switch v := addr.(type) { - case *linux.SockAddrLink: - v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol)) - v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) + if isPacket { + var addr linux.SockAddr + var addrLen uint32 + if senderRequested { + addr, addrLen = socket.ConvertAddress(s.family, res.RemoteAddr) + switch v := addr.(type) { + case *linux.SockAddrLink: + v.Protocol = socket.Htons(uint16(res.LinkPacketInfo.Protocol)) + v.PacketType = toLinuxPacketType(res.LinkPacketInfo.PktType) + } } + + msgLen := res.Count + if trunc { + msgLen = res.Total + } + + var flags int + if res.Total > res.Count { + flags |= linux.MSG_TRUNC + } + + return msgLen, flags, addr, addrLen, s.controlMessages(res.ControlMessages), nil } if peek { - if trunc && numTotal > numRead { - // isPacket must be true. - return numTotal, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), nil + // MSG_TRUNC with MSG_PEEK on a TCP socket returns the + // amount that could be read, and does not write to buffer. + if trunc { + // TCP endpoint does not return the total bytes in buffer as numTotal. + // We need to query it from socket option. + rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if err != nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) + } + msgLen := int(dst.NumBytes()) + if msgLen > rql { + msgLen = rql + } + return msgLen, 0, nil, 0, socket.ControlMessages{}, nil } - return numRead, 0, nil, 0, s.controlMessages(), nil + } else if n := res.Count; n != 0 { + s.Endpoint.ModerateRecvBuf(n) } - var msgLen int - if isPacket { - msgLen = numTotal - } else { - msgLen = numRead - } - - var flags int - if msgLen > numRead { - flags |= linux.MSG_TRUNC - } - - n := numRead - if trunc { - n = msgLen - } - - cmsg := s.controlMessages() + cmsg := s.controlMessages(res.ControlMessages) s.fillCmsgInq(&cmsg) - return n, flags, addr, addrLen, cmsg, nil + return res.Count, 0, nil, 0, cmsg, syserr.TranslateNetstackError(err) } -func (s *socketOpsCommon) controlMessages() socket.ControlMessages { +func (s *socketOpsCommon) controlMessages(cm tcpip.ControlMessages) socket.ControlMessages { + readCM := socket.NewIPControlMessages(s.family, cm) return socket.ControlMessages{ IP: socket.IPControlMessages{ - HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, - Timestamp: s.readCM.Timestamp, - HasInq: s.readCM.HasInq, - Inq: s.readCM.Inq, - HasTOS: s.readCM.HasTOS, - TOS: s.readCM.TOS, - HasTClass: s.readCM.HasTClass, - TClass: s.readCM.TClass, - HasIPPacketInfo: s.readCM.HasIPPacketInfo, - PacketInfo: s.readCM.PacketInfo, - OriginalDstAddress: s.readCM.OriginalDstAddress, - SockErr: s.readCM.SockErr, + HasTimestamp: readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: readCM.Timestamp, + HasInq: readCM.HasInq, + Inq: readCM.Inq, + HasTOS: readCM.HasTOS, + TOS: readCM.TOS, + HasTClass: readCM.HasTClass, + TClass: readCM.TClass, + HasIPPacketInfo: readCM.HasIPPacketInfo, + PacketInfo: readCM.PacketInfo, + OriginalDstAddress: readCM.OriginalDstAddress, + SockErr: readCM.SockErr, }, } } @@ -2743,11 +2674,11 @@ func (s *socketOpsCommon) controlMessages() socket.ControlMessages { // successfully writing packet data out to userspace. // // Precondition: s.readMu must be locked. -func (s *socketOpsCommon) updateTimestamp() { +func (s *socketOpsCommon) updateTimestamp(cm tcpip.ControlMessages) { // Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled. if !s.sockOptTimestamp { s.timestampValid = true - s.timestampNS = s.readCM.Timestamp + s.timestampNS = cm.Timestamp } }