iptables: support postrouting hook and SNAT target

The current SNAT implementation has several limitations:
- SNAT source port has to be specified. It is not optional.
- SNAT source port range is not supported.
- SNAT for UDP is a one-way translation. No response packets
  are handled (because conntrack doesn't support UDP currently).
- SNAT and REDIRECT can't work on the same connection.

Fixes #5489

PiperOrigin-RevId: 367750325
This commit is contained in:
Toshi Kikuchi 2021-04-09 21:09:47 -07:00 committed by gVisor bot
parent ea7faa5057
commit d1edabdca0
18 changed files with 850 additions and 237 deletions

View File

@ -375,6 +375,17 @@ type XTRedirectTarget struct {
// SizeOfXTRedirectTarget is the size of an XTRedirectTarget.
const SizeOfXTRedirectTarget = 56
// XTSNATTarget triggers Source NAT when reached.
//
// The trailing 4 bytes of padding make the struct 8 byte aligned.
type XTSNATTarget struct {
// Target is the generic target header that precedes the NAT range.
Target XTEntryTarget
// NfRange describes the address/port range the source is translated to.
NfRange NfNATIPV4MultiRangeCompat
// Padding; see the type comment.
_ [4]byte
}
// SizeOfXTSNATTarget is the size of an XTSNATTarget.
const SizeOfXTSNATTarget = 56
// IPTGetinfo is the argument for the IPT_SO_GET_INFO sockopt. It corresponds
// to struct ipt_getinfo in include/uapi/linux/netfilter_ipv4/ip_tables.h.
//

View File

@ -274,10 +274,10 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
}
// TODO(gvisor.dev/issue/170): Support other chains.
// Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now,
// make sure all other chains point to ACCEPT rules.
// Since we don't support FORWARD yet, make sure all other chains point to
// ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting {
if hook := stack.Hook(hook); hook == stack.Forward {
if ruleIdx == stack.HookUnset {
continue
}

View File

@ -35,6 +35,11 @@ const ErrorTargetName = "ERROR"
// change the destination port and/or IP for packets.
const RedirectTargetName = "REDIRECT"
// SNATTargetName is used to mark targets as SNAT targets. SNAT targets should
// only be reached in the NAT table. These targets change the source port
// and/or IP of packets.
const SNATTargetName = "SNAT"
func init() {
// Standard targets include ACCEPT, DROP, RETURN, and JUMP.
registerTargetMaker(&standardTargetMaker{
@ -59,6 +64,13 @@ func init() {
registerTargetMaker(&nfNATTargetMaker{
NetworkProtocol: header.IPv6ProtocolNumber,
})
registerTargetMaker(&snatTargetMakerV4{
NetworkProtocol: header.IPv4ProtocolNumber,
})
registerTargetMaker(&snatTargetMakerV6{
NetworkProtocol: header.IPv6ProtocolNumber,
})
}
// The stack package provides some basic, useful targets for us. The following
@ -131,6 +143,17 @@ func (rt *redirectTarget) id() targetID {
}
}
// snatTarget wraps stack.SNATTarget so that it can be identified by the
// target makers registered in this package.
type snatTarget struct {
	stack.SNATTarget
}

// id returns the targetID (SNAT name plus the embedded target's network
// protocol) identifying this target.
func (st *snatTarget) id() targetID {
	var tid targetID
	tid.name = SNATTargetName
	tid.networkProtocol = st.NetworkProtocol
	return tid
}
type standardTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
}
@ -341,7 +364,7 @@ type nfNATTarget struct {
Range linux.NFNATRange
}
const nfNATMarhsalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange
const nfNATMarshalledSize = linux.SizeOfXTEntryTarget + linux.SizeOfNFNATRange
type nfNATTargetMaker struct {
NetworkProtocol tcpip.NetworkProtocolNumber
@ -358,7 +381,7 @@ func (*nfNATTargetMaker) marshal(target target) []byte {
rt := target.(*redirectTarget)
nt := nfNATTarget{
Target: linux.XTEntryTarget{
TargetSize: nfNATMarhsalledSize,
TargetSize: nfNATMarshalledSize,
},
Range: linux.NFNATRange{
Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED,
@ -371,12 +394,12 @@ func (*nfNATTargetMaker) marshal(target target) []byte {
nt.Range.MinProto = htons(rt.Port)
nt.Range.MaxProto = nt.Range.MinProto
ret := make([]byte, 0, nfNATMarhsalledSize)
ret := make([]byte, 0, nfNATMarshalledSize)
return binary.Marshal(ret, hostarch.ByteOrder, nt)
}
func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
if size := nfNATMarhsalledSize; len(buf) < size {
if size := nfNATMarshalledSize; len(buf) < size {
nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size)
return nil, syserr.ErrInvalidArgument
}
@ -387,7 +410,7 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
}
var natRange linux.NFNATRange
buf = buf[linux.SizeOfXTEntryTarget:nfNATMarhsalledSize]
buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
binary.Unmarshal(buf, hostarch.ByteOrder, &natRange)
// We don't support port or address ranges.
@ -418,6 +441,161 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
return &target, nil
}
// snatTargetMakerV4 marshals and unmarshals SNAT targets using the
// linux.XTSNATTarget layout.
type snatTargetMakerV4 struct {
	// NetworkProtocol is the network protocol this maker is registered for.
	NetworkProtocol tcpip.NetworkProtocolNumber
}

// id returns the targetID (SNAT name plus network protocol) that this maker
// handles.
func (st *snatTargetMakerV4) id() targetID {
	var tid targetID
	tid.name = SNATTargetName
	tid.networkProtocol = st.NetworkProtocol
	return tid
}
// marshal serializes target, which must be a *snatTarget, as a
// linux.XTSNATTarget holding a single range whose min and max address/port
// both equal the target's address and port.
func (*snatTargetMakerV4) marshal(target target) []byte {
	snat := target.(*snatTarget)
	port := htons(snat.Port)

	// Fill in the target header.
	var out linux.XTSNATTarget
	out.Target.TargetSize = linux.SizeOfXTSNATTarget
	copy(out.Target.Name[:], SNATTargetName)

	// Exactly one range, with identical lower and upper bounds.
	out.NfRange.RangeSize = 1
	out.NfRange.RangeIPV4.Flags = linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED
	out.NfRange.RangeIPV4.MinPort = port
	out.NfRange.RangeIPV4.MaxPort = port
	copy(out.NfRange.RangeIPV4.MinIP[:], snat.Addr)
	copy(out.NfRange.RangeIPV4.MaxIP[:], snat.Addr)

	return binary.Marshal(make([]byte, 0, linux.SizeOfXTSNATTarget), hostarch.ByteOrder, out)
}
// unmarshal parses buf as a linux.XTSNATTarget and converts it into a
// snatTarget for the rule described by filter.
//
// It returns syserr.ErrInvalidArgument when buf is too short, when the rule's
// protocol is neither TCP nor UDP, when the number of ranges is not exactly 1,
// when the source port is zero, or when a port or address range (min != max)
// is requested.
func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
	if len(buf) < linux.SizeOfXTSNATTarget {
		nflog("snatTargetMakerV4: buf has insufficient size for snat target %d", len(buf))
		return nil, syserr.ErrInvalidArgument
	}

	if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
		nflog("snatTargetMakerV4: bad proto %d", p)
		return nil, syserr.ErrInvalidArgument
	}

	var st linux.XTSNATTarget
	buf = buf[:linux.SizeOfXTSNATTarget]
	binary.Unmarshal(buf, hostarch.ByteOrder, &st)

	// Copy linux.XTSNATTarget to stack.SNATTarget.
	target := snatTarget{SNATTarget: stack.SNATTarget{
		NetworkProtocol: filter.NetworkProtocol(),
	}}

	// RangeSize should be 1.
	nfRange := st.NfRange
	if nfRange.RangeSize != 1 {
		nflog("snatTargetMakerV4: bad rangesize %d", nfRange.RangeSize)
		return nil, syserr.ErrInvalidArgument
	}

	// TODO(gvisor.dev/issue/5772): If the rule doesn't specify the source port,
	// choose one automatically.
	if nfRange.RangeIPV4.MinPort == 0 {
		nflog("snatTargetMakerV4: snat target needs to specify a non-zero port")
		return nil, syserr.ErrInvalidArgument
	}

	// TODO(gvisor.dev/issue/170): Port range is not supported yet.
	if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
		nflog("snatTargetMakerV4: MinPort != MaxPort (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
		return nil, syserr.ErrInvalidArgument
	}
	// Address ranges are not supported either. Log the mismatching addresses
	// (not the ports) so the message reflects the failed check.
	if nfRange.RangeIPV4.MinIP != nfRange.RangeIPV4.MaxIP {
		nflog("snatTargetMakerV4: MinIP != MaxIP (%v, %v)", nfRange.RangeIPV4.MinIP, nfRange.RangeIPV4.MaxIP)
		return nil, syserr.ErrInvalidArgument
	}

	target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
	target.Port = ntohs(nfRange.RangeIPV4.MinPort)

	return &target, nil
}
// snatTargetMakerV6 marshals and unmarshals SNAT targets using the
// nfNATTarget layout (revision 1).
type snatTargetMakerV6 struct {
	// NetworkProtocol is the network protocol this maker is registered for.
	NetworkProtocol tcpip.NetworkProtocolNumber
}

// id returns the targetID (SNAT name, network protocol and revision 1) that
// this maker handles.
func (st *snatTargetMakerV6) id() targetID {
	var tid targetID
	tid.name = SNATTargetName
	tid.networkProtocol = st.NetworkProtocol
	tid.revision = 1
	return tid
}
// marshal serializes target, which must be a *snatTarget, as an nfNATTarget
// whose min and max address/port both equal the target's address and port.
func (*snatTargetMakerV6) marshal(target target) []byte {
	snat := target.(*snatTarget)
	port := htons(snat.Port)

	// Fill in the target header.
	var out nfNATTarget
	out.Target.TargetSize = nfNATMarshalledSize
	copy(out.Target.Name[:], SNATTargetName)

	// A single address and port: lower and upper bounds are identical.
	out.Range.Flags = linux.NF_NAT_RANGE_MAP_IPS | linux.NF_NAT_RANGE_PROTO_SPECIFIED
	copy(out.Range.MinAddr[:], snat.Addr)
	copy(out.Range.MaxAddr[:], snat.Addr)
	out.Range.MinProto = port
	out.Range.MaxProto = port

	return binary.Marshal(make([]byte, 0, nfNATMarshalledSize), hostarch.ByteOrder, out)
}
// unmarshal parses the linux.NFNATRange that follows the XTEntryTarget header
// in buf and converts it into a snatTarget for the rule described by filter.
//
// It returns syserr.ErrInvalidArgument when buf is too short, when the rule's
// protocol is neither TCP nor UDP, when an address or port range (min != max)
// is requested, or when the range flags are anything other than
// NF_NAT_RANGE_MAP_IPS|NF_NAT_RANGE_PROTO_SPECIFIED.
func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) {
	if want := nfNATMarshalledSize; len(buf) < want {
		nflog("snatTargetMakerV6: buf has insufficient size (%d) for SNAT V6 target (%d)", len(buf), want)
		return nil, syserr.ErrInvalidArgument
	}

	switch proto := filter.Protocol; proto {
	case header.TCPProtocolNumber, header.UDPProtocolNumber:
		// Supported transport protocols.
	default:
		nflog("snatTargetMakerV6: bad proto %d", proto)
		return nil, syserr.ErrInvalidArgument
	}

	// Skip the entry target header; only the range is decoded.
	var rng linux.NFNATRange
	binary.Unmarshal(buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize], hostarch.ByteOrder, &rng)

	switch {
	// TODO(gvisor.dev/issue/5689): Support port or address ranges.
	case rng.MinAddr != rng.MaxAddr:
		nflog("snatTargetMakerV6: MinAddr and MaxAddr are different")
		return nil, syserr.ErrInvalidArgument
	case rng.MinProto != rng.MaxProto:
		nflog("snatTargetMakerV6: MinProto and MaxProto are different")
		return nil, syserr.ErrInvalidArgument
	// TODO(gvisor.dev/issue/5698): Support other NF_NAT_RANGE flags.
	case rng.Flags != linux.NF_NAT_RANGE_MAP_IPS|linux.NF_NAT_RANGE_PROTO_SPECIFIED:
		nflog("snatTargetMakerV6: invalid range flags %d", rng.Flags)
		return nil, syserr.ErrInvalidArgument
	}

	return &snatTarget{
		SNATTarget: stack.SNATTarget{
			NetworkProtocol: filter.NetworkProtocol(),
			Addr:            tcpip.Address(rng.MinAddr[:]),
			Port:            ntohs(rng.MinProto),
		},
	}, nil
}
// translateToStandardTarget translates from the value in a
// linux.XTStandardTarget to an stack.Verdict.
func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) {

View File

@ -21,53 +21,56 @@ import "gvisor.dev/gvisor/pkg/tcpip"
// MultiCounterIPStats holds IP statistics, each counter may have several
// versions.
type MultiCounterIPStats struct {
// PacketsReceived is the total number of IP packets received from the link
// layer.
// PacketsReceived is the number of IP packets received from the link layer.
PacketsReceived tcpip.MultiCounterStat
// DisabledPacketsReceived is the total number of IP packets received from the
// link layer when the IP layer is disabled.
// DisabledPacketsReceived is the number of IP packets received from the link
// layer when the IP layer is disabled.
DisabledPacketsReceived tcpip.MultiCounterStat
// InvalidDestinationAddressesReceived is the total number of IP packets
// received with an unknown or invalid destination address.
// InvalidDestinationAddressesReceived is the number of IP packets received
// with an unknown or invalid destination address.
InvalidDestinationAddressesReceived tcpip.MultiCounterStat
// InvalidSourceAddressesReceived is the total number of IP packets received
// with a source address that should never have been received on the wire.
// InvalidSourceAddressesReceived is the number of IP packets received with a
// source address that should never have been received on the wire.
InvalidSourceAddressesReceived tcpip.MultiCounterStat
// PacketsDelivered is the total number of incoming IP packets that are
// successfully delivered to the transport layer.
// PacketsDelivered is the number of incoming IP packets that are successfully
// delivered to the transport layer.
PacketsDelivered tcpip.MultiCounterStat
// PacketsSent is the total number of IP packets sent via WritePacket.
// PacketsSent is the number of IP packets sent via WritePacket.
PacketsSent tcpip.MultiCounterStat
// OutgoingPacketErrors is the total number of IP packets which failed to
// write to a link-layer endpoint.
// OutgoingPacketErrors is the number of IP packets which failed to write to a
// link-layer endpoint.
OutgoingPacketErrors tcpip.MultiCounterStat
// MalformedPacketsReceived is the total number of IP Packets that were
// dropped due to the IP packet header failing validation checks.
// MalformedPacketsReceived is the number of IP Packets that were dropped due
// to the IP packet header failing validation checks.
MalformedPacketsReceived tcpip.MultiCounterStat
// MalformedFragmentsReceived is the total number of IP Fragments that were
// dropped due to the fragment failing validation checks.
// MalformedFragmentsReceived is the number of IP Fragments that were dropped
// due to the fragment failing validation checks.
MalformedFragmentsReceived tcpip.MultiCounterStat
// IPTablesPreroutingDropped is the total number of IP packets dropped in the
// IPTablesPreroutingDropped is the number of IP packets dropped in the
// Prerouting chain.
IPTablesPreroutingDropped tcpip.MultiCounterStat
// IPTablesInputDropped is the total number of IP packets dropped in the Input
// IPTablesInputDropped is the number of IP packets dropped in the Input
// chain.
IPTablesInputDropped tcpip.MultiCounterStat
// IPTablesOutputDropped is the total number of IP packets dropped in the
// Output chain.
// IPTablesOutputDropped is the number of IP packets dropped in the Output
// chain.
IPTablesOutputDropped tcpip.MultiCounterStat
// IPTablesPostroutingDropped is the number of IP packets dropped in the
// Postrouting chain.
IPTablesPostroutingDropped tcpip.MultiCounterStat
// TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
// of IPStats.
@ -98,6 +101,7 @@ func (m *MultiCounterIPStats) Init(a, b *tcpip.IPStats) {
m.IPTablesPreroutingDropped.Init(a.IPTablesPreroutingDropped, b.IPTablesPreroutingDropped)
m.IPTablesInputDropped.Init(a.IPTablesInputDropped, b.IPTablesInputDropped)
m.IPTablesOutputDropped.Init(a.IPTablesOutputDropped, b.IPTablesOutputDropped)
m.IPTablesPostroutingDropped.Init(a.IPTablesPostroutingDropped, b.IPTablesPostroutingDropped)
m.OptionTimestampReceived.Init(a.OptionTimestampReceived, b.OptionTimestampReceived)
m.OptionRecordRouteReceived.Init(a.OptionRecordRouteReceived, b.OptionRecordRouteReceived)
m.OptionRouterAlertReceived.Init(a.OptionRouterAlertReceived, b.OptionRouterAlertReceived)

View File

@ -415,6 +415,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return nil
}
// Postrouting NAT can only change the source address, and does not alter the
// route or outgoing interface of the packet.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesPostroutingDropped.Increment()
return nil
}
stats := e.stats.ip
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
@ -486,9 +495,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
// iptables filtering. All packets that reach here are locally
// generated.
dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "", outNicName)
stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
for pkt := range dropped {
outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
for pkt := range outputDropped {
pkts.Remove(pkt)
}
@ -510,6 +519,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
}
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
// change the source address, and does not alter the route or outgoing
// interface of the packet.
postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, gso, r, "" /* inNicName */, outNicName)
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
for pkt := range postroutingDropped {
pkts.Remove(pkt)
}
// The rest of the packets can be delivered to the NIC as a batch.
pktsLen := pkts.Len()
written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
@ -517,7 +535,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
// Dropped packets aren't errors, so include them in the return value.
return locallyDelivered + written + len(dropped), err
return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.

