iptables: support postrouting hook and SNAT target
The current SNAT implementation has several limitations: - SNAT source port has to be specified. It is not optional. - SNAT source port range is not supported. - SNAT for UDP is a one-way translation. No response packets are handled (because conntrack doesn't support UDP currently). - SNAT and REDIRECT can't work on the same connection. Fixes #5489 PiperOrigin-RevId: 367750325
This commit is contained in:
parent
ea7faa5057
commit
d1edabdca0
|
@ -375,6 +375,17 @@ type XTRedirectTarget struct {
|
|||
// SizeOfXTRedirectTarget is the size of an XTRedirectTarget.
|
||||
const SizeOfXTRedirectTarget = 56
|
||||
|
||||
// XTSNATTarget triggers Source NAT when reached.
|
||||
// Adding 4 bytes of padding to make the struct 8 byte aligned.
|
||||
type XTSNATTarget struct {
|
||||
Target XTEntryTarget
|
||||
NfRange NfNATIPV4MultiRangeCompat
|
||||
_ [4]byte
|
||||
}
|
||||
|
||||
// SizeOfXTSNATTarget is the size of an XTSNATTarget.
|
||||
const SizeOfXTSNATTarget = 56
|
||||
|
||||
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
|
||||
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
|
||||
//
|
||||
|
|
|
@ -274,10 +274,10 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
|
|||
}
|
||||
|
||||
// TODO(gvisor.dev/issue/170): Support other chains.
|
||||
// Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now,
|
||||
// make sure all other chains point to ACCEPT rules.
|
||||
// Since we don't support FORWARD, yet, make sure all other chains point to
|
||||
// ACCEPT rules.
|
||||
for hook, ruleIdx := range table.BuiltinChains {
|
||||
if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting {
|
||||
if hook := stack.Hook(hook); hook == stack.Forward {
|
||||
if ruleIdx == stack.HookUnset {
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -35,6 +35,11 @@ const ErrorTargetName = "ERROR"
|
|||
// change the destination port and/or IP for packets.
|
||||
const RedirectTargetName = "REDIRECT"
|
||||
|
||||
// SNATTargetName is used to mark targets as SNAT targets. SNAT targets should
|
||||
// be reached for only NAT table. These targets will change the source port
|
||||
// and/or IP for packets.
|
||||
const SNATTargetName = "SNAT"
|
||||
|
||||
func init() {
|
||||
// Standard targets include ACCEPT, DROP, RETURN, and JUMP.
|
||||
registerTargetMaker(&standardTargetMaker{
|
||||
|
@ -59,6 +64,13 @@ func init() {
|
|||
registerTargetMaker(&nfNATTargetMaker{
|
||||
NetworkProtocol: header.IPv6ProtocolNumber,
|
||||
})
|
||||
|
||||
registerTargetMaker(&snatTargetMakerV4{
|
||||
NetworkProtocol: header.IPv4ProtocolNumber,
|
||||
})
|
||||
registerTargetMaker(&snatTargetMakerV6{
|
||||
NetworkProtocol: header.IPv6ProtocolNumber,
|
||||
})
|
||||
}
|
||||
|
||||
// The stack package provides some basic, useful targets for us. The following
|
||||
|
@ -131,6 +143,17 @@ func (rt *redirectTarget) id() targetID {
|
|||
}
|
||||
}
|
||||
|
||||
type snatTarget struct {
|
||||
stack.SNATTarget
|
||||
}
|
||||
|
||||
func (st *snatTarget) id() targetID {
|
||||
return targetID{
|
||||
name: SNATTargetName,
|
||||
networkProtocol: st.NetworkProtocol,
|
||||
}
|
||||
}
|
||||
|
||||
type standardTargetMaker struct {
|
||||
NetworkProtocol tcpip.NetworkProtocolNumber
|
||||
}
|
||||
|
@ -341,7 +364,7 @@ type nfNATTarget struct {
|
|||
Range linux.NFNATRange
|
||||
}
|
||||
|
||||
const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange
|
||||
const nfNATMarshalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange
|
||||
|
||||
type nfNATTargetMaker struct {
|
||||
NetworkProtocol tcpip.NetworkProtocolNumber
|
||||
|
@ -358,7 +381,7 @@ func (*nfNATTargetMaker) marshal(target target) []byte {
|
|||
rt := target.(*redirectTarget)
|
||||
nt := nfNATTarget{
|
||||
Target: linux.XTEntryTarget{
|
||||
TargetSize: nfNATMarhsalledSize,
|
||||
TargetSize: nfNATMarshalledSize,
|
||||
},
|
||||
Range: linux.NFNATRange{
|
||||
Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED,
|
||||
|
@ -371,12 +394,12 @@ func (*nfNATTargetMaker) marshal(target target) []byte {
|
|||
nt.Range.MinProto = htons(rt.Port)
|
||||
nt.Range.MaxProto = nt.Range.MinProto
|
||||
|
||||
ret := make([]byte, 0, nfNATMarhsalledSize)
|
||||
ret := make([]byte, 0, nfNATMarshalledSize)
|
||||
return binary.Marshal(ret, hostarch.ByteOrder, nt)
|
||||
}
|
||||
|
||||
func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
|
||||
if size := nfNATMarhsalledSize; len(buf) < size {
|
||||
if size := nfNATMarshalledSize; len(buf) < size {
|
||||
nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size)
|
||||
return nil, syserr.ErrInvalidArgument
|
||||
}
|
||||
|
@ -387,7 +410,7 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
|
|||
}
|
||||
|
||||
var natRange linux.NFNATRange
|
||||
buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize]
|
||||
buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
|
||||
binary.Unmarshal(buf, hostarch.ByteOrder, &natRange)
|
||||
|
||||
// We don't support port or address ranges.
|
||||
|
@ -418,6 +441,161 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
|
|||
return &target, nil
|
||||
}
|
||||
|
||||
type snatTargetMakerV4 struct {
|
||||
NetworkProtocol tcpip.NetworkProtocolNumber
|
||||
}
|
||||
|
||||
func (st *snatTargetMakerV4) id() targetID {
|
||||
return targetID{
|
||||
name: SNATTargetName,
|
||||
networkProtocol: st.NetworkProtocol,
|
||||
}
|
||||
}
|
||||
|
||||
func (*snatTargetMakerV4) marshal(target target) []byte {
|
||||
st := target.(*snatTarget)
|
||||
// This is a snat target named snat.
|
||||
xt := linux.XTSNATTarget{
|
||||
Target: linux.XTEntryTarget{
|
||||
TargetSize: linux.SizeOfXTSNATTarget,
|
||||
},
|
||||
}
|
||||
copy(xt.Target.Name[:], SNATTargetName)
|
||||
|
||||
xt.NfRange.RangeSize = 1
|
||||
xt.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED
|
||||
xt.NfRange.RangeIPV4.MinPort = htons(st.Port)
|
||||
xt.NfRange.RangeIPV4.MaxPort = xt.NfRange.RangeIPV4.MinPort
|
||||
copy(xt.NfRange.RangeIPV4.MinIP[:], st.Addr)
|
||||
copy(xt.NfRange.RangeIPV4.MaxIP[:], st.Addr)
|
||||
ret := make([]byte, 0, linux.SizeOfXTSNATTarget)
|
||||
return binary.Marshal(ret, hostarch.ByteOrder, xt)
|
||||
}
|
||||
|
||||
func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
|
||||
if len(buf) < linux.SizeOfXTSNATTarget {
|
||||
nflog("snatTargetMakerV4: buf has insufficient size for snat target %d", len(buf))
|
||||
return nil, syserr.ErrInvalidArgument
|
||||
}
|
||||
|
||||
if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
|
||||
nflog("snatTargetMakerV4: bad proto %d", p)
|
||||
return nil, syserr.ErrInvalidArgument
|
||||
}
|
||||
|
||||
var st linux.XTSNATTarget
|
||||
buf = buf[:linux.SizeOfXTSNATTarget]
|
||||
binary.Unmarshal(buf, hostarch.ByteOrder, &st)
|
||||
|
||||
// Copy linux.XTSNATTarget to stack.SNATTarget.
|
||||
target := snatTarget{SNATTarget: stack.SNATTarget{
|
||||
NetworkProtocol: filter.NetworkProtocol(),
|
||||
}}
|
||||
|
||||
// RangeSize should be 1.
|
||||
nfRange := st.NfRange
|
||||
if nfRange.RangeSize != 1 {
|
||||
nflog("snatTargetMakerV4: bad rangesize %d", nfRange.RangeSize)
|
||||
return nil, syserr.ErrInvalidArgument
|
||||
}
|
||||
|
||||
// TODO(gvisor.dev/issue/5772): If the rule doesn't specify the source port,
|
||||
// choose one automatically.
|
||||
if nfRange.RangeIPV4.MinPort == 0 {
|
||||
nflog("snatTargetMakerV4: snat target needs to specify a non-zero port")
|
||||
return nil, syserr.ErrInvalidArgument
|
||||
}
|
||||
|
||||
// TODO(gvisor.dev/issue/170): Port range is not supported yet.
|
||||
if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
|
||||
nflog("snatTargetMakerV4: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
|
||||
return nil, syserr.ErrInvalidArgument
|
||||
}
|
||||
if nfRange.RangeIPV4.MinIP != nfRange.RangeIPV4.MaxIP {
|
||||
nflog("snatTargetMakerV4: MinIP != MaxIP (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
|
||||
return nil, syserr.ErrInvalidArgument
|
||||
}
|
||||
|
||||
target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
|
||||
target.Port = ntohs(nfRange.RangeIPV4.MinPort)
|
||||
|
||||
return &target, nil
|
||||
}
|
||||
|
||||
type snatTargetMakerV6 struct {
|
||||
NetworkProtocol tcpip.NetworkProtocolNumber
|
||||
}
|
||||
|
||||
func (st *snatTargetMakerV6) id() targetID {
|
||||
return targetID{
|
||||
name: SNATTargetName,
|
||||
networkProtocol: st.NetworkProtocol,
|
||||
revision: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (*snatTargetMakerV6) marshal(target target) []byte {
|
||||
st := target.(*snatTarget)
|
||||
nt := nfNATTarget{
|
||||
Target: linux.XTEntryTarget{
|
||||
TargetSize: nfNATMarshalledSize,
|
||||
},
|
||||
Range: linux.NFNATRange{
|
||||
Flags: linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED,
|
||||
},
|
||||
}
|
||||
copy(nt.Target.Name[:], SNATTargetName)
|
||||
copy(nt.Range.MinAddr[:], st.Addr)
|
||||
copy(nt.Range.MaxAddr[:], st.Addr)
|
||||
nt.Range.MinProto = htons(st.Port)
|
||||
nt.Range.MaxProto = nt.Range.MinProto
|
||||
|
||||
ret := make([]byte, 0, nfNATMarshalledSize)
|
||||
return binary.Marshal(ret, hostarch.ByteOrder, nt)
|
||||
}
|
||||
|
||||
func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
|
||||
if size := nfNATMarshalledSize; len(buf) < size {
|
||||
nflog("snatTargetMakerV6: buf has insufficient size (%d) for SNAT V6 target (%d)", len(buf), size)
|
||||
return nil, syserr.ErrInvalidArgument
|
||||
}
|
||||
|
||||
if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
|
||||
nflog("snatTargetMakerV6: bad proto %d", p)
|
||||
return nil, syserr.ErrInvalidArgument
|
||||
}
|
||||
|
||||
var natRange linux.NFNATRange
|
||||
buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
|
||||
binary.Unmarshal(buf, hostarch.ByteOrder, &natRange)
|
||||
|
||||
// TODO(gvisor.dev/issue/5689): Support port or address ranges.
|
||||
if natRange.MinAddr != natRange.MaxAddr {
|
||||
nflog("snatTargetMakerV6: MinAddr and MaxAddr are different")
|
||||
return nil, syserr.ErrInvalidArgument
|
||||
}
|
||||
if natRange.MinProto != natRange.MaxProto {
|
||||
nflog("snatTargetMakerV6: MinProto and MaxProto are different")
|
||||
return nil, syserr.ErrInvalidArgument
|
||||
}
|
||||
|
||||
// TODO(gvisor.dev/issue/5698): Support other NF_NAT_RANGE flags.
|
||||
if natRange.Flags != linux.NF_NAT_RANGE_MAP_IPS|linux.NF_NAT_RANGE_PROTO_SPECIFIED {
|
||||
nflog("snatTargetMakerV6: invalid range flags %d", natRange.Flags)
|
||||
return nil, syserr.ErrInvalidArgument
|
||||
}
|
||||
|
||||
target := snatTarget{
|
||||
SNATTarget: stack.SNATTarget{
|
||||
NetworkProtocol: filter.NetworkProtocol(),
|
||||
Addr: tcpip.Address(natRange.MinAddr[:]),
|
||||
Port: ntohs(natRange.MinProto),
|
||||
},
|
||||
}
|
||||
|
||||
return &target, nil
|
||||
}
|
||||
|
||||
// translateToStandardTarget translates from the value in a
|
||||
// linux.XTStandardTarget to an stack.Verdict.
|
||||
func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) {
|
||||
|
|
|
@ -21,53 +21,56 @@ import "gvisor.dev/gvisor/pkg/tcpip"
|
|||
// MultiCounterIPStats holds IP statistics, each counter may have several
|
||||
// versions.
|
||||
type MultiCounterIPStats struct {
|
||||
// PacketsReceived is the total number of IP packets received from the link
|
||||
// layer.
|
||||
// PacketsReceived is the number of IP packets received from the link layer.
|
||||
PacketsReceived tcpip.MultiCounterStat
|
||||
|
||||
// DisabledPacketsReceived is the total number of IP packets received from the
|
||||
// link layer when the IP layer is disabled.
|
||||
// DisabledPacketsReceived is the number of IP packets received from the link
|
||||
// layer when the IP layer is disabled.
|
||||
DisabledPacketsReceived tcpip.MultiCounterStat
|
||||
|
||||
// InvalidDestinationAddressesReceived is the total number of IP packets
|
||||
// received with an unknown or invalid destination address.
|
||||
// InvalidDestinationAddressesReceived is the number of IP packets received
|
||||
// with an unknown or invalid destination address.
|
||||
InvalidDestinationAddressesReceived tcpip.MultiCounterStat
|
||||
|
||||
// InvalidSourceAddressesReceived is the total number of IP packets received
|
||||
// with a source address that should never have been received on the wire.
|
||||
// InvalidSourceAddressesReceived is the number of IP packets received with a
|
||||
// source address that should never have been received on the wire.
|
||||
InvalidSourceAddressesReceived tcpip.MultiCounterStat
|
||||
|
||||
// PacketsDelivered is the total number of incoming IP packets that are
|
||||
// successfully delivered to the transport layer.
|
||||
// PacketsDelivered is the number of incoming IP packets that are successfully
|
||||
// delivered to the transport layer.
|
||||
PacketsDelivered tcpip.MultiCounterStat
|
||||
|
||||
// PacketsSent is the total number of IP packets sent via WritePacket.
|
||||
// PacketsSent is the number of IP packets sent via WritePacket.
|
||||
PacketsSent tcpip.MultiCounterStat
|
||||
|
||||
// OutgoingPacketErrors is the total number of IP packets which failed to
|
||||
// write to a link-layer endpoint.
|
||||
// OutgoingPacketErrors is the number of IP packets which failed to write to a
|
||||
// link-layer endpoint.
|
||||
OutgoingPacketErrors tcpip.MultiCounterStat
|
||||
|
||||
// MalformedPacketsReceived is the total number of IP Packets that were
|
||||
// dropped due to the IP packet header failing validation checks.
|
||||
// MalformedPacketsReceived is the number of IP Packets that were dropped due
|
||||
// to the IP packet header failing validation checks.
|
||||
MalformedPacketsReceived tcpip.MultiCounterStat
|
||||
|
||||
// MalformedFragmentsReceived is the total number of IP Fragments that were
|
||||
// dropped due to the fragment failing validation checks.
|
||||
// MalformedFragmentsReceived is the number of IP Fragments that were dropped
|
||||
// due to the fragment failing validation checks.
|
||||
MalformedFragmentsReceived tcpip.MultiCounterStat
|
||||
|
||||
// IPTablesPreroutingDropped is the total number of IP packets dropped in the
|
||||
// IPTablesPreroutingDropped is the number of IP packets dropped in the
|
||||
// Prerouting chain.
|
||||
IPTablesPreroutingDropped tcpip.MultiCounterStat
|
||||
|
||||
// IPTablesInputDropped is the total number of IP packets dropped in the Input
|
||||
// IPTablesInputDropped is the number of IP packets dropped in the Input
|
||||
// chain.
|
||||
IPTablesInputDropped tcpip.MultiCounterStat
|
||||
|
||||
// IPTablesOutputDropped is the total number of IP packets dropped in the
|
||||
// Output chain.
|
||||
// IPTablesOutputDropped is the number of IP packets dropped in the Output
|
||||
// chain.
|
||||
IPTablesOutputDropped tcpip.MultiCounterStat
|
||||
|
||||
// IPTablesPostroutingDropped is the number of IP packets dropped in the
|
||||
// Postrouting chain.
|
||||
IPTablesPostroutingDropped tcpip.MultiCounterStat
|
||||
|
||||
// TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
|
||||
// of IPStats.
|
||||
|
||||
|
@ -98,6 +101,7 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
|
|||
m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped)
|
||||
m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped)
|
||||
m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped)
|
||||
m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped)
|
||||
m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived)
|
||||
m.OptionRecordRouteReceived.Init(a.OptionRecordRouteReceived, b.OptionRecordRouteReceived)
|
||||
m.OptionRouterAlertReceived.Init(a.OptionRouterAlertReceived, b.OptionRouterAlertReceived)
|
||||
|
|
|
@ -415,6 +415,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
|
|||
return nil
|
||||
}
|
||||
|
||||
// Postrouting NAT can only change the source address, and does not alter the
|
||||
// route or outgoing interface of the packet.
|
||||
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
|
||||
if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
|
||||
// iptables is telling us to drop the packet.
|
||||
e.stats.ip.IPTablesPostroutingDropped.Increment()
|
||||
return nil
|
||||
}
|
||||
|
||||
stats := e.stats.ip
|
||||
|
||||
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
|
||||
|
@ -486,9 +495,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
|
|||
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
|
||||
// iptables filtering. All packets that reach here are locally
|
||||
// generated.
|
||||
dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName)
|
||||
stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
|
||||
for pkt := range dropped {
|
||||
outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
|
||||
stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
|
||||
for pkt := range outputDropped {
|
||||
pkts.Remove(pkt)
|
||||
}
|
||||
|
||||
|
@ -510,6 +519,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
|
|||
|
||||
}
|
||||
|
||||
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
|
||||
// change the source address, and does not alter the route or outgoing
|
||||
// interface of the packet.
|
||||
postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, gso, r, "" /* inNicName */, outNicName)
|
||||
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
|
||||
for pkt := range postroutingDropped {
|
||||
pkts.Remove(pkt)
|
||||
}
|
||||
|
||||
// The rest of the packets can be delivered to the NIC as a batch.
|
||||
pktsLen := pkts.Len()
|
||||
written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
|
||||
|
@ -517,7 +535,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
|
|||
stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
|
||||
|
||||
// Dropped packets aren't errors, so include them in the return value.
|
||||
return locallyDelivered + written + len(dropped), err
|
||||
return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err
|
||||
}
|
||||
|
||||
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
|
||||
|
|
|
@ -2612,34 +2612,36 @@ func TestWriteStats(t *testing.T) {
|
|||
const nPackets = 3
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*testing.T, *stack.Stack)
|
||||
allowPackets int
|
||||
expectSent int
|
||||
expectDropped int
|
||||
expectWritten int
|
||||
name string
|
||||
setup func(*testing.T, *stack.Stack)
|
||||
allowPackets int
|
||||
expectSent int
|
||||
expectOutputDropped int
|
||||
expectPostroutingDropped int
|
||||
expectWritten int
|
||||
}{
|
||||
{
|
||||
name: "Accept all",
|
||||
// No setup needed, tables accept everything by default.
|
||||
setup: func(*testing.T, *stack.Stack) {},
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: nPackets,
|
||||
expectDropped: 0,
|
||||
expectWritten: nPackets,
|
||||
setup: func(*testing.T, *stack.Stack) {},
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: nPackets,
|
||||
expectOutputDropped: 0,
|
||||
expectPostroutingDropped: 0,
|
||||
expectWritten: nPackets,
|
||||
}, {
|
||||
name: "Accept all with error",
|
||||
// No setup needed, tables accept everything by default.
|
||||
setup: func(*testing.T, *stack.Stack) {},
|
||||
allowPackets: nPackets - 1,
|
||||
expectSent: nPackets - 1,
|
||||
expectDropped: 0,
|
||||
expectWritten: nPackets - 1,
|
||||
setup: func(*testing.T, *stack.Stack) {},
|
||||
allowPackets: nPackets - 1,
|
||||
expectSent: nPackets - 1,
|
||||
expectOutputDropped: 0,
|
||||
expectPostroutingDropped: 0,
|
||||
expectWritten: nPackets - 1,
|
||||
}, {
|
||||
name: "Drop all",
|
||||
name: "Drop all with Output chain",
|
||||
setup: func(t *testing.T, stk *stack.Stack) {
|
||||
// Install Output DROP rule.
|
||||
t.Helper()
|
||||
ipt := stk.IPTables()
|
||||
filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
|
||||
ruleIdx := filter.BuiltinChains[stack.Output]
|
||||
|
@ -2648,16 +2650,32 @@ func TestWriteStats(t *testing.T) {
|
|||
t.Fatalf("failed to replace table: %s", err)
|
||||
}
|
||||
},
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: 0,
|
||||
expectDropped: nPackets,
|
||||
expectWritten: nPackets,
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: 0,
|
||||
expectOutputDropped: nPackets,
|
||||
expectPostroutingDropped: 0,
|
||||
expectWritten: nPackets,
|
||||
}, {
|
||||
name: "Drop some",
|
||||
name: "Drop all with Postrouting chain",
|
||||
setup: func(t *testing.T, stk *stack.Stack) {
|
||||
ipt := stk.IPTables()
|
||||
filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
|
||||
ruleIdx := filter.BuiltinChains[stack.Postrouting]
|
||||
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
|
||||
if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
|
||||
t.Fatalf("failed to replace table: %s", err)
|
||||
}
|
||||
},
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: 0,
|
||||
expectOutputDropped: 0,
|
||||
expectPostroutingDropped: nPackets,
|
||||
expectWritten: nPackets,
|
||||
}, {
|
||||
name: "Drop some with Output chain",
|
||||
setup: func(t *testing.T, stk *stack.Stack) {
|
||||
// Install Output DROP rule that matches only 1
|
||||
// of the 3 packets.
|
||||
t.Helper()
|
||||
ipt := stk.IPTables()
|
||||
filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
|
||||
// We'll match and DROP the last packet.
|
||||
|
@ -2670,10 +2688,33 @@ func TestWriteStats(t *testing.T) {
|
|||
t.Fatalf("failed to replace table: %s", err)
|
||||
}
|
||||
},
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: nPackets - 1,
|
||||
expectDropped: 1,
|
||||
expectWritten: nPackets,
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: nPackets - 1,
|
||||
expectOutputDropped: 1,
|
||||
expectPostroutingDropped: 0,
|
||||
expectWritten: nPackets,
|
||||
}, {
|
||||
name: "Drop some with Postrouting chain",
|
||||
setup: func(t *testing.T, stk *stack.Stack) {
|
||||
// Install Postrouting DROP rule that matches only 1
|
||||
// of the 3 packets.
|
||||
ipt := stk.IPTables()
|
||||
filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
|
||||
// We'll match and DROP the last packet.
|
||||
ruleIdx := filter.BuiltinChains[stack.Postrouting]
|
||||
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
|
||||
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
|
||||
// Make sure the next rule is ACCEPT.
|
||||
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
|
||||
if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
|
||||
t.Fatalf("failed to replace table: %s", err)
|
||||
}
|
||||
},
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: nPackets - 1,
|
||||
expectOutputDropped: 0,
|
||||
expectPostroutingDropped: 1,
|
||||
expectWritten: nPackets,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -2724,13 +2765,16 @@ func TestWriteStats(t *testing.T) {
|
|||
nWritten, _ := writer.writePackets(rt, pkts)
|
||||
|
||||
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
|
||||
t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
|
||||
t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent)
|
||||
}
|
||||
if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
|
||||
t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
|
||||
if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped {
|
||||
t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped)
|
||||
}
|
||||
if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped {
|
||||
t.Errorf("got rt.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped)
|
||||
}
|
||||
if nWritten != test.expectWritten {
|
||||
t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
|
||||
t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -769,6 +769,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
|
|||
return nil
|
||||
}
|
||||
|
||||
// Postrouting NAT can only change the source address, and does not alter the
|
||||
// route or outgoing interface of the packet.
|
||||
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
|
||||
if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
|
||||
// iptables is telling us to drop the packet.
|
||||
e.stats.ip.IPTablesPostroutingDropped.Increment()
|
||||
return nil
|
||||
}
|
||||
|
||||
stats := e.stats.ip
|
||||
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
|
||||
if err != nil {
|
||||
|
@ -840,9 +849,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
|
|||
// iptables filtering. All packets that reach here are locally
|
||||
// generated.
|
||||
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
|
||||
dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
|
||||
stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
|
||||
for pkt := range dropped {
|
||||
outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
|
||||
stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
|
||||
for pkt := range outputDropped {
|
||||
pkts.Remove(pkt)
|
||||
}
|
||||
|
||||
|
@ -863,6 +872,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
|
|||
locallyDelivered++
|
||||
}
|
||||
|
||||
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
|
||||
// change the source address, and does not alter the route or outgoing
|
||||
// interface of the packet.
|
||||
postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, gso, r, "" /* inNicName */, outNicName)
|
||||
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
|
||||
for pkt := range postroutingDropped {
|
||||
pkts.Remove(pkt)
|
||||
}
|
||||
|
||||
// The rest of the packets can be delivered to the NIC as a batch.
|
||||
pktsLen := pkts.Len()
|
||||
written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
|
||||
|
@ -870,7 +888,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
|
|||
stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
|
||||
|
||||
// Dropped packets aren't errors, so include them in the return value.
|
||||
return locallyDelivered + written + len(dropped), err
|
||||
return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err
|
||||
}
|
||||
|
||||
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.
|
||||
|
|
|
@ -2468,34 +2468,36 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
|
|||
func TestWriteStats(t *testing.T) {
|
||||
const nPackets = 3
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(*testing.T, *stack.Stack)
|
||||
allowPackets int
|
||||
expectSent int
|
||||
expectDropped int
|
||||
expectWritten int
|
||||
name string
|
||||
setup func(*testing.T, *stack.Stack)
|
||||
allowPackets int
|
||||
expectSent int
|
||||
expectOutputDropped int
|
||||
expectPostroutingDropped int
|
||||
expectWritten int
|
||||
}{
|
||||
{
|
||||
name: "Accept all",
|
||||
// No setup needed, tables accept everything by default.
|
||||
setup: func(*testing.T, *stack.Stack) {},
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: nPackets,
|
||||
expectDropped: 0,
|
||||
expectWritten: nPackets,
|
||||
setup: func(*testing.T, *stack.Stack) {},
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: nPackets,
|
||||
expectOutputDropped: 0,
|
||||
expectPostroutingDropped: 0,
|
||||
expectWritten: nPackets,
|
||||
}, {
|
||||
name: "Accept all with error",
|
||||
// No setup needed, tables accept everything by default.
|
||||
setup: func(*testing.T, *stack.Stack) {},
|
||||
allowPackets: nPackets - 1,
|
||||
expectSent: nPackets - 1,
|
||||
expectDropped: 0,
|
||||
expectWritten: nPackets - 1,
|
||||
setup: func(*testing.T, *stack.Stack) {},
|
||||
allowPackets: nPackets - 1,
|
||||
expectSent: nPackets - 1,
|
||||
expectOutputDropped: 0,
|
||||
expectPostroutingDropped: 0,
|
||||
expectWritten: nPackets - 1,
|
||||
}, {
|
||||
name: "Drop all",
|
||||
name: "Drop all with Output chain",
|
||||
setup: func(t *testing.T, stk *stack.Stack) {
|
||||
// Install Output DROP rule.
|
||||
t.Helper()
|
||||
ipt := stk.IPTables()
|
||||
filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
|
||||
ruleIdx := filter.BuiltinChains[stack.Output]
|
||||
|
@ -2504,16 +2506,33 @@ func TestWriteStats(t *testing.T) {
|
|||
t.Fatalf("failed to replace table: %v", err)
|
||||
}
|
||||
},
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: 0,
|
||||
expectDropped: nPackets,
|
||||
expectWritten: nPackets,
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: 0,
|
||||
expectOutputDropped: nPackets,
|
||||
expectPostroutingDropped: 0,
|
||||
expectWritten: nPackets,
|
||||
}, {
|
||||
name: "Drop some",
|
||||
name: "Drop all with Postrouting chain",
|
||||
setup: func(t *testing.T, stk *stack.Stack) {
|
||||
// Install Output DROP rule.
|
||||
ipt := stk.IPTables()
|
||||
filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
|
||||
ruleIdx := filter.BuiltinChains[stack.Postrouting]
|
||||
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
|
||||
if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
|
||||
t.Fatalf("failed to replace table: %v", err)
|
||||
}
|
||||
},
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: 0,
|
||||
expectOutputDropped: 0,
|
||||
expectPostroutingDropped: nPackets,
|
||||
expectWritten: nPackets,
|
||||
}, {
|
||||
name: "Drop some with Output chain",
|
||||
setup: func(t *testing.T, stk *stack.Stack) {
|
||||
// Install Output DROP rule that matches only 1
|
||||
// of the 3 packets.
|
||||
t.Helper()
|
||||
ipt := stk.IPTables()
|
||||
filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
|
||||
// We'll match and DROP the last packet.
|
||||
|
@ -2526,10 +2545,33 @@ func TestWriteStats(t *testing.T) {
|
|||
t.Fatalf("failed to replace table: %v", err)
|
||||
}
|
||||
},
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: nPackets - 1,
|
||||
expectDropped: 1,
|
||||
expectWritten: nPackets,
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: nPackets - 1,
|
||||
expectOutputDropped: 1,
|
||||
expectPostroutingDropped: 0,
|
||||
expectWritten: nPackets,
|
||||
}, {
|
||||
name: "Drop some with Postrouting chain",
|
||||
setup: func(t *testing.T, stk *stack.Stack) {
|
||||
// Install Postrouting DROP rule that matches only 1
|
||||
// of the 3 packets.
|
||||
ipt := stk.IPTables()
|
||||
filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
|
||||
// We'll match and DROP the last packet.
|
||||
ruleIdx := filter.BuiltinChains[stack.Postrouting]
|
||||
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
|
||||
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
|
||||
// Make sure the next rule is ACCEPT.
|
||||
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
|
||||
if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
|
||||
t.Fatalf("failed to replace table: %v", err)
|
||||
}
|
||||
},
|
||||
allowPackets: math.MaxInt32,
|
||||
expectSent: nPackets - 1,
|
||||
expectOutputDropped: 0,
|
||||
expectPostroutingDropped: 1,
|
||||
expectWritten: nPackets,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -2578,13 +2620,16 @@ func TestWriteStats(t *testing.T) {
|
|||
nWritten, _ := writer.writePackets(rt, pkts)
|
||||
|
||||
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
|
||||
t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
|
||||
t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent)
|
||||
}
|
||||
if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
|
||||
t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
|
||||
if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped {
|
||||
t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped)
|
||||
}
|
||||
if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped {
|
||||
t.Errorf("got r.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped)
|
||||
}
|
||||
if nWritten != test.expectWritten {
|
||||
t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
|
||||
t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -45,6 +45,7 @@ go_library(
|
|||
"addressable_endpoint_state.go",
|
||||
"conntrack.go",
|
||||
"headertype_string.go",
|
||||
"hook_string.go",
|
||||
"icmp_rate_limit.go",
|
||||
"iptables.go",
|
||||
"iptables_state.go",
|
||||
|
|
|
@ -16,6 +16,7 @@ package stack
|
|||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
@ -29,7 +30,7 @@ import (
|
|||
// The connection is created for a packet if it does not exist. Every
|
||||
// connection contains two tuples (original and reply). The tuples are
|
||||
// manipulated if there is a matching NAT rule. The packet is modified by
|
||||
// looking at the tuples in the Prerouting and Output hooks.
|
||||
// looking at the tuples in each hook.
|
||||
//
|
||||
// Currently, only TCP tracking is supported.
|
||||
|
||||
|
@ -46,12 +47,14 @@ const (
|
|||
)
|
||||
|
||||
// Manipulation type for the connection.
|
||||
// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and
|
||||
// DNAT at the same time.
|
||||
type manipType int
|
||||
|
||||
const (
|
||||
manipNone manipType = iota
|
||||
manipDstPrerouting
|
||||
manipDstOutput
|
||||
manipSource
|
||||
manipDestination
|
||||
)
|
||||
|
||||
// tuple holds a connection's identifying and manipulating data in one
|
||||
|
@ -108,6 +111,7 @@ type conn struct {
|
|||
reply tuple
|
||||
|
||||
// manip indicates if the packet should be manipulated. It is immutable.
|
||||
// TODO(gvisor.dev/issue/5696): Support updating manipulation type.
|
||||
manip manipType
|
||||
|
||||
// tcbHook indicates if the packet is inbound or outbound to
|
||||
|
@ -124,6 +128,18 @@ type conn struct {
|
|||
lastUsed time.Time `state:".(unixTime)"`
|
||||
}
|
||||
|
||||
// newConn creates new connection.
|
||||
func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
|
||||
conn := conn{
|
||||
manip: manip,
|
||||
tcbHook: hook,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
conn.original = tuple{conn: &conn, tupleID: orig}
|
||||
conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
|
||||
return &conn
|
||||
}
|
||||
|
||||
// timedOut returns whether the connection timed out based on its state.
|
||||
func (cn *conn) timedOut(now time.Time) bool {
|
||||
const establishedTimeout = 5 * 24 * time.Hour
|
||||
|
@ -219,18 +235,6 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
// newConn creates new connection.
|
||||
func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
|
||||
conn := conn{
|
||||
manip: manip,
|
||||
tcbHook: hook,
|
||||
lastUsed: time.Now(),
|
||||
}
|
||||
conn.original = tuple{conn: &conn, tupleID: orig}
|
||||
conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
|
||||
return &conn
|
||||
}
|
||||
|
||||
func (ct *ConnTrack) init() {
|
||||
ct.mu.Lock()
|
||||
defer ct.mu.Unlock()
|
||||
|
@ -284,20 +288,41 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint1
|
|||
return nil
|
||||
}
|
||||
|
||||
// Create a new connection and change the port as per the iptables
|
||||
// rule. This tuple will be used to manipulate the packet in
|
||||
// handlePacket.
|
||||
replyTID := tid.reply()
|
||||
replyTID.srcAddr = address
|
||||
replyTID.srcPort = port
|
||||
var manip manipType
|
||||
switch hook {
|
||||
case Prerouting:
|
||||
manip = manipDstPrerouting
|
||||
case Output:
|
||||
manip = manipDstOutput
|
||||
|
||||
conn, _ := ct.connForTID(tid)
|
||||
if conn != nil {
|
||||
// The connection is already tracked.
|
||||
// TODO(gvisor.dev/issue/5696): Support updating an existing connection.
|
||||
return nil
|
||||
}
|
||||
conn := newConn(tid, replyTID, manip, hook)
|
||||
conn = newConn(tid, replyTID, manipDestination, hook)
|
||||
ct.insertConn(conn)
|
||||
return conn
|
||||
}
|
||||
|
||||
func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
|
||||
tid, err := packetToTupleID(pkt)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if hook != Input && hook != Postrouting {
|
||||
return nil
|
||||
}
|
||||
|
||||
replyTID := tid.reply()
|
||||
replyTID.dstAddr = address
|
||||
replyTID.dstPort = port
|
||||
|
||||
conn, _ := ct.connForTID(tid)
|
||||
if conn != nil {
|
||||
// The connection is already tracked.
|
||||
// TODO(gvisor.dev/issue/5696): Support updating an existing connection.
|
||||
return nil
|
||||
}
|
||||
conn = newConn(tid, replyTID, manipSource, hook)
|
||||
ct.insertConn(conn)
|
||||
return conn
|
||||
}
|
||||
|
@ -322,6 +347,7 @@ func (ct *ConnTrack) insertConn(conn *conn) {
|
|||
|
||||
// Now that we hold the locks, ensure the tuple hasn't been inserted by
|
||||
// another thread.
|
||||
// TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too?
|
||||
alreadyInserted := false
|
||||
for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
|
||||
if other.tupleID == conn.original.tupleID {
|
||||
|
@ -343,86 +369,6 @@ func (ct *ConnTrack) insertConn(conn *conn) {
|
|||
}
|
||||
}
|
||||
|
||||
// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
|
||||
// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
|
||||
func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
|
||||
// If this is a noop entry, don't do anything.
|
||||
if conn.manip == manipNone {
|
||||
return
|
||||
}
|
||||
|
||||
netHeader := pkt.Network()
|
||||
tcpHeader := header.TCP(pkt.TransportHeader().View())
|
||||
|
||||
// For prerouting redirection, packets going in the original direction
|
||||
// have their destinations modified and replies have their sources
|
||||
// modified.
|
||||
switch dir {
|
||||
case dirOriginal:
|
||||
port := conn.reply.srcPort
|
||||
tcpHeader.SetDestinationPort(port)
|
||||
netHeader.SetDestinationAddress(conn.reply.srcAddr)
|
||||
case dirReply:
|
||||
port := conn.original.dstPort
|
||||
tcpHeader.SetSourcePort(port)
|
||||
netHeader.SetSourceAddress(conn.original.dstAddr)
|
||||
}
|
||||
|
||||
// TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
|
||||
// on inbound packets, so we don't recalculate them. However, we should
|
||||
// support cases when they are validated, e.g. when we can't offload
|
||||
// receive checksumming.
|
||||
|
||||
// After modification, IPv4 packets need a valid checksum.
|
||||
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
|
||||
netHeader := header.IPv4(pkt.NetworkHeader().View())
|
||||
netHeader.SetChecksum(0)
|
||||
netHeader.SetChecksum(^netHeader.CalculateChecksum())
|
||||
}
|
||||
}
|
||||
|
||||
// handlePacketOutput manipulates ports for packets in Output hook.
|
||||
func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
|
||||
// If this is a noop entry, don't do anything.
|
||||
if conn.manip == manipNone {
|
||||
return
|
||||
}
|
||||
|
||||
netHeader := pkt.Network()
|
||||
tcpHeader := header.TCP(pkt.TransportHeader().View())
|
||||
|
||||
// For output redirection, packets going in the original direction
|
||||
// have their destinations modified and replies have their sources
|
||||
// modified. For prerouting redirection, we only reach this point
|
||||
// when replying, so packet sources are modified.
|
||||
if conn.manip == manipDstOutput && dir == dirOriginal {
|
||||
port := conn.reply.srcPort
|
||||
tcpHeader.SetDestinationPort(port)
|
||||
netHeader.SetDestinationAddress(conn.reply.srcAddr)
|
||||
} else {
|
||||
port := conn.original.dstPort
|
||||
tcpHeader.SetSourcePort(port)
|
||||
netHeader.SetSourceAddress(conn.original.dstAddr)
|
||||
}
|
||||
|
||||
// Calculate the TCP checksum and set it.
|
||||
tcpHeader.SetChecksum(0)
|
||||
length := uint16(len(tcpHeader) + pkt.Data().Size())
|
||||
xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
|
||||
if gso != nil && gso.NeedsCsum {
|
||||
tcpHeader.SetChecksum(xsum)
|
||||
} else if r.RequiresTXTransportChecksum() {
|
||||
xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
|
||||
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
|
||||
}
|
||||
|
||||
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
|
||||
netHeader := header.IPv4(pkt.NetworkHeader().View())
|
||||
netHeader.SetChecksum(0)
|
||||
netHeader.SetChecksum(^netHeader.CalculateChecksum())
|
||||
}
|
||||
}
|
||||
|
||||
// handlePacket will manipulate the port and address of the packet if the
|
||||
// connection exists. Returns whether, after the packet traverses the tables,
|
||||
// it should create a new entry in the table.
|
||||
|
@ -431,7 +377,9 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
|
|||
return false
|
||||
}
|
||||
|
||||
if hook != Prerouting && hook != Output {
|
||||
switch hook {
|
||||
case Prerouting, Input, Output, Postrouting:
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -441,23 +389,79 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
|
|||
}
|
||||
|
||||
conn, dir := ct.connFor(pkt)
|
||||
// Connection or Rule not found for the packet.
|
||||
// Connection not found for the packet.
|
||||
if conn == nil {
|
||||
return true
|
||||
// If this is the last hook in the data path for this packet (Input if
|
||||
// incoming, Postrouting if outgoing), indicate that a connection should be
|
||||
// inserted by the end of this hook.
|
||||
return hook == Input || hook == Postrouting
|
||||
}
|
||||
|
||||
netHeader := pkt.Network()
|
||||
tcpHeader := header.TCP(pkt.TransportHeader().View())
|
||||
if len(tcpHeader) < header.TCPMinimumSize {
|
||||
return false
|
||||
}
|
||||
|
||||
// TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
|
||||
// validated if checksum offloading is off. It may require IP defrag if the
|
||||
// packets are fragmented.
|
||||
|
||||
switch hook {
|
||||
case Prerouting:
|
||||
handlePacketPrerouting(pkt, conn, dir)
|
||||
case Output:
|
||||
handlePacketOutput(pkt, conn, gso, r, dir)
|
||||
case Prerouting, Output:
|
||||
if conn.manip == manipDestination {
|
||||
switch dir {
|
||||
case dirOriginal:
|
||||
tcpHeader.SetDestinationPort(conn.reply.srcPort)
|
||||
netHeader.SetDestinationAddress(conn.reply.srcAddr)
|
||||
case dirReply:
|
||||
tcpHeader.SetSourcePort(conn.original.dstPort)
|
||||
netHeader.SetSourceAddress(conn.original.dstAddr)
|
||||
}
|
||||
pkt.NatDone = true
|
||||
}
|
||||
case Input, Postrouting:
|
||||
if conn.manip == manipSource {
|
||||
switch dir {
|
||||
case dirOriginal:
|
||||
tcpHeader.SetSourcePort(conn.reply.dstPort)
|
||||
netHeader.SetSourceAddress(conn.reply.dstAddr)
|
||||
case dirReply:
|
||||
tcpHeader.SetDestinationPort(conn.original.srcPort)
|
||||
netHeader.SetDestinationAddress(conn.original.srcAddr)
|
||||
}
|
||||
pkt.NatDone = true
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("unrecognized hook = %s", hook))
|
||||
}
|
||||
if !pkt.NatDone {
|
||||
return false
|
||||
}
|
||||
|
||||
switch hook {
|
||||
case Prerouting, Input:
|
||||
case Output, Postrouting:
|
||||
// Calculate the TCP checksum and set it.
|
||||
tcpHeader.SetChecksum(0)
|
||||
length := uint16(len(tcpHeader) + pkt.Data().Size())
|
||||
xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
|
||||
if gso != nil && gso.NeedsCsum {
|
||||
tcpHeader.SetChecksum(xsum)
|
||||
} else if r.RequiresTXTransportChecksum() {
|
||||
xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
|
||||
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("unrecognized hook = %s", hook))
|
||||
}
|
||||
|
||||
// After modification, IPv4 packets need a valid checksum.
|
||||
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
|
||||
netHeader := header.IPv4(pkt.NetworkHeader().View())
|
||||
netHeader.SetChecksum(0)
|
||||
netHeader.SetChecksum(^netHeader.CalculateChecksum())
|
||||
}
|
||||
pkt.NatDone = true
|
||||
|
||||
// Update the state of tcb.
|
||||
// TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
|
||||
|
@ -638,8 +642,8 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
|
|||
if conn == nil {
|
||||
// Not a tracked connection.
|
||||
return "", 0, &tcpip.ErrNotConnected{}
|
||||
} else if conn.manip == manipNone {
|
||||
// Unmanipulated connection.
|
||||
} else if conn.manip != manipDestination {
|
||||
// Unmanipulated destination.
|
||||
return "", 0, &tcpip.ErrInvalidOptionValue{}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright 2021 The gVisor Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at //
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Code generated by "stringer -type Hook ."; DO NOT EDIT.
|
||||
|
||||
package stack
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[Prerouting-0]
|
||||
_ = x[Input-1]
|
||||
_ = x[Forward-2]
|
||||
_ = x[Output-3]
|
||||
_ = x[Postrouting-4]
|
||||
_ = x[NumHooks-5]
|
||||
}
|
||||
|
||||
const _Hook_name = "PreroutingInputForwardOutputPostroutingNumHooks"
|
||||
|
||||
var _Hook_index = [...]uint8{0, 10, 15, 22, 28, 39, 47}
|
||||
|
||||
func (i Hook) String() string {
|
||||
if i >= Hook(len(_Hook_index)-1) {
|
||||
return "Hook(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _Hook_name[_Hook_index[i]:_Hook_index[i+1]]
|
||||
}
|
|
@ -175,9 +175,10 @@ func DefaultTables() *IPTables {
|
|||
},
|
||||
},
|
||||
priorities: [NumHooks][]TableID{
|
||||
Prerouting: {MangleID, NATID},
|
||||
Input: {NATID, FilterID},
|
||||
Output: {MangleID, NATID, FilterID},
|
||||
Prerouting: {MangleID, NATID},
|
||||
Input: {NATID, FilterID},
|
||||
Output: {MangleID, NATID, FilterID},
|
||||
Postrouting: {MangleID, NATID},
|
||||
},
|
||||
connections: ConnTrack{
|
||||
seed: generateRandUint32(),
|
||||
|
|
|
@ -182,3 +182,81 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
|
|||
|
||||
return RuleAccept, 0
|
||||
}
|
||||
|
||||
// SNATTarget modifies the source port/IP in the outgoing packets.
|
||||
type SNATTarget struct {
|
||||
Addr tcpip.Address
|
||||
Port uint16
|
||||
|
||||
// NetworkProtocol is the network protocol the target is used with. It
|
||||
// is immutable.
|
||||
NetworkProtocol tcpip.NetworkProtocolNumber
|
||||
}
|
||||
|
||||
// Action implements Target.Action.
|
||||
func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
|
||||
// Sanity check.
|
||||
if st.NetworkProtocol != pkt.NetworkProtocolNumber {
|
||||
panic(fmt.Sprintf(
|
||||
"SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
|
||||
st.NetworkProtocol, pkt.NetworkProtocolNumber))
|
||||
}
|
||||
|
||||
// Packet is already manipulated.
|
||||
if pkt.NatDone {
|
||||
return RuleAccept, 0
|
||||
}
|
||||
|
||||
// Drop the packet if network and transport header are not set.
|
||||
if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
|
||||
return RuleDrop, 0
|
||||
}
|
||||
|
||||
switch hook {
|
||||
case Postrouting, Input:
|
||||
case Prerouting, Output, Forward:
|
||||
panic(fmt.Sprintf("%s not supported", hook))
|
||||
default:
|
||||
panic(fmt.Sprintf("%s unrecognized", hook))
|
||||
}
|
||||
|
||||
switch protocol := pkt.TransportProtocolNumber; protocol {
|
||||
case header.UDPProtocolNumber:
|
||||
udpHeader := header.UDP(pkt.TransportHeader().View())
|
||||
udpHeader.SetChecksum(0)
|
||||
udpHeader.SetSourcePort(st.Port)
|
||||
netHeader := pkt.Network()
|
||||
netHeader.SetSourceAddress(st.Addr)
|
||||
|
||||
// Only calculate the checksum if offloading isn't supported.
|
||||
if r.RequiresTXTransportChecksum() {
|
||||
length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
|
||||
xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
|
||||
xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
|
||||
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
|
||||
}
|
||||
|
||||
// After modification, IPv4 packets need a valid checksum.
|
||||
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
|
||||
netHeader := header.IPv4(pkt.NetworkHeader().View())
|
||||
netHeader.SetChecksum(0)
|
||||
netHeader.SetChecksum(^netHeader.CalculateChecksum())
|
||||
}
|
||||
pkt.NatDone = true
|
||||
case header.TCPProtocolNumber:
|
||||
if ct == nil {
|
||||
return RuleAccept, 0
|
||||
}
|
||||
|
||||
// Set up conection for matching NAT rule. Only the first
|
||||
// packet of the connection comes here. Other packets will be
|
||||
// manipulated in connection tracking.
|
||||
if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil {
|
||||
ct.handlePacket(pkt, hook, gso, r)
|
||||
}
|
||||
default:
|
||||
return RuleDrop, 0
|
||||
}
|
||||
|
||||
return RuleAccept, 0
|
||||
}
|
||||
|
|
|
@ -299,9 +299,18 @@ func (pk *PacketBuffer) Network() header.Network {
|
|||
// See PacketBuffer.Data for details about how a packet buffer holds an inbound
|
||||
// packet.
|
||||
func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
|
||||
return NewPacketBuffer(PacketBufferOptions{
|
||||
newPk := NewPacketBuffer(PacketBufferOptions{
|
||||
Data: buffer.NewVectorisedView(pk.Size(), pk.Views()),
|
||||
})
|
||||
// TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
|
||||
// maintain this flag in the packet. Currently conntrack needs this flag to
|
||||
// tell if a noop connection should be inserted at Input hook. Once conntrack
|
||||
// redefines the manipulation field as mutable, we won't need the special noop
|
||||
// connection.
|
||||
if pk.NatDone {
|
||||
newPk.NatDone = true
|
||||
}
|
||||
return newPk
|
||||
}
|
||||
|
||||
// headerInfo stores metadata about a header in a packet.
|
||||
|
|
|
@ -1556,6 +1556,10 @@ type IPStats struct {
|
|||
// chain.
|
||||
IPTablesOutputDropped *StatCounter
|
||||
|
||||
// IPTablesPostroutingDropped is the number of IP packets dropped in the
|
||||
// Postrouting chain.
|
||||
IPTablesPostroutingDropped *StatCounter
|
||||
|
||||
// TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
|
||||
// of IPStats.
|
||||
// OptionTimestampReceived is the number of Timestamp options seen.
|
||||
|
|
|
@ -456,3 +456,11 @@ func TestNATPreRECVORIGDSTADDR(t *testing.T) {
|
|||
func TestNATOutRECVORIGDSTADDR(t *testing.T) {
|
||||
singleTest(t, &NATOutRECVORIGDSTADDR{})
|
||||
}
|
||||
|
||||
func TestNATPostSNATUDP(t *testing.T) {
|
||||
singleTest(t, &NATPostSNATUDP{})
|
||||
}
|
||||
|
||||
func TestNATPostSNATTCP(t *testing.T) {
|
||||
singleTest(t, &NATPostSNATTCP{})
|
||||
}
|
||||
|
|
|
@ -69,29 +69,41 @@ func tableRules(ipv6 bool, table string, argsList [][]string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// listenUDP listens on a UDP port and returns the value of net.Conn.Read() for
|
||||
// the first read on that port.
|
||||
// listenUDP listens on a UDP port and returns nil if the first read from that
|
||||
// port is successful.
|
||||
func listenUDP(ctx context.Context, port int, ipv6 bool) error {
|
||||
_, err := listenUDPFrom(ctx, port, ipv6)
|
||||
return err
|
||||
}
|
||||
|
||||
// listenUDPFrom listens on a UDP port and returns the sender's UDP address if
|
||||
// the first read from that port is successful.
|
||||
func listenUDPFrom(ctx context.Context, port int, ipv6 bool) (*net.UDPAddr, error) {
|
||||
localAddr := net.UDPAddr{
|
||||
Port: port,
|
||||
}
|
||||
conn, err := net.ListenUDP(udpNetwork(ipv6), &localAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
ch := make(chan error)
|
||||
type result struct {
|
||||
remoteAddr *net.UDPAddr
|
||||
err error
|
||||
}
|
||||
|
||||
ch := make(chan result)
|
||||
go func() {
|
||||
_, err = conn.Read([]byte{0})
|
||||
ch <- err
|
||||
_, remoteAddr, err := conn.ReadFromUDP([]byte{0})
|
||||
ch <- result{remoteAddr, err}
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-ch:
|
||||
return err
|
||||
case res := <-ch:
|
||||
return res.remoteAddr, res.err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
return nil, fmt.Errorf("timed out reading from %s: %w", &localAddr, ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -125,8 +137,16 @@ func sendUDPLoop(ctx context.Context, ip net.IP, port int, ipv6 bool) error {
|
|||
}
|
||||
}
|
||||
|
||||
// listenTCP listens for connections on a TCP port.
|
||||
// listenTCP listens for connections on a TCP port, and returns nil if a
|
||||
// connection is established.
|
||||
func listenTCP(ctx context.Context, port int, ipv6 bool) error {
|
||||
_, err := listenTCPFrom(ctx, port, ipv6)
|
||||
return err
|
||||
}
|
||||
|
||||
// listenTCP listens for connections on a TCP port, and returns the remote
|
||||
// TCP address if a connection is established.
|
||||
func listenTCPFrom(ctx context.Context, port int, ipv6 bool) (net.Addr, error) {
|
||||
localAddr := net.TCPAddr{
|
||||
Port: port,
|
||||
}
|
||||
|
@ -134,23 +154,32 @@ func listenTCP(ctx context.Context, port int, ipv6 bool) error {
|
|||
// Starts listening on port.
|
||||
lConn, err := net.ListenTCP(tcpNetwork(ipv6), &localAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
defer lConn.Close()
|
||||
|
||||
type result struct {
|
||||
remoteAddr net.Addr
|
||||
err error
|
||||
}
|
||||
|
||||
// Accept connections on port.
|
||||
ch := make(chan error)
|
||||
ch := make(chan result)
|
||||
go func() {
|
||||
conn, err := lConn.AcceptTCP()
|
||||
ch <- err
|
||||
var remoteAddr net.Addr
|
||||
if err == nil {
|
||||
remoteAddr = conn.RemoteAddr()
|
||||
}
|
||||
ch <- result{remoteAddr, err}
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-ch:
|
||||
return err
|
||||
case res := <-ch:
|
||||
return res.remoteAddr, res.err
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("timed out waiting for a connection at %#v: %w", localAddr, ctx.Err())
|
||||
return nil, fmt.Errorf("timed out waiting for a connection at %s: %w", &localAddr, ctx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"gvisor.dev/gvisor/pkg/binary"
|
||||
|
@ -48,6 +49,8 @@ func init() {
|
|||
RegisterTestCase(&NATOutOriginalDst{})
|
||||
RegisterTestCase(&NATPreRECVORIGDSTADDR{})
|
||||
RegisterTestCase(&NATOutRECVORIGDSTADDR{})
|
||||
RegisterTestCase(&NATPostSNATUDP{})
|
||||
RegisterTestCase(&NATPostSNATTCP{})
|
||||
}
|
||||
|
||||
// NATPreRedirectUDPPort tests that packets are redirected to different port.
|
||||
|
@ -486,7 +489,12 @@ func (*NATLoopbackSkipsPrerouting) Name() string {
|
|||
// ContainerAction implements TestCase.ContainerAction.
|
||||
func (*NATLoopbackSkipsPrerouting) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
|
||||
// Redirect anything sent to localhost to an unused port.
|
||||
dest := []byte{127, 0, 0, 1}
|
||||
var dest net.IP
|
||||
if ipv6 {
|
||||
dest = net.IPv6loopback
|
||||
} else {
|
||||
dest = net.IPv4(127, 0, 0, 1)
|
||||
}
|
||||
if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -915,3 +923,115 @@ func addrMatches6(got unix.RawSockaddrInet6, wantAddrs []net.IP, port uint16) er
|
|||
}
|
||||
return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs)
|
||||
}
|
||||
|
||||
const (
|
||||
snatAddrV4 = "194.236.50.155"
|
||||
snatAddrV6 = "2a0a::1"
|
||||
snatPort = 43
|
||||
)
|
||||
|
||||
// NATPostSNATUDP tests that the source port/IP in the packets are modified as expected.
|
||||
type NATPostSNATUDP struct{ localCase }
|
||||
|
||||
var _ TestCase = (*NATPostSNATUDP)(nil)
|
||||
|
||||
// Name implements TestCase.Name.
|
||||
func (*NATPostSNATUDP) Name() string {
|
||||
return "NATPostSNATUDP"
|
||||
}
|
||||
|
||||
// ContainerAction implements TestCase.ContainerAction.
|
||||
func (*NATPostSNATUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
|
||||
var source string
|
||||
if ipv6 {
|
||||
source = fmt.Sprintf("[%s]:%d", snatAddrV6, snatPort)
|
||||
} else {
|
||||
source = fmt.Sprintf("%s:%d", snatAddrV4, snatPort)
|
||||
}
|
||||
|
||||
if err := natTable(ipv6, "-A", "POSTROUTING", "-p", "udp", "-j", "SNAT", "--to-source", source); err != nil {
|
||||
return err
|
||||
}
|
||||
return sendUDPLoop(ctx, ip, acceptPort, ipv6)
|
||||
}
|
||||
|
||||
// LocalAction implements TestCase.LocalAction.
|
||||
func (*NATPostSNATUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
|
||||
remote, err := listenUDPFrom(ctx, acceptPort, ipv6)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var snatAddr string
|
||||
if ipv6 {
|
||||
snatAddr = snatAddrV6
|
||||
} else {
|
||||
snatAddr = snatAddrV4
|
||||
}
|
||||
if got, want := remote.IP, net.ParseIP(snatAddr); !got.Equal(want) {
|
||||
return fmt.Errorf("got remote address = %s, want = %s", got, want)
|
||||
}
|
||||
if got, want := remote.Port, snatPort; got != want {
|
||||
return fmt.Errorf("got remote port = %d, want = %d", got, want)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NATPostSNATTCP tests that the source port/IP in the packets are modified as
|
||||
// expected.
|
||||
type NATPostSNATTCP struct{ localCase }
|
||||
|
||||
var _ TestCase = (*NATPostSNATTCP)(nil)
|
||||
|
||||
// Name implements TestCase.Name.
|
||||
func (*NATPostSNATTCP) Name() string {
|
||||
return "NATPostSNATTCP"
|
||||
}
|
||||
|
||||
// ContainerAction implements TestCase.ContainerAction.
|
||||
func (*NATPostSNATTCP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
|
||||
addrs, err := getInterfaceAddrs(ipv6)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var source string
|
||||
for _, addr := range addrs {
|
||||
if addr.To4() != nil {
|
||||
if !ipv6 {
|
||||
source = fmt.Sprintf("%s:%d", addr, snatPort)
|
||||
}
|
||||
} else if ipv6 && addr.IsGlobalUnicast() {
|
||||
source = fmt.Sprintf("[%s]:%d", addr, snatPort)
|
||||
}
|
||||
}
|
||||
if source == "" {
|
||||
return fmt.Errorf("can't find any interface address to use")
|
||||
}
|
||||
|
||||
if err := natTable(ipv6, "-A", "POSTROUTING", "-p", "tcp", "-j", "SNAT", "--to-source", source); err != nil {
|
||||
return err
|
||||
}
|
||||
return connectTCP(ctx, ip, acceptPort, ipv6)
|
||||
}
|
||||
|
||||
// LocalAction implements TestCase.LocalAction.
|
||||
func (*NATPostSNATTCP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
|
||||
remote, err := listenTCPFrom(ctx, acceptPort, ipv6)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
HostStr, portStr, err := net.SplitHostPort(remote.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if got, want := HostStr, ip.String(); got != want {
|
||||
return fmt.Errorf("got remote address = %s, want = %s", got, want)
|
||||
}
|
||||
port, err := strconv.ParseInt(portStr, 10, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if got, want := int(port), snatPort; got != want {
|
||||
return fmt.Errorf("got remote port = %d, want = %d", got, want)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue