mirror of
https://github.com/rclone/rclone.git
synced 2025-03-04 02:11:32 +01:00
* Lower pacer minSleep to establish new connections faster * Use Echo requests to check whether connections are working (required an upgrade of go-smb2) * Only remount shares when needed * Use context for connection establishment * When returning a connection to the pool, only check the ones that encountered errors * Close connections in parallel
254 lines
5.4 KiB
Go
254 lines
5.4 KiB
Go
package smb
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"time"
|
|
|
|
smb2 "github.com/cloudsoda/go-smb2"
|
|
"github.com/rclone/rclone/fs"
|
|
"github.com/rclone/rclone/fs/accounting"
|
|
"github.com/rclone/rclone/fs/config/obscure"
|
|
"github.com/rclone/rclone/fs/fshttp"
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
// dial starts a client connection to the given SMB server. It is a
|
|
// convenience function that connects to the given network address,
|
|
// initiates the SMB handshake, and then sets up a Client.
|
|
//
|
|
// The context is only used for establishing the connection, not after.
|
|
func (f *Fs) dial(ctx context.Context, network, addr string) (*conn, error) {
|
|
dialer := fshttp.NewDialer(ctx)
|
|
tconn, err := dialer.DialContext(ctx, network, addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
pass := ""
|
|
if f.opt.Pass != "" {
|
|
pass, err = obscure.Reveal(f.opt.Pass)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
d := &smb2.Dialer{}
|
|
if f.opt.UseKerberos {
|
|
cl, err := getKerberosClient()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
spn := f.opt.SPN
|
|
if spn == "" {
|
|
spn = "cifs/" + f.opt.Host
|
|
}
|
|
|
|
d.Initiator = &smb2.Krb5Initiator{
|
|
Client: cl,
|
|
TargetSPN: spn,
|
|
}
|
|
} else {
|
|
d.Initiator = &smb2.NTLMInitiator{
|
|
User: f.opt.User,
|
|
Password: pass,
|
|
Domain: f.opt.Domain,
|
|
TargetSPN: f.opt.SPN,
|
|
}
|
|
}
|
|
|
|
session, err := d.DialConn(ctx, tconn, addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &conn{
|
|
smbSession: session,
|
|
conn: &tconn,
|
|
}, nil
|
|
}
|
|
|
|
// conn encapsulates an SMB client session together with the network
// connection it runs over and the share, if any, currently mounted on it.
type conn struct {
	conn       *net.Conn     // network connection the session was dialled over
	smbSession *smb2.Session // SMB session established over conn
	smbShare   *smb2.Share   // currently mounted share, or nil if none
	shareName  string        // name of the mounted share, "" if none
}
|
|
|
|
// Closes the connection
|
|
func (c *conn) close() (err error) {
|
|
if c.smbShare != nil {
|
|
err = c.smbShare.Umount()
|
|
}
|
|
sessionLogoffErr := c.smbSession.Logoff()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return sessionLogoffErr
|
|
}
|
|
|
|
// True if it's closed
|
|
func (c *conn) closed() bool {
|
|
return c.smbSession.Echo() != nil
|
|
}
|
|
|
|
// Show that we are using a SMB session
|
|
//
|
|
// Call removeSession() when done
|
|
func (f *Fs) addSession() {
|
|
f.sessions.Add(1)
|
|
}
|
|
|
|
// Show the SMB session is no longer in use
|
|
func (f *Fs) removeSession() {
|
|
f.sessions.Add(-1)
|
|
}
|
|
|
|
// getSessions shows whether there are any sessions in use
|
|
func (f *Fs) getSessions() int32 {
|
|
return f.sessions.Load()
|
|
}
|
|
|
|
// Open a new connection to the SMB server.
|
|
//
|
|
// The context is only used for establishing the connection, not after.
|
|
func (f *Fs) newConnection(ctx context.Context, share string) (c *conn, err error) {
|
|
c, err = f.dial(ctx, "tcp", f.opt.Host+":"+f.opt.Port)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("couldn't connect SMB: %w", err)
|
|
}
|
|
if share != "" {
|
|
// mount the specified share as well if user requested
|
|
err = c.mountShare(share)
|
|
if err != nil {
|
|
_ = c.smbSession.Logoff()
|
|
return nil, fmt.Errorf("couldn't initialize SMB: %w", err)
|
|
}
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
// Ensure the specified share is mounted or the session is unmounted
|
|
func (c *conn) mountShare(share string) (err error) {
|
|
if c.shareName == share {
|
|
return nil
|
|
}
|
|
if c.smbShare != nil {
|
|
err = c.smbShare.Umount()
|
|
c.smbShare = nil
|
|
}
|
|
if err != nil {
|
|
return
|
|
}
|
|
if share != "" {
|
|
c.smbShare, err = c.smbSession.Mount(share)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
c.shareName = share
|
|
return nil
|
|
}
|
|
|
|
// Get a SMB connection from the pool, or open a new one
|
|
func (f *Fs) getConnection(ctx context.Context, share string) (c *conn, err error) {
|
|
accounting.LimitTPS(ctx)
|
|
f.poolMu.Lock()
|
|
for len(f.pool) > 0 {
|
|
c = f.pool[0]
|
|
f.pool = f.pool[1:]
|
|
err = c.mountShare(share)
|
|
if err == nil {
|
|
break
|
|
}
|
|
fs.Debugf(f, "Discarding unusable SMB connection: %v", err)
|
|
c = nil
|
|
}
|
|
f.poolMu.Unlock()
|
|
if c != nil {
|
|
return c, nil
|
|
}
|
|
err = f.pacer.Call(func() (bool, error) {
|
|
c, err = f.newConnection(ctx, share)
|
|
if err != nil {
|
|
return true, err
|
|
}
|
|
return false, nil
|
|
})
|
|
return c, err
|
|
}
|
|
|
|
// Return a SMB connection to the pool
|
|
//
|
|
// It nils the pointed to connection out so it can't be reused
|
|
//
|
|
// if err is not nil then it checks the connection is alive using an
|
|
// ECHO request
|
|
func (f *Fs) putConnection(pc **conn, err error) {
|
|
if pc == nil {
|
|
return
|
|
}
|
|
c := *pc
|
|
if c == nil {
|
|
return
|
|
}
|
|
*pc = nil
|
|
if err != nil {
|
|
// If not a regular SMB error then check the connection
|
|
if !(errors.Is(err, os.ErrNotExist) || errors.Is(err, os.ErrExist) || errors.Is(err, os.ErrPermission)) {
|
|
echoErr := c.smbSession.Echo()
|
|
if echoErr != nil {
|
|
fs.Debugf(f, "Connection failed, closing: %v", echoErr)
|
|
_ = c.close()
|
|
return
|
|
}
|
|
fs.Debugf(f, "Connection OK after error: %v", err)
|
|
}
|
|
}
|
|
|
|
f.poolMu.Lock()
|
|
f.pool = append(f.pool, c)
|
|
if f.opt.IdleTimeout > 0 {
|
|
f.drain.Reset(time.Duration(f.opt.IdleTimeout)) // nudge on the pool emptying timer
|
|
}
|
|
f.poolMu.Unlock()
|
|
}
|
|
|
|
// Drain the pool of any connections
|
|
func (f *Fs) drainPool(ctx context.Context) (err error) {
|
|
f.poolMu.Lock()
|
|
defer f.poolMu.Unlock()
|
|
if sessions := f.getSessions(); sessions != 0 {
|
|
fs.Debugf(f, "Not closing %d unused connections as %d sessions active", len(f.pool), sessions)
|
|
if f.opt.IdleTimeout > 0 {
|
|
f.drain.Reset(time.Duration(f.opt.IdleTimeout)) // nudge on the pool emptying timer
|
|
}
|
|
return nil
|
|
}
|
|
if f.opt.IdleTimeout > 0 {
|
|
f.drain.Stop()
|
|
}
|
|
if len(f.pool) != 0 {
|
|
fs.Debugf(f, "Closing %d unused connections", len(f.pool))
|
|
}
|
|
|
|
g, _ := errgroup.WithContext(ctx)
|
|
for i, c := range f.pool {
|
|
g.Go(func() (err error) {
|
|
if !c.closed() {
|
|
err = c.close()
|
|
}
|
|
f.pool[i] = nil
|
|
return err
|
|
})
|
|
}
|
|
err = g.Wait()
|
|
f.pool = nil
|
|
return err
|
|
}
|