NetBird SSH (#361)

This PR adds support for SSH access through the NetBird network
without managing SSH skeys.
NetBird client app has an embedded SSH server (Linux/Mac only) 
and a netbird ssh command.
This commit is contained in:
Misha Bragin 2022-06-23 17:04:53 +02:00 committed by GitHub
parent f883a10535
commit 06860c4c10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 1702 additions and 349 deletions

View File

@ -94,6 +94,7 @@ func init() {
rootCmd.AddCommand(statusCmd) rootCmd.AddCommand(statusCmd)
rootCmd.AddCommand(loginCmd) rootCmd.AddCommand(loginCmd)
rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
} }

137
client/cmd/ssh.go Normal file
View File

@ -0,0 +1,137 @@
package cmd
import (
"context"
"errors"
"fmt"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"os"
"os/signal"
"strings"
"syscall"
)
var (
port int
user = "netbird"
host string
)
var sshCmd = &cobra.Command{
Use: "ssh",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New("requires a host argument")
}
split := strings.Split(args[0], "@")
if len(split) == 2 {
user = split[0]
host = split[1]
} else {
host = args[0]
}
return nil
},
Short: "connect to a remote SSH server",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars()
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, "console")
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
if !util.IsAdmin() {
cmd.Printf("error: you must have Administrator privileges to run this command\n")
return nil
}
ctx := internal.CtxInitState(cmd.Context())
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
defer func() {
err := conn.Close()
if err != nil {
log.Warnf("failed closing dameon gRPC client connection %v", err)
return
}
}()
client := proto.NewDaemonServiceClient(conn)
status, err := client.Status(ctx, &proto.StatusRequest{})
if err != nil {
return fmt.Errorf("unable to get daemon status: %v", err)
}
if status.Status != string(internal.StatusConnected) {
// todo maybe automatically start it?
cmd.Printf("You are disconnected from the NetBird network. Please run the UP command first to connect: \n\n" +
" netbird up \n\n")
return nil
}
config, err := internal.ReadConfig("", "", configPath, nil)
if err != nil {
return err
}
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
sshctx, cancel := context.WithCancel(ctx)
go func() {
if err := runSSH(sshctx, host, []byte(config.SSHKey)); err != nil {
log.Print(err)
}
cancel()
}()
select {
case <-sig:
cancel()
case <-sshctx.Done():
}
return nil
},
}
func runSSH(ctx context.Context, addr string, pemKey []byte) error {
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
if err != nil {
return err
}
go func() {
<-ctx.Done()
err = c.Close()
if err != nil {
return
}
}()
err = c.OpenTerminal()
if err != nil {
return err
}
return nil
}
func init() {
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
}

View File

