diff --git a/ca/tls_test.go b/ca/tls_test.go index 24b8ef01..dbcc6023 100644 --- a/ca/tls_test.go +++ b/ca/tls_test.go @@ -59,9 +59,13 @@ func generateOTT(subject string) string { return raw } -func startTestServer(tlsConfig *tls.Config, handler http.Handler) *httptest.Server { +func startTestServer(baseContext context.Context, tlsConfig *tls.Config, handler http.Handler) *httptest.Server { srv := httptest.NewUnstartedServer(handler) srv.TLS = tlsConfig + // Base context MUST be set before the start of the server + srv.Config.BaseContext = func(l net.Listener) context.Context { + return baseContext + } srv.StartTLS() // Force the use of GetCertificate on IPs srv.TLS.Certificates = nil @@ -78,11 +82,8 @@ func startCATestServer() *httptest.Server { panic(err) } // Use a httptest.Server instead - srv := startTestServer(ca.srv.TLSConfig, ca.srv.Handler) baseContext := buildContext(ca.auth, nil, nil, nil) - srv.Config.BaseContext = func(net.Listener) context.Context { - return baseContext - } + srv := startTestServer(baseContext, ca.srv.TLSConfig, ca.srv.Handler) return srv } @@ -153,7 +154,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) { if err != nil { t.Fatalf("Client.GetServerTLSConfig() error = %v", err) } - srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain)) + srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) defer srvMTLS.Close() // Create TLS server @@ -163,7 +164,7 @@ func TestClient_GetServerTLSConfig_http(t *testing.T) { if err != nil { t.Fatalf("Client.GetServerTLSConfig() error = %v", err) } - srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain)) + srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) defer srvTLS.Close() tests := []struct { @@ -258,7 +259,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { if err != nil { t.Fatalf("Client.GetServerTLSConfig() error = %v", err) } - srvMTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain)) + srvMTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) defer srvMTLS.Close() // Start TLS server @@ -268,7 +269,7 @@ func TestClient_GetServerTLSConfig_renew(t *testing.T) { if err != nil { t.Fatalf("Client.GetServerTLSConfig() error = %v", err) } - srvTLS := startTestServer(tlsConfig, serverHandler(t, clientDomain)) + srvTLS := startTestServer(context.Background(), tlsConfig, serverHandler(t, clientDomain)) defer srvTLS.Close() // Transport