Files
mailrelay/mock_smtp_server.go
wiggin77 b244db94e2 fix race
2025-05-25 00:06:14 -04:00

466 lines
11 KiB
Go

package main
import (
"bufio"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
"net"
"strings"
"sync"
"testing"
"time"
)
// Protocol literals and tuning knobs shared by the mock server.
const (
	// smtpSTARTTLS is the command verb a client sends to upgrade its session to TLS.
	smtpSTARTTLS = "STARTTLS"

	// rsaKeyBits is the RSA modulus size of the throwaway test certificate key.
	rsaKeyBits = 2048

	// ipOctet127 is the first octet of the IPv4 loopback address put in the certificate.
	ipOctet127 = 127

	// sleepDurationMs is the pause, in milliseconds, taken after launching the accept loop.
	sleepDurationMs = 10

	// minAuthParts is the minimum field count (verb plus one argument) the
	// command handlers require before replying with a syntax error.
	minAuthParts = 2
)
// MockSMTPServer represents a mock SMTP server for testing.
//
// Start/StartTLS bind it to an OS-assigned port on 127.0.0.1. Each accepted
// connection runs a scripted SMTP dialogue and is appended to Connections
// when the session ends. The exported configuration fields are read by the
// connection handlers without taking mu.
type MockSMTPServer struct {
	// listener accepts client connections; after StartTLS it is TLS-wrapped.
	listener net.Listener
	// tlsConfig carries the self-signed certificate used for implicit TLS and STARTTLS.
	tlsConfig *tls.Config

	// address and port are filled in from the bound listener by Start/StartTLS.
	address string
	port    int

	// running is set by Start/StartTLS and cleared by Stop; the accept loop polls it.
	running bool

	// mu is held while Start/StartTLS/Stop update listener, address, port and
	// running, and around every access to Connections in this file.
	// Address and Port read their fields without it.
	mu sync.Mutex

	// Recorded interactions: one entry per finished session, appended when
	// the session ends.
	Connections []MockConnection

	// Configuration

	// RequireAuth makes EHLO advertise AUTH. Unauthenticated MAIL is still accepted.
	RequireAuth bool
	// RequireSTARTTLS makes EHLO advertise STARTTLS. Plaintext sessions are still accepted.
	RequireSTARTTLS bool
	// SupportLoginAuth adds LOGIN to the advertised AUTH mechanisms (only when RequireAuth is set).
	SupportLoginAuth bool
	// ResponseDelay is slept once per connection, before the 220 greeting.
	ResponseDelay time.Duration
	// FailCommands maps upper-case command verbs to fail with "550 Command failed".
	FailCommands map[string]bool
	// CustomResponses maps upper-case verbs to a raw reply line (CRLF is appended).
	// It is consulted after FailCommands and before the built-in handlers.
	CustomResponses map[string]string
	// ImplicitTLS is true if the server uses implicit TLS (like port 465). StartTLS
	// sets it, and it marks every session as UsedTLS.
	ImplicitTLS bool
}
// MockConnection records what happened during one client session.
type MockConnection struct {
	// Commands holds each line read by the main command loop, whitespace-trimmed.
	// AUTH continuation lines and the DATA body are read by the command
	// handlers directly and are not included.
	Commands []string
	// From is the reverse-path given with MAIL FROM, without angle brackets.
	From string
	// To lists the RCPT TO addresses in the order they were received.
	To []string
	// Data is the message body received after DATA, up to but excluding the
	// terminating "." line, with line endings preserved.
	Data string
	// AuthUser and AuthPass are set by AUTH.
	// NOTE(review): PLAIN stores the fixed placeholders "testuser"/"testpass"
	// without decoding, while LOGIN stores the client's base64-encoded
	// responses verbatim.
	AuthUser string
	AuthPass string
	// UsedTLS is true for implicit-TLS sessions and after a successful STARTTLS.
	UsedTLS bool
}
// NewMockSMTPServer creates a new mock SMTP server that is configured but
// not yet listening; call Start or StartTLS next. It generates a fresh
// self-signed certificate for TLS and fails the test immediately if that
// is not possible.
func NewMockSMTPServer(t *testing.T) *MockSMTPServer {
	// Report a certificate failure at the caller's line rather than here.
	t.Helper()

	cert, err := generateTestCert()
	if err != nil {
		t.Fatalf("Failed to generate test certificate: %v", err)
	}
	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{cert},
		// NOTE(review): InsecureSkipVerify only affects the client side of a
		// handshake; it has no effect on this server-side config.
		InsecureSkipVerify: true,
	}
	return &MockSMTPServer{
		tlsConfig:       tlsConfig,
		Connections:     make([]MockConnection, 0),
		FailCommands:    make(map[string]bool),
		CustomResponses: make(map[string]string),
	}
}
// Start starts the mock SMTP server on an OS-assigned 127.0.0.1 port with a
// plaintext listener (clients may still upgrade individual sessions via
// STARTTLS). It returns an error if the server is already running or the
// port cannot be bound.
func (s *MockSMTPServer) Start() error {
	return s.startListening(false)
}