@ -3,16 +3,16 @@ package internal
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"net/url" "net/url"
"os" "os"
"github.com/netbirdio/netbird/iface"
"github.com/netbirdio/netbird/util"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
) )
var managementURLDefault *url.URL var managementURLDefault *url.URL
@ -38,12 +38,18 @@ type Config struct {
AdminURL *url.URL AdminURL *url.URL
WgIface string WgIface string
IFaceBlackList []string IFaceBlackList []string
// SSHKey is a private SSH key in a PEM format
SSHKey string
} }
// createNewConfig creates a new config generating a new Wireguard key and saving to file // createNewConfig creates a new config generating a new Wireguard key and saving to file
func createNewConfig(managementURL, adminURL, configPath, preSharedKey string) (*Config, error) { func createNewConfig(managementURL, adminURL, configPath, preSharedKey string) (*Config, error) {
wgKey := generateKey() wgKey := generateKey()
config := &Config{PrivateKey: wgKey, WgIface: iface.WgInterfaceDefault, IFaceBlackList: []string{}} pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
if err != nil {
return nil, err
}
config := &Config{SSHKey: string(pem), PrivateKey: wgKey, WgIface: iface.WgInterfaceDefault, IFaceBlackList: []string{}}
if managementURL != "" { if managementURL != "" {
URL, err := parseURL("Management URL", managementURL) URL, err := parseURL("Management URL", managementURL)
if err != nil { if err != nil {
@ -61,7 +67,7 @@ func createNewConfig(managementURL, adminURL, configPath, preSharedKey string) (
config.IFaceBlackList = []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts", config.IFaceBlackList = []string{iface.WgInterfaceDefault, "tun0", "zt", "ZeroTier", "utun", "wg", "ts",
"Tailscale", "tailscale"} "Tailscale", "tailscale"}
err := util.WriteJson(configPath, config) err = util.WriteJson(configPath, config)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -126,6 +132,14 @@ func ReadConfig(managementURL, adminURL, configPath string, preSharedKey *string
config.PreSharedKey = *preSharedKey config.PreSharedKey = *preSharedKey
refresh = true refresh = true
} }
if config.SSHKey == "" {
pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
if err != nil {
return nil, err
}
config.SSHKey = string(pem)
refresh = true
}
if refresh { if refresh {
// since we have new management URL, we need to update config file // since we have new management URL, we need to update config file

View File

@ -2,6 +2,7 @@ package internal
import ( import (
"context" "context"
"github.com/netbirdio/netbird/client/ssh"
"time" "time"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
@ -63,8 +64,13 @@ func RunClient(ctx context.Context, config *Config) error {
engineCtx, cancel := context.WithCancel(ctx) engineCtx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
publicSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return err
}
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config // connect (just a connection, no stream yet) and login to Management Service to get an initial global Wiretrustee config
mgmClient, loginResp, err := connectToManagement(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled) mgmClient, loginResp, err := connectToManagement(engineCtx, config.ManagementURL.Host, myPrivateKey, mgmTlsEnabled,
publicSSHKey)
if err != nil { if err != nil {
log.Debug(err) log.Debug(err)
if s, ok := status.FromError(err); ok && s.Code() == codes.PermissionDenied { if s, ok := status.FromError(err); ok && s.Code() == codes.PermissionDenied {
@ -147,6 +153,7 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
IFaceBlackList: config.IFaceBlackList, IFaceBlackList: config.IFaceBlackList,
WgPrivateKey: key, WgPrivateKey: key,
WgPort: iface.DefaultWgPort, WgPort: iface.DefaultWgPort,
SSHKey: []byte(config.SSHKey),
} }
if config.PreSharedKey != "" { if config.PreSharedKey != "" {
@ -179,7 +186,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig,
} }
// connectToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc) // connectToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc)
func connectToManagement(ctx context.Context, managementAddr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*mgm.GrpcClient, *mgmProto.LoginResponse, error) { func connectToManagement(ctx context.Context, managementAddr string, ourPrivateKey wgtypes.Key, tlsEnabled bool, pubSSHKey []byte) (*mgm.GrpcClient, *mgmProto.LoginResponse, error) {
log.Debugf("connecting to Management Service %s", managementAddr) log.Debugf("connecting to Management Service %s", managementAddr)
client, err := mgm.NewClient(ctx, managementAddr, ourPrivateKey, tlsEnabled) client, err := mgm.NewClient(ctx, managementAddr, ourPrivateKey, tlsEnabled)
if err != nil { if err != nil {
@ -193,7 +200,7 @@ func connectToManagement(ctx context.Context, managementAddr string, ourPrivateK
} }
sysInfo := system.GetInfo(ctx) sysInfo := system.GetInfo(ctx)
loginResp, err := client.Login(*serverPublicKey, sysInfo) loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -3,8 +3,10 @@ package internal
import ( import (
"context" "context"
"fmt" "fmt"
nbssh "github.com/netbirdio/netbird/client/ssh"
"math/rand" "math/rand"
"net" "net"
"runtime"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -54,6 +56,9 @@ type EngineConfig struct {
// UDPMuxSrflxPort default value 0 - the system will pick an available port // UDPMuxSrflxPort default value 0 - the system will pick an available port
UDPMuxSrflxPort int UDPMuxSrflxPort int
// SSHKey is a private SSH key in a PEM format
SSHKey []byte
} }
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
@ -87,6 +92,9 @@ type Engine struct {
// networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service // networkSerial is the latest CurrentSerial (state ID) of the network sent by the Management service
networkSerial uint64 networkSerial uint64
sshServerFunc func(hostKeyPEM []byte, addr string) (nbssh.Server, error)
sshServer nbssh.Server
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@ -111,6 +119,7 @@ func NewEngine(
STUNs: []*ice.URL{}, STUNs: []*ice.URL{},
TURNs: []*ice.URL{}, TURNs: []*ice.URL{},
networkSerial: 0, networkSerial: 0,
sshServerFunc: nbssh.DefaultSSHServer,
} }
} }
@ -283,9 +292,14 @@ func (e *Engine) removeAllPeers() error {
return nil return nil
} }
// removePeer closes an existing peer connection and removes a peer // removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server
func (e *Engine) removePeer(peerKey string) error { func (e *Engine) removePeer(peerKey string) error {
log.Debugf("removing peer from engine %s", peerKey) log.Debugf("removing peer from engine %s", peerKey)
if e.sshServer != nil {
e.sshServer.RemoveAuthorizedKey(peerKey)
}
conn, exists := e.peerConns[peerKey] conn, exists := e.peerConns[peerKey]
if exists { if exists {
delete(e.peerConns, peerKey) delete(e.peerConns, peerKey)
@ -398,12 +412,6 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
} }
if update.GetNetworkMap() != nil { if update.GetNetworkMap() != nil {
if update.GetNetworkMap().GetPeerConfig() != nil {
err := e.updateConfig(update.GetNetworkMap().GetPeerConfig())
if err != nil {
return err
}
}
// only apply new changes and ignore old ones // only apply new changes and ignore old ones
err := e.updateNetworkMap(update.GetNetworkMap()) err := e.updateNetworkMap(update.GetNetworkMap())
if err != nil { if err != nil {
@ -414,6 +422,49 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil return nil
} }
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if sshConf.GetSshEnabled() {
if runtime.GOOS == "windows" {
log.Warnf("running SSH server on Windows is not supported")
return nil
}
// start SSH server if it wasn't running
if e.sshServer == nil {
//nil sshServer means it has not yet been started
var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey,
fmt.Sprintf("%s:%d", e.wgInterface.Address.IP.String(), nbssh.DefaultSSHPort))
if err != nil {
return err
}
go func() {
// blocking
err = e.sshServer.Start()
if err != nil {
// will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err)
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.sshServer = nil
log.Infof("stopped SSH server")
}()
} else {
log.Debugf("SSH server is already running")
}
} else {
// Disable SSH server request, so stop it if it was running
if e.sshServer != nil {
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
}
}
return nil
}
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if e.wgInterface.Address.String() != conf.Address { if e.wgInterface.Address.String() != conf.Address {
oldAddr := e.wgInterface.Address.String() oldAddr := e.wgInterface.Address.String()
@ -422,9 +473,17 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
if err != nil { if err != nil {
return err return err
} }
e.config.WgAddr = conf.Address
log.Infof("updated peer address from %s to %s", oldAddr, conf.Address) log.Infof("updated peer address from %s to %s", oldAddr, conf.Address)
} }
if conf.GetSshConfig() != nil {
err := e.updateSSH(conf.GetSshConfig())
if err != nil {
log.Warnf("failed handling SSH server setup %v", e)
}
}
return nil return nil
} }
@ -486,6 +545,15 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error {
} }
func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
// intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't
if networkMap.GetPeerConfig() != nil {
err := e.updateConfig(networkMap.GetPeerConfig())
if err != nil {
return err
}
}
serial := networkMap.GetSerial() serial := networkMap.GetSerial()
if e.networkSerial > serial { if e.networkSerial > serial {
log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial) log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial)
@ -515,6 +583,18 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
if err != nil { if err != nil {
return err return err
} }
// update SSHServer by adding remote peer SSH keys
if e.sshServer != nil {
for _, config := range networkMap.GetRemotePeers() {
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
if err != nil {
log.Warnf("failed adding authroized key to SSH DefaultServer %v", err)
}
}
}
}
} }
e.networkSerial = serial e.networkSerial = serial

View File

@ -3,6 +3,9 @@ package internal
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/iface"
"github.com/stretchr/testify/assert"
"net" "net"
"os" "os"
"path/filepath" "path/filepath"
@ -40,6 +43,140 @@ var (
} }
) )
func TestEngine_SSH(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping TestEngine_SSH on Windows")
}
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
WgIfaceName: "utun101",
WgAddr: "100.64.0.1/24",
WgPrivateKey: key,
WgPort: 33100,
})
var sshKeysAdded []string
var sshPeersRemoved []string
sshCtx, cancel := context.WithCancel(context.Background())
engine.sshServerFunc = func(hostKeyPEM []byte, addr string) (ssh.Server, error) {
return &ssh.MockServer{
Ctx: sshCtx,
StopFunc: func() error {
cancel()
return nil
},
StartFunc: func() error {
<-ctx.Done()
return ctx.Err()
},
AddAuthorizedKeyFunc: func(peer, newKey string) error {
sshKeysAdded = append(sshKeysAdded, newKey)
return nil
},
RemoveAuthorizedKeyFunc: func(peer string) {
sshPeersRemoved = append(sshPeersRemoved, peer)
},
}, nil
}
err = engine.Start()
if err != nil {
t.Fatal(err)
}
defer func() {
err := engine.Stop()
if err != nil {
return
}
}()
peerWithSSH := &mgmtProto.RemotePeerConfig{
WgPubKey: "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=",
AllowedIps: []string{"100.64.0.21/24"},
SshConfig: &mgmtProto.SSHConfig{
SshPubKey: []byte("ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ"),
},
}
// SSH server is not enabled so SSH config of a remote peer should be ignored
networkMap := &mgmtProto.NetworkMap{
Serial: 6,
PeerConfig: nil,
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
assert.Nil(t, engine.sshServer)
// SSH server is enabled, therefore SSH config should be applied
networkMap = &mgmtProto.NetworkMap{
Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshKeysAdded, "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ")
// now remove peer
networkMap = &mgmtProto.NetworkMap{
Serial: 8,
RemotePeers: []*mgmtProto.RemotePeerConfig{},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
//time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
// now disable SSH server
networkMap = &mgmtProto.NetworkMap{
Serial: 9,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: false}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
assert.Nil(t, engine.sshServer)
}
func TestEngine_UpdateNetworkMap(t *testing.T) { func TestEngine_UpdateNetworkMap(t *testing.T) {
// test setup // test setup
key, err := wgtypes.GeneratePrivateKey() key, err := wgtypes.GeneratePrivateKey()
@ -52,11 +189,12 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
defer cancel() defer cancel()
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{
WgIfaceName: "utun100", WgIfaceName: "utun102",
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
}) })
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", iface.DefaultMTU)
type testCase struct { type testCase struct {
name string name string
@ -231,7 +369,7 @@ func TestEngine_Sync(t *testing.T) {
} }
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, &EngineConfig{ engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, &EngineConfig{
WgIfaceName: "utun100", WgIfaceName: "utun103",
WgAddr: "100.64.0.1/24", WgAddr: "100.64.0.1/24",
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
@ -418,7 +556,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
} }
info := system.GetInfo(ctx) info := system.GetInfo(ctx)
resp, err := mgmtClient.Register(*publicKey, setupKey, "", info) resp, err := mgmtClient.Register(*publicKey, setupKey, "", info, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -2,8 +2,8 @@ package internal
import ( import (
"context" "context"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
mgm "github.com/netbirdio/netbird/management/client" mgm "github.com/netbirdio/netbird/management/client"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
@ -40,7 +40,11 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
return err return err
} }
_, err = loginPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken) pubSSHKey, err := ssh.GeneratePublicKey([]byte(config.SSHKey))
if err != nil {
return err
}
_, err = loginPeer(ctx, *serverKey, mgmClient, setupKey, jwtToken, pubSSHKey)
if err != nil { if err != nil {
log.Errorf("failed logging-in peer on Management Service : %v", err) log.Errorf("failed logging-in peer on Management Service : %v", err)
return err return err
@ -56,13 +60,13 @@ func Login(ctx context.Context, config *Config, setupKey string, jwtToken string
} }
// loginPeer attempts to login to Management Service. If peer wasn't registered, tries the registration flow. // loginPeer attempts to login to Management Service. If peer wasn't registered, tries the registration flow.
func loginPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string) (*mgmProto.LoginResponse, error) { func loginPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
sysInfo := system.GetInfo(ctx) sysInfo := system.GetInfo(ctx)
loginResp, err := client.Login(serverPublicKey, sysInfo) loginResp, err := client.Login(serverPublicKey, sysInfo, pubSSHKey)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Code() == codes.PermissionDenied { if s, ok := status.FromError(err); ok && s.Code() == codes.PermissionDenied {
log.Debugf("peer registration required") log.Debugf("peer registration required")
return registerPeer(ctx, serverPublicKey, client, setupKey, jwtToken) return registerPeer(ctx, serverPublicKey, client, setupKey, jwtToken, pubSSHKey)
} else { } else {
return nil, err return nil, err
} }
@ -75,7 +79,7 @@ func loginPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.Grp
// registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key.
// Otherwise tries to register with the provided setupKey via command line. // Otherwise tries to register with the provided setupKey via command line.
func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string) (*mgmProto.LoginResponse, error) { func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string, jwtToken string, pubSSHKey []byte) (*mgmProto.LoginResponse, error) {
validSetupKey, err := uuid.Parse(setupKey) validSetupKey, err := uuid.Parse(setupKey)
if err != nil && jwtToken == "" { if err != nil && jwtToken == "" {
return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err) return nil, status.Errorf(codes.InvalidArgument, "invalid setup-key or no sso information provided, err: %v", err)
@ -83,7 +87,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
log.Debugf("sending peer registration request to Management Service") log.Debugf("sending peer registration request to Management Service")
info := system.GetInfo(ctx) info := system.GetInfo(ctx)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info) loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey)
if err != nil { if err != nil {
log.Errorf("failed registering peer %v,%s", err, validSetupKey.String()) log.Errorf("failed registering peer %v,%s", err, validSetupKey.String())
return nil, err return nil, err

114
client/ssh/client.go Normal file
View File

@ -0,0 +1,114 @@
package ssh
import (
"fmt"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
"net"
"os"
)
// Client wraps crypto/ssh Client to simplify usage
type Client struct {
client *ssh.Client
}
// Close closes the wrapped SSH Client
func (c *Client) Close() error {
return c.client.Close()
}
// OpenTerminal starts an interactive terminal session with the remote SSH server
func (c *Client) OpenTerminal() error {
session, err := c.client.NewSession()
if err != nil {
return fmt.Errorf("failed to open new session: %v", err)
}
defer func() {
err := session.Close()
if err != nil {
return
}
}()
fd := int(os.Stdout.Fd())
state, err := term.MakeRaw(fd)
if err != nil {
return fmt.Errorf("failed to run raw terminal: %s", err)
}
defer func() {
err := term.Restore(fd, state)
if err != nil {
return
}
}()
w, h, err := term.GetSize(fd)
if err != nil {
return fmt.Errorf("terminal get size: %s", err)
}
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
}
terminal := os.Getenv("TERM")
if terminal == "" {
terminal = "xterm-256color"
}
if err := session.RequestPty(terminal, h, w, modes); err != nil {
return fmt.Errorf("failed requesting pty session with xterm: %s", err)
}
session.Stdout = os.Stdout
session.Stderr = os.Stderr
session.Stdin = os.Stdin
if err := session.Shell(); err != nil {
return fmt.Errorf("failed to start login shell on the remote host: %s", err)
}
if err := session.Wait(); err != nil {
if e, ok := err.(*ssh.ExitError); ok {
switch e.ExitStatus() {
case 130:
return nil
}
}
return fmt.Errorf("failed running SSH session: %s", err)
}
return nil
}
// DialWithKey connects to the remote SSH server with a provided private key file (PEM).
func DialWithKey(addr, user string, privateKey []byte) (*Client, error) {
signer, err := ssh.ParsePrivateKey(privateKey)
if err != nil {
return nil, err
}
config := &ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(signer),
},
HostKeyCallback: ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }),
}
return Dial("tcp", addr, config)
}
// Dial connects to the remote SSH server.
func Dial(network, addr string, config *ssh.ClientConfig) (*Client, error) {
client, err := ssh.Dial(network, addr, config)
if err != nil {
return nil, err
}
return &Client{
client: client,
}, nil
}

183
client/ssh/server.go Normal file
View File

@ -0,0 +1,183 @@
package ssh
import (
"fmt"
"github.com/creack/pty"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
"io"
"net"
"os"
"os/exec"
"sync"
)
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
const DefaultSSHPort = 44338
// DefaultSSHServer is a function that creates DefaultServer
func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) {
return newDefaultServer(hostKeyPEM, addr)
}
// Server is an interface of SSH server
type Server interface {
// Stop stops SSH server.
Stop() error
// Start starts SSH server. Blocking
Start() error
// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
RemoveAuthorizedKey(peer string)
// AddAuthorizedKey add a given peer key to server authorized keys
AddAuthorizedKey(peer, newKey string) error
}
// DefaultServer is the embedded NetBird SSH server
type DefaultServer struct {
listener net.Listener
// authorizedKeys is ssh pub key indexed by peer WireGuard public key
authorizedKeys map[string]ssh.PublicKey
mu sync.Mutex
hostKeyPEM []byte
sessions []ssh.Session
}
// newDefaultServer creates new server with provided host key
func newDefaultServer(hostKeyPEM []byte, addr string) (*DefaultServer, error) {
ln, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
allowedKeys := make(map[string]ssh.PublicKey)
return &DefaultServer{listener: ln, mu: sync.Mutex{}, hostKeyPEM: hostKeyPEM, authorizedKeys: allowedKeys, sessions: make([]ssh.Session, 0)}, nil
}
// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
func (srv *DefaultServer) RemoveAuthorizedKey(peer string) {
srv.mu.Lock()
defer srv.mu.Unlock()
delete(srv.authorizedKeys, peer)
}
// AddAuthorizedKey add a given peer key to server authorized keys
func (srv *DefaultServer) AddAuthorizedKey(peer, newKey string) error {
srv.mu.Lock()
defer srv.mu.Unlock()
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey))
if err != nil {
return err
}
srv.authorizedKeys[peer] = parsedKey
return nil
}
// Stop stops SSH server.
func (srv *DefaultServer) Stop() error {
srv.mu.Lock()
defer srv.mu.Unlock()
err := srv.listener.Close()
if err != nil {
return err
}
for _, session := range srv.sessions {
err := session.Close()
if err != nil {
log.Warnf("failed closing SSH session from %v", err)
}
}
return nil
}
func (srv *DefaultServer) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
srv.mu.Lock()
defer srv.mu.Unlock()
for _, allowed := range srv.authorizedKeys {
if ssh.KeysEqual(allowed, key) {
return true
}
}
return false
}
func getShellType() string {
shell := os.Getenv("SHELL")
if shell == "" {
shell = "sh"
}
return shell
}
// sessionHandler handles SSH session post auth
func (srv *DefaultServer) sessionHandler(session ssh.Session) {
srv.mu.Lock()
srv.sessions = append(srv.sessions, session)
srv.mu.Unlock()
ptyReq, winCh, isPty := session.Pty()
if isPty {
cmd := exec.Command(getShellType())
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%session", ptyReq.Term))
file, err := pty.Start(cmd)
if err != nil {
log.Errorf("failed starting SSH server %v", err)
}
go func() {
for win := range winCh {
setWinSize(file, win.Width, win.Height)
}
}()
srv.stdInOut(file, session)
err = cmd.Wait()
if err != nil {
return
}
} else {
_, err := io.WriteString(session, "only PTY is supported.\n")
if err != nil {
return
}
err = session.Exit(1)
if err != nil {
return
}
}
}
func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) {
go func() {
// stdin
_, err := io.Copy(file, session)
if err != nil {
return
}
}()
go func() {
// stdout
_, err := io.Copy(session, file)
if err != nil {
return
}
}()
}
// Start starts SSH server. Blocking
func (srv *DefaultServer) Start() error {
log.Infof("starting SSH server on addr: %s", srv.listener.Addr().String())
publicKeyOption := ssh.PublicKeyAuth(srv.publicKeyHandler)
hostKeyPEM := ssh.HostKeyPEM(srv.hostKeyPEM)
err := ssh.Serve(srv.listener, srv.sessionHandler, publicKeyOption, hostKeyPEM)
if err != nil {
return err
}
return nil
}

