From 570fb1db6b4e01be37386a379fea4d63e5a3cdc2 Mon Sep 17 00:00:00 2001
From: Bhasker Hariharan
Date: Wed, 14 Aug 2019 14:33:11 -0700
Subject: [PATCH] Improve SendMsg performance.

Before this change, SendMsg would copy all the data into a new slice
even if the underlying socket could only accept a small amount of it.
This is especially inefficient with non-blocking sockets under high
throughput, where large writes can get ErrWouldBlock, or when a
timeout is associated with the sendmsg() syscall. With this change we
delay copying bytes in until they are needed and only copy what can
potentially be sent/held in the socket buffer, reducing the need to
repeatedly copy data.

Also includes a minor fix to transition to state FIN-WAIT-1 when
shutdown(..., SHUT_WR) is called, instead of when we transmit the
actual FIN. Otherwise the socket could remain in the CONNECTED state
even though the user has called shutdown() on it.

Updates #627

PiperOrigin-RevId: 263430505
---
 pkg/sentry/socket/epsocket/epsocket.go    |  27 +++---
 pkg/tcpip/transport/tcp/endpoint.go       |  89 +++++++++++------
 test/syscalls/linux/BUILD                 |   4 +
 test/syscalls/linux/partial_bad_buffer.cc | 110 ++++++++++++++++++++++
 4 files changed, 189 insertions(+), 41 deletions(-)

diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 8cb5c823f..0f2cd05fc 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -429,6 +429,11 @@ func (i *ioSequencePayload) Size() int {
 	return int(i.src.NumBytes())
 }
 
+// DropFirst drops the first n bytes from the underlying src.
+func (i *ioSequencePayload) DropFirst(n int) {
+	i.src = i.src.DropFirst(int(n))
+}
+
 // Write implements fs.FileOperations.Write.
 func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
 	f := &ioSequencePayload{ctx: ctx, src: src}
@@ -2026,28 +2031,22 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
 		addr = &addrBuf
 	}
 
-	v := buffer.NewView(int(src.NumBytes()))
-
-	// Copy all the data into the buffer.
-	if _, err := src.CopyIn(t, v); err != nil {
-		return 0, syserr.FromError(err)
-	}
-
 	opts := tcpip.WriteOptions{
 		To:          addr,
 		More:        flags&linux.MSG_MORE != 0,
 		EndOfRecord: flags&linux.MSG_EOR != 0,
 	}
 
-	n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+	v := &ioSequencePayload{t, src}
+	n, resCh, err := s.Endpoint.Write(v, opts)
 	if resCh != nil {
 		if err := t.Block(resCh); err != nil {
 			return 0, syserr.FromError(err)
 		}
-		n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+		n, _, err = s.Endpoint.Write(v, opts)
 	}
 	dontWait := flags&linux.MSG_DONTWAIT != 0
-	if err == nil && (n >= uintptr(len(v)) || dontWait) {
+	if err == nil && (n >= uintptr(v.Size()) || dontWait) {
 		// Complete write.
		return int(n), nil
	}
@@ -2061,18 +2060,18 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
 	s.EventRegister(&e, waiter.EventOut)
 	defer s.EventUnregister(&e)
 
-	v.TrimFront(int(n))
+	v.DropFirst(int(n))
 	total := n
 
 	for {
-		n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
-		v.TrimFront(int(n))
+		n, _, err = s.Endpoint.Write(v, opts)
+		v.DropFirst(int(n))
 		total += n
 
 		if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
 			return 0, syserr.TranslateNetstackError(err)
 		}
 
-		if err == nil && len(v) == 0 || err != nil && err != tcpip.ErrWouldBlock {
+		if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock {
 			return int(total), nil
 		}
 
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index e67169111..7c42a830a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -878,6 +878,34 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
 	return v, nil
 }
 
+// isEndpointWritableLocked checks if a given endpoint is writable
+// and also returns the number of bytes that can be written at this
+// moment. If the endpoint is not writable, it returns an error
+// indicating why it's not writable.
+// Caller must hold e.mu and e.sndBufMu.
+func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
+	// The endpoint cannot be written to if it's not connected.
+	if !e.state.connected() {
+		switch e.state {
+		case StateError:
+			return 0, e.hardError
+		default:
+			return 0, tcpip.ErrClosedForSend
+		}
+	}
+
+	// Check if the connection has already been closed for sends.
+	if e.sndClosed {
+		return 0, tcpip.ErrClosedForSend
+	}
+
+	avail := e.sndBufSize - e.sndBufUsed
+	if avail <= 0 {
+		return 0, tcpip.ErrWouldBlock
+	}
+	return avail, nil
+}
+
 // Write writes data to the endpoint's peer.
 func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
 	// Linux completely ignores any address passed to sendto(2) for TCP sockets
@@ -885,53 +913,60 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c
 	// and opts.EndOfRecord are also ignored.
 
 	e.mu.RLock()
-	defer e.mu.RUnlock()
+	e.sndBufMu.Lock()
 
-	// The endpoint cannot be written to if it's not connected.
-	if !e.state.connected() {
-		switch e.state {
-		case StateError:
-			return 0, nil, e.hardError
-		default:
-			return 0, nil, tcpip.ErrClosedForSend
-		}
+	avail, err := e.isEndpointWritableLocked()
+	if err != nil {
+		e.sndBufMu.Unlock()
+		e.mu.RUnlock()
+		return 0, nil, err
 	}
 
+	e.sndBufMu.Unlock()
+	e.mu.RUnlock()
+
 	// Nothing to do if the buffer is empty.
 	if p.Size() == 0 {
 		return 0, nil, nil
 	}
 
-	e.sndBufMu.Lock()
-
-	// Check if the connection has already been closed for sends.
-	if e.sndClosed {
-		e.sndBufMu.Unlock()
-		return 0, nil, tcpip.ErrClosedForSend
-	}
-
-	// Check against the limit.
-	avail := e.sndBufSize - e.sndBufUsed
-	if avail <= 0 {
-		e.sndBufMu.Unlock()
-		return 0, nil, tcpip.ErrWouldBlock
-	}
-
+	// Copy in memory without holding sndBufMu so that the worker goroutine
+	// can make progress independently of this operation.
 	v, perr := p.Get(avail)
 	if perr != nil {
-		e.sndBufMu.Unlock()
 		return 0, nil, perr
 	}
 
-	l := len(v)
-	s := newSegmentFromView(&e.route, e.id, v)
+	e.mu.RLock()
+	e.sndBufMu.Lock()
+
+	// Because we released the lock before copying, check state again
+	// to make sure the endpoint is still in a valid state for a
+	// write.
+	avail, err = e.isEndpointWritableLocked()
+	if err != nil {
+		e.sndBufMu.Unlock()
+		e.mu.RUnlock()
+		return 0, nil, err
+	}
+
+	// Discard any excess data copied in if avail was reduced by a
+	// simultaneous write call to the socket.
+	if avail < len(v) {
+		v = v[:avail]
+	}
 
 	// Add data to the send queue.
+	l := len(v)
+	s := newSegmentFromView(&e.route, e.id, v)
 	e.sndBufUsed += l
 	e.sndBufInQueue += seqnum.Size(l)
 	e.sndQueue.PushBack(s)
 	e.sndBufMu.Unlock()
+	// Release the endpoint lock to prevent deadlocks due to lock
+	// order inversion when acquiring workMu.
+	e.mu.RUnlock()
 
 	if e.workMu.TryLock() {
 		// Do the work inline.
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 16666e772..d28ce4ba1 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -1252,10 +1252,14 @@ cc_binary(
     srcs = ["partial_bad_buffer.cc"],
     linkstatic = 1,
     deps = [
+        "//test/syscalls/linux:socket_test_util",
+        "//test/util:file_descriptor",
         "//test/util:fs_util",
+        "//test/util:posix_error",
         "//test/util:temp_path",
         "//test/util:test_main",
         "//test/util:test_util",
+        "@com_google_absl//absl/time",
        "@com_google_googletest//:gtest",
     ],
 )
diff --git a/test/syscalls/linux/partial_bad_buffer.cc b/test/syscalls/linux/partial_bad_buffer.cc
index 83b1ad4e4..33822ee57 100644
--- a/test/syscalls/linux/partial_bad_buffer.cc
+++ b/test/syscalls/linux/partial_bad_buffer.cc
@@ -14,13 +14,20 @@
 
 #include <errno.h>
 #include <fcntl.h>
+#include <netinet/in.h>
+#include <netinet/tcp.h>
 #include <sys/socket.h>
+#include <sys/stat.h>
 #include <sys/syscall.h>
 #include <sys/types.h>
 #include <unistd.h>
 
 #include "gtest/gtest.h"
+#include "absl/time/clock.h"
+#include "test/syscalls/linux/socket_test_util.h"
+#include "test/util/file_descriptor.h"
 #include "test/util/fs_util.h"
+#include "test/util/posix_error.h"
 #include "test/util/temp_path.h"
 #include "test/util/test_util.h"
@@ -299,6 +306,109 @@ TEST_F(PartialBadBufferTest, WriteEfaultIsntPartial) {
   EXPECT_STREQ(buf, kMessage);
 }
 
+PosixErrorOr<sockaddr_storage> InetLoopbackAddr(int family) {
+  struct sockaddr_storage addr;
+  memset(&addr, 0, sizeof(addr));
+  addr.ss_family = family;
+  switch (family) {
+    case AF_INET:
+      reinterpret_cast<struct sockaddr_in*>(&addr)->sin_addr.s_addr =
+          htonl(INADDR_LOOPBACK);
+      break;
+    case AF_INET6:
+      reinterpret_cast<struct sockaddr_in6*>(&addr)->sin6_addr =
+          in6addr_loopback;
+      break;
+    default:
+      return PosixError(EINVAL,
+                        absl::StrCat("unknown socket family: ", family));
+  }
+  return addr;
+}
+
+// SendMsgTCP verifies that calling sendmsg with a bad address returns
+// EFAULT. It also verifies that passing a buffer made up of 2 pages, one
+// valid and one guard page, succeeds as long as the write is for exactly
+// the size of 1 page.
+TEST_F(PartialBadBufferTest, SendMsgTCP) {
+  auto listen_socket =
+      ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
+
+  // Initialize address to the loopback one.
+  sockaddr_storage addr = ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(AF_INET));
+  socklen_t addrlen = sizeof(addr);
+
+  // Bind to some port then start listening.
+  ASSERT_THAT(bind(listen_socket.get(),
+                   reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+              SyscallSucceeds());
+
+  ASSERT_THAT(listen(listen_socket.get(), SOMAXCONN), SyscallSucceeds());
+
+  // Get the address we're listening on, then connect to it. We need to do this
+  // because we're allowing the stack to pick a port for us.
+  ASSERT_THAT(getsockname(listen_socket.get(),
+                          reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
+              SyscallSucceeds());
+
+  auto send_socket =
+      ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP));
+
+  ASSERT_THAT(
+      RetryEINTR(connect)(send_socket.get(),
+                          reinterpret_cast<struct sockaddr*>(&addr), addrlen),
+      SyscallSucceeds());
+
+  // Accept the connection.
+  auto recv_socket =
+      ASSERT_NO_ERRNO_AND_VALUE(Accept(listen_socket.get(), nullptr, nullptr));
+
+  // TODO(gvisor.dev/issue/674): Update this once Netstack matches Linux
+  // behaviour on a setsockopt of SO_RCVBUF/SO_SNDBUF.
+  //
+  // Set SO_SNDBUF for the socket to exactly kPageSize+1.
+  //
+  // gVisor does not double the value passed in SO_SNDBUF like Linux does, so
+  // we just increase it by 1 byte here for gVisor so that we can test writing
+  // 1 byte past the valid page and check that it triggers EFAULT
+  // correctly. Otherwise in gVisor the sendmsg call will just return with no
+  // error, with kPageSize bytes written successfully.
+  const uint32_t buf_size = kPageSize + 1;
+  ASSERT_THAT(setsockopt(send_socket.get(), SOL_SOCKET, SO_SNDBUF, &buf_size,
+                         sizeof(buf_size)),
+              SyscallSucceedsWithValue(0));
+
+  struct msghdr hdr = {};
+  struct iovec iov = {};
+  iov.iov_base = bad_buffer_;
+  iov.iov_len = kPageSize;
+  hdr.msg_iov = &iov;
+  hdr.msg_iovlen = 1;
+
+  ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0),
+              SyscallFailsWithErrno(EFAULT));
+
+  // Now assert that writing kPageSize from addr_ succeeds.
+  iov.iov_base = addr_;
+  ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0),
+              SyscallSucceedsWithValue(kPageSize));
+  // Read all the data out so that we drain the socket's SND_BUF on the
+  // sender.
+  std::vector<char> buffer(kPageSize);
+  ASSERT_THAT(RetryEINTR(read)(recv_socket.get(), buffer.data(), kPageSize),
+              SyscallSucceedsWithValue(kPageSize));
+
+  // Sleep for a short while to ensure that we have time to process the
+  // ACKs. This is not strictly required unless running under gotsan, which
+  // is a lot slower and can result in the next write writing only 1 byte
+  // instead of our intended kPageSize + 1.
+  absl::SleepFor(absl::Milliseconds(50));
+
+  // Now assert that writing > kPageSize results in EFAULT.
+  iov.iov_len = kPageSize + 1;
+  ASSERT_THAT(RetryEINTR(sendmsg)(send_socket.get(), &hdr, 0),
+              SyscallFailsWithErrno(EFAULT));
+}
+
 }  // namespace
 
 }  // namespace testing
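
Editor's note on the epsocket side of the change: the heart of it is the
lazily-copied payload. Rather than copying the whole IOSequence into a view
up front, SendMsg now hands the endpoint an ioSequencePayload whose Get
copies only what the endpoint asks for, and whose DropFirst advances past
bytes already accepted. Here is a minimal standalone sketch of that shape,
using a plain byte slice in place of usermem.IOSequence; the type and names
below are illustrative stand-ins, not the real gVisor types.

package main

import "fmt"

// lazyPayload stands in for ioSequencePayload: it keeps a reference to
// the source bytes and copies only on demand, rather than all up front.
type lazyPayload struct {
	src []byte // stands in for the usermem.IOSequence
}

// Get copies out at most size bytes, mirroring tcpip.Payload.Get: the
// endpoint asks for exactly as much as its send buffer can hold.
func (p *lazyPayload) Get(size int) ([]byte, error) {
	if size > len(p.src) {
		size = len(p.src)
	}
	v := make([]byte, size)
	copy(v, p.src[:size])
	return v, nil
}

// DropFirst advances past bytes the endpoint already accepted, so a
// retry after ErrWouldBlock resumes where the last attempt stopped.
func (p *lazyPayload) DropFirst(n int) { p.src = p.src[n:] }

// Size reports how many bytes remain to be written.
func (p *lazyPayload) Size() int { return len(p.src) }

func main() {
	p := &lazyPayload{src: []byte("0123456789")}
	v, _ := p.Get(4) // the socket buffer can take 4 bytes right now
	fmt.Printf("queued %q\n", v)
	p.DropFirst(len(v)) // only the sent prefix is consumed
	fmt.Println("remaining:", p.Size())
}

With this shape, the SendMsg retry loop becomes Write, DropFirst(n), repeat:
each iteration copies only the bytes the socket buffer actually accepted,
instead of copying the entire message once per attempt.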
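The endpoint.Write rewrite follows a check/copy/recheck pattern: compute
avail while holding e.mu and e.sndBufMu, release both locks for the
expensive copy from user memory, then reacquire them, re-validate
writability, and trim the copied view if a concurrent writer shrank the
available space. Below is a minimal sketch of the same pattern over an
ordinary mutex-guarded buffer; the names are assumptions for illustration,
not the real endpoint code.

package main

import (
	"errors"
	"fmt"
	"sync"
)

var errWouldBlock = errors.New("would block")

// sendBuf is a stand-in for the endpoint's send queue: a fixed-capacity
// buffer guarded by a mutex.
type sendBuf struct {
	mu    sync.Mutex
	data  []byte
	limit int
}

// write mirrors the two-phase structure of the new endpoint.Write:
// check available space, copy with the lock released, then recheck and
// trim before committing.
func (b *sendBuf) write(src []byte) (int, error) {
	b.mu.Lock()
	avail := b.limit - len(b.data)
	b.mu.Unlock()
	if avail <= 0 {
		return 0, errWouldBlock
	}

	// The expensive copy (p.Get(avail) in the real code, a copy from
	// user memory) runs without the lock so other writers and the
	// protocol worker can make progress.
	if avail > len(src) {
		avail = len(src)
	}
	v := append([]byte(nil), src[:avail]...)

	b.mu.Lock()
	defer b.mu.Unlock()
	// Re-validate: a concurrent writer may have consumed space while we
	// were copying, so discard any excess rather than overfilling.
	if now := b.limit - len(b.data); now < len(v) {
		if now <= 0 {
			return 0, errWouldBlock
		}
		v = v[:now]
	}
	b.data = append(b.data, v...)
	return len(v), nil
}

func main() {
	b := &sendBuf{limit: 8}
	n, err := b.write([]byte("hello, world"))
	fmt.Println(n, err) // 8 <nil>: only what fits was copied and queued
}

The trade-off is that a concurrent writer can force some copied bytes to be
discarded, but the common case no longer holds sndBufMu across the
user-memory copy, so the protocol worker can drain the queue in parallel.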