Fix SO_ERROR behavior for TCP in gVisor.

Fixes the behaviour of SO_ERROR for TCP sockets. In Linux it returns
sk->sk_err and, if sk->sk_err is 0, it returns sk->sk_soft_err. In gVisor TCP,
endpoint.HardError is the equivalent of sk->sk_err and endpoint.LastError
holds soft errors. This change brings gVisor into alignment with Linux such
that both hard and soft errors are cleared when they are retrieved by calling
getsockopt(.. SO_ERROR) on a socket.

Fixes #3812

PiperOrigin-RevId: 342868552
This commit is contained in:
Bhasker Hariharan 2020-11-17 08:30:31 -08:00 committed by gVisor bot
parent 938aabeecb
commit fb9a649f39
9 changed files with 241 additions and 59 deletions

View File

@ -2686,7 +2686,7 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ
// Always do at least one fetchReadView, even if the number of bytes to
// read is 0.
err = s.fetchReadView()
if err != nil {
if err != nil || len(s.readView) == 0 {
break
}
if dst.NumBytes() == 0 {
@ -2709,15 +2709,20 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ
}
copied += n
s.readView.TrimFront(n)
if len(s.readView) == 0 {
atomic.StoreUint32(&s.readViewHasData, 0)
}
dst = dst.DropFirst(n)
if e != nil {
err = syserr.FromError(e)
break
}
// If we are done reading requested data then stop.
if dst.NumBytes() == 0 {
break
}
}
if len(s.readView) == 0 {
atomic.StoreUint32(&s.readViewHasData, 0)
}
// If we managed to copy something, we must deliver it.

View File

