From 0e53e22418b9f85ecf3cd2f5d6364fc96fac001b Mon Sep 17 00:00:00 2001 From: Kevin Kaichuang Yang Date: Thu, 17 Jul 2025 11:50:57 -0400 Subject: [PATCH] Fix more plots. --- analysis/fpd.py | 6 +++--- analysis/plot_metrics.py | 5 +++-- analysis/plot_valid.py | 9 +++++++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/analysis/fpd.py b/analysis/fpd.py index 5436d58..f07140f 100644 --- a/analysis/fpd.py +++ b/analysis/fpd.py @@ -104,10 +104,10 @@ def calculate_fid(act1, act2, eps=1e-6): for j, s2 in enumerate(natural_sets): if i > j: ej = pb_embedding_dict[s2] - mmd = mmd_rbf(ei[:], ej[:], gamma=pb_gamma) * mult + # mmd = mmd_rbf(ei[:], ej[:], gamma=pb_gamma) * mult fpd = calculate_fid(ei[:], ej[:], eps=1e-6) - print(s, s2, mmd, fpd) - pb_mmd_dict[s + ':' + s2] = mmd + print(s, s2, fpd) + # pb_mmd_dict[s + ':' + s2] = mmd pb_fpd_dict[s + ':' + s2] = fpd rfd_sets = [ diff --git a/analysis/plot_metrics.py b/analysis/plot_metrics.py index c7305c9..92b7c92 100644 --- a/analysis/plot_metrics.py +++ b/analysis/plot_metrics.py @@ -3,6 +3,7 @@ import matplotlib.pyplot as plt import pandas as pd import seaborn as sns +import numpy as np model_order = [ '170m-uniref50', @@ -446,7 +447,7 @@ def plot_cdf(x, color=sns.color_palette()[0], label=None, ax=None, **kwargs, ): fig, ax = plt.subplots(1, 1) pal = sns.color_palette() _ = sns.scatterplot(data=melted, x='value', y='Fraction expressed', hue='variable', style='variable', - ax=ax, palette=[pal[4], pal[7]], s=100) + ax=ax, palette=[pal[7], pal[4]], s=100) _ = ax.set_xlabel("Perplexity") handles, labels = ax.get_legend_handles_labels() ax.legend(handles=handles[:], labels=labels[:]) # This gets rid of the title @@ -467,7 +468,7 @@ def plot_cdf(x, color=sns.color_palette()[0], label=None, ax=None, **kwargs, ): fig, ax = plt.subplots(1, 1) pal = sns.color_palette() _ = sns.scatterplot(data=melted, x='value', y='Fraction expressed', hue='variable', style='variable', - ax=ax, palette=[pal[4], pal[7]], s=100) + ax=ax, palette=[pal[7], pal[4]], s=100) _ = ax.set_xlabel("FPD") handles, labels = ax.get_legend_handles_labels() ax.legend(handles=handles[:], labels=labels[:]) # This gets rid of the title diff --git a/analysis/plot_valid.py b/analysis/plot_valid.py index ea2a17a..9c2aa17 100644 --- a/analysis/plot_valid.py +++ b/analysis/plot_valid.py @@ -32,7 +32,11 @@ out_file = os.path.join(out_fpath, "valid_" + model + '_' + str(checkpoint) + "_" + "uniref" + "_" + direction + "_%d.pt" %rank) try: dat = torch.load(out_file) - ces.append(dat["ce"]) + if '3b' in model and direction == 'forward' and checkpoint == 43300: + ces.append(dat["ce"][::1000]) + else: + ces.append(dat["ce"]) + print(rank, torch.cat(ces).shape) except(EOFError): continue ces = torch.cat(ces) @@ -47,11 +51,12 @@ # _ = ax1.fill_between(x, ce_by_pos + se_by_pos, ce_by_pos - se_by_pos, alpha=0.3, color=pal[i]) _ = ax1.set_xlabel('position') _ = ax1.set_ylabel('cross-entropy') + _ = ax1.set_ylim(0.5, 3) _ = ax1.legend() ax2 = ax1.twinx() n = np.isfinite(ces).sum(axis=0) _ = ax2.plot(x, n, "-", color="gray") - _ = ax2.set_ylabel('n') + _ = ax2.set_ylabel('# of val sequences', rotation=270, labelpad=15) _ = fig1.savefig(os.path.join(out_fpath, model + "_" + "uniref" + "_" + direction + ".pdf"), dpi=300, bbox_inches="tight") # #