From d739aab345d0b21ad06d52f3063f0d671774c55f Mon Sep 17 00:00:00 2001 From: Mariano Cano Date: Thu, 17 Aug 2023 12:56:26 -0700 Subject: [PATCH] Define BaseContext before starting the server in tests If the http.Server BaseContext is not define before the start of the server, it might not be properly set depending on the goroutine scheduler. This was causing random errors on CI. --- ca/tls_test.go | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/ca/tls_test.go b/ca/tls_test.go index 24b8ef01..dbcc6023 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -59,9 +59,13 @@ func generateOTT(subject string) string { return raw } -func startTestServer(tlsConfig *tls.Config, handler http.Handler) *httptest.Server { +func startTestServer(baseContext context.Context, tlsConfig *tls.Config, handler http.Handler) *httptest.Server { srv := httptest.NewUnstartedServer(handler) srv.TLS = tlsConfig + // Base context MUST be set before the start of the server + srv.Config.BaseContext = func(l net.Listener) context.Context { + return baseContext + } srv.StartTLS() // Force the use of GetCertificate on IPs srv.TLS.Certificates = nil @@ -78,11 +82,8 @@ func startCATestServer() *httptest.Server { panic(err) } // Use a httptest.Server instead - srv := startTestServer(ca.srv.TLSConfig, ca.srv.Handler) baseContext := buildContext(ca.auth, nil, nil, nil) - srv.Config.BaseContext = func(net.Listener) context.Context { - return baseContext - } + srv := startTestServer(baseContext, ca.srv.TLSConfig, ca.srv.Handler) return srv } @@ -153,7 +154,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) { if err != nil { t.Fatalf("Client.GetServerTLSConfig() error = %v", err) } - srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain)) + srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) defer srvMTLS.Close() // Create TLS server @@ -163,7 +164,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) { if err != nil { t.Fatalf("Client.GetServerTLSConfig() error = %v", err) } - srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain)) + srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) defer srvTLS.Close() tests := []struct { @@ -258,7 +259,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { if err != nil { t.Fatalf("Client.GetServerTLSConfig() error = %v", err) } - srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain)) + srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) defer srvMTLS.Close() // Start TLS server @@ -268,7 +269,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { if err != nil { t.Fatalf("Client.GetServerTLSConfig() error = %v", err) } - srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain)) + srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) defer srvTLS.Close() // Transport