You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

293 lines
8.3 KiB
Rust

use bytes::{Bytes, BytesMut};
use std::convert::TryInto;
use std::error::Error;
use std::fmt::{Debug, Display, Formatter};
use torment_core::infohash::v1::U160;
use torment_core::infohash::InfoHashCapable;
use torment_core::Bitfield;
/// A peer-wire message as decoded by `Message::from_bytes` and encoded by
/// `Message::to_bytes`: the first byte is the message ID and the rest is the
/// ID-specific body (the 4-byte length prefix is handled separately by
/// `Message::to_length_prefixed_bytes`).
///
/// The derived `Ord`/`PartialOrd` compare variants by declaration order, so
/// reordering the variants below changes how messages sort.
#[derive(Clone, Debug, Ord, PartialOrd, Eq, PartialEq)]
pub enum Message {
/// Wire ID 0; encoded as the ID byte alone.
Choke,
/// Wire ID 1; encoded as the ID byte alone.
Unchoke,
/// Wire ID 2; encoded as the ID byte alone.
Interested,
/// Wire ID 3; encoded as the ID byte alone.
NotInterested,
/// Wire ID 4; body is a single big-endian `u32` piece number.
Have(u32),
/// Wire ID 5; body is the raw bitfield bytes.
Bitfield(Bitfield),
/// Wire ID 6; body is index, offset and length as big-endian `u32`s.
Request(SelectionMessage),
/// Wire ID 8; same body layout as `Request`.
Cancel(SelectionMessage),
/// Wire ID 7; index and offset as big-endian `u32`s, followed by the block data.
Piece(PieceMessage),
}
/// Reasons a peer-wire message or handshake payload can fail to parse.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessageParsingError {
    /// The input ended before all bytes required by the message were present.
    TooShort,
    /// The message ID byte is not one this module decodes.
    UnknownType,
    /// A handshake did not start with the `\x13BitTorrent protocol` prefix.
    InvalidPrefix,
}

impl Error for MessageParsingError {}

impl Display for MessageParsingError {
    /// Human-readable description; use `{:?}` for the variant name.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            MessageParsingError::TooShort => "message too short",
            MessageParsingError::UnknownType => "unknown message type",
            MessageParsingError::InvalidPrefix => "invalid handshake protocol prefix",
        };
        f.write_str(text)
    }
}
impl Message {
pub fn from_bytes(bytes: Bytes) -> Result<Message, MessageParsingError> {
Ok(match bytes[0] {
0 => Message::Choke,
1 => Message::Unchoke,
2 => Message::Interested,
3 => Message::NotInterested,
4 => {
if bytes.len() < 5 {
return Err(MessageParsingError::TooShort);
}
Message::Have(u32::from_be_bytes(bytes[1..5].try_into().unwrap()))
}
5 => Message::Bitfield(Bitfield::new(bytes.slice(1..), bytes.len() - 1)),
6 | 8 => {
if bytes.len() < 13 {
return Err(MessageParsingError::TooShort);
}
let selection = SelectionMessage {
index: u32::from_be_bytes(bytes[1..5].try_into().unwrap()),
offset: u32::from_be_bytes(bytes[5..9].try_into().unwrap()),
length: u32::from_be_bytes(bytes[9..13].try_into().unwrap()),
};
if bytes[0] == 6 {
Message::Request(selection)
} else {
Message::Cancel(selection)
}
}
7 => {
if bytes.len() < 9 {
return Err(MessageParsingError::TooShort);
}
Message::Piece(PieceMessage {
index: u32::from_be_bytes(bytes[1..5].try_into().unwrap()),
offset: u32::from_be_bytes(bytes[5..9].try_into().unwrap()),
piece: bytes.slice(9..),
})
}
_ => return Err(MessageParsingError::UnknownType),
})
}
pub fn to_length_prefixed_bytes(&self) -> Vec<u8> {
let buffer = self.to_bytes();
let size = (buffer.len() as u32).to_be_bytes();
let mut map = Vec::with_capacity(buffer.len() + 4);
map.extend_from_slice(&size);
map.extend(buffer);
map
}
pub fn to_bytes(&self) -> Bytes {
match self {
Message::Choke => Bytes::from_static(&[0u8]),
Message::Unchoke => Bytes::from_static(&[1u8]),
Message::Interested => Bytes::from_static(&[2u8]),
Message::NotInterested => Bytes::from_static(&[3u8]),
Message::Have(piece) => {
let mut buffer = BytesMut::with_capacity(5);
buffer.extend(&[4]);
buffer.extend_from_slice(&piece.to_be_bytes());
buffer.freeze()
}
Message::Bitfield(bitfield) => {
let mut buffer = BytesMut::with_capacity(1 + bitfield.as_ref().len());
buffer.extend(&[5]);
buffer.extend_from_slice(bitfield.as_ref());
buffer.freeze()
}
Message::Request(request) => {
let mut buffer = BytesMut::with_capacity(13);
buffer.extend(&[6]);
buffer.extend_from_slice(&request.index.to_be_bytes());
buffer.extend_from_slice(&request.offset.to_be_bytes());
buffer.extend_from_slice(&request.length.to_be_bytes());
buffer.freeze()
}
Message::Piece(piece) => {
let mut buffer = BytesMut::with_capacity(9 + piece.piece.len());
buffer.extend(&[7]);
buffer.extend_from_slice(&piece.index.to_be_bytes());
buffer.extend_from_slice(&piece.offset.to_be_bytes());
buffer.extend_from_slice(piece.piece.as_ref());
buffer.freeze()
}
Message::Cancel(cancel) => {
let mut buffer = BytesMut::with_capacity(13);
buffer.extend(&[8]);
buffer.extend_from_slice(&cancel.index.to_be_bytes());
buffer.extend_from_slice(&cancel.offset.to_be_bytes());
buffer.extend_from_slice(&cancel.length.to_be_bytes());
buffer.freeze()
}
}
}
}
/// The fixed 68-byte BitTorrent handshake.
///
/// Layout: 20-byte protocol prefix, 8 reserved bytes, 20-byte info hash,
/// 20-byte peer ID.
#[derive(Clone, Debug, Copy)]
pub struct Handshake {
    reserved: [u8; 8],
    info_hash: U160,
    peer_id: U160,
}

