commit 786f71c7f72d7311a083f13197582b28bb29d2a6 Author: Anton Zadvorny Date: Fri Jul 9 19:50:43 2021 +0300 Initial commit diff --git a/addr.go b/addr.go new file mode 100644 index 0000000..5231447 --- /dev/null +++ b/addr.go @@ -0,0 +1,116 @@ +package socks5 + +import ( + "encoding/binary" + "errors" + "net" + "strconv" +) + +var ( + // AddrZero is zero address + AddrZero = &Addr{IP: net.IPv4zero, Port: 0} +) + +// Addr implements net.Addr interface +type Addr struct { + IP net.IP + Host string + Port int +} + +// AddrFromString returns new address from string representation +func AddrFromString(address string) (*Addr, error) { + if address == "" { + return AddrZero, nil + } + + host, port, err := splitHostPort(address) + if err != nil { + return nil, err + } + + addr := &Addr{Port: port} + addr.IP = net.ParseIP(host) + if addr.IP == nil { + addr.Host = host + } + + return addr, nil +} + +func AddrFromSocks(atyp ATYP, host []byte, port []byte) *Addr { + addr := &Addr{Port: int(binary.BigEndian.Uint16(port))} + + switch atyp { + case ATYPFQDN: + addr.Host = string(host[1:]) + default: + addr.IP = net.IP(host) + } + + return addr +} + +// Socks returns socks protocol formatted address type, address and port +func (s *Addr) Socks() (ATYP, []byte, []byte) { + portBytes := make([]byte, 2) + binary.BigEndian.PutUint16(portBytes, uint16(s.Port)) + + if s.IP != nil { + if ip4 := s.IP.To4(); ip4 != nil { + return ATYPIPv4, []byte(ip4), portBytes + } else if ip6 := s.IP.To16(); ip6 != nil { + return ATYPIPv6, []byte(ip6), portBytes + } + } + + addr := []byte{byte(len(s.Host))} + addr = append(addr, []byte(s.Host)...) + return ATYPFQDN, addr, portBytes +} + +// UDP returns udp address +func (s *Addr) UDP() (*net.UDPAddr, error) { + if s.IP != nil { + return &net.UDPAddr{IP: s.IP, Port: s.Port}, nil + } + + return net.ResolveUDPAddr("udp", s.Host) +} + +// Network returns address Network +func (s *Addr) Network() string { + return "socks" +} + +func (s *Addr) String() string { + if s == nil { + return "" + } + + port := strconv.Itoa(s.Port) + if s.IP == nil { + return net.JoinHostPort(s.Host, port) + } + + return net.JoinHostPort(s.IP.String(), port) +} + +func splitHostPort(address string) (string, int, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return "", 0, err + } + + portInt, err := strconv.Atoi(port) + if err != nil { + return "", 0, err + } + + if 1 > portInt || portInt > 0xffff { + return "", 0, errors.New("port number out of range " + port) + } + + return host, portInt, nil +} diff --git a/datagram.go b/datagram.go new file mode 100644 index 0000000..7277065 --- /dev/null +++ b/datagram.go @@ -0,0 +1,83 @@ +package socks5 + +import ( + "bytes" + "io" +) + +// Datagram is the datagram packet +type Datagram struct { + ATYP ATYP + DstAddr []byte + DstPort []byte + Data []byte +} + +// NewDatagram creates new datagram packet +func NewDatagram(atyp ATYP, host []byte, port []byte, data []byte) *Datagram { + return &Datagram{ + ATYP: atyp, + DstAddr: host, + DstPort: port, + Data: data, + } +} + +// NewDatagramFrom reads datagram packet from reader +func NewDatagramFrom(r io.Reader) (*Datagram, error) { + buf := make([]byte, 4) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + + if !bytes.Equal([]byte{Reserved, Reserved}, buf[:1]) { + return nil, ErrBadDatagram + } + + if buf[2] != DatagramStandalone { + return nil, ErrFragmentedDatagram + } + + addr, err := ReadAddress(r, ATYP(buf[3])) + if err != nil { + return nil, err + } + + port := make([]byte, 2) + if _, err := io.ReadFull(r, port); err != nil { + return nil, err + } + + data, err := io.ReadAll(r) + if err != nil { + return nil, err + } + + datagram := &Datagram{ + ATYP: ATYP(buf[3]), + DstAddr: addr, + DstPort: port, + Data: data, + } + + return datagram, nil +} + +// WriteTo writes datagram packet +func (s *Datagram) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write(s.Bytes()) + if err != nil { + return int64(n), err + } + + return int64(n), nil +} + +func (s *Datagram) Bytes() []byte { + buf := bytes.NewBuffer([]byte{Reserved, Reserved, DatagramStandalone, byte(s.ATYP)}) + buf.Write(s.DstAddr) + buf.Write(s.DstPort) + buf.Write(s.Data) + + return buf.Bytes() +} diff --git a/dialer.go b/dialer.go new file mode 100644 index 0000000..e28cf85 --- /dev/null +++ b/dialer.go @@ -0,0 +1,237 @@ +package socks5 + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "time" +) + +var ( + // ErrNetworkNotImplemented is the network not implemented error + ErrNetworkNotImplemented = errors.New("network not implemented") +) + +var ( + noDeadline = time.Time{} + expiredDeadline = time.Unix(1, 0) +) + +// Dialer is the socks5 dialer +type Dialer struct { + Server string + Username string + Password string + ProxyDial func(ctx context.Context, network string, addr string) (net.Conn, error) +} + +func (s *Dialer) Dial(network string, address string) (net.Conn, error) { + return s.DialContext(context.Background(), network, address) +} + +func (s *Dialer) DialContext(ctx context.Context, network string, address string) (net.Conn, error) { + proxy, err := AddrFromString(s.Server) + if err != nil { + return nil, &net.OpError{Op: "socks", Net: network, Source: proxy, Addr: AddrZero, Err: err} + } + + addr, err := AddrFromString(address) + if err != nil { + return nil, &net.OpError{Op: "socks", Net: network, Source: proxy, Addr: AddrZero, Err: err} + } + + switch network { + case "tcp", "tcp4", "tcp6": + case "udp", "udp4", "udp6": + default: + return nil, &net.OpError{Op: "socks", Net: network, Source: proxy, Addr: addr, Err: ErrNetworkNotImplemented} + } + + var conn net.Conn + if s.ProxyDial != nil { + conn, err = s.ProxyDial(ctx, "tcp", s.Server) + } else { + var dd net.Dialer + conn, err = dd.DialContext(ctx, "tcp", s.Server) + } + + if err != nil { + return nil, &net.OpError{Op: "socks", Net: "tcp", Source: proxy, Addr: addr, Err: err} + } + + if deadline, ok := ctx.Deadline(); ok && !deadline.IsZero() { + conn.SetDeadline(deadline) + defer conn.SetDeadline(noDeadline) + } + + var ctxErr error + if ctx != context.Background() { + errCh := make(chan error, 1) + doneCh := make(chan struct{}) + + defer func() { + close(doneCh) + if ctxErr == nil { + ctxErr = <-errCh + } + }() + + go func() { + select { + case <-ctx.Done(): + conn.SetDeadline(expiredDeadline) + errCh <- ctx.Err() + case <-doneCh: + errCh <- nil + } + }() + } + + if err := s.negotiate(conn); err != nil { + conn.Close() + return nil, &net.OpError{Op: "socks", Net: "tcp", Source: proxy, Addr: addr, Err: err} + } + + if network == "udp" || network == "udp4" || network == "udp6" { + uc, err := s.handleUDP(conn, addr) + if err != nil { + conn.Close() + return nil, &net.OpError{Op: "socks", Net: "udp", Source: proxy, Addr: addr, Err: err} + } + + return uc, ctxErr + } + + if err := s.handleTCP(conn, addr); err != nil { + conn.Close() + return nil, &net.OpError{Op: "socks", Net: "tcp", Source: proxy, Addr: addr, Err: err} + } + + return conn, ctxErr +} + +func (s *Dialer) DialUDP() (net.PacketConn, error) { + conn, err := s.DialContext(context.Background(), "udp", "") + return conn.(net.PacketConn), err +} + +func (s *Dialer) handleTCP(conn net.Conn, addr *Addr) error { + atyp, host, port := addr.Socks() + + req := NewRequest(CommandConnect, atyp, host, port) + if _, err := req.WriteTo(conn); err != nil { + return err + } + + rep, err := NewReplyFrom(conn) + if err != nil { + return err + } + + if rep.Status != ReplyStatusSuccess { + return fmt.Errorf("bad reply status: %s", rep.Status) + } + + return nil +} + +func (s *Dialer) handleUDP(conn net.Conn, addr *Addr) (*UDPConn, error) { + atyp, host, port := addr.Socks() + + req := NewRequest(CommandUDPAssociate, atyp, host, port) + if _, err := req.WriteTo(conn); err != nil { + return nil, err + } + + rep, err := NewReplyFrom(conn) + if err != nil { + return nil, err + } + + if rep.Status != ReplyStatusSuccess { + return nil, fmt.Errorf("bad reply status: %s", rep.Status) + } + + packetConn, err := net.ListenPacket("udp", "") + if err != nil { + return nil, err + } + + go func() { + io.Copy(io.Discard, conn) + conn.Close() + packetConn.Close() + }() + + remoteAddr, err := AddrFromSocks(rep.ATYP, rep.BindAddr, rep.BindPort).UDP() + if err != nil { + return nil, err + } + + if remoteAddr.IP.IsUnspecified() { + serverAddr, err := net.ResolveUDPAddr("udp", s.Server) + if err != nil { + return nil, fmt.Errorf("resolve udp address %s: %w", s.Server, err) + } + + remoteAddr.IP = serverAddr.IP + } + + udpConn := &UDPConn{ + PacketConn: packetConn, + TCPConn: conn, + ServerAddr: remoteAddr, + } + + if addr != AddrZero { + connectedAddr, err := addr.UDP() + if err != nil { + return nil, err + } + + udpConn.ConnectedAddr = connectedAddr + } + + return udpConn, nil +} + +func (s *Dialer) negotiate(conn net.Conn) error { + method := AuthMethodNone + if s.Username != "" && s.Password != "" { + method = AuthMethodPassword + } + + negReq := NewNegotiationRequest([]AuthMethod{method}) + if _, err := negReq.WriteTo(conn); err != nil { + return err + } + + negRep, err := NewNegotiationReplyFrom(conn) + if err != nil { + return err + } + + if negRep.Method != method { + return ErrAuthMethodNotSupported + } + + if negRep.Method == AuthMethodPassword { + passReq := NewPasswordNegotiationRequest(s.Username, s.Password) + if _, err := passReq.WriteTo(conn); err != nil { + return err + } + + passRep, err := NewPasswordNegotiationReplyFrom(conn) + if err != nil { + return err + } + + if passRep.Status != PasswordStatusSuccess { + return ErrPasswordAuth + } + } + + return nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..885137a --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module go.pkg.cx/socks5 + +go 1.16 diff --git a/negotiation.go b/negotiation.go new file mode 100644 index 0000000..3271227 --- /dev/null +++ b/negotiation.go @@ -0,0 +1,108 @@ +package socks5 + +import ( + "io" +) + +// NegotiationRequest is the negotiation request packet +type NegotiationRequest struct { + NMethods int + Methods []AuthMethod +} + +// NewNegotiationRequest returns new negotiation request packet +func NewNegotiationRequest(methods []AuthMethod) *NegotiationRequest { + return &NegotiationRequest{ + NMethods: len(methods), + Methods: methods, + } +} + +// NewNegotiationRequestFrom reads negotiation request packet from reader +func NewNegotiationRequestFrom(r io.Reader) (*NegotiationRequest, error) { + buf := make([]byte, 2) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + + if buf[0] != Version { + return nil, ErrVersion + } + + methodsLen := int(buf[1]) + if methodsLen == 0 { + return nil, ErrBadRequest + } + + methodsBytes := make([]byte, methodsLen) + if _, err := io.ReadFull(r, methodsBytes); err != nil { + return nil, err + } + + methods := make([]AuthMethod, methodsLen) + for i := range methodsBytes { + methods[i] = AuthMethod(methodsBytes[i]) + } + + req := &NegotiationRequest{ + NMethods: methodsLen, + Methods: methods, + } + + return req, nil +} + +// WriteTo writes negotiation request packet +func (s *NegotiationRequest) WriteTo(w io.Writer) (int64, error) { + buf := []byte{Version, byte(s.NMethods)} + for i := range s.Methods { + buf = append(buf, byte(s.Methods[i])) + } + + n, err := w.Write(buf) + if err != nil { + return int64(n), err + } + + return int64(n), nil +} + +// NegotiationReply is the negotiation reply packet +type NegotiationReply struct { + Method AuthMethod +} + +// NewNegotiationReply returns new negotiation reply packet +func NewNegotiationReply(method AuthMethod) *NegotiationReply { + return &NegotiationReply{ + Method: method, + } +} + +// NewNegotiationReplyFrom reads negotiation reply packet from reader +func NewNegotiationReplyFrom(r io.Reader) (*NegotiationReply, error) { + buf := make([]byte, 2) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + + if buf[0] != Version { + return nil, ErrVersion + } + + rep := &NegotiationReply{ + Method: AuthMethod(buf[1]), + } + + return rep, nil +} + +// WriteTo writes negotiation reply packet +func (s *NegotiationReply) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write([]byte{Version, byte(s.Method)}) + if err != nil { + return int64(n), err + } + + return int64(n), nil +} diff --git a/password_negotiation.go b/password_negotiation.go new file mode 100644 index 0000000..fcce119 --- /dev/null +++ b/password_negotiation.go @@ -0,0 +1,114 @@ +package socks5 + +import ( + "io" +) + +// PasswordNegotiationRequest is the password negotiation reqeust packet +type PasswordNegotiationRequest struct { + UsernameLen byte + Username string + PasswordLen byte + Password string +} + +// NewPasswordNegotiationRequest returns new password negotiation request packet +func NewPasswordNegotiationRequest(username string, password string) *PasswordNegotiationRequest { + return &PasswordNegotiationRequest{ + UsernameLen: byte(len(username)), + Username: username, + PasswordLen: byte(len(password)), + Password: password, + } +} + +// NewPasswordNegotiationRequestFrom reads password negotiation request packet from reader +func NewPasswordNegotiationRequestFrom(r io.Reader) (*PasswordNegotiationRequest, error) { + buf := make([]byte, 2) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + + if buf[0] != PasswordVersion { + return nil, ErrPasswordAuthVersion + } + + if buf[1] == 0 { + return nil, ErrBadRequest + } + + usernameBuf := make([]byte, int(buf[1])+1) + if _, err := io.ReadFull(r, usernameBuf); err != nil { + return nil, err + } + + var passwordBuf []byte + if usernameBuf[int(buf[1])] != 0 { + passwordBuf = make([]byte, int(usernameBuf[int(buf[1])])) + if _, err := io.ReadFull(r, passwordBuf); err != nil { + return nil, err + } + } + + return &PasswordNegotiationRequest{ + UsernameLen: buf[1], + Username: string(usernameBuf[:int(buf[1])]), + PasswordLen: usernameBuf[int(buf[1])], + Password: string(passwordBuf), + }, nil +} + +// WriteTo writes password negotiation request packet +func (s *PasswordNegotiationRequest) WriteTo(w io.Writer) (int64, error) { + buf := []byte{PasswordVersion, s.UsernameLen} + buf = append(buf, []byte(s.Username)...) + buf = append(buf, s.PasswordLen) + buf = append(buf, []byte(s.Password)...) + + n, err := w.Write(buf) + if err != nil { + return int64(n), err + } + + return int64(n), nil +} + +// PasswordNegotiationReply is the password negotiation reply packet +type PasswordNegotiationReply struct { + Status PasswordStatus +} + +// NewPasswordNegotiationReply returns password negotiation reply packet +func NewPasswordNegotiationReply(status PasswordStatus) *PasswordNegotiationReply { + return &PasswordNegotiationReply{ + Status: status, + } +} + +// NewPasswordNegotiationReplyFrom reads password negotiation reply packet from reader +func NewPasswordNegotiationReplyFrom(r io.Reader) (*PasswordNegotiationReply, error) { + buf := make([]byte, 2) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + + if buf[0] != PasswordVersion { + return nil, ErrPasswordAuthVersion + } + + rep := &PasswordNegotiationReply{ + Status: PasswordStatus(buf[1]), + } + + return rep, nil +} + +// WriteTo writes password negotiation reply packet +func (s *PasswordNegotiationReply) WriteTo(w io.Writer) (int64, error) { + n, err := w.Write([]byte{PasswordVersion, byte(s.Status)}) + if err != nil { + return int64(n), err + } + + return int64(n), nil +} diff --git a/reply.go b/reply.go new file mode 100644 index 0000000..088f6c1 --- /dev/null +++ b/reply.go @@ -0,0 +1,68 @@ +package socks5 + +import ( + "io" +) + +// Reply is the reply packet +type Reply struct { + Status ReplyStatus + ATYP ATYP + BindAddr []byte + BindPort []byte +} + +// NewReply returns new reply packet +func NewReply(status ReplyStatus, atyp ATYP, host []byte, port []byte) *Reply { + return &Reply{ + Status: status, + ATYP: atyp, + BindAddr: host, + BindPort: port, + } +} + +// NewReplyFrom reads reply packet from reader +func NewReplyFrom(r io.Reader) (*Reply, error) { + buf := make([]byte, 4) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + + if buf[0] != Version { + return nil, ErrVersion + } + + addr, err := ReadAddress(r, ATYP(buf[3])) + if err != nil { + return nil, err + } + + port := make([]byte, 2) + if _, err := io.ReadFull(r, port); err != nil { + return nil, err + } + + rep := &Reply{ + Status: ReplyStatus(buf[1]), + ATYP: ATYP(buf[3]), + BindAddr: addr, + BindPort: port, + } + + return rep, nil +} + +// WriteTo writes reply packet +func (s *Reply) WriteTo(w io.Writer) (int64, error) { + buf := []byte{Version, byte(s.Status), Reserved, byte(s.ATYP)} + buf = append(buf, s.BindAddr...) + buf = append(buf, s.BindPort...) + + n, err := w.Write(buf) + if err != nil { + return int64(n), err + } + + return int64(n), nil +} diff --git a/request.go b/request.go new file mode 100644 index 0000000..3bcdd2d --- /dev/null +++ b/request.go @@ -0,0 +1,68 @@ +package socks5 + +import ( + "io" +) + +// Request is the request packet +type Request struct { + Command Command + ATYP ATYP + DstAddr []byte + DstPort []byte +} + +// NewRequest returns new request packet +func NewRequest(command Command, atyp ATYP, host []byte, port []byte) *Request { + return &Request{ + Command: command, + ATYP: atyp, + DstAddr: host, + DstPort: port, + } +} + +// NewRequestFrom reads request packet from client +func NewRequestFrom(r io.Reader) (*Request, error) { + buf := make([]byte, 4) + if _, err := io.ReadFull(r, buf); err != nil { + return nil, err + } + + if buf[0] != Version { + return nil, ErrVersion + } + + addr, err := ReadAddress(r, ATYP(buf[3])) + if err != nil { + return nil, err + } + + port := make([]byte, 2) + if _, err := io.ReadFull(r, port); err != nil { + return nil, err + } + + req := &Request{ + Command: Command(buf[1]), + ATYP: ATYP(buf[3]), + DstAddr: addr, + DstPort: port, + } + + return req, nil +} + +// WriteTo writes request packet +func (s *Request) WriteTo(w io.Writer) (int64, error) { + buf := []byte{Version, byte(s.Command), Reserved, byte(s.ATYP)} + buf = append(buf, s.DstAddr...) + buf = append(buf, s.DstPort...) + + n, err := w.Write(buf) + if err != nil { + return int64(n), err + } + + return int64(n), nil +} diff --git a/socks5.go b/socks5.go new file mode 100644 index 0000000..7c03c1f --- /dev/null +++ b/socks5.go @@ -0,0 +1,151 @@ +package socks5 + +import ( + "errors" +) + +const ( + // Version is the protocol version + Version byte = 0x05 + // Reserved is the reserved field value + Reserved byte = 0x00 + // PasswordVersion is the username/password auth protocol version + PasswordVersion byte = 0x01 + // DatagramStandalone is the standalone datagram fragment field value + DatagramStandalone byte = 0x00 +) + +// AuthMethod is the auth method +type AuthMethod byte + +// Auth methods +const ( + AuthMethodNone AuthMethod = 0x00 + AuthMethodGSSAPI AuthMethod = 0x01 + AuthMethodPassword AuthMethod = 0x02 + AuthMethodNotSupported AuthMethod = 0xFF +) + +func (am AuthMethod) String() string { + return authMethodNames[am] +} + +var authMethodNames = map[AuthMethod]string{ + AuthMethodNone: "none", + AuthMethodGSSAPI: "gssapi", + AuthMethodPassword: "password", + AuthMethodNotSupported: "not supported", +} + +// PasswordStatus is the password auth status +type PasswordStatus byte + +// Password statuses +const ( + PasswordStatusSuccess PasswordStatus = 0x00 + PasswordStatusFailure PasswordStatus = 0x01 +) + +func (ps PasswordStatus) String() string { + return passwordStatusNames[ps] +} + +var passwordStatusNames = map[PasswordStatus]string{ + PasswordStatusSuccess: "success", + PasswordStatusFailure: "failure", +} + +// Command is the command +type Command byte + +// Commands +const ( + CommandConnect Command = 0x01 + CommandBind Command = 0x02 + CommandUDPAssociate Command = 0x03 +) + +func (c Command) String() string { + return commandNames[c] +} + +var commandNames = map[Command]string{ + CommandConnect: "connect", + CommandBind: "bind", + CommandUDPAssociate: "udp associate", +} + +// ATYP is the address type +type ATYP byte + +// Address types +const ( + ATYPIPv4 ATYP = 0x01 + ATYPFQDN ATYP = 0x03 + ATYPIPv6 ATYP = 0x04 +) + +func (a ATYP) String() string { + return atypNames[a] +} + +var atypNames = map[ATYP]string{ + ATYPIPv4: "ipv4", + ATYPFQDN: "fqdn", + ATYPIPv6: "ipv6", +} + +// ReplyStatus is the reply status +type ReplyStatus byte + +// Reply statuses +const ( + ReplyStatusSuccess ReplyStatus = 0x00 + ReplyStatusServerFailure ReplyStatus = 0x01 + ReplyStatusNotAllowed ReplyStatus = 0x02 + ReplyStatusNetworkUnreachable ReplyStatus = 0x03 + ReplyStatusHostUnreachable ReplyStatus = 0x04 + ReplyStatusConnectionRefused ReplyStatus = 0x05 + ReplyStatusTTLExpired ReplyStatus = 0x06 + ReplyStatusCommandNotSupported ReplyStatus = 0x07 + ReplyStatusAddressNotSupported ReplyStatus = 0x08 +) + +func (rs ReplyStatus) String() string { + return replyStatusNames[rs] +} + +var replyStatusNames = map[ReplyStatus]string{ + ReplyStatusSuccess: "success", + ReplyStatusServerFailure: "server failure", + ReplyStatusNotAllowed: "not allowed", + ReplyStatusNetworkUnreachable: "network unreachable", + ReplyStatusHostUnreachable: "host unreachable", + ReplyStatusConnectionRefused: "connection refused", + ReplyStatusTTLExpired: "TTL expired", + ReplyStatusCommandNotSupported: "command not supported", + ReplyStatusAddressNotSupported: "address not supported", +} + +var ( + // ErrVersion is the protocol version error + ErrVersion = errors.New("invalid protocol version") + // ErrBadReply is the bad reply error + ErrBadReply = errors.New("bad reply") + // ErrBadRequest is the bad request error + ErrBadRequest = errors.New("bad request") + // ErrBadDatagram is the bad datagram error + ErrBadDatagram = errors.New("bad datagram") + // ErrFragmentedDatagram us the fragmented datagram error + ErrFragmentedDatagram = errors.New("fragmented datagram") + // ErrIPAuth is the invalid ip error + ErrIPAuth = errors.New("invalid ip auth") + // ErrPasswordAuth is the invalid username or password error + ErrPasswordAuth = errors.New("invalid username or password") + // ErrAuthMethodNotSupported is the error when auth method not supported + ErrAuthMethodNotSupported = errors.New("auth method not supported") + // ErrCommandNotSupported is the error when command not supported + ErrCommandNotSupported = errors.New("command not supported") + // ErrPasswordAuthVersion is the password auth version error + ErrPasswordAuthVersion = errors.New("invalid password auth version") +) diff --git a/udp.go b/udp.go new file mode 100644 index 0000000..f5fe64b --- /dev/null +++ b/udp.go @@ -0,0 +1,68 @@ +package socks5 + +import ( + "bytes" + "errors" + "net" +) + +var ( + // ErrNoConnectedAddress is the no connected address error + ErrNoConnectedAddress = errors.New("no connected address") +) + +type UDPConn struct { + net.PacketConn + + TCPConn net.Conn + ServerAddr net.Addr + ConnectedAddr net.Addr +} + +func (s *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) { + _, _, err := s.PacketConn.ReadFrom(b) + if err != nil { + return 0, nil, err + } + + datagram, err := NewDatagramFrom(bytes.NewReader(b)) + if err != nil { + return 0, nil, err + } + + n := copy(b, datagram.Data) + return n, AddrFromSocks(datagram.ATYP, datagram.DstAddr, datagram.DstPort), nil +} + +func (s *UDPConn) Read(b []byte) (int, error) { + n, _, err := s.ReadFrom(b) + return n, err +} + +func (s *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { + socksAddr, err := AddrFromString(addr.String()) + if err != nil { + return 0, err + } + + atyp, host, port := socksAddr.Socks() + datagram := NewDatagram(atyp, host, port, b) + return s.PacketConn.WriteTo(datagram.Bytes(), s.ServerAddr) +} + +func (s *UDPConn) Write(b []byte) (int, error) { + if s.ConnectedAddr == nil { + return 0, ErrNoConnectedAddress + } + + return s.WriteTo(b, s.ConnectedAddr) +} + +func (s *UDPConn) Close() error { + s.TCPConn.Close() + return s.PacketConn.Close() +} + +func (s *UDPConn) RemoteAddr() net.Addr { + return s.ConnectedAddr +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..6b4565b --- /dev/null +++ b/util.go @@ -0,0 +1,92 @@ +package socks5 + +import ( + "encoding/binary" + "io" + "net" + "strconv" +) + +// NewEmptyAddrReply returns reply with empty address and port +func NewEmptyAddrReply(status ReplyStatus, atyp ATYP) *Reply { + if atyp == ATYPIPv6 { + return NewReply(status, ATYPIPv6, []byte(net.IPv6zero), []byte{0x00, 0x00}) + } + + return NewReply(status, ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}) +} + +// ReadAddress reads address from reader +func ReadAddress(r io.Reader, atyp ATYP) ([]byte, error) { + var addr []byte + + //nolint:exhaustive + switch atyp { + case ATYPIPv4: + addr = make([]byte, 4) + if _, err := io.ReadFull(r, addr); err != nil { + return nil, err + } + case ATYPIPv6: + addr = make([]byte, 16) + if _, err := io.ReadFull(r, addr); err != nil { + return nil, err + } + case ATYPFQDN: + domainLen := make([]byte, 1) + if _, err := io.ReadFull(r, domainLen); err != nil { + return nil, err + } + + if domainLen[0] == 0 { + return nil, ErrBadRequest + } + + addr = make([]byte, int(domainLen[0])) + if _, err := io.ReadFull(r, addr); err != nil { + return nil, err + } + + addr = append(domainLen, addr...) //nolint:makezero + default: + return nil, ErrBadRequest + } + + return addr, nil +} + +// ParseAddress parses address from string +func ParseAddress(address string) (ATYP, []byte, []byte, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return 0x00, []byte{}, []byte{}, err + } + + var ( + atyp ATYP + addr []byte + ) + + ip := net.ParseIP(host) + if ip4 := ip.To4(); ip4 != nil { + atyp = ATYPIPv4 + addr = []byte(ip4) + } else if ip6 := ip.To16(); ip6 != nil { + atyp = ATYPIPv6 + addr = []byte(ip6) + } else { + atyp = ATYPFQDN + addr = []byte{byte(len(host))} + addr = append(addr, []byte(host)...) + } + + portInt, err := strconv.Atoi(port) + if err != nil { + return 0x00, []byte{}, []byte{}, err + } + + portBytes := make([]byte, 2) + binary.BigEndian.PutUint16(portBytes, uint16(portInt)) + + return atyp, addr, portBytes, nil +}