Implement MSG_WAITALL
MSG_WAITALL requests that recv family calls do not perform short reads. It only has an effect for SOCK_STREAM sockets, other types ignore it. PiperOrigin-RevId: 224918540 Change-Id: Id97fbf972f1f7cbd4e08eec0138f8cbdf1c94fe7
This commit is contained in:
parent
d3bc79bc84
commit
5d87d8865f
|
@ -169,7 +169,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F
|
||||||
|
|
||||||
ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
|
ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
|
||||||
|
|
||||||
return unixsocket.NewWithDirent(ctx, d, ep, flags), nil
|
return unixsocket.NewWithDirent(ctx, d, ep, e.stype != transport.SockStream, flags), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// newSocket allocates a new unix socket with host endpoint.
|
// newSocket allocates a new unix socket with host endpoint.
|
||||||
|
@ -201,7 +201,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error)
|
||||||
|
|
||||||
ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
|
ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
|
||||||
|
|
||||||
return unixsocket.New(ctx, ep), nil
|
return unixsocket.New(ctx, ep, e.stype != transport.SockStream), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send implements transport.ConnectedEndpoint.Send.
|
// Send implements transport.ConnectedEndpoint.Send.
|
||||||
|
|
|
@ -1300,6 +1300,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
|
||||||
func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
|
func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
|
||||||
trunc := flags&linux.MSG_TRUNC != 0
|
trunc := flags&linux.MSG_TRUNC != 0
|
||||||
peek := flags&linux.MSG_PEEK != 0
|
peek := flags&linux.MSG_PEEK != 0
|
||||||
|
dontWait := flags&linux.MSG_DONTWAIT != 0
|
||||||
|
waitAll := flags&linux.MSG_WAITALL != 0
|
||||||
if senderRequested && !s.isPacketBased() {
|
if senderRequested && !s.isPacketBased() {
|
||||||
// Stream sockets ignore the sender address.
|
// Stream sockets ignore the sender address.
|
||||||
senderRequested = false
|
senderRequested = false
|
||||||
|
@ -1311,10 +1313,19 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
|
||||||
return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
|
return 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != syserr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
|
if err != nil && (err != syserr.ErrWouldBlock || dontWait) {
|
||||||
|
// Read failed and we should not retry.
|
||||||
|
return 0, nil, 0, socket.ControlMessages{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err == nil && (dontWait || !waitAll || s.isPacketBased() || int64(n) >= dst.NumBytes()) {
|
||||||
|
// We got all the data we need.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Don't overwrite any data we received.
|
||||||
|
dst = dst.DropFirst(n)
|
||||||
|
|
||||||
// We'll have to block. Register for notifications and keep trying to
|
// We'll have to block. Register for notifications and keep trying to
|
||||||
// send all the data.
|
// send all the data.
|
||||||
e, ch := waiter.NewChannelEntry(nil)
|
e, ch := waiter.NewChannelEntry(nil)
|
||||||
|
@ -1322,10 +1333,23 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
|
||||||
defer s.EventUnregister(&e)
|
defer s.EventUnregister(&e)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
n, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
|
var rn int
|
||||||
if err != syserr.ErrWouldBlock {
|
rn, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
|
||||||
|
n += rn
|
||||||
|
if err != nil && err != syserr.ErrWouldBlock {
|
||||||
|
// Always stop on errors other than would block as we generally
|
||||||
|
// won't be able to get any more data. Eat the error if we got
|
||||||
|
// any data.
|
||||||
|
if n > 0 {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if err == nil && (s.isPacketBased() || !waitAll || int64(rn) >= dst.NumBytes()) {
|
||||||
|
// We got all the data we need.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dst = dst.DropFirst(rn)
|
||||||
|
|
||||||
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
|
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
|
||||||
if err == syserror.ETIMEDOUT {
|
if err == syserror.ETIMEDOUT {
|
||||||
|
|
|
@ -53,19 +53,21 @@ type SocketOperations struct {
|
||||||
fsutil.NoopFlush `state:"nosave"`
|
fsutil.NoopFlush `state:"nosave"`
|
||||||
fsutil.NoMMap `state:"nosave"`
|
fsutil.NoMMap `state:"nosave"`
|
||||||
ep transport.Endpoint
|
ep transport.Endpoint
|
||||||
|
isPacket bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new unix socket.
|
// New creates a new unix socket.
|
||||||
func New(ctx context.Context, endpoint transport.Endpoint) *fs.File {
|
func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File {
|
||||||
dirent := socket.NewDirent(ctx, unixSocketDevice)
|
dirent := socket.NewDirent(ctx, unixSocketDevice)
|
||||||
defer dirent.DecRef()
|
defer dirent.DecRef()
|
||||||
return NewWithDirent(ctx, dirent, endpoint, fs.FileFlags{Read: true, Write: true})
|
return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true})
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWithDirent creates a new unix socket using an existing dirent.
|
// NewWithDirent creates a new unix socket using an existing dirent.
|
||||||
func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, flags fs.FileFlags) *fs.File {
|
func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File {
|
||||||
return fs.NewFile(ctx, d, flags, &SocketOperations{
|
return fs.NewFile(ctx, d, flags, &SocketOperations{
|
||||||
ep: ep,
|
ep: ep,
|
||||||
|
isPacket: isPacket,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -188,7 +190,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ns := New(t, ep)
|
ns := New(t, ep, s.isPacket)
|
||||||
defer ns.DecRef()
|
defer ns.DecRef()
|
||||||
|
|
||||||
if flags&linux.SOCK_NONBLOCK != 0 {
|
if flags&linux.SOCK_NONBLOCK != 0 {
|
||||||
|
@ -471,6 +473,8 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
|
||||||
func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
|
func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
|
||||||
trunc := flags&linux.MSG_TRUNC != 0
|
trunc := flags&linux.MSG_TRUNC != 0
|
||||||
peek := flags&linux.MSG_PEEK != 0
|
peek := flags&linux.MSG_PEEK != 0
|
||||||
|
dontWait := flags&linux.MSG_DONTWAIT != 0
|
||||||
|
waitAll := flags&linux.MSG_WAITALL != 0
|
||||||
|
|
||||||
// Calculate the number of FDs for which we have space and if we are
|
// Calculate the number of FDs for which we have space and if we are
|
||||||
// requesting credentials.
|
// requesting credentials.
|
||||||
|
@ -497,7 +501,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
|
||||||
if senderRequested {
|
if senderRequested {
|
||||||
r.From = &tcpip.FullAddress{}
|
r.From = &tcpip.FullAddress{}
|
||||||
}
|
}
|
||||||
if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
|
var total int64
|
||||||
|
if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait {
|
||||||
var from interface{}
|
var from interface{}
|
||||||
var fromLen uint32
|
var fromLen uint32
|
||||||
if r.From != nil {
|
if r.From != nil {
|
||||||
|
@ -506,7 +511,13 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
|
||||||
if trunc {
|
if trunc {
|
||||||
n = int64(r.MsgSize)
|
n = int64(r.MsgSize)
|
||||||
}
|
}
|
||||||
return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
|
if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
|
||||||
|
return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't overwrite any data we received.
|
||||||
|
dst = dst.DropFirst64(n)
|
||||||
|
total += n
|
||||||
}
|
}
|
||||||
|
|
||||||
// We'll have to block. Register for notification and keep trying to
|
// We'll have to block. Register for notification and keep trying to
|
||||||
|
@ -525,7 +536,13 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
|
||||||
if trunc {
|
if trunc {
|
||||||
n = int64(r.MsgSize)
|
n = int64(r.MsgSize)
|
||||||
}
|
}
|
||||||
return int(n), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
|
total += n
|
||||||
|
if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
|
||||||
|
return int(total), from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Don't overwrite any data we received.
|
||||||
|
dst = dst.DropFirst64(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
|
if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
|
||||||
|
@ -549,16 +566,21 @@ func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int)
|
||||||
|
|
||||||
// Create the endpoint and socket.
|
// Create the endpoint and socket.
|
||||||
var ep transport.Endpoint
|
var ep transport.Endpoint
|
||||||
|
var isPacket bool
|
||||||
switch stype {
|
switch stype {
|
||||||
case linux.SOCK_DGRAM:
|
case linux.SOCK_DGRAM:
|
||||||
|
isPacket = true
|
||||||
ep = transport.NewConnectionless()
|
ep = transport.NewConnectionless()
|
||||||
case linux.SOCK_STREAM, linux.SOCK_SEQPACKET:
|
case linux.SOCK_SEQPACKET:
|
||||||
|
isPacket = true
|
||||||
|
fallthrough
|
||||||
|
case linux.SOCK_STREAM:
|
||||||
ep = transport.NewConnectioned(stype, t.Kernel())
|
ep = transport.NewConnectioned(stype, t.Kernel())
|
||||||
default:
|
default:
|
||||||
return nil, syserr.ErrInvalidArgument
|
return nil, syserr.ErrInvalidArgument
|
||||||
}
|
}
|
||||||
|
|
||||||
return New(t, ep), nil
|
return New(t, ep, isPacket), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pair creates a new pair of AF_UNIX connected sockets.
|
// Pair creates a new pair of AF_UNIX connected sockets.
|
||||||
|
@ -568,16 +590,19 @@ func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*
|
||||||
return nil, nil, syserr.ErrInvalidArgument
|
return nil, nil, syserr.ErrInvalidArgument
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var isPacket bool
|
||||||
switch stype {
|
switch stype {
|
||||||
case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
|
case linux.SOCK_STREAM:
|
||||||
|
case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
|
||||||
|
isPacket = true
|
||||||
default:
|
default:
|
||||||
return nil, nil, syserr.ErrInvalidArgument
|
return nil, nil, syserr.ErrInvalidArgument
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the endpoints and sockets.
|
// Create the endpoints and sockets.
|
||||||
ep1, ep2 := transport.NewPair(stype, t.Kernel())
|
ep1, ep2 := transport.NewPair(stype, t.Kernel())
|
||||||
s1 := New(t, ep1)
|
s1 := New(t, ep1, isPacket)
|
||||||
s2 := New(t, ep2)
|
s2 := New(t, ep2, isPacket)
|
||||||
|
|
||||||
return s1, s2, nil
|
return s1, s2, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -602,7 +602,7 @@ func RecvMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reject flags that we don't handle yet.
|
// Reject flags that we don't handle yet.
|
||||||
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
|
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE|linux.MSG_WAITALL) != 0 {
|
||||||
return 0, nil, syscall.EINVAL
|
return 0, nil, syscall.EINVAL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -635,7 +635,7 @@ func RecvMMsg(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reject flags that we don't handle yet.
|
// Reject flags that we don't handle yet.
|
||||||
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE) != 0 {
|
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CMSG_CLOEXEC|linux.MSG_ERRQUEUE|linux.MSG_WAITALL) != 0 {
|
||||||
return 0, nil, syscall.EINVAL
|
return 0, nil, syscall.EINVAL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -791,7 +791,7 @@ func recvFrom(t *kernel.Task, fd kdefs.FD, bufPtr usermem.Addr, bufLen uint64, f
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reject flags that we don't handle yet.
|
// Reject flags that we don't handle yet.
|
||||||
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CONFIRM) != 0 {
|
if flags & ^(linux.MSG_DONTWAIT|linux.MSG_NOSIGNAL|linux.MSG_PEEK|linux.MSG_TRUNC|linux.MSG_CTRUNC|linux.MSG_CONFIRM|linux.MSG_WAITALL) != 0 {
|
||||||
return 0, syscall.EINVAL
|
return 0, syscall.EINVAL
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -383,8 +383,6 @@ TEST_P(AllSocketPairTest, RecvmsgTimeoutOneSecondSucceeds) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(AllSocketPairTest, RecvWaitAll) {
|
TEST_P(AllSocketPairTest, RecvWaitAll) {
|
||||||
SKIP_IF(IsRunningOnGvisor()); // FIXME: Support MSG_WAITALL.
|
|
||||||
|
|
||||||
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
|
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
|
||||||
|
|
||||||
char sent_data[100];
|
char sent_data[100];
|
||||||
|
@ -399,5 +397,14 @@ TEST_P(AllSocketPairTest, RecvWaitAll) {
|
||||||
SyscallSucceedsWithValue(sizeof(sent_data)));
|
SyscallSucceedsWithValue(sizeof(sent_data)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_P(AllSocketPairTest, RecvWaitAllDontWait) {
|
||||||
|
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
|
||||||
|
|
||||||
|
char data[100] = {};
|
||||||
|
ASSERT_THAT(RetryEINTR(recv)(sockets->second_fd(), data, sizeof(data),
|
||||||
|
MSG_WAITALL | MSG_DONTWAIT),
|
||||||
|
SyscallFailsWithErrno(EAGAIN));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace testing
|
} // namespace testing
|
||||||
} // namespace gvisor
|
} // namespace gvisor
|
||||||
|
|
|
@ -31,8 +31,6 @@ namespace gvisor {
|
||||||
namespace testing {
|
namespace testing {
|
||||||
|
|
||||||
TEST_P(BlockingNonStreamSocketPairTest, RecvLessThanBufferWaitAll) {
|
TEST_P(BlockingNonStreamSocketPairTest, RecvLessThanBufferWaitAll) {
|
||||||
SKIP_IF(IsRunningOnGvisor()); // FIXME: Support MSG_WAITALL.
|
|
||||||
|
|
||||||
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
|
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
|
||||||
|
|
||||||
char sent_data[100];
|
char sent_data[100];
|
||||||
|
|
|
@ -99,8 +99,6 @@ TEST_P(BlockingStreamSocketPairTest, RecvLessThanBuffer) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(BlockingStreamSocketPairTest, RecvLessThanBufferWaitAll) {
|
TEST_P(BlockingStreamSocketPairTest, RecvLessThanBufferWaitAll) {
|
||||||
SKIP_IF(IsRunningOnGvisor()); // FIXME: Support MSG_WAITALL.
|
|
||||||
|
|
||||||
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
|
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
|
||||||
|
|
||||||
char sent_data[100];
|
char sent_data[100];
|
||||||
|
|
Loading…
Reference in New Issue