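// NOTE: the lines below were captured from the repository web page this file
// was copied from (topic-editor hint and file statistics); they are not source.
// You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
//
// 433 lines
// 15 KiB
// Rust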

#![allow(dead_code)]
use crate::session_manager::SessionManager;
use bytes::{Buf, Bytes, BytesMut};
use crossbeam_channel::{bounded, unbounded, Receiver, Sender};
use polling::{Event, Poller};
use rand::seq::SliceRandom;
use std::cmp::max;
use std::collections::{HashMap, HashSet, VecDeque};
use std::convert::TryInto;
use std::io::{ErrorKind, Read, Write};
use std::net::{IpAddr, Ipv6Addr, SocketAddr, TcpListener, TcpStream};
use std::ops::Add;
use std::option::Option::Some;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread::JoinHandle;
use std::time::{Duration, Instant};
use torment_core::infohash::v1::U160;
use torment_core::LookupFilter;
use torment_peer::message::{Handshake, Message};
use torment_peer::PeerProtocol;
use torment_peer::PeerProtocol::TCP;
/// Top-level driver: owns the peer session state and talks to the `io/poll`
/// thread that performs all socket I/O.
pub struct TormentInstance {
    /// Our peer id (taken from `session.id()`), sent in the handshakes we write.
    id: U160,
    /// Port the `io/poll` thread listens on for incoming peers.
    port: u16,
    /// Initialised to `false`; not read anywhere in the visible code.
    dht_enabled: bool,
    /// Channels to, and the join handle of, the `io/poll` thread.
    connections: OffshoreConnections,
    /// Peers whose handshake we received, keyed by socket address and mapped to
    /// `(peer_id, info_hash)` from that handshake; used to route incoming frames
    /// to the session.
    tracking: HashMap<SocketAddr, (U160, U160)>,
    /// Addresses we dialed and wrote our handshake to; checked when the peer's
    /// handshake arrives so we do not send ours a second time.
    handshake_sent: HashSet<SocketAddr>,
    /// Session state consulted for torrents, peers and trackers, and the source
    /// of outgoing messages (`dump_queue`).
    session: SessionManager,
}
/// Logic-side endpoints of the `io/poll` thread.
struct OffshoreConnections {
    /// Handshakes, received frames and disconnects reported by the I/O thread.
    control_receiver: Receiver<ControlMessage>,
    /// Handle of the `io/poll` thread (not joined in the visible code).
    thread: JoinHandle<()>,
    /// Commands for the I/O thread: bytes to write (`Data`) and dialed streams
    /// to adopt (`Own`).
    thread_sender: Sender<ControlMessage>,
}
/// Messages exchanged between the logic side and the `io/poll` thread; the same
/// enum is used in both directions.
enum ControlMessage {
    /// I/O -> logic: payload of one received length-prefixed frame (prefix
    /// removed). Logic -> I/O: raw bytes to write to the peer as-is.
    Data(SocketAddr, Bytes),
    /// I/O -> logic: the peer's 68-byte handshake was parsed successfully.
    Handshake(Handshake, SocketAddr, PeerProtocol),
    /// Logic -> I/O: an outbound connection, with our handshake already
    /// written, for the I/O thread to take over.
    Own(TcpStream, SocketAddr),
    /// I/O -> logic: the stream closed or failed. When received by the I/O
    /// thread, the stream for that address is dropped.
    Disown(SocketAddr),
}
/// Per-connection state kept by the `io/poll` thread.
struct Stream {
    /// Bytes read from the socket that have not yet been consumed as the
    /// handshake or as a complete frame.
    buffer: BytesMut,
    /// Remote address; also this stream's key in the I/O thread's map.
    addr: SocketAddr,
    /// The connected socket.
    stream: TcpStream,
    /// Whether the peer's 68-byte handshake has been parsed; frames are only
    /// forwarded after that.
    handshake_received: bool,
    /// Poller key that identifies this stream in readiness events.
    id: usize,
    /// `true` for streams adopted via `Own` (we wrote our handshake), `false`
    /// for accepted ones; not read by the visible code.
    handshake_sent: bool,
}
fn connection_pool_thread(
port: u16,
control_sender: Sender<ControlMessage>,
thread_receiver: Receiver<ControlMessage>,
) {
let mut streams: HashMap<SocketAddr, Stream> = HashMap::new();
let counter = AtomicUsize::new(0);
let mut index: HashMap<usize, SocketAddr> = HashMap::new();
let listener =
TcpListener::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port)).unwrap();
let (notify_sender, notify_receiver) = bounded(1);
let (notify_restart_sender, notify_restart_receiver) = bounded(1);
let poller = Arc::new(Poller::new().unwrap());
let thread_receiver_clone = thread_receiver.clone();
let poller_clone = Arc::clone(&poller);
std::thread::Builder::new()
.name(format!("io/notify"))
.spawn(move || {
for message in thread_receiver {
// println!("Received channel message, notifying wait");
poller_clone.notify().unwrap();
notify_sender.send(message).unwrap();
notify_restart_receiver.recv().unwrap();
}
});
poller.insert(&listener).unwrap();
poller.interest(&listener, Event::readable(usize::MAX));
let mut buffer = [0u8; 1024 * 1024];
loop {
let mut events = vec![];
match poller.wait(&mut events, None) {
Ok(_) => {}
Err(err) if err.kind() == ErrorKind::Interrupted => {}
Err(err) => Err(err).unwrap(),
}
// println!("Events => {:?}", events);
for event in events {
if !event.readable {
continue;
}
if event.key == usize::MAX {
while let Ok((stream, addr)) = listener.accept() {
let id = counter.fetch_add(1, Ordering::AcqRel);
stream.set_nodelay(true).unwrap();
poller.insert(&stream).unwrap();
poller.interest(&stream, Event::readable(id)).unwrap();
streams.insert(
addr,
Stream {
addr,
handshake_received: false,
handshake_sent: false,
buffer: BytesMut::new(),
stream,
id,
},
);
index.insert(id, addr);
}
poller
.interest(&listener, Event::readable(usize::MAX))
.unwrap();
continue;
}
let addr = index[&event.key];
let size = streams
.get_mut(&addr)
.unwrap()
.stream
.read(&mut buffer)
.unwrap_or(0);
if size == 0 {
control_sender.send(ControlMessage::Disown(addr)).unwrap();
index.remove(&event.key);
if let Some(stream) = streams.remove(&addr) {
poller.remove(&stream.stream).unwrap();
}
println!("{} == Disconnected", addr);
continue;
}
// println!("{} => {} bytes", addr, size);
let stream = streams.get_mut(&addr).unwrap();
stream.buffer.extend_from_slice(&buffer[0..size]);
if stream.buffer.len() >= 68 && !stream.handshake_received {
let handshake = Handshake::from_bytes(stream.buffer.split_to(68).freeze());
let handshake = match handshake {
Err(_) => {
control_sender.send(ControlMessage::Disown(addr)).unwrap();
index.remove(&event.key);
if let Some(stream) = streams.remove(&addr) {
poller.remove(&stream.stream).unwrap();
}
continue;
}
Ok(hand) => hand,
};
println!("{} => {:?}", stream.addr, handshake);
control_sender
.send(ControlMessage::Handshake(handshake, stream.addr, TCP))
.unwrap();
stream.handshake_received = true;
}
poller
.interest(&stream.stream, Event::readable(event.key))
.unwrap();
if !stream.handshake_received {
continue;
}
while stream.buffer.len() >= 4 {
let len = u32::from_be_bytes(stream.buffer[..4].try_into().unwrap());
if len == 0 {
stream.buffer.advance(4);
continue;
}
if len + 4 > stream.buffer.len() as u32 {
break;
}
let mut message_bytes = stream.buffer.split_to((4 + len) as usize);
message_bytes.advance(4);
control_sender
.send(ControlMessage::Data(stream.addr, message_bytes.freeze()))
.unwrap();
}
}
let mut msgs = vec![];
if let Ok(message) = notify_receiver.try_recv() {
msgs.push(message);
}
while let Ok(message) = thread_receiver_clone.try_recv() {
msgs.push(message);
}
for message in msgs {
match message {
ControlMessage::Data(target, data) => {
let ok = if let Some(stream) = streams.get_mut(&target) {
if stream.stream.write_all(&data).is_err() {
control_sender
.send(ControlMessage::Disown(stream.addr))
.unwrap();
false
} else {
true
}
} else {
true
};
if !ok {
if let Some(stream) = streams.remove(&target) {
poller.remove(&stream.stream).unwrap();
index.remove(&stream.id);
println!("{} == Disconnected", stream.addr);
}
}
}
ControlMessage::Disown(addr) => {
if let Some(stream) = streams.remove(&addr) {
index.remove(&stream.id);
poller.remove(&stream.stream).unwrap();
}
}
ControlMessage::Own(stream, addr) => {
let id = counter.fetch_add(1, Ordering::AcqRel);
poller.insert(&stream).unwrap();
poller.interest(&stream, Event::readable(id)).unwrap();
streams.insert(
addr,
Stream {
addr,
handshake_received: false,
handshake_sent: true,
buffer: BytesMut::new(),
stream,
id,
},
);
index.insert(id, addr);
}
_ => {}
};
}
notify_restart_sender.send(()).unwrap();
}
}
impl TormentInstance {
pub fn new(port: u16, session: SessionManager) -> TormentInstance {
let (control_sender, control_receiver) = unbounded();
let (thread_sender, thread_receiver) = unbounded();
let thread = std::thread::Builder::new()
.name(format!("io/poll"))
.spawn(move || connection_pool_thread(port, control_sender, thread_receiver))
.unwrap();
TormentInstance {
id: session.id(),
port,
dht_enabled: false,
connections: OffshoreConnections {
control_receiver,
thread,
thread_sender,
},
tracking: Default::default(),
handshake_sent: Default::default(),
session,
}
}
pub fn tracker_logic(&mut self) {
self.session.announce();
self.session.tracker_manager_mut().house_keeping();
for torrent in self.session.torrents() {
let peer_count = self.session.peer_count(torrent);
if peer_count > 25 {
continue;
}
let mut peers = self
.session
.peer_storage
.get_peers(torrent, LookupFilter::All);
println!("have {} peers for {}", peers.len(), torrent);
peers.shuffle(&mut rand::thread_rng());
let mut todo = max(0, 25 - peer_count);
while todo > 0 && peers.len() > 0 {
let peer = peers.pop().unwrap();
if self.tracking.contains_key(&peer) {
continue;
}
let thread_sender = self.connections.thread_sender.clone();
let peer_id = self.id;
self.handshake_sent.insert(peer);
std::thread::spawn(move || {
// println!("Trying connection with {}", peer);
if let Ok(mut stream) = TcpStream::connect(peer) {
stream.set_nodelay(true).unwrap();
stream
.write_all(&Handshake::new(peer_id, torrent).to_bytes())
.unwrap();
// println!("{} <= Connecting", peer);
thread_sender
.send(ControlMessage::Own(stream, peer))
.unwrap();
}
});
todo -= 1;
}
}
}
pub fn logic_loop(&mut self) {
let mut next_house_keeping = Instant::now();
loop {
if next_house_keeping < Instant::now() {
self.session.house_keeping();
self.tracker_logic();
next_house_keeping = Instant::now().add(Duration::from_secs(10));
}
while next_house_keeping > Instant::now() {
if let Ok(message) = self
.connections
.control_receiver
.recv_timeout(Duration::from_secs(1))
{
match message {
ControlMessage::Handshake(handshake, addr, protocol) => {
let handshake_sent = self.handshake_sent.remove(&addr);
self.tracking
.insert(addr, (handshake.peer_id(), handshake.info_hash()));
if self.session.handshake(handshake, addr, protocol) {
if !handshake_sent {
self.connections
.thread_sender
.send(ControlMessage::Data(
addr,
Bytes::from(
Handshake::new(self.id, handshake.info_hash())
.to_bytes()
.to_vec(),
),
))
.unwrap();
}
}
}
ControlMessage::Data(sock_addr, data) => {
if let Some((peer_id, info_hash)) =
self.tracking.get(&sock_addr).copied()
{
let message = Message::from_bytes(data);
// println!("{} => {:?}", sock_addr, message);
if message.is_err() {
continue;
}
self.session.process(info_hash, peer_id, message.unwrap());
} else {
// println!("{} => ????", sock_addr);
}
}
_ => {}
}
}
let queue: HashMap<_, VecDeque<_>> =
self.session
.dump_queue()
.into_iter()
.fold(HashMap::new(), |mut map, item| {
map.entry(item.addr).or_default().push_back(item.message);
map
});
for (addr, messages) in queue {
let mut bytes = BytesMut::new();
for msg in messages {
// println!("{} <= {:?}", addr, msg);
bytes.extend_from_slice(&msg.to_length_prefixed_bytes());
}
self.connections
.thread_sender
.send(ControlMessage::Data(addr, bytes.freeze()))
.unwrap()
}
}
}
}
}