View File

@ -2612,34 +2612,36 @@ func TestWriteStats(t *testing.T) {
const nPackets = 3
tests := []struct {
name string
setup func(*testing.T, *stack.Stack)
allowPackets int
expectSent int
expectDropped int
expectWritten int
name string
setup func(*testing.T, *stack.Stack)
allowPackets int
expectSent int
expectOutputDropped int
expectPostroutingDropped int
expectWritten int
}{
{
name: "Accept all",
// No setup needed, tables accept everything by default.
setup: func(*testing.T, *stack.Stack) {},
allowPackets: math.MaxInt32,
expectSent: nPackets,
expectDropped: 0,
expectWritten: nPackets,
setup: func(*testing.T, *stack.Stack) {},
allowPackets: math.MaxInt32,
expectSent: nPackets,
expectOutputDropped: 0,
expectPostroutingDropped: 0,
expectWritten: nPackets,
}, {
name: "Accept all with error",
// No setup needed, tables accept everything by default.
setup: func(*testing.T, *stack.Stack) {},
allowPackets: nPackets - 1,
expectSent: nPackets - 1,
expectDropped: 0,
expectWritten: nPackets - 1,
setup: func(*testing.T, *stack.Stack) {},
allowPackets: nPackets - 1,
expectSent: nPackets - 1,
expectOutputDropped: 0,
expectPostroutingDropped: 0,
expectWritten: nPackets - 1,
}, {
name: "Drop all",
name: "Drop all with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule.
t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Output]
@ -2648,16 +2650,32 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %s", err)
}
},
allowPackets: math.MaxInt32,
expectSent: 0,
expectDropped: nPackets,
expectWritten: nPackets,
allowPackets: math.MaxInt32,
expectSent: 0,
expectOutputDropped: nPackets,
expectPostroutingDropped: 0,
expectWritten: nPackets,
}, {
name: "Drop some",
name: "Drop all with Postrouting chain",
setup: func(t *testing.T, stk *stack.Stack) {
ipt := stk.IPTables()
filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Postrouting]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %s", err)
}
},
allowPackets: math.MaxInt32,
expectSent: 0,
expectOutputDropped: 0,
expectPostroutingDropped: nPackets,
expectWritten: nPackets,
}, {
name: "Drop some with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule that matches only 1
// of the 3 packets.
t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, false /* ipv6 */)
// We'll match and DROP the last packet.
@ -2670,10 +2688,33 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %s", err)
}
},
allowPackets: math.MaxInt32,
expectSent: nPackets - 1,
expectDropped: 1,
expectWritten: nPackets,
allowPackets: math.MaxInt32,
expectSent: nPackets - 1,
expectOutputDropped: 1,
expectPostroutingDropped: 0,
expectWritten: nPackets,
}, {
name: "Drop some with Postrouting chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Postrouting DROP rule that matches only 1
// of the 3 packets.
ipt := stk.IPTables()
filter := ipt.GetTable(stack.NATID, false /* ipv6 */)
// We'll match and DROP the last packet.
ruleIdx := filter.BuiltinChains[stack.Postrouting]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
// Make sure the next rule is ACCEPT.
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %s", err)
}
},
allowPackets: math.MaxInt32,
expectSent: nPackets - 1,
expectOutputDropped: 0,
expectPostroutingDropped: 1,
expectWritten: nPackets,
},
}
@ -2724,13 +2765,16 @@ func TestWriteStats(t *testing.T) {
nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent)
}
if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped {
t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped)
}
if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped {
t.Errorf("got rt.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped)
}
if nWritten != test.expectWritten {
t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten)
}
})
}

