From 0924c2ca0a7222d731bca883e803bde5c6335d94 Mon Sep 17 00:00:00 2001 From: playfulkittykat <69809064+playfulkittykat@users.noreply.github.com> Date: Mon, 28 Oct 2024 06:18:24 -0400 Subject: [PATCH] Make a browser compatible rate limiter (#11) --- Cargo.toml | 11 +++- src/client.rs | 16 ++++-- src/client/gloo_rate_limit.rs | 53 +++++++++++++++++++ .../{rate_limit.rs => tokio_rate_limit.rs} | 10 ++-- 4 files changed, 79 insertions(+), 11 deletions(-) create mode 100644 src/client/gloo_rate_limit.rs rename src/client/{rate_limit.rs => tokio_rate_limit.rs} (80%) diff --git a/Cargo.toml b/Cargo.toml index 11c8c34..11571b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,7 @@ codecov = { repository = "nasso/rs621" } [features] default = ["rate-limit", "reqwest/default-tls"] socks = ["reqwest/socks"] -rate-limit = ["tokio", "tokio/time", "tokio/sync"] +rate-limit = ["gloo-timers", "futures", "web-time", "tokio"] [dependencies] thiserror = "1" @@ -33,8 +33,15 @@ derivative = "2" itertools = "0.10" futures = { version = "0.3", default-features = false } reqwest = { version = ">=0.11, <0.13", default-features = false, features = ["json"] } -tokio = { optional = true, version = "1" } [dev-dependencies] mockito = "0.30" tokio = { version = "1", features = ["rt-multi-thread", "macros"] } + +[target.'cfg(target_family = "wasm")'.dependencies] +gloo-timers = { optional = true, version = "0.3", features = ["futures"] } +futures = { optional = true, version = "0.3", features = ["std", "alloc"] } +web-time = { optional = true, version = "1.1.0" } + +[target.'cfg(not(target_family = "wasm"))'.dependencies] +tokio = { optional = true, version = "1", features = ["time", "sync"] } diff --git a/src/client.rs b/src/client.rs index 1cfdc07..9bbfd66 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,17 +1,27 @@ -#[cfg(feature = "rate-limit")] +#[cfg(all(target_family = "wasm", feature = "rate-limit"))] +#[path = "client/gloo_rate_limit.rs"] +mod rate_limit; + +#[cfg(all(not(target_family = "wasm"), feature = "rate-limit"))] +#[path = "client/tokio_rate_limit.rs"] mod rate_limit; #[cfg(not(feature = "rate-limit"))] #[path = "client/dummy_rate_limit.rs"] mod rate_limit; +/// Forced cool down duration performed at every request. E621 allows at most 2 requests per second, +/// so the lowest safe value we can have here is 500 ms. +#[cfg(feature = "rate-limit")] +const REQ_COOLDOWN_DURATION: std::time::Duration = std::time::Duration::from_millis(600); + use futures::Future; use reqwest::{Response, Url}; use serde::Serialize; use { super::error::{Error, Result}, - reqwest::header::{HeaderMap, HeaderValue}, + reqwest::header::HeaderMap, }; #[cfg(any(target_arch = "wasm32", target_arch = "wasm64"))] @@ -29,7 +39,7 @@ fn create_header_map>(user_agent: T) -> Result { let mut headers = HeaderMap::new(); headers.insert( reqwest::header::USER_AGENT, - HeaderValue::from_bytes(user_agent.as_ref()) + reqwest::header::HeaderValue::from_bytes(user_agent.as_ref()) .map_err(|e| Error::InvalidHeaderValue(format!("{}", e)))?, ); diff --git a/src/client/gloo_rate_limit.rs b/src/client/gloo_rate_limit.rs new file mode 100644 index 0000000..5b1eead --- /dev/null +++ b/src/client/gloo_rate_limit.rs @@ -0,0 +1,53 @@ +use super::REQ_COOLDOWN_DURATION; + +use futures::lock::{Mutex, MutexGuard}; + +use std::future::Future; +use std::sync::Arc; + +use web_time::Instant; + +#[derive(Debug, Clone, Default)] +pub struct RateLimit { + // Use a `futures` `Mutex` because ~500ms is crazy long to block an async task. + deadline: Arc>>, +} + +struct Guard<'a>(MutexGuard<'a, Option>); + +impl<'a> Drop for Guard<'a> { + fn drop(&mut self) { + // Use a `Drop` impl so that updating the deadline is panic-safe. + *self.0 = Some(Instant::now() + REQ_COOLDOWN_DURATION); + } +} + +impl RateLimit { + async fn lock(&self) -> Guard { + loop { + let now = Instant::now(); + + let deadline = { + let guard = self.deadline.lock().await; + + match &*guard { + None => return Guard(guard), + Some(deadline) if now >= *deadline => return Guard(guard), + Some(deadline) => *deadline, + } + }; + + gloo_timers::future::sleep(deadline - now).await; + } + } + + pub async fn check(self, fut: F) -> R + where + F: Future, + { + let guard = self.lock().await; + let result = fut.await; + drop(guard); + result + } +} diff --git a/src/client/rate_limit.rs b/src/client/tokio_rate_limit.rs similarity index 80% rename from src/client/rate_limit.rs rename to src/client/tokio_rate_limit.rs index f3bfc64..1a63a70 100644 --- a/src/client/rate_limit.rs +++ b/src/client/tokio_rate_limit.rs @@ -1,13 +1,11 @@ -use futures::Future; +use super::REQ_COOLDOWN_DURATION; + +use std::future::Future; use std::sync::Arc; use tokio::sync::{Mutex, MutexGuard}; -use tokio::time::{sleep_until, Duration, Instant}; - -/// Forced cool down duration performed at every request. E621 allows at most 2 requests per second, -/// so the lowest safe value we can have here is 500 ms. -const REQ_COOLDOWN_DURATION: Duration = Duration::from_millis(600); +use tokio::time::{sleep_until, Instant}; #[derive(Debug, Clone, Default)] pub struct RateLimit {