work in progress on ssl implementation

This commit is contained in:
Emiel Bruijntjes 2018-03-06 22:03:53 +01:00
parent 0ca9bc9dad
commit 7aa7794e3e
14 changed files with 118 additions and 109 deletions

View File

@ -66,7 +66,7 @@ int main()
MyHandler handler(loop);
// init the SSL library
// SSL_library_init();
SSL_library_init();
// make a connection
AMQP::Address address("amqps://guest:guest@localhost/");

View File

@ -119,11 +119,10 @@ private:
public:
/**
* Constructor
* @param library The library to load the function from
* @param name Name of the function
*/
Function(void *library, const char *name) :
_method(dlsym(library, name)) {}
Function(const char *name) :
_method(dlsym(RTLD_DEFAULT, name)) {}
/**
* Destructor

View File

@ -1,64 +0,0 @@
/**
* Library.h
*
* The Library class is a wrapper around dlopen()
*
* @author Emiel Bruijntjes <emiel.bruijntjes@copernica.com>
* @copyright 2018 Copernica BV
*/
/**
* Include guard
*/
#pragma once
/**
* Begin of namespace
*/
namespace AMQP {
/**
* Class definition
*/
class Library
{
private:
/**
* The library handle
* @var void *
*/
void *_handle;
public:
/**
* Constructor
*/
Library() : _handle(dlopen(nullptr, RTLD_NOW)) {}
/**
* No copying
* @param that
*/
Library(const Library &that) = delete;
/**
* Destructor
*/
virtual ~Library()
{
// close library
if (_handle) dlclose(_handle);
}
/**
* Cast to the handle
* @return void *
*/
operator void * () { return _handle; }
};
/**
* End of namespace
*/
}

View File