View File

@ -769,6 +769,15 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet
return nil
}
// Postrouting NAT can only change the source address, and does not alter the
// route or outgoing interface of the packet.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, gso, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok {
// iptables is telling us to drop the packet.
e.stats.ip.IPTablesPostroutingDropped.Increment()
return nil
}
stats := e.stats.ip
networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size()))
if err != nil {
@ -840,9 +849,9 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// iptables filtering. All packets that reach here are locally
// generated.
outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID())
dropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
stats.IPTablesOutputDropped.IncrementBy(uint64(len(dropped)))
for pkt := range dropped {
outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, gso, r, "" /* inNicName */, outNicName)
stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped)))
for pkt := range outputDropped {
pkts.Remove(pkt)
}
@ -863,6 +872,15 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
locallyDelivered++
}
// We ignore the list of NAT-ed packets here because Postrouting NAT can only
// change the source address, and does not alter the route or outgoing
// interface of the packet.
postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, gso, r, "" /* inNicName */, outNicName)
stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped)))
for pkt := range postroutingDropped {
pkts.Remove(pkt)
}
// The rest of the packets can be delivered to the NIC as a batch.
pktsLen := pkts.Len()
written, err := e.nic.WritePackets(r, gso, pkts, ProtocolNumber)
@ -870,7 +888,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
stats.OutgoingPacketErrors.IncrementBy(uint64(pktsLen - written))
// Dropped packets aren't errors, so include them in the return value.
return locallyDelivered + written + len(dropped), err
return locallyDelivered + written + len(outputDropped) + len(postroutingDropped), err
}
// WriteHeaderIncludedPacket implements stack.NetworkEndpoint.

