Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions tensorpipe/transport/shm/reactor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ void writeToken(util::ringbuffer::Producer& producer, Reactor::TToken token) {
TP_DCHECK_EQ(rv, sizeof(token));
break;
}
producer.semPostData();
}

} // namespace
Expand All @@ -41,12 +42,13 @@ Reactor::Reactor() {
consumer_.emplace(rb);
producer_.emplace(rb);
deferredFunctionToken_ = add([this]() { handleDeferredFunctionFromLoop(); });
wakeUpToken_ = add([]() { ; });
thread_ = std::thread(&Reactor::run, this);
}

void Reactor::close() {
if (!closed_.exchange(true)) {
// No need to wake up the reactor, since it is busy-waiting.
trigger(wakeUpToken_);
}
}

Expand Down Expand Up @@ -94,6 +96,11 @@ void Reactor::remove(TToken token) {
functions_[token] = nullptr;
reusableTokens_.insert(token);
functionCount_--;
if (functionCount_ >= 2) {
// Wake up loop thread upon each token removal except for the
// two used to defer and wake up.
writeToken(producer_.value(), wakeUpToken_);
}
}

void Reactor::trigger(TToken token) {
Expand All @@ -108,14 +115,13 @@ std::tuple<int, int> Reactor::fds() const {
void Reactor::run() {
setThreadName("TP_SHM_reactor");
// Stop when another thread has asked the reactor the close and when
// all functions have been removed except for the one used to defer.
while (!closed_ || functionCount_ > 1) {
// all functions have been removed except for the two used to defer
// and wakeup.
while (!closed_ || functionCount_ > 2) {
uint32_t token;
consumer_->semWaitData();
auto ret = consumer_->copy(sizeof(token), &token);
if (ret == -ENODATA) {
std::this_thread::yield();
continue;
}
TP_DCHECK_NE(ret, -ENODATA);

TFunction fn;

Expand All @@ -133,6 +139,7 @@ void Reactor::run() {
}
TP_DCHECK(deferredFunctionList_.empty());
remove(deferredFunctionToken_);
remove(wakeUpToken_);
}

Reactor::Trigger::Trigger(Fd&& headerFd, Fd&& dataFd)
Expand Down
1 change: 1 addition & 0 deletions tensorpipe/transport/shm/reactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class Reactor final {
std::atomic<bool> closed_{false};
std::atomic<bool> joined_{false};

TToken wakeUpToken_;
TToken deferredFunctionToken_;
std::mutex deferredFunctionMutex_;
std::list<TDeferredFunction> deferredFunctionList_;
Expand Down
4 changes: 4 additions & 0 deletions tensorpipe/util/ringbuffer/consumer.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ class Consumer : public RingBufferWrapper {
return static_cast<ssize_t>(size);
}

void semWaitData() {
semWait_();
}

protected:
bool inTx_{false};

Expand Down
4 changes: 4 additions & 0 deletions tensorpipe/util/ringbuffer/producer.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ class Producer : public RingBufferWrapper {
return writeInTx(sizeof(T), &d);
}

void semPostData() {
semPost_();
}

[[nodiscard]] std::pair<ssize_t, void*> reserveContiguousInTx(
const size_t size) {
if (unlikely(size == 0)) {
Expand Down
25 changes: 25 additions & 0 deletions tensorpipe/util/ringbuffer/ringbuffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include <semaphore.h>
#include <sys/types.h>

#include <atomic>
Expand Down Expand Up @@ -76,6 +77,12 @@ class RingBufferHeader {
" buffer to ever be larger than what an int can hold";
in_write_tx.clear();
in_read_tx.clear();

// Init semaphore to work on shared memory.
int ret = sem_init(&sem_, 1, 0);
if (ret != 0) {
TP_THROW_SYSTEM(errno);
}
}

// Get size that is only guaranteed to be correct when producers and consumers
Expand Down Expand Up @@ -107,6 +114,14 @@ class RingBufferHeader {
atomicTail_.store(atomicHead_.load());
}

void semWait() {
sem_wait(&sem_);
}

void semPost() {
sem_post(&sem_);
}

// acquired by producers.
std::atomic_flag in_write_tx;
// acquired by consumers.
Expand All @@ -118,6 +133,8 @@ class RingBufferHeader {
// Written by consumer.
std::atomic<uint64_t> atomicTail_{0};

sem_t sem_;

// http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2007/n2427.html#atomics.lockfree
// static_assert(
// decltype(atomicHead_)::is_always_lock_free,
Expand Down Expand Up @@ -194,6 +211,14 @@ class RingBufferWrapper {
}

protected:
void semWait_() {
header_.semWait();
}

void semPost_() {
header_.semPost();
}

std::shared_ptr<RingBuffer> rb_;
RingBufferHeader& header_;
uint8_t* const data_;
Expand Down