@ -11,7 +11,6 @@
*/
#include "openssl.h"
#include "function.h"
#include "library.h"
/**
* Begin of namespace
@ -19,16 +18,35 @@
namespace AMQP { namespace OpenSSL {
/**
* Get the library handle
* @return void *
* Is the openssl library loaded?
* @return bool
*/
static void *library()
bool valid()
{
// create on instance
static Library instance;
// create a function
static Function<decltype(::SSL_CTX_new)> func("SSL_CTX_new");
// we need a library
return func;
}
/**
* Get the SSL_METHOD for outgoing connections
* @return SSL_METHOD *
*/
const SSL_METHOD *TLS_client_method()
{
// create a function that loads the method
static Function<decltype(TLS_client_method)> func("TLS_client_method");
// return the instance (it has a cast-to-void-ptr operator)
return instance;
// call the openssl function
if (func) return func();
// older openssl libraries do not have this function, so we try to load an other function
static Function<decltype(TLS_client_method)> old("SSLv23_client_method");
// call the old one
return old();
}
/**
@ -39,7 +57,7 @@ static void *library()
SSL_CTX *SSL_CTX_new(const SSL_METHOD *method)
{
// create a function
static Function<decltype(::SSL_CTX_new)> func(library(), "SSL_CTX_new");
static Function<decltype(::SSL_CTX_new)> func("SSL_CTX_new");
// call the openssl function
return func(method);
@ -55,7 +73,7 @@ SSL_CTX *SSL_CTX_new(const SSL_METHOD *method)
int SSL_read(SSL *ssl, void *buf, int num)
{
// create a function
static Function<decltype(::SSL_read)> func(library(), "SSL_read");
static Function<decltype(::SSL_read)> func("SSL_read");
// call the openssl function
return func(ssl, buf, num);
@ -71,7 +89,7 @@ int SSL_read(SSL *ssl, void *buf, int num)
int SSL_write(SSL *ssl, const void *buf, int num)
{
// create a function
static Function<decltype(::SSL_write)> func(library(), "SSL_write");
static Function<decltype(::SSL_write)> func("SSL_write");
// call the openssl function
return func(ssl, buf, num);
@ -86,12 +104,24 @@ int SSL_write(SSL *ssl, const void *buf, int num)
int SSL_set_fd(SSL *ssl, int fd)
{
// create a function
static Function<decltype(::SSL_set_fd)> func(library(), "SSL_set_fd");
static Function<decltype(::SSL_set_fd)> func("SSL_set_fd");
// call the openssl function
return func(ssl, fd);
}
/**
* Free an allocated ssl context
* @param ctx
*/
void SSL_CTX_free(SSL_CTX *ctx)
{
// create a function
static Function<decltype(::SSL_CTX_free)> func("SSL_CTX_free");
// call the openssl function
return func(ctx);
}
/**
* Free an allocated SSL structure
@ -101,7 +131,7 @@ int SSL_set_fd(SSL *ssl, int fd)
void SSL_free(SSL *ssl)
{
// create a function
static Function<decltype(::SSL_free)> func(library(), "SSL_free");
static Function<decltype(::SSL_free)> func("SSL_free");
// call the openssl function
return func(ssl);
@ -115,7 +145,7 @@ void SSL_free(SSL *ssl)
SSL *SSL_new(SSL_CTX *ctx)
{
// create a function
static Function<decltype(::SSL_new)> func(library(), "SSL_new");
static Function<decltype(::SSL_new)> func("SSL_new");
// call the openssl function
return func(ctx);
@ -129,7 +159,7 @@ SSL *SSL_new(SSL_CTX *ctx)
int SSL_up_ref(SSL *ssl)
{
// create a function
static Function<decltype(SSL_up_ref)> func(library(), "SSL_up_ref");
static Function<decltype(SSL_up_ref)> func("SSL_up_ref");
// call the openssl function if it exists
if (func) return func(ssl);
@ -147,7 +177,7 @@ int SSL_up_ref(SSL *ssl)
int SSL_shutdown(SSL *ssl)
{
// create a function
static Function<decltype(::SSL_shutdown)> func(library(), "SSL_shutdown");
static Function<decltype(::SSL_shutdown)> func("SSL_shutdown");
// call the openssl function
return func(ssl);
@ -160,7 +190,7 @@ int SSL_shutdown(SSL *ssl)
void SSL_set_connect_state(SSL *ssl)
{
// create a function
static Function<decltype(::SSL_set_connect_state)> func(library(), "SSL_set_connect_state");
static Function<decltype(::SSL_set_connect_state)> func("SSL_set_connect_state");
// call the openssl function
func(ssl);
@ -175,7 +205,7 @@ void SSL_set_connect_state(SSL *ssl)
int SSL_do_handshake(SSL *ssl)
{
// create a function
static Function<decltype(::SSL_do_handshake)> func(library(), "SSL_do_handshake");
static Function<decltype(::SSL_do_handshake)> func("SSL_do_handshake");
// call the openssl function
return func(ssl);
@ -183,18 +213,35 @@ int SSL_do_handshake(SSL *ssl)
/**
* Obtain result code for TLS/SSL I/O operation
* @param ssl SSL object
* @param ret the returned diagnostic value of SSL calls
* @param ssl SSL object
* @param ret the returned diagnostic value of SSL calls
* @return int returns error values
*/
int SSL_get_error(const SSL *ssl, int ret)
{
// create a function
static Function<decltype(::SSL_get_error)> func(library(), "SSL_get_error");
static Function<decltype(::SSL_get_error)> func("SSL_get_error");
// call the openssl function
return func(ssl, ret);
}
/**
* Internal handling function for a ssl context
* @param ssl ssl context
* @param cmd command
* @param larg first arg
* @param parg second arg
* @return long
*/
long SSL_ctrl(SSL *ssl, int cmd, long larg, void *parg)
{
// create a function
static Function<decltype(::SSL_ctrl)> func("SSL_ctrl");
// call the openssl function
return func(ssl, cmd, larg, parg);
}
/**
* End of namespace

View File

@ -33,6 +33,7 @@ bool valid();
/**
* List of all wrapper methods that are in use inside AMQP-CPP
*/
const SSL_METHOD *TLS_client_method();
SSL_CTX *SSL_CTX_new(const SSL_METHOD *method);
SSL *SSL_new(SSL_CTX *ctx);
int SSL_do_handshake(SSL *ssl);
@ -43,7 +44,9 @@ int SSL_set_fd(SSL *ssl, int fd);
int SSL_get_error(const SSL *ssl, int ret);
int SSL_up_ref(SSL *ssl);
void SSL_set_connect_state(SSL *ssl);
void SSL_CTX_free(SSL_CTX *ctx);
void SSL_free(SSL *ssl);
long SSL_ctrl(SSL *ssl, int cmd, long larg, void *parg);
/**
* End of namespace

View File

@ -148,7 +148,7 @@ private:
TcpState *repeat(int result)
{
// error was returned, so we must investigate what is going on
auto error = SSL_get_error(_ssl, result);
auto error = OpenSSL::SSL_get_error(_ssl, result);
// check the error
switch (error) {
@ -286,6 +286,20 @@ public:
}
}
/**
* Flush the connection, sent all buffered data to the socket
* @return TcpState new tcp state
*/
virtual TcpState *flush() override
{
// create an object to wait for the filedescriptor to becomes active
Wait wait(_socket);
// @todo implementation
return this;
}
/**
* Send data over the connection
* @param buffer buffer to send

View File

@ -35,7 +35,7 @@ public:
* @param method
* @throws std::runtime_error
*/
SslContext(const SSL_METHOD *method) : _ctx(SSL_CTX_new(method))
SslContext(const SSL_METHOD *method) : _ctx(OpenSSL::SSL_CTX_new(method))
{
// report error
if (_ctx == nullptr) throw std::runtime_error("failed to construct ssl context");
@ -69,7 +69,7 @@ public:
virtual ~SslContext()
{
// free resource (this updates the refcount -1, and may destruct it)
SSL_CTX_free(_ctx);
OpenSSL::SSL_CTX_free(_ctx);
}
/**

View File

@ -111,18 +111,18 @@ public:
*/
SslHandshake(TcpConnection *connection, int socket, const std::string &hostname, TcpOutBuffer &&buffer, TcpHandler *handler) :
TcpState(connection, handler),
_ssl(SslContext(SSLv23_client_method())),
_ssl(SslContext(OpenSSL::TLS_client_method())),
_socket(socket),
_out(std::move(buffer))
{
// we will be using the ssl context as a client
SSL_set_connect_state(_ssl);
OpenSSL::SSL_set_connect_state(_ssl);
// associate domain name with the connection
SSL_set_tlsext_host_name(_ssl, hostname.data());
OpenSSL::SSL_ctrl(_ssl, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, (void *)hostname.data());
// associate the ssl context with the socket filedescriptor
if (SSL_set_fd(_ssl, socket) == 0) throw std::runtime_error("failed to associate filedescriptor with ssl socket");
if (OpenSSL::SSL_set_fd(_ssl, socket) == 0) throw std::runtime_error("failed to associate filedescriptor with ssl socket");
// we are going to wait until the socket becomes writable before we start the handshake
_handler->monitor(_connection, _socket, writable);
@ -164,7 +164,7 @@ public:
if (result == 1) return nextstate(new SslConnected(_connection, _socket, _ssl, std::move(_out), _handler));
// error was returned, so we must investigate what is going on
auto error = SSL_get_error(_ssl, result);
auto error = OpenSSL::SSL_get_error(_ssl, result);
// check the error
switch (error) {
@ -198,13 +198,13 @@ public:
while (true)
{
// start the ssl handshake
int result = SSL_do_handshake(_ssl);
int result = OpenSSL::SSL_do_handshake(_ssl);
// if the connection succeeds, we can move to the ssl-connected state
if (result == 1) return nextstate(new SslConnected(_connection, _socket, _ssl, std::move(_out), _handler));
// error was returned, so we must investigate what is going on
auto error = SSL_get_error(_ssl, result);
auto error = OpenSSL::SSL_get_error(_ssl, result);
// check the error
switch (error)

View File

@ -70,7 +70,7 @@ private:
TcpState *repeat(int result)
{
// error was returned, so we must investigate what is going on
auto error = SSL_get_error(_ssl, result);
auto error = OpenSSL::SSL_get_error(_ssl, result);
// check the error
switch (error) {
@ -145,7 +145,7 @@ public:
Monitor monitor(this);
// close the connection
auto result = SSL_shutdown(_ssl);
auto result = OpenSSL::SSL_shutdown(_ssl);
// if this is a success, we can proceed with the event loop
if (result > 0) return proceed();

View File

@ -34,7 +34,7 @@ public:
* Constructor
* @param ctx
*/
SslWrapper(SSL_CTX *ctx) : _ssl(SSL_new(ctx))
SslWrapper(SSL_CTX *ctx) : _ssl(OpenSSL::SSL_new(ctx))
{
// report error
if (_ssl == nullptr) throw std::runtime_error("failed to construct ssl structure");
@ -68,7 +68,7 @@ public:
virtual ~SslWrapper()
{
// destruct object
SSL_free(_ssl);
OpenSSL::SSL_free(_ssl);
}
/**

View File

@ -225,7 +225,7 @@ public:
// create an object to wait for the filedescriptor to becomes active
Wait wait(_socket);
// keep running until the out buffer is empty
// keep running until the out buffer is not empty
while (_out)
{
// poll the socket, is it already writable?

View File

@ -121,8 +121,17 @@ public:
*/
ssize_t receivefrom(SSL *ssl, uint32_t expected)
{
// @todo implementation
return 0;
// number of bytes to that still fit in the buffer
size_t bytes = expected - _size;
// read data
auto result = OpenSSL::SSL_read(ssl, (void *)(_data + _size), bytes);
// update total buffer size on success
if (result > 0) _size += result;
// done
return result;
}
/**

View File

@ -18,7 +18,8 @@
*/
#include <sys/ioctl.h>
#include <sys/uio.h>
#include <openssl/ssl.h>
#include "openssl.h"
/**
* FIONREAD on Solaris is defined elsewhere
*/
@ -284,7 +285,7 @@ public:
if (buffers == 0) return 0;
// send the data
auto result = SSL_write(ssl, buffer[0].iov_base, buffer[0].iov_len);
auto result = OpenSSL::SSL_write(ssl, buffer[0].iov_base, buffer[0].iov_len);
// on success we shrink the buffer
if (result > 0) shrink(result);

View File

@ -93,7 +93,7 @@ private:
try
{
// check if we support openssl in the first place
if (!OpenSSL::valid()) throw std::runtime_error("Secure connection cannot be established: the application has no access to openssl");
if (!OpenSSL::valid()) throw std::runtime_error("Secure connection cannot be established: libssl.so cannot be loaded");
// get address info
AddressInfo addresses(_hostname.data(), _port);