diff --git a/inference/infer.cpp b/inference/infer.cpp index a0d3a60..6a8020d 100644 --- a/inference/infer.cpp +++ b/inference/infer.cpp @@ -28,8 +28,8 @@ namespace multiverso { namespace lightlda LocalModel* model = new LocalModel(&meta); model->Init(); //init document stream data_stream = CreateDataStream(); - //init documents - InitDocument(); + //init doc-topic + InitDocTopic(); //init alias table AliasTable* alias_table = new AliasTable(); //init inferers @@ -102,7 +102,7 @@ namespace multiverso { namespace lightlda return nullptr; } - static void InitDocument() + static void InitDocTopic() { xorshift_rng rng; for (int32_t block = 0; block < Config::num_blocks; ++block) diff --git a/inference/inferer.cpp b/inference/inferer.cpp index 2bd8f3b..64ff8d6 100644 --- a/inference/inferer.cpp +++ b/inference/inferer.cpp @@ -74,7 +74,8 @@ namespace multiverso { namespace lightlda for (int32_t doc_id = id_; doc_id < data.Size(); doc_id += thread_num_) { Document* doc = data.GetOneDoc(doc_id); - sampler_->SampleOneDoc(doc, 0, lastword, model_, alias_); + //TODO: Asymmeric prior + sampler_->SampleOneDoc(doc, 0, lastword, model_, alias_, nullptr); } } diff --git a/src/alias_table.cpp b/src/alias_table.cpp index 5b383bd..0f1dfe8 100644 --- a/src/alias_table.cpp +++ b/src/alias_table.cpp @@ -13,10 +13,11 @@ namespace multiverso { namespace lightlda { _THREAD_LOCAL std::vector* AliasTable::q_w_proportion_; - _THREAD_LOCAL std::vector* AliasTable::q_w_proportion_int_; - _THREAD_LOCAL std::vector>* AliasTable::L_; - _THREAD_LOCAL std::vector>* AliasTable::H_; + _THREAD_LOCAL std::vector* AliasMultinomialRNGInt::q_proportion_int_; + _THREAD_LOCAL std::vector>* AliasMultinomialRNGInt::L_; + _THREAD_LOCAL std::vector>* AliasMultinomialRNGInt::H_; + // -- AliasTable implement area --------------------------------- // AliasTable::AliasTable() { memory_size_ = Config::alias_capacity / sizeof(int32_t); @@ -25,6 +26,8 @@ namespace multiverso { namespace lightlda beta_ = Config::beta; 
beta_sum_ = beta_ * num_vocabs_; memory_block_ = new int32_t[memory_size_]; + + alias_rng_int_ = new AliasMultinomialRNGInt(num_topics_); beta_kv_vector_ = new int32_t[2 * num_topics_]; @@ -34,6 +37,7 @@ namespace multiverso { namespace lightlda AliasTable::~AliasTable() { + delete alias_rng_int_; delete[] memory_block_; delete[] beta_kv_vector_; } @@ -47,12 +51,6 @@ namespace multiverso { namespace lightlda { if (q_w_proportion_ == nullptr) q_w_proportion_ = new std::vector(num_topics_); - if (q_w_proportion_int_ == nullptr) - q_w_proportion_int_ = new std::vector(num_topics_); - if (L_ == nullptr) - L_ = new std::vector>(num_topics_); - if (H_ == nullptr) - H_ = new std::vector>(num_topics_); // Compute the proportion Row& summary_row = model->GetSummaryRow(); if (word == -1) // build alias row for beta @@ -63,8 +61,7 @@ namespace multiverso { namespace lightlda (*q_w_proportion_)[k] = beta_ / (summary_row.At(k) + beta_sum_); beta_mass_ += (*q_w_proportion_)[k]; } - AliasMultinomialRNG(num_topics_, beta_mass_, beta_height_, - beta_kv_vector_); + alias_rng_int_->Build(*q_w_proportion_, num_topics_, beta_mass_, beta_height_, beta_kv_vector_); } else // build alias row for word { @@ -105,7 +102,7 @@ namespace multiverso { namespace lightlda word_topic_row.NonzeroSize()); } } - AliasMultinomialRNG(size, mass_[word], height_[word], + alias_rng_int_->Build(*q_w_proportion_, size, mass_[word], height_[word], memory_block_ + word_entry.begin_offset); } return 0; @@ -118,62 +115,53 @@ namespace multiverso { namespace lightlda int32_t capacity = word_entry.capacity; if (word_entry.is_dense) { - auto sample = rng.rand(); - int32_t idx = sample / height_[word]; - if (capacity <= idx) idx = capacity - 1; - - int32_t* p = kv_vector + 2 * idx; - int32_t k = *p++; - int32_t v = *p; - int32_t m = -(sample < v); - return (idx & m) | (k & ~m); + return alias_rng_int_->Propose(rng, height_[word], kv_vector); } else { - auto sample = rng.rand_double() * (mass_[word] + beta_mass_); - 
if (sample < mass_[word]) - { - int32_t* idx_vector = kv_vector + 2 * word_entry.capacity; - auto n_kw_sample = rng.rand(); - int32_t idx = n_kw_sample / height_[word]; - if (capacity <= idx) idx = capacity - 1; - int32_t* p = kv_vector + 2 * idx; - int32_t k = *p++; - int32_t v = *p; - int32_t id = idx_vector[idx]; - int32_t m = -(n_kw_sample < v); - return (id & m) | (idx_vector[k] & ~m); - } - else - { - auto beta_sample = rng.rand(); - int32_t idx = beta_sample / beta_height_; - if (num_topics_ <= idx) idx = num_topics_ - 1; - int32_t* p = beta_kv_vector_ + 2 * idx; - int32_t k = *p++; - int32_t v = *p; - int32_t m = -(beta_sample < v); - return (idx & m) | (k & ~m); - } + return alias_rng_int_->Propose(rng, height_[word], beta_height_, + mass_[word], beta_mass_, + kv_vector, capacity, + beta_kv_vector_); } } void AliasTable::Clear() { - delete q_w_proportion_; - q_w_proportion_ = nullptr; - delete q_w_proportion_int_; - q_w_proportion_int_ = nullptr; - delete L_; - L_ = nullptr; - delete H_; - H_ = nullptr; + delete q_w_proportion_; + q_w_proportion_ = nullptr; } + // -- AliasTable implement area --------------------------------- // - - void AliasTable::AliasMultinomialRNG(int32_t size, float mass, int32_t& height, - int32_t* kv_vector) + // -- AliasMultinomialRNGInt implement area --------------------------------- // + void AliasMultinomialRNGInt::Build(const std::vector& q_proportion, int32_t size, + float mass, int32_t & height, int32_t* kv_vector) { + if (q_proportion_int_ == nullptr) + { + q_proportion_int_ = new std::vector(size_); + } + else if(q_proportion_int_->size() != size_) + { + q_proportion_int_->resize(size_); + } + if (L_ == nullptr) + { + L_ = new std::vector>(size_); + } + else if(L_->size() != size_) + { + L_->resize(size_); + } + if (H_ == nullptr) + { + H_ = new std::vector>(size_); + } + else if(H_->size() != size_) + { + H_->resize(size_); + } + int32_t mass_int = 0x7fffffff; int32_t a_int = mass_int / size; mass_int = a_int * size; @@ 
-181,10 +169,9 @@ namespace multiverso { namespace lightlda int64_t mass_sum = 0; for (int32_t i = 0; i < size; ++i) { - (*q_w_proportion_)[i] /= mass; - (*q_w_proportion_int_)[i] = - static_cast((*q_w_proportion_)[i] * mass_int); - mass_sum += (*q_w_proportion_int_)[i]; + (*q_proportion_int_)[i] = + static_cast(q_proportion[i] / mass * mass_int); + mass_sum += (*q_proportion_int_)[i]; } if (mass_sum > mass_int) { @@ -192,9 +179,9 @@ namespace multiverso { namespace lightlda int32_t id = 0; for (int32_t i = 0; i < more;) { - if ((*q_w_proportion_int_)[id] >= 1) + if ((*q_proportion_int_)[id] >= 1) { - --(*q_w_proportion_int_)[id]; + --(*q_proportion_int_)[id]; ++i; } id = (id + 1) % size; @@ -207,7 +194,7 @@ namespace multiverso { namespace lightlda int32_t id = 0; for (int32_t i = 0; i < more; ++i) { - ++(*q_w_proportion_int_)[id]; + ++(*q_proportion_int_)[id]; id = (id + 1) % size; } } @@ -221,7 +208,7 @@ namespace multiverso { namespace lightlda int32_t L_head = 0, L_tail = 0, H_head = 0, H_tail = 0; for (int32_t k = 0; k < size; ++k) { - int32_t val = (*q_w_proportion_int_)[k]; + int32_t val = (*q_proportion_int_)[k]; if (val < height) { (*L_)[L_tail].first = k; @@ -276,5 +263,63 @@ namespace multiverso { namespace lightlda ++H_head; } } + + void AliasMultinomialRNGInt::Clear() + { + delete q_proportion_int_; + q_proportion_int_ = nullptr; + delete L_; + L_ = nullptr; + delete H_; + H_ = nullptr; + } + + int32_t AliasMultinomialRNGInt::Propose(xorshift_rng& rng, int32_t height, + int32_t* kv_vector) + { + auto sample = rng.rand(); + int32_t idx = sample / height; + if (size_ <= idx) idx = size_ - 1; + + int32_t* p = kv_vector + 2 * idx; + int32_t k = *p++; + int32_t v = *p; + int32_t m = -(sample < v); + return (idx & m) | (k & ~m); + } + + int32_t AliasMultinomialRNGInt::Propose(xorshift_rng& rng, + int32_t height, int32_t height_sum, + float mass, float mass_sum, + int32_t* kv_vector, int32_t vsize, + int32_t* kv_vector_sum) + { + auto sample = 
rng.rand_double() * (mass + mass_sum); + if (sample < mass) + { + int32_t* idx_vector = kv_vector + 2 * vsize; + auto n_sample = rng.rand(); + int32_t idx = n_sample / height; + if (vsize <= idx) idx = vsize - 1; + int32_t* p = kv_vector + 2 * idx; + int32_t k = *p++; + int32_t v = *p; + int32_t id = idx_vector[idx]; + int32_t m = -(n_sample < v); + return (id & m) | (idx_vector[k] & ~m); + } + else + { + auto n_sample = rng.rand(); + int32_t idx = n_sample / height_sum; + if (size_ <= idx) idx = size_ - 1; + int32_t* p = kv_vector_sum + 2 * idx; + int32_t k = *p++; + int32_t v = *p; + int32_t m = -(n_sample < v); + return (idx & m) | (k & ~m); + } + } + // -- AliasMultinomialRNGInt implement area --------------------------------- // } // namespace lightlda } // namespace multiverso diff --git a/src/alias_table.h b/src/alias_table.h index 4646e64..adf3b66 100644 --- a/src/alias_table.h +++ b/src/alias_table.h @@ -23,6 +23,31 @@ namespace multiverso { namespace lightlda class xorshift_rng; class AliasTableIndex; + class AliasMultinomialRNGInt + { + public: + AliasMultinomialRNGInt(int32_t size): size_(size) {} + void Build(const std::vector& q_proportion, int32_t size, + float mass, int32_t & height, int32_t* kv_vector); + static void Clear(); + + //for dense sampling + int32_t Propose(xorshift_rng& rng, int32_t height, int32_t* kv_vector); + //for sparse sampling + int32_t Propose(xorshift_rng& rng, + int32_t height, int32_t height_sum, + float mass, float mass_sum, + int32_t* kv_vector, int32_t vsize, + int32_t* kv_vector_sum); + + private: + int32_t size_; + // thread local storage used for building alias + _THREAD_LOCAL static std::vector* q_proportion_int_; + _THREAD_LOCAL static std::vector>* L_; + _THREAD_LOCAL static std::vector>* H_; + }; + /*! * \brief AliasTable is the storage for alias tables used for fast sampling * from lightlda word proposal distribution. It optimize memory usage @@ -56,8 +81,7 @@ namespace multiverso { namespace lightlda /*! 
\brief Clear the alias table */ void Clear(); private: - void AliasMultinomialRNG(int32_t size, float mass, int32_t& height, - int32_t* kv_vector); + AliasMultinomialRNGInt * alias_rng_int_; int* memory_block_; int64_t memory_size_; AliasTableIndex* table_index_; @@ -71,9 +95,6 @@ namespace multiverso { namespace lightlda // thread local storage used for building alias _THREAD_LOCAL static std::vector* q_w_proportion_; - _THREAD_LOCAL static std::vector* q_w_proportion_int_; - _THREAD_LOCAL static std::vector>* L_; - _THREAD_LOCAL static std::vector>* H_; int num_vocabs_; int num_topics_; diff --git a/src/asym_alpha.cpp b/src/asym_alpha.cpp new file mode 100644 index 0000000..d0a04aa --- /dev/null +++ b/src/asym_alpha.cpp @@ -0,0 +1,117 @@ +#include "asym_alpha.h" +#include "alias_table.h" +#include "model.h" +#include "common.h" +#include + +namespace multiverso +{ +namespace lightlda +{ + AsymAlpha::AsymAlpha() : dirichlet_scale_(1.0), dirichlet_shape_(1.00001) + { + num_topic_ = Config::num_topics; + max_doc_length_ = kMaxDocLength; + num_alpha_iterations_ = Config::num_alpha_iterations; + alpha_sum_ = num_topic_ * Config::alpha; + non_zero_limit_.resize(num_topic_); + alpha_base_measure_.resize(num_topic_, Config::alpha); + kv_vector_ = new int32_t[2 * num_topic_]; + alias_rng_int_ = new AliasMultinomialRNGInt(num_topic_); + } + + AsymAlpha::~AsymAlpha() + { + delete [] kv_vector_; + delete alias_rng_int_; + } + + void AsymAlpha::LearnDirichletPrior(ModelBase * model) + { + float oldParametersK; + float currentDigamma; + float denominator; + int nonZeroLimit; + float parametersSum; + + // get the initial non_zero_limit_ + for (int k = 0; k < num_topic_; ++k) + { + non_zero_limit_[k] = 0; + Row& row = model->GetTopicFrequencyRow(k); + for (int i = 1; i <= max_doc_length_; ++i) + { + if (row.At(i) > 0) + { + non_zero_limit_[k] = i; + } + } + } + + // get the initial atomic_alpha_sum_ + parametersSum = 0; + for (int k = 0; k < num_topic_; k++) + { + parametersSum 
+= alpha_base_measure_[k]; + } + + Row& doc_length_row = model->GetDocLengthRow(); + + for (int iteration = 0; + iteration < num_alpha_iterations_; ++iteration) + { + // Calculate the denominator + denominator = 0; + currentDigamma = 0; + + // Iterate over the histogram: + for (int i = 1; i <= max_doc_length_; i++) + { + currentDigamma += 1 / (parametersSum + i - 1); + denominator += doc_length_row.At(i) * currentDigamma; + } + // Bayesian estimation Part I + denominator -= 1 / dirichlet_scale_; + + // Calculate the individual parameters + parametersSum = 0; + + for (int k = 0; k < num_topic_; k++) + { + // What's the largest non-zero element in the histogram? + nonZeroLimit = non_zero_limit_[k]; + + oldParametersK = alpha_base_measure_[k]; + alpha_base_measure_[k] = 0; + currentDigamma = 0; + + Row& row = model->GetTopicFrequencyRow(k); + + for (int i = 1; i <= nonZeroLimit; i++) + { + currentDigamma += 1 / (oldParametersK + i - 1); + alpha_base_measure_[k] += row.At(i) * currentDigamma; + } + + // Bayesian estimation part II + alpha_base_measure_[k] = oldParametersK + * (alpha_base_measure_[k] + dirichlet_shape_) + / denominator; + parametersSum += alpha_base_measure_[k]; + } + } + alpha_sum_ = parametersSum; + } + + void AsymAlpha::BuildAlias() + { + alias_rng_int_->Build(alpha_base_measure_, num_topic_, + alpha_sum_, alpha_height_, kv_vector_); + } + + int32_t AsymAlpha::Next() + { + return alias_rng_int_->Propose(rng_, alpha_height_, kv_vector_); + } +} // namespace lightlda +} // namespace multiverso diff --git a/src/asym_alpha.h b/src/asym_alpha.h new file mode 100644 index 0000000..ea0c600 --- /dev/null +++ b/src/asym_alpha.h @@ -0,0 +1,59 @@ +/*! 
+ * \file asym_alpha.h + * \brief Defines asymmetric prior alpha + */ + +#ifndef LIGHTLDA_ASYM_ALPHA_H_ +#define LIGHTLDA_ASYM_ALPHA_H_ + +#include +#include +#include "util.h" + +namespace multiverso +{ +namespace lightlda +{ + class ModelBase; + class AliasMultinomialRNGInt; + + class AsymAlpha + { + public: + AsymAlpha(); + ~AsymAlpha(); + void LearnDirichletPrior(ModelBase * model); + void BuildAlias(); + int32_t Next(); + float At(int32_t idx) const; + float AlphaSum() const; + private: + int32_t num_topic_; + int32_t max_doc_length_; + int32_t num_alpha_iterations_; + float dirichlet_scale_; + float dirichlet_shape_; + float alpha_sum_; + int32_t alpha_height_; + int32_t* kv_vector_; + std::vector non_zero_limit_; + std::vector alpha_base_measure_; + xorshift_rng rng_; + AliasMultinomialRNGInt * alias_rng_int_; + }; + + // -- inline functions definition area --------------------------------- // + inline float AsymAlpha::At(int32_t idx) const + { + return alpha_base_measure_[idx]; + } + + inline float AsymAlpha::AlphaSum() const + { + return alpha_sum_; + } + // -- inline functions definition area --------------------------------- // +} // namespace lightlda +} // namespace multiverso + +#endif //LIGHTLDA_ASYM_ALPHA_H_ diff --git a/src/common.cpp b/src/common.cpp index 3e18535..245d2e1 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -10,6 +10,8 @@ namespace multiverso { namespace lightlda int32_t Config::num_vocabs = -1; int32_t Config::num_topics = 100; int32_t Config::num_iterations = 100; + int32_t Config::num_alpha_iterations = 100; + int32_t Config::learn_alpha_every = 5; int32_t Config::mh_steps = 2; int32_t Config::num_servers = 1; int32_t Config::num_local_workers = 1; @@ -22,6 +24,7 @@ namespace multiverso { namespace lightlda std::string Config::input_dir = ""; bool Config::warm_start = false; bool Config::inference = false; + bool Config::asymmetric_prior = false; bool Config::out_of_core = false; int64_t Config::data_capacity = 1024 * kMB; 
int64_t Config::model_capacity = 512 * kMB; @@ -44,6 +47,8 @@ namespace multiverso { namespace lightlda if (strcmp(argv[i], "-num_vocabs") == 0) num_vocabs = atoi(argv[i + 1]); if (strcmp(argv[i], "-num_topics") == 0) num_topics = atoi(argv[i + 1]); if (strcmp(argv[i], "-num_iterations") == 0) num_iterations = atoi(argv[i + 1]); + if (strcmp(argv[i], "-num_alpha_iterations") == 0) num_alpha_iterations = atoi(argv[i + 1]); + if (strcmp(argv[i], "-learn_alpha_every") == 0) learn_alpha_every = atoi(argv[i + 1]); if (strcmp(argv[i], "-mh_steps") == 0) mh_steps = atoi(argv[i + 1]); if (strcmp(argv[i], "-num_servers") == 0) num_servers = atoi(argv[i + 1]); if (strcmp(argv[i], "-num_local_workers") == 0) num_local_workers = atoi(argv[i + 1]); @@ -56,10 +61,11 @@ namespace multiverso { namespace lightlda if (strcmp(argv[i], "-server_file") == 0) server_file = std::string(argv[i + 1]); if (strcmp(argv[i], "-warm_start") == 0) warm_start = true; if (strcmp(argv[i], "-out_of_core") == 0) out_of_core = true; + if (strcmp(argv[i], "-asymmetric_prior") == 0) asymmetric_prior = true; if (strcmp(argv[i], "-data_capacity") == 0) data_capacity = atoi(argv[i + 1]) * kMB; if (strcmp(argv[i], "-model_capacity") == 0) model_capacity = atoi(argv[i + 1]) * kMB; if (strcmp(argv[i], "-alias_capacity") == 0) alias_capacity = atoi(argv[i + 1]) * kMB; - if (strcmp(argv[i], "-delta_capacity") == 0) delta_capacity = atoi(argv[i + 1]) * kMB; + if (strcmp(argv[i], "-delta_capacity") == 0) delta_capacity = atoi(argv[i + 1]) * kMB; } Check(); } @@ -70,6 +76,8 @@ namespace multiverso { namespace lightlda printf("-num_vocabs Size of dataset vocabulary \n"); printf("-num_topics Number of topics. Default: 100\n"); printf("-num_iterations Number of iteratioins. Default: 100\n"); + printf("-num_alpha_iterations Number of learning alpha iterations. Default: 100\n"); + printf("-learn_alpha_every Frequency of learning alpha. Default: 5\n"); printf("-mh_steps Metropolis-hasting steps. 
Default: 2\n"); printf("-alpha Dirichlet prior alpha. Default: 0.1\n"); printf("-beta Dirichlet prior beta. Default: 0.01\n\n"); @@ -82,6 +90,7 @@ namespace multiverso { namespace lightlda printf("-num_aggregator Number of local aggregation threads. Default: 1\n"); printf("-server_file Server endpoint file. Used by MPI-free version\n"); printf("-warm_start Warm start \n"); + printf("-asymmetric_prior Use asymmetric prior \n\n"); printf("-out_of_core Use out of core computing \n\n"); printf("-data_capacity Memory pool size(MB) for data storage, \n"); printf(" should larger than the any data block\n"); @@ -97,6 +106,8 @@ namespace multiverso { namespace lightlda printf("-num_vocabs Size of dataset vocabulary \n"); printf("-num_topics Number of topics. Default: 100\n"); printf("-num_iterations Number of iteratioins. Default: 100\n"); + printf("-num_alpha_iterations Number of learning alpha iterations. Default: 100\n"); + printf("-learn_alpha_every Frequency of learning alpha. Default: 5\n"); printf("-mh_steps Metropolis-hasting steps. Default: 2\n"); printf("-alpha Dirichlet prior alpha. Default: 0.1\n"); printf("-beta Dirichlet prior beta. Default: 0.01\n\n"); @@ -106,6 +117,7 @@ namespace multiverso { namespace lightlda printf(" files generated by dump_block \n\n"); printf("-num_local_workers Number of local training threads. 
Default: 4\n"); printf("-warm_start Warm start \n"); + printf("-asymmetric_prior Use asymmetric prior \n\n"); printf("-out_of_core Use out of core computing \n\n"); printf("-data_capacity Memory pool size(MB) for data storage, \n"); printf(" should larger than the any data block\n"); @@ -126,7 +138,8 @@ namespace multiverso { namespace lightlda void Config::Check() { - if (input_dir == "" || num_vocabs <= 0 || max_num_document == -1) + if (input_dir == "" || num_vocabs <= 0 || max_num_document == -1 || + (asymmetric_prior && learn_alpha_every > num_iterations)) { PrintUsage(); } diff --git a/src/common.h b/src/common.h index bfef1ed..7fec997 100644 --- a/src/common.h +++ b/src/common.h @@ -16,6 +16,10 @@ namespace multiverso { namespace lightlda const int32_t kWordTopicTable = 0; /*! \brief constant variable for table id */ const int32_t kSummaryRow = 1; + /*! \brief constant variable for table id */ + const int32_t kTopicFrequencyTable = 2; + /*! \brief constant variable for table id */ + const int32_t kDocLengthRow = 3; /*! \brief load factor for sparse hash table */ const int32_t kLoadFactor = 2; /*! \brief max length of a document */ @@ -36,8 +40,12 @@ namespace multiverso { namespace lightlda static int32_t num_vocabs; /*! \brief number of topics */ static int32_t num_topics; - /*! \brief number of iterations for trainning */ + /*! \brief number of iterations */ static int32_t num_iterations; + /*! \brief number of learning alpha iterations */ + static int32_t num_alpha_iterations; + /*! \brief frequency of learning alpha */ + static int32_t learn_alpha_every; /*! \brief number of metropolis-hastings steps */ static int32_t mh_steps; /*! \brief number of servers for Multiverso setting */ @@ -62,6 +70,8 @@ namespace multiverso { namespace lightlda static bool warm_start; /*! \brief inference mode */ static bool inference; + /*! \brief asymmetric prior */ + static bool asymmetric_prior; /*! 
\brief option specity whether use out of core computation */ static bool out_of_core; /*! \brief memory capacity settings, for memory pools */ diff --git a/src/data_stream.cpp b/src/data_stream.cpp index 21675e9..0d46560 100644 --- a/src/data_stream.cpp +++ b/src/data_stream.cpp @@ -31,8 +31,8 @@ namespace multiverso { namespace lightlda { typedef DoubleBuffer DataBuffer; public: - DiskDataStream(int32_t num_blocks, std::string data_path, - int32_t num_iterations); + DiskDataStream(std::string data_path, + int32_t num_blocks, int32_t num_iterations); virtual ~DiskDataStream(); virtual void BeforeDataAccess() override; virtual void EndDataAccess() override; @@ -46,12 +46,12 @@ namespace multiverso { namespace lightlda DataBuffer* data_buffer_; /*! \brief current block id to be accessed */ int32_t block_id_; + /*! \brief data path */ + std::string data_path_; /*! \brief number of data blocks in disk */ int32_t num_blocks_; /*! \brief number of training iterations */ int32_t num_iterations_; - /*! \brief data path */ - std::string data_path_; /*! 
\brief backend thread for data preload */ std::thread preload_thread_; bool working_; @@ -95,9 +95,9 @@ namespace multiverso { namespace lightlda return *data_buffer_[index_]; } - DiskDataStream::DiskDataStream(int32_t num_blocks, - std::string data_path, int32_t num_iterations) : - num_blocks_(num_blocks), data_path_(data_path), + DiskDataStream::DiskDataStream(std::string data_path, + int32_t num_blocks, int32_t num_iterations) : + data_path_(data_path), num_blocks_(num_blocks), num_iterations_(num_iterations), working_(false) { block_id_ = 0; @@ -178,7 +178,7 @@ namespace multiverso { namespace lightlda { if (Config::out_of_core && Config::num_blocks != 1) { - return new DiskDataStream(Config::num_blocks, Config::input_dir, + return new DiskDataStream(Config::input_dir, Config::num_blocks, Config::num_iterations); } else diff --git a/src/lightlda.cpp b/src/lightlda.cpp index 199634d..3bff702 100644 --- a/src/lightlda.cpp +++ b/src/lightlda.cpp @@ -1,10 +1,12 @@ #include "common.h" #include "trainer.h" #include "alias_table.h" +#include "asym_alpha.h" #include "data_stream.h" #include "data_block.h" #include "document.h" #include "meta.h" +#include "model.h" #include "util.h" #include #include @@ -22,12 +24,17 @@ namespace multiverso { namespace lightlda Config::Init(argc, argv); AliasTable* alias_table = new AliasTable(); + AsymAlpha* asym_alpha = nullptr; + if(Config::asymmetric_prior) + { + asym_alpha = new AsymAlpha(); + } Barrier* barrier = new Barrier(Config::num_local_workers); meta.Init(); std::vector trainers; for (int32_t i = 0; i < Config::num_local_workers; ++i) { - Trainer* trainer = new Trainer(alias_table, barrier, &meta); + Trainer* trainer = new Trainer(alias_table, asym_alpha, barrier, &meta); trainers.push_back(trainer); } @@ -43,7 +50,8 @@ namespace multiverso { namespace lightlda + std::to_string(clock()) + ".log"); data_stream = CreateDataStream(); - InitMultiverso(); + InitDocTopic(); + PSModel::Init(&meta, data_stream); Train(); 
Multiverso::Close(); @@ -59,6 +67,7 @@ namespace multiverso { namespace lightlda delete data_stream; delete barrier; delete alias_table; + if(Config::asymmetric_prior) delete asym_alpha; } private: static void Train() @@ -93,16 +102,7 @@ namespace multiverso { namespace lightlda Multiverso::EndTrain(); } - static void InitMultiverso() - { - Multiverso::BeginConfig(); - CreateTable(); - ConfigTable(); - Initialize(); - Multiverso::EndConfig(); - } - - static void Initialize() + static void InitDocTopic() { xorshift_rng rng; for (int32_t block = 0; block < Config::num_blocks; ++block) @@ -124,14 +124,8 @@ namespace multiverso { namespace lightlda // Init the latent variable if (!Config::warm_start) doc->SetTopic(cursor, rng.rand_k(Config::num_topics)); - // Init the server table - Multiverso::AddToServer(kWordTopicTable, - doc->Word(cursor), doc->Topic(cursor), 1); - Multiverso::AddToServer(kSummaryRow, - 0, doc->Topic(cursor), 1); } } - Multiverso::Flush(); } data_stream->EndDataAccess(); } @@ -163,60 +157,6 @@ namespace multiverso { namespace lightlda } } - static void CreateTable() - { - int32_t num_vocabs = Config::num_vocabs; - int32_t num_topics = Config::num_topics; - Type int_type = Type::Int; - Type longlong_type = Type::LongLong; - multiverso::Format dense_format = multiverso::Format::Dense; - multiverso::Format sparse_format = multiverso::Format::Sparse; - - Multiverso::AddServerTable(kWordTopicTable, num_vocabs, - num_topics, int_type, dense_format); - Multiverso::AddCacheTable(kWordTopicTable, num_vocabs, - num_topics, int_type, dense_format, Config::model_capacity); - Multiverso::AddAggregatorTable(kWordTopicTable, num_vocabs, - num_topics, int_type, dense_format, Config::delta_capacity); - - Multiverso::AddTable(kSummaryRow, 1, Config::num_topics, - longlong_type, dense_format); - } - - static void ConfigTable() - { - multiverso::Format dense_format = multiverso::Format::Dense; - multiverso::Format sparse_format = multiverso::Format::Sparse; - for 
(int32_t word = 0; word < Config::num_vocabs; ++word) - { - if (meta.tf(word) > 0) - { - if (meta.tf(word) * kLoadFactor > Config::num_topics) - { - Multiverso::SetServerRow(kWordTopicTable, - word, dense_format, Config::num_topics); - Multiverso::SetCacheRow(kWordTopicTable, - word, dense_format, Config::num_topics); - } - else - { - Multiverso::SetServerRow(kWordTopicTable, - word, sparse_format, meta.tf(word) * kLoadFactor); - Multiverso::SetCacheRow(kWordTopicTable, - word, sparse_format, meta.tf(word) * kLoadFactor); - } - } - if (meta.local_tf(word) > 0) - { - if (meta.local_tf(word) * 2 * kLoadFactor > Config::num_topics) - Multiverso::SetAggregatorRow(kWordTopicTable, - word, dense_format, Config::num_topics); - else - Multiverso::SetAggregatorRow(kWordTopicTable, word, - sparse_format, meta.local_tf(word) * 2 * kLoadFactor); - } - } - } private: /*! \brief training data access */ static IDataStream* data_stream; diff --git a/src/meta.cpp b/src/meta.cpp index ca867ec..53cac17 100644 --- a/src/meta.cpp +++ b/src/meta.cpp @@ -7,7 +7,7 @@ namespace multiverso { namespace lightlda { LocalVocab::LocalVocab() - : num_slices_(0), own_memory_(false), vocabs_(nullptr), size_(0) + : num_slices_(0), vocabs_(nullptr), size_(0), own_memory_(false) {} LocalVocab::~LocalVocab() @@ -161,6 +161,13 @@ namespace multiverso { namespace lightlda { Log::Info("Actual Model capacity: %d MB, Alias capacity: %d MB, Delta capacity: %dMB\n", model_offset/1024/1024, alias_offset/1024/1024, delta_offset/1024/1024); + if(Config::asymmetric_prior) + { + Log::Info("Actual asymmetric alpha capacity: %d MB, Alias capacity: %dMB, Delta capacity: %d MB\n", + static_cast<int32_t>(Config::num_topics * (kMaxDocLength + 1) * sizeof(int32_t)/1024/1024), + static_cast<int32_t>(2 * Config::num_topics * sizeof(int32_t)/1024/1024), + static_cast<int32_t>(Config::num_topics * (kMaxDocLength + 1) * sizeof(int32_t)/1024/1024)); + } local_vocab.slice_index_.push_back(j); ++local_vocab.num_slices_; model_offset = model_size; diff --git a/src/model.cpp b/src/model.cpp 
index 1878d53..ceab6ee 100644 --- a/src/model.cpp +++ b/src/model.cpp @@ -13,14 +13,20 @@ #include "meta.h" #include "trainer.h" +#include "data_stream.h" +#include "data_block.h" +#include "document.h" #include #include namespace multiverso { namespace lightlda { - LocalModel::LocalModel(Meta * meta) : word_topic_table_(nullptr), - summary_table_(nullptr), meta_(meta) + // -- local model implement area --------------------------------- // + LocalModel::LocalModel(Meta * meta) : + word_topic_table_(nullptr), summary_table_(nullptr), + topic_frequency_table_(nullptr), doc_length_table_(nullptr), + meta_(meta) { CreateTable(); } @@ -35,7 +41,7 @@ namespace multiverso { namespace lightlda int32_t num_vocabs = Config::num_vocabs; int32_t num_topics = Config::num_topics; multiverso::Format dense_format = multiverso::Format::Dense; - multiverso::Format sparse_format = multiverso::Format::Sparse; + //multiverso::Format sparse_format = multiverso::Format::Sparse; Type int_type = Type::Int; Type longlong_type = Type::LongLong; @@ -178,13 +184,13 @@ namespace multiverso { namespace lightlda model_file.close(); } - void LocalModel::AddWordTopicRow( + void LocalModel::AddWordTopic( integer_t word_id, integer_t topic_id, int32_t delta) { Log::Fatal("Not implemented yet\n"); } - void LocalModel::AddSummaryRow(integer_t topic_id, int64_t delta) + void LocalModel::AddSummary(integer_t topic_id, int64_t delta) { Log::Fatal("Not implemented yet\n"); } @@ -198,7 +204,166 @@ namespace multiverso { namespace lightlda { return *(static_cast*>(summary_table_->GetRow(0))); } + + Row& LocalModel::GetTopicFrequencyRow(integer_t topic_id) + { + return *(static_cast*>(topic_frequency_table_->GetRow(topic_id))); + } + Row& LocalModel::GetDocLengthRow() + { + return *(static_cast*>(doc_length_table_->GetRow(0))); + } + void LocalModel::AddTopicFrequency(integer_t topic_id, integer_t freq, + int32_t delta) + { + Log::Fatal("Not implemented yet\n"); + } + void LocalModel::AddDocLength(integer_t 
doc_len, int32_t delta) + { + Log::Fatal("Not implemented yet\n"); + } + // -- local model implement area --------------------------------- // + + // -- PS model implement area --------------------------------- // + void PSModel::Init(Meta* meta, IDataStream * data_stream) + { + Multiverso::BeginConfig(); + CreateTable(); + ConfigTable(meta); + LoadTable(meta, data_stream); + Multiverso::EndConfig(); + } + + void PSModel::CreateTable() + { + int32_t num_vocabs = Config::num_vocabs; + int32_t num_topics = Config::num_topics; + Type int_type = Type::Int; + Type longlong_type = Type::LongLong; + multiverso::Format dense_format = multiverso::Format::Dense; + //multiverso::Format sparse_format = multiverso::Format::Sparse; + + Multiverso::AddServerTable(kWordTopicTable, num_vocabs, + num_topics, int_type, dense_format); + Multiverso::AddCacheTable(kWordTopicTable, num_vocabs, + num_topics, int_type, dense_format, Config::model_capacity); + Multiverso::AddAggregatorTable(kWordTopicTable, num_vocabs, + num_topics, int_type, dense_format, Config::delta_capacity); + Multiverso::AddTable(kSummaryRow, 1, Config::num_topics, + longlong_type, dense_format); + + if(Config::asymmetric_prior) + { + Multiverso::AddServerTable(kTopicFrequencyTable, num_topics, + kMaxDocLength + 1, int_type, dense_format); + Multiverso::AddCacheTable(kTopicFrequencyTable, num_topics, + kMaxDocLength + 1, int_type, dense_format, + num_topics * (kMaxDocLength + 1) * sizeof(int32_t)); + Multiverso::AddAggregatorTable(kTopicFrequencyTable, num_topics, + kMaxDocLength + 1, int_type, dense_format, + num_topics * (kMaxDocLength + 1) * sizeof(int32_t)); + Multiverso::AddTable(kDocLengthRow, 1, (kMaxDocLength + 1), + int_type, dense_format); + } + } + + void PSModel::ConfigTable(Meta* meta) + { + multiverso::Format dense_format = multiverso::Format::Dense; + multiverso::Format sparse_format = multiverso::Format::Sparse; + for (int32_t word = 0; word < Config::num_vocabs; ++word) + { + if (meta->tf(word) > 0) 
+ { + if (meta->tf(word) * kLoadFactor > Config::num_topics) + { + Multiverso::SetServerRow(kWordTopicTable, + word, dense_format, Config::num_topics); + Multiverso::SetCacheRow(kWordTopicTable, + word, dense_format, Config::num_topics); + } + else + { + Multiverso::SetServerRow(kWordTopicTable, + word, sparse_format, meta->tf(word) * kLoadFactor); + Multiverso::SetCacheRow(kWordTopicTable, + word, sparse_format, meta->tf(word) * kLoadFactor); + } + } + if (meta->local_tf(word) > 0) + { + if (meta->local_tf(word) * 2 * kLoadFactor > Config::num_topics) + Multiverso::SetAggregatorRow(kWordTopicTable, + word, dense_format, Config::num_topics); + else + Multiverso::SetAggregatorRow(kWordTopicTable, word, + sparse_format, meta->local_tf(word) * 2 * kLoadFactor); + } + } + if(Config::asymmetric_prior) + { + for(int32_t topic = 0; topic < Config::num_topics; topic++) + { + Multiverso::SetServerRow(kTopicFrequencyTable, + topic, dense_format, kMaxDocLength + 1); + Multiverso::SetCacheRow(kTopicFrequencyTable, + topic, dense_format, kMaxDocLength + 1); + Multiverso::SetAggregatorRow(kTopicFrequencyTable, + topic, dense_format, kMaxDocLength + 1); + } + } + } + void PSModel::LoadTable(Meta* meta, IDataStream * data_stream) + { + int32_t t, c; + std::unique_ptr> doc_topic_counter; + doc_topic_counter.reset(new Row(0, + multiverso::Format::Sparse, kMaxDocLength)); + for (int32_t block = 0; block < Config::num_blocks; ++block) + { + data_stream->BeforeDataAccess(); + DataBlock& data_block = data_stream->CurrDataBlock(); + int32_t num_slice = meta->local_vocab(block).num_slice(); + for (int32_t slice = 0; slice < num_slice; ++slice) + { + for (int32_t i = 0; i < data_block.Size(); ++i) + { + Document* doc = data_block.GetOneDoc(i); + int32_t& cursor = doc->Cursor(); + if (slice == 0) cursor = 0; + int32_t last_word = meta->local_vocab(block).LastWord(slice); + // Init the server table + for (; cursor < doc->Size(); ++cursor) + { + if (doc->Word(cursor) > last_word) break; + 
Multiverso::AddToServer(kWordTopicTable, + doc->Word(cursor), doc->Topic(cursor), 1); + Multiverso::AddToServer(kSummaryRow, + 0, doc->Topic(cursor), 1); + } + if(Config::asymmetric_prior && 0 == slice) + { + doc_topic_counter->Clear(); + doc->GetDocTopicVector(*doc_topic_counter); + Row::iterator iter = doc_topic_counter->Iterator(); + while (iter.HasNext()) + { + t = iter.Key(); + c = iter.Value(); + Multiverso::AddToServer(kTopicFrequencyTable, + t, c, 1); + iter.Next(); + } + Multiverso::AddToServer(kDocLengthRow, 0, doc->Size(), 1); + } + } + Multiverso::Flush(); + } + data_stream->EndDataAccess(); + } + } + Row& PSModel::GetWordTopicRow(integer_t word_id) { return trainer_->GetRow(kWordTopicTable, word_id); @@ -209,16 +374,35 @@ namespace multiverso { namespace lightlda return trainer_->GetRow(kSummaryRow, 0); } - void PSModel::AddWordTopicRow( + void PSModel::AddWordTopic( integer_t word_id, integer_t topic_id, int32_t delta) { trainer_->Add(kWordTopicTable, word_id, topic_id, delta); } - void PSModel::AddSummaryRow(integer_t topic_id, int64_t delta) + void PSModel::AddSummary(integer_t topic_id, int64_t delta) { trainer_->Add(kSummaryRow, 0, topic_id, delta); } + Row& PSModel::GetTopicFrequencyRow(integer_t topic_id) + { + return trainer_->GetRow(kTopicFrequencyTable, topic_id); + } + Row& PSModel::GetDocLengthRow() + { + return trainer_->GetRow(kDocLengthRow, 0); + } + void PSModel::AddTopicFrequency(integer_t topic_id, integer_t freq, + int32_t delta) + { + trainer_->Add(kTopicFrequencyTable, topic_id, freq, delta); + } + void PSModel::AddDocLength(integer_t doc_len, int32_t delta) + { + trainer_->Add(kDocLengthRow, 0, doc_len, delta); /* column doc_len of the doc-length row (the row GetDocLengthRow reads), not the topic summary row */ + } + // -- PS model implement area --------------------------------- // + } // namespace lightlda } // namespace multiverso diff --git a/src/model.h b/src/model.h index 2e71acf..bc678e6 100644 --- a/src/model.h +++ b/src/model.h @@ -21,6 +21,7 @@ namespace lightlda { class Meta; class Trainer; + class IDataStream; /*!
\brief interface for acceess to model */ class ModelBase @@ -29,9 +30,15 @@ namespace lightlda virtual ~ModelBase() {} virtual Row& GetWordTopicRow(integer_t word_id) = 0; virtual Row& GetSummaryRow() = 0; - virtual void AddWordTopicRow(integer_t word_id, integer_t topic_id, + virtual void AddWordTopic(integer_t word_id, integer_t topic_id, int32_t delta) = 0; - virtual void AddSummaryRow(integer_t topic_id, int64_t delta) = 0; + virtual void AddSummary(integer_t topic_id, int64_t delta) = 0; + + virtual Row& GetTopicFrequencyRow(integer_t topic_id) = 0; + virtual Row& GetDocLengthRow() = 0; + virtual void AddTopicFrequency(integer_t topic_id, integer_t freq, + int32_t delta) = 0; + virtual void AddDocLength(integer_t doc_len, int32_t delta) = 0; }; /*! \brief model based on local buffer */ @@ -43,9 +50,15 @@ namespace lightlda Row& GetWordTopicRow(integer_t word_id) override; Row& GetSummaryRow() override; - void AddWordTopicRow(integer_t word_id, integer_t topic_id, + void AddWordTopic(integer_t word_id, integer_t topic_id, + int32_t delta) override; + void AddSummary(integer_t topic_id, int64_t delta) override; + + Row& GetTopicFrequencyRow(integer_t topic_id) override; + Row& GetDocLengthRow() override; + void AddTopicFrequency(integer_t topic_id, integer_t freq, int32_t delta) override; - void AddSummaryRow(integer_t topic_id, int64_t delta) override; + void AddDocLength(integer_t doc_len, int32_t delta) override; private: void CreateTable(); @@ -55,6 +68,8 @@ namespace lightlda std::unique_ptr word_topic_table_; std::unique_ptr
summary_table_; + std::unique_ptr
topic_frequency_table_; + std::unique_ptr
doc_length_table_; Meta* meta_; LocalModel(const LocalModel&) = delete; @@ -64,14 +79,27 @@ namespace lightlda /*! \brief model based on parameter server */ class PSModel : public ModelBase { + public: + static void Init(Meta* meta, IDataStream * data_stream); + private: + static void CreateTable(); + static void ConfigTable(Meta* meta); + static void LoadTable(Meta* meta, IDataStream * data_stream); + public: explicit PSModel(Trainer* trainer) : trainer_(trainer) {} Row& GetWordTopicRow(integer_t word_id) override; Row& GetSummaryRow() override; - void AddWordTopicRow(integer_t word_id, integer_t topic_id, + void AddWordTopic(integer_t word_id, integer_t topic_id, + int32_t delta) override; + void AddSummary(integer_t topic_id, int64_t delta) override; + + Row& GetTopicFrequencyRow(integer_t topic_id) override; + Row& GetDocLengthRow() override; + void AddTopicFrequency(integer_t topic_id, integer_t freq, int32_t delta) override; - void AddSummaryRow(integer_t topic_id, int64_t delta) override; + void AddDocLength(integer_t doc_len, int32_t delta) override; private: Trainer* trainer_; diff --git a/src/sampler.cpp b/src/sampler.cpp index 7af201f..c1849aa 100644 --- a/src/sampler.cpp +++ b/src/sampler.cpp @@ -1,12 +1,14 @@ #include "sampler.h" #include "alias_table.h" +#include "asym_alpha.h" #include "common.h" #include "document.h" #include "model.h" #include #include +#include namespace multiverso { namespace lightlda { @@ -28,7 +30,8 @@ namespace multiverso { namespace lightlda } int32_t LightDocSampler::SampleOneDoc(Document* doc, int32_t slice, - int32_t lastword, ModelBase* model, AliasTable* alias) + int32_t lastword, ModelBase* model, AliasTable* alias, + AsymAlpha* asym_alpha) { DocInit(doc); int32_t num_tokens = 0; @@ -40,7 +43,7 @@ namespace multiverso { namespace lightlda if (word > lastword) break; int32_t old_topic = doc->Topic(cursor); int32_t new_topic = Sample(doc, word, old_topic, old_topic, - model, alias); + model, alias, asym_alpha); if 
(old_topic != new_topic) { doc->SetTopic(cursor, new_topic); @@ -48,10 +51,25 @@ namespace multiverso { namespace lightlda doc_topic_counter_->Add(new_topic, 1); if(!Config::inference) { - model->AddWordTopicRow(word, old_topic, -1); - model->AddSummaryRow(old_topic, -1); - model->AddWordTopicRow(word, new_topic, 1); - model->AddSummaryRow(new_topic, 1); + model->AddWordTopic(word, old_topic, -1); + model->AddSummary(old_topic, -1); + model->AddWordTopic(word, new_topic, 1); + model->AddSummary(new_topic, 1); + if(Config::asymmetric_prior) + { + int32_t old_freq = doc_topic_counter_->At(old_topic); + int32_t new_freq = doc_topic_counter_->At(new_topic); + model->AddTopicFrequency(old_topic, old_freq + 1, -1); + if(old_freq > 0) + { + model->AddTopicFrequency(old_topic, old_freq, 1); + } + model->AddTopicFrequency(new_topic, new_freq, 1); + if(new_freq - 1 > 0) + { + model->AddTopicFrequency(new_topic, new_freq - 1, -1); + } + } } } ++num_tokens; @@ -67,12 +85,14 @@ namespace multiverso { namespace lightlda int32_t LightDocSampler::Sample(Document* doc, int32_t word, int32_t old_topic, int32_t s, - ModelBase* model, AliasTable* alias) + ModelBase* model, AliasTable* alias, + AsymAlpha* asym_alpha) { int32_t t, w_t_cnt, w_s_cnt; int64_t n_t, n_s; float n_td_alpha, n_sd_alpha; float n_tw_beta, n_sw_beta, n_t_beta_sum, n_s_beta_sum; + double n_td_or_alpha; float proposal_t, proposal_s; float nominator, denominator; double rejection, pi; @@ -98,8 +118,16 @@ namespace multiverso { namespace lightlda n_t = summary_row.At(t); n_s = summary_row.At(s); - n_td_alpha = doc_topic_counter_->At(t) + alpha_; - n_sd_alpha = doc_topic_counter_->At(s) + alpha_; + if(asym_alpha) + { + n_td_alpha = doc_topic_counter_->At(t) + asym_alpha->At(t); + n_sd_alpha = doc_topic_counter_->At(s) + asym_alpha->At(s); + } + else + { + n_td_alpha = doc_topic_counter_->At(t) + alpha_; + n_sd_alpha = doc_topic_counter_->At(s) + alpha_; + } n_tw_beta = w_t_cnt + beta_; n_t_beta_sum = n_t + beta_sum_; 
n_sw_beta = w_s_cnt + beta_; @@ -129,8 +157,14 @@ namespace multiverso { namespace lightlda s = (t & m) | (s & ~m); } // Doc proposal - double n_td_or_alpha = rng_.rand_double() * - (doc->Size() + alpha_sum_); + if(asym_alpha) + { + n_td_or_alpha = rng_.rand_double() * (doc->Size() + asym_alpha->AlphaSum()); + } + else + { + n_td_or_alpha = rng_.rand_double() * (doc->Size() + alpha_sum_); + } if (n_td_or_alpha < doc->Size()) { int32_t t_idx = static_cast(n_td_or_alpha); @@ -138,7 +172,14 @@ namespace multiverso { namespace lightlda } else { - t = rng_.rand_k(num_topic_); + if(asym_alpha) + { + t = asym_alpha->Next(); + } + else + { + t = rng_.rand_k(num_topic_); + } } if (t != s) { @@ -149,8 +190,19 @@ namespace multiverso { namespace lightlda n_t = summary_row.At(t); n_s = summary_row.At(s); - n_td_alpha = doc_topic_counter_->At(t) + alpha_; - n_sd_alpha = doc_topic_counter_->At(s) + alpha_; + if(asym_alpha) + { + n_td_alpha = doc_topic_counter_->At(t) + asym_alpha->At(t); + n_sd_alpha = doc_topic_counter_->At(s) + asym_alpha->At(s); + } + else + { + n_td_alpha = doc_topic_counter_->At(t) + alpha_; + n_sd_alpha = doc_topic_counter_->At(s) + alpha_; + } + proposal_t = n_td_alpha; + proposal_s = n_sd_alpha; + n_tw_beta = w_t_cnt + beta_; n_t_beta_sum = n_t + beta_sum_; n_sw_beta = w_s_cnt + beta_; @@ -169,9 +221,6 @@ namespace multiverso { namespace lightlda } - proposal_s = (doc_topic_counter_->At(s) + alpha_); - proposal_t = (doc_topic_counter_->At(t) + alpha_); - nominator = n_td_alpha * n_tw_beta * n_s_beta_sum * proposal_s; denominator = n_sd_alpha * n_sw_beta * n_t_beta_sum * proposal_t; @@ -186,7 +235,7 @@ namespace multiverso { namespace lightlda int32_t LightDocSampler::ApproxSample(Document* doc, int32_t word, int32_t old_topic, int32_t s, - ModelBase* model, AliasTable* alias) + ModelBase* model, AliasTable* alias, AsymAlpha* asym_alpha) { float n_tw_beta, n_sw_beta, n_t_beta_sum, n_s_beta_sum; float nominator, denominator; @@ -202,8 +251,16 @@ namespace 
multiverso { namespace lightlda t = alias->Propose(word, rng_); if (t != s) { - nominator = doc_topic_counter_->At(t) + alpha_; - denominator = doc_topic_counter_->At(s) + alpha_; + if(asym_alpha) + { + nominator = doc_topic_counter_->At(t) + asym_alpha->At(t); + denominator = doc_topic_counter_->At(s) + asym_alpha->At(s); + } + else + { + nominator = doc_topic_counter_->At(t) + alpha_; + denominator = doc_topic_counter_->At(s) + alpha_; + } if (t == old_topic) { nominator -= 1; @@ -227,7 +284,14 @@ namespace multiverso { namespace lightlda } else { - t = rng_.rand_k(num_topic_); + if(asym_alpha) + { + t = asym_alpha->Next(); + } + else + { + t = rng_.rand_k(num_topic_); + } } if (t != s) { diff --git a/src/sampler.h b/src/sampler.h index 404dd6e..789702a 100644 --- a/src/sampler.h +++ b/src/sampler.h @@ -18,6 +18,7 @@ namespace multiverso namespace multiverso { namespace lightlda { class AliasTable; + class AsymAlpha; class Document; class ModelBase; @@ -34,10 +35,11 @@ namespace multiverso { namespace lightlda * \param lastword last word of current slice * \param model pointer model, for access of model * \param alias pointer to alias table, for access of alias + * \param asym_alpha asym alpha prior provider * \return number of sampled token */ int32_t SampleOneDoc(Document* doc, int32_t slice, int32_t lastword, - ModelBase* model, AliasTable* alias); + ModelBase* model, AliasTable* alias, AsymAlpha* asym_alpha); /*! * \brief Get doc-topic-counter, for reusing this container * \return reference to light hash map @@ -57,9 +59,11 @@ namespace multiverso { namespace lightlda * \param old_topic old topic assignment of this token * \param model access * \param alias for alias table access + * \param asym_alpha asym alpha prior provider */ int32_t Sample(Document* doc, int32_t word, int32_t state, - int32_t old_topic, ModelBase* model, AliasTable* alias); + int32_t old_topic, ModelBase* model, AliasTable* alias, + AsymAlpha* asym_alpha); /*! 
* \brief Sample the latent topic assignment for a token. This function @@ -69,7 +73,8 @@ namespace multiverso { namespace lightlda * \param same with Sample */ int32_t ApproxSample(Document* doc, int32_t word, int32_t state, - int32_t old_topic, ModelBase* model, AliasTable* alias); + int32_t old_topic, ModelBase* model, AliasTable* alias, + AsymAlpha* asym_alpha); private: // lda hyper-parameter float alpha_; diff --git a/src/trainer.cpp b/src/trainer.cpp index eca707e..386773a 100644 --- a/src/trainer.cpp +++ b/src/trainer.cpp @@ -1,6 +1,7 @@ #include "trainer.h" #include "alias_table.h" +#include "asym_alpha.h" #include "common.h" #include "data_block.h" #include "eval.h" @@ -19,8 +20,10 @@ namespace multiverso { namespace lightlda double Trainer::word_llh_ = 0.0; Trainer::Trainer(AliasTable* alias_table, + AsymAlpha* asym_alpha, Barrier* barrier, Meta* meta) : - alias_(alias_table), barrier_(barrier), meta_(meta), + alias_(alias_table), asym_alpha_(asym_alpha), + barrier_(barrier), meta_(meta), model_(nullptr) { sampler_ = new LightDocSampler(); @@ -71,13 +74,27 @@ namespace multiverso { namespace lightlda Log::Info("Rank = %d, Alias Time used: %.2f s \n", Multiverso::ProcessRank(), watch.ElapsedSeconds()); } + // Learn alpha and build Alias table + if(asym_alpha_!= nullptr && + iter % Config::learn_alpha_every == 0 && + id == 0) + { + watch.Restart(); + asym_alpha_->LearnDirichletPrior(model_); + asym_alpha_->BuildAlias(); + Log::Info("Rank = %d, AsymAplha Time used: %.2f s \n", + Multiverso::ProcessRank(), watch.ElapsedSeconds()); + } + barrier_->Wait(); + int32_t num_token = 0; watch.Restart(); // Train with lightlda sampler for (int32_t doc_id = id; doc_id < data.Size(); doc_id += trainer_num) { Document* doc = data.GetOneDoc(doc_id); - num_token += sampler_->SampleOneDoc(doc, slice, lastword, model_, alias_); + num_token += sampler_->SampleOneDoc(doc, slice, lastword, model_, alias_, + asym_alpha_); } if (TrainerId() == 0) { @@ -100,7 +117,11 @@ namespace 
multiverso { namespace lightlda // if (iter != 0 && iter % 50 == 0) Dump(iter, lda_data_block); // Clear the thread information in alias table - if (iter == Config::num_iterations - 1) alias_->Clear(); + if (iter == Config::num_iterations - 1) + { + AliasMultinomialRNGInt::Clear(); + alias_->Clear(); + } } void Trainer::Evaluate(LDADataBlock* lda_data_block) @@ -191,6 +212,7 @@ namespace multiverso { namespace lightlda reinterpret_cast(data_block); // Request word-topic-table int32_t slice = lda_data_block->slice(); + int32_t iter = lda_data_block->iteration(); DataBlock& data = lda_data_block->data(); const LocalVocab& local_vocab = data.meta(); @@ -199,10 +221,21 @@ namespace multiverso { namespace lightlda { RequestRow(kWordTopicTable, *p); } - Log::Debug("Request params. start = %d, end = %d\n", + Log::Debug("Request word-topic-table. start = %d, end = %d\n", *local_vocab.begin(slice), *(local_vocab.end(slice) - 1)); // Request summary-row RequestTable(kSummaryRow); + Log::Debug("Request summary-row\n"); + + if(Config::asymmetric_prior && iter % Config::learn_alpha_every == 0) + { + // Request topic-frequency-table + RequestTable(kTopicFrequencyTable); + Log::Debug("Request topic-frequency-table\n"); + // Request doc-length-row + RequestTable(kDocLengthRow); + Log::Debug("Request doc-length-row\n"); + } } } // namespace lightlda } // namespace multiverso diff --git a/src/trainer.h b/src/trainer.h index 2b483fd..cc2f1f5 100644 --- a/src/trainer.h +++ b/src/trainer.h @@ -14,6 +14,7 @@ namespace multiverso { namespace lightlda { class AliasTable; + class AsymAlpha; class LDADataBlock; class LightDocSampler; class Meta; @@ -23,7 +24,7 @@ namespace multiverso { namespace lightlda class Trainer : public TrainerBase { public: - Trainer(AliasTable* alias, Barrier* barrier, Meta* meta); + Trainer(AliasTable* alias, AsymAlpha* asym_alpha, Barrier* barrier, Meta* meta); ~Trainer(); /*! 
* \brief Defines Trainning method for a data_block in one iteration @@ -41,6 +42,8 @@ namespace multiverso { namespace lightlda private: /*! \brief alias table, for alias access */ AliasTable* alias_; + /*! \brief asym alpha */ + AsymAlpha* asym_alpha_; /*! \brief sampler for lightlda */ LightDocSampler* sampler_; /*! \brief barrier for thread-sync */