diff --git a/Cargo.toml b/Cargo.toml index 41dada5..8b61efc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,10 +22,10 @@ doc = false required-features = ["build-binary"] [dependencies] -fs2 = "0.4" -reqwest = { version = "0.11.0", default-features = false, features = [ - "blocking", -] } +tokio = { version = "1", features = ["fs", "io-util"] } +futures = "0.3" +fs4 = { version = "0.7", features = ["tokio"] } +reqwest = { version = "0.11.0", default-features = false, features = ["stream"]} sha2 = "0.10" tempfile = "3.1" log = "0.4" @@ -44,7 +44,7 @@ color-eyre = { version = "0.6", optional = true } [features] default = ["default-tls"] -build-binary = ["env_logger", "structopt", "color-eyre"] +build-binary = ["env_logger", "structopt", "color-eyre", "tokio/macros", "tokio/rt"] rustls-tls = ["reqwest/rustls-tls"] default-tls = ["reqwest/default-tls"] diff --git a/src/cache.rs b/src/cache.rs index 9735884..6f3da29 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -1,16 +1,18 @@ -use fs2::FileExt; +use fs4::tokio::AsyncFileExt; +use futures::StreamExt; use glob::glob; use log::{debug, error, info, warn}; use rand::distributions::{Distribution, Uniform}; -use reqwest::blocking::{Client, ClientBuilder}; use reqwest::header::ETAG; +use reqwest::{Client, ClientBuilder}; use std::default::Default; use std::env; -use std::fs::{self, OpenOptions}; +use std::fs::{self}; use std::path::{Path, PathBuf}; use std::thread; use std::time::{self, Duration}; use tempfile::NamedTempFile; +use tokio::fs::OpenOptions; use crate::archives::{extract_archive, ArchiveFormat}; use crate::utils::hash_str; @@ -39,7 +41,7 @@ impl CacheBuilder { CacheBuilder { config: Config { dir: None, - client_builder: ClientBuilder::new().timeout(None), + client_builder: ClientBuilder::new(), max_retries: 3, max_backoff: 5000, freshness_lifetime: None, @@ -217,8 +219,9 @@ impl Cache { /// /// If the resource is local file, it's path is returned. If the resource is a static HTTP /// resource, it will cached locally and the path to the cache file will be returned. - pub fn cached_path(&self, resource: &str) -> Result { + pub async fn cached_path(&self, resource: &str) -> Result { self.cached_path_with_options(resource, &Options::default()) + .await } /// Get the cached path to a resource using the given options. @@ -252,7 +255,7 @@ impl Cache { /// ).unwrap(); /// assert!(path.is_dir()); /// ``` - pub fn cached_path_with_options( + pub async fn cached_path_with_options( &self, resource: &str, options: &Options, @@ -289,7 +292,9 @@ impl Cache { } } else { // This is a remote resource, so fetch it to the cache. - let meta = self.fetch_remote_resource(resource, options.subdir.as_deref())?; + let meta = self + .fetch_remote_resource(resource, options.subdir.as_deref()) + .await?; // Check if we need to extract. if options.extract { @@ -313,7 +318,8 @@ impl Cache { .read(true) .write(true) .create(true) - .open(lock_path)?; + .open(lock_path) + .await?; filelock.lock_exclusive()?; debug!("Lock on extraction directory acquired for {}", resource); @@ -351,16 +357,20 @@ impl Cache { since = "0.4.4", note = "Please use Cache::cached_path_with_options() instead" )] - pub fn cached_path_in_subdir( + pub async fn cached_path_in_subdir( &self, resource: &str, subdir: Option<&str>, ) -> Result { let options = Options::new(subdir, false); - self.cached_path_with_options(resource, &options) + self.cached_path_with_options(resource, &options).await } - fn fetch_remote_resource(&self, resource: &str, subdir: Option<&str>) -> Result { + async fn fetch_remote_resource( + &self, + resource: &str, + subdir: Option<&str>, + ) -> Result { // Otherwise we attempt to parse the URL. let url = reqwest::Url::parse(resource).map_err(|_| Error::InvalidUrl(String::from(resource)))?; @@ -392,7 +402,7 @@ impl Cache { // No existing version or the existing versions are older than their freshness // lifetimes, so we'll query for the ETAG of the resource and then compare // that with any existing versions. - let etag = self.try_get_etag(resource, &url)?; + let etag = self.try_get_etag(resource, &url).await?; let path = self.resource_to_filepath(resource, &etag, subdir, None); // Before going further we need to obtain a lock on the file to provide @@ -403,7 +413,8 @@ impl Cache { .read(true) .write(true) .create(true) - .open(lock_path)?; + .open(lock_path) + .await?; filelock.lock_exclusive()?; debug!("Lock acquired for {}", resource); @@ -417,7 +428,9 @@ impl Cache { } // No up-to-date version cached, so we have to try downloading it. - let meta = self.try_download_resource(resource, &url, &path, &etag)?; + let meta = self + .try_download_resource(resource, &url, &path, &etag) + .await?; info!("New version of {} cached", resource); @@ -455,7 +468,7 @@ impl Cache { ) } - fn try_download_resource( + async fn try_download_resource( &self, resource: &str, url: &reqwest::Url, @@ -464,7 +477,7 @@ impl Cache { ) -> Result { let mut retries: u32 = 0; loop { - match self.download_resource(resource, url, path, etag) { + match self.download_resource(resource, url, path, etag).await { Ok(meta) => { return Ok(meta); } @@ -489,7 +502,7 @@ impl Cache { } } - fn download_resource( + async fn download_resource( &self, resource: &str, url: &reqwest::Url, @@ -498,10 +511,11 @@ impl Cache { ) -> Result { debug!("Attempting connection to {}", url); - let mut response = self + let response = self .http_client .get(url.clone()) - .send()? + .send() + .await? .error_for_status()?; debug!("Opened connection to {}", url); @@ -510,7 +524,9 @@ impl Cache { // Otherwise if we wrote directly to the cache file and the download got // interrupted we could be left with a corrupted cache file. let tempfile = NamedTempFile::new_in(path.parent().unwrap())?; - let mut tempfile_write_handle = OpenOptions::new().write(true).open(tempfile.path())?; + let mut tempfile_write_handle = + OpenOptions::new().write(true).open(tempfile.path()).await?; + let mut tempfile_write_handle = std::pin::pin!(tempfile_write_handle); info!("Starting download of {}", url); @@ -520,11 +536,21 @@ impl Cache { response.content_length(), tempfile_write_handle, ); - let bytes = response.copy_to(&mut download_wrapper)?; + let mut bytes_stream = response.bytes_stream(); + let mut bytes = 0; + while let Some(item) = bytes_stream.next().await { + bytes += tokio::io::copy(&mut item?.as_ref(), &mut download_wrapper).await?; + } + download_wrapper.finish(); bytes } else { - response.copy_to(&mut tempfile_write_handle)? + let mut bytes_stream = response.bytes_stream(); + let mut bytes = 0; + while let Some(item) = bytes_stream.next().await { + bytes += tokio::io::copy(&mut item?.as_ref(), &mut tempfile_write_handle).await?; + } + bytes }; info!("Downloaded {} bytes", bytes); @@ -545,10 +571,14 @@ impl Cache { Ok(meta) } - fn try_get_etag(&self, resource: &str, url: &reqwest::Url) -> Result, Error> { + async fn try_get_etag( + &self, + resource: &str, + url: &reqwest::Url, + ) -> Result, Error> { let mut retries: u32 = 0; loop { - match self.get_etag(url) { + match self.get_etag(url).await { Ok(etag) => return Ok(etag), Err(err) => { if retries >= self.max_retries { @@ -571,13 +601,16 @@ impl Cache { } } - fn get_etag(&self, url: &reqwest::Url) -> Result, Error> { + async fn get_etag(&self, url: &reqwest::Url) -> Result, Error> { debug!("Fetching ETAG for {}", url); + let response = self .http_client - .head(url.clone()) - .send()? + .get(url.clone()) + .send() + .await? .error_for_status()?; + if let Some(etag) = response.headers().get(ETAG) { if let Ok(s) = etag.to_str() { Ok(Some(s.into())) diff --git a/src/lib.rs b/src/lib.rs index 2b708b6..1d59d25 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -109,9 +109,9 @@ pub use crate::progress_bar::ProgressBar; /// with a temporary [`Cache`](crate::cache::Cache) object. /// Therefore if you're going to be calling this function multiple times, /// it's more efficient to create and use a single `Cache` instead. -pub fn cached_path(resource: &str) -> Result { +pub async fn cached_path(resource: &str) -> Result { let cache = Cache::builder().build()?; - cache.cached_path(resource) + cache.cached_path(resource).await } /// Get the cached path to a resource using the given options. @@ -121,9 +121,9 @@ pub fn cached_path(resource: &str) -> Result { /// with a temporary [`Cache`](crate::cache::Cache) object. /// Therefore if you're going to be calling this function multiple times, /// it's more efficient to create and use a single `Cache` instead. -pub fn cached_path_with_options(resource: &str, options: &Options) -> Result { +pub async fn cached_path_with_options(resource: &str, options: &Options) -> Result { let cache = Cache::builder().build()?; - cache.cached_path_with_options(resource, options) + cache.cached_path_with_options(resource, options).await } #[cfg(test)] diff --git a/src/main.rs b/src/main.rs index e5b7f8c..bd4576e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -59,7 +59,8 @@ struct Opt { quietly: bool, } -fn main() -> Result<()> { +#[tokio::main] +async fn main() -> Result<()> { color_eyre::install()?; env_logger::init(); let opt = Opt::from_args(); @@ -68,7 +69,9 @@ fn main() -> Result<()> { let cache = build_cache_from_opt(&opt)?; let options = Options::new(opt.subdir.as_deref(), opt.extract); - let path = cache.cached_path_with_options(&opt.resource, &options)?; + let path = cache + .cached_path_with_options(&opt.resource, &options) + .await?; println!("{}", path.to_string_lossy()); Ok(()) diff --git a/src/progress_bar.rs b/src/progress_bar.rs index e5766b3..b3721b9 100644 --- a/src/progress_bar.rs +++ b/src/progress_bar.rs @@ -1,5 +1,8 @@ use std::io::{self, Write}; +use std::pin::Pin; +use std::task::{Context, Poll}; use std::time::Instant; +use tokio::io::AsyncWrite; /// Progress bar types. /// @@ -23,11 +26,11 @@ impl Default for ProgressBar { } impl ProgressBar { - pub(crate) fn wrap_download( - &self, + pub(crate) fn wrap_download<'a, W: AsyncWrite>( + &'a self, resource: &str, content_length: Option, - writer: W, + writer: Pin<&'a mut W>, ) -> DownloadWrapper { let bar: Box = match self { ProgressBar::Full => Box::new(FullDownloadBar::new(content_length)), @@ -37,16 +40,17 @@ impl ProgressBar { } } -pub(crate) struct DownloadWrapper { +pub(crate) struct DownloadWrapper<'a, W: AsyncWrite> { bar: Box, - writer: W, + writer: Pin<&'a mut W>, } -impl DownloadWrapper +impl<'a, W> DownloadWrapper<'a, W> where - W: Write, + W: AsyncWrite, { - fn new(bar: Box, writer: W) -> Self { + fn new(bar: Box, writer: Pin<&'a mut W>) -> Self { + // let writer = std::pin::pin!(writer); Self { bar, writer } } @@ -55,27 +59,37 @@ where } } -impl Write for DownloadWrapper { - fn write(&mut self, buf: &[u8]) -> io::Result { - self.writer.write(buf) +impl AsyncWrite for DownloadWrapper<'_, W> { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.writer.as_mut().poll_write(cx, buf) } - fn flush(&mut self) -> io::Result<()> { - self.writer.flush() + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.writer.as_mut().poll_flush(cx) } - fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> io::Result { - self.writer.write_vectored(bufs) + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice], + ) -> Poll> { + self.writer.as_mut().poll_write_vectored(cx, bufs) } - fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { - self.writer.write_all(buf).map(|()| { - self.bar.tick(buf.len()); - }) + fn is_write_vectored(&self) -> bool { + self.writer.is_write_vectored() + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.writer.as_mut().poll_shutdown(cx) } } -trait DownloadBar { +trait DownloadBar: Send + Sync { fn tick(&mut self, chunk_size: usize); fn finish(&self);