Remove useless cached state

Simplify some logic while I'm here.

PiperOrigin-RevId: 351491593
This commit is contained in:
Tamir Duberstein 2021-01-12 18:39:47 -08:00 committed by gVisor bot
parent 8b0f0b4d11
commit 626a8ca225
1 changed files with 79 additions and 148 deletions

View File

@ -309,11 +309,6 @@ type socketOpsCommon struct {
// readMu protects access to the below fields. // readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"` readMu sync.Mutex `state:"nosave"`
// readCM holds control message information for the last packet read
// from Endpoint.
readCM socket.IPControlMessages
sender tcpip.FullAddress
linkPacketInfo tcpip.LinkPacketInfo
// sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps
// of returned messages can be returned via control messages. When // of returned messages can be returned via control messages. When
@ -368,25 +363,6 @@ func (s *socketOpsCommon) isPacketBased() bool {
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
} }
// Precondition: s.readMu must be held.
func (s *socketOpsCommon) readLocked(dst io.Writer, count int, peek bool) (numRead, numTotal int, serr *syserr.Error) {
res, err := s.Endpoint.Read(dst, count, tcpip.ReadOptions{
Peek: peek,
NeedRemoteAddr: true,
NeedLinkPacketInfo: true,
})
// Assign these anyways.
s.readCM = socket.NewIPControlMessages(s.family, res.ControlMessages)
s.sender = res.RemoteAddr
s.linkPacketInfo = res.LinkPacketInfo
if err != nil {
return 0, 0, syserr.TranslateNetstackError(err)
}
return res.Count, res.Total, nil
}
// Release implements fs.FileOperations.Release. // Release implements fs.FileOperations.Release.
func (s *socketOpsCommon) Release(ctx context.Context) { func (s *socketOpsCommon) Release(ctx context.Context) {
e, ch := waiter.NewChannelEntry(nil) e, ch := waiter.NewChannelEntry(nil)
@ -436,11 +412,13 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write
defer s.readMu.Unlock() defer s.readMu.Unlock()
// This may return a blocking error. // This may return a blocking error.
n, _, err := s.readLocked(dst, int(count), dup /* peek */) res, err := s.Endpoint.Read(dst, int(count), tcpip.ReadOptions{
Peek: dup,
})
if err != nil { if err != nil {
return 0, err.ToError() return 0, syserr.TranslateNetstackError(err).ToError()
} }
return int64(n), nil return int64(res.Count), nil
} }
// ioSequencePayload implements tcpip.Payload. // ioSequencePayload implements tcpip.Payload.
@ -2557,22 +2535,6 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
return a, l, nil return a, l, nil
} }
// streamRead is the fast path for non-blocking, non-peek, stream-based socket.
//
// Precondition: s.readMu must be locked.
func (s *socketOpsCommon) streamRead(ctx context.Context, dst io.Writer, count int) (int, *syserr.Error) {
// Always do at least one read, even if the number of bytes to read is 0.
var n int
n, _, err := s.readLocked(dst, count, false /* peek */)
if err != nil {
return 0, err
}
if n > 0 {
s.Endpoint.ModerateRecvBuf(n)
}
return n, nil
}
func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
if !s.sockOptInq { if !s.sockOptInq {
return return
@ -2608,133 +2570,102 @@ func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
isPacket := s.isPacketBased() isPacket := s.isPacketBased()
// Fast path for regular reads from stream (e.g., TCP) endpoints. Note readOptions := tcpip.ReadOptions{
// that senderRequested is ignored for stream sockets. Peek: peek,
if !peek && !isPacket { NeedRemoteAddr: senderRequested,
// TCP sockets discard the data if MSG_TRUNC is set. NeedLinkPacketInfo: isPacket,
//
// This behavior is documented in man 7 tcp:
// Since version 2.4, Linux supports the use of MSG_TRUNC in the flags
// argument of recv(2) (and recvmsg(2)). This flag causes the received
// bytes of data to be discarded, rather than passed back in a
// caller-supplied buffer.
s.readMu.Lock()
var w io.Writer
if trunc {
w = ioutil.Discard
} else {
w = dst.Writer(ctx)
}
n, err := s.streamRead(ctx, w, int(dst.NumBytes()))
if err == nil && !trunc {
// Set the control message, even if 0 bytes were read.
s.updateTimestamp()
}
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
s.readMu.Unlock()
return n, 0, nil, 0, cmsg, err
} }
s.readMu.Lock() // TCP sockets discard the data if MSG_TRUNC is set.
defer s.readMu.Unlock() //
// This behavior is documented in man 7 tcp:
// MSG_TRUNC with MSG_PEEK on a TCP socket returns the // Since version 2.4, Linux supports the use of MSG_TRUNC in the flags
// amount that could be read, and does not write to buffer. // argument of recv(2) (and recvmsg(2)). This flag causes the received
isTCPPeekTrunc := !isPacket && peek && trunc // bytes of data to be discarded, rather than passed back in a
// caller-supplied buffer.
var w io.Writer var w io.Writer
if isTCPPeekTrunc { if !isPacket && trunc {
w = ioutil.Discard w = ioutil.Discard
} else { } else {
w = dst.Writer(ctx) w = dst.Writer(ctx)
} }
var numRead, numTotal int s.readMu.Lock()
var err *syserr.Error defer s.readMu.Unlock()
numRead, numTotal, err = s.readLocked(w, int(dst.NumBytes()), peek)
res, err := s.Endpoint.Read(w, int(dst.NumBytes()), readOptions)
if err != nil { if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, err return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
} }
if isTCPPeekTrunc {
// TCP endpoint does not return the total bytes in buffer as numTotal.
// We need to query it from socket option.
rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
available := int(rql)
bufLen := int(dst.NumBytes())
if available < bufLen {
return available, 0, nil, 0, socket.ControlMessages{}, nil
}
return bufLen, 0, nil, 0, socket.ControlMessages{}, nil
}
// Set the control message, even if 0 bytes were read. // Set the control message, even if 0 bytes were read.
s.updateTimestamp() s.updateTimestamp(res.ControlMessages)
var addr linux.SockAddr if isPacket {
var addrLen uint32 var addr linux.SockAddr
if isPacket && senderRequested { var addrLen uint32
addr, addrLen = socket.ConvertAddress(s.family, s.sender) if senderRequested {
switch v := addr.(type) { addr, addrLen = socket.ConvertAddress(s.family, res.RemoteAddr)
case *linux.SockAddrLink: switch v := addr.(type) {
v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol)) case *linux.SockAddrLink:
v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) v.Protocol = socket.Htons(uint16(res.LinkPacketInfo.Protocol))
v.PacketType = toLinuxPacketType(res.LinkPacketInfo.PktType)
}
} }
msgLen := res.Count
if trunc {
msgLen = res.Total
}
var flags int
if res.Total > res.Count {
flags |= linux.MSG_TRUNC
}
return msgLen, flags, addr, addrLen, s.controlMessages(res.ControlMessages), nil
} }
if peek { if peek {
if trunc && numTotal > numRead { // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
// isPacket must be true. // amount that could be read, and does not write to buffer.
return numTotal, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), nil if trunc {
// TCP endpoint does not return the total bytes in buffer as numTotal.
// We need to query it from socket option.
rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
msgLen := int(dst.NumBytes())
if msgLen > rql {
msgLen = rql
}
return msgLen, 0, nil, 0, socket.ControlMessages{}, nil
} }
return numRead, 0, nil, 0, s.controlMessages(), nil } else if n := res.Count; n != 0 {
s.Endpoint.ModerateRecvBuf(n)
} }
var msgLen int cmsg := s.controlMessages(res.ControlMessages)
if isPacket {
msgLen = numTotal
} else {
msgLen = numRead
}
var flags int
if msgLen > numRead {
flags |= linux.MSG_TRUNC
}
n := numRead
if trunc {
n = msgLen
}
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg) s.fillCmsgInq(&cmsg)
return n, flags, addr, addrLen, cmsg, nil return res.Count, 0, nil, 0, cmsg, syserr.TranslateNetstackError(err)
} }
func (s *socketOpsCommon) controlMessages() socket.ControlMessages { func (s *socketOpsCommon) controlMessages(cm tcpip.ControlMessages) socket.ControlMessages {
readCM := socket.NewIPControlMessages(s.family, cm)
return socket.ControlMessages{ return socket.ControlMessages{
IP: socket.IPControlMessages{ IP: socket.IPControlMessages{
HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, HasTimestamp: readCM.HasTimestamp && s.sockOptTimestamp,
Timestamp: s.readCM.Timestamp, Timestamp: readCM.Timestamp,
HasInq: s.readCM.HasInq, HasInq: readCM.HasInq,
Inq: s.readCM.Inq, Inq: readCM.Inq,
HasTOS: s.readCM.HasTOS, HasTOS: readCM.HasTOS,
TOS: s.readCM.TOS, TOS: readCM.TOS,
HasTClass: s.readCM.HasTClass, HasTClass: readCM.HasTClass,
TClass: s.readCM.TClass, TClass: readCM.TClass,
HasIPPacketInfo: s.readCM.HasIPPacketInfo, HasIPPacketInfo: readCM.HasIPPacketInfo,
PacketInfo: s.readCM.PacketInfo, PacketInfo: readCM.PacketInfo,
OriginalDstAddress: s.readCM.OriginalDstAddress, OriginalDstAddress: readCM.OriginalDstAddress,
SockErr: s.readCM.SockErr, SockErr: readCM.SockErr,
}, },
} }
} }
@ -2743,11 +2674,11 @@ func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
// successfully writing packet data out to userspace. // successfully writing packet data out to userspace.
// //
// Precondition: s.readMu must be locked. // Precondition: s.readMu must be locked.
func (s *socketOpsCommon) updateTimestamp() { func (s *socketOpsCommon) updateTimestamp(cm tcpip.ControlMessages) {
// Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled. // Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled.
if !s.sockOptTimestamp { if !s.sockOptTimestamp {
s.timestampValid = true s.timestampValid = true
s.timestampNS = s.readCM.Timestamp s.timestampNS = cm.Timestamp
} }
} }