diff --git a/relay/client/client.go b/relay/client/client.go index c2f6f9c71..f57bb2b92 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -19,10 +19,6 @@ const ( serverResponseTimeout = 8 * time.Second ) -var ( - reconnectingTimeout = 5 * time.Second -) - type Msg struct { buf []byte } @@ -47,35 +43,42 @@ type Client struct { serviceIsRunning bool serviceIsRunningMutex sync.Mutex wgReadLoop sync.WaitGroup - onDisconnected chan struct{} remoteAddr net.Addr + + onDisconnectListener func() + listenerMutex sync.Mutex } func NewClient(ctx context.Context, serverAddress, peerID string) *Client { ctx, ctxCancel := context.WithCancel(ctx) hashedID, hashedStringId := messages.HashID(peerID) return &Client{ - log: log.WithField("client_id", hashedStringId), - ctx: ctx, - ctxCancel: ctxCancel, - serverAddress: serverAddress, - hashedID: hashedID, - conns: make(map[string]*connContainer), - onDisconnected: make(chan struct{}), + log: log.WithField("client_id", hashedStringId), + ctx: ctx, + ctxCancel: ctxCancel, + serverAddress: serverAddress, + hashedID: hashedID, + conns: make(map[string]*connContainer), } } +func (c *Client) SetOnDisconnectListener(fn func()) { + c.listenerMutex.Lock() + defer c.listenerMutex.Unlock() + c.onDisconnectListener = fn +} + func (c *Client) Connect() error { c.serviceIsRunningMutex.Lock() + defer c.serviceIsRunningMutex.Unlock() + if c.serviceIsRunning { - c.serviceIsRunningMutex.Unlock() return nil } err := c.connect() if err != nil { - c.serviceIsRunningMutex.Unlock() return err } @@ -84,41 +87,6 @@ func (c *Client) Connect() error { c.wgReadLoop.Add(1) go c.readLoop() - c.serviceIsRunningMutex.Unlock() - - go func() { - <-c.ctx.Done() - cErr := c.close() - if cErr != nil { - log.Errorf("failed to close relay connection: %s", cErr) - } - }() - - go c.reconnectGuard() - - return nil -} - -func (c *Client) ConnectWithoutReconnect() error { - c.serviceIsRunningMutex.Lock() - if c.serviceIsRunning { - c.serviceIsRunningMutex.Unlock() - return nil - } - - err := c.connect() - if err != nil { - c.serviceIsRunningMutex.Unlock() - return err - } - - c.serviceIsRunning = true - - c.wgReadLoop.Add(1) - go c.readLoop() - - c.serviceIsRunningMutex.Unlock() - go func() { <-c.ctx.Done() cErr := c.close() @@ -210,33 +178,6 @@ func (c *Client) close() error { return err } -func (c *Client) reconnectGuard() { - for { - c.wgReadLoop.Wait() - - c.serviceIsRunningMutex.Lock() - if !c.serviceIsRunning { - c.serviceIsRunningMutex.Unlock() - return - } - - log.Infof("reconnecting to relay server") - err := c.connect() - if err != nil { - log.Errorf("failed to reconnect to relay server: %s", err) - c.serviceIsRunningMutex.Unlock() - time.Sleep(reconnectingTimeout) - continue - } - log.Infof("reconnected to relay server") - c.wgReadLoop.Add(1) - go c.readLoop() - - c.serviceIsRunningMutex.Unlock() - - } -} - func (c *Client) handShake() error { defer func() { err := c.relayConn.SetReadDeadline(time.Time{}) @@ -322,6 +263,8 @@ func (c *Client) readLoop() { } } + c.notifyDisconnected() + if c.serviceIsRunning { _ = c.relayConn.Close() } @@ -384,3 +327,23 @@ func (c *Client) closeConn(id string) error { return nil } + +func (c *Client) onDisconnect() { + c.listenerMutex.Lock() + defer c.listenerMutex.Unlock() + + if c.onDisconnectListener == nil { + return + } + c.onDisconnectListener() +} + +func (c *Client) notifyDisconnected() { + c.listenerMutex.Lock() + defer c.listenerMutex.Unlock() + + if c.onDisconnectListener == nil { + return + } + go c.onDisconnectListener() +} diff --git a/relay/client/guard.go b/relay/client/guard.go new file mode 100644 index 000000000..47a6ff722 --- /dev/null +++ b/relay/client/guard.go @@ -0,0 +1,33 @@ +package client + +import ( + "context" + "time" +) + +var ( + reconnectingTimeout = 5 * time.Second +) + +type Guard struct { + ctx context.Context + relayClient *Client +} + +func NewGuard(context context.Context, relayClient *Client) *Guard { + g := &Guard{ + ctx: context, + relayClient: relayClient, + } + + return g +} + +func (g *Guard) OnDisconnected() { + select { + case <-time.After(time.Second): + _ = g.relayClient.Connect() + case <-g.ctx.Done(): + return + } +} diff --git a/relay/client/manager.go b/relay/client/manager.go index 686b4ac4f..e2ced6f5b 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -4,17 +4,19 @@ import ( "context" "fmt" "net" + "sync" ) -// Manager todo: thread safe type Manager struct { ctx context.Context srvAddress string peerID string - relayClient *Client + relayClient *Client + reconnectGuard *Guard - relayClients map[string]*Client + relayClients map[string]*Client + relayClientsMutex sync.Mutex } func NewManager(ctx context.Context, serverAddress string, peerID string) *Manager { @@ -28,13 +30,33 @@ func NewManager(ctx context.Context, serverAddress string, peerID string) *Manag func (m *Manager) Serve() error { m.relayClient = NewClient(m.ctx, m.srvAddress, m.peerID) + m.reconnectGuard = NewGuard(m.ctx, m.relayClient) + m.relayClient.SetOnDisconnectListener(m.reconnectGuard.OnDisconnected) err := m.relayClient.Connect() if err != nil { return err } + return nil } +func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { + if m.relayClient == nil { + return nil, fmt.Errorf("relay client not connected") + } + + foreign, err := m.isForeignServer(serverAddress) + if err != nil { + return nil, err + } + + if foreign { + return m.openConnVia(serverAddress, peerKey) + } else { + return m.relayClient.OpenConn(peerKey) + } +} + func (m *Manager) RelayAddress() (net.Addr, error) { if m.relayClient == nil { return nil, fmt.Errorf("relay client not connected") @@ -42,48 +64,38 @@ func (m *Manager) RelayAddress() (net.Addr, error) { return m.relayClient.RelayRemoteAddress() } -func (m *Manager) OpenConn(peerKey string) (net.Conn, error) { - if m.relayClient == nil { - return nil, fmt.Errorf("relay client not connected") - } - - rAddr, err := m.relayClient.RelayRemoteAddress() - if err != nil { - return nil, fmt.Errorf("relay client not connected") - } - - return m.OpenConnTo(rAddr.String(), peerKey) -} - -func (m *Manager) OpenConnTo(serverAddress, peerKey string) (net.Conn, error) { - if m.relayClient == nil { - return nil, fmt.Errorf("relay client not connected") - } - rAddr, err := m.relayClient.RelayRemoteAddress() - if err != nil { - return nil, fmt.Errorf("relay client not connected") - } - - if rAddr.String() == serverAddress { - return m.relayClient.OpenConn(peerKey) - } - +func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { relayClient, ok := m.relayClients[serverAddress] if ok { return relayClient.OpenConn(peerKey) } relayClient = NewClient(m.ctx, serverAddress, m.peerID) - err = relayClient.ConnectWithoutReconnect() + err := relayClient.Connect() if err != nil { return nil, err } - + relayClient.SetOnDisconnectListener(func() { + m.deleteRelayConn(serverAddress) + }) conn, err := relayClient.OpenConn(peerKey) if err != nil { return nil, err } m.relayClients[serverAddress] = relayClient + return conn, nil } + +func (m *Manager) deleteRelayConn(address string) { + delete(m.relayClients, address) +} + +func (m *Manager) isForeignServer(address string) (bool, error) { + rAddr, err := m.relayClient.RelayRemoteAddress() + if err != nil { + return false, fmt.Errorf("relay client not connected") + } + return rAddr.String() != address, nil +} diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index 6c8edd534..2c041fff4 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -48,6 +48,10 @@ func TestNewManager(t *testing.T) { if err != nil { t.Fatalf("failed to connect to server: %s", err) } + aliceSrvAddr, err := clientAlice.RelayAddress() + if err != nil { + t.Fatalf("failed to get relay address: %s", err) + } clientBob := NewManager(ctx, addr2, idBob) err = clientBob.Serve() @@ -59,12 +63,12 @@ func TestNewManager(t *testing.T) { if err != nil { t.Fatalf("failed to get relay address: %s", err) } - connAliceToBob, err := clientAlice.OpenConnTo(bobsSrvAddr.String(), idBob) + connAliceToBob, err := clientAlice.OpenConn(bobsSrvAddr.String(), idBob) if err != nil { t.Fatalf("failed to bind channel: %s", err) } - connBobToAlice, err := clientBob.OpenConn(idAlice) + connBobToAlice, err := clientBob.OpenConn(aliceSrvAddr.String(), idAlice) if err != nil { t.Fatalf("failed to bind channel: %s", err) }