diff --git a/client/internal/engine.go b/client/internal/engine.go index 74a07927c..c377c12e1 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1231,36 +1231,19 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer PreSharedKey: e.config.PreSharedKey, } - if e.config.RosenpassEnabled && !e.config.RosenpassPermissive { - lk := []byte(e.config.WgPrivateKey.PublicKey().String()) - rk := []byte(wgConfig.RemoteKey) - var keyInput []byte - if string(lk) > string(rk) { - //nolint:gocritic - keyInput = append(lk[:16], rk[:16]...) - } else { - //nolint:gocritic - keyInput = append(rk[:16], lk[:16]...) - } - - key, err := wgtypes.NewKey(keyInput) - if err != nil { - return nil, err - } - - wgConfig.PreSharedKey = &key - } - // randomize connection timeout timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond config := peer.ConnConfig{ - Key: pubKey, - LocalKey: e.config.WgPrivateKey.PublicKey().String(), - Timeout: timeout, - WgConfig: wgConfig, - LocalWgPort: e.config.WgPort, - RosenpassPubKey: e.getRosenpassPubKey(), - RosenpassAddr: e.getRosenpassAddr(), + Key: pubKey, + LocalKey: e.config.WgPrivateKey.PublicKey().String(), + Timeout: timeout, + WgConfig: wgConfig, + LocalWgPort: e.config.WgPort, + RosenpassConfig: peer.RosenpassConfig{ + PubKey: e.getRosenpassPubKey(), + Addr: e.getRosenpassAddr(), + PermissiveMode: e.config.RosenpassPermissive, + }, ICEConfig: icemaker.Config{ StunTurn: &e.stunTurn, InterfaceBlackList: e.config.IFaceBlackList, diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 85f94b53f..44e8997bc 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -60,6 +60,15 @@ type WgConfig struct { PreSharedKey *wgtypes.Key } +type RosenpassConfig struct { + // RosenpassPubKey is this peer's Rosenpass public key + PubKey []byte + // RosenpassPubKey is this peer's RosenpassAddr server address (IP:port) + Addr string + + PermissiveMode bool +} + // ConnConfig is a peer Connection configuration type ConnConfig struct { // Key is a public key of a remote peer @@ -73,10 +82,7 @@ type ConnConfig struct { LocalWgPort int - // RosenpassPubKey is this peer's Rosenpass public key - RosenpassPubKey []byte - // RosenpassPubKey is this peer's RosenpassAddr server address (IP:port) - RosenpassAddr string + RosenpassConfig RosenpassConfig // ICEConfig ICE protocol configuration ICEConfig icemaker.Config @@ -109,6 +115,8 @@ type Conn struct { connIDICE nbnet.ConnectionID beforeAddPeerHooks []nbnet.AddHookFunc afterRemovePeerHooks []nbnet.RemoveHookFunc + // used to store the remote Rosenpass key for Relayed connection in case of connection update from ice + rosenpassRemoteKey []byte wgProxyICE wgproxy.Proxy wgProxyRelay wgproxy.Proxy @@ -375,7 +383,7 @@ func (conn *Conn) onICEConnectionIsReady(priority ConnPriority, iceConnInfo ICEC wgProxy.Work() } - if err = conn.configureWGEndpoint(ep); err != nil { + if err = conn.configureWGEndpoint(ep, iceConnInfo.RosenpassPubKey); err != nil { conn.handleConfigurationFailure(err, wgProxy) return } @@ -408,7 +416,7 @@ func (conn *Conn) onICEStateDisconnected() { conn.dumpState.SwitchToRelay() conn.wgProxyRelay.Work() - if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { + if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), conn.rosenpassRemoteKey); err != nil { conn.log.Errorf("failed to switch to relay conn: %v", err) } @@ -478,7 +486,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { } wgProxy.Work() - if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil { + if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.log.Warnf("Failed to close relay connection: %v", err) } @@ -493,6 +501,7 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { }() wgConfigWorkaround() + conn.rosenpassRemoteKey = rci.rosenpassPubKey conn.currentConnPriority = connPriorityRelay conn.statusRelay.Set(StatusConnected) conn.setRelayedProxy(wgProxy) @@ -556,13 +565,14 @@ func (conn *Conn) listenGuardEvent(ctx context.Context) { } } -func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error { +func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr, remoteRPKey []byte) error { + presharedKey := conn.presharedKey(remoteRPKey) return conn.config.WgConfig.WgInterface.UpdatePeer( conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, addr, - conn.config.WgConfig.PreSharedKey, + presharedKey, ) } @@ -783,6 +793,44 @@ func (conn *Conn) AllowedIP() netip.Addr { return conn.config.WgConfig.AllowedIps[0].Addr() } +func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key { + if conn.config.RosenpassConfig.PubKey == nil { + return conn.config.WgConfig.PreSharedKey + } + + if remoteRosenpassKey == nil && conn.config.RosenpassConfig.PermissiveMode { + return conn.config.WgConfig.PreSharedKey + } + + determKey, err := conn.rosenpassDetermKey() + if err != nil { + conn.log.Errorf("failed to generate Rosenpass initial key: %v", err) + return conn.config.WgConfig.PreSharedKey + } + + return determKey +} + +// todo: move this logic into Rosenpass package +func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) { + lk := []byte(conn.config.LocalKey) + rk := []byte(conn.config.Key) // remote key + var keyInput []byte + if string(lk) > string(rk) { + //nolint:gocritic + keyInput = append(lk[:16], rk[:16]...) + } else { + //nolint:gocritic + keyInput = append(rk[:16], lk[:16]...) + } + + key, err := wgtypes.NewKey(keyInput) + if err != nil { + return nil, err + } + return &key, nil +} + func isController(config ConnConfig) bool { return config.LocalKey > config.Key } diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 505bedb7f..6d55cfff4 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -2,6 +2,7 @@ package peer import ( "context" + "fmt" "os" "sync" "testing" @@ -161,3 +162,145 @@ func TestConn_Status(t *testing.T) { }) } } + +func TestConn_presharedKey(t *testing.T) { + conn1 := Conn{ + config: ConnConfig{ + Key: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + LocalKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + RosenpassConfig: RosenpassConfig{}, + }, + } + conn2 := Conn{ + config: ConnConfig{ + Key: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + LocalKey: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + RosenpassConfig: RosenpassConfig{}, + }, + } + + tests := []struct { + conn1Permissive bool + conn1RosenpassEnabled bool + conn2Permissive bool + conn2RosenpassEnabled bool + conn1ExpectedInitialKey bool + conn2ExpectedInitialKey bool + }{ + { + conn1Permissive: false, + conn1RosenpassEnabled: false, + conn2Permissive: false, + conn2RosenpassEnabled: false, + conn1ExpectedInitialKey: false, + conn2ExpectedInitialKey: false, + }, + { + conn1Permissive: false, + conn1RosenpassEnabled: true, + conn2Permissive: false, + conn2RosenpassEnabled: true, + conn1ExpectedInitialKey: true, + conn2ExpectedInitialKey: true, + }, + { + conn1Permissive: false, + conn1RosenpassEnabled: true, + conn2Permissive: false, + conn2RosenpassEnabled: false, + conn1ExpectedInitialKey: true, + conn2ExpectedInitialKey: false, + }, + { + conn1Permissive: false, + conn1RosenpassEnabled: false, + conn2Permissive: false, + conn2RosenpassEnabled: true, + conn1ExpectedInitialKey: false, + conn2ExpectedInitialKey: true, + }, + { + conn1Permissive: true, + conn1RosenpassEnabled: true, + conn2Permissive: false, + conn2RosenpassEnabled: false, + conn1ExpectedInitialKey: false, + conn2ExpectedInitialKey: false, + }, + { + conn1Permissive: false, + conn1RosenpassEnabled: false, + conn2Permissive: true, + conn2RosenpassEnabled: true, + conn1ExpectedInitialKey: false, + conn2ExpectedInitialKey: false, + }, + { + conn1Permissive: true, + conn1RosenpassEnabled: true, + conn2Permissive: true, + conn2RosenpassEnabled: true, + conn1ExpectedInitialKey: true, + conn2ExpectedInitialKey: true, + }, + { + conn1Permissive: false, + conn1RosenpassEnabled: false, + conn2Permissive: false, + conn2RosenpassEnabled: true, + conn1ExpectedInitialKey: false, + conn2ExpectedInitialKey: true, + }, + { + conn1Permissive: false, + conn1RosenpassEnabled: true, + conn2Permissive: true, + conn2RosenpassEnabled: true, + conn1ExpectedInitialKey: true, + conn2ExpectedInitialKey: true, + }, + } + + conn1.config.RosenpassConfig.PermissiveMode = true + for i, test := range tests { + tcase := i + 1 + t.Run(fmt.Sprintf("Rosenpass test case %d", tcase), func(t *testing.T) { + conn1.config.RosenpassConfig = RosenpassConfig{} + conn2.config.RosenpassConfig = RosenpassConfig{} + + if test.conn1RosenpassEnabled { + conn1.config.RosenpassConfig.PubKey = []byte("dummykey") + } + conn1.config.RosenpassConfig.PermissiveMode = test.conn1Permissive + + if test.conn2RosenpassEnabled { + conn2.config.RosenpassConfig.PubKey = []byte("dummykey") + } + conn2.config.RosenpassConfig.PermissiveMode = test.conn2Permissive + + conn1PresharedKey := conn1.presharedKey(conn2.config.RosenpassConfig.PubKey) + conn2PresharedKey := conn2.presharedKey(conn1.config.RosenpassConfig.PubKey) + + if test.conn1ExpectedInitialKey { + if conn1PresharedKey == nil { + t.Errorf("Case %d: Expected conn1 to have a non-nil key, but got nil", tcase) + } + } else { + if conn1PresharedKey != nil { + t.Errorf("Case %d: Expected conn1 to have a nil key, but got %v", tcase, conn1PresharedKey) + } + } + + // Assert conn2's key expectation + if test.conn2ExpectedInitialKey { + if conn2PresharedKey == nil { + t.Errorf("Case %d: Expected conn2 to have a non-nil key, but got nil", tcase) + } + } else { + if conn2PresharedKey != nil { + t.Errorf("Case %d: Expected conn2 to have a nil key, but got %v", tcase, conn2PresharedKey) + } + } + }) + } +} diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index d23727e96..224ea0262 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -154,8 +154,8 @@ func (h *Handshaker) sendOffer() error { IceCredentials: IceCredentials{iceUFrag, icePwd}, WgListenPort: h.config.LocalWgPort, Version: version.NetbirdVersion(), - RosenpassPubKey: h.config.RosenpassPubKey, - RosenpassAddr: h.config.RosenpassAddr, + RosenpassPubKey: h.config.RosenpassConfig.PubKey, + RosenpassAddr: h.config.RosenpassConfig.Addr, } addr, err := h.relay.RelayInstanceAddress() @@ -174,8 +174,8 @@ func (h *Handshaker) sendAnswer() error { IceCredentials: IceCredentials{uFrag, pwd}, WgListenPort: h.config.LocalWgPort, Version: version.NetbirdVersion(), - RosenpassPubKey: h.config.RosenpassPubKey, - RosenpassAddr: h.config.RosenpassAddr, + RosenpassPubKey: h.config.RosenpassConfig.PubKey, + RosenpassAddr: h.config.RosenpassConfig.Addr, } addr, err := h.relay.RelayInstanceAddress() if err == nil {