44
client/ssh/server_mock.go Normal file
View File

@ -0,0 +1,44 @@
package ssh
import "context"
// MockServer mocks ssh.Server
type MockServer struct {
Ctx context.Context
StopFunc func() error
StartFunc func() error
AddAuthorizedKeyFunc func(peer, newKey string) error
RemoveAuthorizedKeyFunc func(peer string)
}
// RemoveAuthorizedKey removes SSH key of a given peer from the authorized keys
func (srv *MockServer) RemoveAuthorizedKey(peer string) {
if srv.RemoveAuthorizedKeyFunc == nil {
return
}
srv.RemoveAuthorizedKeyFunc(peer)
}
// AddAuthorizedKey add a given peer key to server authorized keys
func (srv *MockServer) AddAuthorizedKey(peer, newKey string) error {
if srv.AddAuthorizedKeyFunc == nil {
return nil
}
return srv.AddAuthorizedKeyFunc(peer, newKey)
}
// Stop stops SSH server.
func (srv *MockServer) Stop() error {
if srv.StopFunc == nil {
return nil
}
return srv.StopFunc()
}
// Start starts SSH server. Blocking
func (srv *MockServer) Start() error {
if srv.StartFunc == nil {
return nil
}
return srv.StartFunc()
}

121
client/ssh/server_test.go Normal file
View File

@ -0,0 +1,121 @@
package ssh
import (
"fmt"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ssh"
"strings"
"testing"
)
func TestServer_AddAuthorizedKey(t *testing.T) {
key, err := GeneratePrivateKey(ED25519)
if err != nil {
t.Fatal(err)
}
server, err := newDefaultServer(key, "localhost:")
if err != nil {
t.Fatal(err)
}
// add multiple keys
keys := map[string][]byte{}
for i := 0; i < 10; i++ {
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
remotePrivKey, err := GeneratePrivateKey(ED25519)
if err != nil {
t.Fatal(err)
}
remotePubKey, err := GeneratePublicKey(remotePrivKey)
if err != nil {
t.Fatal(err)
}
err = server.AddAuthorizedKey(peer, string(remotePubKey))
if err != nil {
t.Error(err)
}
keys[peer] = remotePubKey
}
// make sure that all keys have been added
for peer, remotePubKey := range keys {
k, ok := server.authorizedKeys[peer]
assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys")
assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(ssh.MarshalAuthorizedKey(k))))
}
}
func TestServer_RemoveAuthorizedKey(t *testing.T) {
key, err := GeneratePrivateKey(ED25519)
if err != nil {
t.Fatal(err)
}
server, err := newDefaultServer(key, "localhost:")
if err != nil {
t.Fatal(err)
}
remotePrivKey, err := GeneratePrivateKey(ED25519)
if err != nil {
t.Fatal(err)
}
remotePubKey, err := GeneratePublicKey(remotePrivKey)
if err != nil {
t.Fatal(err)
}
err = server.AddAuthorizedKey("remotePeer", string(remotePubKey))
if err != nil {
t.Error(err)
}
server.RemoveAuthorizedKey("remotePeer")
_, ok := server.authorizedKeys["remotePeer"]
assert.False(t, ok, "expecting remotePeer's SSH key to be removed")
}
func TestServer_PubKeyHandler(t *testing.T) {
key, err := GeneratePrivateKey(ED25519)
if err != nil {
t.Fatal(err)
}
server, err := newDefaultServer(key, "localhost:")
if err != nil {
t.Fatal(err)
}
var keys []ssh.PublicKey
for i := 0; i < 10; i++ {
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
remotePrivKey, err := GeneratePrivateKey(ED25519)
if err != nil {
t.Fatal(err)
}
remotePubKey, err := GeneratePublicKey(remotePrivKey)
if err != nil {
t.Fatal(err)
}
remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey)
if err != nil {
t.Fatal(err)
}
err = server.AddAuthorizedKey(peer, string(remotePubKey))
if err != nil {
t.Error(err)
}
keys = append(keys, remoteParsedPubKey)
}
for _, key := range keys {
accepted := server.publicKeyHandler(nil, key)
assert.Truef(t, accepted, "expecting SSH connection to be accepted for a given SSH key %s", string(ssh.MarshalAuthorizedKey(key)))
}
}

