[client] Fix UDP proxy to notify listener when remote conn closed (#4199)

* Fix UDP proxy to notify listener when remote conn closed

* Fix sender tests to use t.Errorf for timeout assertions

* Fix potential nil pointer
This commit is contained in:
Zoltan Papp
2025-07-25 14:14:45 +02:00
committed by GitHub
parent cb85d3f2fc
commit 31872a7fb6
3 changed files with 25 additions and 8 deletions

View File

@@ -1,7 +1,10 @@
package listener package listener
import "sync"
type CloseListener struct { type CloseListener struct {
listener func() listener func()
mu sync.Mutex
} }
func NewCloseListener() *CloseListener { func NewCloseListener() *CloseListener {
@@ -9,11 +12,21 @@ func NewCloseListener() *CloseListener {
} }
func (c *CloseListener) SetCloseListener(listener func()) { func (c *CloseListener) SetCloseListener(listener func()) {
c.mu.Lock()
defer c.mu.Unlock()
c.listener = listener c.listener = listener
} }
func (c *CloseListener) Notify() { func (c *CloseListener) Notify() {
if c.listener != nil { c.mu.Lock()
c.listener()
if c.listener == nil {
c.mu.Unlock()
return
} }
listener := c.listener
c.mu.Unlock()
listener()
} }

View File

@@ -183,6 +183,11 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) {
for { for {
n, err := p.remoteConnRead(ctx, buf) n, err := p.remoteConnRead(ctx, buf)
if err != nil { if err != nil {
if ctx.Err() != nil {
return
}
p.closeListener.Notify()
return return
} }

View File

@@ -122,10 +122,6 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
originalTimeout := healthCheckTimeout originalTimeout := healthCheckTimeout
healthCheckInterval = 1 * time.Second healthCheckInterval = 1 * time.Second
healthCheckTimeout = 500 * time.Millisecond healthCheckTimeout = 500 * time.Millisecond
defer func() {
healthCheckInterval = originalInterval
healthCheckTimeout = originalTimeout
}()
//nolint:tenv //nolint:tenv
os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold))
@@ -164,20 +160,23 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) {
select { select {
case <-sender.Timeout: case <-sender.Timeout:
if tc.resetCounterOnce { if tc.resetCounterOnce {
t.Fatalf("should not have timed out before %s", testTimeout) t.Errorf("should not have timed out before %s", testTimeout)
} }
case <-time.After(testTimeout): case <-time.After(testTimeout):
if tc.resetCounterOnce { if tc.resetCounterOnce {
return return
} }
t.Fatalf("should have timed out before %s", testTimeout) t.Errorf("should have timed out before %s", testTimeout)
} }
cancel()
select { select {
case <-senderExit: case <-senderExit:
case <-time.After(2 * time.Second): case <-time.After(2 * time.Second):
t.Fatalf("sender did not exit in time") t.Fatalf("sender did not exit in time")
} }
healthCheckInterval = originalInterval
healthCheckTimeout = originalTimeout
}) })
} }