diff --git a/client/iface/wgproxy/listener/listener.go b/client/iface/wgproxy/listener/listener.go index bfd651548..a8ee354a1 100644 --- a/client/iface/wgproxy/listener/listener.go +++ b/client/iface/wgproxy/listener/listener.go @@ -1,7 +1,10 @@ package listener +import "sync" + type CloseListener struct { listener func() + mu sync.Mutex } func NewCloseListener() *CloseListener { @@ -9,11 +12,21 @@ func NewCloseListener() *CloseListener { } func (c *CloseListener) SetCloseListener(listener func()) { + c.mu.Lock() + defer c.mu.Unlock() + c.listener = listener } func (c *CloseListener) Notify() { - if c.listener != nil { - c.listener() + c.mu.Lock() + + if c.listener == nil { + c.mu.Unlock() + return } + listener := c.listener + c.mu.Unlock() + + listener() } diff --git a/client/iface/wgproxy/udp/proxy.go b/client/iface/wgproxy/udp/proxy.go index df45d8ca5..139ccd4ed 100644 --- a/client/iface/wgproxy/udp/proxy.go +++ b/client/iface/wgproxy/udp/proxy.go @@ -183,6 +183,11 @@ func (p *WGUDPProxy) proxyToLocal(ctx context.Context) { for { n, err := p.remoteConnRead(ctx, buf) if err != nil { + if ctx.Err() != nil { + return + } + + p.closeListener.Notify() return } diff --git a/relay/healthcheck/sender_test.go b/relay/healthcheck/sender_test.go index 39d266b48..23446366a 100644 --- a/relay/healthcheck/sender_test.go +++ b/relay/healthcheck/sender_test.go @@ -122,10 +122,6 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { originalTimeout := healthCheckTimeout healthCheckInterval = 1 * time.Second healthCheckTimeout = 500 * time.Millisecond - defer func() { - healthCheckInterval = originalInterval - healthCheckTimeout = originalTimeout - }() //nolint:tenv os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) @@ -164,20 +160,23 @@ func TestSenderHealthCheckAttemptThreshold(t *testing.T) { select { case <-sender.Timeout: if tc.resetCounterOnce { - t.Fatalf("should not have timed out before %s", testTimeout) + t.Errorf("should not have timed out before %s", testTimeout) } case <-time.After(testTimeout): if tc.resetCounterOnce { return } - t.Fatalf("should have timed out before %s", testTimeout) + t.Errorf("should have timed out before %s", testTimeout) } + cancel() select { case <-senderExit: case <-time.After(2 * time.Second): t.Fatalf("sender did not exit in time") } + healthCheckInterval = originalInterval + healthCheckTimeout = originalTimeout }) }