impl Handshake {
    /// Length byte (0x13 = 19) followed by the protocol string: bytes 0..20 of every handshake.
    const PROTOCOL_PREFIX: &'static [u8; 20] = b"\x13BitTorrent protocol";

    /// Parses a handshake from the first 68 bytes of `bytes`; any extra bytes are ignored.
    ///
    /// # Errors
    ///
    /// * [`MessageParsingError::TooShort`] if fewer than 68 bytes are given.
    /// * [`MessageParsingError::InvalidPrefix`] if the protocol prefix does not match.
    pub fn from_bytes(bytes: Bytes) -> Result<Handshake, MessageParsingError> {
        if bytes.len() < 68 {
            return Err(MessageParsingError::TooShort);
        }
        if &bytes[0..20] != Self::PROTOCOL_PREFIX {
            return Err(MessageParsingError::InvalidPrefix);
        }
        let mut reserved = [0u8; 8];
        reserved.copy_from_slice(&bytes[20..28]);
        Ok(Handshake {
            reserved,
            info_hash: U160::from_bytes(&bytes[28..48]).expect("slice is exactly 20 bytes"),
            peer_id: U160::from_bytes(&bytes[48..68]).expect("slice is exactly 20 bytes"),
        })
    }

    /// Serializes the handshake into its 68-byte wire form. Inverse of [`Handshake::from_bytes`].
    pub fn to_bytes(&self) -> [u8; 68] {
        let mut header = [0u8; 68];
        header[0..20].copy_from_slice(Self::PROTOCOL_PREFIX);
        // Write the stored reserved bytes so a parsed handshake re-serializes unchanged;
        // handshakes built with `new` carry all-zero reserved bytes.
        header[20..28].copy_from_slice(&self.reserved);
        header[28..48].copy_from_slice(self.info_hash.to_bytes().as_ref());
        header[48..68].copy_from_slice(self.peer_id.to_bytes().as_ref());
        header
    }

    /// Builds a handshake with all reserved bytes set to zero.
    pub fn new(peer_id: U160, info_hash: U160) -> Handshake {
        Handshake {
            reserved: [0; 8],
            info_hash,
            peer_id,
        }
    }

    /// The 20-byte peer ID (bytes 48..68 on the wire).
    pub fn peer_id(&self) -> U160 {
        self.peer_id
    }

    /// The 20-byte info hash (bytes 28..48 on the wire).
    pub fn info_hash(&self) -> U160 {
        self.info_hash
    }

    /// The 8 reserved bytes (bytes 20..28 on the wire).
    pub fn reserved(&self) -> [u8; 8] {
        self.reserved
    }
}
/// Identifies a block within a piece by piece `index`, `offset` within that
/// piece, and `length`. This is the body of `Request` and `Cancel` messages.
///
/// Field order is significant: the derived `Ord` compares `index`, then
/// `offset`, then `length`.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct SelectionMessage {
    index: u32,
    offset: u32,
    length: u32,
}

