mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 18:22:37 +02:00
Fix close conn threading issue
This commit is contained in:
parent
3430b81622
commit
4ced07dd8d
@ -30,18 +30,17 @@ type connContainer struct {
|
||||
|
||||
type Client struct {
|
||||
log *log.Entry
|
||||
parentCtx context.Context
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
serverAddress string
|
||||
hashedID []byte
|
||||
|
||||
readyToOpenConns bool
|
||||
conns map[string]*connContainer
|
||||
connsMutext sync.Mutex // protect conns and readyToOpenConns bool
|
||||
|
||||
relayConn net.Conn
|
||||
conns map[string]*connContainer
|
||||
serviceIsRunning bool
|
||||
serviceIsRunningMutex sync.Mutex
|
||||
mu sync.Mutex
|
||||
readLoopMutex sync.Mutex
|
||||
wgReadLoop sync.WaitGroup
|
||||
|
||||
remoteAddr net.Addr
|
||||
@ -51,12 +50,11 @@ type Client struct {
|
||||
}
|
||||
|
||||
func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
|
||||
ctx, ctxCancel := context.WithCancel(ctx)
|
||||
hashedID, hashedStringId := messages.HashID(peerID)
|
||||
return &Client{
|
||||
log: log.WithField("client_id", hashedStringId),
|
||||
ctx: ctx,
|
||||
ctxCancel: ctxCancel,
|
||||
parentCtx: ctx,
|
||||
ctxCancel: func() {},
|
||||
serverAddress: serverAddress,
|
||||
hashedID: hashedID,
|
||||
conns: make(map[string]*connContainer),
|
||||
@ -70,39 +68,44 @@ func (c *Client) SetOnDisconnectListener(fn func()) {
|
||||
}
|
||||
|
||||
func (c *Client) Connect() error {
|
||||
c.serviceIsRunningMutex.Lock()
|
||||
defer c.serviceIsRunningMutex.Unlock()
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
|
||||
if c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
err := c.connect()
|
||||
if err != nil {
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
c.serviceIsRunning = true
|
||||
|
||||
c.wgReadLoop.Add(1)
|
||||
go c.readLoop()
|
||||
|
||||
go func() {
|
||||
<-c.ctx.Done()
|
||||
cErr := c.close()
|
||||
c.ctx, c.ctxCancel = context.WithCancel(c.parentCtx)
|
||||
context.AfterFunc(c.ctx, func() {
|
||||
cErr := c.Close()
|
||||
if cErr != nil {
|
||||
log.Errorf("failed to close relay connection: %s", cErr)
|
||||
}
|
||||
}()
|
||||
})
|
||||
c.wgReadLoop.Add(1)
|
||||
go c.readLoop(c.relayConn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// todo: what should happen of call with the same peerID?
|
||||
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
||||
c.connsMutext.Lock()
|
||||
defer c.connsMutext.Unlock()
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if !c.readyToOpenConns {
|
||||
if !c.serviceIsRunning {
|
||||
return nil, fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
|
||||
@ -119,8 +122,8 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
||||
}
|
||||
|
||||
func (c *Client) RelayRemoteAddress() (net.Addr, error) {
|
||||
c.serviceIsRunningMutex.Lock()
|
||||
defer c.serviceIsRunningMutex.Unlock()
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.remoteAddr == nil {
|
||||
return nil, fmt.Errorf("relay connection is not established")
|
||||
}
|
||||
@ -128,14 +131,21 @@ func (c *Client) RelayRemoteAddress() (net.Addr, error) {
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.serviceIsRunningMutex.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
c.serviceIsRunningMutex.Unlock()
|
||||
return nil
|
||||
}
|
||||
c.readLoopMutex.Lock()
|
||||
defer c.readLoopMutex.Unlock()
|
||||
|
||||
c.mu.Lock()
|
||||
var err error
|
||||
if c.serviceIsRunning {
|
||||
c.serviceIsRunning = false
|
||||
err = c.relayConn.Close()
|
||||
}
|
||||
c.closeAllConns()
|
||||
c.mu.Unlock()
|
||||
|
||||
c.wgReadLoop.Wait()
|
||||
c.ctxCancel()
|
||||
return c.close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) connect() error {
|
||||
@ -157,27 +167,9 @@ func (c *Client) connect() error {
|
||||
|
||||
c.remoteAddr = conn.RemoteAddr()
|
||||
|
||||
c.readyToOpenConns = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) close() error {
|
||||
c.serviceIsRunningMutex.Lock()
|
||||
defer c.serviceIsRunningMutex.Unlock()
|
||||
|
||||
if !c.serviceIsRunning {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.serviceIsRunning = false
|
||||
|
||||
err := c.relayConn.Close()
|
||||
|
||||
c.wgReadLoop.Wait()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) handShake() error {
|
||||
defer func() {
|
||||
err := c.relayConn.SetReadDeadline(time.Time{})
|
||||
@ -223,16 +215,18 @@ func (c *Client) handShake() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) readLoop() {
|
||||
func (c *Client) readLoop(relayConn net.Conn) {
|
||||
var errExit error
|
||||
var n int
|
||||
for {
|
||||
buf := make([]byte, bufferSize)
|
||||
n, errExit = c.relayConn.Read(buf)
|
||||
n, errExit = relayConn.Read(buf)
|
||||
if errExit != nil {
|
||||
c.mu.Lock()
|
||||
if c.serviceIsRunning {
|
||||
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
break
|
||||
}
|
||||
|
||||
@ -251,44 +245,44 @@ func (c *Client) readLoop() {
|
||||
}
|
||||
stringID := messages.HashIDToString(peerID)
|
||||
|
||||
c.mu.Lock()
|
||||
if !c.serviceIsRunning {
|
||||
c.mu.Unlock()
|
||||
break
|
||||
}
|
||||
container, ok := c.conns[stringID]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
c.log.Errorf("peer not found: %s", stringID)
|
||||
continue
|
||||
}
|
||||
|
||||
container.messages <- Msg{
|
||||
buf[:n],
|
||||
}
|
||||
container.messages <- Msg{buf[:n]}
|
||||
}
|
||||
}
|
||||
|
||||
c.notifyDisconnected()
|
||||
|
||||
if c.serviceIsRunning {
|
||||
_ = c.relayConn.Close()
|
||||
}
|
||||
|
||||
c.connsMutext.Lock()
|
||||
c.readyToOpenConns = false
|
||||
for _, container := range c.conns {
|
||||
close(container.messages)
|
||||
}
|
||||
c.conns = make(map[string]*connContainer)
|
||||
c.connsMutext.Unlock()
|
||||
|
||||
c.log.Tracef("exit from read loop")
|
||||
c.wgReadLoop.Done()
|
||||
|
||||
c.Close()
|
||||
}
|
||||
|
||||
// todo check by reference too, the id is not enought because the id come from the outer conn
|
||||
func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
|
||||
c.connsMutext.Lock()
|
||||
c.mu.Lock()
|
||||
// conn, ok := c.conns[id]
|
||||
_, ok := c.conns[id]
|
||||
c.mu.Unlock()
|
||||
if !ok {
|
||||
c.connsMutext.Unlock()
|
||||
return 0, io.EOF
|
||||
}
|
||||
c.connsMutext.Unlock()
|
||||
/*
|
||||
if conn != clientRef {
|
||||
return 0, io.EOF
|
||||
}
|
||||
*/
|
||||
msg := messages.MarshalTransportMsg(dstID, payload)
|
||||
n, err := c.relayConn.Write(msg)
|
||||
if err != nil {
|
||||
@ -314,9 +308,17 @@ func (c *Client) generateConnReaderFN(msgChannel chan Msg) func(b []byte) (n int
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) closeAllConns() {
|
||||
for _, container := range c.conns {
|
||||
close(container.messages)
|
||||
}
|
||||
c.conns = make(map[string]*connContainer)
|
||||
}
|
||||
|
||||
// todo check by reference too, the id is not enought because the id come from the outer conn
|
||||
func (c *Client) closeConn(id string) error {
|
||||
c.connsMutext.Lock()
|
||||
defer c.connsMutext.Unlock()
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
conn, ok := c.conns[id]
|
||||
if !ok {
|
||||
|
Loading…
x
Reference in New Issue
Block a user