View File

@ -2468,34 +2468,36 @@ func TestFragmentReassemblyTimeout(t *testing.T) {
func TestWriteStats(t *testing.T) {
const nPackets = 3
tests := []struct {
name string
setup func(*testing.T, *stack.Stack)
allowPackets int
expectSent int
expectDropped int
expectWritten int
name string
setup func(*testing.T, *stack.Stack)
allowPackets int
expectSent int
expectOutputDropped int
expectPostroutingDropped int
expectWritten int
}{
{
name: "Accept all",
// No setup needed, tables accept everything by default.
setup: func(*testing.T, *stack.Stack) {},
allowPackets: math.MaxInt32,
expectSent: nPackets,
expectDropped: 0,
expectWritten: nPackets,
setup: func(*testing.T, *stack.Stack) {},
allowPackets: math.MaxInt32,
expectSent: nPackets,
expectOutputDropped: 0,
expectPostroutingDropped: 0,
expectWritten: nPackets,
}, {
name: "Accept all with error",
// No setup needed, tables accept everything by default.
setup: func(*testing.T, *stack.Stack) {},
allowPackets: nPackets - 1,
expectSent: nPackets - 1,
expectDropped: 0,
expectWritten: nPackets - 1,
setup: func(*testing.T, *stack.Stack) {},
allowPackets: nPackets - 1,
expectSent: nPackets - 1,
expectOutputDropped: 0,
expectPostroutingDropped: 0,
expectWritten: nPackets - 1,
}, {
name: "Drop all",
name: "Drop all with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule.
t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Output]
@ -2504,16 +2506,33 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %v", err)
}
},
allowPackets: math.MaxInt32,
expectSent: 0,
expectDropped: nPackets,
expectWritten: nPackets,
allowPackets: math.MaxInt32,
expectSent: 0,
expectOutputDropped: nPackets,
expectPostroutingDropped: 0,
expectWritten: nPackets,
}, {
name: "Drop some",
name: "Drop all with Postrouting chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Postrouting DROP rule.
ipt := stk.IPTables()
filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
ruleIdx := filter.BuiltinChains[stack.Postrouting]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %v", err)
}
},
allowPackets: math.MaxInt32,
expectSent: 0,
expectOutputDropped: 0,
expectPostroutingDropped: nPackets,
expectWritten: nPackets,
}, {
name: "Drop some with Output chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Output DROP rule that matches only 1
// of the 3 packets.
t.Helper()
ipt := stk.IPTables()
filter := ipt.GetTable(stack.FilterID, true /* ipv6 */)
// We'll match and DROP the last packet.
@ -2526,10 +2545,33 @@ func TestWriteStats(t *testing.T) {
t.Fatalf("failed to replace table: %v", err)
}
},
allowPackets: math.MaxInt32,
expectSent: nPackets - 1,
expectDropped: 1,
expectWritten: nPackets,
allowPackets: math.MaxInt32,
expectSent: nPackets - 1,
expectOutputDropped: 1,
expectPostroutingDropped: 0,
expectWritten: nPackets,
}, {
name: "Drop some with Postrouting chain",
setup: func(t *testing.T, stk *stack.Stack) {
// Install Postrouting DROP rule that matches only 1
// of the 3 packets.
ipt := stk.IPTables()
filter := ipt.GetTable(stack.NATID, true /* ipv6 */)
// We'll match and DROP the last packet.
ruleIdx := filter.BuiltinChains[stack.Postrouting]
filter.Rules[ruleIdx].Target = &stack.DropTarget{}
filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}}
// Make sure the next rule is ACCEPT.
filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{}
if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil {
t.Fatalf("failed to replace table: %v", err)
}
},
allowPackets: math.MaxInt32,
expectSent: nPackets - 1,
expectOutputDropped: 0,
expectPostroutingDropped: 1,
expectWritten: nPackets,
},
}
@ -2578,13 +2620,16 @@ func TestWriteStats(t *testing.T) {
nWritten, _ := writer.writePackets(rt, pkts)
if got := int(rt.Stats().IP.PacketsSent.Value()); got != test.expectSent {
t.Errorf("sent %d packets, but expected to send %d", got, test.expectSent)
t.Errorf("got rt.Stats().IP.PacketsSent.Value() = %d, want = %d", got, test.expectSent)
}
if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectDropped {
t.Errorf("dropped %d packets, but expected to drop %d", got, test.expectDropped)
if got := int(rt.Stats().IP.IPTablesOutputDropped.Value()); got != test.expectOutputDropped {
t.Errorf("got rt.Stats().IP.IPTablesOutputDropped.Value() = %d, want = %d", got, test.expectOutputDropped)
}
if got := int(rt.Stats().IP.IPTablesPostroutingDropped.Value()); got != test.expectPostroutingDropped {
t.Errorf("got r.Stats().IP.IPTablesPostroutingDropped.Value() = %d, want = %d", got, test.expectPostroutingDropped)
}
if nWritten != test.expectWritten {
t.Errorf("wrote %d packets, but expected WritePackets to return %d", nWritten, test.expectWritten)
t.Errorf("got nWritten = %d, want = %d", nWritten, test.expectWritten)
}
})
}

View File

