From a5a2d6897dace49a47047771838d5ba32bd85b6f Mon Sep 17 00:00:00 2001 From: Christo Buschek Date: Tue, 9 Jan 2024 17:01:43 +0100 Subject: [PATCH] refactor crate to use async/tokio over blocking calls The current use of `reqwest::blocking` panics when used in an async environment. This commit keeps all functionality as is but uses the Tokio runtime to make it work in an async environment. In turn, this means that after this change the crate no longer works in a sync environment. Documentation updates are omitted from this commit. ``` #[tokio::main] async fn main() { let cache_dir = dirs::cache_dir().unwrap(); let remote = "https://..."; let local_path = "my/download/path"; let cache = Cache::builder() .dir(cache_dir) .progress_bar(None) .build() .unwrap(); let path = cache .cached_path_with_options(&remote, &Options::default().subdir(&local_path)) .await .unwrap(); println!("{:?}", path); } ``` --- Cargo.toml | 10 +++--- src/cache.rs | 87 +++++++++++++++++++++++++++++++-------------- src/lib.rs | 8 ++--- src/main.rs | 7 ++-- src/progress_bar.rs | 54 +++++++++++++++++----------- 5 files changed, 108 insertions(+), 58 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 41dada5..8b61efc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,10 +22,10 @@ doc = false required-features = ["build-binary"] [dependencies] -fs2 = "0.4" -reqwest = { version = "0.11.0", default-features = false, features = [ - "blocking", -] } +tokio = { version = "1", features = ["fs", "io-util"] } +futures = "0.3" +fs4 = { version = "0.7", features = ["tokio"] } +reqwest = { version = "0.11.0", default-features = false, features = ["stream"]} sha2 = "0.10" tempfile = "3.1" log = "0.4" @@ -44,7 +44,7 @@ color-eyre = { version = "0.6", optional = true } [features] default = ["default-tls"] -build-binary = ["env_logger", "structopt", "color-eyre"] +build-binary = ["env_logger", "structopt", "color-eyre", "tokio/macros", "tokio/rt"] rustls-tls = ["reqwest/rustls-tls"] default-tls = ["reqwest/default-tls"] 
diff --git a/src/cache.rs b/src/cache.rs index 9735884..6f3da29 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,16 +1,18 @@ -use fs2::FileExt; +use fs4::tokio::AsyncFileExt; +use futures::StreamExt; use glob::glob; use log::{debug, error, info, warn}; use rand::distributions::{Distribution, Uniform}; -use reqwest::blocking::{Client, ClientBuilder}; use reqwest::header::ETAG; +use reqwest::{Client, ClientBuilder}; use std::default::Default; use std::env; -use std::fs::{self, OpenOptions}; +use std::fs::{self}; use std::path::{Path, PathBuf}; use std::thread; use std::time::{self, Duration}; use tempfile::NamedTempFile; +use tokio::fs::OpenOptions; use crate::archives::{extract_archive, ArchiveFormat}; use crate::utils::hash_str; @@ -39,7 +41,7 @@ impl CacheBuilder { CacheBuilder { config: Config { dir: None, - client_builder: ClientBuilder::new().timeout(None), + client_builder: ClientBuilder::new(), max_retries: 3, max_backoff: 5000, freshness_lifetime: None, @@ -217,8 +219,9 @@ impl Cache { /// /// If the resource is local file, it's path is returned. If the resource is a static HTTP /// resource, it will cached locally and the path to the cache file will be returned. - pub fn cached_path(&self, resource: &str) -> Result { + pub async fn cached_path(&self, resource: &str) -> Result { self.cached_path_with_options(resource, &Options::default()) + .await } /// Get the cached path to a resource using the given options. @@ -252,7 +255,7 @@ impl Cache { /// ).unwrap(); /// assert!(path.is_dir()); /// ``` - pub fn cached_path_with_options( + pub async fn cached_path_with_options( &self, resource: &str, options: &Options, @@ -289,7 +292,9 @@ impl Cache { } } else { // This is a remote resource, so fetch it to the cache. - let meta = self.fetch_remote_resource(resource, options.subdir.as_deref())?; + let meta = self + .fetch_remote_resource(resource, options.subdir.as_deref()) + .await?; // Check if we need to extract. 
if options.extract { @@ -313,7 +318,8 @@ impl Cache { .read(true) .write(true) .create(true) - .open(lock_path)?; + .open(lock_path) + .await?; filelock.lock_exclusive()?; debug!("Lock on extraction directory acquired for {}", resource); @@ -351,16 +357,20 @@ impl Cache { since = "0.4.4", note = "Please use Cache::cached_path_with_options() instead" )] - pub fn cached_path_in_subdir( + pub async fn cached_path_in_subdir( &self, resource: &str, subdir: Option<&str>, ) -> Result { let options = Options::new(subdir, false); - self.cached_path_with_options(resource, &options) + self.cached_path_with_options(resource, &options).await } - fn fetch_remote_resource(&self, resource: &str, subdir: Option<&str>) -> Result { + async fn fetch_remote_resource( + &self, + resource: &str, + subdir: Option<&str>, + ) -> Result { // Otherwise we attempt to parse the URL. let url = reqwest::Url::parse(resource).map_err(|_| Error::InvalidUrl(String::from(resource)))?; @@ -392,7 +402,7 @@ impl Cache { // No existing version or the existing versions are older than their freshness // lifetimes, so we'll query for the ETAG of the resource and then compare // that with any existing versions. - let etag = self.try_get_etag(resource, &url)?; + let etag = self.try_get_etag(resource, &url).await?; let path = self.resource_to_filepath(resource, &etag, subdir, None); // Before going further we need to obtain a lock on the file to provide @@ -403,7 +413,8 @@ impl Cache { .read(true) .write(true) .create(true) - .open(lock_path)?; + .open(lock_path) + .await?; filelock.lock_exclusive()?; debug!("Lock acquired for {}", resource); @@ -417,7 +428,9 @@ impl Cache { } // No up-to-date version cached, so we have to try downloading it. 
- let meta = self.try_download_resource(resource, &url, &path, &etag)?; + let meta = self + .try_download_resource(resource, &url, &path, &etag) + .await?; info!("New version of {} cached", resource); @@ -455,7 +468,7 @@ impl Cache { ) } - fn try_download_resource( + async fn try_download_resource( &self, resource: &str, url: &reqwest::Url, @@ -464,7 +477,7 @@ impl Cache { ) -> Result { let mut retries: u32 = 0; loop { - match self.download_resource(resource, url, path, etag) { + match self.download_resource(resource, url, path, etag).await { Ok(meta) => { return Ok(meta); } @@ -489,7 +502,7 @@ impl Cache { } } - fn download_resource( + async fn download_resource( &self, resource: &str, url: &reqwest::Url, @@ -498,10 +511,11 @@ impl Cache { ) -> Result { debug!("Attempting connection to {}", url); - let mut response = self + let response = self .http_client .get(url.clone()) - .send()? + .send() + .await? .error_for_status()?; debug!("Opened connection to {}", url); @@ -510,7 +524,9 @@ impl Cache { // Otherwise if we wrote directly to the cache file and the download got // interrupted we could be left with a corrupted cache file. let tempfile = NamedTempFile::new_in(path.parent().unwrap())?; - let mut tempfile_write_handle = OpenOptions::new().write(true).open(tempfile.path())?; + let mut tempfile_write_handle = + OpenOptions::new().write(true).open(tempfile.path()).await?; + let mut tempfile_write_handle = std::pin::pin!(tempfile_write_handle); info!("Starting download of {}", url); @@ -520,11 +536,21 @@ impl Cache { response.content_length(), tempfile_write_handle, ); - let bytes = response.copy_to(&mut download_wrapper)?; + let mut bytes_stream = response.bytes_stream(); + let mut bytes = 0; + while let Some(item) = bytes_stream.next().await { + bytes += tokio::io::copy(&mut item?.as_ref(), &mut download_wrapper).await?; + } + download_wrapper.finish(); bytes } else { - response.copy_to(&mut tempfile_write_handle)? 
+ let mut bytes_stream = response.bytes_stream(); + let mut bytes = 0; + while let Some(item) = bytes_stream.next().await { + bytes += tokio::io::copy(&mut item?.as_ref(), &mut tempfile_write_handle).await?; + } + bytes }; info!("Downloaded {} bytes", bytes); @@ -545,10 +571,14 @@ impl Cache { Ok(meta) } - fn try_get_etag(&self, resource: &str, url: &reqwest::Url) -> Result, Error> { + async fn try_get_etag( + &self, + resource: &str, + url: &reqwest::Url, + ) -> Result, Error> { let mut retries: u32 = 0; loop { - match self.get_etag(url) { + match self.get_etag(url).await { Ok(etag) => return Ok(etag), Err(err) => { if retries >= self.max_retries { @@ -571,13 +601,16 @@ impl Cache { } } - fn get_etag(&self, url: &reqwest::Url) -> Result, Error> { + async fn get_etag(&self, url: &reqwest::Url) -> Result, Error> { debug!("Fetching ETAG for {}", url); + let response = self .http_client - .head(url.clone()) - .send()? + .get(url.clone()) + .send() + .await? .error_for_status()?; + if let Some(etag) = response.headers().get(ETAG) { if let Ok(s) = etag.to_str() { Ok(Some(s.into())) diff --git a/src/lib.rs b/src/lib.rs index 2b708b6..1d59d25 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -109,9 +109,9 @@ pub use crate::progress_bar::ProgressBar; /// with a temporary [`Cache`](crate::cache::Cache) object. /// Therefore if you're going to be calling this function multiple times, /// it's more efficient to create and use a single `Cache` instead. -pub fn cached_path(resource: &str) -> Result { +pub async fn cached_path(resource: &str) -> Result { let cache = Cache::builder().build()?; - cache.cached_path(resource) + cache.cached_path(resource).await } /// Get the cached path to a resource using the given options. @@ -121,9 +121,9 @@ pub fn cached_path(resource: &str) -> Result { /// with a temporary [`Cache`](crate::cache::Cache) object. /// Therefore if you're going to be calling this function multiple times, /// it's more efficient to create and use a single `Cache` instead. 
-pub fn cached_path_with_options(resource: &str, options: &Options) -> Result { +pub async fn cached_path_with_options(resource: &str, options: &Options) -> Result { let cache = Cache::builder().build()?; - cache.cached_path_with_options(resource, options) + cache.cached_path_with_options(resource, options).await } #[cfg(test)] diff --git a/src/main.rs b/src/main.rs index e5b7f8c..bd4576e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -59,7 +59,8 @@ struct Opt { quietly: bool, } -fn main() -> Result<()> { +#[tokio::main] +async fn main() -> Result<()> { color_eyre::install()?; env_logger::init(); let opt = Opt::from_args(); @@ -68,7 +69,9 @@ fn main() -> Result<()> { let cache = build_cache_from_opt(&opt)?; let options = Options::new(opt.subdir.as_deref(), opt.extract); - let path = cache.cached_path_with_options(&opt.resource, &options)?; + let path = cache + .cached_path_with_options(&opt.resource, &options) + .await?; println!("{}", path.to_string_lossy()); Ok(()) diff --git a/src/progress_bar.rs b/src/progress_bar.rs index e5766b3..b3721b9 100644 --- a/src/progress_bar.rs +++ b/src/progress_bar.rs @@ -1,5 +1,8 @@ use std::io::{self, Write}; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Instant; +use tokio::io::AsyncWrite; /// Progress bar types. 
/// @@ -23,11 +26,11 @@ impl Default for ProgressBar { } impl ProgressBar { - pub(crate) fn wrap_download( - &self, + pub(crate) fn wrap_download<'a, W: AsyncWrite>( + &'a self, resource: &str, content_length: Option, - writer: W, + writer: Pin<&'a mut W>, ) -> DownloadWrapper { let bar: Box = match self { ProgressBar::Full => Box::new(FullDownloadBar::new(content_length)), @@ -37,16 +40,17 @@ impl ProgressBar { } } -pub(crate) struct DownloadWrapper { +pub(crate) struct DownloadWrapper<'a, W: AsyncWrite> { bar: Box, - writer: W, + writer: Pin<&'a mut W>, } -impl DownloadWrapper +impl<'a, W> DownloadWrapper<'a, W> where - W: Write, + W: AsyncWrite, { - fn new(bar: Box, writer: W) -> Self { + fn new(bar: Box, writer: Pin<&'a mut W>) -> Self { + // let writer = std::pin::pin!(writer); Self { bar, writer } } @@ -55,27 +59,37 @@ where } } -impl Write for DownloadWrapper { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.writer.write(buf) +impl AsyncWrite for DownloadWrapper<'_, W> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.writer.as_mut().poll_write(cx, buf) } - fn flush(&mut self) -> io::Result<()> { - self.writer.flush() + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.writer.as_mut().poll_flush(cx) } - fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> io::Result { - self.writer.write_vectored(bufs) + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice], + ) -> Poll> { + self.writer.as_mut().poll_write_vectored(cx, bufs) } - fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { - self.writer.write_all(buf).map(|()| { - self.bar.tick(buf.len()); - }) + fn is_write_vectored(&self) -> bool { + self.writer.is_write_vectored() + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.writer.as_mut().poll_shutdown(cx) } } -trait DownloadBar { +trait DownloadBar: Send + Sync { fn 
tick(&mut self, chunk_size: usize); fn finish(&self);