diff --git a/pyest/gm/gm.py b/pyest/gm/gm.py index e485502..ad97d36 100755 --- a/pyest/gm/gm.py +++ b/pyest/gm/gm.py @@ -31,6 +31,7 @@ 'v_eval_mvnpdfchol', 'integral_gauss_product', 'integral_gauss_product_chol', + 'integral_squared_gm', 'marginal_2d', 'marginal_nd', 'comp_bounds', @@ -463,6 +464,24 @@ def integral_gauss_product(m1, P1, m2, P2, allow_singular=False): return multivariate_normal.pdf(m1, m2, P1 + P2, allow_singular=allow_singular) +def integral_squared_gm(p): + ''' compute integral of squared Gaussian mixture + + Parameters + ---------- + p : GaussianMixture + Gaussian mixture + + Returns + ------- + integral : float + integral of the squared Gaussian mixture + ''' + return np.sum([ + wi*wj*eval_mvnpdf(mi, mj, Pi + Pj) for (wi, mi, Pi) in p for (wj, mj, Pj) in p + ]) + + def marginal_2d(m, P, dimensions=[0, 1]): """ compute 2D marginal of GM""" return marginal_nd(m, P, dimensions) diff --git a/pyest/metrics.py b/pyest/metrics.py index 6b2e1a5..f2837ae 100644 --- a/pyest/metrics.py +++ b/pyest/metrics.py @@ -1,6 +1,7 @@ import numpy as np import pyest.gm as pygm from scipy.linalg import solve_triangular +from scipy.integrate import dblquad def l2_dist(p1, p2): @@ -87,4 +88,106 @@ def madem(m, S, m_ref): """ return np.linalg.norm( solve_triangular(S, m - m_ref, lower=True) - ) \ No newline at end of file + ) + + +def integral_squared_error_2d(p1, p2, a, b, c, d, epsabs=1.49e-2, epsrel=1.49e-2): + ''' compute integral squared error between two 2D densities + + Parameters + ---------- + p1: callable + first density, p1([x,y]) + p2: callable + second density, p2([x,y]) + a, b : float + The limits of integration in x: a