From 7aa7794e3e944895c0b1b4635669affc93800968 Mon Sep 17 00:00:00 2001 From: Emiel Bruijntjes Date: Tue, 6 Mar 2018 22:03:53 +0100 Subject: [PATCH] work in progress on ssl implementation --- examples/libev.cpp | 2 +- src/linux_tcp/function.h | 5 +- src/linux_tcp/library.h | 64 -------------------------- src/linux_tcp/openssl.cpp | 89 +++++++++++++++++++++++++++--------- src/linux_tcp/openssl.h | 3 ++ src/linux_tcp/sslconnected.h | 16 ++++++- src/linux_tcp/sslcontext.h | 4 +- src/linux_tcp/sslhandshake.h | 14 +++--- src/linux_tcp/sslshutdown.h | 4 +- src/linux_tcp/sslwrapper.h | 4 +- src/linux_tcp/tcpconnected.h | 2 +- src/linux_tcp/tcpinbuffer.h | 13 +++++- src/linux_tcp/tcpoutbuffer.h | 5 +- src/linux_tcp/tcpresolver.h | 2 +- 14 files changed, 118 insertions(+), 109 deletions(-) delete mode 100644 src/linux_tcp/library.h diff --git a/examples/libev.cpp b/examples/libev.cpp index 4d909f6..9ee0e2a 100644 --- a/examples/libev.cpp +++ b/examples/libev.cpp @@ -66,7 +66,7 @@ int main() MyHandler handler(loop); // init the SSL library -// SSL_library_init(); + SSL_library_init(); // make a connection AMQP::Address address("amqps://guest:guest@localhost/"); diff --git a/src/linux_tcp/function.h b/src/linux_tcp/function.h index b243643..fa1250d 100644 --- a/src/linux_tcp/function.h +++ b/src/linux_tcp/function.h @@ -119,11 +119,10 @@ private: public: /** * Constructor - * @param library The library to load the function from * @param name Name of the function */ - Function(void *library, const char *name) : - _method(dlsym(library, name)) {} + Function(const char *name) : + _method(dlsym(RTLD_DEFAULT, name)) {} /** * Destructor diff --git a/src/linux_tcp/library.h b/src/linux_tcp/library.h deleted file mode 100644 index 7330eda..0000000 --- a/src/linux_tcp/library.h +++ /dev/null @@ -1,64 +0,0 @@ -/** - * Library.h - * - * The Library class is a wrapper around dlopen() - * - * @author Emiel Bruijntjes - * @copyright 2018 Copernica BV - */ - -/** - * Include guard - */ -#pragma once - -/** - * Begin of namespace - */ -namespace AMQP { - -/** - * Class definition - */ -class Library -{ -private: - /** - * The library handle - * @var void * - */ - void *_handle; - -public: - /** - * Constructor - */ - Library() : _handle(dlopen(nullptr, RTLD_NOW)) {} - - /** - * No copying - * @param that - */ - Library(const Library &that) = delete; - - /** - * Destructor - */ - virtual ~Library() - { - // close library - if (_handle) dlclose(_handle); - } - - /** - * Cast to the handle - * @return void * - */ - operator void * () { return _handle; } -}; - -/** - * End of namespace - */ -} - diff --git a/src/linux_tcp/openssl.cpp b/src/linux_tcp/openssl.cpp index 8e9139d..39bff69 100644 --- a/src/linux_tcp/openssl.cpp +++ b/src/linux_tcp/openssl.cpp @@ -11,7 +11,6 @@ */ #include "openssl.h" #include "function.h" -#include "library.h" /** * Begin of namespace @@ -19,16 +18,35 @@ namespace AMQP { namespace OpenSSL { /** - * Get the library handle - * @return void * + * Is the openssl library loaded? + * @return bool */ -static void *library() +bool valid() { - // create on instance - static Library instance; + // create a function + static Function func("SSL_CTX_new"); + + // we need a library + return func; +} + +/** + * Get the SSL_METHOD for outgoing connections + * @return SSL_METHOD * + */ +const SSL_METHOD *TLS_client_method() +{ + // create a function that loads the method + static Function func("TLS_client_method"); - // return the instance (it has a cast-to-void-ptr operator) - return instance; + // call the openssl function + if (func) return func(); + + // older openssl libraries do not have this function, so we try to load an other function + static Function old("SSLv23_client_method"); + + // call the old one + return old(); } /** @@ -39,7 +57,7 @@ static void *library() SSL_CTX *SSL_CTX_new(const SSL_METHOD *method) { // create a function - static Function func(library(), "SSL_CTX_new"); + static Function func("SSL_CTX_new"); // call the openssl function return func(method); @@ -55,7 +73,7 @@ SSL_CTX *SSL_CTX_new(const SSL_METHOD *method) int SSL_read(SSL *ssl, void *buf, int num) { // create a function - static Function func(library(), "SSL_read"); + static Function func("SSL_read"); // call the openssl function return func(ssl, buf, num); @@ -71,7 +89,7 @@ int SSL_read(SSL *ssl, void *buf, int num) int SSL_write(SSL *ssl, const void *buf, int num) { // create a function - static Function func(library(), "SSL_write"); + static Function func("SSL_write"); // call the openssl function return func(ssl, buf, num); @@ -86,12 +104,24 @@ int SSL_write(SSL *ssl, const void *buf, int num) int SSL_set_fd(SSL *ssl, int fd) { // create a function - static Function func(library(), "SSL_set_fd"); + static Function func("SSL_set_fd"); // call the openssl function return func(ssl, fd); } +/** + * Free an allocated ssl context + * @param ctx + */ +void SSL_CTX_free(SSL_CTX *ctx) +{ + // create a function + static Function func("SSL_CTX_free"); + + // call the openssl function + return func(ctx); +} /** * Free an allocated SSL structure @@ -101,7 +131,7 @@ int SSL_set_fd(SSL *ssl, int fd) void SSL_free(SSL *ssl) { // create a function - static Function func(library(), "SSL_free"); + static Function func("SSL_free"); // call the openssl function return func(ssl); @@ -115,7 +145,7 @@ void SSL_free(SSL *ssl) SSL *SSL_new(SSL_CTX *ctx) { // create a function - static Function func(library(), "SSL_new"); + static Function func("SSL_new"); // call the openssl function return func(ctx); @@ -129,7 +159,7 @@ SSL *SSL_new(SSL_CTX *ctx) int SSL_up_ref(SSL *ssl) { // create a function - static Function func(library(), "SSL_up_ref"); + static Function func("SSL_up_ref"); // call the openssl function if it exists if (func) return func(ssl); @@ -147,7 +177,7 @@ int SSL_up_ref(SSL *ssl) int SSL_shutdown(SSL *ssl) { // create a function - static Function func(library(), "SSL_shutdown"); + static Function func("SSL_shutdown"); // call the openssl function return func(ssl); @@ -160,7 +190,7 @@ int SSL_shutdown(SSL *ssl) void SSL_set_connect_state(SSL *ssl) { // create a function - static Function func(library(), "SSL_set_connect_state"); + static Function func("SSL_set_connect_state"); // call the openssl function func(ssl); @@ -175,7 +205,7 @@ void SSL_set_connect_state(SSL *ssl) int SSL_do_handshake(SSL *ssl) { // create a function - static Function func(library(), "SSL_do_handshake"); + static Function func("SSL_do_handshake"); // call the openssl function return func(ssl); @@ -183,18 +213,35 @@ int SSL_do_handshake(SSL *ssl) /** * Obtain result code for TLS/SSL I/O operation - * @param ssl SSL object - * @param ret the returned diagnostic value of SSL calls + * @param ssl SSL object + * @param ret the returned diagnostic value of SSL calls * @return int returns error values */ int SSL_get_error(const SSL *ssl, int ret) { // create a function - static Function func(library(), "SSL_get_error"); + static Function func("SSL_get_error"); // call the openssl function return func(ssl, ret); } + +/** + * Internal handling function for a ssl context + * @param ssl ssl context + * @param cmd command + * @param larg first arg + * @param parg second arg + * @return long + */ +long SSL_ctrl(SSL *ssl, int cmd, long larg, void *parg) +{ + // create a function + static Function func("SSL_ctrl"); + + // call the openssl function + return func(ssl, cmd, larg, parg); +} /** * End of namespace diff --git a/src/linux_tcp/openssl.h b/src/linux_tcp/openssl.h index 9a21dc9..387654d 100644 --- a/src/linux_tcp/openssl.h +++ b/src/linux_tcp/openssl.h @@ -33,6 +33,7 @@ bool valid(); /** * List of all wrapper methods that are in use inside AMQP-CPP */ +const SSL_METHOD *TLS_client_method(); SSL_CTX *SSL_CTX_new(const SSL_METHOD *method); SSL *SSL_new(SSL_CTX *ctx); int SSL_do_handshake(SSL *ssl); @@ -43,7 +44,9 @@ int SSL_set_fd(SSL *ssl, int fd); int SSL_get_error(const SSL *ssl, int ret); int SSL_up_ref(SSL *ssl); void SSL_set_connect_state(SSL *ssl); +void SSL_CTX_free(SSL_CTX *ctx); void SSL_free(SSL *ssl); +long SSL_ctrl(SSL *ssl, int cmd, long larg, void *parg); /** * End of namespace diff --git a/src/linux_tcp/sslconnected.h b/src/linux_tcp/sslconnected.h index 1ad708b..b6e5d21 100644 --- a/src/linux_tcp/sslconnected.h +++ b/src/linux_tcp/sslconnected.h @@ -148,7 +148,7 @@ private: TcpState *repeat(int result) { // error was returned, so we must investigate what is going on - auto error = SSL_get_error(_ssl, result); + auto error = OpenSSL::SSL_get_error(_ssl, result); // check the error switch (error) { @@ -286,6 +286,20 @@ public: } } + /** + * Flush the connection, sent all buffered data to the socket + * @return TcpState new tcp state + */ + virtual TcpState *flush() override + { + // create an object to wait for the filedescriptor to becomes active + Wait wait(_socket); + + // @todo implementation + + return this; + } + /** * Send data over the connection * @param buffer buffer to send diff --git a/src/linux_tcp/sslcontext.h b/src/linux_tcp/sslcontext.h index a53c04d..2c2d8b3 100644 --- a/src/linux_tcp/sslcontext.h +++ b/src/linux_tcp/sslcontext.h @@ -35,7 +35,7 @@ public: * @param method * @throws std::runtime_error */ - SslContext(const SSL_METHOD *method) : _ctx(SSL_CTX_new(method)) + SslContext(const SSL_METHOD *method) : _ctx(OpenSSL::SSL_CTX_new(method)) { // report error if (_ctx == nullptr) throw std::runtime_error("failed to construct ssl context"); @@ -69,7 +69,7 @@ public: virtual ~SslContext() { // free resource (this updates the refcount -1, and may destruct it) - SSL_CTX_free(_ctx); + OpenSSL::SSL_CTX_free(_ctx); } /** diff --git a/src/linux_tcp/sslhandshake.h b/src/linux_tcp/sslhandshake.h index fb8af3b..4095664 100644 --- a/src/linux_tcp/sslhandshake.h +++ b/src/linux_tcp/sslhandshake.h @@ -111,18 +111,18 @@ public: */ SslHandshake(TcpConnection *connection, int socket, const std::string &hostname, TcpOutBuffer &&buffer, TcpHandler *handler) : TcpState(connection, handler), - _ssl(SslContext(SSLv23_client_method())), + _ssl(SslContext(OpenSSL::TLS_client_method())), _socket(socket), _out(std::move(buffer)) { // we will be using the ssl context as a client - SSL_set_connect_state(_ssl); + OpenSSL::SSL_set_connect_state(_ssl); // associate domain name with the connection - SSL_set_tlsext_host_name(_ssl, hostname.data()); + OpenSSL::SSL_ctrl(_ssl, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, (void *)hostname.data()); // associate the ssl context with the socket filedescriptor - if (SSL_set_fd(_ssl, socket) == 0) throw std::runtime_error("failed to associate filedescriptor with ssl socket"); + if (OpenSSL::SSL_set_fd(_ssl, socket) == 0) throw std::runtime_error("failed to associate filedescriptor with ssl socket"); // we are going to wait until the socket becomes writable before we start the handshake _handler->monitor(_connection, _socket, writable); @@ -164,7 +164,7 @@ public: if (result == 1) return nextstate(new SslConnected(_connection, _socket, _ssl, std::move(_out), _handler)); // error was returned, so we must investigate what is going on - auto error = SSL_get_error(_ssl, result); + auto error = OpenSSL::SSL_get_error(_ssl, result); // check the error switch (error) { @@ -198,13 +198,13 @@ public: while (true) { // start the ssl handshake - int result = SSL_do_handshake(_ssl); + int result = OpenSSL::SSL_do_handshake(_ssl); // if the connection succeeds, we can move to the ssl-connected state if (result == 1) return nextstate(new SslConnected(_connection, _socket, _ssl, std::move(_out), _handler)); // error was returned, so we must investigate what is going on - auto error = SSL_get_error(_ssl, result); + auto error = OpenSSL::SSL_get_error(_ssl, result); // check the error switch (error) diff --git a/src/linux_tcp/sslshutdown.h b/src/linux_tcp/sslshutdown.h index 4a0315e..2361155 100644 --- a/src/linux_tcp/sslshutdown.h +++ b/src/linux_tcp/sslshutdown.h @@ -70,7 +70,7 @@ private: TcpState *repeat(int result) { // error was returned, so we must investigate what is going on - auto error = SSL_get_error(_ssl, result); + auto error = OpenSSL::SSL_get_error(_ssl, result); // check the error switch (error) { @@ -145,7 +145,7 @@ public: Monitor monitor(this); // close the connection - auto result = SSL_shutdown(_ssl); + auto result = OpenSSL::SSL_shutdown(_ssl); // if this is a success, we can proceed with the event loop if (result > 0) return proceed(); diff --git a/src/linux_tcp/sslwrapper.h b/src/linux_tcp/sslwrapper.h index 5c88810..6aa5edd 100644 --- a/src/linux_tcp/sslwrapper.h +++ b/src/linux_tcp/sslwrapper.h @@ -34,7 +34,7 @@ public: * Constructor * @param ctx */ - SslWrapper(SSL_CTX *ctx) : _ssl(SSL_new(ctx)) + SslWrapper(SSL_CTX *ctx) : _ssl(OpenSSL::SSL_new(ctx)) { // report error if (_ssl == nullptr) throw std::runtime_error("failed to construct ssl structure"); @@ -68,7 +68,7 @@ public: virtual ~SslWrapper() { // destruct object - SSL_free(_ssl); + OpenSSL::SSL_free(_ssl); } /** diff --git a/src/linux_tcp/tcpconnected.h b/src/linux_tcp/tcpconnected.h index 670b3db..4aeb8e8 100644 --- a/src/linux_tcp/tcpconnected.h +++ b/src/linux_tcp/tcpconnected.h @@ -225,7 +225,7 @@ public: // create an object to wait for the filedescriptor to becomes active Wait wait(_socket); - // keep running until the out buffer is empty + // keep running until the out buffer is not empty while (_out) { // poll the socket, is it already writable? diff --git a/src/linux_tcp/tcpinbuffer.h b/src/linux_tcp/tcpinbuffer.h index bd020e3..79b708f 100644 --- a/src/linux_tcp/tcpinbuffer.h +++ b/src/linux_tcp/tcpinbuffer.h @@ -121,8 +121,17 @@ public: */ ssize_t receivefrom(SSL *ssl, uint32_t expected) { - // @todo implementation - return 0; + // number of bytes to that still fit in the buffer + size_t bytes = expected - _size; + + // read data + auto result = OpenSSL::SSL_read(ssl, (void *)(_data + _size), bytes); + + // update total buffer size on success + if (result > 0) _size += result; + + // done + return result; } /** diff --git a/src/linux_tcp/tcpoutbuffer.h b/src/linux_tcp/tcpoutbuffer.h index 8d18526..31dbd27 100644 --- a/src/linux_tcp/tcpoutbuffer.h +++ b/src/linux_tcp/tcpoutbuffer.h @@ -18,7 +18,8 @@ */ #include #include -#include +#include "openssl.h" + /** * FIONREAD on Solaris is defined elsewhere */ @@ -284,7 +285,7 @@ public: if (buffers == 0) return 0; // send the data - auto result = SSL_write(ssl, buffer[0].iov_base, buffer[0].iov_len); + auto result = OpenSSL::SSL_write(ssl, buffer[0].iov_base, buffer[0].iov_len); // on success we shrink the buffer if (result > 0) shrink(result); diff --git a/src/linux_tcp/tcpresolver.h b/src/linux_tcp/tcpresolver.h index be45c4f..ee24e2d 100644 --- a/src/linux_tcp/tcpresolver.h +++ b/src/linux_tcp/tcpresolver.h @@ -93,7 +93,7 @@ private: try { // check if we support openssl in the first place - if (!OpenSSL::valid()) throw std::runtime_error("Secure connection cannot be established: the application has no access to openssl"); + if (!OpenSSL::valid()) throw std::runtime_error("Secure connection cannot be established: libssl.so cannot be loaded"); // get address info AddressInfo addresses(_hostname.data(), _port);