You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

189 lines
5.1 KiB
Rust

use crate::discord::api::Error::Failed;
use crate::discord::dto::{GatewayBot, RateLimitNotice};
use crate::discord::ws::{connect_discord_ws, DiscordReceiver, DiscordSocket};
use log::debug;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use reqwest::{Client, ClientBuilder, IntoUrl, Request, RequestBuilder, Url};
use serde::de::DeserializeOwned;
use std::borrow::BorrowMut;
use std::cmp::max;
use std::collections::HashMap;
use std::future::Future;
use std::ops::{Add, DerefMut, Sub};
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
use std::time::Instant;
use tokio::time::{delay_for, delay_until, Delay, Duration};
/// Client for the Discord HTTP API.
///
/// Clones share the same rate-limit state, because it lives behind an `Arc`.
#[derive(Clone)]
pub struct DiscordAPI {
    /// Rate-limit deadlines shared by every clone of this client.
    rate_limit: Arc<RwLock<RateLimit>>,
    /// Base URL that request paths are resolved against.
    base_url: Url,
    /// HTTP client that sends every request with the bot's `Authorization` header.
    client: Client,
    /// Raw bot token; also handed to the websocket connection.
    token: String,
}

impl std::fmt::Debug for DiscordAPI {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The bot token is a credential: never include it in debug/log output.
        f.debug_struct("DiscordAPI")
            .field("rate_limit", &self.rate_limit)
            .field("base_url", &self.base_url)
            .field("client", &self.client)
            .field("token", &"<redacted>")
            .finish()
    }
}
/// Rate-limit deadlines learned from HTTP 429 responses.
#[derive(Clone, Debug)]
struct RateLimit {
    /// No request is sent before this instant; raised by 429 notices flagged `global`.
    global: Instant,
    /// Per-URL "not before" instants, keyed by the full request URL.
    per_url: HashMap<Url, Instant>,
}
impl RateLimit {
fn new() -> RateLimit {
RateLimit {
global: Instant::now().sub(Duration::from_secs(5)),
per_url: HashMap::new(),
}
}
fn wait(&self, url: Url) -> OptionalDelay {
let now = Instant::now();
let mut delay_till = None;
if now < self.global {
delay_till = Some(self.global);
}
if let Some(per_url) = self.per_url.get(&url) {
if now < *per_url {
if let Some(other) = delay_till {
if other < *per_url {
delay_till = Some(*per_url);
}
} else {
delay_till = Some(*per_url)
}
}
}
delay_till.map(Into::into).map(delay_until).into()
}
fn update(&mut self, url: Url, notice: RateLimitNotice) -> OptionalDelay {
let after = Instant::now().add(Duration::from_millis(notice.retry_after));
if notice.global {
self.global = max(self.global, after);
}
self.per_url
.entry(url.clone())
.and_modify(|&mut mut x| {
let old = x.clone();
std::mem::replace(&mut x, max(old, after));
})
.or_insert(after);
self.wait(url)
}
}
/// Errors returned by `DiscordAPI` requests.
#[derive(Debug)]
pub enum Error {
    /// Sending the request failed, or a response body could not be decoded.
    RequestError(reqwest::Error),
    /// The API answered with an unexpected HTTP status code (the contained value).
    Failed(u16),
}

impl From<reqwest::Error> for Error {
    fn from(x: reqwest::Error) -> Self {
        Error::RequestError(x)
    }
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::RequestError(e) => write!(f, "Discord API request failed: {}", e),
            Error::Failed(status) => write!(f, "Discord API returned HTTP status {}", status),
        }
    }
}

impl std::error::Error for Error {
    /// Exposes the underlying `reqwest` error, if there is one.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::RequestError(e) => Some(e),
            Error::Failed(_) => None,
        }
    }
}
struct OptionalDelay(Option<Delay>);
impl Into<OptionalDelay> for Option<Delay> {
fn into(self) -> OptionalDelay {
OptionalDelay(self)
}
}
impl Future for OptionalDelay {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.0.is_none() {
Poll::Ready(())
} else {
let pin = unsafe { Pin::new_unchecked(self.get_mut().0.as_mut().unwrap()) };
pin.poll(cx)
}
}
}
impl DiscordAPI {
    /// Resolves `path` against the API base URL.
    ///
    /// `path` should be relative and have no leading `/` (e.g. `"gateway/bot"`) so that it is
    /// placed below `/api/`. A leading `/` resolves against the host root instead.
    ///
    /// # Panics
    ///
    /// Panics if `path` cannot be parsed as a URL reference relative to the base URL.
    pub fn url(&self, path: &str) -> Url {
        self.base_url
            .join(path)
            .expect("API path is not a valid relative URL")
    }

