diff --git a/model.py b/model.py index ca375a3..1c16af6 100644 --- a/model.py +++ b/model.py @@ -21,11 +21,10 @@ def __init__(self, batch_size=16, seq_len=16, b_size=50, x_size=1*64*64, c_size= self.p_d = Decoder() # losses - self.kl = KullbackLeibler(self.q, self.p_b1) - self.b_ll = -NLL(self.p_b2) - self.t_nll = NLL(self.p_t) + self.kl_1 = KullbackLeibler(self.q, self.p_b1) + self.kl_2 = KullbackLeibler(self.p_b2, self.p_t) self.d_nll = NLL(self.p_d) - self.loss_cls = (self.kl+self.b_ll+self.t_nll+self.d_nll).mean() + self.loss_cls = (self.kl_1+self.kl_2+self.d_nll).mean() def forward(self, batch): batch_size, seq_len, *_ = batch.size()