You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
openmw-tes3coop/apps/master/SimpleWeb/base_server.hpp

512 lines
20 KiB
C++

#ifndef BASE_SERVER_HPP
#define BASE_SERVER_HPP
#include <boost/asio.hpp>
#include <boost/algorithm/string/predicate.hpp>
#include <boost/functional/hash.hpp>
#include <map>
#include <unordered_map>
#include <thread>
#include <functional>
#include <iostream>
#include <sstream>
#include <regex>
#ifndef CASE_INSENSITIVE_EQUALS_AND_HASH
#define CASE_INSENSITIVE_EQUALS_AND_HASH
//Based on http://www.boost.org/doc/libs/1_60_0/doc/html/unordered/hash_equality.html
/// Equality predicate for unordered containers whose string keys must compare
/// without regard to letter case (used for the HTTP header map).
struct case_insensitive_equals
{
    /// Returns true when both keys are equal ignoring case, as decided by boost::algorithm::iequals.
    bool operator()(const std::string &lhs, const std::string &rhs) const
    {
        const bool same_ignoring_case = boost::algorithm::iequals(lhs, rhs);
        return same_ignoring_case;
    }
};
/// Hash functor that pairs with case_insensitive_equals: keys that differ only in
/// letter case produce the same hash value.
struct case_insensitive_hash
{
    /// Combines the lower-cased value of every character of `key` into one hash.
    size_t operator()(const std::string &key) const
    {
        std::size_t seed = 0;
        for (const char c: key)
        {
            // std::tolower requires a value representable as unsigned char (or EOF);
            // passing a negative plain char is undefined behaviour, so convert first.
            boost::hash_combine(seed, std::tolower(static_cast<unsigned char>(c)));
        }
        return seed;
    }
};
#endif
namespace SimpleWeb
{
template<class socket_type>
class Server;
template<class socket_type>
class ServerBase
{
public:
/// Virtual so that a derived server can be destroyed through a ServerBase pointer.
virtual ~ServerBase() = default;
/// Output stream handed to resource handlers. Everything written to it is
/// collected in an internal boost::asio::streambuf, which the server later
/// writes to the client socket.
class Response : public std::ostream
{
    friend class ServerBase<socket_type>;

    boost::asio::streambuf streambuf;
    std::shared_ptr<socket_type> socket;

    // Only ServerBase creates responses; the stream is bound to the internal buffer.
    Response(const std::shared_ptr<socket_type> &client_socket)
            : std::ostream(&streambuf), socket(client_socket)
    {}

public:
    /// Number of bytes currently held in the response buffer.
    size_t size()
    {
        return streambuf.size();
    }

    /// If true, force server to close the connection after the response has been sent.
    ///
    /// This is useful when implementing an HTTP/1.0 server that sends content
    /// without specifying the content length.
    bool close_connection_after_response = false;
};
/// Input stream over the streambuf that holds a request's bytes.
class Content : public std::istream
{
    friend class ServerBase<socket_type>;

public:
    /// Number of bytes currently available in the underlying streambuf.
    size_t size()
    {
        return streambuf.size();
    }

    /// Reads everything that is left in the stream and returns it as one string.
    /// The data is consumed: a second call returns only what arrived in between.
    std::string string()
    {
        std::ostringstream drained;
        drained << rdbuf();
        return drained.str();
    }

private:
    boost::asio::streambuf &streambuf;

