445 lines
13 KiB
Go
445 lines
13 KiB
Go
// Copyright 2018 The gVisor Authors.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
// Binary tcp_proxy is a simple TCP proxy.
|
|
package main
|
|
|
|
import (
|
|
"encoding/gob"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"math/rand"
|
|
"net"
|
|
"os"
|
|
"os/signal"
|
|
"regexp"
|
|
"runtime"
|
|
"runtime/pprof"
|
|
"strconv"
|
|
"syscall"
|
|
"time"
|
|
|
|
"golang.org/x/sys/unix"
|
|
"gvisor.dev/gvisor/pkg/tcpip"
|
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
|
"gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
|
|
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
|
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
|
)
|
|
|
|
var (
	port    = flag.Int("port", 0, "bind port (all addresses)")
	forward = flag.String("forward", "", "forwarding target")

	// client and server pick which side of the proxy runs on netstack:
	// with -client the forwarding target is dialed through netstack, and
	// with -server incoming connections are accepted through netstack.
	client = flag.Bool("client", false, "use netstack for dial")
	server = flag.Bool("server", false, "use netstack for listen")

	// Netstack-specific options.
	mtu                = flag.Int("mtu", 1280, "mtu for network stack")
	addr               = flag.String("addr", "", "address for tap-based netstack")
	mask               = flag.Int("mask", 8, "mask size for address")
	iface              = flag.String("iface", "", "network interface name to bind for netstack")
	sack               = flag.Bool("sack", false, "enable SACK support for netstack")
	cubic              = flag.Bool("cubic", false, "enable use of CUBIC congestion control for netstack")
	gso                = flag.Int("gso", 0, "GSO maximum size")
	swgso              = flag.Bool("swgso", false, "software-level GSO")
	clientTCPProbeFile = flag.String("client_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.")
	serverTCPProbeFile = flag.String("server_tcp_probe_file", "", "if specified, installs a tcp probe to dump endpoint state to the specified file.")
	cpuprofile         = flag.String("cpuprofile", "", "write cpu profile to the specified file.")
	memprofile         = flag.String("memprofile", "", "write memory profile to the specified file.")
)
|
|
|
|
// impl is one side of the proxy: either the host network stack or a
// netstack instance.
type impl interface {
	// listen opens a TCP listener on the given port.
	listen(port int) (net.Listener, error)

	// dial opens a TCP connection to address ("host:port").
	dial(address string) (net.Conn, error)

	// printStats logs implementation-specific statistics, if any.
	printStats()
}
|
|
|
|
// netImpl implements impl on top of the host's network stack via the
// standard net package.
type netImpl struct{}

// dial opens a host TCP connection to address.
func (netImpl) dial(address string) (net.Conn, error) {
	const network = "tcp"
	return net.Dial(network, address)
}

// listen opens a host TCP listener on port across all local addresses.
func (netImpl) listen(port int) (net.Listener, error) {
	bindAddr := fmt.Sprintf(":%d", port)
	return net.Listen("tcp", bindAddr)
}