86
client/ssh/util.go Normal file
View File

@ -0,0 +1,86 @@
package ssh
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"golang.org/x/crypto/ed25519"
gossh "golang.org/x/crypto/ssh"
"strings"
)
// KeyType is a type of SSH key
type KeyType string
// ED25519 is key of type ed25519
const ED25519 KeyType = "ed25519"
// ECDSA is key of type ecdsa
const ECDSA KeyType = "ecdsa"
// RSA is key of type rsa
const RSA KeyType = "rsa"
// RSAKeySize is a size of newly generated RSA key
const RSAKeySize = 2048
// GeneratePrivateKey creates RSA Private Key of specified byte size
func GeneratePrivateKey(keyType KeyType) ([]byte, error) {
var key crypto.Signer
var err error
switch keyType {
case ED25519:
_, key, err = ed25519.GenerateKey(rand.Reader)
case ECDSA:
key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
case RSA:
key, err = rsa.GenerateKey(rand.Reader, RSAKeySize)
default:
return nil, fmt.Errorf("unsupported ket type %s", keyType)
}
if err != nil {
return nil, err
}
pemBytes, err := EncodePrivateKeyToPEM(key)
if err != nil {
return nil, err
}
return pemBytes, nil
}
// GeneratePublicKey returns the public part of the private key
func GeneratePublicKey(key []byte) ([]byte, error) {
signer, err := gossh.ParsePrivateKey(key)
if err != nil {
return nil, err
}
strKey := strings.TrimSpace(string(gossh.MarshalAuthorizedKey(signer.PublicKey())))
return []byte(strKey), nil
}
// EncodePrivateKeyToPEM encodes Private Key from RSA to PEM format
func EncodePrivateKeyToPEM(privateKey crypto.Signer) ([]byte, error) {
mk, err := x509.MarshalPKCS8PrivateKey(privateKey)
if err != nil {
return nil, err
}
// pem.Block
privBlock := pem.Block{
Type: "PRIVATE KEY",
Bytes: mk,
}
// Private key in PEM format
privatePEM := pem.EncodeToMemory(&privBlock)
return privatePEM, nil
}

14
client/ssh/window_unix.go Normal file
View File

@ -0,0 +1,14 @@
//go:build linux || darwin
package ssh
import (
"os"
"syscall"
"unsafe"
)
func setWinSize(file *os.File, width, height int) {
syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(syscall.TIOCSWINSZ), //nolint
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(height), uint16(width), 0, 0})))
}

View File

@ -0,0 +1,9 @@
package ssh
import (
"os"
)
func setWinSize(file *os.File, width, height int) {
}

8
go.mod
View File

@ -18,7 +18,7 @@ require (
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.1.0 github.com/vishvananda/netlink v1.1.0
golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9
golang.org/x/sys v0.0.0-20220513210249-45d2b4557a2a golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664
golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434 golang.zx2c4.com/wireguard v0.0.0-20211209221555-9c9e7e272434
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de golang.zx2c4.com/wireguard/wgctrl v0.0.0-20211215182854-7a385b3431de
golang.zx2c4.com/wireguard/windows v0.5.1 golang.zx2c4.com/wireguard/windows v0.5.1
@ -30,18 +30,22 @@ require (
require ( require (
fyne.io/fyne/v2 v2.1.4 fyne.io/fyne/v2 v2.1.4
github.com/c-robinson/iplib v1.0.3 github.com/c-robinson/iplib v1.0.3
github.com/creack/pty v1.1.18
github.com/eko/gocache/v2 v2.3.1 github.com/eko/gocache/v2 v2.3.1
github.com/getlantern/systray v1.2.1 github.com/getlantern/systray v1.2.1
github.com/gliderlabs/ssh v0.3.4
github.com/magiconair/properties v1.8.5 github.com/magiconair/properties v1.8.5
github.com/patrickmn/go-cache v2.1.0+incompatible github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/rs/xid v1.3.0 github.com/rs/xid v1.3.0
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/stretchr/testify v1.7.1 github.com/stretchr/testify v1.7.1
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467
) )
require ( require (
github.com/BurntSushi/toml v0.4.1 // indirect github.com/BurntSushi/toml v0.4.1 // indirect
github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect github.com/XiaoMi/pegasus-go-client v0.0.0-20210427083443-f3b6b08bc4c2 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/beorn7/perks v1.0.1 // indirect github.com/beorn7/perks v1.0.1 // indirect
github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d // indirect github.com/bradfitz/gomemcache v0.0.0-20220106215444-fb4bf637b56d // indirect
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect
@ -110,5 +114,3 @@ require (
) )
replace github.com/pion/ice/v2 => github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb replace github.com/pion/ice/v2 => github.com/wiretrustee/ice/v2 v2.1.21-0.20220218121004-dc81faead4bb
//replace github.com/eko/gocache/v3 => /home/braginini/Documents/projects/my/wiretrustee/gocache

12
go.sum
View File

@ -69,6 +69,8 @@ github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRF
github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho=
github.com/allegro/bigcache/v3 v3.0.2 h1:AKZCw+5eAaVyNTBmI2fgyPVJhHkdWder3O9IrprcQfI= github.com/allegro/bigcache/v3 v3.0.2 h1:AKZCw+5eAaVyNTBmI2fgyPVJhHkdWder3O9IrprcQfI=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
@ -118,6 +120,8 @@ github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.1/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v0.0.0-20151105211317-5215b55f46b2/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v0.0.0-20151105211317-5215b55f46b2/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@ -178,6 +182,8 @@ github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2H
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do= github.com/gin-gonic/gin v1.5.0/go.mod h1:Nd6IXA8m5kNZdNEHMBd93KT+mdY3+bewLgRvmCsR2Do=
github.com/gliderlabs/ssh v0.3.4 h1:+AXBtim7MTKaLVPgvE+3mhewYRawNLTd+jEEz/wExZw=
github.com/gliderlabs/ssh v0.3.4/go.mod h1:ZSS+CUoKHDrqVakTfTWUlKSr9MtMFkC4UvtQKD7O914=
github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f h1:s0O46d8fPwk9kU4k1jj76wBquMVETx7uveQD9MCIQoU= github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f h1:s0O46d8fPwk9kU4k1jj76wBquMVETx7uveQD9MCIQoU=
github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f/go.mod h1:wjpnOv6ONl2SuJSxqCPVaPZibGFdSci9HFocT9qtVYM= github.com/go-gl/gl v0.0.0-20210813123233-e4099ee2221f/go.mod h1:wjpnOv6ONl2SuJSxqCPVaPZibGFdSci9HFocT9qtVYM=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
@ -644,6 +650,7 @@ golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8U
golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY= golang.org/x/crypto v0.0.0-20190923035154-9ee001bba392/go.mod h1:/lpIB1dKB+9EgE3H3cr1v9wB50oz8l4C4h62xy7jSTY=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20211202192323-5770296d904e/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
@ -891,8 +898,13 @@ golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220513210249-45d2b4557a2a h1:N2T1jUrTQE9Re6TFF5PhvEHXHCguynGhKjWVsIUt5cY= golang.org/x/sys v0.0.0-20220513210249-45d2b4557a2a h1:N2T1jUrTQE9Re6TFF5PhvEHXHCguynGhKjWVsIUt5cY=
golang.org/x/sys v0.0.0-20220513210249-45d2b4557a2a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220513210249-45d2b4557a2a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664 h1:wEZYwx+kK+KlZ0hpvP2Ls1Xr4+RWnlzGFwPP0aiDjIU=
golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467 h1:CBpWXWQpIRjzmkkA+M7q9Fqnwd2mZr3AFqexg8YTfoM=
golang.org/x/term v0.0.0-20220526004731-065cf7ba2467/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@ -12,7 +12,7 @@ type Client interface {
io.Closer io.Closer
Sync(msgHandler func(msg *proto.SyncResponse) error) error Sync(msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKey() (*wgtypes.Key, error) GetServerPublicKey() (*wgtypes.Key, error)
Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info) (*proto.LoginResponse, error) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
Login(serverKey wgtypes.Key, sysInfo *system.Info) (*proto.LoginResponse, error) Login(serverKey wgtypes.Key, sysInfo *system.Info, sshKey []byte) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
} }

View File

@ -158,7 +158,7 @@ func TestClient_LoginUnregistered_ShouldThrow_401(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
sysInfo := system.GetInfo(context.TODO()) sysInfo := system.GetInfo(context.TODO())
_, err = client.Login(*key, sysInfo) _, err = client.Login(*key, sysInfo, nil)
if err == nil { if err == nil {
t.Error("expecting err on unregistered login, got nil") t.Error("expecting err on unregistered login, got nil")
} }
@ -186,7 +186,7 @@ func TestClient_LoginRegistered(t *testing.T) {
t.Error(err) t.Error(err)
} }
info := system.GetInfo(context.TODO()) info := system.GetInfo(context.TODO())
resp, err := client.Register(*key, ValidKey, "", info) resp, err := client.Register(*key, ValidKey, "", info, nil)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -216,7 +216,7 @@ func TestClient_Sync(t *testing.T) {
} }
info := system.GetInfo(context.TODO()) info := system.GetInfo(context.TODO())
_, err = client.Register(*serverKey, ValidKey, "", info) _, err = client.Register(*serverKey, ValidKey, "", info, nil)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@ -232,7 +232,7 @@ func TestClient_Sync(t *testing.T) {
} }
info = system.GetInfo(context.TODO()) info = system.GetInfo(context.TODO())
_, err = remoteClient.Register(*serverKey, ValidKey, "", info) _, err = remoteClient.Register(*serverKey, ValidKey, "", info, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -330,7 +330,7 @@ func Test_SystemMetaDataFromClient(t *testing.T) {
} }
info := system.GetInfo(context.TODO()) info := system.GetInfo(context.TODO())
_, err = testClient.Register(*key, ValidKey, "", info) _, err = testClient.Register(*key, ValidKey, "", info, nil)
if err != nil { if err != nil {
t.Errorf("error while trying to register client: %v", err) t.Errorf("error while trying to register client: %v", err)
} }