    // Only ServerBase creates Content objects; the buffer is referenced, not owned.
    Content(boost::asio::streambuf &buffer) : std::istream(&buffer), streambuf(buffer)
    {}
};
/// One parsed HTTP request together with the buffer its bytes were read into.
class Request
{
    friend class ServerBase<socket_type>;
    friend class Server<socket_type>;

public:
    /// Request-line parts, e.g. "GET", "/index.html" and "1.1".
    std::string method, path, http_version;
    /// Stream over the request buffer; after header parsing it yields the body.
    Content content;
    /// Header fields; lookups ignore the case of the field name.
    std::unordered_multimap<std::string, std::string, case_insensitive_hash, case_insensitive_equals> header;
    /// Regex match of `path` against the resource pattern that was selected for this request.
    std::smatch path_match;
    /// Peer address as text; left empty when it could not be determined.
    std::string remote_endpoint_address;
    /// Peer port; stays 0 when it could not be determined.
    unsigned short remote_endpoint_port = 0;

private:
    // `content` is declared before `streambuf`, so it is constructed first; this is
    // fine because Content only stores the buffer's address during construction.
    Request(const socket_type &socket) : content(streambuf)
    {
        try
        {
            // Query the endpoint once so address and port describe the same peer.
            const auto endpoint = socket.lowest_layer().remote_endpoint();
            remote_endpoint_address = endpoint.address().to_string();
            remote_endpoint_port = endpoint.port();
        }
        catch (...)
        {
            // Best effort: the peer may already be gone. The defaults above stay in place.
        }
    }

    boost::asio::streambuf streambuf;
};
/// Server settings. Adjust them through ServerBase::config before start() is called.
class Config
{
    friend class ServerBase<socket_type>;

    // Only ServerBase constructs the configuration; the port is the one mandatory value.
    Config(unsigned short listen_port) : port(listen_port)
    {}

public:
    /// Port number to listen on. Defaults to 80 for HTTP and 443 for HTTPS.
    unsigned short port;

    /// How many threads run the io_service once start() is called. Defaults to 1.
    size_t thread_pool_size = 1;

    /// Seconds allowed for receiving the request header. Defaults to 5.
    size_t timeout_request = 5;

    /// Seconds allowed for transferring content. Defaults to 300.
    size_t timeout_content = 300;

    /// IPv4 address in dotted decimal form or IPv6 address in hexadecimal notation.
    /// When left empty the server listens on any address.
    std::string address;

    /// Set to false to avoid binding the socket to an address that is already in use. Defaults to true.
    bool reuse_address = true;
};
///Set before calling start().
Config config;
private:
/// std::regex that also remembers the text it was built from, so that it can
/// serve as a key in an ordered container (ordering is by that text).
class regex_orderable : public std::regex
{
    std::string source_text;

public:
    /// Builds the regex from a C string and keeps a copy of the pattern text.
    regex_orderable(const char *pattern) : std::regex(pattern), source_text(pattern)
    {}

    /// Builds the regex from a std::string and keeps a copy of the pattern text.
    regex_orderable(const std::string &pattern) : std::regex(pattern), source_text(pattern)
    {}

    /// Lexicographic comparison of the pattern texts.
    bool operator<(const regex_orderable &rhs) const
    {
        return source_text.compare(rhs.source_text) < 0;
    }
};
public:
/// Warning: do not add or remove resources after start() is called
///
/// Handler table: path regex -> (HTTP method -> handler). For each request the
/// entries are visited in map order (ordered by the regex source text) and the
/// first entry that has a handler for the request method and whose regex
/// matches the whole request path is called.
std::map<regex_orderable, std::map<std::string,
std::function<void(std::shared_ptr<typename ServerBase<socket_type>::Response>,
std::shared_ptr<typename ServerBase<socket_type>::Request>)>>>
resource;
/// Fallback handlers keyed by HTTP method, used when no entry in `resource`
/// matched. If there is no fallback for the method either, no handler is called.
std::map<std::string,
std::function<void(std::shared_ptr<typename ServerBase<socket_type>::Response>,
std::shared_ptr<typename ServerBase<socket_type>::Request>)>> default_resource;
/// Optional error callback. When set, it receives the request and an error code
/// for failures such as a failed socket read or write, a Content-Length header
/// that cannot be converted to a number, or a handler that throws std::exception.
std::function<
void(std::shared_ptr<typename ServerBase<socket_type>::Request>,
const boost::system::error_code &)>
on_error;
/// Optional upgrade callback. When set and the request carries an "Upgrade"
/// header, it is called with the socket and the request instead of a resource handler.
std::function<void(std::shared_ptr<socket_type> socket,
std::shared_ptr<typename ServerBase<socket_type>::Request>)> on_upgrade;
/// Opens the listening socket, starts accepting connections and runs the
/// io_service. With config.thread_pool_size > 0 this call blocks until the
/// io_service stops running; with 0 it returns right after accepting has been set up.
virtual void start()
{
    // Create an io_service unless the user supplied one; make a stopped one runnable again.
    if (!io_service)
        io_service = std::make_shared<boost::asio::io_service>();
    if (io_service->stopped())
        io_service->reset();

    // Listen on the configured address, or on any IPv4 address when none is given.
    boost::asio::ip::tcp::endpoint endpoint;
    if (config.address.empty())
        endpoint = boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), config.port);
    else
        endpoint = boost::asio::ip::tcp::endpoint(boost::asio::ip::address::from_string(config.address),
                                                  config.port);

