mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-21 18:22:37 +02:00
Fix close conn threading issue
This commit is contained in:
parent
3430b81622
commit
4ced07dd8d
@ -30,19 +30,18 @@ type connContainer struct {
|
|||||||
|
|
||||||
type Client struct {
|
type Client struct {
|
||||||
log *log.Entry
|
log *log.Entry
|
||||||
|
parentCtx context.Context
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel context.CancelFunc
|
ctxCancel context.CancelFunc
|
||||||
serverAddress string
|
serverAddress string
|
||||||
hashedID []byte
|
hashedID []byte
|
||||||
|
|
||||||
readyToOpenConns bool
|
relayConn net.Conn
|
||||||
conns map[string]*connContainer
|
conns map[string]*connContainer
|
||||||
connsMutext sync.Mutex // protect conns and readyToOpenConns bool
|
serviceIsRunning bool
|
||||||
|
mu sync.Mutex
|
||||||
relayConn net.Conn
|
readLoopMutex sync.Mutex
|
||||||
serviceIsRunning bool
|
wgReadLoop sync.WaitGroup
|
||||||
serviceIsRunningMutex sync.Mutex
|
|
||||||
wgReadLoop sync.WaitGroup
|
|
||||||
|
|
||||||
remoteAddr net.Addr
|
remoteAddr net.Addr
|
||||||
|
|
||||||
@ -51,12 +50,11 @@ type Client struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
|
func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
|
||||||
ctx, ctxCancel := context.WithCancel(ctx)
|
|
||||||
hashedID, hashedStringId := messages.HashID(peerID)
|
hashedID, hashedStringId := messages.HashID(peerID)
|
||||||
return &Client{
|
return &Client{
|
||||||
log: log.WithField("client_id", hashedStringId),
|
log: log.WithField("client_id", hashedStringId),
|
||||||
ctx: ctx,
|
parentCtx: ctx,
|
||||||
ctxCancel: ctxCancel,
|
ctxCancel: func() {},
|
||||||
serverAddress: serverAddress,
|
serverAddress: serverAddress,
|
||||||
hashedID: hashedID,
|
hashedID: hashedID,
|
||||||
conns: make(map[string]*connContainer),
|
conns: make(map[string]*connContainer),
|
||||||
@ -70,39 +68,44 @@ func (c *Client) SetOnDisconnectListener(fn func()) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect() error {
|
||||||
c.serviceIsRunningMutex.Lock()
|
c.readLoopMutex.Lock()
|
||||||
defer c.serviceIsRunningMutex.Unlock()
|
defer c.readLoopMutex.Unlock()
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
|
||||||
if c.serviceIsRunning {
|
if c.serviceIsRunning {
|
||||||
|
c.mu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
err := c.connect()
|
err := c.connect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
c.mu.Unlock()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.serviceIsRunning = true
|
c.serviceIsRunning = true
|
||||||
|
|
||||||
c.wgReadLoop.Add(1)
|
c.ctx, c.ctxCancel = context.WithCancel(c.parentCtx)
|
||||||
go c.readLoop()
|
context.AfterFunc(c.ctx, func() {
|
||||||
|
cErr := c.Close()
|
||||||
go func() {
|
|
||||||
<-c.ctx.Done()
|
|
||||||
cErr := c.close()
|
|
||||||
if cErr != nil {
|
if cErr != nil {
|
||||||
log.Errorf("failed to close relay connection: %s", cErr)
|
log.Errorf("failed to close relay connection: %s", cErr)
|
||||||
}
|
}
|
||||||
}()
|
})
|
||||||
|
c.wgReadLoop.Add(1)
|
||||||
|
go c.readLoop(c.relayConn)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// todo: what should happen of call with the same peerID?
|
||||||
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
||||||
c.connsMutext.Lock()
|
c.mu.Lock()
|
||||||
defer c.connsMutext.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
if !c.readyToOpenConns {
|
if !c.serviceIsRunning {
|
||||||
return nil, fmt.Errorf("relay connection is not established")
|
return nil, fmt.Errorf("relay connection is not established")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,8 +122,8 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) RelayRemoteAddress() (net.Addr, error) {
|
func (c *Client) RelayRemoteAddress() (net.Addr, error) {
|
||||||
c.serviceIsRunningMutex.Lock()
|
c.mu.Lock()
|
||||||
defer c.serviceIsRunningMutex.Unlock()
|
defer c.mu.Unlock()
|
||||||
if c.remoteAddr == nil {
|
if c.remoteAddr == nil {
|
||||||
return nil, fmt.Errorf("relay connection is not established")
|
return nil, fmt.Errorf("relay connection is not established")
|
||||||
}
|
}
|
||||||
@ -128,14 +131,21 @@ func (c *Client) RelayRemoteAddress() (net.Addr, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
c.serviceIsRunningMutex.Lock()
|
c.readLoopMutex.Lock()
|
||||||
if !c.serviceIsRunning {
|
defer c.readLoopMutex.Unlock()
|
||||||
c.serviceIsRunningMutex.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
var err error
|
||||||
|
if c.serviceIsRunning {
|
||||||
|
c.serviceIsRunning = false
|
||||||
|
err = c.relayConn.Close()
|
||||||
|
}
|
||||||
|
c.closeAllConns()
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
c.wgReadLoop.Wait()
|
||||||
c.ctxCancel()
|
c.ctxCancel()
|
||||||
return c.close()
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) connect() error {
|
func (c *Client) connect() error {
|
||||||
@ -157,27 +167,9 @@ func (c *Client) connect() error {
|
|||||||
|
|
||||||
c.remoteAddr = conn.RemoteAddr()
|
c.remoteAddr = conn.RemoteAddr()
|
||||||
|
|
||||||
c.readyToOpenConns = true
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) close() error {
|
|
||||||
c.serviceIsRunningMutex.Lock()
|
|
||||||
defer c.serviceIsRunningMutex.Unlock()
|
|
||||||
|
|
||||||
if !c.serviceIsRunning {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
c.serviceIsRunning = false
|
|
||||||
|
|
||||||
err := c.relayConn.Close()
|
|
||||||
|
|
||||||
c.wgReadLoop.Wait()
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) handShake() error {
|
func (c *Client) handShake() error {
|
||||||
defer func() {
|
defer func() {
|
||||||
err := c.relayConn.SetReadDeadline(time.Time{})
|
err := c.relayConn.SetReadDeadline(time.Time{})
|
||||||
@ -223,16 +215,18 @@ func (c *Client) handShake() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) readLoop() {
|
func (c *Client) readLoop(relayConn net.Conn) {
|
||||||
var errExit error
|
var errExit error
|
||||||
var n int
|
var n int
|
||||||
for {
|
for {
|
||||||
buf := make([]byte, bufferSize)
|
buf := make([]byte, bufferSize)
|
||||||
n, errExit = c.relayConn.Read(buf)
|
n, errExit = relayConn.Read(buf)
|
||||||
if errExit != nil {
|
if errExit != nil {
|
||||||
|
c.mu.Lock()
|
||||||
if c.serviceIsRunning {
|
if c.serviceIsRunning {
|
||||||
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
c.log.Debugf("failed to read message from relay server: %s", errExit)
|
||||||
}
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -251,44 +245,44 @@ func (c *Client) readLoop() {
|
|||||||
}
|
}
|
||||||
stringID := messages.HashIDToString(peerID)
|
stringID := messages.HashIDToString(peerID)
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
if !c.serviceIsRunning {
|
||||||
|
c.mu.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
container, ok := c.conns[stringID]
|
container, ok := c.conns[stringID]
|
||||||
|
c.mu.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
c.log.Errorf("peer not found: %s", stringID)
|
c.log.Errorf("peer not found: %s", stringID)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
container.messages <- Msg{
|
container.messages <- Msg{buf[:n]}
|
||||||
buf[:n],
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.notifyDisconnected()
|
c.notifyDisconnected()
|
||||||
|
|
||||||
if c.serviceIsRunning {
|
|
||||||
_ = c.relayConn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
c.connsMutext.Lock()
|
|
||||||
c.readyToOpenConns = false
|
|
||||||
for _, container := range c.conns {
|
|
||||||
close(container.messages)
|
|
||||||
}
|
|
||||||
c.conns = make(map[string]*connContainer)
|
|
||||||
c.connsMutext.Unlock()
|
|
||||||
|
|
||||||
c.log.Tracef("exit from read loop")
|
c.log.Tracef("exit from read loop")
|
||||||
c.wgReadLoop.Done()
|
c.wgReadLoop.Done()
|
||||||
|
|
||||||
|
c.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// todo check by reference too, the id is not enought because the id come from the outer conn
|
||||||
func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
|
func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
|
||||||
c.connsMutext.Lock()
|
c.mu.Lock()
|
||||||
|
// conn, ok := c.conns[id]
|
||||||
_, ok := c.conns[id]
|
_, ok := c.conns[id]
|
||||||
|
c.mu.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
c.connsMutext.Unlock()
|
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
}
|
}
|
||||||
c.connsMutext.Unlock()
|
/*
|
||||||
|
if conn != clientRef {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
*/
|
||||||
msg := messages.MarshalTransportMsg(dstID, payload)
|
msg := messages.MarshalTransportMsg(dstID, payload)
|
||||||
n, err := c.relayConn.Write(msg)
|
n, err := c.relayConn.Write(msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -314,9 +308,17 @@ func (c *Client) generateConnReaderFN(msgChannel chan Msg) func(b []byte) (n int
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) closeAllConns() {
|
||||||
|
for _, container := range c.conns {
|
||||||
|
close(container.messages)
|
||||||
|
}
|
||||||
|
c.conns = make(map[string]*connContainer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo check by reference too, the id is not enought because the id come from the outer conn
|
||||||
func (c *Client) closeConn(id string) error {
|
func (c *Client) closeConn(id string) error {
|
||||||
c.connsMutext.Lock()
|
c.mu.Lock()
|
||||||
defer c.connsMutext.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
conn, ok := c.conns[id]
|
conn, ok := c.conns[id]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user