Fix close conn threading issue

Zoltan Papp 2024-06-03 00:29:08 +02:00
parent 3430b81622
commit 4ced07dd8d


@@ -30,18 +30,17 @@ type connContainer struct {
 type Client struct {
 	log           *log.Entry
+	parentCtx     context.Context
 	ctx           context.Context
 	ctxCancel     context.CancelFunc
 	serverAddress string
 	hashedID      []byte
-	readyToOpenConns bool
-	conns            map[string]*connContainer
-	connsMutext      sync.Mutex // protect conns and readyToOpenConns
 	relayConn     net.Conn
+	conns         map[string]*connContainer
 	serviceIsRunning bool
-	serviceIsRunningMutex sync.Mutex
+	mu            sync.Mutex
+	readLoopMutex sync.Mutex
 	wgReadLoop    sync.WaitGroup
 	remoteAddr    net.Addr
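The locking model changes here: the old connsMutext and serviceIsRunningMutex collapse into a single state lock, mu, guarding conns and serviceIsRunning, while the new readLoopMutex serializes Connect and Close around the read-loop lifecycle. A minimal sketch of that two-lock split, using hypothetical names rather than this commit's types:

package sketch

import "sync"

// Two locks with distinct roles, mirroring the pattern above.
type service struct {
	mu      sync.Mutex // guards the mutable state below
	running bool

	lifecycleMu sync.Mutex // serializes Start/Stop as whole operations
}

func (s *service) Start() {
	s.lifecycleMu.Lock() // held across the whole start sequence
	defer s.lifecycleMu.Unlock()

	s.mu.Lock() // held only while touching shared state
	if s.running {
		s.mu.Unlock()
		return
	}
	s.running = true
	s.mu.Unlock()

	// ... dial, spawn background goroutine, etc. ...
}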
@@ -51,12 +50,11 @@ type Client struct {
 }

 func NewClient(ctx context.Context, serverAddress, peerID string) *Client {
-	ctx, ctxCancel := context.WithCancel(ctx)
 	hashedID, hashedStringId := messages.HashID(peerID)
 	return &Client{
 		log:           log.WithField("client_id", hashedStringId),
-		ctx:           ctx,
-		ctxCancel:     ctxCancel,
+		parentCtx:     ctx,
+		ctxCancel:     func() {},
 		serverAddress: serverAddress,
 		hashedID:      hashedID,
 		conns:         make(map[string]*connContainer),
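NewClient no longer derives a cancellable context up front: it stores the caller's context as parentCtx and seeds ctxCancel with a no-op func() {}, so Close can always call c.ctxCancel() even if Connect never ran. A small sketch of the no-op-cancel idiom, with hypothetical names:

package sketch

import "context"

type client struct {
	cancel context.CancelFunc
}

func newClient() *client {
	// Seed with a no-op so Close never has to nil-check the field.
	return &client{cancel: func() {}}
}

func (c *client) Close() {
	c.cancel() // safe even before the first Connect
}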
@@ -70,39 +68,44 @@ func (c *Client) SetOnDisconnectListener(fn func()) {
 }

 func (c *Client) Connect() error {
-	c.serviceIsRunningMutex.Lock()
-	defer c.serviceIsRunningMutex.Unlock()
+	c.readLoopMutex.Lock()
+	defer c.readLoopMutex.Unlock()

+	c.mu.Lock()
 	if c.serviceIsRunning {
+		c.mu.Unlock()
 		return nil
 	}
+	c.mu.Unlock()

 	err := c.connect()
 	if err != nil {
+		c.mu.Unlock()
 		return err
 	}

 	c.serviceIsRunning = true

-	c.wgReadLoop.Add(1)
-	go c.readLoop()
-
-	go func() {
-		<-c.ctx.Done()
-		cErr := c.close()
+	c.ctx, c.ctxCancel = context.WithCancel(c.parentCtx)
+	context.AfterFunc(c.ctx, func() {
+		cErr := c.Close()
 		if cErr != nil {
 			log.Errorf("failed to close relay connection: %s", cErr)
 		}
-	}()
+	})
+
+	c.wgReadLoop.Add(1)
+	go c.readLoop(c.relayConn)

 	return nil
 }

+// todo: what should happen of call with the same peerID?
 func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
-	c.connsMutext.Lock()
-	defer c.connsMutext.Unlock()
+	c.mu.Lock()
+	defer c.mu.Unlock()

-	if !c.readyToOpenConns {
+	if !c.serviceIsRunning {
 		return nil, fmt.Errorf("relay connection is not established")
 	}
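Connect now builds a fresh per-connection context from parentCtx on every call, which is what makes the client reconnectable after Close, and it swaps the hand-rolled <-ctx.Done() goroutine for context.AfterFunc (Go 1.21+), which runs its callback in a new goroutine once the context is cancelled. A runnable sketch of that substitution, independent of this codebase:

package main

import (
	"context"
	"fmt"
	"time"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())

	// Old shape: go func() { <-ctx.Done(); cleanup() }()
	// New shape: AfterFunc registers cleanup; it runs in its own
	// goroutine as soon as ctx is cancelled.
	stop := context.AfterFunc(ctx, func() {
		fmt.Println("connection teardown runs here")
	})
	_ = stop // calling stop() before cancel would deregister the callback

	cancel()
	time.Sleep(100 * time.Millisecond) // let the callback fire
}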
@@ -119,8 +122,8 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) {
 }

 func (c *Client) RelayRemoteAddress() (net.Addr, error) {
-	c.serviceIsRunningMutex.Lock()
-	defer c.serviceIsRunningMutex.Unlock()
+	c.mu.Lock()
+	defer c.mu.Unlock()

 	if c.remoteAddr == nil {
 		return nil, fmt.Errorf("relay connection is not established")
 	}
@@ -128,14 +131,21 @@ func (c *Client) RelayRemoteAddress() (net.Addr, error) {
 }

 func (c *Client) Close() error {
-	c.serviceIsRunningMutex.Lock()
-	if !c.serviceIsRunning {
-		c.serviceIsRunningMutex.Unlock()
-		return nil
-	}
+	c.readLoopMutex.Lock()
+	defer c.readLoopMutex.Unlock()
+
+	c.mu.Lock()
+	var err error
+	if c.serviceIsRunning {
+		c.serviceIsRunning = false
+		err = c.relayConn.Close()
+	}
+	c.closeAllConns()
+	c.mu.Unlock()
+
+	c.wgReadLoop.Wait()
+
 	c.ctxCancel()
-	return c.close()
+	return err
 }
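The rewritten Close encodes a strict shutdown order: flip serviceIsRunning and close the socket under mu, drop the per-peer connections, release mu, and only then wait on wgReadLoop, because the read loop itself takes mu on its exit path and waiting while still holding the lock could deadlock. A compact sketch of the wait-outside-the-lock rule (svc is a hypothetical stand-in, not this commit's Client):

package sketch

import (
	"net"
	"sync"
)

type svc struct {
	mu      sync.Mutex
	running bool
	conn    net.Conn
	wg      sync.WaitGroup // tracks the reader goroutine
}

func (s *svc) Stop() error {
	s.mu.Lock()
	var err error
	if s.running {
		s.running = false
		err = s.conn.Close() // unblocks the reader's pending Read
	}
	s.mu.Unlock() // release first: the reader locks s.mu while exiting

	s.wg.Wait() // safe now; cannot deadlock against the reader
	return err
}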
 func (c *Client) connect() error {

@@ -157,27 +167,9 @@ func (c *Client) connect() error {
 	c.remoteAddr = conn.RemoteAddr()

-	c.readyToOpenConns = true
-
 	return nil
 }

-func (c *Client) close() error {
-	c.serviceIsRunningMutex.Lock()
-	defer c.serviceIsRunningMutex.Unlock()
-
-	if !c.serviceIsRunning {
-		return nil
-	}
-
-	c.serviceIsRunning = false
-	err := c.relayConn.Close()
-	c.wgReadLoop.Wait()
-
-	return err
-}
-
 func (c *Client) handShake() error {
 	defer func() {
 		err := c.relayConn.SetReadDeadline(time.Time{})
@@ -223,16 +215,18 @@ func (c *Client) handShake() error {
 	return nil
 }

-func (c *Client) readLoop() {
+func (c *Client) readLoop(relayConn net.Conn) {
 	var errExit error
 	var n int
 	for {
 		buf := make([]byte, bufferSize)
-		n, errExit = c.relayConn.Read(buf)
+		n, errExit = relayConn.Read(buf)
 		if errExit != nil {
+			c.mu.Lock()
 			if c.serviceIsRunning {
 				c.log.Debugf("failed to read message from relay server: %s", errExit)
 			}
+			c.mu.Unlock()
 			break
 		}
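Passing the connection into readLoop as an argument, rather than re-reading c.relayConn inside the loop, pins each reader goroutine to the conn it was started for, so a later reconnect that swaps the field cannot race with an old reader. A minimal sketch of the idiom, with hypothetical names:

package sketch

import (
	"net"
	"sync"
)

type svc struct {
	wg sync.WaitGroup
}

func (s *svc) start(conn net.Conn) {
	s.wg.Add(1)
	go s.readLoop(conn) // conn is fixed for this goroutine's lifetime
}

func (s *svc) readLoop(conn net.Conn) {
	defer s.wg.Done()
	buf := make([]byte, 1500)
	for {
		if _, err := conn.Read(buf); err != nil {
			return // a Close() on conn lands here and ends the loop
		}
	}
}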
@@ -251,44 +245,44 @@
 		}

 		stringID := messages.HashIDToString(peerID)
+		c.mu.Lock()
+		if !c.serviceIsRunning {
+			c.mu.Unlock()
+			break
+		}
 		container, ok := c.conns[stringID]
+		c.mu.Unlock()
 		if !ok {
 			c.log.Errorf("peer not found: %s", stringID)
 			continue
 		}

-		container.messages <- Msg{
-			buf[:n],
-		}
+		container.messages <- Msg{buf[:n]}
 		}
 	}

 	c.notifyDisconnected()

-	if c.serviceIsRunning {
-		_ = c.relayConn.Close()
-	}
-
-	c.connsMutext.Lock()
-	c.readyToOpenConns = false
-	for _, container := range c.conns {
-		close(container.messages)
-	}
-	c.conns = make(map[string]*connContainer)
-	c.connsMutext.Unlock()
-
 	c.log.Tracef("exit from read loop")
 	c.wgReadLoop.Done()
+	c.Close()
 }

+// todo check by reference too, the id is not enought because the id come from the outer conn
 func (c *Client) writeTo(id string, dstID []byte, payload []byte) (int, error) {
-	c.connsMutext.Lock()
+	c.mu.Lock()
+	//conn, ok := c.conns[id]
 	_, ok := c.conns[id]
+	c.mu.Unlock()
 	if !ok {
-		c.connsMutext.Unlock()
 		return 0, io.EOF
 	}
-	c.connsMutext.Unlock()
+	/*
+		if conn != clientRef {
+			return 0, io.EOF
+		}
+	*/

 	msg := messages.MarshalTransportMsg(dstID, payload)
 	n, err := c.relayConn.Write(msg)
 	if err != nil {
@@ -314,9 +308,17 @@ func (c *Client) generateConnReaderFN(msgChannel chan Msg) func(b []byte) (n int, err error) {
 	}
 }

+func (c *Client) closeAllConns() {
+	for _, container := range c.conns {
+		close(container.messages)
+	}
+	c.conns = make(map[string]*connContainer)
+}
+
+// todo check by reference too, the id is not enought because the id come from the outer conn
 func (c *Client) closeConn(id string) error {
-	c.connsMutext.Lock()
-	defer c.connsMutext.Unlock()
+	c.mu.Lock()
+	defer c.mu.Unlock()

 	conn, ok := c.conns[id]
 	if !ok {