// StartTLS starts the mock SMTP server with implicit TLS (like port 465):
// every accepted connection is wrapped in TLS using the server certificate.
// It returns an error if the server is already running or the port cannot
// be bound.
func (s *MockSMTPServer) StartTLS() error {
	return s.startListening(true)
}

// startListening binds the listener, launches the accept loop and then
// pauses briefly. The pause happens after the mutex is released: the accept
// loop needs that mutex, so sleeping while holding it would only delay it.
func (s *MockSMTPServer) startListening(implicitTLS bool) error {
	if err := s.bindListener(implicitTLS); err != nil {
		return err
	}
	// Give the server a moment to start.
	time.Sleep(sleepDurationMs * time.Millisecond)
	return nil
}

// bindListener performs the locked part of startup: it creates the listener
// (TLS-wrapped when implicitTLS is true), records its address, and launches
// acceptConnections.
func (s *MockSMTPServer) bindListener(implicitTLS bool) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.running {
		// Binding again would leak the current listener and its accept goroutine.
		return fmt.Errorf("mock SMTP server already running on %s:%d", s.address, s.port)
	}

	tcpListener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return err
	}
	addr := tcpListener.Addr().(*net.TCPAddr)
	s.address = addr.IP.String()
	s.port = addr.Port

	s.listener = tcpListener
	if implicitTLS {
		s.listener = tls.NewListener(tcpListener, s.tlsConfig)
		s.ImplicitTLS = true
	}
	s.running = true
	go s.acceptConnections()
	return nil
}
// Stop closes the listener, which makes the accept loop wind down, and marks
// the server as no longer running. Sessions already in progress are not
// interrupted. It does nothing if no listener was ever created.
func (s *MockSMTPServer) Stop() {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.listener == nil {
		return
	}
	// The close error is deliberately ignored; there is nothing useful to do with it here.
	_ = s.listener.Close()
	s.running = false
}
// Address reports, as text, the IP address the listener was bound to by
// Start or StartTLS. It is empty until one of them succeeds.
func (s *MockSMTPServer) Address() string {
	return s.address
}

// Port reports the TCP port the listener was bound to by Start or StartTLS.
// It is zero until one of them succeeds.
func (s *MockSMTPServer) Port() int {
	return s.port
}
// GetLastConnection returns a pointer to the most recently recorded session,
// or nil when nothing has been recorded yet. The pointer refers to the
// element stored in s.Connections, not to a copy.
func (s *MockSMTPServer) GetLastConnection() *MockConnection {
	s.mu.Lock()
	defer s.mu.Unlock()

	count := len(s.Connections)
	if count > 0 {
		return &s.Connections[count-1]
	}
	return nil
}
// Reset forgets every recorded session by swapping in a fresh, empty
// (non-nil) slice. The previous backing array is left untouched, so
// pointers obtained from GetLastConnection keep their contents.
func (s *MockSMTPServer) Reset() {
	s.mu.Lock()
	s.Connections = []MockConnection{}
	s.mu.Unlock()
}
// acceptConnections accepts clients on the listener that was active when it
// started and serves each one on its own goroutine. It stops when the server
// is stopped. It also stops as soon as its listener is no longer the server's
// active one (Stop followed by Start/StartTLS), so a stale loop never touches
// the replacement listener. Accept errors while still running are printed
// and retried.
func (s *MockSMTPServer) acceptConnections() {
	// Capture the listener once, under the lock: a restart rewrites s.listener.
	s.mu.Lock()
	listener := s.listener
	s.mu.Unlock()

	for s.isRunning() {
		conn, err := listener.Accept()
		if err != nil {
			if !s.ownsListener(listener) {
				return
			}
			if s.isRunning() {
				fmt.Printf("Accept error: %v\n", err)
			}
			continue
		}
		go s.handleConnection(conn)
	}
}