View File

@ -233,13 +233,21 @@ func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*pro
// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key // Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key
// Takes care of encrypting and decrypting messages. // Takes care of encrypting and decrypting messages.
// This method will also collect system info and send it with the request (e.g. hostname, os, etc) // This method will also collect system info and send it with the request (e.g. hostname, os, etc)
func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info) (*proto.LoginResponse, error) { func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, sysInfo *system.Info, pubSSHKey []byte) (*proto.LoginResponse, error) {
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken}) keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: infoToMetaData(sysInfo), JwtToken: jwtToken, PeerKeys: keys})
} }
// Login attempts login to Management Server. Takes care of encrypting and decrypting messages. // Login attempts login to Management Server. Takes care of encrypting and decrypting messages.
func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info) (*proto.LoginResponse, error) { func (c *GrpcClient) Login(serverKey wgtypes.Key, sysInfo *system.Info, pubSSHKey []byte) (*proto.LoginResponse, error) {
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo)}) keys := &proto.PeerKeys{
SshPubKey: pubSSHKey,
WgPubKey: []byte(c.key.PublicKey().String()),
}
return c.login(serverKey, &proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys})
} }
// GetDeviceAuthorizationFlow returns a device authorization flow information. // GetDeviceAuthorizationFlow returns a device authorization flow information.

View File

@ -10,8 +10,8 @@ type MockClient struct {
CloseFunc func() error CloseFunc func() error
SyncFunc func(msgHandler func(msg *proto.SyncResponse) error) error SyncFunc func(msgHandler func(msg *proto.SyncResponse) error) error
GetServerPublicKeyFunc func() (*wgtypes.Key, error) GetServerPublicKeyFunc func() (*wgtypes.Key, error)
RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info) (*proto.LoginResponse, error) RegisterFunc func(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
LoginFunc func(serverKey wgtypes.Key, info *system.Info) (*proto.LoginResponse, error) LoginFunc func(serverKey wgtypes.Key, info *system.Info, sshKey []byte) (*proto.LoginResponse, error)
GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) GetDeviceAuthorizationFlowFunc func(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error)
} }
@ -36,18 +36,18 @@ func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) {
return m.GetServerPublicKeyFunc() return m.GetServerPublicKeyFunc()
} }
func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info) (*proto.LoginResponse, error) { func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string, jwtToken string, info *system.Info, sshKey []byte) (*proto.LoginResponse, error) {
if m.RegisterFunc == nil { if m.RegisterFunc == nil {
return nil, nil return nil, nil
} }
return m.RegisterFunc(serverKey, setupKey, jwtToken, info) return m.RegisterFunc(serverKey, setupKey, jwtToken, info, sshKey)
} }
func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info) (*proto.LoginResponse, error) { func (m *MockClient) Login(serverKey wgtypes.Key, info *system.Info, sshKey []byte) (*proto.LoginResponse, error) {
if m.LoginFunc == nil { if m.LoginFunc == nil {
return nil, nil return nil, nil
} }
return m.LoginFunc(serverKey, info) return m.LoginFunc(serverKey, info, sshKey)
} }
func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) { func (m *MockClient) GetDeviceAuthorizationFlow(serverKey wgtypes.Key) (*proto.DeviceAuthorizationFlow, error) {

File diff suppressed because it is too large Load Diff

View File

@ -71,9 +71,21 @@ message LoginRequest {
PeerSystemMeta meta = 2; PeerSystemMeta meta = 2;
// SSO token (can be empty) // SSO token (can be empty)
string jwtToken = 3; string jwtToken = 3;
// Can be absent for now.
PeerKeys peerKeys = 4;
}
// PeerKeys is additional peer info like SSH pub key and WireGuard public key.
// This message is sent on Login or register requests, or when a key rotation has to happen.
message PeerKeys {
// sshPubKey represents a public SSH key of the peer. Can be absent.
bytes sshPubKey = 1;
// wgPubKey represents a public WireGuard key of the peer. Can be absent.
bytes wgPubKey = 2;
} }
// Peer machine meta data // PeerSystemMeta is machine meta data like OS and version.
message PeerSystemMeta { message PeerSystemMeta {
string hostname = 1; string hostname = 1;
string goOS = 2; string goOS = 2;
@ -143,6 +155,9 @@ message PeerConfig {
string address = 1; string address = 1;
// Wiretrustee DNS server (a Wireguard DNS config) // Wiretrustee DNS server (a Wireguard DNS config)
string dns = 2; string dns = 2;
// SSHConfig of the peer.
SSHConfig sshConfig = 3;
} }
// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections
@ -172,7 +187,22 @@ message RemotePeerConfig {
// Wireguard allowed IPs of a remote peer e.g. [10.30.30.1/32] // Wireguard allowed IPs of a remote peer e.g. [10.30.30.1/32]
repeated string allowedIps = 2; repeated string allowedIps = 2;
// SSHConfig is a SSH config of the remote peer. SSHConfig.sshPubKey should be ignored because peer knows it's SSH key.
SSHConfig sshConfig = 3;
} }
// SSHConfig represents SSH configurations of a peer.
message SSHConfig {
// sshEnabled indicates whether a SSH server is enabled on this peer
bool sshEnabled = 1;
// sshPubKey is a SSH public key of a peer to be added to authorized_hosts.
// This property should be ignore if SSHConfig comes from PeerConfig.
bytes sshPubKey = 2;
}
// DeviceAuthorizationFlowRequest empty struct for future expansion // DeviceAuthorizationFlowRequest empty struct for future expansion
message DeviceAuthorizationFlowRequest {} message DeviceAuthorizationFlowRequest {}
// DeviceAuthorizationFlow represents Device Authorization Flow information // DeviceAuthorizationFlow represents Device Authorization Flow information

View File

@ -51,6 +51,7 @@ type AccountManager interface {
GetNetworkMap(peerKey string) (*NetworkMap, error) GetNetworkMap(peerKey string) (*NetworkMap, error)
AddPeer(setupKey string, userId string, peer *Peer) (*Peer, error) AddPeer(setupKey string, userId string, peer *Peer) (*Peer, error)
UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error
UpdatePeerSSHKey(peerKey string, sshKey string) error
GetUsersFromAccount(accountId string) ([]*UserInfo, error) GetUsersFromAccount(accountId string) ([]*UserInfo, error)
GetGroup(accountId, groupID string) (*Group, error) GetGroup(accountId, groupID string) (*Group, error)
SaveGroup(accountId string, group *Group) error SaveGroup(accountId string, group *Group) error
@ -65,6 +66,7 @@ type AccountManager interface {
UpdateRule(accountID string, ruleID string, operations []RuleUpdateOperation) (*Rule, error) UpdateRule(accountID string, ruleID string, operations []RuleUpdateOperation) (*Rule, error)
DeleteRule(accountId, ruleID string) error DeleteRule(accountId, ruleID string) error
ListRules(accountId string) ([]*Rule, error) ListRules(accountId string) ([]*Rule, error)
UpdatePeer(accountID string, peer *Peer) (*Peer, error)
} }
type DefaultAccountManager struct { type DefaultAccountManager struct {

View File

@ -185,9 +185,15 @@ func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Pe
return nil, status.Errorf(codes.InvalidArgument, "peer meta data was not provided") return nil, status.Errorf(codes.InvalidArgument, "peer meta data was not provided")
} }
var sshKey []byte
if req.GetPeerKeys() != nil {
sshKey = req.GetPeerKeys().GetSshPubKey()
}
peer, err := s.accountManager.AddPeer(reqSetupKey, userId, &Peer{ peer, err := s.accountManager.AddPeer(reqSetupKey, userId, &Peer{
Key: peerKey.String(), Key: peerKey.String(),
Name: meta.GetHostname(), Name: meta.GetHostname(),
SSHKey: string(sshKey),
Meta: PeerSystemMeta{ Meta: PeerSystemMeta{
Hostname: meta.GetHostname(), Hostname: meta.GetHostname(),
GoOS: meta.GetGoOS(), GoOS: meta.GetGoOS(),
@ -290,6 +296,19 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
return nil, status.Error(codes.Internal, "internal server error") return nil, status.Error(codes.Internal, "internal server error")
} }
} }
var sshKey []byte
if loginReq.GetPeerKeys() != nil {
sshKey = loginReq.GetPeerKeys().GetSshPubKey()
}
if len(sshKey) > 0 {
err = s.accountManager.UpdatePeerSSHKey(peerKey.String(), string(sshKey))
if err != nil {
return nil, err
}
}
// if peer has reached this point then it has logged in // if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{ loginResp := &proto.LoginResponse{
WiretrusteeConfig: toWiretrusteeConfig(s.config, nil), WiretrusteeConfig: toWiretrusteeConfig(s.config, nil),
@ -366,6 +385,7 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot
func toPeerConfig(peer *Peer) *proto.PeerConfig { func toPeerConfig(peer *Peer) *proto.PeerConfig {
return &proto.PeerConfig{ return &proto.PeerConfig{
Address: fmt.Sprintf("%s/%d", peer.IP.String(), SubnetSize), // take it from the network Address: fmt.Sprintf("%s/%d", peer.IP.String(), SubnetSize), // take it from the network
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
} }
} }
@ -375,9 +395,9 @@ func toRemotePeerConfig(peers []*Peer) []*proto.RemotePeerConfig {
remotePeers = append(remotePeers, &proto.RemotePeerConfig{ remotePeers = append(remotePeers, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key, WgPubKey: rPeer.Key,
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)}, AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
}) })
} }
return remotePeers return remotePeers
} }

