package socks5 import ( "context" "errors" "fmt" "io" "net" "time" ) var ( // ErrNetworkNotImplemented is the network not implemented error ErrNetworkNotImplemented = errors.New("network not implemented") ) var ( noDeadline = time.Time{} expiredDeadline = time.Unix(1, 0) ) // Dialer is the socks5 dialer type Dialer struct { Server string Username string Password string ProxyDial func(ctx context.Context, network string, addr string) (net.Conn, error) } func (s *Dialer) Dial(network string, address string) (net.Conn, error) { return s.DialContext(context.Background(), network, address) } func (s *Dialer) DialContext(ctx context.Context, network string, address string) (net.Conn, error) { proxy, err := AddrFromString(s.Server) if err != nil { return nil, &net.OpError{Op: "socks", Net: network, Source: proxy, Addr: AddrZero, Err: err} } addr, err := AddrFromString(address) if err != nil { return nil, &net.OpError{Op: "socks", Net: network, Source: proxy, Addr: AddrZero, Err: err} } switch network { case "tcp", "tcp4", "tcp6": case "udp", "udp4", "udp6": default: return nil, &net.OpError{Op: "socks", Net: network, Source: proxy, Addr: addr, Err: ErrNetworkNotImplemented} } var conn net.Conn if s.ProxyDial != nil { conn, err = s.ProxyDial(ctx, "tcp", s.Server) } else { var dd net.Dialer conn, err = dd.DialContext(ctx, "tcp", s.Server) } if err != nil { return nil, &net.OpError{Op: "socks", Net: "tcp", Source: proxy, Addr: addr, Err: err} } if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { conn.SetDeadline(deadline) defer conn.SetDeadline(noDeadline) } var ctxErr error if ctx != context.Background() { errCh := make(chan error, 1) doneCh := make(chan struct{}) defer func() { close(doneCh) if ctxErr == nil { ctxErr = <-errCh } }() go func() { select { case <-ctx.Done(): conn.SetDeadline(expiredDeadline) errCh <- ctx.Err() case <-doneCh: errCh <- nil } }() } if err := s.negotiate(conn); err != nil { conn.Close() return nil, &net.OpError{Op: "socks", Net: "tcp", Source: proxy, Addr: addr, Err: err} } if network == "udp" || network == "udp4" || network == "udp6" { uc, err := s.handleUDP(conn, addr) if err != nil { conn.Close() return nil, &net.OpError{Op: "socks", Net: "udp", Source: proxy, Addr: addr, Err: err} } return uc, ctxErr } if err := s.handleTCP(conn, addr); err != nil { conn.Close() return nil, &net.OpError{Op: "socks", Net: "tcp", Source: proxy, Addr: addr, Err: err} } return conn, ctxErr } func (s *Dialer) DialUDP() (net.PacketConn, error) { conn, err := s.DialContext(context.Background(), "udp", "") if conn == nil { return nil, err } return conn.(net.PacketConn), err } func (s *Dialer) handleTCP(conn net.Conn, addr *Addr) error { atyp, host, port := addr.Socks() req := NewRequest(CommandConnect, atyp, host, port) if _, err := req.WriteTo(conn); err != nil { return err } rep, err := NewReplyFrom(conn) if err != nil { return err } if rep.Status != ReplyStatusSuccess { return fmt.Errorf("bad reply status: %s", rep.Status) } return nil } func (s *Dialer) handleUDP(conn net.Conn, addr *Addr) (*UDPConn, error) { atyp, host, port := addr.Socks() req := NewRequest(CommandUDPAssociate, atyp, host, port) if _, err := req.WriteTo(conn); err != nil { return nil, err } rep, err := NewReplyFrom(conn) if err != nil { return nil, err } if rep.Status != ReplyStatusSuccess { return nil, fmt.Errorf("bad reply status: %s", rep.Status) } packetConn, err := net.ListenPacket("udp", "") if err != nil { return nil, err } go func() { io.Copy(io.Discard, conn) conn.Close() packetConn.Close() }() remoteAddr, err := AddrFromSocks(rep.ATYP, rep.BindAddr, rep.BindPort).UDP() if err != nil { return nil, err } if remoteAddr.IP.IsUnspecified() { serverAddr, err := net.ResolveUDPAddr("udp", s.Server) if err != nil { return nil, fmt.Errorf("resolve udp address %s: %w", s.Server, err) } remoteAddr.IP = serverAddr.IP } udpConn := &UDPConn{ PacketConn: packetConn, TCPConn: conn, ServerAddr: remoteAddr, } if addr != AddrZero { connectedAddr, err := addr.UDP() if err != nil { return nil, err } udpConn.ConnectedAddr = connectedAddr } return udpConn, nil } func (s *Dialer) negotiate(conn net.Conn) error { method := AuthMethodNone if s.Username != "" && s.Password != "" { method = AuthMethodPassword } negReq := NewNegotiationRequest([]AuthMethod{method}) if _, err := negReq.WriteTo(conn); err != nil { return err } negRep, err := NewNegotiationReplyFrom(conn) if err != nil { return err } if negRep.Method != method { return ErrAuthMethodNotSupported } if negRep.Method == AuthMethodPassword { passReq := NewPasswordNegotiationRequest(s.Username, s.Password) if _, err := passReq.WriteTo(conn); err != nil { return err } passRep, err := NewPasswordNegotiationReplyFrom(conn) if err != nil { return err } if passRep.Status != PasswordStatusSuccess { return ErrPasswordAuth } } return nil }