Skip to content

Commit

Permalink
Make a browser compatible rate limiter (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
playfulkittykat authored Oct 28, 2024
1 parent 299494b commit 0924c2c
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 11 deletions.
11 changes: 9 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ codecov = { repository = "nasso/rs621" }
[features]
default = ["rate-limit", "reqwest/default-tls"]
socks = ["reqwest/socks"]
rate-limit = ["tokio", "tokio/time", "tokio/sync"]
rate-limit = ["gloo-timers", "futures", "web-time", "tokio"]

[dependencies]
thiserror = "1"
Expand All @@ -33,8 +33,15 @@ derivative = "2"
itertools = "0.10"
futures = { version = "0.3", default-features = false }
reqwest = { version = ">=0.11, <0.13", default-features = false, features = ["json"] }
tokio = { optional = true, version = "1" }

[dev-dependencies]
mockito = "0.30"
tokio = { version = "1", features = ["rt-multi-thread", "macros"] }

[target.'cfg(target_family = "wasm")'.dependencies]
gloo-timers = { optional = true, version = "0.3", features = ["futures"] }
futures = { optional = true, version = "0.3", features = ["std", "alloc"] }
web-time = { optional = true, version = "1.1.0" }

[target.'cfg(not(target_family = "wasm"))'.dependencies]
tokio = { optional = true, version = "1", features = ["time", "sync"] }
16 changes: 13 additions & 3 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
#[cfg(feature = "rate-limit")]
#[cfg(all(target_family = "wasm", feature = "rate-limit"))]
#[path = "client/gloo_rate_limit.rs"]
mod rate_limit;

#[cfg(all(not(target_family = "wasm"), feature = "rate-limit"))]
#[path = "client/tokio_rate_limit.rs"]
mod rate_limit;

#[cfg(not(feature = "rate-limit"))]
#[path = "client/dummy_rate_limit.rs"]
mod rate_limit;

/// Forced cool down duration performed at every request. E621 allows at most 2 requests per second,
/// so the lowest safe value we can have here is 500 ms.
#[cfg(feature = "rate-limit")]
const REQ_COOLDOWN_DURATION: std::time::Duration = std::time::Duration::from_millis(600);

use futures::Future;
use reqwest::{Response, Url};
use serde::Serialize;

use {
super::error::{Error, Result},
reqwest::header::{HeaderMap, HeaderValue},
reqwest::header::HeaderMap,
};

#[cfg(any(target_arch = "wasm32", target_arch = "wasm64"))]
Expand All @@ -29,7 +39,7 @@ fn create_header_map<T: AsRef<[u8]>>(user_agent: T) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
headers.insert(
reqwest::header::USER_AGENT,
HeaderValue::from_bytes(user_agent.as_ref())
reqwest::header::HeaderValue::from_bytes(user_agent.as_ref())
.map_err(|e| Error::InvalidHeaderValue(format!("{}", e)))?,
);

Expand Down
53 changes: 53 additions & 0 deletions src/client/gloo_rate_limit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use super::REQ_COOLDOWN_DURATION;

use futures::lock::{Mutex, MutexGuard};

use std::future::Future;
use std::sync::Arc;

use web_time::Instant;

#[derive(Debug, Clone, Default)]
pub struct RateLimit {
// Use a `futures` `Mutex` because ~500ms is crazy long to block an async task.
deadline: Arc<Mutex<Option<Instant>>>,
}

struct Guard<'a>(MutexGuard<'a, Option<Instant>>);

impl<'a> Drop for Guard<'a> {
fn drop(&mut self) {
// Use a `Drop` impl so that updating the deadline is panic-safe.
*self.0 = Some(Instant::now() + REQ_COOLDOWN_DURATION);
}
}

impl RateLimit {
async fn lock(&self) -> Guard {
loop {
let now = Instant::now();

let deadline = {
let guard = self.deadline.lock().await;

match &*guard {
None => return Guard(guard),
Some(deadline) if now >= *deadline => return Guard(guard),
Some(deadline) => *deadline,
}
};

gloo_timers::future::sleep(deadline - now).await;
}
}

pub async fn check<F, R>(self, fut: F) -> R
where
F: Future<Output = R>,
{
let guard = self.lock().await;
let result = fut.await;
drop(guard);
result
}
}
10 changes: 4 additions & 6 deletions src/client/rate_limit.rs → src/client/tokio_rate_limit.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use futures::Future;
use super::REQ_COOLDOWN_DURATION;

use std::future::Future;

use std::sync::Arc;

use tokio::sync::{Mutex, MutexGuard};
use tokio::time::{sleep_until, Duration, Instant};

/// Forced cool down duration performed at every request. E621 allows at most 2 requests per second,
/// so the lowest safe value we can have here is 500 ms.
const REQ_COOLDOWN_DURATION: Duration = Duration::from_millis(600);
use tokio::time::{sleep_until, Instant};

#[derive(Debug, Clone, Default)]
pub struct RateLimit {
Expand Down

0 comments on commit 0924c2c

Please sign in to comment.