Hi,
Thanks for your great work!
I have some questions about the MMD/JMMD loss.
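The linear-time unbiased estimate of JMMD in the paper is as follows:

$$
\hat{D}_{\mathcal{L}}(P, Q) = \frac{2}{n} \sum_{i=1}^{n/2} \Big( \prod_{\ell \in \mathcal{L}} k^{\ell}\big(\mathbf{z}_{2i-1}^{s\ell}, \mathbf{z}_{2i}^{s\ell}\big) + \prod_{\ell \in \mathcal{L}} k^{\ell}\big(\mathbf{z}_{2i-1}^{t\ell}, \mathbf{z}_{2i}^{t\ell}\big) \Big) - \frac{2}{n} \sum_{i=1}^{n/2} \Big( \prod_{\ell \in \mathcal{L}} k^{\ell}\big(\mathbf{z}_{2i-1}^{s\ell}, \mathbf{z}_{2i}^{t\ell}\big) + \prod_{\ell \in \mathcal{L}} k^{\ell}\big(\mathbf{z}_{2i-1}^{t\ell}, \mathbf{z}_{2i}^{s\ell}\big) \Big)
$$
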
while the code is:
```python
# Linear version
loss = 0
for i in range(batch_size):
    s1, s2 = i, (i+1)%batch_size
    t1, t2 = s1+batch_size, s2+batch_size
    loss += kernels[s1, s2] + kernels[t1, t2]
    loss -= kernels[s1, t2] + kernels[s2, t1]
return loss / float(batch_size)
```
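For anyone who wants to reproduce this, here is a self-contained sketch of the quoted loop. The layout is inferred from `t1 = s1 + batch_size`: `kernels` is assumed to be a (2n × 2n) matrix over the stacked [source; target] batch. The JMMD joint kernel is built as the elementwise product of per-layer kernel matrices, which mirrors the product over layers in the paper's formula. The helper `gaussian_kernel`, its single bandwidth `sigma`, and the two random "layers" are hypothetical stand-ins, not the repository's multi-kernel code.

```python
import numpy as np


def gaussian_kernel(source, target, sigma=1.0):
    """Hypothetical single-bandwidth Gaussian kernel over the stacked [source; target]."""
    total = np.concatenate([source, target], axis=0)                   # (2n, d)
    sq_dists = ((total[:, None, :] - total[None, :, :]) ** 2).sum(-1)  # (2n, 2n)
    return np.exp(-sq_dists / (2.0 * sigma ** 2))


def linear_jmmd_cyclic(kernels, batch_size):
    """The loop from the issue: index i is paired with (i + 1) % batch_size,
    so all n indices are visited, giving n overlapping pairs."""
    loss = 0.0
    for i in range(batch_size):
        s1, s2 = i, (i + 1) % batch_size
        t1, t2 = s1 + batch_size, s2 + batch_size
        loss += kernels[s1, s2] + kernels[t1, t2]  # within-domain terms
        loss -= kernels[s1, t2] + kernels[s2, t1]  # cross-domain terms
    return loss / float(batch_size)


rng = np.random.default_rng(0)
n = 8
# Two hypothetical "layers" of activations for the source and target batches.
src_l1, tgt_l1 = rng.normal(size=(n, 4)), rng.normal(loc=1.0, size=(n, 4))
src_l2, tgt_l2 = rng.normal(size=(n, 2)), rng.normal(loc=1.0, size=(n, 2))
# JMMD joint kernel: elementwise product of the per-layer kernel matrices.
joint = gaussian_kernel(src_l1, tgt_l1) * gaussian_kernel(src_l2, tgt_l2)
print(linear_jmmd_cyclic(joint, n))
```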
It seems the sample pairs do not match. In the paper, the first term sums over n/2 disjoint pairs $(\mathbf{z}_{2i-1}, \mathbf{z}_{2i})$. The code instead loops over all n = `batch_size` indices and pairs each index `i` with `(i+1) % batch_size`, so it uses n overlapping pairs. Why are they different?
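To make the mismatch concrete, here is a minimal sketch of the paper-style disjoint-pair estimate, along with the index pairs each version visits. It assumes the same (2n × 2n) stacked [source; target] kernel layout as the quoted code and an even n. `linear_jmmd_paired` and `pairs_used` are hypothetical names, not functions from the repository.

```python
import numpy as np


def linear_jmmd_paired(kernels, batch_size):
    """Sketch of the paper-style estimate: disjoint pairs (z_{2i-1}, z_{2i}),
    i = 1..n/2 (0-based indices 2i and 2i + 1), scaled by 2/n."""
    if batch_size % 2 != 0:
        raise ValueError("the disjoint-pair estimate needs an even batch size")
    loss = 0.0
    for i in range(batch_size // 2):
        s1, s2 = 2 * i, 2 * i + 1
        t1, t2 = s1 + batch_size, s2 + batch_size
        loss += kernels[s1, s2] + kernels[t1, t2]  # within-domain terms
        loss -= kernels[s1, t2] + kernels[t1, s2]  # cross-domain terms
    return 2.0 * loss / batch_size


def pairs_used(batch_size):
    """Index pairs entering the first term: paper (disjoint) vs. code (cyclic)."""
    paper = [(2 * i, 2 * i + 1) for i in range(batch_size // 2)]
    code = [(i, (i + 1) % batch_size) for i in range(batch_size)]
    return paper, code


paper, code = pairs_used(4)
print(paper)  # [(0, 1), (2, 3)]
print(code)   # [(0, 1), (1, 2), (2, 3), (3, 0)]
```

In both versions the normalization turns the sum into a mean over the pairs used: the paper weights n/2 pairs by 2/n, and the code weights n pairs by 1/n. The visible difference is therefore which pairs enter the estimate, not the scaling.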
Looking forward to your reply.