tcpip/transport/udp: add Forwarder type
Add a UDP forwarder for intercepting and forwarding UDP sessions. Change-Id: I2d83c900c1931adfc59a532dd4f6b33a0db406c9 PiperOrigin-RevId: 244293576
This commit is contained in:
parent
f4d434c180
commit
cec2cdc12f
|
@ -222,6 +222,62 @@ func TestCloseReaderWithForwarder(t *testing.T) {
|
||||||
sender.close()
|
sender.close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUDPForwarder(t *testing.T) {
|
||||||
|
s, terr := newLoopbackStack()
|
||||||
|
if terr != nil {
|
||||||
|
t.Fatalf("newLoopbackStack() = %v", terr)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
|
||||||
|
addr1 := tcpip.FullAddress{NICID, ip1, 11211}
|
||||||
|
s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
|
||||||
|
ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
|
||||||
|
addr2 := tcpip.FullAddress{NICID, ip2, 11311}
|
||||||
|
s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
fwd := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
|
||||||
|
defer close(done)
|
||||||
|
|
||||||
|
var wq waiter.Queue
|
||||||
|
ep, err := r.CreateEndpoint(&wq)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("r.CreateEndpoint() = %v", err)
|
||||||
|
}
|
||||||
|
defer ep.Close()
|
||||||
|
|
||||||
|
c := NewConn(&wq, ep)
|
||||||
|
|
||||||
|
buf := make([]byte, 256)
|
||||||
|
n, e := c.Read(buf)
|
||||||
|
if e != nil {
|
||||||
|
t.Errorf("c.Read() = %v", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, e := c.Write(buf[:n]); e != nil {
|
||||||
|
t.Errorf("c.Write() = %v", e)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
s.SetTransportProtocolHandler(udp.ProtocolNumber, fwd.HandlePacket)
|
||||||
|
|
||||||
|
c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("NewPacketConn(port 5):", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sent := "abc123"
|
||||||
|
sendAddr := fullToUDPAddr(addr1)
|
||||||
|
if n, err := c2.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) {
|
||||||
|
t.Errorf("c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 256)
|
||||||
|
n, recvAddr, err := c2.ReadFrom(buf)
|
||||||
|
if err != nil || recvAddr.String() != sendAddr.String() {
|
||||||
|
t.Errorf("c1.ReadFrom() = %d, %v, %v, want = %d, %v, %v", n, recvAddr, err, len(sent), sendAddr, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// TestDeadlineChange tests that changing the deadline affects currently blocked reads.
|
// TestDeadlineChange tests that changing the deadline affects currently blocked reads.
|
||||||
func TestDeadlineChange(t *testing.T) {
|
func TestDeadlineChange(t *testing.T) {
|
||||||
s, err := newLoopbackStack()
|
s, err := newLoopbackStack()
|
||||||
|
|
|
@ -20,6 +20,7 @@ go_library(
|
||||||
srcs = [
|
srcs = [
|
||||||
"endpoint.go",
|
"endpoint.go",
|
||||||
"endpoint_state.go",
|
"endpoint_state.go",
|
||||||
|
"forwarder.go",
|
||||||
"protocol.go",
|
"protocol.go",
|
||||||
"udp_packet_list.go",
|
"udp_packet_list.go",
|
||||||
],
|
],
|
||||||
|
|
|
@ -0,0 +1,96 @@
|
||||||
|
// Copyright 2019 Google LLC
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package udp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"gvisor.googlesource.com/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
|
||||||
|
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.googlesource.com/gvisor/pkg/waiter"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Forwarder is a session request forwarder, which allows clients to decide
|
||||||
|
// what to do with a session request, for example: ignore it, or process it.
|
||||||
|
//
|
||||||
|
// The canonical way of using it is to pass the Forwarder.HandlePacket function
|
||||||
|
// to stack.SetTransportProtocolHandler.
|
||||||
|
type Forwarder struct {
|
||||||
|
handler func(*ForwarderRequest)
|
||||||
|
|
||||||
|
stack *stack.Stack
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewForwarder allocates and initializes a new forwarder.
|
||||||
|
func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
|
||||||
|
return &Forwarder{
|
||||||
|
stack: s,
|
||||||
|
handler: handler,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandlePacket handles all packets.
|
||||||
|
//
|
||||||
|
// This function is expected to be passed as an argument to the
|
||||||
|
// stack.SetTransportProtocolHandler function.
|
||||||
|
func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
|
||||||
|
f.handler(&ForwarderRequest{
|
||||||
|
stack: f.stack,
|
||||||
|
route: r,
|
||||||
|
id: id,
|
||||||
|
vv: vv,
|
||||||
|
})
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForwarderRequest represents a session request received by the forwarder and
|
||||||
|
// passed to the client. Clients may optionally create an endpoint to represent
|
||||||
|
// it via CreateEndpoint.
|
||||||
|
type ForwarderRequest struct {
|
||||||
|
stack *stack.Stack
|
||||||
|
route *stack.Route
|
||||||
|
id stack.TransportEndpointID
|
||||||
|
vv buffer.VectorisedView
|
||||||
|
}
|
||||||
|
|
||||||
|
// ID returns the 4-tuple (src address, src port, dst address, dst port) that
|
||||||
|
// represents the session request.
|
||||||
|
func (r *ForwarderRequest) ID() stack.TransportEndpointID {
|
||||||
|
return r.id
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateEndpoint creates a connected UDP endpoint for the session request.
|
||||||
|
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
|
||||||
|
ep := newEndpoint(r.stack, r.route.NetProto, queue)
|
||||||
|
if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil {
|
||||||
|
ep.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ep.id = r.id
|
||||||
|
ep.route = r.route.Clone()
|
||||||
|
ep.dstPort = r.id.RemotePort
|
||||||
|
ep.regNICID = r.route.NICID()
|
||||||
|
|
||||||
|
ep.state = stateConnected
|
||||||
|
|
||||||
|
ep.rcvMu.Lock()
|
||||||
|
ep.rcvReady = true
|
||||||
|
ep.rcvMu.Unlock()
|
||||||
|
|
||||||
|
ep.HandlePacket(r.route, r.id, r.vv)
|
||||||
|
|
||||||
|
return ep, nil
|
||||||
|
}
|
Loading…
Reference in New Issue