diff --git a/src/lib.rs b/src/lib.rs index 8d784dac5..023403a34 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,6 +58,7 @@ use multiaddr::{Multiaddr, Protocol}; use transport::Endpoint; use types::ConnectionId; +use crate::transport::manager::DialFailureAddresses; pub use bandwidth::BandwidthSink; pub use error::Error; pub use peer_id::PeerId; @@ -199,6 +200,7 @@ impl Litep2p { config.fallback_names.clone(), config.codec, litep2p_config.keep_alive_timeout, + DialFailureAddresses::NotRequired, ); let executor = Arc::clone(&litep2p_config.executor); litep2p_config.executor.run(Box::pin(async move { @@ -219,6 +221,7 @@ impl Litep2p { config.fallback_names.clone(), config.codec, litep2p_config.keep_alive_timeout, + DialFailureAddresses::NotRequired, ); litep2p_config.executor.run(Box::pin(async move { RequestResponseProtocol::new(service, config).run().await @@ -234,6 +237,7 @@ impl Litep2p { Vec::new(), protocol.codec(), litep2p_config.keep_alive_timeout, + DialFailureAddresses::NotRequired, ); litep2p_config.executor.run(Box::pin(async move { let _ = protocol.run(service).await; @@ -253,6 +257,7 @@ impl Litep2p { Vec::new(), ping_config.codec, litep2p_config.keep_alive_timeout, + DialFailureAddresses::NotRequired, ); litep2p_config.executor.run(Box::pin(async move { Ping::new(service, ping_config).run().await @@ -276,6 +281,7 @@ impl Litep2p { fallback_names, kademlia_config.codec, litep2p_config.keep_alive_timeout, + DialFailureAddresses::Required, ); litep2p_config.executor.run(Box::pin(async move { let _ = Kademlia::new(service, kademlia_config).run().await; @@ -297,6 +303,7 @@ impl Litep2p { Vec::new(), identify_config.codec, litep2p_config.keep_alive_timeout, + DialFailureAddresses::NotRequired, ); identify_config.public = Some(litep2p_config.keypair.public().into()); @@ -317,6 +324,7 @@ impl Litep2p { Vec::new(), bitswap_config.codec, litep2p_config.keep_alive_timeout, + DialFailureAddresses::NotRequired, ); litep2p_config.executor.run(Box::pin(async move { Bitswap::new(service, bitswap_config).run().await diff --git a/src/protocol/protocol_set.rs b/src/protocol/protocol_set.rs index 3993e89a7..ff84e982e 100644 --- a/src/protocol/protocol_set.rs +++ b/src/protocol/protocol_set.rs @@ -83,7 +83,8 @@ pub enum InnerTransportEvent { /// Failed to dial peer. /// - /// This is reported to that protocol which initiated the connection. + /// This is reported to that protocol which initiated the connection. The addresses are only forwarded + /// if the protocol was registered with `DialFailureAddresses::Required`. DialFailure { /// Peer ID. peer: PeerId, @@ -437,6 +438,7 @@ mod tests { use super::*; use crate::mock::substream::MockSubstream; use std::collections::HashSet; + use crate::transport::manager::DialFailureAddresses; #[tokio::test] async fn fallback_is_provided() { @@ -456,6 +458,7 @@ mod tests { ProtocolName::from("/notif/1/fallback/1"), ProtocolName::from("/notif/1/fallback/2"), ], + dial_failure_mode: DialFailureAddresses::NotRequired, }, )]), ); @@ -503,6 +506,7 @@ mod tests { ProtocolName::from("/notif/1/fallback/1"), ProtocolName::from("/notif/1/fallback/2"), ], + dial_failure_mode: DialFailureAddresses::NotRequired, }, )]), ); @@ -550,6 +554,7 @@ mod tests { ProtocolName::from("/notif/1/fallback/1"), ProtocolName::from("/notif/1/fallback/2"), ], + dial_failure_mode: DialFailureAddresses::NotRequired, }, )]), ); diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs index bb8d63b89..b82fa097d 100644 --- a/src/transport/manager/mod.rs +++ b/src/transport/manager/mod.rs @@ -73,6 +73,19 @@ pub(crate) mod handle; /// Logging target for the file. const LOG_TARGET: &str = "litep2p::transport-manager"; +/// Determines if a protocol requires the list of failed addresses upon a dial failure. +/// +/// This is used during protocol registration with the `TransportManager` to specify +/// whether `InnerTransportEvent::DialFailure` events sent to this protocol should +/// include the specific multiaddresses that failed. +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub enum DialFailureAddresses { + /// The protocol needs the list of failed addresses. + Required, + /// The protocol does not need the list of failed addresses. + NotRequired, +} + /// The connection established result. #[derive(Debug, Clone, Copy, Eq, PartialEq)] enum ConnectionEstablishedResult { @@ -106,6 +119,9 @@ pub struct ProtocolContext { /// Fallback names for the protocol. pub fallback_names: Vec, + + /// Specifies if the protocol requires dial failure addresses. + pub dial_failure_mode: DialFailureAddresses, } impl ProtocolContext { @@ -114,11 +130,20 @@ impl ProtocolContext { codec: ProtocolCodec, tx: Sender, fallback_names: Vec, + dial_failure_mode: DialFailureAddresses, ) -> Self { Self { tx, codec, fallback_names, + dial_failure_mode, + } + } + + fn dial_failure_addresses(&self, addresses: &[Multiaddr]) -> Vec { + match self.dial_failure_mode { + DialFailureAddresses::Required => addresses.to_vec(), + DialFailureAddresses::NotRequired => Vec::new(), } } } @@ -402,6 +427,7 @@ impl TransportManager { fallback_names: Vec, codec: ProtocolCodec, keep_alive_timeout: Duration, + dial_failure_mode: DialFailureAddresses, ) -> TransportService { assert!(!self.protocol_names.contains(&protocol)); @@ -422,7 +448,7 @@ impl TransportManager { self.protocols.insert( protocol.clone(), - ProtocolContext::new(codec, sender, fallback_names.clone()), + ProtocolContext::new(codec, sender, fallback_names.clone(), dial_failure_mode), ); self.protocol_names.insert(protocol); self.protocol_names.extend(fallback_names); @@ -1186,10 +1212,10 @@ impl TransportManager { ?protocol, "dial failure, notify protocol", ); - match context.tx.try_send(InnerTransportEvent::DialFailure { - peer, - addresses: vec![address.clone()], - }) { + + let addresses = context.dial_failure_addresses(std::slice::from_ref(&address)); + + match context.tx.try_send(InnerTransportEvent::DialFailure { peer, addresses: addresses.clone() }) { Ok(()) => {} Err(_) => { tracing::trace!( @@ -1202,10 +1228,7 @@ impl TransportManager { ); let _ = context .tx - .send(InnerTransportEvent::DialFailure { - peer, - addresses: vec![address.clone()], - }) + .send(InnerTransportEvent::DialFailure { peer, addresses }) .await; } } @@ -1336,12 +1359,10 @@ impl TransportManager { .collect::>(); for (protocol, context) in &self.protocols { + let addresses = context.dial_failure_addresses(&addresses); let _ = match context .tx - .try_send(InnerTransportEvent::DialFailure { - peer, - addresses: addresses.clone(), - }) { + .try_send(InnerTransportEvent::DialFailure { peer, addresses: addresses.clone() }) { Ok(_) => Ok(()), Err(_) => { tracing::trace!( @@ -1354,10 +1375,7 @@ impl TransportManager { context .tx - .send(InnerTransportEvent::DialFailure { - peer, - addresses: addresses.clone(), - }) + .send(InnerTransportEvent::DialFailure { peer, addresses }) .await } }; @@ -1648,12 +1666,14 @@ mod tests { Vec::new(), ProtocolCodec::UnsignedVarint(None), KEEP_ALIVE_TIMEOUT, + DialFailureAddresses::NotRequired ); manager.register_protocol( ProtocolName::from("/notif/1"), Vec::new(), ProtocolCodec::UnsignedVarint(None), KEEP_ALIVE_TIMEOUT, + DialFailureAddresses::NotRequired ); } @@ -1668,6 +1688,7 @@ mod tests { Vec::new(), ProtocolCodec::UnsignedVarint(None), KEEP_ALIVE_TIMEOUT, + DialFailureAddresses::NotRequired, ); manager.register_protocol( ProtocolName::from("/notif/2"), @@ -1677,6 +1698,7 @@ mod tests { ], ProtocolCodec::UnsignedVarint(None), KEEP_ALIVE_TIMEOUT, + DialFailureAddresses::NotRequired, ); } @@ -1694,6 +1716,7 @@ mod tests { ], ProtocolCodec::UnsignedVarint(None), KEEP_ALIVE_TIMEOUT, + DialFailureAddresses::NotRequired, ); manager.register_protocol( ProtocolName::from("/notif/2"), @@ -1703,6 +1726,7 @@ mod tests { ], ProtocolCodec::UnsignedVarint(None), KEEP_ALIVE_TIMEOUT, + DialFailureAddresses::NotRequired, ); } diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 79c582c03..2e4249ca5 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -154,6 +154,7 @@ pub(crate) enum TransportEvent { }, /// Failed to dial remote peer. + /// The addresses are only forwarded to the protocol if it was registered with `DialFailureAddresses::Required`. DialFailure { /// Connection ID. connection_id: ConnectionId, @@ -166,6 +167,7 @@ pub(crate) enum TransportEvent { }, /// Open failure for an unnegotiated set of connections. + /// The addresses are only forwarded to the protocol if it was registered with `DialFailureAddresses::Required`. OpenFailure { /// Connection ID. connection_id: ConnectionId, diff --git a/src/transport/quic/mod.rs b/src/transport/quic/mod.rs index 2c1536317..d2ea05bc6 100644 --- a/src/transport/quic/mod.rs +++ b/src/transport/quic/mod.rs @@ -595,6 +595,7 @@ mod tests { }; use multihash::Multihash; use tokio::sync::mpsc::channel; + use crate::transport::manager::DialFailureAddresses; #[tokio::test] async fn test_quinn() { @@ -620,6 +621,7 @@ mod tests { tx: tx1, codec: ProtocolCodec::Identity(32), fallback_names: Vec::new(), + dial_failure_mode: DialFailureAddresses::NotRequired, }, )]), }; @@ -647,6 +649,7 @@ mod tests { tx: tx2, codec: ProtocolCodec::Identity(32), fallback_names: Vec::new(), + dial_failure_mode: DialFailureAddresses::NotRequired, }, )]), }; diff --git a/src/transport/tcp/mod.rs b/src/transport/tcp/mod.rs index 9d7524335..1052bfe72 100644 --- a/src/transport/tcp/mod.rs +++ b/src/transport/tcp/mod.rs @@ -710,6 +710,7 @@ mod tests { use multihash::Multihash; use std::sync::Arc; use tokio::sync::mpsc::channel; + use crate::transport::manager::DialFailureAddresses; #[tokio::test] async fn connect_and_accept_works() { @@ -736,6 +737,7 @@ mod tests { tx: tx1, codec: ProtocolCodec::Identity(32), fallback_names: Vec::new(), + dial_failure_mode: DialFailureAddresses::NotRequired, }, )]), }; @@ -767,7 +769,9 @@ mod tests { tx: tx2, codec: ProtocolCodec::Identity(32), fallback_names: Vec::new(), + dial_failure_mode: DialFailureAddresses::NotRequired, }, + )]), }; let transport_config2 = Config { @@ -830,6 +834,7 @@ mod tests { tx: tx1, codec: ProtocolCodec::Identity(32), fallback_names: Vec::new(), + dial_failure_mode: DialFailureAddresses::NotRequired, }, )]), }; @@ -861,6 +866,7 @@ mod tests { tx: tx2, codec: ProtocolCodec::Identity(32), fallback_names: Vec::new(), + dial_failure_mode: DialFailureAddresses::NotRequired, }, )]), }; @@ -919,6 +925,7 @@ mod tests { tx: tx1, codec: ProtocolCodec::Identity(32), fallback_names: Vec::new(), + dial_failure_mode: DialFailureAddresses::NotRequired, }, )]), }; @@ -957,6 +964,7 @@ mod tests { tx: tx2, codec: ProtocolCodec::Identity(32), fallback_names: Vec::new(), + dial_failure_mode: DialFailureAddresses::NotRequired, }, )]), };