From 31bc4b683309b345a3e6e9c02cd247ff017cdc31 Mon Sep 17 00:00:00 2001 From: Michael Quigley Date: Wed, 7 Feb 2024 13:59:34 -0500 Subject: [PATCH] initial socks backend (#558) --- endpoints/socks/backend.go | 53 +++++ endpoints/socks/socks5.go | 403 +++++++++++++++++++++++++++++++++++++ 2 files changed, 456 insertions(+) create mode 100644 endpoints/socks/backend.go create mode 100755 endpoints/socks/socks5.go diff --git a/endpoints/socks/backend.go b/endpoints/socks/backend.go new file mode 100644 index 00000000..1ad152b4 --- /dev/null +++ b/endpoints/socks/backend.go @@ -0,0 +1,53 @@ +package socks + +import ( + "github.com/openziti/sdk-golang/ziti" + "github.com/openziti/sdk-golang/ziti/edge" + "github.com/openziti/zrok/endpoints" + "github.com/pkg/errors" + "time" +) + +type BackendConfig struct { + IdentityPath string + ShrToken string + Requests chan *endpoints.Request +} + +type Backend struct { + cfg *BackendConfig + listener edge.Listener + server *Server +} + +func NewBackend(cfg *BackendConfig) (*Backend, error) { + options := ziti.ListenOptions{ + ConnectTimeout: 5 * time.Minute, + WaitForNEstablishedListeners: 1, + } + zcfg, err := ziti.NewConfigFromFile(cfg.IdentityPath) + if err != nil { + return nil, errors.Wrap(err, "error loading ziti identity") + } + zctx, err := ziti.NewContext(zcfg) + if err != nil { + return nil, errors.Wrap(err, "error loading ziti context") + } + listener, err := zctx.ListenWithOptions(cfg.ShrToken, &options) + if err != nil { + return nil, err + } + + return &Backend{ + cfg: cfg, + listener: listener, + server: &Server{}, + }, nil +} + +func (b *Backend) Run() error { + if err := b.server.Serve(b.listener); err != nil { + return err + } + return nil +} diff --git a/endpoints/socks/socks5.go b/endpoints/socks/socks5.go new file mode 100755 index 00000000..8b1d1593 --- /dev/null +++ b/endpoints/socks/socks5.go @@ -0,0 +1,403 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package socks5 is a SOCKS5 server implementation. +// +// This is used for userspace networking in Tailscale. Specifically, +// this is used for dialing out of the machine to other nodes, without +// the host kernel's involvement, so it doesn't proper routing tables, +// TUN, IPv6, etc. This package is meant to only handle the SOCKS5 protocol +// details and not any integration with Tailscale internals itself. +// +// The glue between this package and Tailscale is in net/socks5/tssocks. +package socks + +import ( + "context" + "encoding/binary" + "fmt" + "github.com/sirupsen/logrus" + "io" + "net" + "strconv" + "time" +) + +// Authentication METHODs described in RFC 1928, section 3. +const ( + noAuthRequired byte = 0 + passwordAuth byte = 2 + noAcceptableAuth byte = 255 +) + +// passwordAuthVersion is the auth version byte described in RFC 1929. +const passwordAuthVersion = 1 + +// socks5Version is the byte that represents the SOCKS version +// in requests. +const socks5Version byte = 5 + +// commandType are the bytes sent in SOCKS5 packets +// that represent the kind of connection the client needs. +type commandType byte + +// The set of valid SOCKS5 commands as described in RFC 1928. +const ( + connect commandType = 1 + bind commandType = 2 + udpAssociate commandType = 3 +) + +// addrType are the bytes sent in SOCKS5 packets +// that represent particular address types. +type addrType byte + +// The set of valid SOCKS5 address types as defined in RFC 1928. +const ( + ipv4 addrType = 1 + domainName addrType = 3 + ipv6 addrType = 4 +) + +// replyCode are the bytes sent in SOCKS5 packets +// that represent replies from the server to a client +// request. +type replyCode byte + +// The set of valid SOCKS5 reply types as per the RFC 1928. +const ( + success replyCode = 0 + generalFailure replyCode = 1 + connectionNotAllowed replyCode = 2 + networkUnreachable replyCode = 3 + hostUnreachable replyCode = 4 + connectionRefused replyCode = 5 + ttlExpired replyCode = 6 + commandNotSupported replyCode = 7 + addrTypeNotSupported replyCode = 8 +) + +// Server is a SOCKS5 proxy server. +type Server struct { + // Dialer optionally specifies the dialer to use for outgoing connections. + // If nil, the net package's standard dialer is used. + Dialer func(ctx context.Context, network, addr string) (net.Conn, error) + + // Username and Password, if set, are the credential clients must provide. + Username string + Password string +} + +func (s *Server) dial(ctx context.Context, network, addr string) (net.Conn, error) { + dial := s.Dialer + if dial == nil { + dialer := &net.Dialer{} + dial = dialer.DialContext + } + return dial(ctx, network, addr) +} + +// Serve accepts and handles incoming connections on the given listener. +func (s *Server) Serve(l net.Listener) error { + defer l.Close() + for { + c, err := l.Accept() + if err != nil { + return err + } + go func() { + defer c.Close() + conn := &Conn{clientConn: c, srv: s} + err := conn.Run() + if err != nil { + logrus.Infof("client connection failed: %v", err) + } + }() + } +} + +// Conn is a SOCKS5 connection for client to reach +// server. +type Conn struct { + // The struct is filled by each of the internal + // methods in turn as the transaction progresses. + + srv *Server + clientConn net.Conn + request *request +} + +// Run starts the new connection. +func (c *Conn) Run() error { + needAuth := c.srv.Username != "" || c.srv.Password != "" + authMethod := noAuthRequired + if needAuth { + authMethod = passwordAuth + } + + err := parseClientGreeting(c.clientConn, authMethod) + if err != nil { + c.clientConn.Write([]byte{socks5Version, noAcceptableAuth}) + return err + } + c.clientConn.Write([]byte{socks5Version, authMethod}) + if !needAuth { + return c.handleRequest() + } + + user, pwd, err := parseClientAuth(c.clientConn) + if err != nil || user != c.srv.Username || pwd != c.srv.Password { + c.clientConn.Write([]byte{1, 1}) // auth error + return err + } + c.clientConn.Write([]byte{1, 0}) // auth success + + return c.handleRequest() +} + +func (c *Conn) handleRequest() error { + req, err := parseClientRequest(c.clientConn) + if err != nil { + res := &response{reply: generalFailure} + buf, _ := res.marshal() + c.clientConn.Write(buf) + return err + } + if req.command != connect { + res := &response{reply: commandNotSupported} + buf, _ := res.marshal() + c.clientConn.Write(buf) + return fmt.Errorf("unsupported command %v", req.command) + } + c.request = req + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + srv, err := c.srv.dial( + ctx, + "tcp", + net.JoinHostPort(c.request.destination, strconv.Itoa(int(c.request.port))), + ) + if err != nil { + res := &response{reply: generalFailure} + buf, _ := res.marshal() + c.clientConn.Write(buf) + return err + } + defer srv.Close() + serverAddr, serverPortStr, err := net.SplitHostPort(srv.LocalAddr().String()) + if err != nil { + return err + } + serverPort, _ := strconv.Atoi(serverPortStr) + + var bindAddrType addrType + if ip := net.ParseIP(serverAddr); ip != nil { + if ip.To4() != nil { + bindAddrType = ipv4 + } else { + bindAddrType = ipv6 + } + } else { + bindAddrType = domainName + } + res := &response{ + reply: success, + bindAddrType: bindAddrType, + bindAddr: serverAddr, + bindPort: uint16(serverPort), + } + buf, err := res.marshal() + if err != nil { + res = &response{reply: generalFailure} + buf, _ = res.marshal() + } + c.clientConn.Write(buf) + + errc := make(chan error, 2) + go func() { + _, err := io.Copy(c.clientConn, srv) + if err != nil { + err = fmt.Errorf("from backend to client: %w", err) + } + errc <- err + }() + go func() { + _, err := io.Copy(srv, c.clientConn) + if err != nil { + err = fmt.Errorf("from client to backend: %w", err) + } + errc <- err + }() + return <-errc +} + +// parseClientGreeting parses a request initiation packet. +func parseClientGreeting(r io.Reader, authMethod byte) error { + var hdr [2]byte + _, err := io.ReadFull(r, hdr[:]) + if err != nil { + return fmt.Errorf("could not read packet header") + } + if hdr[0] != socks5Version { + return fmt.Errorf("incompatible SOCKS version") + } + count := int(hdr[1]) + methods := make([]byte, count) + _, err = io.ReadFull(r, methods) + if err != nil { + return fmt.Errorf("could not read methods") + } + for _, m := range methods { + if m == authMethod { + return nil + } + } + return fmt.Errorf("no acceptable auth methods") +} + +func parseClientAuth(r io.Reader) (usr, pwd string, err error) { + var hdr [2]byte + if _, err := io.ReadFull(r, hdr[:]); err != nil { + return "", "", fmt.Errorf("could not read auth packet header") + } + if hdr[0] != passwordAuthVersion { + return "", "", fmt.Errorf("bad SOCKS auth version") + } + usrLen := int(hdr[1]) + usrBytes := make([]byte, usrLen) + if _, err := io.ReadFull(r, usrBytes); err != nil { + return "", "", fmt.Errorf("could not read auth packet username") + } + var hdrPwd [1]byte + if _, err := io.ReadFull(r, hdrPwd[:]); err != nil { + return "", "", fmt.Errorf("could not read auth packet password length") + } + pwdLen := int(hdrPwd[0]) + pwdBytes := make([]byte, pwdLen) + if _, err := io.ReadFull(r, pwdBytes); err != nil { + return "", "", fmt.Errorf("could not read auth packet password") + } + return string(usrBytes), string(pwdBytes), nil +} + +// request represents data contained within a SOCKS5 +// connection request packet. +type request struct { + command commandType + destination string + port uint16 + destAddrType addrType +} + +// parseClientRequest converts raw packet bytes into a +// SOCKS5Request struct. +func parseClientRequest(r io.Reader) (*request, error) { + var hdr [4]byte + _, err := io.ReadFull(r, hdr[:]) + if err != nil { + return nil, fmt.Errorf("could not read packet header") + } + cmd := hdr[1] + destAddrType := addrType(hdr[3]) + + var destination string + var port uint16 + + if destAddrType == ipv4 { + var ip [4]byte + _, err = io.ReadFull(r, ip[:]) + if err != nil { + return nil, fmt.Errorf("could not read IPv4 address") + } + destination = net.IP(ip[:]).String() + } else if destAddrType == domainName { + var dstSizeByte [1]byte + _, err = io.ReadFull(r, dstSizeByte[:]) + if err != nil { + return nil, fmt.Errorf("could not read domain name size") + } + dstSize := int(dstSizeByte[0]) + domainName := make([]byte, dstSize) + _, err = io.ReadFull(r, domainName) + if err != nil { + return nil, fmt.Errorf("could not read domain name") + } + destination = string(domainName) + } else if destAddrType == ipv6 { + var ip [16]byte + _, err = io.ReadFull(r, ip[:]) + if err != nil { + return nil, fmt.Errorf("could not read IPv6 address") + } + destination = net.IP(ip[:]).String() + } else { + return nil, fmt.Errorf("unsupported address type") + } + var portBytes [2]byte + _, err = io.ReadFull(r, portBytes[:]) + if err != nil { + return nil, fmt.Errorf("could not read port") + } + port = binary.BigEndian.Uint16(portBytes[:]) + + return &request{ + command: commandType(cmd), + destination: destination, + port: port, + destAddrType: destAddrType, + }, nil +} + +// response contains the contents of +// a response packet sent from the proxy +// to the client. +type response struct { + reply replyCode + bindAddrType addrType + bindAddr string + bindPort uint16 +} + +// marshal converts a SOCKS5Response struct into +// a packet. If res.reply == Success, it may throw an error on +// receiving an invalid bind address. Otherwise, it will not throw. +func (res *response) marshal() ([]byte, error) { + pkt := make([]byte, 4) + pkt[0] = socks5Version + pkt[1] = byte(res.reply) + pkt[2] = 0 // null reserved byte + pkt[3] = byte(res.bindAddrType) + + if res.reply != success { + return pkt, nil + } + + var addr []byte + switch res.bindAddrType { + case ipv4: + addr = net.ParseIP(res.bindAddr).To4() + if addr == nil { + return nil, fmt.Errorf("invalid IPv4 address for binding") + } + case domainName: + if len(res.bindAddr) > 255 { + return nil, fmt.Errorf("invalid domain name for binding") + } + addr = make([]byte, 0, len(res.bindAddr)+1) + addr = append(addr, byte(len(res.bindAddr))) + addr = append(addr, []byte(res.bindAddr)...) + case ipv6: + addr = net.ParseIP(res.bindAddr).To16() + if addr == nil { + return nil, fmt.Errorf("invalid IPv6 address for binding") + } + default: + return nil, fmt.Errorf("unsupported address type") + } + + pkt = append(pkt, addr...) + pkt = binary.BigEndian.AppendUint16(pkt, uint16(res.bindPort)) + + return pkt, nil +}