mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-19 11:20:18 +02:00
[client] Fix UDP proxy to notify listener when remote conn closed (#4199)
* Fix UDP proxy to notify listener when remote conn closed * Fix sender tests to use t.Errorf for timeout assertions * Fix potential nil pointer
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
package listener
|
||||
|
||||
import "sync"
|
||||
|
||||
type CloseListener struct {
|
||||
listener func()
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewCloseListener() *CloseListener {
|
||||
@@ -9,11 +12,21 @@ func NewCloseListener() *CloseListener {
|
||||
}
|
||||
|
||||
func (c *CloseListener) SetCloseListener(listener func()) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.listener = listener
|
||||
}
|
||||
|
||||
func (c *CloseListener) Notify() {
|
||||
if c.listener != nil {
|
||||
c.listener()
|
||||
c.mu.Lock()
|
||||
|
||||
if c.listener == nil {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
listener := c.listener
|
||||
c.mu.Unlock()
|
||||
|
||||
listener()
|
||||
}
|
||||
|
@@ -183,6 +183,11 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
|
||||
for {
|
||||
n, err := p.remoteConnRead(ctx, buf)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.closeListener.Notify()
|
||||
return
|
||||
}
|
||||
|
||||
|
@@ -122,10 +122,6 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
||||
originalTimeout := healthCheckTimeout
|
||||
healthCheckInterval = 1 * time.Second
|
||||
healthCheckTimeout = 500 * time.Millisecond
|
||||
defer func() {
|
||||
healthCheckInterval = originalInterval
|
||||
healthCheckTimeout = originalTimeout
|
||||
}()
|
||||
|
||||
//nolint:tenv
|
||||
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
||||
@@ -164,20 +160,23 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
||||
select {
|
||||
case <-sender.Timeout:
|
||||
if tc.resetCounterOnce {
|
||||
t.Fatalf("should not have timed out before %s", testTimeout)
|
||||
t.Errorf("should not have timed out before %s", testTimeout)
|
||||
}
|
||||
case <-time.After(testTimeout):
|
||||
if tc.resetCounterOnce {
|
||||
return
|
||||
}
|
||||
t.Fatalf("should have timed out before %s", testTimeout)
|
||||
t.Errorf("should have timed out before %s", testTimeout)
|
||||
}
|
||||
|
||||
cancel()
|
||||
select {
|
||||
case <-senderExit:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatalf("sender did not exit in time")
|
||||
}
|
||||
healthCheckInterval = originalInterval
|
||||
healthCheckTimeout = originalTimeout
|
||||
})
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user