    /// Builds a client that sends `Authorization: Bot <token>` with every request.
    ///
    /// # Errors
    ///
    /// Returns the `reqwest` error if the underlying HTTP client cannot be built.
    ///
    /// # Panics
    ///
    /// Panics if `token` contains characters that are not allowed in an HTTP header value
    /// (for example a trailing newline left over from reading the token from a file).
    pub fn create(token: &str) -> reqwest::Result<DiscordAPI> {
        let mut auth = HeaderValue::from_str(&format!("Bot {}", token))
            .expect("bot token contains characters that are not valid in an HTTP header");
        // Mask the credential in `Debug` output of the header map.
        auth.set_sensitive(true);
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", auth);
        // NOTE(review): per reqwest docs, `connection_verbose` logs connection reads/writes at
        // TRACE level; that presumably includes the Authorization header above. Confirm and
        // consider disabling it outside of debugging.
        let client = ClientBuilder::new()
            .connection_verbose(true)
            .default_headers(headers)
            .build()?;
        let base_url =
            Url::parse("https://discordapp.com/api/").expect("constant base URL is valid");
        Ok(DiscordAPI {
            rate_limit: Arc::new(RwLock::new(RateLimit::new())),
            base_url,
            client,
            token: token.to_string(),
        })
    }

    /// Sends `request` (addressed to `url`) while honouring rate limits, and decodes the JSON body.
    ///
    /// Before each attempt, the call waits until the global deadline and the deadline for
    /// `url` have both passed. A 429 response is recorded, waited out and retried.
    ///
    /// # Errors
    ///
    /// * `Error::RequestError` if sending fails or a response body cannot be decoded.
    /// * `Error::Failed` with the status code for any status other than 200 or 429.
    ///
    /// # Panics
    ///
    /// Panics if `request` cannot be cloned (e.g. a streaming body) or if the
    /// rate-limit lock is poisoned.
    async fn send<T: DeserializeOwned>(
        &self,
        url: Url,
        request: RequestBuilder,
    ) -> Result<T, Error> {
        loop {
            // Use blocking `read`/`write` rather than `try_read`/`try_write`. The state is
            // shared by every clone of this client, so another task may hold the lock at the
            // same moment, and `try_*().unwrap()` would then panic. Each guard lives only
            // inside its block and is dropped before the following `.await`.
            let delay = {
                self.rate_limit
                    .read()
                    .expect("rate limit lock poisoned")
                    .wait(url.clone())
            };
            delay.await;
            // `send` consumes the builder, so every attempt needs its own copy.
            let attempt = request.try_clone().expect("Can't clone body");
            debug!("Requesting: {:?}", url);
            let response = attempt.send().await?;
            let status_code = response.status().as_u16();
            if status_code == 429 {
                let notice: RateLimitNotice = response.json().await?;
                let delay = {
                    self.rate_limit
                        .write()
                        .expect("rate limit lock poisoned")
                        .update(url.clone(), notice)
                };
                delay.await;
                continue;
            }
            if status_code != 200 {
                return Err(Failed(status_code));
            }
            return Ok(response.json::<T>().await?);
        }
    }

    /// Fetches `gateway/bot`, whose response carries the gateway URL used by
    /// `connect_websocket_bot`.
    pub async fn get_gateway_bot(&self) -> Result<GatewayBot, Error> {
        let url = self.url("gateway/bot");
        self.send(url.clone(), self.client.get(url)).await
    }

    /// Looks up the gateway URL and opens a websocket connection to it with this bot's token.
    pub async fn connect_websocket_bot(&self) -> Result<(DiscordSocket, DiscordReceiver), Error> {
        let gateway = self.get_gateway_bot().await?;
        Ok(connect_discord_ws(gateway.url.clone(), self.token.clone()).await?)
    }
}