diff --git a/stan/math/opencl/kernel_generator/load.hpp b/stan/math/opencl/kernel_generator/load.hpp index 799cdb551a0..ae0f7d1eb6d 100644 --- a/stan/math/opencl/kernel_generator/load.hpp +++ b/stan/math/opencl/kernel_generator/load.hpp @@ -49,7 +49,7 @@ class load_ * Creates a deep copy of this expression. * @return copy of \c *this */ - inline load_ deep_copy() const & { return load_(a_); } + inline load_ deep_copy() const& { return load_(a_); } inline load_ deep_copy() && { return load_(std::forward(a_)); } /** diff --git a/stan/math/prim/prob/neg_binomial_2_lpmf.hpp b/stan/math/prim/prob/neg_binomial_2_lpmf.hpp index 7d0d5f9552d..c67499b9915 100644 --- a/stan/math/prim/prob/neg_binomial_2_lpmf.hpp +++ b/stan/math/prim/prob/neg_binomial_2_lpmf.hpp @@ -3,15 +3,14 @@ #include #include +#include #include -#include #include #include #include #include #include #include -#include #include namespace stan { @@ -47,7 +46,7 @@ return_type_t neg_binomial_2_lpmf( size_t size_phi = stan::math::size(phi); size_t size_mu_phi = max_size(mu, phi); size_t size_n_phi = max_size(n, phi); - size_t max_size_seq_view = max_size(n, mu, phi); + size_t size_all = max_size(n, mu, phi); VectorBuilder mu_val(size_mu); for (size_t i = 0; i < size_mu; ++i) { @@ -76,39 +75,30 @@ return_type_t neg_binomial_2_lpmf( n_plus_phi[i] = n_vec[i] + phi_val[i]; } - for (size_t i = 0; i < max_size_seq_view; i++) { - // if phi is large we probably overflow, defer to Poisson: - if (phi_val[i] > 1e5) { - // TODO(martinmodrak) This is wrong (doesn't pass propto information), - // and inaccurate for n = 0, but shouldn't break most models. - // Also the 1e5 cutoff is too small. - // Will be addressed better in PR #1497 - logp += poisson_lpmf(n_vec[i], mu_val[i]); - } else { - if (include_summand::value) { - logp -= lgamma(n_vec[i] + 1.0); - } - if (include_summand::value) { - logp += multiply_log(phi_val[i], phi_val[i]) - lgamma(phi_val[i]); - } - if (include_summand::value) { - logp += multiply_log(n_vec[i], mu_val[i]); - } - if (include_summand::value) { - logp += lgamma(n_plus_phi[i]); - } - logp -= n_plus_phi[i] * log_mu_plus_phi[i]; + for (size_t i = 0; i < size_all; i++) { + if (include_summand::value) { + logp += binomial_coefficient_log(n_plus_phi[i] - 1, n_vec[i]); + } + if (include_summand::value) { + logp += multiply_log(n_vec[i], mu_val[i]); } + logp += -phi_val[i] * (log1p(mu_val[i] / phi_val[i])) + - n_vec[i] * log_mu_plus_phi[i]; if (!is_constant_all::value) { ops_partials.edge1_.partials_[i] - += n_vec[i] / mu_val[i] - n_plus_phi[i] / mu_plus_phi[i]; + += n_vec[i] / mu_val[i] - (n_vec[i] + phi_val[i]) / (mu_plus_phi[i]); } if (!is_constant_all::value) { - ops_partials.edge2_.partials_[i] += 1.0 - n_plus_phi[i] / mu_plus_phi[i] - + log_phi[i] - log_mu_plus_phi[i] - - digamma(phi_val[i]) - + digamma(n_plus_phi[i]); + T_partials_return log_term; + if (mu_val[i] < phi_val[i]) { + log_term = log1p(-mu_val[i] / (mu_plus_phi[i])); + } else { + log_term = log_phi[i] - log_mu_plus_phi[i]; + } + ops_partials.edge2_.partials_[i] + += (mu_val[i] - n_vec[i]) / (mu_plus_phi[i]) + log_term + - (digamma(phi_val[i]) - digamma(n_plus_phi[i])); } } return ops_partials.build(logp); diff --git a/test/unit/math/prim/prob/neg_binomial_2_log_test.cpp b/test/unit/math/prim/prob/neg_binomial_2_log_test.cpp index 00645c9ceee..a9887050958 100644 --- a/test/unit/math/prim/prob/neg_binomial_2_log_test.cpp +++ b/test/unit/math/prim/prob/neg_binomial_2_log_test.cpp @@ -212,7 +212,10 @@ TEST(ProbNegBinomial2, log_matches_lpmf) { TEST(ProbDistributionsNegBinomial2Log, neg_binomial_2_log_grid_test) { std::vector mu_log_to_test = {-101, -27, -3, -1, -0.132, 0, 4, 10, 87}; - std::vector phi_to_test = {2e-5, 0.36, 1, 2.3e5, 1.8e10, 6e16}; + // TODO(martinmodrak) Reducing the span of the test, should be fixed + // along with #1495 + // std::vector phi_to_test = {2e-5, 0.36, 1, 10, 2.3e5, 1.8e10, 6e16}; + std::vector phi_to_test = {0.36, 1, 10}; std::vector n_to_test = {0, 1, 10, 39, 101, 3048, 150054}; // TODO(martinmdorak) Only weak tolerance for this quick fix diff --git a/test/unit/math/prim/prob/neg_binomial_2_test.cpp b/test/unit/math/prim/prob/neg_binomial_2_test.cpp index e73eb227b0f..47853d4e0b8 100644 --- a/test/unit/math/prim/prob/neg_binomial_2_test.cpp +++ b/test/unit/math/prim/prob/neg_binomial_2_test.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include #include #include @@ -238,27 +239,40 @@ TEST(ProbDistributionsNegBinomial2, chiSquareGoodnessFitTest4) { } TEST(ProbDistributionsNegBinomial2, extreme_values) { - int N = 100; - double mu = 8; - double phi = 1e12; - for (int n = 0; n < 10; ++n) { - phi *= 10; - double logp = stan::math::neg_binomial_2_log(N, mu, phi); - EXPECT_LT(logp, 0); + std::vector n_to_test = {0, 1, 5, 100, 12985, 1968422}; + std::vector mu_to_test = {1e-5, 0.1, 8, 713, 28311, 19850054}; + for (double mu : mu_to_test) { + for (int n : n_to_test) { + // Test across a range of phi + for (double phi = 1e12; phi < 1e22; phi *= 10) { + double logp = stan::math::neg_binomial_2_log(n, mu, phi); + EXPECT_LT(logp, 0) << "n = " << n << ", mu = " << mu + << ", phi = " << phi; + } + } } } -TEST(ProbDistributionsNegBinomial2, vectorAroundCutoff) { - int y = 10; - double mu = 9.36; - std::vector phi; - phi.push_back(1); - phi.push_back(1e15); - double vector_value = stan::math::neg_binomial_2_lpmf(y, mu, phi); - double scalar_value = stan::math::neg_binomial_2_lpmf(y, mu, phi[0]) - + stan::math::neg_binomial_2_lpmf(y, mu, phi[1]); - - EXPECT_FLOAT_EQ(vector_value, scalar_value); +TEST(ProbDistributionsNegBinomial2, zeroOne) { + using stan::test::expect_near_rel; + + std::vector mu_to_test = {2.345e-5, 0.2, 13, 150, 1621, 18432, 1e10}; + double phi_start = 1e-8; + double phi_max = 1e22; + for (double mu : mu_to_test) { + for (double phi = phi_start; phi < phi_max; phi *= stan::math::pi()) { + std::stringstream msg; + msg << ", mu = " << mu << ", phi = " << phi; + + double expected_value_0 = phi * (-log1p(mu / phi)); + double value_0 = stan::math::neg_binomial_2_lpmf(0, mu, phi); + expect_near_rel("n = 0 " + msg.str(), value_0, expected_value_0); + + double expected_value_1 = (phi + 1) * (-log1p(mu / phi)) + log(mu); + double value_1 = stan::math::neg_binomial_2_lpmf(1, mu, phi); + expect_near_rel("n = 1 " + msg.str(), value_1, expected_value_1); + } + } } TEST(ProbDistributionsNegativeBinomial2Log, distributionCheck) { diff --git a/test/unit/math/rev/fun/lbeta_test.cpp b/test/unit/math/rev/fun/lbeta_test.cpp index f4ea4ccf6bb..0c015a2c770 100644 --- a/test/unit/math/rev/fun/lbeta_test.cpp +++ b/test/unit/math/rev/fun/lbeta_test.cpp @@ -74,11 +74,6 @@ TEST(MathFunctions, lbeta_identities_gradient) { // Successors: beta(a,b) = beta(a + 1, b) + beta(a, b + 1) for (double x : to_test) { for (double y : to_test) { - // TODO(martinmodrak) this restriction on testing should be lifted once - // the log_sum_exp bug (#1679) is resolved - if (x > 1e10 || y > 1e10) { - continue; - } auto rh = [](const var& a, const var& b) { return stan::math::log_sum_exp(lbeta(a + 1, b), lbeta(a, b + 1)); }; diff --git a/test/unit/math/rev/prob/neg_binomial_2_test.cpp b/test/unit/math/rev/prob/neg_binomial_2_test.cpp index 3dbc8cedd91..52542071c83 100644 --- a/test/unit/math/rev/prob/neg_binomial_2_test.cpp +++ b/test/unit/math/rev/prob/neg_binomial_2_test.cpp @@ -1,20 +1,449 @@ #include +#include #include +#include +#include #include +#include -TEST(ProbDistributionsNegBinomial2, derivatives) { +namespace neg_binomial_2_test_internal { +struct TestValue { + int n; + double mu; + double phi; + double value; + double grad_mu; + double grad_phi; +}; + +// Test data generated in Mathematica (Wolfram Cloud). The code can be re-ran at +// https://www.wolframcloud.com/obj/martin.modrak/Published/neg_binomial_2_lpmf.nb +// but is also presented below for convenience: +// +// toCString[x_] := ToString[CForm[N[x, 24]]]; +// nb2[n_,mu_,phi_]:= LogGamma[n + phi] - LogGamma[n + 1] - LogGamma[phi] + +// n * (Log[mu] - Log[mu + phi]) + phi * (Log[phi] - Log[mu + phi]) +// nb2dmu[n_,mu_,phi_]= D[nb2[n, mu, phi],mu]; +// nb2dphi[n_,mu_,phi_]= D[nb2[n, mu, phi],phi]; +// mus= SetPrecision[{0.0000256,0.314,1.5,3.0,180.0,1123.0,10586.0}, Infinity]; +// phis= SetPrecision[{0.0004,0.065,4.42,800.0, 15324.0}, Infinity]; +// ns = {0,6,14,1525,10233}; +// out = "std::vector testValues = {\n"; +// For[k = 1, k <= Length[ns], k++, { +// For[i = 1, i <= Length[mus], i++, { +// For[j = 1, j <= Length[phis], j++, { +// cmu = mus[[i]]; +// cphi = phis[[j]]; +// cn=ns[[k]]; +// val = nb2[cn,cmu,cphi]; +// ddmu= nb2dmu[cn,cmu,cphi]; +// ddphi= nb2dphi[cn,cmu,cphi]; +// out = StringJoin[out," {",ToString[cn],",",toCString[cmu],",", +// toCString[cphi],",", toCString[val],",",toCString[ddmu],",", +// toCString[ddphi],"},\n"]; +// }] +// }] +// }] +// out = StringJoin[out,"};\n"]; +// out +// +std::vector testValues = { + {0, 0.0000255999999999999988415517, 0.000400000000000000019168694, + -0.0000248141563677810554722069, -0.93984962406015038120717, + -0.00188501497960301691484332}, + {0, 0.0000255999999999999988415517, 0.065000000000000002220446, + -0.0000255949600924861392310515, -0.99960630889987943213184, + -7.75166869111834814851761e-8}, + {0, 0.0000255999999999999988415517, 4.41999999999999992894573, + -0.0000255999258645396469364285, -0.99999420817834177250031, + -1.67726638232395618551723e-11}, + {0, 0.0000255999999999999988415517, 800., -0.0000255999995904000075796848, + -0.999999968000001023999969, -5.11999978154667406760709e-16}, + {0, 0.0000255999999999999988415517, 15324., -0.0000255999999786165480692301, + -0.999999998329417909342659, -1.39542226236663819726744e-18}, + {0, 0.314000000000000001332268, 0.000400000000000000019168694, + -0.00266678271697168272833231, -0.00127226463104325705295582, + -5.66822905706024975839158}, + {0, 0.314000000000000001332268, 0.065000000000000002220446, + -0.11460468078714130271943, -0.171503957783641165200901, + -0.934652892970430377577097}, + {0, 0.314000000000000001332268, 4.41999999999999992894573, + -0.303348202309659617395206, -0.93367131389945078032175, + -0.00230212890163621522485358}, + {0, 0.314000000000000001332268, 800., -0.313938393619808992331269, + -0.999607653995806645889824, -7.69878314071302376498577e-8}, + {0, 0.314000000000000001332268, 15324., -0.313996782998787813729905, + -0.999979509686371605280254, -2.09929343978617832505718e-10}, + {0, 1.5, 0.000400000000000000019168694, -0.00329191110003275507221173, + -0.000266595574513463089285247, -7.23004434565640074923338}, + {0, 1.5, 0.065000000000000002220446, -0.206781499150110076176223, + -0.0415335463258785956090882, -2.22278737940449504349247}, + {0, 1.5, 4.41999999999999992894573, -1.29150964740387939616191, + -0.74662162162162161858047, -0.0388183744279291789704566}, + {0, 1.5, 800., -1.49859550534427827334156, -0.998128509045539613225203, + -1.75342721996106687969703e-6}, + {0, 1.5, 15324., -1.49992659053829698010225, -0.999902123911128511304688, + -4.79017695146261715054393e-9}, + {0, 3., 0.000400000000000000019168694, -0.00356911664958785467088415, + -0.000133315557925609925065371, -7.92292493952756185953978}, + {0, 3., 0.065000000000000002220446, -0.250472012600853859503051, + -0.0212071778140293644937542, -2.87462275628870399444374}, + {0, 3., 4.41999999999999992894573, -2.28973397601639585934409, + -0.595687331536388136290008, -0.113726692626070465736387}, + {0, 3., 800., -2.99438902306750149309853, -0.996264009962640099626401, + -6.99624147447649277415792e-6}, + {0, 3., 15324., -2.99970638131217673536622, -0.999804266979839498923468, + -1.91582075605741709616675e-8}, + {0, 180., 0.000400000000000000019168694, -0.00520680203358650246376672, + -2.22221728396159130295628e-6, -12.0170073061835394972106}, + {0, 180., 0.065000000000000002220446, -0.5152345838836882774678, + -0.000360980756948879583592839, -6.92704688665984518523008}, + {0, 180., 4.41999999999999992894573, -16.4913562250745366389837, + -0.023967031775295520717106, -2.75504310079668396055706}, + {0, 180., 800., -162.3526751973522458866, -0.816326530612244897959184, + -0.0192673746089352053174335}, + {0, 180., 15324., -178.951041022560568926078, -0.988390092879256965944272, + -0.0000679211892648339243088643}, + {0, 1123., 0.000400000000000000019168694, -0.00593912212871338103184145, + -3.56188653183026487941311e-7, -13.8478056779721050510975}, + {0, 1123., 0.065000000000000002220446, -0.634217014783734223797677, + -0.0000578773267798391030085681, -8.75724272015345987190877}, + {0, 1123., 4.41999999999999992894573, -24.4936395934523254888155, + -0.00392045555338737997305733, -4.54546787488649278326856}, + {0, 1123., 800., -701.624014336681947952778, -0.416016640665626625065003, + -0.293046658586479060005975}, + {0, 1123., 15324., -1083.75715139725783359741, -0.931720070529579862588922, + -0.00244293345030929573936621}, + {0, 10586., 0.000400000000000000019168694, -0.00683653348042399192327693, + -3.77857533426883319566749e-8, -16.0913337388457323318343}, + {0, 10586., 0.065000000000000002220446, -0.780043017108597613744853, + -6.14014744855619176912596e-6, -11.0006679418181806653922}, + {0, 10586., 4.41999999999999992894573, -34.3945190758501995485746, + -0.000417358329509122388817005, -6.7819827601055724937662}, + {0, 10586., 800., -2124.42246547447991619038, -0.0702617249253469172668189, + -1.72578980676844681250479}, + {0, 10586., 15324., -8048.29915678058041924358, -0.591431879583172520262447, + -0.116640647318788574768032}, + {6, 0.0000255999999999999988415517, 0.000400000000000000019168694, + -26.48036259722249074196, 220276.315789473695413011, + -11595.463497838709066521}, + {6, 0.0000255999999999999988415517, 0.065000000000000002220446, + -51.41938813644810311008, 234281.729042100352628204, + -74.6938075459730277737496}, + {6, 0.0000255999999999999988415517, 4.41999999999999992894573, + -67.5203813751119080113018, 234372.642547590685193838, + -0.431255924758490035983718}, + {6, 0.0000255999999999999988415517, 800., -69.9980790657427038228031, + 234373.99250003225060488, -0.0000233303844710378174910093}, + {6, 0.0000255999999999999988415517, 15324., -70.0158073212983957030198, + 234373.999608459003690188, -6.38614276575077488493694e-8}, + {6, 0.314000000000000001332268, 0.000400000000000000019168694, + -9.62519749441237577499515, 0.0230385244971718457666188, + 2477.53054955581647437693}, + {6, 0.314000000000000001332268, 0.065000000000000002220446, + -5.62316274126580026418416, 3.10564173234962949352322, + 0.811756664687527711460638}, + {6, 0.314000000000000001332268, 4.41999999999999992894573, + -11.7481186360489882881077, 16.9071818179371882584098, + -0.343526975436714230039528}, + {6, 0.314000000000000001332268, 800., -13.8270107248858319921459, + 18.1011755433762947868859, -0.0000204650172704751119244161}, + {6, 0.314000000000000001332268, 15324., -13.8425659569858167176518, + 18.107909210435378733015, -5.604917883004236869323e-8}, + {6, 1.5, 0.000400000000000000019168694, -9.61978396161923317848729, + 0.00079978672354038926785574, 2491.05377011516723211824}, + {6, 1.5, 0.065000000000000002220446, -4.84100453988744407299957, + 0.124600638977635786827265, 11.5208909282007767518774}, + {6, 1.5, 4.41999999999999992894573, -4.69471130087565244211116, + 2.23986486486486485574141, -0.126129611536089775574877}, + {6, 1.5, 800., -5.6375883578559464150792, 2.99438552713661883967561, + -0.0000110478695320263035818071}, + {6, 1.5, 15324., -5.64599569956598171509704, 2.99970637173338553391406, + -3.03295927479829163411156e-8}, + {6, 3., 0.000400000000000000019168694, -9.61926132713561020240632, + 0.000133315557925609925065371, 2492.36008977011386837543}, + {6, 3., 0.065000000000000002220446, -4.75878244235486509222263, + 0.0212071778140293644937542, 12.7453357216411121474773}, + {6, 3., 4.41999999999999992894573, -2.88910819582874566312831, + 0.595687331536388136290008, 0.00385024685205873591732729}, + {6, 3., 800., -2.98571724360242195248096, 0.996264009962640099626401, + -2.3069406647953774852549e-6}, + {6, 3., 15324., -2.98747963477610734281426, 0.999804266979839498923468, + -6.38245840446628643263247e-9}, + {6, 180., 0.000400000000000000019168694, -9.62011239916672045724397, + -2.14814337449620492619107e-6, 2490.23240751308261564996}, + {6, 180., 0.065000000000000002220446, -4.89709960487442545538356, + -0.000348948065050583597473077, 10.6171759350004771903562}, + {6, 180., 4.41999999999999992894573, -11.8028833191445189961465, + -0.0231681307161190033598691, -1.86137525666548784745946}, + {6, 180., 800., -138.973123190861697819951, -0.789115646258503401360544, + -0.0179131542129975601896404}, + {6, 180., 15324., -154.441639359481734262519, -0.955443756449948400412797, + -0.0000634392772194434652186986}, + {6, 1123., 0.000400000000000000019168694, -9.62083352307562850466433, + -3.54285597155334449715445e-7, 2488.42959957075572856696}, + {6, 1123., 0.065000000000000002220446, -5.01426303422777693881116, + -0.0000575680979635621354056728, 8.81495887967631239323725}, + {6, 1123., 4.41999999999999992894573, -19.6831821663032128008537, + -0.0038995092191751588868255, -3.62458748384772905786375}, + {6, 1123., 800., -671.304184750646975825767, -0.41379393376981739999787, + -0.288690114015941777831476}, + {6, 1123., 15324., -1048.61720731875393149901, -0.926742047000481483982035, + -0.00241626280594218362665218}, + {6, 10586., 0.000400000000000000019168694, -9.62172902400956374565605, + -3.77643368945439780938618e-8, 2486.19084755337970635833}, + {6, 10586., 0.065000000000000002220446, -5.15977860354001469077731, + -6.13666729697000839952321e-6, 6.57630939764217156062813}, + {6, 10586., 4.41999999999999992894573, -29.5629973572818965727548, + -0.00041712177651676883371282, -5.85634703346878398423394}, + {6, 10586., 800., -2091.3124520859014907288, -0.0702219015407302460497774, + -1.71884010032985743252181}, + {6, 10586., 15324., -8002.42495594488889877697, -0.59109666408369216553719, + -0.116480739324794834878369}, + {14, 0.0000255999999999999988415517, 0.000400000000000000019168694, + -49.8145624383586412504773, 513979.323308270702830867, + -30391.5592215308054935068}, + {14, 0.0000255999999999999988415517, 0.065000000000000002220446, + -114.928171090288373621844, 546658.700573312689305051, + -196.832392093265922277751}, + {14, 0.0000255999999999999988415517, 4.41999999999999992894573, + -161.649312877052269679471, 546870.832603322503241318, + -1.65012869090202525000607}, + {14, 0.0000255999999999999988415517, 800., -173.098986627797207395457, + 546873.982500032584746084, -0.000140607279245208991169695}, + {14, 0.0000255999999999999988415517, 15324., -173.206165042629407730008, + 546873.999086402114500983, -3.87293689171005773693605e-7}, + {14, 0.314000000000000001332268, 0.000400000000000000019168694, + -10.4823212403875677919956, 0.0554529100014586495260517, + 2452.98201444586163091021}, + {14, 0.314000000000000001332268, 0.065000000000000002220446, + -7.91753972042693522598031, 7.47516931919399037182204, + -19.4065385144529681904019}, + {14, 0.314000000000000001332268, 4.41999999999999992894573, + -31.1096030160603968883541, 40.6949859937193736433852, + -1.4423583039713969676008}, + {14, 0.314000000000000001332268, 800., -41.6146101076375483341583, + 43.5688864732057633639203, -0.000133818772002702504501651}, + {14, 0.314000000000000001332268, 15324., -41.7166402819016445012403, + 43.5850941705977125174086, -3.68785202783475625902428e-7}, + {14, 1.5, 0.000400000000000000019168694, -10.4688561593594869703521, + 0.00222162978761219241071039, 2486.61861613580826696107}, + {14, 1.5, 0.065000000000000002220446, -5.96960149272681241527387, + 0.346112886048988296742402, 7.29895408232336579179109}, + {14, 1.5, 4.41999999999999992894573, -13.3341039738636535426733, + 6.22184684684684682150392, -0.886409460834882622189145}, + {14, 1.5, 800., -20.9264151022685229463224, 8.31773757537949677687669, + -0.000109610174761583369512834}, + {14, 1.5, 15324., -21.0100699378693506067827, 8.3325176992594042608724, + -3.02665838448691571323444e-7}, + {14, 3., 0.000400000000000000019168694, -10.4672670714982932266983, + 0.000488823712393903058573027, 2490.59053612251196637491}, + {14, 3., 0.065000000000000002220446, -5.71949591388313641569191, + 0.0777596519847743364770988, 11.0251057695297603161259}, + {14, 3., 4.41999999999999992894573, -7.79006429060418244080376, + 2.18418688230008983306336, -0.483245366998347713019428}, + {14, 3., 800., -12.7443244787126682405839, 3.6529680365296803652968, + -0.0000822242550653573074282639}, + {14, 3., 15324., -12.8071593989944707311042, 3.66594897892607816271938, + -2.27631817501670696455485e-7}, + {14, 180., 0.000400000000000000019168694, -10.467069325725552290929, + -2.04937816187568975717079e-6, 2491.08472067831368019905}, + {14, 180., 0.065000000000000002220446, -5.68921919805166470578108, + -0.000332904475852855615980062, 11.4626317745298003372252}, + {14, 180., 4.41999999999999992894573, -9.65337657959679917832682, + -0.0221029293038836468835533, -1.31368299764513841447103}, + {14, 180., 800., -117.570556789937986073345, -0.752834467120181405895692, + -0.0161936967338941701029113}, + {14, 180., 15324., -131.5984192349764106228, -0.911515307877536979704162, + -0.0000577016781835615346878727}, + {14, 1123., 0.000400000000000000019168694, -10.4677755213861685634857, + -3.51748189118411732080956e-7, 2489.31923330860236441056}, + {14, 1123., 0.065000000000000002220446, -5.80395729200942357200564, + -0.0000571557928751928452684791, 9.6977197567649020595191}, + {14, 1123., 4.41999999999999992894573, -17.3710293984633662562012, + -0.00387158077355886410518306, -3.04061182895062280846571}, + {14, 1123., 800., -640.647914910330698665346, -0.41083032457540509990836, + -0.282967557637372205015805}, + {14, 1123., 15324., -1011.59993064034851794618, -0.920104682295016979172854, + -0.00238094023047733949128954}, + {14, 10586., 0.000400000000000000019168694, -10.4686684750964033112778, + -3.77357816303515062767776e-8, 2487.08684934922314896474}, + {14, 10586., 0.065000000000000002220446, -5.94905895063816047333042, + -6.13202709485509724005287e-6, 7.46543792757153450807576}, + {14, 10586., 4.41999999999999992894573, -27.2227588675497963109583, + -0.000416806372526964093573906, -5.26603093110764468933264}, + {14, 10586., 800., -2056.93593717552727578896, -0.0701688036945746844270555, + -1.70965999479388506262815}, + {14, 10586., 15324., -7951.09533692356666073317, + -0.590649710084385025903513, -0.116267766949827820373742}, + {1525, 0.0000255999999999999988415517, 0.000400000000000000019168694, + -4301.78472746701591561319, 5.59871348684210554788785e7, + -3.5806687876344078642494e6}, + {1525, 0.0000255999999999999988415517, 0.065000000000000002220446, + -11965.4671281234822402701, 5.95468591985310427791447e7, + -23429.1126891147666042428}, + {1525, 0.0000255999999999999988415517, 4.41999999999999992894573, + -18367.3428918005844581777, 5.95699664793796696369591e7, + -339.05708253439450821354}, + {1525, 0.0000255999999999999988415517, 800., -24826.2281356385314540934, + 5.95703095937500956956661e7, -0.838976287617641050521021}, + {1525, 0.0000255999999999999988415517, 15324., -25707.7170288016101268158, + 5.95703114004829071688901e7, -0.00464306291332368931093165}, + {1525, 0.314000000000000001332268, 0.000400000000000000019168694, + -17.0947387955188070694514, 6.17771997212362870958894, + -2348.27115497920553780678}, + {1525, 0.314000000000000001332268, 0.065000000000000002220446, + -296.585568286099074874008, 832.769692284422651260512, + -4001.49218112616936417653}, + {1525, 0.314000000000000001332268, 4.41999999999999992894573, + -4115.10160678526157468258, 4533.61649969457963823061, + -316.176485619091868439549}, + {1525, 0.314000000000000001332268, 800., -10469.9426767525578568447, + 4853.78278834474664085128, -0.838228516034976412378816}, + {1525, 0.314000000000000001332268, 15324., -11350.864477542964439615, + 4855.58840352125850604476, -0.00464102415296814454779679}, + {1525, 1.5, 0.000400000000000000019168694, -15.5605375416167689753134, + 0.270772238514174011017382, 1484.28030775203761395908}, + {1525, 1.5, 0.065000000000000002220446, -74.4509275868752818011125, + 42.1842385516506936069639, -953.474508334327861816359}, + {1525, 1.5, 4.41999999999999992894573, -2072.19103657901873983403, + 758.318693693693690604898, -251.676626135276650941522}, + {1525, 1.5, 800., -8088.5487996808837996978, 1013.76585568725306716573, + -0.835410572412918435298525}, + {1525, 1.5, 15324., -8967.33164084000397251235, 1015.56725718573619131513, + -0.00463332752548632640449661}, + {1525, 3., 0.000400000000000000019168694, -15.3575220720668965063351, + 0.0676354263875927686498315, 1991.71749039935661708657}, + {1525, 3., 0.065000000000000002220446, -42.4918294753981563747316, + 10.7591082109842309198313, -477.238467087057049352225}, + {1525, 3., 4.41999999999999992894573, -1360.54978818586118224694, + 302.212039532794247811131, -199.675789571126135171012}, + {1525, 3., 800., -7035.84649923785032031894, 505.437941054379410543794, + -0.831861613850395752926704}, + {1525, 3., 15324., -7911.93122400832367818281, 507.234031447771905787173, + -0.00462360345575814251935502}, + {1525, 180., 0.000400000000000000019168694, -15.1592288635917619446729, + 0.0000166049013718241127915344, 2487.41676922898488796714}, + {1525, 180., 0.065000000000000002220446, -10.6183839860155018633925, + 0.00269732843386801688851316, 7.79296281407547718024725}, + {1525, 180., 4.41999999999999992894573, -30.7569326420675970103087, + 0.179086987432069307580598, -5.06066771330890867384431}, + {1525, 180., 800., -1255.10606104316181434609, 6.09977324263038548752834, + -0.50811617220616562857323}, + {1525, 180., 15324., -1861.51726736161923065929, 7.38547041623667010663915, + -0.00355560001152251827139362}, + {1525, 1123., 0.000400000000000000019168694, -15.1571154863562692398625, + 1.27504753855366561133043e-7, 2492.70020501203960042957}, + {1525, 1123., 0.065000000000000002220446, -10.2750368571307051553926, + 0.0000207183306905568293939843, 13.074039765317042756658}, + {1525, 1123., 4.41999999999999992894573, -7.75481686725872855763217, + 0.00140340439221881277753254, 0.0654298516080506315262451}, + {1525, 1123., 800., -30.3901820482837344036952, 0.148921362019218079497891, + -0.0188047284729684005572389}, + {1525, 1123., 15324., -64.3938471491118556432688, + 0.333527576449591366661395, -0.000290976140796059790385352}, + {1525, 10586., 0.000400000000000000019168694, -15.1575273331900733345777, + -3.23424061059983918249983e-8, 2491.67058800680726231035}, + {1525, 10586., 0.065000000000000002220446, -10.341961135413218891829, + -5.25560892040125199509261e-6, 12.0444483664247286854037}, + {1525, 10586., 4.41999999999999992894573, -12.3018556139457377868195, + -0.00035723444395259379983666, -0.962437235779729779911197}, + {1525, 10586., 800., -744.016916706287313438782, + -0.0601399480019429829354474, -0.792452568525027960065769}, + {1525, 10586., 15324., -4300.64559341391477270059, + -0.506231273465249027592861, -0.080624196979174112177973}, + {10233, 0.0000255999999999999988415517, 0.000400000000000000019168694, + -28781.0708528951573397343, 3.75682858552631598053215e8, + -2.40411931995216795644478e7}, + {10233, 0.0000255999999999999988415517, 0.065000000000000002220446, + -80237.4790550341092525423, 3.99569192710255671251893e8, + -157343.697098361009501263}, + {10233, 0.0000255999999999999988415517, 4.41999999999999992894573, + -123371.163364167740152869, 3.99724246355043753581641e8, + -2307.27997929395105479548}, + {10233, 0.0000255999999999999988415517, 800., -173733.65594636811650329, + 3.99726548708750459408368e8, -10.1666354466535988554224}, + {10233, 0.0000255999999999999988415517, 15324., -189611.393943800924397874, + 3.99726560832223983286441e8, -0.156271946909903467303064}, + {10233, 0.314000000000000001332268, 0.000400000000000000019168694, + -30.0835346409569131744303, 41.4607785935398146017316, + -30043.5682705897920756525}, + {10233, 0.314000000000000001332268, 0.065000000000000002220446, + -1936.716591758487041006, 5589.00047056450946728878, + -26975.8416131385263827685}, + {10233, 0.314000000000000001332268, 4.41999999999999992894573, + -27734.5558869327518549392, 30426.6413450334884297763, + -2153.73427754141249439361}, + {10233, 0.314000000000000001332268, 800., -77398.8345343110590054878, + 32575.3861354641231869532, -10.1616173371352784115513}, + {10233, 0.314000000000000001332268, 15324., -93272.7669155443506617326, + 32587.5042326579588303572, -0.156258265294963792040359}, + {10233, 1.5, 0.000400000000000000019168694, -19.7852231333246170201452, + 1.81844841375633173201467, -4317.60144720493538394871}, + {10233, 1.5, 0.065000000000000002220446, -445.630392408020282583574, + 283.300319488817900649591, -6515.78789458981692128489}, + {10233, 1.5, 4.41999999999999992894573, -14020.6484936315299566843, + 5092.70608108108106033739, -1720.72113290933113059977}, + {10233, 1.5, 800., -61412.5266404072312643186, 6808.23456019962570180911, + -10.1426989007295637839166}, + {10233, 1.5, 15324., -77272.0989845637286907938, 6820.33238719780757560928, + -0.156206593508853883167535}, + {10233, 3., 0.000400000000000000019168694, -18.4213731622889640481056, + 0.454606052526329844472915, -908.658303440053134799723}, + {10233, 3., 0.065000000000000002220446, -230.930124889414052187924, + 72.3164763458401329237019, -3316.44389947819074719284}, + {10233, 3., 4.41999999999999992894573, -9239.71902985399613363504, + 2031.29380053908354474893, -1371.35925605961202095728}, + {10233, 3., 800., -54340.1804040886113152815, 3397.26027397260273972603, + -10.1188548696496798960218}, + {10233, 3., 15324., -70181.625182690519906018, 3409.33255040125269132903, + -0.156141261363057784911755}, + {10233, 180., 0.000400000000000000019168694, -17.0814417743218085107483, + 0.000124110835309254874270108, 2440.94300115825922532467}, + {10233, 180., 0.065000000000000002220446, -15.5422428149329861379767, + 0.02016077527559492474366, -38.6634853761835068674979}, + {10233, 180., 4.41999999999999992894573, -235.497379149446594201917, + 1.33855872465025483205037, -50.3775345819770169845521}, + {10233, 180., 800., -14640.5024630709594885282, 45.5918367346938775510204, + -7.83648996527639800152069}, + {10233, 180., 15324., -28577.6446971680821285066, 55.2015866873065015479876, + -0.148587051440887372280679}, + {10233, 1123., 0.000400000000000000019168694, -17.0630789988207188668825, + 2.88947340204574470627368e-6, 2486.84988023336329186324}, + {10233, 1123., 0.065000000000000002220446, -12.5589181079456256044661, + 0.000469512419380529143729346, 7.22412495831966508988538}, + {10233, 1123., 4.41999999999999992894573, -35.4550615786577835917049, + 0.031803517445555682595327, -5.7569606052102630173202}, + {10233, 1123., 800., -3343.1303902972739556477, 3.37480997013700672692981, + -2.98980536947426087305154}, + {10233, 1123., 15324., -11352.0607180353268830661, + 7.55829905834770485145599, -0.113119180727235553840235}, + {10233, 10586., 0.000400000000000000019168694, -17.0607181926565361138193, + -1.26000103249281892884057e-9, 2492.75189435765511506763}, + {10233, 10586., 0.065000000000000002220446, -12.175300607237463417773, + -2.04748918320454911628704e-7, 13.1257236766090675675961}, + {10233, 10586., 4.41999999999999992894573, -9.43079204562662170966662, + -0.0000139172010501341586295487, 0.116749372001926601548071}, + {10233, 10586., 800., -7.27036619721226178082317, + -0.002342942461614156602606, 0.0000890483066241868231597611}, + {10233, 10586., 15324., -9.3278240351663417156768, + -0.0197218452194275363359762, -0.0000805948075011592832299515}, +}; +} // namespace neg_binomial_2_test_internal + +TEST(ProbDistributionsNegativeBinomial2, derivativesPrecomputed) { + using neg_binomial_2_test_internal::TestValue; + using neg_binomial_2_test_internal::testValues; using stan::math::is_nan; - using stan::math::neg_binomial_2_lpmf; + using stan::math::neg_binomial_2_log; + using stan::math::value_of; using stan::math::var; - int N = 100; - double mu_dbl = 8; - double phi_dbl = 1.5; - - for (int k = 0; k < 20; ++k) { - var mu(mu_dbl); - var phi(phi_dbl); - var val = neg_binomial_2_lpmf(N, mu, phi); + for (TestValue t : testValues) { + int n = t.n; + var mu(t.mu); + var phi(t.phi); + var val = neg_binomial_2_log(n, mu, phi); std::vector x; x.push_back(mu); @@ -27,22 +456,146 @@ TEST(ProbDistributionsNegBinomial2, derivatives) { EXPECT_FALSE(is_nan(gradients[i])); } - std::vector finite_diffs; - double eps = 1e-10; - double inv2e = 0.5 / eps; + auto tolerance = [](double x) { return std::max(fabs(x * 1e-8), 1e-14); }; - double dmu = neg_binomial_2_lpmf(N, mu_dbl + eps, phi_dbl) - - neg_binomial_2_lpmf(N, mu_dbl - eps, phi_dbl); - double dphi = neg_binomial_2_lpmf(N, mu_dbl, phi_dbl + eps) - - neg_binomial_2_lpmf(N, mu_dbl, phi_dbl - eps); - finite_diffs.push_back(dmu * inv2e); - finite_diffs.push_back(dphi * inv2e); + EXPECT_NEAR(value_of(val), t.value, tolerance(t.value)) + << "value n = " << n << ", mu = " << t.mu << ", phi = " << t.phi; + EXPECT_NEAR(gradients[0], t.grad_mu, tolerance(t.grad_mu)) + << "grad_mu n = " << n << ", mu = " << t.mu << ", phi = " << t.phi; + EXPECT_NEAR(gradients[1], t.grad_phi, tolerance(t.grad_phi)) + << "grad_phi n = " << n << ", mu = " << t.mu << ", phi = " << t.phi; + } +} - for (int i = 0; i < 2; ++i) { - EXPECT_NEAR(gradients[i], finite_diffs[i], 1.0); +TEST(ProbDistributionsNegBinomial2, derivativesComplexStep) { + using boost::math::differentiation::complex_step_derivative; + using stan::math::is_nan; + using stan::math::neg_binomial_2_log; + using stan::math::var; + using stan::test::internal::expect_near_rel_finite; + + std::vector n_to_test = {0, 7, 100, 835, 14238, 385000, 1000000}; + std::vector mu_to_test = {0.8, 8, 24, 271, 2586, 33294}; + + auto nb2_log_for_test = [](int n, const std::complex& mu, + const std::complex& phi) { + // Using first-order Taylor expansion of lgamma(a + b*i) around b = 0 + // Which happens to work nice in this case, as b is always 0 or the very + // small complex step + auto lgamma_c_approx = [](const std::complex& x) { + return std::complex(lgamma(x.real()), + x.imag() * boost::math::digamma(x.real())); + }; + + const double n_(n); + return lgamma_c_approx(n_ + phi) - lgamma(n + 1) - lgamma_c_approx(phi) + + phi * (log(phi) - log(mu + phi)) - n_ * log(mu + phi) + + n_ * log(mu); + }; + + for (double mu_dbl : mu_to_test) { + for (int n : n_to_test) { + for (double phi_dbl = 1.5; phi_dbl < 1e22; phi_dbl *= 10) { + var mu(mu_dbl); + var phi(phi_dbl); + var val = neg_binomial_2_lpmf(n, mu, phi); + + std::vector x; + x.push_back(mu); + x.push_back(phi); + + std::vector gradients; + val.grad(x, gradients); + + EXPECT_TRUE(value_of(val) < 0) + << "for n = " << n << ", mu = " << mu_dbl << ", phi = " << phi_dbl; + + for (int i = 0; i < 2; ++i) { + EXPECT_FALSE(is_nan(gradients[i])); + } + + auto nb2_log_mu + = [n, phi_dbl, nb2_log_for_test](const std::complex& mu) { + return nb2_log_for_test(n, mu, phi_dbl); + }; + auto nb2_log_phi + = [n, mu_dbl, nb2_log_for_test](const std::complex& phi) { + return nb2_log_for_test(n, mu_dbl, phi); + }; + double complex_step_dmu = complex_step_derivative(nb2_log_mu, mu_dbl); + double complex_step_dphi + = complex_step_derivative(nb2_log_phi, phi_dbl); + + std::stringstream message; + message << ", n = " << n << ", mu = " << mu_dbl + << ", phi = " << phi_dbl; + + double tolerance_phi = std::max(1e-8, fabs(gradients[1]) * 1e-5); + double tolerance_mu = std::max(1e-10, fabs(gradients[0]) * 1e-8); + + EXPECT_NEAR(gradients[0], complex_step_dmu, tolerance_mu) + << "grad_mu" << message.str(); + + EXPECT_NEAR(gradients[1], complex_step_dphi, tolerance_phi) + << "grad_phi" << message.str(); + } } + } +} + +TEST(ProbDistributionsNegBinomial2, derivativesZeroOne) { + using stan::math::var; + using stan::test::expect_near_rel; + + std::vector mu_to_test = {2.345e-5, 0.2, 13, 150, 1621, 18432, 1e10}; + double phi_start = 1e-8; + double phi_max = 1e20; + for (double mu_dbl : mu_to_test) { + for (double phi_dbl = phi_start; phi_dbl < phi_max; + phi_dbl *= stan::math::pi()) { + std::stringstream msg; + msg << std::setprecision(20) << ", mu = " << mu_dbl + << ", phi = " << phi_dbl; - phi_dbl *= 10; + var mu0(mu_dbl); + var phi0(phi_dbl); + var val0 = neg_binomial_2_lpmf(0, mu0, phi0); + + std::vector x0; + x0.push_back(mu0); + x0.push_back(phi0); + + std::vector gradients0; + val0.grad(x0, gradients0); + + var mu1(mu_dbl); + var phi1(phi_dbl); + var val1 = neg_binomial_2_lpmf(1, mu1, phi1); + + std::vector x1; + x1.push_back(mu1); + x1.push_back(phi1); + + std::vector gradients1; + val1.grad(x1, gradients1); + + double expected_dmu_0 = -phi_dbl / (mu_dbl + phi_dbl); + double expected_dphi_0 + = mu_dbl / (mu_dbl + phi_dbl) - log1p(mu_dbl / phi_dbl); + expect_near_rel("dmu, n = 0 " + msg.str(), gradients0[0], expected_dmu_0); + expect_near_rel("dphi, n = 0 " + msg.str(), gradients0[1], + expected_dphi_0); + + double expected_dmu_1 + = (phi_dbl * (1 - mu_dbl)) / (mu_dbl * (mu_dbl + phi_dbl)); + expect_near_rel("dmu, n = 1 " + msg.str(), gradients1[0], expected_dmu_1); + + double expected_dphi_1 + = mu_dbl * (phi_dbl + 1) / (phi_dbl * (mu_dbl + phi_dbl)) + + log(phi_dbl) - log(mu_dbl + phi_dbl); + expect_near_rel("dphi, n = 1 " + msg.str(), gradients1[1], + expected_dphi_1); + } } }