impl SelectionMessage {
    /// Builds a selection from its three components.
    pub fn new(index: u32, offset: u32, length: u32) -> SelectionMessage {
        Self { index, offset, length }
    }

    /// Piece index of the selected block.
    pub fn index(&self) -> u32 {
        let Self { index, .. } = *self;
        index
    }

    /// Offset of the selected block within its piece.
    pub fn offset(&self) -> u32 {
        let Self { offset, .. } = *self;
        offset
    }

    /// Length of the selected block.
    pub fn length(&self) -> u32 {
        let Self { length, .. } = *self;
        length
    }
}
#[derive(Clone, Ord, PartialOrd, Eq, PartialEq)]
pub struct PieceMessage {
pub index: u32,
pub offset: u32,
pub piece: Bytes,
}
impl Debug for PieceMessage {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PieceMessage")
.field("index", &self.index)
.field("offset", &self.offset)
.field("piece", &self.piece.len())
.finish()
}
}
impl PieceMessage {
pub fn index(&self) -> u32 {
self.index
}
pub fn offset(&self) -> u32 {
self.offset
}
pub fn length(&self) -> u32 {
self.piece.len() as u32
}
pub fn piece(&self) -> Bytes {
self.piece.clone()
}
}
#[cfg(test)]
mod tests {
    use crate::message::{Handshake, Message, MessageParsingError, PieceMessage, SelectionMessage};
    use bytes::Bytes;
    use torment_core::infohash::v1::U160;
    use torment_core::infohash::InfoHashCapable;
    use torment_core::Bitfield;

    /// Every message kind survives `to_bytes` -> `from_bytes`.
    #[test]
    fn round_trip() {
        let msgs = [
            Message::Choke,
            Message::Unchoke,
            Message::Interested,
            Message::NotInterested,
            Message::Have(42),
            Message::Bitfield(Bitfield::with_size(4)),
            Message::Request(SelectionMessage {
                index: 69,
                offset: 1337,
                length: 42,
            }),
            Message::Piece(PieceMessage {
                index: 69,
                offset: 1337,
                piece: Bytes::from_static(b"hewwo"),
            }),
            Message::Cancel(SelectionMessage {
                index: 69,
                offset: 1337,
                length: 42,
            }),
        ];
        for msg in &msgs {
            assert_eq!(msg, &Message::from_bytes(msg.to_bytes()).unwrap())
        }
    }

    /// The frame is a big-endian `u32` payload length followed by the payload.
    #[test]
    fn length_prefix_is_big_endian_payload_length() {
        assert_eq!(
            Message::Have(1).to_length_prefixed_bytes(),
            vec![0, 0, 0, 5, 4, 0, 0, 0, 1]
        );
        assert_eq!(Message::Choke.to_length_prefixed_bytes(), vec![0, 0, 0, 1, 0]);
    }

    /// Fixed-size bodies shorter than required are rejected, not sliced out of bounds.
    #[test]
    fn truncated_bodies_are_rejected() {
        let cases: [&[u8]; 4] = [
            &[4, 0, 0, 0],
            &[6, 0, 0, 0, 0],
            &[7, 0, 0, 0, 0, 0, 0, 0],
            &[8],
        ];
        for raw in cases.iter() {
            let result = Message::from_bytes(Bytes::from(raw.to_vec()));
            assert!(
                matches!(result, Err(MessageParsingError::TooShort)),
                "{:?}",
                raw
            );
        }
    }

    #[test]
    fn unknown_message_id_is_rejected() {
        let result = Message::from_bytes(Bytes::from_static(&[42]));
        assert!(matches!(result, Err(MessageParsingError::UnknownType)));
    }

    #[test]
    fn handshake_rejects_bad_input() {
        assert!(matches!(
            Handshake::from_bytes(Bytes::from_static(b"short")),
            Err(MessageParsingError::TooShort)
        ));
        assert!(matches!(
            Handshake::from_bytes(Bytes::from(vec![0u8; 68])),
            Err(MessageParsingError::InvalidPrefix)
        ));
    }

    #[test]
    fn handshake_round_trip() {
        let peer_id = U160::from_bytes(&[1u8; 20][..]).unwrap();
        let info_hash = U160::from_bytes(&[2u8; 20][..]).unwrap();
        let handshake = Handshake::new(peer_id, info_hash);
        let encoded = handshake.to_bytes();
        let decoded = Handshake::from_bytes(Bytes::from(encoded.to_vec())).unwrap();
        assert_eq!(&decoded.to_bytes()[..], &encoded[..]);
    }
}