@@ -19,6 +19,7 @@ const (
1919 dnsPacketSize = 1232 // EDNS0-safe UDP payload size to avoid IPv6 fragmentation; accommodates most real-world queries.
2020 maxRecursionDepth = 10 // maximum CNAME chain depth
2121 maxConcurrentRequests = 1000 // maximum concurrent DNS request handlers
22+ staggerDelay = 150 // milliseconds to wait before querying next nameserver
2223)
2324
2425// dnsCacheEntry represents a cached DNS response with metadata
@@ -447,51 +448,108 @@ func (s *DNSResolver) bytesToMsg(data []byte) *dns.Msg {
447448 return msg
448449}
449450
450- // forwardQueryTCP forwards a DNS query and returns the response (for TCP)
451+ // forwardQueryTCP forwards a DNS query and returns the response (for TCP) using staggered queries
451452func (s * DNSResolver ) forwardQueryTCP (originalData []byte , domain string , queryType uint16 , nameservers []string , cacheKey string ) * dns.Msg {
452- for _ , ns := range nameservers {
453- s .logger .Debug ("sending TCP DNS query for %s (%s) to %s" , domain , dns .TypeToString [queryType ], ns )
454- response , err := s .queryNameserver (s .ctx , originalData , ns )
455- if err != nil {
456- s .logger .Debug ("Failed to query %s: %v" , ns , err )
457- continue
458- }
459- if response == nil {
460- continue
461- }
453+ if len (nameservers ) == 0 {
454+ return nil
455+ }
462456
463- s .debugDNSResponse (response , domain , queryType )
457+ ctx , cancel := context .WithTimeout (s .ctx , s .queryTimeout )
458+ defer cancel ()
464459
465- // If the server returned SERVFAIL or REFUSED, try the next nameserver
466- // NXDOMAIN is authoritative and should not be retried
467- if response .Rcode == dns .RcodeServerFailure || response .Rcode == dns .RcodeRefused {
468- s .logger .Debug ("Nameserver %s returned %s, trying next" , ns , dns .RcodeToString [response .Rcode ])
469- continue
470- }
460+ resultCh := make (chan nsResponse , len (nameservers ))
461+ staggerTimer := time .NewTimer (staggerDelay * time .Millisecond )
462+ defer staggerTimer .Stop ()
471463
472- // Check if we need to recursively resolve CNAME
473- if needsRecursion , cnameTarget := s .needsCNAMEResolution (response , queryType ); needsRecursion {
474- s .logger .Debug ("Response contains CNAME without final record, recursively resolving %s" , cnameTarget )
475- cnameNameservers := s .selectNameservers (cnameTarget )
476- finalResponse , err := s .resolveCNAMERecursively (s .ctx , cnameTarget , queryType , cnameNameservers , 1 )
477- if err != nil {
478- s .logger .Error ("Failed to resolve CNAME chain: %v" , err )
479- } else if finalResponse != nil {
480- finalResponse .MsgHdr = response .MsgHdr
481- finalResponse .Question = response .Question
482- finalResponse .Answer = append (response .Answer , finalResponse .Answer ... )
483- response = finalResponse
484- s .debugDNSResponse (response , domain , queryType )
464+ nextNS := 0
465+
466+ // Start first query immediately
467+ s .logger .Debug ("sending TCP DNS query for %s (%s) to %s" , domain , dns .TypeToString [queryType ], nameservers [nextNS ])
468+ go func (ns string ) {
469+ response , err := s .queryNameserver (ctx , originalData , ns )
470+ resultCh <- nsResponse {response : response , ns : ns , err : err }
471+ }(nameservers [nextNS ])
472+ nextNS ++
473+
474+ responsesExpected := 1
475+ responsesReceived := 0
476+ failedResponses := 0
477+
478+ // Continue while we're waiting for responses OR we have more nameservers to try
479+ for responsesReceived < responsesExpected || nextNS < len (nameservers ) {
480+ select {
481+ case <- ctx .Done ():
482+ s .logger .Debug ("TCP DNS query for %s (%s) timed out" , domain , dns .TypeToString [queryType ])
483+ return nil
484+
485+ case <- staggerTimer .C :
486+ if nextNS < len (nameservers ) {
487+ s .logger .Debug ("staggering: sending TCP DNS query for %s (%s) to %s" , domain , dns .TypeToString [queryType ], nameservers [nextNS ])
488+ go func (ns string ) {
489+ response , err := s .queryNameserver (ctx , originalData , ns )
490+ resultCh <- nsResponse {response : response , ns : ns , err : err }
491+ }(nameservers [nextNS ])
492+ nextNS ++
493+ responsesExpected ++
494+ staggerTimer .Reset (staggerDelay * time .Millisecond )
485495 }
486- }
487496
488- // Cache the response
489- responseBytes , err := response .Pack ()
490- if err == nil {
491- s .cacheResponse (cacheKey , responseBytes , response , domain , queryType )
492- }
497+ case result := <- resultCh :
498+ responsesReceived ++
493499
494- return response
500+ if result .err != nil {
501+ s .logger .Debug ("Failed to query %s: %v" , result .ns , result .err )
502+ failedResponses ++
503+ if failedResponses == responsesReceived && nextNS < len (nameservers ) {
504+ staggerTimer .Reset (0 )
505+ }
506+ continue
507+ }
508+
509+ if result .response == nil {
510+ failedResponses ++
511+ if failedResponses == responsesReceived && nextNS < len (nameservers ) {
512+ staggerTimer .Reset (0 )
513+ }
514+ continue
515+ }
516+
517+ response := result .response
518+ s .debugDNSResponse (response , domain , queryType )
519+
520+ if response .Rcode == dns .RcodeServerFailure || response .Rcode == dns .RcodeRefused {
521+ s .logger .Debug ("Nameserver %s returned %s, waiting for other responses" , result .ns , dns .RcodeToString [response .Rcode ])
522+ failedResponses ++
523+ if failedResponses == responsesReceived && nextNS < len (nameservers ) {
524+ staggerTimer .Reset (0 )
525+ }
526+ continue
527+ }
528+
529+ // Check if we need to recursively resolve CNAME
530+ if needsRecursion , cnameTarget := s .needsCNAMEResolution (response , queryType ); needsRecursion {
531+ s .logger .Debug ("Response contains CNAME without final record, recursively resolving %s" , cnameTarget )
532+ cnameNameservers := s .selectNameservers (cnameTarget )
533+ finalResponse , err := s .resolveCNAMERecursively (ctx , cnameTarget , queryType , cnameNameservers , 1 )
534+ if err != nil {
535+ s .logger .Error ("Failed to resolve CNAME chain: %v" , err )
536+ } else if finalResponse != nil {
537+ finalResponse .MsgHdr = response .MsgHdr
538+ finalResponse .Question = response .Question
539+ finalResponse .Answer = append (response .Answer , finalResponse .Answer ... )
540+ response = finalResponse
541+ s .debugDNSResponse (response , domain , queryType )
542+ }
543+ }
544+
545+ // Cache the response
546+ responseBytes , err := response .Pack ()
547+ if err == nil {
548+ s .cacheResponse (cacheKey , responseBytes , response , domain , queryType )
549+ }
550+
551+ return response
552+ }
495553 }
496554
497555 s .logger .Debug ("TCP DNS returned error for %s (%s)" , domain , dns .TypeToString [queryType ])
@@ -589,66 +647,138 @@ func (s *DNSResolver) selectNameservers(target string) []string {
589647 return nameservers
590648}
591649
592- // forwardQuery forwards a DNS query to nameservers
650+ // nsResponse holds the result of a nameserver query
651+ type nsResponse struct {
652+ response * dns.Msg
653+ ns string
654+ err error
655+ }
656+
657+ // forwardQuery forwards a DNS query to nameservers using staggered concurrent queries.
658+ // It starts with the first nameserver, then after staggerDelay fires off additional
659+ // nameservers if no response has been received yet. Uses the first successful response.
593660func (s * DNSResolver ) forwardQuery (conn * net.UDPConn , originalData []byte , domain string , queryType uint16 , clientAddr * net.UDPAddr , nameservers []string , cacheKey string ) {
594- for _ , ns := range nameservers {
595- s .logger .Debug ("sending DNS query for %s (%s) to %s" , domain , dns .TypeToString [queryType ], ns )
596- response , err := s .queryNameserver (s .ctx , originalData , ns )
597- if err != nil {
598- s .logger .Debug ("Failed to query %s: %v" , ns , err )
599- continue
600- }
601- if response == nil {
602- continue
603- }
661+ if len (nameservers ) == 0 {
662+ s .sendErrorResponse (conn , originalData , clientAddr )
663+ return
664+ }
604665
605- s .debugDNSResponse (response , domain , queryType )
666+ ctx , cancel := context .WithTimeout (s .ctx , s .queryTimeout )
667+ defer cancel ()
606668
607- // If the server returned SERVFAIL or REFUSED, try the next nameserver
608- // NXDOMAIN is authoritative and should not be retried
609- if response .Rcode == dns .RcodeServerFailure || response .Rcode == dns .RcodeRefused {
610- s .logger .Debug ("Nameserver %s returned %s, trying next" , ns , dns .RcodeToString [response .Rcode ])
611- continue
612- }
669+ resultCh := make (chan nsResponse , len (nameservers ))
670+ staggerTimer := time .NewTimer (staggerDelay * time .Millisecond )
671+ defer staggerTimer .Stop ()
613672
614- // Check if we need to recursively resolve CNAME
615- if needsRecursion , cnameTarget := s .needsCNAMEResolution (response , queryType ); needsRecursion {
616- s .logger .Debug ("Response contains CNAME without final record, recursively resolving %s" , cnameTarget )
673+ // Track which nameservers we've started querying
674+ nextNS := 0
617675
618- // Determine nameservers for CNAME target
619- cnameNameservers := s .selectNameservers (cnameTarget )
676+ // Start first query immediately
677+ s .logger .Debug ("sending DNS query for %s (%s) to %s" , domain , dns .TypeToString [queryType ], nameservers [nextNS ])
678+ go func (ns string ) {
679+ response , err := s .queryNameserver (ctx , originalData , ns )
680+ resultCh <- nsResponse {response : response , ns : ns , err : err }
681+ }(nameservers [nextNS ])
682+ nextNS ++
620683
621- // Recursively resolve the CNAME
622- finalResponse , err := s .resolveCNAMERecursively (s .ctx , cnameTarget , queryType , cnameNameservers , 1 )
623- if err != nil {
624- s .logger .Error ("Failed to resolve CNAME chain: %v" , err )
625- // Fall back to sending the partial CNAME response
626- } else if finalResponse != nil {
627- // Merge CNAME records with final response
628- finalResponse .MsgHdr = response .MsgHdr // Preserve original header fields and transaction ID
629- finalResponse .Question = response .Question // Preserve original question
630- finalResponse .Answer = append (response .Answer , finalResponse .Answer ... )
631- response = finalResponse
632- s .debugDNSResponse (response , domain , queryType )
633- }
634- }
684+ responsesExpected := 1
685+ responsesReceived := 0
686+ failedResponses := 0
635687
636- // Pack the response back to raw bytes
637- responseBytes , err := response .Pack ()
638- if err != nil {
639- s .logger .Error ("failed to pack DNS response: %s" , err )
688+ // Continue while we're waiting for responses OR we have more nameservers to try
689+ for responsesReceived < responsesExpected || nextNS < len (nameservers ) {
690+ select {
691+ case <- ctx .Done ():
692+ s .logger .Debug ("DNS query for %s (%s) timed out" , domain , dns .TypeToString [queryType ])
693+ s .sendErrorResponse (conn , originalData , clientAddr )
640694 return
641- }
642695
643- // Cache the response using the minimum TTL from all answers
644- s .cacheResponse (cacheKey , responseBytes , response , domain , queryType )
696+ case <- staggerTimer .C :
697+ // No response yet, fire off the next nameserver
698+ if nextNS < len (nameservers ) {
699+ s .logger .Debug ("staggering: sending DNS query for %s (%s) to %s" , domain , dns .TypeToString [queryType ], nameservers [nextNS ])
700+ go func (ns string ) {
701+ response , err := s .queryNameserver (ctx , originalData , ns )
702+ resultCh <- nsResponse {response : response , ns : ns , err : err }
703+ }(nameservers [nextNS ])
704+ nextNS ++
705+ responsesExpected ++
706+ staggerTimer .Reset (staggerDelay * time .Millisecond )
707+ }
645708
646- // Send response back to client
647- _ , err = conn .WriteToUDP (responseBytes , clientAddr )
648- if err != nil {
649- s .logger .Error ("Failed to send DNS response: %v" , err )
709+ case result := <- resultCh :
710+ responsesReceived ++
711+
712+ if result .err != nil {
713+ s .logger .Debug ("Failed to query %s: %v" , result .ns , result .err )
714+ failedResponses ++
715+ // If all responses so far have failed and we have more nameservers, trigger next one immediately
716+ if failedResponses == responsesReceived && nextNS < len (nameservers ) {
717+ staggerTimer .Reset (0 ) // Fire next nameserver immediately
718+ }
719+ continue
720+ }
721+
722+ if result .response == nil {
723+ failedResponses ++
724+ if failedResponses == responsesReceived && nextNS < len (nameservers ) {
725+ staggerTimer .Reset (0 )
726+ }
727+ continue
728+ }
729+
730+ response := result .response
731+ s .debugDNSResponse (response , domain , queryType )
732+
733+ // If the server returned SERVFAIL or REFUSED, wait for other responses or try next
734+ if response .Rcode == dns .RcodeServerFailure || response .Rcode == dns .RcodeRefused {
735+ s .logger .Debug ("Nameserver %s returned %s, waiting for other responses" , result .ns , dns .RcodeToString [response .Rcode ])
736+ failedResponses ++
737+ if failedResponses == responsesReceived && nextNS < len (nameservers ) {
738+ staggerTimer .Reset (0 )
739+ }
740+ continue
741+ }
742+
743+ // Check if we need to recursively resolve CNAME
744+ if needsRecursion , cnameTarget := s .needsCNAMEResolution (response , queryType ); needsRecursion {
745+ s .logger .Debug ("Response contains CNAME without final record, recursively resolving %s" , cnameTarget )
746+
747+ // Determine nameservers for CNAME target
748+ cnameNameservers := s .selectNameservers (cnameTarget )
749+
750+ // Recursively resolve the CNAME
751+ finalResponse , err := s .resolveCNAMERecursively (ctx , cnameTarget , queryType , cnameNameservers , 1 )
752+ if err != nil {
753+ s .logger .Error ("Failed to resolve CNAME chain: %v" , err )
754+ // Fall back to sending the partial CNAME response
755+ } else if finalResponse != nil {
756+ // Merge CNAME records with final response
757+ finalResponse .MsgHdr = response .MsgHdr // Preserve original header fields and transaction ID
758+ finalResponse .Question = response .Question // Preserve original question
759+ finalResponse .Answer = append (response .Answer , finalResponse .Answer ... )
760+ response = finalResponse
761+ s .debugDNSResponse (response , domain , queryType )
762+ }
763+ }
764+
765+ // Pack the response back to raw bytes
766+ responseBytes , err := response .Pack ()
767+ if err != nil {
768+ s .logger .Error ("failed to pack DNS response: %s" , err )
769+ continue
770+ }
771+
772+ // Cache the response using the minimum TTL from all answers
773+ s .cacheResponse (cacheKey , responseBytes , response , domain , queryType )
774+
775+ // Send response back to client
776+ _ , err = conn .WriteToUDP (responseBytes , clientAddr )
777+ if err != nil {
778+ s .logger .Error ("Failed to send DNS response: %v" , err )
779+ }
780+ return
650781 }
651- return
652782 }
653783
654784 s .logger .Debug ("DNS returned error for %s (%s) to %s" , domain , dns .TypeToString [queryType ], clientAddr )
0 commit comments