diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5f5d293..8989561 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,9 @@ jobs: - name: "Setup Rust Toolchain" uses: actions-rust-lang/setup-rust-toolchain@v1 - name: "Run Tests" - run: cargo test --all-features + run: cargo test --all-features --doc + - name: "Run Example" + run: cd examples/echo_chat && cargo run # Check formatting with rustfmt formatting: diff --git a/Cargo.lock b/Cargo.lock index 5c1f9e2..b0ddde5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -67,9 +67,9 @@ checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" [[package]] name = "cc" -version = "1.2.6" +version = "1.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d6dbb628b8f8555f86d0323c2eb39e3ec81901f4b83e091db8a6a76d316a333" +checksum = "a012a0df96dd6d06ba9a1b29d6402d1a5d77c6befd2566afdc26e10603dc93d7" dependencies = [ "shlex", ] @@ -806,7 +806,7 @@ dependencies = [ [[package]] name = "socketeer" -version = "0.0.3" +version = "0.1.0" dependencies = [ "bytes", "futures", @@ -829,9 +829,9 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "syn" -version = "2.0.93" +version = "2.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c786062daee0d6db1132800e623df74274a0a87322d8e183338e01b3d98d058" +checksum = "46f71c0377baf4ef1cc3e3402ded576dccc315800fbc62dfc7fe04b009773b4a" dependencies = [ "proc-macro2", "quote", @@ -851,12 +851,13 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.14.0" +version = "3.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" +checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" dependencies = [ "cfg-if", "fastrand", + "getrandom", "once_cell", "rustix", "windows-sys 0.59.0", diff --git a/Cargo.toml b/Cargo.toml index 215f4c7..6e82e63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "socketeer" -version = "0.0.3" +version = "0.1.0" edition = "2021" description = "Simplified websocket client based on Tokio-Tungstenite" authors = ["Zach Heylmun "] diff --git a/src/error.rs b/src/error.rs index 3c4660d..d6687cb 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,25 +1,35 @@ use thiserror::Error; use tokio_tungstenite::tungstenite::Message; +/// Error type for the Socketeer library. +/// This type is used to represent all possible external errors that can occur when using the Socketeer library. #[derive(Debug, Error)] pub enum Error { /// Url Parse Error #[error("Failed to parse URL: {}", 0)] UrlParse { + /// The URL that failed to parse url: String, + /// The source of the error, from the [URL crate](https://docs.rs/url/2.2.2/url/enum.ParseError.html) #[source] source: url::ParseError, }, + /// Websocket Error + /// Error thrown by the Tungstenite library when there is an issue with the websocket connection. #[error("Tungstenite error: {0}")] WebsocketError(#[from] tokio_tungstenite::tungstenite::Error), + /// Socketeer error when the websocket connection is closed unexpectedly. #[error("Socket Closed")] WebsocketClosed, - #[error("Channel Full")] - ChannelFull, + /// Error thrown if a message type not handled by `socketeer` is received. #[error("Unexpected Message type: {0}")] - UnexpectedMessage(Message), + UnexpectedMessageType(Message), + /// Error thrown if the message received fails to serialize or deserialize. #[error("Serialization Error: {0}")] SerializationError(#[from] serde_json::Error), + /// Error thrown if socketeer is dropped without closing the connection. + /// This error will be removed once async destructors are stabilized. + /// See [issue](https://github.com/rust-lang/rust/issues/126482) #[error("Socketeer dropped without closing")] - SocketeerDropped, + SocketeerDroppedWithoutClosing, } diff --git a/src/lib.rs b/src/lib.rs index 6f497b8..4dce9b6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,5 @@ #![doc = include_str!("../README.md")] - +#![deny(missing_docs)] mod error; #[cfg(feature = "mocking")] mod mock_server; @@ -32,6 +32,13 @@ struct TxChannelPayload { response_tx: oneshot::Sender>, } +/// A WebSocket client that manages the connection to a WebSocket server. +/// The client can send and receive messages, and will transparently handle protocol messages. +/// # Type Parameters +/// - `RxMessage`: The type of message that the client will receive from the server. +/// - `TxMessage`: The type of message that the client will send to the server. +/// - `CHANNEL_SIZE`: The size of the internal channels used to communicate between +/// the task managing the WebSocket connection and the client. #[derive(Debug)] pub struct Socketeer< RxMessage: for<'a> Deserialize<'a> + Debug, @@ -84,6 +91,12 @@ impl< }) } + /// Wait for the next parsed message from the WebSocket connection. + /// + /// # Errors + /// + /// - If the WebSocket connection is closed or otherwise errored + /// - If the message cannot be deserialized #[cfg_attr(feature = "tracing", instrument)] pub async fn next_message(&mut self) -> Result { let Some(message) = self.receiever.recv().await else { @@ -102,17 +115,24 @@ impl< let message = serde_json::from_slice(&message)?; Ok(message) } - _ => Err(Error::UnexpectedMessage(message)), + _ => Err(Error::UnexpectedMessageType(message)), } } + /// Send a message to the WebSocket connection. + /// This function will wait for the message to be sent before returning. + /// + /// # Errors + /// + /// - If the message cannot be serialized + /// - If the WebSocket connection is closed, or otherwise errored #[cfg_attr(feature = "tracing", instrument)] pub async fn send(&self, message: TxMessage) -> Result<(), Error> { #[cfg(feature = "tracing")] debug!("Sending message: {:?}", message); let (tx, rx) = oneshot::channel::>(); - let message = serde_json::to_string(&message).unwrap(); + let message = serde_json::to_string(&message)?; self.sender .send(TxChannelPayload { @@ -122,7 +142,10 @@ impl< .await .map_err(|_| Error::WebsocketClosed)?; // We'll ensure that we always respond before dropping the tx channel - rx.await.unwrap() + match rx.await { + Ok(result) => result, + Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"), + } } /// Consume self, closing down any remaining send/recieve, and return a new Socketeer instance if successful @@ -149,6 +172,11 @@ impl< Self::connect(&url).await } + /// Close the WebSocket connection gracefully. + /// This function will wait for the connection to close before returning. + /// # Errors + /// - If the WebSocket connection is already closed + /// - If the WebSocket connection cannot be closed #[cfg_attr(feature = "tracing", instrument)] pub async fn close_connection(self) -> Result<(), Error> { #[cfg(feature = "tracing")] @@ -164,9 +192,14 @@ impl< }) .await .map_err(|_| Error::WebsocketClosed)?; - rx.await.unwrap()?; - self.socket_handle.await.unwrap()?; - Ok(()) + match rx.await { + Ok(result) => result, + Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"), + }?; + match self.socket_handle.await { + Ok(result) => result, + Err(_) => unreachable!("Socket loop does not panic, and is not cancelled"), + } } } @@ -219,12 +252,12 @@ async fn send_socket_message( LoopState::Running } } - Err(_) => LoopState::Error(Error::SocketeerDropped), + Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing), } } else { #[cfg(feature = "tracing")] error!("Socketeer dropped without closing connection"); - LoopState::Error(Error::SocketeerDropped) + LoopState::Error(Error::SocketeerDroppedWithoutClosing) } } @@ -264,7 +297,7 @@ async fn socket_message_received( } Message::Text(_) | Message::Binary(_) => match sender.send(message).await { Ok(()) => LoopState::Running, - Err(_) => LoopState::Error(Error::SocketeerDropped), + Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing), }, _ => LoopState::Running, }, diff --git a/src/mock_server.rs b/src/mock_server.rs index 2ac49e1..13e926e 100644 --- a/src/mock_server.rs +++ b/src/mock_server.rs @@ -19,8 +19,11 @@ use tracing::debug; /// Control messages for testing with the echo server. #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, PartialOrd)] pub enum EchoControlMessage { + /// Send a message which the server should echo back Message(String), + /// Request that the server send the client a ping SendPing, + /// Request that the server close the connection Close, }