Skip to content

Commit e85104a

Browse files
authored
fix(dns): implement staggered concurrent queries to reduce resolution latency (#145)
* fix(dns): implement staggered concurrent queries to reduce resolution latency

  Previously, the DNS resolver queried nameservers sequentially, waiting for the full timeout (5s) on each failure before trying the next server. With 3 internal nameservers, this caused 10-15s delays when initial servers were slow or unreachable (common in Docker environments).

  This change implements staggered concurrent queries:
  - Query the first nameserver immediately
  - After 150ms, if no response, fire the next nameserver in parallel
  - Continue staggering until all nameservers have been queried or a valid response arrives
  - Use the first successful response
  - On SERVFAIL/REFUSED/error, immediately trigger the next nameserver (no wait)

  Before: ns0 timeout (5s) -> ns1 timeout (5s) -> ns2 responds = ~10s+
  After: ns0 -> +150ms ns1 -> +150ms ns2 -> first response wins = <300ms

  Applied to both UDP (forwardQuery) and TCP (forwardQueryTCP) paths.

* lint the files
1 parent 545750e commit e85104a

File tree

3 files changed

+238
-98
lines changed

3 files changed

+238
-98
lines changed

dns/resolv.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,10 @@ func isLoopbackAddress(addr string) bool {
113113
// DefaultExternalDNSServers as fallbacks. Duplicates are removed.
114114
func GetSystemNameservers() []string {
115115
parsed, err := ParseResolvConf(DefaultResolveConfFilename)
116-
116+
117117
var result []string
118118
seen := make(map[string]bool)
119-
119+
120120
// Add system nameservers first (if any)
121121
if err == nil {
122122
for _, ns := range parsed.Nameservers {
@@ -126,15 +126,15 @@ func GetSystemNameservers() []string {
126126
}
127127
}
128128
}
129-
129+
130130
// Add default external nameservers as fallbacks
131131
for _, ns := range DefaultExternalDNSServers {
132132
if !seen[ns] {
133133
seen[ns] = true
134134
result = append(result, ns)
135135
}
136136
}
137-
137+
138138
return result
139139
}
140140

dns/server.go

Lines changed: 217 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ const (
1919
dnsPacketSize = 1232 // EDNS0-safe UDP payload size to avoid IPv6 fragmentation; accommodates most real-world queries.
2020
maxRecursionDepth = 10 // maximum CNAME chain depth
2121
maxConcurrentRequests = 1000 // maximum concurrent DNS request handlers
22+
staggerDelay = 150 // milliseconds to wait before querying next nameserver
2223
)
2324

2425
// dnsCacheEntry represents a cached DNS response with metadata
@@ -447,51 +448,108 @@ func (s *DNSResolver) bytesToMsg(data []byte) *dns.Msg {
447448
return msg
448449
}
449450

450-
// forwardQueryTCP forwards a DNS query and returns the response (for TCP)
451+
// forwardQueryTCP forwards a DNS query and returns the response (for TCP) using staggered queries
451452
func (s *DNSResolver) forwardQueryTCP(originalData []byte, domain string, queryType uint16, nameservers []string, cacheKey string) *dns.Msg {
452-
for _, ns := range nameservers {
453-
s.logger.Debug("sending TCP DNS query for %s (%s) to %s", domain, dns.TypeToString[queryType], ns)
454-
response, err := s.queryNameserver(s.ctx, originalData, ns)
455-
if err != nil {
456-
s.logger.Debug("Failed to query %s: %v", ns, err)
457-
continue
458-
}
459-
if response == nil {
460-
continue
461-
}
453+
if len(nameservers) == 0 {
454+
return nil
455+
}
462456

463-
s.debugDNSResponse(response, domain, queryType)
457+
ctx, cancel := context.WithTimeout(s.ctx, s.queryTimeout)
458+
defer cancel()
464459

465-
// If the server returned SERVFAIL or REFUSED, try the next nameserver
466-
// NXDOMAIN is authoritative and should not be retried
467-
if response.Rcode == dns.RcodeServerFailure || response.Rcode == dns.RcodeRefused {
468-
s.logger.Debug("Nameserver %s returned %s, trying next", ns, dns.RcodeToString[response.Rcode])
469-
continue
470-
}
460+
resultCh := make(chan nsResponse, len(nameservers))
461+
staggerTimer := time.NewTimer(staggerDelay * time.Millisecond)
462+
defer staggerTimer.Stop()
471463

472-
// Check if we need to recursively resolve CNAME
473-
if needsRecursion, cnameTarget := s.needsCNAMEResolution(response, queryType); needsRecursion {
474-
s.logger.Debug("Response contains CNAME without final record, recursively resolving %s", cnameTarget)
475-
cnameNameservers := s.selectNameservers(cnameTarget)
476-
finalResponse, err := s.resolveCNAMERecursively(s.ctx, cnameTarget, queryType, cnameNameservers, 1)
477-
if err != nil {
478-
s.logger.Error("Failed to resolve CNAME chain: %v", err)
479-
} else if finalResponse != nil {
480-
finalResponse.MsgHdr = response.MsgHdr
481-
finalResponse.Question = response.Question
482-
finalResponse.Answer = append(response.Answer, finalResponse.Answer...)
483-
response = finalResponse
484-
s.debugDNSResponse(response, domain, queryType)
464+
nextNS := 0
465+
466+
// Start first query immediately
467+
s.logger.Debug("sending TCP DNS query for %s (%s) to %s", domain, dns.TypeToString[queryType], nameservers[nextNS])
468+
go func(ns string) {
469+
response, err := s.queryNameserver(ctx, originalData, ns)
470+
resultCh <- nsResponse{response: response, ns: ns, err: err}
471+
}(nameservers[nextNS])
472+
nextNS++
473+
474+
responsesExpected := 1
475+
responsesReceived := 0
476+
failedResponses := 0
477+
478+
// Continue while we're waiting for responses OR we have more nameservers to try
479+
for responsesReceived < responsesExpected || nextNS < len(nameservers) {
480+
select {
481+
case <-ctx.Done():
482+
s.logger.Debug("TCP DNS query for %s (%s) timed out", domain, dns.TypeToString[queryType])
483+
return nil
484+
485+
case <-staggerTimer.C:
486+
if nextNS < len(nameservers) {
487+
s.logger.Debug("staggering: sending TCP DNS query for %s (%s) to %s", domain, dns.TypeToString[queryType], nameservers[nextNS])
488+
go func(ns string) {
489+
response, err := s.queryNameserver(ctx, originalData, ns)
490+
resultCh <- nsResponse{response: response, ns: ns, err: err}
491+
}(nameservers[nextNS])
492+
nextNS++
493+
responsesExpected++
494+
staggerTimer.Reset(staggerDelay * time.Millisecond)
485495
}
486-
}
487496

488-
// Cache the response
489-
responseBytes, err := response.Pack()
490-
if err == nil {
491-
s.cacheResponse(cacheKey, responseBytes, response, domain, queryType)
492-
}
497+
case result := <-resultCh:
498+
responsesReceived++
493499

494-
return response
500+
if result.err != nil {
501+
s.logger.Debug("Failed to query %s: %v", result.ns, result.err)
502+
failedResponses++
503+
if failedResponses == responsesReceived && nextNS < len(nameservers) {
504+
staggerTimer.Reset(0)
505+
}
506+
continue
507+
}
508+
509+
if result.response == nil {
510+
failedResponses++
511+
if failedResponses == responsesReceived && nextNS < len(nameservers) {
512+
staggerTimer.Reset(0)
513+
}
514+
continue
515+
}
516+
517+
response := result.response
518+
s.debugDNSResponse(response, domain, queryType)
519+
520+
if response.Rcode == dns.RcodeServerFailure || response.Rcode == dns.RcodeRefused {
521+
s.logger.Debug("Nameserver %s returned %s, waiting for other responses", result.ns, dns.RcodeToString[response.Rcode])
522+
failedResponses++
523+
if failedResponses == responsesReceived && nextNS < len(nameservers) {
524+
staggerTimer.Reset(0)
525+
}
526+
continue
527+
}
528+
529+
// Check if we need to recursively resolve CNAME
530+
if needsRecursion, cnameTarget := s.needsCNAMEResolution(response, queryType); needsRecursion {
531+
s.logger.Debug("Response contains CNAME without final record, recursively resolving %s", cnameTarget)
532+
cnameNameservers := s.selectNameservers(cnameTarget)
533+
finalResponse, err := s.resolveCNAMERecursively(ctx, cnameTarget, queryType, cnameNameservers, 1)
534+
if err != nil {
535+
s.logger.Error("Failed to resolve CNAME chain: %v", err)
536+
} else if finalResponse != nil {
537+
finalResponse.MsgHdr = response.MsgHdr
538+
finalResponse.Question = response.Question
539+
finalResponse.Answer = append(response.Answer, finalResponse.Answer...)
540+
response = finalResponse
541+
s.debugDNSResponse(response, domain, queryType)
542+
}
543+
}
544+
545+
// Cache the response
546+
responseBytes, err := response.Pack()
547+
if err == nil {
548+
s.cacheResponse(cacheKey, responseBytes, response, domain, queryType)
549+
}
550+
551+
return response
552+
}
495553
}
496554

497555
s.logger.Debug("TCP DNS returned error for %s (%s)", domain, dns.TypeToString[queryType])
@@ -589,66 +647,138 @@ func (s *DNSResolver) selectNameservers(target string) []string {
589647
return nameservers
590648
}
591649

592-
// forwardQuery forwards a DNS query to nameservers
650+
// nsResponse holds the result of one nameserver query; each query goroutine sends exactly one on the result channel.
651+
type nsResponse struct {
652+
response *dns.Msg // message returned by queryNameserver; consumers treat a nil message as a failed attempt
653+
ns string // nameserver the query was sent to (used in log messages)
654+
err error // error returned by queryNameserver, if any
655+
}
656+
657+
// forwardQuery forwards a DNS query to nameservers using staggered concurrent queries.
658+
// It starts with the first nameserver, then after staggerDelay fires off additional
659+
// nameservers if no response has been received yet. Uses the first successful response.
593660
func (s *DNSResolver) forwardQuery(conn *net.UDPConn, originalData []byte, domain string, queryType uint16, clientAddr *net.UDPAddr, nameservers []string, cacheKey string) {
594-
for _, ns := range nameservers {
595-
s.logger.Debug("sending DNS query for %s (%s) to %s", domain, dns.TypeToString[queryType], ns)
596-
response, err := s.queryNameserver(s.ctx, originalData, ns)
597-
if err != nil {
598-
s.logger.Debug("Failed to query %s: %v", ns, err)
599-
continue
600-
}
601-
if response == nil {
602-
continue
603-
}
661+
if len(nameservers) == 0 {
662+
s.sendErrorResponse(conn, originalData, clientAddr)
663+
return
664+
}
604665

605-
s.debugDNSResponse(response, domain, queryType)
666+
ctx, cancel := context.WithTimeout(s.ctx, s.queryTimeout)
667+
defer cancel()
606668

607-
// If the server returned SERVFAIL or REFUSED, try the next nameserver
608-
// NXDOMAIN is authoritative and should not be retried
609-
if response.Rcode == dns.RcodeServerFailure || response.Rcode == dns.RcodeRefused {
610-
s.logger.Debug("Nameserver %s returned %s, trying next", ns, dns.RcodeToString[response.Rcode])
611-
continue
612-
}
669+
resultCh := make(chan nsResponse, len(nameservers))
670+
staggerTimer := time.NewTimer(staggerDelay * time.Millisecond)
671+
defer staggerTimer.Stop()
613672

614-
// Check if we need to recursively resolve CNAME
615-
if needsRecursion, cnameTarget := s.needsCNAMEResolution(response, queryType); needsRecursion {
616-
s.logger.Debug("Response contains CNAME without final record, recursively resolving %s", cnameTarget)
673+
// Track which nameservers we've started querying
674+
nextNS := 0
617675

618-
// Determine nameservers for CNAME target
619-
cnameNameservers := s.selectNameservers(cnameTarget)
676+
// Start first query immediately
677+
s.logger.Debug("sending DNS query for %s (%s) to %s", domain, dns.TypeToString[queryType], nameservers[nextNS])
678+
go func(ns string) {
679+
response, err := s.queryNameserver(ctx, originalData, ns)
680+
resultCh <- nsResponse{response: response, ns: ns, err: err}
681+
}(nameservers[nextNS])
682+
nextNS++
620683

621-
// Recursively resolve the CNAME
622-
finalResponse, err := s.resolveCNAMERecursively(s.ctx, cnameTarget, queryType, cnameNameservers, 1)
623-
if err != nil {
624-
s.logger.Error("Failed to resolve CNAME chain: %v", err)
625-
// Fall back to sending the partial CNAME response
626-
} else if finalResponse != nil {
627-
// Merge CNAME records with final response
628-
finalResponse.MsgHdr = response.MsgHdr // Preserve original header fields and transaction ID
629-
finalResponse.Question = response.Question // Preserve original question
630-
finalResponse.Answer = append(response.Answer, finalResponse.Answer...)
631-
response = finalResponse
632-
s.debugDNSResponse(response, domain, queryType)
633-
}
634-
}
684+
responsesExpected := 1
685+
responsesReceived := 0
686+
failedResponses := 0
635687

636-
// Pack the response back to raw bytes
637-
responseBytes, err := response.Pack()
638-
if err != nil {
639-
s.logger.Error("failed to pack DNS response: %s", err)
688+
// Continue while we're waiting for responses OR we have more nameservers to try
689+
for responsesReceived < responsesExpected || nextNS < len(nameservers) {
690+
select {
691+
case <-ctx.Done():
692+
s.logger.Debug("DNS query for %s (%s) timed out", domain, dns.TypeToString[queryType])
693+
s.sendErrorResponse(conn, originalData, clientAddr)
640694
return
641-
}
642695

643-
// Cache the response using the minimum TTL from all answers
644-
s.cacheResponse(cacheKey, responseBytes, response, domain, queryType)
696+
case <-staggerTimer.C:
697+
// No response yet, fire off the next nameserver
698+
if nextNS < len(nameservers) {
699+
s.logger.Debug("staggering: sending DNS query for %s (%s) to %s", domain, dns.TypeToString[queryType], nameservers[nextNS])
700+
go func(ns string) {
701+
response, err := s.queryNameserver(ctx, originalData, ns)
702+
resultCh <- nsResponse{response: response, ns: ns, err: err}
703+
}(nameservers[nextNS])
704+
nextNS++
705+
responsesExpected++
706+
staggerTimer.Reset(staggerDelay * time.Millisecond)
707+
}
645708

646-
// Send response back to client
647-
_, err = conn.WriteToUDP(responseBytes, clientAddr)
648-
if err != nil {
649-
s.logger.Error("Failed to send DNS response: %v", err)
709+
case result := <-resultCh:
710+
responsesReceived++
711+
712+
if result.err != nil {
713+
s.logger.Debug("Failed to query %s: %v", result.ns, result.err)
714+
failedResponses++
715+
// If all responses so far have failed and we have more nameservers, trigger next one immediately
716+
if failedResponses == responsesReceived && nextNS < len(nameservers) {
717+
staggerTimer.Reset(0) // Fire next nameserver immediately
718+
}
719+
continue
720+
}
721+
722+
if result.response == nil {
723+
failedResponses++
724+
if failedResponses == responsesReceived && nextNS < len(nameservers) {
725+
staggerTimer.Reset(0)
726+
}
727+
continue
728+
}
729+
730+
response := result.response
731+
s.debugDNSResponse(response, domain, queryType)
732+
733+
// If the server returned SERVFAIL or REFUSED, wait for other responses or try next
734+
if response.Rcode == dns.RcodeServerFailure || response.Rcode == dns.RcodeRefused {
735+
s.logger.Debug("Nameserver %s returned %s, waiting for other responses", result.ns, dns.RcodeToString[response.Rcode])
736+
failedResponses++
737+
if failedResponses == responsesReceived && nextNS < len(nameservers) {
738+
staggerTimer.Reset(0)
739+
}
740+
continue
741+
}
742+
743+
// Check if we need to recursively resolve CNAME
744+
if needsRecursion, cnameTarget := s.needsCNAMEResolution(response, queryType); needsRecursion {
745+
s.logger.Debug("Response contains CNAME without final record, recursively resolving %s", cnameTarget)
746+
747+
// Determine nameservers for CNAME target
748+
cnameNameservers := s.selectNameservers(cnameTarget)
749+
750+
// Recursively resolve the CNAME
751+
finalResponse, err := s.resolveCNAMERecursively(ctx, cnameTarget, queryType, cnameNameservers, 1)
752+
if err != nil {
753+
s.logger.Error("Failed to resolve CNAME chain: %v", err)
754+
// Fall back to sending the partial CNAME response
755+
} else if finalResponse != nil {
756+
// Merge CNAME records with final response
757+
finalResponse.MsgHdr = response.MsgHdr // Preserve original header fields and transaction ID
758+
finalResponse.Question = response.Question // Preserve original question
759+
finalResponse.Answer = append(response.Answer, finalResponse.Answer...)
760+
response = finalResponse
761+
s.debugDNSResponse(response, domain, queryType)
762+
}
763+
}
764+
765+
// Pack the response back to raw bytes
766+
responseBytes, err := response.Pack()
767+
if err != nil {
768+
s.logger.Error("failed to pack DNS response: %s", err)
769+
continue
770+
}
771+
772+
// Cache the response using the minimum TTL from all answers
773+
s.cacheResponse(cacheKey, responseBytes, response, domain, queryType)
774+
775+
// Send response back to client
776+
_, err = conn.WriteToUDP(responseBytes, clientAddr)
777+
if err != nil {
778+
s.logger.Error("Failed to send DNS response: %v", err)
779+
}
780+
return
650781
}
651-
return
652782
}
653783

654784
s.logger.Debug("DNS returned error for %s (%s) to %s", domain, dns.TypeToString[queryType], clientAddr)

0 commit comments

Comments
 (0)