diff --git a/bdns/dns.go b/bdns/dns.go index ea91a5c43..b6f5c1eba 100644 --- a/bdns/dns.go +++ b/bdns/dns.go @@ -17,6 +17,7 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/letsencrypt/boulder/features" blog "github.com/letsencrypt/boulder/log" "github.com/letsencrypt/boulder/metrics" ) @@ -89,21 +90,31 @@ func New( log blog.Logger, tlsConfig *tls.Config, ) Client { - // Clone the default transport because it comes with various settings that we - // like, which are different from the zero value of an `http.Transport`. Then - // set it to force HTTP/2, because Unbound will reject non-HTTP/2 DoH - // requests. - transport := http.DefaultTransport.(*http.Transport).Clone() - transport.TLSClientConfig = tlsConfig - transport.ForceAttemptHTTP2 = true - - exchanger := &dohExchanger{ - clk: clk, - hc: http.Client{ - Timeout: readTimeout, - Transport: transport, - }, - userAgent: userAgent, + var exchanger exchanger + + if features.Get().DOH { + // Clone the default transport because it comes with various settings that we + // like, which are different from the zero value of an `http.Transport`. Then + // set it to force HTTP/2, because Unbound will reject non-HTTP/2 DoH + // requests. + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = tlsConfig + transport.ForceAttemptHTTP2 = true + + exchanger = &dohExchanger{ + clk: clk, + hc: http.Client{ + Timeout: readTimeout, + Transport: transport, + }, + userAgent: userAgent, + } + } else { + exchanger = &dns.Client{ + // Set timeout for underlying net.Conn + ReadTimeout: readTimeout, + Net: "udp", + } } queryTime := promauto.With(stats).NewHistogramVec( @@ -230,8 +241,14 @@ func (c *impl) exchangeOne(ctx context.Context, hostname string, qtype uint16) ( // Check if the error is a network timeout, rather than a local context // timeout. If it is, retry instead of giving up. - var netErr net.Error - isRetryable := ctx.Err() == nil && errors.As(err, &netErr) && netErr.Timeout() + var isRetryable bool + if features.Get().DOH { + var netErr net.Error + isRetryable = ctx.Err() == nil && errors.As(err, &netErr) && netErr.Timeout() + } else { + var opErr *net.OpError + isRetryable = ctx.Err() == nil && errors.As(err, &opErr) && opErr.Temporary() + } hasRetriesLeft := tries < c.maxTries if isRetryable && hasRetriesLeft { continue