diff --git a/README.md b/README.md index 33538a9..390d06b 100644 --- a/README.md +++ b/README.md @@ -64,46 +64,58 @@ Then check out our other commercial and open source solutions: INSTALLING ========== -AMQP-CPP comes with an optional Linux-only TCP module that takes care of the network part required for the AMQP-CPP core library. If you use this module, you are required to link with `pthread`. -There are two methods to compile AMQP-CPP: CMake and Make. CMake is platform portable, but the Makefile only works on Linux. Building of a shared library is currently not supported on Windows. +AMQP-CPP comes with an optional Linux-only TCP module that takes care of the +network part required for the AMQP-CPP core library. If you use this module, you +are required to link with `pthread` and `dl`. + +There are two methods to compile AMQP-CPP: CMake and Make. CMake is platform portable +and works on all systems, while the Makefile only works on Linux. Building of a shared +library is currently not supported on Windows. After building there are two relevant files to include when using the library. -File|Include when? -----|------------ -amqpcpp.h|Always -amqpcpp/linux_tcp.h|If using the Linux-only TCP module + File | Include when? +---------------------|-------------------------------------------------------- + amqpcpp.h | Always + amqpcpp/linux_tcp.h | If using the Linux-only TCP module On Windows you are required to define `NOMINMAX` when compiling code that includes public AMQP-CPP header files. -## CMake -The CMake file supports both building and installing. You can choose not to use the install functionality, and instead manually use the build output at `bin/`. Keep in mind that the TCP module is only supported for Linux. An example install method would be: -``` bash +## Using cmake + +The CMake file supports both building and installing. You can choose not to use +the install functionality, and instead manually use the build output at `bin/`. Keep +in mind that the TCP module is only supported for Linux. An example install method +would be: + +```bash mkdir build cd build cmake .. [-DAMQP-CPP_AMQBUILD_SHARED] [-DAMQP-CPP_LINUX_TCP] cmake --build .. --target install ``` -Option|Default|Meaning -------|-------|------- -AMQP-CPP_BUILD_SHARED|OFF|Static lib(ON) or shared lib(OFF)? Shared is not supported on Windows. -AMQP-CPP_LINUX_TCP|OFF|Should the Linux-only TCP module be built? + Option | Default | Meaning +-------------------------|---------|----------------------------------------------------------------------- + AMQP-CPP_BUILD_SHARED | OFF | Static lib(ON) or shared lib(OFF)? Shared is not supported on Windows. + AMQP-CPP_LINUX_TCP | OFF | Should the Linux-only TCP module be built? -## Make -Installing the library is as easy -as running `make` and `make install`. This will install the full version of -the AMQP-CPP, including the system specific TCP module. If you do not need the -additional TCP module (because you take care of handling the network stuff -yourself), you can also compile a pure form of the library. Use `make pure` -and `make install` for that. +## Using make + +Compiling and installing AMQP-CPP with make is as easy as running `make` and +then `make install`. This will install the full version of AMQP-CPP, including +the system specific TCP module. If you do not need the additional TCP module +(because you take care of handling the network stuff yourself), you can also +compile a pure form of the library. Use `make pure` and `make install` for that. When you compile an application that uses the AMQP-CPP library, do not forget to link with the library. For gcc and clang the linker flag is -lamqpcpp. If you use the fullblown version of AMQP-CPP (with the TCP module), you also -need to pass a -lpthread linker flag, because the TCP module uses a thread -for running an asynchronous and non-blocking DNS hostname lookup. +need to pass the -lpthread and -ldl linker flags, because the TCP module uses a +thread for running an asynchronous and non-blocking DNS hostname lookup, and it +may dynamically look up functions from the openssl library if a secure connection +to RabbitMQ has to be set up. HOW TO USE AMQP-CPP @@ -394,6 +406,88 @@ channel.declareQueue("my-queue"); channel.bindQueue("my-exchange", "my-queue", "my-routing-key"); ```` +SECURE CONNECTIONS +================== + +The TCP module of AMQP-CPP also supports setting up secure connections. If your +RabbitMQ server accepts SSL connections, you can specify the address to your +server using the amqps:// protocol: + +````c++ +// init the SSL library (this works for openssl 1.1, for openssl 1.0 use SSL_library_init()) +OPENSSL_init_ssl(0, NULL); + +// address of the server (secure!) +AMQP::Address address("amqps://guest:guest@localhost/vhost"); + +// create a AMQP connection object +AMQP::TcpConnection connection(&myHandler, address); +```` + +There are two things to take care of if you want to create a secure connection: +(1) you must link your application with the -lssl flag (or use dlopen()), and (2) +you must initialize the openssl library by calling OPENSSL_init_ssl(). This +initializating must take place before you let you application connect to RabbitMQ. +This is necessary because AMQP-CPP needs access to the openssl library to set up +secure connections. It can only access this library if you have linked your +application with this library, or if you have loaded this library at runtime +using dlopen()). + +Linking openssl is the normal thing to do. You just have to add the `-lssl` flag +to your linker. If you however do not want to link your application with openssl, +you can also load the openssl library at runtime, and pass in the pointer to the +handle to AMQP-CPP: + +````c++ +// dynamically open the openssl library +void *handle = dlopen("/path/to/openssl.so", RTLD_LAZY); + +// tell AMQP-CPP library where the handle to openssl can be found +AMQP::openssl(handle); + +// @todo call functions to initialize openssl, and create the AMQP connection +// (see exampe above) +```` + +By itself, AMQP-CPP does not check if the created TLS connection is sufficient +secure. Whether the certificate is expired, self-signed, missing or invalid: for +AMQP-CPP it all doesn't matter and the connection is simply permitted. If you +want to be more strict (for example: if you want to verify the server's certificate), +you must do this yourself by implementing the "onSecured()" method in your handler +object: + +````c++ +#include + +class MyTcpHandler : public AMQP::TcpHandler +{ + /** + * Method that is called right after the TLS connection has been created. + * In this method you can check the connection properties (like the certificate) + * and return false if you find it not secure enough + * @param connection the connection that has just completed the tls handshake + * @param ssl SSL structure from the openssl library + * @return bool true if connection is secure enough to start the AMQP protocol + */ + virtual bool onSecure(AMQP::TcpConnection *connection, const SSL *ssl) override + { + // @todo call functions from the openssl library to check the certificate, + // like SSL_get_peer_certificate() or SSL_get_verify_result(). + // For now we always allow the connection to proceed + return true; + } + + /** + * All other methods (like onConnected(), onError(), etc) are left out of this + * example, but would be here if this was an actual user space handler class. + */ +}; +```` + +The SSL pointer that is passed to the onSecured() method refers to the "SSL" +structure from the openssl library. + + EXISTING EVENT LOOPS ==================== diff --git a/examples/libev.cpp b/examples/libev.cpp index 9ee0e2a..7bfd1af 100644 --- a/examples/libev.cpp +++ b/examples/libev.cpp @@ -14,6 +14,7 @@ #include #include #include +#include /** * Custom handler @@ -66,20 +67,35 @@ int main() MyHandler handler(loop); // init the SSL library +#if OPENSSL_VERSION_NUMBER < 0x10100000L SSL_library_init(); +#else + OPENSSL_init_ssl(0, NULL); +#endif // make a connection - AMQP::Address address("amqps://guest:guest@localhost/"); + AMQP::Address address("amqp://guest:guest@localhost/"); +// AMQP::Address address("amqps://guest:guest@localhost/"); AMQP::TcpConnection connection(&handler, address); // we need a channel too AMQP::TcpChannel channel(&connection); // create a temporary queue - channel.declareQueue(AMQP::exclusive).onSuccess([&connection](const std::string &name, uint32_t messagecount, uint32_t consumercount) { + channel.declareQueue(AMQP::exclusive).onSuccess([&connection, &channel](const std::string &name, uint32_t messagecount, uint32_t consumercount) { // report the name of the temporary queue std::cout << "declared queue " << name << std::endl; + + // close the channel + channel.close().onSuccess([&connection, &channel]() { + + // report that channel was closed + std::cout << "channel closed" << std::endl; + + // close the connection + connection.close(); + }); }); // run the loop diff --git a/include/amqpcpp.h b/include/amqpcpp.h index 39ec8f5..8079b1f 100644 --- a/include/amqpcpp.h +++ b/include/amqpcpp.h @@ -79,4 +79,4 @@ #include "amqpcpp/connectionhandler.h" #include "amqpcpp/connectionimpl.h" #include "amqpcpp/connection.h" - +#include "amqpcpp/openssl.h" diff --git a/include/amqpcpp/address.h b/include/amqpcpp/address.h index c8faaa1..e9bab9a 100644 --- a/include/amqpcpp/address.h +++ b/include/amqpcpp/address.h @@ -27,7 +27,7 @@ private: * The auth method * @var bool */ - bool _secure; + bool _secure = false; /** * Login data (username + password) @@ -52,6 +52,16 @@ private: * @var std::string */ std::string _vhost; + + + /** + * The default port + * @return uint16_t + */ + uint16_t defaultport() const + { + return _secure ? 5671 : 5672; + } public: /** @@ -67,13 +77,13 @@ public: const char *last = data + size; // must start with ampqs:// to have a secure connection (and we also assign a different default port) - _secure = strncmp(data, "amqps://", 8) == 0; - - // default port changes for secure connections - if (_secure) _port = 5671; + if (strncmp(data, "amqps://", 8) == 0) _secure = true; // otherwise protocol must be amqp:// else if (strncmp(data, "amqp://", 7) != 0) throw std::runtime_error("AMQP address should start with \"amqp://\" or \"amqps://\""); + + // assign default port (we may overwrite it later) + _port = defaultport(); // begin of the string was parsed data += _secure ? 8 : 7; @@ -299,9 +309,15 @@ public: { // start with the protocol and login stream << (address._secure ? "amqps://" : "amqp://"); + + // do we have a login? + if (address._login) stream << address._login << "@"; + + // write hostname + stream << address._hostname; // do we need a special portnumber? - if (address._port != 5672) stream << ":" << address._port; + if (address._port != address.defaultport()) stream << ":" << address._port; // append default vhost stream << "/"; diff --git a/include/amqpcpp/linux_tcp/tcphandler.h b/include/amqpcpp/linux_tcp/tcphandler.h index 5549425..9b0c87d 100644 --- a/include/amqpcpp/linux_tcp/tcphandler.h +++ b/include/amqpcpp/linux_tcp/tcphandler.h @@ -6,7 +6,7 @@ * class. * * @author Emiel Bruijntjes - * @copyright 2015 Copernica BV + * @copyright 2015 - 2018 Copernica BV */ /** @@ -14,6 +14,11 @@ */ #pragma once +/** + * Dependencies + */ +#include + /** * Set up namespace */ @@ -35,9 +40,33 @@ public: */ virtual ~TcpHandler() = default; + /** + * Method that is called after a TCP connection has been set up and the initial + * TLS handshake is finished too, but right before the AMQP login handshake is + * going to take place and the first data is going to be sent over the connection. + * This method allows you to inspect the TLS certificate and other connection + * properties, and to break up the connection if you find it not secure enough. + * The default implementation considers all connections to be secure, even if the + * connection has a self-signed or even invalid certificate. To be more strict, + * override this method, inspect the certificate and return false if you do not + * want to use the connection. The passed in SSL pointer is a pointer to a SSL + * structure from the openssl library. This method is only called for secure + * connections (connection with an amqps:// address). + * @param connection The connection for which TLS was just started + * @param ssl Pointer to the SSL structure that can be inspected + * @return bool True to proceed / accept the connection, false to break up + */ + virtual bool onSecured(TcpConnection *connection, const SSL *ssl) + { + // default implementation: do not inspect anything, just allow the connection + return true; + } + /** * Method that is called when the heartbeat frequency is negotiated - * between the server and the client. + * between the server and the client. Applications can override this method + * if they want to use a different heartbeat interval (for example: return 0 + * to disable heartbeats) * @param connection The connection that suggested a heartbeat interval * @param interval The suggested interval from the server * @return uint16_t The interval to use @@ -51,7 +80,9 @@ public: } /** - * Method that is called when the TCP connection ends up in a connected state + * Method that is called when the AMQP connection ends up in a connected state + * This method is called after the TCP connection has been set up, the (optional) + * secure TLS connection, and the AMQP login handshake has been completed. * @param connection The TCP connection */ virtual void onConnected(TcpConnection *connection) {} diff --git a/include/amqpcpp/login.h b/include/amqpcpp/login.h index b9276da..f805e18 100644 --- a/include/amqpcpp/login.h +++ b/include/amqpcpp/login.h @@ -3,7 +3,7 @@ * * This class combines login, password and vhost * - * @copyright 2014 Copernica BV + * @copyright 2014 - 2018 Copernica BV */ /** @@ -65,7 +65,25 @@ public: /** * Destructor */ - virtual ~Login() {} + virtual ~Login() = default; + + /** + * Cast to boolean: is the login set? + * @return bool + */ + operator bool () const + { + return !_user.empty() || !_password.empty(); + } + + /** + * Negate operator: is it not set + * @return bool + */ + bool operator! () const + { + return _user.empty() && _password.empty(); + } /** * Retrieve the user name @@ -143,7 +161,7 @@ public: friend std::ostream &operator<<(std::ostream &stream, const Login &login) { // write username and password - return stream << login._user << "@" << login._password; + return stream << login._user << ":" << login._password; } }; diff --git a/include/amqpcpp/openssl.h b/include/amqpcpp/openssl.h new file mode 100644 index 0000000..f4ac1e1 --- /dev/null +++ b/include/amqpcpp/openssl.h @@ -0,0 +1,37 @@ +/** + * OpenSSL.h + * + * Function to set openssl features + * + * @author Emiel Bruijntjes + * @copyright 2018 Copernica BV + */ + +/** + * Include guard + */ +#pragma once + +/** + * Begin of namespace + */ +namespace AMQP { + +/** + * To make secure "amqps://" connections, AMQP-CPP relies on functions from the + * openssl library. It your application is dynamically linked to openssl (because + * it was compiled with the "-lssl" flag), this works flawlessly because AMQPCPP + * can then locate the openssl symbols in its own project space. However, if the + * openssl library was not linked, you either cannot use amqps:// connections, + * or you have to supply a handle to the openssl library yourself, using the + * following method. + * + * @param handle Handle returned by dlopen() that has access to openssl + */ +void openssl(void *handle); + +/** + * End of namespace + */ +} + diff --git a/src/connectionimpl.cpp b/src/connectionimpl.cpp index 332e106..ab9f781 100644 --- a/src/connectionimpl.cpp +++ b/src/connectionimpl.cpp @@ -3,7 +3,7 @@ * * Implementation of an AMQP connection * - * @copyright 2014 - 2017 Copernica BV + * @copyright 2014 - 2018 Copernica BV */ #include "includes.h" #include "protocolheaderframe.h" @@ -147,7 +147,7 @@ uint64_t ConnectionImpl::parse(const Buffer &buffer) // data we need for the next frame, otherwise we need at least 7 // bytes for processing the header of the next frame _expected = receivedFrame.header() ? (uint32_t)receivedFrame.totalSize() : 7; - + // we're ready for now return processed; } diff --git a/src/linux_tcp/function.h b/src/linux_tcp/function.h index cb8bf89..899f536 100644 --- a/src/linux_tcp/function.h +++ b/src/linux_tcp/function.h @@ -115,10 +115,11 @@ private: public: /** * Constructor + * @param handle Handle to access openssl * @param name Name of the function */ - Function(const char *name) : - _method(dlsym(RTLD_DEFAULT, name)) {} + Function(void *handle, const char *name) : + _method(dlsym(handle, name)) {} /** * Destructor diff --git a/src/linux_tcp/openssl.cpp b/src/linux_tcp/openssl.cpp index 39bff69..2afc378 100644 --- a/src/linux_tcp/openssl.cpp +++ b/src/linux_tcp/openssl.cpp @@ -11,11 +11,33 @@ */ #include "openssl.h" #include "function.h" +#include "amqpcpp/openssl.h" /** - * Begin of namespace + * The handle to access openssl (the result of dlopen("openssl.so")) + * By default we set this to RTLD_DEFAULT, so that AMQP-CPP checks the internal process space */ -namespace AMQP { namespace OpenSSL { +static void *handle = RTLD_DEFAULT; + +/** + * Just the AMQP namespace + */ +namespace AMQP { + +/** + * Function to set the openssl handle + * @param ptr + */ +void openssl(void *ptr) +{ + // store handle + handle = ptr; +} + +/** + * Begin of AMQP::OpenSSL namespace + */ +namespace OpenSSL { /** * Is the openssl library loaded? @@ -24,7 +46,7 @@ namespace AMQP { namespace OpenSSL { bool valid() { // create a function - static Function func("SSL_CTX_new"); + static Function func(handle, "SSL_CTX_new"); // we need a library return func; @@ -37,13 +59,13 @@ bool valid() const SSL_METHOD *TLS_client_method() { // create a function that loads the method - static Function func("TLS_client_method"); + static Function func(handle, "TLS_client_method"); // call the openssl function if (func) return func(); // older openssl libraries do not have this function, so we try to load an other function - static Function old("SSLv23_client_method"); + static Function old(handle, "SSLv23_client_method"); // call the old one return old(); @@ -57,7 +79,7 @@ const SSL_METHOD *TLS_client_method() SSL_CTX *SSL_CTX_new(const SSL_METHOD *method) { // create a function - static Function func("SSL_CTX_new"); + static Function func(handle, "SSL_CTX_new"); // call the openssl function return func(method); @@ -73,7 +95,7 @@ SSL_CTX *SSL_CTX_new(const SSL_METHOD *method) int SSL_read(SSL *ssl, void *buf, int num) { // create a function - static Function func("SSL_read"); + static Function func(handle, "SSL_read"); // call the openssl function return func(ssl, buf, num); @@ -89,7 +111,7 @@ int SSL_read(SSL *ssl, void *buf, int num) int SSL_write(SSL *ssl, const void *buf, int num) { // create a function - static Function func("SSL_write"); + static Function func(handle, "SSL_write"); // call the openssl function return func(ssl, buf, num); @@ -104,12 +126,27 @@ int SSL_write(SSL *ssl, const void *buf, int num) int SSL_set_fd(SSL *ssl, int fd) { // create a function - static Function func("SSL_set_fd"); + static Function func(handle, "SSL_set_fd"); // call the openssl function return func(ssl, fd); } +/** + * The number of bytes availabe in the ssl struct that have been read + * from the socket, but that have not been returned the SSL_read() + * @param ssl SSL object + * @return int number of bytes + */ +int SSL_pending(const SSL *ssl) +{ + // create a function + static Function func(handle, "SSL_pending"); + + // call the openssl function + return func(ssl); +} + /** * Free an allocated ssl context * @param ctx @@ -117,7 +154,7 @@ int SSL_set_fd(SSL *ssl, int fd) void SSL_CTX_free(SSL_CTX *ctx) { // create a function - static Function func("SSL_CTX_free"); + static Function func(handle, "SSL_CTX_free"); // call the openssl function return func(ctx); @@ -131,7 +168,7 @@ void SSL_CTX_free(SSL_CTX *ctx) void SSL_free(SSL *ssl) { // create a function - static Function func("SSL_free"); + static Function func(handle, "SSL_free"); // call the openssl function return func(ssl); @@ -145,7 +182,7 @@ void SSL_free(SSL *ssl) SSL *SSL_new(SSL_CTX *ctx) { // create a function - static Function func("SSL_new"); + static Function func(handle, "SSL_new"); // call the openssl function return func(ctx); @@ -159,7 +196,7 @@ SSL *SSL_new(SSL_CTX *ctx) int SSL_up_ref(SSL *ssl) { // create a function - static Function func("SSL_up_ref"); + static Function func(handle, "SSL_up_ref"); // call the openssl function if it exists if (func) return func(ssl); @@ -177,7 +214,7 @@ int SSL_up_ref(SSL *ssl) int SSL_shutdown(SSL *ssl) { // create a function - static Function func("SSL_shutdown"); + static Function func(handle, "SSL_shutdown"); // call the openssl function return func(ssl); @@ -190,7 +227,7 @@ int SSL_shutdown(SSL *ssl) void SSL_set_connect_state(SSL *ssl) { // create a function - static Function func("SSL_set_connect_state"); + static Function func(handle, "SSL_set_connect_state"); // call the openssl function func(ssl); @@ -205,7 +242,21 @@ void SSL_set_connect_state(SSL *ssl) int SSL_do_handshake(SSL *ssl) { // create a function - static Function func("SSL_do_handshake"); + static Function func(handle, "SSL_do_handshake"); + + // call the openssl function + return func(ssl); +} + +/** + * Obtain shutdown statue for TLS/SSL I/O operation + * @param ssl SSL object + * @return int returns error values + */ +int SSL_get_shutdown(const SSL *ssl) +{ + // create a function + static Function func(handle, "SSL_get_shutdown"); // call the openssl function return func(ssl); @@ -220,7 +271,7 @@ int SSL_do_handshake(SSL *ssl) int SSL_get_error(const SSL *ssl, int ret) { // create a function - static Function func("SSL_get_error"); + static Function func(handle, "SSL_get_error"); // call the openssl function return func(ssl, ret); @@ -237,12 +288,28 @@ int SSL_get_error(const SSL *ssl, int ret) long SSL_ctrl(SSL *ssl, int cmd, long larg, void *parg) { // create a function - static Function func("SSL_ctrl"); + static Function func(handle, "SSL_ctrl"); // call the openssl function return func(ssl, cmd, larg, parg); } + +/** + * Set the certificate file to be used by the connection + * @param ssl ssl structure + * @param file filename + * @param type type of file + * @return int + */ +int SSL_use_certificate_file(SSL *ssl, const char *file, int type) +{ + // create a function + static Function func(handle, "SSL_use_certificate_file"); + // call the openssl function + return func(ssl, file, type); +} + /** * End of namespace */ diff --git a/src/linux_tcp/openssl.h b/src/linux_tcp/openssl.h index 387654d..e3461e2 100644 --- a/src/linux_tcp/openssl.h +++ b/src/linux_tcp/openssl.h @@ -40,9 +40,11 @@ int SSL_do_handshake(SSL *ssl); int SSL_read(SSL *ssl, void *buf, int num); int SSL_write(SSL *ssl, const void *buf, int num); int SSL_shutdown(SSL *ssl); +int SSL_pending(const SSL *ssl); int SSL_set_fd(SSL *ssl, int fd); +int SSL_get_shutdown(const SSL *ssl); int SSL_get_error(const SSL *ssl, int ret); -int SSL_up_ref(SSL *ssl); +int SSL_use_certificate_file(SSL *ssl, const char *file, int type); void SSL_set_connect_state(SSL *ssl); void SSL_CTX_free(SSL_CTX *ctx); void SSL_free(SSL *ssl); diff --git a/src/linux_tcp/sslconnected.h b/src/linux_tcp/sslconnected.h index b6e5d21..7d4b90e 100644 --- a/src/linux_tcp/sslconnected.h +++ b/src/linux_tcp/sslconnected.h @@ -59,17 +59,23 @@ private: * Are we now busy with sending or receiving? * @var enum */ - enum { + enum State { state_idle, state_sending, state_receiving } _state; /** - * Is the object already closed? + * Should we close the connection after we've finished all operations? * @var bool */ bool _closed = false; + + /** + * Have we reported the final instruction to the user? + * @var bool + */ + bool _finalized = false; /** * Cached reallocation instruction @@ -79,25 +85,37 @@ private: /** - * Helper method to report an error - * @return bool Was an error reported? + * Close the connection + * @return bool */ - bool reportError() + bool close() { - // we have an error - report this to the user - _handler->onError(_connection, strerror(errno)); + // do nothing if already closed + if (_socket < 0) return false; + + // and stop monitoring it + _handler->monitor(_connection, _socket, 0); + + // close the socket + ::close(_socket); + + // forget filedescriptor + _socket = -1; // done return true; } - + /** - * Construct the next state + * Construct the final state * @param monitor Object that monitors whether connection still exists * @return TcpState* */ - TcpState *nextState(const Monitor &monitor) + TcpState *finalstate(const Monitor &monitor) { + // close the socket if it is still open + close(); + // if the object is still in a valid state, we can move to the close-state, // otherwise there is no point in moving to a next state return monitor.valid() ? new TcpClosed(this) : nullptr; @@ -113,25 +131,22 @@ private: // if we still have an outgoing buffer we want to send out data if (_out) { - // we still have a buffer with outgoing data - _state = state_sending; - // let's wait until the socket becomes writable _handler->monitor(_connection, _socket, readable | writable); } else if (_closed) { - // we forget the current handler to prevent that things are changed - _handler = nullptr; - // start the state that closes the connection - return new SslShutdown(_connection, _socket, _ssl, _handler); + auto *nextstate = new SslShutdown(_connection, _socket, std::move(_ssl), _finalized, _handler); + + // we forget the current socket to prevent that it gets destructed + _socket = -1; + + // report the next state + return nextstate; } else { - // outgoing buffer is empty, we're idle again waiting for further input - _state = state_idle; - // let's wait until the socket becomes readable _handler->monitor(_connection, _socket, readable); } @@ -141,51 +156,76 @@ private: } /** - * Method to repeat the previous call - * @param result result of an earlier openssl operation + * Method to repeat the previous call\ + * @param monitor monitor to check if connection object still exists + * @param state the state that we were in + * @param result result of an earlier SSL_get_error call * @return TcpState* */ - TcpState *repeat(int result) + TcpState *repeat(const Monitor &monitor, enum State state, int error) { - // error was returned, so we must investigate what is going on - auto error = OpenSSL::SSL_get_error(_ssl, result); - // check the error switch (error) { case SSL_ERROR_WANT_READ: + // remember state + _state = state; + // the operation must be repeated when readable _handler->monitor(_connection, _socket, readable); - return this; + + // allow chaining + return monitor.valid() ? this : nullptr; case SSL_ERROR_WANT_WRITE: + // remember state + _state = state; + // wait until socket becomes writable again _handler->monitor(_connection, _socket, readable | writable); - return this; + + // allow chaining + return monitor.valid() ? this : nullptr; + + case SSL_ERROR_NONE: + // we're ready for the next instruction from userspace + _state = state_idle; + + // turns out no error occured, an no action has to be rescheduled + _handler->monitor(_connection, _socket, _out || _closed ? readable | writable : readable); + + // allow chaining + return monitor.valid() ? this : nullptr; default: - - // @todo check how to handle this - return this; + // if we have already reported an error to user space, we can go to the final state right away + if (_finalized) return finalstate(monitor); + + // remember that we've sent out an error + _finalized = true; + + // tell the handler + _handler->onError(_connection, "ssl error"); + + // go to the final state + return finalstate(monitor); } } /** * Parse the received buffer - * @param size + * @param monitor object to check the existance of the connection object + * @param size number of bytes available * @return TcpState */ - TcpState *parse(size_t size) + TcpState *parse(const Monitor &monitor, size_t size) { // we need a local copy of the buffer - because it is possible that "this" // object gets destructed halfway through the call to the parse() method TcpInBuffer buffer(std::move(_in)); - // because the object might soon be destructed, we create a monitor to check this - Monitor monitor(this); - // parse the buffer auto processed = _connection->parse(buffer); - + // "this" could be removed by now, check this if (!monitor.valid()) return nullptr; @@ -196,7 +236,10 @@ private: _in = std::move(buffer); // do we have to reallocate? - if (_reallocate) _in.reallocate(_reallocate); + if (!_reallocate) return this; + + // reallocate the buffer + _in.reallocate(_reallocate); // we can remove the reallocate instruction _reallocate = 0; @@ -205,6 +248,58 @@ private: return this; } + /** + * Perform a write operation + * @param monitor + * @return TcpState* + */ + TcpState *write(const Monitor &monitor) + { + // assume default state + _state = state_idle; + + // try to send more data from the outgoing buffer + auto result = _out.sendto(_ssl); + + // if this is a success, we can proceed with the event loop + if (result > 0) return proceed(); + + // the operation failed, we may have to repeat our call + return repeat(monitor, state_sending, OpenSSL::SSL_get_error(_ssl, result)); + } + + /** + * Perform a receive operation + * @param monitor + * @return TcpState + */ + TcpState *receive(const Monitor &monitor) + { + // start a loop + do + { + // assume default state + _state = state_idle; + + // read data from ssl into the buffer + auto result = _in.receivefrom(_ssl, _connection->expected()); + + // if this is a failure, we are going to repeat the operation + if (result <= 0) return repeat(monitor, state_receiving, OpenSSL::SSL_get_error(_ssl, result)); + + // go process the received data + auto *nextstate = parse(monitor, result); + + // leap out if we moved to a different state + if (nextstate != this) return nextstate; + } + while (OpenSSL::SSL_pending(_ssl) > 0); + + // go to the next state + return proceed(); + } + + public: /** * Constructor @@ -214,31 +309,25 @@ public: * @param buffer The buffer that was already built * @param handler User-supplied handler object */ - SslConnected(TcpConnection *connection, int socket, const SslWrapper &ssl, TcpOutBuffer &&buffer, TcpHandler *handler) : + SslConnected(TcpConnection *connection, int socket, SslWrapper &&ssl, TcpOutBuffer &&buffer, TcpHandler *handler) : TcpState(connection, handler), - _ssl(ssl), + _ssl(std::move(ssl)), _socket(socket), _out(std::move(buffer)), _in(4096), _state(_out ? state_sending : state_idle) { // tell the handler to monitor the socket if there is an out - _handler->monitor(_connection, _socket, _state == state_sending ? writable : readable); - } + _handler->monitor(_connection, _socket, _state == state_sending ? readable | writable : readable); + } /** * Destructor */ virtual ~SslConnected() noexcept { - // skip if handler is already forgotten - if (_handler == nullptr) return; - - // we no longer have to monitor the socket - _handler->monitor(_connection, _socket, 0); - // close the socket - close(_socket); + close(); } /** @@ -248,55 +337,86 @@ public: virtual int fileno() const override { return _socket; } /** - * Process the filedescriptor in the object + * Process the filedescriptor in the object + * @param monitor Object that can be used to find out if connection object is still alive * @param fd The filedescriptor that is active * @param flags AMQP::readable and/or AMQP::writable * @return New implementation object */ - virtual TcpState *process(int fd, int flags) + virtual TcpState *process(const Monitor &monitor, int fd, int flags) override { // the socket must be the one this connection writes to if (fd != _socket) return this; - // because the object might soon be destructed, we create a monitor to check this - Monitor monitor(this); - - // are we busy with sending or receiving data? - if (_state == state_sending) - { - // try to send more data from the outgoing buffer - auto result = _out.sendto(_ssl); - - // if this is a success, we can proceed with the event loop - if (result > 0) return proceed(); - - // the operation failed, we may have to repeat our call - else return repeat(result); - } - else - { - // read data from ssl into the buffer - auto result = _in.receivefrom(_ssl, _connection->expected()); - - // if this is a success, we may have to update the monitor - if (result > 0) return parse(result); - - // the operation failed, we may have to repeat our call - else return repeat(result); - } + // if we were busy with a write operation, we have to repeat that + if (_state == state_sending) return write(monitor); + + // same is true for read operations, they should also be repeated + if (_state == state_receiving) return receive(monitor); + + // if the socket is readable, we are going to receive data + if (flags & readable) return receive(monitor); + + // socket is not readable (so it must be writable), do we have data to write? + if (_out) return write(monitor); + + // the only scenario in which we can end up here is the socket should be + // closed, but instead of moving to the shutdown-state right, we call proceed() + // because that function is a little more careful + return proceed(); } /** * Flush the connection, sent all buffered data to the socket + * @param monitor Object to check if connection still exists * @return TcpState new tcp state */ - virtual TcpState *flush() override + virtual TcpState *flush(const Monitor &monitor) override { + // we are not going to do this is object is busy reading + if (_state == state_receiving) return this; + // create an object to wait for the filedescriptor to becomes active Wait wait(_socket); - // @todo implementation + // keep looping while we have an outgoing buffer + while (_out) + { + // move to the idle-state + _state = state_idle; + + // try to send more data from the outgoing buffer + auto result = _out.sendto(_ssl); + + // was this a success? + if (result > 0) + { + // proceed to the next state + auto *nextstate = proceed(); + + // leap out if we move to a different state + if (nextstate != this) return nextstate; + } + else + { + // error was returned, so we must investigate what is going on + auto error = OpenSSL::SSL_get_error(_ssl, result); + + // get the next state given the error + auto *nextstate = repeat(monitor, state_sending, error); + + // leap out if we move to a different state + if (nextstate != this) return nextstate; + + // check the type of error, and wait now + switch (error) { + case SSL_ERROR_WANT_READ: wait.readable(); break; + case SSL_ERROR_WANT_WRITE: wait.active(); break; + } + } + } + // done return this; } @@ -305,7 +425,7 @@ public: * @param buffer buffer to send * @param size size of the buffer */ - virtual void send(const char *buffer, size_t size) + virtual void send(const char *buffer, size_t size) override { // put the data in the outgoing buffer _out.add(buffer, size); @@ -314,9 +434,6 @@ public: // for that operation to complete before we can move on if (_state != state_idle) return; - // object is now busy sending - _state = state_sending; - // let's wait until the socket becomes writable _handler->monitor(_connection, _socket, readable | writable); } @@ -334,20 +451,40 @@ public: // pass to base return TcpState::reportNegotiate(heartbeat); } + + /** + * Report a connection error + * @param error + */ + virtual void reportError(const char *error) override + { + // we want to start the elegant ssl shutdown procedure, so we call reportClosed() here too, + // because that function does exactly what we want to do here too + reportClosed(); + + // if the user was already notified of an final state, we do not have to proceed + if (_finalized) return; + + // remember that this is the final call to user space + _finalized = true; + + // pass to handler + _handler->onError(_connection, error); + } /** * Report to the handler that the connection was nicely closed */ virtual void reportClosed() override { - // remember that the object is closed + // remember that the object is going to be closed _closed = true; - // if the previous operation is still in progress + // if the previous operation is still in progress we can wait for that if (_state != state_idle) return; - // wait until the connection is writable - _handler->monitor(_connection, _socket, writable); + // wait until the connection is writable so that we can close it then + _handler->monitor(_connection, _socket, readable | writable); } }; diff --git a/src/linux_tcp/sslcontext.h b/src/linux_tcp/sslcontext.h index 2c2d8b3..ab9dbc7 100644 --- a/src/linux_tcp/sslcontext.h +++ b/src/linux_tcp/sslcontext.h @@ -32,7 +32,7 @@ private: public: /** * Constructor - * @param method + * @param method the connect method to use * @throws std::runtime_error */ SslContext(const SSL_METHOD *method) : _ctx(OpenSSL::SSL_CTX_new(method)) @@ -42,26 +42,11 @@ public: } /** - * Constructor that wraps around an existing context - * @param context - */ - SslContext(SSL_CTX *context) : _ctx(context) - { - // increment refcount - // @todo fix this - //SSL_ctx_up_ref(context); - } - - /** - * Copy constructor + * Copy constructor is delete because the object is refcounted, + * and we do not have a decent way to update the refcount in openssl 1.0 * @param that */ - SslContext(SslContext &that) : _ctx(that._ctx) - { - // increment refcount - // @todo fix this - //SSL_ctx_up_ref(context); - } + SslContext(SslContext &that) = delete; /** * Destructor diff --git a/src/linux_tcp/sslhandshake.h b/src/linux_tcp/sslhandshake.h index 4095664..caa1b4f 100644 --- a/src/linux_tcp/sslhandshake.h +++ b/src/linux_tcp/sslhandshake.h @@ -53,32 +53,58 @@ private: /** * Report a new state - * @param state + * @param monitor * @return TcpState */ - TcpState *nextstate(TcpState *state) + TcpState *nextstate(const Monitor &monitor) { - // forget the socket to prevent that it is closed by the destructor + // check if the handler allows the connection + bool allowed = _handler->onSecured(_connection, _ssl); + + // leap out if the user space function destructed the object + if (!monitor.valid()) return nullptr; + + // copy the socket because we might forget it + auto socket = _socket; + + // forget the socket member to prevent that it is closed by the destructor _socket = -1; - // done - return state; + // if connection is allowed, we move to the next state + if (allowed) return new SslConnected(_connection, socket, std::move(_ssl), std::move(_out), _handler); + + // report that the connection is broken + _handler->onError(_connection, "TLS connection has been rejected"); + + // the onError method could have destructed this object + if (!monitor.valid()) return nullptr; + + // shutdown the connection + return new SslShutdown(_connection, socket, std::move(_ssl), true, _handler); } /** * Helper method to report an error + * @param monitor * @return TcpState* */ - TcpState *reportError() + TcpState *reportError(const Monitor &monitor) { // we are no longer interested in any events for this socket _handler->monitor(_connection, _socket, 0); + // close the socket + close(_socket); + + // forget filedescriptor + _socket = -1; + // we have an error - report this to the user _handler->onError(_connection, "failed to setup ssl connection"); - // done, go to the closed state - return new TcpClosed(_connection, _handler); + // done, go to the closed state (plus check if connection still exists, because + // after the onError() call the user space program may have destructed that object) + return monitor.valid() ? new TcpClosed(this) : nullptr; } /** @@ -114,7 +140,7 @@ public: _ssl(SslContext(OpenSSL::TLS_client_method())), _socket(socket), _out(std::move(buffer)) - { + { // we will be using the ssl context as a client OpenSSL::SSL_set_connect_state(_ssl); @@ -136,7 +162,7 @@ public: // leap out if socket is invalidated if (_socket < 0) return; - // close the socket + // the object got destructed without moving to a new state, this is normally close(_socket); } @@ -148,29 +174,30 @@ public: /** * Process the filedescriptor in the object + * @param monitor Object to check if connection still exists * @param fd Filedescriptor that is active * @param flags AMQP::readable and/or AMQP::writable * @return New state object */ - virtual TcpState *process(int fd, int flags) override + virtual TcpState *process(const Monitor &monitor, int fd, int flags) override { // must be the socket if (fd != _socket) return this; // start the ssl handshake int result = OpenSSL::SSL_do_handshake(_ssl); - + // if the connection succeeds, we can move to the ssl-connected state - if (result == 1) return nextstate(new SslConnected(_connection, _socket, _ssl, std::move(_out), _handler)); + if (result == 1) return nextstate(monitor); // error was returned, so we must investigate what is going on auto error = OpenSSL::SSL_get_error(_ssl, result); - + // check the error switch (error) { case SSL_ERROR_WANT_READ: return proceed(readable); case SSL_ERROR_WANT_WRITE: return proceed(readable | writable); - default: return reportError(); + default: return reportError(monitor); } } @@ -187,9 +214,10 @@ public: /** * Flush the connection, sent all buffered data to the socket + * @param monitor Object to check if connection still exists * @return TcpState new tcp state */ - virtual TcpState *flush() override + virtual TcpState *flush(const Monitor &monitor) override { // create an object to wait for the filedescriptor to becomes active Wait wait(_socket); @@ -201,48 +229,24 @@ public: int result = OpenSSL::SSL_do_handshake(_ssl); // if the connection succeeds, we can move to the ssl-connected state - if (result == 1) return nextstate(new SslConnected(_connection, _socket, _ssl, std::move(_out), _handler)); + if (result == 1) return nextstate(monitor); // error was returned, so we must investigate what is going on auto error = OpenSSL::SSL_get_error(_ssl, result); // check the error - switch (error) - { - // if openssl reports that socket readability or writability is needed, - // we wait for that until this situation is reached - case SSL_ERROR_WANT_READ: wait.readable(); break; - case SSL_ERROR_WANT_WRITE: wait.active(); break; - - // something is wrong, we proceed to the next state - default: return reportError(); + switch (error) { + + // if openssl reports that socket readability or writability is needed, + // we wait for that until this situation is reached + case SSL_ERROR_WANT_READ: wait.readable(); break; + case SSL_ERROR_WANT_WRITE: wait.active(); break; + + // something is wrong, we proceed to the next state + default: return reportError(monitor); } } } - - /** - * Report to the handler that the connection was nicely closed - */ - virtual void reportClosed() override - { - // we no longer have to monitor the socket - _handler->monitor(_connection, _socket, 0); - - // close the socket - close(_socket); - - // socket is closed now - _socket = -1; - - // copy the handler (if might destruct this object) - auto *handler = _handler; - - // reset member before the handler can make a mess of it - _handler = nullptr; - - // notify to handler - handler->onClosed(_connection); - } }; /** diff --git a/src/linux_tcp/sslshutdown.h b/src/linux_tcp/sslshutdown.h index 2361155..64791cf 100644 --- a/src/linux_tcp/sslshutdown.h +++ b/src/linux_tcp/sslshutdown.h @@ -34,44 +34,94 @@ private: * @var int */ int _socket; + + /** + * Have we already notified user space of connection end? + * @var bool + */ + bool _finalized; /** - * Proceed with the next operation after the previous operation was - * a success, possibly changing the filedescriptor-monitor - * @return TcpState* + * Close the socket + * @return bool */ - TcpState *proceed() + bool close() { - // construct monitor to prevent that we access members if object is destructed - Monitor monitor(this); + // skip if already closed + if (_socket < 0) return false; // we're no longer interested in events _handler->monitor(_connection, _socket, 0); - // stop if object was destructed - if (!monitor) return nullptr; - // close the socket - close(_socket); + ::close(_socket); // forget the socket _socket = -1; - // go to the closed state - return new TcpClosed(_connection, _handler); + // done + return true; + } + + /** + * Report an error + * @param monitor object to check if connection still exists + * @return TcpState* + */ + TcpState *reporterror(const Monitor &monitor) + { + // close the socket + close(); + + // if we have already told user space that connection is gone + if (_finalized) return new TcpClosed(this); + + // object will be finalized now + _finalized = true; + + // inform user space that the party is over + _handler->onError(_connection, "ssl shutdown error"); + + // go to the final state (if not yet disconnected) + return monitor.valid() ? new TcpClosed(this) : nullptr; + } + + /** + * Proceed with the next operation after the previous operation was + * a success, possibly changing the filedescriptor-monitor + * @param monitor object to check if connection still exists + * @return TcpState* + */ + TcpState *proceed(const Monitor &monitor) + { + // close the socket + close(); + + // if we have already told user space that connection is gone + if (_finalized) return new TcpClosed(this); + + // object will be finalized now + _finalized = true; + + // inform user space that the party is over + _handler->onClosed(_connection); + + // go to the final state (if not yet disconnected) + return monitor.valid() ? new TcpClosed(this) : nullptr; } /** * Method to repeat the previous call + * @param monitor object to check if connection still exists * @param result result of an earlier openssl operation * @return TcpState* */ - TcpState *repeat(int result) + TcpState *repeat(const Monitor &monitor, int result) { // error was returned, so we must investigate what is going on auto error = OpenSSL::SSL_get_error(_ssl, result); - + // check the error switch (error) { case SSL_ERROR_WANT_READ: @@ -85,9 +135,8 @@ private: return this; default: - - // @todo check how to handle this - return this; + // go to the final state (if not yet disconnected) + return reporterror(monitor); } } @@ -98,15 +147,17 @@ public: * @param connection Parent TCP connection object * @param socket The socket filedescriptor * @param ssl The SSL structure + * @param finalized Is the user already notified of connection end (onError() has been called) * @param handler User-supplied handler object */ - SslShutdown(TcpConnection *connection, int socket, const SslWrapper &ssl, TcpHandler *handler) : + SslShutdown(TcpConnection *connection, int socket, SslWrapper &&ssl, bool finalized, TcpHandler *handler) : TcpState(connection, handler), - _ssl(ssl), - _socket(socket) + _ssl(std::move(ssl)), + _socket(socket), + _finalized(finalized) { - // tell the handler to monitor the socket if there is an out - _handler->monitor(_connection, _socket, readable); + // wait until the socket is accessible + _handler->monitor(_connection, _socket, readable | writable); } /** @@ -114,14 +165,8 @@ public: */ virtual ~SslShutdown() noexcept { - // skip if handler is already forgotten - if (_handler == nullptr) return; - - // we no longer have to monitor the socket - _handler->monitor(_connection, _socket, 0); - // close the socket - close(_socket); + close(); } /** @@ -132,26 +177,66 @@ public: /** * Process the filedescriptor in the object + * @param monitor Object to check if connection still exists * @param fd The filedescriptor that is active * @param flags AMQP::readable and/or AMQP::writable * @return New implementation object */ - virtual TcpState *process(int fd, int flags) + virtual TcpState *process(const Monitor &monitor, int fd, int flags) override { // the socket must be the one this connection writes to if (fd != _socket) return this; - // because the object might soon be destructed, we create a monitor to check this - Monitor monitor(this); - // close the connection auto result = OpenSSL::SSL_shutdown(_ssl); - + + // on result==0 we need an additional call + while (result == 0) result = OpenSSL::SSL_shutdown(_ssl); + // if this is a success, we can proceed with the event loop - if (result > 0) return proceed(); + if (result > 0) return proceed(monitor); // the operation failed, we may have to repeat our call - else return repeat(result); + else return repeat(monitor, result); + } + + /** + * Flush the connection, sent all buffered data to the socket + * @param monitor Object to check if connection still exists + * @return TcpState new tcp state + */ + virtual TcpState *flush(const Monitor &monitor) override + { + // create an object to wait for the filedescriptor to becomes active + Wait wait(_socket); + + // keep looping + while (true) + { + // close the connection + auto result = OpenSSL::SSL_shutdown(_ssl); + + // on result==0 we need an additional call + while (result == 0) result = OpenSSL::SSL_shutdown(_ssl); + + // if this is a success, we can proceed with the event loop + if (result > 0) return proceed(monitor); + + // error was returned, so we must investigate what is going on + auto error = OpenSSL::SSL_get_error(_ssl, result); + + // check the error + switch (error) { + + // if openssl reports that socket readability or writability is needed, + // we wait for that until this situation is reached + case SSL_ERROR_WANT_READ: wait.readable(); break; + case SSL_ERROR_WANT_WRITE: wait.active(); break; + + // something is wrong, we proceed to the next state + default: return reporterror(monitor); + } + } } }; diff --git a/src/linux_tcp/sslwrapper.h b/src/linux_tcp/sslwrapper.h index 6aa5edd..b72e861 100644 --- a/src/linux_tcp/sslwrapper.h +++ b/src/linux_tcp/sslwrapper.h @@ -33,33 +33,31 @@ public: /** * Constructor * @param ctx + * @param file */ SslWrapper(SSL_CTX *ctx) : _ssl(OpenSSL::SSL_new(ctx)) { // report error if (_ssl == nullptr) throw std::runtime_error("failed to construct ssl structure"); + + //OpenSSL::SSL_use_certificate_file(_ssl, "cert.pem", SSL_FILETYPE_PEM); } /** - * Wrapper constructor - * @param ssl - */ - SslWrapper(SSL *ssl) : _ssl(ssl) - { - // one more reference - // @todo fix this - //CRYPTO_add(_ssl); - } - - /** - * Copy constructor + * Copy constructor is removed because openssl 1.0 has no way to up refcount + * (otherwise we could be safely copying objects around) * @param that */ - SslWrapper(const SslWrapper &that) : _ssl(that._ssl) + SslWrapper(const SslWrapper &that) = delete; + + /** + * Move constructor + * @param that + */ + SslWrapper(SslWrapper &&that) : _ssl(that._ssl) { - // one more reference - // @todo fix this - //SSL_up_ref(_ssl); + // invalidate other object + that._ssl = nullptr; } /** @@ -67,6 +65,9 @@ public: */ virtual ~SslWrapper() { + // do nothing if already moved away + if (_ssl == nullptr) return; + // destruct object OpenSSL::SSL_free(_ssl); } diff --git a/src/linux_tcp/tcpconnected.h b/src/linux_tcp/tcpconnected.h index 4aeb8e8..bc04223 100644 --- a/src/linux_tcp/tcpconnected.h +++ b/src/linux_tcp/tcpconnected.h @@ -54,8 +54,36 @@ private: * @var size_t */ size_t _reallocate = 0; + + /** + * Have we already made the last report to the user (about an error or closed connection?) + * @var bool + */ + bool _finalized = false; + /** + * Close the connection + * @return bool + */ + bool close() + { + // do nothing if already closed + if (_socket < 0) return false; + + // and stop monitoring it + _handler->monitor(_connection, _socket, 0); + + // close the socket + ::close(_socket); + + // forget filedescriptor + _socket = -1; + + // done + return true; + } + /** * Helper method to report an error * @return bool Was an error reported? @@ -65,6 +93,16 @@ private: // some errors are ok and do not (necessarily) mean that we're disconnected if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) return false; + // connection can be closed now + close(); + + // if the user has already been notified, we do not have to do anything else + if (_finalized) return true; + + // update the _finalized member before we make the call to user space because + // the user space may destruct this object + _finalized = true; + // we have an error - report this to the user _handler->onError(_connection, strerror(errno)); @@ -110,14 +148,8 @@ public: */ virtual ~TcpConnected() noexcept { - // skip if handler is already forgotten - if (_handler == nullptr) return; - - // we no longer have to monitor the socket - _handler->monitor(_connection, _socket, 0); - // close the socket - close(_socket); + close(); } /** @@ -128,18 +160,16 @@ public: /** * Process the filedescriptor in the object + * @param monitor Monitor to check if the object is still alive * @param fd Filedescriptor that is active * @param flags AMQP::readable and/or AMQP::writable * @return New state object */ - virtual TcpState *process(int fd, int flags) override + virtual TcpState *process(const Monitor &monitor, int fd, int flags) override { // must be the socket if (fd != _socket) return this; - // because the object might soon be destructed, we create a monitor to check this - Monitor monitor(this); - // can we write more data to the socket? if (flags & writable) { @@ -147,7 +177,7 @@ public: auto result = _out.sendto(_socket); // are we in an error state? - if (result < 0 && reportError()) return nextState(monitor); + if (result < 0 && reportError()) return nextState(monitor); // if buffer is empty by now, we no longer have to check for // writability, but only for readability @@ -218,9 +248,10 @@ public: /** * Flush the connection, sent all buffered data to the socket + * @param monitor Object to check if connection still lives * @return TcpState new tcp state */ - virtual TcpState *flush() override + virtual TcpState *flush(const Monitor &monitor) override { // create an object to wait for the filedescriptor to becomes active Wait wait(_socket); @@ -232,7 +263,7 @@ public: if (!wait.writable()) return this; // socket is writable, send as much data as possible - auto *newstate = process(_socket, writable); + auto *newstate = process(monitor, _socket, writable); // are we done if (newstate != this) return newstate; @@ -256,28 +287,47 @@ public: return TcpState::reportNegotiate(heartbeat); } + /** + * Report to the handler that the object is in an error state. + * @param error + */ + virtual void reportError(const char *error) override + { + // close the socket + close(); + + // if the user was already notified of an final state, we do not have to proceed + if (_finalized) return; + + // remember that this is the final call to user space + _finalized = true; + + // pass to handler + _handler->onError(_connection, error); + } + /** * Report to the handler that the connection was nicely closed + * This is the counter-part of the connection->close() call. */ virtual void reportClosed() override { - // we no longer have to monitor the socket - _handler->monitor(_connection, _socket, 0); + // we will shutdown the socket in a very elegant way, we notify the peer + // that we will not be sending out more write operations + shutdown(_socket, SHUT_WR); + + // we still monitor the socket for readability to see if our close call was + // confirmed by the peer + _handler->monitor(_connection, _socket, readable); - // close the socket - close(_socket); + // if the user was already notified of an final state, we do not have to proceed + if (_finalized) return; - // socket is closed now - _socket = -1; + // remember that this is the final call to user space + _finalized = true; - // copy the handler (if might destruct this object) - auto *handler = _handler; - - // reset member before the handler can make a mess of it - _handler = nullptr; - - // notify to handler - handler->onClosed(_connection); + // pass to handler + _handler->onClosed(_connection); } }; diff --git a/src/linux_tcp/tcpconnection.cpp b/src/linux_tcp/tcpconnection.cpp index b1141b9..b57922b 100644 --- a/src/linux_tcp/tcpconnection.cpp +++ b/src/linux_tcp/tcpconnection.cpp @@ -12,6 +12,7 @@ */ #include "includes.h" #include "tcpresolver.h" +#include "tcpstate.h" /** * Set up namespace @@ -53,11 +54,11 @@ int TcpConnection::fileno() const */ void TcpConnection::process(int fd, int flags) { - // monitor the object for destruction - Monitor monitor{ this }; + // monitor the object for destruction, because you never know what the user + Monitor monitor(this); // pass on the the state, that returns a new impl - auto *result = _state->process(fd, flags); + auto *result = _state->process(monitor, fd, flags); // are we still valid if (!monitor.valid()) return; @@ -83,7 +84,7 @@ void TcpConnection::flush() while (true) { // flush the object - auto *newstate = _state->flush(); + auto *newstate = _state->flush(monitor); // done if object no longer exists if (!monitor.valid()) return; @@ -137,14 +138,8 @@ void TcpConnection::onHeartbeat(Connection *connection) */ void TcpConnection::onError(Connection *connection, const char *message) { - // current object is going to be removed, but we have to keep it for a while - auto ptr = std::move(_state); - - // object is now in a closed state - _state.reset(new TcpClosed(ptr.get())); - // tell the implementation to report the error - ptr->reportError(message); + _state->reportError(message); } /** @@ -163,14 +158,8 @@ void TcpConnection::onConnected(Connection *connection) */ void TcpConnection::onClosed(Connection *connection) { - // current object is going to be removed, but we have to keep it for a while - auto ptr = std::move(_state); - - // object is now in a closed state - _state.reset(new TcpClosed(ptr.get())); - // tell the implementation to report that connection is closed now - ptr->reportClosed(); + _state->reportClosed(); } /** diff --git a/src/linux_tcp/tcpresolver.h b/src/linux_tcp/tcpresolver.h index ee24e2d..97e8916 100644 --- a/src/linux_tcp/tcpresolver.h +++ b/src/linux_tcp/tcpresolver.h @@ -93,7 +93,7 @@ private: try { // check if we support openssl in the first place - if (!OpenSSL::valid()) throw std::runtime_error("Secure connection cannot be established: libssl.so cannot be loaded"); + if (_secure && !OpenSSL::valid()) throw std::runtime_error("Secure connection cannot be established: libssl.so cannot be loaded"); // get address info AddressInfo addresses(_hostname.data(), _port); @@ -194,7 +194,7 @@ public: if (_socket >= 0) { // if we need a secure connection, we move to the tls handshake - // @todo catch exception + // @todo catch possible exception if (_secure) return new SslHandshake(_connection, _socket, _hostname, std::move(_buffer), _handler); // otherwise we have a valid regular tcp connection @@ -212,11 +212,12 @@ public: /** * Wait for the resolver to be ready + * @param monitor Object to check if connection still exists * @param fd The filedescriptor that is active * @param flags Flags to indicate that fd is readable and/or writable * @return New implementation object */ - virtual TcpState *process(int fd, int flags) override + virtual TcpState *process(const Monitor &monitor, int fd, int flags) override { // only works if the incoming pipe is readable if (fd != _pipe.in() || !(flags & readable)) return this; @@ -227,9 +228,10 @@ public: /** * Flush state / wait for the connection to complete + * @param monitor Object to check if connection still exists * @return New implementation object */ - virtual TcpState *flush() override + virtual TcpState *flush(const Monitor &monitor) override { // just wait for the other thread to be ready _thread.join(); diff --git a/src/linux_tcp/tcpstate.h b/src/linux_tcp/tcpstate.h index 93ddc21..1c3a280 100644 --- a/src/linux_tcp/tcpstate.h +++ b/src/linux_tcp/tcpstate.h @@ -64,12 +64,19 @@ public: virtual int fileno() const { return -1; } /** - * Process the filedescriptor in the object + * Process the filedescriptor in the object + * + * This method should return the handler object that will be responsible for + * all future readable/writable events for the file descriptor, or nullptr + * if the underlying connection object has already been destructed by the + * user and it would be pointless to set up a new handler. + * + * @param monitor Monitor that can be used to check if the tcp connection is still alive * @param fd The filedescriptor that is active * @param flags AMQP::readable and/or AMQP::writable * @return New implementation object */ - virtual TcpState *process(int fd, int flags) + virtual TcpState *process(const Monitor &monitor, int fd, int flags) { // default implementation does nothing and preserves same implementation return this; @@ -77,8 +84,8 @@ public: /** * Send data over the connection - * @param buffer buffer to send - * @param size size of the buffer + * @param buffer Buffer to send + * @param size Size of the buffer */ virtual void send(const char *buffer, size_t size) { @@ -86,7 +93,25 @@ public: } /** - * Report that heartbeat negotiation is going on + * Flush the connection, all outgoing operations should be completed. + * + * If the state changes during the operation, the new state object should + * be returned instead, or nullptr if the user has closed the connection + * in the meantime. If the connection object got destructed by a user space + * call, this method should return nullptr. A monitor object is pass in to + * allow the flush() method to check if the connection still exists. + * + * If this object returns a new state object (instead of "this"), the + * connection object will immediately proceed with calling flush() on that + * new state object too. + * + * @param monitor Monitor that can be used to check if the tcp connection is still alive + * @return TcpState New implementation object + */ + virtual TcpState *flush(const Monitor &monitor) { return this; } + + /** + * Report to the handler that heartbeat negotiation is going on * @param heartbeat suggested heartbeat * @return uint16_t accepted heartbeat */ @@ -97,16 +122,16 @@ public: } /** - * Flush the connection - * @return TcpState new implementation object - */ - virtual TcpState *flush() { return this; } - - /** - * Report to the handler that the object is in an error state + * Report to the handler that the object is in an error state. + * + * This is the last method to be called on the handler object, from now on + * the handler will no longer be called to report things to user space. + * The state object itself stays active, and further calls to process() + * may be possible. + * * @param error */ - void reportError(const char *error) + virtual void reportError(const char *error) { // pass to handler _handler->onError(_connection, error); @@ -115,7 +140,7 @@ public: /** * Report that a heartbeat frame was received */ - void reportHeartbeat() + virtual void reportHeartbeat() { // pass to handler _handler->onHeartbeat(_connection); @@ -124,14 +149,19 @@ public: /** * Report to the handler that the connection is ready for use */ - void reportConnected() + virtual void reportConnected() { // pass to handler _handler->onConnected(_connection); } /** - * Report to the handler that the connection was nicely closed + * Report to the handler that the connection was correctly closed, after + * the user has called the Connection::close() method. The underlying TCP + * connection still has to be closed. + * + * This is the last method that is called on the object, from now on no + * other methods may be called on the _handler variable. */ virtual void reportClosed() {