mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-16 10:08:12 +02:00
[client] Stop flow grpc receiver properly (#3596)
This commit is contained in:
256
flow/client/client_test.go
Normal file
256
flow/client/client_test.go
Normal file
@ -0,0 +1,256 @@
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
flow "github.com/netbirdio/netbird/flow/client"
|
||||
"github.com/netbirdio/netbird/flow/proto"
|
||||
)
|
||||
|
||||
type testServer struct {
|
||||
proto.UnimplementedFlowServiceServer
|
||||
events chan *proto.FlowEvent
|
||||
acks chan *proto.FlowEventAck
|
||||
grpcSrv *grpc.Server
|
||||
addr string
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T) *testServer {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
s := &testServer{
|
||||
events: make(chan *proto.FlowEvent, 100),
|
||||
acks: make(chan *proto.FlowEventAck, 100),
|
||||
grpcSrv: grpc.NewServer(),
|
||||
addr: listener.Addr().String(),
|
||||
}
|
||||
|
||||
proto.RegisterFlowServiceServer(s.grpcSrv, s)
|
||||
|
||||
go func() {
|
||||
if err := s.grpcSrv.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
|
||||
t.Logf("server error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
t.Cleanup(func() {
|
||||
s.grpcSrv.Stop()
|
||||
})
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
|
||||
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(stream.Context())
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
defer cancel()
|
||||
for {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !event.IsInitiator {
|
||||
select {
|
||||
case s.events <- event:
|
||||
ack := &proto.FlowEventAck{
|
||||
EventId: event.EventId,
|
||||
}
|
||||
select {
|
||||
case s.acks <- ack:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case ack := <-s.acks:
|
||||
if err := stream.Send(ack); err != nil {
|
||||
return err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestReceive(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
})
|
||||
|
||||
receivedAcks := make(map[string]bool)
|
||||
receiveDone := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
||||
if !msg.IsInitiator && len(msg.EventId) > 0 {
|
||||
id := string(msg.EventId)
|
||||
receivedAcks[id] = true
|
||||
|
||||
if len(receivedAcks) >= 3 {
|
||||
close(receiveDone)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
t.Logf("receive error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
eventID := uuid.New().String()
|
||||
|
||||
// Create acknowledgment and send it to the flow through our test server
|
||||
ack := &proto.FlowEventAck{
|
||||
EventId: []byte(eventID),
|
||||
}
|
||||
|
||||
select {
|
||||
case server.acks <- ack:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timeout sending ack")
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-receiveDone:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for acks to be processed")
|
||||
}
|
||||
|
||||
assert.Equal(t, 3, len(receivedAcks))
|
||||
}
|
||||
|
||||
func TestReceive_ContextCancellation(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
})
|
||||
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
handlerCalled := false
|
||||
msgHandler := func(msg *proto.FlowEventAck) error {
|
||||
if !msg.IsInitiator {
|
||||
handlerCalled = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err = client.Receive(ctx, 1*time.Second, msgHandler)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "context canceled")
|
||||
assert.False(t, handlerCalled)
|
||||
}
|
||||
|
||||
func TestSend(t *testing.T) {
|
||||
server := newTestServer(t)
|
||||
|
||||
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "failed to close flow")
|
||||
})
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
ackReceived := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
err := client.Receive(ctx, 1*time.Second, func(ack *proto.FlowEventAck) error {
|
||||
if len(ack.EventId) > 0 && !ack.IsInitiator {
|
||||
close(ackReceived)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
t.Logf("receive error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
testEvent := &proto.FlowEvent{
|
||||
EventId: []byte("test-event-id"),
|
||||
PublicKey: []byte("test-public-key"),
|
||||
FlowFields: &proto.FlowFields{
|
||||
FlowId: []byte("test-flow-id"),
|
||||
Protocol: 6,
|
||||
SourceIp: []byte{192, 168, 1, 1},
|
||||
DestIp: []byte{192, 168, 1, 2},
|
||||
ConnectionInfo: &proto.FlowFields_PortInfo{
|
||||
PortInfo: &proto.PortInfo{
|
||||
SourcePort: 12345,
|
||||
DestPort: 443,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = client.Send(testEvent)
|
||||
require.NoError(t, err)
|
||||
|
||||
var receivedEvent *proto.FlowEvent
|
||||
select {
|
||||
case receivedEvent = <-server.events:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for event to be received by server")
|
||||
}
|
||||
|
||||
assert.Equal(t, testEvent.EventId, receivedEvent.EventId)
|
||||
assert.Equal(t, testEvent.PublicKey, receivedEvent.PublicKey)
|
||||
|
||||
select {
|
||||
case <-ackReceived:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for ack to be received by flow")
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user