Remove useless cached read state (readCM, sender, linkPacketInfo) from socketOpsCommon
Simplify some logic while I'm here.

PiperOrigin-RevId: 351491593
This commit is contained in:
parent
8b0f0b4d11
commit
626a8ca225
|
@@ -309,11 +309,6 @@ type socketOpsCommon struct {
|
||||||
|
|
||||||
// readMu protects access to the below fields.
|
// readMu protects access to the below fields.
|
||||||
readMu sync.Mutex `state:"nosave"`
|
readMu sync.Mutex `state:"nosave"`
|
||||||
// readCM holds control message information for the last packet read
|
|
||||||
// from Endpoint.
|
|
||||||
readCM socket.IPControlMessages
|
|
||||||
sender tcpip.FullAddress
|
|
||||||
linkPacketInfo tcpip.LinkPacketInfo
|
|
||||||
|
|
||||||
// sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps
|
// sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps
|
||||||
// of returned messages can be returned via control messages. When
|
// of returned messages can be returned via control messages. When
|
||||||
|
@@ -368,25 +363,6 @@ func (s *socketOpsCommon) isPacketBased() bool {
|
||||||
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
|
return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
|
||||||
}
|
}
|
||||||
|
|
||||||
// Precondition: s.readMu must be held.
|
|
||||||
func (s *socketOpsCommon) readLocked(dst io.Writer, count int, peek bool) (numRead, numTotal int, serr *syserr.Error) {
|
|
||||||
res, err := s.Endpoint.Read(dst, count, tcpip.ReadOptions{
|
|
||||||
Peek: peek,
|
|
||||||
NeedRemoteAddr: true,
|
|
||||||
NeedLinkPacketInfo: true,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Assign these anyways.
|
|
||||||
s.readCM = socket.NewIPControlMessages(s.family, res.ControlMessages)
|
|
||||||
s.sender = res.RemoteAddr
|
|
||||||
s.linkPacketInfo = res.LinkPacketInfo
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, syserr.TranslateNetstackError(err)
|
|
||||||
}
|
|
||||||
return res.Count, res.Total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Release implements fs.FileOperations.Release.
|
// Release implements fs.FileOperations.Release.
|
||||||
func (s *socketOpsCommon) Release(ctx context.Context) {
|
func (s *socketOpsCommon) Release(ctx context.Context) {
|
||||||
e, ch := waiter.NewChannelEntry(nil)
|
e, ch := waiter.NewChannelEntry(nil)
|
||||||
|
@@ -436,11 +412,13 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write
|
||||||
defer s.readMu.Unlock()
|
defer s.readMu.Unlock()
|
||||||
|
|
||||||
// This may return a blocking error.
|
// This may return a blocking error.
|
||||||
n, _, err := s.readLocked(dst, int(count), dup /* peek */)
|
res, err := s.Endpoint.Read(dst, int(count), tcpip.ReadOptions{
|
||||||
|
Peek: dup,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err.ToError()
|
return 0, syserr.TranslateNetstackError(err).ToError()
|
||||||
}
|
}
|
||||||
return int64(n), nil
|
return int64(res.Count), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ioSequencePayload implements tcpip.Payload.
|
// ioSequencePayload implements tcpip.Payload.
|
||||||
|
@@ -2557,22 +2535,6 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
|
||||||
return a, l, nil
|
return a, l, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// streamRead is the fast path for non-blocking, non-peek, stream-based socket.
|
|
||||||
//
|
|
||||||
// Precondition: s.readMu must be locked.
|
|
||||||
func (s *socketOpsCommon) streamRead(ctx context.Context, dst io.Writer, count int) (int, *syserr.Error) {
|
|
||||||
// Always do at least one read, even if the number of bytes to read is 0.
|
|
||||||
var n int
|
|
||||||
n, _, err := s.readLocked(dst, count, false /* peek */)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if n > 0 {
|
|
||||||
s.Endpoint.ModerateRecvBuf(n)
|
|
||||||
}
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
|
func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
|
||||||
if !s.sockOptInq {
|
if !s.sockOptInq {
|
||||||
return
|
return
|
||||||
|
@@ -2608,133 +2570,102 @@ func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
|
||||||
func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
|
func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
|
||||||
isPacket := s.isPacketBased()
|
isPacket := s.isPacketBased()
|
||||||
|
|
||||||
// Fast path for regular reads from stream (e.g., TCP) endpoints. Note
|
readOptions := tcpip.ReadOptions{
|
||||||
// that senderRequested is ignored for stream sockets.
|
Peek: peek,
|
||||||
if !peek && !isPacket {
|
NeedRemoteAddr: senderRequested,
|
||||||
// TCP sockets discard the data if MSG_TRUNC is set.
|
NeedLinkPacketInfo: isPacket,
|
||||||
//
|
|
||||||
// This behavior is documented in man 7 tcp:
|
|
||||||
// Since version 2.4, Linux supports the use of MSG_TRUNC in the flags
|
|
||||||
// argument of recv(2) (and recvmsg(2)). This flag causes the received
|
|
||||||
// bytes of data to be discarded, rather than passed back in a
|
|
||||||
// caller-supplied buffer.
|
|
||||||
s.readMu.Lock()
|
|
||||||
|
|
||||||
var w io.Writer
|
|
||||||
if trunc {
|
|
||||||
w = ioutil.Discard
|
|
||||||
} else {
|
|
||||||
w = dst.Writer(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := s.streamRead(ctx, w, int(dst.NumBytes()))
|
|
||||||
|
|
||||||
if err == nil && !trunc {
|
|
||||||
// Set the control message, even if 0 bytes were read.
|
|
||||||
s.updateTimestamp()
|
|
||||||
}
|
|
||||||
|
|
||||||
cmsg := s.controlMessages()
|
|
||||||
s.fillCmsgInq(&cmsg)
|
|
||||||
s.readMu.Unlock()
|
|
||||||
return n, 0, nil, 0, cmsg, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
s.readMu.Lock()
|
// TCP sockets discard the data if MSG_TRUNC is set.
|
||||||
defer s.readMu.Unlock()
|
//
|
||||||
|
// This behavior is documented in man 7 tcp:
|
||||||
// MSG_TRUNC with MSG_PEEK on a TCP socket returns the
|
// Since version 2.4, Linux supports the use of MSG_TRUNC in the flags
|
||||||
// amount that could be read, and does not write to buffer.
|
// argument of recv(2) (and recvmsg(2)). This flag causes the received
|
||||||
isTCPPeekTrunc := !isPacket && peek && trunc
|
// bytes of data to be discarded, rather than passed back in a
|
||||||
|
// caller-supplied buffer.
|
||||||
var w io.Writer
|
var w io.Writer
|
||||||
if isTCPPeekTrunc {
|
if !isPacket && trunc {
|
||||||
w = ioutil.Discard
|
w = ioutil.Discard
|
||||||
} else {
|
} else {
|
||||||
w = dst.Writer(ctx)
|
w = dst.Writer(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
var numRead, numTotal int
|
s.readMu.Lock()
|
||||||
var err *syserr.Error
|
defer s.readMu.Unlock()
|
||||||
numRead, numTotal, err = s.readLocked(w, int(dst.NumBytes()), peek)
|
|
||||||
|
res, err := s.Endpoint.Read(w, int(dst.NumBytes()), readOptions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, nil, 0, socket.ControlMessages{}, err
|
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if isTCPPeekTrunc {
|
|
||||||
// TCP endpoint does not return the total bytes in buffer as numTotal.
|
|
||||||
// We need to query it from socket option.
|
|
||||||
rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
|
|
||||||
}
|
|
||||||
available := int(rql)
|
|
||||||
bufLen := int(dst.NumBytes())
|
|
||||||
if available < bufLen {
|
|
||||||
return available, 0, nil, 0, socket.ControlMessages{}, nil
|
|
||||||
}
|
|
||||||
return bufLen, 0, nil, 0, socket.ControlMessages{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the control message, even if 0 bytes were read.
|
// Set the control message, even if 0 bytes were read.
|
||||||
s.updateTimestamp()
|
s.updateTimestamp(res.ControlMessages)
|
||||||
|
|
||||||
var addr linux.SockAddr
|
if isPacket {
|
||||||
var addrLen uint32
|
var addr linux.SockAddr
|
||||||
if isPacket && senderRequested {
|
var addrLen uint32
|
||||||
addr, addrLen = socket.ConvertAddress(s.family, s.sender)
|
if senderRequested {
|
||||||
switch v := addr.(type) {
|
addr, addrLen = socket.ConvertAddress(s.family, res.RemoteAddr)
|
||||||
case *linux.SockAddrLink:
|
switch v := addr.(type) {
|
||||||
v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol))
|
case *linux.SockAddrLink:
|
||||||
v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
|
v.Protocol = socket.Htons(uint16(res.LinkPacketInfo.Protocol))
|
||||||
|
v.PacketType = toLinuxPacketType(res.LinkPacketInfo.PktType)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
msgLen := res.Count
|
||||||
|
if trunc {
|
||||||
|
msgLen = res.Total
|
||||||
|
}
|
||||||
|
|
||||||
|
var flags int
|
||||||
|
if res.Total > res.Count {
|
||||||
|
flags |= linux.MSG_TRUNC
|
||||||
|
}
|
||||||
|
|
||||||
|
return msgLen, flags, addr, addrLen, s.controlMessages(res.ControlMessages), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if peek {
|
if peek {
|
||||||
if trunc && numTotal > numRead {
|
// MSG_TRUNC with MSG_PEEK on a TCP socket returns the
|
||||||
// isPacket must be true.
|
// amount that could be read, and does not write to buffer.
|
||||||
return numTotal, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), nil
|
if trunc {
|
||||||
|
// TCP endpoint does not return the total bytes in buffer as numTotal.
|
||||||
|
// We need to query it from socket option.
|
||||||
|
rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
|
||||||
|
}
|
||||||
|
msgLen := int(dst.NumBytes())
|
||||||
|
if msgLen > rql {
|
||||||
|
msgLen = rql
|
||||||
|
}
|
||||||
|
return msgLen, 0, nil, 0, socket.ControlMessages{}, nil
|
||||||
}
|
}
|
||||||
return numRead, 0, nil, 0, s.controlMessages(), nil
|
} else if n := res.Count; n != 0 {
|
||||||
|
s.Endpoint.ModerateRecvBuf(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
var msgLen int
|
cmsg := s.controlMessages(res.ControlMessages)
|
||||||
if isPacket {
|
|
||||||
msgLen = numTotal
|
|
||||||
} else {
|
|
||||||
msgLen = numRead
|
|
||||||
}
|
|
||||||
|
|
||||||
var flags int
|
|
||||||
if msgLen > numRead {
|
|
||||||
flags |= linux.MSG_TRUNC
|
|
||||||
}
|
|
||||||
|
|
||||||
n := numRead
|
|
||||||
if trunc {
|
|
||||||
n = msgLen
|
|
||||||
}
|
|
||||||
|
|
||||||
cmsg := s.controlMessages()
|
|
||||||
s.fillCmsgInq(&cmsg)
|
s.fillCmsgInq(&cmsg)
|
||||||
return n, flags, addr, addrLen, cmsg, nil
|
return res.Count, 0, nil, 0, cmsg, syserr.TranslateNetstackError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
|
func (s *socketOpsCommon) controlMessages(cm tcpip.ControlMessages) socket.ControlMessages {
|
||||||
|
readCM := socket.NewIPControlMessages(s.family, cm)
|
||||||
return socket.ControlMessages{
|
return socket.ControlMessages{
|
||||||
IP: socket.IPControlMessages{
|
IP: socket.IPControlMessages{
|
||||||
HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
|
HasTimestamp: readCM.HasTimestamp && s.sockOptTimestamp,
|
||||||
Timestamp: s.readCM.Timestamp,
|
Timestamp: readCM.Timestamp,
|
||||||
HasInq: s.readCM.HasInq,
|
HasInq: readCM.HasInq,
|
||||||
Inq: s.readCM.Inq,
|
Inq: readCM.Inq,
|
||||||
HasTOS: s.readCM.HasTOS,
|
HasTOS: readCM.HasTOS,
|
||||||
TOS: s.readCM.TOS,
|
TOS: readCM.TOS,
|
||||||
HasTClass: s.readCM.HasTClass,
|
HasTClass: readCM.HasTClass,
|
||||||
TClass: s.readCM.TClass,
|
TClass: readCM.TClass,
|
||||||
HasIPPacketInfo: s.readCM.HasIPPacketInfo,
|
HasIPPacketInfo: readCM.HasIPPacketInfo,
|
||||||
PacketInfo: s.readCM.PacketInfo,
|
PacketInfo: readCM.PacketInfo,
|
||||||
OriginalDstAddress: s.readCM.OriginalDstAddress,
|
OriginalDstAddress: readCM.OriginalDstAddress,
|
||||||
SockErr: s.readCM.SockErr,
|
SockErr: readCM.SockErr,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -2743,11 +2674,11 @@ func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
|
||||||
// successfully writing packet data out to userspace.
|
// successfully writing packet data out to userspace.
|
||||||
//
|
//
|
||||||
// Precondition: s.readMu must be locked.
|
// Precondition: s.readMu must be locked.
|
||||||
func (s *socketOpsCommon) updateTimestamp() {
|
func (s *socketOpsCommon) updateTimestamp(cm tcpip.ControlMessages) {
|
||||||
// Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled.
|
// Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled.
|
||||||
if !s.sockOptTimestamp {
|
if !s.sockOptTimestamp {
|
||||||
s.timestampValid = true
|
s.timestampValid = true
|
||||||
s.timestampNS = s.readCM.Timestamp
|
s.timestampNS = cm.Timestamp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue