Remove useless cached state
Simplify some logic while I'm here. PiperOrigin-RevId: 351491593
This commit is contained in:
parent
8b0f0b4d11
commit
626a8ca225
|
@ -309,11 +309,6 @@ type socketOpsCommon struct {
|
|||
|
||||
// readMu protects access to the below fields.
|
||||
readMu sync.Mutex `state:"nosave"`
|
||||
// readCM holds control message information for the last packet read
|
||||
// from Endpoint.
|
||||
readCM socket.IPControlMessages
|
||||
sender tcpip.FullAddress
|
||||
linkPacketInfo tcpip.LinkPacketInfo
|
||||
|
||||
// sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps
|
||||
// of returned messages can be returned via control messages. When
|
||||
|
@ -368,25 +363,6 @@ func (s *socketOpsCommon) isPacketBased() bool {
|
|||
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
|
||||
}
|
||||
|
||||
// Precondition: s.readMu must be held.
|
||||
func (s *socketOpsCommon) readLocked(dst io.Writer, count int, peek bool) (numRead, numTotal int, serr *syserr.Error) {
|
||||
res, err := s.Endpoint.Read(dst, count, tcpip.ReadOptions{
|
||||
Peek: peek,
|
||||
NeedRemoteAddr: true,
|
||||
NeedLinkPacketInfo: true,
|
||||
})
|
||||
|
||||
// Assign these anyways.
|
||||
s.readCM = socket.NewIPControlMessages(s.family, res.ControlMessages)
|
||||
s.sender = res.RemoteAddr
|
||||
s.linkPacketInfo = res.LinkPacketInfo
|
||||
|
||||
if err != nil {
|
||||
return 0, 0, syserr.TranslateNetstackError(err)
|
||||
}
|
||||
return res.Count, res.Total, nil
|
||||
}
|
||||
|
||||
// Release implements fs.FileOperations.Release.
|
||||
func (s *socketOpsCommon) Release(ctx context.Context) {
|
||||
e, ch := waiter.NewChannelEntry(nil)
|
||||
|
@ -436,11 +412,13 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write
|
|||
defer s.readMu.Unlock()
|
||||
|
||||
// This may return a blocking error.
|
||||
n, _, err := s.readLocked(dst, int(count), dup /* peek */)
|
||||
res, err := s.Endpoint.Read(dst, int(count), tcpip.ReadOptions{
|
||||
Peek: dup,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err.ToError()
|
||||
return 0, syserr.TranslateNetstackError(err).ToError()
|
||||
}
|
||||
return int64(n), nil
|
||||
return int64(res.Count), nil
|
||||
}
|
||||
|
||||
// ioSequencePayload implements tcpip.Payload.
|
||||
|
@ -2557,22 +2535,6 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
|
|||
return a, l, nil
|
||||
}
|
||||
|
||||
// streamRead is the fast path for non-blocking, non-peek, stream-based socket.
|
||||
//
|
||||
// Precondition: s.readMu must be locked.
|
||||
func (s *socketOpsCommon) streamRead(ctx context.Context, dst io.Writer, count int) (int, *syserr.Error) {
|
||||
// Always do at least one read, even if the number of bytes to read is 0.
|
||||
var n int
|
||||
n, _, err := s.readLocked(dst, count, false /* peek */)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if n > 0 {
|
||||
s.Endpoint.ModerateRecvBuf(n)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
|
||||
if !s.sockOptInq {
|
||||
return
|
||||
|
@ -2608,133 +2570,102 @@ func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
|
|||
func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
|
||||
isPacket := s.isPacketBased()
|
||||
|
||||
// Fast path for regular reads from stream (e.g., TCP) endpoints. Note
|
||||
// that senderRequested is ignored for stream sockets.
|
||||
if !peek && !isPacket {
|
||||
// TCP sockets discard the data if MSG_TRUNC is set.
|
||||
//
|
||||
// This behavior is documented in man 7 tcp:
|
||||
// Since version 2.4, Linux supports the use of MSG_TRUNC in the flags
|
||||
// argument of recv(2) (and recvmsg(2)). This flag causes the received
|
||||
// bytes of data to be discarded, rather than passed back in a
|
||||
// caller-supplied buffer.
|
||||
s.readMu.Lock()
|
||||
|
||||
var w io.Writer
|
||||
if trunc {
|
||||
w = ioutil.Discard
|
||||
} else {
|
||||
w = dst.Writer(ctx)
|
||||
}
|
||||
|
||||
n, err := s.streamRead(ctx, w, int(dst.NumBytes()))
|
||||
|
||||
if err == nil && !trunc {
|
||||
// Set the control message, even if 0 bytes were read.
|
||||
s.updateTimestamp()
|
||||
}
|
||||
|
||||
cmsg := s.controlMessages()
|
||||
s.fillCmsgInq(&cmsg)
|
||||
s.readMu.Unlock()
|
||||
return n, 0, nil, 0, cmsg, err
|
||||
readOptions := tcpip.ReadOptions{
|
||||
Peek: peek,
|
||||
NeedRemoteAddr: senderRequested,
|
||||
NeedLinkPacketInfo: isPacket,
|
||||
}
|
||||
|
||||
s.readMu.Lock()
|
||||
defer s.readMu.Unlock()
|
||||
|
||||
// MSG_TRUNC with MSG_PEEK on a TCP socket returns the
|
||||
// amount that could be read, and does not write to buffer.
|
||||
isTCPPeekTrunc := !isPacket && peek && trunc
|
||||
|
||||
// TCP sockets discard the data if MSG_TRUNC is set.
|
||||
//
|
||||
// This behavior is documented in man 7 tcp:
|
||||
// Since version 2.4, Linux supports the use of MSG_TRUNC in the flags
|
||||
// argument of recv(2) (and recvmsg(2)). This flag causes the received
|
||||
// bytes of data to be discarded, rather than passed back in a
|
||||
// caller-supplied buffer.
|
||||
var w io.Writer
|
||||
if isTCPPeekTrunc {
|
||||
if !isPacket && trunc {
|
||||
w = ioutil.Discard
|
||||
} else {
|
||||
w = dst.Writer(ctx)
|
||||
}
|
||||
|
||||
var numRead, numTotal int
|
||||
var err *syserr.Error
|
||||
numRead, numTotal, err = s.readLocked(w, int(dst.NumBytes()), peek)
|
||||
s.readMu.Lock()
|
||||
defer s.readMu.Unlock()
|
||||
|
||||
res, err := s.Endpoint.Read(w, int(dst.NumBytes()), readOptions)
|
||||
if err != nil {
|
||||
return 0, 0, nil, 0, socket.ControlMessages{}, err
|
||||
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
|
||||
}
|
||||
|
||||
if isTCPPeekTrunc {
|
||||
// TCP endpoint does not return the total bytes in buffer as numTotal.
|
||||
// We need to query it from socket option.
|
||||
rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
|
||||
if err != nil {
|
||||
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
|
||||
}
|
||||
available := int(rql)
|
||||
bufLen := int(dst.NumBytes())
|
||||
if available < bufLen {
|
||||
return available, 0, nil, 0, socket.ControlMessages{}, nil
|
||||
}
|
||||
return bufLen, 0, nil, 0, socket.ControlMessages{}, nil
|
||||
}
|
||||
|
||||
// Set the control message, even if 0 bytes were read.
|
||||
s.updateTimestamp()
|
||||
s.updateTimestamp(res.ControlMessages)
|
||||
|
||||
var addr linux.SockAddr
|
||||
var addrLen uint32
|
||||
if isPacket && senderRequested {
|
||||
addr, addrLen = socket.ConvertAddress(s.family, s.sender)
|
||||
switch v := addr.(type) {
|
||||
case *linux.SockAddrLink:
|
||||
v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol))
|
||||
v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
|
||||
if isPacket {
|
||||
var addr linux.SockAddr
|
||||
var addrLen uint32
|
||||
if senderRequested {
|
||||
addr, addrLen = socket.ConvertAddress(s.family, res.RemoteAddr)
|
||||
switch v := addr.(type) {
|
||||
case *linux.SockAddrLink:
|
||||
v.Protocol = socket.Htons(uint16(res.LinkPacketInfo.Protocol))
|
||||
v.PacketType = toLinuxPacketType(res.LinkPacketInfo.PktType)
|
||||
}
|
||||
}
|
||||
|
||||
msgLen := res.Count
|
||||
if trunc {
|
||||
msgLen = res.Total
|
||||
}
|
||||
|
||||
var flags int
|
||||
if res.Total > res.Count {
|
||||
flags |= linux.MSG_TRUNC
|
||||
}
|
||||
|
||||
return msgLen, flags, addr, addrLen, s.controlMessages(res.ControlMessages), nil
|
||||
}
|
||||
|
||||
if peek {
|
||||
if trunc && numTotal > numRead {
|
||||
// isPacket must be true.
|
||||
return numTotal, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), nil
|
||||
// MSG_TRUNC with MSG_PEEK on a TCP socket returns the
|
||||
// amount that could be read, and does not write to buffer.
|
||||
if trunc {
|
||||
// TCP endpoint does not return the total bytes in buffer as numTotal.
|
||||
// We need to query it from socket option.
|
||||
rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
|
||||
if err != nil {
|
||||
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
|
||||
}
|
||||
msgLen := int(dst.NumBytes())
|
||||
if msgLen > rql {
|
||||
msgLen = rql
|
||||
}
|
||||
return msgLen, 0, nil, 0, socket.ControlMessages{}, nil
|
||||
}
|
||||
return numRead, 0, nil, 0, s.controlMessages(), nil
|
||||
} else if n := res.Count; n != 0 {
|
||||
s.Endpoint.ModerateRecvBuf(n)
|
||||
}
|
||||
|
||||
var msgLen int
|
||||
if isPacket {
|
||||
msgLen = numTotal
|
||||
} else {
|
||||
msgLen = numRead
|
||||
}
|
||||
|
||||
var flags int
|
||||
if msgLen > numRead {
|
||||
flags |= linux.MSG_TRUNC
|
||||
}
|
||||
|
||||
n := numRead
|
||||
if trunc {
|
||||
n = msgLen
|
||||
}
|
||||
|
||||
cmsg := s.controlMessages()
|
||||
cmsg := s.controlMessages(res.ControlMessages)
|
||||
s.fillCmsgInq(&cmsg)
|
||||
return n, flags, addr, addrLen, cmsg, nil
|
||||
return res.Count, 0, nil, 0, cmsg, syserr.TranslateNetstackError(err)
|
||||
}
|
||||
|
||||
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
|
||||
func (s *socketOpsCommon) controlMessages(cm tcpip.ControlMessages) socket.ControlMessages {
|
||||
readCM := socket.NewIPControlMessages(s.family, cm)
|
||||
return socket.ControlMessages{
|
||||
IP: socket.IPControlMessages{
|
||||
HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
|
||||
Timestamp: s.readCM.Timestamp,
|
||||
HasInq: s.readCM.HasInq,
|
||||
Inq: s.readCM.Inq,
|
||||
HasTOS: s.readCM.HasTOS,
|
||||
TOS: s.readCM.TOS,
|
||||
HasTClass: s.readCM.HasTClass,
|
||||
TClass: s.readCM.TClass,
|
||||
HasIPPacketInfo: s.readCM.HasIPPacketInfo,
|
||||
PacketInfo: s.readCM.PacketInfo,
|
||||
OriginalDstAddress: s.readCM.OriginalDstAddress,
|
||||
SockErr: s.readCM.SockErr,
|
||||
HasTimestamp: readCM.HasTimestamp && s.sockOptTimestamp,
|
||||
Timestamp: readCM.Timestamp,
|
||||
HasInq: readCM.HasInq,
|
||||
Inq: readCM.Inq,
|
||||
HasTOS: readCM.HasTOS,
|
||||
TOS: readCM.TOS,
|
||||
HasTClass: readCM.HasTClass,
|
||||
TClass: readCM.TClass,
|
||||
HasIPPacketInfo: readCM.HasIPPacketInfo,
|
||||
PacketInfo: readCM.PacketInfo,
|
||||
OriginalDstAddress: readCM.OriginalDstAddress,
|
||||
SockErr: readCM.SockErr,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -2743,11 +2674,11 @@ func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
|
|||
// successfully writing packet data out to userspace.
|
||||
//
|
||||
// Precondition: s.readMu must be locked.
|
||||
func (s *socketOpsCommon) updateTimestamp() {
|
||||
func (s *socketOpsCommon) updateTimestamp(cm tcpip.ControlMessages) {
|
||||
// Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled.
|
||||
if !s.sockOptTimestamp {
|
||||
s.timestampValid = true
|
||||
s.timestampNS = s.readCM.Timestamp
|
||||
s.timestampNS = cm.Timestamp
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue