From 28119ba89dcae23a966ab1fe9ed33b1dab063d5f Mon Sep 17 00:00:00 2001 From: hiyijian Date: Tue, 19 Jan 2016 11:57:42 +0800 Subject: [PATCH 1/3] Support asymmetric Dirichlet prior --- inference/infer.cpp | 6 +- inference/inferer.cpp | 3 +- src/alias_table.cpp | 156 +++++++++++++++++++-------------- src/alias_table.h | 31 +++++-- src/asym_alpha.cpp | 122 ++++++++++++++++++++++++++ src/asym_alpha.h | 60 +++++++++++++ src/common.cpp | 6 +- src/common.h | 10 ++- src/data_stream.cpp | 16 ++-- src/lightlda.cpp | 83 +++--------------- src/meta.cpp | 6 +- src/model.cpp | 198 ++++++++++++++++++++++++++++++++++++++++-- src/model.h | 40 +++++++-- src/sampler.cpp | 102 +++++++++++++++++----- src/sampler.h | 11 ++- src/trainer.cpp | 38 +++++++- src/trainer.h | 5 +- 17 files changed, 696 insertions(+), 197 deletions(-) create mode 100644 src/asym_alpha.cpp create mode 100644 src/asym_alpha.h diff --git a/inference/infer.cpp b/inference/infer.cpp index a0d3a60..6a8020d 100644 --- a/inference/infer.cpp +++ b/inference/infer.cpp @@ -28,8 +28,8 @@ namespace multiverso { namespace lightlda LocalModel* model = new LocalModel(&meta); model->Init(); //init document stream data_stream = CreateDataStream(); - //init documents - InitDocument(); + //init doc-topic + InitDocTopic(); //init alias table AliasTable* alias_table = new AliasTable(); //init inferers @@ -102,7 +102,7 @@ namespace multiverso { namespace lightlda return nullptr; } - static void InitDocument() + static void InitDocTopic() { xorshift_rng rng; for (int32_t block = 0; block < Config::num_blocks; ++block) diff --git a/inference/inferer.cpp b/inference/inferer.cpp index 2bd8f3b..64ff8d6 100644 --- a/inference/inferer.cpp +++ b/inference/inferer.cpp @@ -74,7 +74,8 @@ namespace multiverso { namespace lightlda for (int32_t doc_id = id_; doc_id < data.Size(); doc_id += thread_num_) { Document* doc = data.GetOneDoc(doc_id); - sampler_->SampleOneDoc(doc, 0, lastword, model_, alias_); + //TODO: Asymmeric prior + sampler_->SampleOneDoc(doc, 0, lastword, model_, alias_, nullptr); } } diff --git a/src/alias_table.cpp b/src/alias_table.cpp index 5b383bd..738acc3 100644 --- a/src/alias_table.cpp +++ b/src/alias_table.cpp @@ -10,13 +10,16 @@ #include #include +#define SAFE_DELETE(p) if((p)) { delete (p); (p) = nullptr; } + namespace multiverso { namespace lightlda { _THREAD_LOCAL std::vector* AliasTable::q_w_proportion_; - _THREAD_LOCAL std::vector* AliasTable::q_w_proportion_int_; - _THREAD_LOCAL std::vector>* AliasTable::L_; - _THREAD_LOCAL std::vector>* AliasTable::H_; + _THREAD_LOCAL std::vector* AliasMultinomialRNGInt::q_proportion_int_; + _THREAD_LOCAL std::vector>* AliasMultinomialRNGInt::L_; + _THREAD_LOCAL std::vector>* AliasMultinomialRNGInt::H_; + // -- AliasTable implement area --------------------------------- // AliasTable::AliasTable() { memory_size_ = Config::alias_capacity / sizeof(int32_t); @@ -25,6 +28,8 @@ namespace multiverso { namespace lightlda beta_ = Config::beta; beta_sum_ = beta_ * num_vocabs_; memory_block_ = new int32_t[memory_size_]; + + alias_rng_int_ = new AliasMultinomialRNGInt(num_topics_); beta_kv_vector_ = new int32_t[2 * num_topics_]; @@ -34,6 +39,7 @@ namespace multiverso { namespace lightlda AliasTable::~AliasTable() { + delete alias_rng_int_; delete[] memory_block_; delete[] beta_kv_vector_; } @@ -47,12 +53,6 @@ namespace multiverso { namespace lightlda { if (q_w_proportion_ == nullptr) q_w_proportion_ = new std::vector(num_topics_); - if (q_w_proportion_int_ == nullptr) - q_w_proportion_int_ = new std::vector(num_topics_); - if (L_ == nullptr) - L_ = new std::vector>(num_topics_); - if (H_ == nullptr) - H_ = new std::vector>(num_topics_); // Compute the proportion Row& summary_row = model->GetSummaryRow(); if (word == -1) // build alias row for beta @@ -63,8 +63,7 @@ namespace multiverso { namespace lightlda (*q_w_proportion_)[k] = beta_ / (summary_row.At(k) + beta_sum_); beta_mass_ += (*q_w_proportion_)[k]; } - AliasMultinomialRNG(num_topics_, beta_mass_, beta_height_, - beta_kv_vector_); + alias_rng_int_->Build(*q_w_proportion_, num_topics_, beta_mass_, beta_height_, beta_kv_vector_); } else // build alias row for word { @@ -105,7 +104,7 @@ namespace multiverso { namespace lightlda word_topic_row.NonzeroSize()); } } - AliasMultinomialRNG(size, mass_[word], height_[word], + alias_rng_int_->Build(*q_w_proportion_, size, mass_[word], height_[word], memory_block_ + word_entry.begin_offset); } return 0; @@ -118,62 +117,35 @@ namespace multiverso { namespace lightlda int32_t capacity = word_entry.capacity; if (word_entry.is_dense) { - auto sample = rng.rand(); - int32_t idx = sample / height_[word]; - if (capacity <= idx) idx = capacity - 1; - - int32_t* p = kv_vector + 2 * idx; - int32_t k = *p++; - int32_t v = *p; - int32_t m = -(sample < v); - return (idx & m) | (k & ~m); + return alias_rng_int_->Propose(rng, height_[word], kv_vector); } else { - auto sample = rng.rand_double() * (mass_[word] + beta_mass_); - if (sample < mass_[word]) - { - int32_t* idx_vector = kv_vector + 2 * word_entry.capacity; - auto n_kw_sample = rng.rand(); - int32_t idx = n_kw_sample / height_[word]; - if (capacity <= idx) idx = capacity - 1; - int32_t* p = kv_vector + 2 * idx; - int32_t k = *p++; - int32_t v = *p; - int32_t id = idx_vector[idx]; - int32_t m = -(n_kw_sample < v); - return (id & m) | (idx_vector[k] & ~m); - } - else - { - auto beta_sample = rng.rand(); - int32_t idx = beta_sample / beta_height_; - if (num_topics_ <= idx) idx = num_topics_ - 1; - int32_t* p = beta_kv_vector_ + 2 * idx; - int32_t k = *p++; - int32_t v = *p; - int32_t m = -(beta_sample < v); - return (idx & m) | (k & ~m); - } + return alias_rng_int_->Propose(rng, height_[word], beta_height_, + mass_[word], beta_mass_, + kv_vector, capacity, + beta_kv_vector_); } } void AliasTable::Clear() { - delete q_w_proportion_; - q_w_proportion_ = nullptr; - delete q_w_proportion_int_; - q_w_proportion_int_ = nullptr; - delete L_; - L_ = nullptr; - delete H_; - H_ = nullptr; + SAFE_DELETE(q_w_proportion_); + alias_rng_int_->Clear(); } + // -- AliasTable implement area --------------------------------- // - - void AliasTable::AliasMultinomialRNG(int32_t size, float mass, int32_t& height, - int32_t* kv_vector) + // -- AliasMultinomialRNGInt implement area --------------------------------- // + void AliasMultinomialRNGInt::Build(const std::vector& q_proportion, int32_t size, + float mass, int32_t & height, int32_t* kv_vector) { + if (q_proportion_int_ == nullptr) + q_proportion_int_ = new std::vector(size_); + if (L_ == nullptr) + L_ = new std::vector>(size_); + if (H_ == nullptr) + H_ = new std::vector>(size_); + int32_t mass_int = 0x7fffffff; int32_t a_int = mass_int / size; mass_int = a_int * size; @@ -181,10 +153,9 @@ namespace multiverso { namespace lightlda int64_t mass_sum = 0; for (int32_t i = 0; i < size; ++i) { - (*q_w_proportion_)[i] /= mass; - (*q_w_proportion_int_)[i] = - static_cast((*q_w_proportion_)[i] * mass_int); - mass_sum += (*q_w_proportion_int_)[i]; + (*q_proportion_int_)[i] = + static_cast(q_proportion[i] / mass * mass_int); + mass_sum += (*q_proportion_int_)[i]; } if (mass_sum > mass_int) { @@ -192,9 +163,9 @@ namespace multiverso { namespace lightlda int32_t id = 0; for (int32_t i = 0; i < more;) { - if ((*q_w_proportion_int_)[id] >= 1) + if ((*q_proportion_int_)[id] >= 1) { - --(*q_w_proportion_int_)[id]; + --(*q_proportion_int_)[id]; ++i; } id = (id + 1) % size; @@ -207,7 +178,7 @@ namespace multiverso { namespace lightlda int32_t id = 0; for (int32_t i = 0; i < more; ++i) { - ++(*q_w_proportion_int_)[id]; + ++(*q_proportion_int_)[id]; id = (id + 1) % size; } } @@ -221,7 +192,7 @@ namespace multiverso { namespace lightlda int32_t L_head = 0, L_tail = 0, H_head = 0, H_tail = 0; for (int32_t k = 0; k < size; ++k) { - int32_t val = (*q_w_proportion_int_)[k]; + int32_t val = (*q_proportion_int_)[k]; if (val < height) { (*L_)[L_tail].first = k; @@ -276,5 +247,60 @@ namespace multiverso { namespace lightlda ++H_head; } } + + void AliasMultinomialRNGInt::Clear() + { + SAFE_DELETE(q_proportion_int_); + SAFE_DELETE(L_); + SAFE_DELETE(H_); + } + + int32_t AliasMultinomialRNGInt::Propose(xorshift_rng& rng, int32_t height, + int32_t* kv_vector) + { + auto sample = rng.rand(); + int32_t idx = sample / height; + if (size_ <= idx) idx = size_ - 1; + + int32_t* p = kv_vector + 2 * idx; + int32_t k = *p++; + int32_t v = *p; + int32_t m = -(sample < v); + return (idx & m) | (k & ~m); + } + + int32_t AliasMultinomialRNGInt::Propose(xorshift_rng& rng, + int32_t height, int32_t height_sum, + float mass, float mass_sum, + int32_t* kv_vector, int32_t vsize, + int32_t* kv_vector_sum) + { + auto sample = rng.rand_double() * (mass + mass_sum); + if (sample < mass) + { + int32_t* idx_vector = kv_vector + 2 * vsize; + auto n_sample = rng.rand(); + int32_t idx = n_sample / height; + if (vsize <= idx) idx = vsize - 1; + int32_t* p = kv_vector + 2 * idx; + int32_t k = *p++; + int32_t v = *p; + int32_t id = idx_vector[idx]; + int32_t m = -(n_sample < v); + return (id & m) | (idx_vector[k] & ~m); + } + else + { + auto n_sample = rng.rand(); + int32_t idx = n_sample / height_sum; + if (size_ <= idx) idx = size_ - 1; + int32_t* p = kv_vector_sum + 2 * idx; + int32_t k = *p++; + int32_t v = *p; + int32_t m = -(n_sample < v); + return (idx & m) | (k & ~m); + } + } + // -- AliasMultinomialRNGInt implement area --------------------------------- // } // namespace lightlda } // namespace multiverso diff --git a/src/alias_table.h b/src/alias_table.h index 4646e64..d3eece8 100644 --- a/src/alias_table.h +++ b/src/alias_table.h @@ -23,6 +23,31 @@ namespace multiverso { namespace lightlda class xorshift_rng; class AliasTableIndex; + class AliasMultinomialRNGInt + { + public: + AliasMultinomialRNGInt(int32_t size): size_(size) {} + void Build(const std::vector& q_proportion, int32_t size, + float mass, int32_t & height, int32_t* kv_vector); + void Clear(); + + //for dense sampling + int32_t Propose(xorshift_rng& rng, int32_t height, int32_t* kv_vector); + //for sparse sampling + int32_t Propose(xorshift_rng& rng, + int32_t height, int32_t height_sum, + float mass, float mass_sum, + int32_t* kv_vector, int32_t vsize, + int32_t* kv_vector_sum); + + private: + int32_t size_; + // thread local storage used for building alias + _THREAD_LOCAL static std::vector* q_proportion_int_; + _THREAD_LOCAL static std::vector>* L_; + _THREAD_LOCAL static std::vector>* H_; + }; + /*! * \brief AliasTable is the storage for alias tables used for fast sampling * from lightlda word proposal distribution. It optimize memory usage @@ -56,8 +81,7 @@ namespace multiverso { namespace lightlda /*! \brief Clear the alias table */ void Clear(); private: - void AliasMultinomialRNG(int32_t size, float mass, int32_t& height, - int32_t* kv_vector); + AliasMultinomialRNGInt * alias_rng_int_; int* memory_block_; int64_t memory_size_; AliasTableIndex* table_index_; @@ -71,9 +95,6 @@ namespace multiverso { namespace lightlda // thread local storage used for building alias _THREAD_LOCAL static std::vector* q_w_proportion_; - _THREAD_LOCAL static std::vector* q_w_proportion_int_; - _THREAD_LOCAL static std::vector>* L_; - _THREAD_LOCAL static std::vector>* H_; int num_vocabs_; int num_topics_; diff --git a/src/asym_alpha.cpp b/src/asym_alpha.cpp new file mode 100644 index 0000000..1eb040e --- /dev/null +++ b/src/asym_alpha.cpp @@ -0,0 +1,122 @@ +#include "asym_alpha.h" +#include "alias_table.h" +#include "model.h" +#include "common.h" +#include + +namespace multiverso +{ +namespace lightlda +{ + AsymAlpha::AsymAlpha() : dirichlet_scale_(1.0), dirichlet_shape_(1.00001) + { + num_topic_ = Config::num_topics; + max_doc_length_ = kMaxDocLength; + num_alpha_iterations_ = Config::num_alpha_iterations; + alpha_sum_ = num_topic_ * Config::alpha; + non_zero_limit_.resize(num_topic_); + alpha_base_measure_.resize(num_topic_, Config::alpha); + kv_vector_ = new int32_t[2 * num_topic_]; + alias_rng_int_ = new AliasMultinomialRNGInt(num_topic_); + } + + AsymAlpha::~AsymAlpha() + { + delete [] kv_vector_; + delete alias_rng_int_; + } + + void AsymAlpha::LearnDirichletPrior(ModelBase * model) + { + float oldParametersK; + float currentDigamma; + float denominator; + int nonZeroLimit; + float parametersSum; + + // get the initial non_zero_limit_ + for (int k = 0; k < num_topic_; ++k) + { + non_zero_limit_[k] = 0; + Row& row = model->GetTopicFrequencyRow(k); + for (int i = 0; i < max_doc_length_; ++i) + { + if (row.At(i) > 0) + { + non_zero_limit_[k] = i; + } + } + } + + // get the initial atomic_alpha_sum_ + parametersSum = 0; + for (int k = 0; k < num_topic_; k++) + { + parametersSum += alpha_base_measure_[k]; + } + + Row& doc_length_row = model->GetDocLengthRow(); + + for (int iteration = 0; + iteration < num_alpha_iterations_; ++iteration) + { + // Calculate the denominator + denominator = 0; + currentDigamma = 0; + + // Iterate over the histogram: + for (int i = 1; i < max_doc_length_; i++) + { + currentDigamma += 1 / (parametersSum + i - 1); + denominator += doc_length_row.At(i) * currentDigamma; + } + // Bayesian estimation Part I + denominator -= 1 / dirichlet_scale_; + + // Calculate the individual parameters + parametersSum = 0; + + for (int k = 0; k < num_topic_; k++) + { + // What's the largest non-zero element in the histogram? + nonZeroLimit = non_zero_limit_[k]; + + oldParametersK = alpha_base_measure_[k]; + alpha_base_measure_[k] = 0; + currentDigamma = 0; + + Row& row = model->GetTopicFrequencyRow(k); + + for (int i = 1; i <= nonZeroLimit; i++) + { + currentDigamma += 1 / (oldParametersK + i - 1); + alpha_base_measure_[k] += row.At(i) * currentDigamma; + } + + // Bayesian estimation part II + alpha_base_measure_[k] = oldParametersK + * (alpha_base_measure_[k] + dirichlet_shape_) + / denominator; + parametersSum += alpha_base_measure_[k]; + } + } + alpha_sum_ = parametersSum; + } + + void AsymAlpha::BuildAlias() + { + alias_rng_int_->Build(alpha_base_measure_, num_topic_, + alpha_sum_, alpha_height_, kv_vector_); + } + + void AsymAlpha::Clear() + { + alias_rng_int_->Clear(); + } + + int32_t AsymAlpha::Next() + { + return alias_rng_int_->Propose(rng_, alpha_height_, kv_vector_); + } +} // namespace lightlda +} // namespace multiverso \ No newline at end of file diff --git a/src/asym_alpha.h b/src/asym_alpha.h new file mode 100644 index 0000000..76ac109 --- /dev/null +++ b/src/asym_alpha.h @@ -0,0 +1,60 @@ +/*! + * \file asym_alpha.h + * \brief Defines asymmetric prior alpha + */ + +#ifndef LIGHTLDA_ASYM_ALPHA_H_ +#define LIGHTLDA_ASYM_ALPHA_H_ + +#include +#include +#include "util.h" + +namespace multiverso +{ +namespace lightlda +{ + class ModelBase; + class AliasMultinomialRNGInt; + + class AsymAlpha + { + public: + AsymAlpha(); + ~AsymAlpha(); + void LearnDirichletPrior(ModelBase * model); + void BuildAlias(); + void Clear(); + int32_t Next(); + float At(int32_t idx) const; + float AlphaSum() const; + private: + int32_t num_topic_; + int32_t max_doc_length_; + int32_t num_alpha_iterations_; + float dirichlet_scale_; + float dirichlet_shape_; + float alpha_sum_; + int32_t alpha_height_; + int32_t* kv_vector_; + std::vector non_zero_limit_; + std::vector alpha_base_measure_; + xorshift_rng rng_; + AliasMultinomialRNGInt * alias_rng_int_; + }; + + // -- inline functions definition area --------------------------------- // + inline float AsymAlpha::At(int32_t idx) const + { + return alpha_base_measure_[idx]; + } + + inline float AsymAlpha::AlphaSum() const + { + return alpha_sum_; + } + // -- inline functions definition area --------------------------------- // +} // namespace lightlda +} // namespace multiverso + +#endif //LIGHTLDA_ASYM_ALPHA_H_ diff --git a/src/common.cpp b/src/common.cpp index 3e18535..4fb38d6 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -10,6 +10,7 @@ namespace multiverso { namespace lightlda int32_t Config::num_vocabs = -1; int32_t Config::num_topics = 100; int32_t Config::num_iterations = 100; + int32_t Config::num_alpha_iterations = 0; int32_t Config::mh_steps = 2; int32_t Config::num_servers = 1; int32_t Config::num_local_workers = 1; @@ -22,6 +23,7 @@ namespace multiverso { namespace lightlda std::string Config::input_dir = ""; bool Config::warm_start = false; bool Config::inference = false; + bool Config::asymmetric_prior = false; bool Config::out_of_core = false; int64_t Config::data_capacity = 1024 * kMB; int64_t Config::model_capacity = 512 * kMB; @@ -44,6 +46,7 @@ namespace multiverso { namespace lightlda if (strcmp(argv[i], "-num_vocabs") == 0) num_vocabs = atoi(argv[i + 1]); if (strcmp(argv[i], "-num_topics") == 0) num_topics = atoi(argv[i + 1]); if (strcmp(argv[i], "-num_iterations") == 0) num_iterations = atoi(argv[i + 1]); + if (strcmp(argv[i], "-num_alpha_iterations") == 0) num_alpha_iterations = atoi(argv[i + 1]); if (strcmp(argv[i], "-mh_steps") == 0) mh_steps = atoi(argv[i + 1]); if (strcmp(argv[i], "-num_servers") == 0) num_servers = atoi(argv[i + 1]); if (strcmp(argv[i], "-num_local_workers") == 0) num_local_workers = atoi(argv[i + 1]); @@ -59,7 +62,8 @@ namespace multiverso { namespace lightlda if (strcmp(argv[i], "-data_capacity") == 0) data_capacity = atoi(argv[i + 1]) * kMB; if (strcmp(argv[i], "-model_capacity") == 0) model_capacity = atoi(argv[i + 1]) * kMB; if (strcmp(argv[i], "-alias_capacity") == 0) alias_capacity = atoi(argv[i + 1]) * kMB; - if (strcmp(argv[i], "-delta_capacity") == 0) delta_capacity = atoi(argv[i + 1]) * kMB; + if (strcmp(argv[i], "-delta_capacity") == 0) delta_capacity = atoi(argv[i + 1]) * kMB; + if(num_alpha_iterations > 0) asymmetric_prior = true; } Check(); } diff --git a/src/common.h b/src/common.h index bfef1ed..f000169 100644 --- a/src/common.h +++ b/src/common.h @@ -16,6 +16,10 @@ namespace multiverso { namespace lightlda const int32_t kWordTopicTable = 0; /*! \brief constant variable for table id */ const int32_t kSummaryRow = 1; + /*! \brief constant variable for table id */ + const int32_t kTopicFrequencyTable = 2; + /*! \brief constant variable for table id */ + const int32_t kDocLengthRow = 3; /*! \brief load factor for sparse hash table */ const int32_t kLoadFactor = 2; /*! \brief max length of a document */ @@ -36,8 +40,10 @@ namespace multiverso { namespace lightlda static int32_t num_vocabs; /*! \brief number of topics */ static int32_t num_topics; - /*! \brief number of iterations for trainning */ + /*! \brief number of iterations */ static int32_t num_iterations; + /*! \brief number of estimating alpha iterations */ + static int32_t num_alpha_iterations; /*! \brief number of metropolis-hastings steps */ static int32_t mh_steps; /*! \brief number of servers for Multiverso setting */ @@ -62,6 +68,8 @@ namespace multiverso { namespace lightlda static bool warm_start; /*! \brief inference mode */ static bool inference; + /*! \brief asymmetric prior */ + static bool asymmetric_prior; /*! \brief option specity whether use out of core computation */ static bool out_of_core; /*! \brief memory capacity settings, for memory pools */ diff --git a/src/data_stream.cpp b/src/data_stream.cpp index 21675e9..0d46560 100644 --- a/src/data_stream.cpp +++ b/src/data_stream.cpp @@ -31,8 +31,8 @@ namespace multiverso { namespace lightlda { typedef DoubleBuffer DataBuffer; public: - DiskDataStream(int32_t num_blocks, std::string data_path, - int32_t num_iterations); + DiskDataStream(std::string data_path, + int32_t num_blocks, int32_t num_iterations); virtual ~DiskDataStream(); virtual void BeforeDataAccess() override; virtual void EndDataAccess() override; @@ -46,12 +46,12 @@ namespace multiverso { namespace lightlda DataBuffer* data_buffer_; /*! \brief current block id to be accessed */ int32_t block_id_; + /*! \brief data path */ + std::string data_path_; /*! \brief number of data blocks in disk */ int32_t num_blocks_; /*! \brief number of training iterations */ int32_t num_iterations_; - /*! \brief data path */ - std::string data_path_; /*! \brief backend thread for data preload */ std::thread preload_thread_; bool working_; @@ -95,9 +95,9 @@ namespace multiverso { namespace lightlda return *data_buffer_[index_]; } - DiskDataStream::DiskDataStream(int32_t num_blocks, - std::string data_path, int32_t num_iterations) : - num_blocks_(num_blocks), data_path_(data_path), + DiskDataStream::DiskDataStream(std::string data_path, + int32_t num_blocks, int32_t num_iterations) : + data_path_(data_path), num_blocks_(num_blocks), num_iterations_(num_iterations), working_(false) { block_id_ = 0; @@ -178,7 +178,7 @@ namespace multiverso { namespace lightlda { if (Config::out_of_core && Config::num_blocks != 1) { - return new DiskDataStream(Config::num_blocks, Config::input_dir, + return new DiskDataStream(Config::input_dir, Config::num_blocks, Config::num_iterations); } else diff --git a/src/lightlda.cpp b/src/lightlda.cpp index 199634d..d364fb0 100644 --- a/src/lightlda.cpp +++ b/src/lightlda.cpp @@ -1,10 +1,12 @@ #include "common.h" #include "trainer.h" #include "alias_table.h" +#include "asym_alpha.h" #include "data_stream.h" #include "data_block.h" #include "document.h" #include "meta.h" +#include "model.h" #include "util.h" #include #include @@ -22,12 +24,17 @@ namespace multiverso { namespace lightlda Config::Init(argc, argv); AliasTable* alias_table = new AliasTable(); + AsymAlpha* asym_alpha = nullptr; + if(Config::asymmetric_prior) + { + asym_alpha = new AsymAlpha(); + } Barrier* barrier = new Barrier(Config::num_local_workers); meta.Init(); std::vector trainers; for (int32_t i = 0; i < Config::num_local_workers; ++i) { - Trainer* trainer = new Trainer(alias_table, barrier, &meta); + Trainer* trainer = new Trainer(alias_table, asym_alpha, barrier, &meta); trainers.push_back(trainer); } @@ -43,7 +50,8 @@ namespace multiverso { namespace lightlda + std::to_string(clock()) + ".log"); data_stream = CreateDataStream(); - InitMultiverso(); + InitDocTopic(); + PSModel::Init(&meta, data_stream); Train(); Multiverso::Close(); @@ -59,6 +67,7 @@ namespace multiverso { namespace lightlda delete data_stream; delete barrier; delete alias_table; + if(Config::asymmetric_prior) delete asym_alpha; } private: static void Train() @@ -93,16 +102,7 @@ namespace multiverso { namespace lightlda Multiverso::EndTrain(); } - static void InitMultiverso() - { - Multiverso::BeginConfig(); - CreateTable(); - ConfigTable(); - Initialize(); - Multiverso::EndConfig(); - } - - static void Initialize() + static void InitDocTopic() { xorshift_rng rng; for (int32_t block = 0; block < Config::num_blocks; ++block) @@ -124,11 +124,6 @@ namespace multiverso { namespace lightlda // Init the latent variable if (!Config::warm_start) doc->SetTopic(cursor, rng.rand_k(Config::num_topics)); - // Init the server table - Multiverso::AddToServer(kWordTopicTable, - doc->Word(cursor), doc->Topic(cursor), 1); - Multiverso::AddToServer(kSummaryRow, - 0, doc->Topic(cursor), 1); } } Multiverso::Flush(); @@ -163,60 +158,6 @@ namespace multiverso { namespace lightlda } } - static void CreateTable() - { - int32_t num_vocabs = Config::num_vocabs; - int32_t num_topics = Config::num_topics; - Type int_type = Type::Int; - Type longlong_type = Type::LongLong; - multiverso::Format dense_format = multiverso::Format::Dense; - multiverso::Format sparse_format = multiverso::Format::Sparse; - - Multiverso::AddServerTable(kWordTopicTable, num_vocabs, - num_topics, int_type, dense_format); - Multiverso::AddCacheTable(kWordTopicTable, num_vocabs, - num_topics, int_type, dense_format, Config::model_capacity); - Multiverso::AddAggregatorTable(kWordTopicTable, num_vocabs, - num_topics, int_type, dense_format, Config::delta_capacity); - - Multiverso::AddTable(kSummaryRow, 1, Config::num_topics, - longlong_type, dense_format); - } - - static void ConfigTable() - { - multiverso::Format dense_format = multiverso::Format::Dense; - multiverso::Format sparse_format = multiverso::Format::Sparse; - for (int32_t word = 0; word < Config::num_vocabs; ++word) - { - if (meta.tf(word) > 0) - { - if (meta.tf(word) * kLoadFactor > Config::num_topics) - { - Multiverso::SetServerRow(kWordTopicTable, - word, dense_format, Config::num_topics); - Multiverso::SetCacheRow(kWordTopicTable, - word, dense_format, Config::num_topics); - } - else - { - Multiverso::SetServerRow(kWordTopicTable, - word, sparse_format, meta.tf(word) * kLoadFactor); - Multiverso::SetCacheRow(kWordTopicTable, - word, sparse_format, meta.tf(word) * kLoadFactor); - } - } - if (meta.local_tf(word) > 0) - { - if (meta.local_tf(word) * 2 * kLoadFactor > Config::num_topics) - Multiverso::SetAggregatorRow(kWordTopicTable, - word, dense_format, Config::num_topics); - else - Multiverso::SetAggregatorRow(kWordTopicTable, word, - sparse_format, meta.local_tf(word) * 2 * kLoadFactor); - } - } - } private: /*! \brief training data access */ static IDataStream* data_stream; diff --git a/src/meta.cpp b/src/meta.cpp index ca867ec..a34b786 100644 --- a/src/meta.cpp +++ b/src/meta.cpp @@ -7,7 +7,7 @@ namespace multiverso { namespace lightlda { LocalVocab::LocalVocab() - : num_slices_(0), own_memory_(false), vocabs_(nullptr), size_(0) + : num_slices_(0), vocabs_(nullptr), size_(0), own_memory_(false) {} LocalVocab::~LocalVocab() @@ -161,6 +161,10 @@ namespace multiverso { namespace lightlda { Log::Info("Actual Model capacity: %d MB, Alias capacity: %d MB, Delta capacity: %dMB\n", model_offset/1024/1024, alias_offset/1024/1024, delta_offset/1024/1024); + Log::Info("Actual asymmetric alpha capacity: %d MB, Alias capacity: %dMB, Delta capacity: %d MB\n", + Config::num_topics * kMaxDocLength * sizeof(int32_t)/1024/1024, + 2 * Config::num_topics * sizeof(int32_t)/1024/1024, + Config::num_topics * kMaxDocLength * sizeof(int32_t)/1024/1024); local_vocab.slice_index_.push_back(j); ++local_vocab.num_slices_; model_offset = model_size; diff --git a/src/model.cpp b/src/model.cpp index 1878d53..792d4c5 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -13,14 +13,20 @@ #include "meta.h" #include "trainer.h" +#include "data_stream.h" +#include "data_block.h" +#include "document.h" #include #include namespace multiverso { namespace lightlda { - LocalModel::LocalModel(Meta * meta) : word_topic_table_(nullptr), - summary_table_(nullptr), meta_(meta) + // -- local model implement area --------------------------------- // + LocalModel::LocalModel(Meta * meta) : + word_topic_table_(nullptr), summary_table_(nullptr), + topic_frequency_table_(nullptr), doc_length_table_(nullptr), + meta_(meta) { CreateTable(); } @@ -35,7 +41,7 @@ namespace multiverso { namespace lightlda int32_t num_vocabs = Config::num_vocabs; int32_t num_topics = Config::num_topics; multiverso::Format dense_format = multiverso::Format::Dense; - multiverso::Format sparse_format = multiverso::Format::Sparse; + //multiverso::Format sparse_format = multiverso::Format::Sparse; Type int_type = Type::Int; Type longlong_type = Type::LongLong; @@ -178,13 +184,13 @@ namespace multiverso { namespace lightlda model_file.close(); } - void LocalModel::AddWordTopicRow( + void LocalModel::AddWordTopic( integer_t word_id, integer_t topic_id, int32_t delta) { Log::Fatal("Not implemented yet\n"); } - void LocalModel::AddSummaryRow(integer_t topic_id, int64_t delta) + void LocalModel::AddSummary(integer_t topic_id, int64_t delta) { Log::Fatal("Not implemented yet\n"); } @@ -198,7 +204,166 @@ namespace multiverso { namespace lightlda { return *(static_cast*>(summary_table_->GetRow(0))); } + + Row& LocalModel::GetTopicFrequencyRow(integer_t topic_id) + { + return *(static_cast*>(topic_frequency_table_->GetRow(topic_id))); + } + Row& LocalModel::GetDocLengthRow() + { + return *(static_cast*>(doc_length_table_->GetRow(0))); + } + void LocalModel::AddTopicFrequency(integer_t topic_id, integer_t freq, + int32_t delta) + { + Log::Fatal("Not implemented yet\n"); + } + void LocalModel::AddDocLength(integer_t doc_len, int32_t delta) + { + Log::Fatal("Not implemented yet\n"); + } + // -- local model implement area --------------------------------- // + + // -- PS model implement area --------------------------------- // + void PSModel::Init(Meta* meta, IDataStream * data_stream) + { + Multiverso::BeginConfig(); + CreateTable(); + ConfigTable(meta); + LoadTable(meta, data_stream); + Multiverso::EndConfig(); + } + + void PSModel::CreateTable() + { + int32_t num_vocabs = Config::num_vocabs; + int32_t num_topics = Config::num_topics; + Type int_type = Type::Int; + Type longlong_type = Type::LongLong; + multiverso::Format dense_format = multiverso::Format::Dense; + //multiverso::Format sparse_format = multiverso::Format::Sparse; + + Multiverso::AddServerTable(kWordTopicTable, num_vocabs, + num_topics, int_type, dense_format); + Multiverso::AddCacheTable(kWordTopicTable, num_vocabs, + num_topics, int_type, dense_format, Config::model_capacity); + Multiverso::AddAggregatorTable(kWordTopicTable, num_vocabs, + num_topics, int_type, dense_format, Config::delta_capacity); + Multiverso::AddTable(kSummaryRow, 1, Config::num_topics, + longlong_type, dense_format); + + if(Config::asymmetric_prior) + { + Multiverso::AddServerTable(kTopicFrequencyTable, num_topics, + kMaxDocLength, int_type, dense_format); + Multiverso::AddCacheTable(kTopicFrequencyTable, num_topics, + kMaxDocLength, int_type, dense_format, + num_topics * kMaxDocLength * sizeof(int32_t)); + Multiverso::AddAggregatorTable(kTopicFrequencyTable, num_vocabs, + num_topics, int_type, dense_format, + num_topics * kMaxDocLength * sizeof(int32_t)); + Multiverso::AddTable(kDocLengthRow, 1, kMaxDocLength, + int_type, dense_format); + } + } + + void PSModel::ConfigTable(Meta* meta) + { + multiverso::Format dense_format = multiverso::Format::Dense; + multiverso::Format sparse_format = multiverso::Format::Sparse; + for (int32_t word = 0; word < Config::num_vocabs; ++word) + { + if (meta->tf(word) > 0) + { + if (meta->tf(word) * kLoadFactor > Config::num_topics) + { + Multiverso::SetServerRow(kWordTopicTable, + word, dense_format, Config::num_topics); + Multiverso::SetCacheRow(kWordTopicTable, + word, dense_format, Config::num_topics); + } + else + { + Multiverso::SetServerRow(kWordTopicTable, + word, sparse_format, meta->tf(word) * kLoadFactor); + Multiverso::SetCacheRow(kWordTopicTable, + word, sparse_format, meta->tf(word) * kLoadFactor); + } + } + if (meta->local_tf(word) > 0) + { + if (meta->local_tf(word) * 2 * kLoadFactor > Config::num_topics) + Multiverso::SetAggregatorRow(kWordTopicTable, + word, dense_format, Config::num_topics); + else + Multiverso::SetAggregatorRow(kWordTopicTable, word, + sparse_format, meta->local_tf(word) * 2 * kLoadFactor); + } + } + if(Config::asymmetric_prior) + { + for(int32_t topic = 0; topic < Config::num_topics; topic++) + { + Multiverso::SetServerRow(kTopicFrequencyTable, + topic, dense_format, kMaxDocLength); + Multiverso::SetCacheRow(kTopicFrequencyTable, + topic, dense_format, kMaxDocLength); + Multiverso::SetAggregatorRow(kTopicFrequencyTable, + topic, dense_format, kMaxDocLength); + } + } + } + void PSModel::LoadTable(Meta* meta, IDataStream * data_stream) + { + int32_t t, c; + std::unique_ptr> doc_topic_counter; + doc_topic_counter.reset(new Row(0, + multiverso::Format::Sparse, kMaxDocLength)); + for (int32_t block = 0; block < Config::num_blocks; ++block) + { + data_stream->BeforeDataAccess(); + DataBlock& data_block = data_stream->CurrDataBlock(); + int32_t num_slice = meta->local_vocab(block).num_slice(); + for (int32_t slice = 0; slice < num_slice; ++slice) + { + for (int32_t i = 0; i < data_block.Size(); ++i) + { + Document* doc = data_block.GetOneDoc(i); + int32_t& cursor = doc->Cursor(); + if (slice == 0) cursor = 0; + int32_t last_word = meta->local_vocab(block).LastWord(slice); + // Init the server table + for (; cursor < doc->Size(); ++cursor) + { + if (doc->Word(cursor) > last_word) break; + Multiverso::AddToServer(kWordTopicTable, + doc->Word(cursor), doc->Topic(cursor), 1); + Multiverso::AddToServer(kSummaryRow, + 0, doc->Topic(cursor), 1); + } + if(Config::asymmetric_prior && 0 == slice) + { + doc_topic_counter->Clear(); + doc->GetDocTopicVector(*doc_topic_counter); + Row::iterator iter = doc_topic_counter->Iterator(); + while (iter.HasNext()) + { + t = iter.Key(); + c = iter.Value(); + Multiverso::AddToServer(kTopicFrequencyTable, + t, c, 1); + iter.Next(); + } + Multiverso::AddToServer(kDocLengthRow, 0, doc->Size(), 1); + } + } + Multiverso::Flush(); + } + data_stream->EndDataAccess(); + } + } + Row& PSModel::GetWordTopicRow(integer_t word_id) { return trainer_->GetRow(kWordTopicTable, word_id); @@ -209,16 +374,35 @@ namespace multiverso { namespace lightlda return trainer_->GetRow(kSummaryRow, 0); } - void PSModel::AddWordTopicRow( + void PSModel::AddWordTopic( integer_t word_id, integer_t topic_id, int32_t delta) { trainer_->Add(kWordTopicTable, word_id, topic_id, delta); } - void PSModel::AddSummaryRow(integer_t topic_id, int64_t delta) + void PSModel::AddSummary(integer_t topic_id, int64_t delta) { trainer_->Add(kSummaryRow, 0, topic_id, delta); } + Row& PSModel::GetTopicFrequencyRow(integer_t topic_id) + { + return trainer_->GetRow(kTopicFrequencyTable, topic_id); + } + Row& PSModel::GetDocLengthRow() + { + return trainer_->GetRow(kDocLengthRow, 0); + } + void PSModel::AddTopicFrequency(integer_t topic_id, integer_t freq, + int32_t delta) + { + trainer_->Add(kTopicFrequencyTable, topic_id, freq, delta); + } + void PSModel::AddDocLength(integer_t doc_len, int32_t delta) + { + trainer_->Add(kSummaryRow, 0, doc_len, delta); + } + // -- PS model implement area --------------------------------- // + } // namespace lightlda } // namespace multiverso diff --git a/src/model.h b/src/model.h index 2e71acf..bc678e6 100644 --- a/src/model.h +++ b/src/model.h @@ -21,6 +21,7 @@ namespace lightlda { class Meta; class Trainer; + class IDataStream; /*! \brief interface for acceess to model */ class ModelBase @@ -29,9 +30,15 @@ namespace lightlda virtual ~ModelBase() {} virtual Row& GetWordTopicRow(integer_t word_id) = 0; virtual Row& GetSummaryRow() = 0; - virtual void AddWordTopicRow(integer_t word_id, integer_t topic_id, + virtual void AddWordTopic(integer_t word_id, integer_t topic_id, int32_t delta) = 0; - virtual void AddSummaryRow(integer_t topic_id, int64_t delta) = 0; + virtual void AddSummary(integer_t topic_id, int64_t delta) = 0; + + virtual Row& GetTopicFrequencyRow(integer_t topic_id) = 0; + virtual Row& GetDocLengthRow() = 0; + virtual void AddTopicFrequency(integer_t topic_id, integer_t freq, + int32_t delta) = 0; + virtual void AddDocLength(integer_t doc_len, int32_t delta) = 0; }; /*! \brief model based on local buffer */ @@ -43,9 +50,15 @@ namespace lightlda Row& GetWordTopicRow(integer_t word_id) override; Row& GetSummaryRow() override; - void AddWordTopicRow(integer_t word_id, integer_t topic_id, + void AddWordTopic(integer_t word_id, integer_t topic_id, + int32_t delta) override; + void AddSummary(integer_t topic_id, int64_t delta) override; + + Row& GetTopicFrequencyRow(integer_t topic_id) override; + Row& GetDocLengthRow() override; + void AddTopicFrequency(integer_t topic_id, integer_t freq, int32_t delta) override; - void AddSummaryRow(integer_t topic_id, int64_t delta) override; + void AddDocLength(integer_t doc_len, int32_t delta) override; private: void CreateTable(); @@ -55,6 +68,8 @@ namespace lightlda std::unique_ptr word_topic_table_; std::unique_ptr
summary_table_; + std::unique_ptr
topic_frequency_table_; + std::unique_ptr
doc_length_table_; Meta* meta_; LocalModel(const LocalModel&) = delete; @@ -64,14 +79,27 @@ namespace lightlda /*! \brief model based on parameter server */ class PSModel : public ModelBase { + public: + static void Init(Meta* meta, IDataStream * data_stream); + private: + static void CreateTable(); + static void ConfigTable(Meta* meta); + static void LoadTable(Meta* meta, IDataStream * data_stream); + public: explicit PSModel(Trainer* trainer) : trainer_(trainer) {} Row& GetWordTopicRow(integer_t word_id) override; Row& GetSummaryRow() override; - void AddWordTopicRow(integer_t word_id, integer_t topic_id, + void AddWordTopic(integer_t word_id, integer_t topic_id, + int32_t delta) override; + void AddSummary(integer_t topic_id, int64_t delta) override; + + Row& GetTopicFrequencyRow(integer_t topic_id) override; + Row& GetDocLengthRow() override; + void AddTopicFrequency(integer_t topic_id, integer_t freq, int32_t delta) override; - void AddSummaryRow(integer_t topic_id, int64_t delta) override; + void AddDocLength(integer_t doc_len, int32_t delta) override; private: Trainer* trainer_; diff --git a/src/sampler.cpp b/src/sampler.cpp index 7af201f..07fde4a 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -1,12 +1,14 @@ #include "sampler.h" #include "alias_table.h" +#include "asym_alpha.h" #include "common.h" #include "document.h" #include "model.h" #include #include +#include namespace multiverso { namespace lightlda { @@ -28,7 +30,8 @@ namespace multiverso { namespace lightlda } int32_t LightDocSampler::SampleOneDoc(Document* doc, int32_t slice, - int32_t lastword, ModelBase* model, AliasTable* alias) + int32_t lastword, ModelBase* model, AliasTable* alias, + AsymAlpha* asym_alpha) { DocInit(doc); int32_t num_tokens = 0; @@ -40,7 +43,7 @@ namespace multiverso { namespace lightlda if (word > lastword) break; int32_t old_topic = doc->Topic(cursor); int32_t new_topic = Sample(doc, word, old_topic, old_topic, - model, alias); + model, alias, asym_alpha); if (old_topic != new_topic) { doc->SetTopic(cursor, new_topic); @@ -48,10 +51,21 @@ namespace multiverso { namespace lightlda doc_topic_counter_->Add(new_topic, 1); if(!Config::inference) { - model->AddWordTopicRow(word, old_topic, -1); - model->AddSummaryRow(old_topic, -1); - model->AddWordTopicRow(word, new_topic, 1); - model->AddSummaryRow(new_topic, 1); + model->AddWordTopic(word, old_topic, -1); + model->AddSummary(old_topic, -1); + model->AddWordTopic(word, new_topic, 1); + model->AddSummary(new_topic, 1); + if(Config::asymmetric_prior) + { + int32_t old_freq = doc_topic_counter_->At(old_topic) + 1; + int32_t new_freq = doc_topic_counter_->At(new_topic); + model->AddTopicFrequency(old_topic, old_freq, -1); + if(new_freq - 1 > 0) + { + model->AddTopicFrequency(new_topic, new_freq - 1, -1); + } + model->AddTopicFrequency(new_topic, new_freq, 1); + } } } ++num_tokens; @@ -67,12 +81,14 @@ namespace multiverso { namespace lightlda int32_t LightDocSampler::Sample(Document* doc, int32_t word, int32_t old_topic, int32_t s, - ModelBase* model, AliasTable* alias) + ModelBase* model, AliasTable* alias, + AsymAlpha* asym_alpha) { int32_t t, w_t_cnt, w_s_cnt; int64_t n_t, n_s; float n_td_alpha, n_sd_alpha; float n_tw_beta, n_sw_beta, n_t_beta_sum, n_s_beta_sum; + double n_td_or_alpha; float proposal_t, proposal_s; float nominator, denominator; double rejection, pi; @@ -98,8 +114,16 @@ namespace multiverso { namespace lightlda n_t = summary_row.At(t); n_s = summary_row.At(s); - n_td_alpha = doc_topic_counter_->At(t) + alpha_; - n_sd_alpha = doc_topic_counter_->At(s) + alpha_; + if(asym_alpha) + { + n_td_alpha = doc_topic_counter_->At(t) + asym_alpha->At(t); + n_sd_alpha = doc_topic_counter_->At(s) + asym_alpha->At(s); + } + else + { + n_td_alpha = doc_topic_counter_->At(t) + alpha_; + n_sd_alpha = doc_topic_counter_->At(s) + alpha_; + } n_tw_beta = w_t_cnt + beta_; n_t_beta_sum = n_t + beta_sum_; n_sw_beta = w_s_cnt + beta_; @@ -129,8 +153,14 @@ namespace multiverso { namespace lightlda s = (t & m) | (s & ~m); } // Doc proposal - double n_td_or_alpha = rng_.rand_double() * - (doc->Size() + alpha_sum_); + if(asym_alpha) + { + n_td_or_alpha = rng_.rand_double() * (doc->Size() + asym_alpha->AlphaSum()); + } + else + { + n_td_or_alpha = rng_.rand_double() * (doc->Size() + alpha_sum_); + } if (n_td_or_alpha < doc->Size()) { int32_t t_idx = static_cast(n_td_or_alpha); @@ -138,7 +168,14 @@ namespace multiverso { namespace lightlda } else { - t = rng_.rand_k(num_topic_); + if(asym_alpha) + { + t = asym_alpha->Next(); + } + else + { + t = rng_.rand_k(num_topic_); + } } if (t != s) { @@ -149,8 +186,19 @@ namespace multiverso { namespace lightlda n_t = summary_row.At(t); n_s = summary_row.At(s); - n_td_alpha = doc_topic_counter_->At(t) + alpha_; - n_sd_alpha = doc_topic_counter_->At(s) + alpha_; + if(asym_alpha) + { + n_td_alpha = doc_topic_counter_->At(t) + asym_alpha->At(t); + n_sd_alpha = doc_topic_counter_->At(s) + asym_alpha->At(s); + } + else + { + n_td_alpha = doc_topic_counter_->At(t) + alpha_; + n_sd_alpha = doc_topic_counter_->At(s) + alpha_; + } + proposal_t = n_td_alpha; + proposal_s = n_sd_alpha; + n_tw_beta = w_t_cnt + beta_; n_t_beta_sum = n_t + beta_sum_; n_sw_beta = w_s_cnt + beta_; @@ -169,9 +217,6 @@ namespace multiverso { namespace lightlda } - proposal_s = (doc_topic_counter_->At(s) + alpha_); - proposal_t = (doc_topic_counter_->At(t) + alpha_); - nominator = n_td_alpha * n_tw_beta * n_s_beta_sum * proposal_s; denominator = n_sd_alpha * n_sw_beta * n_t_beta_sum * proposal_t; @@ -186,7 +231,7 @@ namespace multiverso { namespace lightlda int32_t LightDocSampler::ApproxSample(Document* doc, int32_t word, int32_t old_topic, int32_t s, - ModelBase* model, AliasTable* alias) + ModelBase* model, AliasTable* alias, AsymAlpha* asym_alpha) { float n_tw_beta, n_sw_beta, n_t_beta_sum, n_s_beta_sum; float nominator, denominator; @@ -202,8 +247,16 @@ namespace multiverso { namespace lightlda t = alias->Propose(word, rng_); if (t != s) { - nominator = doc_topic_counter_->At(t) + alpha_; - denominator = doc_topic_counter_->At(s) + alpha_; + if(asym_alpha) + { + nominator = doc_topic_counter_->At(t) + asym_alpha->At(t); + denominator = doc_topic_counter_->At(s) + asym_alpha->At(s); + } + else + { + nominator = doc_topic_counter_->At(t) + alpha_; + denominator = doc_topic_counter_->At(s) + alpha_; + } if (t == old_topic) { nominator -= 1; @@ -227,7 +280,14 @@ namespace multiverso { namespace lightlda } else { - t = rng_.rand_k(num_topic_); + if(asym_alpha) + { + t = asym_alpha->Next(); + } + else + { + t = rng_.rand_k(num_topic_); + } } if (t != s) { diff --git a/src/sampler.h b/src/sampler.h index 404dd6e..789702a 100644 --- a/src/sampler.h +++ b/src/sampler.h @@ -18,6 +18,7 @@ namespace multiverso namespace multiverso { namespace lightlda { class AliasTable; + class AsymAlpha; class Document; class ModelBase; @@ -34,10 +35,11 @@ namespace multiverso { namespace lightlda * \param lastword last word of current slice * \param model pointer model, for access of model * \param alias pointer to alias table, for access of alias + * \param asym_alpha asym alpha prior provider * \return number of sampled token */ int32_t SampleOneDoc(Document* doc, int32_t slice, int32_t lastword, - ModelBase* model, AliasTable* alias); + ModelBase* model, AliasTable* alias, AsymAlpha* asym_alpha); /*! * \brief Get doc-topic-counter, for reusing this container * \return reference to light hash map @@ -57,9 +59,11 @@ namespace multiverso { namespace lightlda * \param old_topic old topic assignment of this token * \param model access * \param alias for alias table access + * \param asym_alpha asym alpha prior provider */ int32_t Sample(Document* doc, int32_t word, int32_t state, - int32_t old_topic, ModelBase* model, AliasTable* alias); + int32_t old_topic, ModelBase* model, AliasTable* alias, + AsymAlpha* asym_alpha); /*! * \brief Sample the latent topic assignment for a token. This function @@ -69,7 +73,8 @@ namespace multiverso { namespace lightlda * \param same with Sample */ int32_t ApproxSample(Document* doc, int32_t word, int32_t state, - int32_t old_topic, ModelBase* model, AliasTable* alias); + int32_t old_topic, ModelBase* model, AliasTable* alias, + AsymAlpha* asym_alpha); private: // lda hyper-parameter float alpha_; diff --git a/src/trainer.cpp b/src/trainer.cpp index eca707e..7d25d80 100644 --- a/src/trainer.cpp +++ b/src/trainer.cpp @@ -1,6 +1,7 @@ #include "trainer.h" #include "alias_table.h" +#include "asym_alpha.h" #include "common.h" #include "data_block.h" #include "eval.h" @@ -19,8 +20,10 @@ namespace multiverso { namespace lightlda double Trainer::word_llh_ = 0.0; Trainer::Trainer(AliasTable* alias_table, + AsymAlpha* asym_alpha, Barrier* barrier, Meta* meta) : - alias_(alias_table), barrier_(barrier), meta_(meta), + alias_(alias_table), asym_alpha_(asym_alpha), + barrier_(barrier), meta_(meta), model_(nullptr) { sampler_ = new LightDocSampler(); @@ -71,13 +74,27 @@ namespace multiverso { namespace lightlda Log::Info("Rank = %d, Alias Time used: %.2f s \n", Multiverso::ProcessRank(), watch.ElapsedSeconds()); } + // Learn alpha and build Alias table + if(asym_alpha_!= nullptr && + iter % 5 == 0 && + id == 0) + { + watch.Restart(); + asym_alpha_->LearnDirichletPrior(model_); + asym_alpha_->BuildAlias(); + Log::Info("Rank = %d, AsymAplha Time used: %.2f s \n", + Multiverso::ProcessRank(), watch.ElapsedSeconds()); + } + barrier_->Wait(); + int32_t num_token = 0; watch.Restart(); // Train with lightlda sampler for (int32_t doc_id = id; doc_id < data.Size(); doc_id += trainer_num) { Document* doc = data.GetOneDoc(doc_id); - num_token += sampler_->SampleOneDoc(doc, slice, lastword, model_, alias_); + num_token += sampler_->SampleOneDoc(doc, slice, lastword, model_, alias_, + asym_alpha_); } if (TrainerId() == 0) { @@ -100,7 +117,11 @@ namespace multiverso { namespace lightlda // if (iter != 0 && iter % 50 == 0) Dump(iter, lda_data_block); // Clear the thread information in alias table - if (iter == Config::num_iterations - 1) alias_->Clear(); + if (iter == Config::num_iterations - 1) + { + alias_->Clear(); + if(asym_alpha_) asym_alpha_->Clear(); + } } void Trainer::Evaluate(LDADataBlock* lda_data_block) @@ -203,6 +224,17 @@ namespace multiverso { namespace lightlda *local_vocab.begin(slice), *(local_vocab.end(slice) - 1)); // Request summary-row RequestTable(kSummaryRow); + + if(Config::asymmetric_prior) + { + // Request topic-frequency-table + for(int32_t topic = 0; topic < Config::num_topics; topic++) + { + RequestRow(kTopicFrequencyTable, topic); + } + // Request doc-length-row + RequestTable(kDocLengthRow); + } } } // namespace lightlda } // namespace multiverso diff --git a/src/trainer.h b/src/trainer.h index 2b483fd..cc2f1f5 100644 --- a/src/trainer.h +++ b/src/trainer.h @@ -14,6 +14,7 @@ namespace multiverso { namespace lightlda { class AliasTable; + class AsymAlpha; class LDADataBlock; class LightDocSampler; class Meta; @@ -23,7 +24,7 @@ namespace multiverso { namespace lightlda class Trainer : public TrainerBase { public: - Trainer(AliasTable* alias, Barrier* barrier, Meta* meta); + Trainer(AliasTable* alias, AsymAlpha* asym_alpha, Barrier* barrier, Meta* meta); ~Trainer(); /*! * \brief Defines Trainning method for a data_block in one iteration @@ -41,6 +42,8 @@ namespace multiverso { namespace lightlda private: /*! \brief alias table, for alias access */ AliasTable* alias_; + /*! \brief asym alpha */ + AsymAlpha* asym_alpha_; /*! \brief sampler for lightlda */ LightDocSampler* sampler_; /*! \brief barrier for thread-sync */ From 7515493680b643a052527be7436583719e6d2f24 Mon Sep 17 00:00:00 2001 From: hiyijian Date: Wed, 20 Jan 2016 14:18:56 +0800 Subject: [PATCH 2/3] fix bug for updating topic-frequency-table --- src/asym_alpha.cpp | 4 ++-- src/meta.cpp | 4 ++-- src/model.cpp | 20 ++++++++++---------- src/sampler.cpp | 10 +++++++--- src/trainer.cpp | 10 +++++++--- 5 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/asym_alpha.cpp b/src/asym_alpha.cpp index 1eb040e..c0ef946 100644 --- a/src/asym_alpha.cpp +++ b/src/asym_alpha.cpp @@ -39,7 +39,7 @@ namespace lightlda { non_zero_limit_[k] = 0; Row& row = model->GetTopicFrequencyRow(k); - for (int i = 0; i < max_doc_length_; ++i) + for (int i = 1; i <= max_doc_length_; ++i) { if (row.At(i) > 0) { @@ -65,7 +65,7 @@ namespace lightlda currentDigamma = 0; // Iterate over the histogram: - for (int i = 1; i < max_doc_length_; i++) + for (int i = 1; i <= max_doc_length_; i++) { currentDigamma += 1 / (parametersSum + i - 1); denominator += doc_length_row.At(i) * currentDigamma; diff --git a/src/meta.cpp b/src/meta.cpp index a34b786..46e593b 100644 --- a/src/meta.cpp +++ b/src/meta.cpp @@ -162,9 +162,9 @@ namespace multiverso { namespace lightlda Log::Info("Actual Model capacity: %d MB, Alias capacity: %d MB, Delta capacity: %dMB\n", model_offset/1024/1024, alias_offset/1024/1024, delta_offset/1024/1024); Log::Info("Actual asymmetric alpha capacity: %d MB, Alias capacity: %dMB, Delta capacity: %d MB\n", - Config::num_topics * kMaxDocLength * sizeof(int32_t)/1024/1024, + Config::num_topics * (kMaxDocLength + 1) * sizeof(int32_t)/1024/1024, 2 * Config::num_topics * sizeof(int32_t)/1024/1024, - Config::num_topics * kMaxDocLength * sizeof(int32_t)/1024/1024); + Config::num_topics * (kMaxDocLength + 1) * sizeof(int32_t)/1024/1024); local_vocab.slice_index_.push_back(j); ++local_vocab.num_slices_; model_offset = model_size; diff --git a/src/model.cpp b/src/model.cpp index 792d4c5..ceab6ee 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -255,14 +255,14 @@ namespace multiverso { namespace lightlda if(Config::asymmetric_prior) { Multiverso::AddServerTable(kTopicFrequencyTable, num_topics, - kMaxDocLength, int_type, dense_format); + kMaxDocLength + 1, int_type, dense_format); Multiverso::AddCacheTable(kTopicFrequencyTable, num_topics, - kMaxDocLength, int_type, dense_format, - num_topics * kMaxDocLength * sizeof(int32_t)); - Multiverso::AddAggregatorTable(kTopicFrequencyTable, num_vocabs, - num_topics, int_type, dense_format, - num_topics * kMaxDocLength * sizeof(int32_t)); - Multiverso::AddTable(kDocLengthRow, 1, kMaxDocLength, + kMaxDocLength + 1, int_type, dense_format, + num_topics * (kMaxDocLength + 1) * sizeof(int32_t)); + Multiverso::AddAggregatorTable(kTopicFrequencyTable, num_topics, + kMaxDocLength + 1, int_type, dense_format, + num_topics * (kMaxDocLength + 1) * sizeof(int32_t)); + Multiverso::AddTable(kDocLengthRow, 1, (kMaxDocLength + 1), int_type, dense_format); } } @@ -305,11 +305,11 @@ namespace multiverso { namespace lightlda for(int32_t topic = 0; topic < Config::num_topics; topic++) { Multiverso::SetServerRow(kTopicFrequencyTable, - topic, dense_format, kMaxDocLength); + topic, dense_format, kMaxDocLength + 1); Multiverso::SetCacheRow(kTopicFrequencyTable, - topic, dense_format, kMaxDocLength); + topic, dense_format, kMaxDocLength + 1); Multiverso::SetAggregatorRow(kTopicFrequencyTable, - topic, dense_format, kMaxDocLength); + topic, dense_format, kMaxDocLength + 1); } } } diff --git a/src/sampler.cpp b/src/sampler.cpp index 07fde4a..c1849aa 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -57,14 +57,18 @@ namespace multiverso { namespace lightlda model->AddSummary(new_topic, 1); if(Config::asymmetric_prior) { - int32_t old_freq = doc_topic_counter_->At(old_topic) + 1; + int32_t old_freq = doc_topic_counter_->At(old_topic); int32_t new_freq = doc_topic_counter_->At(new_topic); - model->AddTopicFrequency(old_topic, old_freq, -1); + model->AddTopicFrequency(old_topic, old_freq + 1, -1); + if(old_freq > 0) + { + model->AddTopicFrequency(old_topic, old_freq, 1); + } + model->AddTopicFrequency(new_topic, new_freq, 1); if(new_freq - 1 > 0) { model->AddTopicFrequency(new_topic, new_freq - 1, -1); } - model->AddTopicFrequency(new_topic, new_freq, 1); } } } diff --git a/src/trainer.cpp b/src/trainer.cpp index 7d25d80..6211bd6 100644 --- a/src/trainer.cpp +++ b/src/trainer.cpp @@ -220,20 +220,24 @@ namespace multiverso { namespace lightlda { RequestRow(kWordTopicTable, *p); } - Log::Debug("Request params. start = %d, end = %d\n", + Log::Debug("Request word-topic-table. start = %d, end = %d\n", *local_vocab.begin(slice), *(local_vocab.end(slice) - 1)); // Request summary-row RequestTable(kSummaryRow); + Log::Debug("Request summary-row\n"); if(Config::asymmetric_prior) { // Request topic-frequency-table - for(int32_t topic = 0; topic < Config::num_topics; topic++) + RequestTable(kTopicFrequencyTable); + Log::Debug("Request topic-frequency-table\n"); + /*for(int32_t topic = 0; topic < Config::num_topics; topic++) { RequestRow(kTopicFrequencyTable, topic); - } + }*/ // Request doc-length-row RequestTable(kDocLengthRow); + Log::Debug("Request doc-length-row\n"); } } } // namespace lightlda From e7e86db794e4e90bd1ae520ace06a7b0c25e846c Mon Sep 17 00:00:00 2001 From: hiyijian Date: Tue, 26 Jan 2016 14:29:26 +0800 Subject: [PATCH 3/3] some refactor --- src/alias_table.cpp | 33 ++++++++++++++++++++++++++------- src/alias_table.h | 2 +- src/asym_alpha.cpp | 7 +------ src/asym_alpha.h | 1 - src/common.cpp | 15 ++++++++++++--- src/common.h | 4 +++- src/lightlda.cpp | 1 - src/meta.cpp | 11 +++++++---- src/trainer.cpp | 11 ++++------- 9 files changed, 54 insertions(+), 31 deletions(-) diff --git a/src/alias_table.cpp b/src/alias_table.cpp index 738acc3..0f1dfe8 100644 --- a/src/alias_table.cpp +++ b/src/alias_table.cpp @@ -10,8 +10,6 @@ #include #include -#define SAFE_DELETE(p) if((p)) { delete (p); (p) = nullptr; } - namespace multiverso { namespace lightlda { _THREAD_LOCAL std::vector* AliasTable::q_w_proportion_; @@ -130,8 +128,8 @@ namespace multiverso { namespace lightlda void AliasTable::Clear() { - SAFE_DELETE(q_w_proportion_); - alias_rng_int_->Clear(); + delete q_w_proportion_; + q_w_proportion_ = nullptr; } // -- AliasTable implement area --------------------------------- // @@ -140,11 +138,29 @@ namespace multiverso { namespace lightlda float mass, int32_t & height, int32_t* kv_vector) { if (q_proportion_int_ == nullptr) + { q_proportion_int_ = new std::vector(size_); + } + else if(q_proportion_int_->size() != size_) + { + q_proportion_int_->resize(size_); + } if (L_ == nullptr) + { L_ = new std::vector>(size_); + } + else if(L_->size() != size_) + { + L_->resize(size_); + } if (H_ == nullptr) + { H_ = new std::vector>(size_); + } + else if(H_->size() != size_) + { + H_->resize(size_); + } int32_t mass_int = 0x7fffffff; int32_t a_int = mass_int / size; @@ -250,9 +266,12 @@ namespace multiverso { namespace lightlda void AliasMultinomialRNGInt::Clear() { - SAFE_DELETE(q_proportion_int_); - SAFE_DELETE(L_); - SAFE_DELETE(H_); + delete q_proportion_int_; + q_proportion_int_ = nullptr; + delete L_; + L_ = nullptr; + delete H_; + H_ = nullptr; } int32_t AliasMultinomialRNGInt::Propose(xorshift_rng& rng, int32_t height, diff --git a/src/alias_table.h b/src/alias_table.h index d3eece8..adf3b66 100644 --- a/src/alias_table.h +++ b/src/alias_table.h @@ -29,7 +29,7 @@ namespace multiverso { namespace lightlda AliasMultinomialRNGInt(int32_t size): size_(size) {} void Build(const std::vector& q_proportion, int32_t size, float mass, int32_t & height, int32_t* kv_vector); - void Clear(); + static void Clear(); //for dense sampling int32_t Propose(xorshift_rng& rng, int32_t height, int32_t* kv_vector); diff --git a/src/asym_alpha.cpp b/src/asym_alpha.cpp index c0ef946..d0a04aa 100644 --- a/src/asym_alpha.cpp +++ b/src/asym_alpha.cpp @@ -109,14 +109,9 @@ namespace lightlda alpha_sum_, alpha_height_, kv_vector_); } - void AsymAlpha::Clear() - { - alias_rng_int_->Clear(); - } - int32_t AsymAlpha::Next() { return alias_rng_int_->Propose(rng_, alpha_height_, kv_vector_); } } // namespace lightlda -} // namespace multiverso \ No newline at end of file +} // namespace multiverso diff --git a/src/asym_alpha.h b/src/asym_alpha.h index 76ac109..ea0c600 100644 --- a/src/asym_alpha.h +++ b/src/asym_alpha.h @@ -24,7 +24,6 @@ namespace lightlda ~AsymAlpha(); void LearnDirichletPrior(ModelBase * model); void BuildAlias(); - void Clear(); int32_t Next(); float At(int32_t idx) const; float AlphaSum() const; diff --git a/src/common.cpp b/src/common.cpp index 4fb38d6..245d2e1 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -10,7 +10,8 @@ namespace multiverso { namespace lightlda int32_t Config::num_vocabs = -1; int32_t Config::num_topics = 100; int32_t Config::num_iterations = 100; - int32_t Config::num_alpha_iterations = 0; + int32_t Config::num_alpha_iterations = 100; + int32_t Config::learn_alpha_every = 5; int32_t Config::mh_steps = 2; int32_t Config::num_servers = 1; int32_t Config::num_local_workers = 1; @@ -47,6 +48,7 @@ namespace multiverso { namespace lightlda if (strcmp(argv[i], "-num_topics") == 0) num_topics = atoi(argv[i + 1]); if (strcmp(argv[i], "-num_iterations") == 0) num_iterations = atoi(argv[i + 1]); if (strcmp(argv[i], "-num_alpha_iterations") == 0) num_alpha_iterations = atoi(argv[i + 1]); + if (strcmp(argv[i], "-learn_alpha_every") == 0) learn_alpha_every = atoi(argv[i + 1]); if (strcmp(argv[i], "-mh_steps") == 0) mh_steps = atoi(argv[i + 1]); if (strcmp(argv[i], "-num_servers") == 0) num_servers = atoi(argv[i + 1]); if (strcmp(argv[i], "-num_local_workers") == 0) num_local_workers = atoi(argv[i + 1]); @@ -59,11 +61,11 @@ namespace multiverso { namespace lightlda if (strcmp(argv[i], "-server_file") == 0) server_file = std::string(argv[i + 1]); if (strcmp(argv[i], "-warm_start") == 0) warm_start = true; if (strcmp(argv[i], "-out_of_core") == 0) out_of_core = true; + if (strcmp(argv[i], "-asymmetric_prior") == 0) asymmetric_prior = true; if (strcmp(argv[i], "-data_capacity") == 0) data_capacity = atoi(argv[i + 1]) * kMB; if (strcmp(argv[i], "-model_capacity") == 0) model_capacity = atoi(argv[i + 1]) * kMB; if (strcmp(argv[i], "-alias_capacity") == 0) alias_capacity = atoi(argv[i + 1]) * kMB; if (strcmp(argv[i], "-delta_capacity") == 0) delta_capacity = atoi(argv[i + 1]) * kMB; - if(num_alpha_iterations > 0) asymmetric_prior = true; } Check(); } @@ -74,6 +76,8 @@ namespace multiverso { namespace lightlda printf("-num_vocabs Size of dataset vocabulary \n"); printf("-num_topics Number of topics. Default: 100\n"); printf("-num_iterations Number of iteratioins. Default: 100\n"); + printf("-num_alhpa_iterations Number of learning alpha iteratioins. Default: 100\n"); + printf("-learn_alhpa_every Frequency of learning alpha. Default: 5\n"); printf("-mh_steps Metropolis-hasting steps. Default: 2\n"); printf("-alpha Dirichlet prior alpha. Default: 0.1\n"); printf("-beta Dirichlet prior beta. Default: 0.01\n\n"); @@ -86,6 +90,7 @@ namespace multiverso { namespace lightlda printf("-num_aggregator Number of local aggregation threads. Default: 1\n"); printf("-server_file Server endpoint file. Used by MPI-free version\n"); printf("-warm_start Warm start \n"); + printf("-asymmetric_prior Use asymmetric prior \n\n"); printf("-out_of_core Use out of core computing \n\n"); printf("-data_capacity Memory pool size(MB) for data storage, \n"); printf(" should larger than the any data block\n"); @@ -101,6 +106,8 @@ namespace multiverso { namespace lightlda printf("-num_vocabs Size of dataset vocabulary \n"); printf("-num_topics Number of topics. Default: 100\n"); printf("-num_iterations Number of iteratioins. Default: 100\n"); + printf("-num_alhpa_iterations Number of learning alpha iteratioins. Default: 100\n"); + printf("-learn_alhpa_every Frequency of learning alpha. Default: 5\n"); printf("-mh_steps Metropolis-hasting steps. Default: 2\n"); printf("-alpha Dirichlet prior alpha. Default: 0.1\n"); printf("-beta Dirichlet prior beta. Default: 0.01\n\n"); @@ -110,6 +117,7 @@ namespace multiverso { namespace lightlda printf(" files generated by dump_block \n\n"); printf("-num_local_workers Number of local training threads. Default: 4\n"); printf("-warm_start Warm start \n"); + printf("-asymmetric_prior Use asymmetric prior \n\n"); printf("-out_of_core Use out of core computing \n\n"); printf("-data_capacity Memory pool size(MB) for data storage, \n"); printf(" should larger than the any data block\n"); @@ -130,7 +138,8 @@ namespace multiverso { namespace lightlda void Config::Check() { - if (input_dir == "" || num_vocabs <= 0 || max_num_document == -1) + if (input_dir == "" || num_vocabs <= 0 || max_num_document == -1 || + (asymmetric_prior && learn_alpha_every > num_iterations)) { PrintUsage(); } diff --git a/src/common.h b/src/common.h index f000169..7fec997 100644 --- a/src/common.h +++ b/src/common.h @@ -42,8 +42,10 @@ namespace multiverso { namespace lightlda static int32_t num_topics; /*! \brief number of iterations */ static int32_t num_iterations; - /*! \brief number of estimating alpha iterations */ + /*! \brief number of learning alpha iterations */ static int32_t num_alpha_iterations; + /*! \brief frequency of learning alpha */ + static int32_t learn_alpha_every; /*! \brief number of metropolis-hastings steps */ static int32_t mh_steps; /*! \brief number of servers for Multiverso setting */ diff --git a/src/lightlda.cpp b/src/lightlda.cpp index d364fb0..3bff702 100644 --- a/src/lightlda.cpp +++ b/src/lightlda.cpp @@ -126,7 +126,6 @@ namespace multiverso { namespace lightlda doc->SetTopic(cursor, rng.rand_k(Config::num_topics)); } } - Multiverso::Flush(); } data_stream->EndDataAccess(); } diff --git a/src/meta.cpp b/src/meta.cpp index 46e593b..53cac17 100644 --- a/src/meta.cpp +++ b/src/meta.cpp @@ -161,10 +161,13 @@ namespace multiverso { namespace lightlda { Log::Info("Actual Model capacity: %d MB, Alias capacity: %d MB, Delta capacity: %dMB\n", model_offset/1024/1024, alias_offset/1024/1024, delta_offset/1024/1024); - Log::Info("Actual asymmetric alpha capacity: %d MB, Alias capacity: %dMB, Delta capacity: %d MB\n", - Config::num_topics * (kMaxDocLength + 1) * sizeof(int32_t)/1024/1024, - 2 * Config::num_topics * sizeof(int32_t)/1024/1024, - Config::num_topics * (kMaxDocLength + 1) * sizeof(int32_t)/1024/1024); + if(Config::asymmetric_prior) + { + Log::Info("Actual asymmetric alpha capacity: %d MB, Alias capacity: %dMB, Delta capacity: %d MB\n", + Config::num_topics * (kMaxDocLength + 1) * sizeof(int32_t)/1024/1024, + 2 * Config::num_topics * sizeof(int32_t)/1024/1024, + Config::num_topics * (kMaxDocLength + 1) * sizeof(int32_t)/1024/1024); + } local_vocab.slice_index_.push_back(j); ++local_vocab.num_slices_; model_offset = model_size; diff --git a/src/trainer.cpp b/src/trainer.cpp index 6211bd6..386773a 100644 --- a/src/trainer.cpp +++ b/src/trainer.cpp @@ -76,7 +76,7 @@ namespace multiverso { namespace lightlda } // Learn alpha and build Alias table if(asym_alpha_!= nullptr && - iter % 5 == 0 && + iter % Config::learn_alpha_every == 0 && id == 0) { watch.Restart(); @@ -119,8 +119,8 @@ namespace multiverso { namespace lightlda // Clear the thread information in alias table if (iter == Config::num_iterations - 1) { + AliasMultinomialRNGInt::Clear(); alias_->Clear(); - if(asym_alpha_) asym_alpha_->Clear(); } } @@ -212,6 +212,7 @@ namespace multiverso { namespace lightlda reinterpret_cast(data_block); // Request word-topic-table int32_t slice = lda_data_block->slice(); + int32_t iter = lda_data_block->iteration(); DataBlock& data = lda_data_block->data(); const LocalVocab& local_vocab = data.meta(); @@ -226,15 +227,11 @@ namespace multiverso { namespace lightlda RequestTable(kSummaryRow); Log::Debug("Request summary-row\n"); - if(Config::asymmetric_prior) + if(Config::asymmetric_prior && iter % Config::learn_alpha_every == 0) { // Request topic-frequency-table RequestTable(kTopicFrequencyTable); Log::Debug("Request topic-frequency-table\n"); - /*for(int32_t topic = 0; topic < Config::num_topics; topic++) - { - RequestRow(kTopicFrequencyTable, topic); - }*/ // Request doc-length-row RequestTable(kDocLengthRow); Log::Debug("Request doc-length-row\n");