View File

@ -85,6 +85,9 @@ components:
required: required:
- type - type
- value - value
ssh_enabled:
description: Indicates whether SSH server is enabled on this peer
type: boolean
required: required:
- ip - ip
- connected - connected
@ -93,6 +96,7 @@ components:
- version - version
- groups - groups
- activated_by - activated_by
- ssh_enabled
SetupKey: SetupKey:
type: object type: object
properties: properties:
@ -397,8 +401,11 @@ paths:
properties: properties:
name: name:
type: string type: string
ssh_enabled:
type: boolean
required: required:
- name - name
- ssh_enabled
responses: responses:
'200': '200':
description: A Peer object description: A Peer object

View File

@ -115,6 +115,9 @@ type Peer struct {
// Peer's operating system and version // Peer's operating system and version
Os string `json:"os"` Os string `json:"os"`
// Indicates whether SSH server is enabled on this peer
SshEnabled bool `json:"ssh_enabled"`
// Peer's daemon or cli version // Peer's daemon or cli version
Version string `json:"version"` Version string `json:"version"`
} }
@ -266,6 +269,7 @@ type PutApiGroupsIdJSONBody struct {
// PutApiPeersIdJSONBody defines parameters for PutApiPeersId. // PutApiPeersIdJSONBody defines parameters for PutApiPeersId.
type PutApiPeersIdJSONBody struct { type PutApiPeersIdJSONBody struct {
Name string `json:"name"` Name string `json:"name"`
SshEnabled bool `json:"ssh_enabled"`
} }
// PostApiRulesJSONBody defines parameters for PostApiRules. // PostApiRulesJSONBody defines parameters for PostApiRules.

View File

@ -34,7 +34,9 @@ func (h *Peers) updatePeer(account *server.Account, peer *server.Peer, w http.Re
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
peer, err = h.accountManager.RenamePeer(account.Id, peer.Key, req.Name)
update := &server.Peer{Key: peer.Key, SSHEnabled: req.SshEnabled, Name: req.Name}
peer, err = h.accountManager.UpdatePeer(account.Id, update)
if err != nil { if err != nil {
log.Errorf("failed updating peer %s under account %s %v", peerIp, account.Id, err) log.Errorf("failed updating peer %s under account %s %v", peerIp, account.Id, err)
http.Redirect(w, r, "/", http.StatusInternalServerError) http.Redirect(w, r, "/", http.StatusInternalServerError)
@ -141,5 +143,6 @@ func toPeerResponse(peer *server.Peer, account *server.Account) *api.Peer {
Os: fmt.Sprintf("%s %s", peer.Meta.OS, peer.Meta.Core), Os: fmt.Sprintf("%s %s", peer.Meta.OS, peer.Meta.Core),
Version: peer.Meta.WtVersion, Version: peer.Meta.WtVersion,
Groups: groupsInfo, Groups: groupsInfo,
SshEnabled: peer.SSHEnabled,
} }
} }

View File

@ -41,6 +41,8 @@ type MockAccountManager struct {
ListRulesFunc func(accountID string) ([]*server.Rule, error) ListRulesFunc func(accountID string) ([]*server.Rule, error)
GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error) GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error)
UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error
UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error
UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error)
} }
// GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface
@ -48,7 +50,7 @@ func (am *MockAccountManager) GetUsersFromAccount(accountID string) ([]*server.U
if am.GetUsersFromAccountFunc != nil { if am.GetUsersFromAccountFunc != nil {
return am.GetUsersFromAccountFunc(accountID) return am.GetUsersFromAccountFunc(accountID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented")
} }
// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface // GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface
@ -60,7 +62,7 @@ func (am *MockAccountManager) GetOrCreateAccountByUser(
} }
return nil, status.Errorf( return nil, status.Errorf(
codes.Unimplemented, codes.Unimplemented,
"method GetOrCreateAccountByUser not implemented", "method GetOrCreateAccountByUser is not implemented",
) )
} }
@ -69,7 +71,7 @@ func (am *MockAccountManager) GetAccountByUser(userId string) (*server.Account,
if am.GetAccountByUserFunc != nil { if am.GetAccountByUserFunc != nil {
return am.GetAccountByUserFunc(userId) return am.GetAccountByUserFunc(userId)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetAccountByUser is not implemented")
} }
// AddSetupKey mock implementation of AddSetupKey from server.AccountManager interface // AddSetupKey mock implementation of AddSetupKey from server.AccountManager interface
@ -82,7 +84,7 @@ func (am *MockAccountManager) AddSetupKey(
if am.AddSetupKeyFunc != nil { if am.AddSetupKeyFunc != nil {
return am.AddSetupKeyFunc(accountId, keyName, keyType, expiresIn) return am.AddSetupKeyFunc(accountId, keyName, keyType, expiresIn)
} }
return nil, status.Errorf(codes.Unimplemented, "method AddSetupKey not implemented") return nil, status.Errorf(codes.Unimplemented, "method AddSetupKey is not implemented")
} }
// RevokeSetupKey mock implementation of RevokeSetupKey from server.AccountManager interface // RevokeSetupKey mock implementation of RevokeSetupKey from server.AccountManager interface
@ -93,7 +95,7 @@ func (am *MockAccountManager) RevokeSetupKey(
if am.RevokeSetupKeyFunc != nil { if am.RevokeSetupKeyFunc != nil {
return am.RevokeSetupKeyFunc(accountId, keyId) return am.RevokeSetupKeyFunc(accountId, keyId)
} }
return nil, status.Errorf(codes.Unimplemented, "method RevokeSetupKey not implemented") return nil, status.Errorf(codes.Unimplemented, "method RevokeSetupKey is not implemented")
} }
// RenameSetupKey mock implementation of RenameSetupKey from server.AccountManager interface // RenameSetupKey mock implementation of RenameSetupKey from server.AccountManager interface
@ -105,7 +107,7 @@ func (am *MockAccountManager) RenameSetupKey(
if am.RenameSetupKeyFunc != nil { if am.RenameSetupKeyFunc != nil {
return am.RenameSetupKeyFunc(accountId, keyId, newName) return am.RenameSetupKeyFunc(accountId, keyId, newName)
} }
return nil, status.Errorf(codes.Unimplemented, "method RenameSetupKey not implemented") return nil, status.Errorf(codes.Unimplemented, "method RenameSetupKey is not implemented")
} }
// GetAccountById mock implementation of GetAccountById from server.AccountManager interface // GetAccountById mock implementation of GetAccountById from server.AccountManager interface
@ -113,7 +115,7 @@ func (am *MockAccountManager) GetAccountById(accountId string) (*server.Account,
if am.GetAccountByIdFunc != nil { if am.GetAccountByIdFunc != nil {
return am.GetAccountByIdFunc(accountId) return am.GetAccountByIdFunc(accountId)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetAccountById not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetAccountById is not implemented")
} }
// GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface // GetAccountByUserOrAccountId mock implementation of GetAccountByUserOrAccountId from server.AccountManager interface
@ -125,7 +127,7 @@ func (am *MockAccountManager) GetAccountByUserOrAccountId(
} }
return nil, status.Errorf( return nil, status.Errorf(
codes.Unimplemented, codes.Unimplemented,
"method GetAccountByUserOrAccountId not implemented", "method GetAccountByUserOrAccountId is not implemented",
) )
} }
@ -138,7 +140,7 @@ func (am *MockAccountManager) GetAccountWithAuthorizationClaims(
} }
return nil, status.Errorf( return nil, status.Errorf(
codes.Unimplemented, codes.Unimplemented,
"method GetAccountWithAuthorizationClaims not implemented", "method GetAccountWithAuthorizationClaims is not implemented",
) )
} }
@ -147,7 +149,7 @@ func (am *MockAccountManager) AccountExists(accountId string) (*bool, error) {
if am.AccountExistsFunc != nil { if am.AccountExistsFunc != nil {
return am.AccountExistsFunc(accountId) return am.AccountExistsFunc(accountId)
} }
return nil, status.Errorf(codes.Unimplemented, "method AccountExists not implemented") return nil, status.Errorf(codes.Unimplemented, "method AccountExists is not implemented")
} }
// GetPeer mock implementation of GetPeer from server.AccountManager interface // GetPeer mock implementation of GetPeer from server.AccountManager interface
@ -155,7 +157,7 @@ func (am *MockAccountManager) GetPeer(peerKey string) (*server.Peer, error) {
if am.GetPeerFunc != nil { if am.GetPeerFunc != nil {
return am.GetPeerFunc(peerKey) return am.GetPeerFunc(peerKey)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetPeer not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetPeer is not implemented")
} }
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
@ -163,7 +165,7 @@ func (am *MockAccountManager) MarkPeerConnected(peerKey string, connected bool)
if am.MarkPeerConnectedFunc != nil { if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(peerKey, connected) return am.MarkPeerConnectedFunc(peerKey, connected)
} }
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected not implemented") return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
} }
// RenamePeer mock implementation of RenamePeer from server.AccountManager interface // RenamePeer mock implementation of RenamePeer from server.AccountManager interface
@ -175,7 +177,7 @@ func (am *MockAccountManager) RenamePeer(
if am.RenamePeerFunc != nil { if am.RenamePeerFunc != nil {
return am.RenamePeerFunc(accountId, peerKey, newName) return am.RenamePeerFunc(accountId, peerKey, newName)
} }
return nil, status.Errorf(codes.Unimplemented, "method RenamePeer not implemented") return nil, status.Errorf(codes.Unimplemented, "method RenamePeer is not implemented")
} }
// DeletePeer mock implementation of DeletePeer from server.AccountManager interface // DeletePeer mock implementation of DeletePeer from server.AccountManager interface
@ -183,7 +185,7 @@ func (am *MockAccountManager) DeletePeer(accountId string, peerKey string) (*ser
if am.DeletePeerFunc != nil { if am.DeletePeerFunc != nil {
return am.DeletePeerFunc(accountId, peerKey) return am.DeletePeerFunc(accountId, peerKey)
} }
return nil, status.Errorf(codes.Unimplemented, "method DeletePeer not implemented") return nil, status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented")
} }
// GetPeerByIP mock implementation of GetPeerByIP from server.AccountManager interface // GetPeerByIP mock implementation of GetPeerByIP from server.AccountManager interface
@ -191,7 +193,7 @@ func (am *MockAccountManager) GetPeerByIP(accountId string, peerIP string) (*ser
if am.GetPeerByIPFunc != nil { if am.GetPeerByIPFunc != nil {
return am.GetPeerByIPFunc(accountId, peerIP) return am.GetPeerByIPFunc(accountId, peerIP)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetPeerByIP not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetPeerByIP is not implemented")
} }
// GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface // GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface
@ -199,7 +201,7 @@ func (am *MockAccountManager) GetNetworkMap(peerKey string) (*server.NetworkMap,
if am.GetNetworkMapFunc != nil { if am.GetNetworkMapFunc != nil {
return am.GetNetworkMapFunc(peerKey) return am.GetNetworkMapFunc(peerKey)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetNetworkMap not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetNetworkMap is not implemented")
} }
// AddPeer mock implementation of AddPeer from server.AccountManager interface // AddPeer mock implementation of AddPeer from server.AccountManager interface
@ -211,7 +213,7 @@ func (am *MockAccountManager) AddPeer(
if am.AddPeerFunc != nil { if am.AddPeerFunc != nil {
return am.AddPeerFunc(setupKey, userId, peer) return am.AddPeerFunc(setupKey, userId, peer)
} }
return nil, status.Errorf(codes.Unimplemented, "method AddPeer not implemented") return nil, status.Errorf(codes.Unimplemented, "method AddPeer is not implemented")
} }
// GetGroup mock implementation of GetGroup from server.AccountManager interface // GetGroup mock implementation of GetGroup from server.AccountManager interface
@ -219,7 +221,7 @@ func (am *MockAccountManager) GetGroup(accountID, groupID string) (*server.Group
if am.GetGroupFunc != nil { if am.GetGroupFunc != nil {
return am.GetGroupFunc(accountID, groupID) return am.GetGroupFunc(accountID, groupID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetGroup not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetGroup is not implemented")
} }
// SaveGroup mock implementation of SaveGroup from server.AccountManager interface // SaveGroup mock implementation of SaveGroup from server.AccountManager interface
@ -227,7 +229,7 @@ func (am *MockAccountManager) SaveGroup(accountID string, group *server.Group) e
if am.SaveGroupFunc != nil { if am.SaveGroupFunc != nil {
return am.SaveGroupFunc(accountID, group) return am.SaveGroupFunc(accountID, group)
} }
return status.Errorf(codes.Unimplemented, "method SaveGroup not implemented") return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
} }
// UpdateGroup mock implementation of UpdateGroup from server.AccountManager interface // UpdateGroup mock implementation of UpdateGroup from server.AccountManager interface
@ -243,7 +245,7 @@ func (am *MockAccountManager) DeleteGroup(accountID, groupID string) error {
if am.DeleteGroupFunc != nil { if am.DeleteGroupFunc != nil {
return am.DeleteGroupFunc(accountID, groupID) return am.DeleteGroupFunc(accountID, groupID)
} }
return status.Errorf(codes.Unimplemented, "method DeleteGroup not implemented") return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented")
} }
// ListGroups mock implementation of ListGroups from server.AccountManager interface // ListGroups mock implementation of ListGroups from server.AccountManager interface
@ -251,7 +253,7 @@ func (am *MockAccountManager) ListGroups(accountID string) ([]*server.Group, err
if am.ListGroupsFunc != nil { if am.ListGroupsFunc != nil {
return am.ListGroupsFunc(accountID) return am.ListGroupsFunc(accountID)
} }
return nil, status.Errorf(codes.Unimplemented, "method ListGroups not implemented") return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
} }
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
@ -259,7 +261,7 @@ func (am *MockAccountManager) GroupAddPeer(accountID, groupID, peerKey string) e
if am.GroupAddPeerFunc != nil { if am.GroupAddPeerFunc != nil {
return am.GroupAddPeerFunc(accountID, groupID, peerKey) return am.GroupAddPeerFunc(accountID, groupID, peerKey)
} }
return status.Errorf(codes.Unimplemented, "method GroupAddPeer not implemented") return status.Errorf(codes.Unimplemented, "method GroupAddPeer is not implemented")
} }
// GroupDeletePeer mock implementation of GroupDeletePeer from server.AccountManager interface // GroupDeletePeer mock implementation of GroupDeletePeer from server.AccountManager interface
@ -267,7 +269,7 @@ func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerKey string
if am.GroupDeletePeerFunc != nil { if am.GroupDeletePeerFunc != nil {
return am.GroupDeletePeerFunc(accountID, groupID, peerKey) return am.GroupDeletePeerFunc(accountID, groupID, peerKey)
} }
return status.Errorf(codes.Unimplemented, "method GroupDeletePeer not implemented") return status.Errorf(codes.Unimplemented, "method GroupDeletePeer is not implemented")
} }
// GroupListPeers mock implementation of GroupListPeers from server.AccountManager interface // GroupListPeers mock implementation of GroupListPeers from server.AccountManager interface
@ -275,7 +277,7 @@ func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*serv
if am.GroupListPeersFunc != nil { if am.GroupListPeersFunc != nil {
return am.GroupListPeersFunc(accountID, groupID) return am.GroupListPeersFunc(accountID, groupID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GroupListPeers not implemented") return nil, status.Errorf(codes.Unimplemented, "method GroupListPeers is not implemented")
} }
// GetRule mock implementation of GetRule from server.AccountManager interface // GetRule mock implementation of GetRule from server.AccountManager interface
@ -283,7 +285,7 @@ func (am *MockAccountManager) GetRule(accountID, ruleID string) (*server.Rule, e
if am.GetRuleFunc != nil { if am.GetRuleFunc != nil {
return am.GetRuleFunc(accountID, ruleID) return am.GetRuleFunc(accountID, ruleID)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetRule not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetRule is not implemented")
} }
// SaveRule mock implementation of SaveRule from server.AccountManager interface // SaveRule mock implementation of SaveRule from server.AccountManager interface
@ -291,7 +293,7 @@ func (am *MockAccountManager) SaveRule(accountID string, rule *server.Rule) erro
if am.SaveRuleFunc != nil { if am.SaveRuleFunc != nil {
return am.SaveRuleFunc(accountID, rule) return am.SaveRuleFunc(accountID, rule)
} }
return status.Errorf(codes.Unimplemented, "method SaveRule not implemented") return status.Errorf(codes.Unimplemented, "method SaveRule is not implemented")
} }
// UpdateRule mock implementation of UpdateRule from server.AccountManager interface // UpdateRule mock implementation of UpdateRule from server.AccountManager interface
@ -307,7 +309,7 @@ func (am *MockAccountManager) DeleteRule(accountID, ruleID string) error {
if am.DeleteRuleFunc != nil { if am.DeleteRuleFunc != nil {
return am.DeleteRuleFunc(accountID, ruleID) return am.DeleteRuleFunc(accountID, ruleID)
} }
return status.Errorf(codes.Unimplemented, "method DeleteRule not implemented") return status.Errorf(codes.Unimplemented, "method DeleteRule is not implemented")
} }
// ListRules mock implementation of ListRules from server.AccountManager interface // ListRules mock implementation of ListRules from server.AccountManager interface
@ -315,7 +317,7 @@ func (am *MockAccountManager) ListRules(accountID string) ([]*server.Rule, error
if am.ListRulesFunc != nil { if am.ListRulesFunc != nil {
return am.ListRulesFunc(accountID) return am.ListRulesFunc(accountID)
} }
return nil, status.Errorf(codes.Unimplemented, "method ListRules not implemented") return nil, status.Errorf(codes.Unimplemented, "method ListRules is not implemented")
} }
// UpdatePeerMeta mock implementation of UpdatePeerMeta from server.AccountManager interface // UpdatePeerMeta mock implementation of UpdatePeerMeta from server.AccountManager interface
@ -323,7 +325,7 @@ func (am *MockAccountManager) UpdatePeerMeta(peerKey string, meta server.PeerSys
if am.UpdatePeerMetaFunc != nil { if am.UpdatePeerMetaFunc != nil {
return am.UpdatePeerMetaFunc(peerKey, meta) return am.UpdatePeerMetaFunc(peerKey, meta)
} }
return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc not implemented") return status.Errorf(codes.Unimplemented, "method UpdatePeerMetaFunc is not implemented")
} }
// IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface // IsUserAdmin mock implementation of IsUserAdmin from server.AccountManager interface
@ -331,5 +333,21 @@ func (am *MockAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaims)
if am.IsUserAdminFunc != nil { if am.IsUserAdminFunc != nil {
return am.IsUserAdminFunc(claims) return am.IsUserAdminFunc(claims)
} }
return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin not implemented") return false, status.Errorf(codes.Unimplemented, "method IsUserAdmin is not implemented")
}
// UpdatePeerSSHKey mocks UpdatePeerSSHKey function of the account manager
func (am *MockAccountManager) UpdatePeerSSHKey(peerKey string, sshKey string) error {
if am.UpdatePeerSSHKeyFunc != nil {
return am.UpdatePeerSSHKeyFunc(peerKey, sshKey)
}
return status.Errorf(codes.Unimplemented, "method UpdatePeerSSHKey is is not implemented")
}
// UpdatePeer mocks UpdatePeerFunc function of the account manager
func (am *MockAccountManager) UpdatePeer(accountID string, peer *server.Peer) (*server.Peer, error) {
if am.UpdatePeerFunc != nil {
return am.UpdatePeerFunc(accountID, peer)
}
return nil, status.Errorf(codes.Unimplemented, "method UpdatePeerFunc is is not implemented")
} }

