Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ use multiaddr::{Multiaddr, Protocol};
use transport::Endpoint;
use types::ConnectionId;

use crate::transport::manager::DialFailureAddresses;
pub use bandwidth::BandwidthSink;
pub use error::Error;
pub use peer_id::PeerId;
Expand Down Expand Up @@ -199,6 +200,7 @@ impl Litep2p {
config.fallback_names.clone(),
config.codec,
litep2p_config.keep_alive_timeout,
DialFailureAddresses::NotRequired,
);
let executor = Arc::clone(&litep2p_config.executor);
litep2p_config.executor.run(Box::pin(async move {
Expand All @@ -219,6 +221,7 @@ impl Litep2p {
config.fallback_names.clone(),
config.codec,
litep2p_config.keep_alive_timeout,
DialFailureAddresses::NotRequired,
);
litep2p_config.executor.run(Box::pin(async move {
RequestResponseProtocol::new(service, config).run().await
Expand All @@ -234,6 +237,7 @@ impl Litep2p {
Vec::new(),
protocol.codec(),
litep2p_config.keep_alive_timeout,
DialFailureAddresses::NotRequired,
);
litep2p_config.executor.run(Box::pin(async move {
let _ = protocol.run(service).await;
Expand All @@ -253,6 +257,7 @@ impl Litep2p {
Vec::new(),
ping_config.codec,
litep2p_config.keep_alive_timeout,
DialFailureAddresses::NotRequired,
);
litep2p_config.executor.run(Box::pin(async move {
Ping::new(service, ping_config).run().await
Expand All @@ -276,6 +281,7 @@ impl Litep2p {
fallback_names,
kademlia_config.codec,
litep2p_config.keep_alive_timeout,
DialFailureAddresses::Required,
);
litep2p_config.executor.run(Box::pin(async move {
let _ = Kademlia::new(service, kademlia_config).run().await;
Expand All @@ -297,6 +303,7 @@ impl Litep2p {
Vec::new(),
identify_config.codec,
litep2p_config.keep_alive_timeout,
DialFailureAddresses::NotRequired,
);
identify_config.public = Some(litep2p_config.keypair.public().into());

Expand All @@ -317,6 +324,7 @@ impl Litep2p {
Vec::new(),
bitswap_config.codec,
litep2p_config.keep_alive_timeout,
DialFailureAddresses::NotRequired,
);
litep2p_config.executor.run(Box::pin(async move {
Bitswap::new(service, bitswap_config).run().await
Expand Down
7 changes: 6 additions & 1 deletion src/protocol/protocol_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ pub enum InnerTransportEvent {

/// Failed to dial peer.
///
/// This is reported to that protocol which initiated the connection.
/// This is reported to that protocol which initiated the connection. The addresses are only forwarded
/// if the protocol was registered with `DialFailureAddresses::Required`.
DialFailure {
/// Peer ID.
peer: PeerId,
Expand Down Expand Up @@ -437,6 +438,7 @@ mod tests {
use super::*;
use crate::mock::substream::MockSubstream;
use std::collections::HashSet;
use crate::transport::manager::DialFailureAddresses;

#[tokio::test]
async fn fallback_is_provided() {
Expand All @@ -456,6 +458,7 @@ mod tests {
ProtocolName::from("/notif/1/fallback/1"),
ProtocolName::from("/notif/1/fallback/2"),
],
dial_failure_mode: DialFailureAddresses::NotRequired,
},
)]),
);
Expand Down Expand Up @@ -503,6 +506,7 @@ mod tests {
ProtocolName::from("/notif/1/fallback/1"),
ProtocolName::from("/notif/1/fallback/2"),
],
dial_failure_mode: DialFailureAddresses::NotRequired,
},
)]),
);
Expand Down Expand Up @@ -550,6 +554,7 @@ mod tests {
ProtocolName::from("/notif/1/fallback/1"),
ProtocolName::from("/notif/1/fallback/2"),
],
dial_failure_mode: DialFailureAddresses::NotRequired,
},
)]),
);
Expand Down
58 changes: 41 additions & 17 deletions src/transport/manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,19 @@ pub(crate) mod handle;
/// Logging target for the file.
const LOG_TARGET: &str = "litep2p::transport-manager";

/// Determines if a protocol requires the list of failed addresses upon a dial failure.
///
/// This is used during protocol registration with the `TransportManager` to specify
/// whether `InnerTransportEvent::DialFailure` events sent to this protocol should
/// include the specific multiaddresses that failed.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum DialFailureAddresses {
/// The protocol needs the list of failed addresses.
Required,
/// The protocol does not need the list of failed addresses.
NotRequired,
}

/// The connection established result.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum ConnectionEstablishedResult {
Expand Down Expand Up @@ -106,6 +119,9 @@ pub struct ProtocolContext {

/// Fallback names for the protocol.
pub fallback_names: Vec<ProtocolName>,

/// Specifies if the protocol requires dial failure addresses.
pub dial_failure_mode: DialFailureAddresses,
}

impl ProtocolContext {
Expand All @@ -114,11 +130,20 @@ impl ProtocolContext {
codec: ProtocolCodec,
tx: Sender<InnerTransportEvent>,
fallback_names: Vec<ProtocolName>,
dial_failure_mode: DialFailureAddresses,
) -> Self {
Self {
tx,
codec,
fallback_names,
dial_failure_mode,
}
}

fn dial_failure_addresses(&self, addresses: &[Multiaddr]) -> Vec<Multiaddr> {
match self.dial_failure_mode {
DialFailureAddresses::Required => addresses.to_vec(),
DialFailureAddresses::NotRequired => Vec::new(),
}
}
}
Expand Down Expand Up @@ -402,6 +427,7 @@ impl TransportManager {
fallback_names: Vec<ProtocolName>,
codec: ProtocolCodec,
keep_alive_timeout: Duration,
dial_failure_mode: DialFailureAddresses,
) -> TransportService {
assert!(!self.protocol_names.contains(&protocol));

Expand All @@ -422,7 +448,7 @@ impl TransportManager {

self.protocols.insert(
protocol.clone(),
ProtocolContext::new(codec, sender, fallback_names.clone()),
ProtocolContext::new(codec, sender, fallback_names.clone(), dial_failure_mode),
);
self.protocol_names.insert(protocol);
self.protocol_names.extend(fallback_names);
Expand Down Expand Up @@ -1186,10 +1212,10 @@ impl TransportManager {
?protocol,
"dial failure, notify protocol",
);
match context.tx.try_send(InnerTransportEvent::DialFailure {
peer,
addresses: vec![address.clone()],
}) {

let addresses = context.dial_failure_addresses(std::slice::from_ref(&address));

match context.tx.try_send(InnerTransportEvent::DialFailure { peer, addresses: addresses.clone() }) {
Ok(()) => {}
Err(_) => {
tracing::trace!(
Expand All @@ -1202,10 +1228,7 @@ impl TransportManager {
);
let _ = context
.tx
.send(InnerTransportEvent::DialFailure {
peer,
addresses: vec![address.clone()],
})
.send(InnerTransportEvent::DialFailure { peer, addresses })
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dharjeezy Can you also add to the doc of InnerTransportEvent & TransportEvent that addresses is only forwarded if the protocol was registered with DialFailureAddresses::Required, please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is done now @dmitry-markin. can we merge it in now?

.await;
}
}
Expand Down Expand Up @@ -1336,12 +1359,10 @@ impl TransportManager {
.collect::<Vec<_>>();

for (protocol, context) in &self.protocols {
let addresses = context.dial_failure_addresses(&addresses);
let _ = match context
.tx
.try_send(InnerTransportEvent::DialFailure {
peer,
addresses: addresses.clone(),
}) {
.try_send(InnerTransportEvent::DialFailure { peer, addresses: addresses.clone() }) {
Ok(_) => Ok(()),
Err(_) => {
tracing::trace!(
Expand All @@ -1354,10 +1375,7 @@ impl TransportManager {

context
.tx
.send(InnerTransportEvent::DialFailure {
peer,
addresses: addresses.clone(),
})
.send(InnerTransportEvent::DialFailure { peer, addresses })
.await
}
};
Expand Down Expand Up @@ -1648,12 +1666,14 @@ mod tests {
Vec::new(),
ProtocolCodec::UnsignedVarint(None),
KEEP_ALIVE_TIMEOUT,
DialFailureAddresses::NotRequired
);
manager.register_protocol(
ProtocolName::from("/notif/1"),
Vec::new(),
ProtocolCodec::UnsignedVarint(None),
KEEP_ALIVE_TIMEOUT,
DialFailureAddresses::NotRequired
);
}

Expand All @@ -1668,6 +1688,7 @@ mod tests {
Vec::new(),
ProtocolCodec::UnsignedVarint(None),
KEEP_ALIVE_TIMEOUT,
DialFailureAddresses::NotRequired,
);
manager.register_protocol(
ProtocolName::from("/notif/2"),
Expand All @@ -1677,6 +1698,7 @@ mod tests {
],
ProtocolCodec::UnsignedVarint(None),
KEEP_ALIVE_TIMEOUT,
DialFailureAddresses::NotRequired,
);
}

Expand All @@ -1694,6 +1716,7 @@ mod tests {
],
ProtocolCodec::UnsignedVarint(None),
KEEP_ALIVE_TIMEOUT,
DialFailureAddresses::NotRequired,
);
manager.register_protocol(
ProtocolName::from("/notif/2"),
Expand All @@ -1703,6 +1726,7 @@ mod tests {
],
ProtocolCodec::UnsignedVarint(None),
KEEP_ALIVE_TIMEOUT,
DialFailureAddresses::NotRequired,
);
}

Expand Down
2 changes: 2 additions & 0 deletions src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ pub(crate) enum TransportEvent {
},

/// Failed to dial remote peer.
/// The addresses are only forwarded to the protocol if it was registered with `DialFailureAddresses::Required`.
DialFailure {
/// Connection ID.
connection_id: ConnectionId,
Expand All @@ -166,6 +167,7 @@ pub(crate) enum TransportEvent {
},

/// Open failure for an unnegotiated set of connections.
/// The addresses are only forwarded to the protocol if it was registered with `DialFailureAddresses::Required`.
OpenFailure {
/// Connection ID.
connection_id: ConnectionId,
Expand Down
3 changes: 3 additions & 0 deletions src/transport/quic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ mod tests {
};
use multihash::Multihash;
use tokio::sync::mpsc::channel;
use crate::transport::manager::DialFailureAddresses;

#[tokio::test]
async fn test_quinn() {
Expand All @@ -620,6 +621,7 @@ mod tests {
tx: tx1,
codec: ProtocolCodec::Identity(32),
fallback_names: Vec::new(),
dial_failure_mode: DialFailureAddresses::NotRequired,
},
)]),
};
Expand Down Expand Up @@ -647,6 +649,7 @@ mod tests {
tx: tx2,
codec: ProtocolCodec::Identity(32),
fallback_names: Vec::new(),
dial_failure_mode: DialFailureAddresses::NotRequired,
},
)]),
};
Expand Down
8 changes: 8 additions & 0 deletions src/transport/tcp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,7 @@ mod tests {
use multihash::Multihash;
use std::sync::Arc;
use tokio::sync::mpsc::channel;
use crate::transport::manager::DialFailureAddresses;

#[tokio::test]
async fn connect_and_accept_works() {
Expand All @@ -736,6 +737,7 @@ mod tests {
tx: tx1,
codec: ProtocolCodec::Identity(32),
fallback_names: Vec::new(),
dial_failure_mode: DialFailureAddresses::NotRequired,
},
)]),
};
Expand Down Expand Up @@ -767,7 +769,9 @@ mod tests {
tx: tx2,
codec: ProtocolCodec::Identity(32),
fallback_names: Vec::new(),
dial_failure_mode: DialFailureAddresses::NotRequired,
},

)]),
};
let transport_config2 = Config {
Expand Down Expand Up @@ -830,6 +834,7 @@ mod tests {
tx: tx1,
codec: ProtocolCodec::Identity(32),
fallback_names: Vec::new(),
dial_failure_mode: DialFailureAddresses::NotRequired,
},
)]),
};
Expand Down Expand Up @@ -861,6 +866,7 @@ mod tests {
tx: tx2,
codec: ProtocolCodec::Identity(32),
fallback_names: Vec::new(),
dial_failure_mode: DialFailureAddresses::NotRequired,
},
)]),
};
Expand Down Expand Up @@ -919,6 +925,7 @@ mod tests {
tx: tx1,
codec: ProtocolCodec::Identity(32),
fallback_names: Vec::new(),
dial_failure_mode: DialFailureAddresses::NotRequired,
},
)]),
};
Expand Down Expand Up @@ -957,6 +964,7 @@ mod tests {
tx: tx2,
codec: ProtocolCodec::Identity(32),
fallback_names: Vec::new(),
dial_failure_mode: DialFailureAddresses::NotRequired,
},
)]),
};
Expand Down