// printStats is a no-op: the host stack exposes no stats here.
func (netImpl) printStats() {}
|
|
|
|
const (
	// nicID is the ID of the single NIC created on each netstack.
	nicID = 1

	// bufSize is the SO_RCVBUF/SO_SNDBUF size requested for raw sockets (4 MiB).
	bufSize = 4 * 1024 * 1024
)
|
|
|
|
type netstackImpl struct {
|
|
s *stack.Stack
|
|
addr tcpip.Address
|
|
mode string
|
|
}
|
|
|
|
func setupNetwork(ifaceName string, numChannels int) (fds []int, err error) {
|
|
// Get all interfaces in the namespace.
|
|
ifaces, err := net.Interfaces()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying interfaces: %v", err)
|
|
}
|
|
|
|
for _, iface := range ifaces {
|
|
if iface.Name != ifaceName {
|
|
continue
|
|
}
|
|
// Create the socket.
|
|
const protocol = 0x0300 // htons(ETH_P_ALL)
|
|
fds := make([]int, numChannels)
|
|
for i := range fds {
|
|
fd, err := syscall.Socket(syscall.AF_PACKET, syscall.SOCK_RAW, protocol)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to create raw socket: %v", err)
|
|
}
|
|
|
|
// Bind to the appropriate device.
|
|
ll := syscall.SockaddrLinklayer{
|
|
Protocol: protocol,
|
|
Ifindex: iface.Index,
|
|
Pkttype: syscall.PACKET_HOST,
|
|
}
|
|
if err := syscall.Bind(fd, &ll); err != nil {
|
|
return nil, fmt.Errorf("unable to bind to %q: %v", iface.Name, err)
|
|
}
|
|
|
|
// RAW Sockets by default have a very small SO_RCVBUF of 256KB,
|
|
// up it to at least 4MB to reduce packet drops.
|
|
if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, bufSize); err != nil {
|
|
return nil, fmt.Errorf("setsockopt(..., SO_RCVBUF, %v,..) = %v", bufSize, err)
|
|
}
|
|
|
|
if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, bufSize); err != nil {
|
|
return nil, fmt.Errorf("setsockopt(..., SO_SNDBUF, %v,..) = %v", bufSize, err)
|
|
}
|
|
|
|
if !*swgso && *gso != 0 {
|
|
if err := syscall.SetsockoptInt(fd, syscall.SOL_PACKET, unix.PACKET_VNET_HDR, 1); err != nil {
|
|
return nil, fmt.Errorf("unable to enable the PACKET_VNET_HDR option: %v", err)
|
|
}
|
|
}
|
|
fds[i] = fd
|
|
}
|
|
return fds, nil
|
|
}
|
|
return nil, fmt.Errorf("failed to find interface: %v", ifaceName)
|
|
}
|
|
|
|
func newNetstackImpl(mode string) (impl, error) {
|
|
fds, err := setupNetwork(*iface, runtime.GOMAXPROCS(-1))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Parse details.
|
|
parsedAddr := tcpip.Address(net.ParseIP(*addr).To4())
|
|
parsedDest := tcpip.Address("") // Filled in below.
|
|
parsedMask := tcpip.AddressMask("") // Filled in below.
|
|
switch *mask {
|
|
case 8:
|
|
parsedDest = tcpip.Address([]byte{parsedAddr[0], 0, 0, 0})
|
|
parsedMask = tcpip.AddressMask([]byte{0xff, 0, 0, 0})
|
|
case 16:
|
|
parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], 0, 0})
|
|
parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0, 0})
|
|
case 24:
|
|
parsedDest = tcpip.Address([]byte{parsedAddr[0], parsedAddr[1], parsedAddr[2], 0})
|
|
parsedMask = tcpip.AddressMask([]byte{0xff, 0xff, 0xff, 0})
|
|
default:
|
|
// This is just laziness; we don't expect a different mask.
|
|
return nil, fmt.Errorf("mask %d not supported", mask)
|
|
}
|
|
|
|
// Create a new network stack.
|
|
netProtos := []stack.NetworkProtocol{ipv4.NewProtocol(), arp.NewProtocol()}
|
|
transProtos := []stack.TransportProtocol{tcp.NewProtocol(), udp.NewProtocol()}
|
|
s := stack.New(stack.Options{
|
|
NetworkProtocols: netProtos,
|
|
TransportProtocols: transProtos,
|
|
})
|
|
|
|
// Generate a new mac for the eth device.
|
|
mac := make(net.HardwareAddr, 6)
|
|
rand.Read(mac) // Fill with random data.
|
|
mac[0] &^= 0x1 // Clear multicast bit.
|
|
mac[0] |= 0x2 // Set local assignment bit (IEEE802).
|
|
ep, err := fdbased.New(&fdbased.Options{
|
|
FDs: fds,
|
|
MTU: uint32(*mtu),
|
|
EthernetHeader: true,
|
|
Address: tcpip.LinkAddress(mac),
|
|
// Enable checksum generation as we need to generate valid
|
|
// checksums for the veth device to deliver our packets to the
|
|
// peer. But we do want to disable checksum verification as veth
|
|
// devices do perform GRO and the linux host kernel may not
|
|
// regenerate valid checksums after GRO.
|
|
TXChecksumOffload: false,
|
|
RXChecksumOffload: true,
|
|
PacketDispatchMode: fdbased.RecvMMsg,
|
|
GSOMaxSize: uint32(*gso),
|
|
SoftwareGSOEnabled: *swgso,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create FD endpoint: %v", err)
|
|
}
|
|
if err := s.CreateNIC(nicID, ep); err != nil {
|
|
return nil, fmt.Errorf("error creating NIC %q: %v", *iface, err)
|
|
}
|
|
if err := s.AddAddress(nicID, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
|
|
return nil, fmt.Errorf("error adding ARP address to %q: %v", *iface, err)
|
|
}
|
|
if err := s.AddAddress(nicID, ipv4.ProtocolNumber, parsedAddr); err != nil {
|
|
return nil, fmt.Errorf("error adding IP address to %q: %v", *iface, err)
|
|
}
|
|
|
|
subnet, err := tcpip.NewSubnet(parsedDest, parsedMask)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("tcpip.Subnet(%s, %s): %s", parsedDest, parsedMask, err)
|
|
}
|
|
// Add default route; we only support
|
|
s.SetRouteTable([]tcpip.Route{
|
|
{
|
|
Destination: subnet,
|
|
NIC: nicID,
|
|
},
|
|
})
|
|
|
|
// Set protocol options.
|
|
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(*sack)); err != nil {
|
|
return nil, fmt.Errorf("SetTransportProtocolOption for SACKEnabled failed: %v", err)
|
|
}
|
|
|
|
// Set Congestion Control to cubic if requested.
|
|
if *cubic {
|
|
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.CongestionControlOption("cubic")); err != nil {
|
|
return nil, fmt.Errorf("SetTransportProtocolOption for CongestionControlOption(cubic) failed: %v", err)
|
|
}
|
|
}
|
|
|
|
return netstackImpl{
|
|
s: s,
|
|
addr: parsedAddr,
|
|
mode: mode,
|
|
}, nil
|
|
}
|
|
|
|
func (n netstackImpl) dial(address string) (net.Conn, error) {
|
|
host, port, err := net.SplitHostPort(address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if host == "" {
|
|
// A host must be provided for the dial.
|
|
return nil, fmt.Errorf("no host provided")
|
|
}
|
|
portNumber, err := strconv.Atoi(port)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
addr := tcpip.FullAddress{
|
|
NIC: nicID,
|
|
Addr: tcpip.Address(net.ParseIP(host).To4()),
|
|
Port: uint16(portNumber),
|
|
}
|
|
conn, err := gonet.DialTCP(n.s, addr, ipv4.ProtocolNumber)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return conn, nil
|
|
}
|
|
|
|
func (n netstackImpl) listen(port int) (net.Listener, error) {
|
|
addr := tcpip.FullAddress{
|
|
NIC: nicID,
|
|
Port: uint16(port),
|
|
}
|
|
listener, err := gonet.ListenTCP(n.s, addr, ipv4.ProtocolNumber)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return listener, nil
|
|
}
|
|
|
|
// zeroFieldsRegexp matches a "Name:0" field, together with any whitespace
// before it, in %+v output. printStats uses it to drop zero fields.
var zeroFieldsRegexp = regexp.MustCompile("\\s*[a-zA-Z0-9]*:0")
|
|
|
|
func (n netstackImpl) printStats() {
|
|
// Don't show zero fields.
|
|
stats := zeroFieldsRegexp.ReplaceAllString(fmt.Sprintf("%+v", n.s.Stats()), "")
|
|
log.Printf("netstack %s Stats: %+v\n", n.mode, stats)
|
|
}
|
|
|
|
// installProbe installs a TCP Probe function that will dump endpoint
|
|
// state to the specified file. It also returns a close func() that
|
|
// can be used to close the probeFile.
|
|
func (n netstackImpl) installProbe(probeFileName string) (close func()) {
|
|
// Install Probe to dump out end point state.
|
|
probeFile, err := os.Create(probeFileName)
|
|
if err != nil {
|
|
log.Fatalf("failed to create tcp_probe file %s: %v", probeFileName, err)
|
|
}
|
|
probeEncoder := gob.NewEncoder(probeFile)
|
|
// Install a TCP Probe.
|
|
n.s.AddTCPProbe(func(state stack.TCPEndpointState) {
|
|
probeEncoder.Encode(state)
|
|
})
|
|
return func() { probeFile.Close() }
|
|
}
|
|
|
|
func main() {
|
|
flag.Parse()
|
|
if *port == 0 {
|
|
log.Fatalf("no port provided")
|
|
}
|
|
if *forward == "" {
|
|
log.Fatalf("no forward provided")
|
|
}
|
|
// Seed the random number generator to ensure that we are given MAC addresses that don't
|
|
// for the case of the client and server stack.
|
|
rand.Seed(time.Now().UTC().UnixNano())
|
|
|
|
if *cpuprofile != "" {
|
|
f, err := os.Create(*cpuprofile)
|
|
if err != nil {
|
|
log.Fatal("could not create CPU profile: ", err)
|
|
}
|
|
defer func() {
|
|
if err := f.Close(); err != nil {
|
|
log.Print("error closing CPU profile: ", err)
|
|
}
|
|
}()
|
|
if err := pprof.StartCPUProfile(f); err != nil {
|
|
log.Fatal("could not start CPU profile: ", err)
|
|
}
|
|
defer pprof.StopCPUProfile()
|
|
}
|
|
|
|
var (
|
|
in impl
|
|
out impl
|
|
err error
|
|
)
|
|
if *server {
|
|
in, err = newNetstackImpl("server")
|
|
if *serverTCPProbeFile != "" {
|
|
defer in.(netstackImpl).installProbe(*serverTCPProbeFile)()
|
|
}
|
|
|
|
} else {
|
|
in = netImpl{}
|
|
}
|
|
if err != nil {
|
|
log.Fatalf("netstack error: %v", err)
|
|
}
|
|
if *client {
|
|
out, err = newNetstackImpl("client")
|
|
if *clientTCPProbeFile != "" {
|
|
defer out.(netstackImpl).installProbe(*clientTCPProbeFile)()
|
|
}
|
|
} else {
|
|
out = netImpl{}
|
|
}
|
|
if err != nil {
|
|
log.Fatalf("netstack error: %v", err)
|
|
}
|
|
|
|
// Dial forward before binding.
|
|
var next net.Conn
|
|
for {
|
|
next, err = out.dial(*forward)
|
|
if err == nil {
|
|
break
|
|
}
|
|
time.Sleep(50 * time.Millisecond)
|
|
log.Printf("connect failed retrying: %v", err)
|
|
}
|
|
|
|
// Bind once to the server socket.
|
|
listener, err := in.listen(*port)
|
|
if err != nil {
|
|
// Should not happen, everything must be bound by this time
|
|
// this proxy is started.
|
|
log.Fatalf("unable to listen: %v", err)
|
|
}
|
|
log.Printf("client=%v, server=%v, ready.", *client, *server)
|
|
|
|
sigs := make(chan os.Signal, 1)
|
|
signal.Notify(sigs, syscall.SIGTERM)
|
|
go func() {
|
|
<-sigs
|
|
if *cpuprofile != "" {
|
|
pprof.StopCPUProfile()
|
|
}
|
|
if *memprofile != "" {
|
|
f, err := os.Create(*memprofile)
|
|
if err != nil {
|
|
log.Fatal("could not create memory profile: ", err)
|
|
}
|
|
defer func() {
|
|
if err := f.Close(); err != nil {
|
|
log.Print("error closing memory profile: ", err)
|
|
}
|
|
}()
|
|
runtime.GC() // get up-to-date statistics
|
|
if err := pprof.WriteHeapProfile(f); err != nil {
|
|
log.Fatalf("Unable to write heap profile: %v", err)
|
|
}
|
|
}
|
|
os.Exit(0)
|
|
}()
|
|
|
|
for {
|
|
// Forward all connections.
|
|
inConn, err := listener.Accept()
|
|
if err != nil {
|
|
// This should not happen; we are listening
|
|
// successfully. Exhausted all available FDs?
|
|
log.Fatalf("accept error: %v", err)
|
|
}
|
|
log.Printf("incoming connection established.")
|
|
|
|
// Copy both ways.
|
|
go io.Copy(inConn, next)
|
|
go io.Copy(next, inConn)
|
|
|
|
// Print stats every second.
|
|
go func() {
|
|
t := time.NewTicker(time.Second)
|
|
defer t.Stop()
|
|
for {
|
|
<-t.C
|
|
in.printStats()
|
|
out.printStats()
|
|
}
|
|
}()
|
|
|
|
for {
|
|
// Dial again.
|
|
next, err = out.dial(*forward)
|
|
if err == nil {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|