Fixed a bug where a frame could be sent exceeding the maximum frame size (resulting in protocol errors) and added some optimizations

This commit is contained in:
Martijn Otto 2015-04-30 10:59:03 +02:00
parent b9caf0199d
commit 45deeaa754
9 changed files with 117 additions and 80 deletions

View File

@ -13,6 +13,7 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include <map> #include <map>
#include <unordered_map>
#include <queue> #include <queue>
#include <set> #include <set>
#include <limits> #include <limits>

View File

@ -53,9 +53,9 @@ protected:
/** /**
* All channels that are active * All channels that are active
* @var map * @var std::unordered_map<uint16_t, std::shared_ptr<ChannelImpl>>
*/ */
std::map<uint16_t, std::shared_ptr<ChannelImpl>> _channels; std::unordered_map<uint16_t, std::shared_ptr<ChannelImpl>> _channels;
/** /**
* The last unused channel ID * The last unused channel ID

View File

@ -24,13 +24,13 @@ protected:
* @var BooleanSet * @var BooleanSet
*/ */
BooleanSet _bools1; BooleanSet _bools1;
/** /**
* Second set of booleans * Second set of booleans
* @var BooleanSet * @var BooleanSet
*/ */
BooleanSet _bools2; BooleanSet _bools2;
/** /**
* MIME content type * MIME content type
* @var ShortString * @var ShortString
@ -121,8 +121,8 @@ protected:
* in a derived class * in a derived class
*/ */
MetaData() {} MetaData() {}
public: public:
/** /**
* Read incoming frame * Read incoming frame
@ -217,6 +217,23 @@ public:
void setTimestamp (uint64_t value) { _timestamp = value; _bools2.set(6,true); } void setTimestamp (uint64_t value) { _timestamp = value; _bools2.set(6,true); }
void setMessageID (const std::string &value) { _messageID = value; _bools2.set(7,true); } void setMessageID (const std::string &value) { _messageID = value; _bools2.set(7,true); }
/**
* Set the various supported fields using r-value references
*
* @param value moveable value
*/
void setExpiration (std::string &&value) { _expiration = std::move(value); _bools1.set(0,true); }
void setReplyTo (std::string &&value) { _replyTo = std::move(value); _bools1.set(1,true); }
void setCorrelationID (std::string &&value) { _correlationID = std::move(value); _bools1.set(2,true); }
void setHeaders (Table &&value) { _headers = std::move(value); _bools1.set(5,true); }
void setContentEncoding (std::string &&value) { _contentEncoding = std::move(value); _bools1.set(6,true); }
void setContentType (std::string &&value) { _contentType = std::move(value); _bools1.set(7,true); }
void setClusterID (std::string &&value) { _clusterID = std::move(value); _bools2.set(2,true); }
void setAppID (std::string &&value) { _appID = std::move(value); _bools2.set(3,true); }
void setUserID (std::string &&value) { _userID = std::move(value); _bools2.set(4,true); }
void setTypeName (std::string &&value) { _typeName = std::move(value); _bools2.set(5,true); }
void setMessageID (std::string &&value) { _messageID = std::move(value); _bools2.set(7,true); }
/** /**
* Retrieve the fields * Retrieve the fields
* @return string * @return string
@ -235,7 +252,7 @@ public:
const std::string &typeName () const { return _typeName; } const std::string &typeName () const { return _typeName; }
uint64_t timestamp () const { return _timestamp; } uint64_t timestamp () const { return _timestamp; }
const std::string &messageID () const { return _messageID; } const std::string &messageID () const { return _messageID; }
/** /**
* Is this a message with persistent storage * Is this a message with persistent storage
* This is an alias for retrieving the delivery mode and checking if it is set to 2 * This is an alias for retrieving the delivery mode and checking if it is set to 2
@ -245,19 +262,19 @@ public:
{ {
return hasDeliveryMode() && deliveryMode() == 2; return hasDeliveryMode() && deliveryMode() == 2;
} }
/** /**
* Set whether storage should be persistent or not * Set whether storage should be persistent or not
* @param bool * @param bool
*/ */
void setPersistent(bool value = true) void setPersistent(bool value = true)
{ {
if (value) if (value)
{ {
// simply set the delivery mode // simply set the delivery mode
setDeliveryMode(2); setDeliveryMode(2);
} }
else else
{ {
// we remove the field from the header // we remove the field from the header
_deliveryMode = 0; _deliveryMode = 0;
@ -273,7 +290,7 @@ public:
{ {
// the result (2 for the two boolean sets) // the result (2 for the two boolean sets)
uint32_t result = 2; uint32_t result = 2;
if (hasExpiration()) result += _expiration.size(); if (hasExpiration()) result += _expiration.size();
if (hasReplyTo()) result += _replyTo.size(); if (hasReplyTo()) result += _replyTo.size();
if (hasCorrelationID()) result += _correlationID.size(); if (hasCorrelationID()) result += _correlationID.size();
@ -288,11 +305,11 @@ public:
if (hasTypeName()) result += _typeName.size(); if (hasTypeName()) result += _typeName.size();
if (hasTimestamp()) result += _timestamp.size(); if (hasTimestamp()) result += _timestamp.size();
if (hasMessageID()) result += _messageID.size(); if (hasMessageID()) result += _messageID.size();
// done // done
return result; return result;
} }
/** /**
* Fill an output buffer * Fill an output buffer
* @param buffer * @param buffer
@ -302,7 +319,7 @@ public:
// the two boolean sets are always present // the two boolean sets are always present
_bools1.fill(buffer); _bools1.fill(buffer);
_bools2.fill(buffer); _bools2.fill(buffer);
// only copy the properties that were sent // only copy the properties that were sent
if (hasContentType()) _contentType.fill(buffer); if (hasContentType()) _contentType.fill(buffer);
if (hasContentEncoding()) _contentEncoding.fill(buffer); if (hasContentEncoding()) _contentEncoding.fill(buffer);

View File

@ -21,9 +21,9 @@ class OutBuffer
private: private:
/** /**
* Pointer to the beginning of the buffer * Pointer to the beginning of the buffer
* @var char* * @var std::unique_ptr<char[]>
*/ */
char *_buffer; std::unique_ptr<char[]> _buffer;
/** /**
* Pointer to the buffer to be filled * Pointer to the buffer to be filled
@ -49,56 +49,47 @@ public:
* Constructor * Constructor
* @param capacity * @param capacity
*/ */
OutBuffer(uint32_t capacity) OutBuffer(uint32_t capacity) :
{ _buffer(new char[capacity]),
// initialize members _current(_buffer.get()),
_size = 0; _size(0),
_capacity = capacity; _capacity(capacity)
_buffer = _current = new char[capacity]; {}
}
/** /**
* Copy constructor * Copy constructor
* @param that * @param that
*/ */
OutBuffer(const OutBuffer &that) OutBuffer(const OutBuffer &that) :
_buffer(new char[that._capacity]),
_current(_buffer.get() + that._size),
_size(that._size),
_capacity(that._capacity)
{ {
// initialize members
_size = that._size;
_capacity = that._capacity;
_buffer = new char[_capacity];
_current = _buffer + _size;
// copy memory // copy memory
memcpy(_buffer, that._buffer, _size); memcpy(_buffer.get(), that._buffer.get(), _size);
} }
/** /**
* Move constructor * Move constructor
* @param that * @param that
*/ */
OutBuffer(OutBuffer &&that) OutBuffer(OutBuffer &&that) :
_buffer(std::move(that._buffer)),
_current(that._current),
_size(that._size),
_capacity(that._capacity)
{ {
// copy all members
_size = that._size;
_capacity = that._capacity;
_buffer = that._buffer;
_current = that._current;
// reset the other object // reset the other object
that._size = 0; that._size = 0;
that._capacity = 0; that._capacity = 0;
that._buffer = nullptr;
that._current = nullptr; that._current = nullptr;
} }
/** /**
* Destructor * Destructor
*/ */
virtual ~OutBuffer() virtual ~OutBuffer() {}
{
if (_buffer) delete[] _buffer;
}
/** /**
* Get access to the internal buffer * Get access to the internal buffer
@ -106,7 +97,7 @@ public:
*/ */
const char *data() const const char *data() const
{ {
return _buffer; return _buffer.get();
} }
/** /**

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
/** /**
* String field types for amqp * String field types for amqp
* *
* @copyright 2014 Copernica BV * @copyright 2014 Copernica BV
*/ */
@ -34,7 +34,14 @@ public:
* *
* @param value string value * @param value string value
*/ */
StringField(std::string value) : _data(value) {} StringField(const std::string &value) : _data(value) {}
/**
* Construct based on a std::string
*
* @param value string value
*/
StringField(std::string &&value) : _data(std::move(value)) {}
/** /**
* Construct based on received data * Construct based on received data
@ -44,7 +51,7 @@ public:
{ {
// get the size // get the size
T size(frame); T size(frame);
// allocate string // allocate string
_data = std::string(frame.nextData(size.value()), (size_t) size.value()); _data = std::string(frame.nextData(size.value()), (size_t) size.value());
} }
@ -69,7 +76,7 @@ public:
* *
* @param value new value * @param value new value
*/ */
StringField& operator=(const std::string& value) StringField& operator=(const std::string &value)
{ {
// overwrite data // overwrite data
_data = value; _data = value;
@ -78,6 +85,20 @@ public:
return *this; return *this;
} }
/**
* Assign a new value
*
* @param value new value
*/
StringField& operator=(std::string &&value)
{
// overwrite data
_data = std::move(value);
// allow chaining
return *this;
}
/** /**
* Get the size this field will take when * Get the size this field will take when
* encoded in the AMQP wire-frame format * encoded in the AMQP wire-frame format
@ -87,7 +108,7 @@ public:
{ {
// find out size of the size parameter // find out size of the size parameter
T size(_data.size()); T size(_data.size());
// size of the uint8 or uint32 + the actual string size // size of the uint8 or uint32 + the actual string size
return size.size() + _data.size(); return size.size() + _data.size();
} }
@ -128,7 +149,7 @@ public:
{ {
// create size // create size
T size(_data.size()); T size(_data.size());
// first, write down the size of the string // first, write down the size of the string
size.fill(buffer); size.fill(buffer);

View File

@ -1,7 +1,7 @@
#pragma once #pragma once
/** /**
* AMQP field table * AMQP field table
* *
* @copyright 2014 Copernica BV * @copyright 2014 Copernica BV
*/ */
@ -40,7 +40,7 @@ public:
* @param frame received frame to decode * @param frame received frame to decode
*/ */
Table(ReceivedFrame &frame); Table(ReceivedFrame &frame);
/** /**
* Copy constructor * Copy constructor
* @param table * @param table
@ -64,7 +64,7 @@ public:
* @return Table * @return Table
*/ */
Table &operator=(const Table &table); Table &operator=(const Table &table);
/** /**
* Move assignment operator * Move assignment operator
* @param table * @param table
@ -104,7 +104,7 @@ public:
/** /**
* Get a field * Get a field
* *
* If the field does not exist, an empty string field is returned * If the field does not exist, an empty string field is returned
* *
* @param name field name * @param name field name
@ -153,7 +153,7 @@ public:
} }
/** /**
* Write encoded payload to the given buffer. * Write encoded payload to the given buffer.
* @param buffer * @param buffer
*/ */
virtual void fill(OutBuffer& buffer) const override; virtual void fill(OutBuffer& buffer) const override;
@ -175,23 +175,23 @@ public:
{ {
// prefix // prefix
stream << "table("; stream << "table(";
// is this the first iteration // is this the first iteration
bool first = true; bool first = true;
// loop through all members // loop through all members
for (auto &iter : _fields) for (auto &iter : _fields)
{ {
// split with comma // split with comma
if (!first) stream << ","; if (!first) stream << ",";
// show output // show output
stream << iter.first << ":" << *iter.second; stream << iter.first << ":" << *iter.second;
// no longer first iter // no longer first iter
first = false; first = false;
} }
// postfix // postfix
stream << ")"; stream << ")";
} }

View File

@ -287,6 +287,10 @@ bool ConnectionImpl::send(const Frame &frame)
// some frames can be sent _after_ the close() function was called // some frames can be sent _after_ the close() function was called
if (_closed && !frame.partOfShutdown()) return false; if (_closed && !frame.partOfShutdown()) return false;
// if the frame is bigger than we allow on the connection
// it is impossible to send out this frame successfully
if (frame.totalSize() > _maxFrame) return false;
// we need an output buffer // we need an output buffer
OutBuffer buffer(frame.buffer()); OutBuffer buffer(frame.buffer());

View File

@ -18,6 +18,7 @@
#include <ostream> #include <ostream>
#include <math.h> #include <math.h>
#include <map> #include <map>
#include <unordered_map>
#include <vector> #include <vector>
#include <queue> #include <queue>

View File

@ -10,25 +10,25 @@ namespace AMQP {
*/ */
Table::Table(ReceivedFrame &frame) Table::Table(ReceivedFrame &frame)
{ {
// table buffer begins with the number of bytes to read // table buffer begins with the number of bytes to read
uint32_t bytesToRead = frame.nextUint32(); uint32_t bytesToRead = frame.nextUint32();
// keep going until the correct number of bytes is read. // keep going until the correct number of bytes is read.
while (bytesToRead > 0) while (bytesToRead > 0)
{ {
// field name and type // field name and type
ShortString name(frame); ShortString name(frame);
// subtract number of bytes to read, plus one byte for the decoded type // subtract number of bytes to read, plus one byte for the decoded type
bytesToRead -= (name.size() + 1); bytesToRead -= (name.size() + 1);
// get the field // get the field
Field *field = Field::decode(frame); Field *field = Field::decode(frame);
if (!field) continue; if (!field) continue;
// add field // add field
_fields[name] = std::shared_ptr<Field>(field); _fields[name] = std::shared_ptr<Field>(field);
// subtract size // subtract size
bytesToRead -= field->size(); bytesToRead -= field->size();
} }
@ -43,8 +43,10 @@ Table::Table(const Table &table)
// loop through the table records // loop through the table records
for (auto iter = table._fields.begin(); iter != table._fields.end(); iter++) for (auto iter = table._fields.begin(); iter != table._fields.end(); iter++)
{ {
// add the field // since a map is always ordered, we know that each element will
_fields[iter->first] = std::shared_ptr<Field>(iter->second->clone()); // be inserted at the end of the new map, so we can simply use
// emplace_hint and hint at insertion at the end of the map
_fields.emplace_hint(_fields.end(), std::make_pair(iter->first, iter->second->clone()));
} }
} }
@ -57,21 +59,21 @@ Table &Table::operator=(const Table &table)
{ {
// skip self assignment // skip self assignment
if (this == &table) return *this; if (this == &table) return *this;
// empty current fields // empty current fields
_fields.clear(); _fields.clear();
// loop through the table records // loop through the table records
for (auto iter = table._fields.begin(); iter != table._fields.end(); iter++) for (auto iter = table._fields.begin(); iter != table._fields.end(); iter++)
{ {
// add the field // add the field
_fields[iter->first] = std::shared_ptr<Field>(iter->second->clone()); _fields[iter->first] = std::shared_ptr<Field>(iter->second->clone());
} }
// done // done
return *this; return *this;
} }
/** /**
* Move assignment operator * Move assignment operator
* @param table * @param table
@ -81,17 +83,17 @@ Table &Table::operator=(Table &&table)
{ {
// skip self assignment // skip self assignment
if (this == &table) return *this; if (this == &table) return *this;
// copy fields // copy fields
_fields = std::move(table._fields); _fields = std::move(table._fields);
// done // done
return *this; return *this;
} }
/** /**
* Get a field * Get a field
* *
* If the field does not exist, an empty string field is returned * If the field does not exist, an empty string field is returned
* *
* @param name field name * @param name field name
@ -101,7 +103,7 @@ const Field &Table::get(const std::string &name) const
{ {
// we need an empty string // we need an empty string
static ShortString empty; static ShortString empty;
// locate the element first // locate the element first
auto iter(_fields.find(name)); auto iter(_fields.find(name));
@ -140,7 +142,7 @@ size_t Table::size() const
} }
/** /**
* Write encoded payload to the given buffer. * Write encoded payload to the given buffer.
*/ */
void Table::fill(OutBuffer& buffer) const void Table::fill(OutBuffer& buffer) const
{ {
@ -153,7 +155,7 @@ void Table::fill(OutBuffer& buffer) const
// encode the field name // encode the field name
ShortString name(iter->first); ShortString name(iter->first);
name.fill(buffer); name.fill(buffer);
// encode the element type // encode the element type
buffer.add((uint8_t) iter->second->typeID()); buffer.add((uint8_t) iter->second->typeID());