diff --git a/draw.py b/draw.py index eefff2a..c8fdf2c 100644 --- a/draw.py +++ b/draw.py @@ -188,7 +188,7 @@ def binary_crossentropy(t,o): mu2=tf.square(mus[t]) sigma2=tf.square(sigmas[t]) logsigma=logsigmas[t] - kl_terms[t]=0.5*tf.reduce_sum(mu2+sigma2-2*logsigma,1)-T*.5 # each kl term is (1xminibatch) + kl_terms[t]=0.5*tf.reduce_sum(mu2+sigma2-2*logsigma-1,1) # each kl term is (1xminibatch) KL=tf.add_n(kl_terms) # this is 1xminibatch, corresponding to summing kl_terms from 1:T Lz=tf.reduce_mean(KL) # average over minibatches