[bugfix] httpclient not signing subsequent redirect requests (#2798)

* move http request signing to transport

* actually hook up the http roundtripper ...

* add code comments for the new gtscontext functions
This commit is contained in:
kim 2024-04-02 12:12:26 +01:00 committed by GitHub
parent 4bbdef02f1
commit d61d5c8a6a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 98 additions and 34 deletions

View File

@ -19,6 +19,7 @@
import ( import (
"context" "context"
"net/http"
"net/url" "net/url"
"github.com/superseriousbusiness/gotosocial/internal/gtsmodel" "github.com/superseriousbusiness/gotosocial/internal/gtsmodel"
@ -42,6 +43,7 @@
httpSigKey httpSigKey
httpSigPubKeyIDKey httpSigPubKeyIDKey
dryRunKey dryRunKey
httpClientSignFnKey
) )
// DryRun returns whether the "dryrun" context key has been set. This can be // DryRun returns whether the "dryrun" context key has been set. This can be
@ -127,6 +129,19 @@ func SetOtherIRIs(ctx context.Context, iris []*url.URL) context.Context {
return context.WithValue(ctx, otherIRIsKey, iris) return context.WithValue(ctx, otherIRIsKey, iris)
} }
// HTTPClientSignFunc returns an httpclient signing function for the current client
// request context. This can be used to resign a request as calling transport's user.
func HTTPClientSignFunc(ctx context.Context) func(*http.Request) error {
fn, _ := ctx.Value(httpClientSignFnKey).(func(*http.Request) error)
return fn
}
// SetHTTPClientSignFunc stores the given httpclient signing function and returns the wrapped
// context. See HTTPClientSignFunc() for further information on the signing function value.
func SetHTTPClientSignFunc(ctx context.Context, fn func(*http.Request) error) context.Context {
return context.WithValue(ctx, httpClientSignFnKey, fn)
}
// HTTPSignatureVerifier returns an http signature verifier for the current ActivityPub // HTTPSignatureVerifier returns an http signature verifier for the current ActivityPub
// request chain. This verifier can be called to authenticate the current request. // request chain. This verifier can be called to authenticate the current request.
func HTTPSignatureVerifier(ctx context.Context) httpsig.VerifierWithOptions { func HTTPSignatureVerifier(ctx context.Context) httpsig.VerifierWithOptions {

View File

@ -32,7 +32,6 @@
"time" "time"
"codeberg.org/gruf/go-bytesize" "codeberg.org/gruf/go-bytesize"
"codeberg.org/gruf/go-byteutil"
"codeberg.org/gruf/go-cache/v3" "codeberg.org/gruf/go-cache/v3"
errorsv2 "codeberg.org/gruf/go-errors/v2" errorsv2 "codeberg.org/gruf/go-errors/v2"
"codeberg.org/gruf/go-iotools" "codeberg.org/gruf/go-iotools"
@ -163,7 +162,7 @@ func New(cfg Config) *Client {
} }
// Set underlying HTTP client roundtripper. // Set underlying HTTP client roundtripper.
c.client.Transport = &http.Transport{ c.client.Transport = &signingtransport{http.Transport{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
ForceAttemptHTTP2: true, ForceAttemptHTTP2: true,
DialContext: d.DialContext, DialContext: d.DialContext,
@ -175,7 +174,7 @@ func New(cfg Config) *Client {
ReadBufferSize: cfg.ReadBufferSize, ReadBufferSize: cfg.ReadBufferSize,
WriteBufferSize: cfg.WriteBufferSize, WriteBufferSize: cfg.WriteBufferSize,
DisableCompression: cfg.DisableCompression, DisableCompression: cfg.DisableCompression,
} }}
// Initiate outgoing bad hosts lookup cache. // Initiate outgoing bad hosts lookup cache.
c.badHosts = cache.NewTTL[string, struct{}](0, 1000, 0) c.badHosts = cache.NewTTL[string, struct{}](0, 1000, 0)
@ -239,23 +238,6 @@ func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, e
for i := 0; i < maxRetries; i++ { for i := 0; i < maxRetries; i++ {
var backoff time.Duration var backoff time.Duration
// Reset signing header fields
now := time.Now().UTC()
r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
r.Header.Del("Signature")
r.Header.Del("Digest")
// Rewind body reader and content-length if set.
if rc, ok := r.Body.(*byteutil.ReadNopCloser); ok {
rc.Rewind() // set len AFTER rewind
r.ContentLength = int64(rc.Len())
}
// Sign the outgoing request.
if err := sign(r); err != nil {
return nil, err
}
l.Info("performing request") l.Info("performing request")
// Perform the request. // Perform the request.
@ -276,6 +258,9 @@ func (c *Client) DoSigned(r *http.Request, sign SignFunc) (rsp *http.Response, e
// Search for a provided "Retry-After" header value. // Search for a provided "Retry-After" header value.
if after := rsp.Header.Get("Retry-After"); after != "" { if after := rsp.Header.Get("Retry-After"); after != "" {
// Get current time.
now := time.Now()
if u, _ := strconv.ParseUint(after, 10, 32); u != 0 { if u, _ := strconv.ParseUint(after, 10, 32); u != 0 {
// An integer number of backoff seconds was provided. // An integer number of backoff seconds was provided.
backoff = time.Duration(u) * time.Second backoff = time.Duration(u) * time.Second

View File

@ -17,12 +17,45 @@
package httpclient package httpclient
import "net/http" import (
"net/http"
"time"
"codeberg.org/gruf/go-byteutil"
"github.com/superseriousbusiness/gotosocial/internal/gtscontext"
)
// SignFunc is a function signature that provides request signing. // SignFunc is a function signature that provides request signing.
type SignFunc func(r *http.Request) error type SignFunc func(r *http.Request) error
type SigningClient interface { // signingtransport wraps an http.Transport{}
Do(r *http.Request) (*http.Response, error) // (RoundTripper implementer) to check request
DoSigned(r *http.Request, sign SignFunc) (*http.Response, error) // context for a signing function and using for
// all subsequent trips through RoundTrip().
type signingtransport struct {
http.Transport // underlying transport
}
func (t *signingtransport) RoundTrip(r *http.Request) (*http.Response, error) {
if sign := gtscontext.HTTPClientSignFunc(r.Context()); sign != nil {
// Reset signing header fields
now := time.Now().UTC()
r.Header.Set("Date", now.Format("Mon, 02 Jan 2006 15:04:05")+" GMT")
r.Header.Del("Signature")
r.Header.Del("Digest")
// Rewind body reader and content-length if set.
if rc, ok := r.Body.(*byteutil.ReadNopCloser); ok {
rc.Rewind() // set len AFTER rewind
r.ContentLength = int64(rc.Len())
}
// Sign the outgoing request.
if err := sign(r); err != nil {
return nil, err
}
}
// Pass to underlying transport.
return t.Transport.RoundTrip(r)
} }

View File

@ -37,7 +37,6 @@
"github.com/superseriousbusiness/gotosocial/internal/config" "github.com/superseriousbusiness/gotosocial/internal/config"
"github.com/superseriousbusiness/gotosocial/internal/db" "github.com/superseriousbusiness/gotosocial/internal/db"
"github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb" "github.com/superseriousbusiness/gotosocial/internal/federation/federatingdb"
"github.com/superseriousbusiness/gotosocial/internal/httpclient"
"github.com/superseriousbusiness/gotosocial/internal/state" "github.com/superseriousbusiness/gotosocial/internal/state"
) )
@ -54,14 +53,14 @@ type controller struct {
state *state.State state *state.State
fedDB federatingdb.DB fedDB federatingdb.DB
clock pub.Clock clock pub.Clock
client httpclient.SigningClient client pub.HttpClient
trspCache cache.TTLCache[string, *transport] trspCache cache.TTLCache[string, *transport]
userAgent string userAgent string
senders int // no. concurrent batch delivery routines. senders int // no. concurrent batch delivery routines.
} }
// NewController returns an implementation of the Controller interface for creating new transports // NewController returns an implementation of the Controller interface for creating new transports
func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.Clock, client httpclient.SigningClient) Controller { func NewController(state *state.State, federatingDB federatingdb.DB, clock pub.Clock, client pub.HttpClient) Controller {
var ( var (
host = config.GetHost() host = config.GetHost()
proto = config.GetProtocol() proto = config.GetProtocol()

View File

@ -93,30 +93,61 @@ func (t *transport) GET(r *http.Request) (*http.Response, error) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
return nil, errors.New("must be GET request") return nil, errors.New("must be GET request")
} }
ctx := r.Context() // extract, set pubkey ID.
// Prepare HTTP GET signing func with opts.
sign := t.signGET(httpsig.SignatureOption{
ExcludeQueryStringFromPathPseudoHeader: false,
})
ctx := r.Context() // update with signing details.
ctx = gtscontext.SetOutgoingPublicKeyID(ctx, t.pubKeyID) ctx = gtscontext.SetOutgoingPublicKeyID(ctx, t.pubKeyID)
ctx = gtscontext.SetHTTPClientSignFunc(ctx, sign)
r = r.WithContext(ctx) // replace request ctx. r = r.WithContext(ctx) // replace request ctx.
// Set our predefined controller user-agent.
r.Header.Set("User-Agent", t.controller.userAgent) r.Header.Set("User-Agent", t.controller.userAgent)
resp, err := t.controller.client.DoSigned(r, t.signGET(httpsig.SignatureOption{ExcludeQueryStringFromPathPseudoHeader: false})) // Pass to underlying HTTP client.
resp, err := t.controller.client.Do(r)
if err != nil || resp.StatusCode != http.StatusUnauthorized { if err != nil || resp.StatusCode != http.StatusUnauthorized {
return resp, err return resp, err
} }
// try again without the path included in the HTTP signature for better compatibility // Ignore this response.
_ = resp.Body.Close() _ = resp.Body.Close()
return t.controller.client.DoSigned(r, t.signGET(httpsig.SignatureOption{ExcludeQueryStringFromPathPseudoHeader: true}))
// Try again without the path included in
// the HTTP signature for better compatibility.
sign = t.signGET(httpsig.SignatureOption{
ExcludeQueryStringFromPathPseudoHeader: true,
})
ctx = r.Context() // update with signing details.
ctx = gtscontext.SetHTTPClientSignFunc(ctx, sign)
r = r.WithContext(ctx) // replace request ctx.
// Pass to underlying HTTP client.
return t.controller.client.Do(r)
} }
func (t *transport) POST(r *http.Request, body []byte) (*http.Response, error) { func (t *transport) POST(r *http.Request, body []byte) (*http.Response, error) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
return nil, errors.New("must be POST request") return nil, errors.New("must be POST request")
} }
ctx := r.Context() // extract, set pubkey ID.
// Prepare POST signer.
sign := t.signPOST(body)
ctx := r.Context() // update with signing details.
ctx = gtscontext.SetOutgoingPublicKeyID(ctx, t.pubKeyID) ctx = gtscontext.SetOutgoingPublicKeyID(ctx, t.pubKeyID)
ctx = gtscontext.SetHTTPClientSignFunc(ctx, sign)
r = r.WithContext(ctx) // replace request ctx. r = r.WithContext(ctx) // replace request ctx.
// Set our predefined controller user-agent.
r.Header.Set("User-Agent", t.controller.userAgent) r.Header.Set("User-Agent", t.controller.userAgent)
return t.controller.client.DoSigned(r, t.signPOST(body))
// Pass to underlying HTTP client.
return t.controller.client.Do(r)
} }
// signGET will safely sign an HTTP GET request. // signGET will safely sign an HTTP GET request.

View File

@ -26,6 +26,7 @@
"strings" "strings"
"sync" "sync"
"github.com/superseriousbusiness/activity/pub"
"github.com/superseriousbusiness/activity/streams" "github.com/superseriousbusiness/activity/streams"
"github.com/superseriousbusiness/activity/streams/vocab" "github.com/superseriousbusiness/activity/streams/vocab"
apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model" apimodel "github.com/superseriousbusiness/gotosocial/internal/api/model"
@ -51,7 +52,7 @@
// Unlike the other test interfaces provided in this package, you'll probably want to call this function // Unlike the other test interfaces provided in this package, you'll probably want to call this function
// PER TEST rather than per suite, so that the do function can be set on a test by test (or even more granular) // PER TEST rather than per suite, so that the do function can be set on a test by test (or even more granular)
// basis. // basis.
func NewTestTransportController(state *state.State, client httpclient.SigningClient) transport.Controller { func NewTestTransportController(state *state.State, client pub.HttpClient) transport.Controller {
return transport.NewController(state, NewTestFederatingDB(state), &federation.Clock{}, client) return transport.NewController(state, NewTestFederatingDB(state), &federation.Clock{}, client)
} }