@ -45,6 +45,7 @@ go_library(
"addressable_endpoint_state.go",
"conntrack.go",
"headertype_string.go",
"hook_string.go",
"icmp_rate_limit.go",
"iptables.go",
"iptables_state.go",

View File

@ -16,6 +16,7 @@ package stack
import (
"encoding/binary"
"fmt"
"sync"
"time"
@ -29,7 +30,7 @@ import (
// The connection is created for a packet if it does not exist. Every
// connection contains two tuples (original and reply). The tuples are
// manipulated if there is a matching NAT rule. The packet is modified by
// looking at the tuples in the Prerouting and Output hooks.
// looking at the tuples in each hook.
//
// Currently, only TCP tracking is supported.
@ -46,12 +47,14 @@ const (
)
// Manipulation type for the connection.
// TODO(gvisor.dev/issue/5696): Define this as a bit set and support SNAT and
// DNAT at the same time.
type manipType int
const (
manipNone manipType = iota
manipDstPrerouting
manipDstOutput
manipSource
manipDestination
)
// tuple holds a connection's identifying and manipulating data in one
@ -108,6 +111,7 @@ type conn struct {
reply tuple
// manip indicates if the packet should be manipulated. It is immutable.
// TODO(gvisor.dev/issue/5696): Support updating manipulation type.
manip manipType
// tcbHook indicates if the packet is inbound or outbound to
@ -124,6 +128,18 @@ type conn struct {
lastUsed time.Time `state:".(unixTime)"`
}
// newConn creates a new connection whose original and reply tuples carry the
// given tuple IDs. Both tuples point back at the returned conn, and lastUsed
// is initialized to the current time.
func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
	c := &conn{
		manip:    manip,
		tcbHook:  hook,
		lastUsed: time.Now(),
	}
	c.original = tuple{tupleID: orig, conn: c}
	c.reply = tuple{tupleID: reply, conn: c, direction: dirReply}
	return c
}
// timedOut returns whether the connection timed out based on its state.
func (cn *conn) timedOut(now time.Time) bool {
const establishedTimeout = 5 * 24 * time.Hour
@ -219,18 +235,6 @@ func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) {
}, nil
}
// newConn creates a new connection whose original and reply tuples carry the
// given tuple IDs. lastUsed is initialized to the current time.
func newConn(orig, reply tupleID, manip manipType, hook Hook) *conn {
conn := conn{
manip: manip,
tcbHook: hook,
lastUsed: time.Now(),
}
// Both tuples point back at the connection that owns them.
conn.original = tuple{conn: &conn, tupleID: orig}
conn.reply = tuple{conn: &conn, tupleID: reply, direction: dirReply}
return &conn
}
func (ct *ConnTrack) init() {
ct.mu.Lock()
defer ct.mu.Unlock()
@ -284,20 +288,41 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint1
return nil
}
// Create a new connection and change the port as per the iptables
// rule. This tuple will be used to manipulate the packet in
// handlePacket.
replyTID := tid.reply()
replyTID.srcAddr = address
replyTID.srcPort = port
var manip manipType
switch hook {
case Prerouting:
manip = manipDstPrerouting
case Output:
manip = manipDstOutput
conn, _ := ct.connForTID(tid)
if conn != nil {
// The connection is already tracked.
// TODO(gvisor.dev/issue/5696): Support updating an existing connection.
return nil
}
conn := newConn(tid, replyTID, manip, hook)
conn = newConn(tid, replyTID, manipDestination, hook)
ct.insertConn(conn)
return conn
}
// insertSNATConn starts tracking the connection pkt belongs to as a
// source-NAT connection whose reply tuple is addressed to address:port.
//
// It returns the newly inserted conn, or nil when pkt cannot be converted to
// a tuple ID, when hook is neither Input nor Postrouting, or when the
// connection is already tracked.
func (ct *ConnTrack) insertSNATConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn {
	origTID, err := packetToTupleID(pkt)
	if err != nil {
		return nil
	}

	switch hook {
	case Input, Postrouting:
	default:
		return nil
	}

	// Replies come back to the translated source, so that is the destination
	// recorded in the reply tuple.
	natTID := origTID.reply()
	natTID.dstAddr = address
	natTID.dstPort = port

	if tracked, _ := ct.connForTID(origTID); tracked != nil {
		// The connection is already tracked.
		// TODO(gvisor.dev/issue/5696): Support updating an existing connection.
		return nil
	}

	created := newConn(origTID, natTID, manipSource, hook)
	ct.insertConn(created)
	return created
}
@ -322,6 +347,7 @@ func (ct *ConnTrack) insertConn(conn *conn) {
// Now that we hold the locks, ensure the tuple hasn't been inserted by
// another thread.
// TODO(gvisor.dev/issue/5773): Should check conn.reply.tupleID, too?
alreadyInserted := false
for other := ct.buckets[tupleBucket].tuples.Front(); other != nil; other = other.Next() {
if other.tupleID == conn.original.tupleID {
@ -343,86 +369,6 @@ func (ct *ConnTrack) insertConn(conn *conn) {
}
}
// handlePacketPrerouting manipulates ports for packets in Prerouting hook.
// TODO(gvisor.dev/issue/170): Change address for Prerouting hook.
//
// The TCP port and network address of pkt are rewritten in place from the
// tuples stored in conn; dir selects whether the destination (dirOriginal) or
// the source (dirReply) is rewritten. The IPv4 header checksum is refreshed
// afterwards. Nothing is touched when conn.manip is manipNone.
func handlePacketPrerouting(pkt *PacketBuffer, conn *conn, dir direction) {
// If this is a noop entry, don't do anything.
if conn.manip == manipNone {
return
}
netHeader := pkt.Network()
// NOTE(review): the transport header is viewed as TCP without a length
// check here — presumably the caller guarantees a TCP packet; confirm.
tcpHeader := header.TCP(pkt.TransportHeader().View())
// For prerouting redirection, packets going in the original direction
// have their destinations modified and replies have their sources
// modified.
switch dir {
case dirOriginal:
port := conn.reply.srcPort
tcpHeader.SetDestinationPort(port)
netHeader.SetDestinationAddress(conn.reply.srcAddr)
case dirReply:
port := conn.original.dstPort
tcpHeader.SetSourcePort(port)
netHeader.SetSourceAddress(conn.original.dstAddr)
}
// TODO(gvisor.dev/issue/170): TCP checksums aren't usually validated
// on inbound packets, so we don't recalculate them. However, we should
// support cases when they are validated, e.g. when we can't offload
// receive checksumming.
// After modification, IPv4 packets need a valid checksum.
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
netHeader := header.IPv4(pkt.NetworkHeader().View())
// The checksum field is zeroed first so it does not feed into the
// recomputed value.
netHeader.SetChecksum(0)
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
}
// handlePacketOutput manipulates ports for packets in Output hook.
//
// The TCP port and network address of pkt are rewritten in place from the
// tuples stored in conn, after which the TCP checksum and, for IPv4, the
// network header checksum are refreshed. gso and r decide how much of the TCP
// checksum is computed here. Nothing is touched when conn.manip is manipNone.
func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir direction) {
// If this is a noop entry, don't do anything.
if conn.manip == manipNone {
return
}
netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
// For output redirection, packets going in the original direction
// have their destinations modified and replies have their sources
// modified. For prerouting redirection, we only reach this point
// when replying, so packet sources are modified.
if conn.manip == manipDstOutput && dir == dirOriginal {
port := conn.reply.srcPort
tcpHeader.SetDestinationPort(port)
netHeader.SetDestinationAddress(conn.reply.srcAddr)
} else {
port := conn.original.dstPort
tcpHeader.SetSourcePort(port)
netHeader.SetSourceAddress(conn.original.dstAddr)
}
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
length := uint16(len(tcpHeader) + pkt.Data().Size())
xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if gso != nil && gso.NeedsCsum {
// Only the pseudo-header sum is stored in this case.
// NOTE(review): presumably the rest is completed by checksum offload —
// confirm against the GSO path.
tcpHeader.SetChecksum(xsum)
} else if r.RequiresTXTransportChecksum() {
// Full software checksum: pseudo-header, payload and TCP header.
xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
// After the address rewrite the IPv4 header checksum is stale; recompute it
// with the field zeroed first.
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
netHeader := header.IPv4(pkt.NetworkHeader().View())
netHeader.SetChecksum(0)
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
}
// handlePacket will manipulate the port and address of the packet if the
// connection exists. Returns whether, after the packet traverses the tables,
// it should create a new entry in the table.
@ -431,7 +377,9 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
return false
}
if hook != Prerouting && hook != Output {
switch hook {
case Prerouting, Input, Output, Postrouting:
default:
return false
}
@ -441,23 +389,79 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, gso *GSO, r *Rou
}
conn, dir := ct.connFor(pkt)
// Connection or Rule not found for the packet.
// Connection not found for the packet.
if conn == nil {
return true
// If this is the last hook in the data path for this packet (Input if
// incoming, Postrouting if outgoing), indicate that a connection should be
// inserted by the end of this hook.
return hook == Input || hook == Postrouting
}
netHeader := pkt.Network()
tcpHeader := header.TCP(pkt.TransportHeader().View())
if len(tcpHeader) < header.TCPMinimumSize {
return false
}
// TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
// validated if checksum offloading is off. It may require IP defrag if the
// packets are fragmented.
switch hook {
case Prerouting:
handlePacketPrerouting(pkt, conn, dir)
case Output:
handlePacketOutput(pkt, conn, gso, r, dir)
case Prerouting, Output:
if conn.manip == manipDestination {
switch dir {
case dirOriginal:
tcpHeader.SetDestinationPort(conn.reply.srcPort)
netHeader.SetDestinationAddress(conn.reply.srcAddr)
case dirReply:
tcpHeader.SetSourcePort(conn.original.dstPort)
netHeader.SetSourceAddress(conn.original.dstAddr)
}
pkt.NatDone = true
}
case Input, Postrouting:
if conn.manip == manipSource {
switch dir {
case dirOriginal:
tcpHeader.SetSourcePort(conn.reply.dstPort)
netHeader.SetSourceAddress(conn.reply.dstAddr)
case dirReply:
tcpHeader.SetDestinationPort(conn.original.srcPort)
netHeader.SetDestinationAddress(conn.original.srcAddr)
}
pkt.NatDone = true
}
default:
panic(fmt.Sprintf("unrecognized hook = %s", hook))
}
if !pkt.NatDone {
return false
}
switch hook {
case Prerouting, Input:
case Output, Postrouting:
// Calculate the TCP checksum and set it.
tcpHeader.SetChecksum(0)
length := uint16(len(tcpHeader) + pkt.Data().Size())
xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
if gso != nil && gso.NeedsCsum {
tcpHeader.SetChecksum(xsum)
} else if r.RequiresTXTransportChecksum() {
xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum))
}
default:
panic(fmt.Sprintf("unrecognized hook = %s", hook))
}
// After modification, IPv4 packets need a valid checksum.
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
netHeader := header.IPv4(pkt.NetworkHeader().View())
netHeader.SetChecksum(0)
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
pkt.NatDone = true
// Update the state of tcb.
// TODO(gvisor.dev/issue/170): Add support in tcpcontrack to handle
@ -638,8 +642,8 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ
if conn == nil {
// Not a tracked connection.
return "", 0, &tcpip.ErrNotConnected{}
} else if conn.manip == manipNone {
// Unmanipulated connection.
} else if conn.manip != manipDestination {
// Unmanipulated destination.
return "", 0, &tcpip.ErrInvalidOptionValue{}
}

View File

@ -0,0 +1,41 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at //
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Code generated by "stringer -type Hook ."; DO NOT EDIT.
package stack
import "strconv"
// _ is a compile-time guard emitted by stringer. Each index expression below
// is only in range for the one-element array x while the named Hook constant
// still has the value recorded here (Prerouting=0 ... NumHooks=5); the
// function is never called.
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[Prerouting-0]
_ = x[Input-1]
_ = x[Forward-2]
_ = x[Output-3]
_ = x[Postrouting-4]
_ = x[NumHooks-5]
}
const _Hook_name = "PreroutingInputForwardOutputPostroutingNumHooks"
var _Hook_index = [...]uint8{0, 10, 15, 22, 28, 39, 47}
// String returns the name of Hook i, taken as the slice of _Hook_name between
// consecutive offsets in _Hook_index. Values at or beyond the number of named
// constants are rendered as "Hook(N)" with N in decimal.
func (i Hook) String() string {
// _Hook_index holds one more offset than there are names, hence the -1.
if i >= Hook(len(_Hook_index)-1) {
return "Hook(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _Hook_name[_Hook_index[i]:_Hook_index[i+1]]
}

View File

@ -175,9 +175,10 @@ func DefaultTables() *IPTables {
},
},
priorities: [NumHooks][]TableID{
Prerouting: {MangleID, NATID},
Input: {NATID, FilterID},
Output: {MangleID, NATID, FilterID},
Prerouting: {MangleID, NATID},
Input: {NATID, FilterID},
Output: {MangleID, NATID, FilterID},
Postrouting: {MangleID, NATID},
},
connections: ConnTrack{
seed: generateRandUint32(),

View File

@ -182,3 +182,81 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs
return RuleAccept, 0
}
// SNATTarget modifies the source port/IP in the outgoing packets.
type SNATTarget struct {
// Addr is the address written into the packet's source address (UDP) or
// recorded as the translated source for the tracked connection (TCP).
Addr tcpip.Address
// Port is the port written into the packet's source port (UDP) or recorded
// as the translated source port for the tracked connection (TCP).
Port uint16
// NetworkProtocol is the network protocol the target is used with. It
// is immutable.
NetworkProtocol tcpip.NetworkProtocolNumber
}
// Action implements Target.Action.
//
// For UDP the source port and address of pkt are rewritten in place to
// st.Port and st.Addr and the checksums are refreshed. For TCP a source-NAT
// connection is inserted into ct and the packet is then handed to
// ct.handlePacket for the actual rewrite. Packets that already have NatDone
// set are accepted untouched; packets with an empty network or transport
// header, or with any other transport protocol, are dropped. The returned
// int is always 0, and the address parameter is not used by this target.
//
// Action panics if the packet's network protocol differs from
// st.NetworkProtocol or if hook is anything other than Postrouting or Input.
func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) {
// Sanity check.
if st.NetworkProtocol != pkt.NetworkProtocolNumber {
panic(fmt.Sprintf(
"SNATTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d",
st.NetworkProtocol, pkt.NetworkProtocolNumber))
}
// Packet is already manipulated.
if pkt.NatDone {
return RuleAccept, 0
}
// Drop the packet if network and transport header are not set.
if pkt.NetworkHeader().View().IsEmpty() || pkt.TransportHeader().View().IsEmpty() {
return RuleDrop, 0
}
// SNAT is only handled at Postrouting and Input; reaching this target from
// any other hook is treated as a programming error.
switch hook {
case Postrouting, Input:
case Prerouting, Output, Forward:
panic(fmt.Sprintf("%s not supported", hook))
default:
panic(fmt.Sprintf("%s unrecognized", hook))
}
switch protocol := pkt.TransportProtocolNumber; protocol {
case header.UDPProtocolNumber:
udpHeader := header.UDP(pkt.TransportHeader().View())
// The checksum field is cleared before the header is modified; it is only
// filled in again below when software checksumming is required.
udpHeader.SetChecksum(0)
udpHeader.SetSourcePort(st.Port)
netHeader := pkt.Network()
netHeader.SetSourceAddress(st.Addr)
// Only calculate the checksum if offloading isn't supported.
if r.RequiresTXTransportChecksum() {
// Transport length is the whole packet minus the network header.
length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View()))
xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length)
xsum = header.ChecksumCombine(xsum, pkt.Data().AsRange().Checksum())
udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum))
}
// After modification, IPv4 packets need a valid checksum.
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber {
netHeader := header.IPv4(pkt.NetworkHeader().View())
netHeader.SetChecksum(0)
netHeader.SetChecksum(^netHeader.CalculateChecksum())
}
pkt.NatDone = true
case header.TCPProtocolNumber:
// Without a connection tracker there is nothing to record; the packet is
// accepted unmodified.
if ct == nil {
return RuleAccept, 0
}
// Set up connection for matching NAT rule. Only the first
// packet of the connection comes here. Other packets will be
// manipulated in connection tracking.
if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil {
ct.handlePacket(pkt, hook, gso, r)
}
default:
return RuleDrop, 0
}
return RuleAccept, 0
}

