diff --git a/relay/client/client.go b/relay/client/client.go index fa4f0da6a..aa4fff5d7 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -87,7 +87,7 @@ func (c *Client) Connect() error { var ctx context.Context ctx, c.ctxCancel = context.WithCancel(c.parentCtx) context.AfterFunc(ctx, func() { - cErr := c.Close() + cErr := c.close(false) if cErr != nil { log.Errorf("failed to close relay connection: %s", cErr) } @@ -144,22 +144,30 @@ func (c *Client) HasConns() bool { // Close closes the connection to the relay server and all connections to other peers. func (c *Client) Close() error { + return c.close(false) +} + +func (c *Client) close(byServer bool) error { c.readLoopMutex.Lock() defer c.readLoopMutex.Unlock() c.mu.Lock() var err error if !c.serviceIsRunning { + c.mu.Unlock() return nil } c.serviceIsRunning = false - err = c.relayConn.Close() c.closeAllConns() + if !byServer { + c.writeCloseMsg() + err = c.relayConn.Close() + } c.mu.Unlock() c.wgReadLoop.Wait() - c.log.Infof("relay client ha been closed: %s", c.serverAddress) + c.log.Infof("relay connection closed with: %s", c.serverAddress) c.ctxCancel() return err } @@ -232,8 +240,11 @@ func (c *Client) handShake() error { } func (c *Client) readLoop(relayConn net.Conn) { - var errExit error - var n int + var ( + errExit error + n int + closedByServer bool + ) for { buf := make([]byte, bufferSize) n, errExit = relayConn.Read(buf) @@ -243,7 +254,7 @@ func (c *Client) readLoop(relayConn net.Conn) { c.log.Debugf("failed to read message from relay server: %s", errExit) } c.mu.Unlock() - break + goto Exit } msgType, err := messages.DetermineServerMsgType(buf[:n]) @@ -264,7 +275,7 @@ func (c *Client) readLoop(relayConn net.Conn) { c.mu.Lock() if !c.serviceIsRunning { c.mu.Unlock() - break + goto Exit } container, ok := c.conns[stringID] c.mu.Unlock() @@ -273,16 +284,19 @@ func (c *Client) readLoop(relayConn net.Conn) { continue } + // todo review is this can cause panic container.messages <- Msg{buf[:n]} + case messages.MsgClose: + closedByServer = true + log.Debugf("relay connection close by server") + goto Exit } } +Exit: c.notifyDisconnected() - - c.log.Tracef("exit from read loop") c.wgReadLoop.Done() - - c.Close() + _ = c.close(closedByServer) } // todo check by reference too, the id is not enought because the id come from the outer conn @@ -365,3 +379,11 @@ func (c *Client) notifyDisconnected() { } go c.onDisconnectListener() } + +func (c *Client) writeCloseMsg() { + msg := messages.MarshalCloseMsg() + _, err := c.relayConn.Write(msg) + if err != nil { + c.log.Errorf("failed to send close message: %s", err) + } +} diff --git a/relay/client/client_test.go b/relay/client/client_test.go index 3aa931d68..f5d122276 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -100,57 +100,56 @@ func TestRegistration(t *testing.T) { } }() - defer func() { - err := srv.Close() - if err != nil { - t.Errorf("failed to close server: %s", err) - } - }() - clientAlice := NewClient(ctx, addr, "alice") err := clientAlice.Connect() if err != nil { + _ = srv.Close() t.Fatalf("failed to connect to server: %s", err) } - defer func() { - err = clientAlice.Close() - if err != nil { - t.Errorf("failed to close conn: %s", err) - } - }() + err = clientAlice.Close() + if err != nil { + t.Errorf("failed to close conn: %s", err) + } + err = srv.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } } func TestRegistrationTimeout(t *testing.T) { ctx := context.Background() - udpListener, err := net.ListenUDP("udp", &net.UDPAddr{ + fakeUDPListener, err := net.ListenUDP("udp", &net.UDPAddr{ Port: 1234, IP: net.ParseIP("0.0.0.0"), }) if err != nil { t.Fatalf("failed to bind UDP server: %s", err) } - defer udpListener.Close() + defer func(fakeUDPListener *net.UDPConn) { + _ = fakeUDPListener.Close() + }(fakeUDPListener) - tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{ + fakeTCPListener, err := net.ListenTCP("tcp", &net.TCPAddr{ Port: 1234, IP: net.ParseIP("0.0.0.0"), }) if err != nil { t.Fatalf("failed to bind TCP server: %s", err) } - defer tcpListener.Close() + defer func(fakeTCPListener *net.TCPListener) { + _ = fakeTCPListener.Close() + }(fakeTCPListener) clientAlice := NewClient(ctx, "127.0.0.1:1234", "alice") err = clientAlice.Connect() if err == nil { t.Errorf("failed to connect to server: %s", err) } - defer func() { - err = clientAlice.Close() - if err != nil { - t.Errorf("failed to close conn: %s", err) - } - }() + log.Debugf("%s", err) + err = clientAlice.Close() + if err != nil { + t.Errorf("failed to close conn: %s", err) + } } func TestEcho(t *testing.T) { @@ -259,18 +258,16 @@ func TestBindToUnavailabePeer(t *testing.T) { if err != nil { t.Errorf("failed to connect to server: %s", err) } - defer func() { - log.Infof("closing client") - err := clientAlice.Close() - if err != nil { - t.Errorf("failed to close client: %s", err) - } - }() - _, err = clientAlice.OpenConn("bob") if err != nil { t.Errorf("failed to bind channel: %s", err) } + + log.Infof("closing client") + err = clientAlice.Close() + if err != nil { + t.Errorf("failed to close client: %s", err) + } } func TestBindReconnect(t *testing.T) { @@ -315,7 +312,7 @@ func TestBindReconnect(t *testing.T) { t.Errorf("failed to bind channel: %s", err) } - log.Infof("closing client") + log.Infof("closing client Alice") err = clientAlice.Close() if err != nil { t.Errorf("failed to close client: %s", err) @@ -403,52 +400,6 @@ func TestCloseConn(t *testing.T) { } } -func TestAutoReconnect(t *testing.T) { - ctx := context.Background() - - addr := "localhost:1234" - srv := server.NewServer() - go func() { - err := srv.Listen(addr) - if err != nil { - t.Errorf("failed to bind server: %s", err) - } - }() - - defer func() { - err := srv.Close() - if err != nil { - log.Errorf("failed to close server: %s", err) - } - }() - - clientAlice := NewClient(ctx, addr, "alice") - err := clientAlice.Connect() - if err != nil { - t.Errorf("failed to connect to server: %s", err) - } - - conn, err := clientAlice.OpenConn("bob") - if err != nil { - t.Errorf("failed to bind channel: %s", err) - } - - _ = clientAlice.relayConn.Close() - - _, err = conn.Read(make([]byte, 1)) - if err == nil { - t.Errorf("unexpected reading from closed connection") - } - - log.Infof("waiting for reconnection") - time.Sleep(reconnectingTimeout) - - _, err = clientAlice.OpenConn("bob") - if err != nil { - t.Errorf("failed to open channel: %s", err) - } -} - func TestCloseRelayConn(t *testing.T) { ctx := context.Background() @@ -491,3 +442,82 @@ func TestCloseRelayConn(t *testing.T) { t.Errorf("unexpected opening connection to closed server") } } + +func TestCloseByServer(t *testing.T) { + ctx := context.Background() + + addr1 := "localhost:1234" + srv1 := server.NewServer() + go func() { + err := srv1.Listen(addr1) + if err != nil { + t.Fatalf("failed to bind server: %s", err) + } + }() + + idAlice := "alice" + log.Debugf("connect by alice") + relayClient := NewClient(ctx, addr1, idAlice) + err := relayClient.Connect() + if err != nil { + log.Fatalf("failed to connect to server: %s", err) + } + + disconnected := make(chan struct{}) + relayClient.SetOnDisconnectListener(func() { + log.Infof("client disconnected") + close(disconnected) + }) + + err = srv1.Close() + if err != nil { + t.Fatalf("failed to close server: %s", err) + } + + select { + case <-disconnected: + case <-time.After(3 * time.Second): + log.Fatalf("timeout waiting for client to disconnect") + } + + _, err = relayClient.OpenConn("bob") + if err == nil { + t.Errorf("unexpected opening connection to closed server") + } +} + +func TestCloseByClient(t *testing.T) { + ctx := context.Background() + + addr1 := "localhost:1234" + srv := server.NewServer() + go func() { + err := srv.Listen(addr1) + if err != nil { + t.Fatalf("failed to bind server: %s", err) + } + }() + + idAlice := "alice" + log.Debugf("connect by alice") + relayClient := NewClient(ctx, addr1, idAlice) + err := relayClient.Connect() + if err != nil { + log.Fatalf("failed to connect to server: %s", err) + } + + err = relayClient.Close() + if err != nil { + t.Errorf("failed to close client: %s", err) + } + + _, err = relayClient.OpenConn("bob") + if err == nil { + t.Errorf("unexpected opening connection to closed server") + } + + err = srv.Close() + if err != nil { + t.Fatalf("failed to close server: %s", err) + } +} diff --git a/relay/client/dialer/ws/client_conn.go b/relay/client/dialer/ws/client_conn.go index 72a3fa9b4..698cb9d24 100644 --- a/relay/client/dialer/ws/client_conn.go +++ b/relay/client/dialer/ws/client_conn.go @@ -6,7 +6,6 @@ import ( "time" "github.com/gorilla/websocket" - log "github.com/sirupsen/logrus" ) type Conn struct { @@ -52,9 +51,5 @@ func (c *Conn) SetDeadline(t time.Time) error { } func (c *Conn) Close() error { - err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)) - if err != nil { - log.Errorf("failed to close conn?: %s", err) - } return c.Conn.Close() } diff --git a/relay/client/dialer/ws/ws.go b/relay/client/dialer/ws/ws.go index 3ab3d38a8..b526bd315 100644 --- a/relay/client/dialer/ws/ws.go +++ b/relay/client/dialer/ws/ws.go @@ -3,13 +3,17 @@ package ws import ( "fmt" "net" + "time" "github.com/gorilla/websocket" ) func Dial(address string) (net.Conn, error) { addr := fmt.Sprintf("ws://" + address) - wsConn, _, err := websocket.DefaultDialer.Dial(addr, nil) + wsDialer := websocket.Dialer{ + HandshakeTimeout: 3 * time.Second, + } + wsConn, _, err := wsDialer.Dial(addr, nil) if err != nil { return nil, err } diff --git a/relay/client/manager.go b/relay/client/manager.go index 98ec3a1c8..911561403 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -84,8 +84,10 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { } if !foreign { + log.Debugf("open connection to permanent server: %s", peerKey) return m.relayClient.OpenConn(peerKey) } else { + log.Debugf("open connection to foreign server: %s", serverAddress) return m.openConnVia(serverAddress, peerKey) } } diff --git a/relay/client/manager_test.go b/relay/client/manager_test.go index d0e22dabd..f4de9b4de 100644 --- a/relay/client/manager_test.go +++ b/relay/client/manager_test.go @@ -47,12 +47,14 @@ func TestForeignConn(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - clientAlice := NewManager(ctx, addr1, idAlice) + mCtx, cancel := context.WithCancel(ctx) + defer cancel() + clientAlice := NewManager(mCtx, addr1, idAlice) clientAlice.Serve() idBob := "bob" log.Debugf("connect by bob") - clientBob := NewManager(ctx, addr2, idBob) + clientBob := NewManager(mCtx, addr2, idBob) clientBob.Serve() bobsSrvAddr, err := clientBob.RelayAddress() @@ -132,61 +134,9 @@ func TestForeginConnClose(t *testing.T) { idAlice := "alice" log.Debugf("connect by alice") - clientAlice := NewManager(ctx, addr1, idAlice) - clientAlice.Serve() - - conn, err := clientAlice.OpenConn(addr2, "anotherpeer") - if err != nil { - t.Fatalf("failed to bind channel: %s", err) - } - - err = conn.Close() - if err != nil { - t.Fatalf("failed to close connection: %s", err) - } - - select {} -} - -func TestForeginAutoClose(t *testing.T) { - ctx := context.Background() - - addr1 := "localhost:1234" - srv1 := server.NewServer() - go func() { - err := srv1.Listen(addr1) - if err != nil { - t.Fatalf("failed to bind server: %s", err) - } - }() - - defer func() { - err := srv1.Close() - if err != nil { - t.Errorf("failed to close server: %s", err) - } - }() - - addr2 := "localhost:2234" - srv2 := server.NewServer() - go func() { - err := srv2.Listen(addr2) - if err != nil { - t.Fatalf("failed to bind server: %s", err) - } - }() - - defer func() { - err := srv2.Close() - if err != nil { - t.Errorf("failed to close server: %s", err) - } - }() - - idAlice := "alice" - log.Debugf("connect by alice") - mgr := NewManager(ctx, addr1, idAlice) - relayCleanupInterval = 2 * time.Second + mCtx, cancel := context.WithCancel(ctx) + defer cancel() + mgr := NewManager(mCtx, addr1, idAlice) mgr.Serve() conn, err := mgr.OpenConn(addr2, "anotherpeer") @@ -198,9 +148,124 @@ func TestForeginAutoClose(t *testing.T) { if err != nil { t.Fatalf("failed to close connection: %s", err) } +} +func TestForeginAutoClose(t *testing.T) { + ctx := context.Background() + relayCleanupInterval = 1 * time.Second + addr1 := "localhost:1234" + srv1 := server.NewServer() + go func() { + t.Log("binding server 1.") + err := srv1.Listen(addr1) + if err != nil { + t.Fatalf("failed to bind server: %s", err) + } + }() + + defer func() { + t.Logf("closing server 1.") + err := srv1.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } + t.Logf("server 1. closed") + }() + + addr2 := "localhost:2234" + srv2 := server.NewServer() + go func() { + t.Log("binding server 2.") + err := srv2.Listen(addr2) + if err != nil { + t.Fatalf("failed to bind server: %s", err) + } + }() + defer func() { + t.Logf("closing server 2.") + err := srv2.Close() + if err != nil { + t.Errorf("failed to close server: %s", err) + } + t.Logf("server 2 closed.") + }() + + idAlice := "alice" + t.Log("connect to server 1.") + mCtx, cancel := context.WithCancel(ctx) + defer cancel() + mgr := NewManager(mCtx, addr1, idAlice) + mgr.Serve() + + t.Log("open connection to another peer") + conn, err := mgr.OpenConn(addr2, "anotherpeer") + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + t.Log("close conn") + err = conn.Close() + if err != nil { + t.Fatalf("failed to close connection: %s", err) + } + + t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second) time.Sleep(relayCleanupInterval + 1*time.Second) if len(mgr.relayClients) != 0 { t.Errorf("expected 0, got %d", len(mgr.relayClients)) } + + t.Logf("closing manager") +} + +func TestAutoReconnect(t *testing.T) { + ctx := context.Background() + reconnectingTimeout = 2 * time.Second + + addr := "localhost:1234" + srv := server.NewServer() + go func() { + err := srv.Listen(addr) + if err != nil { + t.Errorf("failed to bind server: %s", err) + } + }() + + defer func() { + err := srv.Close() + if err != nil { + log.Errorf("failed to close server: %s", err) + } + }() + + mCtx, cancel := context.WithCancel(ctx) + defer cancel() + clientAlice := NewManager(mCtx, addr, "alice") + clientAlice.Serve() + ra, err := clientAlice.RelayAddress() + if err != nil { + t.Errorf("failed to get relay address: %s", err) + } + conn, err := clientAlice.OpenConn(ra.String(), "bob") + if err != nil { + t.Errorf("failed to bind channel: %s", err) + } + + t.Log("closing client relay connection") + // todo figure out moc server + _ = clientAlice.relayClient.relayConn.Close() + t.Log("start test reading") + _, err = conn.Read(make([]byte, 1)) + if err == nil { + t.Errorf("unexpected reading from closed connection") + } + + log.Infof("waiting for reconnection") + time.Sleep(reconnectingTimeout + 1*time.Second) + + log.Infof("reopent the connection") + _, err = clientAlice.OpenConn(ra.String(), "bob") + if err != nil { + t.Errorf("failed to open channel: %s", err) + } } diff --git a/relay/messages/message.go b/relay/messages/message.go index 1c34d5034..7f73daa17 100644 --- a/relay/messages/message.go +++ b/relay/messages/message.go @@ -10,6 +10,7 @@ const ( MsgTypeHello MsgType = 0 MsgTypeHelloResponse MsgType = 1 MsgTypeTransport MsgType = 2 + MsgClose MsgType = 3 ) var ( @@ -26,6 +27,8 @@ func (m MsgType) String() string { return "hello response" case MsgTypeTransport: return "transport" + case MsgClose: + return "close" default: return "unknown" } @@ -39,6 +42,8 @@ func DetermineClientMsgType(msg []byte) (MsgType, error) { return msgType, nil case MsgTypeTransport: return msgType, nil + case MsgClose: + return msgType, nil default: return 0, fmt.Errorf("invalid msg type, len: %d", len(msg)) } @@ -52,6 +57,8 @@ func DetermineServerMsgType(msg []byte) (MsgType, error) { return msgType, nil case MsgTypeTransport: return msgType, nil + case MsgClose: + return msgType, nil default: return 0, fmt.Errorf("invalid msg type (len: %d)", len(msg)) } @@ -81,6 +88,14 @@ func MarshalHelloResponse() []byte { return msg } +// Close message + +func MarshalCloseMsg() []byte { + msg := make([]byte, 1) + msg[0] = byte(MsgClose) + return msg +} + // Transport message func MarshalTransportMsg(peerID []byte, payload []byte) []byte { diff --git a/relay/server/listener/listener.go b/relay/server/listener/listener.go index 66e6d357e..3336e0ad8 100644 --- a/relay/server/listener/listener.go +++ b/relay/server/listener/listener.go @@ -5,4 +5,5 @@ import "net" type Listener interface { Listen(func(conn net.Conn)) error Close() error + WaitForExitAcceptedConns() } diff --git a/relay/server/listener/udp/listener.go b/relay/server/listener/udp/listener.go index ebd1c53f1..3c2dfc070 100644 --- a/relay/server/listener/udp/listener.go +++ b/relay/server/listener/udp/listener.go @@ -21,6 +21,11 @@ type Listener struct { lock sync.Mutex } +func (l *Listener) WaitForExitAcceptedConns() { + l.wg.Wait() + return +} + func NewListener(address string) listener.Listener { return &Listener{ address: address, @@ -61,11 +66,11 @@ func (l *Listener) Close() error { l.lock.Lock() defer l.lock.Unlock() - log.Infof("closing UDP server") if l.listener == nil { return nil } + log.Infof("closing UDP listener") close(l.quit) err := l.listener.Close() l.wg.Wait() diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index b53e97505..632de153f 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -33,8 +33,11 @@ func NewListener(address string) listener.Listener { } } -// Listen todo: prevent multiple call func (l *Listener) Listen(acceptFn func(conn net.Conn)) error { + if l.server != nil { + return errors.New("server is already running") + } + l.acceptFn = acceptFn mux := http.NewServeMux() mux.HandleFunc("/", l.onAccept) @@ -69,6 +72,10 @@ func (l *Listener) Close() error { return nil } +func (l *Listener) WaitForExitAcceptedConns() { + l.wg.Wait() +} + func (l *Listener) onAccept(writer http.ResponseWriter, request *http.Request) { l.wg.Add(1) defer l.wg.Done() diff --git a/relay/server/listener/wsnhooyr/conn.go b/relay/server/listener/wsnhooyr/conn.go index 72b6bfecb..b52e5d082 100644 --- a/relay/server/listener/wsnhooyr/conn.go +++ b/relay/server/listener/wsnhooyr/conn.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "sync" "time" log "github.com/sirupsen/logrus" @@ -17,7 +18,9 @@ type Conn struct { lAddr *net.TCPAddr rAddr *net.TCPAddr - ctx context.Context + closed bool + closedMu sync.Mutex + ctx context.Context } func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn { @@ -32,7 +35,7 @@ func NewConn(wsConn *websocket.Conn, lAddr, rAddr *net.TCPAddr) *Conn { func (c *Conn) Read(b []byte) (n int, err error) { t, r, err := c.Reader(c.ctx) if err != nil { - return 0, ioErrHandling(err) + return 0, c.ioErrHandling(err) } if t != websocket.MessageBinary { @@ -42,7 +45,7 @@ func (c *Conn) Read(b []byte) (n int, err error) { n, err = r.Read(b) if err != nil { - return 0, ioErrHandling(err) + return 0, c.ioErrHandling(err) } return n, err } @@ -76,11 +79,23 @@ func (c *Conn) SetDeadline(t time.Time) error { } func (c *Conn) Close() error { - return c.Conn.Close(websocket.StatusNormalClosure, "") + c.closedMu.Lock() + c.closed = true + c.closedMu.Unlock() + return c.Conn.CloseNow() } -// todo: fix io.EOF handling -func ioErrHandling(err error) error { +func (c *Conn) isClosed() bool { + c.closedMu.Lock() + defer c.closedMu.Unlock() + return c.closed +} + +func (c *Conn) ioErrHandling(err error) error { + if c.isClosed() { + return io.EOF + } + var wErr *websocket.CloseError if !errors.As(err, &wErr) { return err diff --git a/relay/server/listener/wsnhooyr/listener.go b/relay/server/listener/wsnhooyr/listener.go index 88370a1fc..e47a60b47 100644 --- a/relay/server/listener/wsnhooyr/listener.go +++ b/relay/server/listener/wsnhooyr/listener.go @@ -56,15 +56,18 @@ func (l *Listener) Close() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - log.Debugf("closing WS server") + log.Infof("stop WS listener") if err := l.server.Shutdown(ctx); err != nil { return fmt.Errorf("server shutdown failed: %v", err) } - - l.wg.Wait() + log.Infof("WS listener stopped") return nil } +func (l *Listener) WaitForExitAcceptedConns() { + l.wg.Wait() +} + func (l *Listener) onAccept(w http.ResponseWriter, r *http.Request) { l.wg.Add(1) defer l.wg.Done() diff --git a/relay/server/server.go b/relay/server/server.go index cb9816907..bb9666a53 100644 --- a/relay/server/server.go +++ b/relay/server/server.go @@ -15,11 +15,9 @@ import ( ws "github.com/netbirdio/netbird/relay/server/listener/wsnhooyr" ) -// Server -// todo: -// authentication: provide JWT token via RPC call. The MGM server can forward the token to the agents. type Server struct { - store *Store + store *Store + storeMu sync.RWMutex UDPListener listener.Listener WSListener listener.Listener @@ -27,7 +25,8 @@ type Server struct { func NewServer() *Server { return &Server{ - store: NewStore(), + store: NewStore(), + storeMu: sync.RWMutex{}, } } @@ -69,6 +68,11 @@ func (r *Server) Close() error { if r.UDPListener != nil { uErr = r.UDPListener.Close() } + + r.sendCloseMsgs() + + r.WSListener.WaitForExitAcceptedConns() + err := errors.Join(wErr, uErr) return err } @@ -88,7 +92,7 @@ func (r *Server) accept(conn net.Conn) { r.store.AddPeer(peer) defer func() { r.store.DeletePeer(peer) - peer.Log.Infof("peer left") + peer.Log.Infof("relay connection closed") }() for { @@ -132,10 +136,33 @@ func (r *Server) accept(conn net.Conn) { } return }() + case messages.MsgClose: + peer.Log.Infof("peer disconnected gracefully") + _ = conn.Close() + return } } } +func (r *Server) sendCloseMsgs() { + msg := messages.MarshalCloseMsg() + + r.storeMu.Lock() + log.Debugf("sending close messages to %d peers", len(r.store.peers)) + for _, p := range r.store.peers { + _, err := p.conn.Write(msg) + if err != nil { + log.Errorf("failed to send close message to peer: %s", p.String()) + } + + err = p.conn.Close() + if err != nil { + log.Errorf("failed to close connection to peer: %s", err) + } + } + r.storeMu.Unlock() +} + func handShake(conn net.Conn) (*Peer, error) { buf := make([]byte, 1500) n, err := conn.Read(buf) diff --git a/relay/server/store.go b/relay/server/store.go index f785f4d0e..1f0f08600 100644 --- a/relay/server/store.go +++ b/relay/server/store.go @@ -24,7 +24,6 @@ func (s *Store) AddPeer(peer *Peer) { func (s *Store) DeletePeer(peer *Peer) { s.peersLock.Lock() defer s.peersLock.Unlock() - delete(s.peers, peer.String()) } @@ -35,3 +34,14 @@ func (s *Store) Peer(id string) (*Peer, bool) { p, ok := s.peers[id] return p, ok } + +func (s *Store) Peers() []*Peer { + s.peersLock.RLock() + defer s.peersLock.RUnlock() + + peers := make([]*Peer, 0, len(s.peers)) + for _, p := range s.peers { + peers = append(peers, p) + } + return peers +}