package http

import (
	"crypto/tls"
	"net"
	"net/http"
	"reflect"
	"strings"
	"testing"
	"time"

	"golang.org/x/net/nettest"

	"github.com/go-chi/chi/v5"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestGetOptions checks that GetOptions returns a value deeply equal to the
// package-level defaultServerOptions.
func TestGetOptions(t *testing.T) {
	tests := []struct {
		name string
		want Options
	}{
		{name: "basic", want: defaultServerOptions},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := GetOptions(); !reflect.DeepEqual(got, tt.want) {
				t.Errorf("GetOptions() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestMount checks that the package-level Mount registers a handler on the
// default server so that the base router matches a GET for the pattern.
func TestMount(t *testing.T) {
	type args struct {
		pattern string
		h       http.Handler
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{name: "basic", args: args{
			pattern: "/",
			h:       http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {}),
		}, wantErr: false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.wantErr {
				require.Error(t, Mount(tt.args.pattern, tt.args.h))
			} else {
				require.NoError(t, Mount(tt.args.pattern, tt.args.h))
			}
			assert.NotNil(t, defaultServer)
			assert.True(t, defaultServer.baseRouter.Match(chi.NewRouteContext(), "GET", tt.args.pattern), "Failed to match route after registering")
		})
		// Shut the default server down between cases so each one starts clean.
		if err := Shutdown(); err != nil {
			t.Fatal(err)
		}
	}
}

// TestNewServer checks the state of the *server built by NewServer: the
// recorded listener addresses, whether the base router was wrapped (only
// expected when BaseURL is set) and whether a TLS config is present (only
// expected when useSSL reports true for the options).
func TestNewServer(t *testing.T) {
	type args struct {
		listeners    []net.Listener
		tlsListeners []net.Listener
		opt          Options
	}
	listener, err := nettest.NewLocalListener("tcp")
	if err != nil {
		t.Fatal(err)
	}
	// Best-effort close so the test does not leak the socket; the error is
	// ignored because the listener may already have been closed.
	defer func() { _ = listener.Close() }()

	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{name: "default http", args: args{
			listeners:    []net.Listener{listener},
			tlsListeners: []net.Listener{},
			opt:          defaultServerOptions,
		}, wantErr: false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := NewServer(tt.args.listeners, tt.args.tlsListeners, tt.args.opt)
			if (err != nil) != tt.wantErr {
				t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr)
				return
			}

			s, ok := got.(*server)
			require.True(t, ok, "NewServer returned unexpected type")

			// Compare against the listeners actually passed in for this case.
			if len(tt.args.listeners) > 0 {
				assert.Equal(t, tt.args.listeners[0].Addr(), s.addrs[0])
			} else {
				assert.Empty(t, s.addrs)
			}
			if len(tt.args.tlsListeners) > 0 {
				assert.Equal(t, tt.args.tlsListeners[0].Addr(), s.tlsAddrs[0])
			} else {
				assert.Empty(t, s.tlsAddrs)
			}

			if tt.args.opt.BaseURL != "" {
				assert.NotSame(t, s.baseRouter, s.httpServer.Handler, "should have wrapped baseRouter")
			} else {
				assert.Same(t, s.baseRouter, s.httpServer.Handler, "should be baseRouter")
			}

			if useSSL(tt.args.opt) {
				assert.NotNil(t, s.httpServer.TLSConfig, "missing SSL config")
			} else {
				assert.Nil(t, s.httpServer.TLSConfig, "unexpectedly has SSL config")
			}
		})
	}
}

// TestRestart checks that Restart succeeds whether the default server was
// previously restarted or shut down, and that it leaves a new, non-nil
// defaultServer instance behind.
func TestRestart(t *testing.T) {
	tests := []struct {
		name    string
		started bool
		wantErr bool
	}{
		{name: "started", started: true, wantErr: false},
		{name: "stopped", started: false, wantErr: false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.started {
				require.NoError(t, Restart()) // Call it twice basically
			} else {
				require.NoError(t, Shutdown())
			}

			current := defaultServer
			if err := Restart(); (err != nil) != tt.wantErr {
				t.Errorf("Restart() error = %v, wantErr %v", err, tt.wantErr)
			}

			assert.NotNil(t, defaultServer, "failed to start default server")
			assert.NotSame(t, current, defaultServer, "same server instance as before restart")
		})
	}
}

