Remove useless cached state

Simplify some logic while I'm here.

PiperOrigin-RevId: 351491593
This commit is contained in:
Tamir Duberstein 2021-01-12 18:39:47 -08:00 committed by gVisor bot
parent 8b0f0b4d11
commit 626a8ca225
1 changed files with 79 additions and 148 deletions

View File

@ -309,11 +309,6 @@ type socketOpsCommon struct {
// readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"`
// readCM holds control message information for the last packet read
// from Endpoint.
readCM socket.IPControlMessages
sender tcpip.FullAddress
linkPacketInfo tcpip.LinkPacketInfo
// sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps
// of returned messages can be returned via control messages. When
@ -368,25 +363,6 @@ func (s *socketOpsCommon) isPacketBased() bool {
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
}
// Precondition: s.readMu must be held.
func (s *socketOpsCommon) readLocked(dst io.Writer, count int, peek bool) (numRead, numTotal int, serr *syserr.Error) {
res, err := s.Endpoint.Read(dst, count, tcpip.ReadOptions{
Peek: peek,
NeedRemoteAddr: true,
NeedLinkPacketInfo: true,
})
// Assign these anyways.
s.readCM = socket.NewIPControlMessages(s.family, res.ControlMessages)
s.sender = res.RemoteAddr
s.linkPacketInfo = res.LinkPacketInfo
if err != nil {
return 0, 0, syserr.TranslateNetstackError(err)
}
return res.Count, res.Total, nil
}
// Release implements fs.FileOperations.Release.
func (s *socketOpsCommon) Release(ctx context.Context) {
e, ch := waiter.NewChannelEntry(nil)
@ -436,11 +412,13 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write
defer s.readMu.Unlock()
// This may return a blocking error.
n, _, err := s.readLocked(dst, int(count), dup /* peek */)
res, err := s.Endpoint.Read(dst, int(count), tcpip.ReadOptions{
Peek: dup,
})
if err != nil {
return 0, err.ToError()
return 0, syserr.TranslateNetstackError(err).ToError()
}
return int64(n), nil
return int64(res.Count), nil
}
// ioSequencePayload implements tcpip.Payload.
@ -2557,22 +2535,6 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return a, l, nil
}
// streamRead is the fast path for non-blocking, non-peek, stream-based socket.
//
// Precondition: s.readMu must be locked.
func (s *socketOpsCommon) streamRead(ctx context.Context, dst io.Writer, count int) (int, *syserr.Error) {
// Always do at least one read, even if the number of bytes to read is 0.
var n int
n, _, err := s.readLocked(dst, count, false /* peek */)
if err != nil {
return 0, err
}
if n > 0 {
s.Endpoint.ModerateRecvBuf(n)
}
return n, nil
}
func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
if !s.sockOptInq {
return
@ -2608,133 +2570,102 @@ func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
isPacket := s.isPacketBased()
// Fast path for regular reads from stream (e.g., TCP) endpoints. Note
// that senderRequested is ignored for stream sockets.
if !peek && !isPacket {
// TCP sockets discard the data if MSG_TRUNC is set.
//
// This behavior is documented in man 7 tcp:
// Since version 2.4, Linux supports the use of MSG_TRUNC in the flags
// argument of recv(2) (and recvmsg(2)). This flag causes the received
// bytes of data to be discarded, rather than passed back in a
// caller-supplied buffer.
s.readMu.Lock()
var w io.Writer
if trunc {
w = ioutil.Discard
} else {
w = dst.Writer(ctx)
}
n, err := s.streamRead(ctx, w, int(dst.NumBytes()))
if err == nil && !trunc {
// Set the control message, even if 0 bytes were read.
s.updateTimestamp()
}
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
s.readMu.Unlock()
return n, 0, nil, 0, cmsg, err
readOptions := tcpip.ReadOptions{
Peek: peek,
NeedRemoteAddr: senderRequested,
NeedLinkPacketInfo: isPacket,
}
s.readMu.Lock()
defer s.readMu.Unlock()
// MSG_TRUNC with MSG_PEEK on a TCP socket returns the
// amount that could be read, and does not write to buffer.
isTCPPeekTrunc := !isPacket && peek && trunc
// TCP sockets discard the data if MSG_TRUNC is set.
//
// This behavior is documented in man 7 tcp:
// Since version 2.4, Linux supports the use of MSG_TRUNC in the flags
// argument of recv(2) (and recvmsg(2)). This flag causes the received
// bytes of data to be discarded, rather than passed back in a
// caller-supplied buffer.
var w io.Writer
if isTCPPeekTrunc {
if !isPacket && trunc {
w = ioutil.Discard
} else {
w = dst.Writer(ctx)
}
var numRead, numTotal int
var err *syserr.Error
numRead, numTotal, err = s.readLocked(w, int(dst.NumBytes()), peek)
s.readMu.Lock()
defer s.readMu.Unlock()
res, err := s.Endpoint.Read(w, int(dst.NumBytes()), readOptions)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, err
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
if isTCPPeekTrunc {
// TCP endpoint does not return the total bytes in buffer as numTotal.
// We need to query it from socket option.
rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
available := int(rql)
bufLen := int(dst.NumBytes())
if available < bufLen {
return available, 0, nil, 0, socket.ControlMessages{}, nil
}
return bufLen, 0, nil, 0, socket.ControlMessages{}, nil
}
// Set the control message, even if 0 bytes were read.
s.updateTimestamp()
s.updateTimestamp(res.ControlMessages)
var addr linux.SockAddr
var addrLen uint32
if isPacket && senderRequested {
addr, addrLen = socket.ConvertAddress(s.family, s.sender)
switch v := addr.(type) {
case *linux.SockAddrLink:
v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol))
v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
if isPacket {
var addr linux.SockAddr
var addrLen uint32
if senderRequested {
addr, addrLen = socket.ConvertAddress(s.family, res.RemoteAddr)
switch v := addr.(type) {
case *linux.SockAddrLink:
v.Protocol = socket.Htons(uint16(res.LinkPacketInfo.Protocol))
v.PacketType = toLinuxPacketType(res.LinkPacketInfo.PktType)
}
}
msgLen := res.Count
if trunc {
msgLen = res.Total
}
var flags int
if res.Total > res.Count {
flags |= linux.MSG_TRUNC
}
return msgLen, flags, addr, addrLen, s.controlMessages(res.ControlMessages), nil
}
if peek {
if trunc && numTotal > numRead {
// isPacket must be true.
return numTotal, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), nil
// MSG_TRUNC with MSG_PEEK on a TCP socket returns the
// amount that could be read, and does not write to buffer.
if trunc {
// TCP endpoint does not return the total bytes in buffer as numTotal.
// We need to query it from socket option.
rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
msgLen := int(dst.NumBytes())
if msgLen > rql {
msgLen = rql
}
return msgLen, 0, nil, 0, socket.ControlMessages{}, nil
}
return numRead, 0, nil, 0, s.controlMessages(), nil
} else if n := res.Count; n != 0 {
s.Endpoint.ModerateRecvBuf(n)
}
var msgLen int
if isPacket {
msgLen = numTotal
} else {
msgLen = numRead
}
var flags int
if msgLen > numRead {
flags |= linux.MSG_TRUNC
}
n := numRead
if trunc {
n = msgLen
}
cmsg := s.controlMessages()
cmsg := s.controlMessages(res.ControlMessages)
s.fillCmsgInq(&cmsg)
return n, flags, addr, addrLen, cmsg, nil
return res.Count, 0, nil, 0, cmsg, syserr.TranslateNetstackError(err)
}
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
func (s *socketOpsCommon) controlMessages(cm tcpip.ControlMessages) socket.ControlMessages {
readCM := socket.NewIPControlMessages(s.family, cm)
return socket.ControlMessages{
IP: socket.IPControlMessages{
HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
Timestamp: s.readCM.Timestamp,
HasInq: s.readCM.HasInq,
Inq: s.readCM.Inq,
HasTOS: s.readCM.HasTOS,
TOS: s.readCM.TOS,
HasTClass: s.readCM.HasTClass,
TClass: s.readCM.TClass,
HasIPPacketInfo: s.readCM.HasIPPacketInfo,
PacketInfo: s.readCM.PacketInfo,
OriginalDstAddress: s.readCM.OriginalDstAddress,
SockErr: s.readCM.SockErr,
HasTimestamp: readCM.HasTimestamp && s.sockOptTimestamp,
Timestamp: readCM.Timestamp,
HasInq: readCM.HasInq,
Inq: readCM.Inq,
HasTOS: readCM.HasTOS,
TOS: readCM.TOS,
HasTClass: readCM.HasTClass,
TClass: readCM.TClass,
HasIPPacketInfo: readCM.HasIPPacketInfo,
PacketInfo: readCM.PacketInfo,
OriginalDstAddress: readCM.OriginalDstAddress,
SockErr: readCM.SockErr,
},
}
}
@ -2743,11 +2674,11 @@ func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
// successfully writing packet data out to userspace.
//
// Precondition: s.readMu must be locked.
func (s *socketOpsCommon) updateTimestamp() {
func (s *socketOpsCommon) updateTimestamp(cm tcpip.ControlMessages) {
// Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled.
if !s.sockOptTimestamp {
s.timestampValid = true
s.timestampNS = s.readCM.Timestamp
s.timestampNS = cm.Timestamp
}
}