Files
netbird/flow/client/client_test.go
2025-03-28 16:08:31 +01:00

257 lines
5.5 KiB
Go

package client_test
import (
"context"
"errors"
"net"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
flow "github.com/netbirdio/netbird/flow/client"
"github.com/netbirdio/netbird/flow/proto"
)
// testServer is an in-process FlowService implementation used by the tests in
// this file. It embeds proto.UnimplementedFlowServiceServer and provides its
// own Events method.
type testServer struct {
proto.UnimplementedFlowServiceServer
// events receives every non-initiator FlowEvent read from a client stream.
events chan *proto.FlowEvent
// acks holds acknowledgments waiting to be written to the client stream; it
// is fed by Events and, in some tests, directly by the test body.
acks chan *proto.FlowEventAck
// grpcSrv is the gRPC server this service is registered on.
grpcSrv *grpc.Server
// addr is the address of the listener the server is serving on.
addr string
}
// newTestServer starts a gRPC flow server on an ephemeral loopback port and
// returns it. The server is stopped automatically via t.Cleanup.
func newTestServer(t *testing.T) *testServer {
	// Attribute setup failures to the calling test instead of this helper.
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	s := &testServer{
		events:  make(chan *proto.FlowEvent, 100),
		acks:    make(chan *proto.FlowEventAck, 100),
		grpcSrv: grpc.NewServer(),
		addr:    listener.Addr().String(),
	}
	proto.RegisterFlowServiceServer(s.grpcSrv, s)

	go func() {
		// ErrServerStopped is an expected shutdown outcome, not worth logging.
		if err := s.grpcSrv.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
			t.Logf("server error: %v", err)
		}
	}()

	t.Cleanup(func() {
		s.grpcSrv.Stop()
	})

	return s
}
// Events serves one bidirectional flow stream.
//
// It first sends an initiator ack, then runs two loops: a background reader
// that forwards every non-initiator event to s.events and queues an ack
// carrying the same EventId on s.acks, and a foreground writer that sends
// everything taken from s.acks to the client. It returns the send error if a
// write fails, or the context error once the stream context is done or the
// reader stops.
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
	if err := stream.Send(&proto.FlowEventAck{IsInitiator: true}); err != nil {
		return err
	}

	ctx, cancel := context.WithCancel(stream.Context())
	defer cancel()

	// Reader: cancelling on exit unblocks the writer loop below.
	go func() {
		defer cancel()
		for {
			event, recvErr := stream.Recv()
			if recvErr != nil {
				return
			}
			if event.IsInitiator {
				continue
			}

			// Publish the event for the test to inspect.
			select {
			case <-ctx.Done():
				return
			case s.events <- event:
			}

			// Queue the matching acknowledgment for the writer loop.
			select {
			case <-ctx.Done():
				return
			case s.acks <- &proto.FlowEventAck{EventId: event.EventId}:
			}
		}
	}()

	// Writer: drain queued acks to the client until the context ends.
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case pending := <-s.acks:
			if err := stream.Send(pending); err != nil {
				return err
			}
		}
	}
}
// TestReceive verifies that acknowledgments queued on the test server are
// delivered to the handler passed to client.Receive.
func TestReceive(t *testing.T) {
	server := newTestServer(t)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	t.Cleanup(cancel)

	client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
	require.NoError(t, err)
	t.Cleanup(func() {
		err := client.Close()
		assert.NoError(t, err, "failed to close flow")
	})

	receivedAcks := make(map[string]bool)
	receiveDone := make(chan struct{})
	receiverExited := make(chan struct{})

	go func() {
		defer close(receiverExited)
		err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
			if msg.IsInitiator || len(msg.EventId) == 0 {
				return nil
			}
			// Once completion has been signalled the map is read by the test
			// goroutine, so stop touching it and never close the channel twice.
			select {
			case <-receiveDone:
				return nil
			default:
			}
			receivedAcks[string(msg.EventId)] = true
			if len(receivedAcks) >= 3 {
				close(receiveDone)
			}
			return nil
		})
		if err != nil && !errors.Is(err, context.Canceled) {
			t.Logf("receive error: %v", err)
		}
	}()

	// Stop the receiver and wait for it before the test returns: logging
	// through t after the test has completed makes the testing package panic.
	defer func() {
		cancel()
		select {
		case <-receiverExited:
		case <-time.After(2 * time.Second):
			t.Error("Receive did not return after context cancellation")
		}
	}()

	// Give the Receive goroutine time to start (presumably so the stream is
	// established before acks are queued).
	time.Sleep(500 * time.Millisecond)

	for i := 0; i < 3; i++ {
		eventID := uuid.New().String()

		// Create acknowledgment and send it to the flow through our test server
		ack := &proto.FlowEventAck{
			EventId: []byte(eventID),
		}

		select {
		case server.acks <- ack:
		case <-time.After(time.Second):
			t.Fatal("timeout sending ack")
		}
	}

	select {
	case <-receiveDone:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for acks to be processed")
	}

	assert.Len(t, receivedAcks, 3)
}
// TestReceive_ContextCancellation verifies that client.Receive returns a
// "context canceled" error once its context is cancelled, and that the handler
// is never invoked for a non-initiator message in the meantime.
func TestReceive_ContextCancellation(t *testing.T) {
	server := newTestServer(t)

	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)

	client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
	require.NoError(t, err)
	t.Cleanup(func() {
		err := client.Close()
		assert.NoError(t, err, "failed to close flow")
	})

	// Cancel shortly after Receive has started blocking.
	go func() {
		time.Sleep(100 * time.Millisecond)
		cancel()
	}()

	handlerCalled := false
	msgHandler := func(msg *proto.FlowEventAck) error {
		if !msg.IsInitiator {
			handlerCalled = true
		}
		return nil
	}

	err = client.Receive(ctx, 1*time.Second, msgHandler)
	// require, not assert: err.Error() below would panic on a nil error.
	require.Error(t, err)
	assert.Contains(t, err.Error(), "context canceled")
	assert.False(t, handlerCalled)
}
// TestSend verifies that an event passed to client.Send reaches the test
// server and that the server's acknowledgment is delivered back to the
// handler passed to client.Receive.
func TestSend(t *testing.T) {
	server := newTestServer(t)

	client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
	require.NoError(t, err)
	t.Cleanup(func() {
		err := client.Close()
		assert.NoError(t, err, "failed to close flow")
	})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	t.Cleanup(cancel)

	ackReceived := make(chan struct{})
	receiverExited := make(chan struct{})

	go func() {
		defer close(receiverExited)
		err := client.Receive(ctx, 1*time.Second, func(ack *proto.FlowEventAck) error {
			if len(ack.EventId) == 0 || ack.IsInitiator {
				return nil
			}
			// Signal only once; a second matching ack must not re-close the
			// channel and panic.
			select {
			case <-ackReceived:
			default:
				close(ackReceived)
			}
			return nil
		})
		if err != nil && !errors.Is(err, context.Canceled) {
			t.Logf("receive error: %v", err)
		}
	}()

	// Stop the receiver and wait for it before the test returns: logging
	// through t after the test has completed makes the testing package panic.
	defer func() {
		cancel()
		select {
		case <-receiverExited:
		case <-time.After(2 * time.Second):
			t.Error("Receive did not return after context cancellation")
		}
	}()

	// Give the Receive goroutine time to start (presumably so the stream is
	// established before Send is called).
	time.Sleep(500 * time.Millisecond)

	testEvent := &proto.FlowEvent{
		EventId:   []byte("test-event-id"),
		PublicKey: []byte("test-public-key"),
		FlowFields: &proto.FlowFields{
			FlowId:   []byte("test-flow-id"),
			Protocol: 6,
			SourceIp: []byte{192, 168, 1, 1},
			DestIp:   []byte{192, 168, 1, 2},
			ConnectionInfo: &proto.FlowFields_PortInfo{
				PortInfo: &proto.PortInfo{
					SourcePort: 12345,
					DestPort:   443,
				},
			},
		},
	}

	err = client.Send(testEvent)
	require.NoError(t, err)

	var receivedEvent *proto.FlowEvent
	select {
	case receivedEvent = <-server.events:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for event to be received by server")
	}

	assert.Equal(t, testEvent.EventId, receivedEvent.EventId)
	assert.Equal(t, testEvent.PublicKey, receivedEvent.PublicKey)

	select {
	case <-ackReceived:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for ack to be received by flow")
	}
}