mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-20 11:33:07 +02:00
[client] Fix UDP proxy to notify listener when remote conn closed (#4199)
* Fix UDP proxy to notify listener when remote conn closed * Fix sender tests to use t.Errorf for timeout assertions * Fix potential nil pointer
This commit is contained in:
@@ -1,7 +1,10 @@
|
|||||||
package listener
|
package listener
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
type CloseListener struct {
|
type CloseListener struct {
|
||||||
listener func()
|
listener func()
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCloseListener() *CloseListener {
|
func NewCloseListener() *CloseListener {
|
||||||
@@ -9,11 +12,21 @@ func NewCloseListener() *CloseListener {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *CloseListener) SetCloseListener(listener func()) {
|
func (c *CloseListener) SetCloseListener(listener func()) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
c.listener = listener
|
c.listener = listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *CloseListener) Notify() {
|
func (c *CloseListener) Notify() {
|
||||||
if c.listener != nil {
|
c.mu.Lock()
|
||||||
c.listener()
|
|
||||||
|
if c.listener == nil {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
listener := c.listener
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
listener()
|
||||||
}
|
}
|
||||||
|
@@ -183,6 +183,11 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
|
|||||||
for {
|
for {
|
||||||
n, err := p.remoteConnRead(ctx, buf)
|
n, err := p.remoteConnRead(ctx, buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.closeListener.Notify()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -122,10 +122,6 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
|||||||
originalTimeout := healthCheckTimeout
|
originalTimeout := healthCheckTimeout
|
||||||
healthCheckInterval = 1 * time.Second
|
healthCheckInterval = 1 * time.Second
|
||||||
healthCheckTimeout = 500 * time.Millisecond
|
healthCheckTimeout = 500 * time.Millisecond
|
||||||
defer func() {
|
|
||||||
healthCheckInterval = originalInterval
|
|
||||||
healthCheckTimeout = originalTimeout
|
|
||||||
}()
|
|
||||||
|
|
||||||
//nolint:tenv
|
//nolint:tenv
|
||||||
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
|
||||||
@@ -164,20 +160,23 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
|
|||||||
select {
|
select {
|
||||||
case <-sender.Timeout:
|
case <-sender.Timeout:
|
||||||
if tc.resetCounterOnce {
|
if tc.resetCounterOnce {
|
||||||
t.Fatalf("should not have timed out before %s", testTimeout)
|
t.Errorf("should not have timed out before %s", testTimeout)
|
||||||
}
|
}
|
||||||
case <-time.After(testTimeout):
|
case <-time.After(testTimeout):
|
||||||
if tc.resetCounterOnce {
|
if tc.resetCounterOnce {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
t.Fatalf("should have timed out before %s", testTimeout)
|
t.Errorf("should have timed out before %s", testTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
select {
|
select {
|
||||||
case <-senderExit:
|
case <-senderExit:
|
||||||
case <-time.After(2 * time.Second):
|
case <-time.After(2 * time.Second):
|
||||||
t.Fatalf("sender did not exit in time")
|
t.Fatalf("sender did not exit in time")
|
||||||
}
|
}
|
||||||
|
healthCheckInterval = originalInterval
|
||||||
|
healthCheckTimeout = originalTimeout
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user