diff --git a/pkg/abi/linux/netlink.go b/pkg/abi/linux/netlink.go index e8b6544b4..0ba086c76 100644 --- a/pkg/abi/linux/netlink.go +++ b/pkg/abi/linux/netlink.go @@ -122,3 +122,9 @@ const ( NETLINK_EXT_ACK = 11 NETLINK_DUMP_STRICT_CHK = 12 ) + +// NetlinkErrorMessage is struct nlmsgerr, from uapi/linux/netlink.h. +type NetlinkErrorMessage struct { + Error int32 + Header NetlinkMessageHeader +} diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index eccbd527a..d0aab293d 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -511,6 +511,19 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error return nil } +func (s *Socket) dumpErrorMesage(ctx context.Context, hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) *syserr.Error { + m := ms.AddMessage(linux.NetlinkMessageHeader{ + Type: linux.NLMSG_ERROR, + }) + + m.Put(linux.NetlinkErrorMessage{ + Error: int32(-err.ToLinux().Number()), + Header: hdr, + }) + return nil + +} + // processMessages handles each message in buf, passing it to the protocol // handler for final handling. func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error { @@ -545,14 +558,20 @@ func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error continue } + ms := NewMessageSet(s.portID, hdr.Seq) + var err *syserr.Error // TODO(b/68877377): ACKs not supported yet. if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK { - return syserr.ErrNotSupported - } + err = syserr.ErrNotSupported + } else { - ms := NewMessageSet(s.portID, hdr.Seq) - if err := s.protocol.ProcessMessage(ctx, hdr, data, ms); err != nil { - return err + err = s.protocol.ProcessMessage(ctx, hdr, data, ms) + } + if err != nil { + ms = NewMessageSet(s.portID, hdr.Seq) + if err := s.dumpErrorMesage(ctx, hdr, ms, err); err != nil { + return err + } } if err := s.sendResponse(ctx, ms); err != nil { diff --git a/test/syscalls/linux/socket_netdevice.cc b/test/syscalls/linux/socket_netdevice.cc index 6a5fa8965..765f8e0e4 100644 --- a/test/syscalls/linux/socket_netdevice.cc +++ b/test/syscalls/linux/socket_netdevice.cc @@ -89,7 +89,8 @@ TEST(NetdeviceTest, Netmask) { // (i.e. netmask) for the loopback device. int prefixlen = -1; ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), [&](const struct nlmsghdr *hdr) { + fd, &req, sizeof(req), + [&](const struct nlmsghdr *hdr) { EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWADDR), Eq(NLMSG_DONE))); EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI) @@ -111,7 +112,8 @@ TEST(NetdeviceTest, Netmask) { ifaddrmsg->ifa_family == AF_INET) { prefixlen = ifaddrmsg->ifa_prefixlen; } - })); + }, + false)); ASSERT_GE(prefixlen, 0); diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index ac7e0bd3e..32fe0d6d1 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -238,7 +238,8 @@ TEST(NetlinkRouteTest, GetLinkDump) { // Loopback is common among all tests, check that it's found. bool loopbackFound = false; ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { CheckGetLinkResponse(hdr, kSeq, port); if (hdr->nlmsg_type != RTM_NEWLINK) { return; @@ -252,10 +253,44 @@ TEST(NetlinkRouteTest, GetLinkDump) { loopbackFound = true; EXPECT_NE(msg->ifi_flags & IFF_LOOPBACK, 0); } - })); + }, + false)); EXPECT_TRUE(loopbackFound); } +TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + constexpr uint32_t kSeq = 12345; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + // If type & 0x3 is equal to 0x2, this means a get request + // which doesn't require CAP_SYS_ADMIN. + req.hdr.nlmsg_type = ((__RTM_MAX + 1024) & (~0x3)) | 0x2; + req.hdr.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + + ASSERT_NO_ERRNO(NetlinkRequestResponse( + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { + EXPECT_THAT(hdr->nlmsg_type, Eq(NLMSG_ERROR)); + EXPECT_EQ(hdr->nlmsg_seq, kSeq); + EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr)); + + const struct nlmsgerr* msg = + reinterpret_cast(NLMSG_DATA(hdr)); + EXPECT_EQ(msg->error, -EOPNOTSUPP); + }, + true)); +} + TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket()); @@ -364,9 +399,11 @@ TEST(NetlinkRouteTest, ControlMessageIgnored) { req.ifm.ifi_family = AF_UNSPEC; ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { CheckGetLinkResponse(hdr, kSeq, port); - })); + }, + false)); } TEST(NetlinkRouteTest, GetAddrDump) { @@ -388,7 +425,8 @@ TEST(NetlinkRouteTest, GetAddrDump) { req.rgm.rtgen_family = AF_UNSPEC; ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWADDR), Eq(NLMSG_DONE))); EXPECT_TRUE((hdr->nlmsg_flags & NLM_F_MULTI) == NLM_F_MULTI) @@ -405,7 +443,8 @@ TEST(NetlinkRouteTest, GetAddrDump) { EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct ifaddrmsg)); // TODO(mpratt): Check ifaddrmsg contents and following attrs. - })); + }, + false)); } TEST(NetlinkRouteTest, LookupAll) { @@ -448,7 +487,8 @@ TEST(NetlinkRouteTest, GetRouteDump) { bool routeFound = false; bool dstFound = true; ASSERT_NO_ERRNO(NetlinkRequestResponse( - fd, &req, sizeof(req), [&](const struct nlmsghdr* hdr) { + fd, &req, sizeof(req), + [&](const struct nlmsghdr* hdr) { // Validate the reponse to RTM_GETROUTE + NLM_F_DUMP. EXPECT_THAT(hdr->nlmsg_type, AnyOf(Eq(RTM_NEWROUTE), Eq(NLMSG_DONE))); @@ -491,7 +531,8 @@ TEST(NetlinkRouteTest, GetRouteDump) { routeFound = true; dstFound = rtDstFound && dstFound; } - })); + }, + false)); // At least one route found in main route table. EXPECT_TRUE(routeFound); // Found RTA_DST for each route in main table. diff --git a/test/syscalls/linux/socket_netlink_util.cc b/test/syscalls/linux/socket_netlink_util.cc index 728d25434..36b6560c2 100644 --- a/test/syscalls/linux/socket_netlink_util.cc +++ b/test/syscalls/linux/socket_netlink_util.cc @@ -54,7 +54,8 @@ PosixErrorOr NetlinkPortID(int fd) { PosixError NetlinkRequestResponse( const FileDescriptor& fd, void* request, size_t len, - const std::function& fn) { + const std::function& fn, + bool expect_nlmsgerr) { struct iovec iov = {}; iov.iov_base = request; iov.iov_len = len; @@ -93,7 +94,11 @@ PosixError NetlinkRequestResponse( } } while (type != NLMSG_DONE && type != NLMSG_ERROR); - EXPECT_EQ(type, NLMSG_DONE); + if (expect_nlmsgerr) { + EXPECT_EQ(type, NLMSG_ERROR); + } else { + EXPECT_EQ(type, NLMSG_DONE); + } return NoError(); } diff --git a/test/syscalls/linux/socket_netlink_util.h b/test/syscalls/linux/socket_netlink_util.h index bea449107..db8639a2f 100644 --- a/test/syscalls/linux/socket_netlink_util.h +++ b/test/syscalls/linux/socket_netlink_util.h @@ -34,7 +34,8 @@ PosixErrorOr NetlinkPortID(int fd); // Send the passed request and call fn will all response netlink messages. PosixError NetlinkRequestResponse( const FileDescriptor& fd, void* request, size_t len, - const std::function& fn); + const std::function& fn, + bool expect_nlmsgerr); } // namespace testing } // namespace gvisor