From c365863ea9debdb4bdc1255d75fd426a7be4b1b5 Mon Sep 17 00:00:00 2001
From: tobi <31960611+tsmethurst@users.noreply.github.com>
Date: Mon, 18 Apr 2022 17:17:05 +0200
Subject: [PATCH] [bugfix] Use our own (Batch)Deliver implementation for
 federated messages (#466)

---
 internal/transport/deliver.go | 42 +++++++++++++++++++++++++++++++----
 1 file changed, 38 insertions(+), 4 deletions(-)

diff --git a/internal/transport/deliver.go b/internal/transport/deliver.go
index cbba8f080..fe17f7761 100644
--- a/internal/transport/deliver.go
+++ b/internal/transport/deliver.go
@@ -20,7 +20,10 @@
 
 import (
 	"context"
+	"fmt"
 	"net/url"
+	"strings"
+	"sync"
 
 	"github.com/sirupsen/logrus"
 	"github.com/spf13/viper"
@@ -28,16 +31,47 @@
 )
 
 func (t *transport) BatchDeliver(ctx context.Context, b []byte, recipients []*url.URL) error {
-	return t.sigTransport.BatchDeliver(ctx, b, recipients)
+	// concurrently deliver to recipients; for each delivery, buffer the error if it fails
+	wg := sync.WaitGroup{}
+	errCh := make(chan error, len(recipients))
+	for _, recipient := range recipients {
+		wg.Add(1)
+		go func(r *url.URL) {
+			defer wg.Done()
+			if err := t.Deliver(ctx, b, r); err != nil {
+				errCh <- err
+			}
+		}(recipient)
+	}
+
+	// wait until all deliveries have succeeded or failed
+	wg.Wait()
+
+	// receive any buffered errors
+	errs := make([]string, 0, len(recipients))
+outer:
+	for {
+		select {
+		case e := <-errCh:
+			errs = append(errs, e.Error())
+		default:
+			break outer
+		}
+	}
+
+	if len(errs) > 0 {
+		return fmt.Errorf("BatchDeliver: at least one failure: %s", strings.Join(errs, "; "))
+	}
+
+	return nil
 }
 
 func (t *transport) Deliver(ctx context.Context, b []byte, to *url.URL) error {
 	// if the 'to' host is our own, just skip this delivery since we by definition already have the message!
-	if to.Host == viper.GetString(config.Keys.Host) {
+	if to.Host == viper.GetString(config.Keys.Host) || to.Host == viper.GetString(config.Keys.AccountDomain) {
 		return nil
 	}
 
-	l := logrus.WithField("func", "Deliver")
-	l.Debugf("performing POST to %s", to.String())
+	logrus.Debugf("Deliver: posting as %s to %s", t.pubKeyID, to.String())
 	return t.sigTransport.Deliver(ctx, b, to)
 }