diff --git a/include/cppkafka/message.h b/include/cppkafka/message.h index eda40dc..4226c09 100644 --- a/include/cppkafka/message.h +++ b/include/cppkafka/message.h @@ -180,7 +180,7 @@ private: Message(rd_kafka_message_t* handle, NonOwningTag); Message(HandlePtr handle); - void load_internal(void* user_data, InternalPtr internal); + Message& load_internal(); HandlePtr handle_; Buffer payload_; diff --git a/include/cppkafka/message_internal.h b/include/cppkafka/message_internal.h index a71add1..266e145 100644 --- a/include/cppkafka/message_internal.h +++ b/include/cppkafka/message_internal.h @@ -31,11 +31,10 @@ #define CPPKAFKA_MESSAGE_INTERNAL_H #include -#include "message.h" namespace cppkafka { -class Producer; +class Message; struct Internal { virtual ~Internal() = default; @@ -45,16 +44,37 @@ using InternalPtr = std::shared_ptr; /** * \brief Private message data structure */ -class MessageInternal { - friend Producer; -public: - static std::unique_ptr load(const Producer& producer, Message& message); -private: +struct MessageInternal { MessageInternal(void* user_data, std::shared_ptr internal); + static std::unique_ptr load(Message& message); void* user_data_; InternalPtr internal_; }; +template +struct MessageInternalGuard { + MessageInternalGuard(BuilderType& builder) + : builder_(builder), + user_data_(builder.user_data()) { + if (builder_.internal()) { + // Swap contents with user_data + ptr_.reset(new MessageInternal(user_data_, builder_.internal())); + builder_.user_data(ptr_.get()); //overwrite user data + } + } + ~MessageInternalGuard() { + //Restore user data + builder_.user_data(user_data_); + } + void release() { + ptr_.release(); + } +private: + BuilderType& builder_; + std::unique_ptr ptr_; + void* user_data_; +}; + } #endif //CPPKAFKA_MESSAGE_INTERNAL_H diff --git a/include/cppkafka/producer.h b/include/cppkafka/producer.h index a545c95..358a0fc 100644 --- a/include/cppkafka/producer.h +++ b/include/cppkafka/producer.h @@ -31,14 +31,12 @@ #define CPPKAFKA_PRODUCER_H #include -#include #include "kafka_handle_base.h" #include "configuration.h" #include "buffer.h" #include "topic.h" #include "macros.h" #include "message_builder.h" -#include "message_internal.h" namespace cppkafka { @@ -80,7 +78,6 @@ class Message; */ class CPPKAFKA_API Producer : public KafkaHandleBase { public: - friend MessageInternal; /** * The policy to use for the payload. The default policy is COPY_PAYLOAD */ @@ -159,11 +156,7 @@ public: */ void flush(std::chrono::milliseconds timeout); private: - using LoadResult = std::tuple>; - LoadResult load_internal(void* user_data, InternalPtr internal); - PayloadPolicy message_payload_policy_; - bool has_internal_data_; }; } // cppkafka diff --git a/include/cppkafka/utils/buffered_producer.h b/include/cppkafka/utils/buffered_producer.h index cec143c..bf4a2ea 100644 --- a/include/cppkafka/utils/buffered_producer.h +++ b/include/cppkafka/utils/buffered_producer.h @@ -104,7 +104,7 @@ public: /** * Callback to indicate a message failed to be flushed */ - using FlushFailureCallback = std::function; + using FlushFailureCallback = std::function; /** * \brief Constructs a buffered producer using the provided configuration @@ -369,24 +369,22 @@ private: if (!has_internal_data_ && (max_number_retries_ > 0)) { has_internal_data_ = true; //enable once } - if (has_internal_data_) { - // Add message tracker + if (has_internal_data_ && !builder.internal()) { + // Add message tracker only if it hasn't been added before TrackerPtr tracker = std::make_shared(SenderType::Async, max_number_retries_); builder.internal(tracker); return tracker; } return nullptr; } - template void do_add_message(BuilderType&& builder, MessagePriority priority, bool do_flush); - void do_add_message(const Message& message, MessagePriority priority, bool do_flush); - template - void produce_message(MessageType&& message); + template + void produce_message(BuilderType&& builder); Configuration prepare_configuration(Configuration config); void on_delivery_report(const Message& message); - template - void async_produce(MessageType&& message, bool throw_on_error); + template + void async_produce(BuilderType&& message, bool throw_on_error); // Members Producer producer_; @@ -466,7 +464,7 @@ void BufferedProducer::sync_produce(const MessageBuilder& builder) { template void BufferedProducer::produce(const Message& message) { - async_produce(message, true); + async_produce(MessageBuilder(message), true); } template @@ -546,13 +544,6 @@ void BufferedProducer::do_add_message(BuilderType&& builder, } } -template -void BufferedProducer::do_add_message(const Message& message, - MessagePriority priority, - bool do_flush) { - do_add_messsage(MessageBuilder(message), priority, do_flush); -} - template Producer& BufferedProducer::get_producer() { return producer_; @@ -615,11 +606,14 @@ void BufferedProducer::set_flush_failure_callback(FlushFailureCallba } template -template -void BufferedProducer::produce_message(MessageType&& message) { +template +void BufferedProducer::produce_message(BuilderType&& builder) { + using builder_type = typename std::decay::type; while (true) { try { - producer_.produce(std::forward(message)); + MessageInternalGuard internal_guard(const_cast(builder)); + producer_.produce(builder); + internal_guard.release(); // Sent successfully ++pending_acks_; break; @@ -637,23 +631,23 @@ void BufferedProducer::produce_message(MessageType&& message) { } template -template -void BufferedProducer::async_produce(MessageType&& message, bool throw_on_error) { +template +void BufferedProducer::async_produce(BuilderType&& builder, bool throw_on_error) { try { TestParameters* test_params = get_test_parameters(); if (test_params && test_params->force_produce_error_) { throw HandleException(Error(RD_KAFKA_RESP_ERR_UNKNOWN)); } - produce_message(std::forward(message)); + produce_message(std::forward(builder)); } catch (const HandleException& ex) { // If we have a flush failure callback and it returns true, we retry producing this message later CallbackInvoker callback("flush failure", flush_failure_callback_, &producer_); - if (!callback || callback(std::forward(message), ex.get_error())) { - TrackerPtr tracker = std::static_pointer_cast(message.internal()); + if (!callback || callback(std::forward(builder), ex.get_error())) { + TrackerPtr tracker = std::static_pointer_cast(builder.internal()); if (tracker && tracker->num_retries_ > 0) { --tracker->num_retries_; - do_add_message(std::forward(message), MessagePriority::High, false); + do_add_message(std::forward(builder), MessagePriority::High, false); return; } } @@ -676,7 +670,8 @@ template void BufferedProducer::on_delivery_report(const Message& message) { //Get tracker data TestParameters* test_params = get_test_parameters(); - TrackerPtr tracker = std::static_pointer_cast(message.internal()); + TrackerPtr tracker = has_internal_data_ ? + std::static_pointer_cast(MessageInternal::load(const_cast(message))->internal_) : nullptr; bool should_retry = false; if (message.get_error() || (test_params && test_params->force_delivery_error_)) { // We should produce this message again if we don't have a produce failure callback diff --git a/src/configuration.cpp b/src/configuration.cpp index 1783660..061adc7 100644 --- a/src/configuration.cpp +++ b/src/configuration.cpp @@ -31,7 +31,7 @@ #include #include #include "exceptions.h" -#include "message_internal.h" +#include "message.h" #include "producer.h" #include "consumer.h" @@ -40,10 +40,8 @@ using std::map; using std::move; using std::vector; using std::initializer_list; -using std::unique_ptr; -using boost::optional; - using std::chrono::milliseconds; +using boost::optional; namespace cppkafka { @@ -52,7 +50,6 @@ namespace cppkafka { void delivery_report_callback_proxy(rd_kafka_t*, const rd_kafka_message_t* msg, void *opaque) { Producer* handle = static_cast(opaque); Message message = Message::make_non_owning((rd_kafka_message_t*)msg); - unique_ptr internal_data(MessageInternal::load(*handle, message)); CallbackInvoker ("delivery report", handle->get_configuration().get_delivery_report_callback(), handle) (*handle, message); diff --git a/src/message.cpp b/src/message.cpp index 23070bf..d2b3dbb 100644 --- a/src/message.cpp +++ b/src/message.cpp @@ -28,6 +28,7 @@ */ #include "message.h" +#include "message_internal.h" using std::chrono::milliseconds; @@ -64,9 +65,13 @@ Message::Message(HandlePtr handle) user_data_(handle_ ? handle_->_private : nullptr) { } -void Message::load_internal(void* user_data, InternalPtr internal) { - user_data_ = user_data; - internal_ = internal; +Message& Message::load_internal() { + if (user_data_) { + MessageInternal* mi = static_cast(user_data_); + user_data_ = mi->user_data_; + internal_ = mi->internal_; + } + return *this; } // MessageTimestamp diff --git a/src/message_internal.cpp b/src/message_internal.cpp index a385377..c33d469 100644 --- a/src/message_internal.cpp +++ b/src/message_internal.cpp @@ -27,23 +27,22 @@ * */ #include "message_internal.h" -#include "producer.h" +#include "message.h" +#include "message_builder.h" namespace cppkafka { -MessageInternal::MessageInternal(void* user_data, std::shared_ptr internal) +// MessageInternal + +MessageInternal::MessageInternal(void* user_data, + std::shared_ptr internal) : user_data_(user_data), internal_(internal) { } -std::unique_ptr MessageInternal::load(const Producer& producer, Message& message) { - if (producer.has_internal_data_ && message.get_user_data()) { - // Unpack internal data - std::unique_ptr internal_data(static_cast(message.get_user_data())); - message.load_internal(internal_data->user_data_, internal_data->internal_); - return internal_data; - } - return nullptr; +std::unique_ptr MessageInternal::load(Message& message) { + return std::unique_ptr(message.load_internal().get_handle() ? + static_cast(message.get_handle()->_private) : nullptr); } } diff --git a/src/producer.cpp b/src/producer.cpp index bb8affb..4081b53 100644 --- a/src/producer.cpp +++ b/src/producer.cpp @@ -42,8 +42,7 @@ using std::get; namespace cppkafka { Producer::Producer(Configuration config) -: KafkaHandleBase(move(config)), message_payload_policy_(PayloadPolicy::COPY_PAYLOAD), - has_internal_data_(false) { +: KafkaHandleBase(move(config)), message_payload_policy_(PayloadPolicy::COPY_PAYLOAD) { char error_buffer[512]; auto config_handle = get_configuration().get_handle(); rd_kafka_conf_set_opaque(config_handle, this); @@ -69,7 +68,6 @@ void Producer::produce(const MessageBuilder& builder) { const Buffer& payload = builder.payload(); const Buffer& key = builder.key(); const int policy = static_cast(message_payload_policy_); - LoadResult load_result = load_internal(builder.user_data(), builder.internal()); auto result = rd_kafka_producev(get_handle(), RD_KAFKA_V_TOPIC(builder.topic().data()), RD_KAFKA_V_PARTITION(builder.partition()), @@ -77,10 +75,9 @@ void Producer::produce(const MessageBuilder& builder) { RD_KAFKA_V_TIMESTAMP(builder.timestamp().count()), RD_KAFKA_V_KEY((void*)key.get_data(), key.get_size()), RD_KAFKA_V_VALUE((void*)payload.get_data(), payload.get_size()), - RD_KAFKA_V_OPAQUE(get<0>(load_result)), + RD_KAFKA_V_OPAQUE(builder.user_data()), RD_KAFKA_V_END); check_error(result); - get<1>(load_result).release(); //data has been passed-on to rdkafka so we release ownership } void Producer::produce(const Message& message) { @@ -88,7 +85,6 @@ void Producer::produce(const Message& message) { const Buffer& key = message.get_key(); const int policy = static_cast(message_payload_policy_); int64_t duration = message.get_timestamp() ? message.get_timestamp().get().get_timestamp().count() : 0; - LoadResult load_result = load_internal(message.get_user_data(), message.internal()); auto result = rd_kafka_producev(get_handle(), RD_KAFKA_V_TOPIC(message.get_topic().data()), RD_KAFKA_V_PARTITION(message.get_partition()), @@ -96,10 +92,9 @@ void Producer::produce(const Message& message) { RD_KAFKA_V_TIMESTAMP(duration), RD_KAFKA_V_KEY((void*)key.get_data(), key.get_size()), RD_KAFKA_V_VALUE((void*)payload.get_data(), payload.get_size()), - RD_KAFKA_V_OPAQUE(get<0>(load_result)), + RD_KAFKA_V_OPAQUE(message.get_user_data()), RD_KAFKA_V_END); check_error(result); - get<1>(load_result).release(); //data has been passed-on to rdkafka so we release ownership } int Producer::poll() { @@ -119,16 +114,4 @@ void Producer::flush(milliseconds timeout) { check_error(result); } -Producer::LoadResult Producer::load_internal(void* user_data, InternalPtr internal) { - unique_ptr internal_data; - if (!has_internal_data_ && internal) { - has_internal_data_ = true; //enable once for this producer - } - if (has_internal_data_ && get_configuration().get_delivery_report_callback()) { - internal_data.reset(new MessageInternal(user_data, internal)); - user_data = internal_data.get(); //point to the internal data - } - return LoadResult(user_data, move(internal_data)); -} - } // cppkafka