diff --git a/lib/mongo/srv/result.rb b/lib/mongo/srv/result.rb index 318011914f..7c63cb9bac 100644 --- a/lib/mongo/srv/result.rb +++ b/lib/mongo/srv/result.rb @@ -110,17 +110,27 @@ def normalize_hostname(host) # A hostname's domain name consists of each of the '.' delineated # parts after the first. For example, the hostname 'foo.bar.baz' # has the domain name 'bar.baz'. + # + # If the hostname has less than three parts, its domain name is the hostname itself. # # @param [ String ] record_host The host of the SRV record. # # @raise [ Mongo::Error::MismatchedDomain ] If the record's domain name doesn't match that of # the hostname. def validate_same_origin!(record_host) - domain_name ||= query_hostname.split('.')[1..-1] - host_parts = record_host.split('.') + srv_host_domain = query_hostname.split('.') + srv_is_less_than_three_parts = srv_host_domain.length < 3 + unless srv_is_less_than_three_parts + srv_host_domain = srv_host_domain[1..-1] + end + record_host_parts = record_host.split('.') + + if (srv_is_less_than_three_parts && record_host_parts.length <= srv_host_domain.length) + raise Error::MismatchedDomain.new(MISMATCHED_DOMAINNAME % [record_host, srv_host_domain]) + end - unless (host_parts.size > domain_name.size) && (domain_name == host_parts[-domain_name.length..-1]) - raise Error::MismatchedDomain.new(MISMATCHED_DOMAINNAME % [record_host, domain_name]) + unless (record_host_parts.size > srv_host_domain.size) && (srv_host_domain == record_host_parts[-srv_host_domain.size..-1]) + raise Error::MismatchedDomain.new(MISMATCHED_DOMAINNAME % [record_host, srv_host_domain]) end end end diff --git a/lib/mongo/uri/srv_protocol.rb b/lib/mongo/uri/srv_protocol.rb index 4336d1c814..c0484db697 100644 --- a/lib/mongo/uri/srv_protocol.rb +++ b/lib/mongo/uri/srv_protocol.rb @@ -168,8 +168,7 @@ def parse!(remaining) # The hostname cannot include a port. # # The hostname must not begin with a dot, end with a dot, or have - # consecutive dots. The hostname must have a minimum of 3 total - # components (foo.bar.tld). + # consecutive dots. The hostname must have a minimum of 1 component # # Raises Error::InvalidURI if validation fails. def validate_srv_hostname(hostname) @@ -185,8 +184,8 @@ def validate_srv_hostname(hostname) if parts.any?(&:empty?) raise_invalid_error!("Hostname cannot have consecutive dots: #{hostname}") end - if parts.length < 3 - raise_invalid_error!("Hostname must have a minimum of 3 components (foo.bar.tld): #{hostname}") + if parts.length < 1 + raise_invalid_error!("Hostname cannot be empty: #{hostname}") end end diff --git a/spec/mongo/srv/result_spec.rb b/spec/mongo/srv/result_spec.rb index 227679a459..c56536f35e 100644 --- a/spec/mongo/srv/result_spec.rb +++ b/spec/mongo/srv/result_spec.rb @@ -23,6 +23,75 @@ expect(result.address_strs).to eq(['foo.bar.com:42']) end end + + example_srv_names = ['i-love-rb', 'i-love-rb.mongodb', 'i-love-ruby.mongodb.io']; + example_host_names = [ + 'rb-00.i-love-rb', + 'rb-00.i-love-rb.mongodb', + 'i-love-ruby-00.mongodb.io' + ]; + example_host_names_that_do_not_match_parent = [ + 'rb-00.i-love-rb-a-little', + 'rb-00.i-love-rb-a-little.mongodb', + 'i-love-ruby-00.evil-mongodb.io' + ]; + + (0..2).each do |i| + context "when srvName has #{i+1} part#{i != 0 ? 's' : ''}" do + let(:srv_name) { example_srv_names[i] } + let(:host_name) { example_host_names[i] } + let(:mismatched_host_name) { example_host_names_that_do_not_match_parent[i] } + context 'when address does not match parent domain' do + let(:record) do + double('record').tap do |record| + allow(record).to receive(:target).and_return(mismatched_host_name) + allow(record).to receive(:port).and_return(42) + allow(record).to receive(:ttl).and_return(1) + end + end + it 'raises MismatchedDomain error' do + expect { + result = described_class.new(srv_name) + result.add_record(record) + }.to raise_error(Mongo::Error::MismatchedDomain) + end + end + + context 'when address matches parent domain' do + let(:record) do + double('record').tap do |record| + allow(record).to receive(:target).and_return(host_name) + allow(record).to receive(:port).and_return(42) + allow(record).to receive(:ttl).and_return(1) + end + end + it 'adds the record' do + result = described_class.new(srv_name) + result.add_record(record) + + expect(result.address_strs).to eq([host_name + ':42']) + end + end + + if i < 2 + context 'when the address is less than 3 parts' do + let(:record) do + double('record').tap do |record| + allow(record).to receive(:target).and_return(srv_name) + allow(record).to receive(:port).and_return(42) + allow(record).to receive(:ttl).and_return(1) + end + end + it 'does not accept address if it does not contain an extra domain level' do + expect { + result = described_class.new(srv_name) + result.add_record(record) + }.to raise_error(Mongo::Error::MismatchedDomain) + end + end + end + end + end end describe '#normalize_hostname' do diff --git a/spec/mongo/uri/srv_protocol_spec.rb b/spec/mongo/uri/srv_protocol_spec.rb index a1b13a0195..be2790eeda 100644 --- a/spec/mongo/uri/srv_protocol_spec.rb +++ b/spec/mongo/uri/srv_protocol_spec.rb @@ -56,16 +56,6 @@ end end - context 'when the host in URI does not have {hostname}, {domainname} and {tld}' do - - let(:string) { "#{scheme}#{hosts}" } - let(:hosts) { '10gen.cc/' } - - it 'raises an error' do - expect { uri }.to raise_error(Mongo::Error::InvalidURI) - end - end - context 'when the {tld} is empty' do let(:string) { "#{scheme}#{hosts}" } @@ -220,24 +210,6 @@ expect { uri }.to raise_error(Mongo::Error::InvalidURI) end end - - context 'mongodb+srv://example.com?w=1' do - - let(:string) { "#{scheme}example.com?w=1" } - - it 'raises an error' do - expect { uri }.to raise_error(Mongo::Error::InvalidURI) - end - end - - context 'mongodb+srv://example.com/?w' do - - let(:string) { "#{scheme}example.com/?w" } - - it 'raises an error' do - expect { uri }.to raise_error(Mongo::Error::InvalidURI) - end - end end describe 'valid uris' do @@ -1302,8 +1274,8 @@ 'a' end - it 'raises an error' do - expect { validate }.to raise_error(Mongo::Error::InvalidURI) + it 'does not raise an error' do + expect { validate }.to_not raise_error end end @@ -1313,8 +1285,8 @@ 'a.b' end - it 'raises an error' do - expect { validate }.to raise_error(Mongo::Error::InvalidURI) + it 'validates the hostname' do + expect { validate }.not_to raise_error end end