Skip to content

Commit

Permalink
Make a browser compatible rate limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
playfulkittykat committed Oct 21, 2024
1 parent 236bac4 commit 4937bd3
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 17 deletions.
29 changes: 21 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
[package]
name = "rs621"
[workspace]
members = ["rs621-ratelimit"]

[workspace.package]
version = "0.7.0-alpha1"
repository = "https://github.com/nasso/rs621"
license = "MIT OR Apache-2.0"
authors = ["nasso <nassomails@gmail.com>"]
readme = "README.md"

[workspace.dependencies]
futures = { version = "0.3", default-features = false }

[package]
name = "rs621"
version.workspace = true
repository.workspace = true
license.workspace = true
authors.workspace = true
readme.workspace = true
edition = "2018"
description = "Rust crate for the E621 API (a large online archive of furry art)."
repository = "https://github.com/nasso/rs621"
readme = "README.md"
keywords = ["e621", "e926", "furry", "api", "client"]
categories = ["api-bindings"]
license = "MIT OR Apache-2.0"
exclude = ["src/mocked"]

[badges]
Expand All @@ -20,7 +33,7 @@ codecov = { repository = "nasso/rs621" }
[features]
default = ["rate-limit", "reqwest/default-tls"]
socks = ["reqwest/socks"]
rate-limit = ["tokio", "tokio/time", "tokio/sync"]
rate-limit = ["rs621-ratelimit"]

[dependencies]
thiserror = "1"
Expand All @@ -31,9 +44,9 @@ serde = { version = "1", features = ["derive"] }
serde_json = "1"
derivative = "2"
itertools = "0.10"
futures = { version = "0.3", default-features = false }
reqwest = { version = "0.11", default-features = false, features = ["json"] }
tokio = { optional = true, version = "1" }
rs621-ratelimit = { path = "./rs621-ratelimit", optional = true }
futures.workspace = true

[dev-dependencies]
mockito = "0.30"
Expand Down
17 changes: 17 additions & 0 deletions rs621-ratelimit/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "rs621-ratelimit"
edition = "2021"
description = "Request rate limiting library for rs621."
version.workspace = true
license.workspace = true
repository.workspace = true
authors.workspace = true
readme.workspace = true

[target.'cfg(target_family = "wasm")'.dependencies]
gloo-timers = { version = "0.3", features = ["futures"] }
futures = { workspace = true, features = ["std", "alloc"] }
web-time = "1.1.0"

[target.'cfg(not(target_family = "wasm"))'.dependencies]
tokio = { version = "1", features = ["time", "sync"] }
53 changes: 53 additions & 0 deletions rs621-ratelimit/src/gloo.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use crate::REQ_COOLDOWN_DURATION;

use futures::lock::{Mutex, MutexGuard};

use std::future::Future;
use std::sync::Arc;

use web_time::Instant;

#[derive(Debug, Clone, Default)]
pub struct RateLimit {
// Use a `futures` `Mutex` because ~500ms is crazy long to block an async task.
deadline: Arc<Mutex<Option<Instant>>>,
}

struct Guard<'a>(MutexGuard<'a, Option<Instant>>);

impl<'a> Drop for Guard<'a> {
fn drop(&mut self) {
// Use a `Drop` impl so that updating the deadline is panic-safe.
*self.0 = Some(Instant::now() + REQ_COOLDOWN_DURATION);
}
}

impl RateLimit {
async fn lock(&self) -> Guard {
loop {
let now = Instant::now();

let deadline = {
let guard = self.deadline.lock().await;

match &*guard {
None => return Guard(guard),
Some(deadline) if now >= *deadline => return Guard(guard),
Some(deadline) => *deadline,
}
};

gloo_timers::future::sleep(deadline - now).await;
}
}

pub async fn check<F, R>(self, fut: F) -> R
where
F: Future<Output = R>,
{
let guard = self.lock().await;
let result = fut.await;
drop(guard);
result
}
}
16 changes: 16 additions & 0 deletions rs621-ratelimit/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#[cfg(not(target_family = "wasm"))]
#[path = "tokio.rs"]
mod platform;

#[cfg(target_family = "wasm")]
#[path = "gloo.rs"]
mod platform;

#[doc(inline)]
pub use self::platform::*;

use std::time::Duration;

/// Forced cool down duration performed at every request. E621 allows at most 2 requests per second,
/// so the lowest safe value we can have here is 500 ms.
const REQ_COOLDOWN_DURATION: Duration = Duration::from_millis(600);
10 changes: 4 additions & 6 deletions src/client/rate_limit.rs → rs621-ratelimit/src/tokio.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use futures::Future;
use crate::REQ_COOLDOWN_DURATION;

use std::future::Future;

use std::sync::Arc;

use tokio::sync::{Mutex, MutexGuard};
use tokio::time::{sleep_until, Duration, Instant};

/// Forced cool down duration performed at every request. E621 allows at most 2 requests per second,
/// so the lowest safe value we can have here is 500 ms.
const REQ_COOLDOWN_DURATION: Duration = Duration::from_millis(600);
use tokio::time::{sleep_until, Instant};

#[derive(Debug, Clone, Default)]
pub struct RateLimit {
Expand Down
6 changes: 3 additions & 3 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#[cfg(feature = "rate-limit")]
mod rate_limit;
use rs621_ratelimit as rate_limit;

#[cfg(not(feature = "rate-limit"))]
#[path = "client/dummy_rate_limit.rs"]
Expand All @@ -11,7 +11,7 @@ use serde::Serialize;

use {
super::error::{Error, Result},
reqwest::header::{HeaderMap, HeaderValue},
reqwest::header::HeaderMap,
};

#[cfg(any(target_arch = "wasm32", target_arch = "wasm64"))]
Expand All @@ -29,7 +29,7 @@ fn create_header_map<T: AsRef<[u8]>>(user_agent: T) -> Result<HeaderMap> {
let mut headers = HeaderMap::new();
headers.insert(
reqwest::header::USER_AGENT,
HeaderValue::from_bytes(user_agent.as_ref())
reqwest::header::HeaderValue::from_bytes(user_agent.as_ref())
.map_err(|e| Error::InvalidHeaderValue(format!("{}", e)))?,
);

Expand Down

0 comments on commit 4937bd3

Please sign in to comment.