You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
189 lines
5.1 KiB
Rust
189 lines
5.1 KiB
Rust
use crate::discord::api::Error::Failed;
|
|
use crate::discord::dto::{GatewayBot, RateLimitNotice};
|
|
use crate::discord::ws::{connect_discord_ws, DiscordReceiver, DiscordSocket};
|
|
use log::debug;
|
|
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
|
|
use reqwest::{Client, ClientBuilder, IntoUrl, Request, RequestBuilder, Url};
|
|
use serde::de::DeserializeOwned;
|
|
use std::borrow::BorrowMut;
|
|
use std::cmp::max;
|
|
use std::collections::HashMap;
|
|
use std::future::Future;
|
|
use std::ops::{Add, DerefMut, Sub};
|
|
use std::pin::Pin;
|
|
use std::sync::{Arc, RwLock};
|
|
use std::task::{Context, Poll};
|
|
use std::time::Instant;
|
|
use tokio::time::{delay_for, delay_until, Delay, Duration};
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct DiscordAPI {
|
|
rate_limit: Arc<RwLock<RateLimit>>,
|
|
base_url: Url,
|
|
client: Client,
|
|
token: String,
|
|
}
|
|
|
|
#[derive(Clone, Debug)]
|
|
struct RateLimit {
|
|
global: Instant,
|
|
per_url: HashMap<Url, Instant>,
|
|
}
|
|
|
|
impl RateLimit {
|
|
fn new() -> RateLimit {
|
|
RateLimit {
|
|
global: Instant::now().sub(Duration::from_secs(5)),
|
|
per_url: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
fn wait(&self, url: Url) -> OptionalDelay {
|
|
let now = Instant::now();
|
|
let mut delay_till = None;
|
|
if now < self.global {
|
|
delay_till = Some(self.global);
|
|
}
|
|
|
|
if let Some(per_url) = self.per_url.get(&url) {
|
|
if now < *per_url {
|
|
if let Some(other) = delay_till {
|
|
if other < *per_url {
|
|
delay_till = Some(*per_url);
|
|
}
|
|
} else {
|
|
delay_till = Some(*per_url)
|
|
}
|
|
}
|
|
}
|
|
|
|
delay_till.map(Into::into).map(delay_until).into()
|
|
}
|
|
|
|
fn update(&mut self, url: Url, notice: RateLimitNotice) -> OptionalDelay {
|
|
let after = Instant::now().add(Duration::from_millis(notice.retry_after));
|
|
|
|
if notice.global {
|
|
self.global = max(self.global, after);
|
|
}
|
|
|
|
self.per_url
|
|
.entry(url.clone())
|
|
.and_modify(|&mut mut x| {
|
|
let old = x.clone();
|
|
std::mem::replace(&mut x, max(old, after));
|
|
})
|
|
.or_insert(after);
|
|
|
|
self.wait(url)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub enum Error {
|
|
RequestError(reqwest::Error),
|
|
Failed(u16),
|
|
}
|
|
|
|
impl From<reqwest::Error> for Error {
|
|
fn from(x: reqwest::Error) -> Self {
|
|
Error::RequestError(x)
|
|
}
|
|
}
|
|
|
|
/// A `Delay` that may be absent. Awaiting it completes on the first poll when the
/// inner value is `None`; otherwise it completes when the wrapped delay does.
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct OptionalDelay(Option<Delay>);
|
|
|
|
impl Into<OptionalDelay> for Option<Delay> {
|
|
fn into(self) -> OptionalDelay {
|
|
OptionalDelay(self)
|
|
}
|
|
}
|
|
|
|
impl Future for OptionalDelay {
|
|
type Output = ();
|
|
|
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
if self.0.is_none() {
|
|
Poll::Ready(())
|
|
} else {
|
|
let pin = unsafe { Pin::new_unchecked(self.get_mut().0.as_mut().unwrap()) };
|
|
|
|
pin.poll(cx)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl DiscordAPI {
|
|
pub fn url(&self, path: &str) -> Url {
|
|
self.base_url.join(path).unwrap()
|
|
}
|
|
|
|
pub fn create(token: &str) -> reqwest::Result<DiscordAPI> {
|
|
let mut headers = HeaderMap::new();
|
|
|
|
headers.insert(
|
|
"Authorization",
|
|
HeaderValue::from_str(&format!("Bot {}", token)).unwrap(),
|
|
);
|
|
|
|
let client = ClientBuilder::new()
|
|
.connection_verbose(true)
|
|
.default_headers(headers)
|
|
.build()?;
|
|
let base_url = Url::parse("https://discordapp.com/api/").unwrap();
|
|
let rate_limit = RateLimit::new();
|
|
|
|
Ok(DiscordAPI {
|
|
rate_limit: Arc::new(RwLock::new(rate_limit)),
|
|
base_url,
|
|
client,
|
|
token: token.to_string(),
|
|
})
|
|
}
|
|
|
|
async fn send<T: DeserializeOwned>(
|
|
&self,
|
|
url: Url,
|
|
request: RequestBuilder,
|
|
) -> Result<T, Error> {
|
|
loop {
|
|
let delay = { self.rate_limit.try_read().unwrap().wait(url.clone()) };
|
|
delay.await;
|
|
|
|
let clone = request.try_clone().expect("Can't clone body");
|
|
debug!("Requesting: {:?}", url);
|
|
let response = clone.send().await?;
|
|
|
|
let status_code = response.status().as_u16();
|
|
if status_code == 429 {
|
|
let rate_limit: RateLimitNotice = response.json().await?;
|
|
let delay = {
|
|
self.rate_limit
|
|
.try_write()
|
|
.unwrap()
|
|
.update(url.clone(), rate_limit)
|
|
};
|
|
|
|
delay.await;
|
|
continue;
|
|
}
|
|
|
|
if status_code != 200 {
|
|
return Err(Failed(status_code));
|
|
}
|
|
|
|
return Ok(response.json::<T>().await?);
|
|
}
|
|
}
|
|
|
|
pub async fn get_gateway_bot(&self) -> Result<GatewayBot, Error> {
|
|
let url = self.url("gateway/bot");
|
|
self.send(url.clone(), self.client.get(url)).await
|
|
}
|
|
|
|
pub async fn connect_websocket_bot(&self) -> Result<(DiscordSocket, DiscordReceiver), Error> {
|
|
let gateway = self.get_gateway_bot().await?;
|
|
Ok(connect_discord_ws(gateway.url.clone(), self.token.clone()).await?)
|
|
}
|
|
}
|