diff --git a/pykernels/regular.py b/pykernels/regular.py index c65f01e..ab48c93 100644 --- a/pykernels/regular.py +++ b/pykernels/regular.py @@ -465,16 +465,11 @@ def _compute(self, data_1, data_2): if np.any(data_1 < 0) or np.any(data_2 < 0): warnings.warn('MinMax kernel requires data to be strictly positive!') - minkernel = np.zeros((data_1.shape[0], data_2.shape[0])) - maxkernel = np.zeros((data_1.shape[0], data_2.shape[0])) - - for d in range(data_1.shape[1]): - column_1 = data_1[:, d].reshape(-1, 1) - column_2 = data_2[:, d].reshape(-1, 1) - minkernel += np.minimum(column_1, column_2.T) - maxkernel += np.maximum(column_1, column_2.T) - - return minkernel/maxkernel + data_1 = data_1[:, None, :] + data_2 = data_2[None, :, :] + minkernel = np.minimum(data_1, data_2).sum(axis=2) # Sum over features + maxkernel = np.maximum(data_1, data_2).sum(axis=2) + return minkernel / maxkernel def dim(self): return None