diff --git a/.gitignore b/.gitignore index c85eb61d..41be2774 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ .local .vscode + +dist \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 0eaea59c..d035fa57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -902,9 +902,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" [[package]] name = "bzip2" @@ -2813,6 +2813,7 @@ dependencies = [ "arc-swap", "bincode", "borsh 1.5.7", + "bytes", "clap", "crossbeam-channel", "dashmap", @@ -2823,6 +2824,7 @@ dependencies = [ "lazy_static", "libc", "log", + "mio", "prometheus", "prost 0.13.5", "prost-types 0.13.5", @@ -2935,9 +2937,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.176" +version = "0.2.180" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58f929b4d672ea937a23a1ab494143d968337a5f47e56d0815df1e0890ddf174" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" [[package]] name = "libloading" @@ -3201,13 +3203,14 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.4" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", + "log", "wasi 0.11.1+wasi-snapshot-preview1", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index b72a26d7..ca86daa3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,10 @@ members = ["examples", "jito_protos", "proxy"] resolver = "2" +[profile.debug-release] +inherits = "release" +debug = true + [workspace.package] version = "0.2.12-triton" description = "Fast path to receive shreds from Jito, forwarding to local consumers. See https://docs.jito.wtf/lowlatencytxnfeed/ for details." @@ -14,6 +18,7 @@ ahash = "0.8" arc-swap = "1.6" bincode = "1.3.3" borsh = "1.5.3" +bytes = "1.11.0" clap = { version = "4", features = ["derive", "env"] } crossbeam-channel = "0.5.8" dashmap = "5" @@ -24,6 +29,7 @@ jito-protos = { path = "jito_protos" } lazy_static = "1.4.0" libc = "0.2" log = "0.4" +mio = "1.1.1" prost = "0.13" prost-types = "0.13" prometheus = "0.14.0" diff --git a/data-sample.txt b/data-sample.txt new file mode 100644 index 00000000..568c811a --- /dev/null +++ b/data-sample.txt @@ -0,0 +1,75 @@ +set 1: + +shredstream_recv_interval_usec_bucket{le="1"} 41 + +shredstream_recv_interval_usec_bucket{le="5"} 63 + +shredstream_recv_interval_usec_bucket{le="10"} 125 + +shredstream_recv_interval_usec_bucket{le="25"} 47676 + +shredstream_recv_interval_usec_bucket{le="50"} 104430 + +shredstream_recv_interval_usec_bucket{le="100"} 162673 + +shredstream_recv_interval_usec_bucket{le="200"} 190777 + +shredstream_recv_interval_usec_bucket{le="500"} 205050 + +shredstream_recv_interval_usec_bucket{le="1000"} 210046 + +shredstream_recv_interval_usec_bucket{le="2000"} 212204 + +shredstream_recv_interval_usec_bucket{le="+Inf"} 214080 + + + +set 2: + +shredstream_recv_interval_usec_bucket{le="1"} 0 + +shredstream_recv_interval_usec_bucket{le="5"} 0 + +shredstream_recv_interval_usec_bucket{le="10"} 22 + +shredstream_recv_interval_usec_bucket{le="25"} 864700 + +shredstream_recv_interval_usec_bucket{le="50"} 1059516 + +shredstream_recv_interval_usec_bucket{le="100"} 1334130 + +shredstream_recv_interval_usec_bucket{le="200"} 1473381 + +shredstream_recv_interval_usec_bucket{le="500"} 1545124 + +shredstream_recv_interval_usec_bucket{le="1000"} 1569639 + +shredstream_recv_interval_usec_bucket{le="2000"} 1580383 + +shredstream_recv_interval_usec_bucket{le="+Inf"} 1589948 + + + +set 3 : + +shredstream_recv_interval_usec_bucket{le="1"} 0 + +shredstream_recv_interval_usec_bucket{le="5"} 0 + +shredstream_recv_interval_usec_bucket{le="10"} 2 + +shredstream_recv_interval_usec_bucket{le="25"} 129306 + +shredstream_recv_interval_usec_bucket{le="50"} 159982 + +shredstream_recv_interval_usec_bucket{le="100"} 202752 + +shredstream_recv_interval_usec_bucket{le="200"} 225469 + +shredstream_recv_interval_usec_bucket{le="500"} 238215 + +shredstream_recv_interval_usec_bucket{le="1000"} 242741 + +shredstream_recv_interval_usec_bucket{le="2000"} 244727 + +shredstream_recv_interval_usec_bucket{le="+Inf"} 246249 \ No newline at end of file diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 163260bf..0de71da2 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -6,11 +6,22 @@ authors = { workspace = true } homepage = { workspace = true } edition = { workspace = true } +[[bin]] +name = "triton-shredproxy" +path = "src/main2.rs" + +[[bin]] +name = "jito-shredstream-proxy" +path = "src/main.rs" + + + [dependencies] ahash = { workspace = true } arc-swap = { workspace = true } bincode = { workspace = true } borsh = { workspace = true } +bytes = { workspace = true } clap = { workspace = true } crossbeam-channel = { workspace = true } dashmap = { workspace = true } @@ -21,6 +32,7 @@ jito-protos = { workspace = true } lazy_static = { workspace = true } log = { workspace = true } libc = { workspace = true } +mio = { workspace = true } prometheus = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } diff --git a/proxy/src/forwarder.rs b/proxy/src/forwarder.rs index 80275d57..6564f283 100644 --- a/proxy/src/forwarder.rs +++ b/proxy/src/forwarder.rs @@ -14,8 +14,8 @@ use crossbeam_channel::{Receiver, RecvError}; use dashmap::DashMap; use itertools::Itertools; use jito_protos::shredstream::{Entry as PbEntry, TraceShred}; -use log::{debug, error, info, warn}; use libc; +use log::{debug, error, info, warn}; use prost::Message; use socket2::{Domain, Protocol, Socket, Type}; use solana_client::client_error::reqwest; @@ -35,15 +35,13 @@ use solana_streamer::{ use tokio::sync::broadcast::Sender; use crate::{ - ShredstreamProxyError, deshred::{self, ComparableShred, ShredsStateTracker}, prom::{ - observe_dedup_time, observe_send_packet_count, observe_send_duration, - observe_recv_interval, observe_recv_packet_count, - inc_packets_received, inc_packets_deduped, inc_packets_forwarded, - inc_packets_forward_failed, inc_packets_by_source, + inc_packets_by_source, inc_packets_deduped, inc_packets_forward_failed, + inc_packets_forwarded, inc_packets_received, observe_dedup_time, observe_recv_interval, + observe_recv_packet_count, observe_send_duration, observe_send_packet_count, }, - resolve_hostname_port, + resolve_hostname_port, ShredstreamProxyError, }; // values copied from https://github.com/solana-labs/solana/blob/33bde55bbdde13003acf45bb6afe6db4ab599ae4/core/src/sigverify_shreds.rs#L20 @@ -141,7 +139,7 @@ pub fn start_forwarder_threads( thread_hdls.push(hdl); }; - let mut ret = sockets + let mut ret = sockets .into_iter() .chain(maybe_multicast_socket.unwrap_or_default()) .enumerate() @@ -167,13 +165,11 @@ pub fn start_forwarder_threads( let reconstruct_tx = reconstruct_tx.clone(); let exit = exit.clone(); - let send_thread = Builder::new() .name(format!("ssPxyTx_{thread_id}")) .spawn(move || { - let dont_send_to_origin = move |origin: IpAddr, dest: SocketAddr| { - origin != dest.ip() - }; + let dont_send_to_origin = + move |origin: IpAddr, dest: SocketAddr| origin != dest.ip(); let send_socket = { let ipv6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); match try_create_ipv6_socket(ipv6_addr) { @@ -184,7 +180,8 @@ pub fn start_forwarder_threads( Err(e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => { // This error (code 97 on Linux) means IPv6 is not supported. warn!("IPv6 not available. Falling back to IPv4-only for sending."); - let ipv4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); + let ipv4_addr = + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); UdpSocket::bind(ipv4_addr) .expect("Failed to bind to IPv4 socket after IPv6 failed") } @@ -249,20 +246,20 @@ pub fn start_forwarder_threads( if let Some((multicast_origin, multicast_socket)) = maybe_triton_multicast_socket { start_multicast_forwarder_thread( - multicast_origin, - multicast_socket, - recycler, - reconstruct_tx, - unioned_dest_sockets, - deduper, - forward_stats, - metrics, - debug_trace_shred, - should_reconstruct_shreds, - use_discovery_service, - shutdown_receiver, - exit, - &mut ret + multicast_origin, + multicast_socket, + recycler, + reconstruct_tx, + unioned_dest_sockets, + deduper, + forward_stats, + metrics, + debug_trace_shred, + should_reconstruct_shreds, + use_discovery_service, + shutdown_receiver, + exit, + &mut ret, ); } ret @@ -270,8 +267,7 @@ pub fn start_forwarder_threads( /// /// Try to create an IPv6 UDP socket bound to the given address. -/// -fn try_create_ipv6_socket(addr: SocketAddr) -> Result { +pub fn try_create_ipv6_socket(addr: SocketAddr) -> Result { let ipv6_socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; ipv6_socket.set_multicast_hops_v6(IP_MULTICAST_TTL)?; ipv6_socket.bind(&addr.into())?; @@ -323,21 +319,19 @@ pub fn start_multicast_forwarder_thread( let reconstruct_tx = reconstruct_tx.clone(); let exit = exit.clone(); - - let send_thread = Builder::new() .name(format!("ssPxyTxMulticast_{thread_id}")) .spawn(move || { - let dont_send_to_mc_origin = move |_origin, dest: SocketAddr| { - dest.ip() != multicast_origin - }; + let dont_send_to_mc_origin = + move |_origin, dest: SocketAddr| dest.ip() != multicast_origin; let send_socket = { let ipv6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); match try_create_ipv6_socket(ipv6_addr) { Ok(socket) => { info!("Successfully bound send socket to IPv6 dual-stack address."); - socket.set_multicast_loop_v6(false) - .expect("Failed to disable IPv6 multicast loopback"); + socket + .set_multicast_loop_v6(false) + .expect("Failed to disable IPv6 multicast loopback"); socket } Err(e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => { @@ -346,8 +340,11 @@ pub fn start_multicast_forwarder_thread( let ipv4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); let socket = UdpSocket::bind(ipv4_addr) .expect("Failed to bind to IPv4 socket after IPv6 failed"); - socket.set_multicast_ttl_v4(IP_MULTICAST_TTL).expect("IP_MULTICAST_TTL_V4"); - socket.set_multicast_loop_v4(false) + socket + .set_multicast_ttl_v4(IP_MULTICAST_TTL) + .expect("IP_MULTICAST_TTL_V4"); + socket + .set_multicast_loop_v4(false) .expect("Failed to disable IPv4 multicast loopback"); socket } @@ -424,16 +421,14 @@ fn recv_from_channel_and_send_multiple_dest( reconstruct_tx: &crossbeam_channel::Sender, debug_trace_shred: bool, metrics: &ShredMetrics, -) -> Result<(), ShredstreamProxyError> +) -> Result<(), ShredstreamProxyError> where F: Fn(IpAddr, SocketAddr) -> bool, { let packet_batch = maybe_packet_batch.map_err(ShredstreamProxyError::RecvError)?; let trace_shred_received_time = SystemTime::now(); let batch_len = packet_batch.len() as u64; - metrics - .received - .fetch_add(batch_len, Ordering::Relaxed); + metrics.received.fetch_add(batch_len, Ordering::Relaxed); inc_packets_received(batch_len); observe_recv_packet_count(batch_len as f64); debug!( @@ -454,7 +449,9 @@ where &mut packet_batch_vec, ); let t_dedup_usecs = t.elapsed().as_micros() as u64; - metrics.dedup_time_spent.fetch_add(t_dedup_usecs, Ordering::Relaxed); + metrics + .dedup_time_spent + .fetch_add(t_dedup_usecs, Ordering::Relaxed); observe_dedup_time(t_dedup_usecs as f64); inc_packets_deduped(num_deduped); @@ -470,10 +467,12 @@ where *discarded += is_discarded as u64; *not_discarded += (!is_discarded) as u64; }) - .or_insert_with(|| { - (is_discarded as u64, (!is_discarded) as u64) - }); - let status = if is_discarded { "discarded" } else { "forwarded" }; + .or_insert_with(|| (is_discarded as u64, (!is_discarded) as u64)); + let status = if is_discarded { + "discarded" + } else { + "forwarded" + }; inc_packets_by_source(&addr.to_string(), status, 1); }); }); @@ -502,10 +501,14 @@ where .fetch_add(packets_with_dest.len() as u64, Ordering::Relaxed); metrics.send_batch_count.fetch_add(1, Ordering::Relaxed); const MAX_IOV: usize = libc::UIO_MAXIOV as usize; - let max_iov_count = packets_with_dest.len() / MAX_IOV; + let max_iov_count = packets_with_dest.len() / MAX_IOV; let unsaturated_iov_count = packets_with_dest.len() % MAX_IOV; - metrics.saturated_iov_count.fetch_add(max_iov_count as u64, Ordering::Relaxed); - metrics.unsaturated_iov_count.fetch_add(unsaturated_iov_count as u64, Ordering::Relaxed); + metrics + .saturated_iov_count + .fetch_add(max_iov_count as u64, Ordering::Relaxed); + metrics + .unsaturated_iov_count + .fetch_add(unsaturated_iov_count as u64, Ordering::Relaxed); observe_send_packet_count(packets_with_dest.len() as f64); match batch_send(send_socket, &packets_with_dest) { Ok(_) => { @@ -531,7 +534,9 @@ where } } let t_send_usecs = t.elapsed().as_micros() as u64; - metrics.batch_send_time_spent.fetch_add(t_send_usecs, Ordering::Relaxed); + metrics + .batch_send_time_spent + .fetch_add(t_send_usecs, Ordering::Relaxed); observe_send_duration(t_send_usecs as f64); }); @@ -789,7 +794,11 @@ impl ShredMetrics { datapoint_info!( "shredstream_proxy-sendmmsg_iov_metrics", - ("max_iov_count", self.saturated_iov_count.load(Ordering::Relaxed), i64), + ( + "max_iov_count", + self.saturated_iov_count.load(Ordering::Relaxed), + i64 + ), ( "unsaturated_iov_count", self.unsaturated_iov_count.load(Ordering::Relaxed), @@ -798,12 +807,16 @@ impl ShredMetrics { ); datapoint_info!( - "shredstream_proxy-batch_send_metrics", + "shredstream_proxy-batch_send_metrics", ( - "send_batch_size_sum", self.send_batch_size_sum.load(Ordering::Relaxed), i64 + "send_batch_size_sum", + self.send_batch_size_sum.load(Ordering::Relaxed), + i64 ), ( - "send_batch_count", self.send_batch_count.load(Ordering::Relaxed), i64 + "send_batch_count", + self.send_batch_count.load(Ordering::Relaxed), + i64 ) ); @@ -821,7 +834,6 @@ impl ShredMetrics { ), ); - if self.enabled_grpc_service { datapoint_info!( "shredstream_proxy-service_metrics", diff --git a/proxy/src/main.rs b/proxy/src/main.rs index d04e3338..487db536 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -28,9 +28,13 @@ mod deshred; pub mod forwarder; mod heartbeat; mod multicast_config; +mod triton_multicast_config; mod server; mod token_authenticator; mod prom; +mod recv_mmsg; +mod mem; +mod triton_forwarder; #[derive(Clone, Debug, Parser)] #[clap(author, version, about, long_about = None)] diff --git a/proxy/src/main2.rs b/proxy/src/main2.rs new file mode 100644 index 00000000..7b21a122 --- /dev/null +++ b/proxy/src/main2.rs @@ -0,0 +1,491 @@ +use std::{ + collections::HashMap, io::{self, Error, ErrorKind}, net::{IpAddr, Ipv4Addr, SocketAddr, ToSocketAddrs}, num::NonZeroUsize, panic, path::{Path, PathBuf}, str::FromStr, sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, thread::{self, sleep, spawn, JoinHandle}, time::Duration +}; + +use arc_swap::ArcSwap; +use clap::{arg, Parser}; +use crossbeam_channel::{Receiver, RecvError, Sender}; +use log::*; +use signal_hook::consts::{SIGINT, SIGTERM}; +use solana_client::client_error::{reqwest, ClientError}; +use solana_ledger::shred::Shred; +use solana_metrics::set_host_id; +use solana_sdk::{clock::Slot, signature::read_keypair_file}; +use solana_streamer::streamer::StreamerReceiveStats; +use thiserror::Error; +use tokio::runtime::Runtime; +use tonic::Status; + +use crate::{ + forwarder::ShredMetrics, recv_mmsg::FECSetRoutingStrategy, token_authenticator::BlockEngineConnectionError, triton_forwarder::PktRecvTileMemConfig, triton_multicast_config::{TritonMulticastConfig, TritonMulticastConfigV4, TritonMulticastConfigV6, create_multicast_socket_on_device, create_multicast_sockets_triton} +}; +pub mod deshred; +pub mod forwarder; +pub mod heartbeat; +pub mod triton_multicast_config; +pub mod server; +pub mod token_authenticator; +pub mod prom; +pub mod recv_mmsg; +pub mod mem; +pub mod triton_forwarder; + +use triton_forwarder::{PktRecvMemSizing}; + +#[derive(Clone, Debug, Parser)] +#[clap(author, version, about, long_about = None)] +// https://docs.rs/clap/latest/clap/_derive/_cookbook/git_derive/index.html +struct Args { + #[command(subcommand)] + shredstream_args: ProxySubcommands, +} + +#[derive(Clone, Debug, clap::Subcommand)] +enum ProxySubcommands { + /// Requests shreds from Jito and sends to all destinations. + Shredstream(ShredstreamArgs), + + /// Does not request shreds from Jito. Sends anything received on `src-bind-addr`:`src-bind-port` to all destinations. + ForwardOnly(CommonArgs), +} + +#[derive(clap::Args, Clone, Debug)] +struct ShredstreamArgs { + /// Address for Jito Block Engine. + /// See https://jito-labs.gitbook.io/mev/searcher-resources/block-engine#connection-details + #[arg(long, env)] + block_engine_url: String, + + /// Manual override for auth service address. For internal use. + #[arg(long, env)] + auth_url: Option, + + /// Path to keypair file used to authenticate with the backend. + #[arg(long, env)] + auth_keypair: PathBuf, + + /// Desired regions to receive heartbeats from. + /// Receives `n` different streams. Requires at least 1 region, comma separated. + #[arg(long, env, value_delimiter = ',', required(true))] + desired_regions: Vec, + + #[clap(flatten)] + common_args: CommonArgs, +} + +#[derive(clap::Args, Clone, Debug)] +struct CommonArgs { + /// Address where Shredstream proxy listens. + #[arg(long, env, default_value_t = IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)))] + src_bind_addr: IpAddr, + + /// Port where Shredstream proxy listens. Use `0` for random ephemeral port. + #[arg(long, env, default_value_t = 20_000)] + src_bind_port: u16, + + /// Multicast IP to listen for shreds. If none provided, attempts to + /// parse multicast routes for the device specified by `--multicast-device` + /// via `ip --json route show dev `. + #[arg(long, env)] + multicast_bind_ip: Option, + + /// Network device to use for multicast route discovery and interface selection. + /// Example: `eth0`, `en0`, or `doublezero1`. + #[arg(long, env, default_value = "doublezero1")] + multicast_device: String, + + /// Port to receive multicast shreds + #[arg(long, env, default_value_t = 20001)] + multicast_subscribe_port: u16, + + /// Static set of IP:Port where Shredstream proxy forwards shreds to, comma separated. + /// Eg. `127.0.0.1:8001,10.0.0.1:8001`. + // Note: store the original string, so we can do hostname resolution when refreshing destinations + #[arg(long, env, value_delimiter = ',', value_parser = resolve_hostname_port)] + dest_ip_ports: Vec<(SocketAddr, String)>, + + /// Http JSON endpoint to dynamically get IPs for Shredstream proxy to forward shreds. + /// Endpoints are then set-union with `dest-ip-ports`. + #[arg(long, env)] + endpoint_discovery_url: Option, + + /// Port to send shreds to for hosts fetched via `endpoint-discovery-url`. + /// Port can be found using `scripts/get_tvu_port.sh`. + /// See https://jito-labs.gitbook.io/mev/searcher-services/shredstream#running-shredstream + #[arg(long, env)] + discovered_endpoints_port: Option, + + /// Interval between logging stats to stdout and influx + #[arg(long, env, default_value_t = 15_000)] + metrics_report_interval_ms: u64, + + /// Logs trace shreds to stdout and influx + #[arg(long, env, default_value_t = false)] + debug_trace_shred: bool, + + /// Public IP address to use. + /// Overrides value fetched from `ifconfig.me`. + #[arg(long, env)] + public_ip: Option, + + /// Number of threads to use. Defaults to use up to 4. + #[arg(long, env)] + num_threads: Option, + + /// + /// The multicast group (ip addr) to join for receiving shreds. + /// Multicast groups supports IPv4 and IPv6. + #[arg(long, env)] + triton_multicast_group: Option, + /// The interface to bind to for triton multicast. + /// If IPV6 is used, this argument must be provided. + /// If ipv4, then optional (listen on all interfaces if not provided). + #[arg(long, env)] + triton_multicast_bind_interface: Option, + + /// + /// The port to listen on for triton multicast. + /// If not provided, defaults to 8002. + /// NOTE: this port must match the port used by the triton multicast sender. + #[arg(long, env)] + triton_multicast_port: Option, + + /// Address to bind prometheus metrics server to. If not provided, prometheus server is disabled. + #[arg(long, env)] + prometheus_bind_addr: Option, + + /// Number of tiles dedicated to receiving packets. If not provided, defaults to number of CPU cores is 1. + #[arg(long, env)] + num_pkt_recv_tile: Option, + + /// Number of tiles dedicated to forwarding packets. If not provided, defaults to number of CPU cores is 1. + #[arg(long, env)] + num_pkt_fwd_tile: Option, + + /// Memory sizing for EACH packet receiver, uses t-shirt size convention (xs (default),s,m,l,xl,2xl,3xl,4xl,5xl). Each size increase double the memory, starting at 128MiB for x-small. + #[arg(long, env)] + pkt_recv_channel_memsize: Option, + + /// Use hugepage memory for pkt recv tiles shared memory. + #[arg(long, env, default_value_t = false)] + hugepage: bool, +} + +#[derive(Debug, Error)] +pub enum ShredstreamProxyError { + #[error("TonicError {0}")] + TonicError(#[from] tonic::transport::Error), + #[error("GrpcError {0}")] + GrpcError(#[from] Status), + #[error("ReqwestError {0}")] + ReqwestError(#[from] reqwest::Error), + #[error("SerdeJsonError {0}")] + SerdeJsonError(#[from] serde_json::Error), + #[error("RpcError {0}")] + RpcError(#[from] ClientError), + #[error("BlockEngineConnectionError {0}")] + BlockEngineConnectionError(#[from] BlockEngineConnectionError), + #[error("RecvError {0}")] + RecvError(#[from] RecvError), + #[error("IoError {0}")] + IoError(#[from] io::Error), + #[error("Shutdown")] + Shutdown, +} + +fn resolve_hostname_port(hostname_port: &str) -> io::Result<(SocketAddr, String)> { + let socketaddr = hostname_port.to_socket_addrs()?.next().ok_or_else(|| { + Error::new( + ErrorKind::AddrNotAvailable, + format!("Could not find destination {hostname_port}"), + ) + })?; + + Ok((socketaddr, hostname_port.to_string())) +} + +/// Returns public-facing IPV4 address +pub fn get_public_ip() -> reqwest::Result { + info!("Requesting public ip from ifconfig.me..."); + let client = reqwest::blocking::Client::builder() + .local_address(IpAddr::V4(Ipv4Addr::UNSPECIFIED)) + .build()?; + let response = client.get("https://ifconfig.me/ip").send()?.text()?; + let public_ip = IpAddr::from_str(&response).unwrap(); + info!("Retrieved public ip: {public_ip:?}"); + + Ok(public_ip) +} + +// Creates a channel that gets a message every time `SIGINT` is signalled. +fn shutdown_notifier(exit: Arc) -> io::Result<(Sender<()>, Receiver<()>)> { + let (s, r) = crossbeam_channel::bounded(256); + let mut signals = signal_hook::iterator::Signals::new([SIGINT, SIGTERM])?; + + let s_thread = s.clone(); + thread::spawn(move || { + for _ in signals.forever() { + exit.store(true, Ordering::SeqCst); + // send shutdown signal multiple times since crossbeam doesn't have broadcast channels + // each thread will consume a shutdown signal + for _ in 0..256 { + if s_thread.send(()).is_err() { + break; + } + } + } + }); + + Ok((s, r)) +} + +pub type ReconstructedShredsMap = HashMap>>; +fn main() -> Result<(), ShredstreamProxyError> { + env_logger::builder().init(); + let prom_registry = prometheus::Registry::new(); + prom::register_metrics(&prom_registry); + let all_args: Args = Args::parse(); + let shredstream_args = all_args.shredstream_args.clone(); + // common args + let args = match all_args.shredstream_args { + ProxySubcommands::Shredstream(x) => x.common_args, + ProxySubcommands::ForwardOnly(x) => x, + }; + + + let num_pkt_recv_tiles = args.num_pkt_recv_tile + .map(|x| x.get()) + .unwrap_or(args.num_threads.unwrap_or(1)); + + let num_pkt_fwd_tiles = args.num_pkt_fwd_tile + .map(|x| x.get()) + .unwrap_or(args.num_threads.unwrap_or(1)); + + set_host_id(hostname::get()?.into_string().unwrap()); + if (args.endpoint_discovery_url.is_none() && args.discovered_endpoints_port.is_some()) + || (args.endpoint_discovery_url.is_some() && args.discovered_endpoints_port.is_none()) + { + return Err(ShredstreamProxyError::IoError(io::Error::new(ErrorKind::InvalidInput, "Invalid arguments provided, dynamic endpoints requires both --endpoint-discovery-url and --discovered-endpoints-port."))); + } + if args.endpoint_discovery_url.is_none() + && args.discovered_endpoints_port.is_none() + && args.dest_ip_ports.is_empty() + { + return Err(ShredstreamProxyError::IoError(io::Error::new(ErrorKind::InvalidInput, "No destinations found. You must provide values for --dest-ip-ports or --endpoint-discovery-url."))); + } + + let exit = Arc::new(AtomicBool::new(false)); + let (shutdown_sender, shutdown_receiver) = + shutdown_notifier(exit.clone()).expect("Failed to set up signal handler"); + let panic_hook = panic::take_hook(); + { + let exit = exit.clone(); + panic::set_hook(Box::new(move |panic_info| { + exit.store(true, Ordering::SeqCst); + let _ = shutdown_sender.send(()); + error!("exiting process"); + sleep(Duration::from_secs(1)); + // invoke the default handler and exit the process + panic_hook(panic_info); + })); + } + + let metrics = Arc::new(ShredMetrics::new(false)); + + + let mut thread_handles = vec![]; + if let ProxySubcommands::Shredstream(args) = shredstream_args { + let runtime = Runtime::new()?; + if args.desired_regions.len() > 2 { + warn!( + "Too many regions requested, only regions: {:?} will be used", + &args.desired_regions[..2] + ); + } + let heartbeat_hdl = + start_heartbeat(args, &exit, &shutdown_receiver, runtime, metrics.clone()); + thread_handles.push(heartbeat_hdl); + } + + // share sockets between refresh and forwarder thread + let unioned_dest_sockets = Arc::new(ArcSwap::from_pointee( + args.dest_ip_ports + .iter() + .map(|x| x.0) + .collect::>(), + )); + + let forward_stats = Arc::new(StreamerReceiveStats::new("shredstream_proxy-listen_thread")); + let use_discovery_service = + args.endpoint_discovery_url.is_some() && args.discovered_endpoints_port.is_some(); + + let (doublezero_v4_sk_vec, doublezero_v6_sk_vec) = create_multicast_socket_on_device( + &args.multicast_device, + args.multicast_subscribe_port, + args.multicast_bind_ip, + num_pkt_recv_tiles, + )?; + + + let maybe_triton_multicast_config = match args.triton_multicast_group { + Some(multicast_group) => { + log::info!("Using triton multicast group: {}", multicast_group); + match multicast_group { + IpAddr::V4(ipv4) => { + Some(TritonMulticastConfig::Ipv4(TritonMulticastConfigV4 { + multicast_ip: ipv4, + bind_ifname: args.triton_multicast_bind_interface, + listen_port: args.triton_multicast_port.unwrap_or(8002), + })) + } + IpAddr::V6(ipv6) => { + Some(TritonMulticastConfig::Ipv6(TritonMulticastConfigV6 { + multicast_ip: ipv6, + device_ifname: args.triton_multicast_bind_interface + .ok_or_else(|| { + io::Error::new( + ErrorKind::InvalidInput, + "triton-multicast-bind-interface is required for IPv6", + ) + })?, + listen_port: args.triton_multicast_port.unwrap_or(8002), + })) + } + } + } + None => None, + }; + + let pkt_recv_tile_mem_config = PktRecvTileMemConfig { + memory_size: args.pkt_recv_channel_memsize.unwrap_or_default(), + hugepage: args.hugepage, + ..Default::default() + }; + let proxy_th = { + let exit = Arc::clone(&exit); + let pkt_recv_stats = forward_stats.clone(); + let pkt_fwd_stats = metrics.clone(); + let unioned_dest_sockets = Arc::clone(&unioned_dest_sockets); + std::thread::Builder::new() + .name("tritonProxyMain".to_string()) + .spawn(move || { + triton_forwarder::run_proxy_system( + pkt_recv_tile_mem_config, + unioned_dest_sockets, + maybe_triton_multicast_config, + args.src_bind_addr, + args.src_bind_port, + num_pkt_recv_tiles, + num_pkt_fwd_tiles, + FECSetRoutingStrategy, + exit, + pkt_recv_stats, + pkt_fwd_stats, + doublezero_v4_sk_vec, + doublezero_v6_sk_vec, + ); + }) + .expect("tritonProxyMain") + }; + + thread_handles.push(proxy_th); + + let report_metrics_thread = { + let exit = exit.clone(); + spawn(move || { + while !exit.load(Ordering::Relaxed) { + sleep(Duration::from_secs(1)); + forward_stats.report(); + } + }) + }; + thread_handles.push(report_metrics_thread); + + let metrics_hdl = triton_forwarder::start_forwarder_accessory_thread( + metrics.clone(), + args.metrics_report_interval_ms, + shutdown_receiver.clone(), + exit.clone(), + ); + thread_handles.push(metrics_hdl); + if use_discovery_service { + let refresh_handle = forwarder::start_destination_refresh_thread( + args.endpoint_discovery_url.unwrap(), + args.discovered_endpoints_port.unwrap(), + args.dest_ip_ports, + unioned_dest_sockets, + shutdown_receiver.clone(), + exit.clone(), + ); + thread_handles.push(refresh_handle); + } + + if let Some(prom_bind_addr) = args.prometheus_bind_addr { + let prom_hdl = prom::spawn_prometheus_server( + prom_bind_addr, + prom_registry, + shutdown_receiver.clone() + ); + thread_handles.push(prom_hdl); + } + + info!( + "Shredstream started, listening on {}:{}/udp.", + args.src_bind_addr, args.src_bind_port + ); + + + + for thread in thread_handles { + thread.join().expect("thread panicked"); + } + + info!( + "Exiting Shredstream, {} received , {} sent successfully, {} failed, {} duplicate shreds.", + metrics.agg_received_cumulative.load(Ordering::Relaxed), + metrics + .agg_success_forward_cumulative + .load(Ordering::Relaxed), + metrics.agg_fail_forward_cumulative.load(Ordering::Relaxed), + metrics.duplicate_cumulative.load(Ordering::Relaxed), + ); + Ok(()) +} + +fn start_heartbeat( + args: ShredstreamArgs, + exit: &Arc, + shutdown_receiver: &Receiver<()>, + runtime: Runtime, + metrics: Arc, +) -> JoinHandle<()> { + let auth_keypair = Arc::new( + read_keypair_file(Path::new(&args.auth_keypair)).unwrap_or_else(|e| { + panic!( + "Unable to parse keypair file. Ensure that file {:?} is readable. Error: {e}", + args.auth_keypair + ) + }), + ); + + heartbeat::heartbeat_loop_thread( + args.block_engine_url.clone(), + args.auth_url.unwrap_or(args.block_engine_url), + auth_keypair, + args.desired_regions, + SocketAddr::new( + args.common_args + .public_ip + .unwrap_or_else(|| get_public_ip().unwrap()), + args.common_args.src_bind_port, + ), + runtime, + "shredstream_proxy".to_string(), + metrics, + shutdown_receiver.clone(), + exit.clone(), + ) +} diff --git a/proxy/src/mem.rs b/proxy/src/mem.rs new file mode 100644 index 00000000..fed8a7c7 --- /dev/null +++ b/proxy/src/mem.rs @@ -0,0 +1,632 @@ +use std::{ + hint::spin_loop, + sync::{ + Arc, atomic::{AtomicI32, AtomicUsize, Ordering} + }, time::Duration, +}; + +use bytes::{buf::UninitSlice, Buf, BufMut}; + +#[derive(Debug, thiserror::Error)] +#[error("allocation error")] +pub struct AllocError; + +#[repr(C)] +pub struct SharedMem { + pub ptr: *mut u8, + len: usize, +} + +pub fn try_alloc_shared_mem( + num_items: usize, + capacity: usize, + huge: bool, +) -> Result<*mut u8, AllocError> { + // assert!(align.is_power_of_two(), "alignment must be a power of two"); + assert!( + capacity.is_power_of_two(), + "capacity must be a power of two" + ); + let total_len = capacity * num_items; + let ptr = unsafe { + libc::mmap( + std::ptr::null_mut(), + total_len, + libc::PROT_READ | libc::PROT_WRITE, + libc::MAP_SHARED | libc::MAP_ANONYMOUS | if huge { libc::MAP_HUGETLB } else { 0 }, + -1, + 0, + ) + }; + + if std::ptr::eq(ptr, libc::MAP_FAILED) { + return Err(AllocError); + } + + // zero initialize the memory + unsafe { + std::ptr::write_bytes(ptr as *mut u8, 0, total_len); + } + + Ok(ptr as *mut u8) +} + +impl SharedMem { + pub fn new(element_size: usize, capacity: usize, huge: bool) -> Result { + let ptr = try_alloc_shared_mem(element_size, capacity, huge)?; + let len = capacity * element_size; + + Ok(Self { ptr, len }) + } + + pub fn len(&self) -> usize { + self.len + } + + pub fn dealloc(self) { + unsafe { + libc::munmap(self.ptr as *mut libc::c_void, self.len); + } + } +} + +impl Drop for SharedMem { + fn drop(&mut self) { + unsafe { + libc::munmap(self.ptr as *mut libc::c_void, self.len); + } + } +} + +#[derive(Debug)] +#[repr(C, align(16))] +pub struct FrameDesc { + pub ptr: *mut u8, + pub frame_size: usize, +} + +unsafe impl Send for FrameDesc {} + +#[derive(Debug)] +#[repr(C, align(32))] +pub struct FrameBufMut { + ptr: *mut u8, + desc: FrameDesc, +} + +unsafe impl Send for FrameBufMut {} + +#[derive(Debug)] +#[repr(C, align(32))] +pub struct FrameBuf { + curr_ptr: *mut u8, + len: usize, + desc: FrameDesc, +} + +impl FrameBuf { + #[inline] + pub fn len(&self) -> usize { + let end = unsafe { self.desc.ptr.add(self.len) }; + (end as usize) - (self.curr_ptr as usize) + } + + #[inline] + pub fn into_inner(self) -> FrameDesc { + self.desc + } + + #[inline] + pub unsafe fn detach_desc(&self) -> FrameDesc { + FrameDesc { + ptr: self.desc.ptr, + frame_size: self.desc.frame_size, + } + } + + pub unsafe fn unsafe_clone(&self) -> Self { + Self { + curr_ptr: self.curr_ptr, + len: self.len, + desc: FrameDesc { + ptr: self.desc.ptr, + frame_size: self.desc.frame_size, + }, + } + } + + pub unsafe fn unsafe_subslice_clone(&self, offset: usize, len: usize) -> Self { + assert!(offset + len <= self.len()); + Self { + curr_ptr: self.curr_ptr.add(offset), + len, + desc: FrameDesc { + ptr: self.desc.ptr, + frame_size: self.desc.frame_size, + }, + } + } +} + +impl AsRef<[u8]> for FrameBuf { + fn as_ref(&self) -> &[u8] { + self.chunk() + } +} + +unsafe impl Send for FrameBuf {} + +impl From for FrameBuf { + fn from(buf_mut: FrameBufMut) -> Self { + let len = (buf_mut.ptr as usize) - (buf_mut.desc.ptr as usize); + Self { + curr_ptr: buf_mut.desc.ptr, + len, + desc: buf_mut.desc, + } + } +} + +impl FrameDesc { + pub fn as_mut_buf(&self) -> FrameBufMut { + FrameBufMut { + ptr: self.ptr, + desc: FrameDesc { + ptr: self.ptr, + frame_size: self.frame_size, + }, + } + } +} + +impl From for FrameBufMut { + fn from(desc: FrameDesc) -> Self { + Self { + ptr: desc.ptr, + desc, + } + } +} + +impl FrameBufMut { + #[inline] + pub fn base(&self) -> *mut u8 { + ((self.ptr as usize) & !(self.desc.frame_size - 1)) as *mut u8 + } + + #[inline] + pub fn capacity(&self) -> usize { + self.desc.frame_size + } + + + #[inline] + fn end_ptr(&self) -> *const u8 { + unsafe { self.base().add(self.capacity()) } + } + + #[inline] + pub unsafe fn as_mut_ptr(&self) -> *mut u8 { + self.ptr + } + + #[inline] + pub unsafe fn seek(&mut self, offset: usize) { + assert!( + offset < self.desc.frame_size, + "seek offset out of bounds" + ); + let new_ptr = self.desc.ptr.add(offset); + let end_ptr = self.end_ptr(); + assert!(new_ptr as *const u8 <= end_ptr, "seek out of bounds"); + self.ptr = new_ptr; + } +} + +unsafe impl BufMut for FrameBufMut { + fn remaining_mut(&self) -> usize { + // given that ptr must always aligned with `frame_align`, + // we just be able to infer the remaining mut size from frame_align + let frame_offset = (self.ptr as usize) & (self.desc.frame_size - 1); + self.desc.frame_size - frame_offset + } + + unsafe fn advance_mut(&mut self, cnt: usize) { + let new_ptr = self.ptr.add(cnt); + assert!( + new_ptr as *const u8 <= self.end_ptr(), + "advance_mut out of bounds" + ); + self.ptr = new_ptr; + } + + fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice { + unsafe { UninitSlice::from_raw_parts_mut(self.ptr, self.remaining_mut()) } + } +} + +impl Buf for FrameBuf { + fn remaining(&self) -> usize { + let end = unsafe { self.desc.ptr.add(self.len) }; + (end as usize) - (self.curr_ptr as usize) + } + + fn chunk(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.curr_ptr, self.remaining()) } + } + + fn advance(&mut self, cnt: usize) { + let new_ptr = unsafe { self.curr_ptr.add(cnt) }; + let end = unsafe { self.desc.ptr.add(self.len) }; + assert!(new_ptr as *const u8 <= end, "advance out of bounds"); + self.curr_ptr = new_ptr; + } +} + +use std::{ptr, sync::atomic::AtomicBool}; + +// We wrap T to include a 'ready' flag for each slot +#[repr(C)] +struct Slot { + data: std::mem::MaybeUninit, + is_ready: AtomicBool, +} + +struct RingInner { + buf: *mut Slot, // Changed to Slot + capacity: usize, + mask: usize, + head: AtomicUsize, // Producer index (reserved) + tail: AtomicUsize, // Consumer index + futex_flag: AtomicI32, + shmem: Option, +} + +impl Drop for RingInner { + fn drop(&mut self) { + if let Some(shmem) = self.shmem.take() { + let mut tail = self.tail.load(Ordering::Acquire); + let head = self.head.load(Ordering::Acquire); + + // Drop initialized slots + while tail != head { + unsafe { + let slot = &mut *self.buf.add(tail & self.mask); + if slot.is_ready.load(Ordering::Acquire) { + ptr::drop_in_place(slot.data.as_mut_ptr()); + } + } + tail = tail.wrapping_add(1); + } + + drop(shmem); + } + } +} + +unsafe impl Send for RingInner {} +unsafe impl Sync for RingInner {} + +pub struct Tx { + inner: Arc>, +} + +impl Clone for Tx { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + } + } +} + +pub struct Rx { + inner: Arc>, +} + +pub fn message_ring(capacity: usize) -> Result<(Tx, Rx), AllocError> { + let capacity = capacity.next_power_of_two(); + let size = std::mem::size_of::>(); + + // Allocate memory for Slots + let shmem = SharedMem::new(size, capacity, false)?; + let ptr = shmem.ptr as *mut Slot; + // Initialize the is_ready flags to false + for i in 0..capacity { + unsafe { + let slot_ptr = ptr.add(i); + ptr::write(&mut (*slot_ptr).is_ready, AtomicBool::new(false)); + } + } + + let inner = Arc::new(RingInner { + buf: ptr, + capacity, + mask: capacity - 1, + head: AtomicUsize::new(0), + tail: AtomicUsize::new(0), + futex_flag: AtomicI32::new(0), + shmem: Some(shmem), + }); + + Ok(( + Tx { + inner: Arc::clone(&inner), + }, + Rx { inner }, + )) +} + +impl Tx { + pub fn send(&self, value: T) -> Result<(), T> { + loop { + // 1. Load head and tail to check if the ring is full. + // head: Relaxed is okay here because it's only a hint for the CAS. + // tail: Acquire is REQUIRED to ensure we don't overwrite data + // the consumer hasn't finished reading yet. + let head = self.inner.head.load(Ordering::Relaxed); + let tail = self.inner.tail.load(Ordering::Acquire); + + // 2. The Fix: Calculate occupancy with wrapping awareness + let occupancy = head.wrapping_sub(tail); + + // A ring is only full if occupancy is >= capacity. + // We add a check for (usize::MAX / 2) to ignore the "stale head" + // cases where occupancy underflows to a massive number. + if occupancy >= self.inner.capacity && occupancy < (usize::MAX / 2) { + return Err(value); + } + + // if head.wrapping_sub(tail) >= self.inner.capacity { + // log::error!("Ring is full: head={}, tail={}, capacity={}", head, tail, self.inner.capacity); + // return Err(value); // Ring is full + // } + + // 2. Claim a slot using Compare-and-Swap (CAS). + // We use SeqCst or AcqRel here to ensure that once we "win" this slot, + // we have a synchronized view of the memory. + if self + .inner + .head + .compare_exchange_weak(head, head + 1, Ordering::AcqRel, Ordering::Relaxed) + .is_ok() + { + unsafe { + // 3. Calculate slot location. + let slot = &*self.inner.buf.add(head & self.inner.mask); + + // 4. Write the data into the MaybeUninit. + // We use .write() which is a wrapper for ptr::write. + ptr::write(slot.data.as_ptr() as *mut T, value); + + // 5. RELEASE the data to the consumer. + // This store ensures the data write above is visible to + // any thread that performs an Acquire load on is_ready. + slot.is_ready.store(true, Ordering::Release); + } + + // 6. Futex Wake Logic. + // If the consumer is sleeping (futex_flag == 0), we wake them. + // We use Release to ensure the flag update is visible. + if self.inner.futex_flag.swap(1, Ordering::Release) == 0 { + unsafe { + libc::syscall( + libc::SYS_futex, + &self.inner.futex_flag as *const AtomicI32, + libc::FUTEX_WAKE, + 1, // Wake 1 thread + ); + } + } + return Ok(()); + } + // If CAS failed, another producer grabbed 'head'. + // The loop will retry with the new head value. + std::hint::spin_loop(); + } + } +} + +impl Rx { + pub fn recv(&mut self) -> T { + self.recv_timeout_inner(None).expect("recv failed") + } + + pub fn recv_timeout(&mut self, duration: Duration) -> Option { + self.recv_timeout_inner(Some(duration)) + } + + fn recv_timeout_inner(&mut self, duration: Option) -> Option { + for _ in 0..999 { + if let Some(val) = self.try_recv() { + return Some(val); + } + spin_loop(); + } + + loop { + if let Some(val) = self.try_recv() { + return Some(val); + } + + self.inner.futex_flag.store(0, Ordering::SeqCst); + + if let Some(val) = self.try_recv() { + return Some(val); + } + + let timespec: Option = duration.map(|d| libc::timespec { + tv_sec: d.as_secs() as libc::time_t, + tv_nsec: d.subsec_nanos() as libc::c_long, + }); + + let timeout_ptr = match ×pec { + Some(ts) => ts as *const libc::timespec, + None => std::ptr::null(), + }; + + unsafe { + libc::syscall( + libc::SYS_futex, + &self.inner.futex_flag as *const AtomicI32, + libc::FUTEX_WAIT, + 0, + timeout_ptr, + ); + } + + if duration.is_some(){ + return self.try_recv(); + } + } + } + + pub fn try_recv(&mut self) -> Option { + let tail = self.inner.tail.load(Ordering::Relaxed); + + unsafe { + let slot = &*self.inner.buf.add(tail & self.inner.mask); + + // IMPORTANT: In MPSC, even if head > tail, the data at tail might + // not be written yet because the producer was interrupted. + if !slot.is_ready.load(Ordering::Acquire) { + return None; + } + + let val = ptr::read(slot.data.as_ptr()); + + // Reset the flag for the next time this slot is used + slot.is_ready.store(false, Ordering::Release); + + // Increment tail to free the slot + self.inner.tail.store(tail + 1, Ordering::Release); + Some(val) + } + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashSet, sync::Barrier, thread}; + + use super::*; + + #[test] + fn test_mpsc_contention() { + let capacity = 1024; + let (tx, mut rx) = message_ring::(capacity).unwrap(); + + let num_producers = 4; + let msgs_per_producer = 1000; + let barrier = Arc::new(Barrier::new(num_producers + 1)); + let mut handles = Vec::new(); + + // Start Producers + for p in 0..num_producers { + let tx_clone = tx.clone(); + let b_clone = Arc::clone(&barrier); + handles.push(thread::spawn(move || { + b_clone.wait(); // Synchronize start + for i in 0..msgs_per_producer { + let val = p * 10000 + i; + while tx_clone.send(val).is_err() { + spin_loop(); // Wait if ring is full + } + } + })); + } + + barrier.wait(); // Start everyone at once + + let mut received = HashSet::new(); + let total_expected = num_producers * msgs_per_producer; + + for _ in 0..total_expected { + received.insert(rx.recv()); + } + + assert_eq!(received.len(), total_expected); + for h in handles { + h.join().unwrap(); + } + } + + #[test] + fn test_frame_buffer_lifecycle() { + let align = 4096; + let capacity = 1; + // 1. Setup the memory pool + let mem = SharedMem::new(align, capacity, false).unwrap(); + + // At this point, the fill_ring inside PagedAlignedMem logic + // should have been populated. Let's create our own handles for testing. + let (tx_fill, mut rx_fill) = message_ring::(capacity).unwrap(); + let (rx_tx, mut rx_rx) = message_ring::(capacity).unwrap(); + // Manually push frames into our test fill ring + for i in 0..capacity { + tx_fill + .send(FrameDesc { + ptr: unsafe { mem.ptr.add(i * align) }, + frame_size: align, + }) + .unwrap(); + } + + // 2. Simulate taking a frame from the pool + let desc = rx_fill.recv(); + let expected_ptr = desc.ptr; + println!("Received frame at ptr: {:p}", expected_ptr); + let mut buf = desc.as_mut_buf(); + assert_eq!(buf.remaining_mut(), 4096); + buf.put_u32(0xDEADBEEF); + assert_eq!(buf.remaining_mut(), 4092); + + rx_tx.send(desc).unwrap(); + + // 3. Verify the frame returned to the fill ring + let returned_desc = rx_rx.recv(); + assert_eq!(returned_desc.ptr, expected_ptr); + // 4. Verify the frame is zeroed out + } + + #[test] + fn test_blocking_recv() { + let (tx, mut rx) = message_ring::(16).unwrap(); + + let handle = thread::spawn(move || { + thread::sleep(std::time::Duration::from_millis(200)); + tx.send(42).unwrap(); + }); + + let start = std::time::Instant::now(); + let val = rx.recv(); // Should block for ~200ms + + assert_eq!(val, 42); + assert!(start.elapsed().as_millis() >= 200); + handle.join().unwrap(); + } + + #[test] + fn test_buf_and_bufmut_impls() { + let frame_size = 4096; + let shmem = SharedMem::new(frame_size, 1, false).unwrap(); + let desc = FrameDesc { + ptr: shmem.ptr, + frame_size, + }; + + let mut buf_mut: FrameBufMut = desc.into(); + assert_eq!(buf_mut.remaining_mut(), 4096); + buf_mut.put_slice(&[1, 2, 3, 4]); + assert_eq!(buf_mut.remaining_mut(), 4092); + assert_eq!(buf_mut.chunk_mut().len(), 4); + + let mut buf: FrameBuf = buf_mut.into(); + assert_eq!(buf.len(), 4); + assert_eq!(buf.remaining(), 4); + let chunk = buf.chunk(); + assert_eq!(chunk, &[1, 2, 3, 4]); + buf.advance(4); + assert_eq!(buf.remaining(), 0); + assert_eq!(buf.len(), 0) + } +} diff --git a/proxy/src/recv_mmsg.rs b/proxy/src/recv_mmsg.rs new file mode 100644 index 00000000..11294942 --- /dev/null +++ b/proxy/src/recv_mmsg.rs @@ -0,0 +1,441 @@ +use std::{ + cmp, + io, + mem::{self, zeroed, MaybeUninit}, + net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, UdpSocket}, + os::fd::AsRawFd, + sync::atomic::{AtomicBool, Ordering}, + time::{Duration, Instant}, +}; + +use bytes::{Buf, BufMut}; +use itertools::izip; +use libc::{AF_INET, AF_INET6, MSG_DONTWAIT, iovec, mmsghdr, msghdr, sockaddr_storage}; +use log::error; +use mio::Poll; +use socket2::socklen_t; +use solana_perf::packet::{NUM_RCVMMSGS, PACKETS_PER_BATCH}; +use solana_sdk::packet::{Meta, PACKET_DATA_SIZE}; +use solana_streamer::{streamer::StreamerReceiveStats}; + +use crate::{mem::{FrameBuf, FrameBufMut, FrameDesc, Rx, Tx}, prom::{inc_packets_received, observe_recv_packet_count}}; + +pub trait PacketRoutingStrategy: Clone { + fn route_packet(&self, packet: &TritonPacket, num_dest: usize) -> Option; +} + +#[inline] +fn hash_pair(x: u64, y: u32) -> u64 { + let mut h = x ^ ((y as u64) << 32); + h ^= h >> 33; + h = h.wrapping_mul(0xff51afd7ed558ccd); + h ^= h >> 33; + h +} + +#[derive(Debug, Clone)] +pub struct FECSetRoutingStrategy; + +impl PacketRoutingStrategy for FECSetRoutingStrategy { + fn route_packet(&self, packet: &TritonPacket, num_dest: usize) -> Option { + let shred_buf = packet.buffer.chunk(); + let slot = solana_ledger::shred::wire::get_slot(shred_buf)?; + let fec = shred_buf.get(79..79 + 4)?; + let fec_bytes: [u8; 4] = fec.try_into().ok()?; + let fec_set_index = u32::from_le_bytes(fec_bytes); + let hash = hash_pair(slot, fec_set_index); + let dest = (hash as usize) % num_dest; + Some(dest) + } +} + +pub fn recv_loop( + sk_vec: Vec, + exit: &AtomicBool, + stats: &StreamerReceiveStats, + fill_rx: &mut Rx, + packet_tx_vec: &[Tx], + router: R, +) -> std::io::Result<()> +where + R: PacketRoutingStrategy, +{ + let mut packet_batch = Vec::with_capacity(PACKETS_PER_BATCH); + let mut frame_bufmut_vec = Vec::with_capacity(PACKETS_PER_BATCH); + let mut next_stats_report = Instant::now() + Duration::from_secs(1); + let mut router_dest_dist = vec![0usize; packet_tx_vec.len()]; + let mut poll = Poll::new()?; + let mut events = mio::Events::with_capacity(sk_vec.len()); + + // Initial registration of sockets + for (i, socket) in sk_vec.iter().enumerate() { + poll.registry().register( + &mut mio::net::UdpSocket::from_std(socket.try_clone().unwrap()), + mio::Token(i), + mio::Interest::READABLE, + )?; + } + + while !exit.load(Ordering::Relaxed) { + + // Events are always cleared before receiving new ones + let result = poll.poll(&mut events, Some(Duration::from_millis(100))); + + match result { + Ok(_) => { } + Err(e) => { + if e.kind() != io::ErrorKind::TimedOut { + return Err(e); + } + } + } + + if next_stats_report.elapsed() > Duration::ZERO { + next_stats_report = Instant::now() + Duration::from_secs(1); + log::trace!( + "recv_loop: packets_count={}, packet_batches_count={}, full_packet_batches_count={}", + stats.packets_count.load(Ordering::Relaxed), + stats.packet_batches_count.load(Ordering::Relaxed), + stats.full_packet_batches_count.load(Ordering::Relaxed), + ); + } + // Check for exit signal, even if socket is busy + // (for instance the leader transaction socket) + if exit.load(Ordering::Relaxed) { + return Ok(()); + } + // We can't use a for-loop here because we need to be able to drain the readiness of each socket. + // Since each recv_from is bounded by a PACKETS_PER_BATCH, we may need to call recv_from multiple times per socket + // until we get a WouldBlock error. + let mut ev_iter = events.iter(); + let Some(mut ev) = ev_iter.next() else { + continue; + }; + 'drain_readiness_loop: while !exit.load(Ordering::Relaxed) { + + let sk_idx = ev.token().0; + let recv_sk = &sk_vec[sk_idx]; + + // Refill the frame buffers as much as we can, + 'fill_bufmut: while frame_bufmut_vec.len() < PACKETS_PER_BATCH { + let maybe_frame_buf = fill_rx.try_recv(); + match maybe_frame_buf { + Some(frame_desc) => { + let frame_bufmut = frame_desc.as_mut_buf(); + frame_bufmut_vec.push(frame_bufmut); + } + None => { + if frame_bufmut_vec.is_empty() { + // block until we get at least one frame buffer + let Some(frame_desc) = fill_rx.recv_timeout(Duration::from_millis(100)) else { + break 'fill_bufmut + }; + let frame_bufmut = frame_desc.as_mut_buf(); + frame_bufmut_vec.push(frame_bufmut); + } else { + break 'fill_bufmut; + } + } + } + } + + if frame_bufmut_vec.is_empty() { + // No available frame buffers to receive into, wait a bit + log::debug!("recv_loop: no available frame buffers to receive into"); + continue 'drain_readiness_loop; + } + + let t = Instant::now(); + let result = recv_from(&mut frame_bufmut_vec, recv_sk, &mut packet_batch, &exit); + let recv_interval = t.elapsed(); + + + match result { + Ok(len) => { + if len > 0 { + // observe_recv_interval(recv_interval.as_micros() as f64); + inc_packets_received(len as u64); + observe_recv_packet_count(len as f64); + let StreamerReceiveStats { + packets_count, + packet_batches_count, + full_packet_batches_count, + .. + } = stats; + + packets_count.fetch_add(len, Ordering::Relaxed); + packet_batches_count.fetch_add(1, Ordering::Relaxed); + if len == PACKETS_PER_BATCH { + full_packet_batches_count.fetch_add(1, Ordering::Relaxed); + } + packet_batch + .iter_mut() + .for_each(|p| p.meta_mut().set_from_staked_node(false)); + + 'packet_drain: for packet in packet_batch.drain(..) { + let dest_idx = match router.route_packet(&packet, packet_tx_vec.len()) { + Some(idx) => idx, + None => { + log::trace!("Failed to route packet {:?}", packet); + let trashed_frame_bufmut = packet.buffer.into_inner().as_mut_buf(); + frame_bufmut_vec.push(trashed_frame_bufmut); + continue 'packet_drain; + } + }; + router_dest_dist[dest_idx] += 1; + let _ = &packet_tx_vec[dest_idx] + .send(packet) + .expect(format!("failed to send packet to {dest_idx} ring is full, distr:{:?}", router_dest_dist).as_str()); + } + } + } + Err(e) => { + if e.kind() == io::ErrorKind::WouldBlock { + // Only when we drained all events for this poll iteration, we process the next event or break + match ev_iter.next() { + Some(next_ev) => { + ev = next_ev; + continue 'drain_readiness_loop; + } + None => { + break 'drain_readiness_loop; + } + } + } else { + return Err(e); + } + } + } + } + } + Ok(()) +} + +pub fn recv_from( + available_frame_buf_vec: &mut Vec, + socket: &UdpSocket, + batch: &mut Vec, + exit: &AtomicBool, +) -> std::io::Result { + // let mut i: usize = 0; + //DOCUMENTED SIDE-EFFECT + //Performance out of the IO without poll + // * block on the socket until it's readable + // * set the socket to non blocking + // * read until it fails + // * set it back to blocking before returning + // socket.set_nonblocking(false)?; + let batch_capacity = batch.capacity(); + assert!(batch_capacity >= PACKETS_PER_BATCH); + + let mut i = 0; + + while !exit.load(Ordering::Relaxed) { + let npkts = triton_recv_mmsg(socket, available_frame_buf_vec, batch)?; + i += npkts; + if available_frame_buf_vec.is_empty() { + break; + } + if batch.len() >= batch_capacity { + break; + } + // Try to batch into big enough buffers + // will cause less re-shuffling later on. + if i >= PACKETS_PER_BATCH { + break; + } + } + Ok(i) +} + +#[derive(Debug)] +#[repr(C)] +pub struct TritonPacket { + pub buffer: FrameBuf, + pub meta: Meta, +} + +impl TritonPacket { + pub fn new(buffer: FrameBuf) -> Self { + Self { + buffer, + meta: Meta::default(), + } + } + + pub fn meta_mut(&mut self) -> &mut Meta { + &mut self.meta + } +} + +impl AsRef<[u8]> for TritonPacket { + fn as_ref(&self) -> &[u8] { + self.buffer.chunk() + } +} + +pub fn triton_recv_mmsg( + sock: &UdpSocket, + fill_buffers: &mut Vec, + packets: &mut Vec, +) -> io::Result { + // Should never hit this, but bail if the caller didn't provide any Packets + // to receive into + if fill_buffers.is_empty() { + return Ok(0); + } + // Assert that there are no leftovers in packets. + const SOCKADDR_STORAGE_SIZE: socklen_t = mem::size_of::() as socklen_t; + + let mut iovs = [MaybeUninit::uninit(); NUM_RCVMMSGS]; + let mut addrs = [MaybeUninit::zeroed(); NUM_RCVMMSGS]; + let mut hdrs = [MaybeUninit::uninit(); NUM_RCVMMSGS]; + let remaining_packets = packets.capacity() - packets.len(); + let sock_fd = sock.as_raw_fd(); + let count = cmp::min(iovs.len(), remaining_packets).min(fill_buffers.len()); + let mut frame_buffer_inflight_vec: [MaybeUninit; NUM_RCVMMSGS] = + std::array::from_fn(|_| MaybeUninit::uninit()); + + let mut frame_buffer_inflight_cnt = 0; + for (hdr, iov, addr) in izip!(&mut hdrs, &mut iovs, &mut addrs).take(count) { + let buffer = fill_buffers.pop().expect("insufficient fill buffers"); + assert!( + buffer.remaining_mut() >= PACKET_DATA_SIZE, + "fill buffer too small" + ); + let iov_base = unsafe { buffer.as_mut_ptr() as *mut libc::c_void }; + + iov.write(iovec { + iov_base: iov_base, + iov_len: PACKET_DATA_SIZE, + }); + + let msg_hdr = create_msghdr(addr, SOCKADDR_STORAGE_SIZE, iov); + + hdr.write(mmsghdr { + msg_len: 0, + msg_hdr, + }); + // Keep track of the in-flight frame buffers to avoid use-after-free + frame_buffer_inflight_vec[frame_buffer_inflight_cnt].write(buffer); + frame_buffer_inflight_cnt += 1; + } + + let mut ts = libc::timespec { + tv_sec: 1, + tv_nsec: 0, + }; + // TODO: remove .try_into().unwrap() once rust libc fixes recvmmsg types for musl + #[allow(clippy::useless_conversion)] + let nrecv = unsafe { + libc::recvmmsg( + sock_fd, + hdrs[0].assume_init_mut(), + count as u32, + MSG_DONTWAIT.try_into().unwrap(), + &mut ts, + ) + }; + let nrecv = if nrecv < 0 { + // On error, return all in-flight frame buffers back to the caller + for i in 0..frame_buffer_inflight_cnt { + let buffer = unsafe { frame_buffer_inflight_vec[i].assume_init_read() }; + fill_buffers.push(buffer); + } + return Err(io::Error::last_os_error()); + } else { + usize::try_from(nrecv).unwrap() + }; + for (addr, hdr, filled_bufmut) in + izip!(addrs, hdrs, frame_buffer_inflight_vec).take(nrecv) + { + // SAFETY: We initialized `count` elements of `hdrs` above. `count` is + // passed to recvmmsg() as the limit of messages that can be read. So, + // `nrevc <= count` which means we initialized this `hdr` and + // recvmmsg() will have updated it appropriately + let hdr_ref = unsafe { hdr.assume_init_ref() }; + // SAFETY: Similar to above, we initialized this `addr` and recvmmsg() + // will have populated it + let addr_ref = unsafe { addr.assume_init_ref() }; + let mut filled_bufmut = unsafe { filled_bufmut.assume_init_read() }; + unsafe { filled_bufmut.seek(hdr_ref.msg_len as usize); } + let filled_buf: FrameBuf = filled_bufmut.into(); + let mut pkt = TritonPacket { + buffer: filled_buf, + meta: Meta::default(), + }; + pkt.meta_mut().size = hdr_ref.msg_len as usize; + if let Some(addr) = cast_socket_addr(addr_ref, hdr_ref) { + pkt.meta_mut().set_socket_addr(&addr); + } + packets.push(pkt); + } + + for (iov, addr, hdr) in izip!(&mut iovs, &mut addrs, &mut hdrs).take(count) { + // SAFETY: We initialized `count` elements of each array above + // + // It may be that `packets.len() != NUM_RCVMMSGS`; thus, some elements + // in `iovs` / `addrs` / `hdrs` may not get initialized. So, we must + // manually drop `count` elements from each array instead of being able + // to convert [MaybeUninit] to [T] and letting `Drop` do the work + // for us when these items go out of scope at the end of the function + unsafe { + iov.assume_init_drop(); + addr.assume_init_drop(); + hdr.assume_init_drop(); + } + } + + Ok(nrecv) +} + +fn create_msghdr( + msg_name: &mut MaybeUninit, + msg_namelen: socklen_t, + iov: &mut MaybeUninit, +) -> msghdr { + // Cannot construct msghdr directly on musl + // See https://github.com/rust-lang/libc/issues/2344 for more info + let mut msg_hdr: msghdr = unsafe { zeroed() }; + msg_hdr.msg_name = msg_name.as_mut_ptr() as *mut _; + msg_hdr.msg_namelen = msg_namelen; + msg_hdr.msg_iov = iov.as_mut_ptr(); + msg_hdr.msg_iovlen = 1; + msg_hdr.msg_control = std::ptr::null::() as *mut _; + msg_hdr.msg_controllen = 0; + msg_hdr.msg_flags = 0; + msg_hdr +} + +fn cast_socket_addr(addr: &sockaddr_storage, hdr: &mmsghdr) -> Option { + use libc::{sa_family_t, sockaddr_in, sockaddr_in6}; + const SOCKADDR_IN_SIZE: usize = std::mem::size_of::(); + const SOCKADDR_IN6_SIZE: usize = std::mem::size_of::(); + if addr.ss_family == AF_INET as sa_family_t + && hdr.msg_hdr.msg_namelen == SOCKADDR_IN_SIZE as socklen_t + { + // ref: https://github.com/rust-lang/socket2/blob/65085d9dff270e588c0fbdd7217ec0b392b05ef2/src/sockaddr.rs#L167-L172 + let addr = unsafe { &*(addr as *const _ as *const sockaddr_in) }; + return Some(SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes()), + u16::from_be(addr.sin_port), + ))); + } + if addr.ss_family == AF_INET6 as sa_family_t + && hdr.msg_hdr.msg_namelen == SOCKADDR_IN6_SIZE as socklen_t + { + // ref: https://github.com/rust-lang/socket2/blob/65085d9dff270e588c0fbdd7217ec0b392b05ef2/src/sockaddr.rs#L174-L189 + let addr = unsafe { &*(addr as *const _ as *const sockaddr_in6) }; + return Some(SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::from(addr.sin6_addr.s6_addr), + u16::from_be(addr.sin6_port), + addr.sin6_flowinfo, + addr.sin6_scope_id, + ))); + } + error!( + "recvmmsg unexpected ss_family:{} msg_namelen:{}", + addr.ss_family, hdr.msg_hdr.msg_namelen + ); + None +} diff --git a/proxy/src/triton_forwarder.rs b/proxy/src/triton_forwarder.rs new file mode 100644 index 00000000..0114dca0 --- /dev/null +++ b/proxy/src/triton_forwarder.rs @@ -0,0 +1,869 @@ +use std::{ + collections::VecDeque, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, num::NonZeroUsize, os::fd::AsRawFd, str::FromStr, sync::{ + Arc, atomic::{AtomicBool, Ordering} + }, thread::JoinHandle, time::{Duration, Instant} +}; + +use arc_swap::ArcSwap; +use bytes::Buf; +use crossbeam_channel::{Receiver, Sender}; +use itertools::{izip, Itertools}; +use libc; +use log::{debug, error, info, warn}; +use solana_net_utils::SocketConfig; +use solana_perf::deduper::Deduper; +use solana_streamer::{ + sendmmsg::{batch_send, SendPktsError}, + streamer::{StreamerReceiveStats}, +}; + +use crate::{ + forwarder::{ShredMetrics, try_create_ipv6_socket}, mem::{FrameBuf, FrameDesc, Rx, SharedMem, Tx}, prom::{ + inc_packets_deduped, inc_packets_forward_failed, observe_dedup_time, observe_recv_interval, observe_send_duration, observe_send_packet_count + }, recv_mmsg::{PacketRoutingStrategy, TritonPacket}, triton_multicast_config::TritonMulticastConfig +}; + +// values copied from https://github.com/solana-labs/solana/blob/33bde55bbdde13003acf45bb6afe6db4ab599ae4/core/src/sigverify_shreds.rs#L20 +pub const DEDUPER_FALSE_POSITIVE_RATE: f64 = 0.001; +pub const DEDUPER_NUM_BITS: u64 = 637_534_199; // 76MB +pub const DEDUPER_RESET_CYCLE: Duration = Duration::from_secs(5 * 60); +pub const IP_MULTICAST_TTL: u32 = 8; + +#[derive(Debug, Clone, Copy, Default)] +pub enum PktRecvMemSizing { + #[default] + XSmall = 134217728, // 128MiB + Small = 268435456, // 256MiB + Medium = 536870912, // 512MiB + Large = 1073741824, // 1GiB + XLarge = 2147483648, // 2GiB + XXLarge = 4294967296, // 4GiB + XXXLarge = 8589934592, // 8GiB + XXXXLarge = 17179869184, // 16GiB + XXXXXLarge = 34359738368, // 32GiB +} + +#[derive(Debug, thiserror::Error)] +#[error("Invalid ReceiverMemoryCapacity: {0}")] +pub struct ReceiverMemoryCapacityFromStrErr(String); + +impl FromStr for PktRecvMemSizing { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "xsmall" | "xs" => Ok(PktRecvMemSizing::XSmall), + "small" | "s" => Ok(PktRecvMemSizing::Small), + "medium" | "m" => Ok(PktRecvMemSizing::Medium), + "large" | "l" => Ok(PktRecvMemSizing::Large), + "xlarge" | "xl" => Ok(PktRecvMemSizing::XLarge), + "xxlarge" | "xxl" | "2xl" => Ok(PktRecvMemSizing::XXLarge), + "xxxlarge" | "xxxl" | "3xl" => Ok(PktRecvMemSizing::XXXLarge), + "xxxxlarge" | "xxxxl" | "4xl" => Ok(PktRecvMemSizing::XXXXLarge), + "xxxxxlarge" | "xxxxxl" | "5xl" => Ok(PktRecvMemSizing::XXXXXLarge), + _ => Err(s.to_string()), + } + } +} + +#[derive(Clone, Debug)] +pub struct PktRecvTileMemConfig { + pub frame_size: usize, + pub memory_size: PktRecvMemSizing, + pub hugepage: bool, +} + +impl Default for PktRecvTileMemConfig { + fn default() -> Self { + Self { + frame_size: 2048, + memory_size: PktRecvMemSizing::default(), + hugepage: false, + } + } +} + +fn packet_recv_tile( + pkt_recv_idx: usize, + pkt_recv_socket_vec: Vec, + exit: Arc, + forwarder_stats: Arc, + mut fill_rx: Rx, + packet_tx_vec: Vec>, + packet_router: R, + tile_drop_sig: TileClosedSignal, +) -> std::io::Result> +where + R: PacketRoutingStrategy + Send + 'static, +{ + std::thread::Builder::new() + .name(format!("ssListen{pkt_recv_idx}")) + .spawn(move || { + crate::recv_mmsg::recv_loop( + pkt_recv_socket_vec, + &exit, + &forwarder_stats, + &mut fill_rx, + &packet_tx_vec, + packet_router, + ) + .expect("recv_loop"); + drop(tile_drop_sig); + }) +} + +#[derive(Clone, Debug)] +#[repr(C)] +pub struct SharedMemInfo { + pub start_ptr: *const u8, + pub len: usize, // always a power of 2 +} + +unsafe impl Send for SharedMemInfo {} +unsafe impl Sync for SharedMemInfo {} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TileKind { + PktRecv, + PktFwd, +} + +struct TileClosedSignal { + kind: TileKind, + idx: usize, + tx: Option>, +} + +struct TileWaitGroup { + rx: Receiver<(TileKind, usize)>, + tx: Sender<(TileKind, usize)>, +} + +impl TileWaitGroup { + fn new() -> Self { + let (tx, rx) = crossbeam_channel::unbounded(); + Self { rx, tx } + } + + fn get_tile_closed_signal(&self, kind: TileKind, idx: usize) -> TileClosedSignal { + TileClosedSignal { + kind, + idx, + tx: Some(self.tx.clone()), + } + } + + fn wait_first(self) -> (TileKind, usize) { + drop(self.tx); + self.rx.recv().expect("TileWaitGroup::wait_first") + } +} + +impl Drop for TileClosedSignal { + fn drop(&mut self) { + if let Some(tx) = &self.tx { + let _ = tx.send((self.kind, self.idx)); + } + } +} + +fn packet_fwd_tile( + packet_fwd_idx: usize, + hot_dest_vec: Arc>>, + send_socket: UdpSocket, + mut packet_rx: Rx, + fill_tx_vec: Vec>, + shmem_info_vec: Vec, + stats: Arc, + exit: Arc, + tile_drop_sig: TileClosedSignal, +) -> std::io::Result> { + std::thread::Builder::new() + .name(format!("ssPxyTx_{packet_fwd_idx}")) + .spawn(move || { + let mut deduper = Deduper::<2, [u8]>::new(&mut rand::thread_rng(), DEDUPER_NUM_BITS); + const UIO_MAXIOV: usize = libc::UIO_MAXIOV as usize; + // We allocate double size to account for possible overflow if destinations array is really big + let mut next_batch_send: Vec<(FrameBuf, SocketAddr)> = Vec::with_capacity(UIO_MAXIOV); + let mut queued: VecDeque = VecDeque::with_capacity(UIO_MAXIOV); + + let mut last_batch_to_send = Instant::now(); + + for shmem_info in &shmem_info_vec { + assert!( + shmem_info.len.is_power_of_two(), + "shmem_info.len must be a power of 2" + ); + } + + let mut next_deduper_reset_attempt = Instant::now() + Duration::from_secs(2); + let mut recycled_frames: Vec = Vec::with_capacity(UIO_MAXIOV); + while !exit.load(Ordering::Relaxed) { + if next_deduper_reset_attempt.elapsed() > Duration::ZERO { + deduper.maybe_reset( + &mut rand::thread_rng(), + DEDUPER_FALSE_POSITIVE_RATE, + DEDUPER_RESET_CYCLE, + ); + next_deduper_reset_attempt = Instant::now() + Duration::from_secs(2); + // show stats here... + log::debug!( + "send_batch_count: {}, duplicate: {}, total-pkt-sent: {}, queue-len: {}, to-recycle: {}", + stats.send_batch_count.load(Ordering::Relaxed), + stats.duplicate.load(Ordering::Relaxed), + stats.send_batch_size_sum.load(Ordering::Relaxed), + queued.len(), + recycled_frames.len(), + ); + } + + if queued.is_empty() && recycled_frames.is_empty() && next_batch_send.is_empty() { + if let Some(packet) = packet_rx.recv_timeout(Duration::from_millis(100)) { + let data_size = packet.meta.size; + let data_slice = &packet.buffer.chunk()[..data_size]; + if deduper.dedup(data_slice) { + // put it inside the recycle queue + let desc = packet.buffer.into_inner(); + recycled_frames.push(desc); + stats.duplicate.fetch_add(1, Ordering::Relaxed); + } else { + queued.push_back(packet); + } + } + } + + // Fill up the queued OR recycled_frames as much as possible + 'fill_backlog: while queued.len() < UIO_MAXIOV && recycled_frames.len() < UIO_MAXIOV { + // Fill the batch as much as possible. + let Some(packet) = packet_rx.try_recv() else { + break 'fill_backlog; + }; + let data_size = packet.meta.size; + let data_slice = &packet.buffer.chunk()[..data_size]; + let t = Instant::now(); + if deduper.dedup(data_slice) { + // put it inside the recycle queue + let desc = packet.buffer.into_inner(); + recycled_frames.push(desc); + stats.duplicate.fetch_add(1, Ordering::Relaxed); + inc_packets_deduped(1); + } else { + queued.push_back(packet); + } + let dedup_duration = t.elapsed(); + observe_dedup_time(dedup_duration.as_micros() as f64); + } + + let dests = hot_dest_vec.load(); + let dests_len = dests.len(); + // Fill up the next_batch_send + 'fill_batch_send: while next_batch_send.len() < UIO_MAXIOV + && queued.len() > 0 + && recycled_frames.len() < UIO_MAXIOV + { + let remaining = UIO_MAXIOV - next_batch_send.len(); + if dests_len > remaining { + break 'fill_batch_send; + } + + let Some(packet) = queued.pop_front() else { + break 'fill_batch_send; + }; + let buf = packet.buffer; + let desc = unsafe { buf.detach_desc() }; + recycled_frames.push(desc); + + for dest in dests.iter() { + let origin = packet.meta.socket_addr().ip(); + if origin == dest.ip() { + continue; + } + // Cheap to do since we are just copying a pointer + let buf_clone = unsafe { buf.unsafe_subslice_clone(0, packet.meta.size) }; + next_batch_send.push((buf_clone, *dest)); + } + } + + assert!( + next_batch_send.len() <= UIO_MAXIOV, + "next_batch_send.len() = {}", + next_batch_send.len() + ); + assert!( + recycled_frames.len() <= UIO_MAXIOV, + "recycled_frames.len() = {}", + recycled_frames.len() + ); + assert!( + queued.len() <= UIO_MAXIOV, + "queued.len() = {}", + queued.len() + ); + + let batch_send_ts = Instant::now(); + + if !next_batch_send.is_empty() { + let e = last_batch_to_send.elapsed(); + last_batch_to_send = Instant::now(); + + observe_recv_interval(e.as_micros() as f64); + match batch_send(&send_socket, &next_batch_send) { + Ok(_) => { + // Successfully sent all packets in the batch + let send_duration = batch_send_ts.elapsed(); + stats.batch_send_time_spent.fetch_add(send_duration.as_micros() as u64, Ordering::Relaxed); + stats.send_batch_count.fetch_add(1, Ordering::Relaxed); + stats.send_batch_size_sum.fetch_add(next_batch_send.len() as u64, Ordering::Relaxed); + observe_send_duration(send_duration.as_micros() as f64); + observe_send_packet_count(next_batch_send.len() as f64); + } + Err(SendPktsError::IoError(err, num_failed)) => { + error!( + "Failed to send batch of size {}. \ + {num_failed} packets failed. Error: {err}", + next_batch_send.len() + ); + inc_packets_forward_failed(num_failed as u64); + } + } + } + next_batch_send.clear(); + + // Recycle all used frames + while let Some(desc) = recycled_frames.pop() { + let fill_ring_idx = shmem_info_vec + .iter() + .find_position(|shmem_info| { + let p = desc.ptr as usize; + let start = shmem_info.start_ptr as usize; + p >= start && p < start + shmem_info.len + }) + .expect("unknown frame desc") + .0; + fill_tx_vec[fill_ring_idx] + .send(desc) + .expect("frame recycling"); + } + } + log::info!("Exiting pkt_fwd_tile {}", packet_fwd_idx); + drop(tile_drop_sig); + }) +} + +#[allow(clippy::too_many_arguments)] +pub fn run_proxy_system( + pkt_recv_tile_mem_config: PktRecvTileMemConfig, + dest_addr_vec: Arc>>, + multticast_config: Option, + src_ip: IpAddr, + src_port: u16, + num_pkt_recv_tiles: usize, + num_pkt_fwd_tiles: usize, + pkt_router: R, + exit: Arc, + pk_recv_stats: Arc, + pk_fwd_stats: Arc, + doublezero_v4_sk_vec: Vec, + doublezero_v6_sk_vec: Vec, +) where + R: PacketRoutingStrategy + Send + Sync + 'static, +{ + let mut tile_thread_vec: Vec> = Vec::new(); + // Build pkt_recv sockets + let pkt_recv_multicast_sk_vec = if let Some(multicast_config) = multticast_config { + log::info!("Using Triton multicast configuration for pkt_recv tiles"); + let vec = crate::triton_multicast_config::create_multicast_sockets_triton( + &multicast_config, + NonZeroUsize::new(num_pkt_recv_tiles).expect("num_pkt_recv_tiles must be non-zero"), + ).expect("multicast-config"); + Some(vec) + } else { + None + }; + + assert!(doublezero_v4_sk_vec.len() <= num_pkt_recv_tiles, "doublezero_v4_sk_vec.len() ({}) > num_pkt_recv_tiles ({})", doublezero_v4_sk_vec.len(), num_pkt_recv_tiles); + assert!(doublezero_v6_sk_vec.len() <= num_pkt_recv_tiles, "doublezero_v6_sk_vec.len() ({}) > num_pkt_recv_tiles ({})", doublezero_v6_sk_vec.len(), num_pkt_recv_tiles); + + let (_port, pkt_recv_sk_vec) = solana_net_utils::multi_bind_in_range_with_config( + src_ip, + (src_port, src_port + 1), + SocketConfig::default().reuseport(true), + num_pkt_recv_tiles, + ) + .unwrap_or_else(|_| { + panic!("Failed to bind listener sockets. Check that port {src_port} is not in use.") + }); + assert!(pkt_recv_sk_vec.len() == num_pkt_recv_tiles, "pkt_recv_sk_vec.len() ({}) != num_pkt_recv_tiles ({})", pkt_recv_sk_vec.len(), num_pkt_recv_tiles); + + if let Some(multicast_sk_vec) = &pkt_recv_multicast_sk_vec { + assert!(multicast_sk_vec.len() == num_pkt_recv_tiles, "multicast_sk_vec.len() ({}) != num_pkt_recv_tiles ({})", multicast_sk_vec.len(), num_pkt_recv_tiles); + } + + // Make sure socket are set to nonblocking + for sk in &pkt_recv_sk_vec { + sk.set_nonblocking(true).expect("pkt_recv_sk nonblocking"); + } + + + let mut pkt_recv_sk_raw_fd_vec: Vec = Vec::with_capacity(num_pkt_recv_tiles); + for sk in &pkt_recv_sk_vec { + pkt_recv_sk_raw_fd_vec.push(sk.as_raw_fd()); + } + + let num_frames = + pkt_recv_tile_mem_config.memory_size as usize / pkt_recv_tile_mem_config.frame_size; + let frame_size = pkt_recv_tile_mem_config.frame_size; + + let tile_wait_group = TileWaitGroup::new(); + let mut shmem_info_vec: Vec = Vec::with_capacity(num_pkt_recv_tiles); + let mut fill_tx_vec: Vec> = Vec::with_capacity(num_pkt_recv_tiles); + let mut fill_rx_vec: Vec> = Vec::with_capacity(num_pkt_recv_tiles); + let mut shmem_vec: Vec = Vec::with_capacity(num_pkt_recv_tiles); + + let mut pkt_fwd_sk_vec: Vec = Vec::with_capacity(num_pkt_fwd_tiles); + let mut pkt_fwd_sk_raw_fd_vec: Vec = Vec::with_capacity(num_pkt_fwd_tiles); + + let mut packet_rx_vec: Vec> = Vec::with_capacity(num_pkt_fwd_tiles); + let mut packet_tx_vec: Vec> = Vec::with_capacity(num_pkt_fwd_tiles); + + // Create the shared memory regions for recv tiles + for _ in 0..num_pkt_recv_tiles { + assert!( + num_frames.is_power_of_two(), + "num_frames must be a power of 2" + ); + assert!( + frame_size.is_power_of_two(), + "frame_size must be a power of 2" + ); + let shmem = SharedMem::new(frame_size, num_frames, pkt_recv_tile_mem_config.hugepage) + .expect("SharedMem::new"); + log::info!( + "Created shared memory region with frame_size={} num_frames={} total_size={} hugepage={}", + frame_size, + num_frames, + shmem.len(), + pkt_recv_tile_mem_config.hugepage, + ); + + let shmem_info = SharedMemInfo { + start_ptr: shmem.ptr, + len: shmem.len(), + }; + shmem_info_vec.push(shmem_info); + + let (fill_tx, fill_rx) = crate::mem::message_ring(num_frames).expect("frame ring"); + // Fill the fill ring with all frames + for i in 0..num_frames { + let frame_desc = FrameDesc { + ptr: unsafe { shmem.ptr.add(i * frame_size) }, + frame_size: frame_size, + }; + fill_tx + .send(frame_desc) + .expect("initial frame ring population"); + } + shmem_vec.push(shmem); + fill_tx_vec.push(fill_tx); + fill_rx_vec.push(fill_rx); + log::info!("Initialized frame ring with {} frames", num_frames); + } + + // Create socket for sending packets + for _ in 0..num_pkt_fwd_tiles { + let send_socket = { + let ipv6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); + match try_create_ipv6_socket(ipv6_addr) { + Ok(socket) => { + info!("Successfully bound send socket to IPv6 dual-stack address."); + socket + .set_multicast_loop_v6(false) + .expect("Failed to disable IPv6 multicast loopback"); + socket + } + Err(e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => { + // This error (code 97 on Linux) means IPv6 is not supported. + warn!("IPv6 not available. Falling back to IPv4-only for sending."); + let ipv4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); + let socket = UdpSocket::bind(ipv4_addr) + .expect("Failed to bind to IPv4 socket after IPv6 failed"); + socket + .set_multicast_ttl_v4(IP_MULTICAST_TTL) + .expect("IP_MULTICAST_TTL_V4"); + socket + .set_multicast_loop_v4(false) + .expect("Failed to disable IPv4 multicast loopback"); + socket + } + Err(e) => { + // For any other error (e.g., port in use), panic. + panic!("Failed to bind send socket with an unexpected error: {e}"); + } + } + }; + log::info!( + "Packet forwarder sending socket bound to {}", + send_socket.local_addr().unwrap() + ); + pkt_fwd_sk_raw_fd_vec.push(send_socket.as_raw_fd()); + pkt_fwd_sk_vec.push(send_socket); + } + + + let pkt_fwd_tile_ring_capacity = num_frames * num_pkt_recv_tiles; + log::info!( + "Setting pkt_fwd tile's message ring capacity to {} (num_frames {} * num_pkt_recv_tiles {})", + pkt_fwd_tile_ring_capacity, + num_frames, + num_pkt_recv_tiles + ); + + // Create pkt_fwd message rings + // One ring per pkt_fwd tile + for _ in 0..num_pkt_fwd_tiles { + // Worst case scenario all frames from all pkt_recv tiles are sent to this pkt_fwd tile + // We set the ring capacity to that + let (packet_tx, packet_rx) = crate::mem::message_ring(pkt_fwd_tile_ring_capacity).expect("pkt_fwd ring"); + packet_tx_vec.push(packet_tx); + packet_rx_vec.push(packet_rx); + } + + // Spawn pkt_fwd tiles + for (pkt_fwd_idx, pkt_fwd_sk, packet_rx) in izip!( + 0..num_pkt_fwd_tiles, + pkt_fwd_sk_vec.into_iter(), + packet_rx_vec.into_iter() + ) { + let hot_dest_vec = Arc::clone(&dest_addr_vec); + let fill_tx_vec = fill_tx_vec.clone(); + let shmem_info_vec = shmem_info_vec.clone(); + let exit = Arc::clone(&exit); + let th = packet_fwd_tile( + pkt_fwd_idx, + hot_dest_vec, + pkt_fwd_sk, + packet_rx, + fill_tx_vec, + shmem_info_vec, + Arc::clone(&pk_fwd_stats), + exit, + tile_wait_group.get_tile_closed_signal(TileKind::PktFwd, pkt_fwd_idx), + ) + .expect("packet_fwd_tile"); + tile_thread_vec.push(th); + log::info!("Spawned pkt_fwd tile {}", pkt_fwd_idx); + } + + // Spawn pkt_recv tiles + for (pkt_recv_idx, pkt_recv_sk, fill_rx) in izip!( + 0..num_pkt_recv_tiles, + pkt_recv_sk_vec.into_iter(), + fill_rx_vec.into_iter() + ) { + + let mut recv_pkt_vec = vec![ + pkt_recv_sk + ]; + + if let Some(multicast_sk_vec) = &pkt_recv_multicast_sk_vec { + recv_pkt_vec.push(multicast_sk_vec[pkt_recv_idx].try_clone().expect("multicast sk clone")); + } + + if let Some(doublezero_v4_sk) = doublezero_v4_sk_vec.get(pkt_recv_idx) { + recv_pkt_vec.push(doublezero_v4_sk.try_clone().expect("doublezero v4 sk clone")); + } + + if let Some(doublezero_v6_sk) = doublezero_v6_sk_vec.get(pkt_recv_idx) { + recv_pkt_vec.push(doublezero_v6_sk.try_clone().expect("doublezero v6 sk clone")); + } + + let exit = Arc::clone(&exit); + let forwarder_stats = Arc::clone(&pk_recv_stats); + let packet_tx_vec_clone = packet_tx_vec.clone(); + let pkt_router_clone = pkt_router.clone(); + let jh = packet_recv_tile( + pkt_recv_idx, + recv_pkt_vec, + exit, + forwarder_stats, + fill_rx, + packet_tx_vec_clone, + pkt_router_clone, + tile_wait_group.get_tile_closed_signal(TileKind::PktRecv, pkt_recv_idx), + ) + .expect("packet_recv_tile"); + tile_thread_vec.push(jh); + log::info!("Spawned pkt_recv tile {}", pkt_recv_idx); + } + + + let (kind, idx) = tile_wait_group.wait_first(); + warn!("Tile of kind {kind:?} with idx {idx} has exited. Shutting down proxy system"); + + exit.store(true, Ordering::Release); + drop(fill_tx_vec); + drop(packet_tx_vec); + log::info!("Waiting for {} tile threads to exit", tile_thread_vec.len()); + + for th in tile_thread_vec { + let result = th.join(); + if let Err(e) = result { + error!("Tile thread join error: {:?}", e); + } + } +} + +/// Reset dedup + send metrics to influx +pub fn start_forwarder_accessory_thread( + metrics: Arc, + metrics_update_interval_ms: u64, + shutdown_receiver: Receiver<()>, + exit: Arc, +) -> JoinHandle<()> { + std::thread::Builder::new() + .name("ssPxyAccessory".to_string()) + .spawn(move || { + let metrics_tick = + crossbeam_channel::tick(Duration::from_millis(metrics_update_interval_ms)); + while !exit.load(Ordering::Relaxed) { + crossbeam_channel::select! { + // send metrics to influx + recv(metrics_tick) -> _ => { + metrics.report(); + metrics.reset(); + } + + // handle SIGINT shutdown + recv(shutdown_receiver) -> _ => { + break; + } + } + } + }) + .unwrap() +} + +// #[cfg(test)] +// mod tests { +// use std::{ +// net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket}, +// str::FromStr, +// sync::{Arc, Mutex, RwLock}, +// thread, +// thread::sleep, +// time::Duration, +// }; + +// use solana_perf::{ +// deduper::Deduper, +// packet::{Meta, Packet, PacketBatch}, +// }; +// use solana_sdk::packet::{PacketFlags, PACKET_DATA_SIZE}; + +// fn listen_and_collect(listen_socket: UdpSocket, received_packets: Arc>>>) { +// let mut buf = [0u8; PACKET_DATA_SIZE]; +// loop { +// listen_socket.recv(&mut buf).unwrap(); +// received_packets.lock().unwrap().push(Vec::from(buf)); +// } +// } + +// #[test] +// fn test_2shreds_3destinations() { +// let packet_batch = PacketBatch::new(vec![ +// Packet::new( +// [1; PACKET_DATA_SIZE], +// Meta { +// size: PACKET_DATA_SIZE, +// addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), +// port: 48289, // received on random port +// flags: PacketFlags::empty(), +// }, +// ), +// Packet::new( +// [2; PACKET_DATA_SIZE], +// Meta { +// size: PACKET_DATA_SIZE, +// addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), +// port: 9999, +// flags: PacketFlags::empty(), +// }, +// ), +// ]); +// let (packet_sender, packet_receiver) = crossbeam_channel::unbounded::(); +// packet_sender.send(packet_batch).unwrap(); + +// let dest_socketaddrs = vec![ +// SocketAddr::from_str("0.0.0.0:32881").unwrap(), +// SocketAddr::from_str("0.0.0.0:33881").unwrap(), +// SocketAddr::from_str("0.0.0.0:34881").unwrap(), +// ]; + +// let test_listeners = dest_socketaddrs +// .iter() +// .map(|socketaddr| { +// ( +// UdpSocket::bind(socketaddr).unwrap(), +// *socketaddr, +// // store results in vec of packet, where packet is Vec +// Arc::new(Mutex::new(vec![])), +// ) +// }) +// .collect::>(); + +// let udp_sender = UdpSocket::bind("0.0.0.0:10000").unwrap(); + +// // spawn listeners +// test_listeners +// .iter() +// .for_each(|(listen_socket, _socketaddr, to_receive)| { +// let socket = listen_socket.try_clone().unwrap(); +// let to_receive = to_receive.to_owned(); +// thread::spawn(move || listen_and_collect(socket, to_receive)); +// }); + +// // send packets +// recv_from_channel_and_send_multiple_dest( +// packet_receiver.recv(), +// &Arc::new(RwLock::new(Deduper::<2, [u8]>::new( +// &mut rand::thread_rng(), +// crate::forwarder::DEDUPER_NUM_BITS, +// ))), +// &udp_sender, +// &Arc::new(dest_socketaddrs), +// accept_all, +// true, +// &Arc::new(ShredMetrics::default()), +// ) +// .unwrap(); + +// // allow packets to be received +// sleep(Duration::from_millis(500)); + +// let received = test_listeners +// .iter() +// .map(|(_, _, results)| results.clone()) +// .collect::>(); + +// // check results +// for received in received.iter() { +// let received = received.lock().unwrap(); +// assert_eq!(received.len(), 2); +// assert!(received +// .iter() +// .all(|packet| packet.len() == PACKET_DATA_SIZE)); +// assert_eq!(received[0], [1; PACKET_DATA_SIZE]); +// assert_eq!(received[1], [2; PACKET_DATA_SIZE]); +// } + +// assert_eq!( +// received +// .iter() +// .fold(0, |acc, elem| acc + elem.lock().unwrap().len()), +// 6 +// ); +// } + +// #[test] +// fn test_dest_filter() { +// let packet_batch = PacketBatch::new(vec![ +// Packet::new( +// [1; PACKET_DATA_SIZE], +// Meta { +// size: PACKET_DATA_SIZE, +// addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), +// port: 48289, // received on random port +// flags: PacketFlags::empty(), +// }, +// ), +// Packet::new( +// [2; PACKET_DATA_SIZE], +// Meta { +// size: PACKET_DATA_SIZE, +// addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), +// port: 9999, +// flags: PacketFlags::empty(), +// }, +// ), +// ]); +// let (packet_sender, packet_receiver) = crossbeam_channel::unbounded::(); +// packet_sender.send(packet_batch).unwrap(); + +// let dest_socketaddrs = vec![ +// SocketAddr::from_str("0.0.0.0:32881").unwrap(), +// SocketAddr::from_str("0.0.0.0:33881").unwrap(), +// SocketAddr::from_str("0.0.0.0:34881").unwrap(), +// ]; + +// let blacklisted = SocketAddr::from_str("0.0.0.0:34881").unwrap(); // none blacklisted + +// let test_listeners = dest_socketaddrs +// .iter() +// .map(|socketaddr| { +// ( +// UdpSocket::bind(socketaddr).unwrap(), +// *socketaddr, +// // store results in vec of packet, where packet is Vec +// Arc::new(Mutex::new(vec![])), +// ) +// }) +// .collect::>(); + +// let udp_sender = UdpSocket::bind("0.0.0.0:10000").unwrap(); + +// // spawn listeners +// test_listeners +// .iter() +// .for_each(|(listen_socket, _socketaddr, to_receive)| { +// let socket = listen_socket.try_clone().unwrap(); +// let to_receive = to_receive.to_owned(); +// thread::spawn(move || listen_and_collect(socket, to_receive)); +// }); + +// // send packets +// recv_from_channel_and_send_multiple_dest( +// packet_receiver.recv(), +// &Arc::new(RwLock::new(Deduper::<2, [u8]>::new( +// &mut rand::thread_rng(), +// crate::forwarder::DEDUPER_NUM_BITS, +// ))), +// &udp_sender, +// &Arc::new(dest_socketaddrs), +// move |_origin, dest: SocketAddr| dest != blacklisted, +// true, +// &Arc::new(ShredMetrics::default()), +// ) +// .unwrap(); + +// // allow packets to be received +// sleep(Duration::from_millis(500)); + +// let received = test_listeners +// .iter() +// .take(test_listeners.len() - 1) // ignore blacklisted +// .map(|(_, _, results)| results.clone()) +// .collect::>(); + +// // check results +// for received in received.iter() { +// let received = received.lock().unwrap(); +// assert_eq!(received.len(), 2); +// assert!(received +// .iter() +// .all(|packet| packet.len() == PACKET_DATA_SIZE)); +// assert_eq!(received[0], [1; PACKET_DATA_SIZE]); +// assert_eq!(received[1], [2; PACKET_DATA_SIZE]); +// } + +// { +// let received = test_listeners[2].2.lock().unwrap(); // ensure blacklisted received nothing +// assert_eq!(received.len(), 0); +// } +// assert_eq!( +// received +// .iter() +// .fold(0, |acc, elem| acc + elem.lock().unwrap().len()), +// 4 +// ); +// } +// } diff --git a/proxy/src/triton_multicast_config.rs b/proxy/src/triton_multicast_config.rs new file mode 100644 index 00000000..c84b92d7 --- /dev/null +++ b/proxy/src/triton_multicast_config.rs @@ -0,0 +1,422 @@ +use std::{ + io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, num::NonZeroUsize, process::Command +}; + +use itertools::{Itertools, Either}; +use log::{debug, warn, info}; +use serde::Deserialize; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; + +fn run_ip_json(args: &[&str]) -> io::Result> { + let output = Command::new("ip").args(args).output()?; + if output.status.success() { + Ok(output.stdout) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + format!( + "`ip {}` failed with status {}", + args.join(" "), + output.status + ), + )) + } +} + +/// Parse multicast groups routed to `device` via `ip --json route show dev ` +pub fn get_ip_route_for_device(device: &str) -> io::Result> { + let stdout = run_ip_json(&["--json", "route", "show", "dev", device])?; + parse_ip_route_for_device(&stdout) +} + +// Pure JSON parsers for unit testing +pub fn parse_ip_route_for_device(bytes: &[u8]) -> io::Result> { + #[derive(Debug, Deserialize)] + struct RouteRow { + dst: String, + } + + let mut groups = serde_json::from_slice::>(bytes) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? + .into_iter() + .filter_map(|r| { + if let Some((base, mask_str)) = r.dst.split_once('/') { + let ip: IpAddr = base.parse().ok()?; + let mask: u8 = mask_str.parse().ok()?; + let is_exact = match ip { + IpAddr::V4(_) => mask == 32, // check if full-length mask (not partial) + IpAddr::V6(_) => mask == 128, + }; + (ip.is_multicast() && is_exact).then_some(ip) + } else { + let ip: IpAddr = r.dst.parse().ok()?; + ip.is_multicast().then_some(ip) + } + }) + .collect::>(); + + groups.sort_unstable(); + groups.dedup(); + Ok(groups) +} + +/// Return the primary IPv4 address configured on `device` (if any), via `ip --json addr show`. +pub fn ipv4_addr_for_device(device: &str) -> io::Result> { + let stdout = run_ip_json(&["--json", "addr", "show", "dev", device])?; + parse_ipv4_addr_from_ip_addr_show_json(&stdout) +} + +/// Return the interface index for `device` (if any), via `ip --json link show`. +pub fn ifindex_for_device(device: &str) -> io::Result> { + let stdout = run_ip_json(&["--json", "link", "show", "dev", device])?; + parse_ifindex_from_ip_link_show_json(&stdout) +} + +// Pure JSON parsers for unit testing +pub fn parse_ipv4_addr_from_ip_addr_show_json(bytes: &[u8]) -> io::Result> { + #[derive(Debug, Deserialize)] + struct AddrInfo { + family: Option, + local: Option, + } + #[derive(Debug, Deserialize)] + struct IfaceRow { + addr_info: Option>, + } + + let rows: Vec = + serde_json::from_slice(bytes).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + let ip = rows + .into_iter() + .flat_map(|row| row.addr_info.unwrap_or_default()) + .find_map(|info| { + (info.family.as_deref() == Some("inet")) + .then_some(info.local) + .flatten() + }) + .and_then(|s| s.parse::().ok()); + + Ok(ip) +} + +pub fn parse_ifindex_from_ip_link_show_json(bytes: &[u8]) -> io::Result> { + #[derive(Debug, Deserialize)] + struct LinkRow { + ifindex: Option, + } + + let rows: Vec = + serde_json::from_slice(bytes).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(rows.into_iter().last().and_then(|r| r.ifindex)) +} + + +/// Creates one UDP socket bound on `multicast_port` and joins applicable multicast groups. +/// If `multicast_ip` is provided, join just that group, otherwise parse `ip route list` for +/// entries on `device_name` and join all multicast groups found. +pub fn create_multicast_socket_on_device( + device_name: &str, + multicast_port: u16, + multicast_ip: Option, + num_threads: usize, +) -> std::io::Result<(Vec, Vec)> { + let device_ipv4 = ipv4_addr_for_device(device_name).unwrap_or_else(|e| { + debug!("Failed to resolve IPv4 address for device {device_name}: {e}"); + None + }); + let device_ifindex_v6 = ifindex_for_device(device_name).unwrap_or_else(|e| { + debug!("Failed to resolve ifindex for device {device_name}: {e}"); + None + }); + + let mut multicast_groups: Vec = Vec::new(); + let (groups_v4, groups_v6): (Vec, Vec) = match multicast_ip { + Some(IpAddr::V4(g)) => (vec![g], Vec::new()), + Some(IpAddr::V6(g6)) => (Vec::new(), vec![g6]), + None => match get_ip_route_for_device(device_name) { + Ok(ips) => ips.into_iter().partition_map(|ip| match ip { + IpAddr::V4(v4) => Either::Left(v4), + IpAddr::V6(v6) => Either::Right(v6), + }), + Err(e) => { + debug!("Failed to parse 'ip route list' for {device_name}: {e}"); + (Vec::new(), Vec::new()) + } + }, + }; + + multicast_groups.extend(groups_v4.iter().map(|ipv4| IpAddr::V4(*ipv4))); + multicast_groups.extend(groups_v6.iter().map(|ipv6| IpAddr::V6(*ipv6))); + + if groups_v4.is_empty() && groups_v6.is_empty() { + debug!("No multicast groups found for device {device_name}; skipping multicast listener"); + return Ok((Vec::new(), Vec::new())); + } + + let mut sockets_v4: Vec = Vec::new(); + if !groups_v4.is_empty() { + let addr_v4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), multicast_port); + + let first_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; + first_socket.set_reuse_address(true)?; + first_socket.set_reuse_port(true)?; + first_socket.set_nonblocking(true)?; + first_socket.bind(&SockAddr::from(addr_v4))?; + let actual_socket_addr = first_socket.local_addr()?; + + for g in &groups_v4 { + first_socket.join_multicast_v4(g, &device_ipv4.unwrap_or(Ipv4Addr::UNSPECIFIED))?; + info!("Joined IPv4 multicast group {g} port {multicast_port}"); + } + sockets_v4.push(first_socket.into()); + + // Create N-1 additional sockets on the same port for load balancing + for _ in 1..num_threads { + let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + socket.bind(&actual_socket_addr)?; + socket.set_nonblocking(true)?; + for g in &groups_v4 { + socket.join_multicast_v4(g, &device_ipv4.unwrap_or(Ipv4Addr::UNSPECIFIED))?; + } + sockets_v4.push(socket.into()); + } + } + + let mut sockets_v6: Vec = Vec::new(); + if !groups_v6.is_empty() { + let addr_v6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), multicast_port); + + let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_only_v6(true)?; // IPv6-only + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + socket.set_nonblocking(true)?; + socket.bind(&SockAddr::from(addr_v6))?; + let actual_socket_addr = socket.local_addr()?; + for g in &groups_v6 { + socket.join_multicast_v6(g, device_ifindex_v6.unwrap_or(0))?; + info!("Joined IPv6 multicast group {g} port {multicast_port}"); + } + sockets_v6.push(socket.into()); + + // Create N-1 additional sockets on the same port for load balancing + for _ in 1..num_threads { + let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_only_v6(true)?; + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + socket.bind(&actual_socket_addr)?; + socket.set_nonblocking(true)?; + for g in &groups_v6 { + socket.join_multicast_v6(g, device_ifindex_v6.unwrap_or(0))?; + } + sockets_v6.push(socket.into()); + } + } + + Ok((sockets_v4, sockets_v6)) +} + +pub struct TritonMulticastConfigV4 { + pub multicast_ip: Ipv4Addr, + pub bind_ifname: Option, + pub listen_port: u16, +} + +pub struct TritonMulticastConfigV6 { + pub multicast_ip: Ipv6Addr, + pub device_ifname: String, + pub listen_port: u16, +} + +pub enum TritonMulticastConfig { + Ipv4(TritonMulticastConfigV4), + Ipv6(TritonMulticastConfigV6), +} + +impl TritonMulticastConfig { + pub fn ip(&self) -> IpAddr { + match self { + TritonMulticastConfig::Ipv4(cfg) => IpAddr::V4(cfg.multicast_ip), + TritonMulticastConfig::Ipv6(cfg) => IpAddr::V6(cfg.multicast_ip), + } + } +} + +pub fn create_multicast_sockets_triton_v4( + config: &TritonMulticastConfigV4, + num_threads: NonZeroUsize, +) -> io::Result> { + let device_ip = match config.bind_ifname.as_ref() { + Some(ifname) => { + ipv4_addr_for_device(ifname)?.ok_or_else(|| + io::Error::new(io::ErrorKind::NotFound, format!("No IPv4 address found for device {ifname}")) + )? + }, + None => Ipv4Addr::UNSPECIFIED, + }; + let port = config.listen_port; + log::info!("multicast device {} has ip {}", config.bind_ifname.as_deref().unwrap_or("unspecified"), device_ip); + // Step 1: Create first socket, port = 0 → random ephemeral port + let first_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; + first_socket.set_reuse_address(true)?; + first_socket.set_reuse_port(true)?; + first_socket.set_nonblocking(true)?; + first_socket.bind(&SockAddr::from(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port)))?; + first_socket.join_multicast_v4(&config.multicast_ip, &device_ip)?; + log::info!("Joined multicast group {} on device IP {}", config.multicast_ip, device_ip); + let local_port = first_socket.local_addr()?.as_socket().unwrap().port(); + + // Step 2: Create N-1 sockets using that same port + let mut sockets = Vec::with_capacity(num_threads.get()); + sockets.push(first_socket.into()); + + for _ in 1..num_threads.get() { + let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + socket.bind(&SockAddr::from(SocketAddr::new( + IpAddr::V4(Ipv4Addr::UNSPECIFIED), + local_port, + )))?; + socket.set_nonblocking(true)?; + socket.join_multicast_v4(&config.multicast_ip, &device_ip)?; + sockets.push(socket.into()); + } + + Ok(sockets) +} + +// fn create_multicast_socket_triton_v6( +// config: &TritonMulticastConfigV6, +// num_threads: usize, +// ) -> Result { +// let TritonMulticastConfigV6 { +// multicast_ip, +// device_ifname, +// } = config; + +// let addrv6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); +// let socket = UdpSocket::bind(addrv6)?; +// let ifindex = ifindex_for_device(device_ifname)? +// .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, format!("No such device {device_ifname}")))?; +// socket.join_multicast_v6(multicast_ip, ifindex)?; +// Ok(socket) +// } + + +pub fn create_multicast_sockets_triton_v6( + config: &TritonMulticastConfigV6, + num_threads: NonZeroUsize, +) -> io::Result> { + // Get the interface index for the device name (e.g. "eth0") + let ifindex = ifindex_for_device(&config.device_ifname)? + .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, format!("No such device {}", config.device_ifname)))?; + + // Step 1: Bind first socket to port 0 to let kernel choose a random available port + let first_socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + first_socket.set_only_v6(true)?; // IPv6-only + first_socket.set_reuse_address(true)?; + first_socket.set_reuse_port(true)?; + first_socket.set_nonblocking(true)?; + first_socket.bind(&SockAddr::from(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), config.listen_port)))?; + first_socket.join_multicast_v6(&config.multicast_ip, ifindex)?; + let local_port = first_socket.local_addr()?.as_socket().unwrap().port(); + // Step 2: Create N-1 additional sockets on the same port for load balancing + let mut sockets = Vec::with_capacity(num_threads.get()); + sockets.push(first_socket.into()); + + for _ in 1..num_threads.get() { + let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_only_v6(true)?; + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + socket.bind(&SockAddr::from(SocketAddr::new( + IpAddr::V6(Ipv6Addr::UNSPECIFIED), + local_port, + )))?; + socket.set_nonblocking(true)?; + socket.join_multicast_v6(&config.multicast_ip, ifindex)?; + sockets.push(socket.into()); + } + + Ok(sockets) +} + +pub fn create_multicast_sockets_triton( + config: &TritonMulticastConfig, + num_threads: NonZeroUsize, +) -> Result, io::Error> { + + match config { + TritonMulticastConfig::Ipv4(cfg) => { + create_multicast_sockets_triton_v4(cfg, num_threads) + } + TritonMulticastConfig::Ipv6(cfg) => { + create_multicast_sockets_triton_v6(cfg, num_threads) + } + } +} + + +#[cfg(test)] +mod tests { + use std::net::Ipv4Addr; + + use super::{ + parse_ifindex_from_ip_link_show_json, parse_ip_route_for_device, + parse_ipv4_addr_from_ip_addr_show_json, + }; + + #[test] + fn parse_ip_route_for_device_test() { + let json = r#"[{"dst":"169.254.2.112/31","protocol":"kernel","scope":"link","prefsrc":"169.254.2.113","flags":[]},{"dst":"233.84.178.2","gateway":"169.254.2.112","protocol":"static","flags":[]}]"#; + let parsed = parse_ip_route_for_device(json.as_bytes()).unwrap(); + assert_eq!(parsed, vec![Ipv4Addr::new(233, 84, 178, 2)]); + } + + #[test] + fn parse_ipv4_addr_from_addr() { + let json = r#"[ + {"addr_info":[ + {"family":"inet6","local":"fe80::1234"}, + {"family":"inet","local":"192.168.1.10"} + ]} + ]"#; + let parsed = parse_ipv4_addr_from_ip_addr_show_json(json.as_bytes()).unwrap(); + assert_eq!(parsed, Some(Ipv4Addr::new(192, 168, 1, 10))); + } + + #[test] + fn parse_ipv4_addr_from_addr_show_malformed() { + let json = r#"{"not":"an array"}"#; + let res = parse_ipv4_addr_from_ip_addr_show_json(json.as_bytes()); + assert!(res.is_err()); + } + + #[test] + fn parse_ifindex_from_link_show_present() { + let json = r#"[ + {"ifindex":3,"ifname":"lo"} + ]"#; + let parsed = parse_ifindex_from_ip_link_show_json(json.as_bytes()).unwrap(); + assert_eq!(parsed, Some(3)); + } + + #[test] + fn parse_ifindex_from_link_show_empty() { + let json = r#"[]"#; + let parsed = parse_ifindex_from_ip_link_show_json(json.as_bytes()).unwrap(); + assert_eq!(parsed, None); + } + + #[test] + fn parse_ifindex_from_link_show_malformed() { + let json = r#"{"ifindex":3}"#; + let res = parse_ifindex_from_ip_link_show_json(json.as_bytes()); + assert!(res.is_err()); + } +} diff --git a/run-triton-proxy.sh b/run-triton-proxy.sh new file mode 100644 index 00000000..9187c0c4 --- /dev/null +++ b/run-triton-proxy.sh @@ -0,0 +1,13 @@ + +# send random traffic +sudo nping --udp -p 8002 -c 0 --rate 1 --data-length 1200 127.0.0.1 + +while true; do + # Generate 1200 random bytes from /dev/urandom and send them + head -c 1200 /dev/urandom | socat - UDP:127.0.0.1:8002 + sleep 0.01 # Add a small delay between packets +done + + +# run triton-proxy +cargo run --bin triton-proxy -- forward-only --src-bind-addr 127.0.0.1 --src-bind-port 8002 --prometheus-bind-addr 127.0.0.1:9999 --dest-ip-ports 127.0.0.1:8989 \ No newline at end of file diff --git a/scripts/build-dist.sh b/scripts/build-dist.sh new file mode 100755 index 00000000..d8b4ae01 --- /dev/null +++ b/scripts/build-dist.sh @@ -0,0 +1,9 @@ +#!/bin/bash + + +mkdir -p dist +rm -rf dist/* + +cargo build --release --bin triton-shredproxy + +mv target/release/triton-shredproxy dist/triton-shredproxy-ubuntu-22.04 \ No newline at end of file diff --git a/setup_net.sh b/setup_net.sh new file mode 100755 index 00000000..da732aec --- /dev/null +++ b/setup_net.sh @@ -0,0 +1,26 @@ +set -e +# 1. Create the namespace +sudo ip netns add ns1 + +# 2. Create the virtual cable (veth0 <-> veth1) +sudo ip link add veth0 type veth peer name veth1 + +# 3. Move veth1 end into the namespace +sudo ip link set veth1 netns ns1 + +# 4. Assign IP to Host side (veth0) - The Driver +sudo ip addr add 172.31.0.1/24 dev veth0 +sudo ip link set veth0 up + +# 5. Assign IP to Namespace side (veth1) - The Server +sudo ip netns exec ns1 ip addr add 172.31.0.2/24 dev veth1 +sudo ip netns exec ns1 ip link set veth1 up + +# 6. Turn off offloading (Crucial for AF_XDP) +#sudo ethtool -K veth0 gro off +#sudo ip netns exec ns1 ethtool -K veth1 gro off + +#nping --udp -p 1234 --data-length 500 -c 10 + +# 7. Verify connectivity +ping -c 2 172.31.0.2 \ No newline at end of file