diff --git a/relay/client/manager.go b/relay/client/manager.go index 3981415fc..b14a7701b 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -16,6 +16,7 @@ import ( var ( relayCleanupInterval = 60 * time.Second + keepUnusedServerTime = 5 * time.Second ErrRelayClientNotConnected = fmt.Errorf("relay client not connected") ) @@ -27,10 +28,13 @@ type RelayTrack struct { sync.RWMutex relayClient *Client err error + created time.Time } func NewRelayTrack() *RelayTrack { - return &RelayTrack{} + return &RelayTrack{ + created: time.Now(), + } } type OnServerCloseListener func() @@ -302,6 +306,18 @@ func (m *Manager) cleanUpUnusedRelays() { for addr, rt := range m.relayClients { rt.Lock() + // if the connection failed to the server the relay client will be nil + // but the instance will be kept in the relayClients until the next locking + if rt.err != nil { + rt.Unlock() + continue + } + + if time.Since(rt.created) <= keepUnusedServerTime { + rt.Unlock() + continue + } + if rt.relayClient.HasConns() { rt.Unlock() continue diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index e9cc2c581..bfc342f25 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -288,8 +288,9 @@ func TestForeginAutoClose(t *testing.T) { t.Fatalf("failed to close connection: %s", err) } - t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second) - time.Sleep(relayCleanupInterval + 1*time.Second) + timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second + t.Logf("waiting for relay cleanup: %s", timeout) + time.Sleep(timeout) if len(mgr.relayClients) != 0 { t.Errorf("expected 0, got %d", len(mgr.relayClients)) } diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go index eb14581e0..4800e05ba 100644 --- a/relay/client/picker_test.go +++ b/relay/client/picker_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "testing" - "time" ) func TestServerPicker_UnavailableServers(t *testing.T) { @@ -13,7 +12,7 @@ func TestServerPicker_UnavailableServers(t *testing.T) { PeerID: "test", } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) defer cancel() go func() {