diff --git a/Cargo.toml b/Cargo.toml index 96ad74f..a4fa214 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,14 +1,27 @@ -[package] -name = "rs621" +[workspace] +members = ["rs621-ratelimit"] + +[workspace.package] version = "0.7.0-alpha1" +repository = "https://github.com/nasso/rs621" +license = "MIT OR Apache-2.0" authors = ["nasso "] +readme = "README.md" + +[workspace.dependencies] +futures = { version = "0.3", default-features = false } + +[package] +name = "rs621" +version.workspace = true +repository.workspace = true +license.workspace = true +authors.workspace = true +readme.workspace = true edition = "2018" description = "Rust crate for the E621 API (a large online archive of furry art)." -repository = "https://github.com/nasso/rs621" -readme = "README.md" keywords = ["e621", "e926", "furry", "api", "client"] categories = ["api-bindings"] -license = "MIT OR Apache-2.0" exclude = ["src/mocked"] [badges] @@ -20,7 +33,7 @@ codecov = { repository = "nasso/rs621" } [features] default = ["rate-limit", "reqwest/default-tls"] socks = ["reqwest/socks"] -rate-limit = ["tokio", "tokio/time", "tokio/sync"] +rate-limit = ["rs621-ratelimit"] [dependencies] thiserror = "1" @@ -31,9 +44,9 @@ serde = { version = "1", features = ["derive"] } serde_json = "1" derivative = "2" itertools = "0.10" -futures = { version = "0.3", default-features = false } reqwest = { version = "0.11", default-features = false, features = ["json"] } -tokio = { optional = true, version = "1" } +rs621-ratelimit = { path = "./rs621-ratelimit", optional = true } +futures.workspace = true [dev-dependencies] mockito = "0.30" diff --git a/rs621-ratelimit/Cargo.toml b/rs621-ratelimit/Cargo.toml new file mode 100644 index 0000000..7208f61 --- /dev/null +++ b/rs621-ratelimit/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "rs621-ratelimit" +edition = "2021" +description = "Request rate limiting library for rs621." +version.workspace = true +license.workspace = true +repository.workspace = true +authors.workspace = true +readme.workspace = true + +[target.'cfg(target_family = "wasm")'.dependencies] +gloo-timers = { version = "0.3", features = ["futures"] } +futures = { workspace = true, features = ["std", "alloc"] } +web-time = "1.1.0" + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +tokio = { version = "1", features = ["time", "sync"] } diff --git a/rs621-ratelimit/src/gloo.rs b/rs621-ratelimit/src/gloo.rs new file mode 100644 index 0000000..b0aeb8c --- /dev/null +++ b/rs621-ratelimit/src/gloo.rs @@ -0,0 +1,53 @@ +use crate::REQ_COOLDOWN_DURATION; + +use futures::lock::{Mutex, MutexGuard}; + +use std::future::Future; +use std::sync::Arc; + +use web_time::Instant; + +#[derive(Debug, Clone, Default)] +pub struct RateLimit { + // Use a `futures` `Mutex` because ~500ms is crazy long to block an async task. + deadline: Arc>>, +} + +struct Guard<'a>(MutexGuard<'a, Option>); + +impl<'a> Drop for Guard<'a> { + fn drop(&mut self) { + // Use a `Drop` impl so that updating the deadline is panic-safe. + *self.0 = Some(Instant::now() + REQ_COOLDOWN_DURATION); + } +} + +impl RateLimit { + async fn lock(&self) -> Guard { + loop { + let now = Instant::now(); + + let deadline = { + let guard = self.deadline.lock().await; + + match &*guard { + None => return Guard(guard), + Some(deadline) if now >= *deadline => return Guard(guard), + Some(deadline) => *deadline, + } + }; + + gloo_timers::future::sleep(deadline - now).await; + } + } + + pub async fn check(self, fut: F) -> R + where + F: Future, + { + let guard = self.lock().await; + let result = fut.await; + drop(guard); + result + } +} diff --git a/rs621-ratelimit/src/lib.rs b/rs621-ratelimit/src/lib.rs new file mode 100644 index 0000000..fc7ace4 --- /dev/null +++ b/rs621-ratelimit/src/lib.rs @@ -0,0 +1,16 @@ +#[cfg(not(target_family = "wasm"))] +#[path = "tokio.rs"] +mod platform; + +#[cfg(target_family = "wasm")] +#[path = "gloo.rs"] +mod platform; + +#[doc(inline)] +pub use self::platform::*; + +use std::time::Duration; + +/// Forced cool down duration performed at every request. E621 allows at most 2 requests per second, +/// so the lowest safe value we can have here is 500 ms. +const REQ_COOLDOWN_DURATION: Duration = Duration::from_millis(600); diff --git a/src/client/rate_limit.rs b/rs621-ratelimit/src/tokio.rs similarity index 80% rename from src/client/rate_limit.rs rename to rs621-ratelimit/src/tokio.rs index f3bfc64..a2c8cbe 100644 --- a/src/client/rate_limit.rs +++ b/rs621-ratelimit/src/tokio.rs @@ -1,13 +1,11 @@ -use futures::Future; +use crate::REQ_COOLDOWN_DURATION; + +use std::future::Future; use std::sync::Arc; use tokio::sync::{Mutex, MutexGuard}; -use tokio::time::{sleep_until, Duration, Instant}; - -/// Forced cool down duration performed at every request. E621 allows at most 2 requests per second, -/// so the lowest safe value we can have here is 500 ms. -const REQ_COOLDOWN_DURATION: Duration = Duration::from_millis(600); +use tokio::time::{sleep_until, Instant}; #[derive(Debug, Clone, Default)] pub struct RateLimit { diff --git a/src/client.rs b/src/client.rs index 1cfdc07..a1b6ae1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,5 +1,5 @@ #[cfg(feature = "rate-limit")] -mod rate_limit; +use rs621_ratelimit as rate_limit; #[cfg(not(feature = "rate-limit"))] #[path = "client/dummy_rate_limit.rs"] @@ -11,7 +11,7 @@ use serde::Serialize; use { super::error::{Error, Result}, - reqwest::header::{HeaderMap, HeaderValue}, + reqwest::header::HeaderMap, }; #[cfg(any(target_arch = "wasm32", target_arch = "wasm64"))] @@ -29,7 +29,7 @@ fn create_header_map>(user_agent: T) -> Result { let mut headers = HeaderMap::new(); headers.insert( reqwest::header::USER_AGENT, - HeaderValue::from_bytes(user_agent.as_ref()) + reqwest::header::HeaderValue::from_bytes(user_agent.as_ref()) .map_err(|e| Error::InvalidHeaderValue(format!("{}", e)))?, );