diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 7fa423b6a..f4e540c84 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -40,6 +40,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/routemanager" + nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" mgmt "github.com/netbirdio/netbird/management/client" @@ -203,6 +204,13 @@ func TestEngine_SSH(t *testing.T) { return } + // Generate SSH key for the test + sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) + if err != nil { + t.Fatal(err) + return + } + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -218,6 +226,7 @@ func TestEngine_SSH(t *testing.T) { WgPrivateKey: key, WgPort: 33100, ServerSSHAllowed: true, + SSHKey: sshKey, }, MobileDependency{}, peer.NewRecorder("https://mgm"), @@ -229,14 +238,7 @@ func TestEngine_SSH(t *testing.T) { } err = engine.Start() - if err != nil { - t.Skip("skipping TestEngine_SSH due to interface creation failure in CI:", err) - } - - // Additional check to ensure wgInterface was created successfully - if engine.wgInterface == nil { - t.Skip("skipping TestEngine_SSH: wgInterface not initialized (likely due to CI permissions)") - } + require.NoError(t, err) defer func() { err := engine.Stop() @@ -262,9 +264,7 @@ func TestEngine_SSH(t *testing.T) { } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assert.Nil(t, engine.sshServer) @@ -278,9 +278,7 @@ func TestEngine_SSH(t *testing.T) { } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) time.Sleep(250 * time.Millisecond) assert.NotNil(t, engine.sshServer) @@ -293,9 +291,7 @@ func TestEngine_SSH(t *testing.T) { } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // time.Sleep(250 * time.Millisecond) assert.NotNil(t, engine.sshServer) @@ -310,9 +306,7 @@ func TestEngine_SSH(t *testing.T) { } err = engine.updateNetworkMap(networkMap) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assert.Nil(t, engine.sshServer) } diff --git a/client/ssh/client_test.go b/client/ssh/client_test.go index 5cd28814e..dd8642cf4 100644 --- a/client/ssh/client_test.go +++ b/client/ssh/client_test.go @@ -3,6 +3,7 @@ package ssh import ( "bytes" "context" + "errors" "fmt" "io" "net" @@ -14,6 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + cryptossh "golang.org/x/crypto/ssh" ) func TestSSHClient_DialWithKey(t *testing.T) { @@ -703,19 +705,35 @@ func TestSSHClient_SignalHandling(t *testing.T) { defer cmdCancel() // Start a long-running command that will be cancelled + // Use a command that should work reliably across platforms err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10") - // The command may return nil (clean exit on signal) or an error - // What matters is that the context was actually cancelled + // What we care about is that the command was terminated due to context cancellation + // This can manifest in several ways: + // 1. Context deadline exceeded error + // 2. ExitMissingError (clean termination without exit status) + // 3. No error but command completed due to cancellation if err != nil { t.Logf("Received error: %s", err.Error()) - // Accept either context deadline exceeded or other cancellation-related errors - isContextError := strings.Contains(err.Error(), "context deadline exceeded") || - strings.Contains(err.Error(), "context canceled") - assert.True(t, isContextError, "Should be cancelled due to timeout, got: %s", err.Error()) + // Accept context errors or ExitMissingError (both indicate successful cancellation) + var exitMissingErr *cryptossh.ExitMissingError + isValidCancellation := errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, context.Canceled) || + errors.As(err, &exitMissingErr) + + // If we got a valid cancellation error, the test passes + if isValidCancellation { + t.Logf("Command was successfully cancelled") + return + } + + // If we got some other error, that's unexpected + t.Errorf("Unexpected error type: %s", err.Error()) + return } - // Verify the context was actually cancelled (this is the important check) + // If no error was returned, the command might have been cancelled cleanly + // In this case, we should verify the context was actually cancelled assert.Error(t, cmdCtx.Err(), "Context should be cancelled due to timeout") }