From 3e9ede7860328f6f396ae927b7cf87c4c1198a3a Mon Sep 17 00:00:00 2001 From: Jakob Truelsen Date: Sat, 8 Feb 2025 12:04:20 +0100 Subject: [PATCH] client-daemon: Fix more race conditions in upstream connection handling --- src/bin/sadmin/client_daemon.rs | 113 ++++++++++++++++---------------- src/bin/sadmin/main.rs | 2 + src/bin/sadmin/state.rs | 83 +++++++++++++++++++++++ 3 files changed, 140 insertions(+), 58 deletions(-) create mode 100644 src/bin/sadmin/state.rs diff --git a/src/bin/sadmin/client_daemon.rs b/src/bin/sadmin/client_daemon.rs index e15b50e7..5f977799 100644 --- a/src/bin/sadmin/client_daemon.rs +++ b/src/bin/sadmin/client_daemon.rs @@ -30,10 +30,7 @@ use tokio::{ }, process::ChildStdin, select, - sync::{ - mpsc::{UnboundedReceiver, UnboundedSender}, - Notify, - }, + sync::mpsc::{UnboundedReceiver, UnboundedSender}, time::timeout, }; use tokio_rustls::{client::TlsStream, rustls, TlsConnector}; @@ -52,6 +49,7 @@ use crate::{ connection::Config, persist_daemon, service_control::DaemonControlMessage, + state::State, tokio_passfd::{self}, }; use sdnotify::SdNotify; @@ -103,16 +101,20 @@ pub struct ClientDaemon { pub type PersistMessageSender = tokio::sync::oneshot::Sender<(persist_daemon::Message, Option)>; +#[derive(PartialEq, Eq)] +enum ConnectionState { + Good, + Bad, +} + pub struct Client { connector: TlsConnector, pub config: Config, command_tasks: Mutex>>, - send_failure_notify: Notify, - sender_clear: Notify, - new_send_notify: Notify, sender: tokio::sync::Mutex< Option>>, >, + connection_state: State<ConnectionState>, script_stdin: Mutex>>, persist_responses: Mutex>, persist_idc: AtomicU64, @@ -141,34 +143,46 @@ impl Client { message.push(30); loop { let mut s = self.sender.lock().await; - if let Some(v) = s.deref_mut() { - let write_all = write_all_and_flush(v, &message); - let sender_clear = self.sender_clear.notified(); - let sleep = tokio::time::sleep(Duration::from_secs(40)); - tokio::select!( - val = write_all => { - if let Err(e) = val { - // The send 
errored out, notify the recv half so we can try to initiate a new connection - error!("Failed sending message to backend: {}", e); - self.send_failure_notify.notify_one(); - *s = None - } - break - } - _ = sender_clear => {}, - _ = sleep => { - // The send timeouted, notify the recv half so we can try to initiate a new connection - error!("Timout sending message to server"); - self.send_failure_notify.notify_one(); - *s = None - } - ); + if *self.connection_state.get() != ConnectionState::Good { + std::mem::drop(s); + // We do not currently have a send socket so lets wait for one + info!("We do not currently have a send socket so lets wait for one"); + self.connection_state + .wait(|s| s == &ConnectionState::Good) + .await; continue; - } - // We do not currently have a send socket so lets wait for one - info!("We do not currently have a send socket so lets wait for one"); - std::mem::drop(s); - self.new_send_notify.notified().await; + }; + let Some(v) = s.deref_mut() else { + std::mem::drop(s); + // We do not currently have a send socket so lets wait for one + error!("Logic error do not currently have a send socket so lets wait for one"); + self.connection_state.set(ConnectionState::Bad); + self.connection_state + .wait(|s| s == &ConnectionState::Good) + .await; + continue; + }; + let write_all = write_all_and_flush(v, &message); + let disconnected = self.connection_state.wait(|v| v != &ConnectionState::Good); + let sleep = tokio::time::sleep(Duration::from_secs(40)); + tokio::select!( + val = write_all => { + if let Err(e) = val { + // The send errored out, notify the recv half so we can try to initiate a new connection + error!("Failed sending message to backend: {}", e); + self.connection_state.set(ConnectionState::Bad); + } + break + } + _ = disconnected => { + // We are disconnected, wait for reconnect + }, + _ = sleep => { + // The send timeouted, notify the recv half so we can try to initiate a new connection + error!("Timout sending message to server"); + 
self.connection_state.set(ConnectionState::Bad); + } + ); } } @@ -940,8 +954,9 @@ impl Client { auth_message.push(30); write_all_and_flush(&mut write, &auth_message).await?; - *self.sender.lock().await = Some(write); - self.new_send_notify.notify_one(); + let mut l = self.sender.lock().await; + *l = Some(write); + self.connection_state.set(ConnectionState::Good); Ok(read) } @@ -1005,7 +1020,7 @@ impl Client { } let mut start = buffer.len(); let read = read.read_buf(&mut buffer); - let send_failure = self.send_failure_notify.notified(); + let disconnect = self.connection_state.wait(|v| v != &ConnectionState::Good); let sleep = tokio::time::sleep(Duration::from_secs(120)); run_token.set_location(file!(), line!()); tokio::select! { @@ -1022,7 +1037,7 @@ impl Client { } } } - _ = send_failure => { + _ = disconnect => { break } _ = sleep => { @@ -1059,23 +1074,7 @@ impl Client { break; } } - info!("Trying to take sender for disconnect"); - run_token.set_location(file!(), line!()); - { - let f = async { - loop { - self.sender_clear.notify_waiters(); - self.sender_clear.notify_one(); - tokio::time::sleep(Duration::from_millis(1)).await - } - }; - tokio::select! 
{ - mut l = self.sender.lock() => { - let _sender = l.take(); - } - () = f => {panic!()} - } - } + self.connection_state.set(ConnectionState::Bad); run_token.set_location(file!(), line!()); info!("Took sender for disconnect"); if let Some(notifier) = &notifier { @@ -1632,9 +1631,7 @@ pub async fn client_daemon(config: Config, args: ClientDaemon) -> Result<()> { config, db, command_tasks: Default::default(), - send_failure_notify: Default::default(), - sender_clear: Default::default(), - new_send_notify: Default::default(), + connection_state: State::new(ConnectionState::Bad), sender: Default::default(), script_stdin: Default::default(), persist_responses: Default::default(), diff --git a/src/bin/sadmin/main.rs b/src/bin/sadmin/main.rs index 16225d26..b2ded35f 100644 --- a/src/bin/sadmin/main.rs +++ b/src/bin/sadmin/main.rs @@ -32,6 +32,8 @@ mod run; mod service_control; mod service_deploy; #[cfg(feature = "daemon")] +mod state; +#[cfg(feature = "daemon")] mod tokio_passfd; mod upgrade; diff --git a/src/bin/sadmin/state.rs b/src/bin/sadmin/state.rs new file mode 100644 index 00000000..26f16ffb --- /dev/null +++ b/src/bin/sadmin/state.rs @@ -0,0 +1,83 @@ +use std::{ + future::Future, + ops::Deref, + sync::{Mutex, MutexGuard}, + task::Waker, +}; + +struct StateContent<T> { + state: T, + waiters: Vec<Waker>, +} + +/// Store some state T, that can be mutated +/// It is possible to wait for the value to get into a specific state +pub struct State<T> { + content: std::sync::Mutex<StateContent<T>>, +} +pub struct StateValue<'a, T>(MutexGuard<'a, StateContent<T>>); + +impl<T> Deref for StateValue<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.0.state + } +} + +pub struct StateWaiter<'a, T, P: Fn(&T) -> bool> { + state: &'a State<T>, + p: P, +} + +impl<'a, T, P: Fn(&T) -> bool> Future for StateWaiter<'a, T, P> { + type Output = StateValue<'a, T>; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Self::Output> { + let mut content = 
self.state.content.lock().unwrap(); + if (self.p)(&content.state) { + std::task::Poll::Ready(StateValue(content)) + } else { + content.waiters.push(cx.waker().clone()); + std::task::Poll::Pending + } + } +} + +unsafe impl<T, P: Fn(&T) -> bool> Send for StateWaiter<'_, T, P> {} + +impl<T: PartialEq> State<T> { + pub fn new(v: T) -> Self { + State { + content: Mutex::new(StateContent { + state: v, + waiters: Vec::new(), + }), + } + } + + /// Update the value to v and wake every registered waiter so it can re-check its predicate; does nothing if the value is unchanged + pub fn set(&self, v: T) { + let mut inner = self.content.lock().unwrap(); + if inner.state == v { + return; + } + for w in std::mem::take(&mut inner.waiters) { + w.wake(); + } + inner.state = v; + } + + /// Get the current value, return a wrapper of the mutex lock + pub fn get(&self) -> StateValue<'_, T> { + StateValue(self.content.lock().unwrap()) + } + + /// Return a future that resolves once the current value fulfills the predicate `p` + pub fn wait<P: Fn(&T) -> bool>(&self, p: P) -> StateWaiter<'_, T, P> { + StateWaiter { state: self, p } + } +}