[client] Add NB_SKIP_SOCKET_MARK & fix crash instead of returing an error (#2899)

* dialer: fix crash instead of returning error

* add NB_SKIP_SOCKET_MARK
This commit is contained in:
Krzysztof Nazarewski (kdn) 2024-11-19 14:14:58 +01:00 committed by GitHub
parent 52ea2e84e9
commit eb5d0569ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 21 additions and 4 deletions

View File

@ -55,7 +55,7 @@ type ruleParams struct {
// isLegacy determines whether to use the legacy routing setup // isLegacy determines whether to use the legacy routing setup
func isLegacy() bool { func isLegacy() bool {
return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true"
} }
// setIsLegacy sets the legacy routing setup // setIsLegacy sets the legacy routing setup

View File

@ -3,6 +3,9 @@ package grpc
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net" "net"
"os/user" "os/user"
"runtime" "runtime"
@ -23,20 +26,22 @@ func WithCustomDialer() grpc.DialOption {
if runtime.GOOS == "linux" { if runtime.GOOS == "linux" {
currentUser, err := user.Current() currentUser, err := user.Current()
if err != nil { if err != nil {
log.Fatalf("failed to get current user: %v", err) return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
} }
// the custom dialer requires root permissions which are not required for use cases run as non-root // the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" { if currentUser.Uid != "0" {
log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{} dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr) return dialer.DialContext(ctx, "tcp", addr)
} }
} }
log.Debug("Using nbnet.NewDialer()")
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil { if err != nil {
log.Errorf("Failed to dial: %s", err) log.Errorf("Failed to dial: %s", err)
return nil, err return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
} }
return conn, nil return conn, nil
}) })

View File

@ -69,7 +69,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
conn, err := d.Dialer.DialContext(ctx, network, address) conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil { if err != nil {
return nil, fmt.Errorf("dial: %w", err) return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
} }
// Wrap the connection in Conn to handle Close with hooks // Wrap the connection in Conn to handle Close with hooks

View File

@ -4,9 +4,14 @@ package net
import ( import (
"fmt" "fmt"
"os"
"syscall" "syscall"
log "github.com/sirupsen/logrus"
) )
const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK"
// SetSocketMark sets the SO_MARK option on the given socket connection // SetSocketMark sets the SO_MARK option on the given socket connection
func SetSocketMark(conn syscall.Conn) error { func SetSocketMark(conn syscall.Conn) error {
sysconn, err := conn.SyscallConn() sysconn, err := conn.SyscallConn()
@ -36,6 +41,13 @@ func SetRawSocketMark(conn syscall.RawConn) error {
func SetSocketOpt(fd int) error { func SetSocketOpt(fd int) error {
if CustomRoutingDisabled() { if CustomRoutingDisabled() {
log.Infof("Custom routing is disabled, skipping SO_MARK")
return nil
}
// Check for the new environment variable
if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" {
log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK")
return nil return nil
} }