mirror of
https://github.com/netbirdio/netbird.git
synced 2025-07-20 07:57:15 +02:00
257 lines
5.5 KiB
Go
257 lines
5.5 KiB
Go
package client_test
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/grpc"
|
|
|
|
flow "github.com/netbirdio/netbird/flow/client"
|
|
"github.com/netbirdio/netbird/flow/proto"
|
|
)
|
|
|
|
type testServer struct {
|
|
proto.UnimplementedFlowServiceServer
|
|
events chan *proto.FlowEvent
|
|
acks chan *proto.FlowEventAck
|
|
grpcSrv *grpc.Server
|
|
addr string
|
|
}
|
|
|
|
func newTestServer(t *testing.T) *testServer {
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
s := &testServer{
|
|
events: make(chan *proto.FlowEvent, 100),
|
|
acks: make(chan *proto.FlowEventAck, 100),
|
|
grpcSrv: grpc.NewServer(),
|
|
addr: listener.Addr().String(),
|
|
}
|
|
|
|
proto.RegisterFlowServiceServer(s.grpcSrv, s)
|
|
|
|
go func() {
|
|
if err := s.grpcSrv.Serve(listener); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
|
|
t.Logf("server error: %v", err)
|
|
}
|
|
}()
|
|
|
|
t.Cleanup(func() {
|
|
s.grpcSrv.Stop()
|
|
})
|
|
|
|
return s
|
|
}
|
|
|
|
func (s *testServer) Events(stream proto.FlowService_EventsServer) error {
|
|
err := stream.Send(&proto.FlowEventAck{IsInitiator: true})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(stream.Context())
|
|
defer cancel()
|
|
|
|
go func() {
|
|
defer cancel()
|
|
for {
|
|
event, err := stream.Recv()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if !event.IsInitiator {
|
|
select {
|
|
case s.events <- event:
|
|
ack := &proto.FlowEventAck{
|
|
EventId: event.EventId,
|
|
}
|
|
select {
|
|
case s.acks <- ack:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case ack := <-s.acks:
|
|
if err := stream.Send(ack); err != nil {
|
|
return err
|
|
}
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestReceive(t *testing.T) {
|
|
server := newTestServer(t)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
t.Cleanup(cancel)
|
|
|
|
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
err := client.Close()
|
|
assert.NoError(t, err, "failed to close flow")
|
|
})
|
|
|
|
receivedAcks := make(map[string]bool)
|
|
receiveDone := make(chan struct{})
|
|
|
|
go func() {
|
|
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
|
|
if !msg.IsInitiator && len(msg.EventId) > 0 {
|
|
id := string(msg.EventId)
|
|
receivedAcks[id] = true
|
|
|
|
if len(receivedAcks) >= 3 {
|
|
close(receiveDone)
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil && !errors.Is(err, context.Canceled) {
|
|
t.Logf("receive error: %v", err)
|
|
}
|
|
}()
|
|
|
|
time.Sleep(500 * time.Millisecond)
|
|
|
|
for i := 0; i < 3; i++ {
|
|
eventID := uuid.New().String()
|
|
|
|
// Create acknowledgment and send it to the flow through our test server
|
|
ack := &proto.FlowEventAck{
|
|
EventId: []byte(eventID),
|
|
}
|
|
|
|
select {
|
|
case server.acks <- ack:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timeout sending ack")
|
|
}
|
|
}
|
|
|
|
select {
|
|
case <-receiveDone:
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("timeout waiting for acks to be processed")
|
|
}
|
|
|
|
assert.Equal(t, 3, len(receivedAcks))
|
|
}
|
|
|
|
func TestReceive_ContextCancellation(t *testing.T) {
|
|
server := newTestServer(t)
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
t.Cleanup(cancel)
|
|
|
|
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
err := client.Close()
|
|
assert.NoError(t, err, "failed to close flow")
|
|
})
|
|
|
|
go func() {
|
|
time.Sleep(100 * time.Millisecond)
|
|
cancel()
|
|
}()
|
|
|
|
handlerCalled := false
|
|
msgHandler := func(msg *proto.FlowEventAck) error {
|
|
if !msg.IsInitiator {
|
|
handlerCalled = true
|
|
}
|
|
return nil
|
|
}
|
|
|
|
err = client.Receive(ctx, 1*time.Second, msgHandler)
|
|
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "context canceled")
|
|
assert.False(t, handlerCalled)
|
|
}
|
|
|
|
func TestSend(t *testing.T) {
|
|
server := newTestServer(t)
|
|
|
|
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
err := client.Close()
|
|
assert.NoError(t, err, "failed to close flow")
|
|
})
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
t.Cleanup(cancel)
|
|
|
|
ackReceived := make(chan struct{})
|
|
|
|
go func() {
|
|
err := client.Receive(ctx, 1*time.Second, func(ack *proto.FlowEventAck) error {
|
|
if len(ack.EventId) > 0 && !ack.IsInitiator {
|
|
close(ackReceived)
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil && !errors.Is(err, context.Canceled) {
|
|
t.Logf("receive error: %v", err)
|
|
}
|
|
}()
|
|
|
|
time.Sleep(500 * time.Millisecond)
|
|
|
|
testEvent := &proto.FlowEvent{
|
|
EventId: []byte("test-event-id"),
|
|
PublicKey: []byte("test-public-key"),
|
|
FlowFields: &proto.FlowFields{
|
|
FlowId: []byte("test-flow-id"),
|
|
Protocol: 6,
|
|
SourceIp: []byte{192, 168, 1, 1},
|
|
DestIp: []byte{192, 168, 1, 2},
|
|
ConnectionInfo: &proto.FlowFields_PortInfo{
|
|
PortInfo: &proto.PortInfo{
|
|
SourcePort: 12345,
|
|
DestPort: 443,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
err = client.Send(testEvent)
|
|
require.NoError(t, err)
|
|
|
|
var receivedEvent *proto.FlowEvent
|
|
select {
|
|
case receivedEvent = <-server.events:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for event to be received by server")
|
|
}
|
|
|
|
assert.Equal(t, testEvent.EventId, receivedEvent.EventId)
|
|
assert.Equal(t, testEvent.PublicKey, receivedEvent.PublicKey)
|
|
|
|
select {
|
|
case <-ackReceived:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for ack to be received by flow")
|
|
}
|
|
}
|