diff --git a/apps/master/SimpleWeb/LICENSE b/apps/master/SimpleWeb/LICENSE new file mode 100644 index 000000000..7a31af4b8 --- /dev/null +++ b/apps/master/SimpleWeb/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2014-2016 Ole Christian Eidheim + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/apps/master/SimpleWeb/base_server.hpp b/apps/master/SimpleWeb/base_server.hpp index d540a5de9..dae1ddc76 100644 --- a/apps/master/SimpleWeb/base_server.hpp +++ b/apps/master/SimpleWeb/base_server.hpp @@ -1,72 +1,137 @@ -#ifndef BASE_SERVER_HPP -#define BASE_SERVER_HPP +#pragma once -#include -#include -#include - -#include -#include -#include +#include "utility.hpp" +#include #include #include +#include #include +#include +#include #include -#ifndef CASE_INSENSITIVE_EQUALS_AND_HASH -#define CASE_INSENSITIVE_EQUALS_AND_HASH - -//Based on http://www.boost.org/doc/libs/1_60_0/doc/html/unordered/hash_equality.html -struct case_insensitive_equals -{ - bool operator()(const std::string &key1, const std::string &key2) const - { - return boost::algorithm::iequals(key1, key2); - } -}; - -struct case_insensitive_hash -{ - size_t operator()(const std::string &key) const - { - std::size_t seed = 0; - for (auto &c: key) - boost::hash_combine(seed, std::tolower(c)); - return seed; - } -}; - +#ifdef USE_STANDALONE_ASIO +#include +#include +namespace SimpleWeb { + using error_code = std::error_code; + using errc = std::errc; + namespace make_error_code = std; +} // namespace SimpleWeb +#else +#include +#include +namespace SimpleWeb { + namespace asio = boost::asio; + using error_code = boost::system::error_code; + namespace errc = boost::system::errc; + namespace make_error_code = boost::system::errc; +} // namespace SimpleWeb #endif -namespace SimpleWeb -{ - template +namespace SimpleWeb { + template class Server; - template - class ServerBase - { - public: - virtual ~ServerBase() - {} + template + class ServerBase { + protected: + class Session; - class Response : public std::ostream - { + public: + class Response : public std::enable_shared_from_this, public std::ostream { friend class ServerBase; + friend class Server; + + asio::streambuf streambuf; - boost::asio::streambuf streambuf; + std::shared_ptr session; + long timeout_content; - std::shared_ptr socket; + Response(std::shared_ptr session, long timeout_content) noexcept : std::ostream(&streambuf), session(std::move(session)), timeout_content(timeout_content) {} - Response(const std::shared_ptr &socket) : std::ostream(&streambuf), socket(socket) - {} + template + void write_header(const CaseInsensitiveMultimap &header, size_type size) { + bool content_length_written = false; + bool chunked_transfer_encoding = false; + for(auto &field : header) { + if(!content_length_written && case_insensitive_equal(field.first, "content-length")) + content_length_written = true; + else if(!chunked_transfer_encoding && case_insensitive_equal(field.first, "transfer-encoding") && case_insensitive_equal(field.second, "chunked")) + chunked_transfer_encoding = true; + + *this << field.first << ": " << field.second << "\r\n"; + } + if(!content_length_written && !chunked_transfer_encoding && !close_connection_after_response) + *this << "Content-Length: " << size << "\r\n\r\n"; + else + *this << "\r\n"; + } public: - size_t size() - { + size_t size() noexcept { return streambuf.size(); } + /// Use this function if you need to recursively send parts of a longer message + void send(const std::function &callback = nullptr) noexcept { + session->connection->set_timeout(timeout_content); + auto self = this->shared_from_this(); // Keep Response instance alive through the following async_write + asio::async_write(*session->connection->socket, streambuf, [self, callback](const error_code &ec, size_t /*bytes_transferred*/) { + self->session->connection->cancel_timeout(); + auto lock = self->session->connection->handler_runner->continue_lock(); + if(!lock) + return; + if(callback) + callback(ec); + }); + } + + /// Write directly to stream buffer using std::ostream::write + void write(const char_type *ptr, std::streamsize n) { + std::ostream::write(ptr, n); + } + + /// Convenience function for writing status line, potential header fields, and empty content + void write(StatusCode status_code = StatusCode::success_ok, const CaseInsensitiveMultimap &header = CaseInsensitiveMultimap()) { + *this << "HTTP/1.1 " << SimpleWeb::status_code(status_code) << "\r\n"; + write_header(header, 0); + } + + /// Convenience function for writing status line, header fields, and content + void write(StatusCode status_code, const std::string &content, const CaseInsensitiveMultimap &header = CaseInsensitiveMultimap()) { + *this << "HTTP/1.1 " << SimpleWeb::status_code(status_code) << "\r\n"; + write_header(header, content.size()); + if(!content.empty()) + *this << content; + } + + /// Convenience function for writing status line, header fields, and content + void write(StatusCode status_code, std::istream &content, const CaseInsensitiveMultimap &header = CaseInsensitiveMultimap()) { + *this << "HTTP/1.1 " << SimpleWeb::status_code(status_code) << "\r\n"; + content.seekg(0, std::ios::end); + auto size = content.tellg(); + content.seekg(0, std::ios::beg); + write_header(header, size); + if(size) + *this << content.rdbuf(); + } + + /// Convenience function for writing success status line, header fields, and content + void write(const std::string &content, const CaseInsensitiveMultimap &header = CaseInsensitiveMultimap()) { + write(StatusCode::success_ok, content, header); + } + + /// Convenience function for writing success status line, header fields, and content + void write(std::istream &content, const CaseInsensitiveMultimap &header = CaseInsensitiveMultimap()) { + write(StatusCode::success_ok, content, header); + } + + /// Convenience function for writing success status line, and header fields + void write(const CaseInsensitiveMultimap &header) { + write(StatusCode::success_ok, std::string(), header); + } + /// If true, force server to close the connection after the response have been sent. /// /// This is useful when implementing a HTTP/1.0-server sending content @@ -74,438 +139,399 @@ namespace SimpleWeb bool close_connection_after_response = false; }; - class Content : public std::istream - { + class Content : public std::istream { friend class ServerBase; public: - size_t size() - { + size_t size() noexcept { return streambuf.size(); } - - std::string string() - { - std::stringstream ss; - ss << rdbuf(); - return ss.str(); + /// Convenience function to return std::string. The stream buffer is consumed. + std::string string() noexcept { + try { + std::stringstream ss; + ss << rdbuf(); + return ss.str(); + } + catch(...) { + return std::string(); + } } private: - boost::asio::streambuf &streambuf; - - Content(boost::asio::streambuf &streambuf) : std::istream(&streambuf), streambuf(streambuf) - {} + asio::streambuf &streambuf; + Content(asio::streambuf &streambuf) noexcept : std::istream(&streambuf), streambuf(streambuf) {} }; - class Request - { + class Request { friend class ServerBase; - friend class Server; + friend class Session; public: - std::string method, path, http_version; + std::string method, path, query_string, http_version; Content content; - std::unordered_multimap header; + CaseInsensitiveMultimap header; std::smatch path_match; std::string remote_endpoint_address; unsigned short remote_endpoint_port; + /// Returns query keys with percent-decoded values. + CaseInsensitiveMultimap parse_query_string() noexcept { + return SimpleWeb::QueryString::parse(query_string); + } + private: - Request(const socket_type &socket) : content(streambuf) - { - try - { - remote_endpoint_address = socket.lowest_layer().remote_endpoint().address().to_string(); - remote_endpoint_port = socket.lowest_layer().remote_endpoint().port(); + asio::streambuf streambuf; + + Request(const std::string &remote_endpoint_address = std::string(), unsigned short remote_endpoint_port = 0) noexcept + : content(streambuf), remote_endpoint_address(remote_endpoint_address), remote_endpoint_port(remote_endpoint_port) {} + }; + + protected: + class Connection : public std::enable_shared_from_this { + public: + template + Connection(std::shared_ptr handler_runner, Args &&... args) noexcept : handler_runner(std::move(handler_runner)), socket(new socket_type(std::forward(args)...)) {} + + std::shared_ptr handler_runner; + + std::unique_ptr socket; // Socket must be unique_ptr since asio::ssl::stream is not movable + std::mutex socket_close_mutex; + + std::unique_ptr timer; + + void close() noexcept { + error_code ec; + std::unique_lock lock(socket_close_mutex); // The following operations seems to be needed to run sequentially + socket->lowest_layer().shutdown(asio::ip::tcp::socket::shutdown_both, ec); + socket->lowest_layer().close(ec); + } + + void set_timeout(long seconds) noexcept { + if(seconds == 0) { + timer = nullptr; + return; + } + + timer = std::unique_ptr(new asio::steady_timer(socket->get_io_service())); + timer->expires_from_now(std::chrono::seconds(seconds)); + auto self = this->shared_from_this(); + timer->async_wait([self](const error_code &ec) { + if(!ec) + self->close(); + }); + } + + void cancel_timeout() noexcept { + if(timer) { + error_code ec; + timer->cancel(ec); + } + } + }; + + class Session { + public: + Session(std::shared_ptr connection) noexcept : connection(std::move(connection)) { + try { + auto remote_endpoint = this->connection->socket->lowest_layer().remote_endpoint(); + request = std::shared_ptr(new Request(remote_endpoint.address().to_string(), remote_endpoint.port())); + } + catch(...) { + request = std::shared_ptr(new Request()); } - catch (...) - {} } - boost::asio::streambuf streambuf; + std::shared_ptr connection; + std::shared_ptr request; }; - class Config - { + public: + class Config { friend class ServerBase; - Config(unsigned short port) : port(port) - {} + Config(unsigned short port) noexcept : port(port) {} public: /// Port number to use. Defaults to 80 for HTTP and 443 for HTTPS. unsigned short port; - /// Number of threads that the server will use when start() is called. Defaults to 1 thread. + /// If io_service is not set, number of threads that the server will use when start() is called. + /// Defaults to 1 thread. size_t thread_pool_size = 1; /// Timeout on request handling. Defaults to 5 seconds. - size_t timeout_request = 5; + long timeout_request = 5; /// Timeout on content handling. Defaults to 300 seconds. - size_t timeout_content = 300; + long timeout_content = 300; /// IPv4 address in dotted decimal form or IPv6 address in hexadecimal notation. /// If empty, the address will be any address. std::string address; /// Set to false to avoid binding the socket to an address that is already in use. Defaults to true. bool reuse_address = true; }; - - ///Set before calling start(). + /// Set before calling start(). Config config; private: - class regex_orderable : public std::regex - { + class regex_orderable : public std::regex { std::string str; - public: - regex_orderable(const char *regex_cstr) : std::regex(regex_cstr), str(regex_cstr) - {} - regex_orderable(const std::string ®ex_str) : std::regex(regex_str), str(regex_str) - {} - - bool operator<(const regex_orderable &rhs) const - { + public: + regex_orderable(const char *regex_cstr) : std::regex(regex_cstr), str(regex_cstr) {} + regex_orderable(std::string regex_str) : std::regex(regex_str), str(std::move(regex_str)) {} + bool operator<(const regex_orderable &rhs) const noexcept { return str < rhs.str; } }; public: /// Warning: do not add or remove resources after start() is called - std::map::Response>, - std::shared_ptr::Request>)>>> - resource; + std::map::Response>, std::shared_ptr::Request>)>>> resource; + + std::map::Response>, std::shared_ptr::Request>)>> default_resource; - std::map::Response>, - std::shared_ptr::Request>)>> default_resource; + std::function::Request>, const error_code &)> on_error; - std::function< - void(std::shared_ptr::Request>, - const boost::system::error_code &)> - on_error; + std::function &, std::shared_ptr::Request>)> on_upgrade; - std::function socket, - std::shared_ptr::Request>)> on_upgrade; + /// If you have your own asio::io_service, store its pointer here before running start(). + std::shared_ptr io_service; - virtual void start() - { - if (!io_service) - io_service = std::make_shared(); + virtual void start() { + if(!io_service) { + io_service = std::make_shared(); + internal_io_service = true; + } - if (io_service->stopped()) + if(io_service->stopped()) io_service->reset(); - boost::asio::ip::tcp::endpoint endpoint; - if (config.address.size() > 0) - endpoint = boost::asio::ip::tcp::endpoint(boost::asio::ip::address::from_string(config.address), - config.port); + asio::ip::tcp::endpoint endpoint; + if(config.address.size() > 0) + endpoint = asio::ip::tcp::endpoint(asio::ip::address::from_string(config.address), config.port); else - endpoint = boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), config.port); + endpoint = asio::ip::tcp::endpoint(asio::ip::tcp::v4(), config.port); - if (!acceptor) - acceptor = std::unique_ptr( - new boost::asio::ip::tcp::acceptor(*io_service)); + if(!acceptor) + acceptor = std::unique_ptr(new asio::ip::tcp::acceptor(*io_service)); acceptor->open(endpoint.protocol()); - acceptor->set_option(boost::asio::socket_base::reuse_address(config.reuse_address)); + acceptor->set_option(asio::socket_base::reuse_address(config.reuse_address)); acceptor->bind(endpoint); acceptor->listen(); accept(); - //If thread_pool_size>1, start m_io_service.run() in (thread_pool_size-1) threads for thread-pooling - threads.clear(); - for (size_t c = 1; c < config.thread_pool_size; c++) - { - threads.emplace_back([this]() - { - io_service->run(); - }); - } + if(internal_io_service) { + // If thread_pool_size>1, start m_io_service.run() in (thread_pool_size-1) threads for thread-pooling + threads.clear(); + for(size_t c = 1; c < config.thread_pool_size; c++) { + threads.emplace_back([this]() { + this->io_service->run(); + }); + } - //Main thread - if (config.thread_pool_size > 0) - io_service->run(); + // Main thread + if(config.thread_pool_size > 0) + io_service->run(); - //Wait for the rest of the threads, if any, to finish as well - for (auto &t: threads) - { - t.join(); + // Wait for the rest of the threads, if any, to finish as well + for(auto &t : threads) + t.join(); } } - void stop() - { - acceptor->close(); - if (config.thread_pool_size > 0) - io_service->stop(); + /// Stop accepting new requests, and close current connections. + void stop() noexcept { + if(acceptor) { + error_code ec; + acceptor->close(ec); + + { + std::unique_lock lock(*connections_mutex); + for(auto &connection : *connections) + connection->close(); + connections->clear(); + } + + if(internal_io_service) + io_service->stop(); + } } - ///Use this function if you need to recursively send parts of a longer message - void send(const std::shared_ptr &response, - const std::function &callback = nullptr) const - { - boost::asio::async_write(*response->socket, response->streambuf, [this, response, callback] - (const boost::system::error_code &ec, size_t /*bytes_transferred*/) - { - if (callback) - callback(ec); - }); + virtual ~ServerBase() noexcept { + handler_runner->stop(); + stop(); } - /// If you have your own boost::asio::io_service, store its pointer here before running start(). - /// You might also want to set config.thread_pool_size to 0. - std::shared_ptr io_service; protected: - std::unique_ptr acceptor; + bool internal_io_service = false; + + std::unique_ptr acceptor; std::vector threads; - ServerBase(unsigned short port) : config(port) - {} - - virtual void accept()=0; - - std::shared_ptr - get_timeout_timer(const std::shared_ptr &socket, long seconds) - { - if (seconds == 0) - return nullptr; - - auto timer = std::make_shared(*io_service); - timer->expires_from_now(boost::posix_time::seconds(seconds)); - timer->async_wait([socket](const boost::system::error_code &ec) - { - if (!ec) - { - boost::system::error_code ec; - socket->lowest_layer().shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec); - socket->lowest_layer().close(); - } - }); - return timer; - } + std::shared_ptr> connections; + std::shared_ptr connections_mutex; - void read_request_and_content(const std::shared_ptr &socket) - { - //Create new streambuf (Request::streambuf) for async_read_until() - //shared_ptr is used to pass temporary objects to the asynchronous functions - std::shared_ptr request(new Request(*socket)); + std::shared_ptr handler_runner; - //Set timeout on the following boost::asio::async-read or write function - auto timer = this->get_timeout_timer(socket, config.timeout_request); + ServerBase(unsigned short port) noexcept : config(port), connections(new std::unordered_set()), connections_mutex(new std::mutex()), handler_runner(new ScopeRunner()) {} - boost::asio::async_read_until(*socket, request->streambuf, "\r\n\r\n", [this, socket, request, timer] - (const boost::system::error_code &ec, - size_t bytes_transferred) - { - if (timer) - timer->cancel(); - if (!ec) + virtual void accept() = 0; + + template + std::shared_ptr create_connection(Args &&... args) noexcept { + auto connections = this->connections; + auto connections_mutex = this->connections_mutex; + auto connection = std::shared_ptr(new Connection(handler_runner, std::forward(args)...), [connections, connections_mutex](Connection *connection) { { - //request->streambuf.size() is not necessarily the same as bytes_transferred, from Boost-docs: - //"After a successful async_read_until operation, the streambuf may contain additional data beyond the delimiter" - //The chosen solution is to extract lines from the stream directly when parsing the header. What is left of the - //streambuf (maybe some bytes of the content) is appended to in the async_read-function below (for retrieving content). - size_t num_additional_bytes = - request->streambuf.size() - bytes_transferred; - - if (!this->parse_request(request)) + std::unique_lock lock(*connections_mutex); + auto it = connections->find(connection); + if(it != connections->end()) + connections->erase(it); + } + delete connection; + }); + { + std::unique_lock lock(*connections_mutex); + connections->emplace(connection.get()); + } + return connection; + } + + void read_request_and_content(const std::shared_ptr &session) { + session->connection->set_timeout(config.timeout_request); + asio::async_read_until(*session->connection->socket, session->request->streambuf, "\r\n\r\n", [this, session](const error_code &ec, size_t bytes_transferred) { + session->connection->cancel_timeout(); + auto lock = session->connection->handler_runner->continue_lock(); + if(!lock) + return; + if(!ec) { + // request->streambuf.size() is not necessarily the same as bytes_transferred, from Boost-docs: + // "After a successful async_read_until operation, the streambuf may contain additional data beyond the delimiter" + // The chosen solution is to extract lines from the stream directly when parsing the header. What is left of the + // streambuf (maybe some bytes of the content) is appended to in the async_read-function below (for retrieving content). + size_t num_additional_bytes = session->request->streambuf.size() - bytes_transferred; + + if(!RequestMessage::parse(session->request->content, session->request->method, session->request->path, + session->request->query_string, session->request->http_version, session->request->header)) { + if(this->on_error) + this->on_error(session->request, make_error_code::make_error_code(errc::protocol_error)); return; + } - //If content, read that as well - auto it = request->header.find("Content-Length"); - if (it != request->header.end()) - { - unsigned long long content_length; - try - { + // If content, read that as well + auto it = session->request->header.find("Content-Length"); + if(it != session->request->header.end()) { + unsigned long long content_length = 0; + try { content_length = stoull(it->second); } - catch (const std::exception &e) - { - if (on_error) - on_error(request, boost::system::error_code( - boost::system::errc::protocol_error, - boost::system::generic_category())); + catch(const std::exception &e) { + if(this->on_error) + this->on_error(session->request, make_error_code::make_error_code(errc::protocol_error)); return; } - if (content_length > num_additional_bytes) - { - //Set timeout on the following boost::asio::async-read or write function - auto timer = this->get_timeout_timer(socket, - config.timeout_content); - boost::asio::async_read(*socket, request->streambuf, - boost::asio::transfer_exactly( - content_length - - num_additional_bytes), - [this, socket, request, timer] - (const boost::system::error_code &ec, - size_t /*bytes_transferred*/) - { - if (timer) - timer->cancel(); - if (!ec) - this->find_resource(socket, - request); - else if (on_error) - on_error(request, ec); - }); + if(content_length > num_additional_bytes) { + session->connection->set_timeout(config.timeout_content); + asio::async_read(*session->connection->socket, session->request->streambuf, asio::transfer_exactly(content_length - num_additional_bytes), [this, session](const error_code &ec, size_t /*bytes_transferred*/) { + session->connection->cancel_timeout(); + auto lock = session->connection->handler_runner->continue_lock(); + if(!lock) + return; + if(!ec) + this->find_resource(session); + else if(this->on_error) + this->on_error(session->request, ec); + }); } else - this->find_resource(socket, request); + this->find_resource(session); } else - this->find_resource(socket, request); + this->find_resource(session); } - else if (on_error) - on_error(request, ec); + else if(this->on_error) + this->on_error(session->request, ec); }); } - bool parse_request(const std::shared_ptr &request) const - { - std::string line; - getline(request->content, line); - size_t method_end; - if ((method_end = line.find(' ')) != std::string::npos) - { - size_t path_end; - if ((path_end = line.find(' ', method_end + 1)) != std::string::npos) - { - request->method = line.substr(0, method_end); - request->path = line.substr(method_end + 1, path_end - method_end - 1); - - size_t protocol_end; - if ((protocol_end = line.find('/', path_end + 1)) != std::string::npos) + void find_resource(const std::shared_ptr &session) { + // Upgrade connection + if(on_upgrade) { + auto it = session->request->header.find("Upgrade"); + if(it != session->request->header.end()) { + // remove connection from connections { - if (line.compare(path_end + 1, protocol_end - path_end - 1, "HTTP") != 0) - return false; - request->http_version = line.substr(protocol_end + 1, line.size() - protocol_end - 2); + std::unique_lock lock(*connections_mutex); + auto it = connections->find(session->connection.get()); + if(it != connections->end()) + connections->erase(it); } - else - return false; - - getline(request->content, line); - size_t param_end; - while ((param_end = line.find(':')) != std::string::npos) - { - size_t value_start = param_end + 1; - if ((value_start) < line.size()) - { - if (line[value_start] == ' ') - value_start++; - if (value_start < line.size()) - request->header.emplace(line.substr(0, param_end), - line.substr(value_start, line.size() - value_start - 1)); - } - - getline(request->content, line); - } - } - else - return false; - } - else - return false; - return true; - } - void find_resource(const std::shared_ptr &socket, const std::shared_ptr &request) - { - //Upgrade connection - if (on_upgrade) - { - auto it = request->header.find("Upgrade"); - if (it != request->header.end()) - { - on_upgrade(socket, request); + on_upgrade(session->connection->socket, session->request); return; } } - //Find path- and method-match, and call write_response - for (auto ®ex_method: resource) - { - auto it = regex_method.second.find(request->method); - if (it != regex_method.second.end()) - { + // Find path- and method-match, and call write_response + for(auto ®ex_method : resource) { + auto it = regex_method.second.find(session->request->method); + if(it != regex_method.second.end()) { std::smatch sm_res; - if (std::regex_match(request->path, sm_res, regex_method.first)) - { - request->path_match = std::move(sm_res); - write_response(socket, request, it->second); + if(std::regex_match(session->request->path, sm_res, regex_method.first)) { + session->request->path_match = std::move(sm_res); + write_response(session, it->second); return; } } } - auto it = default_resource.find(request->method); - if (it != default_resource.end()) - { - write_response(socket, request, it->second); - } + auto it = default_resource.find(session->request->method); + if(it != default_resource.end()) + write_response(session, it->second); } - void write_response(const std::shared_ptr &socket, const std::shared_ptr &request, - std::function::Response>, - std::shared_ptr< - typename ServerBase::Request>)> &resource_function) - { - //Set timeout on the following boost::asio::async-read or write function - auto timer = this->get_timeout_timer(socket, config.timeout_content); - - auto response = std::shared_ptr(new Response(socket), [this, request, timer] - (Response *response_ptr) - { + void write_response(const std::shared_ptr &session, + std::function::Response>, std::shared_ptr::Request>)> &resource_function) { + session->connection->set_timeout(config.timeout_content); + auto response = std::shared_ptr(new Response(session, config.timeout_content), [this](Response *response_ptr) { auto response = std::shared_ptr(response_ptr); - this->send(response, [this, response, request, timer]( - const boost::system::error_code &ec) - { - if (timer) - timer->cancel(); - if (!ec) - { - if (response->close_connection_after_response) + response->send([this, response](const error_code &ec) { + if(!ec) { + if(response->close_connection_after_response) return; - auto range = request->header.equal_range( - "Connection"); - for (auto it = range.first; it != range.second; it++) - { - if (boost::iequals(it->second, "close")) - { + auto range = response->session->request->header.equal_range("Connection"); + for(auto it = range.first; it != range.second; it++) { + if(case_insensitive_equal(it->second, "close")) return; - } - else if (boost::iequals(it->second, "keep-alive")) - { - this->read_request_and_content( - response->socket); + else if(case_insensitive_equal(it->second, "keep-alive")) { + auto new_session = std::make_shared(response->session->connection); + this->read_request_and_content(new_session); return; } } - if (request->http_version >= "1.1") - this->read_request_and_content(response->socket); + if(response->session->request->http_version >= "1.1") { + auto new_session = std::make_shared(response->session->connection); + this->read_request_and_content(new_session); + return; + } } - else if (on_error) - on_error(request, ec); + else if(this->on_error) + this->on_error(response->session->request, ec); }); }); - try - { - resource_function(response, request); + try { + resource_function(response, session->request); } - catch (const std::exception &e) - { - if (on_error) - on_error(request, boost::system::error_code(boost::system::errc::operation_canceled, - boost::system::generic_category())); + catch(const std::exception &e) { + if(on_error) + on_error(session->request, make_error_code::make_error_code(errc::operation_canceled)); return; } } }; -} -#endif //BASE_SERVER_HPP +} \ No newline at end of file diff --git a/apps/master/SimpleWeb/http_server.hpp b/apps/master/SimpleWeb/http_server.hpp index 99a5bbd4b..030f85c28 100644 --- a/apps/master/SimpleWeb/http_server.hpp +++ b/apps/master/SimpleWeb/http_server.hpp @@ -1,55 +1,42 @@ -/* - * https://github.com/eidheim/Simple-Web-Server/ - * - * The MIT License (MIT) - * Copyright (c) 2014-2016 Ole Christian Eidheim - */ - -#ifndef SERVER_HTTP_HPP -#define SERVER_HTTP_HPP +#pragma once #include "base_server.hpp" -namespace SimpleWeb -{ +namespace SimpleWeb { - template + template class Server : public ServerBase {}; - typedef boost::asio::ip::tcp::socket HTTP; + using HTTP = asio::ip::tcp::socket; - template<> - class Server : public ServerBase - { + template <> + class Server : public ServerBase { public: - Server() : ServerBase::ServerBase(80) - {} + Server() noexcept : ServerBase::ServerBase(80) {} protected: - virtual void accept() - { - //Create new socket for this connection - //Shared_ptr is used to pass temporary objects to the asynchronous functions - auto socket = std::make_shared(*io_service); - - acceptor->async_accept(*socket, [this, socket](const boost::system::error_code &ec) - { - //Immediately start accepting a new connection (if io_service hasn't been stopped) - if (ec != boost::asio::error::operation_aborted) - accept(); - - if (!ec) - { - boost::asio::ip::tcp::no_delay option(true); - socket->set_option(option); - - this->read_request_and_content(socket); + void accept() override { + auto session = std::make_shared(create_connection(*io_service)); + + acceptor->async_accept(*session->connection->socket, [this, session](const error_code &ec) { + auto lock = session->connection->handler_runner->continue_lock(); + if(!lock) + return; + + // Immediately start accepting a new connection (unless io_service has been stopped) + if(ec != asio::error::operation_aborted) + this->accept(); + + if(!ec) { + asio::ip::tcp::no_delay option(true); + error_code ec; + session->connection->socket->set_option(option, ec); + + this->read_request_and_content(session); } - else if (on_error) - on_error(std::shared_ptr(new Request(*socket)), ec); + else if(this->on_error) + this->on_error(session->request, ec); }); } }; -} - -#endif //SERVER_HTTP_HPP +} // namespace SimpleWeb diff --git a/apps/master/SimpleWeb/https_server.hpp b/apps/master/SimpleWeb/https_server.hpp index fb8268889..3e79d014d 100644 --- a/apps/master/SimpleWeb/https_server.hpp +++ b/apps/master/SimpleWeb/https_server.hpp @@ -1,91 +1,82 @@ -#ifndef HTTPS_SERVER_HPP -#define HTTPS_SERVER_HPP +#pragma once #include "base_server.hpp" + +#ifdef USE_STANDALONE_ASIO +#include +#else #include -#include +#endif + #include +#include -namespace SimpleWeb -{ - typedef boost::asio::ssl::stream HTTPS; +namespace SimpleWeb { + using HTTPS = asio::ssl::stream; - template<> - class Server : public ServerBase - { + template <> + class Server : public ServerBase { std::string session_id_context; bool set_session_id_context = false; + public: - Server(const std::string &cert_file, const std::string &private_key_file, - const std::string &verify_file = std::string()) : ServerBase::ServerBase(443), - context(boost::asio::ssl::context::tlsv12) - { + Server(const std::string &cert_file, const std::string &private_key_file, const std::string &verify_file = std::string()) + : ServerBase::ServerBase(443), context(asio::ssl::context::tlsv12) { context.use_certificate_chain_file(cert_file); - context.use_private_key_file(private_key_file, boost::asio::ssl::context::pem); + context.use_private_key_file(private_key_file, asio::ssl::context::pem); - if (verify_file.size() > 0) - { + if(verify_file.size() > 0) { context.load_verify_file(verify_file); - context.set_verify_mode(boost::asio::ssl::verify_peer | boost::asio::ssl::verify_fail_if_no_peer_cert | - boost::asio::ssl::verify_client_once); + context.set_verify_mode(asio::ssl::verify_peer | asio::ssl::verify_fail_if_no_peer_cert | asio::ssl::verify_client_once); set_session_id_context = true; } } - void start() - { - if (set_session_id_context) - { + void start() override { + if(set_session_id_context) { // Creating session_id_context from address:port but reversed due to small SSL_MAX_SSL_SESSION_ID_LENGTH session_id_context = std::to_string(config.port) + ':'; session_id_context.append(config.address.rbegin(), config.address.rend()); - SSL_CTX_set_session_id_context(context.native_handle(), - reinterpret_cast(session_id_context.data()), - std::min(session_id_context.size(), - SSL_MAX_SSL_SESSION_ID_LENGTH)); + SSL_CTX_set_session_id_context(context.native_handle(), reinterpret_cast(session_id_context.data()), + std::min(session_id_context.size(), SSL_MAX_SSL_SESSION_ID_LENGTH)); } ServerBase::start(); } protected: - boost::asio::ssl::context context; + asio::ssl::context context; - virtual void accept() - { - //Create new socket for this connection - //Shared_ptr is used to pass temporary objects to the asynchronous functions - auto socket = std::make_shared(*io_service, context); + void accept() override { + auto session = std::make_shared(create_connection(*io_service, context)); - acceptor->async_accept((*socket).lowest_layer(), [this, socket](const boost::system::error_code &ec) - { - //Immediately start accepting a new connection (if io_service hasn't been stopped) - if (ec != boost::asio::error::operation_aborted) - accept(); + acceptor->async_accept(session->connection->socket->lowest_layer(), [this, session](const error_code &ec) { + auto lock = session->connection->handler_runner->continue_lock(); + if(!lock) + return; + if(ec != asio::error::operation_aborted) + this->accept(); - if (!ec) - { - boost::asio::ip::tcp::no_delay option(true); - socket->lowest_layer().set_option(option); + if(!ec) { + asio::ip::tcp::no_delay option(true); + error_code ec; + session->connection->socket->lowest_layer().set_option(option, ec); - //Set timeout on the following boost::asio::ssl::stream::async_handshake - auto timer = get_timeout_timer(socket, config.timeout_request); - socket->async_handshake(boost::asio::ssl::stream_base::server, [this, socket, timer] - (const boost::system::error_code &ec) - { - if (timer) - timer->cancel(); - if (!ec) - read_request_and_content(socket); - else if (on_error) - on_error(std::shared_ptr(new Request(*socket)), ec); + session->connection->set_timeout(config.timeout_request); + session->connection->socket->async_handshake(asio::ssl::stream_base::server, [this, session](const error_code &ec) { + session->connection->cancel_timeout(); + auto lock = session->connection->handler_runner->continue_lock(); + if(!lock) + return; + if(!ec) + this->read_request_and_content(session); + else if(this->on_error) + this->on_error(session->request, ec); }); } - else if (on_error) - on_error(std::shared_ptr(new Request(*socket)), ec); + else if(this->on_error) + this->on_error(session->request, ec); }); } }; -} - -#endif //HTTPS_SERVER_HPP +} // namespace SimpleWeb diff --git a/apps/master/SimpleWeb/status_code.hpp b/apps/master/SimpleWeb/status_code.hpp new file mode 100644 index 000000000..dbed69394 --- /dev/null +++ b/apps/master/SimpleWeb/status_code.hpp @@ -0,0 +1,154 @@ +#pragma once + +#include +#include + +namespace SimpleWeb { + enum class StatusCode { + unknown = 0, + information_continue = 100, + information_switching_protocols, + information_processing, + success_ok = 200, + success_created, + success_accepted, + success_non_authoritative_information, + success_no_content, + success_reset_content, + success_partial_content, + success_multi_status, + success_already_reported, + success_im_used = 226, + redirection_multiple_choices = 300, + redirection_moved_permanently, + redirection_found, + redirection_see_other, + redirection_not_modified, + redirection_use_proxy, + redirection_switch_proxy, + redirection_temporary_redirect, + redirection_permanent_redirect, + client_error_bad_request = 400, + client_error_unauthorized, + client_error_payment_required, + client_error_forbidden, + client_error_not_found, + client_error_method_not_allowed, + client_error_not_acceptable, + client_error_proxy_authentication_required, + client_error_request_timeout, + client_error_conflict, + client_error_gone, + client_error_length_required, + client_error_precondition_failed, + client_error_payload_too_large, + client_error_uri_too_long, + client_error_unsupported_media_type, + client_error_range_not_satisfiable, + client_error_expectation_failed, + client_error_im_a_teapot, + client_error_misdirection_required = 421, + client_error_unprocessable_entity, + client_error_locked, + client_error_failed_dependency, + client_error_upgrade_required = 426, + client_error_precondition_required = 428, + client_error_too_many_requests, + client_error_request_header_fields_too_large = 431, + client_error_unavailable_for_legal_reasons = 451, + server_error_internal_server_error = 500, + server_error_not_implemented, + server_error_bad_gateway, + server_error_service_unavailable, + server_error_gateway_timeout, + server_error_http_version_not_supported, + server_error_variant_also_negotiates, + server_error_insufficient_storage, + server_error_loop_detected, + server_error_not_extended = 510, + server_error_network_authentication_required + }; + + const static std::vector> &status_codes() noexcept { + const static std::vector> status_codes = { + {StatusCode::unknown, ""}, + {StatusCode::information_continue, "100 Continue"}, + {StatusCode::information_switching_protocols, "101 Switching Protocols"}, + {StatusCode::information_processing, "102 Processing"}, + {StatusCode::success_ok, "200 OK"}, + {StatusCode::success_created, "201 Created"}, + {StatusCode::success_accepted, "202 Accepted"}, + {StatusCode::success_non_authoritative_information, "203 Non-Authoritative Information"}, + {StatusCode::success_no_content, "204 No Content"}, + {StatusCode::success_reset_content, "205 Reset Content"}, + {StatusCode::success_partial_content, "206 Partial Content"}, + {StatusCode::success_multi_status, "207 Multi-Status"}, + {StatusCode::success_already_reported, "208 Already Reported"}, + {StatusCode::success_im_used, "226 IM Used"}, + {StatusCode::redirection_multiple_choices, "300 Multiple Choices"}, + {StatusCode::redirection_moved_permanently, "301 Moved Permanently"}, + {StatusCode::redirection_found, "302 Found"}, + {StatusCode::redirection_see_other, "303 See Other"}, + {StatusCode::redirection_not_modified, "304 Not Modified"}, + {StatusCode::redirection_use_proxy, "305 Use Proxy"}, + {StatusCode::redirection_switch_proxy, "306 Switch Proxy"}, + {StatusCode::redirection_temporary_redirect, "307 Temporary Redirect"}, + {StatusCode::redirection_permanent_redirect, "308 Permanent Redirect"}, + {StatusCode::client_error_bad_request, "400 Bad Request"}, + {StatusCode::client_error_unauthorized, "401 Unauthorized"}, + {StatusCode::client_error_payment_required, "402 Payment Required"}, + {StatusCode::client_error_forbidden, "403 Forbidden"}, + {StatusCode::client_error_not_found, "404 Not Found"}, + {StatusCode::client_error_method_not_allowed, "405 Method Not Allowed"}, + {StatusCode::client_error_not_acceptable, "406 Not Acceptable"}, + {StatusCode::client_error_proxy_authentication_required, "407 Proxy Authentication Required"}, + {StatusCode::client_error_request_timeout, "408 Request Timeout"}, + {StatusCode::client_error_conflict, "409 Conflict"}, + {StatusCode::client_error_gone, "410 Gone"}, + {StatusCode::client_error_length_required, "411 Length Required"}, + {StatusCode::client_error_precondition_failed, "412 Precondition Failed"}, + {StatusCode::client_error_payload_too_large, "413 Payload Too Large"}, + {StatusCode::client_error_uri_too_long, "414 URI Too Long"}, + {StatusCode::client_error_unsupported_media_type, "415 Unsupported Media Type"}, + {StatusCode::client_error_range_not_satisfiable, "416 Range Not Satisfiable"}, + {StatusCode::client_error_expectation_failed, "417 Expectation Failed"}, + {StatusCode::client_error_im_a_teapot, "418 I'm a teapot"}, + {StatusCode::client_error_misdirection_required, "421 Misdirected Request"}, + {StatusCode::client_error_unprocessable_entity, "422 Unprocessable Entity"}, + {StatusCode::client_error_locked, "423 Locked"}, + {StatusCode::client_error_failed_dependency, "424 Failed Dependency"}, + {StatusCode::client_error_upgrade_required, "426 Upgrade Required"}, + {StatusCode::client_error_precondition_required, "428 Precondition Required"}, + {StatusCode::client_error_too_many_requests, "429 Too Many Requests"}, + {StatusCode::client_error_request_header_fields_too_large, "431 Request Header Fields Too Large"}, + {StatusCode::client_error_unavailable_for_legal_reasons, "451 Unavailable For Legal Reasons"}, + {StatusCode::server_error_internal_server_error, "500 Internal Server Error"}, + {StatusCode::server_error_not_implemented, "501 Not Implemented"}, + {StatusCode::server_error_bad_gateway, "502 Bad Gateway"}, + {StatusCode::server_error_service_unavailable, "503 Service Unavailable"}, + {StatusCode::server_error_gateway_timeout, "504 Gateway Timeout"}, + {StatusCode::server_error_http_version_not_supported, "505 HTTP Version Not Supported"}, + {StatusCode::server_error_variant_also_negotiates, "506 Variant Also Negotiates"}, + {StatusCode::server_error_insufficient_storage, "507 Insufficient Storage"}, + {StatusCode::server_error_loop_detected, "508 Loop Detected"}, + {StatusCode::server_error_not_extended, "510 Not Extended"}, + {StatusCode::server_error_network_authentication_required, "511 Network Authentication Required"}}; + return status_codes; + } + + inline StatusCode status_code(const std::string &status_code_str) noexcept { + for(auto &status_code : status_codes()) { + if(status_code.second == status_code_str) + return status_code.first; + } + return StatusCode::unknown; + } + + inline const std::string &status_code(StatusCode status_code_enum) noexcept { + for(auto &status_code : status_codes()) { + if(status_code.first == status_code_enum) + return status_code.second; + } + return status_codes()[0].second; + } +} // namespace SimpleWeb diff --git a/apps/master/SimpleWeb/utility.hpp b/apps/master/SimpleWeb/utility.hpp new file mode 100644 index 000000000..af56209de --- /dev/null +++ b/apps/master/SimpleWeb/utility.hpp @@ -0,0 +1,340 @@ +#pragma once + +#include "status_code.hpp" +#include +#include +#include +#include +#include + +namespace SimpleWeb { + inline bool case_insensitive_equal(const std::string &str1, const std::string &str2) noexcept { + return str1.size() == str2.size() && + std::equal(str1.begin(), str1.end(), str2.begin(), [](char a, char b) { + return tolower(a) == tolower(b); + }); + } + class CaseInsensitiveEqual { + public: + bool operator()(const std::string &str1, const std::string &str2) const noexcept { + return case_insensitive_equal(str1, str2); + } + }; + // Based on https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x/2595226#2595226 + class CaseInsensitiveHash { + public: + size_t operator()(const std::string &str) const noexcept { + size_t h = 0; + std::hash hash; + for(auto c : str) + h ^= hash(tolower(c)) + 0x9e3779b9 + (h << 6) + (h >> 2); + return h; + } + }; + + using CaseInsensitiveMultimap = std::unordered_multimap; + + /// Percent encoding and decoding + class Percent { + public: + /// Returns percent-encoded string + static std::string encode(const std::string &value) noexcept { + static auto hex_chars = "0123456789ABCDEF"; + + std::string result; + result.reserve(value.size()); // Minimum size of result + + for(auto &chr : value) { + if(chr == ' ') + result += '+'; + else if(chr == '!' || chr == '#' || chr == '$' || (chr >= '&' && chr <= ',') || (chr >= '/' && chr <= ';') || chr == '=' || chr == '?' || chr == '@' || chr == '[' || chr == ']') + result += std::string("%") + hex_chars[chr >> 4] + hex_chars[chr & 15]; + else + result += chr; + } + + return result; + } + + /// Returns percent-decoded string + static std::string decode(const std::string &value) noexcept { + std::string result; + result.reserve(value.size() / 3 + (value.size() % 3)); // Minimum size of result + + for(size_t i = 0; i < value.size(); ++i) { + auto &chr = value[i]; + if(chr == '%' && i + 2 < value.size()) { + auto hex = value.substr(i + 1, 2); + auto decoded_chr = static_cast(std::strtol(hex.c_str(), nullptr, 16)); + result += decoded_chr; + i += 2; + } + else if(chr == '+') + result += ' '; + else + result += chr; + } + + return result; + } + }; + + /// Query string creation and parsing + class QueryString { + public: + /// Returns query string created from given field names and values + static std::string create(const CaseInsensitiveMultimap &fields) noexcept { + std::string result; + + bool first = true; + for(auto &field : fields) { + result += (!first ? "&" : "") + field.first + '=' + Percent::encode(field.second); + first = false; + } + + return result; + } + + /// Returns query keys with percent-decoded values. + static CaseInsensitiveMultimap parse(const std::string &query_string) noexcept { + CaseInsensitiveMultimap result; + + if(query_string.empty()) + return result; + + size_t name_pos = 0; + auto name_end_pos = std::string::npos; + auto value_pos = std::string::npos; + for(size_t c = 0; c < query_string.size(); ++c) { + if(query_string[c] == '&') { + auto name = query_string.substr(name_pos, (name_end_pos == std::string::npos ? c : name_end_pos) - name_pos); + if(!name.empty()) { + auto value = value_pos == std::string::npos ? std::string() : query_string.substr(value_pos, c - value_pos); + result.emplace(std::move(name), Percent::decode(value)); + } + name_pos = c + 1; + name_end_pos = std::string::npos; + value_pos = std::string::npos; + } + else if(query_string[c] == '=') { + name_end_pos = c; + value_pos = c + 1; + } + } + if(name_pos < query_string.size()) { + auto name = query_string.substr(name_pos, name_end_pos - name_pos); + if(!name.empty()) { + auto value = value_pos >= query_string.size() ? std::string() : query_string.substr(value_pos); + result.emplace(std::move(name), Percent::decode(value)); + } + } + + return result; + } + }; + + class HttpHeader { + public: + /// Parse header fields + static CaseInsensitiveMultimap parse(std::istream &stream) noexcept { + CaseInsensitiveMultimap result; + std::string line; + getline(stream, line); + size_t param_end; + while((param_end = line.find(':')) != std::string::npos) { + size_t value_start = param_end + 1; + if(value_start < line.size()) { + if(line[value_start] == ' ') + value_start++; + if(value_start < line.size()) + result.emplace(line.substr(0, param_end), line.substr(value_start, line.size() - value_start - 1)); + } + + getline(stream, line); + } + return result; + } + }; + + class RequestMessage { + public: + /// Parse request line and header fields + static bool parse(std::istream &stream, std::string &method, std::string &path, std::string &query_string, std::string &version, CaseInsensitiveMultimap &header) noexcept { + header.clear(); + std::string line; + getline(stream, line); + size_t method_end; + if((method_end = line.find(' ')) != std::string::npos) { + method = line.substr(0, method_end); + + size_t query_start = std::string::npos; + size_t path_and_query_string_end = std::string::npos; + for(size_t i = method_end + 1; i < line.size(); ++i) { + if(line[i] == '?' && (i + 1) < line.size()) + query_start = i + 1; + else if(line[i] == ' ') { + path_and_query_string_end = i; + break; + } + } + if(path_and_query_string_end != std::string::npos) { + if(query_start != std::string::npos) { + path = line.substr(method_end + 1, query_start - method_end - 2); + query_string = line.substr(query_start, path_and_query_string_end - query_start); + } + else + path = line.substr(method_end + 1, path_and_query_string_end - method_end - 1); + + size_t protocol_end; + if((protocol_end = line.find('/', path_and_query_string_end + 1)) != std::string::npos) { + if(line.compare(path_and_query_string_end + 1, protocol_end - path_and_query_string_end - 1, "HTTP") != 0) + return false; + version = line.substr(protocol_end + 1, line.size() - protocol_end - 2); + } + else + return false; + + header = HttpHeader::parse(stream); + } + else + return false; + } + else + return false; + return true; + } + }; + + class ResponseMessage { + public: + /// Parse status line and header fields + static bool parse(std::istream &stream, std::string &version, std::string &status_code, CaseInsensitiveMultimap &header) noexcept { + header.clear(); + std::string line; + getline(stream, line); + size_t version_end = line.find(' '); + if(version_end != std::string::npos) { + if(5 < line.size()) + version = line.substr(5, version_end - 5); + else + return false; + if((version_end + 1) < line.size()) + status_code = line.substr(version_end + 1, line.size() - (version_end + 1) - 1); + else + return false; + + header = HttpHeader::parse(stream); + } + else + return false; + return true; + } + }; + + class ContentDisposition { + public: + /// Can be used to parse the Content-Disposition header field value when + /// clients are posting requests with enctype="multipart/form-data" + static CaseInsensitiveMultimap parse(const std::string &line) { + CaseInsensitiveMultimap result; + + size_t para_start_pos = 0; + size_t para_end_pos = std::string::npos; + size_t value_start_pos = std::string::npos; + for(size_t c = 0; c < line.size(); ++c) { + if(para_start_pos != std::string::npos) { + if(para_end_pos == std::string::npos) { + if(line[c] == ';') { + result.emplace(line.substr(para_start_pos, c - para_start_pos), std::string()); + para_start_pos = std::string::npos; + } + else if(line[c] == '=') + para_end_pos = c; + } + else { + if(value_start_pos == std::string::npos) { + if(line[c] == '"' && c + 1 < line.size()) + value_start_pos = c + 1; + } + else if(line[c] == '"') { + result.emplace(line.substr(para_start_pos, para_end_pos - para_start_pos), line.substr(value_start_pos, c - value_start_pos)); + para_start_pos = std::string::npos; + para_end_pos = std::string::npos; + value_start_pos = std::string::npos; + } + } + } + else if(line[c] != ' ' && line[c] != ';') + para_start_pos = c; + } + if(para_start_pos != std::string::npos && para_end_pos == std::string::npos) + result.emplace(line.substr(para_start_pos), std::string()); + + return result; + } + }; +} // namespace SimpleWeb + +#ifdef __SSE2__ +#include +namespace SimpleWeb { + inline void spin_loop_pause() noexcept { _mm_pause(); } +} // namespace SimpleWeb +// TODO: need verification that the following checks are correct: +#elif defined(_MSC_VER) && _MSC_VER >= 1800 && (defined(_M_X64) || defined(_M_IX86)) +#include +namespace SimpleWeb { + inline void spin_loop_pause() noexcept { _mm_pause(); } +} // namespace SimpleWeb +#else +namespace SimpleWeb { + inline void spin_loop_pause() noexcept {} +} // namespace SimpleWeb +#endif + +namespace SimpleWeb { + /// Makes it possible to for instance cancel Asio handlers without stopping asio::io_service + class ScopeRunner { + /// Scope count that is set to -1 if scopes are to be canceled + std::atomic count; + + public: + class SharedLock { + friend class ScopeRunner; + std::atomic &count; + SharedLock(std::atomic &count) noexcept : count(count) {} + SharedLock &operator=(const SharedLock &) = delete; + SharedLock(const SharedLock &) = delete; + + public: + ~SharedLock() noexcept { + count.fetch_sub(1); + } + }; + + ScopeRunner() noexcept : count(0) {} + + /// Returns nullptr if scope should be exited, or a shared lock otherwise + std::unique_ptr continue_lock() noexcept { + long expected = count; + while(expected >= 0 && !count.compare_exchange_weak(expected, expected + 1)) + spin_loop_pause(); + + if(expected < 0) + return nullptr; + else + return std::unique_ptr(new SharedLock(count)); + } + + /// Blocks until all shared locks are released, then prevents future shared locks + void stop() noexcept { + long expected = 0; + while(!count.compare_exchange_weak(expected, -1)) { + if(expected < 0) + return; + expected = 0; + spin_loop_pause(); + } + } + }; +} // namespace SimpleWeb \ No newline at end of file