Improve SendMsg performance.
SendMsg before this change would copy all the data over into a new slice even if the underlying socket could only accept a small amount of data. This is really inefficient with non-blocking sockets and under high throughput, where large writes could get ErrWouldBlock, or if there was, say, a timeout associated with the sendmsg() syscall. With this change we delay copying bytes in until they are needed and only copy what can potentially be sent/held in the socket buffer, reducing the need to repeatedly copy data over. Also a minor fix to change state to FIN-WAIT-1 when shutdown(..., SHUT_WR) is called instead of when we transmit the actual FIN. Otherwise the socket could remain in CONNECTED state even though the user has called shutdown() on the socket. Updates #627 PiperOrigin-RevId: 263430505
This commit is contained in:
parent
cee044c2ab
commit
570fb1db6b
|
@ -429,6 +429,11 @@ func (i *ioSequencePayload) Size() int {
|
|||
return int(i.src.NumBytes())
|
||||
}
|
||||
|
||||
// DropFirst drops the first n bytes from underlying src.
|
||||
func (i *ioSequencePayload) DropFirst(n int) {
|
||||
i.src = i.src.DropFirst(int(n))
|
||||
}
|
||||
|
||||
// Write implements fs.FileOperations.Write.
|
||||
func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
|
||||
f := &ioSequencePayload{ctx: ctx, src: src}
|
||||
|
@ -2026,28 +2031,22 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
|
|||
addr = &addrBuf
|
||||
}
|
||||
|
||||
v := buffer.NewView(int(src.NumBytes()))
|
||||
|
||||
// Copy all the data into the buffer.
|
||||
if _, err := src.CopyIn(t, v); err != nil {
|
||||
return 0, syserr.FromError(err)
|
||||
}
|
||||
|
||||
opts := tcpip.WriteOptions{
|
||||
To: addr,
|
||||
More: flags&linux.MSG_MORE != 0,
|
||||
EndOfRecord: flags&linux.MSG_EOR != 0,
|
||||
}
|
||||
|
||||
n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
|
||||
v := &ioSequencePayload{t, src}
|
||||
n, resCh, err := s.Endpoint.Write(v, opts)
|
||||
if resCh != nil {
|
||||
if err := t.Block(resCh); err != nil {
|
||||
return 0, syserr.FromError(err)
|
||||
}
|
||||
n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
|
||||
n, _, err = s.Endpoint.Write(v, opts)
|
||||
}
|
||||
dontWait := flags&linux.MSG_DONTWAIT != 0
|
||||
if err == nil && (n >= uintptr(len(v)) || dontWait) {
|
||||
if err == nil && (n >= uintptr(v.Size()) || dontWait) {
|
||||
// Complete write.
|
||||
return int(n), nil
|
||||
}
|
||||
|
@ -2061,18 +2060,18 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
|
|||
s.EventRegister(&e, waiter.EventOut)
|
||||
defer s.EventUnregister(&e)
|
||||
|
||||
v.TrimFront(int(n))
|
||||
v.DropFirst(int(n))
|
||||
total := n
|
||||
for {
|
||||
n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
|
||||
v.TrimFront(int(n))
|
||||
n, _, err = s.Endpoint.Write(v, opts)
|
||||
v.DropFirst(int(n))
|
||||
total += n
|
||||
|
||||
if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
|
||||
return 0, syserr.TranslateNetstackError(err)
|
||||
}
|
||||
|
||||
if err == nil && len(v) == 0 || err != nil && err != tcpip.ErrWouldBlock {
|
||||
if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock {
|
||||
return int(total), nil
|
||||
}
|
||||
|
||||
|
|
|
@ -878,6 +878,34 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
|
|||
return v, nil
|
||||
}
|
||||
|
||||
// isEndpointWritableLocked checks if a given endpoint is writable
|
||||
// and also returns the number of bytes that can be written at this
|
||||
// moment. If the endpoint is not writable then it returns an error
|
||||
// indicating the reason why it's not writable.
|
||||
// Caller must hold e.mu (at least for reading) and e.sndBufMu.
|
||||
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
|
||||
// The endpoint cannot be written to if it's not connected.
|
||||
if !e.state.connected() {
|
||||
switch e.state {
|
||||
case StateError:
|
||||
return 0, e.hardError
|
||||
default:
|
||||
return 0, tcpip.ErrClosedForSend
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the connection has already been closed for sends.
|
||||
if e.sndClosed {
|
||||
return 0, tcpip.ErrClosedForSend
|
||||
}
|
||||
|
||||
avail := e.sndBufSize - e.sndBufUsed
|
||||
if avail <= 0 {
|
||||
return 0, tcpip.ErrWouldBlock
|
||||
}
|
||||
return avail, nil
|
||||
}
|
||||
|
||||
// Write writes data to the endpoint's peer.
|
||||
func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
|
||||
// Linux completely ignores any address passed to sendto(2) for TCP sockets
|
||||
|
@ -885,53 +913,60 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
|
|||
// and opts.EndOfRecord are also ignored.
|
||||
|
||||
e.mu.RLock()
|
||||
defer e.mu.RUnlock()
|
||||
e.sndBufMu.Lock()
|
||||
|
||||
// The endpoint cannot be written to if it's not connected.
|
||||
if !e.state.connected() {
|
||||
switch e.state {
|
||||
case StateError:
|
||||
return 0, nil, e.hardError
|
||||
default:
|
||||
return 0, nil, tcpip.ErrClosedForSend
|
||||
}
|
||||
avail, err := e.isEndpointWritableLocked()
|
||||
if err != nil {
|
||||
e.sndBufMu.Unlock()
|
||||
e.mu.RUnlock()
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
e.sndBufMu.Unlock()
|
||||
e.mu.RUnlock()
|
||||
|
||||
// Nothing to do if the buffer is empty.
|
||||
if p.Size() == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
|
||||
e.sndBufMu.Lock()
|
||||
|
||||
// Check if the connection has already been closed for sends.
|
||||
if e.sndClosed {
|
||||
e.sndBufMu.Unlock()
|
||||
return 0, nil, tcpip.ErrClosedForSend
|
||||
}
|
||||
|
||||
// Check against the limit.
|
||||
avail := e.sndBufSize - e.sndBufUsed
|
||||
if avail <= 0 {
|
||||
e.sndBufMu.Unlock()
|
||||
return 0, nil, tcpip.ErrWouldBlock
|
||||
}
|
||||
|
||||
// Copy in memory without holding sndBufMu so that worker goroutine can
|
||||
// make progress independent of this operation.
|
||||
v, perr := p.Get(avail)
|
||||
if perr != nil {
|
||||
e.sndBufMu.Unlock()
|
||||
return 0, nil, perr
|
||||
}
|
||||
|
||||
l := len(v)
|
||||
s := newSegmentFromView(&e.route, e.id, v)
|
||||
e.mu.RLock()
|
||||
e.sndBufMu.Lock()
|
||||
|
||||
// Because we released the lock before copying, check state again
|
||||
// to make sure the endpoint is still in a valid state for a
|
||||
// write.
|
||||
avail, err = e.isEndpointWritableLocked()
|
||||
if err != nil {
|
||||
e.sndBufMu.Unlock()
|
||||
e.mu.RUnlock()
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
// Discard any excess data copied in due to avail being reduced due to a
|
||||
// simultaneous write call to the socket.
|
||||
if avail < len(v) {
|
||||
v = v[:avail]
|
||||
}
|
||||
|
||||
// Add data to the send queue.
|
||||
l := len(v)
|
||||
s := newSegmentFromView(&e.route, e.id, v)
|
||||
e.sndBufUsed += l
|
||||
e.sndBufInQueue += seqnum.Size(l)
|
||||
e.sndQueue.PushBack(s)
|
||||
|
||||
e.sndBufMu.Unlock()
|
||||
// Release the endpoint lock to prevent deadlocks due to lock
|
||||
// order inversion when acquiring workMu.
|
||||
e.mu.RUnlock()
|
||||
|
||||
if e.workMu.TryLock() {
|
||||
// Do the work inline.
|
||||
|
|
|
@ -1252,10 +1252,14 @@ cc_binary(
|
|||
srcs = ["partial_bad_buffer.cc"],
|
||||
linkstatic = 1,
|
||||
deps = [
|
||||
"//test/syscalls/linux:socket_test_util",
|
||||
"//test/util:file_descriptor",
|
||||
"//test/util:fs_util",
|
||||
"//test/util:posix_error",
|
||||
"//test/util:temp_path",
|
||||
"//test/util:test_main",
|
||||
"//test/util:test_util",
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -14,13 +14,20 @@
|
|||
|
||||
#include <errno.h>
|
||||
#include <fcntl.h>
|
||||
#include <netinet/in.h>
|
||||
#include <netinet/tcp.h>
|
||||
#include <sys/mman.h>
|
||||
#include <sys/socket.h>
|
||||
#include <sys/syscall.h>
|
||||
#include <sys/uio.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "absl/time/clock.h"
|
||||
#include "test/syscalls/linux/socket_test_util.h"
|
||||
#include "test/util/file_descriptor.h"
|
||||
#include "test/util/fs_util.h"
|
||||
#include "test/util/posix_error.h"
|
||||
#include "test/util/temp_path.h"
|
||||
#include "test/util/test_util.h"
|
||||
|
||||
|
@ -299,6 +306,109 @@ TEST_F(PartialBadBufferTest, WriteEfaultIsntPartial) {
|
|||
EXPECT_STREQ(buf, kMessage);
|
||||
}
|
||||
|
||||
PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) {
|
||||
struct sockaddr_storage addr;
|
||||
memset(&addr, 0, sizeof(addr));
|
||||
addr.ss_family = family;
|
||||
switch (family) {
|
||||
case AF_INET:
|
||||
reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr =
|
||||
htonl(INADDR_LOOPBACK);
|
||||
break;
|
||||
case AF_INET6:
|
||||
reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr =
|
||||
in6addr_loopback;
|
||||
break;
|
||||
default:
|
||||
return PosixError(EINVAL,
|
||||
absl::StrCat("unknown socket family: ", family));
|
||||
}
|
||||
return addr;
|
||||
}
|
||||
|
||||
// SendMsgTCP verifies that calling sendmsg with a bad address returns an
|
||||
// EFAULT. It also verifies that passing a buffer which is made up of 2
|
||||
// pages, one valid and one guard page, succeeds as long as the write is
|
||||
// for exactly the size of 1 page.
|
||||
TEST_F(PartialBadBufferTest, SendMsgTCP) {
|
||||
auto listen_socket =
|
||||
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
|
||||
|
||||
// Initialize address to the loopback one.
|
||||
sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET));
|
||||
socklen_t addrlen = sizeof(addr);
|
||||
|
||||
// Bind to some port then start listening.
|
||||
ASSERT_THAT(bind(listen_socket.get(),
|
||||
reinterpret_cast<struct sockaddr*>(&addr), addrlen),
|
||||
SyscallSucceeds());
|
||||
|
||||
ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds());
|
||||
|
||||
// Get the address we're listening on, then connect to it. We need to do this
|
||||
// because we're allowing the stack to pick a port for us.
|
||||
ASSERT_THAT(getsockname(listen_socket.get(),
|
||||
reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
|
||||
SyscallSucceeds());
|
||||
|
||||
auto send_socket =
|
||||
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
|
||||
|
||||
ASSERT_THAT(
|
||||
RetryEINTR(connect)(send_socket.get(),
|
||||
reinterpret_cast<struct sockaddr*>(&addr), addrlen),
|
||||
SyscallSucceeds());
|
||||
|
||||
// Accept the connection.
|
||||
auto recv_socket =
|
||||
ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr));
|
||||
|
||||
// TODO(gvisor.dev/issue/674): Update this once Netstack matches linux
|
||||
// behaviour on a setsockopt of SO_RCVBUF/SO_SNDBUF.
|
||||
//
|
||||
// Set SO_SNDBUF for socket to exactly kPageSize+1.
|
||||
//
|
||||
// gVisor does not double the value passed in SO_SNDBUF like linux does so we
|
||||
// just increase it by 1 byte here for gVisor so that we can test writing 1
|
||||
// byte past the valid page and check that it triggers an EFAULT
|
||||
// correctly. Otherwise in gVisor the sendmsg call will just return with no
|
||||
// error with kPageSize bytes written successfully.
|
||||
const uint32_t buf_size = kPageSize + 1;
|
||||
ASSERT_THAT(setsockopt(send_socket.get(), SOL_SOCKET, SO_SNDBUF, &buf_size,
|
||||
sizeof(buf_size)),
|
||||
SyscallSucceedsWithValue(0));
|
||||
|
||||
struct msghdr hdr = {};
|
||||
struct iovec iov = {};
|
||||
iov.iov_base = bad_buffer_;
|
||||
iov.iov_len = kPageSize;
|
||||
hdr.msg_iov = &iov;
|
||||
hdr.msg_iovlen = 1;
|
||||
|
||||
ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0),
|
||||
SyscallFailsWithErrno(EFAULT));
|
||||
|
||||
// Now assert that writing kPageSize from addr_ succeeds.
|
||||
iov.iov_base = addr_;
|
||||
ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0),
|
||||
SyscallSucceedsWithValue(kPageSize));
|
||||
// Read all the data out so that we drain the socket SND_BUF on the sender.
|
||||
std::vector<char> buffer(kPageSize);
|
||||
ASSERT_THAT(RetryEINTR(read)(recv_socket.get(), buffer.data(), kPageSize),
|
||||
SyscallSucceedsWithValue(kPageSize));
|
||||
|
||||
// Sleep for a short while to ensure that we have time to process the
|
||||
// ACKs. This is not strictly required unless running under gotsan which is a
|
||||
// lot slower and can result in the next write to write only 1 byte instead of
|
||||
// our intended kPageSize + 1.
|
||||
absl::SleepFor(absl::Milliseconds(50));
|
||||
|
||||
// Now assert that writing > kPageSize results in EFAULT.
|
||||
iov.iov_len = kPageSize + 1;
|
||||
ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0),
|
||||
SyscallFailsWithErrno(EFAULT));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace testing
|
||||
|
|
Loading…
Reference in New Issue