// TestRoute checks that the package-level Route registers a sub-router on the
// default server's base router under "<pattern>/*".
func TestRoute(t *testing.T) {
	type args struct {
		pattern string
		fn      func(r chi.Router)
	}
	tests := []struct {
		name string
		args args
		test func(t *testing.T, r chi.Router)
	}{
		{
			name: "basic",
			args: args{
				pattern: "/basic",
				fn:      func(r chi.Router) {},
			},
			test: func(t *testing.T, r chi.Router) {
				require.Len(t, r.Routes(), 1)
				// testify expects (expected, actual).
				assert.Equal(t, "/basic/*", r.Routes()[0].Pattern)
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			require.NoError(t, Restart())

			_, err := Route(tt.args.pattern, tt.args.fn)
			require.NoError(t, err)
			tt.test(t, defaultServer.baseRouter)
		})

		if err := Shutdown(); err != nil {
			t.Fatal(err)
		}
	}
}

// TestSetOptions checks that SetOptions replaces defaultServerOptions and that
// a subsequent Restart builds the default server from them (listen address,
// timeouts, max header bytes and BaseURL wrapping). After each case the
// options are reset with SetOptions(DefaultOpt).
func TestSetOptions(t *testing.T) {
	type args struct {
		opt Options
	}
	tests := []struct {
		name string
		args args
	}{
		{
			name: "basic",
			args: args{opt: Options{
				ListenAddr:         "127.0.0.1:9999",
				BaseURL:            "/basic",
				ServerReadTimeout:  1,
				ServerWriteTimeout: 1,
				MaxHeaderBytes:     1,
			}},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			SetOptions(tt.args.opt)
			require.Equal(t, tt.args.opt, defaultServerOptions)
			require.NoError(t, Restart())

			if useSSL(tt.args.opt) {
				assert.Equal(t, tt.args.opt.ListenAddr, defaultServer.tlsAddrs[0].String())
			} else {
				assert.Equal(t, tt.args.opt.ListenAddr, defaultServer.addrs[0].String())
			}
			assert.Equal(t, tt.args.opt.ServerReadTimeout, defaultServer.httpServer.ReadTimeout)
			assert.Equal(t, tt.args.opt.ServerWriteTimeout, defaultServer.httpServer.WriteTimeout)
			assert.Equal(t, tt.args.opt.MaxHeaderBytes, defaultServer.httpServer.MaxHeaderBytes)
			if tt.args.opt.BaseURL != "" && tt.args.opt.BaseURL != "/" {
				assert.NotSame(t, defaultServer.httpServer.Handler, defaultServer.baseRouter, "BaseURL ignored")
			}
		})
		SetOptions(DefaultOpt)
	}
}

// TestShutdown checks that Shutdown succeeds both when the default server is
// running and when it has already been shut down, and that defaultServer is
// nil afterwards.
func TestShutdown(t *testing.T) {
	tests := []struct {
		name    string
		started bool
		wantErr bool
	}{
		{name: "started", started: true, wantErr: false},
		{name: "stopped", started: false, wantErr: false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.started {
				require.NoError(t, Restart())
			} else {
				require.NoError(t, Shutdown()) // Call it twice basically
			}

			if err := Shutdown(); (err != nil) != tt.wantErr {
				t.Errorf("Shutdown() error = %v, wantErr %v", err, tt.wantErr)
			}
			assert.Nil(t, defaultServer, "default server not deleted")
		})
	}
}

// TestURL checks the string returned by URL after the default server has been
// restarted.
func TestURL(t *testing.T) {
	tests := []struct {
		name string
		want string
	}{
		{name: "basic", want: "http://127.0.0.1:8080/"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			require.NoError(t, Restart())
			if got := URL(); got != tt.want {
				t.Errorf("URL() = %v, want %v", got, tt.want)
			}
		})
	}
}

// Test_server_Mount checks that Mount on a server built by NewServer makes
// the base router match a GET for the mounted pattern.
func Test_server_Mount(t *testing.T) {
	type args struct {
		pattern string
		h       http.Handler
	}
	tests := []struct {
		name string
		args args
		opt  Options
	}{
		{name: "basic", args: args{
			pattern: "/",
			h:       http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {}),
		}, opt: defaultServerOptions},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			listener, err := nettest.NewLocalListener("tcp")
			require.NoError(t, err)
			// Best-effort close so the subtest does not leak the socket.
			defer func() { _ = listener.Close() }()

			s, err2 := NewServer([]net.Listener{listener}, []net.Listener{}, tt.opt)
			require.NoError(t, err2)
			s.Mount(tt.args.pattern, tt.args.h)

			srv, ok := s.(*server)
			require.True(t, ok)
			assert.NotNil(t, srv)
			assert.True(t, srv.baseRouter.Match(chi.NewRouteContext(), "GET", tt.args.pattern), "Failed to Match() route after registering")
		})
	}
}

