diff --git a/torch_cka/cka.py b/torch_cka/cka.py index 9e7c95f..dd71e67 100644 --- a/torch_cka/cka.py +++ b/torch_cka/cka.py @@ -172,7 +172,7 @@ def compare(self, Y = feat2.flatten(1) L = Y @ Y.t() L.fill_diagonal_(0) - assert K.shape == L.shape, f"Feature shape mistach! {K.shape}, {L.shape}" + assert K.shape == L.shape, f"Feature shape mismatch! {K.shape}, {L.shape}" self.hsic_matrix[i, j, 1] += self._HSIC(K, L) / num_batches self.hsic_matrix[i, j, 2] += self._HSIC(L, L) / num_batches @@ -217,4 +217,4 @@ def plot_results(self, if save_path is not None: plt.savefig(save_path, dpi=300) - plt.show() \ No newline at end of file + plt.show()