socks5/dialer.go

238 lines
5.0 KiB
Go
Raw Normal View History

2021-07-09 16:50:43 +00:00
package socks5
import (
"context"
"errors"
"fmt"
"io"
"net"
"time"
)
// ErrNetworkNotImplemented is returned, wrapped in a *net.OpError, when
// DialContext is asked for a network other than "tcp", "tcp4", "tcp6",
// "udp", "udp4" or "udp6".
var ErrNetworkNotImplemented = errors.New("network not implemented")
var (
noDeadline = time.Time{}
expiredDeadline = time.Unix(1, 0)
)
// Dialer is the socks5 dialer
type Dialer struct {
Server string
Username string
Password string
ProxyDial func(ctx context.Context, network string, addr string) (net.Conn, error)
}
// Dial connects to address on the named network through the SOCKS5 proxy.
// It is equivalent to DialContext with context.Background().
func (s *Dialer) Dial(network string, address string) (net.Conn, error) {
	ctx := context.Background()
	return s.DialContext(ctx, network, address)
}
// DialContext connects to address through the SOCKS5 proxy at s.Server.
//
// For "tcp", "tcp4" and "tcp6" it sends a CONNECT request and returns the
// proxied connection. For "udp", "udp4" and "udp6" it sends a UDP ASSOCIATE
// request and returns the connection built by handleUDP. Any other network
// fails with ErrNetworkNotImplemented. Every error is a *net.OpError with
// Op "socks".
//
// The TCP connection to the proxy is opened with s.ProxyDial when set,
// otherwise with a zero-value net.Dialer. The context's deadline, if any,
// bounds the SOCKS handshake and is cleared before the connection is
// returned. If the context is canceled before the handshake completes, the
// in-flight I/O is interrupted, the connection is closed, and the context's
// error is returned.
func (s *Dialer) DialContext(ctx context.Context, network string, address string) (net.Conn, error) {
	proxy, err := AddrFromString(s.Server)
	if err != nil {
		// Source is left unset: proxy did not parse, and storing a possibly-nil
		// *Addr in the net.Addr interface would make the field non-nil, so
		// OpError.Error would call String on it.
		return nil, &net.OpError{Op: "socks", Net: network, Addr: AddrZero, Err: err}
	}
	addr, err := AddrFromString(address)
	if err != nil {
		return nil, &net.OpError{Op: "socks", Net: network, Source: proxy, Addr: AddrZero, Err: err}
	}

	isUDP := false
	switch network {
	case "tcp", "tcp4", "tcp6":
	case "udp", "udp4", "udp6":
		isUDP = true
	default:
		return nil, &net.OpError{Op: "socks", Net: network, Source: proxy, Addr: addr, Err: ErrNetworkNotImplemented}
	}

	var conn net.Conn
	if s.ProxyDial != nil {
		conn, err = s.ProxyDial(ctx, "tcp", s.Server)
	} else {
		var dd net.Dialer
		conn, err = dd.DialContext(ctx, "tcp", s.Server)
	}
	if err != nil {
		return nil, &net.OpError{Op: "socks", Net: "tcp", Source: proxy, Addr: addr, Err: err}
	}

	if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() {
		conn.SetDeadline(deadline)
		// Deferred calls run after stopWatch (below) has returned, so the
		// watcher goroutine can no longer overwrite the cleared deadline.
		defer conn.SetDeadline(noDeadline)
	}

	// stopWatch stops the cancellation watcher and reports ctx.Err() if the
	// context was canceled before the watcher was stopped. Contexts that can
	// never be canceled (Done() == nil) need no watcher.
	stopWatch := func() error { return nil }
	if done := ctx.Done(); done != nil {
		errCh := make(chan error, 1)
		stopCh := make(chan struct{})
		go func() {
			select {
			case <-done:
				// Make any in-flight handshake read/write fail right away.
				conn.SetDeadline(expiredDeadline)
				errCh <- ctx.Err()
			case <-stopCh:
				errCh <- nil
			}
		}()
		stopWatch = func() error {
			close(stopCh)
			return <-errCh
		}
	}

	// handshake runs the SOCKS exchange and reports which protocol the failing
	// step belongs to, so errors keep the Net value of the request that failed.
	handshake := func() (net.Conn, string, error) {
		if err := s.negotiate(conn); err != nil {
			return nil, "tcp", err
		}
		if isUDP {
			uc, err := s.handleUDP(conn, addr)
			if err != nil {
				return nil, "udp", err
			}
			return uc, "udp", nil
		}
		if err := s.handleTCP(conn, addr); err != nil {
			return nil, "tcp", err
		}
		return conn, "tcp", nil
	}
	result, opNet, err := handshake()

	if ctxErr := stopWatch(); ctxErr != nil {
		// The context was canceled: either that caused err (via the expired
		// deadline) or it raced with a successful handshake and has already
		// poisoned conn's deadline. In both cases conn is unusable.
		conn.Close()
		return nil, &net.OpError{Op: "socks", Net: opNet, Source: proxy, Addr: addr, Err: ctxErr}
	}
	if err != nil {
		conn.Close()
		return nil, &net.OpError{Op: "socks", Net: opNet, Source: proxy, Addr: addr, Err: err}
	}
	return result, nil
}
// DialUDP establishes a UDP association through the proxy with an empty
// destination address and returns it as a net.PacketConn.
//
// It uses context.Background(), so the handshake is not bounded by a
// deadline.
func (s *Dialer) DialUDP() (net.PacketConn, error) {
	conn, err := s.DialContext(context.Background(), "udp", "")
	if err != nil {
		// conn is a nil interface here; asserting on it would panic.
		return nil, err
	}
	pc, ok := conn.(net.PacketConn)
	if !ok {
		conn.Close()
		return nil, fmt.Errorf("socks: udp dial returned %T, which is not a net.PacketConn", conn)
	}
	return pc, nil
}
// handleTCP sends a CONNECT request for addr over the already-negotiated
// conn and reads the server's reply. It fails if the request cannot be
// written, the reply cannot be read, or the reply status is not
// ReplyStatusSuccess.
func (s *Dialer) handleTCP(conn net.Conn, addr *Addr) error {
	atyp, host, port := addr.Socks()
	connectReq := NewRequest(CommandConnect, atyp, host, port)
	if _, werr := connectReq.WriteTo(conn); werr != nil {
		return werr
	}

	reply, rerr := NewReplyFrom(conn)
	switch {
	case rerr != nil:
		return rerr
	case reply.Status != ReplyStatusSuccess:
		return fmt.Errorf("bad reply status: %s", reply.Status)
	default:
		return nil
	}
}
// handleUDP sends a UDP ASSOCIATE request for addr over the already-negotiated
// TCP control connection conn. On success it returns a *UDPConn whose
// ServerAddr is the relay address reported by the server, with the IP of
// s.Server substituted when the reported IP is unspecified. When addr is not
// AddrZero, it is recorded as the UDPConn's ConnectedAddr.
//
// A background goroutine drains conn until it is closed or fails, then
// closes both conn and the local UDP socket, tying the association's lifetime
// to the TCP control connection.
func (s *Dialer) handleUDP(conn net.Conn, addr *Addr) (*UDPConn, error) {
	atyp, host, port := addr.Socks()
	req := NewRequest(CommandUDPAssociate, atyp, host, port)
	if _, err := req.WriteTo(conn); err != nil {
		return nil, err
	}
	rep, err := NewReplyFrom(conn)
	if err != nil {
		return nil, err
	}
	if rep.Status != ReplyStatusSuccess {
		return nil, fmt.Errorf("bad reply status: %s", rep.Status)
	}

	// Do all fallible address work before opening the local socket, so these
	// error paths have nothing to release and do not depend on the caller
	// closing conn.
	remoteAddr, err := AddrFromSocks(rep.ATYP, rep.BindAddr, rep.BindPort).UDP()
	if err != nil {
		return nil, err
	}
	if remoteAddr.IP.IsUnspecified() {
		// The server reported a wildcard relay address; send datagrams to the
		// proxy host itself.
		serverAddr, err := net.ResolveUDPAddr("udp", s.Server)
		if err != nil {
			return nil, fmt.Errorf("resolve udp address %s: %w", s.Server, err)
		}
		remoteAddr.IP = serverAddr.IP
	}
	udpConn := &UDPConn{
		TCPConn:    conn,
		ServerAddr: remoteAddr,
	}
	if addr != AddrZero {
		connectedAddr, err := addr.UDP()
		if err != nil {
			return nil, err
		}
		udpConn.ConnectedAddr = connectedAddr
	}

	packetConn, err := net.ListenPacket("udp", "")
	if err != nil {
		return nil, err
	}
	udpConn.PacketConn = packetConn

	// Drain the control connection until it ends, then tear down both sockets.
	go func() {
		io.Copy(io.Discard, conn)
		conn.Close()
		packetConn.Close()
	}()
	return udpConn, nil
}
// negotiate runs the SOCKS5 method-selection exchange on conn. If password
// authentication was selected, it then runs the username/password
// sub-negotiation.
//
// Exactly one method is offered: AuthMethodPassword when both s.Username and
// s.Password are non-empty, AuthMethodNone otherwise. It returns
// ErrAuthMethodNotSupported if the server selects a different method, and
// ErrPasswordAuth if the server rejects the credentials.
func (s *Dialer) negotiate(conn net.Conn) error {
	offered := AuthMethodNone
	if len(s.Username) > 0 && len(s.Password) > 0 {
		offered = AuthMethodPassword
	}

	greeting := NewNegotiationRequest([]AuthMethod{offered})
	if _, werr := greeting.WriteTo(conn); werr != nil {
		return werr
	}
	selection, rerr := NewNegotiationReplyFrom(conn)
	if rerr != nil {
		return rerr
	}
	if selection.Method != offered {
		return ErrAuthMethodNotSupported
	}
	if selection.Method != AuthMethodPassword {
		// No sub-negotiation needed.
		return nil
	}

	creds := NewPasswordNegotiationRequest(s.Username, s.Password)
	if _, werr := creds.WriteTo(conn); werr != nil {
		return werr
	}
	verdict, rerr := NewPasswordNegotiationReplyFrom(conn)
	if rerr != nil {
		return rerr
	}
	if verdict.Status != PasswordStatusSuccess {
		return ErrPasswordAuth
	}
	return nil
}