gvisor/test/syscalls/linux/socket_netlink_util.cc

188 lines
5.9 KiB
C++
Raw Normal View History

// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "test/syscalls/linux/socket_netlink_util.h"
#include <linux/if_arp.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <sys/socket.h>
#include <vector>
#include "absl/strings/str_cat.h"
#include "test/syscalls/linux/socket_test_util.h"
namespace gvisor {
namespace testing {
// Opens a raw AF_NETLINK socket for `protocol` and binds it to a zeroed
// sockaddr_nl (only nl_family is set; nl_pid is left 0, which per netlink(7)
// lets the kernel choose the port ID). Returns the bound descriptor, or the
// errno of the failing socket/bind call.
PosixErrorOr<FileDescriptor> NetlinkBoundSocket(int protocol) {
  FileDescriptor sock;
  ASSIGN_OR_RETURN_ERRNO(sock, Socket(AF_NETLINK, SOCK_RAW, protocol));

  struct sockaddr_nl local = {};
  local.nl_family = AF_NETLINK;
  auto* bind_addr = reinterpret_cast<struct sockaddr*>(&local);
  RETURN_ERROR_IF_SYSCALL_FAIL(bind(sock.get(), bind_addr, sizeof(local)));

  MaybeSave();
  return std::move(sock);
}
// Returns the netlink port ID (the nl_pid field) that getsockname() reports
// for socket `fd`, or the errno if getsockname() fails.
PosixErrorOr<uint32_t> NetlinkPortID(int fd) {
  struct sockaddr_nl local;
  socklen_t local_len = sizeof(local);
  auto* name_out = reinterpret_cast<struct sockaddr*>(&local);
  RETURN_ERROR_IF_SYSCALL_FAIL(getsockname(fd, name_out, &local_len));

  MaybeSave();
  return static_cast<uint32_t>(local.nl_pid);
}
// Sends `request` (`len` bytes) to the kernel over netlink socket `fd` and
// invokes `fn` on every netlink message received in reply. This includes any
// terminating NLMSG_DONE or NLMSG_ERROR message.
//
// Replies are read with recvmsg() into a 4096-byte buffer. A truncated read
// (MSG_TRUNC) is reported as EIO rather than reassembled. Another recvmsg()
// is issued while the last message seen carries NLM_F_MULTI and is neither
// NLMSG_DONE nor NLMSG_ERROR.
//
// Afterwards, if `expect_nlmsgerr` is true, the last message type is expected
// to be NLMSG_ERROR. Otherwise, for multi-part replies, it is expected to be
// NLMSG_DONE. These checks are non-fatal EXPECTs and do not affect the
// returned status.
//
// Returns the sendmsg/recvmsg error, EIO on truncation, or NoError().
PosixError NetlinkRequestResponse(
    const FileDescriptor& fd, void* request, size_t len,
    const std::function<void(const struct nlmsghdr* hdr)>& fn,
    bool expect_nlmsgerr) {
  struct iovec iov = {};
  iov.iov_base = request;
  iov.iov_len = len;

  struct msghdr msg = {};
  msg.msg_iov = &iov;
  msg.msg_iovlen = 1;

  // No destination required; it defaults to pid 0, the kernel.
  RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(sendmsg)(fd.get(), &msg, 0));

  // Reuse `msg` for receiving by repointing its single iovec at `buf`.
  constexpr size_t kBufferSize = 4096;
  std::vector<char> buf(kBufferSize);
  iov.iov_base = buf.data();
  iov.iov_len = buf.size();

  // If NLM_F_MULTI is set, response is a series of messages that ends with a
  // NLMSG_DONE message.
  int type = -1;
  int flags = 0;
  do {
    // Bytes left to parse in this datagram; NLMSG_NEXT decrements it as
    // messages are consumed. Deliberately not named `len`, which would
    // shadow the request-length parameter of this function.
    int remaining;
    RETURN_ERROR_IF_SYSCALL_FAIL(remaining =
                                     RetryEINTR(recvmsg)(fd.get(), &msg, 0));

    // We don't bother with the complexity of dealing with truncated messages.
    // We must allocate a large enough buffer up front.
    if ((msg.msg_flags & MSG_TRUNC) == MSG_TRUNC) {
      return PosixError(EIO,
                        absl::StrCat("Received truncated message with flags: ",
                                     msg.msg_flags));
    }

    for (struct nlmsghdr* hdr = reinterpret_cast<struct nlmsghdr*>(buf.data());
         NLMSG_OK(hdr, remaining); hdr = NLMSG_NEXT(hdr, remaining)) {
      fn(hdr);
      flags = hdr->nlmsg_flags;
      type = hdr->nlmsg_type;
      // Done should include an integer payload for dump_done_errno.
      // See net/netlink/af_netlink.c:netlink_dump
      // Some tools like the 'ip' tool check the minimum length of the
      // NLMSG_DONE message.
      if (type == NLMSG_DONE) {
        EXPECT_GE(hdr->nlmsg_len, NLMSG_LENGTH(sizeof(int)));
      }
    }
  } while ((flags & NLM_F_MULTI) && type != NLMSG_DONE && type != NLMSG_ERROR);

