#![allow(dead_code)] use crate::session_manager::SessionManager; use bytes::{Buf, Bytes, BytesMut}; use crossbeam_channel::{bounded, unbounded, Receiver, Sender}; use polling::{Event, Poller}; use rand::seq::SliceRandom; use std::cmp::max; use std::collections::{HashMap, HashSet, VecDeque}; use std::convert::TryInto; use std::io::{ErrorKind, Read, Write}; use std::net::{IpAddr, Ipv6Addr, SocketAddr, TcpListener, TcpStream}; use std::ops::Add; use std::option::Option::Some; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use std::thread::JoinHandle; use std::time::{Duration, Instant}; use torment_core::infohash::v1::U160; use torment_core::LookupFilter; use torment_peer::message::{Handshake, Message}; use torment_peer::PeerProtocol; use torment_peer::PeerProtocol::TCP; pub struct TormentInstance { id: U160, port: u16, dht_enabled: bool, connections: OffshoreConnections, tracking: HashMap, handshake_sent: HashSet, session: SessionManager, } struct OffshoreConnections { control_receiver: Receiver, thread: JoinHandle<()>, thread_sender: Sender, } enum ControlMessage { Data(SocketAddr, Bytes), Handshake(Handshake, SocketAddr, PeerProtocol), Own(TcpStream, SocketAddr), Disown(SocketAddr), } struct Stream { buffer: BytesMut, addr: SocketAddr, stream: TcpStream, handshake_received: bool, id: usize, handshake_sent: bool, } fn connection_pool_thread( port: u16, control_sender: Sender, thread_receiver: Receiver, ) { let mut streams: HashMap = HashMap::new(); let counter = AtomicUsize::new(0); let mut index: HashMap = HashMap::new(); let listener = TcpListener::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port)).unwrap(); let (notify_sender, notify_receiver) = bounded(1); let (notify_restart_sender, notify_restart_receiver) = bounded(1); let poller = Arc::new(Poller::new().unwrap()); let thread_receiver_clone = thread_receiver.clone(); let poller_clone = Arc::clone(&poller); std::thread::Builder::new() .name(format!("io/notify")) .spawn(move || { for message in thread_receiver { // println!("Received channel message, notifying wait"); poller_clone.notify().unwrap(); notify_sender.send(message).unwrap(); notify_restart_receiver.recv().unwrap(); } }); poller.insert(&listener).unwrap(); poller.interest(&listener, Event::readable(usize::MAX)); let mut buffer = [0u8; 1024 * 1024]; loop { let mut events = vec![]; match poller.wait(&mut events, None) { Ok(_) => {} Err(err) if err.kind() == ErrorKind::Interrupted => {} Err(err) => Err(err).unwrap(), } // println!("Events => {:?}", events); for event in events { if !event.readable { continue; } if event.key == usize::MAX { while let Ok((stream, addr)) = listener.accept() { let id = counter.fetch_add(1, Ordering::AcqRel); stream.set_nodelay(true).unwrap(); poller.insert(&stream).unwrap(); poller.interest(&stream, Event::readable(id)).unwrap(); streams.insert( addr, Stream { addr, handshake_received: false, handshake_sent: false, buffer: BytesMut::new(), stream, id, }, ); index.insert(id, addr); } poller .interest(&listener, Event::readable(usize::MAX)) .unwrap(); continue; } let addr = index[&event.key]; let size = streams .get_mut(&addr) .unwrap() .stream .read(&mut buffer) .unwrap_or(0); if size == 0 { control_sender.send(ControlMessage::Disown(addr)).unwrap(); index.remove(&event.key); if let Some(stream) = streams.remove(&addr) { poller.remove(&stream.stream).unwrap(); } println!("{} == Disconnected", addr); continue; } // println!("{} => {} bytes", addr, size); let stream = streams.get_mut(&addr).unwrap(); stream.buffer.extend_from_slice(&buffer[0..size]); if stream.buffer.len() >= 68 && !stream.handshake_received { let handshake = Handshake::from_bytes(stream.buffer.split_to(68).freeze()); let handshake = match handshake { Err(_) => { control_sender.send(ControlMessage::Disown(addr)).unwrap(); index.remove(&event.key); if let Some(stream) = streams.remove(&addr) { poller.remove(&stream.stream).unwrap(); } continue; } Ok(hand) => hand, }; println!("{} => {:?}", stream.addr, handshake); control_sender .send(ControlMessage::Handshake(handshake, stream.addr, TCP)) .unwrap(); stream.handshake_received = true; } poller .interest(&stream.stream, Event::readable(event.key)) .unwrap(); if !stream.handshake_received { continue; } while stream.buffer.len() >= 4 { let len = u32::from_be_bytes(stream.buffer[..4].try_into().unwrap()); if len == 0 { stream.buffer.advance(4); continue; } if len + 4 > stream.buffer.len() as u32 { break; } let mut message_bytes = stream.buffer.split_to((4 + len) as usize); message_bytes.advance(4); control_sender .send(ControlMessage::Data(stream.addr, message_bytes.freeze())) .unwrap(); } } let mut msgs = vec![]; if let Ok(message) = notify_receiver.try_recv() { msgs.push(message); } while let Ok(message) = thread_receiver_clone.try_recv() { msgs.push(message); } for message in msgs { match message { ControlMessage::Data(target, data) => { let ok = if let Some(stream) = streams.get_mut(&target) { if stream.stream.write_all(&data).is_err() { control_sender .send(ControlMessage::Disown(stream.addr)) .unwrap(); false } else { true } } else { true }; if !ok { if let Some(stream) = streams.remove(&target) { poller.remove(&stream.stream).unwrap(); index.remove(&stream.id); println!("{} == Disconnected", stream.addr); } } } ControlMessage::Disown(addr) => { if let Some(stream) = streams.remove(&addr) { index.remove(&stream.id); poller.remove(&stream.stream).unwrap(); } } ControlMessage::Own(stream, addr) => { let id = counter.fetch_add(1, Ordering::AcqRel); poller.insert(&stream).unwrap(); poller.interest(&stream, Event::readable(id)).unwrap(); streams.insert( addr, Stream { addr, handshake_received: false, handshake_sent: true, buffer: BytesMut::new(), stream, id, }, ); index.insert(id, addr); } _ => {} }; } notify_restart_sender.send(()).unwrap(); } } impl TormentInstance { pub fn new(port: u16, session: SessionManager) -> TormentInstance { let (control_sender, control_receiver) = unbounded(); let (thread_sender, thread_receiver) = unbounded(); let thread = std::thread::Builder::new() .name(format!("io/poll")) .spawn(move || connection_pool_thread(port, control_sender, thread_receiver)) .unwrap(); TormentInstance { id: session.id(), port, dht_enabled: false, connections: OffshoreConnections { control_receiver, thread, thread_sender, }, tracking: Default::default(), handshake_sent: Default::default(), session, } } pub fn tracker_logic(&mut self) { self.session.announce(); self.session.tracker_manager_mut().house_keeping(); for torrent in self.session.torrents() { let peer_count = self.session.peer_count(torrent); if peer_count > 25 { continue; } let mut peers = self .session .peer_storage .get_peers(torrent, LookupFilter::All); println!("have {} peers for {}", peers.len(), torrent); peers.shuffle(&mut rand::thread_rng()); let mut todo = max(0, 25 - peer_count); while todo > 0 && peers.len() > 0 { let peer = peers.pop().unwrap(); if self.tracking.contains_key(&peer) { continue; } let thread_sender = self.connections.thread_sender.clone(); let peer_id = self.id; self.handshake_sent.insert(peer); std::thread::spawn(move || { // println!("Trying connection with {}", peer); if let Ok(mut stream) = TcpStream::connect(peer) { stream.set_nodelay(true).unwrap(); stream .write_all(&Handshake::new(peer_id, torrent).to_bytes()) .unwrap(); // println!("{} <= Connecting", peer); thread_sender .send(ControlMessage::Own(stream, peer)) .unwrap(); } }); todo -= 1; } } } pub fn logic_loop(&mut self) { let mut next_house_keeping = Instant::now(); loop { if next_house_keeping < Instant::now() { self.session.house_keeping(); self.tracker_logic(); next_house_keeping = Instant::now().add(Duration::from_secs(10)); } while next_house_keeping > Instant::now() { if let Ok(message) = self .connections .control_receiver .recv_timeout(Duration::from_secs(1)) { match message { ControlMessage::Handshake(handshake, addr, protocol) => { let handshake_sent = self.handshake_sent.remove(&addr); self.tracking .insert(addr, (handshake.peer_id(), handshake.info_hash())); if self.session.handshake(handshake, addr, protocol) { if !handshake_sent { self.connections .thread_sender .send(ControlMessage::Data( addr, Bytes::from( Handshake::new(self.id, handshake.info_hash()) .to_bytes() .to_vec(), ), )) .unwrap(); } } } ControlMessage::Data(sock_addr, data) => { if let Some((peer_id, info_hash)) = self.tracking.get(&sock_addr).copied() { let message = Message::from_bytes(data); // println!("{} => {:?}", sock_addr, message); if message.is_err() { continue; } self.session.process(info_hash, peer_id, message.unwrap()); } else { // println!("{} => ????", sock_addr); } } _ => {} } } let queue: HashMap<_, VecDeque<_>> = self.session .dump_queue() .into_iter() .fold(HashMap::new(), |mut map, item| { map.entry(item.addr).or_default().push_back(item.message); map }); for (addr, messages) in queue { let mut bytes = BytesMut::new(); for msg in messages { // println!("{} <= {:?}", addr, msg); bytes.extend_from_slice(&msg.to_length_prefixed_bytes()); } self.connections .thread_sender .send(ControlMessage::Data(addr, bytes.freeze())) .unwrap() } } } } }