Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ pom.xml.asc
/out/
/.clj-kondo/.cache/
/cljs-test-runner-out/
.lsp
12 changes: 12 additions & 0 deletions src/gensql/inference/gpm/conditioned.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@
(let [merged-conditions (merge conditions simulate-conditions)]
(gpm.proto/simulate gpm targets merged-conditions)))

gpm.proto/LogProb
(logprob [_ event]
(let [
expression_list (map (fn [[variable value]] `(~'= ~variable ~value)) conditions)
conditions_event
(if (< 1 (count expression_list))
`(~'and ~@expression_list)
(first expression_list))
merged-event `(~'and ~conditions_event ~event)
]
(gpm.proto/logprob gpm merged-event)))

gpm.proto/Variables
(variables [_]
(gpm.proto/variables gpm))
Expand Down
12 changes: 6 additions & 6 deletions src/gensql/inference/gpm/crosscat.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
;; If targets are the same as constraints, the logpdf is 0.
(cond
(= targets constraints)
0.0
0.0 ;; Should be ##Inf for continuous functions
;; If the targets and constraints are not equal but the overlapping parts are,
;; just remove the overlapping keys and recur the scores.
(every? (fn [shared-key]
Expand All @@ -102,11 +102,11 @@
;; Catch overlap of targets and constraints and assure constraint is sampled.
(let [intersection (set/intersection (set targets) (set (keys constraints)))
unconstrained-targets (vec (remove intersection (set targets)))]
(->> views
(map (fn [[_ view]]
(gpm.proto/simulate view unconstrained-targets constraints)))
(filter not-empty)
(apply merge (select-keys constraints intersection)))))
(->> views
(map (fn [[_ view]]
(gpm.proto/simulate view unconstrained-targets constraints)))
(filter not-empty)
(apply merge (select-keys constraints intersection)))))
gpm.proto/Incorporate
(incorporate [this x]
(let [row-id (gensym)]
Expand Down
15 changes: 10 additions & 5 deletions src/gensql/inference/gpm/ensemble.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,27 @@
(.sample ed)))

(defrecord Ensemble [gpms]
gpm.proto/LogProb
(logprob [_ event]
(let [logprobs (map #(gpm.proto/logprob % event) gpms)]
(utils/logmeanexp logprobs)))

gpm.proto/GPM
(simulate [_ targets constraints]
(let [gpm (if-not (seq constraints)
(rand-nth gpms)
(weighted-sample
(zipmap gpms
(map #(gpm.proto/logpdf % constraints {})
gpms))))]
(zipmap gpms
(map #(gpm.proto/logpdf % constraints {})
gpms))))]
(gpm.proto/simulate gpm targets constraints)))

(logpdf [_ targets constraints]
(let [logpdfs (map #(gpm.proto/logpdf % targets constraints) gpms)]
(if (seq constraints)
(utils/logmeanexp-weighted (map #(gpm.proto/logpdf % constraints {}) gpms)
logpdfs)
(utils/logmeanexp logpdfs))))
logpdfs)
(utils/logmeanexp logpdfs))))

gpm.proto/Variables
(variables [_]
Expand Down
2 changes: 1 addition & 1 deletion src/gensql/inference/gpm/primitive_gpms/categorical.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
constrained? (if (= x x') 0 ##-Inf)
:else (let [counts (:counts suff-stats)
alpha (:alpha hyperparameters)
numer (math/log (+ alpha (get counts x)))
numer (math/log (+ alpha (get counts x 0)))
denom (math/log (+ (* alpha (count counts))
(reduce + (vals counts))))]
(- numer denom)))))
Expand Down
11 changes: 11 additions & 0 deletions test/gensql/inference/gpm/conditioned_test.cljc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@
{:x 0 :z 2} {:y 1} {:x 0 :y 1 :z 2}
{:x 0} {:y 1 :z 2} {:x 0 :y 1 :z 2}))

(deftest logprob-conditions
(are [condition-conditions event expected]
(let [model (reify gpm.proto/LogProb
(logprob [_ actual]
actual))
conditioned-model (conditioned/condition model condition-conditions)]
(= expected (gpm/logprob conditioned-model event)))

{:x 0} '(= :y 0) '(and (= :x 0) (= :y 0))
{:x 0 :y 0} '(= :z 1) '(and (and (= :x 0) (= :y 0)) (= :z 1))))

(deftest simulate-conditions
(are [c1 c2 expected]
(let [model (reify gpm.proto/GPM
Expand Down