Fix engine test

This commit is contained in:
Viktor Liu 2025-06-19 22:35:32 +02:00
parent 854b70141d
commit 6c967d1c27
2 changed files with 39 additions and 27 deletions

View File

@ -40,6 +40,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
mgmt "github.com/netbirdio/netbird/management/client" mgmt "github.com/netbirdio/netbird/management/client"
@ -203,6 +204,13 @@ func TestEngine_SSH(t *testing.T) {
return return
} }
// Generate SSH key for the test
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
t.Fatal(err)
return
}
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -218,6 +226,7 @@ func TestEngine_SSH(t *testing.T) {
WgPrivateKey: key, WgPrivateKey: key,
WgPort: 33100, WgPort: 33100,
ServerSSHAllowed: true, ServerSSHAllowed: true,
SSHKey: sshKey,
}, },
MobileDependency{}, MobileDependency{},
peer.NewRecorder("https://mgm"), peer.NewRecorder("https://mgm"),
@ -229,14 +238,7 @@ func TestEngine_SSH(t *testing.T) {
} }
err = engine.Start() err = engine.Start()
if err != nil { require.NoError(t, err)
t.Skip("skipping TestEngine_SSH due to interface creation failure in CI:", err)
}
// Additional check to ensure wgInterface was created successfully
if engine.wgInterface == nil {
t.Skip("skipping TestEngine_SSH: wgInterface not initialized (likely due to CI permissions)")
}
defer func() { defer func() {
err := engine.Stop() err := engine.Stop()
@ -262,9 +264,7 @@ func TestEngine_SSH(t *testing.T) {
} }
err = engine.updateNetworkMap(networkMap) err = engine.updateNetworkMap(networkMap)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
assert.Nil(t, engine.sshServer) assert.Nil(t, engine.sshServer)
@ -278,9 +278,7 @@ func TestEngine_SSH(t *testing.T) {
} }
err = engine.updateNetworkMap(networkMap) err = engine.updateNetworkMap(networkMap)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
time.Sleep(250 * time.Millisecond) time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer) assert.NotNil(t, engine.sshServer)
@ -293,9 +291,7 @@ func TestEngine_SSH(t *testing.T) {
} }
err = engine.updateNetworkMap(networkMap) err = engine.updateNetworkMap(networkMap)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
// time.Sleep(250 * time.Millisecond) // time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer) assert.NotNil(t, engine.sshServer)
@ -310,9 +306,7 @@ func TestEngine_SSH(t *testing.T) {
} }
err = engine.updateNetworkMap(networkMap) err = engine.updateNetworkMap(networkMap)
if err != nil { require.NoError(t, err)
t.Fatal(err)
}
assert.Nil(t, engine.sshServer) assert.Nil(t, engine.sshServer)
} }

View File

@ -3,6 +3,7 @@ package ssh
import ( import (
"bytes" "bytes"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -14,6 +15,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
) )
func TestSSHClient_DialWithKey(t *testing.T) { func TestSSHClient_DialWithKey(t *testing.T) {
@ -703,19 +705,35 @@ func TestSSHClient_SignalHandling(t *testing.T) {
defer cmdCancel() defer cmdCancel()
// Start a long-running command that will be cancelled // Start a long-running command that will be cancelled
// Use a command that should work reliably across platforms
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10") err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
// The command may return nil (clean exit on signal) or an error // What we care about is that the command was terminated due to context cancellation
// What matters is that the context was actually cancelled // This can manifest in several ways:
// 1. Context deadline exceeded error
// 2. ExitMissingError (clean termination without exit status)
// 3. No error but command completed due to cancellation
if err != nil { if err != nil {
t.Logf("Received error: %s", err.Error()) t.Logf("Received error: %s", err.Error())
// Accept either context deadline exceeded or other cancellation-related errors // Accept context errors or ExitMissingError (both indicate successful cancellation)
isContextError := strings.Contains(err.Error(), "context deadline exceeded") || var exitMissingErr *cryptossh.ExitMissingError
strings.Contains(err.Error(), "context canceled") isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
assert.True(t, isContextError, "Should be cancelled due to timeout, got: %s", err.Error()) errors.Is(err, context.Canceled) ||
errors.As(err, &exitMissingErr)
// If we got a valid cancellation error, the test passes
if isValidCancellation {
t.Logf("Command was successfully cancelled")
return
} }
// Verify the context was actually cancelled (this is the important check) // If we got some other error, that's unexpected
t.Errorf("Unexpected error type: %s", err.Error())
return
}
// If no error was returned, the command might have been cancelled cleanly
// In this case, we should verify the context was actually cancelled
assert.Error(t, cmdCtx.Err(), "Context should be cancelled due to timeout") assert.Error(t, cmdCtx.Err(), "Context should be cancelled due to timeout")
} }