diff --git a/backend/smb/connpool.go b/backend/smb/connpool.go index 9067b1f05..e040881d9 100644 --- a/backend/smb/connpool.go +++ b/backend/smb/connpool.go @@ -38,7 +38,7 @@ func (f *Fs) dial(ctx context.Context, network, addr string) (*conn, error) { d := &smb2.Dialer{} if f.opt.UseKerberos { - cl, err := createKerberosClient(f.opt.KerberosCCache) + cl, err := NewKerberosFactory().GetClient(f.opt.KerberosCCache) if err != nil { return nil, err } diff --git a/backend/smb/kerberos.go b/backend/smb/kerberos.go index 58414cda1..fdbb19ddd 100644 --- a/backend/smb/kerberos.go +++ b/backend/smb/kerberos.go @@ -7,17 +7,95 @@ import ( "path/filepath" "strings" "sync" + "time" "github.com/jcmturner/gokrb5/v8/client" "github.com/jcmturner/gokrb5/v8/config" "github.com/jcmturner/gokrb5/v8/credentials" ) -var ( - kerberosClient sync.Map // map[string]*client.Client - kerberosErr sync.Map // map[string]error -) +// KerberosFactory encapsulates dependencies and caches for Kerberos clients. +type KerberosFactory struct { + // clientCache caches Kerberos clients keyed by resolved ccache path. + // Clients are reused unless the associated ccache file changes. + clientCache sync.Map // map[string]*client.Client + // errCache caches errors encountered when loading Kerberos clients. + // Prevents repeated attempts for paths that previously failed. + errCache sync.Map // map[string]error + + // modTimeCache tracks the last known modification time of ccache files. + // Used to detect changes and trigger credential refresh. + modTimeCache sync.Map // map[string]time.Time + + loadCCache func(string) (*credentials.CCache, error) + newClient func(*credentials.CCache, *config.Config, ...func(*client.Settings)) (*client.Client, error) + loadConfig func() (*config.Config, error) +} + +// NewKerberosFactory creates a new instance of KerberosFactory with default dependencies. +func NewKerberosFactory() *KerberosFactory { + return &KerberosFactory{ + loadCCache: credentials.LoadCCache, + newClient: client.NewFromCCache, + loadConfig: defaultLoadKerberosConfig, + } +} + +// GetClient returns a cached Kerberos client or creates a new one if needed. +func (kf *KerberosFactory) GetClient(ccachePath string) (*client.Client, error) { + resolvedPath, err := resolveCcachePath(ccachePath) + if err != nil { + return nil, err + } + + stat, err := os.Stat(resolvedPath) + if err != nil { + kf.errCache.Store(resolvedPath, err) + return nil, err + } + mtime := stat.ModTime() + + if oldMod, ok := kf.modTimeCache.Load(resolvedPath); ok { + if oldTime, ok := oldMod.(time.Time); ok && oldTime.Equal(mtime) { + if errVal, ok := kf.errCache.Load(resolvedPath); ok { + return nil, errVal.(error) + } + if clientVal, ok := kf.clientCache.Load(resolvedPath); ok { + return clientVal.(*client.Client), nil + } + } + } + + // Load Kerberos config + cfg, err := kf.loadConfig() + if err != nil { + kf.errCache.Store(resolvedPath, err) + return nil, err + } + + // Load ccache + ccache, err := kf.loadCCache(resolvedPath) + if err != nil { + kf.errCache.Store(resolvedPath, err) + return nil, err + } + + // Create new client + cl, err := kf.newClient(ccache, cfg) + if err != nil { + kf.errCache.Store(resolvedPath, err) + return nil, err + } + + // Cache and return + kf.clientCache.Store(resolvedPath, cl) + kf.errCache.Delete(resolvedPath) + kf.modTimeCache.Store(resolvedPath, mtime) + return cl, nil +} + +// resolveCcachePath resolves the KRB5 ccache path. func resolveCcachePath(ccachePath string) (string, error) { if ccachePath == "" { ccachePath = os.Getenv("KRB5CCNAME") @@ -50,45 +128,11 @@ func resolveCcachePath(ccachePath string) (string, error) { } } -func loadKerberosConfig() (*config.Config, error) { +// defaultLoadKerberosConfig loads Kerberos config from default or env path. +func defaultLoadKerberosConfig() (*config.Config, error) { cfgPath := os.Getenv("KRB5_CONFIG") if cfgPath == "" { cfgPath = "/etc/krb5.conf" } return config.Load(cfgPath) } - -// createKerberosClient creates a new Kerberos client. -func createKerberosClient(ccachePath string) (*client.Client, error) { - ccachePath, err := resolveCcachePath(ccachePath) - if err != nil { - return nil, err - } - - // check if we already have a client or an error for this ccache path - if errVal, ok := kerberosErr.Load(ccachePath); ok { - return nil, errVal.(error) - } - if clientVal, ok := kerberosClient.Load(ccachePath); ok { - return clientVal.(*client.Client), nil - } - - // create a new client if not found in the map - cfg, err := loadKerberosConfig() - if err != nil { - kerberosErr.Store(ccachePath, err) - return nil, err - } - ccache, err := credentials.LoadCCache(ccachePath) - if err != nil { - kerberosErr.Store(ccachePath, err) - return nil, err - } - cl, err := client.NewFromCCache(ccache, cfg) - if err != nil { - kerberosErr.Store(ccachePath, err) - return nil, err - } - kerberosClient.Store(ccachePath, cl) - return cl, nil -} diff --git a/backend/smb/kerberos_test.go b/backend/smb/kerberos_test.go index 4bfe6d240..dc78b3bbf 100644 --- a/backend/smb/kerberos_test.go +++ b/backend/smb/kerberos_test.go @@ -4,7 +4,11 @@ import ( "os" "path/filepath" "testing" + "time" + "github.com/jcmturner/gokrb5/v8/client" + "github.com/jcmturner/gokrb5/v8/config" + "github.com/jcmturner/gokrb5/v8/credentials" "github.com/stretchr/testify/assert" ) @@ -77,3 +81,62 @@ func TestResolveCcachePath(t *testing.T) { }) } } + +func TestKerberosFactory_GetClient_ReloadOnCcacheChange(t *testing.T) { + // Create temp ccache file + tmpFile, err := os.CreateTemp("", "krb5cc_test") + assert.NoError(t, err) + defer func() { + if err := os.Remove(tmpFile.Name()); err != nil { + t.Logf("Failed to remove temp file %s: %v", tmpFile.Name(), err) + } + }() + + unixPath := filepath.ToSlash(tmpFile.Name()) + ccachePath := "FILE:" + unixPath + + initialContent := []byte("CCACHE_VERSION 4\n") + _, err = tmpFile.Write(initialContent) + assert.NoError(t, err) + assert.NoError(t, tmpFile.Close()) + + // Setup mocks + loadCallCount := 0 + mockLoadCCache := func(path string) (*credentials.CCache, error) { + loadCallCount++ + return &credentials.CCache{}, nil + } + + mockNewClient := func(cc *credentials.CCache, cfg *config.Config, opts ...func(*client.Settings)) (*client.Client, error) { + return &client.Client{}, nil + } + + mockLoadConfig := func() (*config.Config, error) { + return &config.Config{}, nil + } + factory := &KerberosFactory{ + loadCCache: mockLoadCCache, + newClient: mockNewClient, + loadConfig: mockLoadConfig, + } + + // First call — triggers loading + _, err = factory.GetClient(ccachePath) + assert.NoError(t, err) + assert.Equal(t, 1, loadCallCount, "expected 1 load call") + + // Second call — should reuse cache, no additional load + _, err = factory.GetClient(ccachePath) + assert.NoError(t, err) + assert.Equal(t, 1, loadCallCount, "expected cached reuse, no new load") + + // Simulate file update + time.Sleep(1 * time.Second) // ensure mtime changes + err = os.WriteFile(tmpFile.Name(), []byte("CCACHE_VERSION 4\n#updated"), 0600) + assert.NoError(t, err) + + // Third call — should detect change, reload + _, err = factory.GetClient(ccachePath) + assert.NoError(t, err) + assert.Equal(t, 2, loadCallCount, "expected reload on changed ccache") +}