Added logic to conditionally enable internal data

This commit is contained in:
accelerated
2018-06-05 09:07:00 -04:00
parent 597c026555
commit f746653841
8 changed files with 125 additions and 57 deletions

View File

@@ -44,6 +44,7 @@
#include <cppkafka/macros.h>
#include <cppkafka/message.h>
#include <cppkafka/message_builder.h>
#include <cppkafka/message_internal.h>
#include <cppkafka/metadata.h>
#include <cppkafka/producer.h>
#include <cppkafka/queue.h>

View File

@@ -35,6 +35,8 @@
namespace cppkafka {
class Producer;
struct Internal {
virtual ~Internal() = default;
};
@@ -44,25 +46,11 @@ using InternalPtr = std::shared_ptr<Internal>;
* \brief Private message data structure
*/
class MessageInternal {
friend class Producer;
friend Producer;
public:
static std::unique_ptr<MessageInternal> load(Message& message) {
if (message.get_user_data()) {
// Unpack internal data
std::unique_ptr<MessageInternal> internal_data(static_cast<MessageInternal*>(message.get_user_data()));
message.load_internal(internal_data->user_data_, internal_data->internal_);
return internal_data;
}
return nullptr;
}
static std::unique_ptr<MessageInternal> load(const Producer& producer, Message& message);
private:
MessageInternal(void* user_data, std::shared_ptr<Internal> internal)
: user_data_(user_data),
internal_(internal) {
}
MessageInternal(void* user_data, std::shared_ptr<Internal> internal);
void* user_data_;
InternalPtr internal_;
};

View File

@@ -31,12 +31,14 @@
#define CPPKAFKA_PRODUCER_H
#include <memory>
#include <tuple>
#include "kafka_handle_base.h"
#include "configuration.h"
#include "buffer.h"
#include "topic.h"
#include "macros.h"
#include "message_builder.h"
#include "message_internal.h"
namespace cppkafka {
@@ -78,6 +80,7 @@ class Message;
*/
class CPPKAFKA_API Producer : public KafkaHandleBase {
public:
friend MessageInternal;
/**
* The policy to use for the payload. The default policy is COPY_PAYLOAD
*/
@@ -156,7 +159,11 @@ public:
*/
void flush(std::chrono::milliseconds timeout);
private:
using LoadResult = std::tuple<void*, std::unique_ptr<MessageInternal>>;
LoadResult load_internal(void* user_data, InternalPtr internal);
PayloadPolicy message_payload_policy_;
bool has_internal_data_;
};
} // cppkafka

View File

@@ -362,6 +362,21 @@ private:
std::promise<bool> should_retry_;
size_t num_retries_;
};
using TrackerPtr = std::shared_ptr<Tracker>;
template <typename BuilderType>
TrackerPtr add_tracker(BuilderType& builder) {
if (!has_internal_data_ && (max_number_retries_ > 0)) {
has_internal_data_ = true; //enable once
}
if (has_internal_data_) {
// Add message tracker
TrackerPtr tracker = std::make_shared<Tracker>(SenderType::Async, max_number_retries_);
builder.internal(tracker);
return tracker;
}
return nullptr;
}
template <typename BuilderType>
void do_add_message(BuilderType&& builder, MessagePriority priority, bool do_flush);
@@ -385,7 +400,8 @@ private:
std::atomic<size_t> flushes_in_progress_{0};
std::atomic<size_t> total_messages_produced_{0};
std::atomic<size_t> total_messages_dropped_{0};
int max_number_retries_{5};
int max_number_retries_{0};
bool has_internal_data_{false};
#ifdef KAFKA_TEST_INSTANCE
TestParameters* test_params_;
#endif
@@ -412,40 +428,40 @@ BufferedProducer<BufferType>::BufferedProducer(Configuration config)
template <typename BufferType>
void BufferedProducer<BufferType>::add_message(const MessageBuilder& builder) {
// Add message tracker
std::shared_ptr<Tracker> tracker = std::make_shared<Tracker>(SenderType::Async, max_number_retries_);
const_cast<MessageBuilder&>(builder).internal(tracker);
add_tracker(const_cast<MessageBuilder&>(builder));
do_add_message(builder, MessagePriority::Low, true);
}
template <typename BufferType>
void BufferedProducer<BufferType>::add_message(Builder builder) {
// Add message tracker
std::shared_ptr<Tracker> tracker = std::make_shared<Tracker>(SenderType::Async, max_number_retries_);
const_cast<Builder&>(builder).internal(tracker);
add_tracker(builder);
do_add_message(move(builder), MessagePriority::Low, true);
}
template <typename BufferType>
void BufferedProducer<BufferType>::produce(const MessageBuilder& builder) {
// Add message tracker
std::shared_ptr<Tracker> tracker = std::make_shared<Tracker>(SenderType::Async, max_number_retries_);
const_cast<MessageBuilder&>(builder).internal(tracker);
add_tracker(const_cast<MessageBuilder&>(builder));
async_produce(builder, true);
}
template <typename BufferType>
void BufferedProducer<BufferType>::sync_produce(const MessageBuilder& builder) {
// Add message tracker
std::shared_ptr<Tracker> tracker = std::make_shared<Tracker>(SenderType::Async, max_number_retries_);
const_cast<MessageBuilder&>(builder).internal(tracker);
std::future<bool> should_retry;
do {
should_retry = tracker->get_new_future();
TrackerPtr tracker = add_tracker(const_cast<MessageBuilder&>(builder));
if (tracker) {
// produce until we succeed or we reach max retry limit
std::future<bool> should_retry;
do {
should_retry = tracker->get_new_future();
produce_message(builder);
wait_for_acks();
}
while (should_retry.get());
}
else {
// produce once
produce_message(builder);
wait_for_acks();
}
while (should_retry.get());
}
template <typename BufferType>
@@ -634,8 +650,8 @@ void BufferedProducer<BufferType>::async_produce(MessageType&& message, bool thr
// If we have a flush failure callback and it returns true, we retry producing this message later
CallbackInvoker<FlushFailureCallback> callback("flush failure", flush_failure_callback_, &producer_);
if (!callback || callback(std::forward<MessageType>(message), ex.get_error())) {
std::shared_ptr<Tracker> tracker = std::static_pointer_cast<Tracker>(message.internal());
if (tracker->num_retries_ > 0) {
TrackerPtr tracker = std::static_pointer_cast<Tracker>(message.internal());
if (tracker && tracker->num_retries_ > 0) {
--tracker->num_retries_;
do_add_message(std::forward<MessageType>(message), MessagePriority::High, false);
return;
@@ -660,7 +676,7 @@ template <typename BufferType>
void BufferedProducer<BufferType>::on_delivery_report(const Message& message) {
//Get tracker data
TestParameters* test_params = get_test_parameters();
std::shared_ptr<Tracker> tracker = std::static_pointer_cast<Tracker>(message.internal());
TrackerPtr tracker = std::static_pointer_cast<Tracker>(message.internal());
bool should_retry = false;
if (message.get_error() || (test_params && test_params->force_delivery_error_)) {
// We should produce this message again if we don't have a produce failure callback
@@ -668,7 +684,7 @@ void BufferedProducer<BufferType>::on_delivery_report(const Message& message) {
CallbackInvoker<ProduceFailureCallback> callback("produce failure", produce_failure_callback_, &producer_);
if (!callback || callback(message)) {
// Check if we have reached the maximum retry limit
if (tracker->num_retries_ > 0) {
if (tracker && tracker->num_retries_ > 0) {
--tracker->num_retries_;
if (tracker->sender_ == SenderType::Async) {
// Re-enqueue for later retransmission with higher priority (i.e. front of the queue)
@@ -691,7 +707,9 @@ void BufferedProducer<BufferType>::on_delivery_report(const Message& message) {
++total_messages_produced_;
}
// Signal producers
tracker->should_retry_.set_value(should_retry);
if (tracker) {
tracker->should_retry_.set_value(should_retry);
}
// Decrement the expected acks
--pending_acks_;
assert(pending_acks_ != (size_t)-1); // Prevent underflow