// ownsListener reports whether ln is still the server's active listener.
func (s *MockSMTPServer) ownsListener(ln net.Listener) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.listener == ln
}
// isRunning reports, under the mutex, whether Start/StartTLS has succeeded
// and Stop has not been called since.
func (s *MockSMTPServer) isRunning() bool {
	s.mu.Lock()
	running := s.running
	s.mu.Unlock()
	return running
}
// handleConnection drives one SMTP session on conn.
//
// After the optional ResponseDelay it sends the 220 greeting. It then reads
// one command per line and dispatches it. FailCommands and CustomResponses
// take precedence over the built-in handlers. The session ends when the
// client sends QUIT, the connection fails, or a STARTTLS upgrade fails.
// Finally the session is appended to s.Connections and conn is closed.
func (s *MockSMTPServer) handleConnection(conn net.Conn) {
	defer conn.Close()
	if s.ResponseDelay > 0 {
		time.Sleep(s.ResponseDelay)
	}

	reader := bufio.NewReader(conn)
	writer := bufio.NewWriter(conn)
	mockConn := MockConnection{
		Commands: make([]string, 0),
		To:       make([]string, 0),
	}
	// Implicit-TLS listeners hand over *tls.Conn values; STARTTLS sets UsedTLS later.
	if _, ok := conn.(*tls.Conn); ok || s.ImplicitTLS {
		mockConn.UsedTLS = true
	}

	_, _ = writer.WriteString("220 mock.smtp.server ESMTP ready\r\n")
	writer.Flush()

session:
	for {
		line, err := reader.ReadString('\n')
		if err != nil {
			break
		}
		line = strings.TrimSpace(line)
		mockConn.Commands = append(mockConn.Commands, line)
		parts := strings.Fields(line)
		if len(parts) == 0 {
			continue
		}
		cmd := strings.ToUpper(parts[0])

		// Injected failures and canned replies override the built-in handlers.
		if s.FailCommands[cmd] {
			_, _ = writer.WriteString("550 Command failed\r\n")
			writer.Flush()
			continue
		}
		if response, exists := s.CustomResponses[cmd]; exists {
			_, _ = writer.WriteString(response + "\r\n")
			writer.Flush()
			continue
		}

		switch cmd {
		case "EHLO", "HELO":
			s.handleEHLO(writer)
		case smtpSTARTTLS:
			tlsConn, newReader, newWriter, upgraded := s.handleSTARTTLS(conn, reader, writer, &mockConn)
			if !upgraded {
				// The client was already told to begin TLS, so the stream is no
				// longer a usable plaintext session; end it instead of parsing
				// handshake bytes as commands.
				break session
			}
			// Continue over TLS with fresh buffers (buffered plaintext is dropped).
			conn, reader, writer = tlsConn, newReader, newWriter
		case "AUTH":
			s.handleAUTH(parts, reader, writer, &mockConn)
		case "MAIL":
			s.handleMAIL(parts, writer, &mockConn)
		case "RCPT":
			s.handleRCPT(parts, writer, &mockConn)
		case "DATA":
			s.handleDATA(reader, writer, &mockConn)
		case "QUIT":
			// Record before replying, so a client that has already seen "221"
			// can observe this session via GetLastConnection without racing us.
			s.recordConnection(mockConn)
			_, _ = writer.WriteString("221 Bye\r\n")
			writer.Flush()
			return
		default:
			_, _ = writer.WriteString("500 Command not recognized\r\n")
			writer.Flush()
		}
	}
	s.recordConnection(mockConn)
}

