mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 10:18:50 +02:00
Fix engine test
This commit is contained in:
parent
854b70141d
commit
6c967d1c27
@ -40,6 +40,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
mgmt "github.com/netbirdio/netbird/management/client"
|
mgmt "github.com/netbirdio/netbird/management/client"
|
||||||
@ -203,6 +204,13 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate SSH key for the test
|
||||||
|
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@ -218,6 +226,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
ServerSSHAllowed: true,
|
ServerSSHAllowed: true,
|
||||||
|
SSHKey: sshKey,
|
||||||
},
|
},
|
||||||
MobileDependency{},
|
MobileDependency{},
|
||||||
peer.NewRecorder("https://mgm"),
|
peer.NewRecorder("https://mgm"),
|
||||||
@ -229,14 +238,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = engine.Start()
|
err = engine.Start()
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Skip("skipping TestEngine_SSH due to interface creation failure in CI:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Additional check to ensure wgInterface was created successfully
|
|
||||||
if engine.wgInterface == nil {
|
|
||||||
t.Skip("skipping TestEngine_SSH: wgInterface not initialized (likely due to CI permissions)")
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := engine.Stop()
|
err := engine.Stop()
|
||||||
@ -262,9 +264,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
err = engine.updateNetworkMap(networkMap)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Nil(t, engine.sshServer)
|
assert.Nil(t, engine.sshServer)
|
||||||
|
|
||||||
@ -278,9 +278,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
err = engine.updateNetworkMap(networkMap)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
time.Sleep(250 * time.Millisecond)
|
||||||
assert.NotNil(t, engine.sshServer)
|
assert.NotNil(t, engine.sshServer)
|
||||||
@ -293,9 +291,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
err = engine.updateNetworkMap(networkMap)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// time.Sleep(250 * time.Millisecond)
|
// time.Sleep(250 * time.Millisecond)
|
||||||
assert.NotNil(t, engine.sshServer)
|
assert.NotNil(t, engine.sshServer)
|
||||||
@ -310,9 +306,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = engine.updateNetworkMap(networkMap)
|
err = engine.updateNetworkMap(networkMap)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Nil(t, engine.sshServer)
|
assert.Nil(t, engine.sshServer)
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ package ssh
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
@ -14,6 +15,7 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSSHClient_DialWithKey(t *testing.T) {
|
func TestSSHClient_DialWithKey(t *testing.T) {
|
||||||
@ -703,19 +705,35 @@ func TestSSHClient_SignalHandling(t *testing.T) {
|
|||||||
defer cmdCancel()
|
defer cmdCancel()
|
||||||
|
|
||||||
// Start a long-running command that will be cancelled
|
// Start a long-running command that will be cancelled
|
||||||
|
// Use a command that should work reliably across platforms
|
||||||
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
|
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
|
||||||
|
|
||||||
// The command may return nil (clean exit on signal) or an error
|
// What we care about is that the command was terminated due to context cancellation
|
||||||
// What matters is that the context was actually cancelled
|
// This can manifest in several ways:
|
||||||
|
// 1. Context deadline exceeded error
|
||||||
|
// 2. ExitMissingError (clean termination without exit status)
|
||||||
|
// 3. No error but command completed due to cancellation
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("Received error: %s", err.Error())
|
t.Logf("Received error: %s", err.Error())
|
||||||
// Accept either context deadline exceeded or other cancellation-related errors
|
// Accept context errors or ExitMissingError (both indicate successful cancellation)
|
||||||
isContextError := strings.Contains(err.Error(), "context deadline exceeded") ||
|
var exitMissingErr *cryptossh.ExitMissingError
|
||||||
strings.Contains(err.Error(), "context canceled")
|
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
|
||||||
assert.True(t, isContextError, "Should be cancelled due to timeout, got: %s", err.Error())
|
errors.Is(err, context.Canceled) ||
|
||||||
|
errors.As(err, &exitMissingErr)
|
||||||
|
|
||||||
|
// If we got a valid cancellation error, the test passes
|
||||||
|
if isValidCancellation {
|
||||||
|
t.Logf("Command was successfully cancelled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we got some other error, that's unexpected
|
||||||
|
t.Errorf("Unexpected error type: %s", err.Error())
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the context was actually cancelled (this is the important check)
|
// If no error was returned, the command might have been cancelled cleanly
|
||||||
|
// In this case, we should verify the context was actually cancelled
|
||||||
assert.Error(t, cmdCtx.Err(), "Context should be cancelled due to timeout")
|
assert.Error(t, cmdCtx.Err(), "Context should be cancelled due to timeout")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user