diff --git a/examples/libev.cpp b/examples/libev.cpp index 3922363..43cad39 100644 --- a/examples/libev.cpp +++ b/examples/libev.cpp @@ -63,7 +63,10 @@ int main() // handler for libev MyHandler handler(loop); - + + // init the SSL library + SSL_library_init(); + // make a connection AMQP::Address address("amqps://guest:guest@localhost/"); AMQP::TcpConnection connection(&handler, address); diff --git a/src/linux_tcp/tcpinbuffer.h b/src/linux_tcp/tcpinbuffer.h index d7ea09c..4487d82 100644 --- a/src/linux_tcp/tcpinbuffer.h +++ b/src/linux_tcp/tcpinbuffer.h @@ -114,13 +114,11 @@ public: * @param expected number of bytes that the library expects * @return ssize_t */ - /* ssize_t receivefrom(SSL *ssl, uint32_t expected) { // @todo implementation return 0; } - */ /** * Shrink the buffer (in practice this is always called with the full buffer size) diff --git a/src/linux_tcp/tcpoutbuffer.h b/src/linux_tcp/tcpoutbuffer.h index 4364c34..ab9e464 100644 --- a/src/linux_tcp/tcpoutbuffer.h +++ b/src/linux_tcp/tcpoutbuffer.h @@ -5,7 +5,7 @@ * output buffer. This is the implementation of that buffer * * @author Emiel Bruijntjes - * @copyright 2015 - 2016 Copernica BV + * @copyright 2015 - 2018 Copernica BV */ /** @@ -18,6 +18,7 @@ */ #include #include +#include /** * FIONREAD on Solaris is defined elsewhere @@ -272,7 +273,6 @@ public: * @param ssl the ssl context to send data to * @return ssize_t number of bytes sent, or the return value of ssl_write */ - /* ssize_t sendto(SSL *ssl) { // we're going to fill a lot of buffers (for ssl only one buffer at a time can be sent) @@ -281,22 +281,18 @@ public: // fill the buffers, and leap out if there is no data auto buffers = fill(buffer, 1); - std::cout << "buffercount = " << buffers << std::endl; - + // just to be sure we do this check if (buffers == 0) return 0; // send the data auto result = SSL_write(ssl, buffer[0].iov_base, buffer[0].iov_len); - // @todo do we have to move to the next buffer to prevent that this buffer is further filled? - // on success we shrink the buffer if (result > 0) shrink(result); // done return result; } - */ }; /** diff --git a/src/linux_tcp/tcpresolver.h b/src/linux_tcp/tcpresolver.h index 008bbc6..231e5f8 100644 --- a/src/linux_tcp/tcpresolver.h +++ b/src/linux_tcp/tcpresolver.h @@ -5,7 +5,7 @@ * server, and to make the initial connection * * @author Emiel Bruijntjes - * @copyright 2015 - 2016 Copernica BV + * @copyright 2015 - 2018 Copernica BV */ /** @@ -20,6 +20,7 @@ #include "tcpstate.h" #include "tcpclosed.h" #include "tcpconnected.h" +#include "tcpsslhandshake.h" #include /** @@ -179,6 +180,31 @@ public: _thread.join(); } + /** + * Proceed to the next state + * @return TcpState * + */ + TcpState *proceed() + { + // do we have a valid socket? + if (_socket >= 0) + { + // if we need a secure connection, we move to the tls handshake + if (_secure) return new TcpSslHandshake(_connection, _socket, _hostname, std::move(_buffer), _handler); + + // otherwise we have a valid regular tcp connection + return new TcpConnected(_connection, _socket, std::move(_buffer), _handler); + } + else + { + // report error + _handler->onError(_connection, _error.data()); + + // create dummy implementation + return new TcpClosed(_connection, _handler); + } + } + /** * Wait for the resolver to be ready * @param fd The filedescriptor that is active @@ -190,21 +216,8 @@ public: // only works if the incoming pipe is readable if (fd != _pipe.in() || !(flags & readable)) return this; - // do we have a valid socket? - if (_socket >= 0) - { - // if we need a secure connection, we move to the tls handshake - //if (_secure) return new TcpSslHandshake(_connection, _socket, std::move(_buffer), _handler); - - // otherwise we have a valid regular tcp connection - return new TcpConnected(_connection, _socket, std::move(_buffer), _handler); - } - - // report error - _handler->onError(_connection, _error.data()); - - // create dummy implementation - return new TcpClosed(_connection, _handler); + // proceed to the next state + return proceed(); } /** @@ -216,21 +229,8 @@ public: // just wait for the other thread to be ready _thread.join(); - // do we have a valid socket? - if (_socket >= 0) - { - // if we need a secure connection, we move to the tls handshake - //if (_secure) return new TcpSslHandshake(_connection, _socket, std::move(_buffer), _handler); - - // otherwise we have a valid regular tcp connection - return new TcpConnected(_connection, _socket, std::move(_buffer), _handler); - } - - // report error - _handler->onError(_connection, _error.data()); - - // create dummy implementation - return new TcpClosed(_connection, _handler); + // proceed to the next state + return proceed(); } /** diff --git a/src/linux_tcp/tcpssl.h b/src/linux_tcp/tcpssl.h new file mode 100644 index 0000000..6e0d9f3 --- /dev/null +++ b/src/linux_tcp/tcpssl.h @@ -0,0 +1,85 @@ +/** + * TcpSsl.h + * + * Wrapper around a SSL pointer + * + * @author Emiel Bruijntjes + * @copyright 2018 Copernica BV + */ + +/** + * Include guard + */ +#pragma once + +/** + * Begin of namespace + */ +namespace AMQP { + +/** + * Class definition + */ +class TcpSsl +{ +private: + /** + * The wrapped object + * @var SSL* + */ + SSL *_ssl; + +public: + /** + * Constructor + * @param ctx + */ + TcpSsl(SSL_CTX *ctx) : _ssl(SSL_new(ctx)) + { + // report error + if (_ssl == nullptr) throw std::runtime_error("failed to construct ssl structure"); + } + + /** + * Wrapper constructor + * @param ssl + */ + TcpSsl(SSL *ssl) : _ssl(ssl) + { + // one more reference + // @todo fix this + //CRYPTO_add(_ssl); + } + + /** + * Copy constructor + * @param that + */ + TcpSsl(const TcpSsl &that) : _ssl(that._ssl) + { + // one more reference + // @todo fix this + //SSL_up_ref(_ssl); + } + + /** + * Destructor + */ + virtual ~TcpSsl() + { + // destruct object + SSL_free(_ssl); + } + + /** + * Cast to the SSL* + * @return SSL * + */ + operator SSL * () const { return _ssl; } +}; + +/** + * End of namespace + */ +} + diff --git a/src/linux_tcp/tcpsslconnected.h b/src/linux_tcp/tcpsslconnected.h index 232fb0b..f52c6d0 100644 --- a/src/linux_tcp/tcpsslconnected.h +++ b/src/linux_tcp/tcpsslconnected.h @@ -19,8 +19,6 @@ #include "wait.h" #include -#include - /** * Set up namespace */ @@ -133,26 +131,19 @@ private: // error was returned, so we must investigate what is going on auto error = SSL_get_error(_ssl, result); - std::cout << "error = " << error << std::endl; - // check the error switch (error) { case SSL_ERROR_WANT_READ: // the operation must be repeated when readable - std::cout << "want read" << std::endl; - _handler->monitor(_connection, _socket, readable); return this; case SSL_ERROR_WANT_WRITE: // wait until socket becomes writable again - std::cout << "want write" << std::endl; - _handler->monitor(_connection, _socket, writable); return this; default: - std::cout << "something else" << std::endl; // @todo check how to handle this return this; @@ -177,8 +168,6 @@ public: _in(4096), _state(_out ? state_sending : state_idle) { - std::cout << "ssl-connected" << std::endl; - // tell the handler to monitor the socket if there is an out _handler->monitor(_connection, _socket, _state == state_sending ? writable : readable); } @@ -212,11 +201,6 @@ public: */ virtual TcpState *process(int fd, int flags) { - std::cout << "process call in ssl-connected" << std::endl; - - std::cout << fd << " - " << _socket << std::endl; - - // the socket must be the one this connection writes to if (fd != _socket) return this; @@ -226,13 +210,9 @@ public: // are we busy with sending or receiving data? if (_state == state_sending) { - std::cout << "busy sending" << std::endl; - // try to send more data from the outgoing buffer auto result = _out.sendto(_ssl); - std::cout << "result = " << result << std::endl; - // if this is a success, we may have to update the monitor if (result > 0) return proceed(); @@ -251,13 +231,8 @@ public: // the operation failed, we may have to repeat our call else return repeat(result); - - // we're busy with receiving data // @todo check this - - std::cout << "receive data" << std::endl; - } // keep same object diff --git a/src/linux_tcp/tcpsslcontext.h b/src/linux_tcp/tcpsslcontext.h new file mode 100644 index 0000000..9809b49 --- /dev/null +++ b/src/linux_tcp/tcpsslcontext.h @@ -0,0 +1,86 @@ +/** + * TcpSslContext.h + * + * Class to create and maintain a tcp ssl context + * + * @author Emiel Bruijntjes + * @copyright 2018 Copernica BV + */ + +/** + * Include guard + */ +#pragma once + +/** + * Begin of namespace + */ +namespace AMQP { + +/** + * Class definition + */ +class TcpSslContext +{ +private: + /** + * The wrapped context + * @var SSL_CTX + */ + SSL_CTX *_ctx; + +public: + /** + * Constructor + * @param method + * @throws std::runtime_error + */ + TcpSslContext(const SSL_METHOD *method) : _ctx(SSL_CTX_new(method)) + { + // report error + if (_ctx == nullptr) throw std::runtime_error("failed to construct ssl context"); + } + + /** + * Constructor that wraps around an existing context + * @param context + */ + TcpSslContext(SSL_CTX *context) : _ctx(context) + { + // increment refcount + // @todo fix this + //SSL_ctx_up_ref(context); + } + + /** + * Copy constructor + * @param that + */ + TcpSslContext(TcpSslContext &that) : _ctx(that._ctx) + { + // increment refcount + // @todo fix this + //SSL_ctx_up_ref(context); + } + + /** + * Destructor + */ + virtual ~TcpSslContext() + { + // free resource (this updates the refcount -1, and may destruct it) + SSL_CTX_free(_ctx); + } + + /** + * Cast to the actual context + * @return SSL_CTX * + */ + operator SSL_CTX * () { return _ctx; } +}; + +/** + * End of namespace + */ +} + diff --git a/src/linux_tcp/tcpsslhandshake.h b/src/linux_tcp/tcpsslhandshake.h index eccb860..7943a67 100644 --- a/src/linux_tcp/tcpsslhandshake.h +++ b/src/linux_tcp/tcpsslhandshake.h @@ -18,11 +18,8 @@ #include "tcpoutbuffer.h" #include "tcpsslconnected.h" #include "wait.h" - -#include -#include -#include -#include +#include "tcpssl.h" +#include "tcpsslcontext.h" /** * Set up namespace @@ -35,17 +32,11 @@ namespace AMQP { class TcpSslHandshake : public TcpState, private Watchable { private: - /** - * SSL context - * @var SSL_CTX - */ - SSL_CTX *ctx; - /** * SSL structure * @var SSL */ - SSL *_ssl; + TcpSsl _ssl; /** * The socket file descriptor @@ -99,38 +90,26 @@ public: * * @param connection Parent TCP connection object * @param socket The socket filedescriptor + * @param hostname The hostname to connect to * @param context SSL context * @param buffer The buffer that was already built * @param handler User-supplied handler object * @throws std::runtime_error */ - TcpSslHandshake(TcpConnection *connection, int socket, TcpOutBuffer &&buffer, TcpHandler *handler) : + TcpSslHandshake(TcpConnection *connection, int socket, const std::string &hostname, TcpOutBuffer &&buffer, TcpHandler *handler) : TcpState(connection, handler), + _ssl(TcpSslContext(SSLv23_client_method())), _socket(socket), _out(std::move(buffer)) { - // init the SSL library - SSL_library_init(); - - // create ssl context - ctx = SSL_CTX_new(TLS_client_method()); - - // create ssl object - _ssl = SSL_new(ctx); - - // leap out on error - if (_ssl == nullptr) throw std::runtime_error("ERROR: SSL structure is null"); - // we will be using the ssl context as a client SSL_set_connect_state(_ssl); - + + // associate domain name with the connection + SSL_set_tlsext_host_name(_ssl, hostname.data()); // associate the ssl context with the socket filedescriptor - int set_fd_ret = SSL_set_fd(_ssl, socket); - if (set_fd_ret == 0) { - reportError(); - std::cout << "error while setting file descriptor" << std::endl; - } + if (SSL_set_fd(_ssl, socket) == 0) throw std::runtime_error("failed to associate filedescriptor with ssl socket"); // we are going to wait until the socket becomes writable before we start the handshake _handler->monitor(_connection, _socket, writable); @@ -179,39 +158,32 @@ public: switch (error) { case SSL_ERROR_WANT_READ: // the handshake must be repeated when socket is readable, wait for that - std::cout << "wait for readability" << std::endl; _handler->monitor(_connection, _socket, readable); break; case SSL_ERROR_WANT_WRITE: // the handshake must be repeated when socket is readable, wait for that - std::cout << "wait for writability" << std::endl; _handler->monitor(_connection, _socket, writable); break; case SSL_ERROR_WANT_ACCEPT: // the BIO was not connected yet, the SSL function should be called again - std::cout << "wait for acceptability" << ERR_error_string(ERR_get_error(), nullptr) << std::endl; _handler->monitor(_connection, _socket, writable); break; case SSL_ERROR_WANT_X509_LOOKUP: - std::cout << "SSL_ERROR_WANT_X509_LOOKUP" << ERR_error_string(ERR_get_error(), nullptr) << std::endl; _handler->monitor(_connection, _socket, writable); break; case SSL_ERROR_SYSCALL: - std::cout << "SSL_ERROR_SYSCALL: " << ERR_error_string(ERR_get_error(), nullptr) << std::endl; _handler->monitor(_connection, _socket, writable); break; case SSL_ERROR_SSL: - std::cout << "SSL_ERROR_SSL" << ERR_error_string(ERR_get_error(), nullptr) << std::endl; _handler->monitor(_connection, _socket, writable); break; default: - std::cout << "unknown error state " << error << std::endl; return reportError(); } }