diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..029879c --- /dev/null +++ b/auth_test.go @@ -0,0 +1,104 @@ +package main + +import ( + "net/smtp" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestLoginAuth(t *testing.T) { + auth := LoginAuth("testuser", "testpass") + assert.NotNil(t, auth) + + // Type assertion to ensure we get the right type + loginAuth, ok := auth.(*loginAuth) + assert.True(t, ok) + assert.Equal(t, "testuser", loginAuth.username) + assert.Equal(t, "testpass", loginAuth.password) +} + +func TestLoginAuthStart(t *testing.T) { + auth := &loginAuth{ + username: "testuser", + password: "testpass", + } + + method, resp, err := auth.Start(&smtp.ServerInfo{}) + + assert.NoError(t, err) + assert.Equal(t, "LOGIN", method) + assert.Empty(t, resp) +} + +func TestLoginAuthNext(t *testing.T) { + auth := &loginAuth{ + username: "testuser", + password: "testpass", + } + + tests := []struct { + name string + serverMsg string + more bool + expected string + expectErr bool + }{ + { + name: "username prompt - User Name", + serverMsg: "User Name", + more: true, + expected: "testuser", + expectErr: false, + }, + { + name: "username prompt - Username:", + serverMsg: "Username:", + more: true, + expected: "testuser", + expectErr: false, + }, + { + name: "password prompt - Password", + serverMsg: "Password", + more: true, + expected: "testpass", + expectErr: false, + }, + { + name: "password prompt - Password:", + serverMsg: "Password:", + more: true, + expected: "testpass", + expectErr: false, + }, + { + name: "unknown server response", + serverMsg: "Unknown Prompt", + more: true, + expected: "", + expectErr: true, + }, + { + name: "more is false", + serverMsg: "anything", + more: false, + expected: "", + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp, err := auth.Next([]byte(tt.serverMsg), tt.more) + + if tt.expectErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown server response") + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, string(resp)) + } + }) + } +} \ No newline at end of file diff --git a/client.go b/client.go index d9e8b5f..1d5afe7 100644 --- a/client.go +++ b/client.go @@ -8,6 +8,7 @@ import ( "net" "net/smtp" "net/textproto" + "strconv" "github.com/flashmob/go-guerrilla/mail" "github.com/pkg/errors" @@ -19,7 +20,7 @@ type closeable interface { // sendMail sends the contents of the envelope to a SMTP server. func sendMail(e *mail.Envelope, config *relayConfig) error { - server := fmt.Sprintf("%s:%d", config.Server, config.Port) + server := net.JoinHostPort(config.Server, strconv.Itoa(config.Port)) to := getTo(e) var msg bytes.Buffer @@ -36,7 +37,7 @@ func sendMail(e *mail.Envelope, config *relayConfig) error { if AllowedSendersFilter.Blocked(e.RemoteIP) { Logger.Info("Remote IP of " + e.RemoteIP + " not allowed to send email.") - return errors.Wrap(err, "Remote IP of "+e.RemoteIP+" not allowed to send email.") + return errors.New("Remote IP of " + e.RemoteIP + " not allowed to send email.") } tlsconfig := &tls.Config{ diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..158876d --- /dev/null +++ b/client_test.go @@ -0,0 +1,105 @@ +package main + +import ( + "net/textproto" + "testing" + + "github.com/flashmob/go-guerrilla/mail" + "github.com/stretchr/testify/assert" +) + +func TestGetTo(t *testing.T) { + tests := []struct { + name string + envelope *mail.Envelope + expected []string + }{ + { + name: "single recipient", + envelope: &mail.Envelope{ + RcptTo: []mail.Address{ + {User: "user1", Host: "example.com"}, + }, + }, + expected: []string{"user1@example.com"}, + }, + { + name: "multiple recipients", + envelope: &mail.Envelope{ + RcptTo: []mail.Address{ + {User: "user1", Host: "example.com"}, + {User: "user2", Host: "test.com"}, + {User: "admin", Host: "company.org"}, + }, + }, + expected: []string{ + "user1@example.com", + "user2@test.com", + "admin@company.org", + }, + }, + { + name: "no recipients", + envelope: &mail.Envelope{RcptTo: []mail.Address{}}, + expected: nil, + }, + { + name: "nil envelope recipients", + envelope: &mail.Envelope{RcptTo: nil}, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getTo(tt.envelope) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestIsQuitError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "SMTP 221 code (acceptable)", + err: &textproto.Error{Code: 221, Msg: "Bye"}, + expected: false, + }, + { + name: "SMTP 250 code (acceptable)", + err: &textproto.Error{Code: 250, Msg: "OK"}, + expected: false, + }, + { + name: "SMTP 550 error code", + err: &textproto.Error{Code: 550, Msg: "Mailbox not found"}, + expected: true, + }, + { + name: "SMTP 421 error code", + err: &textproto.Error{Code: 421, Msg: "Service not available"}, + expected: true, + }, + { + name: "non-textproto error", + err: assert.AnError, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isQuitError(tt.err) + assert.Equal(t, tt.expected, result) + }) + } +} \ No newline at end of file diff --git a/config_test.go b/config_test.go new file mode 100644 index 0000000..d6d4750 --- /dev/null +++ b/config_test.go @@ -0,0 +1,97 @@ +package main + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestConfigDefaults(t *testing.T) { + var cfg mailRelayConfig + configDefaults(&cfg) + + assert.Equal(t, DefaultSTMPPort, cfg.SMTPPort) + assert.Equal(t, false, cfg.SMTPStartTLS) + assert.Equal(t, false, cfg.SMTPLoginAuthType) + assert.Equal(t, int64(DefaultMaxEmailSize), cfg.MaxEmailSize) + assert.Equal(t, false, cfg.SkipCertVerify) + assert.Equal(t, DefaultLocalListenIP, cfg.LocalListenIP) + assert.Equal(t, DefaultLocalListenPort, cfg.LocalListenPort) + assert.Equal(t, []string{"*"}, cfg.AllowedHosts) + assert.Equal(t, "*", cfg.AllowedSenders) + assert.Equal(t, DefaultTimeoutSecs, cfg.TimeoutSecs) +} + +func TestLoadConfig(t *testing.T) { + tests := []struct { + name string + filename string + wantErr bool + validate func(t *testing.T, cfg *mailRelayConfig) + }{ + { + name: "valid config", + filename: "testdata/valid.json", + wantErr: false, + validate: func(t *testing.T, cfg *mailRelayConfig) { + assert.Equal(t, "smtp.test.com", cfg.SMTPServer) + assert.Equal(t, 587, cfg.SMTPPort) + assert.Equal(t, true, cfg.SMTPStartTLS) + assert.Equal(t, "testuser@test.com", cfg.SMTPUsername) + assert.Equal(t, "testpassword", cfg.SMTPPassword) + assert.Equal(t, "relay.test.com", cfg.SMTPHelo) + assert.Equal(t, "127.0.0.1", cfg.LocalListenIP) + assert.Equal(t, 2525, cfg.LocalListenPort) + assert.Equal(t, []string{"test.com", "example.com"}, cfg.AllowedHosts) + assert.Equal(t, 60, cfg.TimeoutSecs) + }, + }, + { + name: "minimal config with defaults", + filename: "testdata/minimal.json", + wantErr: false, + validate: func(t *testing.T, cfg *mailRelayConfig) { + assert.Equal(t, "smtp.minimal.com", cfg.SMTPServer) + assert.Equal(t, "user@minimal.com", cfg.SMTPUsername) + assert.Equal(t, "password", cfg.SMTPPassword) + // Check that defaults are applied + assert.Equal(t, DefaultSTMPPort, cfg.SMTPPort) + assert.Equal(t, DefaultLocalListenIP, cfg.LocalListenIP) + assert.Equal(t, DefaultLocalListenPort, cfg.LocalListenPort) + assert.Equal(t, []string{"*"}, cfg.AllowedHosts) + }, + }, + { + name: "invalid JSON", + filename: "testdata/invalid.json", + wantErr: true, + validate: nil, + }, + { + name: "nonexistent file", + filename: "testdata/nonexistent.json", + wantErr: true, + validate: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg, err := loadConfig(tt.filename) + + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, cfg) + return + } + + require.NoError(t, err) + require.NotNil(t, cfg) + + if tt.validate != nil { + tt.validate(t, cfg) + } + }) + } +} \ No newline at end of file diff --git a/go.mod b/go.mod index 0570e93..10ca08e 100644 --- a/go.mod +++ b/go.mod @@ -6,15 +6,17 @@ require ( github.com/flashmob/go-guerrilla v1.6.1 github.com/jpillora/ipfilter v1.2.2 github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.5.1 ) require ( github.com/asaskevich/EventBus v0.0.0-20180103000110-68a521d7cbbb // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-sql-driver/mysql v1.5.0 // indirect github.com/konsorten/go-windows-terminal-sequences v1.0.3 // indirect github.com/phuslu/iploc v1.0.20200807 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sirupsen/logrus v1.6.0 // indirect - github.com/stretchr/testify v1.5.1 // indirect github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce // indirect golang.org/x/sys v0.1.0 // indirect gopkg.in/yaml.v2 v2.3.0 // indirect diff --git a/go.sum b/go.sum index 492216e..464dfbb 100644 --- a/go.sum +++ b/go.sum @@ -29,6 +29,7 @@ github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce/go.mod h1:o8v6yHRoi golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= diff --git a/integration_test.go b/integration_test.go new file mode 100644 index 0000000..e42e21c --- /dev/null +++ b/integration_test.go @@ -0,0 +1,370 @@ +package main + +import ( + "bytes" + "testing" + "time" + + "github.com/flashmob/go-guerrilla/log" + "github.com/flashmob/go-guerrilla/mail" + "github.com/jpillora/ipfilter" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// setupTestLogger initializes the logger for testing +func setupTestLogger(t *testing.T) { + var err error + Logger, err = log.GetLogger("stdout", "info") + require.NoError(t, err) +} + +func TestSendMail_Success(t *testing.T) { + setupTestLogger(t) + + // Start mock SMTP server + server := NewMockSMTPServer(t) + require.NoError(t, server.Start()) + defer server.Stop() + + // Configure for testing + config := &relayConfig{ + Server: server.Address(), + Port: server.Port(), + STARTTLS: true, + LoginAuthType: false, + Username: "", + Password: "", + SkipVerify: true, + HeloHost: "", + } + + // Create test envelope + envelope := &mail.Envelope{ + MailFrom: mail.Address{User: "sender", Host: "test.com"}, + RcptTo: []mail.Address{ + {User: "recipient1", Host: "example.com"}, + {User: "recipient2", Host: "example.com"}, + }, + Data: *bytes.NewBufferString("Subject: Test\r\n\r\nThis is a test email."), + RemoteIP: "127.0.0.1", + } + + // Set up IP filter to allow this IP + AllowedSendersFilter = ipfilter.New(ipfilter.Options{ + AllowedIPs: []string{"127.0.0.1"}, + BlockByDefault: false, + }) + + // Send email + err := sendMail(envelope, config) + assert.NoError(t, err) + + // Verify the mock server received the email + conn := server.GetLastConnection() + require.NotNil(t, conn) + assert.Equal(t, "sender@test.com", conn.From) + assert.Equal(t, []string{"recipient1@example.com", "recipient2@example.com"}, conn.To) + assert.Contains(t, conn.Data, "Subject: Test") + assert.Contains(t, conn.Data, "This is a test email.") +} + +func TestSendMail_WithAuthentication(t *testing.T) { + setupTestLogger(t) + // Start mock SMTP server with auth requirement + server := NewMockSMTPServer(t) + server.RequireAuth = true + require.NoError(t, server.Start()) + defer server.Stop() + + // Configure with authentication + config := &relayConfig{ + Server: server.Address(), + Port: server.Port(), + STARTTLS: true, + LoginAuthType: false, + Username: "testuser", + Password: "testpass", + SkipVerify: true, + HeloHost: "relay.test.com", + } + + // Create test envelope + envelope := &mail.Envelope{ + MailFrom: mail.Address{User: "sender", Host: "test.com"}, + RcptTo: []mail.Address{ + {User: "recipient", Host: "example.com"}, + }, + Data: *bytes.NewBufferString("Subject: Auth Test\r\n\r\nAuthenticated email."), + RemoteIP: "127.0.0.1", + } + + // Allow IP + AllowedSendersFilter = ipfilter.New(ipfilter.Options{ + BlockByDefault: false, + }) + + // Send email + err := sendMail(envelope, config) + assert.NoError(t, err) + + // Verify authentication was used + conn := server.GetLastConnection() + require.NotNil(t, conn) + assert.NotEmpty(t, conn.AuthUser) + assert.NotEmpty(t, conn.AuthPass) +} + +func TestSendMail_WithLoginAuth(t *testing.T) { + setupTestLogger(t) + // Start mock SMTP server with LOGIN auth support + server := NewMockSMTPServer(t) + server.RequireAuth = true + server.SupportLoginAuth = true + require.NoError(t, server.Start()) + defer server.Stop() + + // Configure with LOGIN authentication + config := &relayConfig{ + Server: server.Address(), + Port: server.Port(), + STARTTLS: true, + LoginAuthType: true, + Username: "testuser", + Password: "testpass", + SkipVerify: true, + HeloHost: "", + } + + // Create test envelope + envelope := &mail.Envelope{ + MailFrom: mail.Address{User: "sender", Host: "test.com"}, + RcptTo: []mail.Address{ + {User: "recipient", Host: "example.com"}, + }, + Data: *bytes.NewBufferString("Subject: LOGIN Auth Test\r\n\r\nLOGIN authenticated email."), + RemoteIP: "127.0.0.1", + } + + // Allow IP + AllowedSendersFilter = ipfilter.New(ipfilter.Options{ + BlockByDefault: false, + }) + + // Send email + err := sendMail(envelope, config) + assert.NoError(t, err) + + // Verify LOGIN authentication was used + conn := server.GetLastConnection() + require.NotNil(t, conn) + assert.NotEmpty(t, conn.AuthUser) + assert.NotEmpty(t, conn.AuthPass) +} + +func TestSendMail_IPFiltering_Blocked(t *testing.T) { + setupTestLogger(t) + // Start mock SMTP server + server := NewMockSMTPServer(t) + require.NoError(t, server.Start()) + defer server.Stop() + + // Configure server + config := &relayConfig{ + Server: server.Address(), + Port: server.Port(), + STARTTLS: true, + LoginAuthType: false, + Username: "", + Password: "", + SkipVerify: true, + HeloHost: "", + } + + // Create test envelope with blocked IP + envelope := &mail.Envelope{ + MailFrom: mail.Address{User: "sender", Host: "test.com"}, + RcptTo: []mail.Address{ + {User: "recipient", Host: "example.com"}, + }, + Data: *bytes.NewBufferString("Subject: Test\r\n\r\nThis should be blocked."), + RemoteIP: "192.168.1.100", // This IP will be blocked + } + + // Set up IP filter to block this IP + AllowedSendersFilter = ipfilter.New(ipfilter.Options{ + AllowedIPs: []string{"127.0.0.1"}, + BlockByDefault: true, + }) + + // Send email - should fail due to IP filtering + err := sendMail(envelope, config) + assert.Error(t, err) + assert.Contains(t, err.Error(), "192.168.1.100") + assert.Contains(t, err.Error(), "not allowed to send email") + + // Verify no email was sent to the server + conn := server.GetLastConnection() + // The connection might be nil or have no data since the IP was blocked before SMTP + if conn != nil { + assert.Empty(t, conn.From) + } +} + +func TestSendMail_IPFiltering_Allowed(t *testing.T) { + setupTestLogger(t) + // Start mock SMTP server + server := NewMockSMTPServer(t) + require.NoError(t, server.Start()) + defer server.Stop() + + // Configure server + config := &relayConfig{ + Server: server.Address(), + Port: server.Port(), + STARTTLS: true, + LoginAuthType: false, + Username: "", + Password: "", + SkipVerify: true, + HeloHost: "", + } + + // Create test envelope with allowed IP + envelope := &mail.Envelope{ + MailFrom: mail.Address{User: "sender", Host: "test.com"}, + RcptTo: []mail.Address{ + {User: "recipient", Host: "example.com"}, + }, + Data: *bytes.NewBufferString("Subject: Test\r\n\r\nThis should be allowed."), + RemoteIP: "192.168.1.100", + } + + // Set up IP filter to allow this specific IP + AllowedSendersFilter = ipfilter.New(ipfilter.Options{ + AllowedIPs: []string{"192.168.1.0/24"}, + BlockByDefault: true, + }) + + // Send email - should succeed + err := sendMail(envelope, config) + assert.NoError(t, err) + + // Verify email was sent to the server + conn := server.GetLastConnection() + require.NotNil(t, conn) + assert.Equal(t, "sender@test.com", conn.From) + assert.Equal(t, []string{"recipient@example.com"}, conn.To) +} + +func TestSendMail_ServerErrors(t *testing.T) { + setupTestLogger(t) + tests := []struct { + name string + failCommand string + expectError string + }{ + { + name: "MAIL command fails", + failCommand: "MAIL", + expectError: "mail error", + }, + { + name: "RCPT command fails", + failCommand: "RCPT", + expectError: "rcpt error", + }, + { + name: "DATA command fails", + failCommand: "DATA", + expectError: "data error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Start mock SMTP server + server := NewMockSMTPServer(t) + server.FailCommands[tt.failCommand] = true + require.NoError(t, server.Start()) + defer server.Stop() + + // Configure server + config := &relayConfig{ + Server: server.Address(), + Port: server.Port(), + STARTTLS: true, + LoginAuthType: false, + Username: "", + Password: "", + SkipVerify: true, + HeloHost: "", + } + + // Create test envelope + envelope := &mail.Envelope{ + MailFrom: mail.Address{User: "sender", Host: "test.com"}, + RcptTo: []mail.Address{ + {User: "recipient", Host: "example.com"}, + }, + Data: *bytes.NewBufferString("Subject: Test\r\n\r\nThis should fail."), + RemoteIP: "127.0.0.1", + } + + // Allow IP + AllowedSendersFilter = ipfilter.New(ipfilter.Options{ + BlockByDefault: false, + }) + + // Send email - should fail + err := sendMail(envelope, config) + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectError) + }) + } +} + +func TestSendMail_ConnectionTimeout(t *testing.T) { + setupTestLogger(t) + // Start mock SMTP server with delay + server := NewMockSMTPServer(t) + server.ResponseDelay = 100 * time.Millisecond + require.NoError(t, server.Start()) + defer server.Stop() + + // Configure server + config := &relayConfig{ + Server: server.Address(), + Port: server.Port(), + STARTTLS: true, + LoginAuthType: false, + Username: "", + Password: "", + SkipVerify: true, + HeloHost: "", + } + + // Create test envelope + envelope := &mail.Envelope{ + MailFrom: mail.Address{User: "sender", Host: "test.com"}, + RcptTo: []mail.Address{ + {User: "recipient", Host: "example.com"}, + }, + Data: *bytes.NewBufferString("Subject: Timeout Test\r\n\r\nThis tests server delays."), + RemoteIP: "127.0.0.1", + } + + // Allow IP + AllowedSendersFilter = ipfilter.New(ipfilter.Options{ + BlockByDefault: false, + }) + + // Send email - should still succeed despite delay + err := sendMail(envelope, config) + assert.NoError(t, err) + + // Verify email was eventually sent + conn := server.GetLastConnection() + require.NotNil(t, conn) + assert.Equal(t, "sender@test.com", conn.From) +} \ No newline at end of file diff --git a/mock_smtp_server.go b/mock_smtp_server.go new file mode 100644 index 0000000..73efdd2 --- /dev/null +++ b/mock_smtp_server.go @@ -0,0 +1,448 @@ +package main + +import ( + "bufio" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + "net" + "strings" + "sync" + "testing" + "time" +) + +// MockSMTPServer represents a mock SMTP server for testing +type MockSMTPServer struct { + listener net.Listener + tlsConfig *tls.Config + address string + port int + running bool + mu sync.Mutex + + // Recorded interactions + Connections []MockConnection + + // Configuration + RequireAuth bool + RequireSTARTTLS bool + SupportLoginAuth bool + ResponseDelay time.Duration + FailCommands map[string]bool // Commands to fail + CustomResponses map[string]string + ImplicitTLS bool // True if server uses implicit TLS (like port 465) +} + +type MockConnection struct { + Commands []string + From string + To []string + Data string + AuthUser string + AuthPass string + UsedTLS bool +} + +// NewMockSMTPServer creates a new mock SMTP server +func NewMockSMTPServer(t *testing.T) *MockSMTPServer { + cert, err := generateTestCert() + if err != nil { + t.Fatalf("Failed to generate test certificate: %v", err) + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + InsecureSkipVerify: true, + } + + return &MockSMTPServer{ + tlsConfig: tlsConfig, + Connections: make([]MockConnection, 0), + FailCommands: make(map[string]bool), + CustomResponses: make(map[string]string), + } +} + +// Start starts the mock SMTP server +func (s *MockSMTPServer) Start() error { + s.mu.Lock() + defer s.mu.Unlock() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return err + } + + s.listener = listener + s.address = listener.Addr().(*net.TCPAddr).IP.String() + s.port = listener.Addr().(*net.TCPAddr).Port + s.running = true + + go s.acceptConnections() + + // Give the server a moment to start + time.Sleep(10 * time.Millisecond) + return nil +} + +// StartTLS starts the mock SMTP server with implicit TLS +func (s *MockSMTPServer) StartTLS() error { + s.mu.Lock() + defer s.mu.Unlock() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return err + } + + tlsListener := tls.NewListener(listener, s.tlsConfig) + s.listener = tlsListener + s.address = listener.Addr().(*net.TCPAddr).IP.String() + s.port = listener.Addr().(*net.TCPAddr).Port + s.running = true + s.ImplicitTLS = true + + go s.acceptConnections() + + // Give the server a moment to start + time.Sleep(10 * time.Millisecond) + return nil +} + +// Stop stops the mock SMTP server +func (s *MockSMTPServer) Stop() { + s.mu.Lock() + defer s.mu.Unlock() + + if s.listener != nil { + s.listener.Close() + s.running = false + } +} + +// Address returns the server address +func (s *MockSMTPServer) Address() string { + return s.address +} + +// Port returns the server port +func (s *MockSMTPServer) Port() int { + return s.port +} + +// GetLastConnection returns the most recent connection +func (s *MockSMTPServer) GetLastConnection() *MockConnection { + s.mu.Lock() + defer s.mu.Unlock() + + if len(s.Connections) == 0 { + return nil + } + return &s.Connections[len(s.Connections)-1] +} + +// Reset clears all recorded connections +func (s *MockSMTPServer) Reset() { + s.mu.Lock() + defer s.mu.Unlock() + + s.Connections = make([]MockConnection, 0) +} + +func (s *MockSMTPServer) acceptConnections() { + for s.running { + conn, err := s.listener.Accept() + if err != nil { + if s.running { + fmt.Printf("Accept error: %v\n", err) + } + continue + } + + go s.handleConnection(conn) + } +} + +func (s *MockSMTPServer) handleConnection(conn net.Conn) { + defer conn.Close() + + if s.ResponseDelay > 0 { + time.Sleep(s.ResponseDelay) + } + + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + + mockConn := MockConnection{ + Commands: make([]string, 0), + To: make([]string, 0), + } + + // Check if this is a TLS connection (implicit TLS or post-STARTTLS) + if _, ok := conn.(*tls.Conn); ok || s.ImplicitTLS { + mockConn.UsedTLS = true + } + + // Send greeting + writer.WriteString("220 mock.smtp.server ESMTP ready\r\n") + writer.Flush() + + for { + line, err := reader.ReadString('\n') + if err != nil { + break + } + + line = strings.TrimSpace(line) + mockConn.Commands = append(mockConn.Commands, line) + + parts := strings.Fields(line) + if len(parts) == 0 { + continue + } + + cmd := strings.ToUpper(parts[0]) + + // Check if we should fail this command + if s.FailCommands[cmd] { + writer.WriteString("550 Command failed\r\n") + writer.Flush() + continue + } + + // Check for custom responses + if response, exists := s.CustomResponses[cmd]; exists { + writer.WriteString(response + "\r\n") + writer.Flush() + continue + } + + switch cmd { + case "EHLO", "HELO": + s.handleEHLO(writer) + case "STARTTLS": + tlsConn, newReader, newWriter, upgraded := s.handleSTARTTLS(conn, reader, writer, &mockConn) + if upgraded { + // Connection was upgraded to TLS, switch to new connection + conn = tlsConn + reader = newReader + writer = newWriter + } + case "AUTH": + s.handleAUTH(parts, reader, writer, &mockConn) + case "MAIL": + s.handleMAIL(parts, writer, &mockConn) + case "RCPT": + s.handleRCPT(parts, writer, &mockConn) + case "DATA": + s.handleDATA(reader, writer, &mockConn) + case "QUIT": + writer.WriteString("221 Bye\r\n") + writer.Flush() + s.mu.Lock() + s.Connections = append(s.Connections, mockConn) + s.mu.Unlock() + return + default: + writer.WriteString("500 Command not recognized\r\n") + writer.Flush() + } + } + + s.mu.Lock() + s.Connections = append(s.Connections, mockConn) + s.mu.Unlock() +} + +func (s *MockSMTPServer) handleEHLO(writer *bufio.Writer) { + writer.WriteString("250-mock.smtp.server\r\n") + if s.RequireSTARTTLS { + writer.WriteString("250-STARTTLS\r\n") + } + if s.RequireAuth { + if s.SupportLoginAuth { + writer.WriteString("250-AUTH PLAIN LOGIN\r\n") + } else { + writer.WriteString("250-AUTH PLAIN\r\n") + } + } + writer.WriteString("250 SIZE 10240000\r\n") + writer.Flush() +} + +func (s *MockSMTPServer) handleSTARTTLS(conn net.Conn, reader *bufio.Reader, writer *bufio.Writer, mockConn *MockConnection) (*tls.Conn, *bufio.Reader, *bufio.Writer, bool) { + writer.WriteString("220 Ready to start TLS\r\n") + writer.Flush() + + // Upgrade the connection to TLS + tlsConn := tls.Server(conn, s.tlsConfig) + if err := tlsConn.Handshake(); err != nil { + // TLS handshake failed, return original connection + return nil, reader, writer, false + } + + mockConn.UsedTLS = true + + // Return new TLS connection and readers/writers + newReader := bufio.NewReader(tlsConn) + newWriter := bufio.NewWriter(tlsConn) + + return tlsConn, newReader, newWriter, true +} + + +func (s *MockSMTPServer) handleAUTH(parts []string, reader *bufio.Reader, writer *bufio.Writer, mockConn *MockConnection) { + if len(parts) < 2 { + writer.WriteString("501 Syntax error\r\n") + writer.Flush() + return + } + + authType := strings.ToUpper(parts[1]) + + switch authType { + case "PLAIN": + // PLAIN auth can be sent in initial command or as a response to challenge + if len(parts) > 2 { + // Credentials provided in initial command + // authData := parts[2] // In a real implementation, we'd decode base64 and parse username/password + mockConn.AuthUser = "testuser" + mockConn.AuthPass = "testpass" + + writer.WriteString("235 Authentication successful\r\n") + writer.Flush() + } else { + // Challenge/response mode + writer.WriteString("334 \r\n") + writer.Flush() + + authData, _ := reader.ReadString('\n') + authData = strings.TrimSpace(authData) + // In a real implementation, we'd decode base64 and parse username/password + mockConn.AuthUser = "testuser" + mockConn.AuthPass = "testpass" + + writer.WriteString("235 Authentication successful\r\n") + writer.Flush() + } + + case "LOGIN": + writer.WriteString("334 VXNlcm5hbWU6\r\n") // "Username:" in base64 + writer.Flush() + + username, _ := reader.ReadString('\n') + mockConn.AuthUser = strings.TrimSpace(username) + + writer.WriteString("334 UGFzc3dvcmQ6\r\n") // "Password:" in base64 + writer.Flush() + + password, _ := reader.ReadString('\n') + mockConn.AuthPass = strings.TrimSpace(password) + + writer.WriteString("235 Authentication successful\r\n") + writer.Flush() + + default: + writer.WriteString("504 Authentication mechanism not supported\r\n") + writer.Flush() + } +} + +func (s *MockSMTPServer) handleMAIL(parts []string, writer *bufio.Writer, mockConn *MockConnection) { + if len(parts) < 2 { + writer.WriteString("501 Syntax error\r\n") + writer.Flush() + return + } + + fromAddr := strings.TrimPrefix(parts[1], "FROM:") + fromAddr = strings.Trim(fromAddr, "<>") + mockConn.From = fromAddr + + writer.WriteString("250 OK\r\n") + writer.Flush() +} + +func (s *MockSMTPServer) handleRCPT(parts []string, writer *bufio.Writer, mockConn *MockConnection) { + if len(parts) < 2 { + writer.WriteString("501 Syntax error\r\n") + writer.Flush() + return + } + + toAddr := strings.TrimPrefix(parts[1], "TO:") + toAddr = strings.Trim(toAddr, "<>") + mockConn.To = append(mockConn.To, toAddr) + + writer.WriteString("250 OK\r\n") + writer.Flush() +} + +func (s *MockSMTPServer) handleDATA(reader *bufio.Reader, writer *bufio.Writer, mockConn *MockConnection) { + writer.WriteString("354 Start mail input; end with .\r\n") + writer.Flush() + + var dataBuilder strings.Builder + for { + line, err := reader.ReadString('\n') + if err != nil { + break + } + + if strings.TrimSpace(line) == "." { + break + } + + dataBuilder.WriteString(line) + } + + mockConn.Data = dataBuilder.String() + + writer.WriteString("250 OK: message accepted\r\n") + writer.Flush() +} + +// generateTestCert creates a self-signed certificate for testing +func generateTestCert() (tls.Certificate, error) { + // Generate a private key + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return tls.Certificate{}, err + } + + // Create certificate template + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test"}, + Country: []string{"US"}, + Province: []string{""}, + Locality: []string{"Test City"}, + StreetAddress: []string{""}, + PostalCode: []string{""}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, + DNSNames: []string{"localhost"}, + } + + // Generate the certificate + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return tls.Certificate{}, err + } + + // Create the tls.Certificate + return tls.Certificate{ + Certificate: [][]byte{certDER}, + PrivateKey: priv, + }, nil +} \ No newline at end of file diff --git a/testdata/allowed_ips.txt b/testdata/allowed_ips.txt new file mode 100644 index 0000000..ab2411f --- /dev/null +++ b/testdata/allowed_ips.txt @@ -0,0 +1,3 @@ +192.168.1.0/24 +10.0.0.0/8 +127.0.0.1 \ No newline at end of file diff --git a/testdata/invalid.json b/testdata/invalid.json new file mode 100644 index 0000000..0ac9d41 --- /dev/null +++ b/testdata/invalid.json @@ -0,0 +1,5 @@ +{ + "smtp_server": "smtp.test.com", + "smtp_port": "invalid_port", + "missing_quote: "value" +} \ No newline at end of file diff --git a/testdata/minimal.json b/testdata/minimal.json new file mode 100644 index 0000000..328e6dd --- /dev/null +++ b/testdata/minimal.json @@ -0,0 +1,5 @@ +{ + "smtp_server": "smtp.minimal.com", + "smtp_username": "user@minimal.com", + "smtp_password": "password" +} \ No newline at end of file diff --git a/testdata/valid.json b/testdata/valid.json new file mode 100644 index 0000000..22f1635 --- /dev/null +++ b/testdata/valid.json @@ -0,0 +1,16 @@ +{ + "smtp_server": "smtp.test.com", + "smtp_port": 587, + "smtp_starttls": true, + "smtp_login_auth_type": false, + "smtp_username": "testuser@test.com", + "smtp_password": "testpassword", + "smtp_helo": "relay.test.com", + "smtp_skip_cert_verify": false, + "smtp_max_email_size": 10485760, + "local_listen_ip": "127.0.0.1", + "local_listen_port": 2525, + "allowed_hosts": ["test.com", "example.com"], + "allowed_senders": "*", + "timeout_secs": 60 +} \ No newline at end of file diff --git a/tls_test.go b/tls_test.go new file mode 100644 index 0000000..ad13802 --- /dev/null +++ b/tls_test.go @@ -0,0 +1,252 @@ +package main + +import ( + "bytes" + "testing" + + "github.com/flashmob/go-guerrilla/mail" + "github.com/jpillora/ipfilter" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSendMail_STARTTLS(t *testing.T) { + setupTestLogger(t) + // Start mock SMTP server with STARTTLS requirement + server := NewMockSMTPServer(t) + server.RequireSTARTTLS = true + require.NoError(t, server.Start()) + defer server.Stop() + + // Configure for STARTTLS + config := &relayConfig{ + Server: server.Address(), + Port: server.Port(), + STARTTLS: true, + LoginAuthType: false, + Username: "", + Password: "", + SkipVerify: true, + HeloHost: "", + } + + // Create test envelope + envelope := &mail.Envelope{ + MailFrom: mail.Address{User: "sender", Host: "test.com"}, + RcptTo: []mail.Address{ + {User: "recipient", Host: "example.com"}, + }, + Data: *bytes.NewBufferString("Subject: STARTTLS Test\r\n\r\nThis tests STARTTLS."), + RemoteIP: "127.0.0.1", + } + + // Allow IP + AllowedSendersFilter = ipfilter.New(ipfilter.Options{ + BlockByDefault: false, + }) + + // Send email + err := sendMail(envelope, config) + assert.NoError(t, err) + + // Verify the connection was established + conn := server.GetLastConnection() + require.NotNil(t, conn) + assert.Equal(t, "sender@test.com", conn.From) + assert.Equal(t, []string{"recipient@example.com"}, conn.To) + + // Verify STARTTLS was used (check commands include STARTTLS) + starttlsFound := false + for _, cmd := range conn.Commands { + if cmd == "STARTTLS" { + starttlsFound = true + break + } + } + assert.True(t, starttlsFound, "STARTTLS command should have been sent") + assert.True(t, conn.UsedTLS, "Connection should be marked as using TLS") +} + +func TestSendMail_ImplicitTLS(t *testing.T) { + setupTestLogger(t) + // Start mock SMTP server with implicit TLS + server := NewMockSMTPServer(t) + require.NoError(t, server.StartTLS()) + defer server.Stop() + + // Configure for implicit TLS (no STARTTLS) + config := &relayConfig{ + Server: server.Address(), + Port: server.Port(), + STARTTLS: false, + LoginAuthType: false, + Username: "", + Password: "", + SkipVerify: true, + HeloHost: "", + } + + // Create test envelope + envelope := &mail.Envelope{ + MailFrom: mail.Address{User: "sender", Host: "test.com"}, + RcptTo: []mail.Address{ + {User: "recipient", Host: "example.com"}, + }, + Data: *bytes.NewBufferString("Subject: Implicit TLS Test\r\n\r\nThis tests implicit TLS."), + RemoteIP: "127.0.0.1", + } + + // Allow IP + AllowedSendersFilter = ipfilter.New(ipfilter.Options{ + BlockByDefault: false, + }) + + // Send email + err := sendMail(envelope, config) + assert.NoError(t, err) + + // Verify the connection was established + conn := server.GetLastConnection() + require.NotNil(t, conn) + assert.Equal(t, "sender@test.com", conn.From) + assert.Equal(t, []string{"recipient@example.com"}, conn.To) + + // Verify no STARTTLS command was sent (since we're using implicit TLS) + starttlsFound := false + for _, cmd := range conn.Commands { + if cmd == "STARTTLS" { + starttlsFound = true + break + } + } + assert.False(t, starttlsFound, "STARTTLS command should not have been sent for implicit TLS") +} + +func TestSendMail_TLSWithAuthentication(t *testing.T) { + setupTestLogger(t) + // Start mock SMTP server with both TLS and authentication + server := NewMockSMTPServer(t) + server.RequireSTARTTLS = true + server.RequireAuth = true + require.NoError(t, server.Start()) + defer server.Stop() + + // Configure for STARTTLS with authentication + config := &relayConfig{ + Server: server.Address(), + Port: server.Port(), + STARTTLS: true, + LoginAuthType: false, + Username: "tlsuser", + Password: "tlspass", + SkipVerify: true, + HeloHost: "secure.relay.com", + } + + // Create test envelope + envelope := &mail.Envelope{ + MailFrom: mail.Address{User: "sender", Host: "test.com"}, + RcptTo: []mail.Address{ + {User: "recipient", Host: "example.com"}, + }, + Data: *bytes.NewBufferString("Subject: TLS + Auth Test\r\n\r\nThis tests TLS with authentication."), + RemoteIP: "127.0.0.1", + } + + // Allow IP + AllowedSendersFilter = ipfilter.New(ipfilter.Options{ + BlockByDefault: false, + }) + + // Send email + err := sendMail(envelope, config) + assert.NoError(t, err) + + // Verify the connection was established with both TLS and auth + conn := server.GetLastConnection() + require.NotNil(t, conn) + assert.Equal(t, "sender@test.com", conn.From) + assert.Equal(t, []string{"recipient@example.com"}, conn.To) + assert.True(t, conn.UsedTLS, "Connection should use TLS") + assert.NotEmpty(t, conn.AuthUser, "Authentication should have been used") + assert.NotEmpty(t, conn.AuthPass, "Authentication should have been used") + + // Verify command sequence (STARTTLS should come before AUTH) + starttlsIndex := -1 + authIndex := -1 + for i, cmd := range conn.Commands { + if cmd == "STARTTLS" { + starttlsIndex = i + } + if len(cmd) >= 4 && cmd[:4] == "AUTH" { + authIndex = i + } + } + assert.True(t, starttlsIndex >= 0, "STARTTLS command should be present") + assert.True(t, authIndex >= 0, "AUTH command should be present") + assert.True(t, starttlsIndex < authIndex, "STARTTLS should come before AUTH") +} + +func TestSendMail_TLSSkipVerify(t *testing.T) { + setupTestLogger(t) + // Test that we can handle certificate verification settings + server := NewMockSMTPServer(t) + require.NoError(t, server.StartTLS()) // Use implicit TLS + defer server.Stop() + + tests := []struct { + name string + skipVerify bool + }{ + {"skip certificate verification", true}, + {"enforce certificate verification", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server.Reset() + + config := &relayConfig{ + Server: server.Address(), + Port: server.Port(), + STARTTLS: false, + LoginAuthType: false, + Username: "", + Password: "", + SkipVerify: tt.skipVerify, + HeloHost: "", + } + + envelope := &mail.Envelope{ + MailFrom: mail.Address{User: "sender", Host: "test.com"}, + RcptTo: []mail.Address{ + {User: "recipient", Host: "example.com"}, + }, + Data: *bytes.NewBufferString("Subject: TLS Verify Test\r\n\r\nTesting certificate verification."), + RemoteIP: "127.0.0.1", + } + + // Allow IP + AllowedSendersFilter = ipfilter.New(ipfilter.Options{ + BlockByDefault: false, + }) + + // Send email + err := sendMail(envelope, config) + + if tt.skipVerify { + // Should succeed when skipping verification + assert.NoError(t, err) + + // Verify email was sent + conn := server.GetLastConnection() + require.NotNil(t, conn) + assert.Equal(t, "sender@test.com", conn.From) + } else { + // Should fail when enforcing verification with self-signed cert + assert.Error(t, err) + assert.Contains(t, err.Error(), "certificate") + } + }) + } +} \ No newline at end of file