From 16f9c884fc49430f14a426e8efc48d5b61f3a788 Mon Sep 17 00:00:00 2001 From: Louis-Vincent Boudreault Date: Tue, 13 Jan 2026 02:39:54 +0000 Subject: [PATCH 01/13] wip --- Cargo.lock | 5 +- Cargo.toml | 1 + proxy/Cargo.toml | 1 + proxy/src/main.rs | 3 + proxy/src/mem.rs | 556 ++++++++++++++++ proxy/src/recv_mmsg.rs | 317 ++++++++++ proxy/src/triton_forwarder.rs | 1119 +++++++++++++++++++++++++++++++++ 7 files changed, 2000 insertions(+), 2 deletions(-) create mode 100644 proxy/src/mem.rs create mode 100644 proxy/src/recv_mmsg.rs create mode 100644 proxy/src/triton_forwarder.rs diff --git a/Cargo.lock b/Cargo.lock index 0eaea59..1d3f992 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -902,9 +902,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" [[package]] name = "bzip2" @@ -2813,6 +2813,7 @@ dependencies = [ "arc-swap", "bincode", "borsh 1.5.7", + "bytes", "clap", "crossbeam-channel", "dashmap", diff --git a/Cargo.toml b/Cargo.toml index b72a26d..54f0dc7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ ahash = "0.8" arc-swap = "1.6" bincode = "1.3.3" borsh = "1.5.3" +bytes = "1.11.0" clap = { version = "4", features = ["derive", "env"] } crossbeam-channel = "0.5.8" dashmap = "5" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 163260b..3c7e1a9 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -11,6 +11,7 @@ ahash = { workspace = true } arc-swap = { workspace = true } bincode = { workspace = true } borsh = { workspace = true } +bytes = { workspace = true } clap = { workspace = true } crossbeam-channel = { workspace = true } dashmap = { workspace = true } diff --git a/proxy/src/main.rs b/proxy/src/main.rs index d04e333..edb8deb 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -31,6 +31,9 @@ mod multicast_config; mod server; mod token_authenticator; mod prom; +// mod recv_mmsg; +mod mem; +// mod triton_forwarder; #[derive(Clone, Debug, Parser)] #[clap(author, version, about, long_about = None)] diff --git a/proxy/src/mem.rs b/proxy/src/mem.rs new file mode 100644 index 0000000..0c2a7df --- /dev/null +++ b/proxy/src/mem.rs @@ -0,0 +1,556 @@ +use std::{ + cell::UnsafeCell, + hint::spin_loop, + ops::{Index, IndexMut}, + sync::{ + atomic::{AtomicI32, AtomicUsize, Ordering}, + Arc, + }, + thread::{self, Thread}, +}; + +use bytes::{Buf, BufMut, buf::UninitSlice}; + +#[derive(Debug, thiserror::Error)] +#[error("allocation error")] +pub struct AllocError; + +#[repr(C)] +pub struct SharedMem { + ptr: *mut u8, + aligned_size: usize, + capacity: usize, +} + +pub fn try_alloc_shared_mem( + align: usize, + capacity: usize, + huge: bool, +) -> Result<*mut u8, AllocError> { + // assert!(align.is_power_of_two(), "alignment must be a power of two"); + assert!( + capacity.is_power_of_two(), + "capacity must be a power of two" + ); + let aligned_size = capacity * align; + let ptr = unsafe { + libc::mmap( + std::ptr::null_mut(), + aligned_size, + libc::PROT_READ | libc::PROT_WRITE, + libc::MAP_SHARED | libc::MAP_ANONYMOUS | if huge { libc::MAP_HUGETLB } else { 0 }, + -1, + 0, + ) + }; + + if std::ptr::eq(ptr, libc::MAP_FAILED) { + return Err(AllocError); + } + + // zero initialize the memory + unsafe { + std::ptr::write_bytes(ptr as *mut u8, 0, aligned_size); + } + + 
Ok(ptr as *mut u8) +} + + + +impl SharedMem { + fn new(element_size: usize, capacity: usize, huge: bool) -> Result { + let ptr = try_alloc_shared_mem(element_size, capacity, huge)?; + let aligned_size = capacity * element_size; + + Ok(Self { + ptr, + aligned_size, + capacity, + }) + } + + fn dealloc(&self) { + unsafe { + libc::munmap(self.ptr as *mut libc::c_void, self.aligned_size); + } + } +} + +impl Drop for SharedMem { + fn drop(&mut self) { + self.dealloc(); + } +} + +#[derive(Debug)] +#[repr(C, align(16))] +pub struct FrameDesc { + pub ptr: *mut u8, + pub frame_size: usize, +} + +#[derive(Debug)] +#[repr(C, align(32))] +pub struct FrameBufMut { + ptr: *mut u8, + desc: FrameDesc, +} + +unsafe impl Send for FrameBufMut {} + +#[derive(Debug)] +#[repr(C, align(32))] +pub struct FrameBuf { + curr_ptr: *mut u8, + len: usize, + desc: FrameDesc, +} + +impl FrameBuf { + #[inline] + pub fn len(&self) -> usize { + let end = unsafe { self.desc.ptr.add(self.len) }; + (end as usize) - (self.curr_ptr as usize) + } +} + +unsafe impl Send for FrameBuf {} + +impl From for FrameBuf { + fn from(buf_mut: FrameBufMut) -> Self { + let len = (buf_mut.ptr as usize) - (buf_mut.desc.ptr as usize); + Self { + curr_ptr: buf_mut.desc.ptr, + len, + desc: buf_mut.desc, + } + } +} + +impl FrameDesc { + pub fn as_mut_buf(&self) -> FrameBufMut { + FrameBufMut { + ptr: self.ptr, + desc: FrameDesc { + ptr: self.ptr, + frame_size: self.frame_size, + }, + } + } +} + +impl From for FrameBufMut { + fn from(desc: FrameDesc) -> Self { + Self { + ptr: desc.ptr, + desc, + } + } +} + + +impl FrameBufMut { + #[inline] + pub fn base(&self) -> *mut u8 { + ((self.ptr as usize) & !(self.desc.frame_size - 1)) as *mut u8 + } + + pub fn len(&self) -> usize { + let base = self.base() as usize; + (self.ptr as usize) - base + } + + #[inline] + pub fn capacity(&self) -> usize { + self.desc.frame_size + } + + #[inline] + pub fn cast_to(&self) -> *mut T { + self.base() as *mut T + } + + #[inline] + fn end_ptr(&self) -> *const u8 { + unsafe { self.base().add(self.capacity()) } + } +} + +unsafe impl BufMut for FrameBufMut { + fn remaining_mut(&self) -> usize { + // given that ptr must always aligned with `frame_align`, + // we just be able to infer the remaining mut size from frame_align + let frame_offset = (self.ptr as usize) & (self.desc.frame_size - 1); + self.desc.frame_size - frame_offset + } + + unsafe fn advance_mut(&mut self, cnt: usize) { + let new_ptr = self.ptr.add(cnt); + assert!( + new_ptr as *const u8 <= self.end_ptr(), + "advance_mut out of bounds" + ); + self.ptr = new_ptr; + } + + fn chunk_mut(&mut self) -> &mut bytes::buf::UninitSlice { + unsafe { UninitSlice::from_raw_parts_mut(self.ptr, self.remaining_mut()) } + } +} + +impl Buf for FrameBuf { + fn remaining(&self) -> usize { + let end = unsafe { self.desc.ptr.add(self.len) }; + (end as usize) - (self.curr_ptr as usize) + } + + fn chunk(&self) -> &[u8] { + unsafe { std::slice::from_raw_parts(self.curr_ptr, self.remaining()) } + } + + fn advance(&mut self, cnt: usize) { + let new_ptr = unsafe { self.curr_ptr.add(cnt) }; + let end = unsafe { self.desc.ptr.add(self.len) }; + assert!( + new_ptr as *const u8 <= end, + "advance out of bounds" + ); + self.curr_ptr = new_ptr; + } +} + +use std::{ptr, sync::atomic::AtomicBool}; + +// We wrap T to include a 'ready' flag for each slot +#[repr(C, align(64))] +struct Slot { + data: std::mem::MaybeUninit, + is_ready: AtomicBool, +} + +struct RingInner { + buf: *mut Slot, // Changed to Slot + capacity: usize, + mask: usize, + head: 
AtomicUsize, // Producer index (reserved) + tail: AtomicUsize, // Consumer index + futex_flag: AtomicI32, + shmem: Option, +} + +impl Drop for RingInner { + fn drop(&mut self) { + if let Some(shmem) = self.shmem.take() { + let mut tail = self.tail.load(Ordering::Acquire); + let head = self.head.load(Ordering::Acquire); + + // Drop initialized slots + while tail != head { + unsafe { + let slot = &mut *self.buf.add(tail & self.mask); + if slot.is_ready.load(Ordering::Acquire) { + ptr::drop_in_place(slot.data.as_mut_ptr()); + } + } + tail = tail.wrapping_add(1); + } + + drop(shmem); + } + } +} + +unsafe impl Send for RingInner {} +unsafe impl Sync for RingInner {} + +pub struct Tx { + inner: Arc>, +} + +impl Clone for Tx { + fn clone(&self) -> Self { + Self { + inner: Arc::clone(&self.inner), + } + } +} + +pub struct Rx { + inner: Arc>, +} + +pub fn message_ring(capacity: usize) -> Result<(Tx, Rx), AllocError> { + let capacity = capacity.next_power_of_two(); + let align = std::mem::size_of::>(); + + // Allocate memory for Slots + let shmem = SharedMem::new(align, capacity, false)?; + let ptr = shmem.ptr as *mut Slot; + // Initialize the is_ready flags to false + for i in 0..capacity { + unsafe { + let slot_ptr = ptr.add(i); + ptr::write(&mut (*slot_ptr).is_ready, AtomicBool::new(false)); + } + } + + let inner = Arc::new(RingInner { + buf: ptr, + capacity, + mask: capacity - 1, + head: AtomicUsize::new(0), + tail: AtomicUsize::new(0), + futex_flag: AtomicI32::new(0), + shmem: Some(shmem), + }); + + Ok(( + Tx { + inner: Arc::clone(&inner), + }, + Rx { inner }, + )) +} + +impl Tx { + pub fn send(&self, value: T) -> Result<(), T> { + loop { + // 1. Load head and tail to check if the ring is full. + // head: Relaxed is okay here because it's only a hint for the CAS. + // tail: Acquire is REQUIRED to ensure we don't overwrite data + // the consumer hasn't finished reading yet. + let head = self.inner.head.load(Ordering::Relaxed); + let tail = self.inner.tail.load(Ordering::Acquire); + + if head.wrapping_sub(tail) >= self.inner.capacity { + return Err(value); // Ring is full + } + + // 2. Claim a slot using Compare-and-Swap (CAS). + // We use SeqCst or AcqRel here to ensure that once we "win" this slot, + // we have a synchronized view of the memory. + if self + .inner + .head + .compare_exchange_weak(head, head + 1, Ordering::AcqRel, Ordering::Relaxed) + .is_ok() + { + unsafe { + // 3. Calculate slot location. + let slot = &*self.inner.buf.add(head & self.inner.mask); + + // 4. Write the data into the MaybeUninit. + // We use .write() which is a wrapper for ptr::write. + ptr::write(slot.data.as_ptr() as *mut T, value); + + // 5. RELEASE the data to the consumer. + // This store ensures the data write above is visible to + // any thread that performs an Acquire load on is_ready. + slot.is_ready.store(true, Ordering::Release); + } + + // 6. Futex Wake Logic. + // If the consumer is sleeping (futex_flag == 0), we wake them. + // We use Release to ensure the flag update is visible. + if self.inner.futex_flag.swap(1, Ordering::Release) == 0 { + unsafe { + libc::syscall( + libc::SYS_futex, + &self.inner.futex_flag as *const AtomicI32, + libc::FUTEX_WAKE, + 1, // Wake 1 thread + ); + } + } + return Ok(()); + } + // If CAS failed, another producer grabbed 'head'. + // The loop will retry with the new head value. 
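+            // spin_loop() is only a CPU relaxation hint (e.g. PAUSE on x86); it never
+            // yields to the OS, so a losing producer re-reads head/tail and retries the
+            // CAS almost immediately.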
+ std::hint::spin_loop(); + } + } +} + +impl Rx { + pub fn recv(&mut self) -> T { + for _ in 0..999 { + if let Some(val) = self.try_recv() { + return val; + } + spin_loop(); + } + + loop { + if let Some(val) = self.try_recv() { + return val; + } + + self.inner.futex_flag.store(0, Ordering::SeqCst); + + if let Some(val) = self.try_recv() { + return val; + } + + unsafe { + libc::syscall( + libc::SYS_futex, + &self.inner.futex_flag as *const AtomicI32, + libc::FUTEX_WAIT, + 0, + std::ptr::null::(), + ); + } + } + } + + pub fn try_recv(&mut self) -> Option { + let tail = self.inner.tail.load(Ordering::Relaxed); + + unsafe { + let slot = &*self.inner.buf.add(tail & self.inner.mask); + + // IMPORTANT: In MPSC, even if head > tail, the data at tail might + // not be written yet because the producer was interrupted. + if !slot.is_ready.load(Ordering::Acquire) { + return None; + } + + let val = ptr::read(slot.data.as_ptr()); + + // Reset the flag for the next time this slot is used + slot.is_ready.store(false, Ordering::Release); + + // Increment tail to free the slot + self.inner.tail.store(tail + 1, Ordering::Release); + Some(val) + } + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashSet, sync::Barrier}; + + use super::*; + + #[test] + fn test_mpsc_contention() { + let capacity = 1024; + let (tx, mut rx) = message_ring::(capacity).unwrap(); + + let num_producers = 4; + let msgs_per_producer = 1000; + let barrier = Arc::new(Barrier::new(num_producers + 1)); + let mut handles = Vec::new(); + + // Start Producers + for p in 0..num_producers { + let tx_clone = tx.clone(); + let b_clone = Arc::clone(&barrier); + handles.push(thread::spawn(move || { + b_clone.wait(); // Synchronize start + for i in 0..msgs_per_producer { + let val = p * 10000 + i; + while tx_clone.send(val).is_err() { + spin_loop(); // Wait if ring is full + } + } + })); + } + + barrier.wait(); // Start everyone at once + + let mut received = HashSet::new(); + let total_expected = num_producers * msgs_per_producer; + + for _ in 0..total_expected { + received.insert(rx.recv()); + } + + assert_eq!(received.len(), total_expected); + for h in handles { + h.join().unwrap(); + } + } + + #[test] + fn test_frame_buffer_lifecycle() { + let align = 4096; + let capacity = 1; + // 1. Setup the memory pool + let mem = SharedMem::new(align, capacity, false).unwrap(); + + // At this point, the fill_ring inside PagedAlignedMem logic + // should have been populated. Let's create our own handles for testing. + let (tx_fill, mut rx_fill) = message_ring::(capacity).unwrap(); + let (rx_tx, mut rx_rx) = message_ring::(capacity).unwrap(); + // Manually push frames into our test fill ring + for i in 0..capacity { + tx_fill + .send(FrameDesc { + ptr: unsafe { mem.ptr.add(i * align) }, + frame_size: align, + }) + .unwrap(); + } + + // 2. Simulate taking a frame from the pool + let desc = rx_fill.recv(); + let expected_ptr = desc.ptr; + println!("Received frame at ptr: {:p}", expected_ptr); + let mut buf = desc.as_mut_buf(); + assert_eq!(buf.remaining_mut(), 4096); + buf.put_u32(0xDEADBEEF); + assert_eq!(buf.remaining_mut(), 4092); + + rx_tx.send(desc).unwrap(); + + // 3. Verify the frame returned to the fill ring + let returned_desc = rx_rx.recv(); + assert_eq!(returned_desc.ptr, expected_ptr); + // 4. 
Verify the frame is zeroed out + } + + #[test] + fn test_blocking_recv() { + let (tx, mut rx) = message_ring::(16).unwrap(); + + let handle = thread::spawn(move || { + thread::sleep(std::time::Duration::from_millis(200)); + tx.send(42).unwrap(); + }); + + let start = std::time::Instant::now(); + let val = rx.recv(); // Should block for ~200ms + + assert_eq!(val, 42); + assert!(start.elapsed().as_millis() >= 200); + handle.join().unwrap(); + } + + #[test] + fn test_buf_and_bufmut_impls() { + let frame_size = 4096; + let shmem = SharedMem::new(frame_size, 1, false).unwrap(); + let desc = FrameDesc { + ptr: shmem.ptr, + frame_size, + }; + + let mut buf_mut: FrameBufMut = desc.into(); + assert_eq!(buf_mut.remaining_mut(), 4096); + buf_mut.put_slice(&[1, 2, 3, 4]); + assert_eq!(buf_mut.remaining_mut(), 4092); + assert_eq!(buf_mut.len(), 4); + + let mut buf: FrameBuf = buf_mut.into(); + assert_eq!(buf.len(), 4); + assert_eq!(buf.remaining(), 4); + let chunk = buf.chunk(); + assert_eq!(chunk, &[1, 2, 3, 4]); + buf.advance(4); + assert_eq!(buf.remaining(), 0); + assert_eq!(buf.len(), 0) + } +} diff --git a/proxy/src/recv_mmsg.rs b/proxy/src/recv_mmsg.rs new file mode 100644 index 0000000..2c030c8 --- /dev/null +++ b/proxy/src/recv_mmsg.rs @@ -0,0 +1,317 @@ +use bytes::BufMut; +use itertools::izip; +use libc::{AF_INET, AF_INET6, MSG_WAITFORONE, iovec, mmsghdr, msghdr, sockaddr_storage}; +use socket2::socklen_t; +use solana_perf::packet::{NUM_RCVMMSGS, PACKETS_PER_BATCH}; +use solana_sdk::packet::{Meta, PACKET_DATA_SIZE, Packet}; +use log::{error, trace}; +use solana_streamer::{recvmmsg::recv_mmsg, streamer::StreamerReceiveStats}; +use std::{ + cmp, collections::VecDeque, io, mem::{self, MaybeUninit, zeroed}, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, UdpSocket}, os::fd::AsRawFd, sync::atomic::{AtomicBool, Ordering}, time::{Duration, Instant} +}; + +use crate::mem::{FrameBuffer, FrameDesc, PagedAlignedMem, Rx, Tx, try_alloc_shared_mem}; + + +pub struct RecvMemConfig { + pub frames_count: usize, + pub hugepages: bool, +} + + +fn recv_loop( + socket: &UdpSocket, + exit: &AtomicBool, + stats: &StreamerReceiveStats, + coalesce: Duration, + mem_config: &RecvMemConfig, +) -> std::io::Result<()> { + + let data_shmem = try_alloc_shared_mem(PACKET_DATA_SIZE.next_power_of_two(), mem_config.frames_count, mem_config.hugepages).expect("try_alloc_shared_mem"); + let mut packet_batch = Vec::with_capacity(PACKETS_PER_BATCH); + loop { + // Check for exit signal, even if socket is busy + // (for instance the leader transaction socket) + if exit.load(Ordering::Relaxed) { + return Ok(()); + } + + + if let Ok(len) = recv_from(&mut packet_batch, socket, coalesce) { + if len > 0 { + let StreamerReceiveStats { + packets_count, + packet_batches_count, + full_packet_batches_count, + max_channel_len, + .. 
+ } = stats; + + packets_count.fetch_add(len, Ordering::Relaxed); + packet_batches_count.fetch_add(1, Ordering::Relaxed); + max_channel_len.fetch_max(packet_batch_sender.len(), Ordering::Relaxed); + if len == PACKETS_PER_BATCH { + full_packet_batches_count.fetch_add(1, Ordering::Relaxed); + } + packet_batch + .iter_mut() + .for_each(|p| p.meta_mut().set_from_staked_node(is_staked_service)); + packet_batch_sender.send(packet_batch)?; + } + break; + } + } +} + + + +pub fn recv_from( + fill_ring_rx: &mut Rx, + fill_ring_tx: &Tx, + socket: &UdpSocket, + max_wait: Duration, + batch: &mut Vec +) -> std::io::Result { + // let mut i: usize = 0; + //DOCUMENTED SIDE-EFFECT + //Performance out of the IO without poll + // * block on the socket until it's readable + // * set the socket to non blocking + // * read until it fails + // * set it back to blocking before returning + socket.set_nonblocking(false)?; + trace!("receiving on {}", socket.local_addr().unwrap()); + let start = Instant::now(); + + assert!(batch.capacity() >= PACKETS_PER_BATCH); + + struct Defer<'a> { + i: usize, + allocated_frame: usize, + batch: &'a mut Vec, + }; + + impl Drop for Defer<'_> { + fn drop(&mut self) { + // Return unused frames to the fill ring + let exceeding_allocs = self.allocated_frame.saturating_sub(self.i); + (0..exceeding_allocs).for_each(|_| { + if let Some(unused_buffer) = self.batch.pop() { + drop(unused_buffer); + } + }); + self.allocated_frame = 0; + } + } + + let mut defer = Defer { + i: 0, + allocated_frame: 0, + batch, + }; + + loop { + + let frame_desc = fill_ring_rx.recv(); + let buffer = FrameBuffer::new(frame_desc, fill_ring_tx.clone()); + defer.allocated_frame += 1; + defer.batch[defer.i] = TritonPacket::new(buffer); + + let mut j = defer.i + 1; + while j < PACKETS_PER_BATCH { + + let Some(frame_desc) = fill_ring_rx.try_recv() else { + break; + }; + let buffer = FrameBuffer::new(frame_desc, fill_ring_tx.clone()); + defer.batch[j] = TritonPacket::new(buffer); + defer.allocated_frame += 1; + j += 1; + + } + + match triton_recv_mmsg(socket, &mut defer.batch[defer.i..j]) { + Err(_) if defer.i > 0 => { + if start.elapsed() > max_wait { + break; + } + } + Err(e) => { + trace!("recv_from err {:?}", e); + return Err(e); + } + Ok(npkts) => { + if defer.i == 0 { + socket.set_nonblocking(true)?; + } + trace!("got {} packets", npkts); + defer.i += npkts; + // Try to batch into big enough buffers + // will cause less re-shuffling later on. + if start.elapsed() > max_wait || defer.i >= PACKETS_PER_BATCH { + break; + } + } + } + } + + Ok(defer.i) +} + +pub struct TritonPacket { + pub buffer: FrameBuffer, + pub meta: Meta, +} + +impl TritonPacket { + pub fn new(buffer: FrameBuffer) -> Self { + Self { + buffer, + meta: Meta::default(), + } + } + + pub fn meta_mut(&mut self) -> &mut Meta { + &mut self.meta + } +} + + +pub fn triton_recv_mmsg(sock: &UdpSocket, packets: &mut [TritonPacket]) -> io::Result { + // Should never hit this, but bail if the caller didn't provide any Packets + // to receive into + if packets.is_empty() { + return Ok(0); + } + // Assert that there are no leftovers in packets. 
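+    // Minimal sketch of the check the comment above asks for, mirroring the
+    // debug_assert in solana_streamer's recv_mmsg (assumes Meta derives Default + PartialEq):
+    debug_assert!(packets.iter().all(|pkt| pkt.meta == Meta::default()));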
+ const SOCKADDR_STORAGE_SIZE: socklen_t = mem::size_of::() as socklen_t; + + let mut iovs = [MaybeUninit::uninit(); NUM_RCVMMSGS]; + let mut addrs = [MaybeUninit::zeroed(); NUM_RCVMMSGS]; + let mut hdrs = [MaybeUninit::uninit(); NUM_RCVMMSGS]; + + let sock_fd = sock.as_raw_fd(); + let count = cmp::min(iovs.len(), packets.len()); + + for (packet, hdr, iov, addr) in + izip!(packets.iter_mut(), &mut hdrs, &mut iovs, &mut addrs).take(count) + { + let buffer = packet.buffer.base(); + iov.write(iovec { + iov_base: buffer as *mut libc::c_void, + iov_len: PACKET_DATA_SIZE, + }); + + let msg_hdr = create_msghdr(addr, SOCKADDR_STORAGE_SIZE, iov); + + hdr.write(mmsghdr { + msg_len: 0, + msg_hdr, + }); + } + + let mut ts = libc::timespec { + tv_sec: 1, + tv_nsec: 0, + }; + // TODO: remove .try_into().unwrap() once rust libc fixes recvmmsg types for musl + #[allow(clippy::useless_conversion)] + let nrecv = unsafe { + libc::recvmmsg( + sock_fd, + hdrs[0].assume_init_mut(), + count as u32, + MSG_WAITFORONE.try_into().unwrap(), + &mut ts, + ) + }; + let nrecv = if nrecv < 0 { + return Err(io::Error::last_os_error()); + } else { + usize::try_from(nrecv).unwrap() + }; + for (addr, hdr, pkt) in izip!(addrs, hdrs, packets.iter_mut()).take(nrecv) { + // SAFETY: We initialized `count` elements of `hdrs` above. `count` is + // passed to recvmmsg() as the limit of messages that can be read. So, + // `nrevc <= count` which means we initialized this `hdr` and + // recvmmsg() will have updated it appropriately + let hdr_ref = unsafe { hdr.assume_init_ref() }; + // SAFETY: Similar to above, we initialized this `addr` and recvmmsg() + // will have populated it + let addr_ref = unsafe { addr.assume_init_ref() }; + pkt.meta_mut().size = hdr_ref.msg_len as usize; + if let Some(addr) = cast_socket_addr(addr_ref, hdr_ref) { + pkt.meta_mut().set_socket_addr(&addr); + } + } + + for (iov, addr, hdr) in izip!(&mut iovs, &mut addrs, &mut hdrs).take(count) { + // SAFETY: We initialized `count` elements of each array above + // + // It may be that `packets.len() != NUM_RCVMMSGS`; thus, some elements + // in `iovs` / `addrs` / `hdrs` may not get initialized. 
So, we must + // manually drop `count` elements from each array instead of being able + // to convert [MaybeUninit] to [T] and letting `Drop` do the work + // for us when these items go out of scope at the end of the function + unsafe { + iov.assume_init_drop(); + addr.assume_init_drop(); + hdr.assume_init_drop(); + } + } + + Ok(nrecv) +} + + +fn create_msghdr( + msg_name: &mut MaybeUninit, + msg_namelen: socklen_t, + iov: &mut MaybeUninit, +) -> msghdr { + // Cannot construct msghdr directly on musl + // See https://github.com/rust-lang/libc/issues/2344 for more info + let mut msg_hdr: msghdr = unsafe { zeroed() }; + msg_hdr.msg_name = msg_name.as_mut_ptr() as *mut _; + msg_hdr.msg_namelen = msg_namelen; + msg_hdr.msg_iov = iov.as_mut_ptr(); + msg_hdr.msg_iovlen = 1; + msg_hdr.msg_control = std::ptr::null::() as *mut _; + msg_hdr.msg_controllen = 0; + msg_hdr.msg_flags = 0; + msg_hdr +} + + +fn cast_socket_addr(addr: &sockaddr_storage, hdr: &mmsghdr) -> Option { + use libc::{sa_family_t, sockaddr_in, sockaddr_in6}; + const SOCKADDR_IN_SIZE: usize = std::mem::size_of::(); + const SOCKADDR_IN6_SIZE: usize = std::mem::size_of::(); + if addr.ss_family == AF_INET as sa_family_t + && hdr.msg_hdr.msg_namelen == SOCKADDR_IN_SIZE as socklen_t + { + // ref: https://github.com/rust-lang/socket2/blob/65085d9dff270e588c0fbdd7217ec0b392b05ef2/src/sockaddr.rs#L167-L172 + let addr = unsafe { &*(addr as *const _ as *const sockaddr_in) }; + return Some(SocketAddr::V4(SocketAddrV4::new( + Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes()), + u16::from_be(addr.sin_port), + ))); + } + if addr.ss_family == AF_INET6 as sa_family_t + && hdr.msg_hdr.msg_namelen == SOCKADDR_IN6_SIZE as socklen_t + { + // ref: https://github.com/rust-lang/socket2/blob/65085d9dff270e588c0fbdd7217ec0b392b05ef2/src/sockaddr.rs#L174-L189 + let addr = unsafe { &*(addr as *const _ as *const sockaddr_in6) }; + return Some(SocketAddr::V6(SocketAddrV6::new( + Ipv6Addr::from(addr.sin6_addr.s6_addr), + u16::from_be(addr.sin6_port), + addr.sin6_flowinfo, + addr.sin6_scope_id, + ))); + } + error!( + "recvmmsg unexpected ss_family:{} msg_namelen:{}", + addr.ss_family, hdr.msg_hdr.msg_namelen + ); + None +} diff --git a/proxy/src/triton_forwarder.rs b/proxy/src/triton_forwarder.rs new file mode 100644 index 0000000..94261d7 --- /dev/null +++ b/proxy/src/triton_forwarder.rs @@ -0,0 +1,1119 @@ +use std::{ + collections::HashSet, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, + sync::{ + atomic::{AtomicBool, AtomicU64, Ordering}, + Arc, RwLock, + }, + thread::{Builder, JoinHandle}, + time::{Duration, Instant, SystemTime}, +}; + +use arc_swap::ArcSwap; +use crossbeam_channel::{Receiver, RecvError}; +use dashmap::DashMap; +use itertools::Itertools; +use jito_protos::shredstream::{Entry as PbEntry, TraceShred}; +use log::{debug, error, info, warn}; +use libc; +use prost::Message; +use socket2::{Domain, Protocol, Socket, Type}; +use solana_client::client_error::reqwest; +use solana_metrics::{datapoint_info, datapoint_warn}; +use solana_net_utils::SocketConfig; +use solana_perf::{ + deduper::Deduper, + packet::{PacketBatch, PacketBatchRecycler}, + recycler::Recycler, +}; +use solana_streamer::{ + sendmmsg::{batch_send, SendPktsError}, + streamer::{self, StreamerReceiveStats}, +}; +use tokio::sync::broadcast::Sender; + +use crate::{ + ShredstreamProxyError, prom::{ + inc_packets_by_source, inc_packets_deduped, inc_packets_forward_failed, inc_packets_forwarded, inc_packets_received, observe_dedup_time, observe_recv_interval, 
observe_recv_packet_count, observe_send_duration, observe_send_packet_count + }, resolve_hostname_port +}; + +// values copied from https://github.com/solana-labs/solana/blob/33bde55bbdde13003acf45bb6afe6db4ab599ae4/core/src/sigverify_shreds.rs#L20 +pub const DEDUPER_FALSE_POSITIVE_RATE: f64 = 0.001; +pub const DEDUPER_NUM_BITS: u64 = 637_534_199; // 76MB +pub const DEDUPER_RESET_CYCLE: Duration = Duration::from_secs(5 * 60); +pub const IP_MULTICAST_TTL: u32 = 8; + + +fn spawn_packet_receiver( + num_receiver: usize, + src_addr: IpAddr, + src_port: u16, + packet_sender: crossbeam_channel::Sender, + exit: Arc, + recycler: PacketBatchRecycler, + forwarder_stats: Arc, + threads: &mut Vec>, +) { + + let (_port, sockets) = solana_net_utils::multi_bind_in_range_with_config( + src_addr, + (src_port, src_port + 1), + SocketConfig::default().reuseport(true), + num_receiver, + ) + .unwrap_or_else(|_| { + panic!("Failed to bind listener sockets. Check that port {src_port} is not in use.") + }); + + for (thread_id, socket) in sockets.into_iter().enumerate() { + let packet_sender = packet_sender.clone(); + let socket = Arc::new(socket); + + let th = std::thread::Builder::new() + .name(format!("ssListen{thread_id}")) + .spawn(move || { + crate::recv_mmsg( + + ) + }); + + let listen_thread = streamer::receiver( + format!("ssListen{thread_id}"), + Arc::new(socket), + exit.clone(), + packet_sender.clone(), + recycler.clone(), + Arc::clone(&forwarder_stats), + Duration::default(), + false, + None, + false, + ); + threads.push(listen_thread); + } +} + + +fn packet_router( + router_idx: usize, + packet_rx: crossbeam_channel::Receiver, +) -> std::io::Result> { + + std::thread::Builder::new() + .name(format!("pktRouter{router_idx}")) + .spawn(move || { + while let Ok(packet_batch) = packet_rx.recv() { + // route packets based on some criteria + for packet in packet_batch.iter() { + // Example routing logic (to be replaced with actual logic) + let dest = packet.meta().addr; + + debug!("Router {router_idx} routing packet to {dest}"); + // Send packet to appropriate destination + } + } + info!("Exiting packet router thread {router_idx}."); + }) +} + + +/// Bind to ports and start forwarding shreds +#[allow(clippy::too_many_arguments)] +pub fn start_forwarder_threads( + unioned_dest_sockets: Arc>>, /* sockets shared between endpoint discovery thread and forwarders */ + src_addr: IpAddr, + src_port: u16, + maybe_multicast_socket: Option>, + maybe_triton_multicast_socket: Option<(IpAddr, Vec)>, + num_threads: Option, + deduper: Arc>>, + entry_sender: Arc>, + debug_trace_shred: bool, + use_discovery_service: bool, + forward_stats: Arc, + metrics: Arc, + shutdown_receiver: Receiver<()>, + exit: Arc, +) -> Vec> { + let num_threads = num_threads + .unwrap_or_else(|| usize::from(std::thread::available_parallelism().unwrap()).min(4)); + + let recycler: PacketBatchRecycler = Recycler::warmed(100, 1024); + + // multi_bind_in_range returns (port, Vec) + let (_port, sockets) = solana_net_utils::multi_bind_in_range_with_config( + src_addr, + (src_port, src_port + 1), + SocketConfig::default().reuseport(true), + num_threads, + ) + .unwrap_or_else(|_| { + panic!("Failed to bind listener sockets. 
Check that port {src_port} is not in use.") + }); + + let mut ret = sockets + .into_iter() + .chain(maybe_multicast_socket.unwrap_or_default()) + .enumerate() + .flat_map(|(thread_id, source)| { + let (packet_sender, packet_receiver) = crossbeam_channel::unbounded(); + let listen_thread = streamer::receiver( + format!("ssListen{thread_id}"), + Arc::new(source), + exit.clone(), + packet_sender, + recycler.clone(), + forward_stats.clone(), + Duration::default(), + false, + None, + false, + ); + + let deduper = deduper.clone(); + let unioned_dest_sockets = unioned_dest_sockets.clone(); + let metrics = metrics.clone(); + let shutdown_receiver = shutdown_receiver.clone(); + let exit = exit.clone(); + + + let send_thread = Builder::new() + .name(format!("ssPxyTx_{thread_id}")) + .spawn(move || { + let dont_send_to_origin = move |origin: IpAddr, dest: SocketAddr| { + origin != dest.ip() + }; + let send_socket = { + let ipv6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); + match try_create_ipv6_socket(ipv6_addr) { + Ok(socket) => { + info!("Successfully bound send socket to IPv6 dual-stack address."); + socket + } + Err(e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => { + // This error (code 97 on Linux) means IPv6 is not supported. + warn!("IPv6 not available. Falling back to IPv4-only for sending."); + let ipv4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); + UdpSocket::bind(ipv4_addr) + .expect("Failed to bind to IPv4 socket after IPv6 failed") + } + Err(e) => { + // For any other error (e.g., port in use), panic. + panic!("Failed to bind send socket with an unexpected error: {e}"); + } + } + }; + let mut local_dest_sockets = unioned_dest_sockets.load(); + + let refresh_subscribers_tick = if use_discovery_service { + crossbeam_channel::tick(Duration::from_secs(30)) + } else { + crossbeam_channel::tick(Duration::MAX) + }; + + let mut last_recv = Instant::now(); + while !exit.load(Ordering::Relaxed) { + crossbeam_channel::select! { + // forward packets + recv(packet_receiver) -> maybe_packet_batch => { + let e = last_recv.elapsed(); + last_recv = Instant::now(); + observe_recv_interval(e.as_micros() as f64); + let res = recv_from_channel_and_send_multiple_dest( + maybe_packet_batch, + &deduper, + &send_socket, + &local_dest_sockets, + dont_send_to_origin, + debug_trace_shred, + &metrics, + ); + + // If the channel is closed or error, break out + if res.is_err() { + break; + } + } + + // refresh thread-local subscribers + recv(refresh_subscribers_tick) -> _ => { + local_dest_sockets = unioned_dest_sockets.load(); + } + + // handle shutdown (avoid using sleep since it can hang) + recv(shutdown_receiver) -> _ => { + break; + } + } + } + info!("Exiting forwarder thread {thread_id}."); + }) + .unwrap(); + + vec![listen_thread, send_thread] + }) + .collect::>>(); + + if let Some((multicast_origin, multicast_socket)) = maybe_triton_multicast_socket { + start_multicast_forwarder_thread( + multicast_origin, + multicast_socket, + recycler, + unioned_dest_sockets, + deduper, + forward_stats, + metrics, + debug_trace_shred, + use_discovery_service, + shutdown_receiver, + exit, + &mut ret + ); + } + ret +} + +/// +/// Try to create an IPv6 UDP socket bound to the given address. 
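+/// The multicast hop limit is set to `IP_MULTICAST_TTL` before binding, and the
+/// socket2 handle is converted back into a standard `UdpSocket` on success.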
+/// +fn try_create_ipv6_socket(addr: SocketAddr) -> Result { + let ipv6_socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + ipv6_socket.set_multicast_hops_v6(IP_MULTICAST_TTL)?; + ipv6_socket.bind(&addr.into())?; + Ok(ipv6_socket.into()) +} + +pub struct MulticastSource { + pub socket: Arc, + pub group: IpAddr, +} + +#[allow(clippy::too_many_arguments)] +pub fn start_multicast_forwarder_thread( + multicast_origin: IpAddr, + sockets: Vec, + recycler: PacketBatchRecycler, + unioned_dest_sockets: Arc>>, + deduper: Arc>>, + forward_stats: Arc, + metrics: Arc, + debug_trace_shred: bool, + use_discovery_service: bool, + shutdown_receiver: Receiver<()>, + exit: Arc, + out: &mut Vec>, +) { + for (thread_id, socket) in sockets.into_iter().enumerate() { + let (packet_sender, packet_receiver) = crossbeam_channel::unbounded(); + let listen_thread = streamer::receiver( + format!("ssListenMulticast{thread_id}"), + Arc::new(socket), + exit.clone(), + packet_sender, + recycler.clone(), + forward_stats.clone(), + Duration::default(), + false, + None, + false, + ); + + out.push(listen_thread); + let deduper = deduper.clone(); + let unioned_dest_sockets = unioned_dest_sockets.clone(); + let metrics = metrics.clone(); + let shutdown_receiver = shutdown_receiver.clone(); + let exit = exit.clone(); + + + + let send_thread = Builder::new() + .name(format!("ssPxyTxMulticast_{thread_id}")) + .spawn(move || { + let dont_send_to_mc_origin = move |_origin, dest: SocketAddr| { + dest.ip() != multicast_origin + }; + let send_socket = { + let ipv6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); + match try_create_ipv6_socket(ipv6_addr) { + Ok(socket) => { + info!("Successfully bound send socket to IPv6 dual-stack address."); + socket.set_multicast_loop_v6(false) + .expect("Failed to disable IPv6 multicast loopback"); + socket + } + Err(e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => { + // This error (code 97 on Linux) means IPv6 is not supported. + warn!("IPv6 not available. Falling back to IPv4-only for sending."); + let ipv4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); + let socket = UdpSocket::bind(ipv4_addr) + .expect("Failed to bind to IPv4 socket after IPv6 failed"); + socket.set_multicast_ttl_v4(IP_MULTICAST_TTL).expect("IP_MULTICAST_TTL_V4"); + socket.set_multicast_loop_v4(false) + .expect("Failed to disable IPv4 multicast loopback"); + socket + } + Err(e) => { + // For any other error (e.g., port in use), panic. + panic!("Failed to bind send socket with an unexpected error: {e}"); + } + } + }; + let mut local_dest_sockets = unioned_dest_sockets.load(); + + let refresh_subscribers_tick = if use_discovery_service { + crossbeam_channel::tick(Duration::from_secs(30)) + } else { + crossbeam_channel::tick(Duration::MAX) + }; + + while !exit.load(Ordering::Relaxed) { + crossbeam_channel::select! 
{ + // forward packets + recv(packet_receiver) -> maybe_packet_batch => { + let res = recv_from_channel_and_send_multiple_dest( + maybe_packet_batch, + &deduper, + &send_socket, + &local_dest_sockets, + dont_send_to_mc_origin, + debug_trace_shred, + &metrics, + ); + + // If the channel is closed or error, break out + if res.is_err() { + break; + } + } + + // refresh thread-local subscribers + recv(refresh_subscribers_tick) -> _ => { + local_dest_sockets = unioned_dest_sockets.load(); + } + + // handle shutdown (avoid using sleep since it can hang) + recv(shutdown_receiver) -> _ => { + break; + } + } + } + info!("Exiting forwarder thread {thread_id}."); + }) + .unwrap(); + + out.push(send_thread); + } +} + +#[allow(dead_code)] +fn accept_all(_origin: IpAddr, _dest: SocketAddr) -> bool { + true +} + +/// Broadcasts the same packet to multiple recipients, parses it into a Shred if possible, +/// and stores that shred in `all_shreds`. +#[allow(clippy::too_many_arguments)] +fn recv_from_channel_and_send_multiple_dest( + maybe_packet_batch: Result, + deduper: &RwLock>, + send_socket: &UdpSocket, + local_dest_sockets: &[SocketAddr], + local_dest_socket_filter: F, + debug_trace_shred: bool, + metrics: &ShredMetrics, +) -> Result<(), ShredstreamProxyError> +where + F: Fn(IpAddr, SocketAddr) -> bool, +{ + let packet_batch = maybe_packet_batch.map_err(ShredstreamProxyError::RecvError)?; + let trace_shred_received_time = SystemTime::now(); + let batch_len = packet_batch.len() as u64; + metrics + .received + .fetch_add(batch_len, Ordering::Relaxed); + inc_packets_received(batch_len); + observe_recv_packet_count(batch_len as f64); + debug!( + "Got batch of {} packets, total size in bytes: {}", + packet_batch.len(), + packet_batch.iter().map(|x| x.meta().size).sum::() + ); + + + let mut packet_batch_vec = vec![packet_batch]; + + let t = Instant::now(); + let num_deduped = solana_perf::deduper::dedup_packets_and_count_discards( + &deduper.read().unwrap(), + &mut packet_batch_vec, + ); + let t_dedup_usecs = t.elapsed().as_micros() as u64; + metrics.dedup_time_spent.fetch_add(t_dedup_usecs, Ordering::Relaxed); + observe_dedup_time(t_dedup_usecs as f64); + inc_packets_deduped(num_deduped); + + // Store stats for each Packet + packet_batch_vec.iter().for_each(|batch| { + batch.iter().for_each(|packet| { + let addr = packet.meta().addr; + let is_discarded = packet.meta().discard(); + metrics + .packets_received + .entry(addr) + .and_modify(|(discarded, not_discarded)| { + *discarded += is_discarded as u64; + *not_discarded += (!is_discarded) as u64; + }) + .or_insert_with(|| { + (is_discarded as u64, (!is_discarded) as u64) + }); + let status = if is_discarded { "discarded" } else { "forwarded" }; + inc_packets_by_source(&addr.to_string(), status, 1); + }); + }); + + // send out to RPCs + local_dest_sockets.iter().for_each(|outgoing_socketaddr| { + let packets_with_dest = packet_batch_vec[0] + .iter() + .filter_map(|pkt| { + let addr = pkt.meta().addr; + if local_dest_socket_filter(addr, *outgoing_socketaddr) { + Some(pkt) + } else { + None + } + }) + .filter_map(|pkt| { + let data = pkt.data(..)?; + let addr = outgoing_socketaddr; + Some((data, addr)) + }) + .collect::>(); + let t = Instant::now(); + metrics + .send_batch_size_sum + .fetch_add(packets_with_dest.len() as u64, Ordering::Relaxed); + metrics.send_batch_count.fetch_add(1, Ordering::Relaxed); + const MAX_IOV: usize = libc::UIO_MAXIOV as usize; + let max_iov_count = packets_with_dest.len() / MAX_IOV; + let unsaturated_iov_count = 
packets_with_dest.len() % MAX_IOV; + metrics.saturated_iov_count.fetch_add(max_iov_count as u64, Ordering::Relaxed); + metrics.unsaturated_iov_count.fetch_add(unsaturated_iov_count as u64, Ordering::Relaxed); + observe_send_packet_count(packets_with_dest.len() as f64); + match batch_send(send_socket, &packets_with_dest) { + Ok(_) => { + metrics + .success_forward + .fetch_add(packets_with_dest.len() as u64, Ordering::Relaxed); + metrics.duplicate.fetch_add(num_deduped, Ordering::Relaxed); + inc_packets_forwarded(packets_with_dest.len() as u64); + } + Err(SendPktsError::IoError(err, num_failed)) => { + metrics + .fail_forward + .fetch_add(packets_with_dest.len() as u64, Ordering::Relaxed); + metrics + .duplicate + .fetch_add(num_failed as u64, Ordering::Relaxed); + inc_packets_forward_failed(packets_with_dest.len() as u64); + error!( + "Failed to send batch of size {} to {outgoing_socketaddr:?}. \ + {num_failed} packets failed. Error: {err}", + packets_with_dest.len() + ); + } + } + let t_send_usecs = t.elapsed().as_micros() as u64; + metrics.batch_send_time_spent.fetch_add(t_send_usecs, Ordering::Relaxed); + observe_send_duration(t_send_usecs as f64); + }); + + // Count TraceShred shreds + if debug_trace_shred { + packet_batch_vec[0] + .iter() + .filter_map(|p| TraceShred::decode(p.data(..)?).ok()) + .filter(|t| t.created_at.is_some()) + .for_each(|trace_shred| { + let elapsed = trace_shred_received_time + .duration_since(SystemTime::try_from(trace_shred.created_at.unwrap()).unwrap()) + .unwrap_or_default(); + + datapoint_info!( + "shredstream_proxy-trace_shred_latency", + "trace_region" => trace_shred.region, + ("trace_seq_num", trace_shred.seq_num as i64, i64), + ("elapsed_micros", elapsed.as_micros(), i64), + ); + }); + } + + Ok(()) +} + +/// Starts a thread that updates our destinations used by the forwarder threads +pub fn start_destination_refresh_thread( + endpoint_discovery_url: String, + discovered_endpoints_port: u16, + static_dest_sockets: Vec<(SocketAddr, String)>, + unioned_dest_sockets: Arc>>, + shutdown_receiver: Receiver<()>, + exit: Arc, +) -> JoinHandle<()> { + Builder::new().name("ssPxyDstRefresh".to_string()).spawn(move || { + let fetch_socket_tick = crossbeam_channel::tick(Duration::from_secs(30)); + let metrics_tick = crossbeam_channel::tick(Duration::from_secs(30)); + let mut socket_count = static_dest_sockets.len(); + while !exit.load(Ordering::Relaxed) { + crossbeam_channel::select! { + recv(fetch_socket_tick) -> _ => { + let fetched = fetch_unioned_destinations( + &endpoint_discovery_url, + discovered_endpoints_port, + &static_dest_sockets, + ); + let new_sockets = match fetched { + Ok(s) => { + info!("Sending shreds to {} destinations: {s:?}", s.len()); + s + } + Err(e) => { + warn!("Failed to fetch from discovery service, retrying. 
Error: {e}"); + datapoint_warn!("shredstream_proxy-destination_refresh_error", + ("prev_unioned_dest_count", socket_count, i64), + ("errors", 1, i64), + ("error_str", e.to_string(), String), + ); + continue; + } + }; + socket_count = new_sockets.len(); + unioned_dest_sockets.store(Arc::new(new_sockets)); + } + recv(metrics_tick) -> _ => { + datapoint_info!("shredstream_proxy-destination_refresh_stats", + ("destination_count", socket_count, i64), + ); + } + recv(shutdown_receiver) -> _ => { + break; + } + } + } + }).unwrap() +} + +/// Returns dynamically discovered endpoints with CLI arg defined endpoints +fn fetch_unioned_destinations( + endpoint_discovery_url: &str, + discovered_endpoints_port: u16, + static_dest_sockets: &[(SocketAddr, String)], +) -> Result, ShredstreamProxyError> { + let bytes = reqwest::blocking::get(endpoint_discovery_url)?.bytes()?; + + let sockets_json = match serde_json::from_slice::>(&bytes) { + Ok(s) => s, + Err(e) => { + warn!( + "Failed to parse json from: {:?}", + std::str::from_utf8(&bytes) + ); + return Err(ShredstreamProxyError::from(e)); + } + }; + + // resolve again since ip address could change + let static_dest_sockets = static_dest_sockets + .iter() + .filter_map(|(_socketaddr, hostname_port)| { + Some(resolve_hostname_port(hostname_port).ok()?.0) + }) + .collect::>(); + + let unioned_dest_sockets = sockets_json + .into_iter() + .map(|ip| SocketAddr::new(ip, discovered_endpoints_port)) + .chain(static_dest_sockets) + .unique() + .collect::>(); + Ok(unioned_dest_sockets) +} + +/// Reset dedup + send metrics to influx +pub fn start_forwarder_accessory_thread( + deduper: Arc>>, + metrics: Arc, + metrics_update_interval_ms: u64, + shutdown_receiver: Receiver<()>, + exit: Arc, +) -> JoinHandle<()> { + Builder::new() + .name("ssPxyAccessory".to_string()) + .spawn(move || { + let metrics_tick = + crossbeam_channel::tick(Duration::from_millis(metrics_update_interval_ms)); + let deduper_reset_tick = crossbeam_channel::tick(Duration::from_secs(2)); + let mut rng = rand::thread_rng(); + while !exit.load(Ordering::Relaxed) { + crossbeam_channel::select! { + // reset deduper to avoid false positives + recv(deduper_reset_tick) -> _ => { + deduper + .write() + .unwrap() + .maybe_reset(&mut rng, DEDUPER_FALSE_POSITIVE_RATE, DEDUPER_RESET_CYCLE); + } + + // send metrics to influx + recv(metrics_tick) -> _ => { + metrics.report(); + metrics.reset(); + } + + // handle SIGINT shutdown + recv(shutdown_receiver) -> _ => { + break; + } + } + } + }) + .unwrap() +} + +pub struct ShredMetrics { + // receive stats + /// Total number of shreds received. Includes duplicates when receiving shreds from multiple regions + pub received: AtomicU64, + /// Total number of shreds successfully forwarded, accounting for all destinations + pub success_forward: AtomicU64, + /// Total number of shreds failed to forward, accounting for all destinations + pub fail_forward: AtomicU64, + /// Number of duplicate shreds received + pub duplicate: AtomicU64, + /// (discarded, not discarded, from other shredstream instances) + pub packets_received: DashMap, + /// The batch size we are sending to batch_send solana crate call. 
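+    /// Summed across sends; divide by `send_batch_count` to recover the average batch size.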
+ pub send_batch_size_sum: AtomicU64, + pub send_batch_count: AtomicU64, + /// Number of occurrences we can saturated the iovecs in sendmmsg + pub saturated_iov_count: AtomicU64, + /// Number of occurrences we could not saturate the iovecs in sendmmsg + pub unsaturated_iov_count: AtomicU64, + + // service metrics + pub enabled_grpc_service: bool, + /// Number of data shreds recovered using coding shreds + pub recovered_count: AtomicU64, + /// Number of Solana entries decoded from shreds + pub entry_count: AtomicU64, + /// Number of transactions decoded from shreds + pub txn_count: AtomicU64, + /// Number of times we couldn't find the previous DATA_COMPLETE_SHRED flag + pub unknown_start_position_count: AtomicU64, + /// Number of FEC recovery errors + pub fec_recovery_error_count: AtomicU64, + /// Number of bincode Entry deserialization errors + pub bincode_deserialize_error_count: AtomicU64, + /// Number of times we couldn't find the previous DATA_COMPLETE_SHRED flag but tried to deshred+deserialize, and failed + pub unknown_start_position_error_count: AtomicU64, + + // cumulative time spent in deduping packets + pub dedup_time_spent: AtomicU64, + pub batch_send_time_spent: AtomicU64, + + // cumulative metrics (persist after reset) + pub agg_received_cumulative: AtomicU64, + pub agg_success_forward_cumulative: AtomicU64, + pub agg_fail_forward_cumulative: AtomicU64, + pub duplicate_cumulative: AtomicU64, +} + +impl Default for ShredMetrics { + fn default() -> Self { + Self::new(false) + } +} + +impl ShredMetrics { + pub fn new(enabled_grpc_service: bool) -> Self { + Self { + enabled_grpc_service, + received: Default::default(), + success_forward: Default::default(), + fail_forward: Default::default(), + duplicate: Default::default(), + packets_received: DashMap::with_capacity(10), + recovered_count: Default::default(), + entry_count: Default::default(), + txn_count: Default::default(), + unknown_start_position_count: Default::default(), + fec_recovery_error_count: Default::default(), + bincode_deserialize_error_count: Default::default(), + unknown_start_position_error_count: Default::default(), + agg_received_cumulative: Default::default(), + agg_success_forward_cumulative: Default::default(), + agg_fail_forward_cumulative: Default::default(), + duplicate_cumulative: Default::default(), + dedup_time_spent: Default::default(), + batch_send_time_spent: Default::default(), + send_batch_size_sum: Default::default(), + send_batch_count: Default::default(), + saturated_iov_count: Default::default(), + unsaturated_iov_count: Default::default(), + } + } + + pub fn report(&self) { + datapoint_info!( + "shredstream_proxy-connection_metrics", + ("received", self.received.load(Ordering::Relaxed), i64), + ( + "success_forward", + self.success_forward.load(Ordering::Relaxed), + i64 + ), + ( + "fail_forward", + self.fail_forward.load(Ordering::Relaxed), + i64 + ), + ("duplicate", self.duplicate.load(Ordering::Relaxed), i64), + ); + + datapoint_info!( + "shredstream_proxy-sendmmsg_iov_metrics", + ("max_iov_count", self.saturated_iov_count.load(Ordering::Relaxed), i64), + ( + "unsaturated_iov_count", + self.unsaturated_iov_count.load(Ordering::Relaxed), + i64 + ), + ); + + datapoint_info!( + "shredstream_proxy-batch_send_metrics", + ( + "send_batch_size_sum", self.send_batch_size_sum.load(Ordering::Relaxed), i64 + ), + ( + "send_batch_count", self.send_batch_count.load(Ordering::Relaxed), i64 + ) + ); + + datapoint_info!( + "shredstream_proxy-time_allocation", + ( + "deduping", + 
self.dedup_time_spent.load(Ordering::Relaxed), + i64 + ), + ( + "batch_send", + self.batch_send_time_spent.load(Ordering::Relaxed), + i64 + ), + ); + + + if self.enabled_grpc_service { + datapoint_info!( + "shredstream_proxy-service_metrics", + ( + "recovered_count", + self.recovered_count.swap(0, Ordering::Relaxed), + i64 + ), + ( + "entry_count", + self.entry_count.swap(0, Ordering::Relaxed), + i64 + ), + ("txn_count", self.txn_count.swap(0, Ordering::Relaxed), i64), + ( + "unknown_start_position_count", + self.unknown_start_position_count.swap(0, Ordering::Relaxed), + i64 + ), + ( + "fec_recovery_error_count", + self.fec_recovery_error_count.swap(0, Ordering::Relaxed), + i64 + ), + ( + "bincode_deserialize_error_count", + self.bincode_deserialize_error_count + .swap(0, Ordering::Relaxed), + i64 + ), + ( + "unknown_start_position_error_count", + self.unknown_start_position_error_count + .swap(0, Ordering::Relaxed), + i64 + ), + ); + } + + self.packets_received + .retain(|addr, (discarded_packets, not_discarded_packets)| { + datapoint_info!("shredstream_proxy-receiver_stats", + "addr" => addr.to_string(), + ("discarded_packets", *discarded_packets, i64), + ("not_discarded_packets", *not_discarded_packets, i64), + ); + false + }); + } + + /// resets current values, increments cumulative values + pub fn reset(&self) { + self.agg_received_cumulative + .fetch_add(self.received.swap(0, Ordering::Relaxed), Ordering::Relaxed); + self.agg_success_forward_cumulative.fetch_add( + self.success_forward.swap(0, Ordering::Relaxed), + Ordering::Relaxed, + ); + self.agg_fail_forward_cumulative.fetch_add( + self.fail_forward.swap(0, Ordering::Relaxed), + Ordering::Relaxed, + ); + self.duplicate_cumulative + .fetch_add(self.duplicate.swap(0, Ordering::Relaxed), Ordering::Relaxed); + self.dedup_time_spent.swap(0, Ordering::Relaxed); + self.batch_send_time_spent.swap(0, Ordering::Relaxed); + self.send_batch_size_sum.swap(0, Ordering::Relaxed); + self.send_batch_count.swap(0, Ordering::Relaxed); + self.saturated_iov_count.swap(0, Ordering::Relaxed); + self.unsaturated_iov_count.swap(0, Ordering::Relaxed); + } +} + +#[cfg(test)] +mod tests { + use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket}, + str::FromStr, + sync::{Arc, Mutex, RwLock}, + thread, + thread::sleep, + time::Duration, + }; + + use solana_perf::{ + deduper::Deduper, + packet::{Meta, Packet, PacketBatch}, + }; + use solana_sdk::packet::{PacketFlags, PACKET_DATA_SIZE}; + + use crate::triton_forwarder::{accept_all, recv_from_channel_and_send_multiple_dest, ShredMetrics}; + + fn listen_and_collect(listen_socket: UdpSocket, received_packets: Arc>>>) { + let mut buf = [0u8; PACKET_DATA_SIZE]; + loop { + listen_socket.recv(&mut buf).unwrap(); + received_packets.lock().unwrap().push(Vec::from(buf)); + } + } + + #[test] + fn test_2shreds_3destinations() { + let packet_batch = PacketBatch::new(vec![ + Packet::new( + [1; PACKET_DATA_SIZE], + Meta { + size: PACKET_DATA_SIZE, + addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), + port: 48289, // received on random port + flags: PacketFlags::empty(), + }, + ), + Packet::new( + [2; PACKET_DATA_SIZE], + Meta { + size: PACKET_DATA_SIZE, + addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), + port: 9999, + flags: PacketFlags::empty(), + }, + ), + ]); + let (packet_sender, packet_receiver) = crossbeam_channel::unbounded::(); + packet_sender.send(packet_batch).unwrap(); + + let dest_socketaddrs = vec![ + SocketAddr::from_str("0.0.0.0:32881").unwrap(), + SocketAddr::from_str("0.0.0.0:33881").unwrap(), + 
SocketAddr::from_str("0.0.0.0:34881").unwrap(), + ]; + + let test_listeners = dest_socketaddrs + .iter() + .map(|socketaddr| { + ( + UdpSocket::bind(socketaddr).unwrap(), + *socketaddr, + // store results in vec of packet, where packet is Vec + Arc::new(Mutex::new(vec![])), + ) + }) + .collect::>(); + + let udp_sender = UdpSocket::bind("0.0.0.0:10000").unwrap(); + + // spawn listeners + test_listeners + .iter() + .for_each(|(listen_socket, _socketaddr, to_receive)| { + let socket = listen_socket.try_clone().unwrap(); + let to_receive = to_receive.to_owned(); + thread::spawn(move || listen_and_collect(socket, to_receive)); + }); + + // send packets + recv_from_channel_and_send_multiple_dest( + packet_receiver.recv(), + &Arc::new(RwLock::new(Deduper::<2, [u8]>::new( + &mut rand::thread_rng(), + crate::forwarder::DEDUPER_NUM_BITS, + ))), + &udp_sender, + &Arc::new(dest_socketaddrs), + accept_all, + true, + &Arc::new(ShredMetrics::default()), + ) + .unwrap(); + + // allow packets to be received + sleep(Duration::from_millis(500)); + + let received = test_listeners + .iter() + .map(|(_, _, results)| results.clone()) + .collect::>(); + + // check results + for received in received.iter() { + let received = received.lock().unwrap(); + assert_eq!(received.len(), 2); + assert!(received + .iter() + .all(|packet| packet.len() == PACKET_DATA_SIZE)); + assert_eq!(received[0], [1; PACKET_DATA_SIZE]); + assert_eq!(received[1], [2; PACKET_DATA_SIZE]); + } + + assert_eq!( + received + .iter() + .fold(0, |acc, elem| acc + elem.lock().unwrap().len()), + 6 + ); + } + + #[test] + fn test_dest_filter() { + let packet_batch = PacketBatch::new(vec![ + Packet::new( + [1; PACKET_DATA_SIZE], + Meta { + size: PACKET_DATA_SIZE, + addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), + port: 48289, // received on random port + flags: PacketFlags::empty(), + }, + ), + Packet::new( + [2; PACKET_DATA_SIZE], + Meta { + size: PACKET_DATA_SIZE, + addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), + port: 9999, + flags: PacketFlags::empty(), + }, + ), + ]); + let (packet_sender, packet_receiver) = crossbeam_channel::unbounded::(); + packet_sender.send(packet_batch).unwrap(); + + let dest_socketaddrs = vec![ + SocketAddr::from_str("0.0.0.0:32881").unwrap(), + SocketAddr::from_str("0.0.0.0:33881").unwrap(), + SocketAddr::from_str("0.0.0.0:34881").unwrap(), + ]; + + let blacklisted = SocketAddr::from_str("0.0.0.0:34881").unwrap(); // none blacklisted + + let test_listeners = dest_socketaddrs + .iter() + .map(|socketaddr| { + ( + UdpSocket::bind(socketaddr).unwrap(), + *socketaddr, + // store results in vec of packet, where packet is Vec + Arc::new(Mutex::new(vec![])), + ) + }) + .collect::>(); + + let udp_sender = UdpSocket::bind("0.0.0.0:10000").unwrap(); + + // spawn listeners + test_listeners + .iter() + .for_each(|(listen_socket, _socketaddr, to_receive)| { + let socket = listen_socket.try_clone().unwrap(); + let to_receive = to_receive.to_owned(); + thread::spawn(move || listen_and_collect(socket, to_receive)); + }); + + // send packets + recv_from_channel_and_send_multiple_dest( + packet_receiver.recv(), + &Arc::new(RwLock::new(Deduper::<2, [u8]>::new( + &mut rand::thread_rng(), + crate::forwarder::DEDUPER_NUM_BITS, + ))), + &udp_sender, + &Arc::new(dest_socketaddrs), + move |_origin, dest: SocketAddr| dest != blacklisted, + true, + &Arc::new(ShredMetrics::default()), + ) + .unwrap(); + + // allow packets to be received + sleep(Duration::from_millis(500)); + + let received = test_listeners + .iter() + .take(test_listeners.len() - 1) // ignore 
blacklisted + .map(|(_, _, results)| results.clone()) + .collect::>(); + + // check results + for received in received.iter() { + let received = received.lock().unwrap(); + assert_eq!(received.len(), 2); + assert!(received + .iter() + .all(|packet| packet.len() == PACKET_DATA_SIZE)); + assert_eq!(received[0], [1; PACKET_DATA_SIZE]); + assert_eq!(received[1], [2; PACKET_DATA_SIZE]); + } + + { + let received = test_listeners[2].2.lock().unwrap(); // ensure blacklisted received nothing + assert_eq!(received.len(), 0); + } + assert_eq!( + received + .iter() + .fold(0, |acc, elem| acc + elem.lock().unwrap().len()), + 4 + ); + } +} From 5a956e1cf395a59cfaa43f05c70d1afb82d53c52 Mon Sep 17 00:00:00 2001 From: Louis-Vincent Boudreault Date: Wed, 14 Jan 2026 03:16:21 +0000 Subject: [PATCH 02/13] wip: recv_mmsg --- proxy/src/main.rs | 2 +- proxy/src/mem.rs | 5 ++ proxy/src/recv_mmsg.rs | 108 +++++++++++++++++------------------------ 3 files changed, 51 insertions(+), 64 deletions(-) diff --git a/proxy/src/main.rs b/proxy/src/main.rs index edb8deb..073b071 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -31,7 +31,7 @@ mod multicast_config; mod server; mod token_authenticator; mod prom; -// mod recv_mmsg; +mod recv_mmsg; mod mem; // mod triton_forwarder; diff --git a/proxy/src/mem.rs b/proxy/src/mem.rs index 0c2a7df..769b543 100644 --- a/proxy/src/mem.rs +++ b/proxy/src/mem.rs @@ -175,6 +175,11 @@ impl FrameBufMut { fn end_ptr(&self) -> *const u8 { unsafe { self.base().add(self.capacity()) } } + + #[inline] + pub unsafe fn as_mut_ptr(&self) -> *mut u8 { + self.ptr + } } unsafe impl BufMut for FrameBufMut { diff --git a/proxy/src/recv_mmsg.rs b/proxy/src/recv_mmsg.rs index 2c030c8..810cae3 100644 --- a/proxy/src/recv_mmsg.rs +++ b/proxy/src/recv_mmsg.rs @@ -10,7 +10,7 @@ use std::{ cmp, collections::VecDeque, io, mem::{self, MaybeUninit, zeroed}, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, UdpSocket}, os::fd::AsRawFd, sync::atomic::{AtomicBool, Ordering}, time::{Duration, Instant} }; -use crate::mem::{FrameBuffer, FrameDesc, PagedAlignedMem, Rx, Tx, try_alloc_shared_mem}; +use crate::mem::{FrameBuf, FrameBufMut, FrameDesc, Rx, SharedMem, Tx, try_alloc_shared_mem}; pub struct RecvMemConfig { @@ -25,9 +25,10 @@ fn recv_loop( stats: &StreamerReceiveStats, coalesce: Duration, mem_config: &RecvMemConfig, + fill_rx: &mut Rx, + rx_tx: &Tx, ) -> std::io::Result<()> { - let data_shmem = try_alloc_shared_mem(PACKET_DATA_SIZE.next_power_of_two(), mem_config.frames_count, mem_config.hugepages).expect("try_alloc_shared_mem"); let mut packet_batch = Vec::with_capacity(PACKETS_PER_BATCH); loop { // Check for exit signal, even if socket is busy @@ -66,8 +67,7 @@ fn recv_loop( pub fn recv_from( - fill_ring_rx: &mut Rx, - fill_ring_tx: &Tx, + available_frame_buf_vec: &mut Vec, socket: &UdpSocket, max_wait: Duration, batch: &mut Vec @@ -85,53 +85,12 @@ pub fn recv_from( assert!(batch.capacity() >= PACKETS_PER_BATCH); - struct Defer<'a> { - i: usize, - allocated_frame: usize, - batch: &'a mut Vec, - }; - - impl Drop for Defer<'_> { - fn drop(&mut self) { - // Return unused frames to the fill ring - let exceeding_allocs = self.allocated_frame.saturating_sub(self.i); - (0..exceeding_allocs).for_each(|_| { - if let Some(unused_buffer) = self.batch.pop() { - drop(unused_buffer); - } - }); - self.allocated_frame = 0; - } - } - - let mut defer = Defer { - i: 0, - allocated_frame: 0, - batch, - }; + let mut i = 0; loop { - let frame_desc = fill_ring_rx.recv(); - let buffer = 
FrameBuffer::new(frame_desc, fill_ring_tx.clone()); - defer.allocated_frame += 1; - defer.batch[defer.i] = TritonPacket::new(buffer); - - let mut j = defer.i + 1; - while j < PACKETS_PER_BATCH { - - let Some(frame_desc) = fill_ring_rx.try_recv() else { - break; - }; - let buffer = FrameBuffer::new(frame_desc, fill_ring_tx.clone()); - defer.batch[j] = TritonPacket::new(buffer); - defer.allocated_frame += 1; - j += 1; - - } - - match triton_recv_mmsg(socket, &mut defer.batch[defer.i..j]) { - Err(_) if defer.i > 0 => { + match triton_recv_mmsg(socket, available_frame_buf_vec, &mut batch[i..]) { + Err(_) if i > 0 => { if start.elapsed() > max_wait { break; } @@ -141,30 +100,29 @@ pub fn recv_from( return Err(e); } Ok(npkts) => { - if defer.i == 0 { + if i == 0 { socket.set_nonblocking(true)?; } trace!("got {} packets", npkts); - defer.i += npkts; + i += npkts; // Try to batch into big enough buffers // will cause less re-shuffling later on. - if start.elapsed() > max_wait || defer.i >= PACKETS_PER_BATCH { + if start.elapsed() > max_wait || i >= PACKETS_PER_BATCH { break; } } } } - - Ok(defer.i) + Ok(i) } pub struct TritonPacket { - pub buffer: FrameBuffer, + pub buffer: FrameBuf, pub meta: Meta, } impl TritonPacket { - pub fn new(buffer: FrameBuffer) -> Self { + pub fn new(buffer: FrameBuf) -> Self { Self { buffer, meta: Meta::default(), @@ -177,7 +135,11 @@ impl TritonPacket { } -pub fn triton_recv_mmsg(sock: &UdpSocket, packets: &mut [TritonPacket]) -> io::Result { +pub fn triton_recv_mmsg( + sock: &UdpSocket, + fill_buffers: &mut Vec, + packets: &mut [TritonPacket], +) -> io::Result { // Should never hit this, but bail if the caller didn't provide any Packets // to receive into if packets.is_empty() { @@ -191,14 +153,23 @@ pub fn triton_recv_mmsg(sock: &UdpSocket, packets: &mut [TritonPacket]) -> io::R let mut hdrs = [MaybeUninit::uninit(); NUM_RCVMMSGS]; let sock_fd = sock.as_raw_fd(); - let count = cmp::min(iovs.len(), packets.len()); + let count = cmp::min(iovs.len(), packets.len()).min(fill_buffers.len()); + + let mut frame_buffer_inflight_vec: [MaybeUninit; NUM_RCVMMSGS] = + std::array::from_fn(|_| MaybeUninit::uninit()); + + let mut frame_buffer_inflight_cnt = 0; + for (hdr, iov, addr) in + izip!(&mut hdrs, &mut iovs, &mut addrs).take(count) + { + let buffer = fill_buffers.pop().expect("insufficient fill buffers"); + assert!(buffer.remaining_mut() >= PACKET_DATA_SIZE, "fill buffer too small"); + let iov_base = unsafe { + buffer.as_mut_ptr() as *mut libc::c_void + }; - for (packet, hdr, iov, addr) in - izip!(packets.iter_mut(), &mut hdrs, &mut iovs, &mut addrs).take(count) - { - let buffer = packet.buffer.base(); iov.write(iovec { - iov_base: buffer as *mut libc::c_void, + iov_base: iov_base, iov_len: PACKET_DATA_SIZE, }); @@ -208,6 +179,9 @@ pub fn triton_recv_mmsg(sock: &UdpSocket, packets: &mut [TritonPacket]) -> io::R msg_len: 0, msg_hdr, }); + // Keep track of the in-flight frame buffers to avoid use-after-free + frame_buffer_inflight_vec[frame_buffer_inflight_cnt].write(buffer); + frame_buffer_inflight_cnt += 1; } let mut ts = libc::timespec { @@ -226,11 +200,16 @@ pub fn triton_recv_mmsg(sock: &UdpSocket, packets: &mut [TritonPacket]) -> io::R ) }; let nrecv = if nrecv < 0 { + // On error, return all in-flight frame buffers back to the caller + for i in 0..frame_buffer_inflight_cnt { + let buffer = unsafe { frame_buffer_inflight_vec[i].assume_init_read() }; + fill_buffers.push(buffer); + } return Err(io::Error::last_os_error()); } else { usize::try_from(nrecv).unwrap() }; - for 
(addr, hdr, pkt) in izip!(addrs, hdrs, packets.iter_mut()).take(nrecv) { + for (addr, hdr, pkt, filled_bufmut) in izip!(addrs, hdrs, packets.iter_mut(), frame_buffer_inflight_vec).take(nrecv) { // SAFETY: We initialized `count` elements of `hdrs` above. `count` is // passed to recvmmsg() as the limit of messages that can be read. So, // `nrevc <= count` which means we initialized this `hdr` and @@ -239,6 +218,9 @@ pub fn triton_recv_mmsg(sock: &UdpSocket, packets: &mut [TritonPacket]) -> io::R // SAFETY: Similar to above, we initialized this `addr` and recvmmsg() // will have populated it let addr_ref = unsafe { addr.assume_init_ref() }; + let filled_bufmut = unsafe { filled_bufmut.assume_init_read() }; + let filled_buf: FrameBuf = filled_bufmut.into(); + pkt.buffer = filled_buf; pkt.meta_mut().size = hdr_ref.msg_len as usize; if let Some(addr) = cast_socket_addr(addr_ref, hdr_ref) { pkt.meta_mut().set_socket_addr(&addr); From 43a71beb7a20e0381c87862297523303d6c71e3f Mon Sep 17 00:00:00 2001 From: Louis-Vincent Boudreault Date: Wed, 14 Jan 2026 21:32:10 +0000 Subject: [PATCH 03/13] wip: triton_forwarder --- proxy/src/main.rs | 2 +- proxy/src/mem.rs | 76 +- proxy/src/recv_mmsg.rs | 117 ++- proxy/src/triton_forwarder.rs | 1537 +++++++++++---------------------- 4 files changed, 681 insertions(+), 1051 deletions(-) diff --git a/proxy/src/main.rs b/proxy/src/main.rs index 073b071..a6c71ff 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -33,7 +33,7 @@ mod token_authenticator; mod prom; mod recv_mmsg; mod mem; -// mod triton_forwarder; +mod triton_forwarder; #[derive(Clone, Debug, Parser)] #[clap(author, version, about, long_about = None)] diff --git a/proxy/src/mem.rs b/proxy/src/mem.rs index 769b543..dd00c17 100644 --- a/proxy/src/mem.rs +++ b/proxy/src/mem.rs @@ -17,13 +17,12 @@ pub struct AllocError; #[repr(C)] pub struct SharedMem { - ptr: *mut u8, - aligned_size: usize, - capacity: usize, + pub ptr: *mut u8, + len: usize, } pub fn try_alloc_shared_mem( - align: usize, + num_items: usize, capacity: usize, huge: bool, ) -> Result<*mut u8, AllocError> { @@ -32,11 +31,11 @@ pub fn try_alloc_shared_mem( capacity.is_power_of_two(), "capacity must be a power of two" ); - let aligned_size = capacity * align; + let total_len = capacity * num_items; let ptr = unsafe { libc::mmap( std::ptr::null_mut(), - aligned_size, + total_len, libc::PROT_READ | libc::PROT_WRITE, libc::MAP_SHARED | libc::MAP_ANONYMOUS | if huge { libc::MAP_HUGETLB } else { 0 }, -1, @@ -50,7 +49,7 @@ pub fn try_alloc_shared_mem( // zero initialize the memory unsafe { - std::ptr::write_bytes(ptr as *mut u8, 0, aligned_size); + std::ptr::write_bytes(ptr as *mut u8, 0, total_len); } Ok(ptr as *mut u8) @@ -59,27 +58,32 @@ pub fn try_alloc_shared_mem( impl SharedMem { - fn new(element_size: usize, capacity: usize, huge: bool) -> Result { + pub fn new(element_size: usize, capacity: usize, huge: bool) -> Result { let ptr = try_alloc_shared_mem(element_size, capacity, huge)?; - let aligned_size = capacity * element_size; + let len = capacity * element_size; Ok(Self { ptr, - aligned_size, - capacity, + len, }) } - fn dealloc(&self) { + pub fn len(&self) -> usize { + self.len + } + + pub fn dealloc(self) { unsafe { - libc::munmap(self.ptr as *mut libc::c_void, self.aligned_size); + libc::munmap(self.ptr as *mut libc::c_void, self.len); } } } impl Drop for SharedMem { fn drop(&mut self) { - self.dealloc(); + unsafe { + libc::munmap(self.ptr as *mut libc::c_void, self.len); + } } } @@ -90,6 +94,8 @@ pub struct FrameDesc { pub 
frame_size: usize, } +unsafe impl Send for FrameDesc {} + #[derive(Debug)] #[repr(C, align(32))] pub struct FrameBufMut { @@ -113,6 +119,47 @@ impl FrameBuf { let end = unsafe { self.desc.ptr.add(self.len) }; (end as usize) - (self.curr_ptr as usize) } + + + #[inline] + pub fn into_inner(self) -> FrameDesc { + self.desc + } + + #[inline] + pub unsafe fn detach_desc(&self) -> FrameDesc { + FrameDesc { ptr: self.desc.ptr, frame_size: self.desc.frame_size } + } + + pub unsafe fn unsafe_clone(&self) -> Self { + Self { + curr_ptr: self.curr_ptr, + len: self.len, + desc: FrameDesc { + ptr: self.desc.ptr, + frame_size: self.desc.frame_size, + }, + } + } + + pub unsafe fn unsafe_subslice_clone(&self, offset: usize, len: usize) -> Self { + assert!(offset + len <= self.len()); + Self { + curr_ptr: self.curr_ptr.add(offset), + len, + desc: FrameDesc { + ptr: self.desc.ptr, + frame_size: self.desc.frame_size, + }, + } + } + +} + +impl AsRef<[u8]> for FrameBuf { + fn as_ref(&self) -> &[u8] { + self.chunk() + } } unsafe impl Send for FrameBuf {} @@ -128,6 +175,7 @@ impl From for FrameBuf { } } + impl FrameDesc { pub fn as_mut_buf(&self) -> FrameBufMut { FrameBufMut { diff --git a/proxy/src/recv_mmsg.rs b/proxy/src/recv_mmsg.rs index 810cae3..11be2a2 100644 --- a/proxy/src/recv_mmsg.rs +++ b/proxy/src/recv_mmsg.rs @@ -1,65 +1,148 @@ -use bytes::BufMut; +use bytes::{Buf, BufMut}; use itertools::izip; use libc::{AF_INET, AF_INET6, MSG_WAITFORONE, iovec, mmsghdr, msghdr, sockaddr_storage}; use socket2::socklen_t; +use solana_ledger::shred::ShredId; use solana_perf::packet::{NUM_RCVMMSGS, PACKETS_PER_BATCH}; use solana_sdk::packet::{Meta, PACKET_DATA_SIZE, Packet}; use log::{error, trace}; use solana_streamer::{recvmmsg::recv_mmsg, streamer::StreamerReceiveStats}; use std::{ - cmp, collections::VecDeque, io, mem::{self, MaybeUninit, zeroed}, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, UdpSocket}, os::fd::AsRawFd, sync::atomic::{AtomicBool, Ordering}, time::{Duration, Instant} + cmp, collections::VecDeque, hint::spin_loop, io, mem::{self, MaybeUninit, zeroed}, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, UdpSocket}, os::fd::AsRawFd, sync::atomic::{AtomicBool, Ordering}, time::{Duration, Instant} }; use crate::mem::{FrameBuf, FrameBufMut, FrameDesc, Rx, SharedMem, Tx, try_alloc_shared_mem}; +const OFFSET_SHRED_TYPE: usize = 82; +const OFFSET_DATA_PARENT: usize = 83; // 83 + 0 +const OFFSET_DATA_INDEX: usize = 83 - 15; // Index is actually in common header +const OFFSET_CODING_POSITION: usize = 83 + 2; + +// Shred types based on Solana spec +const SHRED_TYPE_DATA: u8 = 0b1010_0101; +const SHRED_TYPE_CODING: u8 = 0b0101_1010; + pub struct RecvMemConfig { pub frames_count: usize, pub hugepages: bool, } +pub trait PacketRoutingStrategy: Clone { + fn route_packet(&self, packet: &TritonPacket, num_dest: usize) -> Option; +} + + +#[inline] +fn hash_pair(x: u64, y: u32) -> u64 { + let mut h = x ^ ((y as u64) << 32); + h ^= h >> 33; + h = h.wrapping_mul(0xff51afd7ed558ccd); + h ^= h >> 33; + h +} + +#[derive(Debug, Clone)] +struct FECSetRoutingStrategy; + +impl PacketRoutingStrategy for FECSetRoutingStrategy { + fn route_packet(&self, packet: &TritonPacket, num_dest: usize) -> Option { + let shred_buf = packet.buffer.chunk(); + let slot = solana_ledger::shred::wire::get_slot(shred_buf)?; + let fec = shred_buf.get(79..79 + 4)?; + let fec_bytes: [u8; 4] = fec.try_into().ok()?; + let fec_set_index = u32::from_le_bytes(fec_bytes); + let hash = hash_pair(slot, fec_set_index); + let 
dest = (hash as usize) % num_dest; + Some(dest) + } +} + + -fn recv_loop( +pub fn recv_loop( socket: &UdpSocket, exit: &AtomicBool, stats: &StreamerReceiveStats, coalesce: Duration, - mem_config: &RecvMemConfig, fill_rx: &mut Rx, - rx_tx: &Tx, -) -> std::io::Result<()> { + packet_tx_vec: &[Tx], + router: R, +) -> std::io::Result<()> + where R: PacketRoutingStrategy +{ let mut packet_batch = Vec::with_capacity(PACKETS_PER_BATCH); + let mut frame_bufmut_vec = Vec::with_capacity(PACKETS_PER_BATCH); + + loop { + // Check for exit signal, even if socket is busy // (for instance the leader transaction socket) if exit.load(Ordering::Relaxed) { return Ok(()); } - - if let Ok(len) = recv_from(&mut packet_batch, socket, coalesce) { + // Refill the frame buffers as much as we can, + 'fill_bufmut: while frame_bufmut_vec.len() < PACKETS_PER_BATCH { + let maybe_frame_buf = fill_rx.try_recv(); + match maybe_frame_buf { + Some(frame_desc) => { + let frame_bufmut = frame_desc.as_mut_buf(); + frame_bufmut_vec.push(frame_bufmut); + } + None => { + if frame_bufmut_vec.is_empty() { + // block until we get at least one frame buffer + let frame_desc = fill_rx.recv(); + let frame_bufmut = frame_desc.as_mut_buf(); + frame_bufmut_vec.push(frame_bufmut); + } else { + break 'fill_bufmut; + } + } + } + } + let result = recv_from( + &mut frame_bufmut_vec, + socket, + coalesce, + &mut packet_batch, + ); + if let Ok(len) = result { if len > 0 { let StreamerReceiveStats { packets_count, packet_batches_count, full_packet_batches_count, - max_channel_len, .. } = stats; packets_count.fetch_add(len, Ordering::Relaxed); packet_batches_count.fetch_add(1, Ordering::Relaxed); - max_channel_len.fetch_max(packet_batch_sender.len(), Ordering::Relaxed); if len == PACKETS_PER_BATCH { full_packet_batches_count.fetch_add(1, Ordering::Relaxed); } packet_batch .iter_mut() - .for_each(|p| p.meta_mut().set_from_staked_node(is_staked_service)); - packet_batch_sender.send(packet_batch)?; + .for_each(|p| p.meta_mut().set_from_staked_node(false)); + + for packet in packet_batch.drain(..) 
{ + let dest_idx = match router.route_packet(&packet, packet_tx_vec.len()) { + Some(idx) => idx, + None => { + log::debug!("Failed to route packet {:?}", packet); + let trashed_frame_bufmut = packet.buffer.into_inner().as_mut_buf(); + frame_bufmut_vec.push(trashed_frame_bufmut); + continue; + } + }; + let _ = &packet_tx_vec[dest_idx] + .send(packet) + .expect("Failed to send packet to processor"); + } } - break; } } } @@ -116,6 +199,8 @@ pub fn recv_from( Ok(i) } +#[derive(Debug)] +#[repr(C)] pub struct TritonPacket { pub buffer: FrameBuf, pub meta: Meta, @@ -134,6 +219,12 @@ impl TritonPacket { } } +impl AsRef<[u8]> for TritonPacket { + fn as_ref(&self) -> &[u8] { + self.buffer.chunk() + } +} + pub fn triton_recv_mmsg( sock: &UdpSocket, diff --git a/proxy/src/triton_forwarder.rs b/proxy/src/triton_forwarder.rs index 94261d7..d6b2709 100644 --- a/proxy/src/triton_forwarder.rs +++ b/proxy/src/triton_forwarder.rs @@ -1,18 +1,14 @@ use std::{ - collections::HashSet, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, - sync::{ - atomic::{AtomicBool, AtomicU64, Ordering}, - Arc, RwLock, - }, - thread::{Builder, JoinHandle}, - time::{Duration, Instant, SystemTime}, + collections::{HashSet, VecDeque}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, str::FromStr, sync::{ + Arc, RwLock, atomic::{AtomicBool, AtomicU64, Ordering} + }, thread::{Builder, JoinHandle}, time::{Duration, Instant, SystemTime} }; use arc_swap::ArcSwap; +use bytes::Buf; use crossbeam_channel::{Receiver, RecvError}; use dashmap::DashMap; -use itertools::Itertools; +use itertools::{Itertools, izip}; use jito_protos::shredstream::{Entry as PbEntry, TraceShred}; use log::{debug, error, info, warn}; use libc; @@ -33,9 +29,9 @@ use solana_streamer::{ use tokio::sync::broadcast::Sender; use crate::{ - ShredstreamProxyError, prom::{ + ShredstreamProxyError, forwarder::{ShredMetrics, try_create_ipv6_socket}, mem::{FrameBuf, FrameDesc, Rx, SharedMem, Tx}, prom::{ inc_packets_by_source, inc_packets_deduped, inc_packets_forward_failed, inc_packets_forwarded, inc_packets_received, observe_dedup_time, observe_recv_interval, observe_recv_packet_count, observe_send_duration, observe_send_packet_count - }, resolve_hostname_port + }, recv_mmsg::{PacketRoutingStrategy, TritonPacket}, resolve_hostname_port }; // values copied from https://github.com/solana-labs/solana/blob/33bde55bbdde13003acf45bb6afe6db4ab599ae4/core/src/sigverify_shreds.rs#L20 @@ -45,1075 +41,570 @@ pub const DEDUPER_RESET_CYCLE: Duration = Duration::from_secs(5 * 60); pub const IP_MULTICAST_TTL: u32 = 8; -fn spawn_packet_receiver( - num_receiver: usize, +#[derive(Debug, Clone, Copy, Default)] +pub enum ReceiverMemorySizing { + #[default] + XSmall = 134217728, // 128MiB + Small = 268435456, // 256MiB + Medium = 536870912, // 512MiB + Large = 1073741824, // 1GiB + XLarge = 2147483648, // 2GiB + XXLarge = 4294967296, // 4GiB + XXXLarge = 8589934592, // 8GiB + XXXXLarge = 17179869184, // 16GiB + XXXXXLarge = 34359738368, // 32GiB +} + +#[derive(Debug, thiserror::Error)] +#[error("Invalid ReceiverMemoryCapacity: {0}")] +pub struct ReceiverMemoryCapacityFromStrErr(String); + + +impl FromStr for ReceiverMemorySizing { + type Err = String; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "xsmall" | "xs" => Ok(ReceiverMemorySizing::XSmall), + "small" | "s" => Ok(ReceiverMemorySizing::Small), + "medium" | "m" => Ok(ReceiverMemorySizing::Medium), + "large" | "l" => Ok(ReceiverMemorySizing::Large), + "xlarge" | "xl" => 
Ok(ReceiverMemorySizing::XLarge), + "xxlarge" | "xxl" | "2xl" => Ok(ReceiverMemorySizing::XXLarge), + "xxxlarge" | "xxxl" | "3xl" => Ok(ReceiverMemorySizing::XXXLarge), + "xxxxlarge" | "xxxxl" | "4xl" => Ok(ReceiverMemorySizing::XXXXLarge), + "xxxxxlarge" | "xxxxxl" | "5xl" => Ok(ReceiverMemorySizing::XXXXXLarge), + _ => Err(s.to_string()), + } + } +} + +#[derive(Clone, Debug)] +pub struct PacketRecvTileMemConfig { + pub frame_size: usize, + pub memory_size: ReceiverMemorySizing, + pub hugepage: bool, + +} + +impl Default for PacketRecvTileMemConfig { + fn default() -> Self { + Self { + frame_size: 2048, + memory_size: ReceiverMemorySizing::default(), + hugepage: false, + } + } +} + +fn packet_recv_tile( + sockets: Vec, src_addr: IpAddr, src_port: u16, - packet_sender: crossbeam_channel::Sender, exit: Arc, - recycler: PacketBatchRecycler, forwarder_stats: Arc, + fill_rx_vec: Vec>, + packet_tx_vec: Vec>, + packet_router: R, threads: &mut Vec>, -) { - - let (_port, sockets) = solana_net_utils::multi_bind_in_range_with_config( - src_addr, - (src_port, src_port + 1), - SocketConfig::default().reuseport(true), - num_receiver, - ) - .unwrap_or_else(|_| { - panic!("Failed to bind listener sockets. Check that port {src_port} is not in use.") - }); +) -> std::io::Result<()> + where R: PacketRoutingStrategy + Send + 'static, +{ + assert!(sockets.len() == fill_rx_vec.len(), "mismatched fill_rx_vec and sockets length"); + assert!(sockets.len() == packet_tx_vec.len(), "mismatched packet_tx_vec and sockets length"); + + // let (_port, sockets) = solana_net_utils::multi_bind_in_range_with_config( + // src_addr, + // (src_port, src_port + 1), + // SocketConfig::default().reuseport(true), + // num_receiver, + // ) + // .unwrap_or_else(|_| { + // panic!("Failed to bind listener sockets. 
Check that port {src_port} is not in use.") + // }); + + for (thread_id, socket, mut fill_rx) in izip!(0..sockets.len(), sockets.into_iter(), fill_rx_vec.into_iter()) { + + // let shmem = SharedMem::new( + // mem_config.frame_size, + // mem_config.memory_size as usize / mem_config.frame_size, + // mem_config.hugepage, + // ).expect("SharedMem::new"); - for (thread_id, socket) in sockets.into_iter().enumerate() { - let packet_sender = packet_sender.clone(); let socket = Arc::new(socket); - + let exit = Arc::clone(&exit); + let stats = Arc::clone(&forwarder_stats); + let packet_tx_vec = packet_tx_vec.clone(); + let packet_router = R::clone(&packet_router); let th = std::thread::Builder::new() .name(format!("ssListen{thread_id}")) .spawn(move || { - crate::recv_mmsg( - + let socket = socket; + crate::recv_mmsg::recv_loop( + &socket, + &exit, + &stats, + Duration::default(), + &mut fill_rx, + &packet_tx_vec, + packet_router, ) - }); - - let listen_thread = streamer::receiver( - format!("ssListen{thread_id}"), - Arc::new(socket), - exit.clone(), - packet_sender.clone(), - recycler.clone(), - Arc::clone(&forwarder_stats), - Duration::default(), - false, - None, - false, - ); - threads.push(listen_thread); + .expect("recv_loop") + })?; + + threads.push(th); } + Ok(()) +} + +#[derive(Clone, Debug)] +#[repr(C)] +pub struct SharedMemInfo { + pub start_ptr: *const u8, + pub len: usize, // always a power of 2 } +unsafe impl Send for SharedMemInfo {} +unsafe impl Sync for SharedMemInfo {} -fn packet_router( - router_idx: usize, - packet_rx: crossbeam_channel::Receiver, +fn packet_fwd_tile( + packet_fwd_idx: usize, + hot_dest_vec: Arc>>, + send_socket: UdpSocket, + mut packet_rx: Rx, + fill_tx_vec: Vec>, + shmem_info_vec: Vec, ) -> std::io::Result> { std::thread::Builder::new() - .name(format!("pktRouter{router_idx}")) + .name(format!("ssPxyTx_{packet_fwd_idx}")) .spawn(move || { - while let Ok(packet_batch) = packet_rx.recv() { - // route packets based on some criteria - for packet in packet_batch.iter() { - // Example routing logic (to be replaced with actual logic) - let dest = packet.meta().addr; - - debug!("Router {router_idx} routing packet to {dest}"); - // Send packet to appropriate destination - } - } - info!("Exiting packet router thread {router_idx}."); - }) -} - -/// Bind to ports and start forwarding shreds -#[allow(clippy::too_many_arguments)] -pub fn start_forwarder_threads( - unioned_dest_sockets: Arc>>, /* sockets shared between endpoint discovery thread and forwarders */ - src_addr: IpAddr, - src_port: u16, - maybe_multicast_socket: Option>, - maybe_triton_multicast_socket: Option<(IpAddr, Vec)>, - num_threads: Option, - deduper: Arc>>, - entry_sender: Arc>, - debug_trace_shred: bool, - use_discovery_service: bool, - forward_stats: Arc, - metrics: Arc, - shutdown_receiver: Receiver<()>, - exit: Arc, -) -> Vec> { - let num_threads = num_threads - .unwrap_or_else(|| usize::from(std::thread::available_parallelism().unwrap()).min(4)); + let mut deduper = Deduper::<2, [u8]>::new(&mut rand::thread_rng(), DEDUPER_NUM_BITS); + const UIO_MAXIOV: usize = libc::UIO_MAXIOV as usize; + // We allocate double size to account for possible overflow if destinations array is really big + let mut next_batch_send: Vec<(FrameBuf, SocketAddr)> = Vec::with_capacity(UIO_MAXIOV); + let mut queued: VecDeque = VecDeque::with_capacity(UIO_MAXIOV); - let recycler: PacketBatchRecycler = Recycler::warmed(100, 1024); + for shmem_info in &shmem_info_vec { + assert!(shmem_info.len.is_power_of_two(), "shmem_info.len must 
be a power of 2"); + } - // multi_bind_in_range returns (port, Vec) - let (_port, sockets) = solana_net_utils::multi_bind_in_range_with_config( - src_addr, - (src_port, src_port + 1), - SocketConfig::default().reuseport(true), - num_threads, - ) - .unwrap_or_else(|_| { - panic!("Failed to bind listener sockets. Check that port {src_port} is not in use.") - }); + let mut next_deduper_reset_attempt = Instant::now() + Duration::from_secs(2); + let mut recycled_frames: Vec = Vec::with_capacity(UIO_MAXIOV); + loop { - let mut ret = sockets - .into_iter() - .chain(maybe_multicast_socket.unwrap_or_default()) - .enumerate() - .flat_map(|(thread_id, source)| { - let (packet_sender, packet_receiver) = crossbeam_channel::unbounded(); - let listen_thread = streamer::receiver( - format!("ssListen{thread_id}"), - Arc::new(source), - exit.clone(), - packet_sender, - recycler.clone(), - forward_stats.clone(), - Duration::default(), - false, - None, - false, - ); - - let deduper = deduper.clone(); - let unioned_dest_sockets = unioned_dest_sockets.clone(); - let metrics = metrics.clone(); - let shutdown_receiver = shutdown_receiver.clone(); - let exit = exit.clone(); - - - let send_thread = Builder::new() - .name(format!("ssPxyTx_{thread_id}")) - .spawn(move || { - let dont_send_to_origin = move |origin: IpAddr, dest: SocketAddr| { - origin != dest.ip() - }; - let send_socket = { - let ipv6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); - match try_create_ipv6_socket(ipv6_addr) { - Ok(socket) => { - info!("Successfully bound send socket to IPv6 dual-stack address."); - socket - } - Err(e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => { - // This error (code 97 on Linux) means IPv6 is not supported. - warn!("IPv6 not available. Falling back to IPv4-only for sending."); - let ipv4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); - UdpSocket::bind(ipv4_addr) - .expect("Failed to bind to IPv4 socket after IPv6 failed") - } - Err(e) => { - // For any other error (e.g., port in use), panic. - panic!("Failed to bind send socket with an unexpected error: {e}"); - } - } - }; - let mut local_dest_sockets = unioned_dest_sockets.load(); + if next_deduper_reset_attempt.elapsed() > Duration::ZERO { + deduper.maybe_reset(&mut rand::thread_rng(), DEDUPER_FALSE_POSITIVE_RATE, DEDUPER_RESET_CYCLE); + next_deduper_reset_attempt = Instant::now() + Duration::from_secs(2); + } - let refresh_subscribers_tick = if use_discovery_service { - crossbeam_channel::tick(Duration::from_secs(30)) + // Fill up the queued OR recycled_frames as much as possible + while queued.len() < UIO_MAXIOV && recycled_frames.len() < UIO_MAXIOV { + // Fill the batch as much as possible. + let packet = packet_rx.recv(); + let data_size = packet.meta.size; + let data_slice = &packet.buffer.chunk()[..data_size]; + if deduper.dedup(data_slice) { + + // put it inside the recycle queue + debug!("Deduped packet from {}", packet.meta.addr); + let desc = packet.buffer.into_inner(); + recycled_frames.push(desc); } else { - crossbeam_channel::tick(Duration::MAX) - }; - - let mut last_recv = Instant::now(); - while !exit.load(Ordering::Relaxed) { - crossbeam_channel::select! 
{ - // forward packets - recv(packet_receiver) -> maybe_packet_batch => { - let e = last_recv.elapsed(); - last_recv = Instant::now(); - observe_recv_interval(e.as_micros() as f64); - let res = recv_from_channel_and_send_multiple_dest( - maybe_packet_batch, - &deduper, - &send_socket, - &local_dest_sockets, - dont_send_to_origin, - debug_trace_shred, - &metrics, - ); - - // If the channel is closed or error, break out - if res.is_err() { - break; - } - } - - // refresh thread-local subscribers - recv(refresh_subscribers_tick) -> _ => { - local_dest_sockets = unioned_dest_sockets.load(); - } - - // handle shutdown (avoid using sleep since it can hang) - recv(shutdown_receiver) -> _ => { - break; - } - } + queued.push_back(packet); } - info!("Exiting forwarder thread {thread_id}."); - }) - .unwrap(); - - vec![listen_thread, send_thread] - }) - .collect::>>(); - - if let Some((multicast_origin, multicast_socket)) = maybe_triton_multicast_socket { - start_multicast_forwarder_thread( - multicast_origin, - multicast_socket, - recycler, - unioned_dest_sockets, - deduper, - forward_stats, - metrics, - debug_trace_shred, - use_discovery_service, - shutdown_receiver, - exit, - &mut ret - ); - } - ret -} - -/// -/// Try to create an IPv6 UDP socket bound to the given address. -/// -fn try_create_ipv6_socket(addr: SocketAddr) -> Result { - let ipv6_socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; - ipv6_socket.set_multicast_hops_v6(IP_MULTICAST_TTL)?; - ipv6_socket.bind(&addr.into())?; - Ok(ipv6_socket.into()) -} -pub struct MulticastSource { - pub socket: Arc, - pub group: IpAddr, -} - -#[allow(clippy::too_many_arguments)] -pub fn start_multicast_forwarder_thread( - multicast_origin: IpAddr, - sockets: Vec, - recycler: PacketBatchRecycler, - unioned_dest_sockets: Arc>>, - deduper: Arc>>, - forward_stats: Arc, - metrics: Arc, - debug_trace_shred: bool, - use_discovery_service: bool, - shutdown_receiver: Receiver<()>, - exit: Arc, - out: &mut Vec>, -) { - for (thread_id, socket) in sockets.into_iter().enumerate() { - let (packet_sender, packet_receiver) = crossbeam_channel::unbounded(); - let listen_thread = streamer::receiver( - format!("ssListenMulticast{thread_id}"), - Arc::new(socket), - exit.clone(), - packet_sender, - recycler.clone(), - forward_stats.clone(), - Duration::default(), - false, - None, - false, - ); - - out.push(listen_thread); - let deduper = deduper.clone(); - let unioned_dest_sockets = unioned_dest_sockets.clone(); - let metrics = metrics.clone(); - let shutdown_receiver = shutdown_receiver.clone(); - let exit = exit.clone(); - - - - let send_thread = Builder::new() - .name(format!("ssPxyTxMulticast_{thread_id}")) - .spawn(move || { - let dont_send_to_mc_origin = move |_origin, dest: SocketAddr| { - dest.ip() != multicast_origin - }; - let send_socket = { - let ipv6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); - match try_create_ipv6_socket(ipv6_addr) { - Ok(socket) => { - info!("Successfully bound send socket to IPv6 dual-stack address."); - socket.set_multicast_loop_v6(false) - .expect("Failed to disable IPv6 multicast loopback"); - socket - } - Err(e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => { - // This error (code 97 on Linux) means IPv6 is not supported. - warn!("IPv6 not available. 
Falling back to IPv4-only for sending."); - let ipv4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); - let socket = UdpSocket::bind(ipv4_addr) - .expect("Failed to bind to IPv4 socket after IPv6 failed"); - socket.set_multicast_ttl_v4(IP_MULTICAST_TTL).expect("IP_MULTICAST_TTL_V4"); - socket.set_multicast_loop_v4(false) - .expect("Failed to disable IPv4 multicast loopback"); - socket - } - Err(e) => { - // For any other error (e.g., port in use), panic. - panic!("Failed to bind send socket with an unexpected error: {e}"); - } - } - }; - let mut local_dest_sockets = unioned_dest_sockets.load(); - - let refresh_subscribers_tick = if use_discovery_service { - crossbeam_channel::tick(Duration::from_secs(30)) - } else { - crossbeam_channel::tick(Duration::MAX) - }; - - while !exit.load(Ordering::Relaxed) { - crossbeam_channel::select! { - // forward packets - recv(packet_receiver) -> maybe_packet_batch => { - let res = recv_from_channel_and_send_multiple_dest( - maybe_packet_batch, - &deduper, - &send_socket, - &local_dest_sockets, - dont_send_to_mc_origin, - debug_trace_shred, - &metrics, - ); - - // If the channel is closed or error, break out - if res.is_err() { - break; - } - } - - // refresh thread-local subscribers - recv(refresh_subscribers_tick) -> _ => { - local_dest_sockets = unioned_dest_sockets.load(); - } - - // handle shutdown (avoid using sleep since it can hang) - recv(shutdown_receiver) -> _ => { - break; - } - } } - info!("Exiting forwarder thread {thread_id}."); - }) - .unwrap(); - out.push(send_thread); - } -} + let dests = hot_dest_vec.load(); + let dests_len = dests.len(); -#[allow(dead_code)] -fn accept_all(_origin: IpAddr, _dest: SocketAddr) -> bool { - true -} + // Fill up the next_batch_send + 'fill_batch_send: while next_batch_send.len() < UIO_MAXIOV && queued.len() > 0 && dests_len > 0 { -/// Broadcasts the same packet to multiple recipients, parses it into a Shred if possible, -/// and stores that shred in `all_shreds`. 
-#[allow(clippy::too_many_arguments)] -fn recv_from_channel_and_send_multiple_dest( - maybe_packet_batch: Result, - deduper: &RwLock>, - send_socket: &UdpSocket, - local_dest_sockets: &[SocketAddr], - local_dest_socket_filter: F, - debug_trace_shred: bool, - metrics: &ShredMetrics, -) -> Result<(), ShredstreamProxyError> -where - F: Fn(IpAddr, SocketAddr) -> bool, -{ - let packet_batch = maybe_packet_batch.map_err(ShredstreamProxyError::RecvError)?; - let trace_shred_received_time = SystemTime::now(); - let batch_len = packet_batch.len() as u64; - metrics - .received - .fetch_add(batch_len, Ordering::Relaxed); - inc_packets_received(batch_len); - observe_recv_packet_count(batch_len as f64); - debug!( - "Got batch of {} packets, total size in bytes: {}", - packet_batch.len(), - packet_batch.iter().map(|x| x.meta().size).sum::() - ); - - - let mut packet_batch_vec = vec![packet_batch]; - - let t = Instant::now(); - let num_deduped = solana_perf::deduper::dedup_packets_and_count_discards( - &deduper.read().unwrap(), - &mut packet_batch_vec, - ); - let t_dedup_usecs = t.elapsed().as_micros() as u64; - metrics.dedup_time_spent.fetch_add(t_dedup_usecs, Ordering::Relaxed); - observe_dedup_time(t_dedup_usecs as f64); - inc_packets_deduped(num_deduped); - - // Store stats for each Packet - packet_batch_vec.iter().for_each(|batch| { - batch.iter().for_each(|packet| { - let addr = packet.meta().addr; - let is_discarded = packet.meta().discard(); - metrics - .packets_received - .entry(addr) - .and_modify(|(discarded, not_discarded)| { - *discarded += is_discarded as u64; - *not_discarded += (!is_discarded) as u64; - }) - .or_insert_with(|| { - (is_discarded as u64, (!is_discarded) as u64) - }); - let status = if is_discarded { "discarded" } else { "forwarded" }; - inc_packets_by_source(&addr.to_string(), status, 1); - }); - }); + let remaining = UIO_MAXIOV - next_batch_send.len(); + if dests_len < remaining { + break 'fill_batch_send; + } - // send out to RPCs - local_dest_sockets.iter().for_each(|outgoing_socketaddr| { - let packets_with_dest = packet_batch_vec[0] - .iter() - .filter_map(|pkt| { - let addr = pkt.meta().addr; - if local_dest_socket_filter(addr, *outgoing_socketaddr) { - Some(pkt) - } else { - None + let Some(packet) = queued.pop_front() else { + break 'fill_batch_send; + }; + let buf = packet.buffer; + let desc = unsafe { buf.detach_desc() }; + recycled_frames.push(desc); + + for dest in dests.iter() { + // Cheap to do since we are just copying a pointer + let buf_clone = unsafe { buf.unsafe_subslice_clone(0, packet.meta.size) }; + next_batch_send.push((buf_clone, *dest)); + } } - }) - .filter_map(|pkt| { - let data = pkt.data(..)?; - let addr = outgoing_socketaddr; - Some((data, addr)) - }) - .collect::>(); - let t = Instant::now(); - metrics - .send_batch_size_sum - .fetch_add(packets_with_dest.len() as u64, Ordering::Relaxed); - metrics.send_batch_count.fetch_add(1, Ordering::Relaxed); - const MAX_IOV: usize = libc::UIO_MAXIOV as usize; - let max_iov_count = packets_with_dest.len() / MAX_IOV; - let unsaturated_iov_count = packets_with_dest.len() % MAX_IOV; - metrics.saturated_iov_count.fetch_add(max_iov_count as u64, Ordering::Relaxed); - metrics.unsaturated_iov_count.fetch_add(unsaturated_iov_count as u64, Ordering::Relaxed); - observe_send_packet_count(packets_with_dest.len() as f64); - match batch_send(send_socket, &packets_with_dest) { - Ok(_) => { - metrics - .success_forward - .fetch_add(packets_with_dest.len() as u64, Ordering::Relaxed); - 
metrics.duplicate.fetch_add(num_deduped, Ordering::Relaxed); - inc_packets_forwarded(packets_with_dest.len() as u64); - } - Err(SendPktsError::IoError(err, num_failed)) => { - metrics - .fail_forward - .fetch_add(packets_with_dest.len() as u64, Ordering::Relaxed); - metrics - .duplicate - .fetch_add(num_failed as u64, Ordering::Relaxed); - inc_packets_forward_failed(packets_with_dest.len() as u64); - error!( - "Failed to send batch of size {} to {outgoing_socketaddr:?}. \ - {num_failed} packets failed. Error: {err}", - packets_with_dest.len() - ); - } - } - let t_send_usecs = t.elapsed().as_micros() as u64; - metrics.batch_send_time_spent.fetch_add(t_send_usecs, Ordering::Relaxed); - observe_send_duration(t_send_usecs as f64); - }); - // Count TraceShred shreds - if debug_trace_shred { - packet_batch_vec[0] - .iter() - .filter_map(|p| TraceShred::decode(p.data(..)?).ok()) - .filter(|t| t.created_at.is_some()) - .for_each(|trace_shred| { - let elapsed = trace_shred_received_time - .duration_since(SystemTime::try_from(trace_shred.created_at.unwrap()).unwrap()) - .unwrap_or_default(); - - datapoint_info!( - "shredstream_proxy-trace_shred_latency", - "trace_region" => trace_shred.region, - ("trace_seq_num", trace_shred.seq_num as i64, i64), - ("elapsed_micros", elapsed.as_micros(), i64), - ); - }); - } + assert!(next_batch_send.len() <= UIO_MAXIOV, "next_batch_send.len() = {}", next_batch_send.len()); + assert!(recycled_frames.len() <= UIO_MAXIOV, "recycled_frames.len() = {}", recycled_frames.len()); + assert!(queued.len() <= UIO_MAXIOV, "queued.len() = {}", queued.len()); - Ok(()) -} - -/// Starts a thread that updates our destinations used by the forwarder threads -pub fn start_destination_refresh_thread( - endpoint_discovery_url: String, - discovered_endpoints_port: u16, - static_dest_sockets: Vec<(SocketAddr, String)>, - unioned_dest_sockets: Arc>>, - shutdown_receiver: Receiver<()>, - exit: Arc, -) -> JoinHandle<()> { - Builder::new().name("ssPxyDstRefresh".to_string()).spawn(move || { - let fetch_socket_tick = crossbeam_channel::tick(Duration::from_secs(30)); - let metrics_tick = crossbeam_channel::tick(Duration::from_secs(30)); - let mut socket_count = static_dest_sockets.len(); - while !exit.load(Ordering::Relaxed) { - crossbeam_channel::select! { - recv(fetch_socket_tick) -> _ => { - let fetched = fetch_unioned_destinations( - &endpoint_discovery_url, - discovered_endpoints_port, - &static_dest_sockets, - ); - let new_sockets = match fetched { - Ok(s) => { - info!("Sending shreds to {} destinations: {s:?}", s.len()); - s - } - Err(e) => { - warn!("Failed to fetch from discovery service, retrying. Error: {e}"); - datapoint_warn!("shredstream_proxy-destination_refresh_error", - ("prev_unioned_dest_count", socket_count, i64), - ("errors", 1, i64), - ("error_str", e.to_string(), String), - ); - continue; - } - }; - socket_count = new_sockets.len(); - unioned_dest_sockets.store(Arc::new(new_sockets)); + match batch_send(&send_socket, &next_batch_send) { + Ok(_) => { + // Successfully sent all packets in the batch } - recv(metrics_tick) -> _ => { - datapoint_info!("shredstream_proxy-destination_refresh_stats", - ("destination_count", socket_count, i64), + Err(SendPktsError::IoError(err, num_failed)) => { + error!( + "Failed to send batch of size {}. \ + {num_failed} packets failed. 
Error: {err}", + next_batch_send.len() ); } - recv(shutdown_receiver) -> _ => { - break; - } } - } - }).unwrap() -} - -/// Returns dynamically discovered endpoints with CLI arg defined endpoints -fn fetch_unioned_destinations( - endpoint_discovery_url: &str, - discovered_endpoints_port: u16, - static_dest_sockets: &[(SocketAddr, String)], -) -> Result, ShredstreamProxyError> { - let bytes = reqwest::blocking::get(endpoint_discovery_url)?.bytes()?; - - let sockets_json = match serde_json::from_slice::>(&bytes) { - Ok(s) => s, - Err(e) => { - warn!( - "Failed to parse json from: {:?}", - std::str::from_utf8(&bytes) - ); - return Err(ShredstreamProxyError::from(e)); - } - }; - // resolve again since ip address could change - let static_dest_sockets = static_dest_sockets - .iter() - .filter_map(|(_socketaddr, hostname_port)| { - Some(resolve_hostname_port(hostname_port).ok()?.0) - }) - .collect::>(); - - let unioned_dest_sockets = sockets_json - .into_iter() - .map(|ip| SocketAddr::new(ip, discovered_endpoints_port)) - .chain(static_dest_sockets) - .unique() - .collect::>(); - Ok(unioned_dest_sockets) -} -/// Reset dedup + send metrics to influx -pub fn start_forwarder_accessory_thread( - deduper: Arc>>, - metrics: Arc, - metrics_update_interval_ms: u64, - shutdown_receiver: Receiver<()>, - exit: Arc, -) -> JoinHandle<()> { - Builder::new() - .name("ssPxyAccessory".to_string()) - .spawn(move || { - let metrics_tick = - crossbeam_channel::tick(Duration::from_millis(metrics_update_interval_ms)); - let deduper_reset_tick = crossbeam_channel::tick(Duration::from_secs(2)); - let mut rng = rand::thread_rng(); - while !exit.load(Ordering::Relaxed) { - crossbeam_channel::select! { - // reset deduper to avoid false positives - recv(deduper_reset_tick) -> _ => { - deduper - .write() - .unwrap() - .maybe_reset(&mut rng, DEDUPER_FALSE_POSITIVE_RATE, DEDUPER_RESET_CYCLE); - } - - // send metrics to influx - recv(metrics_tick) -> _ => { - metrics.report(); - metrics.reset(); - } - - // handle SIGINT shutdown - recv(shutdown_receiver) -> _ => { - break; - } + // Recycle all used frames + while let Some(desc) = recycled_frames.pop() { + let fill_ring_idx = shmem_info_vec.iter().find_position(|shmem_info| { + (desc.ptr as usize) & (shmem_info.len - 1) == (shmem_info.start_ptr as usize) + }).expect("unknown frame desc").0; + fill_tx_vec[fill_ring_idx].send(desc).expect("frame recycling"); } + } }) - .unwrap() } -pub struct ShredMetrics { - // receive stats - /// Total number of shreds received. Includes duplicates when receiving shreds from multiple regions - pub received: AtomicU64, - /// Total number of shreds successfully forwarded, accounting for all destinations - pub success_forward: AtomicU64, - /// Total number of shreds failed to forward, accounting for all destinations - pub fail_forward: AtomicU64, - /// Number of duplicate shreds received - pub duplicate: AtomicU64, - /// (discarded, not discarded, from other shredstream instances) - pub packets_received: DashMap, - /// The batch size we are sending to batch_send solana crate call. 
- pub send_batch_size_sum: AtomicU64, - pub send_batch_count: AtomicU64, - /// Number of occurrences we can saturated the iovecs in sendmmsg - pub saturated_iov_count: AtomicU64, - /// Number of occurrences we could not saturate the iovecs in sendmmsg - pub unsaturated_iov_count: AtomicU64, - - // service metrics - pub enabled_grpc_service: bool, - /// Number of data shreds recovered using coding shreds - pub recovered_count: AtomicU64, - /// Number of Solana entries decoded from shreds - pub entry_count: AtomicU64, - /// Number of transactions decoded from shreds - pub txn_count: AtomicU64, - /// Number of times we couldn't find the previous DATA_COMPLETE_SHRED flag - pub unknown_start_position_count: AtomicU64, - /// Number of FEC recovery errors - pub fec_recovery_error_count: AtomicU64, - /// Number of bincode Entry deserialization errors - pub bincode_deserialize_error_count: AtomicU64, - /// Number of times we couldn't find the previous DATA_COMPLETE_SHRED flag but tried to deshred+deserialize, and failed - pub unknown_start_position_error_count: AtomicU64, - - // cumulative time spent in deduping packets - pub dedup_time_spent: AtomicU64, - pub batch_send_time_spent: AtomicU64, - - // cumulative metrics (persist after reset) - pub agg_received_cumulative: AtomicU64, - pub agg_success_forward_cumulative: AtomicU64, - pub agg_fail_forward_cumulative: AtomicU64, - pub duplicate_cumulative: AtomicU64, +#[derive(thiserror::Error, Debug)] +pub enum ProxySystemError { + #[error(transparent)] + IoError(std::io::Error), + #[error(transparent)] + AllocError(crate::mem::AllocError), } -impl Default for ShredMetrics { - fn default() -> Self { - Self::new(false) - } -} -impl ShredMetrics { - pub fn new(enabled_grpc_service: bool) -> Self { - Self { - enabled_grpc_service, - received: Default::default(), - success_forward: Default::default(), - fail_forward: Default::default(), - duplicate: Default::default(), - packets_received: DashMap::with_capacity(10), - recovered_count: Default::default(), - entry_count: Default::default(), - txn_count: Default::default(), - unknown_start_position_count: Default::default(), - fec_recovery_error_count: Default::default(), - bincode_deserialize_error_count: Default::default(), - unknown_start_position_error_count: Default::default(), - agg_received_cumulative: Default::default(), - agg_success_forward_cumulative: Default::default(), - agg_fail_forward_cumulative: Default::default(), - duplicate_cumulative: Default::default(), - dedup_time_spent: Default::default(), - batch_send_time_spent: Default::default(), - send_batch_size_sum: Default::default(), - send_batch_count: Default::default(), - saturated_iov_count: Default::default(), - unsaturated_iov_count: Default::default(), - } - } - - pub fn report(&self) { - datapoint_info!( - "shredstream_proxy-connection_metrics", - ("received", self.received.load(Ordering::Relaxed), i64), - ( - "success_forward", - self.success_forward.load(Ordering::Relaxed), - i64 - ), - ( - "fail_forward", - self.fail_forward.load(Ordering::Relaxed), - i64 - ), - ("duplicate", self.duplicate.load(Ordering::Relaxed), i64), - ); - - datapoint_info!( - "shredstream_proxy-sendmmsg_iov_metrics", - ("max_iov_count", self.saturated_iov_count.load(Ordering::Relaxed), i64), - ( - "unsaturated_iov_count", - self.unsaturated_iov_count.load(Ordering::Relaxed), - i64 - ), - ); - - datapoint_info!( - "shredstream_proxy-batch_send_metrics", - ( - "send_batch_size_sum", self.send_batch_size_sum.load(Ordering::Relaxed), i64 - ), - ( - 
"send_batch_count", self.send_batch_count.load(Ordering::Relaxed), i64 - ) - ); - - datapoint_info!( - "shredstream_proxy-time_allocation", - ( - "deduping", - self.dedup_time_spent.load(Ordering::Relaxed), - i64 - ), - ( - "batch_send", - self.batch_send_time_spent.load(Ordering::Relaxed), - i64 - ), - ); - - - if self.enabled_grpc_service { - datapoint_info!( - "shredstream_proxy-service_metrics", - ( - "recovered_count", - self.recovered_count.swap(0, Ordering::Relaxed), - i64 - ), - ( - "entry_count", - self.entry_count.swap(0, Ordering::Relaxed), - i64 - ), - ("txn_count", self.txn_count.swap(0, Ordering::Relaxed), i64), - ( - "unknown_start_position_count", - self.unknown_start_position_count.swap(0, Ordering::Relaxed), - i64 - ), - ( - "fec_recovery_error_count", - self.fec_recovery_error_count.swap(0, Ordering::Relaxed), - i64 - ), - ( - "bincode_deserialize_error_count", - self.bincode_deserialize_error_count - .swap(0, Ordering::Relaxed), - i64 - ), - ( - "unknown_start_position_error_count", - self.unknown_start_position_error_count - .swap(0, Ordering::Relaxed), - i64 - ), - ); - } +pub fn spawn_proxy_system( + pkt_recv_tile_mem_config: PacketRecvTileMemConfig, + dest_addr_vec: Arc>>, + src_ip: IpAddr, + src_port: u16, + num_pkt_recv_tiles: usize, + num_pkt_fwd_tiles: usize, + pkt_router: R, + exit: Arc, + stats: Arc, +) -> JoinHandle<()> + where R: PacketRoutingStrategy + Send + Sync + 'static, +{ - self.packets_received - .retain(|addr, (discarded_packets, not_discarded_packets)| { - datapoint_info!("shredstream_proxy-receiver_stats", - "addr" => addr.to_string(), - ("discarded_packets", *discarded_packets, i64), - ("not_discarded_packets", *not_discarded_packets, i64), - ); - false - }); - } + // Build pkt_recv sockets + let (_port, sockets) = solana_net_utils::multi_bind_in_range_with_config( + src_ip, + (src_port, src_port + 1), + SocketConfig::default().reuseport(true), + num_pkt_recv_tiles, + ) + .unwrap_or_else(|_| { + panic!("Failed to bind listener sockets. 
Check that port {src_port} is not in use.")
+    });
+
+    // Create the shared memory regions for recv tiles
+    let mut shmem_info_vec: Vec<SharedMemInfo> = Vec::with_capacity(num_pkt_recv_tiles);
+    let mut fill_tx_vec: Vec<Tx<FrameDesc>> = Vec::with_capacity(num_pkt_recv_tiles);
+    let mut fill_rx_vec: Vec<Rx<FrameDesc>> = Vec::with_capacity(num_pkt_recv_tiles);
+    let mut shmem_vec: Vec<SharedMem> = Vec::with_capacity(num_pkt_recv_tiles);
+    for _ in 0..num_pkt_recv_tiles {
+        let frame_size = pkt_recv_tile_mem_config.frame_size;
+        let num_frames = pkt_recv_tile_mem_config.memory_size as usize / pkt_recv_tile_mem_config.frame_size;
+        assert!(num_frames.is_power_of_two(), "num_frames must be a power of 2");
+        assert!(frame_size.is_power_of_two(), "frame_size must be a power of 2");
+        let shmem = SharedMem::new(
+            frame_size,
+            num_frames,
+            pkt_recv_tile_mem_config.hugepage,
+        ).expect("SharedMem::new");
+
+        let shmem_info = SharedMemInfo {
+            start_ptr: shmem.ptr,
+            len: shmem.len(),
+        };
+        shmem_info_vec.push(shmem_info);
+
+        let (fill_tx, fill_rx) = crate::mem::message_ring(num_frames).expect("frame ring");
+        // Fill the fill ring with all frames
+        for i in 0..num_frames {
+            let frame_desc = FrameDesc {
+                ptr: unsafe { shmem.ptr.add(i * frame_size) },
+                frame_size,
+            };
+            fill_tx.send(frame_desc).expect("initial frame ring population");
+        }
+        shmem_vec.push(shmem);
+        fill_tx_vec.push(fill_tx);
+        fill_rx_vec.push(fill_rx);
+    }
+
+    // Create socket for sending packets
+    let send_socket = {
+        let ipv6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0);
+        match try_create_ipv6_socket(ipv6_addr) {
+            Ok(socket) => {
+                info!("Successfully bound send socket to IPv6 dual-stack address.");
+                socket.set_multicast_loop_v6(false)
+                    .expect("Failed to disable IPv6 multicast loopback");
+                socket
+            }
+            Err(e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => {
+                // This error (code 97 on Linux) means IPv6 is not supported.
+                warn!("IPv6 not available. Falling back to IPv4-only for sending.");
+                let ipv4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
+                let socket = UdpSocket::bind(ipv4_addr)
+                    .expect("Failed to bind to IPv4 socket after IPv6 failed");
+                socket.set_multicast_ttl_v4(IP_MULTICAST_TTL).expect("IP_MULTICAST_TTL_V4");
+                socket.set_multicast_loop_v4(false)
+                    .expect("Failed to disable IPv4 multicast loopback");
+                socket
+            }
+            Err(e) => {
+                // For any other error (e.g., port in use), panic.
+ panic!("Failed to bind send socket with an unexpected error: {e}"); + } + } }; - use solana_perf::{ - deduper::Deduper, - packet::{Meta, Packet, PacketBatch}, - }; - use solana_sdk::packet::{PacketFlags, PACKET_DATA_SIZE}; - use crate::triton_forwarder::{accept_all, recv_from_channel_and_send_multiple_dest, ShredMetrics}; + - fn listen_and_collect(listen_socket: UdpSocket, received_packets: Arc>>>) { - let mut buf = [0u8; PACKET_DATA_SIZE]; - loop { - listen_socket.recv(&mut buf).unwrap(); - received_packets.lock().unwrap().push(Vec::from(buf)); - } - } - #[test] - fn test_2shreds_3destinations() { - let packet_batch = PacketBatch::new(vec![ - Packet::new( - [1; PACKET_DATA_SIZE], - Meta { - size: PACKET_DATA_SIZE, - addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), - port: 48289, // received on random port - flags: PacketFlags::empty(), - }, - ), - Packet::new( - [2; PACKET_DATA_SIZE], - Meta { - size: PACKET_DATA_SIZE, - addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), - port: 9999, - flags: PacketFlags::empty(), - }, - ), - ]); - let (packet_sender, packet_receiver) = crossbeam_channel::unbounded::(); - packet_sender.send(packet_batch).unwrap(); - - let dest_socketaddrs = vec![ - SocketAddr::from_str("0.0.0.0:32881").unwrap(), - SocketAddr::from_str("0.0.0.0:33881").unwrap(), - SocketAddr::from_str("0.0.0.0:34881").unwrap(), - ]; - - let test_listeners = dest_socketaddrs - .iter() - .map(|socketaddr| { - ( - UdpSocket::bind(socketaddr).unwrap(), - *socketaddr, - // store results in vec of packet, where packet is Vec - Arc::new(Mutex::new(vec![])), - ) - }) - .collect::>(); - - let udp_sender = UdpSocket::bind("0.0.0.0:10000").unwrap(); - - // spawn listeners - test_listeners - .iter() - .for_each(|(listen_socket, _socketaddr, to_receive)| { - let socket = listen_socket.try_clone().unwrap(); - let to_receive = to_receive.to_owned(); - thread::spawn(move || listen_and_collect(socket, to_receive)); - }); - - // send packets - recv_from_channel_and_send_multiple_dest( - packet_receiver.recv(), - &Arc::new(RwLock::new(Deduper::<2, [u8]>::new( - &mut rand::thread_rng(), - crate::forwarder::DEDUPER_NUM_BITS, - ))), - &udp_sender, - &Arc::new(dest_socketaddrs), - accept_all, - true, - &Arc::new(ShredMetrics::default()), - ) - .unwrap(); - - // allow packets to be received - sleep(Duration::from_millis(500)); - - let received = test_listeners - .iter() - .map(|(_, _, results)| results.clone()) - .collect::>(); - - // check results - for received in received.iter() { - let received = received.lock().unwrap(); - assert_eq!(received.len(), 2); - assert!(received - .iter() - .all(|packet| packet.len() == PACKET_DATA_SIZE)); - assert_eq!(received[0], [1; PACKET_DATA_SIZE]); - assert_eq!(received[1], [2; PACKET_DATA_SIZE]); - } + todo!() +} - assert_eq!( - received - .iter() - .fold(0, |acc, elem| acc + elem.lock().unwrap().len()), - 6 - ); - } - #[test] - fn test_dest_filter() { - let packet_batch = PacketBatch::new(vec![ - Packet::new( - [1; PACKET_DATA_SIZE], - Meta { - size: PACKET_DATA_SIZE, - addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), - port: 48289, // received on random port - flags: PacketFlags::empty(), - }, - ), - Packet::new( - [2; PACKET_DATA_SIZE], - Meta { - size: PACKET_DATA_SIZE, - addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), - port: 9999, - flags: PacketFlags::empty(), - }, - ), - ]); - let (packet_sender, packet_receiver) = crossbeam_channel::unbounded::(); - packet_sender.send(packet_batch).unwrap(); - - let dest_socketaddrs = vec![ - SocketAddr::from_str("0.0.0.0:32881").unwrap(), - 
SocketAddr::from_str("0.0.0.0:33881").unwrap(), - SocketAddr::from_str("0.0.0.0:34881").unwrap(), - ]; - - let blacklisted = SocketAddr::from_str("0.0.0.0:34881").unwrap(); // none blacklisted - - let test_listeners = dest_socketaddrs - .iter() - .map(|socketaddr| { - ( - UdpSocket::bind(socketaddr).unwrap(), - *socketaddr, - // store results in vec of packet, where packet is Vec - Arc::new(Mutex::new(vec![])), - ) - }) - .collect::>(); - - let udp_sender = UdpSocket::bind("0.0.0.0:10000").unwrap(); - - // spawn listeners - test_listeners - .iter() - .for_each(|(listen_socket, _socketaddr, to_receive)| { - let socket = listen_socket.try_clone().unwrap(); - let to_receive = to_receive.to_owned(); - thread::spawn(move || listen_and_collect(socket, to_receive)); - }); - - // send packets - recv_from_channel_and_send_multiple_dest( - packet_receiver.recv(), - &Arc::new(RwLock::new(Deduper::<2, [u8]>::new( - &mut rand::thread_rng(), - crate::forwarder::DEDUPER_NUM_BITS, - ))), - &udp_sender, - &Arc::new(dest_socketaddrs), - move |_origin, dest: SocketAddr| dest != blacklisted, - true, - &Arc::new(ShredMetrics::default()), - ) - .unwrap(); - - // allow packets to be received - sleep(Duration::from_millis(500)); - - let received = test_listeners - .iter() - .take(test_listeners.len() - 1) // ignore blacklisted - .map(|(_, _, results)| results.clone()) - .collect::>(); - - // check results - for received in received.iter() { - let received = received.lock().unwrap(); - assert_eq!(received.len(), 2); - assert!(received - .iter() - .all(|packet| packet.len() == PACKET_DATA_SIZE)); - assert_eq!(received[0], [1; PACKET_DATA_SIZE]); - assert_eq!(received[1], [2; PACKET_DATA_SIZE]); - } - { - let received = test_listeners[2].2.lock().unwrap(); // ensure blacklisted received nothing - assert_eq!(received.len(), 0); - } - assert_eq!( - received - .iter() - .fold(0, |acc, elem| acc + elem.lock().unwrap().len()), - 4 - ); - } -} +// #[cfg(test)] +// mod tests { +// use std::{ +// net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket}, +// str::FromStr, +// sync::{Arc, Mutex, RwLock}, +// thread, +// thread::sleep, +// time::Duration, +// }; + +// use solana_perf::{ +// deduper::Deduper, +// packet::{Meta, Packet, PacketBatch}, +// }; +// use solana_sdk::packet::{PacketFlags, PACKET_DATA_SIZE}; + + +// fn listen_and_collect(listen_socket: UdpSocket, received_packets: Arc>>>) { +// let mut buf = [0u8; PACKET_DATA_SIZE]; +// loop { +// listen_socket.recv(&mut buf).unwrap(); +// received_packets.lock().unwrap().push(Vec::from(buf)); +// } +// } + +// #[test] +// fn test_2shreds_3destinations() { +// let packet_batch = PacketBatch::new(vec![ +// Packet::new( +// [1; PACKET_DATA_SIZE], +// Meta { +// size: PACKET_DATA_SIZE, +// addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), +// port: 48289, // received on random port +// flags: PacketFlags::empty(), +// }, +// ), +// Packet::new( +// [2; PACKET_DATA_SIZE], +// Meta { +// size: PACKET_DATA_SIZE, +// addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), +// port: 9999, +// flags: PacketFlags::empty(), +// }, +// ), +// ]); +// let (packet_sender, packet_receiver) = crossbeam_channel::unbounded::(); +// packet_sender.send(packet_batch).unwrap(); + +// let dest_socketaddrs = vec![ +// SocketAddr::from_str("0.0.0.0:32881").unwrap(), +// SocketAddr::from_str("0.0.0.0:33881").unwrap(), +// SocketAddr::from_str("0.0.0.0:34881").unwrap(), +// ]; + +// let test_listeners = dest_socketaddrs +// .iter() +// .map(|socketaddr| { +// ( +// UdpSocket::bind(socketaddr).unwrap(), +// *socketaddr, +// // 
store results in vec of packet, where packet is Vec +// Arc::new(Mutex::new(vec![])), +// ) +// }) +// .collect::>(); + +// let udp_sender = UdpSocket::bind("0.0.0.0:10000").unwrap(); + +// // spawn listeners +// test_listeners +// .iter() +// .for_each(|(listen_socket, _socketaddr, to_receive)| { +// let socket = listen_socket.try_clone().unwrap(); +// let to_receive = to_receive.to_owned(); +// thread::spawn(move || listen_and_collect(socket, to_receive)); +// }); + +// // send packets +// recv_from_channel_and_send_multiple_dest( +// packet_receiver.recv(), +// &Arc::new(RwLock::new(Deduper::<2, [u8]>::new( +// &mut rand::thread_rng(), +// crate::forwarder::DEDUPER_NUM_BITS, +// ))), +// &udp_sender, +// &Arc::new(dest_socketaddrs), +// accept_all, +// true, +// &Arc::new(ShredMetrics::default()), +// ) +// .unwrap(); + +// // allow packets to be received +// sleep(Duration::from_millis(500)); + +// let received = test_listeners +// .iter() +// .map(|(_, _, results)| results.clone()) +// .collect::>(); + +// // check results +// for received in received.iter() { +// let received = received.lock().unwrap(); +// assert_eq!(received.len(), 2); +// assert!(received +// .iter() +// .all(|packet| packet.len() == PACKET_DATA_SIZE)); +// assert_eq!(received[0], [1; PACKET_DATA_SIZE]); +// assert_eq!(received[1], [2; PACKET_DATA_SIZE]); +// } + +// assert_eq!( +// received +// .iter() +// .fold(0, |acc, elem| acc + elem.lock().unwrap().len()), +// 6 +// ); +// } + +// #[test] +// fn test_dest_filter() { +// let packet_batch = PacketBatch::new(vec![ +// Packet::new( +// [1; PACKET_DATA_SIZE], +// Meta { +// size: PACKET_DATA_SIZE, +// addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), +// port: 48289, // received on random port +// flags: PacketFlags::empty(), +// }, +// ), +// Packet::new( +// [2; PACKET_DATA_SIZE], +// Meta { +// size: PACKET_DATA_SIZE, +// addr: IpAddr::V4(Ipv4Addr::UNSPECIFIED), +// port: 9999, +// flags: PacketFlags::empty(), +// }, +// ), +// ]); +// let (packet_sender, packet_receiver) = crossbeam_channel::unbounded::(); +// packet_sender.send(packet_batch).unwrap(); + +// let dest_socketaddrs = vec![ +// SocketAddr::from_str("0.0.0.0:32881").unwrap(), +// SocketAddr::from_str("0.0.0.0:33881").unwrap(), +// SocketAddr::from_str("0.0.0.0:34881").unwrap(), +// ]; + +// let blacklisted = SocketAddr::from_str("0.0.0.0:34881").unwrap(); // none blacklisted + +// let test_listeners = dest_socketaddrs +// .iter() +// .map(|socketaddr| { +// ( +// UdpSocket::bind(socketaddr).unwrap(), +// *socketaddr, +// // store results in vec of packet, where packet is Vec +// Arc::new(Mutex::new(vec![])), +// ) +// }) +// .collect::>(); + +// let udp_sender = UdpSocket::bind("0.0.0.0:10000").unwrap(); + +// // spawn listeners +// test_listeners +// .iter() +// .for_each(|(listen_socket, _socketaddr, to_receive)| { +// let socket = listen_socket.try_clone().unwrap(); +// let to_receive = to_receive.to_owned(); +// thread::spawn(move || listen_and_collect(socket, to_receive)); +// }); + +// // send packets +// recv_from_channel_and_send_multiple_dest( +// packet_receiver.recv(), +// &Arc::new(RwLock::new(Deduper::<2, [u8]>::new( +// &mut rand::thread_rng(), +// crate::forwarder::DEDUPER_NUM_BITS, +// ))), +// &udp_sender, +// &Arc::new(dest_socketaddrs), +// move |_origin, dest: SocketAddr| dest != blacklisted, +// true, +// &Arc::new(ShredMetrics::default()), +// ) +// .unwrap(); + +// // allow packets to be received +// sleep(Duration::from_millis(500)); + +// let received = test_listeners +// .iter() 
+// .take(test_listeners.len() - 1) // ignore blacklisted +// .map(|(_, _, results)| results.clone()) +// .collect::>(); + +// // check results +// for received in received.iter() { +// let received = received.lock().unwrap(); +// assert_eq!(received.len(), 2); +// assert!(received +// .iter() +// .all(|packet| packet.len() == PACKET_DATA_SIZE)); +// assert_eq!(received[0], [1; PACKET_DATA_SIZE]); +// assert_eq!(received[1], [2; PACKET_DATA_SIZE]); +// } + +// { +// let received = test_listeners[2].2.lock().unwrap(); // ensure blacklisted received nothing +// assert_eq!(received.len(), 0); +// } +// assert_eq!( +// received +// .iter() +// .fold(0, |acc, elem| acc + elem.lock().unwrap().len()), +// 4 +// ); +// } +// } From d7e239799ee8403bf6c9dbee6040eaa4c65fbf5a Mon Sep 17 00:00:00 2001 From: Louis-Vincent Boudreault Date: Wed, 14 Jan 2026 22:41:34 +0000 Subject: [PATCH 04/13] fmt --- proxy/src/forwarder.rs | 124 ++++++----- proxy/src/mem.rs | 23 +- proxy/src/recv_mmsg.rs | 79 +++---- proxy/src/triton_forwarder.rs | 408 ++++++++++++++++++++++------------ 4 files changed, 375 insertions(+), 259 deletions(-) diff --git a/proxy/src/forwarder.rs b/proxy/src/forwarder.rs index 80275d5..6564f28 100644 --- a/proxy/src/forwarder.rs +++ b/proxy/src/forwarder.rs @@ -14,8 +14,8 @@ use crossbeam_channel::{Receiver, RecvError}; use dashmap::DashMap; use itertools::Itertools; use jito_protos::shredstream::{Entry as PbEntry, TraceShred}; -use log::{debug, error, info, warn}; use libc; +use log::{debug, error, info, warn}; use prost::Message; use socket2::{Domain, Protocol, Socket, Type}; use solana_client::client_error::reqwest; @@ -35,15 +35,13 @@ use solana_streamer::{ use tokio::sync::broadcast::Sender; use crate::{ - ShredstreamProxyError, deshred::{self, ComparableShred, ShredsStateTracker}, prom::{ - observe_dedup_time, observe_send_packet_count, observe_send_duration, - observe_recv_interval, observe_recv_packet_count, - inc_packets_received, inc_packets_deduped, inc_packets_forwarded, - inc_packets_forward_failed, inc_packets_by_source, + inc_packets_by_source, inc_packets_deduped, inc_packets_forward_failed, + inc_packets_forwarded, inc_packets_received, observe_dedup_time, observe_recv_interval, + observe_recv_packet_count, observe_send_duration, observe_send_packet_count, }, - resolve_hostname_port, + resolve_hostname_port, ShredstreamProxyError, }; // values copied from https://github.com/solana-labs/solana/blob/33bde55bbdde13003acf45bb6afe6db4ab599ae4/core/src/sigverify_shreds.rs#L20 @@ -141,7 +139,7 @@ pub fn start_forwarder_threads( thread_hdls.push(hdl); }; - let mut ret = sockets + let mut ret = sockets .into_iter() .chain(maybe_multicast_socket.unwrap_or_default()) .enumerate() @@ -167,13 +165,11 @@ pub fn start_forwarder_threads( let reconstruct_tx = reconstruct_tx.clone(); let exit = exit.clone(); - let send_thread = Builder::new() .name(format!("ssPxyTx_{thread_id}")) .spawn(move || { - let dont_send_to_origin = move |origin: IpAddr, dest: SocketAddr| { - origin != dest.ip() - }; + let dont_send_to_origin = + move |origin: IpAddr, dest: SocketAddr| origin != dest.ip(); let send_socket = { let ipv6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); match try_create_ipv6_socket(ipv6_addr) { @@ -184,7 +180,8 @@ pub fn start_forwarder_threads( Err(e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => { // This error (code 97 on Linux) means IPv6 is not supported. warn!("IPv6 not available. 
Falling back to IPv4-only for sending."); - let ipv4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); + let ipv4_addr = + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); UdpSocket::bind(ipv4_addr) .expect("Failed to bind to IPv4 socket after IPv6 failed") } @@ -249,20 +246,20 @@ pub fn start_forwarder_threads( if let Some((multicast_origin, multicast_socket)) = maybe_triton_multicast_socket { start_multicast_forwarder_thread( - multicast_origin, - multicast_socket, - recycler, - reconstruct_tx, - unioned_dest_sockets, - deduper, - forward_stats, - metrics, - debug_trace_shred, - should_reconstruct_shreds, - use_discovery_service, - shutdown_receiver, - exit, - &mut ret + multicast_origin, + multicast_socket, + recycler, + reconstruct_tx, + unioned_dest_sockets, + deduper, + forward_stats, + metrics, + debug_trace_shred, + should_reconstruct_shreds, + use_discovery_service, + shutdown_receiver, + exit, + &mut ret, ); } ret @@ -270,8 +267,7 @@ pub fn start_forwarder_threads( /// /// Try to create an IPv6 UDP socket bound to the given address. -/// -fn try_create_ipv6_socket(addr: SocketAddr) -> Result { +pub fn try_create_ipv6_socket(addr: SocketAddr) -> Result { let ipv6_socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; ipv6_socket.set_multicast_hops_v6(IP_MULTICAST_TTL)?; ipv6_socket.bind(&addr.into())?; @@ -323,21 +319,19 @@ pub fn start_multicast_forwarder_thread( let reconstruct_tx = reconstruct_tx.clone(); let exit = exit.clone(); - - let send_thread = Builder::new() .name(format!("ssPxyTxMulticast_{thread_id}")) .spawn(move || { - let dont_send_to_mc_origin = move |_origin, dest: SocketAddr| { - dest.ip() != multicast_origin - }; + let dont_send_to_mc_origin = + move |_origin, dest: SocketAddr| dest.ip() != multicast_origin; let send_socket = { let ipv6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); match try_create_ipv6_socket(ipv6_addr) { Ok(socket) => { info!("Successfully bound send socket to IPv6 dual-stack address."); - socket.set_multicast_loop_v6(false) - .expect("Failed to disable IPv6 multicast loopback"); + socket + .set_multicast_loop_v6(false) + .expect("Failed to disable IPv6 multicast loopback"); socket } Err(e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => { @@ -346,8 +340,11 @@ pub fn start_multicast_forwarder_thread( let ipv4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); let socket = UdpSocket::bind(ipv4_addr) .expect("Failed to bind to IPv4 socket after IPv6 failed"); - socket.set_multicast_ttl_v4(IP_MULTICAST_TTL).expect("IP_MULTICAST_TTL_V4"); - socket.set_multicast_loop_v4(false) + socket + .set_multicast_ttl_v4(IP_MULTICAST_TTL) + .expect("IP_MULTICAST_TTL_V4"); + socket + .set_multicast_loop_v4(false) .expect("Failed to disable IPv4 multicast loopback"); socket } @@ -424,16 +421,14 @@ fn recv_from_channel_and_send_multiple_dest( reconstruct_tx: &crossbeam_channel::Sender, debug_trace_shred: bool, metrics: &ShredMetrics, -) -> Result<(), ShredstreamProxyError> +) -> Result<(), ShredstreamProxyError> where F: Fn(IpAddr, SocketAddr) -> bool, { let packet_batch = maybe_packet_batch.map_err(ShredstreamProxyError::RecvError)?; let trace_shred_received_time = SystemTime::now(); let batch_len = packet_batch.len() as u64; - metrics - .received - .fetch_add(batch_len, Ordering::Relaxed); + metrics.received.fetch_add(batch_len, Ordering::Relaxed); inc_packets_received(batch_len); observe_recv_packet_count(batch_len as f64); debug!( @@ -454,7 +449,9 @@ where &mut packet_batch_vec, ); let 
t_dedup_usecs = t.elapsed().as_micros() as u64; - metrics.dedup_time_spent.fetch_add(t_dedup_usecs, Ordering::Relaxed); + metrics + .dedup_time_spent + .fetch_add(t_dedup_usecs, Ordering::Relaxed); observe_dedup_time(t_dedup_usecs as f64); inc_packets_deduped(num_deduped); @@ -470,10 +467,12 @@ where *discarded += is_discarded as u64; *not_discarded += (!is_discarded) as u64; }) - .or_insert_with(|| { - (is_discarded as u64, (!is_discarded) as u64) - }); - let status = if is_discarded { "discarded" } else { "forwarded" }; + .or_insert_with(|| (is_discarded as u64, (!is_discarded) as u64)); + let status = if is_discarded { + "discarded" + } else { + "forwarded" + }; inc_packets_by_source(&addr.to_string(), status, 1); }); }); @@ -502,10 +501,14 @@ where .fetch_add(packets_with_dest.len() as u64, Ordering::Relaxed); metrics.send_batch_count.fetch_add(1, Ordering::Relaxed); const MAX_IOV: usize = libc::UIO_MAXIOV as usize; - let max_iov_count = packets_with_dest.len() / MAX_IOV; + let max_iov_count = packets_with_dest.len() / MAX_IOV; let unsaturated_iov_count = packets_with_dest.len() % MAX_IOV; - metrics.saturated_iov_count.fetch_add(max_iov_count as u64, Ordering::Relaxed); - metrics.unsaturated_iov_count.fetch_add(unsaturated_iov_count as u64, Ordering::Relaxed); + metrics + .saturated_iov_count + .fetch_add(max_iov_count as u64, Ordering::Relaxed); + metrics + .unsaturated_iov_count + .fetch_add(unsaturated_iov_count as u64, Ordering::Relaxed); observe_send_packet_count(packets_with_dest.len() as f64); match batch_send(send_socket, &packets_with_dest) { Ok(_) => { @@ -531,7 +534,9 @@ where } } let t_send_usecs = t.elapsed().as_micros() as u64; - metrics.batch_send_time_spent.fetch_add(t_send_usecs, Ordering::Relaxed); + metrics + .batch_send_time_spent + .fetch_add(t_send_usecs, Ordering::Relaxed); observe_send_duration(t_send_usecs as f64); }); @@ -789,7 +794,11 @@ impl ShredMetrics { datapoint_info!( "shredstream_proxy-sendmmsg_iov_metrics", - ("max_iov_count", self.saturated_iov_count.load(Ordering::Relaxed), i64), + ( + "max_iov_count", + self.saturated_iov_count.load(Ordering::Relaxed), + i64 + ), ( "unsaturated_iov_count", self.unsaturated_iov_count.load(Ordering::Relaxed), @@ -798,12 +807,16 @@ impl ShredMetrics { ); datapoint_info!( - "shredstream_proxy-batch_send_metrics", + "shredstream_proxy-batch_send_metrics", ( - "send_batch_size_sum", self.send_batch_size_sum.load(Ordering::Relaxed), i64 + "send_batch_size_sum", + self.send_batch_size_sum.load(Ordering::Relaxed), + i64 ), ( - "send_batch_count", self.send_batch_count.load(Ordering::Relaxed), i64 + "send_batch_count", + self.send_batch_count.load(Ordering::Relaxed), + i64 ) ); @@ -821,7 +834,6 @@ impl ShredMetrics { ), ); - if self.enabled_grpc_service { datapoint_info!( "shredstream_proxy-service_metrics", diff --git a/proxy/src/mem.rs b/proxy/src/mem.rs index dd00c17..513fe9a 100644 --- a/proxy/src/mem.rs +++ b/proxy/src/mem.rs @@ -9,7 +9,7 @@ use std::{ thread::{self, Thread}, }; -use bytes::{Buf, BufMut, buf::UninitSlice}; +use bytes::{buf::UninitSlice, Buf, BufMut}; #[derive(Debug, thiserror::Error)] #[error("allocation error")] @@ -55,17 +55,12 @@ pub fn try_alloc_shared_mem( Ok(ptr as *mut u8) } - - impl SharedMem { pub fn new(element_size: usize, capacity: usize, huge: bool) -> Result { let ptr = try_alloc_shared_mem(element_size, capacity, huge)?; let len = capacity * element_size; - Ok(Self { - ptr, - len, - }) + Ok(Self { ptr, len }) } pub fn len(&self) -> usize { @@ -120,7 +115,6 @@ impl FrameBuf { (end as 
usize) - (self.curr_ptr as usize) } - #[inline] pub fn into_inner(self) -> FrameDesc { self.desc @@ -128,7 +122,10 @@ impl FrameBuf { #[inline] pub unsafe fn detach_desc(&self) -> FrameDesc { - FrameDesc { ptr: self.desc.ptr, frame_size: self.desc.frame_size } + FrameDesc { + ptr: self.desc.ptr, + frame_size: self.desc.frame_size, + } } pub unsafe fn unsafe_clone(&self) -> Self { @@ -153,7 +150,6 @@ impl FrameBuf { }, } } - } impl AsRef<[u8]> for FrameBuf { @@ -175,7 +171,6 @@ impl From for FrameBuf { } } - impl FrameDesc { pub fn as_mut_buf(&self) -> FrameBufMut { FrameBufMut { @@ -197,7 +192,6 @@ impl From for FrameBufMut { } } - impl FrameBufMut { #[inline] pub fn base(&self) -> *mut u8 { @@ -265,10 +259,7 @@ impl Buf for FrameBuf { fn advance(&mut self, cnt: usize) { let new_ptr = unsafe { self.curr_ptr.add(cnt) }; let end = unsafe { self.desc.ptr.add(self.len) }; - assert!( - new_ptr as *const u8 <= end, - "advance out of bounds" - ); + assert!(new_ptr as *const u8 <= end, "advance out of bounds"); self.curr_ptr = new_ptr; } } diff --git a/proxy/src/recv_mmsg.rs b/proxy/src/recv_mmsg.rs index 11be2a2..a027193 100644 --- a/proxy/src/recv_mmsg.rs +++ b/proxy/src/recv_mmsg.rs @@ -1,18 +1,26 @@ +use std::{ + cmp, + collections::VecDeque, + hint::spin_loop, + io, + mem::{self, zeroed, MaybeUninit}, + net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, UdpSocket}, + os::fd::AsRawFd, + sync::atomic::{AtomicBool, Ordering}, + time::{Duration, Instant}, +}; + use bytes::{Buf, BufMut}; use itertools::izip; -use libc::{AF_INET, AF_INET6, MSG_WAITFORONE, iovec, mmsghdr, msghdr, sockaddr_storage}; +use libc::{iovec, mmsghdr, msghdr, sockaddr_storage, AF_INET, AF_INET6, MSG_WAITFORONE}; +use log::{error, trace}; use socket2::socklen_t; use solana_ledger::shred::ShredId; use solana_perf::packet::{NUM_RCVMMSGS, PACKETS_PER_BATCH}; -use solana_sdk::packet::{Meta, PACKET_DATA_SIZE, Packet}; -use log::{error, trace}; +use solana_sdk::packet::{Meta, Packet, PACKET_DATA_SIZE}; use solana_streamer::{recvmmsg::recv_mmsg, streamer::StreamerReceiveStats}; -use std::{ - cmp, collections::VecDeque, hint::spin_loop, io, mem::{self, MaybeUninit, zeroed}, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, UdpSocket}, os::fd::AsRawFd, sync::atomic::{AtomicBool, Ordering}, time::{Duration, Instant} -}; - -use crate::mem::{FrameBuf, FrameBufMut, FrameDesc, Rx, SharedMem, Tx, try_alloc_shared_mem}; +use crate::mem::{try_alloc_shared_mem, FrameBuf, FrameBufMut, FrameDesc, Rx, SharedMem, Tx}; const OFFSET_SHRED_TYPE: usize = 82; const OFFSET_DATA_PARENT: usize = 83; // 83 + 0 @@ -32,7 +40,6 @@ pub trait PacketRoutingStrategy: Clone { fn route_packet(&self, packet: &TritonPacket, num_dest: usize) -> Option; } - #[inline] fn hash_pair(x: u64, y: u32) -> u64 { let mut h = x ^ ((y as u64) << 32); @@ -49,7 +56,7 @@ impl PacketRoutingStrategy for FECSetRoutingStrategy { fn route_packet(&self, packet: &TritonPacket, num_dest: usize) -> Option { let shred_buf = packet.buffer.chunk(); let slot = solana_ledger::shred::wire::get_slot(shred_buf)?; - let fec = shred_buf.get(79..79 + 4)?; + let fec = shred_buf.get(79..79 + 4)?; let fec_bytes: [u8; 4] = fec.try_into().ok()?; let fec_set_index = u32::from_le_bytes(fec_bytes); let hash = hash_pair(slot, fec_set_index); @@ -58,8 +65,6 @@ impl PacketRoutingStrategy for FECSetRoutingStrategy { } } - - pub fn recv_loop( socket: &UdpSocket, exit: &AtomicBool, @@ -68,23 +73,21 @@ pub fn recv_loop( fill_rx: &mut Rx, packet_tx_vec: &[Tx], router: R, -) -> 
std::io::Result<()> - where R: PacketRoutingStrategy +) -> std::io::Result<()> +where + R: PacketRoutingStrategy, { - let mut packet_batch = Vec::with_capacity(PACKETS_PER_BATCH); let mut frame_bufmut_vec = Vec::with_capacity(PACKETS_PER_BATCH); - loop { - // Check for exit signal, even if socket is busy // (for instance the leader transaction socket) if exit.load(Ordering::Relaxed) { return Ok(()); } - // Refill the frame buffers as much as we can, + // Refill the frame buffers as much as we can, 'fill_bufmut: while frame_bufmut_vec.len() < PACKETS_PER_BATCH { let maybe_frame_buf = fill_rx.try_recv(); match maybe_frame_buf { @@ -104,12 +107,7 @@ pub fn recv_loop( } } } - let result = recv_from( - &mut frame_bufmut_vec, - socket, - coalesce, - &mut packet_batch, - ); + let result = recv_from(&mut frame_bufmut_vec, socket, coalesce, &mut packet_batch); if let Ok(len) = result { if len > 0 { let StreamerReceiveStats { @@ -147,13 +145,11 @@ pub fn recv_loop( } } - - pub fn recv_from( available_frame_buf_vec: &mut Vec, - socket: &UdpSocket, + socket: &UdpSocket, max_wait: Duration, - batch: &mut Vec + batch: &mut Vec, ) -> std::io::Result { // let mut i: usize = 0; //DOCUMENTED SIDE-EFFECT @@ -171,7 +167,6 @@ pub fn recv_from( let mut i = 0; loop { - match triton_recv_mmsg(socket, available_frame_buf_vec, &mut batch[i..]) { Err(_) if i > 0 => { if start.elapsed() > max_wait { @@ -225,11 +220,10 @@ impl AsRef<[u8]> for TritonPacket { } } - pub fn triton_recv_mmsg( - sock: &UdpSocket, - fill_buffers: &mut Vec, - packets: &mut [TritonPacket], + sock: &UdpSocket, + fill_buffers: &mut Vec, + packets: &mut [TritonPacket], ) -> io::Result { // Should never hit this, but bail if the caller didn't provide any Packets // to receive into @@ -250,14 +244,13 @@ pub fn triton_recv_mmsg( std::array::from_fn(|_| MaybeUninit::uninit()); let mut frame_buffer_inflight_cnt = 0; - for (hdr, iov, addr) in - izip!(&mut hdrs, &mut iovs, &mut addrs).take(count) - { + for (hdr, iov, addr) in izip!(&mut hdrs, &mut iovs, &mut addrs).take(count) { let buffer = fill_buffers.pop().expect("insufficient fill buffers"); - assert!(buffer.remaining_mut() >= PACKET_DATA_SIZE, "fill buffer too small"); - let iov_base = unsafe { - buffer.as_mut_ptr() as *mut libc::c_void - }; + assert!( + buffer.remaining_mut() >= PACKET_DATA_SIZE, + "fill buffer too small" + ); + let iov_base = unsafe { buffer.as_mut_ptr() as *mut libc::c_void }; iov.write(iovec { iov_base: iov_base, @@ -300,7 +293,9 @@ pub fn triton_recv_mmsg( } else { usize::try_from(nrecv).unwrap() }; - for (addr, hdr, pkt, filled_bufmut) in izip!(addrs, hdrs, packets.iter_mut(), frame_buffer_inflight_vec).take(nrecv) { + for (addr, hdr, pkt, filled_bufmut) in + izip!(addrs, hdrs, packets.iter_mut(), frame_buffer_inflight_vec).take(nrecv) + { // SAFETY: We initialized `count` elements of `hdrs` above. `count` is // passed to recvmmsg() as the limit of messages that can be read. 
So, // `nrevc <= count` which means we initialized this `hdr` and @@ -336,7 +331,6 @@ pub fn triton_recv_mmsg( Ok(nrecv) } - fn create_msghdr( msg_name: &mut MaybeUninit, msg_namelen: socklen_t, @@ -355,7 +349,6 @@ fn create_msghdr( msg_hdr } - fn cast_socket_addr(addr: &sockaddr_storage, hdr: &mmsghdr) -> Option { use libc::{sa_family_t, sockaddr_in, sockaddr_in6}; const SOCKADDR_IN_SIZE: usize = std::mem::size_of::(); diff --git a/proxy/src/triton_forwarder.rs b/proxy/src/triton_forwarder.rs index d6b2709..de753e1 100644 --- a/proxy/src/triton_forwarder.rs +++ b/proxy/src/triton_forwarder.rs @@ -1,17 +1,23 @@ use std::{ - collections::{HashSet, VecDeque}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, str::FromStr, sync::{ - Arc, RwLock, atomic::{AtomicBool, AtomicU64, Ordering} - }, thread::{Builder, JoinHandle}, time::{Duration, Instant, SystemTime} + collections::{HashSet, VecDeque}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, + str::FromStr, + sync::{ + atomic::{AtomicBool, AtomicU64, Ordering}, + Arc, RwLock, + }, + thread::{Builder, JoinHandle}, + time::{Duration, Instant, SystemTime}, }; use arc_swap::ArcSwap; use bytes::Buf; -use crossbeam_channel::{Receiver, RecvError}; +use crossbeam_channel::{Receiver, RecvError, Sender}; use dashmap::DashMap; -use itertools::{Itertools, izip}; +use itertools::{izip, Itertools}; use jito_protos::shredstream::{Entry as PbEntry, TraceShred}; -use log::{debug, error, info, warn}; use libc; +use log::{debug, error, info, warn}; use prost::Message; use socket2::{Domain, Protocol, Socket, Type}; use solana_client::client_error::reqwest; @@ -22,16 +28,22 @@ use solana_perf::{ packet::{PacketBatch, PacketBatchRecycler}, recycler::Recycler, }; +use solana_sdk::exit; use solana_streamer::{ sendmmsg::{batch_send, SendPktsError}, streamer::{self, StreamerReceiveStats}, }; -use tokio::sync::broadcast::Sender; use crate::{ - ShredstreamProxyError, forwarder::{ShredMetrics, try_create_ipv6_socket}, mem::{FrameBuf, FrameDesc, Rx, SharedMem, Tx}, prom::{ - inc_packets_by_source, inc_packets_deduped, inc_packets_forward_failed, inc_packets_forwarded, inc_packets_received, observe_dedup_time, observe_recv_interval, observe_recv_packet_count, observe_send_duration, observe_send_packet_count - }, recv_mmsg::{PacketRoutingStrategy, TritonPacket}, resolve_hostname_port + forwarder::{try_create_ipv6_socket, ShredMetrics}, + mem::{FrameBuf, FrameDesc, Rx, SharedMem, Tx}, + prom::{ + inc_packets_by_source, inc_packets_deduped, inc_packets_forward_failed, + inc_packets_forwarded, inc_packets_received, observe_dedup_time, observe_recv_interval, + observe_recv_packet_count, observe_send_duration, observe_send_packet_count, + }, + recv_mmsg::{PacketRoutingStrategy, TritonPacket}, + resolve_hostname_port, ShredstreamProxyError, }; // values copied from https://github.com/solana-labs/solana/blob/33bde55bbdde13003acf45bb6afe6db4ab599ae4/core/src/sigverify_shreds.rs#L20 @@ -40,18 +52,17 @@ pub const DEDUPER_NUM_BITS: u64 = 637_534_199; // 76MB pub const DEDUPER_RESET_CYCLE: Duration = Duration::from_secs(5 * 60); pub const IP_MULTICAST_TTL: u32 = 8; - #[derive(Debug, Clone, Copy, Default)] pub enum ReceiverMemorySizing { #[default] XSmall = 134217728, // 128MiB - Small = 268435456, // 256MiB - Medium = 536870912, // 512MiB - Large = 1073741824, // 1GiB - XLarge = 2147483648, // 2GiB - XXLarge = 4294967296, // 4GiB - XXXLarge = 8589934592, // 8GiB - XXXXLarge = 17179869184, // 16GiB + Small = 268435456, // 256MiB + Medium = 536870912, // 512MiB + 
Large = 1073741824, // 1GiB + XLarge = 2147483648, // 2GiB + XXLarge = 4294967296, // 4GiB + XXXLarge = 8589934592, // 8GiB + XXXXLarge = 17179869184, // 16GiB XXXXXLarge = 34359738368, // 32GiB } @@ -59,7 +70,6 @@ pub enum ReceiverMemorySizing { #[error("Invalid ReceiverMemoryCapacity: {0}")] pub struct ReceiverMemoryCapacityFromStrErr(String); - impl FromStr for ReceiverMemorySizing { type Err = String; @@ -72,7 +82,7 @@ impl FromStr for ReceiverMemorySizing { "xlarge" | "xl" => Ok(ReceiverMemorySizing::XLarge), "xxlarge" | "xxl" | "2xl" => Ok(ReceiverMemorySizing::XXLarge), "xxxlarge" | "xxxl" | "3xl" => Ok(ReceiverMemorySizing::XXXLarge), - "xxxxlarge" | "xxxxl" | "4xl" => Ok(ReceiverMemorySizing::XXXXLarge), + "xxxxlarge" | "xxxxl" | "4xl" => Ok(ReceiverMemorySizing::XXXXLarge), "xxxxxlarge" | "xxxxxl" | "5xl" => Ok(ReceiverMemorySizing::XXXXXLarge), _ => Err(s.to_string()), } @@ -84,7 +94,6 @@ pub struct PacketRecvTileMemConfig { pub frame_size: usize, pub memory_size: ReceiverMemorySizing, pub hugepage: bool, - } impl Default for PacketRecvTileMemConfig { @@ -98,63 +107,33 @@ impl Default for PacketRecvTileMemConfig { } fn packet_recv_tile( - sockets: Vec, - src_addr: IpAddr, - src_port: u16, + pkt_recv_idx: usize, + pkt_recv_socket: UdpSocket, exit: Arc, forwarder_stats: Arc, - fill_rx_vec: Vec>, + mut fill_rx: Rx, packet_tx_vec: Vec>, packet_router: R, - threads: &mut Vec>, -) -> std::io::Result<()> - where R: PacketRoutingStrategy + Send + 'static, -{ - assert!(sockets.len() == fill_rx_vec.len(), "mismatched fill_rx_vec and sockets length"); - assert!(sockets.len() == packet_tx_vec.len(), "mismatched packet_tx_vec and sockets length"); - - // let (_port, sockets) = solana_net_utils::multi_bind_in_range_with_config( - // src_addr, - // (src_port, src_port + 1), - // SocketConfig::default().reuseport(true), - // num_receiver, - // ) - // .unwrap_or_else(|_| { - // panic!("Failed to bind listener sockets. 
Check that port {src_port} is not in use.") - // }); - - for (thread_id, socket, mut fill_rx) in izip!(0..sockets.len(), sockets.into_iter(), fill_rx_vec.into_iter()) { - - // let shmem = SharedMem::new( - // mem_config.frame_size, - // mem_config.memory_size as usize / mem_config.frame_size, - // mem_config.hugepage, - // ).expect("SharedMem::new"); - - let socket = Arc::new(socket); - let exit = Arc::clone(&exit); - let stats = Arc::clone(&forwarder_stats); - let packet_tx_vec = packet_tx_vec.clone(); - let packet_router = R::clone(&packet_router); - let th = std::thread::Builder::new() - .name(format!("ssListen{thread_id}")) - .spawn(move || { - let socket = socket; - crate::recv_mmsg::recv_loop( - &socket, - &exit, - &stats, - Duration::default(), - &mut fill_rx, - &packet_tx_vec, - packet_router, - ) - .expect("recv_loop") - })?; - - threads.push(th); - } - Ok(()) + tile_drop_sig: TileClosedSignal, +) -> std::io::Result> +where + R: PacketRoutingStrategy + Send + 'static, +{ + std::thread::Builder::new() + .name(format!("ssListen{pkt_recv_idx}")) + .spawn(move || { + crate::recv_mmsg::recv_loop( + &pkt_recv_socket, + &exit, + &forwarder_stats, + Duration::default(), + &mut fill_rx, + &packet_tx_vec, + packet_router, + ) + .expect("recv_loop"); + drop(tile_drop_sig); + }) } #[derive(Clone, Debug)] @@ -167,6 +146,51 @@ pub struct SharedMemInfo { unsafe impl Send for SharedMemInfo {} unsafe impl Sync for SharedMemInfo {} +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TileKind { + PktRecv, + PktFwd, +} + +struct TileClosedSignal { + kind: TileKind, + idx: usize, + tx: Option>, +} + +struct TileWaitGroup { + rx: Receiver<(TileKind, usize)>, + tx: Sender<(TileKind, usize)>, +} + +impl TileWaitGroup { + fn new() -> Self { + let (tx, rx) = crossbeam_channel::unbounded(); + Self { rx, tx } + } + + fn get_tile_closed_signal(&self, kind: TileKind, idx: usize) -> TileClosedSignal { + TileClosedSignal { + kind, + idx, + tx: Some(self.tx.clone()), + } + } + + fn wait_first(self) -> (TileKind, usize) { + drop(self.tx); + self.rx.recv().expect("TileWaitGroup::wait_first") + } +} + +impl Drop for TileClosedSignal { + fn drop(&mut self) { + if let Some(tx) = &self.tx { + let _ = tx.send((self.kind, self.idx)); + } + } +} + fn packet_fwd_tile( packet_fwd_idx: usize, hot_dest_vec: Arc>>, @@ -174,12 +198,12 @@ fn packet_fwd_tile( mut packet_rx: Rx, fill_tx_vec: Vec>, shmem_info_vec: Vec, + exit: Arc, + tile_drop_sig: TileClosedSignal, ) -> std::io::Result> { - std::thread::Builder::new() .name(format!("ssPxyTx_{packet_fwd_idx}")) .spawn(move || { - let mut deduper = Deduper::<2, [u8]>::new(&mut rand::thread_rng(), DEDUPER_NUM_BITS); const UIO_MAXIOV: usize = libc::UIO_MAXIOV as usize; // We allocate double size to account for possible overflow if destinations array is really big @@ -187,15 +211,21 @@ fn packet_fwd_tile( let mut queued: VecDeque = VecDeque::with_capacity(UIO_MAXIOV); for shmem_info in &shmem_info_vec { - assert!(shmem_info.len.is_power_of_two(), "shmem_info.len must be a power of 2"); + assert!( + shmem_info.len.is_power_of_two(), + "shmem_info.len must be a power of 2" + ); } let mut next_deduper_reset_attempt = Instant::now() + Duration::from_secs(2); let mut recycled_frames: Vec = Vec::with_capacity(UIO_MAXIOV); - loop { - + while exit.load(Ordering::Relaxed) == false { if next_deduper_reset_attempt.elapsed() > Duration::ZERO { - deduper.maybe_reset(&mut rand::thread_rng(), DEDUPER_FALSE_POSITIVE_RATE, DEDUPER_RESET_CYCLE); + deduper.maybe_reset( + &mut rand::thread_rng(), + 
DEDUPER_FALSE_POSITIVE_RATE, + DEDUPER_RESET_CYCLE, + ); next_deduper_reset_attempt = Instant::now() + Duration::from_secs(2); } @@ -206,7 +236,6 @@ fn packet_fwd_tile( let data_size = packet.meta.size; let data_slice = &packet.buffer.chunk()[..data_size]; if deduper.dedup(data_slice) { - // put it inside the recycle queue debug!("Deduped packet from {}", packet.meta.addr); let desc = packet.buffer.into_inner(); @@ -214,15 +243,16 @@ fn packet_fwd_tile( } else { queued.push_back(packet); } - } let dests = hot_dest_vec.load(); let dests_len = dests.len(); // Fill up the next_batch_send - 'fill_batch_send: while next_batch_send.len() < UIO_MAXIOV && queued.len() > 0 && dests_len > 0 { - + 'fill_batch_send: while next_batch_send.len() < UIO_MAXIOV + && queued.len() > 0 + && dests_len > 0 + { let remaining = UIO_MAXIOV - next_batch_send.len(); if dests_len < remaining { break 'fill_batch_send; @@ -242,9 +272,21 @@ fn packet_fwd_tile( } } - assert!(next_batch_send.len() <= UIO_MAXIOV, "next_batch_send.len() = {}", next_batch_send.len()); - assert!(recycled_frames.len() <= UIO_MAXIOV, "recycled_frames.len() = {}", recycled_frames.len()); - assert!(queued.len() <= UIO_MAXIOV, "queued.len() = {}", queued.len()); + assert!( + next_batch_send.len() <= UIO_MAXIOV, + "next_batch_send.len() = {}", + next_batch_send.len() + ); + assert!( + recycled_frames.len() <= UIO_MAXIOV, + "recycled_frames.len() = {}", + recycled_frames.len() + ); + assert!( + queued.len() <= UIO_MAXIOV, + "queued.len() = {}", + queued.len() + ); match batch_send(&send_socket, &next_batch_send) { Ok(_) => { @@ -259,16 +301,22 @@ fn packet_fwd_tile( } } - // Recycle all used frames while let Some(desc) = recycled_frames.pop() { - let fill_ring_idx = shmem_info_vec.iter().find_position(|shmem_info| { - (desc.ptr as usize) & (shmem_info.len - 1) == (shmem_info.start_ptr as usize) - }).expect("unknown frame desc").0; - fill_tx_vec[fill_ring_idx].send(desc).expect("frame recycling"); + let fill_ring_idx = shmem_info_vec + .iter() + .find_position(|shmem_info| { + (desc.ptr as usize) & (shmem_info.len - 1) + == (shmem_info.start_ptr as usize) + }) + .expect("unknown frame desc") + .0; + fill_tx_vec[fill_ring_idx] + .send(desc) + .expect("frame recycling"); } - } + drop(tile_drop_sig); }) } @@ -280,8 +328,7 @@ pub enum ProxySystemError { AllocError(crate::mem::AllocError), } - -pub fn spawn_proxy_system( +pub fn run_proxy_system( pkt_recv_tile_mem_config: PacketRecvTileMemConfig, dest_addr_vec: Arc>>, src_ip: IpAddr, @@ -291,12 +338,12 @@ pub fn spawn_proxy_system( pkt_router: R, exit: Arc, stats: Arc, -) -> JoinHandle<()> - where R: PacketRoutingStrategy + Send + Sync + 'static, +) where + R: PacketRoutingStrategy + Send + Sync + 'static, { - + let mut tile_thread_vec: Vec> = Vec::new(); // Build pkt_recv sockets - let (_port, sockets) = solana_net_utils::multi_bind_in_range_with_config( + let (_port, pkt_recv_sk_vec) = solana_net_utils::multi_bind_in_range_with_config( src_ip, (src_port, src_port + 1), SocketConfig::default().reuseport(true), @@ -306,23 +353,32 @@ pub fn spawn_proxy_system( panic!("Failed to bind listener sockets. 
Check that port {src_port} is not in use.")
     });
 
+    let num_frames =
+        pkt_recv_tile_mem_config.memory_size as usize / pkt_recv_tile_mem_config.frame_size;
+    let frame_size = pkt_recv_tile_mem_config.frame_size;
 
-    // Create the shared memory regions for recv tiles
+    let tile_wait_group = TileWaitGroup::new();
     let mut shmem_info_vec: Vec<SharedMemInfo> = Vec::with_capacity(num_pkt_recv_tiles);
     let mut fill_tx_vec: Vec<Tx<FrameDesc>> = Vec::with_capacity(num_pkt_recv_tiles);
     let mut fill_rx_vec: Vec<Rx<FrameDesc>> = Vec::with_capacity(num_pkt_recv_tiles);
-    letm
     let mut shmem_vec: Vec<SharedMem> = Vec::with_capacity(num_pkt_recv_tiles);
+
+    let mut pkt_fwd_sk_vec: Vec<UdpSocket> = Vec::with_capacity(num_pkt_fwd_tiles);
+    let mut packet_rx_vec: Vec<Rx<TritonPacket>> = Vec::with_capacity(num_pkt_fwd_tiles);
+    let mut packet_tx_vec: Vec<Tx<TritonPacket>> = Vec::with_capacity(num_pkt_fwd_tiles);
+
+    // Create the shared memory regions for recv tiles
     for _ in 0..num_pkt_recv_tiles {
-        let frame_size = pkt_recv_tile_mem_config.frame_size;
-        let num_frames = pkt_recv_tile_mem_config.memory_size as usize / pkt_recv_tile_mem_config.frame_size;
-        assert!(num_frames.is_power_of_two(), "num_frames must be a power of 2");
-        assert!(frame_size.is_power_of_two(), "frame_size must be a power of 2");
-        let shmem = SharedMem::new(
-            frame_size,
-            num_frames,
-            pkt_recv_tile_mem_config.hugepage,
-        ).expect("SharedMem::new");
+        assert!(
+            num_frames.is_power_of_two(),
+            "num_frames must be a power of 2"
+        );
+        assert!(
+            frame_size.is_power_of_two(),
+            "frame_size must be a power of 2"
+        );
+        let shmem = SharedMem::new(frame_size, num_frames, pkt_recv_tile_mem_config.hugepage)
+            .expect("SharedMem::new");
 
         let shmem_info = SharedMemInfo {
             start_ptr: shmem.ptr,
@@ -337,51 +393,116 @@ pub fn spawn_proxy_system(
                 ptr: unsafe { shmem.ptr.add(i * frame_size) },
                 frame_size: frame_size,
             };
-            fill_tx.send(frame_desc).expect("initial frame ring population");
+            fill_tx
+                .send(frame_desc)
+                .expect("initial frame ring population");
         }
 
         shmem_vec.push(shmem);
         fill_tx_vec.push(fill_tx);
         fill_rx_vec.push(fill_rx);
-    } 
+    }
 
     // Create socket for sending packets
-    
-
-    let send_socket = {
-        let ipv6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0);
-        match try_create_ipv6_socket(ipv6_addr) {
-            Ok(socket) => {
-                info!("Successfully bound send socket to IPv6 dual-stack address.");
-                socket.set_multicast_loop_v6(false)
-                    .expect("Failed to disable IPv6 multicast loopback");
-                socket
-            }
-            Err(e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => {
-                // This error (code 97 on Linux) means IPv6 is not supported.
-                warn!("IPv6 not available. Falling back to IPv4-only for sending.");
-                let ipv4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0);
-                let socket = UdpSocket::bind(ipv4_addr)
-                    .expect("Failed to bind to IPv4 socket after IPv6 failed");
-                socket.set_multicast_ttl_v4(IP_MULTICAST_TTL).expect("IP_MULTICAST_TTL_V4");
-                socket.set_multicast_loop_v4(false)
-                    .expect("Failed to disable IPv4 multicast loopback");
-                socket
-            }
-            Err(e) => {
-                // For any other error (e.g., port in use), panic.
- panic!("Failed to bind send socket with an unexpected error: {e}"); + for _ in 0..num_pkt_fwd_tiles { + let send_socket = { + let ipv6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); + match try_create_ipv6_socket(ipv6_addr) { + Ok(socket) => { + info!("Successfully bound send socket to IPv6 dual-stack address."); + socket + .set_multicast_loop_v6(false) + .expect("Failed to disable IPv6 multicast loopback"); + socket + } + Err(e) if e.raw_os_error() == Some(libc::EAFNOSUPPORT) => { + // This error (code 97 on Linux) means IPv6 is not supported. + warn!("IPv6 not available. Falling back to IPv4-only for sending."); + let ipv4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0); + let socket = UdpSocket::bind(ipv4_addr) + .expect("Failed to bind to IPv4 socket after IPv6 failed"); + socket + .set_multicast_ttl_v4(IP_MULTICAST_TTL) + .expect("IP_MULTICAST_TTL_V4"); + socket + .set_multicast_loop_v4(false) + .expect("Failed to disable IPv4 multicast loopback"); + socket + } + Err(e) => { + // For any other error (e.g., port in use), panic. + panic!("Failed to bind send socket with an unexpected error: {e}"); + } } - } - }; - + }; + pkt_fwd_sk_vec.push(send_socket); + } - + // Create pkt_fwd message rings + // One ring per pkt_fwd tile + for _ in 0..num_pkt_fwd_tiles { + let (packet_tx, packet_rx) = crate::mem::message_ring(num_frames).expect("pkt_fwd ring"); + packet_tx_vec.push(packet_tx); + packet_rx_vec.push(packet_rx); + } + // Spawn pkt_fwd tiles + for (pkt_fwd_idx, pkt_fwd_sk, packet_rx) in izip!( + 0..num_pkt_fwd_tiles, + pkt_fwd_sk_vec.into_iter(), + packet_rx_vec.into_iter() + ) { + let hot_dest_vec = Arc::clone(&dest_addr_vec); + let fill_tx_vec = fill_tx_vec.clone(); + let shmem_info_vec = shmem_info_vec.clone(); + let exit = Arc::clone(&exit); + let th = packet_fwd_tile( + pkt_fwd_idx, + hot_dest_vec, + pkt_fwd_sk, + packet_rx, + fill_tx_vec, + shmem_info_vec, + exit, + tile_wait_group.get_tile_closed_signal(TileKind::PktFwd, pkt_fwd_idx), + ) + .expect("packet_fwd_tile"); + tile_thread_vec.push(th); + } - todo!() -} + // Spawn pkt_recv tiles + for (pkt_recv_idx, pkt_recv_sk, fill_rx) in izip!( + 0..num_pkt_recv_tiles, + pkt_recv_sk_vec.into_iter(), + fill_rx_vec.into_iter() + ) { + let exit = Arc::clone(&exit); + let forwarder_stats = Arc::clone(&stats); + let packet_tx_vec_clone = packet_tx_vec.clone(); + let pkt_router_clone = pkt_router.clone(); + packet_recv_tile( + pkt_recv_idx, + pkt_recv_sk, + exit, + forwarder_stats, + fill_rx, + packet_tx_vec_clone, + pkt_router_clone, + tile_wait_group.get_tile_closed_signal(TileKind::PktRecv, pkt_recv_idx), + ) + .expect("packet_recv_tile"); + } + let (kind, idx) = tile_wait_group.wait_first(); + warn!("Tile of kind {kind:?} with idx {idx} has exited. 
Shutting down proxy system"); + exit.store(true, Ordering::Release); + for th in tile_thread_vec { + let result = th.join(); + if let Err(e) = result { + error!("Tile thread join error: {:?}", e); + } + } +} // #[cfg(test)] // mod tests { @@ -400,7 +521,6 @@ pub fn spawn_proxy_system( // }; // use solana_sdk::packet::{PacketFlags, PACKET_DATA_SIZE}; - // fn listen_and_collect(listen_socket: UdpSocket, received_packets: Arc>>>) { // let mut buf = [0u8; PACKET_DATA_SIZE]; // loop { From 921c5673b11698e6746fb7d387cb4f37d29f117b Mon Sep 17 00:00:00 2001 From: Louis-Vincent Boudreault Date: Wed, 14 Jan 2026 23:10:30 +0000 Subject: [PATCH 05/13] first draft of main2 --- proxy/Cargo.toml | 9 + proxy/src/main2.rs | 458 ++++++++++++++++++++++++++++++++++ proxy/src/mem.rs | 40 ++- proxy/src/recv_mmsg.rs | 2 +- proxy/src/triton_forwarder.rs | 80 ++++-- 5 files changed, 557 insertions(+), 32 deletions(-) create mode 100644 proxy/src/main2.rs diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 3c7e1a9..eddbb0c 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -6,6 +6,15 @@ authors = { workspace = true } homepage = { workspace = true } edition = { workspace = true } +[[bin]] +name = "triton-proxy" +path = "src/main2.rs" + +[[bin]] +name = "jito-shredstream-proxy" +path = "src/main.rs" + + [dependencies] ahash = { workspace = true } arc-swap = { workspace = true } diff --git a/proxy/src/main2.rs b/proxy/src/main2.rs new file mode 100644 index 0000000..07efb4f --- /dev/null +++ b/proxy/src/main2.rs @@ -0,0 +1,458 @@ +use std::{ + collections::HashMap, io::{self, Error, ErrorKind}, net::{IpAddr, Ipv4Addr, SocketAddr, ToSocketAddrs}, num::NonZeroUsize, panic, path::{Path, PathBuf}, str::FromStr, sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, thread::{self, sleep, spawn, JoinHandle}, time::Duration +}; + +use arc_swap::ArcSwap; +use clap::{arg, Parser}; +use crossbeam_channel::{Receiver, RecvError, Sender}; +use log::*; +use signal_hook::consts::{SIGINT, SIGTERM}; +use solana_client::client_error::{reqwest, ClientError}; +use solana_ledger::shred::Shred; +use solana_metrics::set_host_id; +use solana_sdk::{clock::Slot, signature::read_keypair_file}; +use solana_streamer::streamer::StreamerReceiveStats; +use thiserror::Error; +use tokio::runtime::Runtime; +use tonic::Status; + +use crate::{ + forwarder::ShredMetrics, multicast_config::{TritonMulticastConfig, TritonMulticastConfigV4, TritonMulticastConfigV6, create_multicast_sockets_triton}, recv_mmsg::FECSetRoutingStrategy, token_authenticator::BlockEngineConnectionError, triton_forwarder::PktRecvTileMemConfig, +}; +mod deshred; +pub mod forwarder; +mod heartbeat; +mod multicast_config; +mod server; +mod token_authenticator; +mod prom; +mod recv_mmsg; +mod mem; +mod triton_forwarder; + +use triton_forwarder::{PktRecvMemSizing}; + +#[derive(Clone, Debug, Parser)] +#[clap(author, version, about, long_about = None)] +// https://docs.rs/clap/latest/clap/_derive/_cookbook/git_derive/index.html +struct Args { + #[command(subcommand)] + shredstream_args: ProxySubcommands, +} + +#[derive(Clone, Debug, clap::Subcommand)] +enum ProxySubcommands { + /// Requests shreds from Jito and sends to all destinations. + Shredstream(ShredstreamArgs), + + /// Does not request shreds from Jito. Sends anything received on `src-bind-addr`:`src-bind-port` to all destinations. + ForwardOnly(CommonArgs), +} + +#[derive(clap::Args, Clone, Debug)] +struct ShredstreamArgs { + /// Address for Jito Block Engine. 
+    /// See https://jito-labs.gitbook.io/mev/searcher-resources/block-engine#connection-details
+    #[arg(long, env)]
+    block_engine_url: String,
+
+    /// Manual override for auth service address. For internal use.
+    #[arg(long, env)]
+    auth_url: Option<String>,
+
+    /// Path to keypair file used to authenticate with the backend.
+    #[arg(long, env)]
+    auth_keypair: PathBuf,
+
+    /// Desired regions to receive heartbeats from.
+    /// Receives `n` different streams. Requires at least 1 region, comma separated.
+    #[arg(long, env, value_delimiter = ',', required(true))]
+    desired_regions: Vec<String>,
+
+    #[clap(flatten)]
+    common_args: CommonArgs,
+}
+
+#[derive(clap::Args, Clone, Debug)]
+struct CommonArgs {
+    /// Address where Shredstream proxy listens.
+    #[arg(long, env, default_value_t = IpAddr::V4(std::net::Ipv4Addr::new(0, 0, 0, 0)))]
+    src_bind_addr: IpAddr,
+
+    /// Port where Shredstream proxy listens. Use `0` for random ephemeral port.
+    #[arg(long, env, default_value_t = 20_000)]
+    src_bind_port: u16,
+
+    /// Static set of IP:Port where Shredstream proxy forwards shreds to, comma separated.
+    /// E.g. `127.0.0.1:8001,10.0.0.1:8001`.
+    // Note: store the original string, so we can do hostname resolution when refreshing destinations
+    #[arg(long, env, value_delimiter = ',', value_parser = resolve_hostname_port)]
+    dest_ip_ports: Vec<(SocketAddr, String)>,
+
+    /// HTTP JSON endpoint to dynamically get IPs for Shredstream proxy to forward shreds.
+    /// Endpoints are then set-union with `dest-ip-ports`.
+    #[arg(long, env)]
+    endpoint_discovery_url: Option<String>,
+
+    /// Port to send shreds to for hosts fetched via `endpoint-discovery-url`.
+    /// Port can be found using `scripts/get_tvu_port.sh`.
+    /// See https://jito-labs.gitbook.io/mev/searcher-services/shredstream#running-shredstream
+    #[arg(long, env)]
+    discovered_endpoints_port: Option<u16>,
+
+    /// Interval between logging stats to stdout and influx
+    #[arg(long, env, default_value_t = 15_000)]
+    metrics_report_interval_ms: u64,
+
+    /// Logs trace shreds to stdout and influx
+    #[arg(long, env, default_value_t = false)]
+    debug_trace_shred: bool,
+
+    /// Public IP address to use.
+    /// Overrides value fetched from `ifconfig.me`.
+    #[arg(long, env)]
+    public_ip: Option<IpAddr>,
+
+    /// Number of threads to use. Defaults to up to 4.
+    #[arg(long, env)]
+    num_threads: Option<usize>,
+
+    /// The multicast group (IP address) to join for receiving shreds.
+    /// Multicast groups support both IPv4 and IPv6.
+    #[arg(long, env)]
+    triton_multicast_group: Option<IpAddr>,
+
+    /// The interface to bind to for triton multicast.
+    /// If IPv6 is used, this argument must be provided.
+    /// If IPv4, it is optional (listens on all interfaces if not provided).
+    #[arg(long, env)]
+    triton_multicast_bind_interface: Option<String>,
+
+    #[arg(long, env, default_value_t = 1)]
+    triton_multicast_num_threads: usize,
+
+    /// Address to bind the Prometheus metrics server to. If not provided, the Prometheus server is disabled.
+    #[arg(long, env)]
+    prometheus_bind_addr: Option<SocketAddr>,
+
+    /// Number of tiles dedicated to receiving packets. Defaults to 1 if not provided.
+    #[arg(long, env)]
+    num_pkt_recv_tile: Option<NonZeroUsize>,
+
+    /// Number of tiles dedicated to forwarding packets. Defaults to 1 if not provided.
+    #[arg(long, env)]
+    num_pkt_fwd_tile: Option<NonZeroUsize>,
+
+    /// Memory sizing for EACH packet receiver, using t-shirt size convention (xs (default), s, m, l, xl, 2xl, 3xl, 4xl, 5xl). Each size increase doubles the memory, starting at 128MiB for x-small.
+ #[arg(long, env)] + pkt_recv_channel_memsize: Option, +} + +#[derive(Debug, Error)] +pub enum ShredstreamProxyError { + #[error("TonicError {0}")] + TonicError(#[from] tonic::transport::Error), + #[error("GrpcError {0}")] + GrpcError(#[from] Status), + #[error("ReqwestError {0}")] + ReqwestError(#[from] reqwest::Error), + #[error("SerdeJsonError {0}")] + SerdeJsonError(#[from] serde_json::Error), + #[error("RpcError {0}")] + RpcError(#[from] ClientError), + #[error("BlockEngineConnectionError {0}")] + BlockEngineConnectionError(#[from] BlockEngineConnectionError), + #[error("RecvError {0}")] + RecvError(#[from] RecvError), + #[error("IoError {0}")] + IoError(#[from] io::Error), + #[error("Shutdown")] + Shutdown, +} + +fn resolve_hostname_port(hostname_port: &str) -> io::Result<(SocketAddr, String)> { + let socketaddr = hostname_port.to_socket_addrs()?.next().ok_or_else(|| { + Error::new( + ErrorKind::AddrNotAvailable, + format!("Could not find destination {hostname_port}"), + ) + })?; + + Ok((socketaddr, hostname_port.to_string())) +} + +/// Returns public-facing IPV4 address +pub fn get_public_ip() -> reqwest::Result { + info!("Requesting public ip from ifconfig.me..."); + let client = reqwest::blocking::Client::builder() + .local_address(IpAddr::V4(Ipv4Addr::UNSPECIFIED)) + .build()?; + let response = client.get("https://ifconfig.me/ip").send()?.text()?; + let public_ip = IpAddr::from_str(&response).unwrap(); + info!("Retrieved public ip: {public_ip:?}"); + + Ok(public_ip) +} + +// Creates a channel that gets a message every time `SIGINT` is signalled. +fn shutdown_notifier(exit: Arc) -> io::Result<(Sender<()>, Receiver<()>)> { + let (s, r) = crossbeam_channel::bounded(256); + let mut signals = signal_hook::iterator::Signals::new([SIGINT, SIGTERM])?; + + let s_thread = s.clone(); + thread::spawn(move || { + for _ in signals.forever() { + exit.store(true, Ordering::SeqCst); + // send shutdown signal multiple times since crossbeam doesn't have broadcast channels + // each thread will consume a shutdown signal + for _ in 0..256 { + if s_thread.send(()).is_err() { + break; + } + } + } + }); + + Ok((s, r)) +} + +pub type ReconstructedShredsMap = HashMap>>; +fn main() -> Result<(), ShredstreamProxyError> { + env_logger::builder().init(); + let prom_registry = prometheus::Registry::new(); + prom::register_metrics(&prom_registry); + let all_args: Args = Args::parse(); + + let shredstream_args = all_args.shredstream_args.clone(); + // common args + let args = match all_args.shredstream_args { + ProxySubcommands::Shredstream(x) => x.common_args, + ProxySubcommands::ForwardOnly(x) => x, + }; + set_host_id(hostname::get()?.into_string().unwrap()); + if (args.endpoint_discovery_url.is_none() && args.discovered_endpoints_port.is_some()) + || (args.endpoint_discovery_url.is_some() && args.discovered_endpoints_port.is_none()) + { + return Err(ShredstreamProxyError::IoError(io::Error::new(ErrorKind::InvalidInput, "Invalid arguments provided, dynamic endpoints requires both --endpoint-discovery-url and --discovered-endpoints-port."))); + } + if args.endpoint_discovery_url.is_none() + && args.discovered_endpoints_port.is_none() + && args.dest_ip_ports.is_empty() + { + return Err(ShredstreamProxyError::IoError(io::Error::new(ErrorKind::InvalidInput, "No destinations found. 
You must provide values for --dest-ip-ports or --endpoint-discovery-url."))); + } + + let exit = Arc::new(AtomicBool::new(false)); + let (shutdown_sender, shutdown_receiver) = + shutdown_notifier(exit.clone()).expect("Failed to set up signal handler"); + let panic_hook = panic::take_hook(); + { + let exit = exit.clone(); + panic::set_hook(Box::new(move |panic_info| { + exit.store(true, Ordering::SeqCst); + let _ = shutdown_sender.send(()); + error!("exiting process"); + sleep(Duration::from_secs(1)); + // invoke the default handler and exit the process + panic_hook(panic_info); + })); + } + + let metrics = Arc::new(ShredMetrics::new(false)); + + + let mut thread_handles = vec![]; + if let ProxySubcommands::Shredstream(args) = shredstream_args { + let runtime = Runtime::new()?; + if args.desired_regions.len() > 2 { + warn!( + "Too many regions requested, only regions: {:?} will be used", + &args.desired_regions[..2] + ); + } + let heartbeat_hdl = + start_heartbeat(args, &exit, &shutdown_receiver, runtime, metrics.clone()); + thread_handles.push(heartbeat_hdl); + } + + // share sockets between refresh and forwarder thread + let unioned_dest_sockets = Arc::new(ArcSwap::from_pointee( + args.dest_ip_ports + .iter() + .map(|x| x.0) + .collect::>(), + )); + + let forward_stats = Arc::new(StreamerReceiveStats::new("shredstream_proxy-listen_thread")); + let use_discovery_service = + args.endpoint_discovery_url.is_some() && args.discovered_endpoints_port.is_some(); + + // let maybe_triton_multicast_config = match args.triton_multicast_group { + // Some(multicast_group) => { + // match multicast_group { + // IpAddr::V4(ipv4) => { + // Some(TritonMulticastConfig::Ipv4(TritonMulticastConfigV4 { + // multicast_ip: ipv4, + // bind_ifname: args.triton_multicast_bind_interface, + // })) + // } + // IpAddr::V6(ipv6) => { + // Some(TritonMulticastConfig::Ipv6(TritonMulticastConfigV6 { + // multicast_ip: ipv6, + // device_ifname: args.triton_multicast_bind_interface + // .ok_or_else(|| { + // io::Error::new( + // ErrorKind::InvalidInput, + // "triton-multicast-bind-interface is required for IPv6", + // ) + // })?, + // })) + // } + // } + // } + // None => None, + // }; + + // let maybe_triton_multicast_socket = maybe_triton_multicast_config + // .and_then(|config| { + // let num_threads = NonZeroUsize::new(args.triton_multicast_num_threads) + // .ok_or_else(|| { + // io::Error::new( + // ErrorKind::InvalidInput, + // "triton-multicast-num-threads must be non-zero", + // ) + // }).ok()?; + // Some( + // create_multicast_sockets_triton(&config, num_threads) + // .map(|ok| (config.ip(), ok)) + // ) + // }) + // .transpose()?; + + let pkt_recv_tile_mem_config = PktRecvTileMemConfig { + memory_size: args.pkt_recv_channel_memsize.unwrap_or_default(), + ..Default::default() + }; + let proxy_th = { + let exit = Arc::clone(&exit); + let forward_stats = forward_stats.clone(); + let unioned_dest_sockets = Arc::clone(&unioned_dest_sockets); + std::thread::Builder::new() + .name("tritonProxyMain".to_string()) + .spawn(move || { + triton_forwarder::run_proxy_system( + pkt_recv_tile_mem_config, + unioned_dest_sockets, + args.src_bind_addr, + args.src_bind_port, + args.num_pkt_recv_tile.map(|x| x.get()).unwrap_or(1), + args.num_pkt_fwd_tile.map(|x| x.get()).unwrap_or(1), + FECSetRoutingStrategy, + exit, + forward_stats, + ); + }) + .expect("tritonProxyMain") + }; + + thread_handles.push(proxy_th); + + let report_metrics_thread = { + let exit = exit.clone(); + spawn(move || { + while !exit.load(Ordering::Relaxed) { + 
sleep(Duration::from_secs(1)); + forward_stats.report(); + } + }) + }; + thread_handles.push(report_metrics_thread); + + // let metrics_hdl = forwarder::start_forwarder_accessory_thread( + // deduper, + // metrics.clone(), + // args.metrics_report_interval_ms, + // shutdown_receiver.clone(), + // exit.clone(), + // ); + // thread_handles.push(metrics_hdl); + if use_discovery_service { + let refresh_handle = forwarder::start_destination_refresh_thread( + args.endpoint_discovery_url.unwrap(), + args.discovered_endpoints_port.unwrap(), + args.dest_ip_ports, + unioned_dest_sockets, + shutdown_receiver.clone(), + exit.clone(), + ); + thread_handles.push(refresh_handle); + } + + if let Some(prom_bind_addr) = args.prometheus_bind_addr { + let prom_hdl = prom::spawn_prometheus_server( + prom_bind_addr, + prom_registry, + shutdown_receiver.clone() + ); + thread_handles.push(prom_hdl); + } + + info!( + "Shredstream started, listening on {}:{}/udp.", + args.src_bind_addr, args.src_bind_port + ); + + + + for thread in thread_handles { + thread.join().expect("thread panicked"); + } + + info!( + "Exiting Shredstream, {} received , {} sent successfully, {} failed, {} duplicate shreds.", + metrics.agg_received_cumulative.load(Ordering::Relaxed), + metrics + .agg_success_forward_cumulative + .load(Ordering::Relaxed), + metrics.agg_fail_forward_cumulative.load(Ordering::Relaxed), + metrics.duplicate_cumulative.load(Ordering::Relaxed), + ); + Ok(()) +} + +fn start_heartbeat( + args: ShredstreamArgs, + exit: &Arc, + shutdown_receiver: &Receiver<()>, + runtime: Runtime, + metrics: Arc, +) -> JoinHandle<()> { + let auth_keypair = Arc::new( + read_keypair_file(Path::new(&args.auth_keypair)).unwrap_or_else(|e| { + panic!( + "Unable to parse keypair file. Ensure that file {:?} is readable. 
Error: {e}", + args.auth_keypair + ) + }), + ); + + heartbeat::heartbeat_loop_thread( + args.block_engine_url.clone(), + args.auth_url.unwrap_or(args.block_engine_url), + auth_keypair, + args.desired_regions, + SocketAddr::new( + args.common_args + .public_ip + .unwrap_or_else(|| get_public_ip().unwrap()), + args.common_args.src_bind_port, + ), + runtime, + "shredstream_proxy".to_string(), + metrics, + shutdown_receiver.clone(), + exit.clone(), + ) +} diff --git a/proxy/src/mem.rs b/proxy/src/mem.rs index 513fe9a..f1eea0d 100644 --- a/proxy/src/mem.rs +++ b/proxy/src/mem.rs @@ -1,12 +1,8 @@ use std::{ - cell::UnsafeCell, hint::spin_loop, - ops::{Index, IndexMut}, sync::{ - atomic::{AtomicI32, AtomicUsize, Ordering}, - Arc, - }, - thread::{self, Thread}, + Arc, atomic::{AtomicI32, AtomicUsize, Ordering} + }, time::Duration, }; use bytes::{buf::UninitSlice, Buf, BufMut}; @@ -418,23 +414,41 @@ impl Tx { impl Rx { pub fn recv(&mut self) -> T { + self.recv_timeout_inner(None).expect("recv failed") + } + + pub fn recv_timeout(&mut self, duration: Duration) -> Option { + self.recv_timeout_inner(Some(duration)) + } + + fn recv_timeout_inner(&mut self, duration: Option) -> Option { for _ in 0..999 { if let Some(val) = self.try_recv() { - return val; + return Some(val); } spin_loop(); } loop { if let Some(val) = self.try_recv() { - return val; + return Some(val); } self.inner.futex_flag.store(0, Ordering::SeqCst); if let Some(val) = self.try_recv() { - return val; + return Some(val); } + + let timespec: Option = duration.map(|d| libc::timespec { + tv_sec: d.as_secs() as libc::time_t, + tv_nsec: d.subsec_nanos() as libc::c_long, + }); + + let timeout_ptr = match ×pec { + Some(ts) => ts as *const libc::timespec, + None => std::ptr::null(), + }; unsafe { libc::syscall( @@ -442,9 +456,13 @@ impl Rx { &self.inner.futex_flag as *const AtomicI32, libc::FUTEX_WAIT, 0, - std::ptr::null::(), + timeout_ptr, ); } + + if duration.is_some(){ + return self.try_recv(); + } } } @@ -474,7 +492,7 @@ impl Rx { #[cfg(test)] mod tests { - use std::{collections::HashSet, sync::Barrier}; + use std::{collections::HashSet, sync::Barrier, thread}; use super::*; diff --git a/proxy/src/recv_mmsg.rs b/proxy/src/recv_mmsg.rs index a027193..faaa4bc 100644 --- a/proxy/src/recv_mmsg.rs +++ b/proxy/src/recv_mmsg.rs @@ -50,7 +50,7 @@ fn hash_pair(x: u64, y: u32) -> u64 { } #[derive(Debug, Clone)] -struct FECSetRoutingStrategy; +pub struct FECSetRoutingStrategy; impl PacketRoutingStrategy for FECSetRoutingStrategy { fn route_packet(&self, packet: &TritonPacket, num_dest: usize) -> Option { diff --git a/proxy/src/triton_forwarder.rs b/proxy/src/triton_forwarder.rs index de753e1..220758e 100644 --- a/proxy/src/triton_forwarder.rs +++ b/proxy/src/triton_forwarder.rs @@ -53,7 +53,7 @@ pub const DEDUPER_RESET_CYCLE: Duration = Duration::from_secs(5 * 60); pub const IP_MULTICAST_TTL: u32 = 8; #[derive(Debug, Clone, Copy, Default)] -pub enum ReceiverMemorySizing { +pub enum PktRecvMemSizing { #[default] XSmall = 134217728, // 128MiB Small = 268435456, // 256MiB @@ -70,37 +70,37 @@ pub enum ReceiverMemorySizing { #[error("Invalid ReceiverMemoryCapacity: {0}")] pub struct ReceiverMemoryCapacityFromStrErr(String); -impl FromStr for ReceiverMemorySizing { +impl FromStr for PktRecvMemSizing { type Err = String; fn from_str(s: &str) -> Result { match s.to_lowercase().as_str() { - "xsmall" | "xs" => Ok(ReceiverMemorySizing::XSmall), - "small" | "s" => Ok(ReceiverMemorySizing::Small), - "medium" | "m" => Ok(ReceiverMemorySizing::Medium), - "large" | 
"l" => Ok(ReceiverMemorySizing::Large), - "xlarge" | "xl" => Ok(ReceiverMemorySizing::XLarge), - "xxlarge" | "xxl" | "2xl" => Ok(ReceiverMemorySizing::XXLarge), - "xxxlarge" | "xxxl" | "3xl" => Ok(ReceiverMemorySizing::XXXLarge), - "xxxxlarge" | "xxxxl" | "4xl" => Ok(ReceiverMemorySizing::XXXXLarge), - "xxxxxlarge" | "xxxxxl" | "5xl" => Ok(ReceiverMemorySizing::XXXXXLarge), + "xsmall" | "xs" => Ok(PktRecvMemSizing::XSmall), + "small" | "s" => Ok(PktRecvMemSizing::Small), + "medium" | "m" => Ok(PktRecvMemSizing::Medium), + "large" | "l" => Ok(PktRecvMemSizing::Large), + "xlarge" | "xl" => Ok(PktRecvMemSizing::XLarge), + "xxlarge" | "xxl" | "2xl" => Ok(PktRecvMemSizing::XXLarge), + "xxxlarge" | "xxxl" | "3xl" => Ok(PktRecvMemSizing::XXXLarge), + "xxxxlarge" | "xxxxl" | "4xl" => Ok(PktRecvMemSizing::XXXXLarge), + "xxxxxlarge" | "xxxxxl" | "5xl" => Ok(PktRecvMemSizing::XXXXXLarge), _ => Err(s.to_string()), } } } #[derive(Clone, Debug)] -pub struct PacketRecvTileMemConfig { +pub struct PktRecvTileMemConfig { pub frame_size: usize, - pub memory_size: ReceiverMemorySizing, + pub memory_size: PktRecvMemSizing, pub hugepage: bool, } -impl Default for PacketRecvTileMemConfig { +impl Default for PktRecvTileMemConfig { fn default() -> Self { Self { frame_size: 2048, - memory_size: ReceiverMemorySizing::default(), + memory_size: PktRecvMemSizing::default(), hugepage: false, } } @@ -219,7 +219,7 @@ fn packet_fwd_tile( let mut next_deduper_reset_attempt = Instant::now() + Duration::from_secs(2); let mut recycled_frames: Vec = Vec::with_capacity(UIO_MAXIOV); - while exit.load(Ordering::Relaxed) == false { + while !exit.load(Ordering::Relaxed) { if next_deduper_reset_attempt.elapsed() > Duration::ZERO { deduper.maybe_reset( &mut rand::thread_rng(), @@ -229,10 +229,27 @@ fn packet_fwd_tile( next_deduper_reset_attempt = Instant::now() + Duration::from_secs(2); } + if queued.is_empty() && recycled_frames.is_empty() && next_batch_send.is_empty() { + if let Some(packet) = packet_rx.recv_timeout(Duration::from_millis(100)) { + let data_size = packet.meta.size; + let data_slice = &packet.buffer.chunk()[..data_size]; + if deduper.dedup(data_slice) { + // put it inside the recycle queue + debug!("Deduped packet from {}", packet.meta.addr); + let desc = packet.buffer.into_inner(); + recycled_frames.push(desc); + } else { + queued.push_back(packet); + } + } + } + // Fill up the queued OR recycled_frames as much as possible - while queued.len() < UIO_MAXIOV && recycled_frames.len() < UIO_MAXIOV { + 'fill_backlog: while queued.len() < UIO_MAXIOV && recycled_frames.len() < UIO_MAXIOV { // Fill the batch as much as possible. 
- let packet = packet_rx.recv(); + let Some(packet) = packet_rx.try_recv() else { + break 'fill_backlog; + }; let data_size = packet.meta.size; let data_slice = &packet.buffer.chunk()[..data_size]; if deduper.dedup(data_slice) { @@ -316,6 +333,7 @@ fn packet_fwd_tile( .expect("frame recycling"); } } + log::info!("Exiting pkt_fwd_tile {}", packet_fwd_idx); drop(tile_drop_sig); }) } @@ -329,7 +347,7 @@ pub enum ProxySystemError { } pub fn run_proxy_system( - pkt_recv_tile_mem_config: PacketRecvTileMemConfig, + pkt_recv_tile_mem_config: PktRecvTileMemConfig, dest_addr_vec: Arc>>, src_ip: IpAddr, src_port: u16, @@ -379,6 +397,13 @@ pub fn run_proxy_system( ); let shmem = SharedMem::new(frame_size, num_frames, pkt_recv_tile_mem_config.hugepage) .expect("SharedMem::new"); + log::info!( + "Created shared memory region with frame_size={} num_frames={} total_size={} hugepage={}", + frame_size, + num_frames, + shmem.len(), + pkt_recv_tile_mem_config.hugepage, + ); let shmem_info = SharedMemInfo { start_ptr: shmem.ptr, @@ -400,6 +425,7 @@ pub fn run_proxy_system( shmem_vec.push(shmem); fill_tx_vec.push(fill_tx); fill_rx_vec.push(fill_rx); + log::info!("Initialized frame ring with {} frames", num_frames); } // Create socket for sending packets @@ -434,6 +460,10 @@ pub fn run_proxy_system( } } }; + log::info!( + "Packet forwarder sending socket bound to {}", + send_socket.local_addr().unwrap() + ); pkt_fwd_sk_vec.push(send_socket); } @@ -444,6 +474,10 @@ pub fn run_proxy_system( packet_tx_vec.push(packet_tx); packet_rx_vec.push(packet_rx); } + log::info!( + "Initialized pkt_fwd message rings with {} slots", + num_frames + ); // Spawn pkt_fwd tiles for (pkt_fwd_idx, pkt_fwd_sk, packet_rx) in izip!( @@ -467,6 +501,7 @@ pub fn run_proxy_system( ) .expect("packet_fwd_tile"); tile_thread_vec.push(th); + log::info!("Spawned pkt_fwd tile {}", pkt_fwd_idx); } // Spawn pkt_recv tiles @@ -479,7 +514,7 @@ pub fn run_proxy_system( let forwarder_stats = Arc::clone(&stats); let packet_tx_vec_clone = packet_tx_vec.clone(); let pkt_router_clone = pkt_router.clone(); - packet_recv_tile( + let jh = packet_recv_tile( pkt_recv_idx, pkt_recv_sk, exit, @@ -490,12 +525,17 @@ pub fn run_proxy_system( tile_wait_group.get_tile_closed_signal(TileKind::PktRecv, pkt_recv_idx), ) .expect("packet_recv_tile"); + tile_thread_vec.push(jh); + log::info!("Spawned pkt_recv tile {}", pkt_recv_idx); } let (kind, idx) = tile_wait_group.wait_first(); warn!("Tile of kind {kind:?} with idx {idx} has exited. 
Shutting down proxy system"); exit.store(true, Ordering::Release); + drop(fill_tx_vec); + drop(packet_tx_vec); + log::info!("Waiting for {} tile threads to exit", tile_thread_vec.len()); for th in tile_thread_vec { let result = th.join(); if let Err(e) = result { From 273b30e42de1e6c12f13f52c0d28af7e754c762d Mon Sep 17 00:00:00 2001 From: Louis-Vincent Boudreault Date: Thu, 15 Jan 2026 01:13:53 +0000 Subject: [PATCH 06/13] fixed triton_forwarder --- proxy/src/main2.rs | 6 ++++-- proxy/src/recv_mmsg.rs | 30 +++++++++++++++++++++--------- proxy/src/triton_forwarder.rs | 11 +++++++++-- setup_net.sh | 23 +++++++++++++++++++++++ 4 files changed, 57 insertions(+), 13 deletions(-) create mode 100755 setup_net.sh diff --git a/proxy/src/main2.rs b/proxy/src/main2.rs index 07efb4f..349bac9 100644 --- a/proxy/src/main2.rs +++ b/proxy/src/main2.rs @@ -337,7 +337,8 @@ fn main() -> Result<(), ShredstreamProxyError> { }; let proxy_th = { let exit = Arc::clone(&exit); - let forward_stats = forward_stats.clone(); + let pkt_recv_stats = forward_stats.clone(); + let pkt_fwd_stats = metrics.clone(); let unioned_dest_sockets = Arc::clone(&unioned_dest_sockets); std::thread::Builder::new() .name("tritonProxyMain".to_string()) @@ -351,7 +352,8 @@ fn main() -> Result<(), ShredstreamProxyError> { args.num_pkt_fwd_tile.map(|x| x.get()).unwrap_or(1), FECSetRoutingStrategy, exit, - forward_stats, + pkt_recv_stats, + pkt_fwd_stats, ); }) .expect("tritonProxyMain") diff --git a/proxy/src/recv_mmsg.rs b/proxy/src/recv_mmsg.rs index faaa4bc..cf83b43 100644 --- a/proxy/src/recv_mmsg.rs +++ b/proxy/src/recv_mmsg.rs @@ -89,6 +89,7 @@ where // Refill the frame buffers as much as we can, 'fill_bufmut: while frame_bufmut_vec.len() < PACKETS_PER_BATCH { + log::trace!(" Refilling frame buffers {}", frame_bufmut_vec.len()); let maybe_frame_buf = fill_rx.try_recv(); match maybe_frame_buf { Some(frame_desc) => { @@ -107,9 +108,12 @@ where } } } + let result = recv_from(&mut frame_bufmut_vec, socket, coalesce, &mut packet_batch); + if let Ok(len) = result { if len > 0 { + log::trace!("Received {} packets", len); let StreamerReceiveStats { packets_count, packet_batches_count, @@ -167,7 +171,7 @@ pub fn recv_from( let mut i = 0; loop { - match triton_recv_mmsg(socket, available_frame_buf_vec, &mut batch[i..]) { + match triton_recv_mmsg(socket, available_frame_buf_vec, batch) { Err(_) if i > 0 => { if start.elapsed() > max_wait { break; @@ -223,11 +227,12 @@ impl AsRef<[u8]> for TritonPacket { pub fn triton_recv_mmsg( sock: &UdpSocket, fill_buffers: &mut Vec, - packets: &mut [TritonPacket], + packets: &mut Vec, ) -> io::Result { // Should never hit this, but bail if the caller didn't provide any Packets // to receive into - if packets.is_empty() { + if fill_buffers.is_empty() { + log::trace!("triton_recv_mmsg: no fill buffers to receive into"); return Ok(0); } // Assert that there are no leftovers in packets. 
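// Illustration (not part of this patch): the hunk below caps each recvmmsg call by
// three independent limits -- iovec slots (NUM_RCVMMSGS), the spare capacity left in
// the output Vec, and the frame buffers currently available. A minimal sketch of that
// budget; the helper name and the NUM_RCVMMSGS = 64 value are assumptions here.
fn recv_budget(spare_packet_capacity: usize, free_frames: usize) -> usize {
    const NUM_RCVMMSGS: usize = 64;
    NUM_RCVMMSGS.min(spare_packet_capacity).min(free_frames)
}
// With room for 10 more packets and 32 free frame buffers, only 10 messages are
// requested from the kernel: recv_budget(10, 32) == 10.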
@@ -236,10 +241,13 @@ pub fn triton_recv_mmsg( let mut iovs = [MaybeUninit::uninit(); NUM_RCVMMSGS]; let mut addrs = [MaybeUninit::zeroed(); NUM_RCVMMSGS]; let mut hdrs = [MaybeUninit::uninit(); NUM_RCVMMSGS]; - + let remaining_packets = packets.capacity() - packets.len(); let sock_fd = sock.as_raw_fd(); - let count = cmp::min(iovs.len(), packets.len()).min(fill_buffers.len()); - + let count = cmp::min(iovs.len(), remaining_packets).min(fill_buffers.len()); + log::trace!( + "triton_recv_mmsg: preparing to receive up to {} packets", + count + ); let mut frame_buffer_inflight_vec: [MaybeUninit; NUM_RCVMMSGS] = std::array::from_fn(|_| MaybeUninit::uninit()); @@ -293,8 +301,8 @@ pub fn triton_recv_mmsg( } else { usize::try_from(nrecv).unwrap() }; - for (addr, hdr, pkt, filled_bufmut) in - izip!(addrs, hdrs, packets.iter_mut(), frame_buffer_inflight_vec).take(nrecv) + for (i, addr, hdr, filled_bufmut) in + izip!(0..nrecv, addrs, hdrs, frame_buffer_inflight_vec).take(nrecv) { // SAFETY: We initialized `count` elements of `hdrs` above. `count` is // passed to recvmmsg() as the limit of messages that can be read. So, @@ -306,11 +314,15 @@ pub fn triton_recv_mmsg( let addr_ref = unsafe { addr.assume_init_ref() }; let filled_bufmut = unsafe { filled_bufmut.assume_init_read() }; let filled_buf: FrameBuf = filled_bufmut.into(); - pkt.buffer = filled_buf; + let mut pkt = TritonPacket { + buffer: filled_buf, + meta: Meta::default(), + }; pkt.meta_mut().size = hdr_ref.msg_len as usize; if let Some(addr) = cast_socket_addr(addr_ref, hdr_ref) { pkt.meta_mut().set_socket_addr(&addr); } + packets.push(pkt); } for (iov, addr, hdr) in izip!(&mut iovs, &mut addrs, &mut hdrs).take(count) { diff --git a/proxy/src/triton_forwarder.rs b/proxy/src/triton_forwarder.rs index 220758e..f9d35f2 100644 --- a/proxy/src/triton_forwarder.rs +++ b/proxy/src/triton_forwarder.rs @@ -198,6 +198,7 @@ fn packet_fwd_tile( mut packet_rx: Rx, fill_tx_vec: Vec>, shmem_info_vec: Vec, + stats: Arc, exit: Arc, tile_drop_sig: TileClosedSignal, ) -> std::io::Result> { @@ -305,9 +306,13 @@ fn packet_fwd_tile( queued.len() ); + let batch_send_ts = Instant::now(); match batch_send(&send_socket, &next_batch_send) { Ok(_) => { // Successfully sent all packets in the batch + let send_duration = batch_send_ts.elapsed(); + stats.batch_send_time_spent.fetch_add(send_duration.as_micros() as u64, Ordering::Relaxed); + stats.send_batch_count.fetch_add(1, Ordering::Relaxed); } Err(SendPktsError::IoError(err, num_failed)) => { error!( @@ -355,7 +360,8 @@ pub fn run_proxy_system( num_pkt_fwd_tiles: usize, pkt_router: R, exit: Arc, - stats: Arc, + pk_recv_stats: Arc, + pk_fwd_stats: Arc, ) where R: PacketRoutingStrategy + Send + Sync + 'static, { @@ -496,6 +502,7 @@ pub fn run_proxy_system( packet_rx, fill_tx_vec, shmem_info_vec, + Arc::clone(&pk_fwd_stats), exit, tile_wait_group.get_tile_closed_signal(TileKind::PktFwd, pkt_fwd_idx), ) @@ -511,7 +518,7 @@ pub fn run_proxy_system( fill_rx_vec.into_iter() ) { let exit = Arc::clone(&exit); - let forwarder_stats = Arc::clone(&stats); + let forwarder_stats = Arc::clone(&pk_recv_stats); let packet_tx_vec_clone = packet_tx_vec.clone(); let pkt_router_clone = pkt_router.clone(); let jh = packet_recv_tile( diff --git a/setup_net.sh b/setup_net.sh new file mode 100755 index 0000000..7540ec1 --- /dev/null +++ b/setup_net.sh @@ -0,0 +1,23 @@ +# 1. Create the namespace +sudo ip netns add ns1 + +# 2. Create the virtual cable (veth0 <-> veth1) +sudo ip link add veth0 type veth peer name veth1 + +# 3. 
Move veth1 end into the namespace +sudo ip link set veth1 netns ns1 + +# 4. Assign IP to Host side (veth0) - The Driver +sudo ip addr add 172.31.0.1/24 dev veth0 +sudo ip link set veth0 up + +# 5. Assign IP to Namespace side (veth1) - The Server +sudo ip netns exec ns1 ip addr add 172.31.0.2/24 dev veth1 +sudo ip netns exec ns1 ip link set veth1 up + +# 6. Turn off offloading (Crucial for AF_XDP) +#sudo ethtool -K veth0 gro off +#sudo ip netns exec ns1 ethtool -K veth1 gro off + +# 7. Verify connectivity +ping -c 2 10.0.0.11 \ No newline at end of file From aea6f46bf9f056b0d05f496377ad82b6f89de9fd Mon Sep 17 00:00:00 2001 From: Louis-Vincent Boudreault Date: Thu, 15 Jan 2026 01:29:59 +0000 Subject: [PATCH 07/13] added prometheus metric to triton-forwarder --- proxy/src/main2.rs | 18 ++++++------- proxy/src/mem.rs | 11 +------- proxy/src/recv_mmsg.rs | 29 +++++++------------- proxy/src/triton_forwarder.rs | 51 ++++++++++++++--------------------- 4 files changed, 39 insertions(+), 70 deletions(-) diff --git a/proxy/src/main2.rs b/proxy/src/main2.rs index 349bac9..d28cfa1 100644 --- a/proxy/src/main2.rs +++ b/proxy/src/main2.rs @@ -22,16 +22,16 @@ use tonic::Status; use crate::{ forwarder::ShredMetrics, multicast_config::{TritonMulticastConfig, TritonMulticastConfigV4, TritonMulticastConfigV6, create_multicast_sockets_triton}, recv_mmsg::FECSetRoutingStrategy, token_authenticator::BlockEngineConnectionError, triton_forwarder::PktRecvTileMemConfig, }; -mod deshred; +pub mod deshred; pub mod forwarder; -mod heartbeat; -mod multicast_config; -mod server; -mod token_authenticator; -mod prom; -mod recv_mmsg; -mod mem; -mod triton_forwarder; +pub mod heartbeat; +pub mod multicast_config; +pub mod server; +pub mod token_authenticator; +pub mod prom; +pub mod recv_mmsg; +pub mod mem; +pub mod triton_forwarder; use triton_forwarder::{PktRecvMemSizing}; diff --git a/proxy/src/mem.rs b/proxy/src/mem.rs index f1eea0d..765f9ec 100644 --- a/proxy/src/mem.rs +++ b/proxy/src/mem.rs @@ -194,20 +194,11 @@ impl FrameBufMut { ((self.ptr as usize) & !(self.desc.frame_size - 1)) as *mut u8 } - pub fn len(&self) -> usize { - let base = self.base() as usize; - (self.ptr as usize) - base - } - #[inline] pub fn capacity(&self) -> usize { self.desc.frame_size } - #[inline] - pub fn cast_to(&self) -> *mut T { - self.base() as *mut T - } #[inline] fn end_ptr(&self) -> *const u8 { @@ -604,7 +595,7 @@ mod tests { assert_eq!(buf_mut.remaining_mut(), 4096); buf_mut.put_slice(&[1, 2, 3, 4]); assert_eq!(buf_mut.remaining_mut(), 4092); - assert_eq!(buf_mut.len(), 4); + assert_eq!(buf_mut.chunk_mut().len(), 4); let mut buf: FrameBuf = buf_mut.into(); assert_eq!(buf.len(), 4); diff --git a/proxy/src/recv_mmsg.rs b/proxy/src/recv_mmsg.rs index cf83b43..2e18536 100644 --- a/proxy/src/recv_mmsg.rs +++ b/proxy/src/recv_mmsg.rs @@ -1,7 +1,5 @@ use std::{ cmp, - collections::VecDeque, - hint::spin_loop, io, mem::{self, zeroed, MaybeUninit}, net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, UdpSocket}, @@ -15,26 +13,11 @@ use itertools::izip; use libc::{iovec, mmsghdr, msghdr, sockaddr_storage, AF_INET, AF_INET6, MSG_WAITFORONE}; use log::{error, trace}; use socket2::socklen_t; -use solana_ledger::shred::ShredId; use solana_perf::packet::{NUM_RCVMMSGS, PACKETS_PER_BATCH}; -use solana_sdk::packet::{Meta, Packet, PACKET_DATA_SIZE}; -use solana_streamer::{recvmmsg::recv_mmsg, streamer::StreamerReceiveStats}; +use solana_sdk::packet::{Meta, PACKET_DATA_SIZE}; +use solana_streamer::{streamer::StreamerReceiveStats}; -use 
crate::mem::{try_alloc_shared_mem, FrameBuf, FrameBufMut, FrameDesc, Rx, SharedMem, Tx}; - -const OFFSET_SHRED_TYPE: usize = 82; -const OFFSET_DATA_PARENT: usize = 83; // 83 + 0 -const OFFSET_DATA_INDEX: usize = 83 - 15; // Index is actually in common header -const OFFSET_CODING_POSITION: usize = 83 + 2; - -// Shred types based on Solana spec -const SHRED_TYPE_DATA: u8 = 0b1010_0101; -const SHRED_TYPE_CODING: u8 = 0b0101_1010; - -pub struct RecvMemConfig { - pub frames_count: usize, - pub hugepages: bool, -} +use crate::{mem::{FrameBuf, FrameBufMut, FrameDesc, Rx, Tx}, prom::{inc_packets_by_source, inc_packets_received, observe_recv_interval, observe_recv_packet_count}}; pub trait PacketRoutingStrategy: Clone { fn route_packet(&self, packet: &TritonPacket, num_dest: usize) -> Option; @@ -109,11 +92,17 @@ where } } + let t = Instant::now(); let result = recv_from(&mut frame_bufmut_vec, socket, coalesce, &mut packet_batch); + let recv_interval = t.elapsed(); + if let Ok(len) = result { if len > 0 { + observe_recv_interval(recv_interval.as_micros() as f64); log::trace!("Received {} packets", len); + inc_packets_received(len as u64); + observe_recv_packet_count(len as f64); let StreamerReceiveStats { packets_count, packet_batches_count, diff --git a/proxy/src/triton_forwarder.rs b/proxy/src/triton_forwarder.rs index f9d35f2..18c63f6 100644 --- a/proxy/src/triton_forwarder.rs +++ b/proxy/src/triton_forwarder.rs @@ -1,49 +1,37 @@ use std::{ - collections::{HashSet, VecDeque}, + collections::VecDeque, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, str::FromStr, sync::{ - atomic::{AtomicBool, AtomicU64, Ordering}, - Arc, RwLock, + atomic::{AtomicBool, Ordering}, + Arc, }, - thread::{Builder, JoinHandle}, - time::{Duration, Instant, SystemTime}, + thread::JoinHandle, + time::{Duration, Instant}, }; use arc_swap::ArcSwap; use bytes::Buf; -use crossbeam_channel::{Receiver, RecvError, Sender}; -use dashmap::DashMap; +use crossbeam_channel::{Receiver, Sender}; use itertools::{izip, Itertools}; -use jito_protos::shredstream::{Entry as PbEntry, TraceShred}; use libc; use log::{debug, error, info, warn}; -use prost::Message; -use socket2::{Domain, Protocol, Socket, Type}; -use solana_client::client_error::reqwest; -use solana_metrics::{datapoint_info, datapoint_warn}; use solana_net_utils::SocketConfig; -use solana_perf::{ - deduper::Deduper, - packet::{PacketBatch, PacketBatchRecycler}, - recycler::Recycler, -}; -use solana_sdk::exit; +use solana_perf::deduper::Deduper; use solana_streamer::{ sendmmsg::{batch_send, SendPktsError}, - streamer::{self, StreamerReceiveStats}, + streamer::{StreamerReceiveStats}, }; use crate::{ forwarder::{try_create_ipv6_socket, ShredMetrics}, mem::{FrameBuf, FrameDesc, Rx, SharedMem, Tx}, prom::{ - inc_packets_by_source, inc_packets_deduped, inc_packets_forward_failed, - inc_packets_forwarded, inc_packets_received, observe_dedup_time, observe_recv_interval, - observe_recv_packet_count, observe_send_duration, observe_send_packet_count, + inc_packets_deduped, inc_packets_forward_failed, + observe_dedup_time, + observe_send_duration, observe_send_packet_count, }, recv_mmsg::{PacketRoutingStrategy, TritonPacket}, - resolve_hostname_port, ShredstreamProxyError, }; // values copied from https://github.com/solana-labs/solana/blob/33bde55bbdde13003acf45bb6afe6db4ab599ae4/core/src/sigverify_shreds.rs#L20 @@ -253,14 +241,18 @@ fn packet_fwd_tile( }; let data_size = packet.meta.size; let data_slice = &packet.buffer.chunk()[..data_size]; + let t = Instant::now(); if 
deduper.dedup(data_slice) { // put it inside the recycle queue debug!("Deduped packet from {}", packet.meta.addr); let desc = packet.buffer.into_inner(); recycled_frames.push(desc); + inc_packets_deduped(1); } else { queued.push_back(packet); } + let dedup_duration = t.elapsed(); + observe_dedup_time(dedup_duration.as_micros() as f64); } let dests = hot_dest_vec.load(); @@ -313,6 +305,8 @@ fn packet_fwd_tile( let send_duration = batch_send_ts.elapsed(); stats.batch_send_time_spent.fetch_add(send_duration.as_micros() as u64, Ordering::Relaxed); stats.send_batch_count.fetch_add(1, Ordering::Relaxed); + observe_send_duration(send_duration.as_micros() as f64); + observe_send_packet_count(next_batch_send.len() as f64); } Err(SendPktsError::IoError(err, num_failed)) => { error!( @@ -320,9 +314,12 @@ fn packet_fwd_tile( {num_failed} packets failed. Error: {err}", next_batch_send.len() ); + inc_packets_forward_failed(num_failed as u64); } } + next_batch_send.clear(); + // Recycle all used frames while let Some(desc) = recycled_frames.pop() { let fill_ring_idx = shmem_info_vec @@ -343,14 +340,6 @@ fn packet_fwd_tile( }) } -#[derive(thiserror::Error, Debug)] -pub enum ProxySystemError { - #[error(transparent)] - IoError(std::io::Error), - #[error(transparent)] - AllocError(crate::mem::AllocError), -} - pub fn run_proxy_system( pkt_recv_tile_mem_config: PktRecvTileMemConfig, dest_addr_vec: Arc>>, From 64326d0b20ade926409ec20edc5163626edbb341 Mon Sep 17 00:00:00 2001 From: Louis-Vincent Boudreault Date: Thu, 15 Jan 2026 02:53:52 +0000 Subject: [PATCH 08/13] support triton multicast --- proxy/src/main.rs | 1 + proxy/src/main2.rs | 71 +++--- proxy/src/recv_mmsg.rs | 13 +- proxy/src/triton_forwarder.rs | 49 +++-- proxy/src/triton_multicast_config.rs | 315 +++++++++++++++++++++++++++ setup_net.sh | 5 +- 6 files changed, 384 insertions(+), 70 deletions(-) create mode 100644 proxy/src/triton_multicast_config.rs diff --git a/proxy/src/main.rs b/proxy/src/main.rs index a6c71ff..487db53 100644 --- a/proxy/src/main.rs +++ b/proxy/src/main.rs @@ -28,6 +28,7 @@ mod deshred; pub mod forwarder; mod heartbeat; mod multicast_config; +mod triton_multicast_config; mod server; mod token_authenticator; mod prom; diff --git a/proxy/src/main2.rs b/proxy/src/main2.rs index d28cfa1..80c7736 100644 --- a/proxy/src/main2.rs +++ b/proxy/src/main2.rs @@ -20,12 +20,12 @@ use tokio::runtime::Runtime; use tonic::Status; use crate::{ - forwarder::ShredMetrics, multicast_config::{TritonMulticastConfig, TritonMulticastConfigV4, TritonMulticastConfigV6, create_multicast_sockets_triton}, recv_mmsg::FECSetRoutingStrategy, token_authenticator::BlockEngineConnectionError, triton_forwarder::PktRecvTileMemConfig, + forwarder::ShredMetrics, triton_multicast_config::{TritonMulticastConfig, TritonMulticastConfigV4, TritonMulticastConfigV6, create_multicast_sockets_triton}, recv_mmsg::FECSetRoutingStrategy, token_authenticator::BlockEngineConnectionError, triton_forwarder::PktRecvTileMemConfig, }; pub mod deshred; pub mod forwarder; pub mod heartbeat; -pub mod multicast_config; +pub mod triton_multicast_config; pub mod server; pub mod token_authenticator; pub mod prom; @@ -289,47 +289,31 @@ fn main() -> Result<(), ShredstreamProxyError> { let use_discovery_service = args.endpoint_discovery_url.is_some() && args.discovered_endpoints_port.is_some(); - // let maybe_triton_multicast_config = match args.triton_multicast_group { - // Some(multicast_group) => { - // match multicast_group { - // IpAddr::V4(ipv4) => { - // 
Some(TritonMulticastConfig::Ipv4(TritonMulticastConfigV4 { - // multicast_ip: ipv4, - // bind_ifname: args.triton_multicast_bind_interface, - // })) - // } - // IpAddr::V6(ipv6) => { - // Some(TritonMulticastConfig::Ipv6(TritonMulticastConfigV6 { - // multicast_ip: ipv6, - // device_ifname: args.triton_multicast_bind_interface - // .ok_or_else(|| { - // io::Error::new( - // ErrorKind::InvalidInput, - // "triton-multicast-bind-interface is required for IPv6", - // ) - // })?, - // })) - // } - // } - // } - // None => None, - // }; - - // let maybe_triton_multicast_socket = maybe_triton_multicast_config - // .and_then(|config| { - // let num_threads = NonZeroUsize::new(args.triton_multicast_num_threads) - // .ok_or_else(|| { - // io::Error::new( - // ErrorKind::InvalidInput, - // "triton-multicast-num-threads must be non-zero", - // ) - // }).ok()?; - // Some( - // create_multicast_sockets_triton(&config, num_threads) - // .map(|ok| (config.ip(), ok)) - // ) - // }) - // .transpose()?; + let maybe_triton_multicast_config = match args.triton_multicast_group { + Some(multicast_group) => { + match multicast_group { + IpAddr::V4(ipv4) => { + Some(TritonMulticastConfig::Ipv4(TritonMulticastConfigV4 { + multicast_ip: ipv4, + bind_ifname: args.triton_multicast_bind_interface, + })) + } + IpAddr::V6(ipv6) => { + Some(TritonMulticastConfig::Ipv6(TritonMulticastConfigV6 { + multicast_ip: ipv6, + device_ifname: args.triton_multicast_bind_interface + .ok_or_else(|| { + io::Error::new( + ErrorKind::InvalidInput, + "triton-multicast-bind-interface is required for IPv6", + ) + })?, + })) + } + } + } + None => None, + }; let pkt_recv_tile_mem_config = PktRecvTileMemConfig { memory_size: args.pkt_recv_channel_memsize.unwrap_or_default(), @@ -346,6 +330,7 @@ fn main() -> Result<(), ShredstreamProxyError> { triton_forwarder::run_proxy_system( pkt_recv_tile_mem_config, unioned_dest_sockets, + maybe_triton_multicast_config, args.src_bind_addr, args.src_bind_port, args.num_pkt_recv_tile.map(|x| x.get()).unwrap_or(1), diff --git a/proxy/src/recv_mmsg.rs b/proxy/src/recv_mmsg.rs index 2e18536..e1bcab4 100644 --- a/proxy/src/recv_mmsg.rs +++ b/proxy/src/recv_mmsg.rs @@ -72,7 +72,6 @@ where // Refill the frame buffers as much as we can, 'fill_bufmut: while frame_bufmut_vec.len() < PACKETS_PER_BATCH { - log::trace!(" Refilling frame buffers {}", frame_bufmut_vec.len()); let maybe_frame_buf = fill_rx.try_recv(); match maybe_frame_buf { Some(frame_desc) => { @@ -91,6 +90,9 @@ where } } } + + + log::trace!("frame bufmut_vec length: {}", frame_bufmut_vec.len()); let t = Instant::now(); let result = recv_from(&mut frame_bufmut_vec, socket, coalesce, &mut packet_batch); @@ -98,6 +100,7 @@ where let recv_interval = t.elapsed(); if let Ok(len) = result { + log::trace!("recv_from got {} packets in {:?}", len, recv_interval); if len > 0 { observe_recv_interval(recv_interval.as_micros() as f64); log::trace!("Received {} packets", len); @@ -160,6 +163,7 @@ pub fn recv_from( let mut i = 0; loop { + log::trace!("Preparing to receive packets, currently have {} packets", i); match triton_recv_mmsg(socket, available_frame_buf_vec, batch) { Err(_) if i > 0 => { if start.elapsed() > max_wait { @@ -269,6 +273,7 @@ pub fn triton_recv_mmsg( tv_sec: 1, tv_nsec: 0, }; + log::trace!("Calling recvmmsg for up to {} packets", count); // TODO: remove .try_into().unwrap() once rust libc fixes recvmmsg types for musl #[allow(clippy::useless_conversion)] let nrecv = unsafe { @@ -280,6 +285,7 @@ pub fn triton_recv_mmsg( &mut ts, ) }; + 
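// Illustration (not part of this patch): recvmmsg takes its timeout as a
// libc::timespec, so the fixed `ts` above corresponds to a one-second bound.
// A Duration-based variant, mirroring the conversion Rx::recv_timeout already
// performs for its futex wait; the helper name is an assumption here.
fn to_timespec(d: std::time::Duration) -> libc::timespec {
    libc::timespec {
        tv_sec: d.as_secs() as libc::time_t,
        tv_nsec: d.subsec_nanos() as libc::c_long,
    }
}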
trace!("recvmmsg returned nrecv={}", nrecv); let nrecv = if nrecv < 0 { // On error, return all in-flight frame buffers back to the caller for i in 0..frame_buffer_inflight_cnt { @@ -290,8 +296,8 @@ pub fn triton_recv_mmsg( } else { usize::try_from(nrecv).unwrap() }; - for (i, addr, hdr, filled_bufmut) in - izip!(0..nrecv, addrs, hdrs, frame_buffer_inflight_vec).take(nrecv) + for (addr, hdr, filled_bufmut) in + izip!(addrs, hdrs, frame_buffer_inflight_vec).take(nrecv) { // SAFETY: We initialized `count` elements of `hdrs` above. `count` is // passed to recvmmsg() as the limit of messages that can be read. So, @@ -309,6 +315,7 @@ pub fn triton_recv_mmsg( }; pkt.meta_mut().size = hdr_ref.msg_len as usize; if let Some(addr) = cast_socket_addr(addr_ref, hdr_ref) { + log::trace!("Received packet from {}", addr); pkt.meta_mut().set_socket_addr(&addr); } packets.push(pkt); diff --git a/proxy/src/triton_forwarder.rs b/proxy/src/triton_forwarder.rs index 18c63f6..772c7a1 100644 --- a/proxy/src/triton_forwarder.rs +++ b/proxy/src/triton_forwarder.rs @@ -1,13 +1,7 @@ use std::{ - collections::VecDeque, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, - str::FromStr, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - thread::JoinHandle, - time::{Duration, Instant}, + collections::VecDeque, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, num::NonZeroUsize, str::FromStr, sync::{ + Arc, atomic::{AtomicBool, Ordering} + }, thread::JoinHandle, time::{Duration, Instant} }; use arc_swap::ArcSwap; @@ -24,14 +18,11 @@ use solana_streamer::{ }; use crate::{ - forwarder::{try_create_ipv6_socket, ShredMetrics}, - mem::{FrameBuf, FrameDesc, Rx, SharedMem, Tx}, - prom::{ + forwarder::{ShredMetrics, try_create_ipv6_socket}, mem::{FrameBuf, FrameDesc, Rx, SharedMem, Tx}, triton_multicast_config::TritonMulticastConfig, prom::{ inc_packets_deduped, inc_packets_forward_failed, observe_dedup_time, observe_send_duration, observe_send_packet_count, - }, - recv_mmsg::{PacketRoutingStrategy, TritonPacket}, + }, recv_mmsg::{PacketRoutingStrategy, TritonPacket} }; // values copied from https://github.com/solana-labs/solana/blob/33bde55bbdde13003acf45bb6afe6db4ab599ae4/core/src/sigverify_shreds.rs#L20 @@ -343,6 +334,7 @@ fn packet_fwd_tile( pub fn run_proxy_system( pkt_recv_tile_mem_config: PktRecvTileMemConfig, dest_addr_vec: Arc>>, + multticast_config: Option, src_ip: IpAddr, src_port: u16, num_pkt_recv_tiles: usize, @@ -356,15 +348,26 @@ pub fn run_proxy_system( { let mut tile_thread_vec: Vec> = Vec::new(); // Build pkt_recv sockets - let (_port, pkt_recv_sk_vec) = solana_net_utils::multi_bind_in_range_with_config( - src_ip, - (src_port, src_port + 1), - SocketConfig::default().reuseport(true), - num_pkt_recv_tiles, - ) - .unwrap_or_else(|_| { - panic!("Failed to bind listener sockets. Check that port {src_port} is not in use.") - }); + let pkt_recv_sk_vec = if let Some(multicast_config) = multticast_config { + log::info!("Using Triton multicast configuration for pkt_recv tiles"); + crate::triton_multicast_config::create_multicast_sockets_triton( + &multicast_config, + NonZeroUsize::new(num_pkt_recv_tiles).expect("num_pkt_recv_tiles must be non-zero"), + src_ip, + src_port, + ).expect("multicast-config") + } else { + let (_port, pkt_recv_sk_vec) = solana_net_utils::multi_bind_in_range_with_config( + src_ip, + (src_port, src_port + 1), + SocketConfig::default().reuseport(true), + num_pkt_recv_tiles, + ) + .unwrap_or_else(|_| { + panic!("Failed to bind listener sockets. 
Check that port {src_port} is not in use.") + }); + pkt_recv_sk_vec + }; let num_frames = pkt_recv_tile_mem_config.memory_size as usize / pkt_recv_tile_mem_config.frame_size; diff --git a/proxy/src/triton_multicast_config.rs b/proxy/src/triton_multicast_config.rs new file mode 100644 index 0000000..2f63e4e --- /dev/null +++ b/proxy/src/triton_multicast_config.rs @@ -0,0 +1,315 @@ +use std::{ + io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, num::NonZeroUsize, process::Command +}; + +use serde::Deserialize; +use socket2::{Domain, Protocol, SockAddr, Socket, Type}; + +fn run_ip_json(args: &[&str]) -> io::Result> { + let output = Command::new("ip").args(args).output()?; + if output.status.success() { + Ok(output.stdout) + } else { + Err(io::Error::new( + io::ErrorKind::Other, + format!( + "`ip {}` failed with status {}", + args.join(" "), + output.status + ), + )) + } +} + +/// Parse multicast groups routed to `device` via `ip --json route show dev ` +pub fn get_ip_route_for_device(device: &str) -> io::Result> { + let stdout = run_ip_json(&["--json", "route", "show", "dev", device])?; + parse_ip_route_for_device(&stdout) +} + +// Pure JSON parsers for unit testing +pub fn parse_ip_route_for_device(bytes: &[u8]) -> io::Result> { + #[derive(Debug, Deserialize)] + struct RouteRow { + dst: String, + } + + let mut groups = serde_json::from_slice::>(bytes) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))? + .into_iter() + .filter_map(|r| { + if let Some((base, mask_str)) = r.dst.split_once('/') { + let ip: IpAddr = base.parse().ok()?; + let mask: u8 = mask_str.parse().ok()?; + let is_exact = match ip { + IpAddr::V4(_) => mask == 32, // check if full-length mask (not partial) + IpAddr::V6(_) => mask == 128, + }; + (ip.is_multicast() && is_exact).then_some(ip) + } else { + let ip: IpAddr = r.dst.parse().ok()?; + ip.is_multicast().then_some(ip) + } + }) + .collect::>(); + + groups.sort_unstable(); + groups.dedup(); + Ok(groups) +} + +/// Return the primary IPv4 address configured on `device` (if any), via `ip --json addr show`. +pub fn ipv4_addr_for_device(device: &str) -> io::Result> { + let stdout = run_ip_json(&["--json", "addr", "show", "dev", device])?; + parse_ipv4_addr_from_ip_addr_show_json(&stdout) +} + +/// Return the interface index for `device` (if any), via `ip --json link show`. 
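// Illustration (not part of this patch): given input shaped like
// `[{"ifindex":3,"ifname":"lo"}]` (the form the unit tests at the bottom of this file
// use), parse_ifindex_from_ip_link_show_json takes the last row's "ifindex" and yields
// Some(3); an empty array `[]` yields None, and malformed JSON is an InvalidData error.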
+pub fn ifindex_for_device(device: &str) -> io::Result> { + let stdout = run_ip_json(&["--json", "link", "show", "dev", device])?; + parse_ifindex_from_ip_link_show_json(&stdout) +} + +// Pure JSON parsers for unit testing +pub fn parse_ipv4_addr_from_ip_addr_show_json(bytes: &[u8]) -> io::Result> { + #[derive(Debug, Deserialize)] + struct AddrInfo { + family: Option, + local: Option, + } + #[derive(Debug, Deserialize)] + struct IfaceRow { + addr_info: Option>, + } + + let rows: Vec = + serde_json::from_slice(bytes).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + + let ip = rows + .into_iter() + .flat_map(|row| row.addr_info.unwrap_or_default()) + .find_map(|info| { + (info.family.as_deref() == Some("inet")) + .then_some(info.local) + .flatten() + }) + .and_then(|s| s.parse::().ok()); + + Ok(ip) +} + +pub fn parse_ifindex_from_ip_link_show_json(bytes: &[u8]) -> io::Result> { + #[derive(Debug, Deserialize)] + struct LinkRow { + ifindex: Option, + } + + let rows: Vec = + serde_json::from_slice(bytes).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + Ok(rows.into_iter().last().and_then(|r| r.ifindex)) +} + + +pub struct TritonMulticastConfigV4 { + pub multicast_ip: Ipv4Addr, + pub bind_ifname: Option, +} + +pub struct TritonMulticastConfigV6 { + pub multicast_ip: Ipv6Addr, + pub device_ifname: String, +} + +pub enum TritonMulticastConfig { + Ipv4(TritonMulticastConfigV4), + Ipv6(TritonMulticastConfigV6), +} + +impl TritonMulticastConfig { + pub fn ip(&self) -> IpAddr { + match self { + TritonMulticastConfig::Ipv4(cfg) => IpAddr::V4(cfg.multicast_ip), + TritonMulticastConfig::Ipv6(cfg) => IpAddr::V6(cfg.multicast_ip), + } + } +} + +pub fn create_multicast_sockets_triton_v4( + config: &TritonMulticastConfigV4, + num_threads: NonZeroUsize, + src_ip: IpAddr, + src_port: u16, +) -> io::Result> { + let device_ip = match config.bind_ifname.as_ref() { + Some(ifname) => { + ipv4_addr_for_device(ifname)?.ok_or_else(|| + io::Error::new(io::ErrorKind::NotFound, format!("No IPv4 address found for device {ifname}")) + )? 
+ }, + None => Ipv4Addr::UNSPECIFIED, + }; + let IpAddr::V4(src_ip) = src_ip else { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "Source IP must be IPv4")); + }; + // Step 1: Create first socket, port = 0 → random ephemeral port + let first_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; + first_socket.set_reuse_address(true)?; + first_socket.set_reuse_port(true)?; + first_socket.bind(&SockAddr::from(SocketAddr::new(IpAddr::V4(src_ip), src_port)))?; + first_socket.join_multicast_v4(&config.multicast_ip, &device_ip)?; + let local_port = first_socket.local_addr()?.as_socket().unwrap().port(); + + // Step 2: Create N-1 sockets using that same port + let mut sockets = Vec::with_capacity(num_threads.get()); + sockets.push(first_socket.into()); + + for _ in 1..num_threads.get() { + let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + socket.bind(&SockAddr::from(SocketAddr::new( + IpAddr::V4(Ipv4Addr::UNSPECIFIED), + local_port, + )))?; + socket.join_multicast_v4(&config.multicast_ip, &device_ip)?; + sockets.push(socket.into()); + } + + Ok(sockets) +} + +// fn create_multicast_socket_triton_v6( +// config: &TritonMulticastConfigV6, +// num_threads: usize, +// ) -> Result { +// let TritonMulticastConfigV6 { +// multicast_ip, +// device_ifname, +// } = config; + +// let addrv6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); +// let socket = UdpSocket::bind(addrv6)?; +// let ifindex = ifindex_for_device(device_ifname)? +// .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, format!("No such device {device_ifname}")))?; +// socket.join_multicast_v6(multicast_ip, ifindex)?; +// Ok(socket) +// } + + +pub fn create_multicast_sockets_triton_v6( + config: &TritonMulticastConfigV6, + num_threads: NonZeroUsize, + src_ip: IpAddr, + src_port: u16, +) -> io::Result> { + // Get the interface index for the device name (e.g. "eth0") + let ifindex = ifindex_for_device(&config.device_ifname)? 
+ .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, format!("No such device {}", config.device_ifname)))?; + + let IpAddr::V6(_src_ip) = src_ip else { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "Source IP must be IPv6")); + }; + // Step 1: Bind first socket to port 0 to let kernel choose a random available port + let first_socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + first_socket.set_only_v6(true)?; // IPv6-only + first_socket.set_reuse_address(true)?; + first_socket.set_reuse_port(true)?; + first_socket.bind(&SockAddr::from(SocketAddr::new(IpAddr::V6(_src_ip), src_port)))?; + first_socket.join_multicast_v6(&config.multicast_ip, ifindex)?; + let local_port = first_socket.local_addr()?.as_socket().unwrap().port(); + // Step 2: Create N-1 additional sockets on the same port for load balancing + let mut sockets = Vec::with_capacity(num_threads.get()); + sockets.push(first_socket.into()); + + for _ in 1..num_threads.get() { + let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_only_v6(true)?; + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + socket.bind(&SockAddr::from(SocketAddr::new( + IpAddr::V6(Ipv6Addr::UNSPECIFIED), + local_port, + )))?; + socket.join_multicast_v6(&config.multicast_ip, ifindex)?; + sockets.push(socket.into()); + } + + Ok(sockets) +} + +pub fn create_multicast_sockets_triton( + config: &TritonMulticastConfig, + num_threads: NonZeroUsize, + src_ip: IpAddr, + src_port: u16, +) -> Result, io::Error> { + + match config { + TritonMulticastConfig::Ipv4(cfg) => { + create_multicast_sockets_triton_v4(cfg, num_threads, src_ip, src_port) + } + TritonMulticastConfig::Ipv6(cfg) => { + create_multicast_sockets_triton_v6(cfg, num_threads, src_ip, src_port) + } + } +} + + +#[cfg(test)] +mod tests { + use std::net::Ipv4Addr; + + use super::{ + parse_ifindex_from_ip_link_show_json, parse_ip_route_for_device, + parse_ipv4_addr_from_ip_addr_show_json, + }; + + #[test] + fn parse_ip_route_for_device_test() { + let json = r#"[{"dst":"169.254.2.112/31","protocol":"kernel","scope":"link","prefsrc":"169.254.2.113","flags":[]},{"dst":"233.84.178.2","gateway":"169.254.2.112","protocol":"static","flags":[]}]"#; + let parsed = parse_ip_route_for_device(json.as_bytes()).unwrap(); + assert_eq!(parsed, vec![Ipv4Addr::new(233, 84, 178, 2)]); + } + + #[test] + fn parse_ipv4_addr_from_addr() { + let json = r#"[ + {"addr_info":[ + {"family":"inet6","local":"fe80::1234"}, + {"family":"inet","local":"192.168.1.10"} + ]} + ]"#; + let parsed = parse_ipv4_addr_from_ip_addr_show_json(json.as_bytes()).unwrap(); + assert_eq!(parsed, Some(Ipv4Addr::new(192, 168, 1, 10))); + } + + #[test] + fn parse_ipv4_addr_from_addr_show_malformed() { + let json = r#"{"not":"an array"}"#; + let res = parse_ipv4_addr_from_ip_addr_show_json(json.as_bytes()); + assert!(res.is_err()); + } + + #[test] + fn parse_ifindex_from_link_show_present() { + let json = r#"[ + {"ifindex":3,"ifname":"lo"} + ]"#; + let parsed = parse_ifindex_from_ip_link_show_json(json.as_bytes()).unwrap(); + assert_eq!(parsed, Some(3)); + } + + #[test] + fn parse_ifindex_from_link_show_empty() { + let json = r#"[]"#; + let parsed = parse_ifindex_from_ip_link_show_json(json.as_bytes()).unwrap(); + assert_eq!(parsed, None); + } + + #[test] + fn parse_ifindex_from_link_show_malformed() { + let json = r#"{"ifindex":3}"#; + let res = parse_ifindex_from_ip_link_show_json(json.as_bytes()); + assert!(res.is_err()); + } +} diff --git a/setup_net.sh b/setup_net.sh index 
7540ec1..da732ae 100755 --- a/setup_net.sh +++ b/setup_net.sh @@ -1,3 +1,4 @@ +set -e # 1. Create the namespace sudo ip netns add ns1 @@ -19,5 +20,7 @@ sudo ip netns exec ns1 ip link set veth1 up #sudo ethtool -K veth0 gro off #sudo ip netns exec ns1 ethtool -K veth1 gro off +#nping --udp -p 1234 --data-length 500 -c 10 + # 7. Verify connectivity -ping -c 2 10.0.0.11 \ No newline at end of file +ping -c 2 172.31.0.2 \ No newline at end of file From 047e257ccdc21e08523ae6926a6f6bc3f590e393 Mon Sep 17 00:00:00 2001 From: Louis-Vincent Boudreault Date: Sat, 17 Jan 2026 21:18:42 +0000 Subject: [PATCH 09/13] fixed many errors --- Cargo.toml | 4 + data-sample.txt | 75 +++++++++++++ proxy/Cargo.toml | 1 + proxy/src/main2.rs | 21 ++-- proxy/src/mem.rs | 33 +++++- proxy/src/recv_mmsg.rs | 120 ++++++++++++--------- proxy/src/triton_forwarder.rs | 154 ++++++++++++++++++++------- proxy/src/triton_multicast_config.rs | 2 + run-triton-proxy.sh | 13 +++ 9 files changed, 326 insertions(+), 97 deletions(-) create mode 100644 data-sample.txt create mode 100644 run-triton-proxy.sh diff --git a/Cargo.toml b/Cargo.toml index 54f0dc7..ef239a7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,10 @@ members = ["examples", "jito_protos", "proxy"] resolver = "2" +[profile.debug-release] +inherits = "release" +debug = true + [workspace.package] version = "0.2.12-triton" description = "Fast path to receive shreds from Jito, forwarding to local consumers. See https://docs.jito.wtf/lowlatencytxnfeed/ for details." diff --git a/data-sample.txt b/data-sample.txt new file mode 100644 index 0000000..568c811 --- /dev/null +++ b/data-sample.txt @@ -0,0 +1,75 @@ +set 1: + +shredstream_recv_interval_usec_bucket{le="1"} 41 + +shredstream_recv_interval_usec_bucket{le="5"} 63 + +shredstream_recv_interval_usec_bucket{le="10"} 125 + +shredstream_recv_interval_usec_bucket{le="25"} 47676 + +shredstream_recv_interval_usec_bucket{le="50"} 104430 + +shredstream_recv_interval_usec_bucket{le="100"} 162673 + +shredstream_recv_interval_usec_bucket{le="200"} 190777 + +shredstream_recv_interval_usec_bucket{le="500"} 205050 + +shredstream_recv_interval_usec_bucket{le="1000"} 210046 + +shredstream_recv_interval_usec_bucket{le="2000"} 212204 + +shredstream_recv_interval_usec_bucket{le="+Inf"} 214080 + + + +set 2: + +shredstream_recv_interval_usec_bucket{le="1"} 0 + +shredstream_recv_interval_usec_bucket{le="5"} 0 + +shredstream_recv_interval_usec_bucket{le="10"} 22 + +shredstream_recv_interval_usec_bucket{le="25"} 864700 + +shredstream_recv_interval_usec_bucket{le="50"} 1059516 + +shredstream_recv_interval_usec_bucket{le="100"} 1334130 + +shredstream_recv_interval_usec_bucket{le="200"} 1473381 + +shredstream_recv_interval_usec_bucket{le="500"} 1545124 + +shredstream_recv_interval_usec_bucket{le="1000"} 1569639 + +shredstream_recv_interval_usec_bucket{le="2000"} 1580383 + +shredstream_recv_interval_usec_bucket{le="+Inf"} 1589948 + + + +set 3 : + +shredstream_recv_interval_usec_bucket{le="1"} 0 + +shredstream_recv_interval_usec_bucket{le="5"} 0 + +shredstream_recv_interval_usec_bucket{le="10"} 2 + +shredstream_recv_interval_usec_bucket{le="25"} 129306 + +shredstream_recv_interval_usec_bucket{le="50"} 159982 + +shredstream_recv_interval_usec_bucket{le="100"} 202752 + +shredstream_recv_interval_usec_bucket{le="200"} 225469 + +shredstream_recv_interval_usec_bucket{le="500"} 238215 + +shredstream_recv_interval_usec_bucket{le="1000"} 242741 + +shredstream_recv_interval_usec_bucket{le="2000"} 244727 + +shredstream_recv_interval_usec_bucket{le="+Inf"} 
246249 \ No newline at end of file diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index eddbb0c..7192d3a 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -15,6 +15,7 @@ name = "jito-shredstream-proxy" path = "src/main.rs" + [dependencies] ahash = { workspace = true } arc-swap = { workspace = true } diff --git a/proxy/src/main2.rs b/proxy/src/main2.rs index 80c7736..25c2f2a 100644 --- a/proxy/src/main2.rs +++ b/proxy/src/main2.rs @@ -148,6 +148,10 @@ struct CommonArgs { /// Memory sizing for EACH packet receiver, uses t-shirt size convention (xs (default),s,m,l,xl,2xl,3xl,4xl,5xl). Each size increase double the memory, starting at 128MiB for x-small. #[arg(long, env)] pkt_recv_channel_memsize: Option, + + /// Use hugepage memory for pkt recv tiles shared memory. + #[arg(long, env, default_value_t = false)] + hugepage: bool, } #[derive(Debug, Error)] @@ -291,6 +295,7 @@ fn main() -> Result<(), ShredstreamProxyError> { let maybe_triton_multicast_config = match args.triton_multicast_group { Some(multicast_group) => { + log::info!("Using triton multicast group: {}", multicast_group); match multicast_group { IpAddr::V4(ipv4) => { Some(TritonMulticastConfig::Ipv4(TritonMulticastConfigV4 { @@ -317,6 +322,7 @@ fn main() -> Result<(), ShredstreamProxyError> { let pkt_recv_tile_mem_config = PktRecvTileMemConfig { memory_size: args.pkt_recv_channel_memsize.unwrap_or_default(), + hugepage: args.hugepage, ..Default::default() }; let proxy_th = { @@ -357,14 +363,13 @@ fn main() -> Result<(), ShredstreamProxyError> { }; thread_handles.push(report_metrics_thread); - // let metrics_hdl = forwarder::start_forwarder_accessory_thread( - // deduper, - // metrics.clone(), - // args.metrics_report_interval_ms, - // shutdown_receiver.clone(), - // exit.clone(), - // ); - // thread_handles.push(metrics_hdl); + let metrics_hdl = triton_forwarder::start_forwarder_accessory_thread( + metrics.clone(), + args.metrics_report_interval_ms, + shutdown_receiver.clone(), + exit.clone(), + ); + thread_handles.push(metrics_hdl); if use_discovery_service { let refresh_handle = forwarder::start_destination_refresh_thread( args.endpoint_discovery_url.unwrap(), diff --git a/proxy/src/mem.rs b/proxy/src/mem.rs index 765f9ec..fed8a7c 100644 --- a/proxy/src/mem.rs +++ b/proxy/src/mem.rs @@ -209,6 +209,18 @@ impl FrameBufMut { pub unsafe fn as_mut_ptr(&self) -> *mut u8 { self.ptr } + + #[inline] + pub unsafe fn seek(&mut self, offset: usize) { + assert!( + offset < self.desc.frame_size, + "seek offset out of bounds" + ); + let new_ptr = self.desc.ptr.add(offset); + let end_ptr = self.end_ptr(); + assert!(new_ptr as *const u8 <= end_ptr, "seek out of bounds"); + self.ptr = new_ptr; + } } unsafe impl BufMut for FrameBufMut { @@ -254,7 +266,7 @@ impl Buf for FrameBuf { use std::{ptr, sync::atomic::AtomicBool}; // We wrap T to include a 'ready' flag for each slot -#[repr(C, align(64))] +#[repr(C)] struct Slot { data: std::mem::MaybeUninit, is_ready: AtomicBool, @@ -313,10 +325,10 @@ pub struct Rx { pub fn message_ring(capacity: usize) -> Result<(Tx, Rx), AllocError> { let capacity = capacity.next_power_of_two(); - let align = std::mem::size_of::>(); + let size = std::mem::size_of::>(); // Allocate memory for Slots - let shmem = SharedMem::new(align, capacity, false)?; + let shmem = SharedMem::new(size, capacity, false)?; let ptr = shmem.ptr as *mut Slot; // Initialize the is_ready flags to false for i in 0..capacity { @@ -354,10 +366,21 @@ impl Tx { let head = self.inner.head.load(Ordering::Relaxed); let tail = 
self.inner.tail.load(Ordering::Acquire); - if head.wrapping_sub(tail) >= self.inner.capacity { - return Err(value); // Ring is full + // 2. The Fix: Calculate occupancy with wrapping awareness + let occupancy = head.wrapping_sub(tail); + + // A ring is only full if occupancy is >= capacity. + // We add a check for (usize::MAX / 2) to ignore the "stale head" + // cases where occupancy underflows to a massive number. + if occupancy >= self.inner.capacity && occupancy < (usize::MAX / 2) { + return Err(value); } + // if head.wrapping_sub(tail) >= self.inner.capacity { + // log::error!("Ring is full: head={}, tail={}, capacity={}", head, tail, self.inner.capacity); + // return Err(value); // Ring is full + // } + // 2. Claim a slot using Compare-and-Swap (CAS). // We use SeqCst or AcqRel here to ensure that once we "win" this slot, // we have a synchronized view of the memory. diff --git a/proxy/src/recv_mmsg.rs b/proxy/src/recv_mmsg.rs index e1bcab4..1d0a57a 100644 --- a/proxy/src/recv_mmsg.rs +++ b/proxy/src/recv_mmsg.rs @@ -14,10 +14,10 @@ use libc::{iovec, mmsghdr, msghdr, sockaddr_storage, AF_INET, AF_INET6, MSG_WAIT use log::{error, trace}; use socket2::socklen_t; use solana_perf::packet::{NUM_RCVMMSGS, PACKETS_PER_BATCH}; -use solana_sdk::packet::{Meta, PACKET_DATA_SIZE}; +use solana_sdk::{exit, packet::{Meta, PACKET_DATA_SIZE}}; use solana_streamer::{streamer::StreamerReceiveStats}; -use crate::{mem::{FrameBuf, FrameBufMut, FrameDesc, Rx, Tx}, prom::{inc_packets_by_source, inc_packets_received, observe_recv_interval, observe_recv_packet_count}}; +use crate::{mem::{FrameBuf, FrameBufMut, FrameDesc, Rx, Tx}, prom::{inc_packets_received, observe_recv_interval, observe_recv_packet_count}}; pub trait PacketRoutingStrategy: Clone { fn route_packet(&self, packet: &TritonPacket, num_dest: usize) -> Option; @@ -62,8 +62,19 @@ where { let mut packet_batch = Vec::with_capacity(PACKETS_PER_BATCH); let mut frame_bufmut_vec = Vec::with_capacity(PACKETS_PER_BATCH); - + let mut next_stats_report = Instant::now() + Duration::from_secs(1); + let mut router_dest_dist = vec![0usize; packet_tx_vec.len()]; loop { + + if next_stats_report.elapsed() > Duration::ZERO { + next_stats_report = Instant::now() + Duration::from_secs(1); + log::trace!( + "recv_loop: packets_count={}, packet_batches_count={}, full_packet_batches_count={}", + stats.packets_count.load(Ordering::Relaxed), + stats.packet_batches_count.load(Ordering::Relaxed), + stats.full_packet_batches_count.load(Ordering::Relaxed), + ); + } // Check for exit signal, even if socket is busy // (for instance the leader transaction socket) if exit.load(Ordering::Relaxed) { @@ -81,7 +92,9 @@ where None => { if frame_bufmut_vec.is_empty() { // block until we get at least one frame buffer - let frame_desc = fill_rx.recv(); + let Some(frame_desc) = fill_rx.recv_timeout(Duration::from_millis(10)) else { + break 'fill_bufmut + }; let frame_bufmut = frame_desc.as_mut_buf(); frame_bufmut_vec.push(frame_bufmut); } else { @@ -91,52 +104,62 @@ where } } + if frame_bufmut_vec.is_empty() { + // No available frame buffers to receive into, wait a bit + log::debug!("recv_loop: no available frame buffers to receive into"); + continue; + } log::trace!("frame bufmut_vec length: {}", frame_bufmut_vec.len()); let t = Instant::now(); - let result = recv_from(&mut frame_bufmut_vec, socket, coalesce, &mut packet_batch); - + let result = recv_from(&mut frame_bufmut_vec, socket, coalesce, &mut packet_batch, &exit); let recv_interval = t.elapsed(); - if let Ok(len) = result { - 
log::trace!("recv_from got {} packets in {:?}", len, recv_interval); - if len > 0 { - observe_recv_interval(recv_interval.as_micros() as f64); - log::trace!("Received {} packets", len); - inc_packets_received(len as u64); - observe_recv_packet_count(len as f64); - let StreamerReceiveStats { - packets_count, - packet_batches_count, - full_packet_batches_count, - .. - } = stats; - - packets_count.fetch_add(len, Ordering::Relaxed); - packet_batches_count.fetch_add(1, Ordering::Relaxed); - if len == PACKETS_PER_BATCH { - full_packet_batches_count.fetch_add(1, Ordering::Relaxed); - } - packet_batch - .iter_mut() - .for_each(|p| p.meta_mut().set_from_staked_node(false)); - - for packet in packet_batch.drain(..) { - let dest_idx = match router.route_packet(&packet, packet_tx_vec.len()) { - Some(idx) => idx, - None => { - log::debug!("Failed to route packet {:?}", packet); - let trashed_frame_bufmut = packet.buffer.into_inner().as_mut_buf(); - frame_bufmut_vec.push(trashed_frame_bufmut); - continue; - } - }; - let _ = &packet_tx_vec[dest_idx] - .send(packet) - .expect("Failed to send packet to processor"); + match result { + Ok(len) => { + log::trace!("recv_from got {} packets in {:?}", len, recv_interval); + if len > 0 { + // observe_recv_interval(recv_interval.as_micros() as f64); + log::trace!("Received {} packets", len); + inc_packets_received(len as u64); + observe_recv_packet_count(len as f64); + let StreamerReceiveStats { + packets_count, + packet_batches_count, + full_packet_batches_count, + .. + } = stats; + + packets_count.fetch_add(len, Ordering::Relaxed); + packet_batches_count.fetch_add(1, Ordering::Relaxed); + if len == PACKETS_PER_BATCH { + full_packet_batches_count.fetch_add(1, Ordering::Relaxed); + } + packet_batch + .iter_mut() + .for_each(|p| p.meta_mut().set_from_staked_node(false)); + + for packet in packet_batch.drain(..) 
{ + let dest_idx = match router.route_packet(&packet, packet_tx_vec.len()) { + Some(idx) => idx, + None => { + log::debug!("Failed to route packet {:?}", packet); + let trashed_frame_bufmut = packet.buffer.into_inner().as_mut_buf(); + frame_bufmut_vec.push(trashed_frame_bufmut); + continue; + } + }; + router_dest_dist[dest_idx] += 1; + let _ = &packet_tx_vec[dest_idx] + .send(packet) + .expect(format!("failed to send packet to {dest_idx} ring is full, distr:{:?}", router_dest_dist).as_str()); + } } } + Err(e) => { + error!("recv_from error: {:?}", e); + } } } } @@ -146,6 +169,7 @@ pub fn recv_from( socket: &UdpSocket, max_wait: Duration, batch: &mut Vec, + exit: &AtomicBool, ) -> std::io::Result { // let mut i: usize = 0; //DOCUMENTED SIDE-EFFECT @@ -162,7 +186,7 @@ pub fn recv_from( let mut i = 0; - loop { + while !exit.load(Ordering::Relaxed) { log::trace!("Preparing to receive packets, currently have {} packets", i); match triton_recv_mmsg(socket, available_frame_buf_vec, batch) { Err(_) if i > 0 => { @@ -273,8 +297,8 @@ pub fn triton_recv_mmsg( tv_sec: 1, tv_nsec: 0, }; - log::trace!("Calling recvmmsg for up to {} packets", count); - // TODO: remove .try_into().unwrap() once rust libc fixes recvmmsg types for musl + // TODO: remove .try_into().unwrap() once rust libc fixes recvmmsg types for musl + log::trace!("Calling recvmmsg with count={}", count); #[allow(clippy::useless_conversion)] let nrecv = unsafe { libc::recvmmsg( @@ -285,7 +309,7 @@ pub fn triton_recv_mmsg( &mut ts, ) }; - trace!("recvmmsg returned nrecv={}", nrecv); + log::trace!("recvmmsg returned nrecv={}", nrecv); let nrecv = if nrecv < 0 { // On error, return all in-flight frame buffers back to the caller for i in 0..frame_buffer_inflight_cnt { @@ -307,7 +331,8 @@ pub fn triton_recv_mmsg( // SAFETY: Similar to above, we initialized this `addr` and recvmmsg() // will have populated it let addr_ref = unsafe { addr.assume_init_ref() }; - let filled_bufmut = unsafe { filled_bufmut.assume_init_read() }; + let mut filled_bufmut = unsafe { filled_bufmut.assume_init_read() }; + unsafe { filled_bufmut.seek(hdr_ref.msg_len as usize); } let filled_buf: FrameBuf = filled_bufmut.into(); let mut pkt = TritonPacket { buffer: filled_buf, @@ -315,7 +340,6 @@ pub fn triton_recv_mmsg( }; pkt.meta_mut().size = hdr_ref.msg_len as usize; if let Some(addr) = cast_socket_addr(addr_ref, hdr_ref) { - log::trace!("Received packet from {}", addr); pkt.meta_mut().set_socket_addr(&addr); } packets.push(pkt); diff --git a/proxy/src/triton_forwarder.rs b/proxy/src/triton_forwarder.rs index 772c7a1..b6aa501 100644 --- a/proxy/src/triton_forwarder.rs +++ b/proxy/src/triton_forwarder.rs @@ -1,5 +1,5 @@ use std::{ - collections::VecDeque, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, num::NonZeroUsize, str::FromStr, sync::{ + collections::VecDeque, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, num::NonZeroUsize, os::fd::AsRawFd, str::FromStr, sync::{ Arc, atomic::{AtomicBool, Ordering} }, thread::JoinHandle, time::{Duration, Instant} }; @@ -18,11 +18,9 @@ use solana_streamer::{ }; use crate::{ - forwarder::{ShredMetrics, try_create_ipv6_socket}, mem::{FrameBuf, FrameDesc, Rx, SharedMem, Tx}, triton_multicast_config::TritonMulticastConfig, prom::{ - inc_packets_deduped, inc_packets_forward_failed, - observe_dedup_time, - observe_send_duration, observe_send_packet_count, - }, recv_mmsg::{PacketRoutingStrategy, TritonPacket} + forwarder::{ShredMetrics, try_create_ipv6_socket}, mem::{FrameBuf, FrameDesc, Rx, SharedMem, Tx}, prom::{ + 
inc_packets_deduped, inc_packets_forward_failed, observe_dedup_time, observe_recv_interval, observe_send_duration, observe_send_packet_count + }, recv_mmsg::{PacketRoutingStrategy, TritonPacket}, triton_multicast_config::TritonMulticastConfig }; // values copied from https://github.com/solana-labs/solana/blob/33bde55bbdde13003acf45bb6afe6db4ab599ae4/core/src/sigverify_shreds.rs#L20 @@ -190,6 +188,8 @@ fn packet_fwd_tile( let mut next_batch_send: Vec<(FrameBuf, SocketAddr)> = Vec::with_capacity(UIO_MAXIOV); let mut queued: VecDeque = VecDeque::with_capacity(UIO_MAXIOV); + let mut last_batch_to_send = Instant::now(); + for shmem_info in &shmem_info_vec { assert!( shmem_info.len.is_power_of_two(), @@ -207,6 +207,15 @@ fn packet_fwd_tile( DEDUPER_RESET_CYCLE, ); next_deduper_reset_attempt = Instant::now() + Duration::from_secs(2); + // show stats here... + log::trace!( + "send_batch_count: {}, duplicate: {}, total-pkt-sent: {}, queue-len: {}, to-recycle: {}", + stats.send_batch_count.load(Ordering::Relaxed), + stats.duplicate.load(Ordering::Relaxed), + stats.send_batch_size_sum.load(Ordering::Relaxed), + queued.len(), + recycled_frames.len(), + ); } if queued.is_empty() && recycled_frames.is_empty() && next_batch_send.is_empty() { @@ -215,9 +224,9 @@ fn packet_fwd_tile( let data_slice = &packet.buffer.chunk()[..data_size]; if deduper.dedup(data_slice) { // put it inside the recycle queue - debug!("Deduped packet from {}", packet.meta.addr); let desc = packet.buffer.into_inner(); recycled_frames.push(desc); + stats.duplicate.fetch_add(1, Ordering::Relaxed); } else { queued.push_back(packet); } @@ -235,9 +244,9 @@ fn packet_fwd_tile( let t = Instant::now(); if deduper.dedup(data_slice) { // put it inside the recycle queue - debug!("Deduped packet from {}", packet.meta.addr); let desc = packet.buffer.into_inner(); recycled_frames.push(desc); + stats.duplicate.fetch_add(1, Ordering::Relaxed); inc_packets_deduped(1); } else { queued.push_back(packet); @@ -248,14 +257,13 @@ fn packet_fwd_tile( let dests = hot_dest_vec.load(); let dests_len = dests.len(); - // Fill up the next_batch_send 'fill_batch_send: while next_batch_send.len() < UIO_MAXIOV && queued.len() > 0 - && dests_len > 0 + && recycled_frames.len() < UIO_MAXIOV { let remaining = UIO_MAXIOV - next_batch_send.len(); - if dests_len < remaining { + if dests_len > remaining { break 'fill_batch_send; } @@ -288,27 +296,34 @@ fn packet_fwd_tile( "queued.len() = {}", queued.len() ); - + let batch_send_ts = Instant::now(); - match batch_send(&send_socket, &next_batch_send) { - Ok(_) => { - // Successfully sent all packets in the batch - let send_duration = batch_send_ts.elapsed(); - stats.batch_send_time_spent.fetch_add(send_duration.as_micros() as u64, Ordering::Relaxed); - stats.send_batch_count.fetch_add(1, Ordering::Relaxed); - observe_send_duration(send_duration.as_micros() as f64); - observe_send_packet_count(next_batch_send.len() as f64); - } - Err(SendPktsError::IoError(err, num_failed)) => { - error!( - "Failed to send batch of size {}. \ - {num_failed} packets failed. 
Error: {err}", - next_batch_send.len() - ); - inc_packets_forward_failed(num_failed as u64); + + if !next_batch_send.is_empty() { + let e = last_batch_to_send.elapsed(); + last_batch_to_send = Instant::now(); + + observe_recv_interval(e.as_micros() as f64); + match batch_send(&send_socket, &next_batch_send) { + Ok(_) => { + // Successfully sent all packets in the batch + let send_duration = batch_send_ts.elapsed(); + stats.batch_send_time_spent.fetch_add(send_duration.as_micros() as u64, Ordering::Relaxed); + stats.send_batch_count.fetch_add(1, Ordering::Relaxed); + stats.send_batch_size_sum.fetch_add(next_batch_send.len() as u64, Ordering::Relaxed); + observe_send_duration(send_duration.as_micros() as f64); + observe_send_packet_count(next_batch_send.len() as f64); + } + Err(SendPktsError::IoError(err, num_failed)) => { + error!( + "Failed to send batch of size {}. \ + {num_failed} packets failed. Error: {err}", + next_batch_send.len() + ); + inc_packets_forward_failed(num_failed as u64); + } } } - next_batch_send.clear(); // Recycle all used frames @@ -316,8 +331,9 @@ fn packet_fwd_tile( let fill_ring_idx = shmem_info_vec .iter() .find_position(|shmem_info| { - (desc.ptr as usize) & (shmem_info.len - 1) - == (shmem_info.start_ptr as usize) + let p = desc.ptr as usize; + let start = shmem_info.start_ptr as usize; + p >= start && p < start + shmem_info.len }) .expect("unknown frame desc") .0; @@ -368,6 +384,12 @@ pub fn run_proxy_system( }); pkt_recv_sk_vec }; + assert!(pkt_recv_sk_vec.len() == num_pkt_recv_tiles, "pkt_recv_sk_vec.len() ({}) != num_pkt_recv_tiles ({})", pkt_recv_sk_vec.len(), num_pkt_recv_tiles); + + let mut pkt_recv_sk_raw_fd_vec: Vec = Vec::with_capacity(num_pkt_recv_tiles); + for sk in &pkt_recv_sk_vec { + pkt_recv_sk_raw_fd_vec.push(sk.as_raw_fd()); + } let num_frames = pkt_recv_tile_mem_config.memory_size as usize / pkt_recv_tile_mem_config.frame_size; @@ -380,6 +402,8 @@ pub fn run_proxy_system( let mut shmem_vec: Vec = Vec::with_capacity(num_pkt_recv_tiles); let mut pkt_fwd_sk_vec: Vec = Vec::with_capacity(num_pkt_fwd_tiles); + let mut pkt_fwd_sk_raw_fd_vec: Vec = Vec::with_capacity(num_pkt_fwd_tiles); + let mut packet_rx_vec: Vec> = Vec::with_capacity(num_pkt_fwd_tiles); let mut packet_tx_vec: Vec> = Vec::with_capacity(num_pkt_fwd_tiles); @@ -462,20 +486,28 @@ pub fn run_proxy_system( "Packet forwarder sending socket bound to {}", send_socket.local_addr().unwrap() ); + pkt_fwd_sk_raw_fd_vec.push(send_socket.as_raw_fd()); pkt_fwd_sk_vec.push(send_socket); } + + let pkt_fwd_tile_ring_capacity = num_frames * num_pkt_recv_tiles; + log::info!( + "Setting pkt_fwd tile's message ring capacity to {} (num_frames {} * num_pkt_recv_tiles {})", + pkt_fwd_tile_ring_capacity, + num_frames, + num_pkt_recv_tiles + ); + // Create pkt_fwd message rings // One ring per pkt_fwd tile for _ in 0..num_pkt_fwd_tiles { - let (packet_tx, packet_rx) = crate::mem::message_ring(num_frames).expect("pkt_fwd ring"); + // Worst case scenario all frames from all pkt_recv tiles are sent to this pkt_fwd tile + // We set the ring capacity to that + let (packet_tx, packet_rx) = crate::mem::message_ring(pkt_fwd_tile_ring_capacity).expect("pkt_fwd ring"); packet_tx_vec.push(packet_tx); packet_rx_vec.push(packet_rx); } - log::info!( - "Initialized pkt_fwd message rings with {} slots", - num_frames - ); // Spawn pkt_fwd tiles for (pkt_fwd_idx, pkt_fwd_sk, packet_rx) in izip!( @@ -528,6 +560,7 @@ pub fn run_proxy_system( log::info!("Spawned pkt_recv tile {}", pkt_recv_idx); } + let (kind, idx) = 
tile_wait_group.wait_first(); warn!("Tile of kind {kind:?} with idx {idx} has exited. Shutting down proxy system"); @@ -535,6 +568,22 @@ pub fn run_proxy_system( drop(fill_tx_vec); drop(packet_tx_vec); log::info!("Waiting for {} tile threads to exit", tile_thread_vec.len()); + + + // There is a special case where pkt_recv could be blocked inside recvmmsg syscall if no new packets arrive in a triton_recv_msg iteratoin (when newly set to blocking mode). + // We have to force close the socket to break it out of the syscall. + for sk_fd in pkt_recv_sk_raw_fd_vec { + unsafe { + libc::close(sk_fd); + } + } + + for sk_fd in pkt_fwd_sk_raw_fd_vec { + unsafe { + libc::close(sk_fd); + } + } + for th in tile_thread_vec { let result = th.join(); if let Err(e) = result { @@ -543,6 +592,39 @@ pub fn run_proxy_system( } } + + + +/// Reset dedup + send metrics to influx +pub fn start_forwarder_accessory_thread( + metrics: Arc, + metrics_update_interval_ms: u64, + shutdown_receiver: Receiver<()>, + exit: Arc, +) -> JoinHandle<()> { + std::thread::Builder::new() + .name("ssPxyAccessory".to_string()) + .spawn(move || { + let metrics_tick = + crossbeam_channel::tick(Duration::from_millis(metrics_update_interval_ms)); + while !exit.load(Ordering::Relaxed) { + crossbeam_channel::select! { + // send metrics to influx + recv(metrics_tick) -> _ => { + metrics.report(); + metrics.reset(); + } + + // handle SIGINT shutdown + recv(shutdown_receiver) -> _ => { + break; + } + } + } + }) + .unwrap() +} + // #[cfg(test)] // mod tests { // use std::{ diff --git a/proxy/src/triton_multicast_config.rs b/proxy/src/triton_multicast_config.rs index 2f63e4e..c312bbd 100644 --- a/proxy/src/triton_multicast_config.rs +++ b/proxy/src/triton_multicast_config.rs @@ -148,6 +148,7 @@ pub fn create_multicast_sockets_triton_v4( }, None => Ipv4Addr::UNSPECIFIED, }; + log::info!("multicast device {} has ip {}", config.bind_ifname.as_deref().unwrap_or("unspecified"), device_ip); let IpAddr::V4(src_ip) = src_ip else { return Err(io::Error::new(io::ErrorKind::InvalidInput, "Source IP must be IPv4")); }; @@ -157,6 +158,7 @@ pub fn create_multicast_sockets_triton_v4( first_socket.set_reuse_port(true)?; first_socket.bind(&SockAddr::from(SocketAddr::new(IpAddr::V4(src_ip), src_port)))?; first_socket.join_multicast_v4(&config.multicast_ip, &device_ip)?; + log::info!("Joined multicast group {} on device IP {}", config.multicast_ip, device_ip); let local_port = first_socket.local_addr()?.as_socket().unwrap().port(); // Step 2: Create N-1 sockets using that same port diff --git a/run-triton-proxy.sh b/run-triton-proxy.sh new file mode 100644 index 0000000..9187c0c --- /dev/null +++ b/run-triton-proxy.sh @@ -0,0 +1,13 @@ + +# send random traffic +sudo nping --udp -p 8002 -c 0 --rate 1 --data-length 1200 127.0.0.1 + +while true; do + # Generate 1200 random bytes from /dev/urandom and send them + head -c 1200 /dev/urandom | socat - UDP:127.0.0.1:8002 + sleep 0.01 # Add a small delay between packets +done + + +# run triton-proxy +cargo run --bin triton-proxy -- forward-only --src-bind-addr 127.0.0.1 --src-bind-port 8002 --prometheus-bind-addr 127.0.0.1:9999 --dest-ip-ports 127.0.0.1:8989 \ No newline at end of file From f88fe1b57cb57c97694dfbdd68230d1d9b529765 Mon Sep 17 00:00:00 2001 From: Louis-Vincent Boudreault Date: Mon, 26 Jan 2026 15:11:22 +0000 Subject: [PATCH 10/13] use epoll to listen to concurrent socket per recv tile --- Cargo.lock | 12 +- Cargo.toml | 1 + proxy/Cargo.toml | 1 + proxy/src/main2.rs | 26 ++- proxy/src/recv_mmsg.rs | 242 
++++++++++++++++----------- proxy/src/triton_forwarder.rs | 78 +++++---- proxy/src/triton_multicast_config.rs | 27 ++- 7 files changed, 233 insertions(+), 154 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1d3f992..d035fa5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2824,6 +2824,7 @@ dependencies = [ "lazy_static", "libc", "log", + "mio", "prometheus", "prost 0.13.5", "prost-types 0.13.5", @@ -2936,9 +2937,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.176" +version = "0.2.180" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58f929b4d672ea937a23a1ab494143d968337a5f47e56d0815df1e0890ddf174" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" [[package]] name = "libloading" @@ -3202,13 +3203,14 @@ dependencies = [ [[package]] name = "mio" -version = "1.0.4" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", + "log", "wasi 0.11.1+wasi-snapshot-preview1", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index ef239a7..ca86daa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ jito-protos = { path = "jito_protos" } lazy_static = "1.4.0" libc = "0.2" log = "0.4" +mio = "1.1.1" prost = "0.13" prost-types = "0.13" prometheus = "0.14.0" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 7192d3a..881cb25 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -32,6 +32,7 @@ jito-protos = { workspace = true } lazy_static = { workspace = true } log = { workspace = true } libc = { workspace = true } +mio = { workspace = true } prometheus = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } diff --git a/proxy/src/main2.rs b/proxy/src/main2.rs index 25c2f2a..3e38fb5 100644 --- a/proxy/src/main2.rs +++ b/proxy/src/main2.rs @@ -86,6 +86,21 @@ struct CommonArgs { #[arg(long, env, default_value_t = 20_000)] src_bind_port: u16, + /// Multicast IP to listen for shreds. If none provided, attempts to + /// parse multicast routes for the device specified by `--multicast-device` + /// via `ip --json route show dev `. + #[arg(long, env)] + multicast_bind_ip: Option, + + /// Network device to use for multicast route discovery and interface selection. + /// Example: `eth0`, `en0`, or `doublezero1`. + #[arg(long, env, default_value = "doublezero1")] + multicast_device: String, + + /// Port to receive multicast shreds + #[arg(long, env, default_value_t = 20001)] + multicast_subscribe_port: u16, + /// Static set of IP:Port where Shredstream proxy forwards shreds to, comma separated. /// Eg. `127.0.0.1:8001,10.0.0.1:8001`. // Note: store the original string, so we can do hostname resolution when refreshing destinations @@ -130,8 +145,13 @@ struct CommonArgs { /// If ipv4, then optional (listen on all interfaces if not provided). #[arg(long, env)] triton_multicast_bind_interface: Option, - #[arg(long, env, default_value_t = 1)] - triton_multicast_num_threads: usize, + + /// + /// The port to listen on for triton multicast. + /// If not provided, defaults to 8002. + /// NOTE: this port must match the port used by the triton multicast sender. + #[arg(long, env)] + triton_multicast_port: Option, /// Address to bind prometheus metrics server to. 
If not provided, prometheus server is disabled. #[arg(long, env)] @@ -301,6 +321,7 @@ fn main() -> Result<(), ShredstreamProxyError> { Some(TritonMulticastConfig::Ipv4(TritonMulticastConfigV4 { multicast_ip: ipv4, bind_ifname: args.triton_multicast_bind_interface, + listen_port: args.triton_multicast_port.unwrap_or(8002), })) } IpAddr::V6(ipv6) => { @@ -313,6 +334,7 @@ fn main() -> Result<(), ShredstreamProxyError> { "triton-multicast-bind-interface is required for IPv6", ) })?, + listen_port: args.triton_multicast_port.unwrap_or(8002), })) } } diff --git a/proxy/src/recv_mmsg.rs b/proxy/src/recv_mmsg.rs index 1d0a57a..64c3e5c 100644 --- a/proxy/src/recv_mmsg.rs +++ b/proxy/src/recv_mmsg.rs @@ -10,14 +10,15 @@ use std::{ use bytes::{Buf, BufMut}; use itertools::izip; -use libc::{iovec, mmsghdr, msghdr, sockaddr_storage, AF_INET, AF_INET6, MSG_WAITFORONE}; +use libc::{AF_INET, AF_INET6, MSG_DONTWAIT, iovec, mmsghdr, msghdr, sockaddr_storage}; use log::{error, trace}; +use mio::Poll; use socket2::socklen_t; use solana_perf::packet::{NUM_RCVMMSGS, PACKETS_PER_BATCH}; -use solana_sdk::{exit, packet::{Meta, PACKET_DATA_SIZE}}; +use solana_sdk::packet::{Meta, PACKET_DATA_SIZE}; use solana_streamer::{streamer::StreamerReceiveStats}; -use crate::{mem::{FrameBuf, FrameBufMut, FrameDesc, Rx, Tx}, prom::{inc_packets_received, observe_recv_interval, observe_recv_packet_count}}; +use crate::{mem::{FrameBuf, FrameBufMut, FrameDesc, Rx, Tx}, prom::{inc_packets_received, observe_recv_packet_count}}; pub trait PacketRoutingStrategy: Clone { fn route_packet(&self, packet: &TritonPacket, num_dest: usize) -> Option; @@ -49,10 +50,9 @@ impl PacketRoutingStrategy for FECSetRoutingStrategy { } pub fn recv_loop( - socket: &UdpSocket, + sk_vec: Vec, exit: &AtomicBool, stats: &StreamerReceiveStats, - coalesce: Duration, fill_rx: &mut Rx, packet_tx_vec: &[Tx], router: R, @@ -64,7 +64,31 @@ where let mut frame_bufmut_vec = Vec::with_capacity(PACKETS_PER_BATCH); let mut next_stats_report = Instant::now() + Duration::from_secs(1); let mut router_dest_dist = vec![0usize; packet_tx_vec.len()]; - loop { + let mut poll = Poll::new()?; + let mut events = mio::Events::with_capacity(sk_vec.len()); + + // Initial registration of sockets + for (i, socket) in sk_vec.iter().enumerate() { + poll.registry().register( + &mut mio::net::UdpSocket::from_std(socket.try_clone().unwrap()), + mio::Token(i), + mio::Interest::READABLE, + )?; + } + + while !exit.load(Ordering::Relaxed) { + + // Events are always cleared before receiving new ones + let result = poll.poll(&mut events, Some(Duration::from_millis(100))); + + match result { + Ok(_) => { } + Err(e) => { + if e.kind() != io::ErrorKind::TimedOut { + return Err(e); + } + } + } if next_stats_report.elapsed() > Duration::ZERO { next_stats_report = Instant::now() + Duration::from_secs(1); @@ -80,94 +104,120 @@ where if exit.load(Ordering::Relaxed) { return Ok(()); } - - // Refill the frame buffers as much as we can, - 'fill_bufmut: while frame_bufmut_vec.len() < PACKETS_PER_BATCH { - let maybe_frame_buf = fill_rx.try_recv(); - match maybe_frame_buf { - Some(frame_desc) => { - let frame_bufmut = frame_desc.as_mut_buf(); - frame_bufmut_vec.push(frame_bufmut); - } - None => { - if frame_bufmut_vec.is_empty() { - // block until we get at least one frame buffer - let Some(frame_desc) = fill_rx.recv_timeout(Duration::from_millis(10)) else { - break 'fill_bufmut - }; + // We can't use a for-loop here because we need to be able to drain the readiness of each socket. 
+ // Since each recv_from is bounded by a PACKETS_PER_BATCH, we may need to call recv_from multiple times per socket + // until we get a WouldBlock error. + let mut ev_iter = events.iter(); + let Some(mut ev) = ev_iter.next() else { + continue; + }; + 'drain_readiness_loop: while !exit.load(Ordering::Relaxed) { + + let sk_idx = ev.token().0; + let recv_sk = &sk_vec[sk_idx]; + + // Refill the frame buffers as much as we can, + 'fill_bufmut: while frame_bufmut_vec.len() < PACKETS_PER_BATCH { + let maybe_frame_buf = fill_rx.try_recv(); + match maybe_frame_buf { + Some(frame_desc) => { let frame_bufmut = frame_desc.as_mut_buf(); frame_bufmut_vec.push(frame_bufmut); - } else { - break 'fill_bufmut; + } + None => { + if frame_bufmut_vec.is_empty() { + // block until we get at least one frame buffer + let Some(frame_desc) = fill_rx.recv_timeout(Duration::from_millis(100)) else { + break 'fill_bufmut + }; + let frame_bufmut = frame_desc.as_mut_buf(); + frame_bufmut_vec.push(frame_bufmut); + } else { + break 'fill_bufmut; + } } } } - } - if frame_bufmut_vec.is_empty() { - // No available frame buffers to receive into, wait a bit - log::debug!("recv_loop: no available frame buffers to receive into"); - continue; - } - - log::trace!("frame bufmut_vec length: {}", frame_bufmut_vec.len()); - - let t = Instant::now(); - let result = recv_from(&mut frame_bufmut_vec, socket, coalesce, &mut packet_batch, &exit); - let recv_interval = t.elapsed(); + if frame_bufmut_vec.is_empty() { + // No available frame buffers to receive into, wait a bit + log::debug!("recv_loop: no available frame buffers to receive into"); + continue 'drain_readiness_loop; + } - match result { - Ok(len) => { - log::trace!("recv_from got {} packets in {:?}", len, recv_interval); - if len > 0 { - // observe_recv_interval(recv_interval.as_micros() as f64); - log::trace!("Received {} packets", len); - inc_packets_received(len as u64); - observe_recv_packet_count(len as f64); - let StreamerReceiveStats { - packets_count, - packet_batches_count, - full_packet_batches_count, - .. - } = stats; - - packets_count.fetch_add(len, Ordering::Relaxed); - packet_batches_count.fetch_add(1, Ordering::Relaxed); - if len == PACKETS_PER_BATCH { - full_packet_batches_count.fetch_add(1, Ordering::Relaxed); + log::trace!("frame bufmut_vec length: {}", frame_bufmut_vec.len()); + + let t = Instant::now(); + let result = recv_from(&mut frame_bufmut_vec, recv_sk, &mut packet_batch, &exit); + let recv_interval = t.elapsed(); + + + match result { + Ok(len) => { + log::trace!("recv_from got {} packets in {:?}", len, recv_interval); + if len > 0 { + // observe_recv_interval(recv_interval.as_micros() as f64); + log::trace!("Received {} packets", len); + inc_packets_received(len as u64); + observe_recv_packet_count(len as f64); + let StreamerReceiveStats { + packets_count, + packet_batches_count, + full_packet_batches_count, + .. + } = stats; + + packets_count.fetch_add(len, Ordering::Relaxed); + packet_batches_count.fetch_add(1, Ordering::Relaxed); + if len == PACKETS_PER_BATCH { + full_packet_batches_count.fetch_add(1, Ordering::Relaxed); + } + packet_batch + .iter_mut() + .for_each(|p| p.meta_mut().set_from_staked_node(false)); + + 'packet_drain: for packet in packet_batch.drain(..) 
{ + let dest_idx = match router.route_packet(&packet, packet_tx_vec.len()) { + Some(idx) => idx, + None => { + log::debug!("Failed to route packet {:?}", packet); + let trashed_frame_bufmut = packet.buffer.into_inner().as_mut_buf(); + frame_bufmut_vec.push(trashed_frame_bufmut); + continue 'packet_drain; + } + }; + router_dest_dist[dest_idx] += 1; + let _ = &packet_tx_vec[dest_idx] + .send(packet) + .expect(format!("failed to send packet to {dest_idx} ring is full, distr:{:?}", router_dest_dist).as_str()); + } } - packet_batch - .iter_mut() - .for_each(|p| p.meta_mut().set_from_staked_node(false)); - - for packet in packet_batch.drain(..) { - let dest_idx = match router.route_packet(&packet, packet_tx_vec.len()) { - Some(idx) => idx, + } + Err(e) => { + if e.kind() == io::ErrorKind::WouldBlock { + // Only when we drained all events for this poll iteration, we process the next event or break + match ev_iter.next() { + Some(next_ev) => { + ev = next_ev; + continue 'drain_readiness_loop; + } None => { - log::debug!("Failed to route packet {:?}", packet); - let trashed_frame_bufmut = packet.buffer.into_inner().as_mut_buf(); - frame_bufmut_vec.push(trashed_frame_bufmut); - continue; + break 'drain_readiness_loop; } - }; - router_dest_dist[dest_idx] += 1; - let _ = &packet_tx_vec[dest_idx] - .send(packet) - .expect(format!("failed to send packet to {dest_idx} ring is full, distr:{:?}", router_dest_dist).as_str()); + } + } else { + return Err(e); } } } - Err(e) => { - error!("recv_from error: {:?}", e); - } } } + Ok(()) } pub fn recv_from( available_frame_buf_vec: &mut Vec, socket: &UdpSocket, - max_wait: Duration, batch: &mut Vec, exit: &AtomicBool, ) -> std::io::Result { @@ -178,38 +228,28 @@ pub fn recv_from( // * set the socket to non blocking // * read until it fails // * set it back to blocking before returning - socket.set_nonblocking(false)?; + // socket.set_nonblocking(false)?; trace!("receiving on {}", socket.local_addr().unwrap()); - let start = Instant::now(); - - assert!(batch.capacity() >= PACKETS_PER_BATCH); + let batch_capacity = batch.capacity(); + assert!(batch_capacity >= PACKETS_PER_BATCH); let mut i = 0; while !exit.load(Ordering::Relaxed) { log::trace!("Preparing to receive packets, currently have {} packets", i); - match triton_recv_mmsg(socket, available_frame_buf_vec, batch) { - Err(_) if i > 0 => { - if start.elapsed() > max_wait { - break; - } - } - Err(e) => { - trace!("recv_from err {:?}", e); - return Err(e); - } - Ok(npkts) => { - if i == 0 { - socket.set_nonblocking(true)?; - } - trace!("got {} packets", npkts); - i += npkts; - // Try to batch into big enough buffers - // will cause less re-shuffling later on. - if start.elapsed() > max_wait || i >= PACKETS_PER_BATCH { - break; - } - } + let npkts = triton_recv_mmsg(socket, available_frame_buf_vec, batch)?; + trace!("got {} packets", npkts); + i += npkts; + if available_frame_buf_vec.is_empty() { + break; + } + if batch.len() >= batch_capacity { + break; + } + // Try to batch into big enough buffers + // will cause less re-shuffling later on. 
+ if i >= PACKETS_PER_BATCH { + break; } } Ok(i) @@ -305,7 +345,7 @@ pub fn triton_recv_mmsg( sock_fd, hdrs[0].assume_init_mut(), count as u32, - MSG_WAITFORONE.try_into().unwrap(), + MSG_DONTWAIT.try_into().unwrap(), &mut ts, ) }; diff --git a/proxy/src/triton_forwarder.rs b/proxy/src/triton_forwarder.rs index b6aa501..a2497ba 100644 --- a/proxy/src/triton_forwarder.rs +++ b/proxy/src/triton_forwarder.rs @@ -85,7 +85,7 @@ impl Default for PktRecvTileMemConfig { fn packet_recv_tile( pkt_recv_idx: usize, - pkt_recv_socket: UdpSocket, + pkt_recv_socket_vec: Vec, exit: Arc, forwarder_stats: Arc, mut fill_rx: Rx, @@ -100,10 +100,9 @@ where .name(format!("ssListen{pkt_recv_idx}")) .spawn(move || { crate::recv_mmsg::recv_loop( - &pkt_recv_socket, + pkt_recv_socket_vec, &exit, &forwarder_stats, - Duration::default(), &mut fill_rx, &packet_tx_vec, packet_router, @@ -364,28 +363,37 @@ pub fn run_proxy_system( { let mut tile_thread_vec: Vec> = Vec::new(); // Build pkt_recv sockets - let pkt_recv_sk_vec = if let Some(multicast_config) = multticast_config { + let pkt_recv_multicast_sk_vec = if let Some(multicast_config) = multticast_config { log::info!("Using Triton multicast configuration for pkt_recv tiles"); - crate::triton_multicast_config::create_multicast_sockets_triton( + let vec = crate::triton_multicast_config::create_multicast_sockets_triton( &multicast_config, NonZeroUsize::new(num_pkt_recv_tiles).expect("num_pkt_recv_tiles must be non-zero"), - src_ip, - src_port, - ).expect("multicast-config") + ).expect("multicast-config"); + Some(vec) } else { - let (_port, pkt_recv_sk_vec) = solana_net_utils::multi_bind_in_range_with_config( - src_ip, - (src_port, src_port + 1), - SocketConfig::default().reuseport(true), - num_pkt_recv_tiles, - ) - .unwrap_or_else(|_| { - panic!("Failed to bind listener sockets. Check that port {src_port} is not in use.") - }); - pkt_recv_sk_vec + None }; + let (_port, pkt_recv_sk_vec) = solana_net_utils::multi_bind_in_range_with_config( + src_ip, + (src_port, src_port + 1), + SocketConfig::default().reuseport(true), + num_pkt_recv_tiles, + ) + .unwrap_or_else(|_| { + panic!("Failed to bind listener sockets. 
Check that port {src_port} is not in use.") + }); assert!(pkt_recv_sk_vec.len() == num_pkt_recv_tiles, "pkt_recv_sk_vec.len() ({}) != num_pkt_recv_tiles ({})", pkt_recv_sk_vec.len(), num_pkt_recv_tiles); + if let Some(multicast_sk_vec) = &pkt_recv_multicast_sk_vec { + assert!(multicast_sk_vec.len() == num_pkt_recv_tiles, "multicast_sk_vec.len() ({}) != num_pkt_recv_tiles ({})", multicast_sk_vec.len(), num_pkt_recv_tiles); + } + + // Make sure socket are set to nonblocking + for sk in &pkt_recv_sk_vec { + sk.set_nonblocking(true).expect("pkt_recv_sk nonblocking"); + } + + let mut pkt_recv_sk_raw_fd_vec: Vec = Vec::with_capacity(num_pkt_recv_tiles); for sk in &pkt_recv_sk_vec { pkt_recv_sk_raw_fd_vec.push(sk.as_raw_fd()); @@ -541,13 +549,23 @@ pub fn run_proxy_system( pkt_recv_sk_vec.into_iter(), fill_rx_vec.into_iter() ) { + + let mut recv_pkt_vec = vec![ + pkt_recv_sk + ]; + + if let Some(multicast_sk_vec) = &pkt_recv_multicast_sk_vec { + recv_pkt_vec.push(multicast_sk_vec[pkt_recv_idx].try_clone().expect("multicast sk clone")); + } + + let exit = Arc::clone(&exit); let forwarder_stats = Arc::clone(&pk_recv_stats); let packet_tx_vec_clone = packet_tx_vec.clone(); let pkt_router_clone = pkt_router.clone(); let jh = packet_recv_tile( pkt_recv_idx, - pkt_recv_sk, + recv_pkt_vec, exit, forwarder_stats, fill_rx, @@ -572,17 +590,17 @@ pub fn run_proxy_system( // There is a special case where pkt_recv could be blocked inside recvmmsg syscall if no new packets arrive in a triton_recv_msg iteratoin (when newly set to blocking mode). // We have to force close the socket to break it out of the syscall. - for sk_fd in pkt_recv_sk_raw_fd_vec { - unsafe { - libc::close(sk_fd); - } - } - - for sk_fd in pkt_fwd_sk_raw_fd_vec { - unsafe { - libc::close(sk_fd); - } - } + // for sk_fd in pkt_recv_sk_raw_fd_vec { + // unsafe { + // libc::close(sk_fd); + // } + // } + + // for sk_fd in pkt_fwd_sk_raw_fd_vec { + // unsafe { + // libc::close(sk_fd); + // } + // } for th in tile_thread_vec { let result = th.join(); diff --git a/proxy/src/triton_multicast_config.rs b/proxy/src/triton_multicast_config.rs index c312bbd..70831bd 100644 --- a/proxy/src/triton_multicast_config.rs +++ b/proxy/src/triton_multicast_config.rs @@ -113,11 +113,13 @@ pub fn parse_ifindex_from_ip_link_show_json(bytes: &[u8]) -> io::Result, + pub listen_port: u16, } pub struct TritonMulticastConfigV6 { pub multicast_ip: Ipv6Addr, pub device_ifname: String, + pub listen_port: u16, } pub enum TritonMulticastConfig { @@ -137,8 +139,6 @@ impl TritonMulticastConfig { pub fn create_multicast_sockets_triton_v4( config: &TritonMulticastConfigV4, num_threads: NonZeroUsize, - src_ip: IpAddr, - src_port: u16, ) -> io::Result> { let device_ip = match config.bind_ifname.as_ref() { Some(ifname) => { @@ -148,15 +148,14 @@ pub fn create_multicast_sockets_triton_v4( }, None => Ipv4Addr::UNSPECIFIED, }; + let port = config.listen_port; log::info!("multicast device {} has ip {}", config.bind_ifname.as_deref().unwrap_or("unspecified"), device_ip); - let IpAddr::V4(src_ip) = src_ip else { - return Err(io::Error::new(io::ErrorKind::InvalidInput, "Source IP must be IPv4")); - }; // Step 1: Create first socket, port = 0 → random ephemeral port let first_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; first_socket.set_reuse_address(true)?; first_socket.set_reuse_port(true)?; - first_socket.bind(&SockAddr::from(SocketAddr::new(IpAddr::V4(src_ip), src_port)))?; + first_socket.set_nonblocking(true)?; + 
first_socket.bind(&SockAddr::from(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port)))?; first_socket.join_multicast_v4(&config.multicast_ip, &device_ip)?; log::info!("Joined multicast group {} on device IP {}", config.multicast_ip, device_ip); let local_port = first_socket.local_addr()?.as_socket().unwrap().port(); @@ -173,6 +172,7 @@ pub fn create_multicast_sockets_triton_v4( IpAddr::V4(Ipv4Addr::UNSPECIFIED), local_port, )))?; + socket.set_nonblocking(true)?; socket.join_multicast_v4(&config.multicast_ip, &device_ip)?; sockets.push(socket.into()); } @@ -201,22 +201,18 @@ pub fn create_multicast_sockets_triton_v4( pub fn create_multicast_sockets_triton_v6( config: &TritonMulticastConfigV6, num_threads: NonZeroUsize, - src_ip: IpAddr, - src_port: u16, ) -> io::Result> { // Get the interface index for the device name (e.g. "eth0") let ifindex = ifindex_for_device(&config.device_ifname)? .ok_or_else(|| io::Error::new(io::ErrorKind::NotFound, format!("No such device {}", config.device_ifname)))?; - let IpAddr::V6(_src_ip) = src_ip else { - return Err(io::Error::new(io::ErrorKind::InvalidInput, "Source IP must be IPv6")); - }; // Step 1: Bind first socket to port 0 to let kernel choose a random available port let first_socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; first_socket.set_only_v6(true)?; // IPv6-only first_socket.set_reuse_address(true)?; first_socket.set_reuse_port(true)?; - first_socket.bind(&SockAddr::from(SocketAddr::new(IpAddr::V6(_src_ip), src_port)))?; + first_socket.set_nonblocking(true)?; + first_socket.bind(&SockAddr::from(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), config.listen_port)))?; first_socket.join_multicast_v6(&config.multicast_ip, ifindex)?; let local_port = first_socket.local_addr()?.as_socket().unwrap().port(); // Step 2: Create N-1 additional sockets on the same port for load balancing @@ -232,6 +228,7 @@ pub fn create_multicast_sockets_triton_v6( IpAddr::V6(Ipv6Addr::UNSPECIFIED), local_port, )))?; + socket.set_nonblocking(true)?; socket.join_multicast_v6(&config.multicast_ip, ifindex)?; sockets.push(socket.into()); } @@ -242,16 +239,14 @@ pub fn create_multicast_sockets_triton_v6( pub fn create_multicast_sockets_triton( config: &TritonMulticastConfig, num_threads: NonZeroUsize, - src_ip: IpAddr, - src_port: u16, ) -> Result, io::Error> { match config { TritonMulticastConfig::Ipv4(cfg) => { - create_multicast_sockets_triton_v4(cfg, num_threads, src_ip, src_port) + create_multicast_sockets_triton_v4(cfg, num_threads) } TritonMulticastConfig::Ipv6(cfg) => { - create_multicast_sockets_triton_v6(cfg, num_threads, src_ip, src_port) + create_multicast_sockets_triton_v6(cfg, num_threads) } } } From 7bcb81985e0cd27a2bb88fdae8e59abd6b8280da Mon Sep 17 00:00:00 2001 From: Louis-Vincent Boudreault Date: Mon, 26 Jan 2026 15:36:42 +0000 Subject: [PATCH 11/13] support doublezero --- proxy/src/main2.rs | 27 ++++++- proxy/src/triton_forwarder.rs | 14 ++++ proxy/src/triton_multicast_config.rs | 110 +++++++++++++++++++++++++++ 3 files changed, 147 insertions(+), 4 deletions(-) diff --git a/proxy/src/main2.rs b/proxy/src/main2.rs index 3e38fb5..7b21a12 100644 --- a/proxy/src/main2.rs +++ b/proxy/src/main2.rs @@ -20,7 +20,7 @@ use tokio::runtime::Runtime; use tonic::Status; use crate::{ - forwarder::ShredMetrics, triton_multicast_config::{TritonMulticastConfig, TritonMulticastConfigV4, TritonMulticastConfigV6, create_multicast_sockets_triton}, recv_mmsg::FECSetRoutingStrategy, token_authenticator::BlockEngineConnectionError, 
triton_forwarder::PktRecvTileMemConfig, + forwarder::ShredMetrics, recv_mmsg::FECSetRoutingStrategy, token_authenticator::BlockEngineConnectionError, triton_forwarder::PktRecvTileMemConfig, triton_multicast_config::{TritonMulticastConfig, TritonMulticastConfigV4, TritonMulticastConfigV6, create_multicast_socket_on_device, create_multicast_sockets_triton} }; pub mod deshred; pub mod forwarder; @@ -248,13 +248,22 @@ fn main() -> Result<(), ShredstreamProxyError> { let prom_registry = prometheus::Registry::new(); prom::register_metrics(&prom_registry); let all_args: Args = Args::parse(); - let shredstream_args = all_args.shredstream_args.clone(); // common args let args = match all_args.shredstream_args { ProxySubcommands::Shredstream(x) => x.common_args, ProxySubcommands::ForwardOnly(x) => x, }; + + + let num_pkt_recv_tiles = args.num_pkt_recv_tile + .map(|x| x.get()) + .unwrap_or(args.num_threads.unwrap_or(1)); + + let num_pkt_fwd_tiles = args.num_pkt_fwd_tile + .map(|x| x.get()) + .unwrap_or(args.num_threads.unwrap_or(1)); + set_host_id(hostname::get()?.into_string().unwrap()); if (args.endpoint_discovery_url.is_none() && args.discovered_endpoints_port.is_some()) || (args.endpoint_discovery_url.is_some() && args.discovered_endpoints_port.is_none()) @@ -313,6 +322,14 @@ fn main() -> Result<(), ShredstreamProxyError> { let use_discovery_service = args.endpoint_discovery_url.is_some() && args.discovered_endpoints_port.is_some(); + let (doublezero_v4_sk_vec, doublezero_v6_sk_vec) = create_multicast_socket_on_device( + &args.multicast_device, + args.multicast_subscribe_port, + args.multicast_bind_ip, + num_pkt_recv_tiles, + )?; + + let maybe_triton_multicast_config = match args.triton_multicast_group { Some(multicast_group) => { log::info!("Using triton multicast group: {}", multicast_group); @@ -361,12 +378,14 @@ fn main() -> Result<(), ShredstreamProxyError> { maybe_triton_multicast_config, args.src_bind_addr, args.src_bind_port, - args.num_pkt_recv_tile.map(|x| x.get()).unwrap_or(1), - args.num_pkt_fwd_tile.map(|x| x.get()).unwrap_or(1), + num_pkt_recv_tiles, + num_pkt_fwd_tiles, FECSetRoutingStrategy, exit, pkt_recv_stats, pkt_fwd_stats, + doublezero_v4_sk_vec, + doublezero_v6_sk_vec, ); }) .expect("tritonProxyMain") diff --git a/proxy/src/triton_forwarder.rs b/proxy/src/triton_forwarder.rs index a2497ba..0254498 100644 --- a/proxy/src/triton_forwarder.rs +++ b/proxy/src/triton_forwarder.rs @@ -346,6 +346,7 @@ fn packet_fwd_tile( }) } +#[allow(clippy::too_many_arguments)] pub fn run_proxy_system( pkt_recv_tile_mem_config: PktRecvTileMemConfig, dest_addr_vec: Arc>>, @@ -358,6 +359,8 @@ pub fn run_proxy_system( exit: Arc, pk_recv_stats: Arc, pk_fwd_stats: Arc, + doublezero_v4_sk_vec: Vec, + doublezero_v6_sk_vec: Vec, ) where R: PacketRoutingStrategy + Send + Sync + 'static, { @@ -373,6 +376,10 @@ pub fn run_proxy_system( } else { None }; + + assert!(doublezero_v4_sk_vec.len() <= num_pkt_recv_tiles, "doublezero_v4_sk_vec.len() ({}) > num_pkt_recv_tiles ({})", doublezero_v4_sk_vec.len(), num_pkt_recv_tiles); + assert!(doublezero_v6_sk_vec.len() <= num_pkt_recv_tiles, "doublezero_v6_sk_vec.len() ({}) > num_pkt_recv_tiles ({})", doublezero_v6_sk_vec.len(), num_pkt_recv_tiles); + let (_port, pkt_recv_sk_vec) = solana_net_utils::multi_bind_in_range_with_config( src_ip, (src_port, src_port + 1), @@ -558,6 +565,13 @@ pub fn run_proxy_system( recv_pkt_vec.push(multicast_sk_vec[pkt_recv_idx].try_clone().expect("multicast sk clone")); } + if let Some(doublezero_sk) = 
doublezero_v4_sk_vec.get(pkt_recv_idx) { + recv_pkt_vec.push(doublezero_sk.try_clone().expect("doublezero v4 sk clone")); + } + + if let Some(doublezero_sk) = doublezero_v6_sk_vec.get(pkt_recv_idx) { + recv_pkt_vec.push(doublezero_sk.try_clone().expect("doublezero v6 sk clone")); + } let exit = Arc::clone(&exit); let forwarder_stats = Arc::clone(&pk_recv_stats); diff --git a/proxy/src/triton_multicast_config.rs b/proxy/src/triton_multicast_config.rs index 70831bd..c84b92d 100644 --- a/proxy/src/triton_multicast_config.rs +++ b/proxy/src/triton_multicast_config.rs @@ -2,6 +2,8 @@ use std::{ io, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}, num::NonZeroUsize, process::Command }; +use itertools::{Itertools, Either}; +use log::{debug, warn, info}; use serde::Deserialize; use socket2::{Domain, Protocol, SockAddr, Socket, Type}; @@ -110,6 +112,114 @@ pub fn parse_ifindex_from_ip_link_show_json(bytes: &[u8]) -> io::Result, + num_threads: usize, +) -> std::io::Result<(Vec, Vec)> { + let device_ipv4 = ipv4_addr_for_device(device_name).unwrap_or_else(|e| { + debug!("Failed to resolve IPv4 address for device {device_name}: {e}"); + None + }); + let device_ifindex_v6 = ifindex_for_device(device_name).unwrap_or_else(|e| { + debug!("Failed to resolve ifindex for device {device_name}: {e}"); + None + }); + + let mut multicast_groups: Vec = Vec::new(); + let (groups_v4, groups_v6): (Vec, Vec) = match multicast_ip { + Some(IpAddr::V4(g)) => (vec![g], Vec::new()), + Some(IpAddr::V6(g6)) => (Vec::new(), vec![g6]), + None => match get_ip_route_for_device(device_name) { + Ok(ips) => ips.into_iter().partition_map(|ip| match ip { + IpAddr::V4(v4) => Either::Left(v4), + IpAddr::V6(v6) => Either::Right(v6), + }), + Err(e) => { + debug!("Failed to parse 'ip route list' for {device_name}: {e}"); + (Vec::new(), Vec::new()) + } + }, + }; + + multicast_groups.extend(groups_v4.iter().map(|ipv4| IpAddr::V4(*ipv4))); + multicast_groups.extend(groups_v6.iter().map(|ipv6| IpAddr::V6(*ipv6))); + + if groups_v4.is_empty() && groups_v6.is_empty() { + debug!("No multicast groups found for device {device_name}; skipping multicast listener"); + return Ok((Vec::new(), Vec::new())); + } + + let mut sockets_v4: Vec = Vec::new(); + if !groups_v4.is_empty() { + let addr_v4 = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), multicast_port); + + let first_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; + first_socket.set_reuse_address(true)?; + first_socket.set_reuse_port(true)?; + first_socket.set_nonblocking(true)?; + first_socket.bind(&SockAddr::from(addr_v4))?; + let actual_socket_addr = first_socket.local_addr()?; + + for g in &groups_v4 { + first_socket.join_multicast_v4(g, &device_ipv4.unwrap_or(Ipv4Addr::UNSPECIFIED))?; + info!("Joined IPv4 multicast group {g} port {multicast_port}"); + } + sockets_v4.push(first_socket.into()); + + // Create N-1 additional sockets on the same port for load balancing + for _ in 1..num_threads { + let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + socket.bind(&actual_socket_addr)?; + socket.set_nonblocking(true)?; + for g in &groups_v4 { + socket.join_multicast_v4(g, &device_ipv4.unwrap_or(Ipv4Addr::UNSPECIFIED))?; + } + sockets_v4.push(socket.into()); + } + } + + let mut sockets_v6: Vec = Vec::new(); + if !groups_v6.is_empty() { + let addr_v6 = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), multicast_port); + + let socket = Socket::new(Domain::IPV6, Type::DGRAM, 
Some(Protocol::UDP))?; + socket.set_only_v6(true)?; // IPv6-only + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + socket.set_nonblocking(true)?; + socket.bind(&SockAddr::from(addr_v6))?; + let actual_socket_addr = socket.local_addr()?; + for g in &groups_v6 { + socket.join_multicast_v6(g, device_ifindex_v6.unwrap_or(0))?; + info!("Joined IPv6 multicast group {g} port {multicast_port}"); + } + sockets_v6.push(socket.into()); + + // Create N-1 additional sockets on the same port for load balancing + for _ in 1..num_threads { + let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))?; + socket.set_only_v6(true)?; + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + socket.bind(&actual_socket_addr)?; + socket.set_nonblocking(true)?; + for g in &groups_v6 { + socket.join_multicast_v6(g, device_ifindex_v6.unwrap_or(0))?; + } + sockets_v6.push(socket.into()); + } + } + + Ok((sockets_v4, sockets_v6)) +} + pub struct TritonMulticastConfigV4 { pub multicast_ip: Ipv4Addr, pub bind_ifname: Option, From 488416a57c33fe04b13c51129460d56abc5225f4 Mon Sep 17 00:00:00 2001 From: Louis-Vincent Boudreault Date: Mon, 26 Jan 2026 16:06:16 +0000 Subject: [PATCH 12/13] cleaned trace logs --- proxy/src/recv_mmsg.rs | 18 ++---------------- proxy/src/triton_forwarder.rs | 34 ++++++++++------------------------ 2 files changed, 12 insertions(+), 40 deletions(-) diff --git a/proxy/src/recv_mmsg.rs b/proxy/src/recv_mmsg.rs index 64c3e5c..1129494 100644 --- a/proxy/src/recv_mmsg.rs +++ b/proxy/src/recv_mmsg.rs @@ -11,7 +11,7 @@ use std::{ use bytes::{Buf, BufMut}; use itertools::izip; use libc::{AF_INET, AF_INET6, MSG_DONTWAIT, iovec, mmsghdr, msghdr, sockaddr_storage}; -use log::{error, trace}; +use log::error; use mio::Poll; use socket2::socklen_t; use solana_perf::packet::{NUM_RCVMMSGS, PACKETS_PER_BATCH}; @@ -144,8 +144,6 @@ where log::debug!("recv_loop: no available frame buffers to receive into"); continue 'drain_readiness_loop; } - - log::trace!("frame bufmut_vec length: {}", frame_bufmut_vec.len()); let t = Instant::now(); let result = recv_from(&mut frame_bufmut_vec, recv_sk, &mut packet_batch, &exit); @@ -154,10 +152,8 @@ where match result { Ok(len) => { - log::trace!("recv_from got {} packets in {:?}", len, recv_interval); if len > 0 { // observe_recv_interval(recv_interval.as_micros() as f64); - log::trace!("Received {} packets", len); inc_packets_received(len as u64); observe_recv_packet_count(len as f64); let StreamerReceiveStats { @@ -180,7 +176,7 @@ where let dest_idx = match router.route_packet(&packet, packet_tx_vec.len()) { Some(idx) => idx, None => { - log::debug!("Failed to route packet {:?}", packet); + log::trace!("Failed to route packet {:?}", packet); let trashed_frame_bufmut = packet.buffer.into_inner().as_mut_buf(); frame_bufmut_vec.push(trashed_frame_bufmut); continue 'packet_drain; @@ -229,16 +225,13 @@ pub fn recv_from( // * read until it fails // * set it back to blocking before returning // socket.set_nonblocking(false)?; - trace!("receiving on {}", socket.local_addr().unwrap()); let batch_capacity = batch.capacity(); assert!(batch_capacity >= PACKETS_PER_BATCH); let mut i = 0; while !exit.load(Ordering::Relaxed) { - log::trace!("Preparing to receive packets, currently have {} packets", i); let npkts = triton_recv_mmsg(socket, available_frame_buf_vec, batch)?; - trace!("got {} packets", npkts); i += npkts; if available_frame_buf_vec.is_empty() { break; @@ -289,7 +282,6 @@ pub fn triton_recv_mmsg( // Should never hit this, but bail if 
the caller didn't provide any Packets // to receive into if fill_buffers.is_empty() { - log::trace!("triton_recv_mmsg: no fill buffers to receive into"); return Ok(0); } // Assert that there are no leftovers in packets. @@ -301,10 +293,6 @@ pub fn triton_recv_mmsg( let remaining_packets = packets.capacity() - packets.len(); let sock_fd = sock.as_raw_fd(); let count = cmp::min(iovs.len(), remaining_packets).min(fill_buffers.len()); - log::trace!( - "triton_recv_mmsg: preparing to receive up to {} packets", - count - ); let mut frame_buffer_inflight_vec: [MaybeUninit; NUM_RCVMMSGS] = std::array::from_fn(|_| MaybeUninit::uninit()); @@ -338,7 +326,6 @@ pub fn triton_recv_mmsg( tv_nsec: 0, }; // TODO: remove .try_into().unwrap() once rust libc fixes recvmmsg types for musl - log::trace!("Calling recvmmsg with count={}", count); #[allow(clippy::useless_conversion)] let nrecv = unsafe { libc::recvmmsg( @@ -349,7 +336,6 @@ pub fn triton_recv_mmsg( &mut ts, ) }; - log::trace!("recvmmsg returned nrecv={}", nrecv); let nrecv = if nrecv < 0 { // On error, return all in-flight frame buffers back to the caller for i in 0..frame_buffer_inflight_cnt { diff --git a/proxy/src/triton_forwarder.rs b/proxy/src/triton_forwarder.rs index 0254498..0114dca 100644 --- a/proxy/src/triton_forwarder.rs +++ b/proxy/src/triton_forwarder.rs @@ -207,7 +207,7 @@ fn packet_fwd_tile( ); next_deduper_reset_attempt = Instant::now() + Duration::from_secs(2); // show stats here... - log::trace!( + log::debug!( "send_batch_count: {}, duplicate: {}, total-pkt-sent: {}, queue-len: {}, to-recycle: {}", stats.send_batch_count.load(Ordering::Relaxed), stats.duplicate.load(Ordering::Relaxed), @@ -274,6 +274,10 @@ fn packet_fwd_tile( recycled_frames.push(desc); for dest in dests.iter() { + let origin = packet.meta.socket_addr().ip(); + if origin == dest.ip() { + continue; + } // Cheap to do since we are just copying a pointer let buf_clone = unsafe { buf.unsafe_subslice_clone(0, packet.meta.size) }; next_batch_send.push((buf_clone, *dest)); @@ -565,12 +569,12 @@ pub fn run_proxy_system( recv_pkt_vec.push(multicast_sk_vec[pkt_recv_idx].try_clone().expect("multicast sk clone")); } - if let Some(doublezero_sk) = doublezero_v4_sk_vec.get(pkt_recv_idx) { - recv_pkt_vec.push(doublezero_sk.try_clone().expect("doublezero v4 sk clone")); + if let Some(doublezero_v4_sk) = doublezero_v4_sk_vec.get(pkt_recv_idx) { + recv_pkt_vec.push(doublezero_v4_sk.try_clone().expect("doublezero v4 sk clone")); } - if let Some(doublezero_sk) = doublezero_v6_sk_vec.get(pkt_recv_idx) { - recv_pkt_vec.push(doublezero_sk.try_clone().expect("doublezero v6 sk clone")); + if let Some(doublezero_v6_sk) = doublezero_v6_sk_vec.get(pkt_recv_idx) { + recv_pkt_vec.push(doublezero_v6_sk.try_clone().expect("doublezero v6 sk clone")); } let exit = Arc::clone(&exit); @@ -600,22 +604,7 @@ pub fn run_proxy_system( drop(fill_tx_vec); drop(packet_tx_vec); log::info!("Waiting for {} tile threads to exit", tile_thread_vec.len()); - - - // There is a special case where pkt_recv could be blocked inside recvmmsg syscall if no new packets arrive in a triton_recv_msg iteratoin (when newly set to blocking mode). - // We have to force close the socket to break it out of the syscall. 
- // for sk_fd in pkt_recv_sk_raw_fd_vec { - // unsafe { - // libc::close(sk_fd); - // } - // } - - // for sk_fd in pkt_fwd_sk_raw_fd_vec { - // unsafe { - // libc::close(sk_fd); - // } - // } - + for th in tile_thread_vec { let result = th.join(); if let Err(e) = result { @@ -624,9 +613,6 @@ pub fn run_proxy_system( } } - - - /// Reset dedup + send metrics to influx pub fn start_forwarder_accessory_thread( metrics: Arc, From 9a5ead52cb33fdeb195493c0f3868e944794510a Mon Sep 17 00:00:00 2001 From: Louis-Vincent Boudreault Date: Mon, 26 Jan 2026 16:25:01 +0000 Subject: [PATCH 13/13] buidl script --- .gitignore | 2 ++ proxy/Cargo.toml | 2 +- scripts/build-dist.sh | 9 +++++++++ 3 files changed, 12 insertions(+), 1 deletion(-) create mode 100755 scripts/build-dist.sh diff --git a/.gitignore b/.gitignore index c85eb61..41be277 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ .local .vscode + +dist \ No newline at end of file diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 881cb25..0de71da 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -7,7 +7,7 @@ homepage = { workspace = true } edition = { workspace = true } [[bin]] -name = "triton-proxy" +name = "triton-shredproxy" path = "src/main2.rs" [[bin]] diff --git a/scripts/build-dist.sh b/scripts/build-dist.sh new file mode 100755 index 0000000..d8b4ae0 --- /dev/null +++ b/scripts/build-dist.sh @@ -0,0 +1,9 @@ +#!/bin/bash + + +mkdir -p dist +rm -rf dist/* + +cargo build --release --bin triton-shredproxy + +mv target/release/triton-shredproxy dist/triton-shredproxy-ubuntu-22.04 \ No newline at end of file
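
For reference, below is a minimal standalone sketch of the drain-readiness receive pattern that PATCH 10/13 introduces in recv_loop: several non-blocking UDP sockets are registered with a mio Poll, and on each readable event the socket is drained with repeated recv calls until it returns WouldBlock, so no readiness is left pending before the next poll. This is a sketch only, assuming mio 1.x with the "os-poll" and "net" features; the addresses, token values, and buffer size are illustrative and not taken from the patches.

    // Standalone sketch (not part of the patches above) of the poll-then-drain
    // receive loop. Values below are illustrative assumptions.
    use std::io::ErrorKind;
    use std::time::Duration;

    use mio::net::UdpSocket;
    use mio::{Events, Interest, Poll, Token};

    fn main() -> std::io::Result<()> {
        let mut poll = Poll::new()?;
        let mut events = Events::with_capacity(16);

        // Two non-blocking sockets stand in for a tile's unicast + multicast sockets.
        let mut sockets = vec![
            UdpSocket::bind("127.0.0.1:0".parse().unwrap())?,
            UdpSocket::bind("127.0.0.1:0".parse().unwrap())?,
        ];
        for (i, sk) in sockets.iter_mut().enumerate() {
            poll.registry().register(sk, Token(i), Interest::READABLE)?;
        }

        let mut buf = [0u8; 1500];
        loop {
            // Bounded wait so the loop can also notice an exit flag, like the tiles do.
            poll.poll(&mut events, Some(Duration::from_millis(100)))?;
            for ev in events.iter() {
                let sk = &sockets[ev.token().0];
                // Drain this socket completely; stop only on WouldBlock.
                loop {
                    match sk.recv_from(&mut buf) {
                        Ok((len, addr)) => println!("got {len} bytes from {addr}"),
                        Err(e) if e.kind() == ErrorKind::WouldBlock => break,
                        Err(e) => return Err(e),
                    }
                }
            }
        }
    }

The bounded 100 ms poll timeout mirrors why the earlier force-close-socket shutdown workaround could be dropped: once the sockets are non-blocking and the wait happens in poll rather than inside recvmmsg, the tile periodically regains control and can observe the exit flag even when no traffic arrives.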