  if (expect_nlmsgerr) {
    EXPECT_EQ(type, NLMSG_ERROR);
  } else if (flags & NLM_F_MULTI) {
    EXPECT_EQ(type, NLMSG_DONE);
  }
  return NoError();
}
// Sends `request` (`len` bytes) to the kernel over netlink socket `fd` and
// performs exactly one recvmsg(). `fn` is invoked on each netlink message
// contained in that single reply datagram. Unlike NetlinkRequestResponse,
// NLM_F_MULTI continuations are not followed.
//
// Returns the sendmsg/recvmsg error, EIO if the reply was truncated by the
// 4096-byte receive buffer, or NoError().
PosixError NetlinkRequestResponseSingle(
    const FileDescriptor& fd, void* request, size_t len,
    const std::function<void(const struct nlmsghdr* hdr)>& fn) {
  struct iovec vec = {};
  vec.iov_base = request;
  vec.iov_len = len;

  struct msghdr mh = {};
  mh.msg_iov = &vec;
  mh.msg_iovlen = 1;

  // No destination required; it defaults to pid 0, the kernel.
  RETURN_ERROR_IF_SYSCALL_FAIL(RetryEINTR(sendmsg)(fd.get(), &mh, 0));

  // The same msghdr is reused for the reply, now aimed at `reply`.
  constexpr size_t kBufferSize = 4096;
  std::vector<char> reply(kBufferSize);
  vec.iov_base = reply.data();
  vec.iov_len = reply.size();

  int remaining;
  RETURN_ERROR_IF_SYSCALL_FAIL(remaining =
                                   RetryEINTR(recvmsg)(fd.get(), &mh, 0));

  // Truncated replies are not reassembled; the buffer must be big enough.
  if (mh.msg_flags & MSG_TRUNC) {
    return PosixError(
        EIO,
        absl::StrCat("Received truncated message with flags: ", mh.msg_flags));
  }

  auto* hdr = reinterpret_cast<struct nlmsghdr*>(reply.data());
  while (NLMSG_OK(hdr, remaining)) {
    fn(hdr);
    hdr = NLMSG_NEXT(hdr, remaining);
  }
  return NoError();
}
// Sends `request` (`len` bytes) over netlink socket `fd` and expects the reply
// to be NLMSG_ERROR message(s) with sequence number `seq`.
//
// Returns PosixError(-nlmsgerr.error) from the last complete NLMSG_ERROR
// message seen. An ACK carries error 0, so the result is NoError(). If no
// complete NLMSG_ERROR payload was received, a sentinel error is returned with
// a descriptive message. Syscall/truncation errors from the exchange are
// propagated unchanged.
PosixError NetlinkRequestAckOrError(const FileDescriptor& fd, uint32_t seq,
                                    void* request, size_t len) {
  // Dummy negative number for "no error message received". nlmsgerr.error is
  // 0 or a negative errno and is negated below, so a negative value here
  // cannot be confused with a decoded result.
  constexpr int kNoErrorReceived = -42;
  int err = kNoErrorReceived;
  RETURN_IF_ERRNO(NetlinkRequestResponse(
      fd, request, len,
      [&](const struct nlmsghdr* hdr) {
        EXPECT_EQ(NLMSG_ERROR, hdr->nlmsg_type);
        EXPECT_EQ(hdr->nlmsg_seq, seq);
        EXPECT_GE(hdr->nlmsg_len, sizeof(*hdr) + sizeof(struct nlmsgerr));
        // The EXPECTs above are non-fatal, so skip decoding anything that is
        // not a complete nlmsgerr rather than reading an unrelated payload.
        if (hdr->nlmsg_type != NLMSG_ERROR ||
            hdr->nlmsg_len < sizeof(*hdr) + sizeof(struct nlmsgerr)) {
          return;
        }
        const struct nlmsgerr* msg =
            reinterpret_cast<const struct nlmsgerr*>(NLMSG_DATA(hdr));
        err = -msg->error;
      },
      true));
  if (err == kNoErrorReceived) {
    return PosixError(err, "no complete NLMSG_ERROR response received");
  }
  return PosixError(err);
}
// Returns the first route attribute of type `attr` in the netlink message
// `hdr`, or nullptr if there is none.
//
// `hdr` is expected to carry an ifinfomsg payload followed by rtattr
// records. `msg` is used only for sizeof, to locate the start of the
// attribute list. Messages too short to hold nlmsghdr + ifinfomsg yield
// nullptr.
const struct rtattr* FindRtAttr(const struct nlmsghdr* hdr,
                                const struct ifinfomsg* msg, int16_t attr) {
  // Offset of the first attribute: the aligned nlmsghdr plus the aligned
  // ifinfomsg. NLMSG_SPACE already rounds up, so no further alignment needed.
  const size_t ifi_space = NLMSG_SPACE(sizeof(*msg));
  // nlmsg_len is unsigned; without this guard the subtraction below would
  // wrap around for a short message.
  if (hdr->nlmsg_len < ifi_space) {
    return nullptr;
  }
  int attrlen = static_cast<int>(hdr->nlmsg_len - ifi_space);
  const struct rtattr* rta = reinterpret_cast<const struct rtattr*>(
      reinterpret_cast<const uint8_t*>(hdr) + ifi_space);
  for (; RTA_OK(rta, attrlen); rta = RTA_NEXT(rta, attrlen)) {
    if (rta->rta_type == attr) {
      return rta;
    }
  }
  return nullptr;
}
} // namespace testing
} // namespace gvisor