View File

@ -299,9 +299,18 @@ func (pk *PacketBuffer) Network() header.Network {
// See PacketBuffer.Data for details about how a packet buffer holds an inbound
// packet.
func (pk *PacketBuffer) CloneToInbound() *PacketBuffer {
return NewPacketBuffer(PacketBufferOptions{
newPk := NewPacketBuffer(PacketBufferOptions{
Data: buffer.NewVectorisedView(pk.Size(), pk.Views()),
})
// TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to
// maintain this flag in the packet. Currently conntrack needs this flag to
// tell if a noop connection should be inserted at Input hook. Once conntrack
// redefines the manipulation field as mutable, we won't need the special noop
// connection.
if pk.NatDone {
newPk.NatDone = true
}
return newPk
}
// headerInfo stores metadata about a header in a packet.

View File

@ -1556,6 +1556,10 @@ type IPStats struct {
// chain.
IPTablesOutputDropped *StatCounter
// IPTablesPostroutingDropped is the number of IP packets dropped in the
// Postrouting chain.
IPTablesPostroutingDropped *StatCounter
// TODO(https://gvisor.dev/issues/5529): Move the IPv4-only option stats out
// of IPStats.
// OptionTimestampReceived is the number of Timestamp options seen.

View File

@ -456,3 +456,11 @@ func TestNATPreRECVORIGDSTADDR(t *testing.T) {
// TestNATOutRECVORIGDSTADDR runs the NATOutRECVORIGDSTADDR case through
// singleTest.
func TestNATOutRECVORIGDSTADDR(t *testing.T) {
	singleTest(t, new(NATOutRECVORIGDSTADDR))
}
// TestNATPostSNATUDP runs the NATPostSNATUDP case through singleTest.
func TestNATPostSNATUDP(t *testing.T) {
	singleTest(t, new(NATPostSNATUDP))
}
// TestNATPostSNATTCP runs the NATPostSNATTCP case through singleTest.
func TestNATPostSNATTCP(t *testing.T) {
	singleTest(t, new(NATPostSNATTCP))
}

View File

@ -69,29 +69,41 @@ func tableRules(ipv6 bool, table string, argsList [][]string) error {
return nil
}
// listenUDP waits for the first datagram on the given UDP port. It returns
// nil when that read succeeds and the error reported by listenUDPFrom
// otherwise; the sender's address is discarded.
func listenUDP(ctx context.Context, port int, ipv6 bool) error {
	if _, err := listenUDPFrom(ctx, port, ipv6); err != nil {
		return err
	}
	return nil
}
// listenUDPFrom listens on a UDP port and returns the sender's UDP address if
// the first read from that port is successful.
//
// The read runs in a helper goroutine so that it can be abandoned when ctx is
// done; in that case an error wrapping ctx.Err() is returned. Errors from
// opening the socket or from the read itself are returned as-is.
func listenUDPFrom(ctx context.Context, port int, ipv6 bool) (*net.UDPAddr, error) {
	localAddr := net.UDPAddr{
		Port: port,
	}
	conn, err := net.ListenUDP(udpNetwork(ipv6), &localAddr)
	if err != nil {
		return nil, err
	}
	// Closing the socket on return also unblocks a read that is still pending
	// after a timeout.
	defer conn.Close()

	type result struct {
		remoteAddr *net.UDPAddr
		err        error
	}

	// Buffered so the reader can always deliver its result and exit, even
	// when nobody is receiving any more because ctx finished first.
	ch := make(chan result, 1)
	go func() {
		_, remoteAddr, err := conn.ReadFromUDP([]byte{0})
		ch <- result{remoteAddr, err}
	}()

	select {
	case res := <-ch:
		return res.remoteAddr, res.err
	case <-ctx.Done():
		return nil, fmt.Errorf("timed out reading from %s: %w", &localAddr, ctx.Err())
	}
}
@ -125,8 +137,16 @@ func sendUDPLoop(ctx context.Context, ip net.IP, port int, ipv6 bool) error {
}
}
// listenTCP waits for one connection on the given TCP port. It returns nil
// once a connection has been accepted and the error reported by listenTCPFrom
// otherwise; the peer's address is discarded.
func listenTCP(ctx context.Context, port int, ipv6 bool) error {
	if _, err := listenTCPFrom(ctx, port, ipv6); err != nil {
		return err
	}
	return nil
}
// listenTCP listens for connections on a TCP port, and returns the remote
// TCP address if a connection is established.
func listenTCPFrom(ctx context.Context, port int, ipv6 bool) (net.Addr, error) {
localAddr := net.TCPAddr{
Port: port,
}
@ -134,23 +154,32 @@ func listenTCP(ctx context.Context, port int, ipv6 bool) error {
// Starts listening on port.
lConn, err := net.ListenTCP(tcpNetwork(ipv6), &localAddr)
if err != nil {
return err
return nil, err
}
defer lConn.Close()
type result struct {
remoteAddr net.Addr
err error
}
// Accept connections on port.
ch := make(chan error)
ch := make(chan result)
go func() {
conn, err := lConn.AcceptTCP()
ch <- err
var remoteAddr net.Addr
if err == nil {
remoteAddr = conn.RemoteAddr()
}
ch <- result{remoteAddr, err}
conn.Close()
}()
select {
case err := <-ch:
return err
case res := <-ch:
return res.remoteAddr, res.err
case <-ctx.Done():
return fmt.Errorf("timed out waiting for a connection at %#v: %w", localAddr, ctx.Err())
return nil, fmt.Errorf("timed out waiting for a connection at %s: %w", &localAddr, ctx.Err())
}
}

View File

@ -19,6 +19,7 @@ import (
"errors"
"fmt"
"net"
"strconv"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/binary"
@ -48,6 +49,8 @@ func init() {
RegisterTestCase(&NATOutOriginalDst{})
RegisterTestCase(&NATPreRECVORIGDSTADDR{})
RegisterTestCase(&NATOutRECVORIGDSTADDR{})
RegisterTestCase(&NATPostSNATUDP{})
RegisterTestCase(&NATPostSNATTCP{})
}
// NATPreRedirectUDPPort tests that packets are redirected to different port.
@ -486,7 +489,12 @@ func (*NATLoopbackSkipsPrerouting) Name() string {
// ContainerAction implements TestCase.ContainerAction.
func (*NATLoopbackSkipsPrerouting) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
// Redirect anything sent to localhost to an unused port.
dest := []byte{127, 0, 0, 1}
var dest net.IP
if ipv6 {
dest = net.IPv6loopback
} else {
dest = net.IPv4(127, 0, 0, 1)
}
if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil {
return err
}
@ -915,3 +923,115 @@ func addrMatches6(got unix.RawSockaddrInet6, wantAddrs []net.IP, port uint16) er
}
return fmt.Errorf("got %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, wantAddrs)
}
// Translated source used by the SNAT tests: snatAddrV4 and snatAddrV6 are the
// addresses NATPostSNATUDP passes to --to-source, and snatPort is the source
// port both NATPostSNATUDP and NATPostSNATTCP expect the peer to see.
const (
snatAddrV4 = "194.236.50.155"
snatAddrV6 = "2a0a::1"
snatPort = 43
)
// NATPostSNATUDP tests that the source port/IP in the packets are modified as expected.
// The container installs a POSTROUTING SNAT rule for UDP and sends datagrams;
// the local side checks the sender address of the first datagram it reads.
type NATPostSNATUDP struct{ localCase }
// Compile-time check that *NATPostSNATUDP satisfies TestCase.
var _ TestCase = (*NATPostSNATUDP)(nil)
// Name implements TestCase.Name. It reports the identifier of this test
// case.
func (*NATPostSNATUDP) Name() string {
	const name = "NATPostSNATUDP"
	return name
}
// ContainerAction implements TestCase.ContainerAction.
//
// It appends a POSTROUTING rule that SNATs UDP traffic to the test's SNAT
// address and port for the requested IP family, then sends UDP datagrams to
// ip on acceptPort.
func (*NATPostSNATUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
	// The IPv6 form brackets the address so the port separator is
	// unambiguous.
	toSource := fmt.Sprintf("%s:%d", snatAddrV4, snatPort)
	if ipv6 {
		toSource = fmt.Sprintf("[%s]:%d", snatAddrV6, snatPort)
	}

	err := natTable(ipv6, "-A", "POSTROUTING", "-p", "udp", "-j", "SNAT", "--to-source", toSource)
	if err != nil {
		return err
	}
	return sendUDPLoop(ctx, ip, acceptPort, ipv6)
}
// LocalAction implements TestCase.LocalAction.
//
// It waits for the first datagram on acceptPort and checks that its sender is
// the SNAT address for the requested IP family and that its source port is
// snatPort.
func (*NATPostSNATUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
	sender, err := listenUDPFrom(ctx, acceptPort, ipv6)
	if err != nil {
		return err
	}

	wantAddr := snatAddrV4
	if ipv6 {
		wantAddr = snatAddrV6
	}
	wantIP := net.ParseIP(wantAddr)
	if !sender.IP.Equal(wantIP) {
		return fmt.Errorf("got remote address = %s, want = %s", sender.IP, wantIP)
	}
	if sender.Port != snatPort {
		return fmt.Errorf("got remote port = %d, want = %d", sender.Port, snatPort)
	}
	return nil
}
// NATPostSNATTCP tests that the source port/IP in the packets are modified as
// expected. The container installs a POSTROUTING SNAT rule for TCP that maps
// to one of its own interface addresses with snatPort and then connects out;
// the local side checks the remote address of the accepted connection.
type NATPostSNATTCP struct{ localCase }
// Compile-time check that *NATPostSNATTCP satisfies TestCase.
var _ TestCase = (*NATPostSNATTCP)(nil)
// Name implements TestCase.Name. It reports the identifier of this test
// case.
func (*NATPostSNATTCP) Name() string {
	const name = "NATPostSNATTCP"
	return name
}
// ContainerAction implements TestCase.ContainerAction.
//
// It picks one of the container's interface addresses for the requested IP
// family (any IPv4 address, or a global unicast IPv6 address), appends a
// POSTROUTING rule that SNATs TCP traffic to that address with snatPort, and
// then connects to ip on acceptPort. When several addresses qualify, the last
// one returned by getInterfaceAddrs is used.
func (*NATPostSNATTCP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error {
	addrs, err := getInterfaceAddrs(ipv6)
	if err != nil {
		return err
	}

	toSource := ""
	for _, candidate := range addrs {
		isV4 := candidate.To4() != nil
		switch {
		case isV4 && !ipv6:
			toSource = fmt.Sprintf("%s:%d", candidate, snatPort)
		case !isV4 && ipv6 && candidate.IsGlobalUnicast():
			toSource = fmt.Sprintf("[%s]:%d", candidate, snatPort)
		}
	}
	if toSource == "" {
		return fmt.Errorf("can't find any interface address to use")
	}

	err = natTable(ipv6, "-A", "POSTROUTING", "-p", "tcp", "-j", "SNAT", "--to-source", toSource)
	if err != nil {
		return err
	}
	return connectTCP(ctx, ip, acceptPort, ipv6)
}
// LocalAction implements TestCase.LocalAction.
//
// It accepts one connection on acceptPort and checks that the peer's address
// is ip and that the peer's port is snatPort. The address is compared as an
// IP value, matching the check in NATPostSNATUDP, so equivalent textual forms
// of the same address are not reported as a mismatch.
func (*NATPostSNATTCP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error {
	remote, err := listenTCPFrom(ctx, acceptPort, ipv6)
	if err != nil {
		return err
	}
	hostStr, portStr, err := net.SplitHostPort(remote.String())
	if err != nil {
		return err
	}
	// A host that does not parse as an IP is always a mismatch; the raw text
	// is reported so the failure shows what was actually received.
	if got := net.ParseIP(hostStr); got == nil || !got.Equal(ip) {
		return fmt.Errorf("got remote address = %s, want = %s", hostStr, ip)
	}
	port, err := strconv.Atoi(portStr)
	if err != nil {
		return err
	}
	if got, want := port, snatPort; got != want {
		return fmt.Errorf("got remote port = %d, want = %d", got, want)
	}
	return nil
}