[netstack] Move SO_PASSCRED option to SocketOptions.
This change also makes the following fixes: - Make SocketOptions use atomic operations instead of having to acquire/drop locks upon each get/set option. - Make documentation more consistent. - Remove tcpip.SocketOptions from socketOpsCommon because it already exists in transport.Endpoint. - Refactors get/set socket options tests to be easily extendable. PiperOrigin-RevId: 343103780
This commit is contained in:
parent
87ed61ea05
commit
fc342fb439
|
@ -120,9 +120,6 @@ type socketOpsCommon struct {
|
|||
// fixed buffer but only consume this many bytes.
|
||||
sendBufferSize uint32
|
||||
|
||||
// passcred indicates if this socket wants SCM credentials.
|
||||
passcred bool
|
||||
|
||||
// filter indicates that this socket has a BPF filter "installed".
|
||||
//
|
||||
// TODO(gvisor.dev/issue/1119): We don't actually support filtering,
|
||||
|
@ -201,10 +198,7 @@ func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
|
|||
|
||||
// Passcred implements transport.Credentialer.Passcred.
|
||||
func (s *socketOpsCommon) Passcred() bool {
|
||||
s.mu.Lock()
|
||||
passcred := s.passcred
|
||||
s.mu.Unlock()
|
||||
return passcred
|
||||
return s.ep.SocketOptions().GetPassCred()
|
||||
}
|
||||
|
||||
// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
|
||||
|
@ -419,9 +413,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
|
|||
}
|
||||
passcred := usermem.ByteOrder.Uint32(opt)
|
||||
|
||||
s.mu.Lock()
|
||||
s.passcred = passcred != 0
|
||||
s.mu.Unlock()
|
||||
s.ep.SocketOptions().SetPassCred(passcred != 0)
|
||||
return nil
|
||||
|
||||
case linux.SO_ATTACH_FILTER:
|
||||
|
|
|
@ -260,10 +260,12 @@ type commonEndpoint interface {
|
|||
// transport.Endpoint.GetSockOpt.
|
||||
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
|
||||
|
||||
// LastError implements tcpip.Endpoint.LastError.
|
||||
// LastError implements tcpip.Endpoint.LastError and
|
||||
// transport.Endpoint.LastError.
|
||||
LastError() *tcpip.Error
|
||||
|
||||
// SocketOptions implements tcpip.Endpoint.SocketOptions.
|
||||
// SocketOptions implements tcpip.Endpoint.SocketOptions and
|
||||
// transport.Endpoint.SocketOptions.
|
||||
SocketOptions() *tcpip.SocketOptions
|
||||
}
|
||||
|
||||
|
@ -1068,13 +1070,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
|
|||
return nil, syserr.ErrInvalidArgument
|
||||
}
|
||||
|
||||
v, err := ep.GetSockOptBool(tcpip.PasscredOption)
|
||||
if err != nil {
|
||||
return nil, syserr.TranslateNetstackError(err)
|
||||
}
|
||||
|
||||
vP := primitive.Int32(boolToInt32(v))
|
||||
return &vP, nil
|
||||
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetPassCred()))
|
||||
return &v, nil
|
||||
|
||||
case linux.SO_SNDBUF:
|
||||
if outLen < sizeOfInt32 {
|
||||
|
@ -1923,7 +1920,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
|
|||
}
|
||||
|
||||
v := usermem.ByteOrder.Uint32(optVal)
|
||||
return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0))
|
||||
ep.SocketOptions().SetPassCred(v != 0)
|
||||
return nil
|
||||
|
||||
case linux.SO_KEEPALIVE:
|
||||
if len(optVal) < sizeOfInt32 {
|
||||
|
|
|
@ -16,8 +16,6 @@
|
|||
package transport
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/abi/linux"
|
||||
"gvisor.dev/gvisor/pkg/context"
|
||||
"gvisor.dev/gvisor/pkg/log"
|
||||
|
@ -203,10 +201,11 @@ type Endpoint interface {
|
|||
// procfs.
|
||||
State() uint32
|
||||
|
||||
// LastError implements tcpip.Endpoint.LastError.
|
||||
// LastError clears and returns the last error reported by the endpoint.
|
||||
LastError() *tcpip.Error
|
||||
|
||||
// SocketOptions implements tcpip.Endpoint.SocketOptions.
|
||||
// SocketOptions returns the structure which contains all the socket
|
||||
// level options.
|
||||
SocketOptions() *tcpip.SocketOptions
|
||||
}
|
||||
|
||||
|
@ -740,10 +739,6 @@ func (e *connectedEndpoint) CloseUnread() {
|
|||
type baseEndpoint struct {
|
||||
*waiter.Queue
|
||||
|
||||
// passcred specifies whether SCM_CREDENTIALS socket control messages are
|
||||
// enabled on this endpoint. Must be accessed atomically.
|
||||
passcred int32
|
||||
|
||||
// Mutex protects the below fields.
|
||||
sync.Mutex `state:"nosave"`
|
||||
|
||||
|
@ -786,7 +781,7 @@ func (e *baseEndpoint) EventUnregister(we *waiter.Entry) {
|
|||
|
||||
// Passcred implements Credentialer.Passcred.
|
||||
func (e *baseEndpoint) Passcred() bool {
|
||||
return atomic.LoadInt32(&e.passcred) != 0
|
||||
return e.SocketOptions().GetPassCred()
|
||||
}
|
||||
|
||||
// ConnectedPasscred implements Credentialer.ConnectedPasscred.
|
||||
|
@ -796,14 +791,6 @@ func (e *baseEndpoint) ConnectedPasscred() bool {
|
|||
return e.connected != nil && e.connected.Passcred()
|
||||
}
|
||||
|
||||
func (e *baseEndpoint) setPasscred(pc bool) {
|
||||
if pc {
|
||||
atomic.StoreInt32(&e.passcred, 1)
|
||||
} else {
|
||||
atomic.StoreInt32(&e.passcred, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// Connected implements ConnectingEndpoint.Connected.
|
||||
func (e *baseEndpoint) Connected() bool {
|
||||
return e.receiver != nil && e.connected != nil
|
||||
|
@ -870,8 +857,6 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
|
|||
|
||||
func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
|
||||
switch opt {
|
||||
case tcpip.PasscredOption:
|
||||
e.setPasscred(v)
|
||||
case tcpip.ReuseAddressOption:
|
||||
default:
|
||||
log.Warningf("Unsupported socket option: %d", opt)
|
||||
|
@ -894,9 +879,6 @@ func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error
|
|||
case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption:
|
||||
return false, nil
|
||||
|
||||
case tcpip.PasscredOption:
|
||||
return e.Passcred(), nil
|
||||
|
||||
default:
|
||||
log.Warningf("Unsupported socket option: %d", opt)
|
||||
return false, tcpip.ErrUnknownProtocolOption
|
||||
|
|
|
@ -115,9 +115,6 @@ type socketOpsCommon struct {
|
|||
// bound, they cannot be modified.
|
||||
abstractName string
|
||||
abstractNamespace *kernel.AbstractSocketNamespace
|
||||
|
||||
// ops is used to get socket level options.
|
||||
ops tcpip.SocketOptions
|
||||
}
|
||||
|
||||
func (s *socketOpsCommon) isPacket() bool {
|
||||
|
|
|
@ -15,31 +15,49 @@
|
|||
package tcpip
|
||||
|
||||
import (
|
||||
"gvisor.dev/gvisor/pkg/sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// SocketOptions contains all the variables which store values for socket
|
||||
// SocketOptions contains all the variables which store values for SOL_SOCKET
|
||||
// level options.
|
||||
//
|
||||
// +stateify savable
|
||||
type SocketOptions struct {
|
||||
// mu protects fields below.
|
||||
mu sync.Mutex `state:"nosave"`
|
||||
broadcastEnabled bool
|
||||
// These fields are accessed and modified using atomic operations.
|
||||
|
||||
// broadcastEnabled determines whether datagram sockets are allowed to send
|
||||
// packets to a broadcast address.
|
||||
broadcastEnabled uint32
|
||||
|
||||
// passCredEnabled determines whether SCM_CREDENTIALS socket control messages
|
||||
// are enabled.
|
||||
passCredEnabled uint32
|
||||
}
|
||||
|
||||
func storeAtomicBool(addr *uint32, v bool) {
|
||||
var val uint32
|
||||
if v {
|
||||
val = 1
|
||||
}
|
||||
atomic.StoreUint32(addr, val)
|
||||
}
|
||||
|
||||
// GetBroadcast gets value for SO_BROADCAST option.
|
||||
func (so *SocketOptions) GetBroadcast() bool {
|
||||
so.mu.Lock()
|
||||
defer so.mu.Unlock()
|
||||
|
||||
return so.broadcastEnabled
|
||||
return atomic.LoadUint32(&so.broadcastEnabled) != 0
|
||||
}
|
||||
|
||||
// SetBroadcast sets value for SO_BROADCAST option.
|
||||
func (so *SocketOptions) SetBroadcast(v bool) {
|
||||
so.mu.Lock()
|
||||
defer so.mu.Unlock()
|
||||
|
||||
so.broadcastEnabled = v
|
||||
storeAtomicBool(&so.broadcastEnabled, v)
|
||||
}
|
||||
|
||||
// GetPassCred gets value for SO_PASSCRED option.
|
||||
func (so *SocketOptions) GetPassCred() bool {
|
||||
return atomic.LoadUint32(&so.passCredEnabled) != 0
|
||||
}
|
||||
|
||||
// SetPassCred sets value for SO_PASSCRED option.
|
||||
func (so *SocketOptions) SetPassCred(v bool) {
|
||||
storeAtomicBool(&so.passCredEnabled, v)
|
||||
}
|
||||
|
|
|
@ -721,12 +721,6 @@ const (
|
|||
// whether UDP checksum is disabled for this socket.
|
||||
NoChecksumOption
|
||||
|
||||
// PasscredOption is used by SetSockOptBool/GetSockOptBool to specify
|
||||
// whether SCM_CREDENTIALS socket control messages are enabled.
|
||||
//
|
||||
// Only supported on Unix sockets.
|
||||
PasscredOption
|
||||
|
||||
// QuickAckOption is stubbed out in SetSockOptBool/GetSockOptBool.
|
||||
QuickAckOption
|
||||
|
||||
|
|
|
@ -857,6 +857,7 @@ func (*endpoint) LastError() *tcpip.Error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// SocketOptions implements tcpip.Endpoint.SocketOptions.
|
||||
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
|
||||
return &e.ops
|
||||
}
|
||||
|
|
|
@ -756,10 +756,12 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
|
|||
// Wait implements stack.TransportEndpoint.Wait.
|
||||
func (*endpoint) Wait() {}
|
||||
|
||||
// LastError implements tcpip.Endpoint.LastError.
|
||||
func (*endpoint) LastError() *tcpip.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SocketOptions implements tcpip.Endpoint.SocketOptions.
|
||||
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
|
||||
return &e.ops
|
||||
}
|
||||
|
|
|
@ -1279,6 +1279,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {
|
|||
e.rcvListMu.Unlock()
|
||||
}
|
||||
|
||||
// SetOwner implements tcpip.Endpoint.SetOwner.
|
||||
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
|
||||
e.owner = owner
|
||||
}
|
||||
|
@ -1299,6 +1300,7 @@ func (e *endpoint) lastErrorLocked() *tcpip.Error {
|
|||
return err
|
||||
}
|
||||
|
||||
// LastError implements tcpip.Endpoint.LastError.
|
||||
func (e *endpoint) LastError() *tcpip.Error {
|
||||
e.LockUser()
|
||||
defer e.UnlockUser()
|
||||
|
@ -3213,6 +3215,7 @@ func (e *endpoint) Wait() {
|
|||
}
|
||||
}
|
||||
|
||||
// SocketOptions implements tcpip.Endpoint.SocketOptions.
|
||||
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
|
||||
return &e.ops
|
||||
}
|
||||
|
|
|
@ -1535,10 +1535,12 @@ func isBroadcastOrMulticast(a tcpip.Address) bool {
|
|||
return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a)
|
||||
}
|
||||
|
||||
// SetOwner implements tcpip.Endpoint.SetOwner.
|
||||
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
|
||||
e.owner = owner
|
||||
}
|
||||
|
||||
// SocketOptions implements tcpip.Endpoint.SocketOptions.
|
||||
func (e *endpoint) SocketOptions() *tcpip.SocketOptions {
|
||||
return &e.ops
|
||||
}
|
||||
|
|
|
@ -818,32 +818,38 @@ TEST_P(AllSocketPairTest, GetSockoptProtocol) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_P(AllSocketPairTest, GetSockoptBroadcast) {
|
||||
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
|
||||
int opt = -1;
|
||||
socklen_t optlen = sizeof(opt);
|
||||
EXPECT_THAT(
|
||||
getsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST, &opt, &optlen),
|
||||
SyscallSucceeds());
|
||||
ASSERT_EQ(optlen, sizeof(opt));
|
||||
EXPECT_EQ(opt, 0);
|
||||
}
|
||||
TEST_P(AllSocketPairTest, SetAndGetBooleanSocketOptions) {
|
||||
int sock_opts[] = {SO_BROADCAST, SO_PASSCRED};
|
||||
for (int sock_opt : sock_opts) {
|
||||
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
|
||||
int enable = -1;
|
||||
socklen_t enableLen = sizeof(enable);
|
||||
|
||||
TEST_P(AllSocketPairTest, SetAndGetSocketBroadcastOption) {
|
||||
auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
|
||||
int kSockOptOn = 1;
|
||||
ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST,
|
||||
&kSockOptOn, sizeof(kSockOptOn)),
|
||||
SyscallSucceedsWithValue(0));
|
||||
// Test that the option is initially set to false.
|
||||
ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, sock_opt, &enable,
|
||||
&enableLen),
|
||||
SyscallSucceeds());
|
||||
ASSERT_EQ(enableLen, sizeof(enable));
|
||||
EXPECT_EQ(enable, 0) << absl::StrFormat(
|
||||
"getsockopt(fd, SOL_SOCKET, %d, &enable, &enableLen) => enable=%d",
|
||||
sock_opt, enable);
|
||||
|
||||
int got = -1;
|
||||
socklen_t length = sizeof(got);
|
||||
ASSERT_THAT(
|
||||
getsockopt(sockets->first_fd(), SOL_SOCKET, SO_BROADCAST, &got, &length),
|
||||
SyscallSucceedsWithValue(0));
|
||||
|
||||
ASSERT_EQ(length, sizeof(got));
|
||||
EXPECT_EQ(got, kSockOptOn);
|
||||
// Test that setting the option to true is reflected in the subsequent
|
||||
// call to getsockopt(2).
|
||||
enable = 1;
|
||||
ASSERT_THAT(setsockopt(sockets->first_fd(), SOL_SOCKET, sock_opt, &enable,
|
||||
sizeof(enable)),
|
||||
SyscallSucceeds());
|
||||
enable = -1;
|
||||
enableLen = sizeof(enable);
|
||||
ASSERT_THAT(getsockopt(sockets->first_fd(), SOL_SOCKET, sock_opt, &enable,
|
||||
&enableLen),
|
||||
SyscallSucceeds());
|
||||
ASSERT_EQ(enableLen, sizeof(enable));
|
||||
EXPECT_EQ(enable, 1) << absl::StrFormat(
|
||||
"getsockopt(fd, SOL_SOCKET, %d, &enable, &enableLen) => enable=%d",
|
||||
sock_opt, enable);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace testing
|
||||
|
|
Loading…
Reference in New Issue