// recordConnection appends a finished session to s.Connections under s.mu.
func (s *MockSMTPServer) recordConnection(mockConn MockConnection) {
	s.mu.Lock()
	s.Connections = append(s.Connections, mockConn)
	s.mu.Unlock()
}
// handleEHLO answers EHLO/HELO with a multi-line 250 reply. It lists
// STARTTLS when RequireSTARTTLS is set, and AUTH PLAIN (plus LOGIN when
// SupportLoginAuth is set) when RequireAuth is set. It always ends with
// the SIZE limit line.
func (s *MockSMTPServer) handleEHLO(writer *bufio.Writer) {
	reply := []string{"250-mock.smtp.server\r\n"}
	if s.RequireSTARTTLS {
		reply = append(reply, "250-STARTTLS\r\n")
	}
	switch {
	case s.RequireAuth && s.SupportLoginAuth:
		reply = append(reply, "250-AUTH PLAIN LOGIN\r\n")
	case s.RequireAuth:
		reply = append(reply, "250-AUTH PLAIN\r\n")
	}
	reply = append(reply, "250 SIZE 10240000\r\n")

	for _, line := range reply {
		_, _ = writer.WriteString(line)
	}
	writer.Flush()
}
// handleSTARTTLS tells the client to begin TLS ("220 Ready to start TLS")
// and performs the server side of the handshake on conn with s.tlsConfig.
//
// On success it marks mockConn as UsedTLS and returns the TLS connection,
// a fresh reader and writer over it, and true. Any plaintext still buffered
// in reader is deliberately dropped. If the handshake fails it returns a nil
// *tls.Conn, the original reader and writer, and false. At that point the
// plaintext stream can no longer be trusted to carry SMTP commands.
func (s *MockSMTPServer) handleSTARTTLS(conn net.Conn, reader *bufio.Reader, writer *bufio.Writer, mockConn *MockConnection) (*tls.Conn, *bufio.Reader, *bufio.Writer, bool) {
	_, _ = writer.WriteString("220 Ready to start TLS\r\n")
	writer.Flush()

	tlsConn := tls.Server(conn, s.tlsConfig)
	if err := tlsConn.Handshake(); err != nil {
		return nil, reader, writer, false
	}
	mockConn.UsedTLS = true
	return tlsConn, bufio.NewReader(tlsConn), bufio.NewWriter(tlsConn), true
}
// handleAUTH runs a permissive AUTH exchange for the PLAIN and LOGIN
// mechanisms. parts is the whitespace-split AUTH command line (parts[1] is
// the mechanism); continuation lines are read from reader. Any completed
// exchange is accepted with 235, and other mechanisms get 504.
//
// NOTE(review): credentials are not decoded. PLAIN records the fixed
// placeholders "testuser"/"testpass"; LOGIN records the client's
// base64-encoded responses verbatim.
//
// If reading a continuation line fails, the exchange is abandoned without
// a reply, since the connection is unusable. A "*" continuation cancels
// the exchange with 501, as RFC 4954 requires.
func (s *MockSMTPServer) handleAUTH(parts []string, reader *bufio.Reader, writer *bufio.Writer, mockConn *MockConnection) {
	reply := func(line string) {
		_, _ = writer.WriteString(line + "\r\n")
		writer.Flush()
	}
	// challenge sends a 334 prompt and returns the trimmed client response;
	// ok is false when the read failed or the client cancelled.
	challenge := func(prompt string) (response string, ok bool) {
		reply(prompt)
		line, err := reader.ReadString('\n')
		if err != nil {
			return "", false
		}
		line = strings.TrimSpace(line)
		if line == "*" {
			reply("501 Authentication cancelled")
			return "", false
		}
		return line, true
	}

	if len(parts) < minAuthParts {
		reply("501 Syntax error")
		return
	}

	switch strings.ToUpper(parts[1]) {
	case "PLAIN":
		// The initial response may ride on the AUTH line; otherwise prompt for it.
		if len(parts) <= minAuthParts {
			if _, ok := challenge("334 "); !ok {
				return
			}
		}
		mockConn.AuthUser = "testuser"
		mockConn.AuthPass = "testpass"
		reply("235 Authentication successful")
	case "LOGIN":
		user, ok := challenge("334 VXNlcm5hbWU6") // "Username:" in base64
		if !ok {
			return
		}
		mockConn.AuthUser = user
		pass, ok := challenge("334 UGFzc3dvcmQ6") // "Password:" in base64
		if !ok {
			return
		}
		mockConn.AuthPass = pass
		reply("235 Authentication successful")
	default:
		reply("504 Authentication mechanism not supported")
	}
}
// minPathCommandParts is the minimum number of fields (verb plus argument)
// a MAIL or RCPT command needs.
const minPathCommandParts = 2

// handleMAIL records the reverse-path of a "MAIL FROM:<addr>" command in
// mockConn.From and replies 250. A MAIL without an argument gets 501 and
// leaves mockConn untouched.
func (s *MockSMTPServer) handleMAIL(parts []string, writer *bufio.Writer, mockConn *MockConnection) {
	if len(parts) < minPathCommandParts {
		_, _ = writer.WriteString("501 Syntax error\r\n")
		writer.Flush()
		return
	}
	mockConn.From = smtpPathArgument(parts[1:], "FROM:")
	_, _ = writer.WriteString("250 OK\r\n")
	writer.Flush()
}