View File

@ -47,6 +47,10 @@ type Peer struct {
Status *PeerStatus Status *PeerStatus
// The user ID that registered the peer // The user ID that registered the peer
UserID string UserID string
// SSHKey is a public SSH key of the peer
SSHKey string
// SSHEnabled indicated whether SSH server is enabled on the peer
SSHEnabled bool
} }
// Copy copies Peer object // Copy copies Peer object
@ -59,6 +63,8 @@ func (p *Peer) Copy() *Peer {
Name: p.Name, Name: p.Name,
Status: p.Status, Status: p.Status,
UserID: p.UserID, UserID: p.UserID,
SSHKey: p.SSHKey,
SSHEnabled: p.SSHEnabled,
} }
} }
@ -100,6 +106,41 @@ func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected boo
return nil return nil
} }
// UpdatePeer updates peer. Only Peer.Name and Peer.SSHEnabled can be updated.
func (am *DefaultAccountManager) UpdatePeer(accountID string, update *Peer) (*Peer, error) {
am.mux.Lock()
defer am.mux.Unlock()
account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, status.Errorf(codes.NotFound, "account not found")
}
peer, err := am.Store.GetPeer(update.Key)
if err != nil {
return nil, err
}
peerCopy := peer.Copy()
if peer.Name != "" {
peerCopy.Name = update.Name
}
peerCopy.SSHEnabled = update.SSHEnabled
err = am.Store.SavePeer(accountID, peerCopy)
if err != nil {
return nil, err
}
err = am.updateAccountPeers(account)
if err != nil {
return nil, err
}
return peerCopy, nil
}
// RenamePeer changes peer's name // RenamePeer changes peer's name
func (am *DefaultAccountManager) RenamePeer( func (am *DefaultAccountManager) RenamePeer(
accountId string, accountId string,
@ -292,6 +333,8 @@ func (am *DefaultAccountManager) AddPeer(
Name: peer.Name, Name: peer.Name,
UserID: userID, UserID: userID,
Status: &PeerStatus{Connected: false, LastSeen: time.Now()}, Status: &PeerStatus{Connected: false, LastSeen: time.Now()},
SSHEnabled: false,
SSHKey: peer.SSHKey,
} }
// add peer to 'All' group // add peer to 'All' group
@ -315,6 +358,38 @@ func (am *DefaultAccountManager) AddPeer(
return newPeer, nil return newPeer, nil
} }
// UpdatePeerSSHKey updates peer's public SSH key
func (am *DefaultAccountManager) UpdatePeerSSHKey(peerKey string, sshKey string) error {
am.mux.Lock()
defer am.mux.Unlock()
if sshKey == "" {
log.Debugf("empty SSH key provided for peer %s, skipping update", peerKey)
return nil
}
peer, err := am.Store.GetPeer(peerKey)
if err != nil {
return err
}
account, err := am.Store.GetPeerAccount(peerKey)
if err != nil {
return err
}
peerCopy := peer.Copy()
peerCopy.SSHKey = sshKey
err = am.Store.SavePeer(account.Id, peerCopy)
if err != nil {
return err
}
// trigger network map update
return am.updateAccountPeers(account)
}
// UpdatePeerMeta updates peer's system metadata // UpdatePeerMeta updates peer's system metadata
func (am *DefaultAccountManager) UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error { func (am *DefaultAccountManager) UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error {
am.mux.Lock() am.mux.Lock()
@ -345,7 +420,7 @@ func (am *DefaultAccountManager) UpdatePeerMeta(peerKey string, meta PeerSystemM
return nil return nil
} }
// getPeersByACL allowed for given peer by ACL // getPeersByACL returns all peers that given peer has access to.
func (am *DefaultAccountManager) getPeersByACL(account *Account, peerKey string) []*Peer { func (am *DefaultAccountManager) getPeersByACL(account *Account, peerKey string) []*Peer {
var peers []*Peer var peers []*Peer
srcRules, err := am.Store.GetPeerSrcRules(account.Id, peerKey) srcRules, err := am.Store.GetPeerSrcRules(account.Id, peerKey)
@ -409,7 +484,8 @@ func (am *DefaultAccountManager) getPeersByACL(account *Account, peerKey string)
return peers return peers
} }
// updateAccountPeers network map constructed by ACL // updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(account *Account) error { func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
// notify other peers of the change // notify other peers of the change
peers, err := am.Store.GetAccountPeers(account.Id) peers, err := am.Store.GetAccountPeers(account.Id)
@ -422,7 +498,7 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
err = am.peersUpdateManager.SendUpdate(p.Key, err = am.peersUpdateManager.SendUpdate(p.Key,
&UpdateMessage{ &UpdateMessage{
Update: &proto.SyncResponse{ Update: &proto.SyncResponse{
// fill those field for backward compatibility // fill deprecated fields for backward compatibility
RemotePeers: update, RemotePeers: update,
RemotePeersIsEmpty: len(update) == 0, RemotePeersIsEmpty: len(update) == 0,
// new field // new field
@ -430,6 +506,7 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) error {
Serial: account.Network.CurrentSerial(), Serial: account.Network.CurrentSerial(),
RemotePeers: update, RemotePeers: update,
RemotePeersIsEmpty: len(update) == 0, RemotePeersIsEmpty: len(update) == 0,
PeerConfig: toPeerConfig(p),
}, },
}, },
}) })

View File

@ -9,6 +9,7 @@ import (
type UpdateMessage struct { type UpdateMessage struct {
Update *proto.SyncResponse Update *proto.SyncResponse
} }
type PeersUpdateManager struct { type PeersUpdateManager struct {
peerChannels map[string]chan *UpdateMessage peerChannels map[string]chan *UpdateMessage
channelsMux *sync.Mutex channelsMux *sync.Mutex

12
util/membership_unix.go Normal file
View File

@ -0,0 +1,12 @@
//go:build linux || darwin
package util
import (
"os"
)
// IsAdmin returns true if user is root
func IsAdmin() bool {
return os.Geteuid() == 0
}

View File

@ -0,0 +1,12 @@
package util
import "golang.zx2c4.com/wireguard/windows/elevate"
// IsAdmin returns true if user has admin privileges
func IsAdmin() bool {
adminDesktop, err := elevate.IsAdminDesktop()
if err == nil && adminDesktop {
return true
}
return false
}