    if (!acceptor)
        acceptor.reset(new boost::asio::ip::tcp::acceptor(*io_service));

    acceptor->open(endpoint.protocol());
    acceptor->set_option(boost::asio::socket_base::reuse_address(config.reuse_address));
    acceptor->bind(endpoint);
    acceptor->listen();

    accept();

    // The calling thread counts as one pool member, so only thread_pool_size - 1 extra threads are spawned.
    threads.clear();
    for (size_t worker = 1; worker < config.thread_pool_size; ++worker)
        threads.emplace_back([this]() { io_service->run(); });

    // Run the io_service on the calling thread as well.
    if (config.thread_pool_size > 0)
        io_service->run();

    // Wait for the extra threads, if any, to finish.
    for (auto &worker_thread: threads)
        worker_thread.join();
}
/// Stops accepting connections and, when the server runs its own thread pool
/// (config.thread_pool_size > 0), stops the io_service so that start() returns.
/// Calling it before start() is harmless.
void stop()
{
    // The acceptor only exists once start() has run.
    if (acceptor)
        acceptor->close();
    // With thread_pool_size == 0 the io_service is driven by the user and is left alone.
    if (config.thread_pool_size > 0 && io_service)
        io_service->stop();
}
/// Asynchronously writes everything buffered in `response` to its socket.
/// Use this function if you need to recursively send parts of a longer message.
/// `callback`, when given, is invoked with the result of the write. The response
/// is kept alive until the write has completed.
void send(const std::shared_ptr<Response> &response,
          const std::function<void(const boost::system::error_code &)> &callback = nullptr) const
{
    boost::asio::async_write(
            *response->socket, response->streambuf,
            // Holding a copy of `response` keeps buffer and socket valid during the write.
            [response, callback](const boost::system::error_code &write_result, size_t /*bytes_transferred*/)
            {
                if (!callback)
                    return;
                callback(write_result);
            });
}
/// If you have your own boost::asio::io_service, store its pointer here before running start().
/// You might also want to set config.thread_pool_size to 0.
std::shared_ptr<boost::asio::io_service> io_service;
protected:
/// Listening socket; created by start() on first use.
std::unique_ptr<boost::asio::ip::tcp::acceptor> acceptor;
/// Extra threads spawned by start() to run the io_service; start() joins them before it returns.
std::vector<std::thread> threads;
/// Stores the port in `config`; all other settings keep their defaults.
ServerBase(unsigned short port) : config(port)
{}
/// Called by start() once the acceptor is listening.
/// Presumably derived classes begin accepting connections here - the implementations are not part of this file.
virtual void accept()=0;
/// Arms a timer that shuts down and closes `socket` unless it is cancelled
/// within `seconds`.
///
/// Returns the armed timer, or nullptr when `seconds` is 0 (timeout disabled).
/// Callers cancel the timer once the guarded operation has finished.
std::shared_ptr<boost::asio::deadline_timer>
get_timeout_timer(const std::shared_ptr<socket_type> &socket, long seconds)
{
    if (seconds == 0)
        return nullptr;

    auto timer = std::make_shared<boost::asio::deadline_timer>(*io_service);
    timer->expires_from_now(boost::posix_time::seconds(seconds));
    timer->async_wait([socket](const boost::system::error_code &wait_result)
                      {
                          // An error code here means the wait did not run to expiry
                          // (for example it was cancelled); the socket must stay open then.
                          if (wait_result)
                              return;

                          // Use the non-throwing overloads: an exception escaping this
                          // handler would propagate out of io_service::run().
                          boost::system::error_code ignored;
                          socket->lowest_layer().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ignored);
                          socket->lowest_layer().close(ignored);
                      });
    return timer;
}
/// Reads one request from `socket`: first the header (up to the empty line),
/// then - if a Content-Length header announces more bytes than were already
/// received - the rest of the body. On success the request is passed on to
/// find_resource(); on any failure on_error (if set) is called and the request
/// is dropped.
void read_request_and_content(const std::shared_ptr<socket_type> &socket)
{
    //Create new streambuf (Request::streambuf) for async_read_until()
    //shared_ptr is used to pass temporary objects to the asynchronous functions
    std::shared_ptr<Request> request(new Request(*socket));

    //Set timeout on the following boost::asio::async-read or write function
    auto header_timer = this->get_timeout_timer(socket, config.timeout_request);

    boost::asio::async_read_until(
            *socket, request->streambuf, "\r\n\r\n",
            [this, socket, request, header_timer](const boost::system::error_code &ec, size_t bytes_transferred)
            {
                if (header_timer)
                    header_timer->cancel();

                if (ec)
                {
                    if (on_error)
                        on_error(request, ec);
                    return;
                }

                // Shared reporting for requests that violate the protocol.
                const auto report_protocol_error = [this, &request]()
                {
                    if (on_error)
                        on_error(request, boost::system::error_code(boost::system::errc::protocol_error,
                                                                    boost::system::generic_category()));
                };

                //request->streambuf.size() is not necessarily the same as bytes_transferred, from Boost-docs:
                //"After a successful async_read_until operation, the streambuf may contain additional data beyond the delimiter"
                //The chosen solution is to extract lines from the stream directly when parsing the header. What is left of the
                //streambuf (maybe some bytes of the content) is appended to in the async_read-function below (for retrieving content).
                const size_t num_additional_bytes = request->streambuf.size() - bytes_transferred;

                // A malformed request is reported the same way as a malformed Content-Length.
                if (!this->parse_request(request))
                {
                    report_protocol_error();
                    return;
                }

                // Without a Content-Length header there is no body to wait for.
                const auto length_field = request->header.find("Content-Length");
                if (length_field == request->header.end())
                {
                    this->find_resource(socket, request);
                    return;
                }

                unsigned long long content_length = 0;
                try
                {
                    content_length = std::stoull(length_field->second);
                }
                catch (const std::exception &)
                {
                    report_protocol_error();
                    return;
                }

                // The whole body already arrived together with the header.
                if (content_length <= num_additional_bytes)
                {
                    this->find_resource(socket, request);
                    return;
                }

                //Set timeout on the following boost::asio::async-read or write function
                auto content_timer = this->get_timeout_timer(socket, config.timeout_content);
                boost::asio::async_read(
                        *socket, request->streambuf,
                        boost::asio::transfer_exactly(content_length - num_additional_bytes),
                        [this, socket, request, content_timer](const boost::system::error_code &read_ec,
                                                               size_t /*bytes_transferred*/)
                        {
                            if (content_timer)
                                content_timer->cancel();
                            if (!read_ec)
                                this->find_resource(socket, request);
                            else if (on_error)
                                on_error(request, read_ec);
                        });
            });
}
/// Parses the request line and the header fields from request->content.
///
/// Fills request->method, request->path, request->http_version and
/// request->header. Header parsing stops at the first line that contains no
/// ':' (normally the empty line that ends the header), leaving the body in the
/// stream. Lines may end in "\r\n" or in a bare "\n".
///
/// Returns false when the request line is not of the form
/// "<method> <path> HTTP/<version>"; method and path may already have been
/// stored in that case.
bool parse_request(const std::shared_ptr<Request> &request) const
{
    // getline removes the '\n'; this removes the '\r' in front of it when present,
    // instead of assuming the last character of every line is a '\r'.
    const auto strip_cr = [](std::string &text)
    {
        if (!text.empty() && text.back() == '\r')
            text.pop_back();
    };

    std::string line;
    getline(request->content, line);
    strip_cr(line);

    const size_t method_end = line.find(' ');
    if (method_end == std::string::npos)
        return false;

    const size_t path_end = line.find(' ', method_end + 1);
    if (path_end == std::string::npos)
        return false;

    request->method = line.substr(0, method_end);
    request->path = line.substr(method_end + 1, path_end - method_end - 1);

    // The protocol token in front of the '/' must be exactly "HTTP".
    const size_t protocol_end = line.find('/', path_end + 1);
    if (protocol_end == std::string::npos)
        return false;
    if (line.compare(path_end + 1, protocol_end - path_end - 1, "HTTP") != 0)
        return false;
    request->http_version = line.substr(protocol_end + 1);

    // Header fields: "<name>:<value>" with at most one space skipped after the colon.
    // Checking the stream state ends the loop if the data runs out; a failed getline
    // leaves `line` unchanged and would otherwise be processed again and again.
    while (getline(request->content, line))
    {
        strip_cr(line);
        const size_t param_end = line.find(':');
        if (param_end == std::string::npos)
            break;

        size_t value_start = param_end + 1;
        if (value_start < line.size() && line[value_start] == ' ')
            ++value_start;
        request->header.emplace(line.substr(0, param_end), line.substr(value_start));
    }
    return true;
}
/// Dispatches a fully read request: to on_upgrade when the request asks for a
/// protocol upgrade, otherwise to the first matching entry of `resource`, and
/// failing that to `default_resource`. When nothing matches, no handler runs.
void find_resource(const std::shared_ptr<socket_type> &socket, const std::shared_ptr<Request> &request)
{
    //Upgrade connection
    if (on_upgrade && request->header.find("Upgrade") != request->header.end())
    {
        on_upgrade(socket, request);
        return;
    }

    //Find path- and method-match, and call write_response
    for (auto &entry: resource)
    {
        auto &handlers_by_method = entry.second;
        auto handler = handlers_by_method.find(request->method);
        if (handler == handlers_by_method.end())
            continue;

        std::smatch match;
        if (!std::regex_match(request->path, match, entry.first))
            continue;

        request->path_match = std::move(match);
        write_response(socket, request, handler->second);
        return;
    }

    // No pattern matched: fall back to the default handler for this method, if any.
    auto fallback = default_resource.find(request->method);
    if (fallback == default_resource.end())
        return;
    write_response(socket, request, fallback->second);
}
void write_response(const std::shared_ptr<socket_type> &socket, const std::shared_ptr<Request> &request,
std::function<void(std::shared_ptr<typename ServerBase<socket_type>::Response>,
std::shared_ptr<
typename ServerBase<socket_type>::Request>)> &resource_function)
{
//Set timeout on the following boost::asio::async-read or write function
auto timer = this->get_timeout_timer(socket, config.timeout_content);
auto response = std::shared_ptr<Response>(new Response(socket), [this, request, timer]
(Response *response_ptr)
{
auto response = std::shared_ptr<Response>(response_ptr);
this->send(response, [this, response, request, timer](
const boost::system::error_code &ec)
{
if (timer)
timer->cancel();
if (!ec)
{
if (response->close_connection_after_response)
return;
auto range = request->header.equal_range(
"Connection");
for (auto it = range.first; it != range.second; it++)
{
if (boost::iequals(it->second, "close"))
{
return;
}
else if (boost::iequals(it->second, "keep-alive"))
{
this->read_request_and_content(
response->socket);
return;
}
}
if (request->http_version >= "1.1")
this->read_request_and_content(response->socket);
}
else if (on_error)
on_error(request, ec);
});
});
try
{
resource_function(response, request);
}
catch (const std::exception &e)
{
if (on_error)
on_error(request, boost::system::error_code(boost::system::errc::operation_canceled,
boost::system::generic_category()));
return;
}
}
};
}
#endif //BASE_SERVER_HPP