use crate::discord::api::Error::Failed; use crate::discord::dto::{GatewayBot, RateLimitNotice}; use crate::discord::ws::{connect_discord_ws, DiscordReceiver, DiscordSocket}; use log::debug; use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; use reqwest::{Client, ClientBuilder, IntoUrl, Request, RequestBuilder, Url}; use serde::de::DeserializeOwned; use std::borrow::BorrowMut; use std::cmp::max; use std::collections::HashMap; use std::future::Future; use std::ops::{Add, DerefMut, Sub}; use std::pin::Pin; use std::sync::{Arc, RwLock}; use std::task::{Context, Poll}; use std::time::Instant; use tokio::time::{delay_for, delay_until, Delay, Duration}; #[derive(Clone, Debug)] pub struct DiscordAPI { rate_limit: Arc>, base_url: Url, client: Client, token: String, } #[derive(Clone, Debug)] struct RateLimit { global: Instant, per_url: HashMap, } impl RateLimit { fn new() -> RateLimit { RateLimit { global: Instant::now().sub(Duration::from_secs(5)), per_url: HashMap::new(), } } fn wait(&self, url: Url) -> OptionalDelay { let now = Instant::now(); let mut delay_till = None; if now < self.global { delay_till = Some(self.global); } if let Some(per_url) = self.per_url.get(&url) { if now < *per_url { if let Some(other) = delay_till { if other < *per_url { delay_till = Some(*per_url); } } else { delay_till = Some(*per_url) } } } delay_till.map(Into::into).map(delay_until).into() } fn update(&mut self, url: Url, notice: RateLimitNotice) -> OptionalDelay { let after = Instant::now().add(Duration::from_millis(notice.retry_after)); if notice.global { self.global = max(self.global, after); } self.per_url .entry(url.clone()) .and_modify(|&mut mut x| { let old = x.clone(); std::mem::replace(&mut x, max(old, after)); }) .or_insert(after); self.wait(url) } } #[derive(Debug)] pub enum Error { RequestError(reqwest::Error), Failed(u16), } impl From for Error { fn from(x: reqwest::Error) -> Self { Error::RequestError(x) } } struct OptionalDelay(Option); impl Into for Option { fn into(self) -> OptionalDelay { OptionalDelay(self) } } impl Future for OptionalDelay { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { if self.0.is_none() { Poll::Ready(()) } else { let pin = unsafe { Pin::new_unchecked(self.get_mut().0.as_mut().unwrap()) }; pin.poll(cx) } } } impl DiscordAPI { pub fn url(&self, path: &str) -> Url { self.base_url.join(path).unwrap() } pub fn create(token: &str) -> reqwest::Result { let mut headers = HeaderMap::new(); headers.insert( "Authorization", HeaderValue::from_str(&format!("Bot {}", token)).unwrap(), ); let client = ClientBuilder::new() .connection_verbose(true) .default_headers(headers) .build()?; let base_url = Url::parse("https://discordapp.com/api/").unwrap(); let rate_limit = RateLimit::new(); Ok(DiscordAPI { rate_limit: Arc::new(RwLock::new(rate_limit)), base_url, client, token: token.to_string(), }) } async fn send( &self, url: Url, request: RequestBuilder, ) -> Result { loop { let delay = { self.rate_limit.try_read().unwrap().wait(url.clone()) }; delay.await; let clone = request.try_clone().expect("Can't clone body"); debug!("Requesting: {:?}", url); let response = clone.send().await?; let status_code = response.status().as_u16(); if status_code == 429 { let rate_limit: RateLimitNotice = response.json().await?; let delay = { self.rate_limit .try_write() .unwrap() .update(url.clone(), rate_limit) }; delay.await; continue; } if status_code != 200 { return Err(Failed(status_code)); } return Ok(response.json::().await?); } } pub async fn get_gateway_bot(&self) -> Result { let url = self.url("gateway/bot"); self.send(url.clone(), self.client.get(url)).await } pub async fn connect_websocket_bot(&self) -> Result<(DiscordSocket, DiscordReceiver), Error> { let gateway = self.get_gateway_bot().await?; Ok(connect_discord_ws(gateway.url.clone(), self.token.clone()).await?) } }