diff --git a/_CoqProject b/_CoqProject index 9521968f49..1b5e659c18 100644 --- a/_CoqProject +++ b/_CoqProject @@ -45,6 +45,12 @@ theories/itv.v theories/convex.v theories/charge.v theories/kernel.v +theories/prob_lang.v +theories/prob_lang_wip.v +theories/lang_syntax_util.v +theories/lang_syntax_toy.v +theories/lang_syntax.v +theories/lang_syntax_examples.v theories/altreals/xfinmap.v theories/altreals/discrete.v theories/altreals/realseq.v diff --git a/coq-mathcomp-analysis.opam b/coq-mathcomp-analysis.opam index cf729d78ed..af86774c4f 100644 --- a/coq-mathcomp-analysis.opam +++ b/coq-mathcomp-analysis.opam @@ -22,6 +22,7 @@ depends: [ "coq-mathcomp-solvable" { (>= "1.15.0" & < "1.18~") | (= "dev") } "coq-mathcomp-field" "coq-mathcomp-bigenough" { (>= "1.0.0") } + "coq-mathcomp-algebra-tactics" { (>= "1.1.1") } ] tags: [ diff --git a/theories/Make b/theories/Make index cd6285c45a..12d3f3af7b 100644 --- a/theories/Make +++ b/theories/Make @@ -36,6 +36,12 @@ itv.v convex.v charge.v kernel.v +prob_lang.v +prob_lang_wip.v +lang_syntax_util.v +lang_syntax_toy.v +lang_syntax.v +lang_syntax_examples.v altreals/xfinmap.v altreals/discrete.v altreals/realseq.v diff --git a/theories/kernel.v b/theories/kernel.v index c6975d0e41..dcde18301e 100644 --- a/theories/kernel.v +++ b/theories/kernel.v @@ -753,36 +753,6 @@ HB.instance Definition _ t := Kernel_isFinite.Build _ _ _ _ R (kadd k1 k2) kadd_finite_uub. End fkadd. -Lemma measurable_fun_mnormalize d d' (X : measurableType d) - (Y : measurableType d') (R : realType) (k : R.-ker X ~> Y) : - measurable_fun [set: X] (fun x => - [the probability _ _ of mnormalize (k x) point] : pprobability Y R). -Proof. -apply: (@measurability _ _ _ _ _ _ - (@pset _ _ _ : set (set (pprobability Y R)))) => //. -move=> _ -[_ [r r01] [Ys mYs <-]] <-. -rewrite /mnormalize /mset /preimage/=. -apply: emeasurable_fun_infty_o => //. -rewrite /mnormalize/=. -rewrite (_ : (fun x => _) = (fun x => if (k x setT == 0) || (k x setT == +oo) - then \d_point Ys else k x Ys * ((fine (k x setT))^-1)%:E)); last first. - by apply/funext => x/=; case: ifPn. -apply: measurable_fun_if => //. -- apply: (measurable_fun_bool true) => //. - rewrite (_ : _ @^-1` _ = [set t | k t setT == 0] `|` - [set t | k t setT == +oo]); last first. - by apply/seteqP; split=> x /= /orP//. - by apply: measurableU; exact: kernel_measurable_eq_cst. -- apply/emeasurable_funM; first exact/measurable_funTS/measurable_kernel. - apply/EFin_measurable_fun; rewrite setTI. - apply: (@measurable_comp _ _ _ _ _ _ [set r : R | r != 0%R]). - + exact: open_measurable. - + by move=> /= _ [x /norP[s0 soo]] <-; rewrite -eqe fineK ?ge0_fin_numE ?ltey. - + apply: open_continuous_measurable_fun => //; apply/in_setP => x /= x0. - exact: inv_continuous. - + by apply: measurableT_comp => //; exact/measurable_funS/measurable_kernel. -Qed. - Section knormalize. Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). Variable f : R.-ker X ~> Y. @@ -790,9 +760,7 @@ Variable f : R.-ker X ~> Y. Definition knormalize (P : probability Y R) : X -> {measure set Y -> \bar R} := fun x => [the measure _ _ of mnormalize (f x) P]. -Variable P : probability Y R. - -Let measurable_fun_knormalize U : +Let measurable_knormalize (P : probability Y R) U : measurable U -> measurable_fun [set: X] (knormalize P ^~ U). Proof. move=> mU; rewrite /knormalize/= /mnormalize /=. @@ -809,7 +777,7 @@ apply: measurable_fun_if => //. - apply: (@measurable_funS _ _ _ _ setT) => //. exact: kernel_measurable_fun_eq_cst. - apply: emeasurable_funM. - by have := measurable_kernel f U mU; exact: measurable_funS. + exact: measurable_funS (measurable_kernel f U mU). apply/EFin_measurable_fun. apply: (@measurable_comp _ _ _ _ _ _ [set r : R | r != 0%R]) => //. + exact: open_measurable. @@ -822,14 +790,14 @@ apply: measurable_fun_if => //. by have := measurable_kernel f _ measurableT; exact: measurable_funS. Qed. -HB.instance Definition _ := isKernel.Build _ _ _ _ R (knormalize P) - measurable_fun_knormalize. +HB.instance Definition _ (P : probability Y R) := + isKernel.Build _ _ _ _ R (knormalize P) (measurable_knormalize P). -Let knormalize1 x : knormalize P x [set: Y] = 1. +Let knormalize1 (P : probability Y R) x : knormalize P x [set: Y] = 1. Proof. by rewrite /knormalize/= probability_setT. Qed. -HB.instance Definition _ := - @Kernel_isProbability.Build _ _ _ _ _ (knormalize P) knormalize1. +HB.instance Definition _ (P : probability Y R):= + @Kernel_isProbability.Build _ _ _ _ _ (knormalize P) (knormalize1 P). End knormalize. diff --git a/theories/lang_syntax.v b/theories/lang_syntax.v new file mode 100644 index 0000000000..19368d333a --- /dev/null +++ b/theories/lang_syntax.v @@ -0,0 +1,1283 @@ +Require Import String. +From HB Require Import structures. +From mathcomp Require Import all_ssreflect ssralg ssrnum ssrint interval. +From mathcomp.classical Require Import mathcomp_extra boolp classical_sets. +From mathcomp.classical Require Import functions cardinality fsbigop. +Require Import signed reals ereal topology normedtype sequences esum measure. +Require Import lebesgue_measure numfun lebesgue_integral itv kernel prob_lang. +Require Import lang_syntax_util. +From mathcomp Require Import ring lra. + +(******************************************************************************) +(* Syntax and Evaluation for a Probabilistic Programming Language *) +(* *) +(* typ == syntax for types of data structures *) +(* measurable_of_typ t == the measurable type corresponding to type t *) +(* It is of type {d & measurableType d} *) +(* mtyp_disp t == the display corresponding to type t *) +(* mtyp t == the measurable type corresponding to type t *) +(* It is of type measurableType (mtyp_disp t) *) +(* measurable_of_seq s == the product space corresponding to the *) +(* list s : seq typ *) +(* It is of type {d & measurableType d} *) +(* acc_typ s n == function that access the nth element of s : seq typ *) +(* It is a function from projT2 (measurable_of_seq s) *) +(* to projT2 (measurable_of_typ (nth Unit s n)) *) +(* ctx == type of context *) +(* := seq (string * type) *) +(* mctx_disp g == the display corresponding to the context g *) +(* mctx g := the measurable type corresponding to the context g *) +(* It is formed of nested pairings of measurable *) +(* spaces. It is of type measurableType (mctx_disp g) *) +(* flag == a flag is either D (deterministic) or *) +(* P (probabilistic) *) +(* exp f g t == syntax of expressions with flag f of type t *) +(* context g *) +(* dval R g t == "deterministic value", i.e., *) +(* function from mctx g to mtyp t *) +(* pval R g t == "probabilistic value", i.e., *) +(* s-finite kernel, from mctx g to mtyp t *) +(* mkswap k == given a kernel k : (Y * X) ~> Z, *) +(* returns a kernel of type (X * Y) ~> Z *) +(* letin' := mkcomp \o mkswap *) +(* e -D> f ; mf == the evaluation of the deterministic expression e *) +(* leads to the deterministic value f *) +(* (mf is the proof that f is measurable) *) +(* e -P> k == the evaluation of the probabilistic function f *) +(* leads to the probabilistic value k *) +(* execD e == a dependent pair of a function corresponding to the *) +(* evaluation of e and a proof that this function is *) +(* measurable *) +(* execP e == a s-finite kernel corresponding to the evaluation *) +(* of the probabilistic expression e *) +(* *) +(******************************************************************************) + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Import Order.TTheory GRing.Theory Num.Def Num.Theory. +Import numFieldTopology.Exports. + +Reserved Notation "f .; g" (at level 60, right associativity, + format "f .; '/ ' g"). +Reserved Notation "e -D> f ; mf" (at level 40). +Reserved Notation "e -P> k" (at level 40). + +Local Open Scope classical_set_scope. +Local Open Scope ring_scope. +Local Open Scope ereal_scope. + +(* TODO: mv *) +Arguments measurable_fst {d1 d2 T1 T2}. +Arguments measurable_snd {d1 d2 T1 T2}. + +Section mswap. +Context d d' d3 (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType). +Variable k : R.-ker Y * X ~> Z. + +Definition mswap xy U := k (swap xy) U. + +Let mswap0 xy : mswap xy set0 = 0. +Proof. done. Qed. + +Let mswap_ge0 x U : 0 <= mswap x U. +Proof. done. Qed. + +Let mswap_sigma_additive x : semi_sigma_additive (mswap x). +Proof. exact: measure_semi_sigma_additive. Qed. + +HB.instance Definition _ x := isMeasure.Build _ R _ + (mswap x) (mswap0 x) (mswap_ge0 x) (@mswap_sigma_additive x). + +Definition mkswap : _ -> {measure set Z -> \bar R} := + fun x => mswap x. + +Let measurable_fun_kswap U : + measurable U -> measurable_fun setT (mkswap ^~ U). +Proof. +move=> mU. +rewrite [X in measurable_fun _ X](_ : _ = k ^~ U \o @swap _ _)//. +apply measurableT_comp => //=; first exact: measurable_kernel. +exact: measurable_swap. +Qed. + +HB.instance Definition _ := isKernel.Build _ _ + (X * Y)%type Z R mkswap measurable_fun_kswap. + +End mswap. + +Section mswap_sfinite_kernel. +Variables (d d' d3 : _) (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType). +Variable k : R.-sfker Y * X ~> Z. + +Let mkswap_sfinite : + exists2 k_ : (R.-ker X * Y ~> Z)^nat, + forall n, measure_fam_uub (k_ n) & + forall x U, measurable U -> mkswap k x U = kseries k_ x U. +Proof. +have [k_ /= kE] := sfinite_kernel k. +exists (fun n => mkswap (k_ n)). + move=> n. + have /measure_fam_uubP[M hM] := measure_uub (k_ n). + by exists M%:num => x/=; exact: hM. +move=> xy U mU. +by rewrite /mswap/= kE. +Qed. + +HB.instance Definition _ := + Kernel_isSFinite_subdef.Build _ _ _ Z R (mkswap k) mkswap_sfinite. + +End mswap_sfinite_kernel. + +Section kswap_finite_kernel_finite. +Context d d' d3 (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType) + (k : R.-fker Y * X ~> Z). + +Let mkswap_finite : measure_fam_uub (mkswap k). +Proof. +have /measure_fam_uubP[r hr] := measure_uub k. +apply/measure_fam_uubP; exists (PosNum [gt0 of r%:num%R]) => x /=. +exact: hr. +Qed. + +HB.instance Definition _ := + Kernel_isFinite.Build _ _ _ Z R (mkswap k) mkswap_finite. + +End kswap_finite_kernel_finite. + +Notation "l .; k" := (mkcomp l (mkswap k)) : ereal_scope. + +Section letin'. +Variables (d d' d3 : _) (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType). + +Definition letin' (l : R.-sfker X ~> Y) (k : R.-sfker Y * X ~> Z) := + locked [the R.-sfker X ~> Z of l .; k]. + +Lemma letin'E (l : R.-sfker X ~> Y) (k : R.-sfker Y * X ~> Z) x U : + letin' l k x U = \int[l x]_y k (y, x) U. +Proof. by rewrite /letin'; unlock. Qed. + +Lemma letin'_letin (l : R.-sfker X ~> Y) (k : R.-sfker Y * X ~> Z) : + letin' l k = letin l (mkswap k). +Proof. by rewrite /letin'; unlock. Qed. + +End letin'. + +Section letin'C. +Import Notations. +Context d d1 d' (X : measurableType d) (Y : measurableType d1) + (Z : measurableType d') (R : realType). +Variables (t : R.-sfker Z ~> X) + (u' : R.-sfker X * Z ~> Y) + (u : R.-sfker Z ~> Y) + (t' : R.-sfker Y * Z ~> X) + (tt' : forall y, t =1 fun z => t' (y, z)) + (uu' : forall x, u =1 fun z => u' (x, z)). + +Definition T' z : set X -> \bar R := t z. +Let T0 z : (T' z) set0 = 0. Proof. by []. Qed. +Let T_ge0 z x : 0 <= (T' z) x. Proof. by []. Qed. +Let T_semi_sigma_additive z : semi_sigma_additive (T' z). +Proof. exact: measure_semi_sigma_additive. Qed. +HB.instance Definition _ z := @isMeasure.Build _ R X (T' z) (T0 z) (T_ge0 z) + (@T_semi_sigma_additive z). + +Let sfinT z : sfinite_measure (T' z). Proof. exact: sfinite_kernel_measure. Qed. +HB.instance Definition _ z := @Measure_isSFinite_subdef.Build _ X R + (T' z) (sfinT z). + +Definition U' z : set Y -> \bar R := u z. +Let U0 z : (U' z) set0 = 0. Proof. by []. Qed. +Let U_ge0 z x : 0 <= (U' z) x. Proof. by []. Qed. +Let U_semi_sigma_additive z : semi_sigma_additive (U' z). +Proof. exact: measure_semi_sigma_additive. Qed. +HB.instance Definition _ z := @isMeasure.Build _ R Y (U' z) (U0 z) (U_ge0 z) + (@U_semi_sigma_additive z). + +Let sfinU z : sfinite_measure (U' z). Proof. exact: sfinite_kernel_measure. Qed. +HB.instance Definition _ z := @Measure_isSFinite_subdef.Build _ Y R + (U' z) (sfinU z). + +Lemma letin'C z A : measurable A -> + letin' t + (letin' u' + (ret (measurable_fun_prod macc1of3' macc0of3'))) z A = + letin' u + (letin' t' + (ret (measurable_fun_prod macc0of3' macc1of3'))) z A. +Proof. +move=> mA. +rewrite !letin'E. +under eq_integral. + move=> x _. + rewrite letin'E -uu'. + under eq_integral do rewrite retE /=. + over. +rewrite (sfinite_Fubini (T' z) (U' z) (fun x => \d_(x.1, x.2) A ))//; last first. + apply/EFin_measurable_fun => /=; rewrite (_ : (fun x => _) = mindic R mA)//. + by apply/funext => -[]. +rewrite /=. +apply: eq_integral => y _. +by rewrite letin'E/= -tt'; apply: eq_integral => // x _; rewrite retE. +Qed. + +End letin'C. +Arguments letin'C {d d1 d' X Y Z R} _ _ _ _. + +Section letin'A. +Context d d' d1 d2 d3 (X : measurableType d) (Y : measurableType d') + (T1 : measurableType d1) (T2 : measurableType d2) (T3 : measurableType d3) + (R : realType). +Import Notations. +Variables (t : R.-sfker X ~> T1) + (u : R.-sfker T1 * X ~> T2) + (v : R.-sfker T2 * X ~> Y) + (v' : R.-sfker T2 * (T1 * X) ~> Y) + (vv' : forall y, v =1 fun xz => v' (xz.1, (y, xz.2))). + +Lemma letin'A x A : measurable A -> + letin' t (letin' u v') x A + = + (letin' (letin' t u) v) x A. +Proof. +move=> mA. +rewrite !letin'E. +under eq_integral do rewrite letin'E. +rewrite letin'_letin/=. +rewrite integral_kcomp; [|by []|]. + apply: eq_integral => z _. + apply: eq_integral => y _. + by rewrite (vv' z). +exact: measurableT_comp (@measurable_kernel _ _ _ _ _ v _ mA) _. +Qed. + +End letin'A. + +Declare Scope lang_scope. +Delimit Scope lang_scope with P. + +Section syntax_of_types. +Import Notations. +Context {R : realType}. + +Inductive typ := +| Unit | Bool | Real +| Pair : typ -> typ -> typ +| Prob : typ -> typ. + +Canonical stype_eqType := Equality.Pack (@gen_eqMixin typ). + +Fixpoint measurable_of_typ (t : typ) : {d & measurableType d} := + match t with + | Unit => existT _ _ munit + | Bool => existT _ _ mbool + | Real => existT _ _ (mR R) + | Pair A B => existT _ _ + [the measurableType (projT1 (measurable_of_typ A), + projT1 (measurable_of_typ B)).-prod%mdisp of + (projT2 (measurable_of_typ A) * + projT2 (measurable_of_typ B))%type] + | Prob A => existT _ _ (pprobability (projT2 (measurable_of_typ A)) R) + end. + +Definition mtyp_disp t : measure_display := projT1 (measurable_of_typ t). + +Definition mtyp t : measurableType (mtyp_disp t) := + projT2 (measurable_of_typ t). + +Definition measurable_of_seq (l : seq typ) : {d & measurableType d} := + iter_mprod (map measurable_of_typ l). + +End syntax_of_types. +Arguments measurable_of_typ {R}. +Arguments mtyp {R}. +Arguments measurable_of_seq {R}. + +Section accessor_functions. +Context {R : realType}. + +(* NB: almost the same as acc (map (@measurable_of_typ R) s) n l, + modulo commutativity of map and measurable_of_typ *) +Fixpoint acc_typ (s : seq typ) n : + projT2 (@measurable_of_seq R s) -> + projT2 (measurable_of_typ (nth Unit s n)) := + match s return + projT2 (measurable_of_seq s) -> projT2 (measurable_of_typ (nth Unit s n)) + with + | [::] => match n with | 0 => (fun=> tt) | m.+1 => (fun=> tt) end + | a :: l => match n with + | 0 => fst + | m.+1 => fun H => @acc_typ l m H.2 + end + end. + +(*Definition acc_typ : forall (s : seq typ) n, + projT2 (@measurable_of_seq R s) -> + projT2 (@measurable_of_typ R (nth Unit s n)). +fix H 1. +intros s n x. +destruct s as [|s]. + destruct n as [|n]. + exact tt. + exact tt. +destruct n as [|n]. + exact (fst x). +rewrite /=. +apply H. +exact: (snd x). +Show Proof. +Defined.*) + +Lemma measurable_acc_typ (s : seq typ) n : measurable_fun setT (@acc_typ s n). +Proof. +elim: s n => //= h t ih [|m]; first exact: measurable_fst. +by apply: (measurableT_comp (ih _)); exact: measurable_snd. +Qed. + +End accessor_functions. +Arguments acc_typ {R} s n. +Arguments measurable_acc_typ {R} s n. + +Section context. +Variables (R : realType). +Definition ctx := seq (string * typ). + +Definition mctx_disp (g : ctx) := projT1 (@measurable_of_seq R (map snd g)). + +Definition mctx (g : ctx) : measurableType (mctx_disp g) := + projT2 (@measurable_of_seq R (map snd g)). + +End context. +Arguments mctx {R}. + +Section syntax_of_expressions. +Context {R : realType}. + +Inductive flag := D | P. + +Section binop. + +Inductive binop := +| binop_and | binop_or +| binop_add | binop_minus | binop_mult. + +Definition type_of_binop (b : binop) : typ := +match b with +| binop_and => Bool +| binop_or => Bool +| binop_add => Real +| binop_minus => Real +| binop_mult => Real +end. + +(* Import Notations. *) + +Definition fun_of_binop g (b : binop) : (mctx g -> mtyp (type_of_binop b)) -> + (mctx g -> mtyp (type_of_binop b)) -> @mctx R g -> @mtyp R (type_of_binop b) := +match b with +| binop_and => (fun f1 f2 x => f1 x && f2 x : mtyp Bool) +| binop_or => (fun f1 f2 x => f1 x || f2 x : mtyp Bool) +| binop_add => (fun f1 f2 => (f1 \+ f2)%R) +| binop_minus => (fun f1 f2 => (f1 \- f2)%R) +| binop_mult => (fun f1 f2 => (f1 \* f2)%R) +end. + +Definition mfun_of_binop g b + (f1 : @mctx R g -> @mtyp R (type_of_binop b)) (mf1 : measurable_fun setT f1) + (f2 : @mctx R g -> @mtyp R (type_of_binop b)) (mf2 : measurable_fun setT f2) : + measurable_fun [set: @mctx R g] (fun_of_binop f1 f2). +destruct b. +exact: measurable_and mf1 mf2. +exact: measurable_or mf1 mf2. +exact: measurable_funD. +exact: measurable_funB. +exact: measurable_funM. +Defined. + +End binop. + +Section relop. +Inductive relop := +| relop_le | relop_lt | relop_eq . + +Definition fun_of_relop g (r : relop) : (@mctx R g -> @mtyp R Real) -> + (mctx g -> mtyp Real) -> @mctx R g -> @mtyp R Bool := +match r with +| relop_le => (fun f1 f2 x => (f1 x <= f2 x)%R) +| relop_lt => (fun f1 f2 x => (f1 x < f2 x)%R) +| relop_eq => (fun f1 f2 x => (f1 x == f2 x)%R) +end. + +Definition mfun_of_relop g r + (f1 : @mctx R g -> @mtyp R Real) (mf1 : measurable_fun setT f1) + (f2 : @mctx R g -> @mtyp R Real) (mf2 : measurable_fun setT f2) : + measurable_fun [set: @mctx R g] (fun_of_relop r f1 f2). +destruct r. +exact: measurable_fun_ler. +exact: measurable_fun_ltr. +exact: measurable_fun_eqr. +Defined. + +End relop. + +Inductive exp : flag -> ctx -> typ -> Type := +| exp_unit g : exp D g Unit +| exp_bool g : bool -> exp D g Bool +| exp_real g : R -> exp D g Real +| exp_bin g (b : binop) : exp D g (type_of_binop b) -> + exp D g (type_of_binop b) -> exp D g (type_of_binop b) +| exp_rel g (b : relop) : exp D g Real -> + exp D g Real -> exp D g Bool +| exp_pair g t1 t2 : exp D g t1 -> exp D g t2 -> exp D g (Pair t1 t2) +| exp_proj1 g t1 t2 : exp D g (Pair t1 t2) -> exp D g t1 +| exp_proj2 g t1 t2 : exp D g (Pair t1 t2) -> exp D g t2 +| exp_var g str t : t = lookup Unit g str -> exp D g t +| exp_bernoulli g (r : {nonneg R}) (r1 : (r%:num <= 1)%R) : + exp D g (Prob Bool) +| exp_binomial g (n : nat) (r : {nonneg R}) (r1 : (r%:num <= 1)%R) : + exp D g (Prob Real) +| exp_uniform g (a b : R) (ab0 : (0 < b - a)%R) : exp D g (Prob Real) +| exp_poisson g : nat -> exp D g Real -> exp D g Real +| exp_normalize g t : exp P g t -> exp D g (Prob t) +| exp_letin g t1 t2 str : exp P g t1 -> exp P ((str, t1) :: g) t2 -> + exp P g t2 +| exp_sample g t : exp D g (Prob t) -> exp P g t +| exp_score g : exp D g Real -> exp P g Unit +| exp_return g t : exp D g t -> exp P g t +| exp_if z g t : exp D g Bool -> exp z g t -> exp z g t -> exp z g t +| exp_weak z g h t x : exp z (g ++ h) t -> + x.1 \notin dom (g ++ h) -> exp z (g ++ x :: h) t. +Arguments exp_var {g} _ {t}. + +Definition exp_var' (str : string) (t : typ) (g : find str t) := + @exp_var (untag (ctx_of g)) str t (ctx_prf g). +Arguments exp_var' str {t} g. + +Lemma exp_var'E str t (f : find str t) H : + exp_var' str f = exp_var str H :> (@exp _ _ _). +Proof. by rewrite /exp_var'; congr exp_var. Qed. + +End syntax_of_expressions. +Arguments exp {R}. +Arguments exp_unit {R g}. +Arguments exp_bool {R g}. +Arguments exp_real {R g}. +Arguments exp_bin {R g} &. +Arguments exp_rel {R g} &. +Arguments exp_pair {R g} & {t1 t2}. +Arguments exp_var {R g} _ {t} H. +Arguments exp_bernoulli {R g}. +Arguments exp_binomial {R g}. +Arguments exp_uniform {R g}. +Arguments exp_poisson {R g}. +Arguments exp_normalize {R g _}. +Arguments exp_letin {R g} & {_ _}. +Arguments exp_sample {R g t}. +Arguments exp_score {R g}. +Arguments exp_return {R g} & {_}. +Arguments exp_if {R z g t}. +Arguments exp_weak {R} z g h {t} x. +Arguments exp_var' {R} str {t} g. + +Declare Custom Entry expr. +Notation "[ e ]" := e (e custom expr at level 5) : lang_scope. +Notation "'TT'" := (exp_unit) (in custom expr at level 1) : lang_scope. +Notation "b ':B'" := (@exp_bool _ _ b%bool) + (in custom expr at level 1) : lang_scope. +Notation "r ':R'" := (@exp_real _ _ r%R) + (in custom expr at level 1, format "r :R") : lang_scope. +Notation "e1 && e2" := (exp_bin binop_and e1 e2) + (in custom expr at level 1) : lang_scope. +Notation "e1 || e2" := (exp_bin binop_or e1 e2) + (in custom expr at level 1) : lang_scope. +Notation "e1 + e2" := (exp_bin binop_add e1 e2) + (in custom expr at level 1) : lang_scope. +Notation "e1 - e2" := (exp_bin binop_minus e1 e2) + (in custom expr at level 1) : lang_scope. +Notation "e1 * e2" := (exp_bin binop_mult e1 e2) + (in custom expr at level 1) : lang_scope. +Notation "e1 <= e2" := (exp_rel relop_le e1 e2) + (in custom expr at level 2) : lang_scope. +Notation "e1 == e2" := (exp_rel relop_eq e1 e2) + (in custom expr at level 2) : lang_scope. +Notation "'return' e" := (@exp_return _ _ _ e) + (in custom expr at level 3) : lang_scope. +(*Notation "% str" := (@exp_var _ _ str%string _ erefl) + (in custom expr at level 1, format "% str") : lang_scope.*) +(* Notation "% str H" := (@exp_var _ _ str%string _ H) + (in custom expr at level 1, format "% str H") : lang_scope. *) +Notation "# str" := (@exp_var' _ str%string _ _) + (in custom expr at level 1, format "# str"). +Notation "e :+ str" := (exp_weak _ [::] _ (str, _) e erefl) + (in custom expr at level 1) : lang_scope. +Notation "( e1 , e2 )" := (exp_pair e1 e2) + (in custom expr at level 1) : lang_scope. +Notation "\pi_1 e" := (exp_proj1 e) + (in custom expr at level 1) : lang_scope. +Notation "\pi_2 e" := (exp_proj2 e) + (in custom expr at level 1) : lang_scope. +Notation "'let' x ':=' e 'in' f" := (exp_letin x e f) + (in custom expr at level 3, + x constr, + f custom expr at level 3, + left associativity) : lang_scope. +Notation "{ c }" := c (in custom expr, c constr) : lang_scope. +Notation "x" := x + (in custom expr at level 0, x ident) : lang_scope. +Notation "'Sample' e" := (exp_sample e) + (in custom expr at level 2) : lang_scope. +Notation "'Score' e" := (exp_score e) + (in custom expr at level 2) : lang_scope. +Notation "'Normalize' e" := (exp_normalize e) + (in custom expr at level 0) : lang_scope. +Notation "'if' e1 'then' e2 'else' e3" := (exp_if e1 e2 e3) + (in custom expr at level 1) : lang_scope. + +Section free_vars. +Context {R : realType}. + +Fixpoint free_vars k g t (e : @exp R k g t) : seq string := + match e with + | exp_unit _ => [::] + | exp_bool _ _ => [::] + | exp_real _ _ => [::] + | exp_bin _ _ e1 e2 => free_vars e1 ++ free_vars e2 + | exp_rel _ _ e1 e2 => free_vars e1 ++ free_vars e2 + | exp_pair _ _ _ e1 e2 => free_vars e1 ++ free_vars e2 + | exp_proj1 _ _ _ e => free_vars e + | exp_proj2 _ _ _ e => free_vars e + | exp_var _ x _ _ => [:: x] + | exp_bernoulli _ _ _ => [::] + | exp_binomial _ _ _ _ => [::] + | exp_uniform _ _ _ _ => [::] + | exp_poisson _ _ e => free_vars e + | exp_normalize _ _ e => free_vars e + | exp_letin _ _ _ x e1 e2 => free_vars e1 ++ rem x (free_vars e2) + | exp_sample _ _ _ => [::] + | exp_score _ e => free_vars e + | exp_return _ _ e => free_vars e + | exp_if _ _ _ e1 e2 e3 => free_vars e1 ++ free_vars e2 ++ free_vars e3 + | exp_weak _ _ _ _ x e _ => rem x.1 (free_vars e) + end. + +End free_vars. + +Definition dval R g t := @mctx R g -> @mtyp R t. +Definition pval R g t := R.-sfker @mctx R g ~> @mtyp R t. + +Section weak. +Context {R : realType}. +Implicit Types (g h : ctx) (x : string * typ). + +Fixpoint mctx_strong g h x (f : @mctx R (g ++ x :: h)) : @mctx R (g ++ h) := + match g as g0 return mctx (g0 ++ x :: h) -> mctx (g0 ++ h) with + | [::] => fun f0 : mctx ([::] ++ x :: h) => let (a, b) := f0 in (fun=> id) a b + | a :: t => uncurry (fun a b => (a, @mctx_strong t h x b)) + end f. + +Definition weak g h x t (f : dval R (g ++ h) t) : dval R (g ++ x :: h) t := + f \o @mctx_strong g h x. + +Lemma measurable_fun_mctx_strong g h x : + measurable_fun setT (@mctx_strong g h x). +Proof. +elim: g h x => [h x|x g ih h x0]; first exact: measurable_snd. +apply/prod_measurable_funP; split. +- rewrite [X in measurable_fun _ X](_ : _ = fst)//. + by apply/funext => -[]. +- rewrite [X in measurable_fun _ X](_ : _ = @mctx_strong g h x0 \o snd). + apply: measurableT_comp; last exact: measurable_snd. + exact: ih. + by apply/funext => -[]. +Qed. + +Lemma measurable_weak g h x t (f : dval R (g ++ h) t) : + measurable_fun setT f -> measurable_fun setT (@weak g h x t f). +Proof. +move=> mf; apply: measurableT_comp; first exact: mf. +exact: measurable_fun_mctx_strong. +Qed. + +Definition kweak g h x t (f : pval R (g ++ h) t) + : @mctx R (g ++ x :: h) -> {measure set @mtyp R t -> \bar R} := + f \o @mctx_strong g h x. + +Section kernel_weak. +Context g h x t (f : pval R (g ++ h) t). + +Let mf U : measurable U -> measurable_fun setT (@kweak g h x t f ^~ U). +Proof. +move=> mU. +rewrite (_ : kweak _ ^~ U = f ^~ U \o @mctx_strong g h x)//. +apply: measurableT_comp => //; first exact: measurable_kernel. +exact: measurable_fun_mctx_strong. +Qed. + +HB.instance Definition _ := isKernel.Build _ _ _ _ _ (@kweak g h x t f) mf. +End kernel_weak. + +Section sfkernel_weak. +Context g h (x : string * typ) t (f : pval R (g ++ h) t). + +Let sf : exists2 s : (R.-ker @mctx R (g ++ x :: h) ~> @mtyp R t)^nat, + forall n, measure_fam_uub (s n) & + forall z U, measurable U -> (@kweak g h x t f) z U = kseries s z U . +Proof. +have [s hs] := sfinite_kernel f. +exists (fun n => @kweak g h x t (s n)). + by move=> n; have [M hM] := measure_uub (s n); exists M => x0; exact: hM. +by move=> z U mU; by rewrite /kweak/= hs. +Qed. + +HB.instance Definition _ := + Kernel_isSFinite_subdef.Build _ _ _ _ _ (@kweak g h x t f) sf. + +End sfkernel_weak. + +Section fkernel_weak. +Context g h x t (f : R.-fker @mctx R (g ++ h) ~> @mtyp R t). + +Let uub : measure_fam_uub (@kweak g h x t f). +Proof. by have [M hM] := measure_uub f; exists M => x0; exact: hM. Qed. + +HB.instance Definition _ := @Kernel_isFinite.Build _ _ _ _ _ + (@kweak g h x t f) uub. +End fkernel_weak. + +End weak. +Arguments weak {R} g h x {t}. +Arguments measurable_weak {R} g h x {t}. +Arguments kweak {R} g h x {t}. + +Section eval. +Context {R : realType}. +Implicit Type (g : ctx) (str : string). +Local Open Scope lang_scope. + +Context (a b : R) (ab0 : (0 < b - a)%R). + +Inductive evalD : forall g t, exp D g t -> + forall f : dval R g t, measurable_fun setT f -> Prop := +| eval_unit g : ([TT] : exp D g _) -D> cst tt ; ktt + +| eval_bool g b : ([b:B] : exp D g _) -D> cst b ; kb b + +| eval_real g r : ([r:R] : exp D g _) -D> cst r ; kr r + +| eval_bin g bop (e1 : exp D g _) f1 mf1 e2 f2 mf2 : + e1 -D> f1 ; mf1 -> e2 -D> f2 ; mf2 -> + exp_bin bop e1 e2 -D> fun_of_binop f1 f2 ; mfun_of_binop mf1 mf2 + +| eval_rel g rop (e1 : exp D g _) f1 mf1 e2 f2 mf2 : + e1 -D> f1 ; mf1 -> e2 -D> f2 ; mf2 -> + exp_rel rop e1 e2 -D> fun_of_relop rop f1 f2 ; mfun_of_relop rop mf1 mf2 + +| eval_pair g t1 (e1 : exp D g t1) f1 mf1 t2 (e2 : exp D g t2) f2 mf2 : + e1 -D> f1 ; mf1 -> e2 -D> f2 ; mf2 -> + [(e1, e2)] -D> fun x => (f1 x, f2 x) ; measurable_fun_prod mf1 mf2 + +| eval_proj1 g t1 t2 (e : exp D g (Pair t1 t2)) f mf : + e -D> f ; mf -> + [\pi_1 e] -D> fst \o f ; measurableT_comp measurable_fst mf + +| eval_proj2 g t1 t2 (e : exp D g (Pair t1 t2)) f mf : + e -D> f ; mf -> + [\pi_2 e] -D> snd \o f ; measurableT_comp measurable_snd mf + +(* | eval_var g str : let i := index str (dom g) in + [% str] -D> acc_typ (map snd g) i ; measurable_acc_typ (map snd g) i *) + +| eval_var g x H : let i := index x (dom g) in + exp_var x H -D> acc_typ (map snd g) i ; measurable_acc_typ (map snd g) i + +| eval_bernoulli g (r : {nonneg R}) (r1 : (r%:num <= 1)%R) : + (exp_bernoulli r r1 : exp D g _) -D> cst (bernoulli r1 : set bool -> \bar R) ; + measurable_cst _ + +| eval_binomial g n (p : {nonneg R}) (p1 : (p%:num <= 1)%R) : + (exp_binomial n p p1 : exp D g _) -D> cst (binomial_probability n p1) ; + measurable_cst _ + +| eval_uniform g (a b : R) (ab0 : (0 < b - a)%R) : + (exp_uniform a b ab0 : exp D g _) -D> cst (uniform_probability ab0) ; + measurable_cst _ + +| eval_poisson g n (e : exp D g _) f mf : + e -D> f ; mf -> + exp_poisson n e -D> poisson n \o f ; + measurableT_comp (measurable_poisson n) mf + +| eval_normalize g t (e : exp P g t) k : + e -P> k -> + [Normalize e] -D> normalize_pt k ; measurable_normalize_pt k + +| evalD_if g t e f mf (e1 : exp D g t) f1 mf1 e2 f2 mf2 : + e -D> f ; mf -> e1 -D> f1 ; mf1 -> e2 -D> f2 ; mf2 -> + [if e then e1 else e2] -D> fun x => if f x then f1 x else f2 x ; + measurable_fun_ifT mf mf1 mf2 + +| evalD_weak g h t e x (H : x.1 \notin dom (g ++ h)) f mf : + e -D> f ; mf -> + (exp_weak _ g h x e H : exp _ _ t) -D> weak g h x f ; + measurable_weak g h x f mf + +where "e -D> v ; mv" := (@evalD _ _ e v mv) + +with evalP : forall g t, exp P g t -> pval R g t -> Prop := + +| eval_letin g t1 t2 str (e1 : exp _ g t1) (e2 : exp _ _ t2) k1 k2 : + e1 -P> k1 -> e2 -P> k2 -> + [let str := e1 in e2] -P> letin' k1 k2 + +| eval_sample g t (e : exp _ _ (Prob t)) + (f : mctx g -> probability (mtyp t) R) mf : + e -D> f ; mf -> [Sample e] -P> sample f mf + +| eval_score g (e : exp _ g _) f mf : + e -D> f ; mf -> [Score e] -P> kscore mf + +| eval_return g t (e : exp D g t) f mf : + e -D> f ; mf -> [return e] -P> ret mf + +| evalP_if g t e f mf (e1 : exp P g t) k1 e2 k2 : + e -D> f ; mf -> e1 -P> k1 -> e2 -P> k2 -> + [if e then e1 else e2] -P> ite mf k1 k2 + +| evalP_weak g h t (e : exp P (g ++ h) t) x + (H : x.1 \notin dom (g ++ h)) f : + e -P> f -> + exp_weak _ g h x e H -P> kweak g h x f + +where "e -P> v" := (@evalP _ _ e v). + +End eval. + +Notation "e -D> v ; mv" := (@evalD _ _ _ e v mv) : lang_scope. +Notation "e -P> v" := (@evalP _ _ _ e v) : lang_scope. + +Scheme evalD_mut_ind := Induction for evalD Sort Prop +with evalP_mut_ind := Induction for evalP Sort Prop. + +(* properties of the evaluation relation *) +Section eval_prop. +Variables (R : realType). +Local Open Scope lang_scope. + +Lemma evalD_uniq g t (e : exp D g t) (u v : dval R g t) mu mv : + e -D> u ; mu -> e -D> v ; mv -> u = v. +Proof. +move=> hu. +apply: (@evalD_mut_ind R + (fun g t (e : exp D g t) f mf (h1 : e -D> f; mf) => + forall v mv, e -D> v; mv -> f = v) + (fun g t (e : exp P g t) u (h1 : e -P> u) => + forall v, e -P> v -> u = v)); last exact: hu. +all: (rewrite {g t e u v mu mv hu}). +- move=> g {}v {}mv. + inversion 1; subst g0. + by inj_ex H3. +- move=> g b {}v {}mv. + inversion 1; subst g0 b0. + by inj_ex H3. +- move=> g r {}v {}mv. + inversion 1; subst g0 r0. + by inj_ex H3. +- move=> g bop e1 f1 mf1 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + inversion 1; subst g0 bop0. + inj_ex H10; subst v. + inj_ex H5; subst e1. + inj_ex H6; subst e5. + by move: H4 H11 => /IH1 <- /IH2 <-. +- move=> g rop e1 f1 mf1 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + inversion 1; subst g0 rop0. + inj_ex H5; subst v. + inj_ex H1; subst e1. + inj_ex H3; subst e3. + by move: H6 H7 => /IH1 <- /IH2 <-. +- move=> g t1 e1 f1 mf1 t2 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + simple inversion 1 => //; subst g0. + case: H3 => ? ?; subst t0 t3. + inj_ex H4; case: H4 => He1 He2. + inj_ex He1; subst e0. + inj_ex He2; subst e3. + inj_ex H5; subst v. + by move=> /IH1 <- /IH2 <-. +- move=> g t1 t2 e f mf H ih v mv. + inversion 1; subst g0 t3 t0. + inj_ex H11; subst v. + clear H9. + inj_ex H7; subst e1. + by rewrite (ih _ _ H4). +- move=> g t1 t2 e f mf H ih v mv. + inversion 1; subst g0 t3 t0. + inj_ex H11; subst v. + clear H9. + inj_ex H7; subst e1. + by rewrite (ih _ _ H4). +- move=> g str H n {}v {}mv. + inversion 1; subst g0. + inj_ex H9; rewrite -H9. + by inj_ex H10. +- move=> g r r1 {}v {}mv. + inversion 1; subst g0 r0. + inj_ex H3; subst v. + by have -> : r1 = r3 by []. +- move=> g n p p1 {}v {}mv. + inversion 1; subst g0 n0 p0. + inj_ex H2; subst v. + by have -> : p1 = p3 by []. +- move=> g a b ab0 {}v {}mv. + inversion 1; subst g0 a0 b0. + inj_ex H2; subst v. + by have -> : ab0 = ab2. +- move=> g n e0 f mf ev IH {}v {}mv. + inversion 1; subst g0 n0. + inj_ex H2; subst e0. + inj_ex H4; subst v. + by rewrite (IH _ _ H3). +- move=> g t e0 k ev IH {}v {}mv. + inversion 1; subst g0 t0. + inj_ex H2; subst e0. + inj_ex H4; subst v. + by rewrite (IH _ H3). +- move=> g t e f mf e1 f1 mf1 e2 f2 mf2 ev ih ev1 ih1 ev2 ih2 v m. + inversion 1; subst g0 t0. + inj_ex H2; subst e0. + inj_ex H6; subst e5. + inj_ex H7; subst e6. + inj_ex H9; subst v. + clear H11. + have ? := ih1 _ _ H12; subst f6. + have ? := ih2 _ _ H13; subst f7. + by rewrite (ih _ _ H5). +- move=> g h t e x H f mf ef ih {}v {}mv. + inversion 1; subst t0 g0 h0 x0. + inj_ex H12; subst e1. + inj_ex H14; subst v. + clear H16. + by rewrite (ih _ _ H5). +- move=> g t1 t2 x e1 e2 k1 k2 ev1 IH1 ev2 IH2 k. + inversion 1; subst g0 t0 t3 x. + inj_ex H7; subst k. + inj_ex H6; subst e5. + inj_ex H5; subst e4. + by rewrite (IH1 _ H4) (IH2 _ H8). +- move=> g t e f mf ev IH k. + inversion 1; subst g0. + inj_ex H5; subst t0. + inj_ex H5; subst e1. + inj_ex H7; subst k. + have ? := IH _ _ H3; subst f1. + by have -> : mf = mf1 by []. +- move=> g e f mf ev IH k. + inversion 1; subst g0. + inj_ex H0; subst e0. + inj_ex H4; subst k. + have ? := IH _ _ H2; subst f1. + by have -> : mf = mf0 by []. +- move=> g t e0 f mf ev IH k. + inversion 1; subst g0 t0. + inj_ex H5; subst e1. + inj_ex H7; subst k. + have ? := IH _ _ H3; subst f1. + by have -> : mf = mf1 by []. +- move=> g t e f mf e1 k1 e2 k2 ev ih ev1 ih1 ev2 ih2 k. + inversion 1; subst g0 t0. + inj_ex H0; subst e0. + inj_ex H1; subst e3. + inj_ex H5; subst k. + inj_ex H2; subst e4. + have ? := ih _ _ H6; subst f1. + have -> : mf = mf0 by []. + by rewrite (ih1 _ H7) (ih2 _ H8). +- move=> g h t e x xgh k ek ih. + inversion 1; subst x0 g0 h0 t0. + inj_ex H13; rewrite -H13. + inj_ex H11; subst e1. + by rewrite (ih _ H4). +Qed. + +Lemma evalP_uniq g t (e : exp P g t) (u v : pval R g t) : + e -P> u -> e -P> v -> u = v. +Proof. +move=> eu. +apply: (@evalP_mut_ind R + (fun g t (e : exp D g t) f mf (h : e -D> f; mf) => + forall v mv, e -D> v; mv -> f = v) + (fun g t (e : exp P g t) u (h : e -P> u) => + forall v, e -P> v -> u = v)); last exact: eu. +all: rewrite {g t e u v eu}. +- move=> g {}v {}mv. + inversion 1; subst g0. + by inj_ex H3. +- move=> g b {}v {}mv. + inversion 1; subst g0 b0. + by inj_ex H3. +- move=> g r {}v {}mv. + inversion 1; subst g0 r0. + by inj_ex H3. +- move=> g bop e1 f1 mf1 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + inversion 1; subst g0 bop0. + inj_ex H10; subst v. + inj_ex H5; subst e1. + inj_ex H6; subst e5. + by move: H4 H11 => /IH1 <- /IH2 <-. +- move=> g rop e1 f1 mf1 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + inversion 1; subst g0 rop0. + inj_ex H5; subst v. + inj_ex H1; subst e1. + inj_ex H3; subst e3. + by move: H6 H7 => /IH1 <- /IH2 <-. +- move=> g t1 e1 f1 mf1 t2 e2 f2 mf2 ev1 IH1 ev2 IH2 {}v {}mv. + simple inversion 1 => //; subst g0. + case: H3 => ? ?; subst t0 t3. + inj_ex H4; case: H4 => He1 He2. + inj_ex He1; subst e0. + inj_ex He2; subst e3. + inj_ex H5; subst v. + move=> e1f0 e2f3. + by rewrite (IH1 _ _ e1f0) (IH2 _ _ e2f3). +- move=> g t1 t2 e f mf H ih v mv. + inversion 1; subst g0 t3 t0. + inj_ex H11; subst v. + clear H9. + inj_ex H7; subst e1. + by rewrite (ih _ _ H4). +- move=> g t1 t2 e f mf H ih v mv. + inversion 1; subst g0 t3 t0. + inj_ex H11; subst v. + clear H9. + inj_ex H7; subst e1. + by rewrite (ih _ _ H4). +- move=> g str H n {}v {}mv. + inversion 1; subst g0. + inj_ex H9; rewrite -H9. + by inj_ex H10. +- move=> g r r1 {}v {}mv. + inversion 1; subst g0 r0. + inj_ex H3; subst v. + by have -> : r1 = r3 by []. +- move=> g n p p1 {}v {}mv. + inversion 1; subst g0 n0 p0. + inj_ex H2; subst v. + by have -> : p1 = p3 by []. +- move=> g a b ab0 {}v {}mv. + inversion 1; subst g0 a0 b0. + inj_ex H2; subst v. + by have -> : ab0 = ab2. +- move=> g n e f mf ev IH {}v {}mv. + inversion 1; subst g0 n0. + inj_ex H2; subst e0. + inj_ex H4; subst v. + inj_ex H5; subst mv. + by rewrite (IH _ _ H3). +- move=> g t e k ev IH {}v {}mv. + inversion 1; subst g0 t0. + inj_ex H2; subst e0. + inj_ex H4; subst v. + inj_ex H5; subst mv. + by rewrite (IH _ H3). +- move=> g t e f mf e1 f1 mf1 e2 f2 mf2 ef ih ef1 ih1 ef2 ih2 {}v {}mv. + inversion 1; subst g0 t0. + inj_ex H2; subst e0. + inj_ex H6; subst e5. + inj_ex H7; subst e6. + inj_ex H9; subst v. + clear H11. + have ? := ih1 _ _ H12; subst f6. + have ? := ih2 _ _ H13; subst f7. + by rewrite (ih _ _ H5). +- move=> g h t e x H f mf ef ih {}v {}mv. + inversion 1; subst x0 g0 h0 t0. + inj_ex H12; subst e1. + inj_ex H14; subst v. + clear H16. + by rewrite (ih _ _ H5). +- move=> g t1 t2 x e1 e2 k1 k2 ev1 IH1 ev2 IH2 k. + inversion 1; subst g0 x t3 t0. + inj_ex H7; subst k. + inj_ex H5; subst e4. + inj_ex H6; subst e5. + by rewrite (IH1 _ H4) (IH2 _ H8). +- move=> g t e f mf ep IH v. + inversion 1; subst g0 t0. + inj_ex H7; subst v. + inj_ex H5; subst e1. + have ? := IH _ _ H3; subst f1. + by have -> : mf = mf1 by []. +- move=> g e f mf ev IH k. + inversion 1; subst g0. + inj_ex H0; subst e0. + inj_ex H4; subst k. + have ? := IH _ _ H2; subst f1. + by have -> : mf = mf0 by []. +- move=> g t e f mf ev IH k. + inversion 1; subst g0 t0. + inj_ex H7; subst k. + inj_ex H5; subst e1. + have ? := IH _ _ H3; subst f1. + by have -> : mf = mf1 by []. +- move=> g t e f mf e1 k1 e2 k2 ev ih ev1 ih1 ev2 ih2 k. + inversion 1; subst g0 t0. + inj_ex H0; subst e0. + inj_ex H1; subst e3. + inj_ex H5; subst k. + inj_ex H2; subst e4. + have ? := ih _ _ H6; subst f1. + have -> : mf0 = mf by []. + by rewrite (ih1 _ H7) (ih2 _ H8). +- move=> g h t e x xgh k ek ih. + inversion 1; subst x0 g0 h0 t0. + inj_ex H13; rewrite -H13. + inj_ex H11; subst e1. + by rewrite (ih _ H4). +Qed. + +Lemma eval_total z g t (e : @exp R z g t) : + (match z with + | D => fun e => exists f mf, e -D> f ; mf + | P => fun e => exists k, e -P> k + end) e. +Proof. +elim: e. +all: rewrite {z g t}. +- by do 2 eexists; exact: eval_unit. +- by do 2 eexists; exact: eval_bool. +- by do 2 eexists; exact: eval_real. +- move=> g b e1 [f1 [mf1 H1]] e2 [f2 [mf2 H2]]. + by exists (fun_of_binop f1 f2); eexists; exact: eval_bin. +- move=> g r e1 [f1 [mf1 H1]] e2 [f2 [mf2 H2]]. + by exists (fun_of_relop r f1 f2); eexists; exact: eval_rel. +- move=> g t1 t2 e1 [f1 [mf1 H1]] e2 [f2 [mf2 H2]]. + by exists (fun x => (f1 x, f2 x)); eexists; exact: eval_pair. +- move=> g t1 t2 e [f [mf H]]. + by exists (fst \o f); eexists; exact: eval_proj1. +- move=> g t1 t2 e [f [mf H]]. + by exists (snd \o f); eexists; exact: eval_proj2. +- by move=> g x t tE; subst t; eexists; eexists; exact: eval_var. +- by eexists; eexists; exact: eval_bernoulli. +- by eexists; eexists; exact: eval_binomial. +- by eexists; eexists; exact: eval_uniform. +- move=> g h e [f [mf H]]. + by exists (poisson h \o f); eexists; exact: eval_poisson. +- move=> g t e [k ek]. + by exists (normalize_pt k); eexists; exact: eval_normalize. +- move=> g t1 t2 x e1 [k1 ev1] e2 [k2 ev2]. + by exists (letin' k1 k2); exact: eval_letin. +- move=> g t e [f [/= mf ef]]. + by eexists; exact: (@eval_sample _ _ _ _ _ mf). +- move=> g e [f [mf f_mf]]. + by exists (kscore mf); exact: eval_score. +- by move=> g t e [f [mf f_mf]]; exists (ret mf); exact: eval_return. +- case. + + move=> g t e1 [f [mf H1]] e2 [f2 [mf2 H2]] e3 [f3 [mf3 H3]]. + by exists (fun g => if f g then f2 g else f3 g), + (measurable_fun_ifT mf mf2 mf3); exact: evalD_if. + + move=> g t e1 [f [mf H1]] e2 [k2 H2] e3 [k3 H3]. + by exists (ite mf k2 k3); exact: evalP_if. +- case=> [g h t x e [f [mf ef]] xgh|g h st x e [k ek] xgh]. + + by exists (weak _ _ _ f), (measurable_weak _ _ _ _ mf); exact/evalD_weak. + + by exists (kweak _ _ _ k); exact: evalP_weak. +Qed. + +Lemma evalD_total g t (e : @exp R D g t) : exists f mf, e -D> f ; mf. +Proof. exact: (eval_total e). Qed. + +Lemma evalP_total g t (e : @exp R P g t) : exists k, e -P> k. +Proof. exact: (eval_total e). Qed. + +End eval_prop. + +Section execution_functions. +Local Open Scope lang_scope. +Context {R : realType}. +Implicit Type g : ctx. + +Definition execD g t (e : exp D g t) : + {f : dval R g t & measurable_fun setT f} := + let: exist _ H := cid (evalD_total e) in + existT _ _ (projT1 (cid H)). + +Lemma eq_execD g t (p1 p2 : @exp R D g t) : + projT1 (execD p1) = projT1 (execD p2) -> execD p1 = execD p2. +Proof. +rewrite /execD /=. +case: cid => /= f1 [mf1 ev1]. +case: cid => /= f2 [mf2 ev2] f12. +subst f2. +have ? : mf1 = mf2 by []. +subst mf2. +congr existT. +rewrite /sval. +case: cid => mf1' ev1'. +have ? : mf1 = mf1' by []. +subst mf1'. +case: cid => mf2' ev2'. +have ? : mf1 = mf2' by []. +by subst mf2'. +Qed. + +Definition execP g t (e : exp P g t) : pval R g t := + projT1 (cid (evalP_total e)). + +Lemma execD_evalD g t e x mx: + @execD g t e = existT _ x mx <-> e -D> x ; mx. +Proof. +rewrite /execD; split. + case: cid => x' [mx' H] [?]; subst x'. + have ? : mx = mx' by []. + by subst mx'. +case: cid => f' [mf' f'mf']/=. +move/evalD_uniq => /(_ _ _ f'mf') => ?; subst f'. +by case: cid => //= ? ?; congr existT. +Qed. + +Lemma evalD_execD g t (e : exp D g t) : + e -D> projT1 (execD e); projT2 (execD e). +Proof. +by rewrite /execD; case: cid => // x [mx xmx]/=; case: cid. +Qed. + +Lemma execP_evalP g t (e : exp P g t) x : + execP e = x <-> e -P> x. +Proof. +rewrite /execP; split; first by move=> <-; case: cid. +case: cid => // x0 Hx0. +by move/evalP_uniq => /(_ _ Hx0) ?; subst x. +Qed. + +Lemma evalP_execP g t (e : exp P g t) : e -P> execP e. +Proof. by rewrite /execP; case: cid. Qed. + +Lemma execD_unit g : @execD g _ [TT] = existT _ (cst tt) ktt. +Proof. exact/execD_evalD/eval_unit. Qed. + +Lemma execD_bool g b : @execD g _ [b:B] = existT _ (cst b) (kb b). +Proof. exact/execD_evalD/eval_bool. Qed. + +Lemma execD_real g r : @execD g _ [r:R] = existT _ (cst r) (kr r). +Proof. exact/execD_evalD/eval_real. Qed. + +Lemma execD_bin g bop (e1 : exp D g _) (e2 : exp D g _) : + let f1 := projT1 (execD e1) in let f2 := projT1 (execD e2) in + let mf1 := projT2 (execD e1) in let mf2 := projT2 (execD e2) in + execD (exp_bin bop e1 e2) = + @existT _ _ (fun_of_binop f1 f2) (mfun_of_binop mf1 mf2). +Proof. +by move=> f1 f2 mf1 mf2; apply/execD_evalD/eval_bin; exact: evalD_execD. +Qed. + +Lemma execD_rel g rop (e1 : exp D g _) (e2 : exp D g _) : + let f1 := projT1 (execD e1) in let f2 := projT1 (execD e2) in + let mf1 := projT2 (execD e1) in let mf2 := projT2 (execD e2) in + execD (exp_rel rop e1 e2) = + @existT _ _ (fun_of_relop rop f1 f2) (mfun_of_relop rop mf1 mf2). +Proof. +by move=> f1 f2 mf1 mf2; apply/execD_evalD/eval_rel; exact: evalD_execD. +Qed. + +Lemma execD_pair g t1 t2 (e1 : exp D g t1) (e2 : exp D g t2) : + let f1 := projT1 (execD e1) in let f2 := projT1 (execD e2) in + let mf1 := projT2 (execD e1) in let mf2 := projT2 (execD e2) in + execD [(e1, e2)] = + @existT _ _ (fun z => (f1 z, f2 z)) + (@measurable_fun_prod _ _ _ (mctx g) (mtyp t1) (mtyp t2) + f1 f2 mf1 mf2). +Proof. +by move=> f1 f2 mf1 mf2; apply/execD_evalD/eval_pair; exact: evalD_execD. +Qed. + +Lemma execD_proj1 g t1 t2 (e : exp D g (Pair t1 t2)) : + let f := projT1 (execD e) in + let mf := projT2 (execD e) in + execD [\pi_1 e] = @existT _ _ (fst \o f) + (measurableT_comp measurable_fst mf). +Proof. +by move=> f mf; apply/execD_evalD/eval_proj1; exact: evalD_execD. +Qed. + +Lemma execD_proj2 g t1 t2 (e : exp D g (Pair t1 t2)) : + let f := projT1 (execD e) in let mf := projT2 (execD e) in + execD [\pi_2 e] = @existT _ _ (snd \o f) + (measurableT_comp measurable_snd mf). +Proof. +by move=> f mf; apply/execD_evalD/eval_proj2; exact: evalD_execD. +Qed. + +Lemma execD_var_erefl g str : let i := index str (dom g) in + @execD g _ (exp_var str erefl) = existT _ (acc_typ (map snd g) i) + (measurable_acc_typ (map snd g) i). +Proof. by move=> i; apply/execD_evalD; exact: eval_var. Qed. + +Lemma execD_var g x (H : nth Unit (map snd g) (index x (dom g)) = lookup Unit g x) : + let i := index x (dom g) in + @execD g _ (exp_var x H) = existT _ (acc_typ (map snd g) i) + (measurable_acc_typ (map snd g) i). +Proof. by move=> i; apply/execD_evalD; exact: eval_var. Qed. + +Lemma execD_bernoulli g r (r1 : (r%:num <= 1)%R) : + @execD g _ (exp_bernoulli r r1) = + existT _ (cst [the probability _ _ of bernoulli r1]) (measurable_cst _). +Proof. exact/execD_evalD/eval_bernoulli. Qed. + +Lemma execD_binomial g n p (p1 : (p%:num <= 1)%R) : + @execD g _ (exp_binomial n p p1) = + existT _ (cst [the probability _ _ of binomial_probability n p1]) (measurable_cst _). +Proof. exact/execD_evalD/eval_binomial. Qed. + +Lemma execD_uniform g a b ab0 : + @execD g _ (exp_uniform a b ab0) = + existT _ (cst [the probability _ _ of uniform_probability ab0]) (measurable_cst _). +Proof. exact/execD_evalD/eval_uniform. Qed. + +Lemma execD_normalize_pt g t (e : exp P g t) : + @execD g _ [Normalize e] = + existT _ (normalize_pt (execP e) : _ -> pprobability _ _) + (measurable_normalize_pt (execP e)). +Proof. exact/execD_evalD/eval_normalize/evalP_execP. Qed. + +Lemma execD_poisson g n (e : exp D g Real) : + execD (exp_poisson n e) = + existT _ (poisson n \o (projT1 (execD e))) + (measurableT_comp (measurable_poisson n) (projT2 (execD e))). +Proof. exact/execD_evalD/eval_poisson/evalD_execD. Qed. + +Lemma execP_if g st e1 e2 e3 : + @execP g st [if e1 then e2 else e3] = + ite (projT2 (execD e1)) (execP e2) (execP e3). +Proof. +by apply/execP_evalP/evalP_if; [apply: evalD_execD| exact: evalP_execP..]. +Qed. + +Lemma execP_letin g x t1 t2 (e1 : exp P g t1) (e2 : exp P ((x, t1) :: g) t2) : + execP [let x := e1 in e2] = letin' (execP e1) (execP e2) :> (R.-sfker _ ~> _). +Proof. by apply/execP_evalP/eval_letin; exact: evalP_execP. Qed. + +Lemma execP_sample g t (e : @exp R D g (Prob t)) : + let x := execD e in + execP [Sample e] = sample (projT1 x) (projT2 x). +Proof. exact/execP_evalP/eval_sample/evalD_execD. Qed. + +Lemma execP_score g (e : exp D g Real) : + execP [Score e] = score (projT2 (execD e)). +Proof. exact/execP_evalP/eval_score/evalD_execD. Qed. + +Lemma execP_return g t (e : exp D g t) : + execP [return e] = ret (projT2 (execD e)). +Proof. exact/execP_evalP/eval_return/evalD_execD. Qed. + +Lemma execP_weak g h x t (e : exp P (g ++ h) t) + (xl : x.1 \notin dom (g ++ h)) : + execP (exp_weak P g h _ e xl) = kweak _ _ _ (execP e). +Proof. exact/execP_evalP/evalP_weak/evalP_execP. Qed. + +End execution_functions. +Arguments execD_var_erefl {R g} str. +Arguments execP_weak {R} g h x {t} e. +Arguments exp_var'E {R} str. diff --git a/theories/lang_syntax_examples.v b/theories/lang_syntax_examples.v new file mode 100644 index 0000000000..0a6148d0e6 --- /dev/null +++ b/theories/lang_syntax_examples.v @@ -0,0 +1,987 @@ +Require Import String. +From HB Require Import structures. +From mathcomp Require Import all_ssreflect ssralg ssrnum ssrint interval. +From mathcomp.classical Require Import mathcomp_extra boolp classical_sets. +From mathcomp.classical Require Import functions cardinality fsbigop. +Require Import signed reals ereal topology normedtype sequences esum measure. +Require Import lebesgue_measure numfun lebesgue_integral kernel prob_lang. +Require Import lang_syntax_util lang_syntax. +From mathcomp Require Import ring lra. + +(******************************************************************************) +(* Examples using the Probabilistic Programming Language of lang_syntax.v *) +(* *) +(* sample_pair_syntax := normalize ( *) +(* let x := sample (bernoulli 1/2) in *) +(* let y := sample (bernoulli 1/3) in *) +(* return (x, y)) *) +(* *) +(* bernoulli13_score := normalize ( *) +(* let x := sample (bernoulli 1/3) in *) +(* let _ := if x then score (1/3) else score (2/3) in *) +(* return x) *) +(* *) +(* bernoulli12_score := normalize ( *) +(* let x := sample (bernoulli 1/2) in *) +(* let _ := if x then score (1/3) else score (2/3) in *) +(* return x) *) +(* *) +(* hard_constraint := let x := Score {0}:R in return TT *) +(* *) +(* associativity of let-in expressions *) +(* *) +(* staton_bus_syntax == example from [Staton, ESOP 2017] *) +(* *) +(* staton_busA_syntax == same as staton_bus_syntax module associativity of *) +(* let-in expression *) +(* *) +(* commutativity of let-in expressions *) +(* *) +(******************************************************************************) + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Import Order.TTheory GRing.Theory Num.Def Num.Theory. +Import numFieldTopology.Exports. + +Local Open Scope classical_set_scope. +Local Open Scope ring_scope. +Local Open Scope ereal_scope. + +(* letin' versions of rewriting laws *) +Lemma letin'_sample_bernoulli d d' (T : measurableType d) + (T' : measurableType d') (R : realType)(r : {nonneg R}) (r1 : (r%:num <= 1)%R) + (u : R.-sfker bool * T ~> T') x y : + letin' (sample_cst (bernoulli r1)) u x y = + r%:num%:E * u (true, x) y + (`1- (r%:num))%:E * u (false, x) y. +Proof. +rewrite letin'E/=. +rewrite ge0_integral_measure_sum// 2!big_ord_recl/= big_ord0 adde0/=. +by rewrite !ge0_integral_mscale//= !integral_dirac//= indicT 2!mul1e. +Qed. + +Section letin'_return. +Context d d' d3 (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType). + +Lemma letin'_kret (k : R.-sfker X ~> Y) + (f : Y * X -> Z) (mf : measurable_fun setT f) x U : + measurable U -> + letin' k (ret mf) x U = k x (curry f ^~ x @^-1` U). +Proof. +move=> mU; rewrite letin'E. +under eq_integral do rewrite retE. +rewrite integral_indic ?setIT// -[X in measurable X]setTI. +exact: (measurableT_comp mf). +Qed. + +Lemma letin'_retk (f : X -> Y) (mf : measurable_fun setT f) + (k : R.-sfker Y * X ~> Z) x U : + measurable U -> letin' (ret mf) k x U = k (f x, x) U. +Proof. +move=> mU; rewrite letin'E retE integral_dirac ?indicT ?mul1e//. +exact: (measurableT_comp (measurable_kernel k _ mU)). +Qed. + +End letin'_return. + +Section letin'_ite. +Context d d2 d3 (T : measurableType d) (T2 : measurableType d2) + (Z : measurableType d3) (R : realType). +Variables (k1 k2 : R.-sfker T ~> Z) + (u : R.-sfker Z * T ~> T2) + (f : T -> bool) (mf : measurable_fun setT f) + (t : T) (U : set T2). + +Lemma letin'_iteT : f t -> letin' (ite mf k1 k2) u t U = letin' k1 u t U. +Proof. +move=> ftT. +rewrite !letin'E/=. +apply: eq_measure_integral => V mV _. +by rewrite iteE ftT. +Qed. + +Lemma letin'_iteF : ~~ f t -> letin' (ite mf k1 k2) u t U = letin' k2 u t U. +Proof. +move=> ftF. +rewrite !letin'E/=. +apply: eq_measure_integral => V mV _. +by rewrite iteE (negbTE ftF). +Qed. + +End letin'_ite. +(* /letin' versions of rewriting laws *) + +Local Open Scope lang_scope. + +Lemma execP_letinL {R : realType} g t1 t2 x (e1 : @exp R P g t1) (e1' : exp P g t1) + (e2 : exp P ((x, t1) :: g) t2) : + forall U, measurable U -> + execP e1 = execP e1' -> + execP [let x := e1 in e2] ^~ U = execP [let x := e1' in e2] ^~ U. +Proof. +by move=> U mU e1e1'; rewrite !execP_letin e1e1'. +Qed. + +Lemma execP_letinR {R : realType} g t1 t2 x (e1 : @exp R P g t1) + (e2 : exp P _ t2) (e2' : exp P ((x, t1) :: g) t2) : + forall U, measurable U -> + execP e2 = execP e2' -> + execP [let x := e1 in e2] ^~ U = execP [let x := e1 in e2'] ^~ U. +Proof. +by move=> U mU e1e1'; rewrite !execP_letin e1e1'. +Qed. + +Local Close Scope lang_scope. + +(* simple tests to check bidirectional hints *) +Module bidi_tests. +Local Open Scope lang_scope. +Import Notations. +Context (R : realType). + +Definition bidi_test1 x : @exp R P [::] _ := [ + let x := return {1}:R in + return #x]. + +Definition bidi_test2 (a b : string) + (a := "a") (b := "b") + (* (ba : infer (b != a)) *) + : @exp R P [::] _ := [ + let a := return {1}:R in + let b := return {true}:B in + (* let c := return {3}:R in + let d := return {4}:R in *) + return (#a, #b)]. + +Definition bidi_test3 (a b c d : string) + (ba : infer (b != a)) (ca : infer (c != a)) + (cb : infer (c != b)) (ab : infer (a != b)) + (ac : infer (a != c)) (bc : infer (b != c)) : @exp R P [::] _ := [ + let a := return {1}:R in + let b := return {2}:R in + let c := return {3}:R in + (* let d := return {4}:R in *) + return (#b, #a)]. + +Definition bidi_test4 (a b c d : string) + (ba : infer (b != a)) (ca : infer (c != a)) + (cb : infer (c != b)) (ab : infer (a != b)) + (ac : infer (a != c)) (bc : infer (b != c)) : @exp R P [::] _ := [ + let a := return {1}:R in + let b := return {2}:R in + let c := return {3}:R in + (* let d := return {4}:R in *) + return {exp_poisson O [#c(*{exp_var c erefl}*)]}]. + +End bidi_tests. + +Section trivial_example. +Local Open Scope lang_scope. +Import Notations. +Context {R : realType}. + +Lemma exec_normalize_return g x r : + projT1 (@execD _ g _ [Normalize return r:R]) x = \d_r :> probability _ R. +Proof. +rewrite execD_normalize_pt execP_return execD_real/=. +exact: normalize_kdirac. +Qed. + +End trivial_example. + +Section sample_pair. +Local Open Scope lang_scope. +Local Open Scope ring_scope. +Import Notations. +Context {R : realType}. + +Definition sample_pair_syntax0 : @exp R _ [::] _ := + [let "x" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in + let "y" := Sample {exp_bernoulli (1 / 3%:R)%:nng (p1S 2)} in + return (#{"x"}, #{"y"})]. + +Definition sample_pair_syntax : exp _ [::] _ := + [Normalize {sample_pair_syntax0}]. + +Lemma exec_sample_pair0 (A : set (bool * bool)) : + @execP R [::] _ sample_pair_syntax0 tt A = + ((1 / 2)%:E * + ((1 / 3)%:E * \d_(true, true) A + + (1 - 1 / 3)%:E * \d_(true, false) A) + + (1 - 1 / 2)%:E * + ((1 / 3)%:E * \d_(false, true) A + + (1 - 1 / 3)%:E * \d_(false, false) A))%E. +Proof. +rewrite !execP_letin !execP_sample !execD_bernoulli execP_return /=. +rewrite execD_pair !exp_var'E (execD_var_erefl "x") (execD_var_erefl "y") /=. +rewrite letin'E integral_measure_add//= !ge0_integral_mscale//= /onem. +rewrite !integral_dirac//= !indicE !in_setT/= !mul1e. +rewrite !letin'E !integral_measure_add//= !ge0_integral_mscale//= /onem. +by rewrite !integral_dirac//= !indicE !in_setT/= !mul1e !diracE. +Qed. + +Lemma exec_sample_pair0_TandT : + @execP R [::] _ sample_pair_syntax0 tt [set (true, true)] = (1 / 6)%:E. +Proof. +rewrite exec_sample_pair0 !diracE mem_set//; do 3 rewrite memNset//=. +by rewrite /= !mule0 mule1 !add0e mule0 adde0; congr (_%:E); lra. +Qed. + +Lemma exec_sample_pair0_TandF : + @execP R [::] _ sample_pair_syntax0 tt [set (true, false)] = (1 / 3)%:E. +Proof. +rewrite exec_sample_pair0 !diracE memNset// mem_set//; do 2 rewrite memNset//. +by rewrite /= !mule0 mule1 !add0e mule0 adde0; congr (_%:E); lra. +Qed. + +Lemma exec_sample_pair0_TandT' : + @execP R [::] _ sample_pair_syntax0 tt [set p | p.1 && p.2] = (1 / 6)%:E. +Proof. +rewrite exec_sample_pair0 !diracE mem_set//; do 3 rewrite memNset//=. +by rewrite /= !mule0 mule1 !add0e mule0 adde0; congr (_%:E); lra. +Qed. + +Lemma exec_sample_pair_TorT : + (projT1 (execD sample_pair_syntax)) tt [set p | p.1 || p.2] = (2 / 3)%:E. +Proof. +rewrite execD_normalize_pt normalizeE/= exec_sample_pair0. +rewrite !diracE; do 4 rewrite mem_set//=. +rewrite eqe ifF; last by apply/negbTE/negP => /orP[/eqP|//]; lra. +rewrite exec_sample_pair0 !diracE; do 3 rewrite mem_set//; rewrite memNset//=. +by rewrite !mule1; congr (_%:E); field. +Qed. + +Definition sample_and_syntax0 : @exp R _ [::] _ := + [let "x" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in + let "y" := Sample {exp_bernoulli (1 / 3%:R)%:nng (p1S 2)} in + return #{"x"} && #{"y"}]. + +Lemma exec_sample_and0 (A : set bool) : + @execP R [::] _ sample_and_syntax0 tt A = ((1 / 6)%:E * \d_true A + + (1 - 1 / 6)%:E * \d_false A)%E. +Proof. +rewrite !execP_letin !execP_sample !execD_bernoulli execP_return /=. +rewrite !(@execD_bin _ _ binop_and) !exp_var'E (execD_var_erefl "x") (execD_var_erefl "y") /=. +rewrite letin'E integral_measure_add//= !ge0_integral_mscale//= /onem. +rewrite !integral_dirac//= !indicE !in_setT/= !mul1e. +rewrite !letin'E !integral_measure_add//= !ge0_integral_mscale//= /onem. +rewrite !integral_dirac//= !indicE !in_setT/= !mul1e. +rewrite muleDr// -addeA; congr (_ + _)%E. + by rewrite !muleA; congr (_%:E); congr (_ * _); field. +rewrite -muleDl// !muleA -muleDl//. +by congr (_%:E); congr (_ * _); field. +Qed. + +Definition sample_bernoulli_and3 : @exp R _ [::] _ := + [let "x" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in + let "y" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in + let "z" := Sample {exp_bernoulli (1 / 2)%:nng (p1S 1)} in + return #{"x"} && #{"y"} && #{"z"}]. + +Lemma exec_sample_bernoulli_and3 t U : + execP sample_bernoulli_and3 t U = ((1 / 8)%:E * \d_true U + + (1 - 1 / 8)%:E * \d_false U)%E. +Proof. +rewrite !execP_letin !execP_sample !execD_bernoulli execP_return /=. +rewrite !(@execD_bin _ _ binop_and) !exp_var'E. +rewrite (execD_var_erefl "x") (execD_var_erefl "y") (execD_var_erefl "z") /=. +rewrite letin'E integral_measure_add//= !ge0_integral_mscale//= /onem. +rewrite !integral_dirac//= !indicE !in_setT/= !mul1e. +rewrite !letin'E !integral_measure_add//= !ge0_integral_mscale//= /onem. +rewrite !integral_dirac//= !indicE !in_setT/= !mul1e. +rewrite !letin'E !integral_measure_add//= !ge0_integral_mscale//= /onem. +rewrite !integral_dirac//= !indicE !in_setT/= !mul1e !diracE. +rewrite !muleDr// -!addeA. +by congr (_ + _)%E; rewrite ?addeA !muleA -?muleDl//; +congr (_ * _)%E; congr (_%:E); field. +Qed. + +End sample_pair. + +Section bernoulli_examples. +Local Open Scope ring_scope. +Local Open Scope lang_scope. +Import Notations. +Context {R : realType}. + +Definition bernoulli13_score := [Normalize + let "x" := Sample {@exp_bernoulli R [::] (1 / 3%:R)%:nng (p1S 2)} in + let "_" := if #{"x"} then Score {(1 / 3)}:R else Score {(2 / 3)}:R in + return #{"x"}]. + +Lemma exec_bernoulli13_score : + execD bernoulli13_score = execD (exp_bernoulli (1 / 5%:R)%:nng (p1S 4)). +Proof. +apply: eq_execD. +rewrite execD_bernoulli/= /bernoulli13_score execD_normalize_pt 2!execP_letin. +rewrite execP_sample/= execD_bernoulli/= execP_if /= exp_var'E. +rewrite (execD_var_erefl "x")/= !execP_return/= 2!execP_score 2!execD_real/=. +apply: funext=> g; apply: eq_probability => U. +rewrite normalizeE !letin'E/=. +under eq_integral. + move=> x _. + rewrite !letin'E. + under eq_integral do rewrite retE /=. + over. +rewrite !integral_measure_add //=; last by move=> b _; rewrite integral_ge0. +rewrite !ge0_integral_mscale //=; last 2 first. + by move=> b _; rewrite integral_ge0. + by move=> b _; rewrite integral_ge0. +rewrite !integral_dirac// !indicE !in_setT !mul1e. +rewrite iteE/= !ge0_integral_mscale//=. +rewrite ger0_norm//. +rewrite !integral_indic//= !iteE/= /mscale/=. +rewrite setTI diracE !in_setT !mule1. +rewrite ger0_norm//. +rewrite -EFinD/= eqe ifF; last first. + by apply/negbTE/negP => /orP[/eqP|//]; rewrite /onem; lra. +rewrite !letin'E/= !iteE/=. +rewrite !ge0_integral_mscale//=. +rewrite ger0_norm//. +rewrite !integral_dirac//= !indicE !in_setT /= !mul1e ger0_norm//. +rewrite exp_var'E (execD_var_erefl "x")/=. +rewrite /bernoulli/= measure_addE/= /mscale/= !mul1r. +rewrite muleDl//; congr (_ + _)%E; + rewrite -!EFinM; + congr (_%:E); + by rewrite indicE /onem; case: (_ \in _); field. +Qed. + +Definition bernoulli12_score := [Normalize + let "x" := Sample {@exp_bernoulli R [::] (1 / 2)%:nng (p1S 1)} in + let "r" := if #{"x"} then Score {(1 / 3)}:R else Score {(2 / 3)}:R in + return #{"x"}]. + +Lemma exec_bernoulli12_score : + execD bernoulli12_score = execD (exp_bernoulli (1 / 3%:R)%:nng (p1S 2)). +Proof. +apply: eq_execD. +rewrite execD_bernoulli/= /bernoulli12_score execD_normalize_pt 2!execP_letin. +rewrite execP_sample/= execD_bernoulli/= execP_if /= exp_var'E. +rewrite (execD_var_erefl "x")/= !execP_return/= 2!execP_score 2!execD_real/=. +apply: funext=> g; apply: eq_probability => U. +rewrite normalizeE !letin'E/=. +under eq_integral. + move=> x _. + rewrite !letin'E. + under eq_integral do rewrite retE /=. + over. +rewrite !integral_measure_add //=; last by move=> b _; rewrite integral_ge0. +rewrite !ge0_integral_mscale //=; last 2 first. + by move=> b _; rewrite integral_ge0. + by move=> b _; rewrite integral_ge0. +rewrite !integral_dirac// !indicE !in_setT !mul1e. +rewrite iteE/= !ge0_integral_mscale//=. +rewrite ger0_norm//. +rewrite !integral_indic//= !iteE/= /mscale/=. +rewrite setTI diracE !in_setT !mule1. +rewrite ger0_norm//. +rewrite -EFinD/= eqe ifF; last first. + apply/negbTE/negP => /orP[/eqP|//]. + by rewrite /onem; lra. +rewrite !letin'E/= !iteE/=. +rewrite !ge0_integral_mscale//=. +rewrite ger0_norm//. +rewrite !integral_dirac//= !indicE !in_setT /= !mul1e ger0_norm//. +rewrite exp_var'E (execD_var_erefl "x")/=. +rewrite /bernoulli/= measure_addE/= /mscale/= !mul1r. +rewrite muleDl//; congr (_ + _)%E; + rewrite -!EFinM; + congr (_%:E); + by rewrite indicE /onem; case: (_ \in _); field. +Qed. + +(* https://dl.acm.org/doi/pdf/10.1145/2933575.2935313 (Sect. 4) *) +Definition bernoulli14_score := [Normalize + let "x" := Sample {@exp_bernoulli R [::] (1 / 4%:R)%:nng (p1S 3)} in + let "r" := if #{"x"} then Score {5}:R else Score {2}:R in + return #{"x"}]. + +Let p511 : ((5%:R / 11%:R)%:nng%:num <= (1 : R)). +Proof. by rewrite /=; lra. Qed. + +Lemma exec_bernoulli14_score : + execD bernoulli14_score = execD (exp_bernoulli (5%:R / 11%:R)%:nng p511). +Proof. +apply: eq_execD. +rewrite execD_bernoulli/= execD_normalize_pt 2!execP_letin. +rewrite execP_sample/= execD_bernoulli/= execP_if /= !exp_var'E. +rewrite !execP_return/= 2!execP_score 2!execD_real/=. +rewrite !(execD_var_erefl "x")/=. +apply: funext=> g; apply: eq_probability => U. +rewrite normalizeE !letin'E/=. +under eq_integral. + move=> x _. + rewrite !letin'E. + under eq_integral do rewrite retE /=. + over. +rewrite !integral_measure_add //=; last by move=> b _; rewrite integral_ge0. +rewrite !ge0_integral_mscale //=; last 2 first. + by move=> b _; exact: integral_ge0. + by move=> b _; exact: integral_ge0. +rewrite !integral_dirac// !indicE !in_setT !mul1e. +rewrite iteE/= !ge0_integral_mscale//=. +rewrite ger0_norm//. +rewrite !integral_indic//= !iteE/= /mscale/=. +rewrite setTI diracE !in_setT !mule1. +rewrite ger0_norm//. +rewrite -EFinD/= eqe ifF; last first. + apply/negbTE/negP => /orP[/eqP|//]. + by rewrite /onem; lra. +rewrite !letin'E/= !iteE/=. +rewrite !ge0_integral_mscale//=. +rewrite ger0_norm//. +rewrite !integral_dirac//= !indicE !in_setT /= !mul1e ger0_norm//. +rewrite /bernoulli/= measure_addE/= /mscale/= !mul1r. +rewrite muleDl//; congr (_ + _)%E; + rewrite -!EFinM; + congr (_%:E); + by rewrite indicE /onem; case: (_ \in _); field. +Qed. + +End bernoulli_examples. + +Section binomial_examples. +Context {R : realType}. +Open Scope lang_scope. +Open Scope ring_scope. + +Definition sample_binomial3 : @exp R _ [::] _ := + [let "x" := Sample {exp_binomial 3 (1 / 2)%:nng (p1S 1)} in + return #{"x"}]. + +Open Scope real_scope. + +Lemma exec_sample_binomial3 t U : + execP sample_binomial3 t U = ((1 / 8)%:E * @dirac _ R 0%:R R U + + (3 / 8)%:E * @dirac _ R 1%:R R U + + (3 / 8)%:E * @dirac _ R 2%:R R U + + (1 / 8)%:E * @dirac _ R 3%:R R U)%E. +Proof. +rewrite /sample_binomial3 execP_letin execP_sample execP_return. +rewrite exp_var'E (execD_var_erefl "x") !execD_binomial/=. +rewrite letin'E ge0_integral_measure_sum//=. +rewrite !big_ord_recl big_ord0 !ge0_integral_mscale//=. +rewrite !integral_dirac// /bump. +rewrite indicT !binS/= !bin0 bin1 bin2 bin_small// addn0. +rewrite expr0 mulr1 mul1r subn0. +rewrite -2!addeA. +congr _%E. +congr (_ + _)%:E. +congr (_ * _). +by field. +by rewrite mul1r. +congr (_ + _). +congr (_ * _). +rewrite expr1 /onem. +by field. +by rewrite mul1r. +congr (_ + _). +congr (_ * _). +rewrite /onem/=. +by field. +by rewrite mul1r. +rewrite addr0. +congr (_ * _). +rewrite /onem/=. +by field. +by rewrite mul1r. +Admitted. + +End binomial_examples. + +Section hard_constraint'. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). + +Definition fail' := + letin' (score (@measurable_cst _ _ X _ setT (0%R : R))) + (ret (@measurable_cst _ _ _ Y setT point)). + +Lemma fail'E x U : fail' x U = 0. +Proof. by rewrite /fail' letin'E ge0_integral_mscale//= normr0 mul0e. Qed. + +End hard_constraint'. +Arguments fail' {d d' X Y R}. + +Section casino_example. +Open Scope lang_scope. +Open Scope ring_scope. +Context (R : realType). + +Lemma a01 : 0 < 1 - 0 :> R. Proof. by []. Qed. + +Definition binomial_le : @exp R _ [::] Bool := + [let "a2" := Sample {exp_binomial 3 (1 / 2)%:nng (p1S 1)} in + return {1}:R <= #{"a2"}]. + +Lemma exec_binomial_le t U : + execP binomial_le t U = ((7 / 8)%:E * @dirac _ _ true R U + + (1 / 8)%:E * @dirac _ _ false R U)%E. +Proof. +rewrite /binomial_le execP_letin execP_sample execP_return execD_rel execD_real. +rewrite exp_var'E (execD_var_erefl "a2") execD_binomial. +rewrite letin'E//= /binomial_probability ge0_integral_measure_sum//=. +rewrite !big_ord_recl big_ord0 !ge0_integral_mscale//=. +rewrite !integral_dirac// /bump. +rewrite indicT !binS/= !bin0 bin1 bin2 bin_small// addn0 !mul1e. +rewrite addeC adde0. +congr (_ + _)%:E. +have -> : (1 <= 1)%R. admit. +have -> : (1 <= 2)%R. admit. +have -> : (1 <= 3)%R. admit. +rewrite -!mulrDl. +congr (_ * _). +rewrite /onem addn0 add0n. +by field. +congr (_ * _). +by field. +by rewrite ler10. +Admitted. + +Definition binomial_guard : @exp R _ [::] Real := + [let "a1" := Sample {exp_binomial 3 (1 / 2)%:nng (p1S 1)} in + let "_" := if #{"a1"} == {1}:R then return TT else Score {0}:R in + return #{"a1"}]. + +Lemma exec_binomial_guard t U : + execP binomial_guard t U = ((7 / 8)%:E * @dirac _ R 1%R R U + + (1 / 8)%:E * @dirac _ R 0%R R U)%E. +Proof. +rewrite /binomial_guard !execP_letin execP_sample execP_return execP_if. +rewrite !exp_var'E execD_rel !(execD_var_erefl "a1") execP_return execD_unit execD_binomial execD_real execP_score execD_real. +rewrite !letin'E//= /binomial_probability ge0_integral_measure_sum//=. +rewrite !big_ord_recl big_ord0 !ge0_integral_mscale//=. +rewrite !integral_dirac// /bump. +rewrite indicT !binS/= !bin0 bin1 bin2 bin_small// addn0 !mul1e. +rewrite !letin'E//= !iteE/= !diracE/=. +have -> : (0 == 1)%R = false; first by admit. +have -> : (1 == 1)%R; first by admit. +have -> : (2 == 1)%R = false; first by admit. +have -> : (3 == 1)%R = false; first by admit. +rewrite addeC adde0. +Admitted. + +(* Definition casino : exp _ [::] _ := + [let "p" := Sample {exp_uniform 0 1 a01} in + let "a1" := Sample {exp_binomial 8 [#{"p"}]} in + return #{"p"}]. *) + +End casino_example. + +(* hard constraints to express score below 1 *) +Lemma score_fail' d (X : measurableType d) {R : realType} + (r : {nonneg R}) (r1 : (r%:num <= 1)%R) : + score (kr r%:num) = + letin' (sample_cst (bernoulli r1) : R.-pker X ~> _) + (ite macc0of2 (ret ktt) fail'). +Proof. +apply/eq_sfkernel => x U. +rewrite letin'E/= /sample; unlock. +rewrite integral_measure_add//= ge0_integral_mscale//= ge0_integral_mscale//=. +rewrite integral_dirac//= integral_dirac//= !indicT/= !mul1e. +by rewrite /mscale/= iteE//= iteE//= fail'E mule0 adde0 ger0_norm. +Qed. + +Section hard_constraint. +Local Open Scope ring_scope. +Local Open Scope lang_scope. +Import Notations. +Context {R : realType} {str : string}. + +Definition hard_constraint g : @exp R _ g _ := + [let str := Score {0}:R in return TT]. + +Lemma exec_hard_constraint g mg U : + execP (hard_constraint g) mg U = fail' (false, tt) U. +Proof. +rewrite execP_letin execP_score execD_real execP_return execD_unit/=. +rewrite letin'E integral_indic//= /mscale/= normr0 mul0e. +by rewrite /fail' letin'E/= ge0_integral_mscale//= normr0 mul0e. +Qed. + +Lemma exec_score_fail (r : {nonneg R}) (r1 : (r%:num <= 1)%R) : + execP (g := [::]) [Score {r%:num}:R] = + execP [let str := Sample {exp_bernoulli r r1} in + if #str then return TT else {hard_constraint _}]. +Proof. +rewrite execP_score execD_real /= score_fail'. +rewrite execP_letin execP_sample/= execD_bernoulli execP_if execP_return. +rewrite execD_unit/= exp_var'E /=. + apply/ctx_prf_head. +move=> h. +apply: eq_sfkernel=> /= -[] U. +rewrite 2!letin'E/=. +apply: eq_integral => b _. +rewrite 2!iteE//=. +case: b => //=. +- suff : projT1 (@execD R _ _ (exp_var str h)) (true, tt) = true by move=> ->. + set g := [:: (str, Bool)]. + have /= := @execD_var R [:: (str, Bool)] str. + by rewrite eqxx => /(_ h) ->. +- have -> : projT1 (@execD R _ _ (exp_var str h)) (false, tt) = false. + set g := [:: (str, Bool)]. + have /= := @execD_var R [:: (str, Bool)] str. + by rewrite eqxx /= => /(_ h) ->. + by rewrite (@exec_hard_constraint [:: (str, Bool)]). +Qed. + +End hard_constraint. + +Section letinA. +Local Open Scope lang_scope. +Variable R : realType. + +Lemma letinA g x y t1 t2 t3 (xyg : x \notin dom ((y, t2) :: g)) + (e1 : @exp R P g t1) + (e2 : exp P [:: (x, t1) & g] t2) + (e3 : exp P [:: (y, t2) & g] t3) : + forall U, measurable U -> + execP [let x := e1 in + let y := e2 in + {@exp_weak _ _ [:: (y, t2)] _ _ (x, t1) e3 xyg}] ^~ U = + execP [let y := + let x := e1 in e2 in + e3] ^~ U. +Proof. +move=> U mU; apply/funext=> z1. +rewrite !execP_letin. +rewrite (execP_weak [:: (y, t2)]). +apply: letin'A => //= z2 z3. +rewrite /kweak /mctx_strong /=. +by destruct z3. +Qed. + +Example letinA12 : forall U, measurable U -> + @execP R [::] _ [let "y" := return {1}:R in + let "x" := return {2}:R in + return #{"x"}] ^~ U = + @execP R [::] _ [let "x" := + let "y" := return {1}:R in return {2}:R in + return #{"x"}] ^~ U. +Proof. +move=> U mU. +rewrite !execP_letin !execP_return !execD_real. +apply: funext=> x. +rewrite !exp_var'E /= !(execD_var_erefl "x")/=. +exact: letin'A. +Qed. + +End letinA. + +Section staton_bus. +Local Open Scope ring_scope. +Local Open Scope lang_scope. +Import Notations. +Context {R : realType}. + +Section tests. + +Local Notation "$ str" := (@exp_var _ _ str%string _ erefl) + (in custom expr at level 1, format "$ str"). + +Definition staton_bus_syntax0_generic (x r u : string) + (rx : infer (r != x)) (Rx : infer (u != x)) + (ur : infer (u != r)) (xr : infer (x != r)) + (xu : infer (x != u)) (ru : infer (r != u)) : @exp R P [::] _ := + [let x := Sample {exp_bernoulli (2 / 7%:R)%:nng p27} in + let r := if #x then return {3}:R else return {10}:R in + let u := Score {exp_poisson 4 [#r]} in + return #x]. + +Fail Definition staton_bus_syntax0_generic' (x r u : string) + (rx : infer (r != x)) (Rx : infer (u != x)) + (ur : infer (u != r)) (xr : infer (x != r)) + (xu : infer (x != u)) (ru : infer (r != u)) : @exp R P [::] _ := + [let x := Sample {exp_bernoulli (2 / 7%:R)%:nng p27} in + let r := if $x then return {3}:R else return {10}:R in + let u := Score {exp_poisson 4 [$r]} in + return $x]. + +Fail Definition staton_bus_syntax0' : @exp R _ [::] _ := + [let "x" := Sample {exp_bernoulli (2 / 7%:R)%:nng p27} in + let "r" := if ${"x"} then return {3}:R else return {10}:R in + let "_" := Score {exp_poisson 4 [${"r"}]} in + return ${"x"}]. + +Definition staton_bus_syntax0 : @exp R _ [::] _ := + [let "x" := Sample {exp_bernoulli (2 / 7%:R)%:nng p27} in + let "r" := if #{"x"} then return {3}:R else return {10}:R in + let "_" := Score {exp_poisson 4 [#{"r"}]} in + return #{"x"}]. + +End tests. + +Definition staton_bus_syntax := [Normalize {staton_bus_syntax0}]. + +Let sample_bern : R.-sfker munit ~> mbool := sample_cst (bernoulli p27). + +Let ite_3_10 : R.-sfker mbool * munit ~> (mR R) := + ite macc0of2 (ret k3) (ret k10). + +Let score_poisson4 : R.-sfker mR R * (mbool * munit) ~> munit := + score (measurableT_comp (measurable_poisson 4) macc0of2). + +Let kstaton_bus' := + letin' sample_bern + (letin' ite_3_10 + (letin' score_poisson4 (ret macc2of4'))). + +Lemma eval_staton_bus0 : staton_bus_syntax0 -P> kstaton_bus'. +Proof. +apply: eval_letin; first by apply: eval_sample; exact: eval_bernoulli. +apply: eval_letin. + apply/evalP_if; [|exact/eval_return/eval_real..]. + rewrite exp_var'E. + by apply/execD_evalD; rewrite (execD_var_erefl "x")/=; congr existT. +apply: eval_letin. + apply/eval_score/eval_poisson. + rewrite exp_var'E. + by apply/execD_evalD; rewrite (execD_var_erefl "r")/=; congr existT. +apply/eval_return/execD_evalD. +by rewrite exp_var'E (execD_var_erefl "x")/=; congr existT. +Qed. + +Lemma exec_staton_bus0' : execP staton_bus_syntax0 = kstaton_bus'. +Proof. +rewrite 3!execP_letin execP_sample/= execD_bernoulli. +rewrite /kstaton_bus'; congr letin'. +rewrite !execP_if !execP_return !execD_real/=. +rewrite exp_var'E (execD_var_erefl "x")/=. +have -> : measurable_acc_typ [:: Bool] 0 = macc0of2 by []. +congr letin'. +rewrite execP_score execD_poisson/=. +rewrite exp_var'E (execD_var_erefl "r")/=. +have -> : measurable_acc_typ [:: Real; Bool] 0 = macc0of2 by []. +congr letin'. +by rewrite exp_var'E (execD_var_erefl "x") /=; congr ret. +Qed. + +Lemma exec_staton_bus : execD staton_bus_syntax = + existT _ (normalize_pt kstaton_bus') (measurable_normalize_pt _). +Proof. by rewrite execD_normalize_pt exec_staton_bus0'. Qed. + +Let poisson4 := @poisson R 4%N. + +Let staton_bus_probability U := + ((2 / 7)%:E * (poisson4 3)%:E * \d_true U + + (5 / 7)%:E * (poisson4 10)%:E * \d_false U)%E. + +Lemma exec_staton_bus0 (U : set bool) : + execP staton_bus_syntax0 tt U = staton_bus_probability U. +Proof. +rewrite exec_staton_bus0' /staton_bus_probability /kstaton_bus'. +rewrite letin'_sample_bernoulli. +rewrite -!muleA; congr (_ * _ + _ * _)%E. +- rewrite letin'_iteT//. + rewrite letin'_retk//. + rewrite letin'_kret//. + rewrite /score_poisson4. + by rewrite /score/= /mscale/= ger0_norm//= poisson_ge0. +- by rewrite onem27. +- rewrite letin'_iteF//. + rewrite letin'_retk//. + rewrite letin'_kret//. + rewrite /score_poisson4. + by rewrite /score/= /mscale/= ger0_norm//= poisson_ge0. +Qed. + +End staton_bus. + +(* same as staton_bus module associativity of letin *) +Section staton_busA. +Local Open Scope ring_scope. +Local Open Scope lang_scope. +Import Notations. +Context {R : realType}. + +Definition staton_busA_syntax0 : @exp R _ [::] _ := + [let "x" := Sample {exp_bernoulli (2 / 7%:R)%:nng p27} in + let "_" := + let "r" := if #{"x"} then return {3}:R else return {10}:R in + Score {exp_poisson 4 [#{"r"}]} in + return #{"x"}]. + +Definition staton_busA_syntax : exp _ [::] _ := + [Normalize {staton_busA_syntax0}]. + +Let sample_bern : R.-sfker munit ~> mbool := sample_cst (bernoulli p27). + +Let ite_3_10 : R.-sfker mbool * munit ~> (mR R) := + ite macc0of2 (ret k3) (ret k10). + +Let score_poisson4 : R.-sfker mR R * (mbool * munit) ~> munit := + score (measurableT_comp (measurable_poisson 4) macc0of3'). + +(* same as kstaton_bus _ (measurable_poisson 4) but expressed with letin' + instead of letin *) +Let kstaton_busA' := + letin' sample_bern + (letin' + (letin' ite_3_10 + score_poisson4) + (ret macc1of3')). + +(*Lemma kstaton_busA'E : kstaton_busA' = kstaton_bus _ (measurable_poisson 4). +Proof. +apply/eq_sfkernel => -[] U. +rewrite /kstaton_busA' /kstaton_bus. +rewrite letin'_letin. +rewrite /sample_bern. +congr (letin _ _ tt U). +rewrite 2!letin'_letin/=. +Abort.*) + +Lemma eval_staton_busA0 : staton_busA_syntax0 -P> kstaton_busA'. +Proof. +apply: eval_letin; first by apply: eval_sample; exact: eval_bernoulli. +apply: eval_letin. + apply: eval_letin. + apply/evalP_if; [|exact/eval_return/eval_real..]. + rewrite exp_var'E. + by apply/execD_evalD; rewrite (execD_var_erefl "x")/=; congr existT. + apply/eval_score/eval_poisson. + rewrite exp_var'E. + by apply/execD_evalD; rewrite (execD_var_erefl "r")/=; congr existT. +apply/eval_return. +by apply/execD_evalD; rewrite exp_var'E (execD_var_erefl "x")/=; congr existT. +Qed. + +Lemma exec_staton_busA0' : execP staton_busA_syntax0 = kstaton_busA'. +Proof. +rewrite 3!execP_letin execP_sample/= execD_bernoulli. +rewrite /kstaton_busA'; congr letin'. +rewrite !execP_if !execP_return !execD_real/=. +rewrite exp_var'E (execD_var_erefl "x")/=. +have -> : measurable_acc_typ [:: Bool] 0 = macc0of2 by []. +congr letin'. + rewrite execP_score execD_poisson/=. + rewrite exp_var'E (execD_var_erefl "r")/=. + by have -> : measurable_acc_typ [:: Real; Bool] 0 = macc0of3' by []. +by rewrite exp_var'E (execD_var_erefl "x") /=; congr ret. +Qed. + +Lemma exec_statonA_bus : execD staton_busA_syntax = + existT _ (normalize_pt kstaton_busA') (measurable_normalize_pt _). +Proof. by rewrite execD_normalize_pt exec_staton_busA0'. Qed. + +(* equivalence between staton_bus and staton_busA *) +Lemma staton_bus_staton_busA : + execP staton_bus_syntax0 = @execP R _ _ staton_busA_syntax0. +Proof. +rewrite /staton_bus_syntax0 /staton_busA_syntax0. +rewrite execP_letin. +rewrite [in RHS]execP_letin. +congr (letin' _). +set e1 := exp_if _ _ _. +set e2 := exp_score _. +set e3 := (exp_return _ in RHS). +pose f := @found _ Unit "x" Bool [::]. +have r_f : "r" \notin [seq i.1 | i <- ("_", Unit) :: untag (ctx_of f)] by []. +have H := @letinA _ _ _ _ _ _ + (lookup Unit (("_", Unit) :: untag (ctx_of f)) "x") + r_f e1 e2 e3. +apply/eq_sfkernel => /= x U. +have mU : + (@mtyp_disp R (lookup Unit (("_", Unit) :: untag (ctx_of f)) "x")).-measurable U. + by []. +move: H => /(_ U mU) /(congr1 (fun f => f x)) <-. +set e3' := exp_return _. +set e3_weak := exp_weak _ _ _ _. +rewrite !execP_letin. +suff: execP e3' = execP (e3_weak e3 r_f) by move=> <-. +rewrite execP_return/= exp_var'E (execD_var_erefl "x") /= /e3_weak. +rewrite (@execP_weak R [:: ("_", Unit)] _ ("r", Real) _ e3 r_f). +rewrite execP_return exp_var'E/= (execD_var_erefl "x") //=. +by apply/eq_sfkernel => /= -[[] [a [b []]]] U0. +Qed. + +Let poisson4 := @poisson R 4%N. + +Lemma exec_staton_busA0 U : execP staton_busA_syntax0 tt U = + ((2 / 7%:R)%:E * (poisson4 3%:R)%:E * \d_true U + + (5%:R / 7%:R)%:E * (poisson4 10%:R)%:E * \d_false U)%E. +Proof. by rewrite -staton_bus_staton_busA exec_staton_bus0. Qed. + +End staton_busA. + +Section letinC. +Local Open Scope lang_scope. +Variable (R : realType). + +Let weak_head g {t1 t2} x (e : @exp R P g t2) (xg : x \notin dom g) := + exp_weak P [::] _ (x, t1) e xg. + +Lemma letinC g t1 t2 (e1 : @exp R P g t1) (e2 : exp P g t2) + (x y : string) + (xy : infer (x != y)) (yx : infer (y != x)) + (xg : x \notin dom g) (yg : y \notin dom g) : + forall U, measurable U -> + execP [ + let x := e1 in + let y := {weak_head e2 xg} in + return (#x, #y)] ^~ U = + execP [ + let y := e2 in + let x := {weak_head e1 yg} in + return (#x, #y)] ^~ U. +Proof. +move=> U mU; apply/funext => z. +rewrite 4!execP_letin. +rewrite 2!(execP_weak [::] g). +rewrite 2!execP_return/=. +rewrite 2!execD_pair/=. +rewrite !exp_var'E. +- exact/(ctx_prf_tail _ yx)/ctx_prf_head. +- exact/ctx_prf_head. +- exact/ctx_prf_head. +- exact/(ctx_prf_tail _ xy)/ctx_prf_head. +- move=> h1 h2 h3 h4. + set g1 := [:: (y, t2), (x, t1) & g]. + set g2 := [:: (x, t1), (y, t2) & g]. + have /= := @execD_var R g1 x. + rewrite (negbTE yx) eqxx => /(_ h4) ->. + have /= := @execD_var R g2 x. + rewrite (negbTE yx) eqxx => /(_ h2) ->. + have /= := @execD_var R g1 y. + rewrite eqxx => /(_ h3) ->. + have /= := @execD_var R g2 y. + rewrite (negbTE xy) eqxx => /(_ h1) -> /=. + have -> : measurable_acc_typ [:: t2, t1 & map snd g] 0 = macc0of3' by []. + have -> : measurable_acc_typ [:: t2, t1 & map snd g] 1 = macc1of3' by []. + rewrite (letin'C _ _ (execP e2) + [the R.-sfker _ ~> _ of @kweak _ [::] _ (y, t2) _ (execP e1)]); + [ |by [] | by [] |by []]. + have -> : measurable_acc_typ [:: t1, t2 & map snd g] 0 = macc0of3' by []. + by have -> : measurable_acc_typ [:: t1, t2 & map snd g] 1 = macc1of3' by []. +Qed. + +Example letinC_ground_variables g t1 t2 (e1 : @exp R P g t1) (e2 : exp P g t2) + (x := "x") (y := "y") + (xg : x \notin dom g) (yg : y \notin dom g) : + forall U, measurable U -> + execP [ + let x := e1 in + let y := {exp_weak _ [::] _ (x, t1) e2 xg} in + return (#x, #y)] ^~ U = + execP [ + let y := e2 in + let x := {exp_weak _ [::] _ (y, t2) e1 yg} in + return (#x, #y)] ^~ U. +Proof. by move=> U mU; rewrite letinC. Qed. + +Example letinC_ground (g := [:: ("a", Unit); ("b", Bool)]) t1 t2 + (e1 : @exp R P g t1) + (e2 : exp P g t2) : + forall U, measurable U -> + execP [let "x" := e1 in + let "y" := e2 :+ {"x"} in + return (#{"x"}, #{"y"})] ^~ U = + execP [let "y" := e2 in + let "x" := e1 :+ {"y"} in + return (#{"x"}, #{"y"})] ^~ U. +Proof. move=> U mU; exact: letinC. Qed. + +End letinC. diff --git a/theories/lang_syntax_toy.v b/theories/lang_syntax_toy.v new file mode 100644 index 0000000000..a328ea1cd7 --- /dev/null +++ b/theories/lang_syntax_toy.v @@ -0,0 +1,550 @@ +Require Import String Classical. +From HB Require Import structures. +From mathcomp Require Import all_ssreflect ssralg. +From mathcomp.classical Require Import mathcomp_extra boolp. +Require Import signed reals topology normedtype. +Require Import lang_syntax_util. + +(******************************************************************************) +(* Intrinsically-typed concrete syntax for a toy language *) +(* *) +(* The main module provided by this file is "lang_intrinsic_tysc" which *) +(* provides an example of intrinsically-typed concrete syntax for a toy *) +(* language (a simplification of the syntax/evaluation formalized in *) +(* lang_syntax.v). Other modules provide even more simplified language for *) +(* pedagogical purposes. *) +(* *) +(* lang_extrinsic == non-intrinsic definition of expression *) +(* lang_intrinsic_ty == intrinsically-typed syntax *) +(* lang_intrinsic_sc == intrinsically-scoped syntax *) +(* lang_intrinsic_tysc == intrinsically-typed/scoped syntax *) +(* *) +(******************************************************************************) + +Set Implicit Arguments. +Unset Strict Implicit. +Set Printing Implicit Defensive. + +Import numFieldTopology.Exports. + +Local Open Scope classical_set_scope. +Local Open Scope ring_scope. +Local Open Scope ereal_scope. + +Section type. +Variables (R : realType). + +Inductive typ := Real | Unit. + +Canonical typ_eqType := Equality.Pack (@gen_eqMixin typ). + +Definition iter_pair (l : list Type) : Type := + foldr (fun x y => (x * y)%type) unit l. + +Definition Type_of_typ (t : typ) : Type := + match t with + | Real => R + | Unit => unit + end. + +Definition ctx := seq (string * typ). + +Definition Type_of_ctx (g : ctx) := iter_pair (map (Type_of_typ \o snd) g). + +Goal Type_of_ctx [:: ("x", Real); ("y", Real)] = (R * (R * unit))%type. +Proof. by []. Qed. + +End type. + +Module lang_extrinsic. +Section lang_extrinsic. +Variable R : realType. +Implicit Types str : string. + +Inductive exp : Type := +| exp_unit : exp +| exp_real : R -> exp +| exp_var (g : ctx) t str : t = lookup Unit g str -> exp +| exp_add : exp -> exp -> exp +| exp_letin str : exp -> exp -> exp. +Arguments exp_var {g t}. + +Fail Example letin_once : exp := + exp_letin "x" (exp_real 1) (exp_var "x" erefl). +Example letin_once : exp := + exp_letin "x" (exp_real 1) (@exp_var [:: ("x", Real)] Real "x" erefl). + +End lang_extrinsic. +End lang_extrinsic. + +Module lang_intrinsic_ty. +Section lang_intrinsic_ty. +Variable R : realType. +Implicit Types str : string. + +Inductive exp : typ -> Type := +| exp_unit : exp Unit +| exp_real : R -> exp Real +| exp_var g t str : t = lookup Unit g str -> exp t +| exp_add : exp Real -> exp Real -> exp Real +| exp_letin t u : string -> exp t -> exp u -> exp u. +Arguments exp_var {g t}. + +Fail Example letin_once : exp Real := + exp_letin "x" (exp_real 1) (exp_var "x" erefl). +Example letin_once : exp Real := + exp_letin "x" (exp_real 1) (@exp_var [:: ("x", Real)] _ "x" erefl). + +End lang_intrinsic_ty. +End lang_intrinsic_ty. + +Module lang_intrinsic_sc. +Section lang_intrinsic_sc. +Variable R : realType. +Implicit Types str : string. + +Inductive exp : ctx -> Type := +| exp_unit g : exp g +| exp_real g : R -> exp g +| exp_var g t str : t = lookup Unit g str -> exp g +| exp_add g : exp g -> exp g -> exp g +| exp_letin g t str : exp g -> exp ((str, t) :: g) -> exp g. +Arguments exp_real {g}. +Arguments exp_var {g t}. +Arguments exp_letin {g t}. + +Declare Custom Entry expr. + +Notation "[ e ]" := e (e custom expr at level 5). +Notation "{ x }" := x (in custom expr, x constr). +Notation "x ':R'" := (exp_real x) (in custom expr at level 1). +Notation "x" := x (in custom expr at level 0, x ident). +Notation "$ x" := (exp_var x erefl) (in custom expr at level 1). +Notation "x + y" := (exp_add x y) + (in custom expr at level 2, left associativity). +Notation "'let' x ':=' e1 'in' e2" := (exp_letin x e1 e2) + (in custom expr at level 3, x constr, + e1 custom expr at level 2, e2 custom expr at level 3, + left associativity). + +Fail Example letin_once : exp [::] := + [let "x" := {1%R}:R in ${"x"}]. +Example letin_once : exp [::] := + [let "x" := {1%R}:R in {@exp_var [:: ("x", Real)] _ "x" erefl}]. + +Fixpoint acc (g : ctx) (i : nat) : + Type_of_ctx R g -> @Type_of_typ R (nth Unit (map snd g) i) := + match g return Type_of_ctx R g -> Type_of_typ R (nth Unit (map snd g) i) with + | [::] => match i with | O => id | j.+1 => id end + | _ :: _ => match i with + | O => fst + | j.+1 => fun H => acc j H.2 + end + end. +Arguments acc : clear implicits. + +Inductive eval : forall g (t : typ), exp g -> (Type_of_ctx R g -> Type_of_typ R t) -> Prop := +| eval_real g c : @eval g Real [c:R] (fun=> c) +| eval_plus g (e1 e2 : exp g) (v1 v2 : R) : + @eval g Real e1 (fun=> v1) -> + @eval g Real e2 (fun=> v2) -> + @eval g Real [e1 + e2] (fun=> v1 + v2) +| eval_var (g : ctx) str i : + i = index str (map fst g) -> eval [$ str] (acc g i). + +Goal @eval [::] Real [{1}:R] (fun=> 1). +Proof. exact: eval_real. Qed. +Goal @eval [::] Real [{1}:R + {2}:R] (fun=> 3). +Proof. exact/eval_plus/eval_real/eval_real. Qed. +Goal @eval [:: ("x", Real)] _ [$ {"x"}] (acc [:: ("x", Real)] 0). +Proof. exact: eval_var. Qed. + +End lang_intrinsic_sc. +End lang_intrinsic_sc. + +Module lang_intrinsic_tysc. +Section lang_intrinsic_tysc. +Variable R : realType. +Implicit Types str : string. + +Inductive typ := Real | Unit | Pair : typ -> typ -> typ. + +Canonical typ_eqType := Equality.Pack (@gen_eqMixin typ). + +Fixpoint mtyp (t : typ) : Type := + match t with + | Real => R + | Unit => unit + | Pair t1 t2 => (mtyp t1 * mtyp t2) + end. + +Definition ctx := seq (string * typ). + +Definition Type_of_ctx (g : ctx) := iter_pair (map (mtyp \o snd) g). + +Goal Type_of_ctx [:: ("x", Real); ("y", Real)] = (R * (R * unit))%type. +Proof. by []. Qed. + +Inductive exp : ctx -> typ -> Type := +| exp_unit g : exp g Unit +| exp_real g : R -> exp g Real +| exp_var g t str : t = lookup Unit g str -> exp g t +| exp_add g : exp g Real -> exp g Real -> exp g Real +| exp_pair g t1 t2 : exp g t1 -> exp g t2 -> exp g (Pair t1 t2) +| exp_letin g t1 t2 x : exp g t1 -> exp ((x, t1) :: g) t2 -> exp g t2. + +Definition exp_var' str (t : typ) (g : find str t) := + @exp_var (untag (ctx_of g)) t str (ctx_prf g). + +Section no_bidirectional_hints. + +Arguments exp_unit {g}. +Arguments exp_real {g}. +Arguments exp_var {g t}. +Arguments exp_add {g}. +Arguments exp_pair {g t1 t2}. +Arguments exp_letin {g t1 t2}. +Arguments exp_var' str {t} g. + +Fail Example letin_add : exp [::] _ := + exp_letin "x" (exp_real 1) + (exp_letin "y" (exp_real 2) + (exp_add (exp_var "x" erefl) + (exp_var "y" erefl))). +Example letin_add : exp [::] _ := + exp_letin "x" (exp_real 1) + (exp_letin "y" (exp_real 2) + (exp_add (@exp_var [:: ("y", Real); ("x", Real)] _ "x" erefl) + (exp_var "y" erefl))). +Reset letin_add. + +Declare Custom Entry expr. + +Notation "[ e ]" := e (e custom expr at level 5). +Notation "{ x }" := x (in custom expr, x constr). +Notation "x ':R'" := (exp_real x) (in custom expr at level 1). +Notation "x" := x (in custom expr at level 0, x ident). +Notation "$ x" := (exp_var x erefl) (in custom expr at level 1). +Notation "# x" := (exp_var' x%string _) (in custom expr at level 1). +Notation "e1 + e2" := (exp_add e1 e2) + (in custom expr at level 2, + (* e1 custom expr at level 1, e2 custom expr at level 2, *) + left associativity). +Notation "( e1 , e2 )" := (exp_pair e1 e2) + (in custom expr at level 1). +Notation "'let' x ':=' e1 'in' e2" := (exp_letin x e1 e2) + (in custom expr at level 3, x constr, + e1 custom expr at level 2, e2 custom expr at level 3, + left associativity). + +Fail Definition let3_add_erefl (a b c : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + $a + $b]. +(* The term "[$ a]" has type "exp ?g2 (lookup Unit ?g2 a)" while it is expected to have type "exp ?g2 Real". *) + +Definition let3_pair_erefl (a b c : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + ($a, $b)]. + +Fail Definition let3_add (a b c : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + #a + #b]. +(* The term "[# a + # b]" has type + "exp (untag (ctx_of (recurse (str':=b) Real ?f))) Real" +while it is expected to have type "exp ((c, Real) :: ?g1) ?u1" +(cannot unify "(b, Real)" and "(c, Real)"). *) + +Fail Definition let3_pair (a b c : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + (#a, #b)]. +(* The term "[# a + # b]" has type "exp (untag (ctx_of (recurse (str':=b) Real ?f))) Real" while it is expected to have type + "exp ((c, Real) :: ?g1) ?u1" (cannot unify "(b, Real)" and "(c, Real)"). *) + +End no_bidirectional_hints. + +Section with_bidirectional_hints. + +Arguments exp_unit {g}. +Arguments exp_real {g}. +Arguments exp_var {g t}. +Arguments exp_add {g} &. +Arguments exp_pair {g} & {t1 t2}. +Arguments exp_letin {g} & {t1 t2}. +Arguments exp_var' str {t} g. + +Declare Custom Entry expr. + +Notation "[ e ]" := e (e custom expr at level 5). +Notation "{ x }" := x (in custom expr, x constr). +Notation "x ':R'" := (exp_real x) (in custom expr at level 1). +Notation "x" := x (in custom expr at level 0, x ident). +Notation "$ x" := (exp_var x%string erefl) (in custom expr at level 1). +Notation "# x" := (exp_var' x%string _) (in custom expr at level 1). +Notation "e1 + e2" := (exp_add e1 e2) + (in custom expr at level 2, + left associativity). +Notation "( e1 , e2 )" := (exp_pair e1 e2) + (in custom expr at level 1). +Notation "'let' x ':=' e1 'in' e2" := (exp_letin x e1 e2) + (in custom expr at level 3, x constr, + e1 custom expr at level 2, e2 custom expr at level 3, + left associativity). + +Fail Definition let2_add_erefl_bidi (a b : string) + (ba : infer (b != a)) (ab : infer (a != b)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + $a + $b]. + +Definition let2_add_erefl_bidi (a b : string) + (ba : infer (b != a)) (ab : infer (a != b)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + #a + #b]. + +Fail Definition let3_add_erefl_bidi (a b c d : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + $a + $b]. +(* The term "[$ a]" has type "exp [:: (c, Real); (b, Real); (a, Real)] (lookup Unit [:: (c, Real); (b, Real); (a, Real)] a)" +while it is expected to have type "exp [:: (c, Real); (b, Real); (a, Real)] Real" +(cannot unify "lookup Unit [:: (c, Real); (b, Real); (a, Real)] a" and "Real"). *) + +Definition let3_pair_erefl_bidi (a b c : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + ($a, $b)]. + +Check let3_pair_erefl_bidi. +(* exp [::] (Pair (lookup Unit [:: (c, Real); (b, Real); (a, Real)] a) (lookup Unit [:: (c, Real); (b, Real); (a, Real)] b)) *) + +Definition let3_add_bidi (a b c : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + #a + #b]. + +Definition let3_pair_bidi (a b c : string) + (ba : infer (b != a)) (ca : infer (c != a)) (cb : infer (c != b)) + (ab : infer (a != b)) (ac : infer (a != c)) (bc : infer (b != c)) + : exp [::] _ := [ + let a := {1}:R in + let b := {2}:R in + let c := {3}:R in + (#a , #b)]. + +Check let3_pair_bidi. +(* exp [::] (Pair Real Real) *) + +Example e0 : exp [::] _ := exp_real 1. +Example letin1 : exp [::] _ := + exp_letin "x" (exp_real 1) (exp_var "x" erefl). +Example letin2 : exp [::] _ := + exp_letin "x" (exp_real 1) + (exp_letin "y" (exp_real 2) + (exp_var "x" erefl)). + +Example letin_add : exp [::] _ := + exp_letin "x" (exp_real 1) + (exp_letin "y" (exp_real 2) + (exp_add (exp_var "x" erefl) + (exp_var "y" erefl))). +Reset letin_add. +Fail Example letin_add (x y : string) + (xy : infer (x != y)) (yx : infer (y != x)) : exp [::] _ := + exp_letin x (exp_real 1) + (exp_letin y (exp_real 2) + (exp_add (exp_var x erefl) (exp_var y erefl))). +Example letin_add (x y : string) + (xy : infer (x != y)) (yx : infer (y != x)) : exp [::] _ := + exp_letin x (exp_real 1) + (exp_letin y (exp_real 2) + (exp_add (exp_var' x _) (exp_var' y _))). +Reset letin_add. + +Example letin_add_custom : exp [::] _ := + [let "x" := {1}:R in + let "y" := {2}:R in + #{"x"} + #{"y"}]. + +Section eval. + +Fixpoint acc (g : ctx) (i : nat) : + Type_of_ctx g -> mtyp (nth Unit (map snd g) i) := + match g return Type_of_ctx g -> mtyp (nth Unit (map snd g) i) with + | [::] => match i with | O => id | j.+1 => id end + | _ :: _ => match i with + | O => fst + | j.+1 => fun H => acc j H.2 + end + end. +Arguments acc : clear implicits. + +Reserved Notation "e '-e->' v" (at level 40). + +Inductive eval : forall g t, exp g t -> (Type_of_ctx g -> mtyp t) -> Prop := +| eval_tt g : (exp_unit : exp g _) -e-> (fun=> tt) +| eval_real g c : (exp_real c : exp g _) -e-> (fun=> c) +| eval_plus g (e1 e2 : exp g Real) v1 v2 : + e1 -e-> v1 -> + e2 -e-> v2 -> + [e1 + e2] -e-> fun x => v1 x + v2 x +| eval_var g str : + let i := index str (map fst g) in + exp_var str erefl -e-> acc g i +| eval_pair g t1 t2 e1 e2 v1 v2 : + e1 -e-> v1 -> + e2 -e-> v2 -> + @exp_pair g t1 t2 e1 e2 -e-> fun x => (v1 x, v2 x) +| eval_letin g t t' str (e1 : exp g t) (e2 : exp ((str, t) :: g) t') v1 v2 : + e1 -e-> v1 -> + e2 -e-> v2 -> + exp_letin str e1 e2 -e-> (fun a => v2 (v1 a, a)) +where "e '-e->' v" := (@eval _ _ e v). + +Lemma eval_uniq g t (e : exp g t) u v : + e -e-> u -> e -e-> v -> u = v. +Proof. +move=> hu. +apply: (@eval_ind + (fun g t (e : exp g t) (u : Type_of_ctx g -> mtyp t) => + forall v, e -e-> v -> u = v)); last exact: hu. +all: (rewrite {g t e u v hu}). +- move=> g v. + inversion 1. + by inj_ex H3. +- move=> g c v. + inversion 1. + by inj_ex H3. +- move=> g e1 e2 v1 v2 ev1 IH1 ev2 IH2 v. + inversion 1. + inj_ex H0; inj_ex H1; subst. + inj_ex H5; subst. + by rewrite (IH1 _ H3) (IH2 _ H4). +- move=> g x i v. + inversion 1. + by inj_ex H6; subst. +- move=> g t1 t2 e1 e2 v1 v2 ev1 IH1 ev2 IH2 v. + inversion 1. + inj_ex H3; inj_ex H4; subst. + inj_ex H5; subst. + by rewrite (IH1 _ H6) (IH2 _ H7). +- move=> g t t' x0 e0 e1 v1 v2 ev1 IH1 ev2 IH2 v. + inversion 1. + inj_ex H5; subst. + inj_ex H6; subst. + inj_ex H7; subst. + by rewrite (IH1 _ H4) (IH2 _ H8). +Qed. + +Lemma eval_total g t (e : exp g t) : exists v, e -e-> v. +Proof. +elim: e. +- by eexists; exact: eval_tt. +- by eexists; exact: eval_real. +- move=> {}g {}t x e; subst t. + by eexists; exact: eval_var. +- move=> {}g e1 [v1] IH1 e2 [v2] IH2. + by eexists; exact: (eval_plus IH1 IH2). +- move=> {}g t1 t2 e1 [v1] IH1 e2 [v2] IH2. + by eexists; exact: (eval_pair IH1 IH2). +- move=> {}g {}t u x e1 [v1] IH1 e2 [v2] IH2. + by eexists; exact: (eval_letin IH1 IH2). +Qed. + +Definition exec g t (e : exp g t) : Type_of_ctx g -> mtyp t := + proj1_sig (cid (@eval_total g t e)). + +Lemma exec_eval g t (e : exp g t) v : exec e = v <-> e -e-> v. +Proof. +split. + by move=> <-; rewrite /exec; case: cid. +move=> ev. +by rewrite /exec; case: cid => f H/=; apply: eval_uniq; eauto. +Qed. + +Lemma eval_exec g t (e : exp g t) : e -e-> exec e. +Proof. by rewrite /exec; case: cid. Qed. + +Lemma exec_real g r : @exec g Real (exp_real r) = (fun=> r). +Proof. exact/exec_eval/eval_real. Qed. + +Lemma exec_var g str t H : + exec (@exp_var _ t str H) = + eq_rect _ (fun a => Type_of_ctx g -> mtyp a) + (acc g (index str (map fst g))) + _ (esym H). +Proof. +subst t. +rewrite {1}/exec. +case: cid => f H. +inversion H; subst g0 str0. +by inj_ex H6; subst f. +Qed. + +Lemma exp_var'E str t (f : find str t) H : exp_var' str f = exp_var str H. +Proof. by rewrite /exp_var'; congr exp_var. Qed. + +Lemma exec_letin g x t1 t2 (e1 : exp g t1) (e2 : exp ((x, t1) :: g) t2) : + exec [let x := e1 in e2] = (fun a => (exec e2) ((exec e1) a, a)). +Proof. by apply/exec_eval/eval_letin; exact: eval_exec. Qed. + +Goal ([{1}:R] : exp [::] _) -e-> (fun=> 1). +Proof. exact: eval_real. Qed. +Goal @eval [::] _ [{1}:R + {2}:R] (fun=> 3). +Proof. exact/eval_plus/eval_real/eval_real. Qed. +Goal @eval [:: ("x", Real)] _ (exp_var "x" erefl) (@acc [:: ("x", Real)] 0). +Proof. exact: eval_var. Qed. +Goal @eval [::] _ [let "x" := {1}:R in #{"x"}] (fun=> 1). +Proof. +apply/exec_eval; rewrite exec_letin/=. +apply/funext => t/=. +by rewrite exp_var'E exec_real/= exec_var/=. +Qed. + +Goal exec (g := [::]) [let "x" := {1}:R in #{"x"}] = (fun=> 1). +Proof. +rewrite exec_letin//=. +apply/funext => x. +by rewrite exp_var'E exec_var/= exec_real. +Qed. + +End eval. + +End with_bidirectional_hints. + +End lang_intrinsic_tysc. +End lang_intrinsic_tysc. diff --git a/theories/lang_syntax_util.v b/theories/lang_syntax_util.v new file mode 100644 index 0000000000..ad84918c9b --- /dev/null +++ b/theories/lang_syntax_util.v @@ -0,0 +1,80 @@ +Require Import String. +From HB Require Import structures. +Require Import Classical_Prop. (* NB: to compile with Coq 8.17 *) +From mathcomp Require Import all_ssreflect. +Require Import signed. + +(******************************************************************************) +(* Shared by lang_syntax_*.v files *) +(******************************************************************************) + +Definition string_eqMixin := @EqMixin string String.eqb eqb_spec. +Canonical string_eqType := EqType string string_eqMixin. + +Ltac inj_ex H := revert H; + match goal with + | |- existT ?P ?l (existT ?Q ?t (existT ?R ?u (existT ?T ?v ?v1))) = + existT ?P ?l (existT ?Q ?t (existT ?R ?u (existT ?T ?v ?v2))) -> _ => + (intro H; do 4 apply Classical_Prop.EqdepTheory.inj_pair2 in H) + | |- existT ?P ?l (existT ?Q ?t (existT ?R ?u ?v1)) = + existT ?P ?l (existT ?Q ?t (existT ?R ?u ?v2)) -> _ => + (intro H; do 3 apply Classical_Prop.EqdepTheory.inj_pair2 in H) + | |- existT ?P ?l (existT ?Q ?t ?v1) = + existT ?P ?l (existT ?Q ?t ?v2) -> _ => + (intro H; do 2 apply Classical_Prop.EqdepTheory.inj_pair2 in H) + | |- existT ?P ?l (existT ?Q ?t ?v1) = + existT ?P ?l (existT ?Q ?t' ?v2) -> _ => + (intro H; do 2 apply Classical_Prop.EqdepTheory.inj_pair2 in H) + | |- existT ?P ?l ?v1 = + existT ?P ?l ?v2 -> _ => + (intro H; apply Classical_Prop.EqdepTheory.inj_pair2 in H) + | |- existT ?P ?l ?v1 = + existT ?P ?l' ?v2 -> _ => + (intro H; apply Classical_Prop.EqdepTheory.inj_pair2 in H) +end. + +Set Implicit Arguments. +Unset Strict Implicit. +Set Printing Implicit Defensive. + +Section tagged_context. +Context {T : eqType} {t0 : T}. +Let ctx := seq (string * T). +Implicit Types (str : string) (g : ctx) (t : T). + +Definition dom g := map fst g. + +Definition lookup g str := nth t0 (map snd g) (index str (dom g)). + +Structure tagged_ctx := Tag {untag : ctx}. + +Structure find str t := Find { + ctx_of : tagged_ctx ; + #[canonical=no] ctx_prf : t = lookup (untag ctx_of) str}. + +Lemma ctx_prf_head str t g : t = lookup ((str, t) :: g) str. +Proof. by rewrite /lookup /= !eqxx. Qed. + +Lemma ctx_prf_tail str t g str' t' : + str' != str -> + t = lookup g str -> + t = lookup ((str', t') :: g) str. +Proof. +move=> str'str tg /=; rewrite /lookup/=. +by case: ifPn => //=; rewrite (negbTE str'str). +Qed. + +Definition recurse_tag g := Tag g. +Canonical found_tag g := recurse_tag g. + +Canonical found str t g : find str t := + @Find str t (found_tag ((str, t) :: g)) + (@ctx_prf_head str t g). + +Canonical recurse str t str' t' {H : infer (str' != str)} + (g : find str t) : find str t := + @Find str t (recurse_tag ((str', t') :: untag (ctx_of g))) + (@ctx_prf_tail str t (untag (ctx_of g)) str' t' H (ctx_prf g)). + +End tagged_context. +Arguments lookup {T} t0 g str. diff --git a/theories/lebesgue_integral.v b/theories/lebesgue_integral.v index 308a2a2609..cea039714f 100644 --- a/theories/lebesgue_integral.v +++ b/theories/lebesgue_integral.v @@ -3491,7 +3491,7 @@ by rewrite mule0 -eq_le => /eqP. Qed. Lemma ae_eq_integral_abs (D : set T) (mD : measurable D) (f : T -> \bar R) : - measurable_fun D f -> \int[mu]_(x in D) `|f x| = 0 <-> ae_eq D f (cst 0). + measurable_fun D f -> \int[mu]_(x in D) `|f x| = 0 <-> ae_eq D f (cst 0). Proof. move=> mf; split=> [iDf0|Df0]. exists (D `&` [set x | f x != 0]); split; diff --git a/theories/lebesgue_measure.v b/theories/lebesgue_measure.v index 714117b0c8..7e03bac8da 100644 --- a/theories/lebesgue_measure.v +++ b/theories/lebesgue_measure.v @@ -1594,6 +1594,33 @@ move=> mf mg mD Y mY; have [| | |] := set_bool Y => /eqP ->. - by rewrite preimage_setT setIT. Qed. +Lemma measurable_fun_ler D f g : measurable_fun D f -> measurable_fun D g -> + measurable_fun D (fun x => f x <= g x). +Proof. +move=> mf mg mD Y mY; have [| | |] := set_bool Y => /eqP ->. +- under eq_fun do rewrite -subr_ge0. + rewrite preimage_true -preimage_itv_c_infty. + by apply: (measurable_funB mg mf) => //; exact: measurable_itv. +- under eq_fun do rewrite leNgt -subr_gt0. + rewrite preimage_false set_predC setCK -preimage_itv_o_infty. + by apply: (measurable_funB mf mg) => //; exact: measurable_itv. +- by rewrite preimage_set0 setI0. +- by rewrite preimage_setT setIT. +Qed. + +(* setT should be D? (derived from measurable_and) *) +Lemma measurable_fun_eqr D f g : measurable_fun D f -> measurable_fun D g -> + measurable_fun D (fun x => f x == g x). +Proof. +move=> mf mg. +rewrite (_ : (fun x : T => f x == g x) = (fun x : T => (f x <= g x) && (g x <= f x))). +apply: (@measurable_and _ _ _ (fun x => f x <= g x) (fun x => g x <= f x)); exact: measurable_fun_ler. +apply: funext => x. +apply/eqP/idP => [->|/andP[Hfg Hgf]]. +by apply/andP. +by apply/le_anti/andP. +Qed. + Lemma measurable_maxr D f g : measurable_fun D f -> measurable_fun D g -> measurable_fun D (f \max g). Proof. diff --git a/theories/measure.v b/theories/measure.v index 8b167b6ba0..2ff3da3eca 100644 --- a/theories/measure.v +++ b/theories/measure.v @@ -100,15 +100,16 @@ From HB Require Import structures. (* The HB class is FiniteMeasure. *) (* SigmaFinite_isFinite == mixin for finite measures *) (* Measure_isFinite == factory for finite measures *) -(* subprobability T R == subprobability measure over the measurableType *) -(* T with values in \bar R with R : realType *) +(* subprobability T R == subprobability measure over the *) +(* measurableType T with values in \bar R with *) +(* R : realType *) (* The HB class is SubProbability. *) -(* probability T R == probability measure over the measurableType T *) +(* probability T R == probability measure over the measurableType T *) (* with values in \bar with R : realType *) (* probability == type of probability measures *) (* The HB class is Probability. *) (* Measure_isProbability == factor for probability measures *) -(* mnormalize mu == normalization of a measure to a probability *) +(* mnormalize mu == normalization of a measure to a probability *) (* {outer_measure set T -> \bar R} == type of an outer measure over sets *) (* of elements of type T : Type where R is *) (* expected to be a numFieldType *) @@ -1137,8 +1138,8 @@ Lemma measurable_fun_if (g h : T1 -> T2) D (mD : measurable D) measurable_fun D (fun t => if f t then g t else h t). Proof. move=> mx my /= _ B mB; rewrite (_ : _ @^-1` B = - ((f @^-1` [set true]) `&` (g @^-1` B)) `|` - ((f @^-1` [set false]) `&` (h @^-1` B))). + ((f @^-1` [set true]) `&` (g @^-1` B)) `|` + ((f @^-1` [set false]) `&` (h @^-1` B))). rewrite setIUr; apply: measurableU. - by rewrite setIA; apply: mx => //; exact: mf. - by rewrite setIA; apply: my => //; exact: mf. @@ -1172,6 +1173,64 @@ have [-> _|-> _|-> _ |-> _] := subset_set2 YT. - by rewrite -setT_bool preimage_setT setIT. Qed. +Lemma measurable_fun_TF D (f : T1 -> bool) : + measurable (D `&` f @^-1` [set true]) -> + measurable (D `&` f @^-1` [set false]) -> measurable_fun D f. +Proof. +move=> mT mF mD /= Y mY. +have := @subsetT _ Y; rewrite setT_bool => YT. +move: mY; have [-> _|-> _|-> _ |-> _] := subset_set2 YT. +- by rewrite preimage0 ?setI0. +apply: mT. +apply: mF. +- by rewrite -setT_bool preimage_setT setIT. +Qed. + +Lemma measurable_and D (f : T1 -> bool) (g : T1 -> bool) : + measurable_fun D f -> measurable_fun D g -> + measurable_fun D (fun x => f x && g x). +Proof. +move=> mf mg mD. +apply: measurable_fun_TF => //. +rewrite [X in measurable X](_ : _ = D `&` f @^-1` [set true] `&` (D `&` g @^-1` [set true])); last first. +rewrite setICA !setIA setIid. +rewrite -setIA. +congr (_ `&` _). +apply/seteqP; split => x /andP //=. +apply: measurableI. +apply: mf => //. apply: mg => //. +rewrite [X in measurable X](_ : _ = D `&` f @^-1` [set false] `|` (D `&` g @^-1` [set false])); last first. +rewrite -setIUr. +congr (_ `&` _). +apply/seteqP; split => x /=. +by case: (f x); case: (g x); [|right|left|left]. +case: (f x); case: (g x) => //=; by case. +apply: measurableU. +exact: mf. +exact: mg. +Qed. + +Lemma measurable_or D (f g : T1 -> bool) : + measurable_fun D f -> measurable_fun D g -> + measurable_fun D (fun x => f x || g x). +Proof. +move=> mf mg mD; apply: measurable_fun_TF => //. +rewrite [X in measurable X](_ : _ = D `&` f @^-1` [set true] `|` D `&` g @^-1` [set true]). + apply: measurableU. + apply: mf => //. + apply: mg => //. + rewrite -setIUr. + congr (_ `&` _). + by apply/seteqP; split=> x /orP. +rewrite [X in measurable X](_ : _ = D `&` f @^-1` [set false] `&` (D `&` g @^-1` [set false])). + apply: measurableI. + apply: mf => //. + apply: mg => //. + rewrite setICA !setIA setIid -setIA. + congr (_ `&` _). + apply/seteqP; split => x //=; case: (f x); case: (g x) => //; by case. +Qed. + End measurable_fun. #[global] Hint Extern 0 (measurable_fun _ (fun=> _)) => solve [apply: measurable_cst] : core. diff --git a/theories/prob_lang.v b/theories/prob_lang.v new file mode 100644 index 0000000000..cd5f72ad69 --- /dev/null +++ b/theories/prob_lang.v @@ -0,0 +1,1978 @@ +(* mathcomp analysis (c) 2022 Inria and AIST. License: CeCILL-C. *) +From HB Require Import structures. +From mathcomp Require Import all_ssreflect ssralg ssrnum ssrint interval finmap. +From mathcomp Require Import rat. +From mathcomp.classical Require Import mathcomp_extra boolp classical_sets. +From mathcomp.classical Require Import functions cardinality fsbigop. +Require Import signed reals ereal signed itv topology normedtype sequences esum measure. +Require Import lebesgue_measure numfun lebesgue_integral exp kernel. +From mathcomp Require Import ring lra. + +(******************************************************************************) +(* Semantics of a probabilistic programming language using s-finite kernels *) +(* *) +(* bernoulli r1 == Bernoulli probability with r1 a proof that *) +(* r : {nonneg R} is smaller than 1 *) +(* uniform_probability a b ab0 == uniform probability over the interval [a,b] *) +(* sample mP == sample according to the probability P where mP is a *) +(* proof that P is a measurable function *) +(* letin l k == execute l, augment the context, and execute k *) +(* ret mf == access the context with f and return the result *) +(* score mf == observe t from d, where f is the density of d and *) +(* t occurs in f *) +(* e.g., score (r e^(-r * t)) = observe t from exp(r) *) +(* normalize k P == normalize the kernel k into a probability kernel, *) +(* P is a default probability in case normalization is *) +(* not possible *) +(* ite mf k1 k2 == access the context with the boolean function f and *) +(* behaves as k1 or k2 according to the result *) +(* *) +(* poisson == Poisson distribution function *) +(* exp_density == density function for exponential distribution *) +(* *) +(******************************************************************************) + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. +Import Order.TTheory GRing.Theory Num.Def Num.ExtraDef Num.Theory. +Import numFieldTopology.Exports. + +Local Open Scope classical_set_scope. +Local Open Scope ring_scope. +Local Open Scope ereal_scope. + +Module Notations. + +(*Notation var1of2 := (@measurable_fst _ _ _ _). +Notation var2of2 := (@measurable_snd _ _ _ _). +Notation var1of3 := (measurableT_comp (@measurable_fst _ _ _ _) + (@measurable_fst _ _ _ _)). +Notation var2of3 := (measurableT_comp (@measurable_snd _ _ _ _) + (@measurable_fst _ _ _ _)). +Notation var3of3 := (@measurable_snd _ _ _ _).*) + +Notation mR := Real_sort__canonical__measure_Measurable. +Notation munit := Datatypes_unit__canonical__measure_Measurable. +Notation mbool := Datatypes_bool__canonical__measure_Measurable. + +End Notations. + +(* TODO: PR *) +Lemma onem_nonneg_proof (R : numDomainType) (p : {nonneg R}) : + (p%:num <= 1 -> 0 <= `1-(p%:num))%R. +Proof. by rewrite /onem/= subr_ge0. Qed. + +Definition onem_nonneg (R : numDomainType) (p : {nonneg R}) + (p1 : (p%:num <= 1)%R) := + NngNum (onem_nonneg_proof p1). +(* /TODO: PR *) + +Lemma invr_nonneg_proof (R : numDomainType) (p : {nonneg R}) : + (0 <= (p%:num)^-1)%R. +Proof. by rewrite invr_ge0. Qed. + +Definition invr_nonneg (R : numDomainType) (p : {nonneg R}) := + NngNum (invr_nonneg_proof p). + +(* TODO: move *) +Lemma eq_probability R d (Y : measurableType d) (m1 m2 : probability Y R) : + (m1 =1 m2 :> (set Y -> \bar R)) -> m1 = m2. +Proof. +move: m1 m2 => [m1 +] [m2 +] /= m1m2. +move/funext : m1m2 => <- -[[c11 c12] [m01] [sf1] [sig1] [fin1] [sub1] [p1]] + [[c21 c22] [m02] [sf2] [sig2] [fin2] [sub2] [p2]]. +have ? : c11 = c21 by []. +subst c21. +have ? : c12 = c22 by []. +subst c22. +have ? : m01 = m02 by []. +subst m02. +have ? : sf1 = sf2 by []. +subst sf2. +have ? : sig1 = sig2 by []. +subst sig2. +have ? : fin1 = fin2 by []. +subst fin2. +have ? : sub1 = sub2 by []. +subst sub2. +have ? : p1 = p2 by []. +subst p2. +by f_equal. +Qed. + +Section constants. +Variable R : realType. +Local Open Scope ring_scope. + +Lemma onem1S n : `1- (1 / n.+1%:R) = (n%:R / n.+1%:R)%:nng%:num :> R. +Proof. +by rewrite /onem/= -{1}(@divrr _ n.+1%:R) ?unitfE// -mulrBl -natr1 addrK. +Qed. + +Lemma p1S n : (1 / n.+1%:R)%:nng%:num <= 1 :> R. +Proof. by rewrite ler_pdivr_mulr//= mul1r ler1n. Qed. + +Lemma p12 : (1 / 2%:R)%:nng%:num <= 1 :> R. Proof. by rewrite p1S. Qed. + +Lemma p14 : (1 / 4%:R)%:nng%:num <= 1 :> R. Proof. by rewrite p1S. Qed. + +Lemma onem27 : `1- (2 / 7%:R) = (5%:R / 7%:R)%:nng%:num :> R. +Proof. by apply/eqP; rewrite subr_eq/= -mulrDl -natrD divrr// unitfE. Qed. + +Lemma p27 : (2 / 7%:R)%:nng%:num <= 1 :> R. +Proof. by rewrite /= lter_pdivr_mulr// mul1r ler_nat. Qed. + +End constants. +Arguments p12 {R}. +Arguments p14 {R}. +Arguments p27 {R}. +Arguments p1S {R}. + +Section bernoulli. +Variables (R : realType) (p : {nonneg R}) (p1 : (p%:num <= 1)%R). +Local Open Scope ring_scope. + +Definition bernoulli : set _ -> \bar R := + measure_add + [the measure _ _ of mscale p [the measure _ _ of dirac true]] + [the measure _ _ of mscale (onem_nonneg p1) [the measure _ _ of dirac false]]. + +HB.instance Definition _ := Measure.on bernoulli. + +Local Close Scope ring_scope. + +Let bernoulli_setT : bernoulli [set: _] = 1. +Proof. +rewrite /bernoulli/= /measure_add/= /msum 2!big_ord_recr/= big_ord0 add0e/=. +by rewrite /mscale/= !diracT !mule1 -EFinD add_onemK. +Qed. + +HB.instance Definition _ := + @Measure_isProbability.Build _ _ R bernoulli bernoulli_setT. + +End bernoulli. + +Lemma integral_bernoulli {R : realType} + (p : {nonneg R}) (p1 : (p%:num <= 1)%R) (f : bool -> set bool -> _) U : + (forall x y, 0 <= f x y) -> + \int[bernoulli p1]_y (f y ) U = + p%:num%:E * f true U + (`1-(p%:num))%:E * f false U. +Proof. +move=> f0. +rewrite ge0_integral_measure_sum// 2!big_ord_recl/= big_ord0 adde0/=. +by rewrite !ge0_integral_mscale//= !integral_dirac//= indicT 2!mul1e. +Qed. + +Section binomial_probability. +Context {R : realType} (n : nat) (p : {nonneg R}) (p1 : (p%:num <= 1)%R). +Local Open Scope ring_scope. + +(* C(n, k) * p^(n-k) * (1-p)^k *) +Definition bino_term (k : nat) :{nonneg R} := + (p%:num^+(n-k)%N * (NngNum (onem_ge0 p1))%:num^+k *+ 'C(n, k))%:nng. + +Lemma bino_term0 : + bino_term 0 = (p%:num^+n)%:nng. +Proof. +rewrite /bino_term bin0 subn0/=. +apply/val_inj => /=. +by field. +Qed. + +Lemma bino_term1 : + bino_term 1 = (p%:num^+(n-1)%N * (NngNum (onem_ge0 p1))%:num *+ n)%:nng. +Proof. +rewrite /bino_term bin1/=. +apply/val_inj => /=. +by rewrite expr1. +Qed. + +(* \sum_(k < n.+1) (bino_coef p n k) * \d_k. *) +Definition binomial_probability := + @msum _ (_ R) R + (fun k => [the measure _ _ of mscale (bino_term k) + [the measure _ _ of @dirac _ R k%:R R]]) n.+1. + +HB.instance Definition _ := Measure.on binomial_probability. + +Let binomial_setT : binomial_probability [set: _] = 1%:E. +Proof. +rewrite /binomial_probability/msum/mscale/bino_term/=/mscale/=. +under eq_bigr do rewrite diracT mule1. +rewrite sumEFin. +rewrite -exprDn_comm; last by rewrite /GRing.comm mulrC. +by rewrite add_onemK; congr _%:E; rewrite expr1n. +Qed. + +HB.instance Definition _ := + @Measure_isProbability.Build _ _ R binomial_probability binomial_setT. + +End binomial_probability. + +Section binomial_example. +Context {R : realType}. +Open Scope ring_scope. + +Lemma binomial3_2 : @binomial_probability R 3 _ (p1S 1) [set 2%:R] = (3 / 8)%:E. +Proof. +rewrite /binomial_probability/msum !big_ord_recl/= big_ord0 adde0 bino_term0. +rewrite /mscale/= !diracE /bump/=. +repeat rewrite ?binS ?bin0 ?bin1 ?bin_small//. +rewrite memNset//=; last by move/eqP; rewrite eqr_nat. +rewrite memNset//=; last by move/eqP; rewrite eqr_nat. +rewrite mem_set//=. +rewrite memNset//=; last by move/eqP; rewrite eqr_nat. +congr _%:E. +rewrite expr0 !mulr1 !mulr0 !add0r !addn0 !add0n /onem. +by field. +Qed. + +End binomial_example. + +Section uniform_probability. +Context (R : realType) (a b : R) (ab0 : (0 < b - a)%R). + +Definition uniform_probability : set R -> \bar R + := mscale (invr_nonneg (NngNum (ltW ab0))) + (mrestr lebesgue_measure (measurable_itv `[a, b])). + +(** TODO: set R -> \bar R を書くとMeasure.onが通らない **) +(** **) +(* HB.instance Definition _ := Measure.on uniform_probability. *) + +Let uniform0 : uniform_probability set0 = 0. +Proof. exact: measure0. Qed. + +Let uniform_ge0 U : 0 <= uniform_probability U. +Proof. exact: measure_ge0. Qed. + +Let uniform_sigma_additive : semi_sigma_additive uniform_probability. +Proof. move=> /= F mF tF mUF; exact: measure_semi_sigma_additive. Qed. + +HB.instance Definition _ := isMeasure.Build _ _ _ uniform_probability + uniform0 uniform_ge0 uniform_sigma_additive. + +Let uniform_probability_setT : uniform_probability [set: _] = 1%:E. +Proof. +rewrite /uniform_probability /mscale/= /mrestr/=. +rewrite setTI lebesgue_measure_itv hlength_itv/= lte_fin. +by rewrite -subr_gt0 ab0 -EFinD -EFinM mulVf// gt_eqF// subr_gt0. +Qed. + +HB.instance Definition _ := @Measure_isProbability.Build _ _ R + uniform_probability uniform_probability_setT. + +End uniform_probability. + +Section mscore. +Context d (T : measurableType d) (R : realType). +Variable f : T -> R. + +Definition mscore t : {measure set _ -> \bar R} := + let p := NngNum (normr_ge0 (f t)) in + [the measure _ _ of mscale p [the measure _ _ of dirac tt]]. + +Lemma mscoreE t U : mscore t U = if U == set0 then 0 else `| (f t)%:E |. +Proof. +rewrite /mscore/= /mscale/=; have [->|->] := set_unit U. + by rewrite eqxx dirac0 mule0. +by rewrite diracT mule1 (negbTE setT0). +Qed. + +Lemma measurable_fun_mscore U : measurable_fun setT f -> + measurable_fun setT (mscore ^~ U). +Proof. +move=> mr; under eq_fun do rewrite mscoreE/=. +have [U0|U0] := eqVneq U set0; first exact: measurable_cst. +by apply: measurableT_comp => //; exact: measurableT_comp. +Qed. + +End mscore. + +(* decomposition of score into finite kernels *) +Module SCORE. +Section score. +Context d (T : measurableType d) (R : realType). +Variable f : T -> R. + +Definition k (mf : measurable_fun setT f) i t U := + if i%:R%:E <= mscore f t U < i.+1%:R%:E then + mscore f t U + else + 0. + +Hypothesis mf : measurable_fun setT f. + +Lemma k0 i t : k mf i t (set0 : set unit) = 0 :> \bar R. +Proof. by rewrite /k measure0; case: ifP. Qed. + +Lemma k_ge0 i t B : 0 <= k mf i t B. +Proof. by rewrite /k; case: ifP. Qed. + +Lemma k_sigma_additive i t : semi_sigma_additive (k mf i t). +Proof. +move=> /= F mF tF mUF; rewrite /k /=. +have [F0|UF0] := eqVneq (\bigcup_n F n) set0. + rewrite F0 measure0 (_ : (fun _ => _) = cst 0). + by case: ifPn => _; exact: cvg_cst. + apply/funext => k; rewrite big1// => n _. + by move: F0 => /bigcup0P -> //; rewrite measure0; case: ifPn. +move: (UF0) => /eqP/bigcup0P/existsNP[m /not_implyP[_ /eqP Fm0]]. +rewrite [in X in _ --> X]mscoreE (negbTE UF0). +rewrite -(cvg_shiftn m.+1)/=. +case: ifPn => ir. + rewrite (_ : (fun _ => _) = cst `|(f t)%:E|); first exact: cvg_cst. + apply/funext => n. + rewrite big_mkord (bigD1 (widen_ord (leq_addl n _) (Ordinal (ltnSn m))))//=. + rewrite [in X in X + _]mscoreE (negbTE Fm0) ir big1 ?adde0// => /= j jk. + rewrite mscoreE; have /eqP -> : F j == set0. + have [/eqP//|Fjtt] := set_unit (F j). + move/trivIsetP : tF => /(_ j m Logic.I Logic.I jk). + by rewrite Fjtt setTI => /eqP; rewrite (negbTE Fm0). + by rewrite eqxx; case: ifP. +rewrite (_ : (fun _ => _) = cst 0); first exact: cvg_cst. +apply/funext => n. +rewrite big_mkord (bigD1 (widen_ord (leq_addl n _) (Ordinal (ltnSn m))))//=. +rewrite [in X in if X then _ else _]mscoreE (negbTE Fm0) (negbTE ir) add0e. +rewrite big1//= => j jm; rewrite mscoreE; have /eqP -> : F j == set0. + have [/eqP//|Fjtt] := set_unit (F j). + move/trivIsetP : tF => /(_ j m Logic.I Logic.I jm). + by rewrite Fjtt setTI => /eqP; rewrite (negbTE Fm0). +by rewrite eqxx; case: ifP. +Qed. + +HB.instance Definition _ i t := isMeasure.Build _ _ _ + (k mf i t) (k0 i t) (k_ge0 i t) (@k_sigma_additive i t). + +Lemma measurable_fun_k i U : measurable U -> measurable_fun setT (k mf i ^~ U). +Proof. +move=> /= mU; rewrite /k /= (_ : (fun x => _) = + (fun x => if i%:R%:E <= x < i.+1%:R%:E then x else 0) \o (mscore f ^~ U)) //. +apply: measurableT_comp => /=; last exact/measurable_fun_mscore. +rewrite (_ : (fun x => _) = (fun x => x * + (\1_(`[i%:R%:E, i.+1%:R%:E [%classic : set _) x)%:E)); last first. + apply/funext => x; case: ifPn => ix; first by rewrite indicE/= mem_set ?mule1. + by rewrite indicE/= memNset ?mule0// /= in_itv/=; exact/negP. +apply: emeasurable_funM => //=; apply/EFin_measurable_fun. +by rewrite (_ : \1__ = mindic R (emeasurable_itv `[(i%:R)%:E, (i.+1%:R)%:E[)). +Qed. + +Definition mk i t := [the measure _ _ of k mf i t]. + +HB.instance Definition _ i := + isKernel.Build _ _ _ _ _ (mk i) (measurable_fun_k i). + +Lemma mk_uub i : measure_fam_uub (mk i). +Proof. +exists i.+1%:R => /= t; rewrite /k mscoreE setT_unit. +by case: ifPn => //; case: ifPn => // _ /andP[]. +Qed. + +HB.instance Definition _ i := + Kernel_isFinite.Build _ _ _ _ _ (mk i) (mk_uub i). + +End score. +End SCORE. + +Section kscore. +Context d (T : measurableType d) (R : realType). +Variable f : T -> R. + +Definition kscore (mf : measurable_fun setT f) + : T -> {measure set _ -> \bar R} := + mscore f. + +Variable mf : measurable_fun setT f. + +Let measurable_fun_kscore U : measurable U -> + measurable_fun setT (kscore mf ^~ U). +Proof. by move=> /= _; exact: measurable_fun_mscore. Qed. + +HB.instance Definition _ := isKernel.Build _ _ T _ R + (kscore mf) measurable_fun_kscore. + +Import SCORE. + +Let sfinite_kscore : exists k : (R.-fker T ~> _)^nat, + forall x U, measurable U -> + kscore mf x U = mseries (k ^~ x) 0 U. +Proof. +rewrite /=; exists (fun i => [the R.-fker _ ~> _ of mk mf i]) => /= t U mU. +rewrite /mseries /kscore/= mscoreE; case: ifPn => [/eqP U0|U0]. + by apply/esym/eseries0 => i _; rewrite U0 measure0. +rewrite /mk /= /k /= mscoreE (negbTE U0). +apply/esym/cvg_lim => //. +rewrite -(cvg_shiftn `|floor (fine `|(f t)%:E|)|%N.+1)/=. +rewrite (_ : (fun _ => _) = cst `|(f t)%:E|); first exact: cvg_cst. +apply/funext => n. +pose floor_f := widen_ord (leq_addl n `|floor `|f t| |.+1) + (Ordinal (ltnSn `|floor `|f t| |)). +rewrite big_mkord (bigD1 floor_f)//= ifT; last first. + rewrite lee_fin lte_fin; apply/andP; split. + by rewrite natr_absz (@ger0_norm _ (floor `|f t|)) ?floor_ge0 ?floor_le. + rewrite -addn1 natrD natr_absz. + by rewrite (@ger0_norm _ (floor `|f t|)) ?floor_ge0 ?lt_succ_floor. +rewrite big1 ?adde0//= => j jk. +rewrite ifF// lte_fin lee_fin. +move: jk; rewrite neq_ltn/= => /orP[|] jr. +- suff : (j.+1%:R <= `|f t|)%R by rewrite leNgt => /negbTE ->; rewrite andbF. + rewrite (_ : j.+1%:R = j.+1%:~R)// floor_ge_int. + move: jr; rewrite -lez_nat => /le_trans; apply. + by rewrite -[leRHS](@ger0_norm _ (floor `|f t|)) ?floor_ge0. +- suff : (`|f t| < j%:R)%R by rewrite ltNge => /negbTE ->. + move: jr; rewrite -ltz_nat -(@ltr_int R) (@gez0_abs (floor `|f t|)) ?floor_ge0//. + by rewrite ltr_int -floor_lt_int. +Qed. + +HB.instance Definition _ := + @Kernel_isSFinite.Build _ _ _ _ _ (kscore mf) sfinite_kscore. + +End kscore. + +(* decomposition of ite into s-finite kernels *) +Module ITE. +Section ite. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). + +Section kiteT. +Variable k : R.-ker X ~> Y. + +Definition kiteT : X * bool -> {measure set Y -> \bar R} := + fun xb => if xb.2 then k xb.1 else [the measure _ _ of mzero]. + +Let measurable_fun_kiteT U : measurable U -> measurable_fun setT (kiteT ^~ U). +Proof. +move=> /= mcU; rewrite /kiteT. +rewrite (_ : (fun _ => _) = + (fun x => if x.2 then k x.1 U else mzero U)); last first. + by apply/funext => -[t b]/=; case: ifPn. +apply: (@measurable_fun_if_pair _ _ _ _ (k ^~ U) (fun=> mzero U)) => //. +exact/measurable_kernel. +Qed. + +#[export] +HB.instance Definition _ := isKernel.Build _ _ _ _ _ + kiteT measurable_fun_kiteT. +End kiteT. + +Section sfkiteT. +Variable k : R.-sfker X ~> Y. + +Let sfinite_kiteT : exists2 k_ : (R.-ker _ ~> _)^nat, + forall n, measure_fam_uub (k_ n) & + forall x U, measurable U -> kiteT k x U = mseries (k_ ^~ x) 0 U. +Proof. +have [k_ hk /=] := sfinite_kernel k. +exists (fun n => [the _.-ker _ ~> _ of kiteT (k_ n)]) => /=. + move=> n; have /measure_fam_uubP[r k_r] := measure_uub (k_ n). + by exists r%:num => /= -[x []]; rewrite /kiteT//= /mzero//. +move=> [x b] U mU; rewrite /kiteT; case: ifPn => hb; first by rewrite hk. +by rewrite /mseries eseries0. +Qed. + +#[export] +HB.instance Definition _ := @Kernel_isSFinite_subdef.Build _ _ _ _ _ + (kiteT k) sfinite_kiteT. +End sfkiteT. + +Section fkiteT. +Variable k : R.-fker X ~> Y. + +Let kiteT_uub : measure_fam_uub (kiteT k). +Proof. +have /measure_fam_uubP[M hM] := measure_uub k. +exists M%:num => /= -[]; rewrite /kiteT => t [|]/=; first exact: hM. +by rewrite /= /mzero. +Qed. + +#[export] +HB.instance Definition _ := Kernel_isFinite.Build _ _ _ _ _ + (kiteT k) kiteT_uub. +End fkiteT. + +Section kiteF. +Variable k : R.-ker X ~> Y. + +Definition kiteF : X * bool -> {measure set Y -> \bar R} := + fun xb => if ~~ xb.2 then k xb.1 else [the measure _ _ of mzero]. + +Let measurable_fun_kiteF U : measurable U -> measurable_fun setT (kiteF ^~ U). +Proof. +move=> /= mcU; rewrite /kiteF. +rewrite (_ : (fun x => _) = + (fun x => if x.2 then mzero U else k x.1 U)); last first. + by apply/funext => -[t b]/=; rewrite if_neg//; case: ifPn. +apply: (@measurable_fun_if_pair _ _ _ _ (fun=> mzero U) (k ^~ U)) => //. +exact/measurable_kernel. +Qed. + +#[export] +HB.instance Definition _ := isKernel.Build _ _ _ _ _ + kiteF measurable_fun_kiteF. + +End kiteF. + +Section sfkiteF. +Variable k : R.-sfker X ~> Y. + +Let sfinite_kiteF : exists2 k_ : (R.-ker _ ~> _)^nat, + forall n, measure_fam_uub (k_ n) & + forall x U, measurable U -> kiteF k x U = mseries (k_ ^~ x) 0 U. +Proof. +have [k_ hk /=] := sfinite_kernel k. +exists (fun n => [the _.-ker _ ~> _ of kiteF (k_ n)]) => /=. + move=> n; have /measure_fam_uubP[r k_r] := measure_uub (k_ n). + by exists r%:num => /= -[x []]; rewrite /kiteF//= /mzero//. +move=> [x b] U mU; rewrite /kiteF; case: ifPn => hb; first by rewrite hk. +by rewrite /mseries eseries0. +Qed. + +#[export] +HB.instance Definition _ := @Kernel_isSFinite_subdef.Build _ _ _ _ _ + (kiteF k) sfinite_kiteF. + +End sfkiteF. + +Section fkiteF. +Variable k : R.-fker X ~> Y. + +Let kiteF_uub : measure_fam_uub (kiteF k). +Proof. +have /measure_fam_uubP[M hM] := measure_uub k. +by exists M%:num => /= -[]; rewrite /kiteF/= => t; case => //=; rewrite /mzero. +Qed. + +#[export] +HB.instance Definition _ := Kernel_isFinite.Build _ _ _ _ _ + (kiteF k) kiteF_uub. + +End fkiteF. +End ite. +End ITE. + +Section ite. +Context d d' (T : measurableType d) (T' : measurableType d') (R : realType). +Variables (f : T -> bool) (u1 u2 : R.-sfker T ~> T'). + +(* NB: not used? *) +Definition mite (mf : measurable_fun setT f) : T -> set T' -> \bar R := + fun t => if f t then u1 t else u2 t. + +Variables mf : measurable_fun setT f. + +Let mite0 t : mite mf t set0 = 0. +Proof. by rewrite /mite; case: ifPn. Qed. + +Let mite_ge0 t U : 0 <= mite mf t U. +Proof. by rewrite /mite; case: ifPn. Qed. + +Let mite_sigma_additive t : semi_sigma_additive (mite mf t). +Proof. +by rewrite /mite; case: ifPn => ft; exact: measure_semi_sigma_additive. +Qed. + +HB.instance Definition _ t := isMeasure.Build _ _ _ (mite mf t) + (mite0 t) (mite_ge0 t) (@mite_sigma_additive t). + +Import ITE. + +(* +Definition kite : R.-sfker T ~> T' := + kdirac mf \; kadd (kiteT u1) (kiteF u2). +*) +Definition kite := + [the R.-sfker _ ~> _ of kdirac mf] \; + [the R.-sfker _ ~> _ of kadd + [the R.-sfker _ ~> T' of kiteT u1] + [the R.-sfker _ ~> T' of kiteF u2] ]. + +End ite. + +Section insn2. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). + +Definition ret (f : X -> Y) (mf : measurable_fun setT f) + : R.-pker X ~> Y := [the R.-pker _ ~> _ of kdirac mf]. + +Definition sample (P : X -> pprobability Y R) (mP : measurable_fun setT P) + : R.-pker X ~> Y := + [the R.-pker _ ~> _ of kprobability mP]. + +Definition sample_cst (P : pprobability Y R) : R.-pker X ~> Y := + sample (measurable_cst P). + +Definition normalize (k : R.-ker X ~> Y) P : X -> probability Y R := + knormalize k P. + +Definition normalize_pt (k : R.-ker X ~> Y) : X -> probability Y R := + normalize k point. + +Lemma measurable_normalize_pt (f : R.-ker X ~> Y) : + measurable_fun [set: X] (normalize_pt f : X -> pprobability Y R). +Proof. +apply: (@measurability _ _ _ _ _ _ + (@pset _ _ _ : set (set (pprobability Y R)))) => //. +move=> _ -[_ [r r01] [Ys mYs <-]] <-. +apply: emeasurable_fun_infty_o => //. +exact: (measurable_kernel (knormalize f point) Ys). +Qed. + +Definition ite (f : X -> bool) (mf : measurable_fun setT f) + (k1 k2 : R.-sfker X ~> Y) : R.-sfker X ~> Y := + locked [the R.-sfker X ~> Y of kite k1 k2 mf]. + +End insn2. +Arguments ret {d d' X Y R f} mf. +Arguments sample_cst {d d' X Y R}. +Arguments sample {d d' X Y R}. + +Section insn2_lemmas. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). + +Lemma retE (f : X -> Y) (mf : measurable_fun setT f) x : + ret mf x = \d_(f x) :> (_ -> \bar R). +Proof. by []. Qed. + +Lemma sample_cstE (P : probability Y R) (x : X) : sample_cst P x = P. +Proof. by []. Qed. + +Lemma sampleE (P : X -> pprobability Y R) (mP : measurable_fun setT P) (x : X) : sample P mP x = P x. +Proof. by []. Qed. + +Lemma normalizeE (f : R.-sfker X ~> Y) P x U : + normalize f P x U = + if (f x [set: Y] == 0) || (f x [set: Y] == +oo) then P U + else f x U * ((fine (f x [set: Y]))^-1)%:E. +Proof. by rewrite /normalize /= /mnormalize; case: ifPn. Qed. + +Lemma iteE (f : X -> bool) (mf : measurable_fun setT f) + (k1 k2 : R.-sfker X ~> Y) x : + ite mf k1 k2 x = if f x then k1 x else k2 x. +Proof. +apply/eq_measure/funext => U. +rewrite /ite; unlock => /=. +rewrite /kcomp/= integral_dirac//=. +rewrite indicT mul1e. +rewrite -/(measure_add (ITE.kiteT k1 (x, f x)) (ITE.kiteF k2 (x, f x))). +rewrite measure_addE. +rewrite /ITE.kiteT /ITE.kiteF/=. +by case: ifPn => fx /=; rewrite /mzero ?(adde0,add0e). +Qed. + +End insn2_lemmas. + +Lemma normalize_kdirac (R : realType) + d (T : measurableType d) d' (T' : measurableType d') (x : T) (r : T') P : + normalize (kdirac (measurable_cst r)) P x = \d_r :> probability T' R. +Proof. +apply: eq_probability => U. +rewrite normalizeE /= diracE in_setT/=. +by rewrite onee_eq0/= indicE in_setT/= -div1r divr1 mule1. +Qed. + +Section insn3. +Context d d' d3 (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType). + +Definition letin (l : R.-sfker X ~> Y) (k : R.-sfker [the measurableType _ of (X * Y)%type] ~> Z) + : R.-sfker X ~> Z := + [the R.-sfker X ~> Z of l \; k]. + +End insn3. + +Section insn3_lemmas. +Context d d' d3 (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType). + +Lemma letinE (l : R.-sfker X ~> Y) (k : R.-sfker [the measurableType _ of (X * Y)%type] ~> Z) x U : + letin l k x U = \int[l x]_y k (x, y) U. +Proof. by []. Qed. + +End insn3_lemmas. + +(* rewriting laws *) +Section letin_return. +Context d d' d3 (X : measurableType d) (Y : measurableType d') + (Z : measurableType d3) (R : realType). + +Lemma letin_kret (k : R.-sfker X ~> Y) + (f : X * Y -> Z) (mf : measurable_fun setT f) x U : + measurable U -> + letin k (ret mf) x U = k x (curry f x @^-1` U). +Proof. +move=> mU; rewrite letinE. +under eq_integral do rewrite retE. +rewrite integral_indic ?setIT// -[X in measurable X]setTI. +exact: (measurableT_comp mf). +Qed. + +Lemma letin_retk + (f : X -> Y) (mf : measurable_fun setT f) + (k : R.-sfker [the measurableType _ of (X * Y)%type] ~> Z) + x U : measurable U -> + letin (ret mf) k x U = k (x, f x) U. +Proof. +move=> mU; rewrite letinE retE integral_dirac ?indicT ?mul1e//. +exact: (measurableT_comp (measurable_kernel k _ mU)). +Qed. + +End letin_return. + +Section insn1. +Context d (X : measurableType d) (R : realType). + +Definition score (f : X -> R) (mf : measurable_fun setT f) + : R.-sfker X ~> _ := + [the R.-sfker X ~> _ of kscore mf]. + +End insn1. + +Section hard_constraint. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). + +Definition fail := + letin (score (@measurable_cst _ _ X _ setT (0%R : R))) + (ret (@measurable_cst _ _ _ Y setT point)). + +Lemma failE x U : fail x U = 0. +Proof. by rewrite /fail letinE ge0_integral_mscale//= normr0 mul0e. Qed. + +End hard_constraint. +Arguments fail {d d' X Y R}. + +Section cst_fun. +Context d (T : measurableType d) (R : realType). + +Definition kr (r : R) := @measurable_cst _ _ T _ setT r. +Definition k3 : measurable_fun _ _ := kr 3%:R. +Definition k10 : measurable_fun _ _ := kr 10%:R. +Definition ktt := @measurable_cst _ _ T _ setT tt. +Definition kb (b : bool) := @measurable_cst _ _ T _ setT b. + +End cst_fun. +Arguments kr {d T R}. +Arguments k3 {d T R}. +Arguments k10 {d T R}. +Arguments ktt {d T}. +Arguments kb {d T}. + +Section iter_mprod. +Import Notations. + +Fixpoint iter_mprod (l : list {d & measurableType d}) + : {d & measurableType d} := + match l with + | [::] => existT measurableType _ munit + | h :: t => let t' := iter_mprod t in + existT _ _ [the measurableType (projT1 h, projT1 t').-prod of + (projT2 h * projT2 t')%type] + end. + +End iter_mprod. + +Section acc. +Import Notations. +Context {R : realType}. + +Fixpoint acc (l : seq {d & measurableType d}) n : + projT2 (iter_mprod l) -> projT2 (nth (existT _ _ munit) l n) := + match l return + projT2 (iter_mprod l) -> projT2 (nth (existT _ _ munit) l n) + with + | [::] => match n with | O => id | m.+1 => id end + | _ :: _ => match n with + | O => fst + | m.+1 => fun H => acc m H.2 + end + end. + +Lemma measurable_acc (l : seq {d & measurableType d}) n : + measurable_fun setT (@acc l n). +Proof. +by elim: l n => //= h t ih [|m] //; exact: (measurableT_comp (ih _)). +Qed. +End acc. +Arguments acc : clear implicits. +Arguments measurable_acc : clear implicits. + +Section rpair_pairA. +Context d0 d1 d2 (T0 : measurableType d0) (T1 : measurableType d1) + (T2 : measurableType d2). + +Definition rpair d (T : measurableType d) t : + ([the measurableType _ of T0] -> [the measurableType _ of T0 * T])%type := + fun x => (x, t). + +Lemma mrpair d (T : measurableType d) t : measurable_fun setT (@rpair _ T t). +Proof. exact: measurable_fun_prod. Qed. + +Definition pairA : ([the measurableType _ of T0 * T1 * T2] -> + [the measurableType _ of T0 * (T1 * T2)])%type := + fun x => (x.1.1, (x.1.2, x.2)). + +Definition mpairA : measurable_fun setT pairA. +Proof. +apply: measurable_fun_prod => /=; first exact: measurableT_comp. +by apply: measurable_fun_prod => //=; exact: measurableT_comp. +Qed. + +Definition pairAi : ([the measurableType _ of T0 * (T1 * T2)] -> + [the measurableType _ of T0 * T1 * T2])%type := + fun x => (x.1, x.2.1, x.2.2). + +Definition mpairAi : measurable_fun setT pairAi. +Proof. +apply: measurable_fun_prod => //=; last exact: measurableT_comp. +apply: measurable_fun_prod => //=; exact: measurableT_comp. +Qed. + +End rpair_pairA. +Arguments rpair {d0 T0 d} T. +#[global] Hint Extern 0 (measurable_fun _ (rpair _ _)) => + solve [apply: mrpair] : core. +Arguments pairA {d0 d1 d2 T0 T1 T2}. +#[global] Hint Extern 0 (measurable_fun _ pairA) => + solve [apply: mpairA] : core. +Arguments pairAi {d0 d1 d2 T0 T1 T2}. +#[global] Hint Extern 0 (measurable_fun _ pairAi) => + solve [apply: mpairAi] : core. + +Section rpair_pairA_comp. +Import Notations. +Context d0 d1 d2 d3 (T0 : measurableType d0) (T1 : measurableType d1) + (T2 : measurableType d2) (T3 : measurableType d3) (R : realType). + +Definition pairAr d (T : measurableType d) t : + ([the measurableType _ of T0 * T1] -> + [the measurableType _ of T0 * (T1 * T)])%type := + pairA \o rpair T t. +Arguments pairAr {d} T. + +Lemma mpairAr d (T : measurableType d) t : measurable_fun setT (pairAr T t). +Proof. exact: measurableT_comp. Qed. + +Definition pairAAr : ([the measurableType _ of T0 * T1 * T2] -> + [the measurableType _ of T0 * (T1 * (T2 * munit))])%type := + pairA \o pairA \o rpair munit tt. + +Lemma mpairAAr : measurable_fun setT pairAAr. +Proof. by do 2 apply: measurableT_comp => //. Qed. + +Definition pairAAAr : ([the measurableType _ of T0 * T1 * T2 * T3] -> + [the measurableType _ of T0 * (T1 * (T2 * (T3 * munit)))])%type := + pairA \o pairA \o pairA \o rpair munit tt. + +Lemma mpairAAAr : measurable_fun setT pairAAAr. +Proof. by do 3 apply: measurableT_comp => //. Qed. + +Definition pairAArAi : ([the measurableType _ of T0 * (T1 * T2)] -> + [the measurableType _ of T0 * (T1 * (T2 * munit))])%type := + pairAAr \o pairAi. + +Lemma mpairAArAi : measurable_fun setT pairAArAi. +Proof. by apply: measurableT_comp => //=; exact: mpairAAr. Qed. + +Definition pairAAArAAi : ([the measurableType _ of T3 * (T0 * (T1 * T2))] -> + [the measurableType _ of T3 * (T0 * (T1 * (T2 * munit)))])%type := + pairA \o pairA \o pairA \o rpair munit tt \o pairAi \o pairAi. + +Lemma mpairAAARAAAi : measurable_fun setT pairAAArAAi. +Proof. by do 5 apply: measurableT_comp => //=. Qed. + +End rpair_pairA_comp. +Arguments pairAr {d0 d1 T0 T1 d} T. +Arguments pairAAr {d0 d1 d2 T0 T1 T2}. +Arguments pairAAAr {d0 d1 d2 d3 T0 T1 T2 T3}. +Arguments pairAArAi {d0 d1 d2 T0 T1 T2}. +Arguments pairAAArAAi {d0 d1 d2 d3 T0 T1 T2 T3}. + +Section accessor_functions. +Import Notations. +Context d0 d1 d2 d3 (T0 : measurableType d0) (T1 : measurableType d1) + (T2 : measurableType d2) (T3 : measurableType d3) (R : realType). + +Definition Of2 := [:: existT _ _ T0; existT _ _ T1]. + +Definition acc0of2 : [the measurableType _ of (T0 * T1)%type] -> T0 := + @acc Of2 0 \o pairAr munit tt. + +Lemma macc0of2 : measurable_fun setT acc0of2. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc Of2 0)|exact: mpairAr]. +Qed. + +Definition acc1of2 : [the measurableType _ of (T0 * T1)%type] -> T1 := + acc Of2 1 \o pairAr munit tt. + +Lemma macc1of2 : measurable_fun setT acc1of2. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc Of2 1)|exact: mpairAr]. +Qed. + +Definition Of3 := [:: existT _ _ T0; existT _ _ T1; existT _ d2 T2]. + +Definition acc1of3 : [the measurableType _ of (T0 * T1 * T2)%type] -> T1 := + acc Of3 1 \o pairAAr. + +Lemma macc1of3 : measurable_fun setT acc1of3. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc Of3 1)|exact: mpairAAr]. +Qed. + +Definition acc2of3 : [the measurableType _ of (T0 * T1 * T2)%type] -> T2 := + acc Of3 2 \o pairAAr. + +Lemma macc2of3 : measurable_fun setT acc2of3. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc Of3 2)|exact: mpairAAr]. +Qed. + +Definition acc0of3' : [the measurableType _ of (T0 * (T1 * T2))%type] -> T0 := + acc Of3 0 \o pairAArAi. + +Lemma macc0of3' : measurable_fun setT acc0of3'. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc Of3 0)|exact: mpairAArAi]. +Qed. + +Definition acc1of3' : [the measurableType _ of (T0 * (T1 * T2))%type] -> T1 := + acc Of3 1 \o pairAArAi. + +Lemma macc1of3' : measurable_fun setT acc1of3'. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc Of3 1)|exact: mpairAArAi]. +Qed. + +Definition acc2of3' : [the measurableType _ of (T0 * (T1 * T2))%type] -> T2 := + acc Of3 2 \o pairAArAi. + +Lemma macc2of3' : measurable_fun setT acc2of3'. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc Of3 2)|exact: mpairAArAi]. +Qed. + +Definition Of4 := + [:: existT _ _ T0; existT _ _ T1; existT _ d2 T2; existT _ d3 T3]. + +Definition acc1of4 : [the measurableType _ of (T0 * T1 * T2 * T3)%type] -> T1 := + acc Of4 1 \o pairAAAr. + +Lemma macc1of4 : measurable_fun setT acc1of4. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc Of4 1)|exact: mpairAAAr]. +Qed. + +Definition acc2of4' : + [the measurableType _ of (T0 * (T1 * (T2 * T3)))%type] -> T2 := + acc Of4 2 \o pairAAArAAi. + +Lemma macc2of4' : measurable_fun setT acc2of4'. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc Of4 2)|exact: mpairAAARAAAi]. +Qed. + +Definition acc3of4 : [the measurableType _ of (T0 * T1 * T2 * T3)%type] -> T3 := + acc Of4 3 \o pairAAAr. + +Lemma macc3of4 : measurable_fun setT acc3of4. +Proof. +by apply: measurableT_comp; [exact: (measurable_acc Of4 3)|exact: mpairAAAr]. +Qed. + +End accessor_functions. +Arguments macc0of2 {d0 d1 _ _}. +Arguments macc1of2 {d0 d1 _ _}. +Arguments macc0of3' {d0 d1 d2 _ _ _}. +Arguments macc1of3 {d0 d1 d2 _ _ _}. +Arguments macc1of3' {d0 d1 d2 _ _ _}. +Arguments macc2of3 {d0 d1 d2 _ _ _}. +Arguments macc2of3' {d0 d1 d2 _ _ _}. +Arguments macc1of4 {d0 d1 d2 d3 _ _ _ _}. +Arguments macc2of4' {d0 d1 d2 d3 _ _ _ _}. +Arguments macc3of4 {d0 d1 d2 d3 _ _ _ _}. + +Section insn1_lemmas. +Import Notations. +Context d (T : measurableType d) (R : realType). + +Let kcomp_scoreE d1 d2 (T1 : measurableType d1) (T2 : measurableType d2) + (g : R.-sfker [the measurableType _ of (T1 * unit)%type] ~> T2) + f (mf : measurable_fun setT f) r U : + (score mf \; g) r U = `|f r|%:E * g (r, tt) U. +Proof. +rewrite /= /kcomp /kscore /= ge0_integral_mscale//=. +by rewrite integral_dirac// indicT mul1e. +Qed. + +Lemma scoreE d' (T' : measurableType d') (x : T * T') (U : set T') (f : R -> R) + (r : R) (r0 : (0 <= r)%R) + (f0 : (forall r, 0 <= r -> 0 <= f r)%R) (mf : measurable_fun setT f) : + score (measurableT_comp mf (@macc1of2 _ _ _ _)) + (x, r) (curry (snd \o fst) x @^-1` U) = + (f r)%:E * \d_x.2 U. +Proof. +by rewrite /score/= /mscale/= ger0_norm//= f0. +Qed. + +Lemma score_score (f : R -> R) (g : R * unit -> R) + (mf : measurable_fun setT f) + (mg : measurable_fun setT g) : + letin (score mf) (score mg) = + score (measurable_funM mf (measurableT_comp mg (measurable_pair2 tt))). +Proof. +apply/eq_sfkernel => x U. +rewrite {1}/letin; unlock. +by rewrite kcomp_scoreE/= /mscale/= diracE normrM muleA EFinM. +Qed. + +(* hard constraints to express score below 1 *) +Lemma score_fail (r : {nonneg R}) (r1 : (r%:num <= 1)%R) : + score (kr r%:num) = + letin (sample_cst (bernoulli r1) : R.-pker T ~> _) + (ite (@macc1of2 _ _ _ _) (ret ktt) fail). +Proof. +apply/eq_sfkernel => x U. +rewrite letinE/= /sample; unlock. +rewrite integral_measure_add//= ge0_integral_mscale//= ge0_integral_mscale//=. +rewrite integral_dirac//= integral_dirac//= !indicT/= !mul1e. +by rewrite /mscale/= iteE//= iteE//= failE mule0 adde0 ger0_norm. +Qed. + +End insn1_lemmas. + +Section letin_ite. +Context d d2 d3 (T : measurableType d) (T2 : measurableType d2) + (Z : measurableType d3) (R : realType). +Variables (k1 k2 : R.-sfker T ~> Z) + (u : R.-sfker [the measurableType _ of (T * Z)%type] ~> T2) + (f : T -> bool) (mf : measurable_fun setT f) + (t : T) (U : set T2). + +Lemma letin_iteT : f t -> letin (ite mf k1 k2) u t U = letin k1 u t U. +Proof. +move=> ftT. +rewrite !letinE/=. +apply: eq_measure_integral => V mV _. +by rewrite iteE ftT. +Qed. + +Lemma letin_iteF : ~~ f t -> letin (ite mf k1 k2) u t U = letin k2 u t U. +Proof. +move=> ftF. +rewrite !letinE/=. +apply: eq_measure_integral => V mV _. +by rewrite iteE (negbTE ftF). +Qed. + +End letin_ite. + +Section letinA. +Context d d' d1 d2 d3 (X : measurableType d) (Y : measurableType d') + (T1 : measurableType d1) (T2 : measurableType d2) (T3 : measurableType d3) + (R : realType). +Import Notations. +Variables (t : R.-sfker X ~> T1) + (u : R.-sfker [the measurableType _ of (X * T1)%type] ~> T2) + (v : R.-sfker [the measurableType _ of (X * T2)%type] ~> Y) + (v' : R.-sfker [the measurableType _ of (X * T1 * T2)%type] ~> Y) + (vv' : forall y, v =1 fun xz => v' (xz.1, y, xz.2)). + +Lemma letinA x A : measurable A -> + letin t (letin u v') x A + = + (letin (letin t u) v) x A. +Proof. +move=> mA. +rewrite !letinE. +under eq_integral do rewrite letinE. +rewrite integral_kcomp; [|by []|]. +- apply: eq_integral => y _. + apply: eq_integral => z _. + by rewrite (vv' y). +exact: (measurableT_comp (measurable_kernel v _ mA)). +Qed. + +End letinA. + +Section letinC. +Context d d1 d' (X : measurableType d) (Y : measurableType d1) + (Z : measurableType d') (R : realType). + +Import Notations. + +Variables (t : R.-sfker Z ~> X) + (t' : R.-sfker [the measurableType _ of (Z * Y)%type] ~> X) + (tt' : forall y, t =1 fun z => t' (z, y)) + (u : R.-sfker Z ~> Y) + (u' : R.-sfker [the measurableType _ of (Z * X)%type] ~> Y) + (uu' : forall x, u =1 fun z => u' (z, x)). + +Definition T z : set X -> \bar R := t z. +Let T0 z : (T z) set0 = 0. Proof. by []. Qed. +Let T_ge0 z x : 0 <= (T z) x. Proof. by []. Qed. +Let T_semi_sigma_additive z : semi_sigma_additive (T z). +Proof. exact: measure_semi_sigma_additive. Qed. +HB.instance Definition _ z := @isMeasure.Build _ R X (T z) (T0 z) (T_ge0 z) + (@T_semi_sigma_additive z). + +Let sfinT z : sfinite_measure (T z). Proof. exact: sfinite_kernel_measure. Qed. +HB.instance Definition _ z := @Measure_isSFinite_subdef.Build _ X R + (T z) (sfinT z). + +Definition U z : set Y -> \bar R := u z. +Let U0 z : (U z) set0 = 0. Proof. by []. Qed. +Let U_ge0 z x : 0 <= (U z) x. Proof. by []. Qed. +Let U_semi_sigma_additive z : semi_sigma_additive (U z). +Proof. exact: measure_semi_sigma_additive. Qed. +HB.instance Definition _ z := @isMeasure.Build _ R Y (U z) (U0 z) (U_ge0 z) + (@U_semi_sigma_additive z). + +Let sfinU z : sfinite_measure (U z). Proof. exact: sfinite_kernel_measure. Qed. +HB.instance Definition _ z := @Measure_isSFinite_subdef.Build _ Y R + (U z) (sfinU z). + +Lemma letinC z A : measurable A -> + letin t + (letin u' + (ret (measurable_fun_prod macc1of3 macc2of3))) z A = + letin u + (letin t' + (ret (measurable_fun_prod macc2of3 macc1of3))) z A. +Proof. +move=> mA. +rewrite !letinE. +under eq_integral. + move=> x _. + rewrite letinE -uu'. + under eq_integral do rewrite retE /=. + over. +rewrite (sfinite_Fubini + [the {sfinite_measure set X -> \bar R} of T z] + [the {sfinite_measure set Y -> \bar R} of U z] + (fun x => \d_(x.1, x.2) A ))//; last first. + apply/EFin_measurable_fun => /=; rewrite (_ : (fun x => _) = mindic R mA)//. + by apply/funext => -[]. +rewrite /=. +apply: eq_integral => y _. +by rewrite letinE/= -tt'; apply: eq_integral => // x _; rewrite retE. +Qed. + +End letinC. + +(* sample programs *) +Section poisson. +Variable R : realType. +Local Open Scope ring_scope. + +(* density function for Poisson *) +Definition poisson k r : R := + if r > 0 then r ^+ k / k`!%:R^-1 * expR (- r) else 1%:R. + +Lemma poisson_ge0 k r : 0 <= poisson k r. +Proof. +rewrite /poisson; case: ifPn => r0//. +by rewrite mulr_ge0 ?expR_ge0// mulr_ge0// exprn_ge0 ?ltW. +Qed. + +Lemma poisson_gt0 k r : 0 < r -> 0 < poisson k.+1 r. +Proof. +move=> r0; rewrite /poisson r0 mulr_gt0 ?expR_gt0//. +by rewrite divr_gt0// ?exprn_gt0// invr_gt0 ltr0n fact_gt0. +Qed. + +Lemma measurable_poisson k : measurable_fun setT (poisson k). +Proof. +rewrite /poisson; apply: measurable_fun_if => //. + apply: (measurable_fun_bool true). + rewrite (_ : _ @^-1` _ = `]0, +oo[%classic)//. + by apply/seteqP; split => x /=; rewrite in_itv/= andbT. +by apply: measurable_funM => /=; + [exact: measurable_funM|exact: measurableT_comp]. +Qed. + +Definition poisson3 := poisson 4 3%:R. (* 0.168 *) +Definition poisson10 := poisson 4 10%:R. (* 0.019 *) + +End poisson. + +Section exponential. +Variable R : realType. +Local Open Scope ring_scope. + +(* density function for exponential *) +Definition exp_density x r : R := r * expR (- r * x). + +Lemma exp_density_gt0 x r : 0 < r -> 0 < exp_density x r. +Proof. by move=> r0; rewrite /exp_density mulr_gt0// expR_gt0. Qed. + +Lemma exp_density_ge0 x r : 0 <= r -> 0 <= exp_density x r. +Proof. by move=> r0; rewrite /exp_density mulr_ge0// expR_ge0. Qed. + +Lemma mexp_density x : measurable_fun setT (exp_density x). +Proof. +apply: measurable_funM => //=; apply: measurableT_comp => //. +exact: measurable_funM. +Qed. + +End exponential. + +Lemma letin_sample_bernoulli d d' (T : measurableType d) + (T' : measurableType d') (R : realType)(r : {nonneg R}) (r1 : (r%:num <= 1)%R) + (u : R.-sfker [the measurableType _ of (T * bool)%type] ~> T') x y : + letin (sample_cst (bernoulli r1)) u x y = + r%:num%:E * u (x, true) y + (`1- (r%:num))%:E * u (x, false) y. +Proof. +rewrite letinE/=. +rewrite ge0_integral_measure_sum// 2!big_ord_recl/= big_ord0 adde0/=. +by rewrite !ge0_integral_mscale//= !integral_dirac//= indicT 2!mul1e. +Qed. + +Section sample_and_return. +Import Notations. +Context d (T : measurableType d) (R : realType). + +Definition sample_and_return : R.-sfker T ~> _ := + letin + (sample_cst [the probability _ _ of bernoulli p27]) (* T -> B *) + (ret macc1of2) (* T * B -> B *). + +Lemma sample_and_returnE t U : sample_and_return t U = + (2 / 7%:R)%:E * \d_true U + (5%:R / 7%:R)%:E * \d_false U. +Proof. +by rewrite /sample_and_return letin_sample_bernoulli !retE onem27. +Qed. + +End sample_and_return. + +(* trivial example *) +Section sample_and_branch. +Import Notations. +Context d (T : measurableType d) (R : realType). + +(* let x = sample (bernoulli (2/7)) in + let r = case x of {(1, _) => return (k3()), (2, _) => return (k10())} in + return r *) + +Definition sample_and_branch : R.-sfker T ~> mR R := + letin + (sample_cst [the probability _ _ of bernoulli p27]) (* T -> B *) + (ite macc1of2 (ret k3) (ret k10)). + +Lemma sample_and_branchE t U : sample_and_branch t U = + (2 / 7%:R)%:E * \d_(3%:R : R) U + + (5%:R / 7%:R)%:E * \d_(10%:R : R) U. +Proof. +by rewrite /sample_and_branch letin_sample_bernoulli/= !iteE !retE onem27. +Qed. + +End sample_and_branch. + +Section bernoulli_and. +Context d (T : measurableType d) (R : realType). +Import Notations. + +Definition bernoulli_and : R.-sfker T ~> mbool := + (letin (sample_cst [the probability _ _ of bernoulli p12]) + (letin (sample_cst [the probability _ _ of bernoulli p12]) + (ret (measurable_and macc1of3 macc2of3)))). + +Lemma bernoulli_andE t U : + bernoulli_and t U = + sample_cst (bernoulli p14) t U. +Proof. +rewrite /bernoulli_and 3!letin_sample_bernoulli/= muleDr//= -muleDl//. +rewrite !muleA -addeA -muleDl// -!EFinM !onem1S/= -splitr mulr1. +have -> : (1 / 2 * (1 / 2) = 1 / 4%:R :> R)%R by rewrite mulf_div mulr1// -natrM. +rewrite /bernoulli/= measure_addE/= /mscale/= -!EFinM; congr( _ + (_ * _)%:E). +have -> : (1 / 2 = 2 / 4%:R :> R)%R. + by apply/eqP; rewrite eqr_div// ?pnatr_eq0// mul1r -natrM. +by rewrite onem1S// -mulrDl. +Qed. + +End bernoulli_and. + +Section staton_bus. +Import Notations. +Context d (T : measurableType d) (R : realType) (h : R -> R). +Hypothesis mh : measurable_fun setT h. +Definition kstaton_bus : R.-sfker T ~> mbool := + letin (sample_cst [the probability _ _ of bernoulli p27]) + (letin + (letin (ite macc1of2 (ret k3) (ret k10)) + (score (measurableT_comp mh macc2of3))) + (ret macc1of3)). + +Definition staton_bus := normalize kstaton_bus. + +End staton_bus. + +(* let x = sample (bernoulli (2/7)) in + let r = case x of {(1, _) => return (k3()), (2, _) => return (k10())} in + let _ = score (1/4! r^4 e^-r) in + return x *) +Section staton_bus_poisson. +Import Notations. +Context d (T : measurableType d) (R : realType). +Let poisson4 := @poisson R 4%N. +Let mpoisson4 := @measurable_poisson R 4%N. + +Definition kstaton_bus_poisson : R.-sfker (mR R) ~> mbool := + kstaton_bus _ mpoisson4. + +Let kstaton_bus_poissonE t U : kstaton_bus_poisson t U = + (2 / 7%:R)%:E * (poisson4 3%:R)%:E * \d_true U + + (5%:R / 7%:R)%:E * (poisson4 10%:R)%:E * \d_false U. +Proof. +rewrite /kstaton_bus. +rewrite letin_sample_bernoulli. +rewrite -!muleA; congr (_ * _ + _ * _). +- rewrite letin_kret//. + rewrite letin_iteT//. + rewrite letin_retk//. + by rewrite scoreE//= => r r0; exact: poisson_ge0. +- by rewrite onem27. + rewrite letin_kret//. + rewrite letin_iteF//. + rewrite letin_retk//. + by rewrite scoreE//= => r r0; exact: poisson_ge0. +Qed. + +(* true -> 2/7 * 0.168 = 2/7 * 3^4 e^-3 / 4! *) +(* false -> 5/7 * 0.019 = 5/7 * 10^4 e^-10 / 4! *) + +Lemma staton_busE P (t : R) U : + let N := ((2 / 7%:R) * poisson4 3%:R + + (5%:R / 7%:R) * poisson4 10%:R)%R in + staton_bus mpoisson4 P t U = + ((2 / 7%:R)%:E * (poisson4 3%:R)%:E * \d_true U + + (5%:R / 7%:R)%:E * (poisson4 10%:R)%:E * \d_false U) * N^-1%:E. +Proof. +rewrite /staton_bus normalizeE /= !kstaton_bus_poissonE !diracT !mule1 ifF //. +apply/negbTE; rewrite gt_eqF// lte_fin. +by rewrite addr_gt0// mulr_gt0//= ?divr_gt0// ?ltr0n// poisson_gt0// ltr0n. +Qed. + +End staton_bus_poisson. + +(* let x = sample (bernoulli (2/7)) in + let r = case x of {(1, _) => return (k3()), (2, _) => return (k10())} in + let _ = score (r e^-(15/60 r)) in + return x *) +Section staton_bus_exponential. +Import Notations. +Context d (T : measurableType d) (R : realType). +Let exp1560 := @exp_density R (ratr (15%:Q / 60%:Q)). +Let mexp1560 := @mexp_density R (ratr (15%:Q / 60%:Q)). + +(* 15/60 = 0.25 *) + +Definition kstaton_bus_exponential : R.-sfker (mR R) ~> mbool := + kstaton_bus _ mexp1560. + +Let kstaton_bus_exponentialE t U : kstaton_bus_exponential t U = + (2 / 7%:R)%:E * (exp1560 3%:R)%:E * \d_true U + + (5%:R / 7%:R)%:E * (exp1560 10%:R)%:E * \d_false U. +Proof. +rewrite /kstaton_bus. +rewrite letin_sample_bernoulli. +rewrite -!muleA; congr (_ * _ + _ * _). +- rewrite letin_kret//. + rewrite letin_iteT//. + rewrite letin_retk//. + rewrite scoreE//= => r r0; exact: exp_density_ge0. +- by rewrite onem27. + rewrite letin_kret//. + rewrite letin_iteF//. + rewrite letin_retk//. + by rewrite scoreE//= => r r0; exact: exp_density_ge0. +Qed. + +(* true -> 5/7 * 0.019 = 5/7 * 10^4 e^-10 / 4! *) +(* false -> 2/7 * 0.168 = 2/7 * 3^4 e^-3 / 4! *) + +Lemma staton_bus_exponentialE P (t : R) U : + let N := ((2 / 7%:R) * exp1560 3%:R + + (5%:R / 7%:R) * exp1560 10%:R)%R in + staton_bus mexp1560 P t U = + ((2 / 7%:R)%:E * (exp1560 3%:R)%:E * \d_true U + + (5%:R / 7%:R)%:E * (exp1560 10%:R)%:E * \d_false U) * N^-1%:E. +Proof. +rewrite /staton_bus. +rewrite normalizeE /= !kstaton_bus_exponentialE !diracT !mule1 ifF //. +apply/negbTE; rewrite gt_eqF// lte_fin. +by rewrite addr_gt0// mulr_gt0//= ?divr_gt0// ?ltr0n// exp_density_gt0 ?ltr0n. +Qed. + +End staton_bus_exponential. + +(* X + Y is a measurableType if X and Y are *) +Canonical sum_pointedType (X Y : pointedType) := + PointedType (X + Y) (@inl X Y point). + +Section measurable_sum. +Context d d' (X : measurableType d) (Y : measurableType d'). + +Definition measurable_sum : set (set (X + Y)) := setT. + +Let sum0 : measurable_sum set0. Proof. by []. Qed. + +Let sumC A : measurable_sum A -> measurable_sum (~` A). Proof. by []. Qed. + +Let sumU (F : (set (X + Y))^nat) : (forall i, measurable_sum (F i)) -> + measurable_sum (\bigcup_i F i). +Proof. by []. Qed. + +HB.instance Definition _ := @isMeasurable.Build default_measure_display (X + Y)%type + (Pointed.class _) measurable_sum sum0 sumC sumU. + +End measurable_sum. + +Lemma measurable_fun_sum dA dB d' (A : measurableType dA) (B : measurableType dB) + (Y : measurableType d') (f : A -> Y) (g : B -> Y) : + measurable_fun setT f -> measurable_fun setT g -> + measurable_fun setT (fun tb : A + B => + match tb with inl a => f a | inr b => g b end). +Proof. +move=> mx my/= _ Z mZ /=; rewrite setTI /=. +rewrite (_ : _ @^-1` Z = inl @` (f @^-1` Z) `|` inr @` (g @^-1` Z)). + exact: measurableU. +apply/seteqP; split. + by move=> [a Zxa|b Zxb]/=; [left; exists a|right; exists b]. +by move=> z [/= [a Zxa <-//=]|]/= [b Zyb <-//=]. +Qed. + +(* TODO: measurable_fun_if_pair -> measurable_fun_if_pair_bool? *) +Lemma measurable_fun_if_pair_nat d d' (X : measurableType d) + (Y : measurableType d') (f g : X -> Y) (n : nat) : + measurable_fun setT f -> measurable_fun setT g -> + measurable_fun setT (fun xn => if xn.2 == n then f xn.1 else g xn.1). +Proof. +move=> mx my; apply: measurable_fun_ifT => //=. +- have h : measurable_fun [set: nat] (fun t => t == n) by []. + exact: (@measurableT_comp _ _ _ _ _ _ _ _ _ h). +- exact: measurableT_comp. +- exact: measurableT_comp. +Qed. + +Module CASE_NAT. +Section case_nat. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). + +Section case_nat_ker. +Variable k : R.-ker X ~> Y. + +Definition case_nat_ j : X * nat -> {measure set Y -> \bar R} := + fun xn => if xn.2 == j then k xn.1 else [the measure _ _ of mzero]. + +Let measurable_fun_case_nat_ m U : measurable U -> + measurable_fun setT (case_nat_ m ^~ U). +Proof. +move=> mU; rewrite /case_nat_ (_ : (fun _ => _) = + (fun x => if x.2 == m then k x.1 U else mzero U)) /=; last first. + by apply/funext => -[t b]/=; case: ifPn. +apply: (@measurable_fun_if_pair_nat _ _ _ _ (k ^~ U) (fun=> mzero U)) => //. +exact/measurable_kernel. +Qed. + +#[export] +HB.instance Definition _ j := isKernel.Build _ _ _ _ _ + (case_nat_ j) (measurable_fun_case_nat_ j). +End case_nat_ker. + +Section sfcase_nat. +Variable k : R.-sfker X ~> Y. + +Let sfcase_nat_ j : exists2 k_ : (R.-ker _ ~> _)^nat, + forall n, measure_fam_uub (k_ n) & + forall x U, measurable U -> case_nat_ k j x U = mseries (k_ ^~ x) 0 U. +Proof. +have [k_ hk /=] := sfinite_kernel k. +exists (fun n => [the _.-ker _ ~> _ of case_nat_ (k_ n) j]) => /=. + move=> n; have /measure_fam_uubP[r k_r] := measure_uub (k_ n). + exists r%:num => /= -[x [|n']]; rewrite /case_nat_//= /mzero//. + by case: ifPn => //= ?; rewrite /mzero. + by case: ifPn => // ?; rewrite /= /mzero. +move=> [x b] U mU; rewrite /case_nat_; case: ifPn => hb; first by rewrite hk. +by rewrite /mseries eseries0. +Qed. + +#[export] +HB.instance Definition _ j := @Kernel_isSFinite_subdef.Build _ _ _ _ _ + (case_nat_ k j) (sfcase_nat_ j). +End sfcase_nat. + +Section fkcase_nat. +Variable k : R.-fker X ~> Y. + +Let case_nat_uub (m : nat) : measure_fam_uub (case_nat_ k m). +Proof. +have /measure_fam_uubP[M hM] := measure_uub k. +exists M%:num => /= -[]; rewrite /case_nat_ => t [|m']/=. + by case: ifPn => //= ?; rewrite /mzero//=. +by case: ifPn => //= ?; rewrite /mzero//=. +Qed. + +#[export] +HB.instance Definition _ j := Kernel_isFinite.Build _ _ _ _ _ + (case_nat_ k j) (case_nat_uub j). +End fkcase_nat. + +End case_nat. +End CASE_NAT. + +Import CASE_NAT. + +Section case_nat. +Context d d' (T : measurableType d) (T' : measurableType d') (R : realType). + +Import CASE_NAT. + +(* case analysis on the nat datatype *) +Definition case_nat (t : R.-sfker T ~> nat) (u_ : (R.-sfker T ~> T')^nat) + : R.-sfker T ~> T' := + [the R.-sfker T ~> nat of t] \; + [the R.-sfker T * nat ~> T' of + kseries (fun n => [the R.-sfker T * nat ~> T' of case_nat_ (u_ n) n])]. + +End case_nat. + +Definition measure_sum_display : + (measure_display * measure_display) -> measure_display. +Proof. exact. Qed. + +Definition image_classes d1 d2 + (T1 : measurableType d1) (T2 : measurableType d2) (T : Type) + (f1 : T1 -> T) (f2 : T2 -> T) := + <>. + +Section sum_salgebra_instance. +Context d1 d2 (T1 : measurableType d1) (T2 : measurableType d2). +Let f1 : T1 -> T1 + T2 := @inl T1 T2. +Let f2 : T2 -> T1 + T2 := @inr T1 T2. + +Lemma sum_salgebra_set0 : image_classes f1 f2 (set0 : set (T1 + T2)). +Proof. exact: sigma_algebra0. Qed. + +Lemma sum_salgebra_setC A : image_classes f1 f2 A -> + image_classes f1 f2 (~` A). +Proof. exact: sigma_algebraC. Qed. + +Lemma sum_salgebra_bigcup (F : _^nat) : (forall i, image_classes f1 f2 (F i)) -> + image_classes f1 f2 (\bigcup_i (F i)). +Proof. exact: sigma_algebra_bigcup. Qed. + +HB.instance Definition sum_salgebra_mixin := + @isMeasurable.Build (measure_sum_display (d1, d2)) + (T1 + T2)%type (Pointed.class _) (image_classes f1 f2) + (sum_salgebra_set0) (sum_salgebra_setC) (sum_salgebra_bigcup). + +End sum_salgebra_instance. +Reserved Notation "p .-sum" (at level 1, format "p .-sum"). +Reserved Notation "p .-sum.-measurable" + (at level 2, format "p .-sum.-measurable"). +Notation "p .-sum" := (measure_sum_display p) : measure_display_scope. +Notation "p .-sum.-measurable" := + ((p.-sum).-measurable : set (set (_ + _))) : + classical_set_scope. + +Module CASE_SUM. + +Section case_suml. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). +Let A : measurableType _ := unit. + +Section kcase_suml. +Variable k : R.-ker X ~> Y. + +Definition case_suml (a : A) : X * A -> {measure set Y -> \bar R} := + fun xa => k xa.1. + +Let measurable_fun_case_suml a U : measurable U -> + measurable_fun setT (case_suml a ^~ U). +Proof. +move=> /= mU; rewrite /case_suml. +have h := measurable_kernel k _ mU. +rewrite (_ : (fun x : X * unit => k x.1 U) = (fun x : X => k x U) \o fst) //. +by apply: measurableT_comp => //. +Qed. + +#[export] +HB.instance Definition _ a := isKernel.Build _ _ _ _ _ + (case_suml a) (measurable_fun_case_suml a). +End kcase_suml. + +Section sfkcase_suml. +Variable k : R.-sfker X ~> Y. + +Let sfinite_case_suml (a : A) : exists2 k_ : (R.-ker _ ~> _)^nat, + forall n, measure_fam_uub (k_ n) & + forall x U, measurable U -> case_suml k a x U = mseries (k_ ^~ x) 0 U. +Proof. +have [k_ hk /=] := sfinite_kernel k. +exists (fun n => [the _.-ker _ ~> _ of case_suml (k_ n) a]) => /=. + move=> n; have /measure_fam_uubP[r k_r] := measure_uub (k_ n). + by exists r%:num => /= -[x []]; rewrite /case_suml//= /mzero//. +move=> [x b] U mU; rewrite /case_suml /=. +by rewrite /mseries hk. +Qed. + +#[export] +HB.instance Definition _ (a : A) := @Kernel_isSFinite_subdef.Build _ _ _ _ _ + (case_suml k a) (sfinite_case_suml a). +End sfkcase_suml. + +Section fkcase_suml. +Variable k : R.-fker X ~> Y. + +Let case_suml_uub (a : A) : measure_fam_uub (case_suml k a). +Proof. +have /measure_fam_uubP[M hM] := measure_uub k. +by exists M%:num => /= -[]; rewrite /case_suml. +Qed. + +#[export] +HB.instance Definition _ a := Kernel_isFinite.Build _ _ _ _ _ + (case_suml k a) (case_suml_uub a). +End fkcase_suml. + +End case_suml. + +Section case_sumr. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). +Let B : measurableType _ := bool. + +Section kcase_sumr. +Variable k : R.-ker X ~> Y. + +Definition case_sumr (b : B) : X * B -> {measure set Y -> \bar R} := + fun xa => if xa.2 == b then k xa.1 else [the measure _ _ of mzero]. + +Let measurable_fun_case_sumr b U : measurable U -> + measurable_fun setT (case_sumr b ^~ U). +Proof. +move=> /= mU; rewrite /case_suml. +have h := measurable_kernel k _ mU. +rewrite /case_sumr. +rewrite (_ : (fun x : X * bool => case_sumr b x U) = + (fun x : X * bool => (if x.2 == b then k x.1 U else [the {measure set Y -> \bar R} of mzero] U))); last first. + apply/funext => x. + rewrite /case_sumr. + by case: ifPn. +apply: measurable_fun_ifT => //=. + rewrite (_ : (fun t : X * bool => t.2 == b) = (fun t : bool => t == b) \o snd)//. + apply: measurableT_comp => //. +rewrite (_ : (fun t : X * bool => k t.1 U) = (fun t : X => k t U) \o fst)//. +by apply: measurableT_comp => //. +Qed. + +#[export] +HB.instance Definition _ b := isKernel.Build _ _ _ _ _ + (case_sumr b) (measurable_fun_case_sumr b). +End kcase_sumr. + +Section sfkcase_sumr. +Variable k : R.-sfker X ~> Y. + +Let sfinite_case_sumr b : exists2 k_ : (R.-ker _ ~> _)^nat, + forall n, measure_fam_uub (k_ n) & + forall x U, measurable U -> case_sumr k b x U = mseries (k_ ^~ x) 0 U. +Proof. +have [k_ hk /=] := sfinite_kernel k. +exists (fun n => [the _.-ker _ ~> _ of case_sumr (k_ n) b]) => /=. + move=> n; have /measure_fam_uubP[r k_r] := measure_uub (k_ n). + by exists r%:num => /= -[x []]; rewrite /case_sumr//=; case: ifPn => // _; + rewrite /= (le_lt_trans _ (k_r x))// /mzero//. +move=> [x [|]] U mU; rewrite /case_sumr /=; case: b => //=; rewrite ?hk//; +by rewrite /mseries/= eseries0. +Qed. + +#[export] +HB.instance Definition _ b := @Kernel_isSFinite_subdef.Build _ _ _ _ _ + (case_sumr k b) (sfinite_case_sumr b). +End sfkcase_sumr. + +Section fkcase_sumr. +Variable k : R.-fker X ~> Y. + +Let case_sumr_uub b : measure_fam_uub (case_sumr k b). +Proof. +have /measure_fam_uubP[M hM] := measure_uub k. +by exists M%:num => /= -[]; rewrite /case_sumr => x [|] /=; case: b => //=; + rewrite (le_lt_trans _ (hM x))// /mzero. +Qed. + +#[export] +HB.instance Definition _ b := Kernel_isFinite.Build _ _ _ _ _ + (case_sumr k b) (case_sumr_uub b). +End fkcase_sumr. + +End case_sumr. +End CASE_SUM. + +Section case_sum'. + +Section kcase_sum'. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). +Let A : measurableType _ := unit. +Let B : measurableType _ := bool. +Variables (k : (A + B)%type -> R.-sfker X ~> Y). + +Definition case_sum' : X * (A + B)%type -> {measure set Y -> \bar R} := + fun xab => match xab with + | (x, inl a) => CASE_SUM.case_suml (k xab.2) a (x, a) + | (x, inr b) => CASE_SUM.case_sumr (k xab.2) b (x, b) + end. + +Let measurable_fun_case_sum' U : measurable U -> + measurable_fun setT (case_sum' ^~ U). +Proof. +rewrite /= => mU. +apply: (measurability (ErealGenInftyO.measurableE R)) => //. +move=> /= _ [_ [x ->] <-]; apply: measurableI => //. +rewrite /case_sum' /CASE_SUM.case_suml /CASE_SUM.case_sumr /=. +rewrite (_ : + (fun x : X * (unit + bool) => (let (x0, s) := x in + match s with inl _ => k x.2 x0 | inr b => if b == b then k x.2 x0 else mzero + end) U) = + (fun x : X * (unit + bool) => k x.2 x.1 U)); last first. + by apply/funext => -[x1 [a|b]] //; rewrite eqxx. +rewrite (_ : _ @^-1` _ = + ([set x1 | k (inl tt) x1 U < x%:E] `*` inl @` [set tt]) `|` + ([set x1 | k (inr false) x1 U < x%:E] `*` inr @` [set false]) `|` + ([set x1 | k (inr true) x1 U < x%:E] `*` inr @` [set true])); last first. + apply/seteqP; split. + - move=> z /=; rewrite in_itv/=; move: z => [z [[]|[|]]]//= ?. + + by do 2 left; split => //; exists tt. + + by right; split => //; exists true. + + by left; right; split => //; exists false. + - move=> z /=; rewrite in_itv/=; move: z => [z [[]|[|]]]//=. + - move=> [[[]//|]|]. + + by move=> [_ []]. + + by move=> [_ []]. + - move=> [[|]|[]//]. + + by move=> [_ []]. + + by move=> [_ [] [|]]. + - move=> [[|[]//]|]. + + by move=> [_ []]. + + by move=> [_ [] [|]]. +pose h1 := [set xub : X * (unit + bool) | k (inl tt) xub.1 U < x%:E]. +have mh1 : measurable h1. + rewrite -[X in measurable X]setTI; apply: emeasurable_fun_infty_o => //=. + have H : measurable_fun [set: X] (fun x => k (inl tt) x U) by exact/measurable_kernel. + move=> _ /= C mC; rewrite setTI. + have := H measurableT _ mC; rewrite setTI => {}H. + rewrite [X in measurable X](_ : _ = ((fun x => k (inl tt) x U) @^-1` C) `*` setT)//. + exact: measurableM. + by apply/seteqP; split => [z//=| z/= []]. +set h2 := [set xub : X * (unit + bool)| k (inr false) xub.1 U < x%:E]. +have mh2 : measurable h2. + rewrite -[X in measurable X]setTI. + apply: emeasurable_fun_infty_o => //=. + have H : measurable_fun [set: X] (fun x => k (inr false) x U) by exact/measurable_kernel. + move=> _ /= C mC; rewrite setTI. + have := H measurableT _ mC; rewrite setTI => {}H. + rewrite [X in measurable X](_ : _ = ((fun x => k (inr false) x U) @^-1` C) `*` setT)//. + exact: measurableM. + by apply/seteqP; split => [z //=|z/= []]. +set h3 := [set xub : X * (unit + bool)| k (inr true) xub.1 U < x%:E]. +have mh3 : measurable h3. + rewrite -[X in measurable X]setTI. + apply: emeasurable_fun_infty_o => //=. + have H : measurable_fun [set: X] (fun x => k (inr true) x U) by exact/measurable_kernel. + move=> _ /= C mC; rewrite setTI. + have := H measurableT _ mC; rewrite setTI => {}H. + rewrite [X in measurable X](_ : _ = ((fun x => k (inr true) x U) @^-1` C) `*` setT)//. + exact: measurableM. + by apply/seteqP; split=> [z//=|z/= []]. +apply: measurableU. +- apply: measurableU. + + apply: measurableM => //. + rewrite [X in measurable X](_ : _ = ysection h1 (inl tt))//. + * by apply: measurable_ysection. + * by apply/seteqP; split => z /=; rewrite /ysection /= inE. + + apply: measurableM => //. + rewrite [X in measurable X](_ : _ = ysection h2 (inr false))//. + * by apply: measurable_ysection. + * by apply/seteqP; split => z /=; rewrite /ysection /= inE. +- apply: measurableM => //. + rewrite [X in measurable X](_ : _ = ysection h3 (inr true))//. + + by apply: measurable_ysection. + + by apply/seteqP; split => z /=; rewrite /ysection /= inE. +Qed. + +#[export] +HB.instance Definition _ := isKernel.Build _ _ _ _ _ + (case_sum') (measurable_fun_case_sum'). +End kcase_sum'. + +Section sfkcase_sum'. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). +Let A : measurableType _ := unit. +Let B : measurableType _ := bool. +Variables (k : (A + B)%type -> R.-sfker X ~> Y). + +Let sfinite_case_sum' : exists2 k_ : (R.-ker _ ~> _)^nat, + forall n, measure_fam_uub (k_ n) & + forall x U, measurable U -> case_sum' k x U = mseries (k_ ^~ x) 0 U. +Proof. +rewrite /=. +set f : A + B -> (R.-fker _ ~> _)^nat := + fun ab : A + B => sval (cid (sfinite_kernel (k ab))). +set Hf := fun ab : A + B => svalP (cid (sfinite_kernel (k ab))). +rewrite /= in Hf. +exists (fun n => [the R.-ker _ ~> _ of case_sum' (fun ab => [the R.-fker _ ~> _ of f ab n])]). + move=> n /=. + have [rtt Hrtt] := measure_uub (f (inl tt) n). + have [rfalse Hrfalse] := measure_uub (f (inr false) n). + have [rtrue Hrtrue] := measure_uub (f (inr true) n). + exists (maxr rtt (maxr rfalse rtrue)) => //= -[x [[]|[|]]] /=. + by rewrite 2!EFin_max lt_maxr Hrtt. + by rewrite /CASE_SUM.case_sumr /= 2!EFin_max 2!lt_maxr Hrtrue 2!orbT. + by rewrite /CASE_SUM.case_sumr /= 2!EFin_max 2!lt_maxr Hrfalse orbT. +move=> [x [[]|[|]]] U mU/=-. +by rewrite (Hf (inl tt) x _ mU). +by rewrite (Hf (inr true) x _ mU). +by rewrite (Hf (inr false) x _ mU). +Qed. + +#[export] +HB.instance Definition _ := @Kernel_isSFinite_subdef.Build _ _ _ _ _ + (case_sum' k) (sfinite_case_sum'). +End sfkcase_sum'. + +End case_sum'. + +Section case_sum. +Context d d' (X : measurableType d) (Y : measurableType d') (R : realType). +Let A : measurableType _ := unit. +Let B : measurableType _ := bool. + +Import CASE_SUM. + +(* case analysis on the datatype unit + bool *) +Definition case_sum (f : R.-sfker X ~> (A + B)%type) + (k : (A + B)%type-> R.-sfker X ~> Y) : R.-sfker X ~> Y := + [the R.-sfker X ~> (A + B)%type of f] \; + [the R.-sfker X * (A + B) ~> Y of case_sum' k]. + +End case_sum. + +(* counting measure as a kernel *) +Section kcounting. +Context d (G : measurableType d) (R : realType). + +Definition kcounting : G -> {measure set nat -> \bar R} := fun=> counting. + +Let mkcounting U : measurable U -> measurable_fun setT (kcounting ^~ U). +Proof. by []. Qed. + +HB.instance Definition _ := isKernel.Build _ _ _ _ _ kcounting mkcounting. + +Let sfkcounting : exists2 k_ : (R.-ker _ ~> _)^nat, + forall n, measure_fam_uub (k_ n) & + forall x U, measurable U -> kcounting x U = mseries (k_ ^~ x) 0 U. +Proof. +exists (fun n => [the R.-fker _ ~> _ of + @kdirac _ _ G nat R _ (@measurable_cst _ _ _ _ setT n)]). + by move=> n /=; exact: measure_uub. +by move=> g U mU; rewrite /kcounting/= counting_dirac. +Qed. + +HB.instance Definition _ := + Kernel_isSFinite_subdef.Build _ _ _ _ R kcounting sfkcounting. + +End kcounting. + +(* formalization of the iterate construct of Staton ESOP 2017, Sect. 4.2 *) +Section iterate. +Context d {G : measurableType d} {R : realType}. +Let A : measurableType _ := unit. +Let B : measurableType _ := bool. + +(* formalization of iterate^n + Gamma |-p iterate^n t from x = u : B *) +Variables (t : R.-sfker (G * A) ~> (A + B)%type) + (u : G -> A) (mu : measurable_fun setT u). + +Fixpoint iterate_ n : R.-sfker G ~> B := + match n with + | 0%N => case_sum (letin (ret mu) t) + (fun x => match x with + | inl a => fail + | inr b => ret (measurable_cst b) + end) + | m.+1 => case_sum (letin (ret mu) t) + (fun x => match x with + | inl a => iterate_ m + | inr b => fail + end) + end. + +(* formalization of iterate (A = unit, B = bool) + Gamma, x : A |-p t : A + B Gamma |-d u : A +----------------------------------------------- + Gamma |-p iterate t from x = u : B *) +Definition iterate : R.-sfker G ~> B := case_nat (kcounting R) iterate_. + +End iterate. + +(* an s-finite kernel to test that two expressions are different *) +Section lift_neq. +Context {R : realType} d (G : measurableType d). +Variables (f : G -> bool) (g : G -> bool). + +Definition flift_neq : G -> bool := fun x' => f x' != g x'. + +Hypotheses (mf : measurable_fun setT f) (mg : measurable_fun setT g). + +(* see also emeasurable_fun_neq *) +Lemma measurable_fun_flift_neq : measurable_fun setT flift_neq. +Proof. +apply: (measurable_fun_bool true). +rewrite /flift_neq /= (_ : _ @^-1` _ = ([set x | f x] `&` [set x | ~~ g x]) `|` + ([set x | ~~ f x] `&` [set x | g x])). + apply: measurableU; apply: measurableI. + - by rewrite -[X in measurable X]setTI; exact: mf. + - rewrite [X in measurable X](_ : _ = ~` [set x | g x]); last first. + by apply/seteqP; split => x /= /negP. + by apply: measurableC; rewrite -[X in measurable X]setTI; exact: mg. + - rewrite [X in measurable X](_ : _ = ~` [set x | f x]); last first. + by apply/seteqP; split => x /= /negP. + by apply: measurableC; rewrite -[X in measurable X]setTI; exact: mf. + - by rewrite -[X in measurable X]setTI; exact: mg. +by apply/seteqP; split => x /=; move: (f x) (g x) => [|] [|]//=; intuition. +Qed. + +Definition lift_neq : R.-sfker G ~> bool := ret measurable_fun_flift_neq. + +End lift_neq. + +Section von_neumann_trick. +Context d {T : measurableType d} {R : realType}. + +Definition kinrtt {d1 d2} {T1 : measurableType d1} {T2 : measurableType d2} := + @measurable_cst _ _ T1 _ setT (@inl unit T2 tt). + +Definition finlb d1 d2 (T1 : measurableType d1) (T2 : measurableType d2) + : T1 * bool -> T2 + bool := fun t1b => inr t1b.2. + +Lemma minlb {d1 d2} {T1 : measurableType d1} {T2 : measurableType d2} : + measurable_fun setT (@finlb _ _ T1 T2). +Proof. exact: measurableT_comp. Qed. + +Variable (D : pprobability bool R (* biased coin *)). + +Definition von_neumann_trick' : R.-sfker (T * unit) ~> (unit + bool)%type := + letin (sample_cst D) + (letin (sample_cst D) + (letin (lift_neq macc1of3 macc2of3) + (ite (macc3of4) + (letin (ret macc1of4) (ret minlb)) + (ret kinrtt)))). + +Definition von_neumann_trick : R.-sfker T ~> bool := + iterate von_neumann_trick' ktt. + +End von_neumann_trick. diff --git a/theories/prob_lang_wip.v b/theories/prob_lang_wip.v new file mode 100644 index 0000000000..9ab73014b7 --- /dev/null +++ b/theories/prob_lang_wip.v @@ -0,0 +1,251 @@ +From HB Require Import structures. +From mathcomp Require Import all_ssreflect ssralg ssrnum ssrint interval finmap. +From mathcomp Require Import rat. +From mathcomp.classical Require Import mathcomp_extra boolp classical_sets. +From mathcomp.classical Require Import functions cardinality fsbigop. +Require Import signed reals ereal topology normedtype sequences esum measure. +Require Import lebesgue_measure numfun lebesgue_integral exp kernel trigo. +Require Import prob_lang. + +(******************************************************************************) +(* Semantics of a probabilistic programming language using s-finite kernels *) +(* (wip about definition of Lebesgue and counting measures) *) +(******************************************************************************) + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. +Import Order.TTheory GRing.Theory Num.Def Num.ExtraDef Num.Theory. +Import numFieldTopology.Exports. + +Local Open Scope classical_set_scope. +Local Open Scope ring_scope. +Local Open Scope ereal_scope. + +Section gauss. +Variable R : realType. +Local Open Scope ring_scope. + +(* density function for gauss *) +Definition gauss_density m s x : R := + (s * sqrtr (pi *+ 2))^-1 * expR (- ((x - m) / s) ^+ 2 / 2%:R). + +Lemma gauss_density_ge0 m s x : 0 <= s -> 0 <= gauss_density m s x. +Proof. by move=> s0; rewrite mulr_ge0 ?expR_ge0// invr_ge0 mulr_ge0. Qed. + +Lemma gauss_density_gt0 m s x : 0 < s -> 0 < gauss_density m s x. +Proof. +move=> s0; rewrite mulr_gt0 ?expR_gt0// invr_gt0 mulr_gt0//. +by rewrite sqrtr_gt0 pmulrn_rgt0// pi_gt0. +Qed. + +Definition gauss01_density : R -> R := gauss_density 0 1. + +Hypothesis integral_gauss01_density : + (\int[lebesgue_measure]_x (gauss01_density x)%:E = 1%E)%E. + +Lemma gauss01_densityE x : + gauss01_density x = (sqrtr (pi *+ 2))^-1 * expR (- (x ^+ 2) / 2%:R). +Proof. by rewrite /gauss01_density /gauss_density mul1r subr0 divr1. Qed. + +Definition mgauss01 (V : set R) := + (\int[lebesgue_measure]_(x in V) (gauss01_density x)%:E)%E. + +Lemma measurable_fun_gauss_density m s : + measurable_fun setT (gauss_density m s). +Proof. +apply: measurable_funM => //=. +apply: measurableT_comp => //=. +apply: measurable_funM => //=. +apply: measurableT_comp => //=. +apply: measurableT_comp (measurable_exprn _) _ => /=. +apply: measurable_funM => //=. +exact: measurable_funD. +Qed. + +Let mgauss010 : mgauss01 set0 = 0%E. +Proof. by rewrite /mgauss01 integral_set0. Qed. + +Let mgauss01_ge0 A : (0 <= mgauss01 A)%E. +Proof. +by rewrite /mgauss01 integral_ge0//= => x _; rewrite lee_fin gauss_density_ge0. +Qed. + +Let mgauss01_sigma_additive : semi_sigma_additive mgauss01. +Proof. +move=> /= F mF tF mUF. +rewrite /mgauss01/= integral_bigcup//=; last first. + apply/integrableP; split. + apply/EFin_measurable_fun. + exact: measurable_funS (measurable_fun_gauss_density 0 1). + rewrite (_ : (fun x => _) = (EFin \o gauss01_density)); last first. + by apply/funext => x; rewrite gee0_abs// lee_fin gauss_density_ge0. + apply: le_lt_trans. + apply: (@subset_integral _ _ _ _ _ setT) => //=. + apply/EFin_measurable_fun. + exact: measurable_fun_gauss_density. + by move=> ? _; rewrite lee_fin gauss_density_ge0. + by rewrite integral_gauss01_density// ltey. +apply: is_cvg_ereal_nneg_natsum_cond => n _ _. +by apply: integral_ge0 => /= x ?; rewrite lee_fin gauss_density_ge0. +Qed. + +HB.instance Definition _ := isMeasure.Build _ _ _ + mgauss01 mgauss010 mgauss01_ge0 mgauss01_sigma_additive. + +Let mgauss01_setT : mgauss01 [set: _] = 1%E. +Proof. by rewrite /mgauss01 integral_gauss01_density. Qed. + +HB.instance Definition _ := @Measure_isProbability.Build _ _ R mgauss01 mgauss01_setT. + +Definition gauss01 := [the probability _ _ of mgauss01]. + +End gauss. + +Section gauss_lebesgue. +Import Notations. +Context d (T : measurableType d) (R : realType). +Hypothesis integral_gauss01_density : + (\int[@lebesgue_measure R]_x (gauss01_density x)%:E = 1%E)%E. + +Let f1 (x : R) := (gauss01_density x) ^-1. + +Hypothesis integral_mgauss01 : forall U, measurable U -> + \int[mgauss01 (R:=R)]_(y in U) (f1 y)%:E = + \int[lebesgue_measure]_(x0 in U) (gauss01_density x0 * f1 x0)%:E. + +Let mf1 : measurable_fun setT f1. +Proof. +apply: (measurable_comp (F := [set r : R | r != 0%R])) => //. +- exact: open_measurable. +- by move=> /= r [t _ <-]; rewrite gt_eqF// gauss_density_gt0. +- apply: open_continuous_measurable_fun => //. + by apply/in_setP => x /= x0; exact: inv_continuous. +- exact: measurable_fun_gauss_density. +Qed. + +Variable mu : {measure set mR R -> \bar R}. + +Definition staton_lebesgue : R.-sfker T ~> _ := + letin (sample_cst (gauss01 integral_gauss01_density : pprobability _ _)) + (letin + (score (measurableT_comp mf1 macc1of2)) + (ret macc1of3)). + +Lemma staton_lebesgueE x U : measurable U -> + staton_lebesgue x U = lebesgue_measure U. +Proof. +move=> mU; rewrite [in LHS]/staton_lebesgue/=. +rewrite [in LHS]letinE /=. +transitivity (\int[@mgauss01 R]_(y in U) (f1 y)%:E). + rewrite -[in RHS](setTI U) integral_setI_indic//=. + apply: eq_integral => //= r. + rewrite letinE/= ge0_integral_mscale//= ger0_norm//; last first. + by rewrite invr_ge0// gauss_density_ge0. + by rewrite integral_dirac// indicT mul1e diracE indicE. +rewrite integral_mgauss01//. +transitivity (\int[lebesgue_measure]_(x in U) (\1_U x)%:E). + apply: eq_integral => /= y yU. + by rewrite /f1 divrr ?indicE ?yU// unitfE gt_eqF// gauss_density_gt0. +by rewrite integral_indic//= setIid. +Qed. + +End gauss_lebesgue. + +(* TODO: move this elsewhere *) +(* assuming x > 0 *) +Definition Gamma {R : realType} (x : R) : \bar R := + \int[lebesgue_measure]_(t in `[0%R, +oo[%classic) (expR (- t) * powR t (x - 1))%:E. + +Definition Rfact {R : realType} (x : R) := Gamma (x + 1)%R. + +Section poisson. +Variable R : realType. +Local Open Scope ring_scope. +Hypothesis integral_poisson_density : forall k, + (\int[lebesgue_measure]_x (@poisson R k x)%:E = 1%E)%E. + +(* density function for poisson *) +Definition poisson1 := @poisson R 1%N. + +Lemma poisson1_ge0 (x : R) : 0 <= poisson1 x. +Proof. exact: poisson_ge0. Qed. + +Definition mpoisson1 (V : set R) : \bar R := + (\int[lebesgue_measure]_(x in V) (poisson1 x)%:E)%E. + +Lemma measurable_fun_poisson1 : measurable_fun setT poisson1. +Proof. exact: measurable_poisson. Qed. + +Let mpoisson10 : mpoisson1 set0 = 0%E. +Proof. by rewrite /mpoisson1 integral_set0. Qed. + +Lemma mpoisson1_ge0 A : (0 <= mpoisson1 A)%E. +Proof. +apply: integral_ge0 => x Ax. +by rewrite lee_fin poisson1_ge0. +Qed. + +Let mpoisson1_sigma_additive : semi_sigma_additive mpoisson1. +Proof. +move=> /= F mF tF mUF. +rewrite /mpoisson1/= integral_bigcup//=; last first. + apply/integrableP; split. + apply/EFin_measurable_fun. + exact: measurable_funS (measurable_poisson _). + rewrite (_ : (fun x => _) = (EFin \o poisson1)); last first. + by apply/funext => x; rewrite gee0_abs// lee_fin poisson1_ge0//. + apply: le_lt_trans. + apply: (@subset_integral _ _ _ _ _ setT) => //=. + by apply/EFin_measurable_fun; exact: measurable_poisson. + by move=> ? _; rewrite lee_fin poisson1_ge0//. + by rewrite /= integral_poisson_density// ltry. +apply: is_cvg_ereal_nneg_natsum_cond => n _ _. +by apply: integral_ge0 => /= x ?; rewrite lee_fin poisson1_ge0. +Qed. + +HB.instance Definition _ := isMeasure.Build _ _ _ + mpoisson1 mpoisson10 mpoisson1_ge0 mpoisson1_sigma_additive. + +Let mpoisson1_setT : mpoisson1 [set: _] = 1%E. +Proof. +rewrite /mpoisson1. +rewrite /poisson1. +by rewrite integral_poisson_density. +Qed. + +HB.instance Definition _ := @Measure_isProbability.Build _ _ R mpoisson1 mpoisson1_setT. + +Definition poisson' := [the probability _ _ of mpoisson1]. + +End poisson. + +(* Staton's definition of the counting measure + Staton ESOP 2017, Sect. 4.2, equation (13) *) +Section staton_counting. +Context d (X : measurableType d). +Variable R : realType. +Import Notations. +Hypothesis integral_poisson_density : forall k, + (\int[lebesgue_measure]_x (@poisson R k x)%:E = 1%E)%E. + +Let f1 x := (poisson1 (x : R)) ^-1. + +Let mf1 : measurable_fun setT f1. +rewrite /f1 /poisson1 /poisson. +apply: (measurable_comp (F := [set r : R | r != 0%R])) => //. +- exact: open_measurable. +- move=> /= r [t ? <-]. + by case: ifPn => // t0; rewrite gt_eqF ?mulr_gt0 ?expR_gt0//= invrK ltr0n. +- apply: open_continuous_measurable_fun => //. + by apply/in_setP => x /= x0; exact: inv_continuous. +- exact: measurable_poisson. +Qed. + +Definition staton_counting : R.-sfker X ~> _ := + letin (sample_cst (@poisson' R integral_poisson_density : pprobability _ _)) + (letin + (score (measurableT_comp mf1 macc1of2)) + (ret macc1of3)). + +End staton_counting.