// Test_server_Route checks that Route on a server built by NewServer
// registers a single route on the base router under "<pattern>/*".
func Test_server_Route(t *testing.T) {
	type args struct {
		pattern string
		fn      func(r chi.Router)
	}
	tests := []struct {
		name string
		args args
		opt  Options
		test func(t *testing.T, r chi.Router)
	}{
		{
			name: "basic",
			args: args{
				pattern: "/basic",
				fn: func(r chi.Router) {
				},
			},
			test: func(t *testing.T, r chi.Router) {
				require.Len(t, r.Routes(), 1)
				// testify expects (expected, actual).
				assert.Equal(t, "/basic/*", r.Routes()[0].Pattern)
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			listener, err := nettest.NewLocalListener("tcp")
			require.NoError(t, err)
			// Best-effort close so the subtest does not leak the socket.
			defer func() { _ = listener.Close() }()

			s, err2 := NewServer([]net.Listener{listener}, []net.Listener{}, tt.opt)
			require.NoError(t, err2)
			s.Route(tt.args.pattern, tt.args.fn)

			srv, ok := s.(*server)
			require.True(t, ok)
			assert.NotNil(t, srv)
			tt.test(t, srv.baseRouter)
		})
	}
}

// Test_server_Shutdown checks that Shutdown on a server built by NewServer
// succeeds and that serving on its http.Server afterwards fails with
// http.ErrServerClosed.
func Test_server_Shutdown(t *testing.T) {
	tests := []struct {
		name    string
		opt     Options
		wantErr bool
	}{
		{
			name:    "basic",
			opt:     defaultServerOptions,
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			listener, err := nettest.NewLocalListener("tcp")
			require.NoError(t, err)

			s, err2 := NewServer([]net.Listener{listener}, []net.Listener{}, tt.opt)
			require.NoError(t, err2)
			srv, ok := s.(*server)
			require.True(t, ok)

			if err := s.Shutdown(); (err != nil) != tt.wantErr {
				t.Errorf("Shutdown() error = %v, wantErr %v", err, tt.wantErr)
			}
			assert.EqualError(t, srv.httpServer.Serve(listener), http.ErrServerClosed.Error())
		})
	}
}

