diff --git a/tensorpipe/transport/shm/reactor.cc b/tensorpipe/transport/shm/reactor.cc index b6297462c..98a356518 100644 --- a/tensorpipe/transport/shm/reactor.cc +++ b/tensorpipe/transport/shm/reactor.cc @@ -27,6 +27,7 @@ void writeToken(util::ringbuffer::Producer& producer, Reactor::TToken token) { TP_DCHECK_EQ(rv, sizeof(token)); break; } + producer.semPostData(); } } // namespace @@ -41,12 +42,13 @@ Reactor::Reactor() { consumer_.emplace(rb); producer_.emplace(rb); deferredFunctionToken_ = add([this]() { handleDeferredFunctionFromLoop(); }); + wakeUpToken_ = add([]() { ; }); thread_ = std::thread(&Reactor::run, this); } void Reactor::close() { if (!closed_.exchange(true)) { - // No need to wake up the reactor, since it is busy-waiting. + trigger(wakeUpToken_); } } @@ -94,6 +96,11 @@ void Reactor::remove(TToken token) { functions_[token] = nullptr; reusableTokens_.insert(token); functionCount_--; + if (functionCount_ >= 2) { + // Wake up loop thread upon each token removal except for the + // two used to defer and wake up. + writeToken(producer_.value(), wakeUpToken_); + } } void Reactor::trigger(TToken token) { @@ -108,14 +115,13 @@ std::tuple Reactor::fds() const { void Reactor::run() { setThreadName("TP_SHM_reactor"); // Stop when another thread has asked the reactor the close and when - // all functions have been removed except for the one used to defer. - while (!closed_ || functionCount_ > 1) { + // all functions have been removed except for the two used to defer + // and wakeup. + while (!closed_ || functionCount_ > 2) { uint32_t token; + consumer_->semWaitData(); auto ret = consumer_->copy(sizeof(token), &token); - if (ret == -ENODATA) { - std::this_thread::yield(); - continue; - } + TP_DCHECK_NE(ret, -ENODATA); TFunction fn; @@ -133,6 +139,7 @@ void Reactor::run() { } TP_DCHECK(deferredFunctionList_.empty()); remove(deferredFunctionToken_); + remove(wakeUpToken_); } Reactor::Trigger::Trigger(Fd&& headerFd, Fd&& dataFd) diff --git a/tensorpipe/transport/shm/reactor.h b/tensorpipe/transport/shm/reactor.h index 044ab7928..8f4975b2e 100644 --- a/tensorpipe/transport/shm/reactor.h +++ b/tensorpipe/transport/shm/reactor.h @@ -83,6 +83,7 @@ class Reactor final { std::atomic closed_{false}; std::atomic joined_{false}; + TToken wakeUpToken_; TToken deferredFunctionToken_; std::mutex deferredFunctionMutex_; std::list deferredFunctionList_; diff --git a/tensorpipe/util/ringbuffer/consumer.h b/tensorpipe/util/ringbuffer/consumer.h index 69c95464c..7cd7d6d25 100644 --- a/tensorpipe/util/ringbuffer/consumer.h +++ b/tensorpipe/util/ringbuffer/consumer.h @@ -154,6 +154,10 @@ class Consumer : public RingBufferWrapper { return static_cast(size); } + void semWaitData() { + semWait_(); + } + protected: bool inTx_{false}; diff --git a/tensorpipe/util/ringbuffer/producer.h b/tensorpipe/util/ringbuffer/producer.h index 9a5c94226..9634639a1 100644 --- a/tensorpipe/util/ringbuffer/producer.h +++ b/tensorpipe/util/ringbuffer/producer.h @@ -119,6 +119,10 @@ class Producer : public RingBufferWrapper { return writeInTx(sizeof(T), &d); } + void semPostData() { + semPost_(); + } + [[nodiscard]] std::pair reserveContiguousInTx( const size_t size) { if (unlikely(size == 0)) { diff --git a/tensorpipe/util/ringbuffer/ringbuffer.h b/tensorpipe/util/ringbuffer/ringbuffer.h index 91664f8ba..da6cb4f35 100644 --- a/tensorpipe/util/ringbuffer/ringbuffer.h +++ b/tensorpipe/util/ringbuffer/ringbuffer.h @@ -8,6 +8,7 @@ #pragma once +#include #include #include @@ -76,6 +77,12 @@ class RingBufferHeader { " buffer to ever be larger than what an int can hold"; in_write_tx.clear(); in_read_tx.clear(); + + // Init semaphore to work on shared memory. + int ret = sem_init(&sem_, 1, 0); + if (ret != 0) { + TP_THROW_SYSTEM(errno); + } } // Get size that is only guaranteed to be correct when producers and consumers @@ -107,6 +114,14 @@ class RingBufferHeader { atomicTail_.store(atomicHead_.load()); } + void semWait() { + sem_wait(&sem_); + } + + void semPost() { + sem_post(&sem_); + } + // acquired by producers. std::atomic_flag in_write_tx; // acquired by consumers. @@ -118,6 +133,8 @@ class RingBufferHeader { // Written by consumer. std::atomic atomicTail_{0}; + sem_t sem_; + // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2007/n2427.html#atomics.lockfree // static_assert( // decltype(atomicHead_)::is_always_lock_free, @@ -194,6 +211,14 @@ class RingBufferWrapper { } protected: + void semWait_() { + header_.semWait(); + } + + void semPost_() { + header_.semPost(); + } + std::shared_ptr rb_; RingBufferHeader& header_; uint8_t* const data_;