// handleRCPT appends the forward-path of a "RCPT TO:<addr>" command to
// mockConn.To and replies 250. A RCPT without an argument gets 501 and
// leaves mockConn untouched.
func (s *MockSMTPServer) handleRCPT(parts []string, writer *bufio.Writer, mockConn *MockConnection) {
	if len(parts) < minPathCommandParts {
		_, _ = writer.WriteString("501 Syntax error\r\n")
		writer.Flush()
		return
	}
	mockConn.To = append(mockConn.To, smtpPathArgument(parts[1:], "TO:"))
	_, _ = writer.WriteString("250 OK\r\n")
	writer.Flush()
}

// smtpPathArgument extracts the mailbox from the argument fields of a MAIL
// or RCPT command. args[0] normally starts with keyword ("FROM:" or "TO:").
// The keyword is matched case-insensitively, as SMTP keywords are
// (RFC 5321 §2.4). When the client put a space after the colon, the path is
// taken from the next field. Trailing ESMTP parameters (e.g. SIZE=123) are
// ignored, and angle brackets are stripped. An argument without the keyword
// is used as-is, minus brackets.
func smtpPathArgument(args []string, keyword string) string {
	if len(args) == 0 {
		return ""
	}
	arg := args[0]
	if len(arg) >= len(keyword) && strings.EqualFold(arg[:len(keyword)], keyword) {
		arg = arg[len(keyword):]
		if arg == "" && len(args) > 1 {
			arg = args[1]
		}
	}
	return strings.Trim(arg, "<>")
}
// handleDATA replies 354 and reads the message body up to the terminating
// line that consists of a single "." (RFC 5321 §4.1.1.4). Along the way it
// undoes the transparency dot-stuffing of §4.5.2: one leading "." is removed
// from every body line. The body, with its original line endings, is stored
// in mockConn.Data and acknowledged with 250.
//
// If the connection fails before the terminator, the lines received so far
// are stored but no 250 is sent, because the message was never completed.
func (s *MockSMTPServer) handleDATA(reader *bufio.Reader, writer *bufio.Writer, mockConn *MockConnection) {
	_, _ = writer.WriteString("354 Start mail input; end with <CRLF>.<CRLF>\r\n")
	writer.Flush()

	var dataBuilder strings.Builder
	for {
		line, err := reader.ReadString('\n')
		if err != nil {
			// Unterminated message: keep what arrived, but do not accept it.
			mockConn.Data = dataBuilder.String()
			return
		}
		// Only an exact "." line ends the data (bare-LF endings tolerated);
		// a "." surrounded by other whitespace is ordinary content.
		if strings.TrimRight(line, "\r\n") == "." {
			break
		}
		// Clients double any leading "." in the body; strip one back off.
		dataBuilder.WriteString(strings.TrimPrefix(line, "."))
	}
	mockConn.Data = dataBuilder.String()
	_, _ = writer.WriteString("250 OK: message accepted\r\n")
	writer.Flush()
}
// generateTestCert builds a throwaway self-signed certificate for the mock
// server. It uses a fresh RSA key of rsaKeyBits bits and is valid from now
// for 365 days. Its key usages cover TLS server authentication, and its
// subject alternative names are "localhost", 127.0.0.1 and ::1.
func generateTestCert() (tls.Certificate, error) {
	key, err := rsa.GenerateKey(rand.Reader, rsaKeyBits)
	if err != nil {
		return tls.Certificate{}, err
	}

	subject := pkix.Name{
		Organization:  []string{"Test"},
		Country:       []string{"US"},
		Province:      []string{""},
		Locality:      []string{"Test City"},
		StreetAddress: []string{""},
		PostalCode:    []string{""},
	}
	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      subject,
		NotBefore:    time.Now(),
		NotAfter:     time.Now().Add(365 * 24 * time.Hour),
		KeyUsage:     x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		DNSNames:     []string{"localhost"},
		IPAddresses:  []net.IP{net.IPv4(ipOctet127, 0, 0, 1), net.IPv6loopback},
	}

	// Self-signed: the template is both the certificate and its issuer.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		return tls.Certificate{}, err
	}

	var cert tls.Certificate
	cert.Certificate = [][]byte{der}
	cert.PrivateKey = key
	return cert, nil
}