@ -26,6 +26,7 @@ go_test(
"//pkg/tcpip",
"//pkg/tcpip/header",
"//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",

View File

@ -28,6 +28,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@ -65,7 +66,7 @@ func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
})
if err := s.CreateNIC(NICID, loopback.New()); err != nil {
if err := s.CreateNIC(NICID, sniffer.New(loopback.New())); err != nil {
return nil, err
}

View File

@ -496,7 +496,7 @@ func (h *handshake) resolveRoute() *tcpip.Error {
h.ep.mu.Lock()
}
if n&notifyError != 0 {
return h.ep.LastError()
return h.ep.lastErrorLocked()
}
}
@ -575,7 +575,6 @@ func (h *handshake) complete() *tcpip.Error {
return err
}
defer timer.stop()
for h.state != handshakeCompleted {
// Unlock before blocking, and reacquire again afterwards (h.ep.mu is held
// throughout handshake processing).
@ -631,9 +630,8 @@ func (h *handshake) complete() *tcpip.Error {
h.ep.mu.Lock()
}
if n&notifyError != 0 {
return h.ep.LastError()
return h.ep.lastErrorLocked()
}
case wakerForNewSegment:
if err := h.processSegments(); err != nil {
return err
@ -1002,7 +1000,7 @@ func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
// Only send a reset if the connection is being aborted for a reason
// other than receiving a reset.
e.setEndpointState(StateError)
e.HardError = err
e.hardError = err
if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
// The exact sequence number to be used for the RST is the same as the
// one used by Linux. We need to handle the case of window being shrunk
@ -1141,7 +1139,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
// delete the TCB, and return.
case StateCloseWait:
e.transitionToStateCloseLocked()
e.HardError = tcpip.ErrAborted
e.hardError = tcpip.ErrAborted
e.notifyProtocolGoroutine(notifyTickleWorker)
return false, nil
default:
@ -1353,7 +1351,6 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
epilogue := func() {
// e.mu is expected to be held upon entering this section.
if e.snd != nil {
e.snd.resendTimer.cleanup()
}
@ -1383,7 +1380,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{
e.lastErrorMu.Unlock()
e.setEndpointState(StateError)
e.HardError = err
e.hardError = err
e.workerCleanup = true
// Lock released below.

View File

@ -315,11 +315,6 @@ func (*Stats) IsEndpointStats() {}
// +stateify savable
type EndpointInfo struct {
stack.TransportEndpointInfo
// HardError is meaningful only when state is stateError. It stores the
// error to be returned when read/write syscalls are called and the
// endpoint is in this state. HardError is protected by endpoint mu.
HardError *tcpip.Error `state:".(string)"`
}
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
@ -386,6 +381,11 @@ type endpoint struct {
waiterQueue *waiter.Queue `state:"wait"`
uniqueID uint64
// hardError is meaningful only when state is stateError. It stores the
// error to be returned when read/write syscalls are called and the
// endpoint is in this state. hardError is protected by endpoint mu.
hardError *tcpip.Error `state:".(string)"`
// lastError represents the last error that the endpoint reported;
// access to it is protected by the following mutex.
lastErrorMu sync.Mutex `state:"nosave"`
@ -1283,7 +1283,15 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
func (e *endpoint) LastError() *tcpip.Error {
// hardErrorLocked returns the hard error currently stored on the endpoint and
// resets the stored value to nil, so a given hard error is handed out once.
//
// Preconditions: e.mu must be held to call this function.
func (e *endpoint) hardErrorLocked() *tcpip.Error {
	var pending *tcpip.Error
	// Take the stored error and clear the field in a single assignment; the
	// right-hand side is evaluated before either target is written.
	pending, e.hardError = e.hardError, nil
	return pending
}
// Preconditions: e.mu must be held to call this function.
func (e *endpoint) lastErrorLocked() *tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
err := e.lastError
@ -1291,6 +1299,15 @@ func (e *endpoint) LastError() *tcpip.Error {
return err
}
// LastError returns the endpoint's pending error while holding the user lock.
// A stored hard error takes precedence and is cleared by this call; when there
// is none, the result of lastErrorLocked is returned instead.
func (e *endpoint) LastError() *tcpip.Error {
	e.LockUser()
	defer e.UnlockUser()

	hard := e.hardErrorLocked()
	if hard == nil {
		// No hard error is pending, so fall back to the last recorded error.
		return e.lastErrorLocked()
	}
	return hard
}
// Read reads data from the endpoint.
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.LockUser()
@ -1312,9 +1329,8 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
bufUsed := e.rcvBufUsed
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
e.rcvListMu.Unlock()
he := e.HardError
if s == StateError {
return buffer.View{}, tcpip.ControlMessages{}, he
return buffer.View{}, tcpip.ControlMessages{}, e.hardErrorLocked()
}
e.stats.ReadErrors.NotConnected.Increment()
return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
@ -1370,9 +1386,13 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
// indicating the reason why it's not writable.
// Caller must hold e.mu and e.sndBufMu
func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
// The endpoint cannot be written to if it's not connected.
switch s := e.EndpointState(); {
case s == StateError:
return 0, e.HardError
if err := e.hardErrorLocked(); err != nil {
return 0, err
}
return 0, tcpip.ErrClosedForSend
case !s.connecting() && !s.connected():
return 0, tcpip.ErrClosedForSend
case s.connecting():
@ -1486,7 +1506,7 @@ func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Erro
// but has some pending unread data.
if s := e.EndpointState(); !s.connected() && s != StateClose {
if s == StateError {
return 0, tcpip.ControlMessages{}, e.HardError
return 0, tcpip.ControlMessages{}, e.hardErrorLocked()
}
e.stats.ReadErrors.InvalidEndpointState.Increment()
return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
@ -2243,7 +2263,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc
return tcpip.ErrAlreadyConnecting
case StateError:
return e.HardError
if err := e.hardErrorLocked(); err != nil {
return err
}
return tcpip.ErrConnectionAborted
default:
return tcpip.ErrInvalidEndpointState
@ -2417,7 +2440,7 @@ func (e *endpoint) startMainLoop(handshake bool) *tcpip.Error {
e.lastErrorMu.Unlock()
e.setEndpointState(StateError)
e.HardError = err
e.hardError = err
// Call cleanupLocked to free up any reservations.
e.cleanupLocked()

View File

@ -321,21 +321,21 @@ func (e *endpoint) loadRecentTSTime(unix unixTime) {
}
// saveHardError is invoked by stateify.
func (e *EndpointInfo) saveHardError() string {
if e.HardError == nil {
func (e *endpoint) saveHardError() string {
if e.hardError == nil {
return ""
}
return e.HardError.String()
return e.hardError.String()
}
// loadHardError is invoked by stateify.
func (e *EndpointInfo) loadHardError(s string) {
func (e *endpoint) loadHardError(s string) {
if s == "" {
return
}
e.HardError = tcpip.StringToError(s)
e.hardError = tcpip.StringToError(s)
}
// saveMeasureTime is invoked by stateify.

View File

@ -75,9 +75,6 @@ func TestGiveUpConnect(t *testing.T) {
// Wait for ep to become writable.
<-notifyCh
if err := ep.LastError(); err != tcpip.ErrAborted {
t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted)
}
// Call Connect again to retrieve the handshake failure status
// and stats updates.
@ -3198,6 +3195,11 @@ loop:
case tcpip.ErrWouldBlock:
select {
case <-ch:
// Expect the state to be StateError and subsequent Reads to fail with HardError.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
break loop
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for reset to arrive")
}
@ -3207,14 +3209,10 @@ loop:
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
}
// Expect the state to be StateError and subsequent Reads to fail with HardError.
if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
}
if tcp.EndpointState(c.EP.State()) != tcp.StateError {
t.Fatalf("got EP state is not StateError")
}
if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 {
t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got)
}

View File

@ -1185,19 +1185,44 @@ TEST_P(SocketInetLoopbackTest, TCPAcceptAfterReset) {
listen_fd.get(), reinterpret_cast<sockaddr*>(&accept_addr), &addrlen));
ASSERT_EQ(addrlen, listener.addr_len);
// TODO(gvisor.dev/issue/3812): Remove after SO_ERROR is fixed.
if (IsRunningOnGvisor()) {
char buf[10];
ASSERT_THAT(ReadFd(accept_fd.get(), buf, sizeof(buf)),
SyscallFailsWithErrno(ECONNRESET));
} else {
// Wait for accept_fd to process the RST.
const int kTimeout = 10000;
struct pollfd pfd = {
.fd = accept_fd.get(),
.events = POLLIN,
};
ASSERT_THAT(poll(&pfd, 1, kTimeout), SyscallSucceedsWithValue(1));
ASSERT_EQ(pfd.revents, POLLIN | POLLHUP | POLLERR);
{
int err;
socklen_t optlen = sizeof(err);
ASSERT_THAT(
getsockopt(accept_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen),
SyscallSucceeds());
ASSERT_EQ(err, ECONNRESET);
// This should return ECONNRESET as the socket just received a RST packet
// from the peer.
ASSERT_EQ(optlen, sizeof(err));
ASSERT_EQ(err, ECONNRESET);
}
{
int err;
socklen_t optlen = sizeof(err);
ASSERT_THAT(
getsockopt(accept_fd.get(), SOL_SOCKET, SO_ERROR, &err, &optlen),
SyscallSucceeds());
// This should return no error as the previous getsockopt call would have
// cleared the socket error.
ASSERT_EQ(optlen, sizeof(err));
ASSERT_EQ(err, 0);
}
{
sockaddr_storage peer_addr;
socklen_t addrlen = sizeof(peer_addr);
// The socket is not connected anymore and should return ENOTCONN.
ASSERT_THAT(getpeername(accept_fd.get(),
reinterpret_cast<sockaddr*>(&peer_addr), &addrlen),
SyscallFailsWithErrno(ENOTCONN));
}
}

View File

@ -964,37 +964,156 @@ TEST_P(TcpSocketTest, PollAfterShutdown) {
SyscallSucceedsWithValue(1));
}
TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) {
TEST_P(SimpleTcpSocketTest, NonBlockingConnectRetry) {
const FileDescriptor listener =
ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
// Initialize address to the loopback one.
sockaddr_storage addr =
ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
socklen_t addrlen = sizeof(addr);
const FileDescriptor s =
// Bind to some port but don't listen yet.
ASSERT_THAT(
bind(listener.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
SyscallSucceeds());
// Get the address we're bound to, then connect to it. We need to do this
// because we're allowing the stack to pick a port for us.
ASSERT_THAT(getsockname(listener.get(),
reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
SyscallSucceeds());
FileDescriptor connector =
ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP));
// Set the FD to O_NONBLOCK.
int opts;
ASSERT_THAT(opts = fcntl(s.get(), F_GETFL), SyscallSucceeds());
opts |= O_NONBLOCK;
ASSERT_THAT(fcntl(s.get(), F_SETFL, opts), SyscallSucceeds());
// Verify that connect fails.
ASSERT_THAT(
RetryEINTR(connect)(connector.get(),
reinterpret_cast<struct sockaddr*>(&addr), addrlen),
SyscallFailsWithErrno(ECONNREFUSED));
ASSERT_THAT(RetryEINTR(connect)(
// Now start listening
ASSERT_THAT(listen(listener.get(), SOMAXCONN), SyscallSucceeds());
// TODO(gvisor.dev/issue/3828): Issuing connect() again on a socket that
// failed first connect should succeed.
if (IsRunningOnGvisor()) {
ASSERT_THAT(
RetryEINTR(connect)(connector.get(),
reinterpret_cast<struct sockaddr*>(&addr), addrlen),
SyscallFailsWithErrno(ECONNABORTED));
return;
}
// Verify that connect now succeeds.
ASSERT_THAT(
RetryEINTR(connect)(connector.get(),
reinterpret_cast<struct sockaddr*>(&addr), addrlen),
SyscallSucceeds());
// Accept the connection.
const FileDescriptor accepted =
ASSERT_NO_ERRNO_AND_VALUE(Accept(listener.get(), nullptr, nullptr));
}
// nonBlockingConnectNoListener returns a socket on which a connect that is
// expected to fail has been issued.
PosixErrorOr<FileDescriptor> nonBlockingConnectNoListener(const int family,
sockaddr_storage addr,
socklen_t addrlen) {
// We will first create a socket and bind to ensure we bind a port but will
// not call listen on this socket.
// Then we will create a new socket that will connect to the port bound by
// the first socket and that should fail.
constexpr int sock_type = SOCK_STREAM | SOCK_NONBLOCK;
int b_sock;
RETURN_ERROR_IF_SYSCALL_FAIL(b_sock = socket(family, sock_type, IPPROTO_TCP));
FileDescriptor b(b_sock);
EXPECT_THAT(bind(b.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
SyscallSucceeds());
// Get the address bound by the listening socket.
EXPECT_THAT(
getsockname(b.get(), reinterpret_cast<struct sockaddr*>(&addr), &addrlen),
SyscallSucceeds());
// Now create another socket and issue a connect on this one. This connect
// should fail as there is no listener.
int c_sock;
RETURN_ERROR_IF_SYSCALL_FAIL(c_sock = socket(family, sock_type, IPPROTO_TCP));
FileDescriptor s(c_sock);
// Now connect to the bound address and this should fail as nothing
// is listening on the bound address.
EXPECT_THAT(RetryEINTR(connect)(
s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
SyscallFailsWithErrno(EINPROGRESS));
// Now polling on the FD with a timeout should return 0 corresponding to no
// FDs ready.
struct pollfd poll_fd = {s.get(), POLLOUT, 0};
EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 10000),
SyscallSucceedsWithValue(1));
// Wait for the connect to fail.
struct pollfd poll_fd = {s.get(), POLLERR, 0};
EXPECT_THAT(RetryEINTR(poll)(&poll_fd, 1, 1000), SyscallSucceedsWithValue(1));
return std::move(s);
}
// Verifies that after a non-blocking connect to a bound-but-not-listening
// address fails, SO_ERROR reports ECONNREFUSED, a following read() returns 0,
// and a second connect() on the same socket fails with ECONNABORTED.
TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListener) {
  sockaddr_storage addr =
      ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
  socklen_t addrlen = sizeof(addr);

  const FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
      nonBlockingConnectNoListener(GetParam(), addr, addrlen));

  // Fetch the pending socket error left behind by the failed connect.
  int err;
  socklen_t optlen = sizeof(err);
  ASSERT_THAT(getsockopt(s.get(), SOL_SOCKET, SO_ERROR, &err, &optlen),
              SyscallSucceeds());
  ASSERT_EQ(optlen, sizeof(err));
  EXPECT_EQ(err, ECONNREFUSED);

  // With the error already retrieved through SO_ERROR, read() is expected to
  // return 0 rather than fail.
  unsigned char c;
  ASSERT_THAT(read(s.get(), &c, sizeof(c)), SyscallSucceedsWithValue(0));

  // Clear O_NONBLOCK so the retried connect below is a blocking call.
  int opts;
  EXPECT_THAT(opts = fcntl(s.get(), F_GETFL), SyscallSucceeds());
  opts &= ~O_NONBLOCK;
  EXPECT_THAT(fcntl(s.get(), F_SETFL, opts), SyscallSucceeds());

  // Try connecting again.
  ASSERT_THAT(RetryEINTR(connect)(
                  s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
              SyscallFailsWithErrno(ECONNABORTED));
}
// Verifies that the first read() after a failed non-blocking connect reports
// ECONNREFUSED, the next read() returns 0, and a second connect() on the same
// socket fails with ECONNABORTED.
TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListenerRead) {
  sockaddr_storage addr =
      ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
  socklen_t addrlen = sizeof(addr);

  const FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
      nonBlockingConnectNoListener(GetParam(), addr, addrlen));

  unsigned char c;
  // The first read surfaces the connect failure.
  ASSERT_THAT(read(s.get(), &c, 1), SyscallFailsWithErrno(ECONNREFUSED));
  // The second read no longer reports the error and returns 0.
  ASSERT_THAT(read(s.get(), &c, 1), SyscallSucceedsWithValue(0));

  // Connecting again on the same socket is expected to fail.
  ASSERT_THAT(RetryEINTR(connect)(
                  s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
              SyscallFailsWithErrno(ECONNABORTED));
}
// Verifies that the first recv(MSG_PEEK) after a failed non-blocking connect
// reports ECONNREFUSED, the next one returns 0, and a second connect() on the
// same socket fails with ECONNABORTED.
TEST_P(SimpleTcpSocketTest, NonBlockingConnectNoListenerPeek) {
  sockaddr_storage addr =
      ASSERT_NO_ERRNO_AND_VALUE(InetLoopbackAddr(GetParam()));
  socklen_t addrlen = sizeof(addr);

  const FileDescriptor s = ASSERT_NO_ERRNO_AND_VALUE(
      nonBlockingConnectNoListener(GetParam(), addr, addrlen));

  unsigned char c;
  // The first peek surfaces the connect failure.
  ASSERT_THAT(recv(s.get(), &c, 1, MSG_PEEK),
              SyscallFailsWithErrno(ECONNREFUSED));
  // The second peek no longer reports the error and returns 0.
  ASSERT_THAT(recv(s.get(), &c, 1, MSG_PEEK), SyscallSucceedsWithValue(0));

  // Connecting again on the same socket is expected to fail.
  ASSERT_THAT(RetryEINTR(connect)(
                  s.get(), reinterpret_cast<struct sockaddr*>(&addr), addrlen),
              SyscallFailsWithErrno(ECONNABORTED));
}
TEST_P(SimpleTcpSocketTest, SelfConnectSendRecv_NoRandomSave) {
@ -1235,6 +1354,19 @@ TEST_P(SimpleTcpSocketTest, CleanupOnConnectionRefused) {
// Attempt #2, with the new socket and reused addr our connect should fail in
// the same way as before, not with an EADDRINUSE.
//
// TODO(gvisor.dev/issue/3828): 2nd connect on a socket which failed connect
// first time should succeed.
// gVisor never issues the second connect and returns ECONNABORTED instead.
// Linux actually sends a SYN again and gets a RST and correctly returns
// ECONNREFUSED.
if (IsRunningOnGvisor()) {
ASSERT_THAT(connect(client_s.get(),
reinterpret_cast<const struct sockaddr*>(&bound_addr),
bound_addrlen),
SyscallFailsWithErrno(ECONNABORTED));
return;
}
ASSERT_THAT(connect(client_s.get(),
reinterpret_cast<const struct sockaddr*>(&bound_addr),
bound_addrlen),