You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
433 lines
15 KiB
Rust
433 lines
15 KiB
Rust
#![allow(dead_code)]
|
|
|
|
use crate::session_manager::SessionManager;
|
|
use bytes::{Buf, Bytes, BytesMut};
|
|
use crossbeam_channel::{bounded, unbounded, Receiver, Sender};
|
|
use polling::{Event, Poller};
|
|
use rand::seq::SliceRandom;
|
|
use std::cmp::max;
|
|
use std::collections::{HashMap, HashSet, VecDeque};
|
|
use std::convert::TryInto;
|
|
use std::io::{ErrorKind, Read, Write};
|
|
use std::net::{IpAddr, Ipv6Addr, SocketAddr, TcpListener, TcpStream};
|
|
use std::ops::Add;
|
|
use std::option::Option::Some;
|
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
use std::sync::Arc;
|
|
use std::thread::JoinHandle;
|
|
use std::time::{Duration, Instant};
|
|
use torment_core::infohash::v1::U160;
|
|
use torment_core::LookupFilter;
|
|
use torment_peer::message::{Handshake, Message};
|
|
use torment_peer::PeerProtocol;
|
|
use torment_peer::PeerProtocol::TCP;
|
|
|
|
/// Top-level client instance: owns the session state plus the channels to the background
/// `io/poll` thread that holds every peer socket.
pub struct TormentInstance {
    /// Our own peer id, copied from `SessionManager::id()` when the instance is created.
    id: U160,
    /// Port handed to the I/O thread, which binds its listener on it.
    port: u16,
    /// Set to `false` by `new`; not read anywhere in this file.
    dht_enabled: bool,
    /// Channels to and from the I/O thread, plus its join handle.
    connections: OffshoreConnections,
    /// Peers whose handshake has been received, keyed by socket address.
    /// Value is `(peer_id, info_hash)` as taken from that handshake.
    tracking: HashMap<SocketAddr, (U160, U160)>,
    /// Addresses we dialled ourselves (the connector thread writes our handshake first);
    /// an entry is removed when that peer's handshake arrives.
    handshake_sent: HashSet<SocketAddr>,
    /// Torrent/peer/tracker state that incoming handshakes and messages are fed into.
    session: SessionManager,
}
|
|
|
|
/// Handles to the background `io/poll` thread.
struct OffshoreConnections {
    /// Messages coming from the I/O thread: parsed handshakes, message payloads, disconnects.
    control_receiver: Receiver<ControlMessage>,
    /// Join handle of the I/O thread; it is never joined in this file.
    thread: JoinHandle<()>,
    /// Messages going to the I/O thread: bytes to write (`Data`) and dialled sockets (`Own`).
    thread_sender: Sender<ControlMessage>,
}
|
|
|
|
/// Messages exchanged between the logic loop, connector threads and the I/O thread.
enum ControlMessage {
    /// I/O -> logic: one received message payload with its 4-byte length prefix stripped.
    /// Logic -> I/O: raw bytes to write to the peer unchanged.
    Data(SocketAddr, Bytes),
    /// I/O -> logic: the first 68 bytes of a connection parsed as the peer's handshake.
    Handshake(Handshake, SocketAddr, PeerProtocol),
    /// Connector thread -> I/O: a freshly dialled socket (our handshake already written)
    /// for the I/O thread to adopt.
    Own(TcpStream, SocketAddr),
    /// I/O -> logic: the connection to this address was dropped.
    /// Also accepted by the I/O thread, which then closes and forgets that socket.
    Disown(SocketAddr),
}
|
|
|
|
/// A peer socket owned by the I/O thread, together with its read-side framing state.
struct Stream {
    /// Bytes read from the socket that do not yet form a complete handshake or message.
    buffer: BytesMut,
    /// Remote address; also the key of this stream in the I/O thread's stream map.
    addr: SocketAddr,
    stream: TcpStream,
    /// Set once the 68-byte handshake has been parsed; after that the buffer holds
    /// length-prefixed messages.
    handshake_received: bool,
    /// Poller key registered for this socket.
    id: usize,
    /// `true` for sockets we dialled (our handshake was written before hand-off),
    /// `false` for accepted inbound connections.
    handshake_sent: bool,
}
|
|
|
|
fn connection_pool_thread(
|
|
port: u16,
|
|
control_sender: Sender<ControlMessage>,
|
|
thread_receiver: Receiver<ControlMessage>,
|
|
) {
|
|
let mut streams: HashMap<SocketAddr, Stream> = HashMap::new();
|
|
let counter = AtomicUsize::new(0);
|
|
let mut index: HashMap<usize, SocketAddr> = HashMap::new();
|
|
let listener =
|
|
TcpListener::bind(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port)).unwrap();
|
|
|
|
let (notify_sender, notify_receiver) = bounded(1);
|
|
let (notify_restart_sender, notify_restart_receiver) = bounded(1);
|
|
|
|
let poller = Arc::new(Poller::new().unwrap());
|
|
let thread_receiver_clone = thread_receiver.clone();
|
|
let poller_clone = Arc::clone(&poller);
|
|
|
|
std::thread::Builder::new()
|
|
.name(format!("io/notify"))
|
|
.spawn(move || {
|
|
for message in thread_receiver {
|
|
// println!("Received channel message, notifying wait");
|
|
poller_clone.notify().unwrap();
|
|
notify_sender.send(message).unwrap();
|
|
notify_restart_receiver.recv().unwrap();
|
|
}
|
|
});
|
|
|
|
poller.insert(&listener).unwrap();
|
|
poller.interest(&listener, Event::readable(usize::MAX));
|
|
|
|
let mut buffer = [0u8; 1024 * 1024];
|
|
loop {
|
|
let mut events = vec![];
|
|
match poller.wait(&mut events, None) {
|
|
Ok(_) => {}
|
|
Err(err) if err.kind() == ErrorKind::Interrupted => {}
|
|
Err(err) => Err(err).unwrap(),
|
|
}
|
|
|
|
// println!("Events => {:?}", events);
|
|
|
|
for event in events {
|
|
if !event.readable {
|
|
continue;
|
|
}
|
|
|
|
if event.key == usize::MAX {
|
|
while let Ok((stream, addr)) = listener.accept() {
|
|
let id = counter.fetch_add(1, Ordering::AcqRel);
|
|
stream.set_nodelay(true).unwrap();
|
|
poller.insert(&stream).unwrap();
|
|
poller.interest(&stream, Event::readable(id)).unwrap();
|
|
streams.insert(
|
|
addr,
|
|
Stream {
|
|
addr,
|
|
handshake_received: false,
|
|
handshake_sent: false,
|
|
buffer: BytesMut::new(),
|
|
stream,
|
|
id,
|
|
},
|
|
);
|
|
index.insert(id, addr);
|
|
}
|
|
|
|
poller
|
|
.interest(&listener, Event::readable(usize::MAX))
|
|
.unwrap();
|
|
continue;
|
|
}
|
|
|
|
let addr = index[&event.key];
|
|
let size = streams
|
|
.get_mut(&addr)
|
|
.unwrap()
|
|
.stream
|
|
.read(&mut buffer)
|
|
.unwrap_or(0);
|
|
|
|
if size == 0 {
|
|
control_sender.send(ControlMessage::Disown(addr)).unwrap();
|
|
index.remove(&event.key);
|
|
if let Some(stream) = streams.remove(&addr) {
|
|
poller.remove(&stream.stream).unwrap();
|
|
}
|
|
println!("{} == Disconnected", addr);
|
|
continue;
|
|
}
|
|
|
|
// println!("{} => {} bytes", addr, size);
|
|
|
|
let stream = streams.get_mut(&addr).unwrap();
|
|
stream.buffer.extend_from_slice(&buffer[0..size]);
|
|
|
|
if stream.buffer.len() >= 68 && !stream.handshake_received {
|
|
let handshake = Handshake::from_bytes(stream.buffer.split_to(68).freeze());
|
|
let handshake = match handshake {
|
|
Err(_) => {
|
|
control_sender.send(ControlMessage::Disown(addr)).unwrap();
|
|
index.remove(&event.key);
|
|
if let Some(stream) = streams.remove(&addr) {
|
|
poller.remove(&stream.stream).unwrap();
|
|
}
|
|
|
|
continue;
|
|
}
|
|
Ok(hand) => hand,
|
|
};
|
|
|
|
println!("{} => {:?}", stream.addr, handshake);
|
|
control_sender
|
|
.send(ControlMessage::Handshake(handshake, stream.addr, TCP))
|
|
.unwrap();
|
|
|
|
stream.handshake_received = true;
|
|
}
|
|
|
|
poller
|
|
.interest(&stream.stream, Event::readable(event.key))
|
|
.unwrap();
|
|
|
|
if !stream.handshake_received {
|
|
continue;
|
|
}
|
|
|
|
while stream.buffer.len() >= 4 {
|
|
let len = u32::from_be_bytes(stream.buffer[..4].try_into().unwrap());
|
|
if len == 0 {
|
|
stream.buffer.advance(4);
|
|
continue;
|
|
}
|
|
|
|
if len + 4 > stream.buffer.len() as u32 {
|
|
break;
|
|
}
|
|
|
|
let mut message_bytes = stream.buffer.split_to((4 + len) as usize);
|
|
message_bytes.advance(4);
|
|
control_sender
|
|
.send(ControlMessage::Data(stream.addr, message_bytes.freeze()))
|
|
.unwrap();
|
|
}
|
|
}
|
|
|
|
let mut msgs = vec![];
|
|
if let Ok(message) = notify_receiver.try_recv() {
|
|
msgs.push(message);
|
|
}
|
|
|
|
while let Ok(message) = thread_receiver_clone.try_recv() {
|
|
msgs.push(message);
|
|
}
|
|
|
|
for message in msgs {
|
|
match message {
|
|
ControlMessage::Data(target, data) => {
|
|
let ok = if let Some(stream) = streams.get_mut(&target) {
|
|
if stream.stream.write_all(&data).is_err() {
|
|
control_sender
|
|
.send(ControlMessage::Disown(stream.addr))
|
|
.unwrap();
|
|
|
|
false
|
|
} else {
|
|
true
|
|
}
|
|
} else {
|
|
true
|
|
};
|
|
|
|
if !ok {
|
|
if let Some(stream) = streams.remove(&target) {
|
|
poller.remove(&stream.stream).unwrap();
|
|
index.remove(&stream.id);
|
|
println!("{} == Disconnected", stream.addr);
|
|
}
|
|
}
|
|
}
|
|
|
|
ControlMessage::Disown(addr) => {
|
|
if let Some(stream) = streams.remove(&addr) {
|
|
index.remove(&stream.id);
|
|
poller.remove(&stream.stream).unwrap();
|
|
}
|
|
}
|
|
|
|
ControlMessage::Own(stream, addr) => {
|
|
let id = counter.fetch_add(1, Ordering::AcqRel);
|
|
poller.insert(&stream).unwrap();
|
|
poller.interest(&stream, Event::readable(id)).unwrap();
|
|
streams.insert(
|
|
addr,
|
|
Stream {
|
|
addr,
|
|
handshake_received: false,
|
|
handshake_sent: true,
|
|
buffer: BytesMut::new(),
|
|
stream,
|
|
id,
|
|
},
|
|
);
|
|
index.insert(id, addr);
|
|
}
|
|
|
|
_ => {}
|
|
};
|
|
}
|
|
|
|
notify_restart_sender.send(()).unwrap();
|
|
}
|
|
}
|
|
|
|
impl TormentInstance {
|
|
pub fn new(port: u16, session: SessionManager) -> TormentInstance {
|
|
let (control_sender, control_receiver) = unbounded();
|
|
|
|
let (thread_sender, thread_receiver) = unbounded();
|
|
let thread = std::thread::Builder::new()
|
|
.name(format!("io/poll"))
|
|
.spawn(move || connection_pool_thread(port, control_sender, thread_receiver))
|
|
.unwrap();
|
|
|
|
TormentInstance {
|
|
id: session.id(),
|
|
port,
|
|
dht_enabled: false,
|
|
connections: OffshoreConnections {
|
|
control_receiver,
|
|
thread,
|
|
thread_sender,
|
|
},
|
|
tracking: Default::default(),
|
|
handshake_sent: Default::default(),
|
|
session,
|
|
}
|
|
}
|
|
|
|
pub fn tracker_logic(&mut self) {
|
|
self.session.announce();
|
|
self.session.tracker_manager_mut().house_keeping();
|
|
|
|
for torrent in self.session.torrents() {
|
|
let peer_count = self.session.peer_count(torrent);
|
|
if peer_count > 25 {
|
|
continue;
|
|
}
|
|
|
|
let mut peers = self
|
|
.session
|
|
.peer_storage
|
|
.get_peers(torrent, LookupFilter::All);
|
|
|
|
println!("have {} peers for {}", peers.len(), torrent);
|
|
|
|
peers.shuffle(&mut rand::thread_rng());
|
|
|
|
let mut todo = max(0, 25 - peer_count);
|
|
while todo > 0 && peers.len() > 0 {
|
|
let peer = peers.pop().unwrap();
|
|
if self.tracking.contains_key(&peer) {
|
|
continue;
|
|
}
|
|
|
|
let thread_sender = self.connections.thread_sender.clone();
|
|
let peer_id = self.id;
|
|
|
|
self.handshake_sent.insert(peer);
|
|
|
|
std::thread::spawn(move || {
|
|
// println!("Trying connection with {}", peer);
|
|
if let Ok(mut stream) = TcpStream::connect(peer) {
|
|
stream.set_nodelay(true).unwrap();
|
|
stream
|
|
.write_all(&Handshake::new(peer_id, torrent).to_bytes())
|
|
.unwrap();
|
|
|
|
// println!("{} <= Connecting", peer);
|
|
|
|
thread_sender
|
|
.send(ControlMessage::Own(stream, peer))
|
|
.unwrap();
|
|
}
|
|
});
|
|
|
|
todo -= 1;
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn logic_loop(&mut self) {
|
|
let mut next_house_keeping = Instant::now();
|
|
|
|
loop {
|
|
if next_house_keeping < Instant::now() {
|
|
self.session.house_keeping();
|
|
self.tracker_logic();
|
|
next_house_keeping = Instant::now().add(Duration::from_secs(10));
|
|
}
|
|
|
|
while next_house_keeping > Instant::now() {
|
|
if let Ok(message) = self
|
|
.connections
|
|
.control_receiver
|
|
.recv_timeout(Duration::from_secs(1))
|
|
{
|
|
match message {
|
|
ControlMessage::Handshake(handshake, addr, protocol) => {
|
|
let handshake_sent = self.handshake_sent.remove(&addr);
|
|
|
|
self.tracking
|
|
.insert(addr, (handshake.peer_id(), handshake.info_hash()));
|
|
if self.session.handshake(handshake, addr, protocol) {
|
|
if !handshake_sent {
|
|
self.connections
|
|
.thread_sender
|
|
.send(ControlMessage::Data(
|
|
addr,
|
|
Bytes::from(
|
|
Handshake::new(self.id, handshake.info_hash())
|
|
.to_bytes()
|
|
.to_vec(),
|
|
),
|
|
))
|
|
.unwrap();
|
|
}
|
|
}
|
|
}
|
|
|
|
ControlMessage::Data(sock_addr, data) => {
|
|
if let Some((peer_id, info_hash)) =
|
|
self.tracking.get(&sock_addr).copied()
|
|
{
|
|
let message = Message::from_bytes(data);
|
|
// println!("{} => {:?}", sock_addr, message);
|
|
if message.is_err() {
|
|
continue;
|
|
}
|
|
|
|
self.session.process(info_hash, peer_id, message.unwrap());
|
|
} else {
|
|
// println!("{} => ????", sock_addr);
|
|
}
|
|
}
|
|
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
let queue: HashMap<_, VecDeque<_>> =
|
|
self.session
|
|
.dump_queue()
|
|
.into_iter()
|
|
.fold(HashMap::new(), |mut map, item| {
|
|
map.entry(item.addr).or_default().push_back(item.message);
|
|
map
|
|
});
|
|
for (addr, messages) in queue {
|
|
let mut bytes = BytesMut::new();
|
|
|
|
for msg in messages {
|
|
// println!("{} <= {:?}", addr, msg);
|
|
bytes.extend_from_slice(&msg.to_length_prefixed_bytes());
|
|
}
|
|
|
|
self.connections
|
|
.thread_sender
|
|
.send(ControlMessage::Data(addr, bytes.freeze()))
|
|
.unwrap()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|