Add --bind flag for choosing the local addr on outgoing connections - fixes #1087

Supported by all remotes except FTP.
This commit is contained in:
Nick Craig-Wood 2017-07-23 16:10:23 +01:00
parent 8b30023f0d
commit 9f24639568
7 changed files with 61 additions and 20 deletions

View File

@ -249,6 +249,13 @@ If running rclone from a script you might want to use today's date as
the directory name passed to `--backup-dir` to store the old files, or
you might want to pass `--suffix` with today's date.
### --bind string ###
Local address to bind to for outgoing connections. This can be an
IPv4 address (1.2.3.4), an IPv6 address (1234::789A) or host name. If
the host name doesn't resolve or resoves to more than one IP address
it will give an error.
### --bwlimit=BANDWIDTH_SPEC ###
This option controls the bandwidth limit. Limits can be specified

View File

@ -126,4 +126,6 @@ with it: `--dump-headers`, `--dump-bodies`, `--dump-auth`
Note that `--timeout` isn't supported (but `--contimeout` is).
Note that `--bind` isn't supported.
FTP could support server side move but doesn't yet.

View File

@ -14,6 +14,7 @@ import (
"io"
"io/ioutil"
"log"
"net"
"os"
"os/user"
"path/filepath"
@ -98,6 +99,7 @@ var (
useListR = BoolP("fast-list", "", false, "Use recursive list if available. Uses more memory but fewer transactions.")
tpsLimit = Float64P("tpslimit", "", 0, "Limit HTTP transactions per second to this.")
tpsLimitBurst = IntP("tpslimit-burst", "", 1, "Max burst of transactions for --tpslimit.")
bindAddr = StringP("bind", "", "", "Local address to bind to for outgoing connections, IPv4, IPv4 or name.")
logLevel = LogLevelNotice
statsLogLevel = LogLevelInfo
bwLimit BwTimetable
@ -232,6 +234,7 @@ type ConfigInfo struct {
BufferSize SizeSuffix
TPSLimit float64
TPSLimitBurst int
BindAddr net.IP
}
// Return the path to the configuration file
@ -398,6 +401,17 @@ func LoadConfig() {
log.Fatalf(`Can only use --suffix with --backup-dir.`)
}
if *bindAddr != "" {
addrs, err := net.LookupIP(*bindAddr)
if err != nil {
log.Fatalf("--bind: Failed to parse %q as IP address: %v", *bindAddr, err)
}
if len(addrs) != 1 {
log.Fatalf("--bind: Expecting 1 IP address for %q but got %d", *bindAddr, len(addrs))
}
Config.BindAddr = addrs[0]
}
// Load configuration file.
var err error
configData, err = loadConfigFile()

View File

@ -272,3 +272,16 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
}
return resp, err
}
// NewDialer creates a net.Dialer structure with Timeout, Keepalive
// and LocalAddr set from rclone flags.
func (ci *ConfigInfo) NewDialer() *net.Dialer {
dialer := &net.Dialer{
Timeout: ci.ConnectTimeout,
KeepAlive: 30 * time.Second,
}
if ci.BindAddr != nil {
dialer.LocalAddr = &net.TCPAddr{IP: ci.BindAddr}
}
return dialer
}

View File

@ -12,23 +12,18 @@ import (
)
// dial with context and timeouts
func dialContextTimeout(ctx context.Context, network, address string, connectTimeout, timeout time.Duration) (net.Conn, error) {
dialer := net.Dialer{
Timeout: connectTimeout,
KeepAlive: 30 * time.Second,
}
func (ci *ConfigInfo) dialContextTimeout(ctx context.Context, network, address string) (net.Conn, error) {
dialer := ci.NewDialer()
c, err := dialer.DialContext(ctx, network, address)
if err != nil {
return c, err
}
return newTimeoutConn(c, timeout), nil
return newTimeoutConn(c, ci.Timeout), nil
}
// Initialise the http.Transport for go1.7+
func (ci *ConfigInfo) initTransport(t *http.Transport) {
t.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
return dialContextTimeout(ctx, network, address, ci.ConnectTimeout, ci.Timeout)
}
t.DialContext = ci.dialContextTimeout
t.IdleConnTimeout = 60 * time.Second
t.ExpectContinueTimeout = ci.ConnectTimeout
}

View File

@ -7,25 +7,19 @@ package fs
import (
"net"
"net/http"
"time"
)
// dial with timeouts
func dialTimeout(network, address string, connectTimeout, timeout time.Duration) (net.Conn, error) {
dialer := net.Dialer{
Timeout: connectTimeout,
KeepAlive: 30 * time.Second,
}
func (ci *ConfigInfo) dialTimeout(network, address string) (net.Conn, error) {
dialer := ci.NewDialer()
c, err := dialer.Dial(network, address)
if err != nil {
return c, err
}
return newTimeoutConn(c, timeout), nil
return newTimeoutConn(c, ci.Timeout), nil
}
// Initialise the http.Transport for pre go1.7
func (ci *ConfigInfo) initTransport(t *http.Transport) {
t.Dial = func(network, address string) (net.Conn, error) {
return dialTimeout(network, address, ci.ConnectTimeout, ci.Timeout)
}
t.Dial = dialTimeout
}

View File

@ -79,6 +79,22 @@ type ObjectReader struct {
sftpFile *sftp.File
}
// Dial starts a client connection to the given SSH server. It is a
// convenience function that connects to the given network address,
// initiates the SSH handshake, and then sets up a Client.
func Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
dialer := fs.Config.NewDialer()
conn, err := dialer.Dial(network, addr)
if err != nil {
return nil, err
}
c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
if err != nil {
return nil, err
}
return ssh.NewClient(c, chans, reqs), nil
}
// NewFs creates a new Fs object from the name and root. It connects to
// the host specified in the config file.
func NewFs(name, root string) (fs.Fs, error) {
@ -135,7 +151,7 @@ func NewFs(name, root string) (fs.Fs, error) {
config.Auth = append(config.Auth, ssh.Password(clearpass))
}
sshClient, err := ssh.Dial("tcp", host+":"+port, config)
sshClient, err := Dial("tcp", host+":"+port, config)
if err != nil {
return nil, errors.Wrap(err, "couldn't connect ssh")
}