diff --git a/.gitignore b/.gitignore index 4ceb2bb..65210fe 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ pom.xml.asc /out/ /.clj-kondo/.cache/ /cljs-test-runner-out/ +.lsp diff --git a/src/gensql/inference/gpm/conditioned.cljc b/src/gensql/inference/gpm/conditioned.cljc index e63d326..1929335 100644 --- a/src/gensql/inference/gpm/conditioned.cljc +++ b/src/gensql/inference/gpm/conditioned.cljc @@ -11,6 +11,18 @@ (let [merged-conditions (merge conditions simulate-conditions)] (gpm.proto/simulate gpm targets merged-conditions))) + gpm.proto/LogProb + (logprob [_ event] + (let [ + expression_list (map (fn [[variable value]] `(~'= ~variable ~value)) conditions) + conditions_event + (if (< 1 (count expression_list)) + `(~'and ~@expression_list) + (first expression_list)) + merged-event `(~'and ~conditions_event ~event) + ] + (gpm.proto/logprob gpm merged-event))) + gpm.proto/Variables (variables [_] (gpm.proto/variables gpm)) diff --git a/src/gensql/inference/gpm/crosscat.cljc b/src/gensql/inference/gpm/crosscat.cljc index a6e5dd3..e184fba 100644 --- a/src/gensql/inference/gpm/crosscat.cljc +++ b/src/gensql/inference/gpm/crosscat.cljc @@ -81,7 +81,7 @@ ;; If targets are the same as constraints, the logpdf is 0. (cond (= targets constraints) - 0.0 + 0.0 ;; Should be ##Inf for continuous functions ;; If the targets and constraints are not equal but the overlapping parts are, ;; just remove the overlapping keys and recur the scores. (every? (fn [shared-key] @@ -102,11 +102,11 @@ ;; Catch overlap of targets and constraints and assure constraint is sampled. (let [intersection (set/intersection (set targets) (set (keys constraints))) unconstrained-targets (vec (remove intersection (set targets)))] - (->> views - (map (fn [[_ view]] - (gpm.proto/simulate view unconstrained-targets constraints))) - (filter not-empty) - (apply merge (select-keys constraints intersection))))) + (->> views + (map (fn [[_ view]] + (gpm.proto/simulate view unconstrained-targets constraints))) + (filter not-empty) + (apply merge (select-keys constraints intersection))))) gpm.proto/Incorporate (incorporate [this x] (let [row-id (gensym)] diff --git a/src/gensql/inference/gpm/ensemble.cljc b/src/gensql/inference/gpm/ensemble.cljc index b45c20d..05ba026 100644 --- a/src/gensql/inference/gpm/ensemble.cljc +++ b/src/gensql/inference/gpm/ensemble.cljc @@ -21,22 +21,27 @@ (.sample ed))) (defrecord Ensemble [gpms] + gpm.proto/LogProb + (logprob [_ event] + (let [logprobs (map #(gpm.proto/logprob % event) gpms)] + (utils/logmeanexp logprobs))) + gpm.proto/GPM (simulate [_ targets constraints] (let [gpm (if-not (seq constraints) (rand-nth gpms) (weighted-sample - (zipmap gpms - (map #(gpm.proto/logpdf % constraints {}) - gpms))))] + (zipmap gpms + (map #(gpm.proto/logpdf % constraints {}) + gpms))))] (gpm.proto/simulate gpm targets constraints))) (logpdf [_ targets constraints] (let [logpdfs (map #(gpm.proto/logpdf % targets constraints) gpms)] (if (seq constraints) (utils/logmeanexp-weighted (map #(gpm.proto/logpdf % constraints {}) gpms) - logpdfs) - (utils/logmeanexp logpdfs)))) + logpdfs) + (utils/logmeanexp logpdfs)))) gpm.proto/Variables (variables [_] diff --git a/src/gensql/inference/gpm/primitive_gpms/categorical.cljc b/src/gensql/inference/gpm/primitive_gpms/categorical.cljc index c635808..5a69ae7 100644 --- a/src/gensql/inference/gpm/primitive_gpms/categorical.cljc +++ b/src/gensql/inference/gpm/primitive_gpms/categorical.cljc @@ -16,7 +16,7 @@ constrained? (if (= x x') 0 ##-Inf) :else (let [counts (:counts suff-stats) alpha (:alpha hyperparameters) - numer (math/log (+ alpha (get counts x))) + numer (math/log (+ alpha (get counts x 0))) denom (math/log (+ (* alpha (count counts)) (reduce + (vals counts))))] (- numer denom))))) diff --git a/test/gensql/inference/gpm/conditioned_test.cljc b/test/gensql/inference/gpm/conditioned_test.cljc index f334cfb..748e587 100644 --- a/test/gensql/inference/gpm/conditioned_test.cljc +++ b/test/gensql/inference/gpm/conditioned_test.cljc @@ -55,6 +55,17 @@ {:x 0 :z 2} {:y 1} {:x 0 :y 1 :z 2} {:x 0} {:y 1 :z 2} {:x 0 :y 1 :z 2})) +(deftest logprob-conditions + (are [condition-conditions event expected] + (let [model (reify gpm.proto/LogProb + (logprob [_ actual] + actual)) + conditioned-model (conditioned/condition model condition-conditions)] + (= expected (gpm/logprob conditioned-model event))) + + {:x 0} '(= :y 0) '(and (= :x 0) (= :y 0)) + {:x 0 :y 0} '(= :z 1) '(and (and (= :x 0) (= :y 0)) (= :z 1)))) + (deftest simulate-conditions (are [c1 c2 expected] (let [model (reify gpm.proto/GPM