// Test_start checks that start brings up the default server from the options
// set with SetOptions, for both a plain and an SSL configuration, and that the
// address returned by URL accepts a HEAD request.
func Test_start(t *testing.T) {
	// The test certificate is self-signed, so the default client must skip
	// verification to be able to connect to the SSL case.
	http.DefaultTransport.(*http.Transport).TLSClientConfig = &tls.Config{InsecureSkipVerify: true}

	// The PEM blocks must keep their line breaks: encoding/pem needs the
	// BEGIN/END markers to sit on their own lines to decode the block.
	sslServerOptions := defaultServerOptions
	sslServerOptions.SslCertBody = []byte(`-----BEGIN CERTIFICATE-----
MIIBhTCCASugAwIBAgIQIRi6zePL6mKjOipn+dNuaTAKBggqhkjOPQQDAjASMRAw
DgYDVQQKEwdBY21lIENvMB4XDTE3MTAyMDE5NDMwNloXDTE4MTAyMDE5NDMwNlow
EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD0d
7VNhbWvZLWPuj/RtHFjvtJBEwOkhbN/BnnE8rnZR8+sbwnc/KhCk3FhnpHZnQz7B
5aETbbIgmuvewdjvSBSjYzBhMA4GA1UdDwEB/wQEAwICpDATBgNVHSUEDDAKBggr
BgEFBQcDATAPBgNVHRMBAf8EBTADAQH/MCkGA1UdEQQiMCCCDmxvY2FsaG9zdDo1
NDUzgg4xMjcuMC4wLjE6NTQ1MzAKBggqhkjOPQQDAgNIADBFAiEA2zpJEPQyz6/l
Wf86aX6PepsntZv2GYlA5UpabfT2EZICICpJ5h/iI+i341gBmLiAFQOyTDT+/wQc
6MF9+Yw1Yy0t
-----END CERTIFICATE-----`)
	sslServerOptions.SslKeyBody = []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`)

	tests := []struct {
		name    string
		opt     Options
		ssl     bool
		wantErr bool
	}{
		{
			name:    "basic",
			opt:     defaultServerOptions,
			wantErr: false,
		},
		{
			name:    "ssl",
			opt:     sslServerOptions,
			ssl:     true,
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Always shut the default server down again, even on failure.
			defer func() {
				err := Shutdown()
				if err != nil {
					t.Fatal("couldn't shutdown server")
				}
			}()

			SetOptions(tt.opt)
			if err := start(); (err != nil) != tt.wantErr {
				t.Errorf("start() error = %v, wantErr %v", err, tt.wantErr)
				return
			}

			s := defaultServer
			router := s.Router()
			router.Head("/", func(writer http.ResponseWriter, request *http.Request) {
				writer.WriteHeader(201)
			})

			testURL := URL()
			if tt.ssl {
				assert.True(t, useSSL(tt.opt))
				assert.Equal(t, tt.opt.ListenAddr, s.tlsAddrs[0].String())
				assert.True(t, strings.HasPrefix(testURL, "https://"))
			} else {
				assert.True(t, strings.HasPrefix(testURL, "http://"))
				assert.Equal(t, tt.opt.ListenAddr, s.addrs[0].String())
			}

			// try to connect to the test server
			// Retry with exponential backoff (1ms, 2ms, 4ms, ...), up to ten
			// attempts in total.
			pause := time.Millisecond
			for i := 0; i < 10; i++ {
				resp, err := http.Head(testURL)
				if err == nil {
					_ = resp.Body.Close()
					return
				}
				// t.Logf("couldn't connect, sleeping for %v: %v", pause, err)
				time.Sleep(pause)
				pause *= 2
			}
			t.Fatal("couldn't connect to server")

			/* accessing s.httpServer.* can't be done synchronously and is a race condition
			assert.Equal(t, tt.opt.ServerReadTimeout, defaultServer.httpServer.ReadTimeout)
			assert.Equal(t, tt.opt.ServerWriteTimeout, defaultServer.httpServer.WriteTimeout)
			assert.Equal(t, tt.opt.MaxHeaderBytes, defaultServer.httpServer.MaxHeaderBytes)
			if tt.opt.BaseURL != "" && tt.opt.BaseURL != "/" {
				assert.NotSame(t, s.baseRouter, s.httpServer.Handler, "should have wrapped baseRouter")
			} else {
				assert.Same(t, s.baseRouter, s.httpServer.Handler, "should be baseRouter")
			}
			if useSSL(tt.opt) {
				require.NotNil(t, s.httpServer.TLSConfig, "missing SSL config")
				assert.NotEmpty(t, s.httpServer.TLSConfig.Certificates, "missing SSL config")
			} else if s.httpServer.TLSConfig != nil {
				assert.Empty(t, s.httpServer.TLSConfig.Certificates, "unexpectedly has SSL config")
			}
			*/
		})
	}
}

// Test_useSSL checks that useSSL reports false when neither SslKey nor
// SslKeyBody is set, and true when either of them is.
func Test_useSSL(t *testing.T) {
	type args struct {
		opt Options
	}
	// Case names are unique so each subtest can be selected by name.
	tests := []struct {
		name string
		args args
		want bool
	}{
		{
			name: "basic",
			args: args{opt: Options{
				SslCert:  "",
				SslKey:   "",
				ClientCA: "",
			}},
			want: false,
		},
		{
			name: "key",
			args: args{opt: Options{
				SslCert:  "",
				SslKey:   "test",
				ClientCA: "",
			}},
			want: true,
		},
		{
			name: "body",
			args: args{opt: Options{
				SslCert:    "",
				SslKey:     "",
				SslKeyBody: []byte(`test`),
				ClientCA:   "",
			}},
			want: true,
		},
		{
			name: "key with min tls version",
			args: args{opt: Options{
				SslCert:       "",
				SslKey:        "test",
				ClientCA:      "",
				MinTLSVersion: "tls1.2",
			}},
			want: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := useSSL(tt.args.opt); got != tt.want {
				t.Errorf("useSSL() = %v, want %v", got, tt.want)
			}
		})
	}
}