From ffe1723f31cbb3cb08aab77301e76b50346c355f Mon Sep 17 00:00:00 2001
From: Isaac Davis
Date: Tue, 25 Jul 2023 15:38:20 -0600
Subject: [PATCH 1/5] Add cloud_regime_analysis.py and add COSP variables to
 adf_variable_defaults.yaml

---
 lib/adf_variable_defaults.yaml            |   39 +-
 scripts/plotting/cloud_regime_analysis.py | 1379 +++++++++++++++++++++
 2 files changed, 1416 insertions(+), 2 deletions(-)
 create mode 100644 scripts/plotting/cloud_regime_analysis.py

diff --git a/lib/adf_variable_defaults.yaml b/lib/adf_variable_defaults.yaml
index e6ff4ccfb..9892bf3e7 100644
--- a/lib/adf_variable_defaults.yaml
+++ b/lib/adf_variable_defaults.yaml
@@ -1,4 +1,3 @@
-
 #This file lists out variable-specific defaults
 #for plotting and observations. These defaults
 #are:
@@ -854,12 +853,32 @@ CLIMODIS:
 CLWMODIS:
   category: "COSP"
 
+CLD_MISR:
+  category: "Clouds"
+  obs_file: 'MISR_obs_data.nc'
+  obs_name: "MISR"
+
+CLMODIS:
+  category: "Clouds"
+  obs_file: 'MODIS_obs_data.nc'
+  obs_name: "MODIS"
+
 FISCCP1_COSP:
-  category: "COSP"
+  category: "Clouds"
+  obs_file: 'ISCCP_obs_data.nc'
+  obs_name: "ISCCP"
 
 ICE_ICLD_VISTAU:
   category: "COSP"
 
+ISCCP_emd_centers:
+  category: "Clouds"
+  obs_file: 'ISCCP_emd-means_n_init5_centers_1.npy'
+
+ISCCP_euclidean_centers:
+  category: "Clouds"
+  obs_file: 'CS_qualitative_clusters.npy'
+
 IWPMODIS:
   category: "COSP"
 
@@ -884,6 +903,22 @@ MEANTB_ISCCP:
 MEANTBCLR_ISCCP:
   category: "COSP"
 
+MISR_euclidean_centers:
+  category: "Clouds"
+  obs_file: 'MISR_6C_weather_state_centers.npy'
+
+MISR_emd_centers:
+  category: "Clouds"
+  obs_file: 'MISR_emd-means_n_init5_centers_1.npy'
+
+MODIS_euclidean_centers:
+  category: "Clouds"
+  obs_file: 'MODIS_6C_weather_state_centers.npy'
+
+MODIS_emd_centers:
+  category: "Clouds"
+  obs_file: 'MODIS_emd-means_n_init5_centers_1.npy'
+
 PCTMODIS:
   category: "COSP"
 
diff --git a/scripts/plotting/cloud_regime_analysis.py b/scripts/plotting/cloud_regime_analysis.py
new file mode 100644
index 000000000..86e3bac63
--- /dev/null
+++ b/scripts/plotting/cloud_regime_analysis.py
@@ -0,0 +1,1379 @@
+#%%
+print()
+import numpy as np
+try: import wasserstein
+except ImportError:
+    print(' Wasserstein package is not installed so wasserstein distance cannot be used. Attempting to use wasserstein distance will raise an error.')
+    print(' To use wasserstein distance please install the wasserstein package in your environment: https://pypi.org/project/Wasserstein/ ')
+import matplotlib.pyplot as plt
+import xarray as xr
+import matplotlib as mpl
+from mpl_toolkits.axes_grid1 import make_axes_locatable
+from numba import njit
+from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter
+import cartopy.crs as ccrs
+from shapely.geometry import Point
+import cartopy
+from shapely.prepared import prep
+import glob
+from math import ceil
+import time
+import dask
+import os
+
+
+#global num_iter, n_samples, data, ds, ht_var_name, tau_var_name, k, height_or_pressure
+
+def cloud_regime_analysis(adf, wasserstein_or_euclidean = "euclidean", data_product='all', premade_cloud_regimes=None, lat_range=None, lon_range=None, only_ocean_or_land=False):
+    """
+    This script/function is designed to generate 2-D lat/lon maps of Cloud Regimes (CRs), as well as plots of the CR
+    centers themselves. It can fit data into CRs using either Wasserstein (AKA Earth Mover's Distance) or the more conventional
+    Euclidean distance. To use this script, the user should add the appropriate COSP variables to the diag_var_list in the yaml file.
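+    For example (illustrative snippet, assuming the standard ADF config yaml layout):
+
+        diag_var_list:
+          - FISCCP1_COSP
+          - CLD_MISR
+          - CLMODIS
+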
+    The appropriate variables are FISCCP1_COSP for ISCCP, CLD_MISR for MISR, and CLMODIS for MODIS. All three should be added to
+    diag_var_list if you wish to perform analysis on all three. The user can also specify whether to perform analysis for just one
+    or for all three of the data products (ISCCP, MODIS, and MISR) for which COSP output exists. A user can also choose to use only
+    a specific lat and lon range, or to use data only over water or over land. Lastly, if a user has custom-made CRs,
+    these can be passed in and the script will fit data into them rather than the premade CRs that the script already points to.
+    There are a total of 6 sets of premade CRs, two for each data product: one set made with euclidean distance and one set made
+    with Wasserstein distance for each of ISCCP, MODIS, and MISR. Therefore, when the wasserstein_or_euclidean variable is changed,
+    it is important to understand that not only does the distance metric used to fit data into CRs change, but also the CRs
+    themselves, unless the user is passing in a set of premade CRs with the premade_cloud_regimes variable.
+
+    Description of kwargs:
+    wasserstein_or_euclidean -> Whether to use wasserstein or euclidean distance to fit CRs, enter "wasserstein" for wasserstein or
+                                "euclidean" for euclidean. This also changes the default CRs that data is fit into from ones created
+                                with kmeans using euclidean distance to ones created with kmeans using wasserstein distance. Default is euclidean distance.
+    data_product -> Which data product to perform analysis for. Enter "ISCCP", "MODIS", "MISR" or "all". Default is "all"
+    premade_cloud_regimes -> If the user wishes to use custom CRs rather than the pre-loaded ones, enter them here as a path to a numpy
+                             array of shape (k, n_tau_bins * n_pressure_bins)
+    lat_range -> Range of latitudes to use entered as a list, Ex. [-30,30]. Default is to use all available latitudes
+    lon_range -> Range of longitudes to use entered as a list, Ex. [-90,90]. Default is to use all available longitudes
+    only_ocean_or_land -> Set to "O" to perform analysis with only points over water, "L" for only points over land, or False
+                          to use data over land and water. Default is False
+    """
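+
+    # Illustrative usage, not executed here (assumes an ADF object `adf` has already been built):
+    #   cloud_regime_analysis(adf)                                   # all three products, euclidean distance
+    #   cloud_regime_analysis(adf, wasserstein_or_euclidean='wasserstein',
+    #                         data_product='ISCCP', lat_range=[-30,30], only_ocean_or_land='O')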
+
+    global k, ht_var_name, tau_var_name, var, mat, mat_b, mat_o
+    dask.config.set({"array.slicing.split_large_chunks": False})
+
+    # Compute cluster labels from precomputed cluster centers with appropriate distance
+    def precomputed_clusters(mat, cl, wasserstein_or_euclidean, ds):
+
+        if wasserstein_or_euclidean == 'euclidean':
+            cluster_dists = np.sum((mat[:,:,None] - cl.T[None,:,:])**2, axis = 1)
+            cluster_labels_temp = np.argmin(cluster_dists, axis = 1)
+
+        if wasserstein_or_euclidean == 'wasserstein':
+
+            # A function to convert mat into the form required for the EMD calculation
+            @njit()
+            def stacking(position_matrix, centroids):
+                centroid_list = []
+
+                for i in range(len(centroids)):
+                    x = np.empty((3,len(mat[0]))).T
+                    x[:,0] = centroids[i]
+                    x[:,1] = position_matrix[0]
+                    x[:,2] = position_matrix[1]
+                    centroid_list.append(x)
+
+                return centroid_list
+
+            # setting shape
+            n1 = len(ds[tau_var_name])
+            n2 = len(ds[ht_var_name])
+
+            # Calculating the max distance between two points to be used as a hyperparameter in EMD
+            # This is not necessarily the only value for this variable that can be used, see the Wasserstein documentation
+            # on the R hyper-parameter for more information
+            R = (n1**2+n2**2)**0.5
+
+            # Creating a flattened position matrix to pass to wasserstein.PairwiseEMD
+            position_matrix = np.zeros((2,n1,n2))
+            position_matrix[0] = np.tile(np.arange(n2),(n1,1))
+            position_matrix[1] = np.tile(np.arange(n1),(n2,1)).T
+            position_matrix = position_matrix.reshape(2,-1)
+
+            # Initialising wasserstein.PairwiseEMD
+            emds = wasserstein.PairwiseEMD(R = R, norm=True, dtype=np.float32, verbose=1, num_threads=162)
+
+            # Rearranging mat to be in the format necessary for wasserstein.PairwiseEMD
+            events = stacking(position_matrix, mat)
+            centroid_list = stacking(position_matrix, cl)
+            emds(events, centroid_list)
+            print(" -Calculating Wasserstein distances")
+            print(" -Warning: This can be slow, but scales very well with additional processors")
+            distances = emds.emds()
+            cluster_labels_temp = np.argmin(distances, axis=1)
+
+        return cluster_labels_temp
+
+    # This function is no longer used, no need to check
+    # Plot the CR cluster centers
+    def plot_hists(cl, cluster_labels, ht_var_name, tau_var_name, adf):
+        #defining number of clusters
+        k = len(cl)
+
+        # setting up plots
+        ylabels = ds[ht_var_name].values
+        xlabels = ds[tau_var_name].values
+        X2,Y2 = np.meshgrid(np.arange(len(xlabels)+1), np.arange(len(ylabels)+1))
+        p = [0,0.2,1,2,3,4,6,8,10,15,99]
+        cmap = mpl.colors.ListedColormap(['white', (0.19215686274509805, 0.25098039215686274, 0.5607843137254902), (0.23529411764705882, 0.3333333333333333, 0.6313725490196078), (0.32941176470588235, 0.5098039215686274, 0.6980392156862745), (0.39215686274509803, 0.6, 0.43137254901960786), (0.44313725490196076, 0.6588235294117647, 0.21568627450980393), (0.4980392156862745, 0.6784313725490196, 0.1843137254901961), (0.5725490196078431, 0.7137254901960784, 0.16862745098039217), (0.7529411764705882, 0.8117647058823529, 0.2), (0.9568627450980393, 0.8980392156862745,0.1607843137254902)])
+        norm = mpl.colors.BoundaryNorm(p,cmap.N)
+        plt.rcParams.update({'font.size': 12})
+        fig_height = 1 + 10/3 * ceil(k/3)
+        fig, ax = plt.subplots(figsize = (17, fig_height), ncols=3, nrows=ceil(k/3), sharex='all', sharey = True)
+
+        aa = ax.ravel()
+        boundaries = p
+        norm = mpl.colors.BoundaryNorm(boundaries, cmap.N, clip=True)
+        aa[1].invert_yaxis()
+
+        # creating weights array for area weighted RFOs
+        weights = cluster_labels.stack(z=('time','lat','lon')).lat.values
+        weights = np.cos(np.deg2rad(weights))
+        weights = weights[valid_indicies]
+        indicies = np.arange(len(mat))
+
+        # Plotting each cluster center
+        for i in range (k):
+
+            # Area weighted relative frequency of occurrence calculation
+            total_rfo_num = cluster_labels == i
+            total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat)))
+            total_rfo_denom = cluster_labels >= 0
+            total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat)))
+            total_rfo = total_rfo_num / total_rfo_denom * 100
+            total_rfo = total_rfo.values
+
+            # Area weighting each histogram belonging to a cluster and taking the mean
+            # if clustering was performed with wasserstein distance and area weighting on, mean of i = cl[i], however if clustering was performed with
+            # conventional kmeans or wasserstein without weighting, these two will not be equal
+            indicies_i = indicies[np.where(cluster_labels_temp == i)]
+            mean = mat[indicies_i] * weights[indicies_i][:,np.newaxis]
+            if len(indicies_i) > 0: mean = np.sum(mean, axis=0) / np.sum(weights[indicies_i])
+            else: mean = np.zeros(len(xlabels)*len(ylabels))
+
+            mean = mean.reshape(len(xlabels),len(ylabels)).T # reshaping into original histogram shape
+            if np.max(mean) <= 1: # Converting fractional data to percent to plot properly
+                mean *= 100
+
+            im = aa[i].pcolormesh(X2,Y2,mean,norm=norm,cmap=cmap)
+            aa[i].set_title(f"CR {i+1}, RFO = {np.round(total_rfo,1)}%")
+
+        # setting titles, labels, etc
+        if data == "MISR": height_or_pressure = 'h'
+        else: height_or_pressure = 'p'
+        if height_or_pressure == 'p': fig.supylabel(f'Cloud-top Pressure ({ds[ht_var_name].units})', fontsize = 12, x = 0.09 )
+        if height_or_pressure == 'h': fig.supylabel(f'Cloud-top Height ({ds[ht_var_name].units})', fontsize = 12, x = 0.09 )
+        # fig.supxlabel('Optical Depth', fontsize = 12, y=0.26 )
+        cbar_ax = fig.add_axes([0.95, 0.38, 0.045, 0.45])
+        cb = fig.colorbar(im, cax=cbar_ax, ticks=p)
+        cb.set_label(label='Cloud Cover (%)', size =10)
+        cb.ax.tick_params(labelsize=9)
+        #aa[6].set_position([0.399, 0.125, 0.228, 0.215])
+        #aa[6].set_position([0.33, 0.117, 0.36, 0.16])
+        #aa[-2].remove()
+
+        bbox = aa[1].get_position()
+        p1 = bbox.p1
+        p0 = bbox.p0
+        fig.suptitle(f'{data} Cloud Regimes', x=0.5, y=p1[1]+(1/fig_height * 0.5), fontsize=15)
+
+        bbox = aa[-2].get_position()
+        p1 = bbox.p1
+        p0 = bbox.p0
+        fig.supxlabel('Optical Depth', fontsize = 12, y=p0[1]-(1/fig_height * 0.5) )
+
+
+        # Removing extra plots
+        for i in range(ceil(k/3)*3-k):
+            aa[-(i+1)].remove()
+        save_path = adf.plot_location[0] + f'/{data}_CR_centers'
+        plt.savefig(save_path)
+
+        if adf.create_html:
+            adf.add_website_data(save_path + ".png", var, adf.get_baseline_info("cam_case_name"))
+
+
+    # Plot the CR centers of obs, baseline and test case
+    def plot_hists_baseline(cl, cluster_labels, cluster_labels_b, ht_var_name, tau_var_name, adf):
+        #defining number of clusters
+        k = len(cl)
+
+        # setting up plots
+        ylabels = ds[ht_var_name].values
+        xlabels = ds[tau_var_name].values
+        X2,Y2 = np.meshgrid(np.arange(len(xlabels)+1), np.arange(len(ylabels)+1))
+        p = [0,0.2,1,2,3,4,6,8,10,15,99]
+        cmap = mpl.colors.ListedColormap(['white', (0.19215686274509805, 0.25098039215686274, 0.5607843137254902), (0.23529411764705882, 0.3333333333333333, 0.6313725490196078), (0.32941176470588235, 0.5098039215686274, 0.6980392156862745), (0.39215686274509803, 0.6, 0.43137254901960786), (0.44313725490196076, 0.6588235294117647, 0.21568627450980393), (0.4980392156862745, 0.6784313725490196, 0.1843137254901961), (0.5725490196078431, 0.7137254901960784, 0.16862745098039217), (0.7529411764705882, 0.8117647058823529, 0.2), (0.9568627450980393, 0.8980392156862745,0.1607843137254902)])
+        norm = mpl.colors.BoundaryNorm(p,cmap.N)
+        plt.rcParams.update({'font.size': 14})
+        fig_height = (1 + 10/3 * ceil(k/3))*3
+        fig, ax = plt.subplots(figsize = (17, fig_height), ncols=3, nrows=k, sharex='all', sharey = True)
+
+        aa = ax.ravel()
+        boundaries = p
+        norm = mpl.colors.BoundaryNorm(boundaries, cmap.N, clip=True)
+        if data != 'MISR': aa[1].invert_yaxis()
+
+        # creating weights array for area weighted RFOs
+        weights = cluster_labels.stack(z=('time','lat','lon')).lat.values
+        weights = np.cos(np.deg2rad(weights))
+        weights = weights[valid_indicies]
+        indicies = np.arange(len(mat))
+
+        for i in range(k):
+
+            im = ax[i,0].pcolormesh(X2,Y2,cl[i].reshape(len(xlabels),len(ylabels)).T,norm=norm,cmap=cmap)
+            ax[i,0].set_title(f" Observation CR {i+1}")
+
+        # Plotting each cluster center (baseline)
+        for i in range (k):
+            # Area weighted relative frequency of occurrence calculation
+            total_rfo_num = cluster_labels_b == i
+            total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels_b.lat)))
+            total_rfo_denom = cluster_labels_b >= 0
+            total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels_b.lat)))
+            total_rfo = total_rfo_num / total_rfo_denom * 100
+            total_rfo = total_rfo.values
+
+            # Area weighting each histogram belonging to a cluster and taking the mean
+            # if clustering was performed with wasserstein distance and area weighting on, mean of i = cl[i], however if clustering was performed with
+            # conventional kmeans or wasserstein without weighting, these two will not be equal
+            indicies_i = indicies[np.where(cluster_labels_temp_b == i)]
+            mean = mat_b[indicies_i] * weights[indicies_i][:,np.newaxis]
+            if len(indicies_i) > 0: mean = np.sum(mean, axis=0) / np.sum(weights[indicies_i])
+            else: mean = np.zeros(len(xlabels)*len(ylabels))
+
+            mean = mean.reshape(len(xlabels),len(ylabels)).T # reshaping into original histogram shape
+            if np.max(mean) <= 1: # Converting fractional data to percent to plot properly
+                mean *= 100
+
+            im = ax[i,1].pcolormesh(X2,Y2,mean,norm=norm,cmap=cmap)
+            ax[i,1].set_title(f"Baseline Case CR {i+1}, RFO = {np.round(total_rfo,1)}%")
+
+        # Plotting each cluster center (test_case)
+        for i in range (k):
+
+            # Area weighted relative frequency of occurrence calculation
+            total_rfo_num = cluster_labels == i
+            total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat)))
+            total_rfo_denom = cluster_labels >= 0
+            total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat)))
+            total_rfo = total_rfo_num / total_rfo_denom * 100
+            total_rfo = total_rfo.values
+
+            # Area weighting each histogram belonging to a cluster and taking the mean
+            # if clustering was performed with wasserstein distance and area weighting on, mean of i = cl[i], however if clustering was performed with
+            # conventional kmeans or wasserstein without weighting, these two will not be equal
+            indicies_i = indicies[np.where(cluster_labels_temp == i)]
+            mean = mat[indicies_i] * weights[indicies_i][:,np.newaxis]
+            if len(indicies_i) > 0: mean = np.sum(mean, axis=0) / np.sum(weights[indicies_i])
+            else: mean = np.zeros(len(xlabels)*len(ylabels))
+
+            mean = mean.reshape(len(xlabels),len(ylabels)).T # reshaping into original histogram shape
+            if np.max(mean) <= 1: # Converting fractional data to percent to plot properly
+                mean *= 100
+
+            im = ax[i,2].pcolormesh(X2,Y2,mean,norm=norm,cmap=cmap)
+            ax[i,2].set_title(f"Test Case CR {i+1}, RFO = {np.round(total_rfo,1)}%")
+
+        # setting titles, labels, etc
+        if data == "MODIS":
+            ylabels = [0, 180, 310, 440, 560, 680, 800, 1000]
+            xlabels = [0, 0.3, 1.3, 3.6, 9.4, 23, 60, 150]
+            ax[0,0].set_yticks(np.arange(8))
+            ax[0,0].set_xticks(np.arange(8))
+            ax[0,0].set_yticklabels(ylabels)
+            ax[0,0].set_xticklabels(xlabels)
+            xticks = ax[0,0].xaxis.get_major_ticks()
+            xticks[0].set_visible(False)
+            xticks[-1].set_visible(False)
+
+        if data == "MISR":
+            xlabels = [0.2, 0.8, 2.4, 6.5, 16.2, 41.5, 100]
+            ylabels = [ 0.25,0.75,1.25,1.75,2.25,2.75,3.5, 4.5, 6,8, 10, 12, 14, 16, 20 ]
+            ax[0,0].set_yticks(np.arange(0,16,2)+0.5)
+            ax[0,0].set_yticklabels(ylabels[0::2])
+            ax[0,0].set_xticks(np.array([1,2,3,4,5,6,7]) -0.5)
+            ax[0,0].set_xticklabels(xlabels, fontsize = 16)
+            xticks = ax[0,0].xaxis.get_major_ticks()
+            xticks[0].set_visible(False)
+            xticks[-1].set_visible(False)
+
+        if data == 'ISCCP':
+            xlabels = [ 0, 1.3, 3.6, 9.4, 22.6, 60.4, 450 ]
+            ylabels = [ 10, 180, 310, 440, 560, 680, 800, 1025]
+            yticks = aa[i].get_yticks().tolist()
+            xticks = aa[i].get_xticks().tolist()
+            aa[i].set_yticks(yticks)
+            aa[i].set_xticks(xticks)
+            aa[i].set_yticklabels(ylabels)
+            aa[i].set_xticklabels(xlabels)
+            xticks = aa[i].xaxis.get_major_ticks()
+            xticks[0].label1.set_visible(False)
+            xticks[-1].label1.set_visible(False)
+
+
+        if data == "MISR": height_or_pressure = 'h'
+        else: height_or_pressure = 'p'
+        if height_or_pressure == 'p': fig.supylabel(f'Cloud-top Pressure ({ds[ht_var_name].units})', x = 0.07 )
+        if height_or_pressure == 'h': fig.supylabel(f'Cloud-top Height ({ds[ht_var_name].units})', x = 0.07 )
+
+        if data == "MODIS":
+            ylabels = [0, 180, 310, 440, 560, 680, 800, 1000]
+            xlabels = [0, 0.3, 1.3, 3.6, 9.4, 23, 60, 150]
+        if data == "MISR":
+            x=1
+
+
+
+        # fig.supxlabel('Optical Depth', fontsize = 12, y=0.26 )
+        cbar_ax = fig.add_axes([0.95, 0.38, 0.045, 0.45])
+        cb = fig.colorbar(im, cax=cbar_ax, ticks=p)
+        cb.set_label(label='Cloud Cover (%)', size =10)
+        cb.ax.tick_params(labelsize=9)
+        #aa[6].set_position([0.399, 0.125, 0.228, 0.215])
+        #aa[6].set_position([0.33, 0.117, 0.36, 0.16])
+        #aa[-2].remove()
+
+        bbox = aa[1].get_position()
+        p1 = bbox.p1
+        p0 = bbox.p0
+        fig.suptitle(f'{data} Cloud Regimes', x=0.5, y=p1[1]+(1/fig_height * 0.5)+0.007, fontsize=18)
+
+        bbox = aa[-2].get_position()
+        p1 = bbox.p1
+        p0 = bbox.p0
+        fig.supxlabel('Optical Depth', y=p0[1]-(1/fig_height * 0.5)-0.007 )
+
+
+        save_path = adf.plot_location[0] + f'/{data}_CR_centers'
+        plt.savefig(save_path)
+
+        if adf.create_html:
+            adf.add_website_data(save_path + ".png", var, adf.get_baseline_info("cam_case_name"))
+
+        # Closing the figure
+        plt.close()
+
+
+    # Plot the CR centers of obs and test case
+    def plot_hists_obs(cl, cluster_labels, cluster_labels_o, ht_var_name, tau_var_name, adf):
+        #defining number of clusters
+        k = len(cl)
+        # setting up plots
+        ylabels = ds[ht_var_name].values
+        xlabels = ds[tau_var_name].values
+        X2,Y2 = np.meshgrid(np.arange(len(xlabels)+1), np.arange(len(ylabels)+1))
+        p = [0,0.2,1,2,3,4,6,8,10,15,99]
+        cmap = mpl.colors.ListedColormap(['white', (0.19215686274509805, 0.25098039215686274, 0.5607843137254902), (0.23529411764705882, 0.3333333333333333, 0.6313725490196078), (0.32941176470588235, 0.5098039215686274, 0.6980392156862745), (0.39215686274509803, 0.6, 0.43137254901960786), (0.44313725490196076, 0.6588235294117647, 0.21568627450980393), (0.4980392156862745, 0.6784313725490196, 0.1843137254901961), (0.5725490196078431, 0.7137254901960784, 0.16862745098039217), (0.7529411764705882, 0.8117647058823529, 0.2), (0.9568627450980393, 0.8980392156862745,0.1607843137254902)])
+        norm = mpl.colors.BoundaryNorm(p,cmap.N)
+        plt.rcParams.update({'font.size': 14})
+        fig_height = (1 + 10/3 * ceil(k/3))*3
+        fig, ax = plt.subplots(figsize = (12, fig_height), ncols=2, nrows=k, sharex='all', sharey = True)
+
+        aa = ax.ravel()
+        boundaries = p
+        norm = mpl.colors.BoundaryNorm(boundaries, cmap.N, clip=True)
+        if data != 'MISR': aa[1].invert_yaxis()
+
+        # creating weights array for area weighted RFOs
+        weights = cluster_labels.stack(z=('time','lat','lon')).lat.values
+        weights = np.cos(np.deg2rad(weights))
+        weights = weights[valid_indicies]
+        indicies = np.arange(len(mat))
+
+        # ax[0,0].set_xticklabels(xlabels)
+        # ax[0,0].set_yticklabels(ylabels)
+
+        for i in range(k):
+            # Area weighted relative frequency of occurrence calculation
+            total_rfo_num = cluster_labels_o == i
+            total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels_o.lat)))
+            total_rfo_denom = cluster_labels_o >= 0
+            total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels_o.lat)))
+            total_rfo = total_rfo_num / total_rfo_denom * 100
+            total_rfo = total_rfo.values
+
+            im = ax[i,0].pcolormesh(X2,Y2,cl[i].reshape(len(xlabels),len(ylabels)).T,norm=norm,cmap=cmap)
+            ax[i,0].set_title(f" Observation CR {i+1}, RFO = {np.round(total_rfo,1)}%")
+
+        # Plotting each cluster center (test_case)
+        for i in range (k):
+
+            # Area weighted relative frequency of occurrence calculation
+            total_rfo_num = cluster_labels == i
+            total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat)))
+            total_rfo_denom = cluster_labels >= 0
+            total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat)))
+            total_rfo = total_rfo_num / total_rfo_denom * 100
+            total_rfo = total_rfo.values
+
+            # Area weighting each histogram belonging to a cluster and taking the mean
+            # if clustering was performed with wasserstein distance and area weighting on, mean of i = cl[i], however if clustering was performed with
+            # conventional kmeans or wasserstein without weighting, these two will not be equal
+            indicies_i = indicies[np.where(cluster_labels_temp == i)]
+            mean = mat[indicies_i] * weights[indicies_i][:,np.newaxis]
+            if len(indicies_i) > 0: mean = np.sum(mean, axis=0) / np.sum(weights[indicies_i])
+            else: mean = np.zeros(len(xlabels)*len(ylabels))
+
+            mean = mean.reshape(len(xlabels),len(ylabels)).T # reshaping into original histogram shape
+            if np.max(mean) <= 1: # Converting fractional data to percent to plot properly
+                mean *= 100
+
+            im = ax[i,1].pcolormesh(X2,Y2,mean,norm=norm,cmap=cmap)
+            ax[i,1].set_title(f"Test Case CR {i+1}, RFO = {np.round(total_rfo,1)}%")
+
+        if data == "MODIS":
+            ylabels = [0, 180, 310, 440, 560, 680, 800, 1000]
+            xlabels = [0, 0.3, 1.3, 3.6, 9.4, 23, 60, 150]
+            ax[0,0].set_yticks(np.arange(8))
+            ax[0,0].set_xticks(np.arange(8))
+            ax[0,0].set_yticklabels(ylabels)
+            ax[0,0].set_xticklabels(xlabels)
+            xticks = ax[0,0].xaxis.get_major_ticks()
+            xticks[0].set_visible(False)
+            xticks[-1].set_visible(False)
+
+        if data == "MISR":
+            xlabels = [0.2, 0.8, 2.4, 6.5, 16.2, 41.5, 100]
+            ylabels = [ 0.25,0.75,1.25,1.75,2.25,2.75,3.5, 4.5, 6,8, 10, 12, 14, 16, 20 ]
+            ax[0,0].set_yticks(np.arange(0,16,2)+0.5)
+            ax[0,0].set_yticklabels(ylabels[0::2])
+            ax[0,0].set_xticks(np.array([1,2,3,4,5,6,7]) -0.5)
+            ax[0,0].set_xticklabels(xlabels, fontsize = 16)
+            xticks = ax[0,0].xaxis.get_major_ticks()
+            xticks[0].set_visible(False)
+            xticks[-1].set_visible(False)
+
+        if data == 'ISCCP':
+            xlabels = [ 0, 1.3, 3.6, 9.4, 22.6, 60.4, 450 ]
+            ylabels = [ 10, 180, 310, 440, 560, 680, 800, 1025]
+            yticks = aa[i].get_yticks().tolist()
+            xticks = aa[i].get_xticks().tolist()
+            aa[i].set_yticks(yticks)
+            aa[i].set_xticks(xticks)
+            aa[i].set_yticklabels(ylabels)
+            aa[i].set_xticklabels(xlabels)
+            xticks = aa[i].xaxis.get_major_ticks()
+            xticks[0].label1.set_visible(False)
+            xticks[-1].label1.set_visible(False)
+
+        # setting titles, labels, etc
+        if data == "MISR": height_or_pressure = 'h'
+        else: height_or_pressure = 'p'
+        if height_or_pressure == 'p': fig.supylabel(f'Cloud-top Pressure ({ds[ht_var_name].units})', x = 0.05 )
+        if height_or_pressure == 'h': fig.supylabel(f'Cloud-top Height ({ds[ht_var_name].units})', x = 0.05)
+        # fig.supxlabel('Optical Depth', fontsize = 12, y=0.26 )
+        cbar_ax = fig.add_axes([0.95, 0.38, 0.045, 0.45])
+        cb = fig.colorbar(im, cax=cbar_ax, ticks=p)
+        cb.set_label(label='Cloud Cover (%)')
+        # cb.ax.tick_params(labelsize=9)
+
+
+        bbox = aa[1].get_position()
+        p1 = bbox.p1
+        p0 = bbox.p0
+        fig.suptitle(f'{data} Cloud Regimes', x=0.5, y=p1[1]+(1/fig_height * 0.5)+0.007, fontsize=18)
+
+        bbox = aa[-2].get_position()
+        p1 = bbox.p1
+        p0 = bbox.p0
+        fig.supxlabel('Optical Depth', y=p0[1]-(1/fig_height * 0.5)-0.007 )
+
+        save_path = adf.plot_location[0] + f'/{data}_CR_centers'
+        plt.savefig(save_path)
+
+        if adf.create_html:
+            adf.add_website_data(save_path + ".png", var, case_name = None, multi_case=True)
+
+        # Closing the figure
+        plt.close()
+
+    # Plot LatLon plots of the frequency of occurrence of the baseline/obs and test case
+    def plot_rfo_obs_base_diff(cluster_labels, cluster_labels_d, adf):
+
+        COLOR = 'black'
+        mpl.rcParams['text.color'] = COLOR
+        mpl.rcParams['axes.labelcolor'] = COLOR
+        mpl.rcParams['xtick.color'] = COLOR
+        mpl.rcParams['ytick.color'] = COLOR
+        plt.rcParams.update({'font.size': 13})
+        plt.rcParams['figure.dpi'] = 500
+        fig_height = 7
+
+        # Comparing obs or baseline?
+        if adf.compare_obs == True:
+            obs_or_base = 'Observation'
+        else:
+            obs_or_base = 'Baseline'
+
+        for cluster in range(k):
+            fig, ax = plt.subplots(ncols=2, nrows=2, subplot_kw={'projection': ccrs.PlateCarree()}, figsize = (12,fig_height))#, sharex='col', sharey='row')
+            plt.subplots_adjust(wspace=0.08, hspace=0.002)
+            aa = ax.ravel()
+
+            # Calculating and plotting rfo of baseline/obs
+            X, Y = np.meshgrid(cluster_labels_d.lon,cluster_labels_d.lat)
+            rfo_d = np.sum(cluster_labels_d==cluster, axis=0) / np.sum(cluster_labels_d >= 0, axis=0) * 100
+            aa[0].set_extent([-180, 180, -90, 90])
+            aa[0].coastlines()
+            mesh = aa[0].pcolormesh(X, Y, rfo_d, transform=ccrs.PlateCarree(), rasterized = True, cmap="GnBu",vmin=0,vmax=100)
+            total_rfo_num = cluster_labels_d == cluster
+            total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels_d.lat)))
+            total_rfo_denom = cluster_labels_d >= 0
+            total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels_d.lat)))
+            total_rfo_d = total_rfo_num / total_rfo_denom * 100
+            aa[0].set_title(f"{obs_or_base}, RFO = {round(float(total_rfo_d),1)}", pad=4)
+
+            # Calculating and plotting rfo of test_case
+            X, Y = np.meshgrid(cluster_labels.lon,cluster_labels.lat)
+            rfo = np.sum(cluster_labels==cluster, axis=0) / np.sum(cluster_labels >= 0, axis=0) * 100
+            aa[1].set_extent([-180, 180, -90, 90])
+            aa[1].coastlines()
+            mesh = aa[1].pcolormesh(X, Y, rfo, transform=ccrs.PlateCarree(), rasterized = True, cmap="GnBu",vmin=0,vmax=100)
+            total_rfo_num = cluster_labels == cluster
+            total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat)))
+            total_rfo_denom = cluster_labels >= 0
+            total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat)))
+            total_rfo = total_rfo_num / total_rfo_denom * 100
+            aa[1].set_title(f"Test Case, RFO = {round(float(total_rfo),1)}", pad=4)
+
+            # Making colorbar
+            cax = fig.add_axes([aa[1].get_position().x1+0.01,aa[1].get_position().y0,0.02,aa[1].get_position().height])
+            cb = plt.colorbar(mesh, cax=cax)
+            cb.set_label(label = 'RFO (%)')
+
+            # Calculating and plotting difference
+            # If observation/baseline is a higher resolution interpolate from obs/baseline to CAM grid
+            if len(cluster_labels_d.lat) * len(cluster_labels_d.lon) > len(cluster_labels.lat) * len(cluster_labels.lon):
+                rfo_d = rfo_d.interp_like(rfo, method="nearest")
+
+            # If CAM is a higher resolution interpolate from CAM to obs/baseline grid
+            if len(cluster_labels_d.lat) * len(cluster_labels_d.lon) <= len(cluster_labels.lat) * len(cluster_labels.lon):
+                rfo = rfo.interp_like(rfo_d, method="nearest")
+                X, Y = np.meshgrid(cluster_labels_d.lon,cluster_labels_d.lat)
+
+            rfo_diff = rfo - rfo_d
+
+            aa[2].set_extent([-180, 180, -90, 90])
+            aa[2].coastlines()
+            mesh = aa[2].pcolormesh(X, Y, rfo_diff, transform=ccrs.PlateCarree(), rasterized = True, cmap="coolwarm",vmin=-100,vmax=100)
+            total_rfo_num = cluster_labels == cluster
+            total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat)))
+            total_rfo_denom = cluster_labels >= 0
+            total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat)))
+            total_rfo = total_rfo_num / total_rfo_denom * 100
+            aa[2].set_title(f"Test - {obs_or_base}, ΔRFO = {round(float(total_rfo-total_rfo_d),1)}", pad=4)
+
+
+            # Setting yticks
+            aa[0].set_yticks([-60,-30,0,30,60], crs=ccrs.PlateCarree())
+            aa[2].set_yticks([-60,-30,0,30,60], crs=ccrs.PlateCarree())
+            lat_formatter = LatitudeFormatter()
+            aa[0].yaxis.set_major_formatter(lat_formatter)
+            aa[2].yaxis.set_major_formatter(lat_formatter)
+
+
+            # making colorbar for diff plot
+            cax = fig.add_axes([aa[2].get_position().x1+0.01,aa[2].get_position().y0,0.02,aa[2].get_position().height])
+            cb = plt.colorbar(mesh, cax=cax)
+            cb.set_label(label = 'ΔRFO (%)')
+
+            # plotting x labels
+            aa[1].set_xticks([-120,-60,0,60,120,], crs=ccrs.PlateCarree())
+            lon_formatter = LongitudeFormatter(zero_direction_label=True)
+            aa[1].xaxis.set_major_formatter(lon_formatter)
+            aa[2].set_xticks([-120,-60,0,60,120,], crs=ccrs.PlateCarree())
+            lon_formatter = LongitudeFormatter(zero_direction_label=True)
+            aa[2].xaxis.set_major_formatter(lon_formatter)
+
+            bbox = aa[1].get_position()
+            p1 = bbox.p1
+            plt.suptitle(f"CR{cluster+1} Relative Frequency of Occurrence", y= p1[1]+(1/fig_height * 0.5))#, {round(cl[cluster,23],4)}")
+
+            aa[-1].remove()
+
+            save_path = adf.plot_location[0] + f'/{data}_CR{cluster+1}_LatLon_mean'
+            plt.savefig(save_path)
+
+            if adf.create_html:
+                adf.add_website_data(save_path + ".png", var, case_name = None, multi_case=True)
+
+            # Closing the figure
+            plt.close()
+
+    # This function is no longer used, no reason to check it
+    # Plot RFO maps of the CRs
+    def 
plot_rfo(cluster_labels, adf): + #defining number of clusters + + COLOR = 'black' + mpl.rcParams['text.color'] = COLOR + mpl.rcParams['axes.labelcolor'] = COLOR + mpl.rcParams['xtick.color'] = COLOR + mpl.rcParams['ytick.color'] = COLOR + plt.rcParams.update({'font.size': 10}) + fig_height = 2.2 * ceil(k/2) + plt.rcParams['figure.dpi'] = 500 + fig, ax = plt.subplots(ncols=2, nrows=int(k/2 + k%2), subplot_kw={'projection': ccrs.PlateCarree()}, figsize = (10,fig_height))#, sharex='col', sharey='row') + plt.subplots_adjust(wspace=0.13, hspace=0.05) + aa = ax.ravel() + + X, Y = np. meshgrid(ds.lon,ds.lat) + + # Plotting the rfo of each cluster + tot_rfo_sum = 0 + + for cluster in range(k): #range(0,k+1): + # Calculating rfo + rfo = np.sum(cluster_labels==cluster, axis=0) / np.sum(cluster_labels >= 0, axis=0) * 100 + # tca_explained = np.sum(cluster_labels == cluster) * np.sum(init_clusters[cluster]) / total_cloud_amnt * 100 + # tca_explained = round(float(tca_explained.values),1) + aa[cluster].set_extent([-180, 180, -90, 90]) + aa[cluster].coastlines() + mesh = aa[cluster].pcolormesh(X, Y, rfo, transform=ccrs.PlateCarree(), rasterized = True, cmap="GnBu",vmin=0,vmax=100) + #total_rfo = np.sum(cluster_labels==cluster) / np.sum(cluster_labels >= 0) * 100 + # total_rfo_num = np.sum(cluster_labels == cluster * np.cos(np.deg2rad(cluster_labels.lat))) + total_rfo_num = cluster_labels == cluster + total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat))) + total_rfo_denom = cluster_labels >= 0 + total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat))) + + total_rfo = total_rfo_num / total_rfo_denom * 100 + tot_rfo_sum += total_rfo + aa[cluster].set_title(f"CR {cluster+1}, RFO = {round(float(total_rfo),1)}", pad=4) + # aa[cluster].gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False) + # x_label_plot_list = [4,5,6] + # y_label_plot_list = [0,2,4,6] + # if cluster in x_label_plot_list: + + + if cluster % 2 == 0: + aa[cluster].set_yticks([-60,-30,0,30,60], crs=ccrs.PlateCarree()) + lat_formatter = LatitudeFormatter() + aa[cluster].yaxis.set_major_formatter(lat_formatter) + + #aa[7].set_title(f"Weathersdfasdfa State {i+1}, RFO = {round(float(total_rfo),1)}", pad=-40) + cb = plt.colorbar(mesh, ax = ax, anchor =(-0.28,0.83), shrink = 0.6) + cb.set_label(label = 'RFO (%)', labelpad=-3) + + x_ticks_indicies = np.array([-1,-2]) + + if k%2 == 1: + aa[-1].remove() + x_ticks_indicies -= 1 + + #aa[-2].set_position([0.27, 0.11, 0.31, 0.15]) + + # plotting x labels on final two plots + aa[x_ticks_indicies[0]].set_xticks([-120,-60,0,60,120,], crs=ccrs.PlateCarree()) + lon_formatter = LongitudeFormatter(zero_direction_label=True) + aa[x_ticks_indicies[0]].xaxis.set_major_formatter(lon_formatter) + aa[x_ticks_indicies[1]].set_xticks([-120,-60,0,60,120,], crs=ccrs.PlateCarree()) + lon_formatter = LongitudeFormatter(zero_direction_label=True) + aa[x_ticks_indicies[1]].xaxis.set_major_formatter(lon_formatter) + + bbox = aa[1].get_position() + p1 = bbox.p1 + plt.suptitle(f"CR Relative Frequency of Occurence", x= 0.43, y= p1[1]+(1/fig_height * 0.5))#, {round(cl[cluster,23],4)}") + + # Saving + save_path = adf.plot_location[0] + f'/{data}_RFO' + plt.savefig(save_path) + + if adf.create_html: + adf.add_website_data(save_path + ".png", var, adf.get_baseline_info("cam_case_name")) + + # This function is no longer used, no reason to check it + # Plot RFO maps of the CRs + def plot_rfo_diff(cluster_labels, cluster_labels_o, adf): + + # Setting plot parameters + COLOR 
= 'black' + mpl.rcParams['text.color'] = COLOR + mpl.rcParams['axes.labelcolor'] = COLOR + mpl.rcParams['xtick.color'] = COLOR + mpl.rcParams['ytick.color'] = COLOR + plt.rcParams.update({'font.size': 10}) + fig_height = 2.2 * ceil(k/2) + fig, ax = plt.subplots(ncols=2, nrows=int(k/2 + k%2), subplot_kw={'projection': ccrs.PlateCarree()}, figsize = (10,fig_height))#, sharex='col', sharey='row') + plt.subplots_adjust(wspace=0.13, hspace=0.05) + aa = ax.ravel() + plt.rcParams['figure.dpi'] = 500 + + # CReating lat-lon mesh + X, Y = np. meshgrid(ds.lon,ds.lat) + + # Plotting the difference in relative frequency of occurence (rfo) of each cluster + for cluster in range(k): + + # Calculating rfo + rfo = np.sum(cluster_labels==cluster, axis=0) / np.sum(cluster_labels >= 0, axis=0) * 100 + rfo_o = np.sum(cluster_labels_o==cluster, axis=0) / np.sum(cluster_labels_o >= 0, axis=0) * 100 + + # If observation/baseline is a higher resolution interpolate from obs/baseline to CAM grid + if len(cluster_labels_o.lat) * len(cluster_labels_o.lon) > len(cluster_labels.lat) * len(cluster_labels.lon): + rfo_o = rfo_o.interp_like(rfo, method="nearest") + + # If CAM is a higher resolution interpolate from CAM to obs/baseline grid + if len(cluster_labels_o.lat) * len(cluster_labels_o.lon) <= len(cluster_labels.lat) * len(cluster_labels.lon): + rfo = rfo.interp_like(rfo_o, method="nearest") + + # difference in RFO + rfo_diff = rfo - rfo_o + + # Setting up subplots and plotting + aa[cluster].set_extent([-180, 180, -90, 90]) + aa[cluster].coastlines() + mesh = aa[cluster].pcolormesh(X, Y, rfo_diff, transform=ccrs.PlateCarree(), rasterized = True, cmap="coolwarm",vmin=-100,vmax=100) + + + # Calucating area weighted rfo difference for the title of subplots + total_rfo_num = cluster_labels == cluster + total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat))) + total_rfo_denom = cluster_labels >= 0 + total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat))) + total_rfo = total_rfo_num / total_rfo_denom * 100 + + total_rfo_num_o = cluster_labels_o == cluster + total_rfo_num_o = np.sum(total_rfo_num_o * np.cos(np.deg2rad(cluster_labels_o.lat))) + total_rfo_denom_o = cluster_labels_o >= 0 + total_rfo_denom_o = np.sum(total_rfo_denom_o * np.cos(np.deg2rad(cluster_labels_o.lat))) + + total_rfo_o = total_rfo_num_o / total_rfo_denom_o * 100 + + # Setting title + aa[cluster].set_title(f"CR {cluster+1}, RFO Diff = {round(float(total_rfo-total_rfo_o),1)}", pad=4) + + # Put latitude labels on even numbered subplots + if cluster % 2 == 0: + aa[cluster].set_yticks([-60,-30,0,30,60], crs=ccrs.PlateCarree()) + lat_formatter = LatitudeFormatter() + aa[cluster].yaxis.set_major_formatter(lat_formatter) + + # Setting colorbar + cb = plt.colorbar(mesh, ax = ax, anchor =(-0.28,0.83), shrink = 0.6) + cb.set_label(label = 'Diff in RFO (%)', labelpad=-3) + + # Removing extra subplot if k is an odd number + x_ticks_indicies = np.array([-1,-2]) + if k%2 == 1: + aa[-1].remove() + x_ticks_indicies -= 1 + + # plotting x labels on final two plots + aa[x_ticks_indicies[0]].set_xticks([-120,-60,0,60,120,], crs=ccrs.PlateCarree()) + lon_formatter = LongitudeFormatter(zero_direction_label=True) + aa[x_ticks_indicies[0]].xaxis.set_major_formatter(lon_formatter) + aa[x_ticks_indicies[1]].set_xticks([-120,-60,0,60,120,], crs=ccrs.PlateCarree()) + lon_formatter = LongitudeFormatter(zero_direction_label=True) + aa[x_ticks_indicies[1]].xaxis.set_major_formatter(lon_formatter) + + # Setting suptitle + bbox = 
aa[1].get_position()
+        p1 = bbox.p1
+        plt.suptitle(f"CR Relative Frequency of Occurrence", x= 0.43, y= p1[1]+(1/fig_height * 0.5))#, {round(cl[cluster,23],4)}")
+
+        # Saving
+        save_path = adf.plot_location[0] + f'/{data}_RFO'
+        plt.savefig(save_path)
+
+        if adf.create_html:
+            adf.add_website_data(save_path + ".png", var, multi_case=True)
+
+    # Create a one hot matrix where lat lon coordinates are over land using cartopy
+    def create_land_mask(ds):
+
+        # Get land data and prep polygons
+        land_110m = cartopy.feature.NaturalEarthFeature('physical', 'land', '110m')
+        land_polygons = list(land_110m.geometries())
+        land_polygons = [prep(land_polygon) for land_polygon in land_polygons]
+
+        # Make lat-lon grid
+        lats = ds.lat.values
+        lons = ds.lon.values
+        lon_grid, lat_grid = np.meshgrid(lons, lats)
+
+        points = [Point(point) for point in zip(lon_grid.ravel(), lat_grid.ravel())]
+
+        # Creating list of coordinates that are over land
+        land = []
+        for land_polygon in land_polygons:
+            land.extend([tuple(point.coords)[0] for point in filter(land_polygon.covers, points)])
+
+
+        landar = np.asarray(land)
+        lat_lon = np.empty((len(lats)*len(lons),2))
+        oh_land = np.zeros((len(lats)*len(lons)))
+        lat_lon[:,0] = lon_grid.flatten()
+        lat_lon[:,1] = lat_grid.flatten()
+
+        # Function to (somewhat) quickly test if a lat-lon point is over land
+        @njit()
+        def test (oh_land, lat_lon, landar):
+            for i in range(len(oh_land)):
+                check = lat_lon[i] == landar
+                if np.max(np.sum(check,axis=1)) == 2:
+                    oh_land[i] = 1
+            return oh_land
+
+        # Turn oh_land into a one hot matrix
+        oh_land = test (oh_land, lat_lon, landar)
+
+        # Reshape into original shape(n_lat, n_lon)
+        oh_land=oh_land.reshape((len(lats),len(lons)))
+
+        return oh_land
+
+    # Checking if kwargs have been entered correctly
+    if wasserstein_or_euclidean not in ['euclidean', 'wasserstein']:
+        print(' WARNING: Invalid option for wasserstein_or_euclidean. Please enter "wasserstein" or "euclidean". Proceeding with default of euclidean distance')
+        wasserstein_or_euclidean = 'euclidean'
+    if data_product not in ['ISCCP', "MODIS", 'MISR', 'all']:
+        print(' WARNING: Invalid option for data_product. Please enter "ISCCP", "MODIS", "MISR" or "all". Proceeding with default of "all"')
+        data_product = 'all'
+    if premade_cloud_regimes != None:
+        if type(premade_cloud_regimes) != str:
+            print(' WARNING: Invalid option for premade_cloud_regimes. Please enter a path to a numpy array of Cloud Regime centers of shape (n_clusters, n_dimensions_of_data). Proceeding with default clusters')
+            premade_cloud_regimes = None
+    if lat_range != None:
+        if type(lat_range) != list or len(lat_range) != 2:
+            print(' WARNING: Invalid option for lat_range. Please enter two values in square brackets separated by a comma. Example: [-30,30]. Proceeding with entire latitude range')
+            lat_range = None
+    if lon_range != None:
+        if type(lon_range) != list or len(lon_range) != 2:
+            print(' WARNING: Invalid option for lon_range. Please enter two values in square brackets separated by a comma. Example: [0,90]. Proceeding with entire longitude range')
+            lon_range = None
+    if only_ocean_or_land not in [False, 'L', 'O']:
+        print(' WARNING: Invalid option for only_ocean_or_land. Please enter "L" for land only, "O" for ocean only. Set to False or leave blank for both land and water. Proceeding with default of False')
+        only_ocean_or_land = False
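+
+    # Minimal sketch of the euclidean assignment performed by precomputed_clusters above
+    # (illustrative values only, not executed):
+    #   mat = np.array([[0., 1.], [2., 3.]])        # 2 histograms with 2 bins each
+    #   cl  = np.array([[0., 0.], [2., 2.]])        # k=2 cluster centers
+    #   d   = np.sum((mat[:,:,None] - cl.T[None,:,:])**2, axis=1)  # (2, 2) squared distances
+    #   np.argmin(d, axis=1)                        # -> array([0, 1])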
+
+    # Checking if the path we wish to save our plots to exists, and creating it if it doesn't
+    if not os.path.isdir(adf.plot_location[0]):
+        os.makedirs(adf.plot_location[0])
+
+    # path to h0 files
+    h0_data_path = adf.get_cam_info("cam_hist_loc", required=True)[0] + "/*h0*.nc"
+
+    # Time Range min and max, or None for all time
+    time_range = [str(adf.get_cam_info("start_year")[0]), str(adf.get_cam_info("end_year")[0])]
+
+    files = glob.glob(h0_data_path)
+    # Opening an initial dataset
+    init_ds = xr.open_mfdataset(files[0])
+
+    # defining dicts for variable names for each data set
+    data_var_dict = {'ISCCP':'FISCCP1_COSP', "MISR":'CLD_MISR', "MODIS":"CLMODIS" }
+    ht_var_dict = {'ISCCP':'cosp_prs', "MISR":'cosp_htmisr', "MODIS":"cosp_prs" }
+    tau_var_dict = {'ISCCP':'cosp_tau', "MISR":'cosp_tau', "MODIS":"cosp_tau_modis" }
+
+    # getting names of cosp data variables for all data products that will get processed
+    if data_product == 'all':
+        var_name = list(data_var_dict.values())
+    else:
+        var_name = [data_var_dict[data_product]]
+
+    # looping through to do analysis on each data product selected
+    for var in var_name:
+
+        # Getting data name corresponding to the variable being opened
+        key_list = list(data_var_dict.keys())
+        val_list = list(data_var_dict.values())
+        position = val_list.index(var)
+        data = key_list[position]
+        ht_var_name = ht_var_dict[data]
+        tau_var_name = tau_var_dict[data]
+
+        print(f'\n Beginning {data} Cloud Regime analysis')
+
+        # variable that gets set to true if var is missing in the data file, and is used to skip that dataset
+        missing_var = False
+
+        t = time.time()
+        # Trying to open time series files from cam_ts_loc
+        try: ds = xr.open_mfdataset(adf.get_cam_info("cam_ts_loc", required=True)[0] + f"/*{var}*")
+
+        # If that doesn't work, trying to open the variables from the h0 files
+        except:
+            print(f" -WARNING: {data} time series file does not exist, was {var} added to the diag_var_list?")
+            print(" -Attempting to use h0 files from cam_hist_loc, but this will be slower" )
+            # Creating a list of all the variables in the dataset
+            remove = list(init_ds.keys())
+            try:
+                # Deleting the variables we want to keep in our dataset, all remaining variables will be dropped upon opening the files, this allows for faster opening of large files
+                remove.remove(var)
+                # If there's a LANDFRAC variable keep it in the dataset
+                landfrac_present = True
+                try: remove.remove('LANDFRAC')
+                except: landfrac_present = False
+
+                ds = xr.open_mfdataset(files, drop_variables = remove)
+
+            # If variables are not present in h0 tell the user the variables do not exist, and that there is not COSP output for this data
+            except:
+                print(f' -{var} does not exist in h0 files, does this run have {data} COSP output? Skipping {data} for now')
+                missing_var = True # used to skip the code below and move onto the next var name
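+
+        # Note: the finally block below runs whether or not the open above succeeded;
+        # missing_var is checked there so products without COSP output are skipped.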
+        # executing further analysis on this data
+        finally:
+
+            # Skipping var if it's not present in data files
+            if missing_var:
+                continue
+
+            # Adjusting lon to run from -180 to 180 if it doesn't already
+            if np.max(ds.lon) > 180:
+                ds.coords['lon'] = (ds.coords['lon'] + 180) % 360 - 180
+                ds = ds.sortby(ds.lon)
+
+
+            # Selecting only points over ocean or points over land if only_ocean_or_land has been used
+            if only_ocean_or_land != False:
+                # If LANDFRAC variable is present, use it to mask
+                if landfrac_present == True:
+                    if only_ocean_or_land == 'L': ds = ds.where(ds.LANDFRAC == 1)
+                    elif only_ocean_or_land == 'O': ds = ds.where(ds.LANDFRAC == 0)
+                # Otherwise use cartopy
+                else:
+                    land = create_land_mask(ds)
+                    dims = ds.dims
+
+                    # Inserting new axis to make land a broadcastable shape with ds
+                    for n in range(len(dims)):
+                        if dims[n] != 'lat' and dims[n] != 'lon':
+                            land = np.expand_dims(land, n)
+
+                    # Masking out the land or water
+                    if only_ocean_or_land == 'L': ds = ds.where(land == 1)
+                    elif only_ocean_or_land == 'O': ds = ds.where(land == 0)
+                    else: raise Exception('Invalid option for only_ocean_or_land: Please enter "O" for ocean only, "L" for land only, or set to False for both land and water')
+
+            # Selecting lat range
+            if lat_range != None:
+                if ds.lat[0] > ds.lat[-1]:
+                    ds = ds.sel(lat=slice(lat_range[1],lat_range[0]))
+                else:
+                    ds = ds.sel(lat=slice(lat_range[0],lat_range[1]))
+
+            # Selecting Lon range
+            if lon_range != None:
+                if ds.lon[0] > ds.lon[-1]:
+                    ds = ds.sel(lon=slice(lon_range[1],lon_range[0]))
+                else:
+                    ds = ds.sel(lon=slice(lon_range[0],lon_range[1]))
+
+            # Selecting time range
+            if time_range != ["None","None"]:
+                # Need these if statements to be robust if the adf obj only has start_year or end_year
+                if time_range[0] == "None":
+                    start = ds.time[0]
+                    end = time_range[1]
+                elif time_range[1] == "None":
+                    start = time_range[0]
+                    end = ds.time[-1]
+                else:
+                    start = time_range[0]
+                    end = time_range[1]
+
+                ds = ds.sel(time=slice(start,end))
+
+            # Turning dataset into a dataarray
+            ds = ds[var]
+
+            # Selecting only valid tau and height/pressure range
+            # Many data products have a -1 bin for failed retrievals, we do not wish to include this
+            tau_selection = {tau_var_name:slice(0,9999999999999)}
+            # Making sure this works for pressure which is ordered largest to smallest and altitude which is ordered smallest to largest
+            if ds[ht_var_name][0] > ds[ht_var_name][-1]: ht_selection = {ht_var_name:slice(9999999999999,0)}
+            else: ht_selection = {ht_var_name:slice(0,9999999999999)}
+            ds = ds.sel(tau_selection)
+            ds = ds.sel(ht_selection)
+
+            # Opening cluster centers
+            # Using premade clusters if they have been provided
+            if type(premade_cloud_regimes) == str:
+                cl = np.load(premade_cloud_regimes)
+                # Checking if the shape is what we'd expect
+                if cl.shape[1] != len(ds[tau_var_name]) * len(ds[ht_var_name]):
+                    if data == 'ISCCP' and cl.shape[1] == 42:
+                        pass
+                    elif data == 'MISR' and cl.shape[1] == 105:
+                        pass
+                    else:
+                        raise Exception(f'premade_cloud_regimes is the wrong shape. premade_cloud_regimes.shape = {cl.shape}, but must be shape (k, {len(ds[tau_var_name]) * len(ds[ht_var_name])}) to fit the loaded {data} data')
+                print(' -Using premade cloud regimes:')
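+
+            # Illustrative sketch (hypothetical file name): for the 6 tau x 7 pressure
+            # observational ISCCP histogram, a custom set of regimes could be saved with
+            #   np.save('my_isccp_crs.npy', centers)   # centers.shape == (k, 42)
+            # and passed via premade_cloud_regimes='my_isccp_crs.npy'.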
+
+            # If custom CRs havent been passed, use either the emd or euclidean premade ones
+            elif wasserstein_or_euclidean == "wasserstein":
+                cluster_centers_path = adf.variable_defaults[f"{data}_emd_centers"]['obs_file']
+                cl = np.load(cluster_centers_path)
+
+            elif wasserstein_or_euclidean == "euclidean":
+                cluster_centers_path = adf.variable_defaults[f"{data}_euclidean_centers"]['obs_file']
+                cl = np.load(cluster_centers_path)
+
+            # Defining k, the number of clusters
+            k = len(cl)
+
+            print(f' -Preprocessing data')
+
+            # COSP ISCCP data has one extra tau bin than the satellite data, and misr has an extra height bin. This checks roughly if we are comparing against the
+            # satellite data, and if so removes the extra tau or ht bin. If a user passes home made CRs from CESM data, no data will be removed
+            if data == 'ISCCP' and cl.shape[1] == 42:
+                # a slightly hacky way to drop the smallest tau bin, but is robust in case tau is flipped in a future version
+                ds = ds.sel(cosp_tau=slice(np.min(ds.cosp_tau)+1e-11,np.inf))
+                print(" -Dropping smallest tau bin to be comparable with observational cloud regimes")
+            if data == 'MISR' and cl.shape[1] == 105:
+                # a slightly hacky way to drop the lowest height bin, but is robust in case height is flipped in a future version
+                ds = ds.sel(cosp_htmisr=slice(np.min(ds.cosp_htmisr)+1e-11,np.inf))
+                print(" -Dropping lowest height bin to be comparable with observational cloud regimes")
+
+            # Selecting only the relevant data and stacking it to shape n_histograms, n_tau * n_pc
+            dims = list(ds.dims)
+            dims.remove(tau_var_name)
+            dims.remove(ht_var_name)
+            histograms = ds.stack(spacetime=(dims), tau_ht=(tau_var_name, ht_var_name))
+            weights = np.cos(np.deg2rad(histograms.lat.values)) # weights array to use with emd-kmeans
+
+            # Turning into a numpy array for clustering
+            mat = histograms.values
+
+            # Removing all histograms with 1 or more nans in them
+            indicies = np.arange(len(mat))
+            is_valid = ~np.isnan(mat.mean(axis=1))
+            is_valid = is_valid.astype(np.int32)
+            valid_indicies = indicies[is_valid==1]
+            mat=mat[valid_indicies]
+            weights=weights[valid_indicies]
+
+            print(f' -Fitting data')
+
+            # Compute cluster labels
+            cluster_labels_temp = precomputed_clusters(mat, cl, wasserstein_or_euclidean, ds)
+
+            # taking the flattened cluster_labels_temp array, and turning it into a DataArray the shape of ds.var_name, and reinserting NaNs in place of missing data
+            cluster_labels = np.full(len(indicies), np.nan, dtype=np.int32)
+            cluster_labels[valid_indicies]=cluster_labels_temp
+            cluster_labels = xr.DataArray(data=cluster_labels, coords={"spacetime":histograms.spacetime},dims=("spacetime") )
+            cluster_labels = cluster_labels.unstack()
+
+            print(' -Plotting')
+
+            # Comparing to observation
+            if adf.compare_obs == True:
+                # defining dicts for variable names for each data set
+                obs_data_var_dict = {'ISCCP':'n_pctaudist', "MISR":'clMISR', "MODIS":"MODIS_CLD_HISTO" }
+                obs_ht_var_dict = {'ISCCP':'levpc', "MISR":'cth', "MODIS":"PRES" }
+                obs_tau_var_dict = {'ISCCP':'levtau', "MISR":'tau', "MODIS":"COT" }
+
+                # Getting data name corresponding to the variable being opened
+                key_list = list(obs_data_var_dict.keys())
+                val_list = list(obs_data_var_dict.values())
+                obs_var = obs_data_var_dict[data]
+                position = val_list.index(obs_var)
+                data = key_list[position]
+                obs_ht_var_name = obs_ht_var_dict[data]
+                obs_tau_var_name = obs_tau_var_dict[data]
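+
+                # For reference: the observational files name their coordinates differently
+                # from CAM, with cloud-top pressure/height as 'levpc'/'cth'/'PRES' and optical
+                # depth as 'levtau'/'tau'/'COT' for ISCCP/MISR/MODIS respectively.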
+
+                print(f' -Starting {data} observation data')
+
+                # Opening observation files. The obs files have three variables, precomputed euclidean cluster labels, precomputed emd cluster labels
+                # and then the raw data to use if custom CRs are passed in.
+                obs_data_path = adf.var_obs_dict[var]['obs_file']
+
+                # Opening the data
+                ds_o = xr.open_dataset(obs_data_path)
+
+                # Selecting either the appropriate pre-computed cluster_labels or the raw data
+                if premade_cloud_regimes == None:
+                    if wasserstein_or_euclidean == 'wasserstein':
+                        ds_o = ds_o.emd_cluster_labels
+                    else:
+                        ds_o = ds_o.euclidean_cluster_labels
+                else:
+                    ds_o = ds_o[obs_var]
+
+                # Adjusting lon to run from -180 to 180 if it doesn't already
+                if np.max(ds_o.lon) > 180:
+                    ds_o.coords['lon'] = (ds_o.coords['lon'] + 180) % 360 - 180
+                    ds_o = ds_o.sortby(ds_o.lon)
+
+                # Selecting only points over ocean or points over land if only_ocean_or_land has been used
+                if only_ocean_or_land != False:
+                    land = create_land_mask(ds_o)
+                    dims = ds_o.dims
+
+                    # inserting new axis to make land a broadcastable shape with ds_o
+                    for n in range(len(dims)):
+                        if dims[n] != 'lat' and dims[n] != 'lon':
+                            land = np.expand_dims(land, n)
+
+                    # Masking out the land or water
+                    if only_ocean_or_land == 'L': ds_o = ds_o.where(land == 1)
+                    elif only_ocean_or_land == 'O': ds_o = ds_o.where(land == 0)
+                    else: raise Exception('Invalid option for only_ocean_or_land: Please enter "O" for ocean only, "L" for land only, or set to False for both land and water')
+
+                # Selecting lat range
+                if lat_range != None:
+                    if ds_o.lat[0] > ds_o.lat[-1]:
+                        ds_o = ds_o.sel(lat=slice(lat_range[1],lat_range[0]))
+                    else:
+                        ds_o = ds_o.sel(lat=slice(lat_range[0],lat_range[1]))
+
+                # Selecting Lon range
+                if lon_range != None:
+                    if ds_o.lon[0] > ds_o.lon[-1]:
+                        ds_o = ds_o.sel(lon=slice(lon_range[1],lon_range[0]))
+                    else:
+                        ds_o = ds_o.sel(lon=slice(lon_range[0],lon_range[1]))
+
+                # Don't select time range for observation, just compare to the full record
+                # # Selecting time range
+                # if time_range != ["None","None"]:
+                #     if time_range[0] == "None":
+                #         start = ds.time[0]
+                #         end = time_range[1]
+                #     elif time_range[1] == "None":
+                #         start = time_range[0]
+                #         end = ds.time[-1]
+                #     else:
+                #         start = time_range[0]
+                #         end = time_range[1]
+
+                #     ds = ds.sel(time=slice(start,end))
+
+                if premade_cloud_regimes == None:
+                    cluster_labels_o = ds_o
+                    cluster_labels_o_temp = cluster_labels_o.stack(spacetime=("time", 'lat', 'lon'))
+                else:
+                    # Selecting only valid tau and height/pressure range
+                    # Many data products have a -1 bin for failed retrievals, we do not wish to include this
+                    tau_selection = {obs_tau_var_name:slice(0,9999999999999)}
+                    # Making sure this works for pressure which is ordered largest to smallest and altitude which is ordered smallest to largest
+                    if ds_o[obs_ht_var_name][0] > ds_o[obs_ht_var_name][-1]: ht_selection = {obs_ht_var_name:slice(9999999999999,0)}
+                    else: ht_selection = {obs_ht_var_name:slice(0,9999999999999)}
+                    ds_o = ds_o.sel(tau_selection)
+                    ds_o = ds_o.sel(ht_selection)
+
+                    # Selecting only the relevant data and stacking it to shape n_histograms, n_tau * n_pc
+                    dims = list(ds_o.dims)
+                    dims.remove(obs_tau_var_name)
+                    dims.remove(obs_ht_var_name)
+                    histograms_o = ds_o.stack(spacetime=(dims), tau_ht=(obs_tau_var_name, obs_ht_var_name))
+                    weights_o = np.cos(np.deg2rad(histograms_o.lat.values)) # weights_o array to use with emd-kmeans
+
+                    # Turning into a numpy array for clustering
+                    mat_o = histograms_o.values
+
+                    # Removing all histograms with 1 or more nans in them
+                    indicies = np.arange(len(mat_o))
+                    is_valid = ~np.isnan(mat_o.mean(axis=1))
+                    is_valid = is_valid.astype(np.int32)
+                    valid_indicies_o = indicies[is_valid==1]
+                    mat_o=mat_o[valid_indicies_o]
+                    weights_o=weights_o[valid_indicies_o]
+
+                    if np.min(mat_o) < 0:
+                        raise Exception(f'Found negative value in ds_o.{obs_var}, if this is a fill value for missing data, convert to nans and try again')
+
+                    print(f' -Fitting data')
+
+                    # Compute cluster labels
+                    cluster_labels_temp_o = precomputed_clusters(mat_o, cl, wasserstein_or_euclidean, ds_o)
+
+                    # Taking the flattened cluster_labels_temp_o array, and turning it into a DataArray the shape of obs_ds.var_name, and reinserting NaNs in place of missing data
+                    cluster_labels_o = np.full(len(indicies), np.nan, dtype=np.int32)
+                    cluster_labels_o[valid_indicies_o]=cluster_labels_temp_o
+                    cluster_labels_o = xr.DataArray(data=cluster_labels_o, coords={"spacetime":histograms_o.spacetime},dims=("spacetime") )
+                    cluster_labels_o = cluster_labels_o.unstack()
+
+                print(f' -Plotting')
+
+                plot_hists_obs(cl, cluster_labels, cluster_labels_o, ht_var_name, tau_var_name, adf)
+                plot_rfo_obs_base_diff(cluster_labels, cluster_labels_o, adf)
+
+            # Comparing to CAM baseline if not comparing to obs
+            else:
+                # path to h0 files
+                baseline_h0_data_path = adf.get_baseline_info("cam_hist_loc", required=True) + "/*h0*.nc"
+                # Time Range min and max, or None for all time
+                time_range_b = [str(adf.get_baseline_info("start_year")), str(adf.get_baseline_info("end_year"))]
+                # Creating a list of files
+                files = glob.glob(baseline_h0_data_path)
+                # Opening an initial dataset
+                init_ds_b = xr.open_dataset(files[0])
+
+                print(f' -Starting {data} CAM baseline data')
+
+                # Variable that gets set to true if var is missing in the data file, and is used to skip processing that dataset
+                missing_var = False
+
+                # Trying to open time series files from cam_ts_loc
+                try: ds_b = xr.open_mfdataset(adf.get_baseline_info("cam_ts_loc", required=True) + f"/*{var}*")
+
+                # If that doesn't work, trying to open the variables from the h0 files
+                except:
+                    print(f" -WARNING: {data} time series file does not exist, was {var} added to the diag_var_list?")
+                    print("  Attempting to use h0 files from cam_hist_loc, but this will be slower" )
+                    # Creating a list of all the variables in the dataset
+                    remove = list(init_ds_b.keys())
+                    try:
+                        # Deleting the variables we want to keep in our dataset, all remaining variables will be dropped upon opening the files, this allows for faster opening of large files
+                        remove.remove(var)
+                        # If there's a LANDFRAC variable keep it in the dataset
+                        landfrac_present = True
+                        try: remove.remove('LANDFRAC')
+                        except: landfrac_present = False
+
+                        # Opening dataset and dropping irrelevant data
+                        ds_b = xr.open_mfdataset(files, drop_variables = remove)
+
+                    # If variables are not present in h0 tell the user the variables do not exist, and that there is not COSP output for this data
+                    except:
+                        print(f' {var} does not exist in h0 files, does this run have {data} COSP output? 
Skipping {data} for now') + missing_var = True # used to skip the code below and move onto the next var name + + # Executing further analysis on this data + finally: + + # Skipping var if its not present in data files + if missing_var: + continue + + # Adjusting lon to run from -180 to 180 if it doesnt already + if np.max(ds_b.lon) > 180: + ds_b.coords['lon'] = (ds_b.coords['lon'] + 180) % 360 - 180 + ds_b = ds_b.sortby(ds_b.lon) + + # Selecting only points over ocean or points over land if only_ocean_or_land has been used + if only_ocean_or_land != False: + # If LANDFRAC variable is present, use it to mask + if landfrac_present == True: + if only_ocean_or_land == 'L': ds_b = ds_b.where(ds_b.LANDFRAC == 1) + elif only_ocean_or_land == 'O': ds_b = ds_b.where(ds_b.LANDFRAC == 0) + # Otherwise use cartopy + else: + land = create_land_mask(ds_b) + dims = ds_b.dims + + # Inserting new axis to make land a broadcastable shape with ds_b + for n in range(len(dims)): + if dims[n] != 'lat' and dims[n] != 'lon': + land = np.expand_dims(land, n) + + # Masking out the land or water + if only_ocean_or_land == 'L': ds_b = ds_b.where(land == 1) + elif only_ocean_or_land == 'O': ds_b = ds_b.where(land == 0) + else: raise Exception('Invalid option for only_ocean_or_land: Please enter "O" for ocean only, "L" for land only, or set to False for both land and water') + + # Selecting lat range + if lat_range != None: + if ds_b.lat[0] > ds_b.lat[-1]: + ds_b = ds_b.sel(lat=slice(lat_range[1],lat_range[0])) + else: + ds_b = ds_b.sel(lat=slice(lat_range[0],lat_range[1])) + + # Selecting Lon range + if lon_range != None: + if ds_b.lon[0] > ds_b.lon[-1]: + ds_b = ds_b.sel(lon=slice(lon_range[1],lon_range[0])) + else: + ds_b = ds_b.sel(lon=slice(lon_range[0],lon_range[1])) + + # Selecting time range + if time_range_b != ["None","None"]: + # Need these if statements to be robust if the adf obj only has start_year or end_year + if time_range_b[0] == "None": + start = ds_b.time[0] + end = time_range_b[1] + elif time_range_b[1] == "None": + start = time_range_b[0] + end = ds_b.time[-1] + else: + start = time_range_b[0] + end = time_range_b[1] + + ds_b = ds_b.sel(time=slice(start,end)) + + # Turning dataset into a dataarray + ds_b = ds_b[var] + + # Selecting only valid tau and height/pressure range + # Many data products have a -1 bin for failed retreivals, we do not wish to include this + tau_selection = {tau_var_name:slice(0,9999999999999)} + # Making sure this works for pressure which is ordered largest to smallest and altitude which is ordered smallest to largest + if ds_b[ht_var_name][0] > ds_b[ht_var_name][-1]: ht_selection = {ht_var_name:slice(9999999999999,0)} + else: ht_selection = {ht_var_name:slice(0,9999999999999)} + ds_b = ds_b.sel(tau_selection) + ds_b = ds_b.sel(ht_selection) + + # Opening cluster centers + # Using premade clusters if they have been provided + if type(premade_cloud_regimes) == str: + cl = np.load(premade_cloud_regimes) + # Checking if the shape is what we'd expect + if premade_cloud_regimes.shape[1] != len(ds_b[tau_var_name]) * len(ds_b[ht_var_name]): + raise Exception (f'premade_cloud_regimes is the wrong shape. premade_cloud_regimes.shape = {premade_cloud_regimes.shape}, but must be shpae (k, {len(ds_b.tau_var_name) * len(ds_b.ht_var_name)}) to fit the loaded data') + print(' -Using premade cloud regimes:') + + + print(f' -Preprocessing data') + + # COSP ISCCP data has one extra tau bin than the satellite data, and misr has an extra height bin. 
+            # COSP ISCCP data has one extra tau bin relative to the satellite data, and MISR has an extra height bin.
+            # This checks roughly whether we are comparing against the satellite data, and if so removes the extra tau or ht bin.
+            # If a user passes homemade CRs from CESM data, no data will be removed
+            if data == 'ISCCP' and cl.shape[1] == 42:
+                # A slightly hacky way to drop the smallest tau bin, but it is robust in case tau is flipped in a future version
+                ds_b = ds_b.sel(cosp_tau=slice(np.min(ds_b.cosp_tau)+1e-11,np.inf))
+                print(" -Dropping smallest tau bin to be comparable with observational cloud regimes")
+            if data == 'MISR' and cl.shape[1] == 105:
+                # A slightly hacky way to drop the lowest height bin, but it is robust in case height is flipped in a future version
+                ds_b = ds_b.sel(cosp_htmisr=slice(np.min(ds_b.cosp_htmisr)+1e-11,np.inf))
+                print(" -Dropping lowest height bin to be comparable with observational cloud regimes")
+
+            # Selecting only the relevant data and stacking it to shape n_histograms, n_tau * n_pc
+            dims = list(ds_b.dims)
+            dims.remove(tau_var_name)
+            dims.remove(ht_var_name)
+            histograms_b = ds_b.stack(spacetime=(dims), tau_ht=(tau_var_name, ht_var_name))
+            weights_b = np.cos(np.deg2rad(histograms_b.lat.values)) # weights_b array to use with emd-kmeans
+
+            # Turning into a numpy array for clustering
+            mat_b = histograms_b.values
+
+            # Removing all histograms with 1 or more nans in them
+            indicies = np.arange(len(mat_b))
+            is_valid = ~np.isnan(mat_b.mean(axis=1))
+            is_valid = is_valid.astype(np.int32)
+            valid_indicies_b = indicies[is_valid==1]
+            mat_b = mat_b[valid_indicies_b]
+            weights_b = weights_b[valid_indicies_b]
+
+            if np.min(mat_b) < 0:
+                raise Exception (f'Found negative value in ds_b.{var}, if this is a fill value for missing data, convert it to nans and try again')
+
+            print(' -Fitting data')
+
+            # Compute cluster labels
+            cluster_labels_temp_b = precomputed_clusters(mat_b, cl, wasserstein_or_euclidean, ds_b)
+
+            # Taking the flattened cluster_labels_temp_b array, turning it into a DataArray the shape of ds.var_name, and reinserting NaNs in place of missing data
+            cluster_labels_b = np.full(len(indicies), np.nan)
+            cluster_labels_b[valid_indicies_b] = cluster_labels_temp_b
+            cluster_labels_b = xr.DataArray(data=cluster_labels_b, coords={"spacetime":histograms_b.spacetime}, dims=("spacetime"))
+            cluster_labels_b = cluster_labels_b.unstack()
+
+            print(' -Plotting')
+
+            # Plotting
+            plot_hists_baseline(cl, cluster_labels, cluster_labels_b, ht_var_name, tau_var_name, adf)
+            plot_rfo_obs_base_diff(cluster_labels, cluster_labels_b, adf)
+
+
+
+
+# %%

From 4de6ee961f0f9eb9d4028cd2438bfd1af5315d42 Mon Sep 17 00:00:00 2001
From: Isaac Davis
Date: Tue, 25 Jul 2023 15:40:39 -0600
Subject: [PATCH 2/5] Minor comment changes

---
 scripts/plotting/cloud_regime_analysis.py | 20 +++++--------------
 1 file changed, 5 insertions(+), 15 deletions(-)

diff --git a/scripts/plotting/cloud_regime_analysis.py b/scripts/plotting/cloud_regime_analysis.py
index 86e3bac63..9611aac42 100644
--- a/scripts/plotting/cloud_regime_analysis.py
+++ b/scripts/plotting/cloud_regime_analysis.py
@@ -894,7 +894,6 @@ def test (oh_land, lat_lon, landar):
     # variable that gets set to true if var is missing in the data file, and is used to skip that dataset
     missing_var = False
 
-    t = time.time()
     # Trying to open time series files from cam_ts_loc
     try: ds = xr.open_mfdataset(adf.get_cam_info("cam_ts_loc", required=True)[0] + f"/*{var}*")
 
@@ -993,6 +992,7 @@ def test (oh_land, lat_lon, landar):
     else: ht_selection = {ht_var_name:slice(0,9999999999999)}
     ds = ds.sel(tau_selection)
     ds = ds.sel(ht_selection)
+    # 
Opening cluster centers # Using premade clusters if they have been provided @@ -1010,12 +1010,14 @@ def test (oh_land, lat_lon, landar): # If custom CRs havent been passed, use either the emd or euclidean premade ones elif wasserstein_or_euclidean == "wasserstein": + obs_data_loc = adf.get_basic_info('obs_data_loc') + '/' cluster_centers_path = adf.variable_defaults[f"{data}_emd_centers"]['obs_file'] - cl = np.load(cluster_centers_path) + cl = np.load(obs_data_loc + cluster_centers_path) elif wasserstein_or_euclidean == "euclidean": + obs_data_loc = adf.get_basic_info('obs_data_loc') + '/' cluster_centers_path = adf.variable_defaults[f"{data}_euclidean_centers"]['obs_file'] - cl = np.load(cluster_centers_path) + cl = np.load(obs_data_loc + cluster_centers_path) # Defining k, the number of clusters k = len(cl) @@ -1062,8 +1064,6 @@ def test (oh_land, lat_lon, landar): cluster_labels = xr.DataArray(data=cluster_labels, coords={"spacetime":histograms.spacetime},dims=("spacetime") ) cluster_labels = cluster_labels.unstack() - print(' -Plotting') - # Comparing to observation if adf.compare_obs == True: # defining dicts for variable names for each data set @@ -1312,16 +1312,6 @@ def test (oh_land, lat_lon, landar): ds_b = ds_b.sel(tau_selection) ds_b = ds_b.sel(ht_selection) - # Opening cluster centers - # Using premade clusters if they have been provided - if type(premade_cloud_regimes) == str: - cl = np.load(premade_cloud_regimes) - # Checking if the shape is what we'd expect - if premade_cloud_regimes.shape[1] != len(ds_b[tau_var_name]) * len(ds_b[ht_var_name]): - raise Exception (f'premade_cloud_regimes is the wrong shape. premade_cloud_regimes.shape = {premade_cloud_regimes.shape}, but must be shpae (k, {len(ds_b.tau_var_name) * len(ds_b.ht_var_name)}) to fit the loaded data') - print(' -Using premade cloud regimes:') - - print(f' -Preprocessing data') # COSP ISCCP data has one extra tau bin than the satellite data, and misr has an extra height bin. 
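Schematically, both branches in the hunk above resolve and load the center files the same way (mirroring the adf calls shown; the euclidean case given here):

import os
import numpy as np

obs_data_loc = adf.get_basic_info('obs_data_loc')
centers_file = adf.variable_defaults[f"{data}_euclidean_centers"]['obs_file']
cl = np.load(os.path.join(obs_data_loc, centers_file))  # (k, n_tau * n_ht)
k = len(cl)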
This checks roughly if we are comparing against the From ce1021f79388d3ba527d89060871a40082943649 Mon Sep 17 00:00:00 2001 From: Brian Medeiros Date: Wed, 20 Aug 2025 15:51:22 -0600 Subject: [PATCH 3/5] refactor to work with current ADF --- lib/adf_variable_defaults.yaml | 61 +- scripts/plotting/cloud_regime_analysis.py | 2693 +++++++++++---------- 2 files changed, 1457 insertions(+), 1297 deletions(-) diff --git a/lib/adf_variable_defaults.yaml b/lib/adf_variable_defaults.yaml index 8b4698780..f1b5b1482 100644 --- a/lib/adf_variable_defaults.yaml +++ b/lib/adf_variable_defaults.yaml @@ -297,10 +297,10 @@ BC: diff_colormap: "BrBG" scale_factor: 1000000000 add_offset: 0 - new_unit: '$\mu$g/m3' + new_unit: '$\\mu$g/m3' mpl: colorbar: - label : '$\mu$g/m3' + label : '$\\mu$g/m3' category: "Aerosols" derivable_from: ["bc_a1", "bc_a4"] pct_diff_contour_levels: [-100,-75,-50,-40,-30,-20,-10,-8,-6,-4,-2,0,2,4,6,8,10,20,30,40,50,75,100] @@ -311,10 +311,10 @@ POM: diff_colormap: "BrBG" scale_factor: 1000000000 add_offset: 0 - new_unit: '$\mu$g/m3' + new_unit: '$\\mu$g/m3' mpl: colorbar: - label : '$\mu$g/m3' + label : '$\\mu$g/m3' category: "Aerosols" derivable_from: ["pom_a1", "pom_a4"] pct_diff_contour_levels: [-100,-75,-50,-40,-30,-20,-10,-8,-6,-4,-2,0,2,4,6,8,10,20,30,40,50,75,100] @@ -325,10 +325,10 @@ SO4: diff_colormap: "BrBG" scale_factor: 1000000000 add_offset: 0 - new_unit: '$\mu$g/m3' + new_unit: '$\\mu$g/m3' mpl: colorbar: - label : '$\mu$g/m3' + label : '$\\mu$g/m3' category: "Aerosols" derivable_from: ["so4_a1", "so4_a2", "so4_a3"] derivable_from_cam_chem: ["so4_a1", "so4_a2", "so4_a3", "so4_a5"] @@ -340,10 +340,10 @@ SOA: diff_colormap: "BrBG" scale_factor: 1000000000 add_offset: 0 - new_unit: '$\mu$g/m3' + new_unit: '$\\mu$g/m3' mpl: colorbar: - label : '$\mu$g/m3' + label : '$\\mu$g/m3' category: "Aerosols" derivable_from: ["soa_a1", "soa_a2"] derivable_from_cam_chem: ["soa1_a1", "soa2_a1", "soa3_a1", "soa4_a1", "soa5_a1", "soa1_a2", "soa2_a2", "soa3_a2", "soa4_a2", "soa5_a2"] @@ -357,10 +357,10 @@ DUST: diff_colormap: "BrBG" scale_factor: 1000000000 add_offset: 0 - new_unit: '$\mu$g/m3' + new_unit: '$\\mu$g/m3' mpl: colorbar: - label : '$\mu$g/m3' + label : '$\\mu$g/m3' category: "Aerosols" derivable_from: ["dst_a1", "dst_a2", "dst_a3"] pct_diff_contour_levels: [-100,-75,-50,-40,-30,-20,-10,-8,-6,-4,-2,0,2,4,6,8,10,20,30,40,50,75,100] @@ -373,13 +373,13 @@ SeaSalt: diff_colormap: "BrBG" scale_factor: 1000000000 add_offset: 0 - new_unit: '$\mu$g/m3' + new_unit: '$\\mu$g/m3' mpl: colorbar: - label : '$\mu$g/m3' + label : '$\\mu$g/m3' ticks: [0.05,0.2,0.4,1,2,6,24,90] diff_colorbar: - label : '$\mu$g/m3' + label : '$\\mu$g/m3' ticks: [-10,8,6,4,2,0,-2,-4,-6,-8,-10] category: "Aerosols" derivable_from: ["ncl_a1", "ncl_a2", "ncl_a3"] @@ -2255,11 +2255,7 @@ FISCCP1_COSP: category: "Clouds" obs_file: 'ISCCP_obs_data.nc' obs_name: "ISCCP" - pct_diff_contour_levels: [-100,-75,-50,-40,-30,-20,-10,-8,-6,-4,-2,0,2,4,6,8,10,20,30,40,50,75,100] - pct_diff_colormap: "PuOr_r" - -ICE_ICLD_VISTAU: - category: "COSP" + obs_var_name: "n_pctaudist" pct_diff_contour_levels: [-100,-75,-50,-40,-30,-20,-10,-8,-6,-4,-2,0,2,4,6,8,10,20,30,40,50,75,100] pct_diff_colormap: "PuOr_r" @@ -2271,13 +2267,10 @@ ISCCP_euclidean_centers: category: "Clouds" obs_file: 'CS_qualitative_clusters.npy' -ISCCP_emd_centers: - category: "Clouds" - obs_file: 'ISCCP_emd-means_n_init5_centers_1.npy' - -ISCCP_euclidean_centers: - category: "Clouds" - obs_file: 'CS_qualitative_clusters.npy' +ICE_ICLD_VISTAU: + category: "COSP" + 
pct_diff_contour_levels: [-100,-75,-50,-40,-30,-20,-10,-8,-6,-4,-2,0,2,4,6,8,10,20,30,40,50,75,100] + pct_diff_colormap: "PuOr_r" IWPMODIS: category: "COSP" @@ -2333,23 +2326,7 @@ MODIS_euclidean_centers: MODIS_emd_centers: category: "Clouds" - obs_file: 'MODIS_emd-means_n_init5_centers_1.np' - -MISR_euclidean_centers: - category: "Clouds" - obs_file: 'MISR_6C_weather_state_centers.npy' - -MISR_emd_centers: - category: "Clouds" - obs_file: 'MISR_emd-means_n_init5_centers_1.npy' - -MODIS_euclidean_centers: - category: "Clouds" - obs_file: 'MODIS_6C_weather_state_centers.npy' - -MODIS_emd_centers: - category: "Clouds" - obs_file: 'MODIS_emd-means_n_init5_centers_1.np' + obs_file: 'MODIS_emd-means_n_init5_centers_1.npy' PCTMODIS: category: "COSP" diff --git a/scripts/plotting/cloud_regime_analysis.py b/scripts/plotting/cloud_regime_analysis.py index 9611aac42..e10af26f8 100644 --- a/scripts/plotting/cloud_regime_analysis.py +++ b/scripts/plotting/cloud_regime_analysis.py @@ -1,1369 +1,1552 @@ -#%% -print() + +from math import ceil +import warnings +from pathlib import Path + import numpy as np -try : import wasserstein -except: - print(' Wasserstein package is not installed so wasserstein distance cannot be used. Attempting to use wassertein distance will raise an error.') - print(' To use wasserstein distance please install the wasserstein package in your environment: https://pypi.org/project/Wasserstein/ ') +import xesmf + +try: + import wasserstein +except: + print( + " Wasserstein package is not installed so wasserstein distance cannot be used. Attempting to use wasserstein distance will raise an error." + ) + print( + " To use wasserstein distance please install the wasserstein package in your environment: https://pypi.org/project/Wasserstein/ " + ) import matplotlib.pyplot as plt -import xarray as xr -import matplotlib as mpl +from matplotlib.cm import ScalarMappable from mpl_toolkits.axes_grid1 import make_axes_locatable -from numba import njit +import cartopy from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter import cartopy.crs as ccrs + +import dask +import xarray as xr +import matplotlib as mpl from shapely.geometry import Point -import cartopy from shapely.prepared import prep -import glob -from math import ceil -import time -import dask -import os - - -#global num_iter, n_samples, data, ds, ht_var_name, tau_var_name, k, height_or_pressure -def cloud_regime_analysis(adf, wasserstein_or_euclidean = "euclidean", data_product='all', premade_cloud_regimes=None, lat_range=None, lon_range=None, only_ocean_or_land=False): +try: + from numba import njit +except ImportError: + + def njit(func=None): + if func is None: + return njit + # Issue a warning that numba is not available + warnings.warn( + "NumPy performance optimization using Numba is not available. 
" + "Fallback to standard Python execution.", + UserWarning, + ) + return func + + +# --- bpm refactor +# --- set up dataclass to get metadata per variable: +from dataclasses import dataclass + + +@dataclass(frozen=True) +class VariableNames: + product_name: str + data_var: str + ht_var: str + tau_var: str + obs_data_var: str + obs_ht_var: str + obs_tau_var: str + + +# Consolidate data into a single dictionary of dataclass objects +ALL_VARS = { + "FISCCP1_COSP": VariableNames( + product_name="ISCCP", + data_var="FISCCP1_COSP", + ht_var="cosp_prs", + tau_var="cosp_tau", + obs_data_var="n_pctaudist", + obs_ht_var="levtau", + obs_tau_var="levpc", + ), + "CLD_MISR": VariableNames( + product_name="MISR", + data_var="CLD_MISR", + ht_var="cosp_htmisr", + tau_var="cosp_tau", + obs_data_var="clMISR", + obs_ht_var="tau", + obs_tau_var="cth", + ), + "CLMODIS": VariableNames( + product_name="MODIS", + data_var="CLMODIS", + ht_var="cosp_prs", + tau_var="cosp_tau_modis", + obs_data_var="MODIS_CLD_HISTO", + obs_ht_var="COT", + obs_tau_var="PRES", + ), +} +# --- + + +def cloud_regime_analysis( + adf, + wasserstein_or_euclidean="euclidean", + premade_cloud_regimes=None, + lat_range=None, + lon_range=None, + only_ocean_or_land=False, +): """ - This script/function is designed to generate 2-D lat/lon maps of Cloud Regimes (CRs), as well as plots of the CR - centers themselves. It can fit data into CRs using either Wassertstein (AKA Earth Movers Distance) or the more conventional + This script/function is designed to generate 2-D lat/lon maps of Cloud Regimes (CRs), as well as plots of the CR + centers themselves. It can fit data into CRs using either Wasserstein (AKA Earth Movers Distance) or the more conventional Euclidean distance. To use this script, the user should add the appropriate COSP variables to the diag_var_list in the yaml file. - The appropriate variables are FISCCP1_COSP for ISCCP, CLD_MISR for MISR, and CLMODIS for MODIS. All three should be added to - diag_var_list if you wish to preform analysis on all three. The user can also specify to preform analysis for just one or for - all three of the data products (ISCCP, MODIS, and MISR) that there exists COSP output for. A user can also choose to use only - a specfic lat and lon range, or to use data only over water or over land. Lastly if a user has CRs that they have custom made, - these can be passed in and the script will fit data into them rather than the premade CRs that the script already points to. - There are a total of 6 sets of premade CRs, two for each data product. One set made with euclidean distance and one set made - with Wasserstein distance for ISCCP, MODIS, and MISR. Therefore when the wasserstein_or_euclidean variables is changes it is - important to undertand that not only the distance metric used to fit data into CRs is changing, but also the CRs themselves - unless the user is passing in a set of premade CRs with the premade_cloud_regimes variable. + The appropriate variables are FISCCP1_COSP for ISCCP, CLD_MISR for MISR, and CLMODIS for MODIS. All three should be added to + diag_var_list if you wish to perform analysis on all three. The user can also specify to perform analysis for just one or for + all three of the data products (ISCCP, MODIS, and MISR) that there exists COSP output for. A user can also choose to use only + a specfic lat and lon range, or to use data only over water or over land. 
Lastly, if a user has custom-made CRs,
+    these can be passed in, and the script will fit data into them rather than into the premade CRs that the script already points to.
+    There are a total of 6 sets of premade CRs, two for each data product: one set made with euclidean distance and one set made
+    with Wasserstein distance for each of ISCCP, MODIS, and MISR. Therefore, when the wasserstein_or_euclidean variable is changed, it is
+    important to understand that not only the distance metric used to fit data into CRs changes, but also the CRs themselves,
+    unless the user is passing in a set of premade CRs with the premade_cloud_regimes variable.

     Description of kwargs:
-    wasserstein_or_euclidean -> Whether to use wasserstein or euclidean distance to fit CRs, enter "wassertein" for wasserstein or
+    wasserstein_or_euclidean -> Whether to use wasserstein or euclidean distance to fit CRs, enter "wasserstein" for wasserstein or
                                 "euclidean" for euclidean. This also changes the default CRs that data is fit into from ones created
-                                with kmeans using euclidean distance to ones using kmeans with wassertein distance. Default is euclidean distance.
-    data_product -> Which data product to preform analysis for. Enter "ISCCP", "MODIS", "MISR" or "all". Default is "all"
-    premade_cloud_regimes -> If the user wishes to use custom CRs rather than the pre-loaded ones, enter them here as a path to a numpy
+                                with kmeans using euclidean distance to ones using kmeans with wasserstein distance. Default is euclidean distance.
+    premade_cloud_regimes -> If the user wishes to use custom CRs rather than the pre-loaded ones, enter them here as a path to a numpy
                              array of shape (k, n_tau_bins * n_pressure_bins)
     lat_range -> Range of latitudes to use, entered as a list, Ex. [-30,30]. Default is to use all available latitudes
     lon_range -> Range of longitudes to use, entered as a list, Ex. [-90,90]. Default is to use all available longitudes
-    only_ocean_or_land -> Set to "O" to preform analysis with only points over water, "L" for only points over land, or False
+    only_ocean_or_land -> Set to "O" to perform analysis with only points over water, "L" for only points over land, or False
                           to use data over land and water. 
Default is False """ - - global k, ht_var_name, tau_var_name, var, mat, mat_b, mat_o dask.config.set({"array.slicing.split_large_chunks": False}) - # Compute cluster labels from precomputed cluster centers with appropriate distance - def precomputed_clusters(mat, cl, wasserstein_or_euclidean, ds): - - if wasserstein_or_euclidean == 'euclidean': - cluster_dists = np.sum((mat[:,:,None] - cl.T[None,:,:])**2, axis = 1) - cluster_labels_temp = np.argmin(cluster_dists, axis = 1) - - if wasserstein_or_euclidean == 'wasserstein': - - # A function to convert mat into the form required for the EMD calculation - @njit() - def stacking(position_matrix, centroids): - centroid_list = [] - - for i in range(len(centroids)): - x = np.empty((3,len(mat[0]))).T - x[:,0] = centroids[i] - x[:,1] = position_matrix[0] - x[:,2] = position_matrix[1] - centroid_list.append(x) - - return centroid_list - - # setting shape - n1 = len(ds[tau_var_name]) - n2 = len(ds[ht_var_name]) - - # Calculating the max distance between two points to be used as hyperparameter in EMD - # This is not necesarily the only value for this variable that can be used, see Wasserstein documentation - # on R hyper-parameter for more information - R = (n1**2+n2**2)**0.5 - - # Creating a flattened position matrix to pass wasersstein.PairwiseEMD - position_matrix = np.zeros((2,n1,n2)) - position_matrix[0] = np.tile(np.arange(n2),(n1,1)) - position_matrix[1] = np.tile(np.arange(n1),(n2,1)).T - position_matrix = position_matrix.reshape(2,-1) - - # Initialising wasserstein.PairwiseEMD - emds = wasserstein.PairwiseEMD(R = R, norm=True, dtype=np.float32, verbose=1, num_threads=162) - - # Rearranging mat to be in the format necesary for wasserstein.PairwiseEMD - events = stacking(position_matrix, mat) - centroid_list = stacking(position_matrix, cl) - emds(events, centroid_list) - print(" -Calculating Wasserstein distances") - print(" -Warning: This can be slow, but scales very well with additional processors") - distances = emds.emds() - labels = np.argmin(distances, axis=1) - - cluster_labels_temp = np.argmin(distances, axis=1) - - return cluster_labels_temp - - # This function is no longer used, no need to check - # Plot the CR cluster centers - def plot_hists(cl, cluster_labels, ht_var_name, tau_var_name, adf): - #defining number of clusters - k = len(cl) - - # setting up plots - ylabels = ds[ht_var_name].values - xlabels = ds[tau_var_name].values - X2,Y2 = np.meshgrid(np.arange(len(xlabels)+1), np.arange(len(ylabels)+1)) - p = [0,0.2,1,2,3,4,6,8,10,15,99] - cmap = mpl.colors.ListedColormap(['white', (0.19215686274509805, 0.25098039215686274, 0.5607843137254902), (0.23529411764705882, 0.3333333333333333, 0.6313725490196078), (0.32941176470588235, 0.5098039215686274, 0.6980392156862745), (0.39215686274509803, 0.6, 0.43137254901960786), (0.44313725490196076, 0.6588235294117647, 0.21568627450980393), (0.4980392156862745, 0.6784313725490196, 0.1843137254901961), (0.5725490196078431, 0.7137254901960784, 0.16862745098039217), (0.7529411764705882, 0.8117647058823529, 0.2), (0.9568627450980393, 0.8980392156862745,0.1607843137254902)]) - norm = mpl.colors.BoundaryNorm(p,cmap.N) - plt.rcParams.update({'font.size': 12}) - fig_height = 1 + 10/3 * ceil(k/3) - fig, ax = plt.subplots(figsize = (17, fig_height), ncols=3, nrows=ceil(k/3), sharex='all', sharey = True) - - aa = ax.ravel() - boundaries = p - norm = mpl.colors.BoundaryNorm(boundaries, cmap.N, clip=True) - aa[1].invert_yaxis() - - # creating weights area for area weighted RFOs - weights = 
cluster_labels.stack(z=('time','lat','lon')).lat.values - weights = np.cos(np.deg2rad(weights)) - weights = weights[valid_indicies] - indicies = np.arange(len(mat)) - - # Plotting each cluster center - for i in range (k): - - # Area Weighted relative Frequency of occurence calculation - total_rfo_num = cluster_labels == i - total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat))) - total_rfo_denom = cluster_labels >= 0 - total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat))) - total_rfo = total_rfo_num / total_rfo_denom * 100 - total_rfo = total_rfo.values - - # Area weighting each histogram belonging to a cluster and taking the mean - # if clustering was preformed with wasserstein distance and area weighting on, mean of i = cl[i], however if clustering was preformed with - # conventional kmeans or wasseerstein without weighting, these two will not be equal - indicies_i = indicies[np.where(cluster_labels_temp == i)] - mean = mat[indicies_i] * weights[indicies_i][:,np.newaxis] - if len(indicies_i) > 0: mean = np.sum(mean, axis=0) / np.sum(weights[indicies_i]) - else: mean = np.zeros(len(xlabels)*len(ylabels)) - - mean = mean.reshape(len(xlabels),len(ylabels)).T # reshaping into original histogram shape - if np.max(mean) <= 1: # Converting fractional data to percent to plot properly - mean *= 100 - - im = aa[i].pcolormesh(X2,Y2,mean,norm=norm,cmap=cmap) - aa[i].set_title(f"CR {i+1}, RFO = {np.round(total_rfo,1)}%") - - # setting titles, labels, etc - if data == "MISR": height_or_pressure = 'h' - else: height_or_pressure = 'p' - if height_or_pressure == 'p': fig.supylabel(f'Cloud-top Pressure ({ds[ht_var_name].units})', fontsize = 12, x = 0.09 ) - if height_or_pressure == 'h': fig.supylabel(f'Cloud-top Height ({ds[ht_var_name].units})', fontsize = 12, x = 0.09 ) - # fig.supxlabel('Optical Depth', fontsize = 12, y=0.26 ) - cbar_ax = fig.add_axes([0.95, 0.38, 0.045, 0.45]) - cb = fig.colorbar(im, cax=cbar_ax, ticks=p) - cb.set_label(label='Cloud Cover (%)', size =10) - cb.ax.tick_params(labelsize=9) - #aa[6].set_position([0.399, 0.125, 0.228, 0.215]) - #aa[6].set_position([0.33, 0.117, 0.36, 0.16]) - #aa[-2].remove() - - bbox = aa[1].get_position() - p1 = bbox.p1 - p0 = bbox.p0 - fig.suptitle(f'{data} Cloud Regimes', x=0.5, y=p1[1]+(1/fig_height * 0.5), fontsize=15) - - bbox = aa[-2].get_position() - p1 = bbox.p1 - p0 = bbox.p0 - fig.supxlabel('Optical Depth', fontsize = 12, y=p0[1]-(1/fig_height * 0.5) ) - - - # Removing extra plots - for i in range(ceil(k/3)*3-k): - aa[-(i+1)].remove() - save_path = adf.plot_location[0] + f'/{data}_CR_centers' - plt.savefig(save_path) - - if adf.create_html: - adf.add_website_data(save_path + ".png", var, adf.get_baseline_info("cam_case_name")) - - - # Plot the CR centers of obs, baseline and test case - def plot_hists_baseline(cl, cluster_labels, cluster_labels_o, ht_var_name, tau_var_name, adf): - # #defining number of clusters - k = len(cl) - - # setting up plots - ylabels = ds[ht_var_name].values - xlabels = ds[tau_var_name].values - X2,Y2 = np.meshgrid(np.arange(len(xlabels)+1), np.arange(len(ylabels)+1)) - p = [0,0.2,1,2,3,4,6,8,10,15,99] - cmap = mpl.colors.ListedColormap(['white', (0.19215686274509805, 0.25098039215686274, 0.5607843137254902), (0.23529411764705882, 0.3333333333333333, 0.6313725490196078), (0.32941176470588235, 0.5098039215686274, 0.6980392156862745), (0.39215686274509803, 0.6, 0.43137254901960786), (0.44313725490196076, 0.6588235294117647, 0.21568627450980393), (0.4980392156862745, 
0.6784313725490196, 0.1843137254901961), (0.5725490196078431, 0.7137254901960784, 0.16862745098039217), (0.7529411764705882, 0.8117647058823529, 0.2), (0.9568627450980393, 0.8980392156862745,0.1607843137254902)]) - norm = mpl.colors.BoundaryNorm(p,cmap.N) - plt.rcParams.update({'font.size': 14}) - fig_height = (1 + 10/3 * ceil(k/3))*3 - fig, ax = plt.subplots(figsize = (17, fig_height), ncols=3, nrows=k, sharex='all', sharey = True) - - aa = ax.ravel() - boundaries = p - norm = mpl.colors.BoundaryNorm(boundaries, cmap.N, clip=True) - if data != 'MISR': aa[1].invert_yaxis() - - # creating weights area for area weighted RFOs - weights = cluster_labels.stack(z=('time','lat','lon')).lat.values - weights = np.cos(np.deg2rad(weights)) - weights = weights[valid_indicies] - indicies = np.arange(len(mat)) - - for i in range(k): - - im = ax[i,0].pcolormesh(X2,Y2,cl[i].reshape(len(xlabels),len(ylabels)).T,norm=norm,cmap=cmap) - ax[i,0].set_title(f" Observation CR {i+1}") - - # Plotting each cluster center (baseline) - for i in range (k): - # Area Weighted relative Frequency of occurence calculation - total_rfo_num = cluster_labels_b == i - total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels_b.lat))) - total_rfo_denom = cluster_labels_b >= 0 - total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels_b.lat))) - total_rfo = total_rfo_num / total_rfo_denom * 100 - total_rfo = total_rfo.values - - # Area weighting each histogram belonging to a cluster and taking the mean - # if clustering was preformed with wasserstein distance and area weighting on, mean of i = cl[i], however if clustering was preformed with - # conventional kmeans or wasseerstein without weighting, these two will not be equal - indicies_i = indicies[np.where(cluster_labels_temp_b == i)] - mean = mat_b[indicies_i] * weights[indicies_i][:,np.newaxis] - if len(indicies_i) > 0: mean = np.sum(mean, axis=0) / np.sum(weights[indicies_i]) - else: mean = np.zeros(len(xlabels)*len(ylabels)) - - mean = mean.reshape(len(xlabels),len(ylabels)).T # reshaping into original histogram shape - if np.max(mean) <= 1: # Converting fractional data to percent to plot properly - mean *= 100 - - im = ax[i,1].pcolormesh(X2,Y2,mean,norm=norm,cmap=cmap) - ax[i,1].set_title(f"Baseline Case CR {i+1}, RFO = {np.round(total_rfo,1)}%") - - # Plotting each cluster center (test_case) - for i in range (k): - - # Area Weighted relative Frequency of occurence calculation - total_rfo_num = cluster_labels == i - total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat))) - total_rfo_denom = cluster_labels >= 0 - total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat))) - total_rfo = total_rfo_num / total_rfo_denom * 100 - total_rfo = total_rfo.values - - # Area weighting each histogram belonging to a cluster and taking the mean - # if clustering was preformed with wasserstein distance and area weighting on, mean of i = cl[i], however if clustering was preformed with - # conventional kmeans or wasseerstein without weighting, these two will not be equal - indicies_i = indicies[np.where(cluster_labels_temp == i)] - mean = mat[indicies_i] * weights[indicies_i][:,np.newaxis] - if len(indicies_i) > 0: mean = np.sum(mean, axis=0) / np.sum(weights[indicies_i]) - else: mean = np.zeros(len(xlabels)*len(ylabels)) - - mean = mean.reshape(len(xlabels),len(ylabels)).T # reshaping into original histogram shape - if np.max(mean) <= 1: # Converting fractional data to percent to plot properly - mean *= 100 - - 
im = ax[i,2].pcolormesh(X2,Y2,mean,norm=norm,cmap=cmap) - ax[i,2].set_title(f"Test Case CR {i+1}, RFO = {np.round(total_rfo,1)}%") - - # setting titles, labels, etc - if data == "MODIS": - ylabels = [0, 180, 310, 440, 560, 680, 800, 1000] - xlabels = [0, 0.3, 1.3, 3.6, 9.4, 23, 60, 150] - ax[0,0].set_yticks(np.arange(8)) - ax[0,0].set_xticks(np.arange(8)) - ax[0,0].set_yticklabels(ylabels) - ax[0,0].set_xticklabels(xlabels) - xticks = ax[0,0].xaxis.get_major_ticks() - xticks[0].set_visible(False) - xticks[-1].set_visible(False) - - if data == "MISR": - xlabels = [0.2, 0.8, 2.4, 6.5, 16.2, 41.5, 100] - ylabels = [ 0.25,0.75,1.25,1.75,2.25,2.75,3.5, 4.5, 6,8, 10, 12, 14, 16, 20 ] - ax[0,0].set_yticks(np.arange(0,16,2)+0.5) - ax[0,0].set_yticklabels(ylabels[0::2]) - ax[0,0].set_xticks(np.array([1,2,3,4,5,6,7]) -0.5) - ax[0,0].set_xticklabels(xlabels, fontsize = 16) - xticks = ax[0,0].xaxis.get_major_ticks() - xticks[0].set_visible(False) - xticks[-1].set_visible(False) - - if data == 'ISCCP': - xlabels = [ 0, 1.3, 3.6, 9.4, 22.6, 60.4, 450 ] - ylabels = [ 10, 180, 310, 440, 560, 680, 800, 1025] - yticks = aa[i].get_yticks().tolist() - xticks = aa[i].get_xticks().tolist() - aa[i].set_yticks(yticks) - aa[i].set_xticks(xticks) - aa[i].set_yticklabels(ylabels) - aa[i].set_xticklabels(xlabels) - xticks = aa[i].xaxis.get_major_ticks() - xticks[0].label1.set_visible(False) - xticks[-1].label1.set_visible(False) - - - if data == "MISR": height_or_pressure = 'h' - else: height_or_pressure = 'p' - if height_or_pressure == 'p': fig.supylabel(f'Cloud-top Pressure ({ds[ht_var_name].units})', x = 0.07 ) - if height_or_pressure == 'h': fig.supylabel(f'Cloud-top Height ({ds[ht_var_name].units})', x = 0.07 ) - - if data == "MODIS": - ylabels = [0, 180, 310, 440, 560, 680, 800, 1000] - xlabels = [0, 0.3, 1.3, 3.6, 9.4, 23, 60, 150] - if data == "MISR": - x=1 - - - - # fig.supxlabel('Optical Depth', fontsize = 12, y=0.26 ) - cbar_ax = fig.add_axes([0.95, 0.38, 0.045, 0.45]) - cb = fig.colorbar(im, cax=cbar_ax, ticks=p) - cb.set_label(label='Cloud Cover (%)', size =10) - cb.ax.tick_params(labelsize=9) - #aa[6].set_position([0.399, 0.125, 0.228, 0.215]) - #aa[6].set_position([0.33, 0.117, 0.36, 0.16]) - #aa[-2].remove() - - bbox = aa[1].get_position() - p1 = bbox.p1 - p0 = bbox.p0 - fig.suptitle(f'{data} Cloud Regimes', x=0.5, y=p1[1]+(1/fig_height * 0.5)+0.007, fontsize=18) - - bbox = aa[-2].get_position() - p1 = bbox.p1 - p0 = bbox.p0 - fig.supxlabel('Optical Depth', y=p0[1]-(1/fig_height * 0.5)-0.007 ) - - - save_path = adf.plot_location[0] + f'/{data}_CR_centers' - plt.savefig(save_path) - - if adf.create_html: - adf.add_website_data(save_path + ".png", var, adf.get_baseline_info("cam_case_name")) - - # Closing the figure - plt.close() - - - # Plot the CR centers of obs and test case - def plot_hists_obs(cl, cluster_labels, cluster_labels_o, ht_var_name, tau_var_name, adf): - #defining number of clusters - k = len(cl) - # setting up plots - ylabels = ds[ht_var_name].values - xlabels = ds[tau_var_name].values - X2,Y2 = np.meshgrid(np.arange(len(xlabels)+1), np.arange(len(ylabels)+1)) - p = [0,0.2,1,2,3,4,6,8,10,15,99] - cmap = mpl.colors.ListedColormap(['white', (0.19215686274509805, 0.25098039215686274, 0.5607843137254902), (0.23529411764705882, 0.3333333333333333, 0.6313725490196078), (0.32941176470588235, 0.5098039215686274, 0.6980392156862745), (0.39215686274509803, 0.6, 0.43137254901960786), (0.44313725490196076, 0.6588235294117647, 0.21568627450980393), (0.4980392156862745, 0.6784313725490196, 
0.1843137254901961), (0.5725490196078431, 0.7137254901960784, 0.16862745098039217), (0.7529411764705882, 0.8117647058823529, 0.2), (0.9568627450980393, 0.8980392156862745,0.1607843137254902)]) - norm = mpl.colors.BoundaryNorm(p,cmap.N) - plt.rcParams.update({'font.size': 14}) - fig_height = (1 + 10/3 * ceil(k/3))*3 - fig, ax = plt.subplots(figsize = (12, fig_height), ncols=2, nrows=k, sharex='all', sharey = True) - - aa = ax.ravel() - boundaries = p - norm = mpl.colors.BoundaryNorm(boundaries, cmap.N, clip=True) - if data != 'MISR': aa[1].invert_yaxis() - - # creating weights area for area weighted RFOs - weights = cluster_labels.stack(z=('time','lat','lon')).lat.values - weights = np.cos(np.deg2rad(weights)) - weights = weights[valid_indicies] - indicies = np.arange(len(mat)) - - # ax[0,0].set_xticklabels(xlabels) - # ax[0,0].set_yticklabels(ylabels) - - for i in range(k): - # Area Weighted relative Frequency of occurence calculation - total_rfo_num = cluster_labels_o == i - total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels_o.lat))) - total_rfo_denom = cluster_labels_o >= 0 - total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels_o.lat))) - total_rfo = total_rfo_num / total_rfo_denom * 100 - total_rfo = total_rfo.values - - im = ax[i,0].pcolormesh(X2,Y2,cl[i].reshape(len(xlabels),len(ylabels)).T,norm=norm,cmap=cmap) - ax[i,0].set_title(f" Observation CR {i+1}, RFO = {np.round(total_rfo,1)}%") - - # Plotting each cluster center (test_case) - for i in range (k): - - # Area Weighted relative Frequency of occurence calculation - total_rfo_num = cluster_labels == i - total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat))) - total_rfo_denom = cluster_labels >= 0 - total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat))) - total_rfo = total_rfo_num / total_rfo_denom * 100 - total_rfo = total_rfo.values - - # Area weighting each histogram belonging to a cluster and taking the mean - # if clustering was preformed with wasserstein distance and area weighting on, mean of i = cl[i], however if clustering was preformed with - # conventional kmeans or wasseerstein without weighting, these two will not be equal - indicies_i = indicies[np.where(cluster_labels_temp == i)] - mean = mat[indicies_i] * weights[indicies_i][:,np.newaxis] - if len(indicies_i) > 0: mean = np.sum(mean, axis=0) / np.sum(weights[indicies_i]) - else: mean = np.zeros(len(xlabels)*len(ylabels)) - - mean = mean.reshape(len(xlabels),len(ylabels)).T # reshaping into original histogram shape - if np.max(mean) <= 1: # Converting fractional data to percent to plot properly - mean *= 100 - - im = ax[i,1].pcolormesh(X2,Y2,mean,norm=norm,cmap=cmap) - ax[i,1].set_title(f"Test Case CR {i+1}, RFO = {np.round(total_rfo,1)}%") - - if data == "MODIS": - ylabels = [0, 180, 310, 440, 560, 680, 800, 1000] - xlabels = [0, 0.3, 1.3, 3.6, 9.4, 23, 60, 150] - ax[0,0].set_yticks(np.arange(8)) - ax[0,0].set_xticks(np.arange(8)) - ax[0,0].set_yticklabels(ylabels) - ax[0,0].set_xticklabels(xlabels) - xticks = ax[0,0].xaxis.get_major_ticks() - xticks[0].set_visible(False) - xticks[-1].set_visible(False) - - if data == "MISR": - xlabels = [0.2, 0.8, 2.4, 6.5, 16.2, 41.5, 100] - ylabels = [ 0.25,0.75,1.25,1.75,2.25,2.75,3.5, 4.5, 6,8, 10, 12, 14, 16, 20 ] - ax[0,0].set_yticks(np.arange(0,16,2)+0.5) - ax[0,0].set_yticklabels(ylabels[0::2]) - ax[0,0].set_xticks(np.array([1,2,3,4,5,6,7]) -0.5) - ax[0,0].set_xticklabels(xlabels, fontsize = 16) - xticks = 
ax[0,0].xaxis.get_major_ticks() - xticks[0].set_visible(False) - xticks[-1].set_visible(False) - - if data == 'ISCCP': - xlabels = [ 0, 1.3, 3.6, 9.4, 22.6, 60.4, 450 ] - ylabels = [ 10, 180, 310, 440, 560, 680, 800, 1025] - yticks = aa[i].get_yticks().tolist() - xticks = aa[i].get_xticks().tolist() - aa[i].set_yticks(yticks) - aa[i].set_xticks(xticks) - aa[i].set_yticklabels(ylabels) - aa[i].set_xticklabels(xlabels) - xticks = aa[i].xaxis.get_major_ticks() - xticks[0].label1.set_visible(False) - xticks[-1].label1.set_visible(False) - - # setting titles, labels, etc - if data == "MISR": height_or_pressure = 'h' - else: height_or_pressure = 'p' - if height_or_pressure == 'p': fig.supylabel(f'Cloud-top Pressure ({ds[ht_var_name].units})', x = 0.05 ) - if height_or_pressure == 'h': fig.supylabel(f'Cloud-top Height ({ds[ht_var_name].units})', x = 0.05) - # fig.supxlabel('Optical Depth', fontsize = 12, y=0.26 ) - cbar_ax = fig.add_axes([0.95, 0.38, 0.045, 0.45]) - cb = fig.colorbar(im, cax=cbar_ax, ticks=p) - cb.set_label(label='Cloud Cover (%)') - # cb.ax.tick_params(labelsize=9) - - - bbox = aa[1].get_position() - p1 = bbox.p1 - p0 = bbox.p0 - fig.suptitle(f'{data} Cloud Regimes', x=0.5, y=p1[1]+(1/fig_height * 0.5)+0.007, fontsize=18) - - bbox = aa[-2].get_position() - p1 = bbox.p1 - p0 = bbox.p0 - fig.supxlabel('Optical Depth', y=p0[1]-(1/fig_height * 0.5)-0.007 ) - - save_path = adf.plot_location[0] + f'/{data}_CR_centers' - plt.savefig(save_path) - - if adf.create_html: - adf.add_website_data(save_path + ".png", var, case_name = None, multi_case=True) - - # Closing the figure - plt.close() - - # Plot LatLon plots of the frequency of occrence of the baseline/obs and test case - def plot_rfo_obs_base_diff(cluster_labels, cluster_labels_d, adf): - - COLOR = 'black' - mpl.rcParams['text.color'] = COLOR - mpl.rcParams['axes.labelcolor'] = COLOR - mpl.rcParams['xtick.color'] = COLOR - mpl.rcParams['ytick.color'] = COLOR - plt.rcParams.update({'font.size': 13}) - plt.rcParams['figure.dpi'] = 500 + # Plot LatLon plots of the frequency of occrence of the baseline/obs and test case + def plot_rfo_obs_base_diff(cluster_labels, cluster_labels_d, adf, field=None): + k = cluster_labels.attrs.get("k") + COLOR = "black" + mpl.rcParams["text.color"] = COLOR + mpl.rcParams["axes.labelcolor"] = COLOR + mpl.rcParams["xtick.color"] = COLOR + mpl.rcParams["ytick.color"] = COLOR + plt.rcParams.update({"font.size": 13}) + plt.rcParams["figure.dpi"] = 500 fig_height = 7 # Comparing obs or baseline? if adf.compare_obs == True: - obs_or_base = 'Observation' + obs_or_base = "Observation" else: - obs_or_base = 'Baseline' + obs_or_base = "Baseline" for cluster in range(k): - fig, ax = plt.subplots(ncols=2, nrows=2, subplot_kw={'projection': ccrs.PlateCarree()}, figsize = (12,fig_height))#, sharex='col', sharey='row') + fig, ax = plt.subplots( + ncols=2, + nrows=2, + subplot_kw={"projection": ccrs.PlateCarree()}, + figsize=(12, fig_height), + ) plt.subplots_adjust(wspace=0.08, hspace=0.002) aa = ax.ravel() - + # Calculating and plotting rfo of baseline/obs - X, Y = np. 
meshgrid(cluster_labels_d.lon,cluster_labels_d.lat) - rfo_d = np.sum(cluster_labels_d==cluster, axis=0) / np.sum(cluster_labels_d >= 0, axis=0) * 100 + X, Y = np.meshgrid(cluster_labels_d.lon, cluster_labels_d.lat) + rfo_d = ( + np.sum(cluster_labels_d == cluster, axis=0) + / np.sum(cluster_labels_d >= 0, axis=0) + * 100 + ) aa[0].set_extent([-180, 180, -90, 90]) aa[0].coastlines() - mesh = aa[0].pcolormesh(X, Y, rfo_d, transform=ccrs.PlateCarree(), rasterized = True, cmap="GnBu",vmin=0,vmax=100) - total_rfo_num = cluster_labels_d == cluster - total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels_d.lat))) + mesh = aa[0].pcolormesh( + X, + Y, + rfo_d, + transform=ccrs.PlateCarree(), + rasterized=True, + cmap="GnBu", + vmin=0, + vmax=100, + ) + total_rfo_num = cluster_labels_d == cluster + total_rfo_num = np.sum( + total_rfo_num * np.cos(np.deg2rad(cluster_labels_d.lat)) + ) total_rfo_denom = cluster_labels_d >= 0 - total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels_d.lat))) - total_rfo_d = total_rfo_num / total_rfo_denom * 100 - aa[0].set_title(f"{obs_or_base}, RFO = {round(float(total_rfo_d),1)}", pad=4) + total_rfo_denom = np.sum( + total_rfo_denom * np.cos(np.deg2rad(cluster_labels_d.lat)) + ) + total_rfo_d = total_rfo_num / total_rfo_denom * 100 + aa[0].set_title( + f"{obs_or_base}, RFO = {round(float(total_rfo_d),1)}", pad=4 + ) # Calculating and plotting rfo of test_case - X, Y = np. meshgrid(cluster_labels.lon,cluster_labels.lat) - rfo = np.sum(cluster_labels==cluster, axis=0) / np.sum(cluster_labels >= 0, axis=0) * 100 + X, Y = np.meshgrid(cluster_labels.lon, cluster_labels.lat) + rfo = ( + np.sum(cluster_labels == cluster, axis=0) + / np.sum(cluster_labels >= 0, axis=0) + * 100 + ) aa[1].set_extent([-180, 180, -90, 90]) aa[1].coastlines() - mesh = aa[1].pcolormesh(X, Y, rfo, transform=ccrs.PlateCarree(), rasterized = True, cmap="GnBu",vmin=0,vmax=100) - total_rfo_num = cluster_labels == cluster - total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat))) + mesh = aa[1].pcolormesh( + X, + Y, + rfo, + transform=ccrs.PlateCarree(), + rasterized=True, + cmap="GnBu", + vmin=0, + vmax=100, + ) + total_rfo_num = cluster_labels == cluster + total_rfo_num = np.sum( + total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat)) + ) total_rfo_denom = cluster_labels >= 0 - total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat))) - total_rfo = total_rfo_num / total_rfo_denom * 100 + total_rfo_denom = np.sum( + total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat)) + ) + total_rfo = total_rfo_num / total_rfo_denom * 100 aa[1].set_title(f"Test Case, RFO = {round(float(total_rfo),1)}", pad=4) # Making colorbar - cax = fig.add_axes([aa[1].get_position().x1+0.01,aa[1].get_position().y0,0.02,aa[1].get_position().height]) - cb = plt.colorbar(mesh, cax=cax) - cb.set_label(label = 'RFO (%)') + cax = fig.add_axes( + [ + aa[1].get_position().x1 + 0.01, + aa[1].get_position().y0, + 0.02, + aa[1].get_position().height, + ] + ) + cb = plt.colorbar(mesh, cax=cax) + cb.set_label(label="RFO (%)") # Calculating and plotting difference # If observation/baseline is a higher resolution interpolate from obs/baseline to CAM grid - if len(cluster_labels_d.lat) * len(cluster_labels_d.lon) > len(cluster_labels.lat) * len(cluster_labels.lon): + if len(cluster_labels_d.lat) * len(cluster_labels_d.lon) > len( + cluster_labels.lat + ) * len(cluster_labels.lon): rfo_d = rfo_d.interp_like(rfo, method="nearest") - + # If CAM is a higher 
resolution interpolate from CAM to obs/baseline grid - if len(cluster_labels_d.lat) * len(cluster_labels_d.lon) <= len(cluster_labels.lat) * len(cluster_labels.lon): + if len(cluster_labels_d.lat) * len(cluster_labels_d.lon) <= len( + cluster_labels.lat + ) * len(cluster_labels.lon): rfo = rfo.interp_like(rfo_d, method="nearest") - X, Y = np. meshgrid(cluster_labels_d.lon,cluster_labels_d.lat) + X, Y = np.meshgrid(cluster_labels_d.lon, cluster_labels_d.lat) rfo_diff = rfo - rfo_d aa[2].set_extent([-180, 180, -90, 90]) aa[2].coastlines() - mesh = aa[2].pcolormesh(X, Y, rfo_diff, transform=ccrs.PlateCarree(), rasterized = True, cmap="coolwarm",vmin=-100,vmax=100) - total_rfo_num = cluster_labels == cluster - total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat))) + mesh = aa[2].pcolormesh( + X, + Y, + rfo_diff, + transform=ccrs.PlateCarree(), + rasterized=True, + cmap="coolwarm", + vmin=-100, + vmax=100, + ) + total_rfo_num = cluster_labels == cluster + total_rfo_num = np.sum( + total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat)) + ) total_rfo_denom = cluster_labels >= 0 - total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat))) - total_rfo = total_rfo_num / total_rfo_denom * 100 - aa[2].set_title(f"Test - {obs_or_base}, ΔRFO = {round(float(total_rfo-total_rfo_d),1)}", pad=4) - + total_rfo_denom = np.sum( + total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat)) + ) + total_rfo = total_rfo_num / total_rfo_denom * 100 + aa[2].set_title( + f"Test - {obs_or_base}, ΔRFO = {round(float(total_rfo-total_rfo_d),1)}", + pad=4, + ) # Setting yticks - aa[0].set_yticks([-60,-30,0,30,60], crs=ccrs.PlateCarree()) - aa[2].set_yticks([-60,-30,0,30,60], crs=ccrs.PlateCarree()) + aa[0].set_yticks([-60, -30, 0, 30, 60], crs=ccrs.PlateCarree()) + aa[2].set_yticks([-60, -30, 0, 30, 60], crs=ccrs.PlateCarree()) lat_formatter = LatitudeFormatter() aa[0].yaxis.set_major_formatter(lat_formatter) aa[2].yaxis.set_major_formatter(lat_formatter) - # making colorbar for diff plot - cax = fig.add_axes([aa[2].get_position().x1+0.01,aa[2].get_position().y0,0.02,aa[2].get_position().height]) - cb = plt.colorbar(mesh, cax=cax) - cb.set_label(label = 'ΔRFO (%)') - - # plotting x labels - aa[1].set_xticks([-120,-60,0,60,120,], crs=ccrs.PlateCarree()) + cax = fig.add_axes( + [ + aa[2].get_position().x1 + 0.01, + aa[2].get_position().y0, + 0.02, + aa[2].get_position().height, + ] + ) + cb = plt.colorbar(mesh, cax=cax) + cb.set_label(label="ΔRFO (%)") + + # plotting x labels + aa[1].set_xticks( + [ + -120, + -60, + 0, + 60, + 120, + ], + crs=ccrs.PlateCarree(), + ) lon_formatter = LongitudeFormatter(zero_direction_label=True) aa[1].xaxis.set_major_formatter(lon_formatter) - aa[2].set_xticks([-120,-60,0,60,120,], crs=ccrs.PlateCarree()) + aa[2].set_xticks( + [ + -120, + -60, + 0, + 60, + 120, + ], + crs=ccrs.PlateCarree(), + ) lon_formatter = LongitudeFormatter(zero_direction_label=True) aa[2].xaxis.set_major_formatter(lon_formatter) bbox = aa[1].get_position() p1 = bbox.p1 - plt.suptitle(f"CR{cluster+1} Relative Frequency of Occurence", y= p1[1]+(1/fig_height * 0.5))#, {round(cl[cluster,23],4)}") + plt.suptitle( + f"CR{cluster+1} Relative Frequency of Occurence", + y=p1[1] + (1 / fig_height * 0.5), + ) aa[-1].remove() - save_path = adf.plot_location[0] + f'/{data}_CR{cluster+1}_LatLon_mean' + save_path = adf.plot_location[0] + f"/{field}_CR{cluster+1}_LatLon_mean" plt.savefig(save_path) if adf.create_html: - adf.add_website_data(save_path + ".png", var, case_name = None, 
multi_case=True) - + adf.add_website_data( + save_path + ".png", field, case_name=None, multi_case=True + ) + # Closing the figure plt.close() - # This function is no longer used, no reason to check it - # Plot RFO maps of the CRss - def plot_rfo(cluster_labels, adf): - #defining number of clusters - - COLOR = 'black' - mpl.rcParams['text.color'] = COLOR - mpl.rcParams['axes.labelcolor'] = COLOR - mpl.rcParams['xtick.color'] = COLOR - mpl.rcParams['ytick.color'] = COLOR - plt.rcParams.update({'font.size': 10}) - fig_height = 2.2 * ceil(k/2) - plt.rcParams['figure.dpi'] = 500 - fig, ax = plt.subplots(ncols=2, nrows=int(k/2 + k%2), subplot_kw={'projection': ccrs.PlateCarree()}, figsize = (10,fig_height))#, sharex='col', sharey='row') - plt.subplots_adjust(wspace=0.13, hspace=0.05) - aa = ax.ravel() - - X, Y = np. meshgrid(ds.lon,ds.lat) - - # Plotting the rfo of each cluster - tot_rfo_sum = 0 - - for cluster in range(k): #range(0,k+1): - # Calculating rfo - rfo = np.sum(cluster_labels==cluster, axis=0) / np.sum(cluster_labels >= 0, axis=0) * 100 - # tca_explained = np.sum(cluster_labels == cluster) * np.sum(init_clusters[cluster]) / total_cloud_amnt * 100 - # tca_explained = round(float(tca_explained.values),1) - aa[cluster].set_extent([-180, 180, -90, 90]) - aa[cluster].coastlines() - mesh = aa[cluster].pcolormesh(X, Y, rfo, transform=ccrs.PlateCarree(), rasterized = True, cmap="GnBu",vmin=0,vmax=100) - #total_rfo = np.sum(cluster_labels==cluster) / np.sum(cluster_labels >= 0) * 100 - # total_rfo_num = np.sum(cluster_labels == cluster * np.cos(np.deg2rad(cluster_labels.lat))) - total_rfo_num = cluster_labels == cluster - total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat))) - total_rfo_denom = cluster_labels >= 0 - total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat))) - - total_rfo = total_rfo_num / total_rfo_denom * 100 - tot_rfo_sum += total_rfo - aa[cluster].set_title(f"CR {cluster+1}, RFO = {round(float(total_rfo),1)}", pad=4) - # aa[cluster].gridlines(draw_labels=True, dms=True, x_inline=False, y_inline=False) - # x_label_plot_list = [4,5,6] - # y_label_plot_list = [0,2,4,6] - # if cluster in x_label_plot_list: + ################################################################### + # MAIN + ################################################################### + # Checking if kwargs have been entered correctly + if wasserstein_or_euclidean not in ["euclidean", "wasserstein"]: + print( + ' WARNING: Invalid option for wasserstein_or_euclidean. Please enter "wasserstein" or "euclidean". Proceeding with default of euclidean distance' + ) + wasserstein_or_euclidean = "euclidean" + if premade_cloud_regimes != None: + if type(premade_cloud_regimes) != str: + print( + " WARNING: Invalid option for premade_cloud_regimes. Please enter a path to a numpy array of Cloud Regime centers of shape (n_clusters, n_dimensions_of_data). Proceeding with default clusters" + ) + premade_cloud_regimes = None + if lat_range != None: + if type(lat_range) != list or len(lat_range) != 2: + print( + " WARNING: Invalid option for lat_range. Please enter two values in square brackets sperated by a comma. Example: [-30,30]. Proceeding with entire latitude range" + ) + lat_range = None + if lon_range != None: + if type(lon_range) != list or len(lon_range) != 2: + print( + " WARNING: Invalid option for lon_range. Please enter two values in square brackets sperated by a comma. Example: [0,90]. 
Proceeding with entire longitude range"
+            )
+            lon_range = None
+    if only_ocean_or_land not in [False, "L", "O"]:
+        print(
+            ' WARNING: Invalid option for only_ocean_or_land. Please enter "L" for land only, "O" for ocean only. Set to False or leave blank for both land and water. Proceeding with default of False'
+        )
+        only_ocean_or_land = False

+    # NOTE: probably have to move into case loop
+    time_range = [
+        str(adf.get_cam_info("start_year")[0]),
+        str(adf.get_cam_info("end_year")[0]),
+    ]
+
+    # --- BPM refactor ---
+    # determine which variables to try
+    cr_vars = []
+    landfrac_present = "LANDFRAC" in adf.diag_var_list
+    print(f"Did we find LANDFRAC in the variable list: {landfrac_present}")
+    for field in adf.diag_var_list:
+        if field in ["FISCCP1_COSP", "CLD_MISR", "CLMODIS"]:
+            cr_vars.append(field)
+
+    # process each COSP cloud variable
+    for field in cr_vars:
+        print(f"WORK ON {field}")
+        cluster_spec = premade_cloud_regimes if premade_cloud_regimes is not None else wasserstein_or_euclidean
+
+        ht_var_name = ALL_VARS[field].ht_var
+        tau_var_name = ALL_VARS[field].tau_var
+        if adf.compare_obs:
+            ref_ht_var_name = ALL_VARS[field].obs_ht_var
+            ref_tau_var_name = ALL_VARS[field].obs_tau_var
+        else:
+            ref_ht_var_name = ALL_VARS[field].ht_var
+            ref_tau_var_name = ALL_VARS[field].tau_var
+
+        # GET REFERENCE DATA, use for all cases
+        ref_data = load_reference_data(adf, field)
+        if adf.compare_obs:
+            # ref_data should be a dataset in this case
+            # reference regime labels or cloud data (to be labeled)
+            if premade_cloud_regimes is None:
+                if wasserstein_or_euclidean == "wasserstein":
+                    ds_o = ref_data.emd_cluster_labels
+                else:
+                    ds_o = ref_data.euclidean_cluster_labels
+            else:
+                ds_o = ref_data[adf.variable_defaults[field]['obs_var_name']]
+        else:
+            ds_o = ref_data  # already a dataarray
+
+        for case_name in adf.data.case_names:
+            c_ts_da = adf.data.load_timeseries_da(case_name, field)
+            if c_ts_da is None:
+                print(
+                    f"\t WARNING: Variable {field} for case '{case_name}' returned None; skipping this variable."
+                )
+                skip_var = True
+                continue
+            else:
+                print(
+                    f"\t Loaded time series for {field} ==> {c_ts_da.shape = }, {c_ts_da.coords = }"
+                )
+            if "ncol" in c_ts_da.dims:
+                # Right now we remap to the fv09 grid because that is the mapping available.
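make_se_regridder and regrid_se_data_bilinear are defined elsewhere in this module; conceptually, applying a precomputed ESMF map file is a sparse matrix product. A rough sketch under that assumption (standard ESMF weight-file variables S, col, row, n_a, n_b, with 1-based indices):

import xarray as xr
from scipy import sparse

def apply_esmf_weights(weights_path, da, column_dim='ncol'):
    # da: (time, ncol) DataArray on the source (SE) grid
    w = xr.open_dataset(weights_path)
    mat = sparse.coo_matrix(
        (w['S'].values, (w['row'].values - 1, w['col'].values - 1)),
        shape=(w.sizes['n_b'], w.sizes['n_a']),
    ).tocsr()
    out = (mat @ da.transpose('time', column_dim).values.T).T  # (time, n_b)
    return out  # reshape to the destination (lat, lon) grid as needed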
+ # TODO: generalize; would save time to remap to sat data grid + print("Trigger regrid (ne30-to-fv09 ONLY)") + regrid_weights_file = Path( + "/glade/work/brianpm/mapping_ne30pg3_to_fv09_esmfbilin.nc" + ) + rg = make_se_regridder( + regrid_weights_file, Method="bilinear" + ) # algorithm needs to match + ds = regrid_se_data_bilinear( + rg, c_ts_da, column_dim_name="ncol" + ) + else: + ds = c_ts_da # assumption: already on lat-lon grid - #aa[7].set_title(f"Weathersdfasdfa State {i+1}, RFO = {round(float(total_rfo),1)}", pad=-40) - cb = plt.colorbar(mesh, ax = ax, anchor =(-0.28,0.83), shrink = 0.6) - cb.set_label(label = 'RFO (%)', labelpad=-3) + ##### DATA PRE-PROCESSING + # Adjusting lon to run from -180 to 180 if it doesnt already + if np.max(ds.lon) > 180: + ds.coords["lon"] = (ds.coords["lon"] + 180) % 360 - 180 + ds = ds.sortby(ds.lon) - x_ticks_indicies = np.array([-1,-2]) + # Selecting only points over ocean or points over land if only_ocean_or_land has been used + ds = apply_land_ocean_mask(ds, only_ocean_or_land, landfrac_present) + if ds is None: + return # Error occurred + # Turning dataset into a dataarray + if isinstance(ds, xr.Dataset): + ds = ds[field] + ds = spatial_subset(ds, lat_range, lon_range) + ds = temporal_subset(ds, time_range) + ds = select_valid_tau_height(ds, tau_var_name, ht_var_name) + ##### + + # CLUSTER CENTERS + cl = load_cluster_centers(adf, cluster_spec, field) + if cl is None: + print(f"Skipping cloud regime analysis for {field} due to failed cluster center loading.") + continue # Skip to the next variable in cr_vars + - if k%2 == 1: - aa[-1].remove() - x_ticks_indicies -= 1 - - #aa[-2].set_position([0.27, 0.11, 0.31, 0.15]) - - # plotting x labels on final two plots - aa[x_ticks_indicies[0]].set_xticks([-120,-60,0,60,120,], crs=ccrs.PlateCarree()) - lon_formatter = LongitudeFormatter(zero_direction_label=True) - aa[x_ticks_indicies[0]].xaxis.set_major_formatter(lon_formatter) - aa[x_ticks_indicies[1]].set_xticks([-120,-60,0,60,120,], crs=ccrs.PlateCarree()) - lon_formatter = LongitudeFormatter(zero_direction_label=True) - aa[x_ticks_indicies[1]].xaxis.set_major_formatter(lon_formatter) - - bbox = aa[1].get_position() - p1 = bbox.p1 - plt.suptitle(f"CR Relative Frequency of Occurence", x= 0.43, y= p1[1]+(1/fig_height * 0.5))#, {round(cl[cluster,23],4)}") - - # Saving - save_path = adf.plot_location[0] + f'/{data}_RFO' - plt.savefig(save_path) - - if adf.create_html: - adf.add_website_data(save_path + ".png", var, adf.get_baseline_info("cam_case_name")) - - # This function is no longer used, no reason to check it - # Plot RFO maps of the CRs - def plot_rfo_diff(cluster_labels, cluster_labels_o, adf): - - # Setting plot parameters - COLOR = 'black' - mpl.rcParams['text.color'] = COLOR - mpl.rcParams['axes.labelcolor'] = COLOR - mpl.rcParams['xtick.color'] = COLOR - mpl.rcParams['ytick.color'] = COLOR - plt.rcParams.update({'font.size': 10}) - fig_height = 2.2 * ceil(k/2) - fig, ax = plt.subplots(ncols=2, nrows=int(k/2 + k%2), subplot_kw={'projection': ccrs.PlateCarree()}, figsize = (10,fig_height))#, sharex='col', sharey='row') - plt.subplots_adjust(wspace=0.13, hspace=0.05) - aa = ax.ravel() - plt.rcParams['figure.dpi'] = 500 - - # CReating lat-lon mesh - X, Y = np. 
meshgrid(ds.lon,ds.lat) - - # Plotting the difference in relative frequency of occurence (rfo) of each cluster - for cluster in range(k): - - # Calculating rfo - rfo = np.sum(cluster_labels==cluster, axis=0) / np.sum(cluster_labels >= 0, axis=0) * 100 - rfo_o = np.sum(cluster_labels_o==cluster, axis=0) / np.sum(cluster_labels_o >= 0, axis=0) * 100 - - # If observation/baseline is a higher resolution interpolate from obs/baseline to CAM grid - if len(cluster_labels_o.lat) * len(cluster_labels_o.lon) > len(cluster_labels.lat) * len(cluster_labels.lon): - rfo_o = rfo_o.interp_like(rfo, method="nearest") + # COSP ISCCP data has one extra tau bin than the satellite data, and misr has an extra height bin. + # This checks roughly if we are comparing against the + # satellite data, and if so removes the extra tau or ht bin. + # If a user passes home made CRs from CESM data, no data will be removed + if ALL_VARS[field].product_name == "ISCCP" and cl.shape[1] == 42: + sel_dict = {tau_var_name: slice(np.min(ds[tau_var_name]) + 1e-11, None)} + ds = ds.sel(sel_dict) + print(f"\t Dropping smallest tau bin ({tau_var_name}) to be comparable with observational cloud regimes") + if ALL_VARS[field].product_name == "MISR" and cl.shape[1] == 105: + sel_dict = {ht_var_name: slice(np.min(ds[ht_var_name]) + 1e-11, None)} + ds = ds.sel(sel_dict) + print(f"\t Dropping lowest height bin ({ht_var_name}) to be comparable with observational cloud regimes") - # If CAM is a higher resolution interpolate from CAM to obs/baseline grid - if len(cluster_labels_o.lat) * len(cluster_labels_o.lon) <= len(cluster_labels.lat) * len(cluster_labels.lon): - rfo = rfo.interp_like(rfo_o, method="nearest") + # CASE CLUSTER LABELING: + cluster_labels = compute_cluster_labels(ds, tau_var_name, ht_var_name, cl, wasserstein_or_euclidean) + print(f"{case_name} {field} cluster labels calculated.") + + ref_opts = {"premade_cloud_regimes":premade_cloud_regimes, + "distance": wasserstein_or_euclidean, + "landsea": only_ocean_or_land, + "landfrac": landfrac_present, # need to deal with this better + "lat_range": lat_range, + "lon_range": lon_range, + "time_range": time_range, + "tau_name": ref_tau_var_name, + "ht_name": ref_ht_var_name, + "data": ALL_VARS[field].product_name + } + cluster_labels_ref = compute_ref_cluster_labels(adf, ds_o, field, ref_opts) + + # PLOTS + taucoord = ds[tau_var_name] + htcoord = ds[ht_var_name] + # let cluster_labels know the number of clusters: + cluster_labels.attrs['k'] = cl.shape[0] + # `plot_rfo_obs_base_diff` expects `cluster_labels_ref` to be latxlon + if adf.compare_obs: + plot_hists_obs( + field, cl, cluster_labels, cluster_labels_ref, ds, ds_o, ht_var_name, tau_var_name, htcoord, taucoord, adf + ) + plot_rfo_obs_base_diff(cluster_labels, cluster_labels_ref, adf, field=field) + else: + plot_hists_baseline( + field, + cl, + cluster_labels, + cluster_labels_ref, + ds, + ds_o, # only is ref histograms for simulation, right + ht_var_name, + tau_var_name, + htcoord, + taucoord, + adf, + ) + plot_rfo_obs_base_diff(cluster_labels, cluster_labels_ref, adf, field=field) + # ^^^ BPM refactor ^^^ + + +def compute_ref_cluster_labels(adf, ds_ref, field, opts): + if adf.compare_obs == True: + ds_o = ds_ref + obs_var = adf.variable_defaults[field]['obs_var_name'] + # Adjusting lon to run from -180 to 180 if it doesnt already + if np.max(ds_o.lon) > 180: + ds_o.coords["lon"] = (ds_o.coords["lon"] + 180) % 360 - 180 + ds_o = ds_o.sortby(ds_o.lon) + + # this landfrac_present is probably not for ref dataset. 
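apply_land_ocean_mask is defined elsewhere in this refactor; a minimal sketch consistent with how it is called here (LANDFRAC-based, with the cartopy fallback omitted):

def apply_land_ocean_mask(ds, only_ocean_or_land, landfrac_present):
    # only_ocean_or_land: False (no mask), 'L' (land only), or 'O' (ocean only)
    if only_ocean_or_land is False:
        return ds
    if landfrac_present and 'LANDFRAC' in ds:
        return ds.where(ds['LANDFRAC'] == (1 if only_ocean_or_land == 'L' else 0))
    # otherwise fall back to a cartopy-derived land mask (see create_land_mask)
    return ds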
+        ds_o = apply_land_ocean_mask(ds_o, opts['landsea'], opts['landfrac'])
+        if ds_o is None:
+            print("[CRA compute_ref_cluster_labels] reference data is None.")
+            return  # Error occurred
+        ds_o = spatial_subset(ds_o, opts['lat_range'], opts['lon_range'])  # bpm
+        if ds_o is None:
+            print("[CRA compute_ref_cluster_labels] reference data is None.")
+            return  # Error occurred
+
+        if opts['premade_cloud_regimes'] is None:
+            print(f"[CRA compute_ref_cluster_labels] {opts['premade_cloud_regimes'] = }")
+            cluster_labels_o = ds_o
+            cluster_labels_ref = cluster_labels_o.stack(
+                spacetime=("time", "lat", "lon")
+            ).unstack()  ## <- do we want to unstack here?
+        else:
+            print(f"[CRA compute_ref_cluster_labels] {opts['premade_cloud_regimes'] = }")
+            ds_o = select_valid_tau_height(ds_o, opts['tau_name'], opts['ht_name'])
+            cluster_labels_ref = finish_cluster_labels(ds_o, opts['tau_name'], opts['ht_name'], opts['cl'], opts['distance'])
+    else:
+        # Compare to simulation case.
+        ds_b = ds_ref
+        time_range_b = [
+            str(adf.get_baseline_info("start_year")),
+            str(adf.get_baseline_info("end_year")),
+        ]
+        landfrac_present = opts['landfrac']
+        # Adjusting lon to run from -180 to 180 if it doesn't already
+        if np.max(ds_b.lon) > 180:
+            ds_b.coords["lon"] = (ds_b.coords["lon"] + 180) % 360 - 180
+            ds_b = ds_b.sortby(ds_b.lon)
+
+        # this landfrac_present is probably not for ds_b
+        ds_b = apply_land_ocean_mask(ds_b, opts['landsea'], opts['landfrac'])
+        if ds_b is None:
+            return  # Error occurred
+        ds_b = spatial_subset(ds_b, opts['lat_range'], opts['lon_range'])
+        ds_b = temporal_subset(ds_b, time_range_b)
+
+        # Turning dataset into a dataarray
+        if isinstance(ds_b, xr.Dataset):
+            ds_b = ds_b[field]
+
+        ds_b = select_valid_tau_height(ds_b, opts['tau_name'], opts['ht_name'])
+
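# [illustrative sketch, not part of the patch] The "drop one edge bin" idiom used just
# below: slicing from just above the coordinate minimum removes exactly one bin and is
# robust if the coordinate is ever flipped, since xarray slices follow coordinate order.
# >>> import numpy as np, xarray as xr
# >>> da = xr.DataArray(np.ones(4), dims="cosp_tau", coords={"cosp_tau": [0.0, 0.3, 1.3, 3.6]})
# >>> da.sel(cosp_tau=slice(float(da.cosp_tau.min()) + 1e-11, None)).cosp_tau.values
# array([0.3, 1.3, 3.6])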
+        # COSP ISCCP data has one extra tau bin relative to the satellite data, and MISR has an extra height bin.
+        # This checks roughly whether we are comparing against the satellite data, and if so removes
+        # the extra tau or ht bin. If a user passes home-made CRs from CESM data, no data will be removed.
+        if opts['data'] == "ISCCP" and opts['cl'].shape[1] == 42:
+            # A slightly hacky way to drop the smallest tau bin, but robust in case tau is flipped in a future version
+            ds_b = ds_b.sel(
+                {opts['tau_name']: slice(np.min(ds_b[opts['tau_name']]) + 1e-11, np.inf)}
+            )
+            print(
+                "\t Dropping smallest tau bin to be comparable with observational cloud regimes"
+            )
+        if opts['data'] == "MISR" and opts['cl'].shape[1] == 105:
+            # A slightly hacky way to drop the lowest height bin, but robust in case height is flipped in a future version
+            ds_b = ds_b.sel(
+                {opts['ht_name']: slice(np.min(ds_b[opts['ht_name']]) + 1e-11, np.inf)}
+            )
+            print(
+                "\t Dropping lowest height bin to be comparable with observational cloud regimes"
+            )
+
+        cluster_labels_ref = finish_cluster_labels(ds_b, opts['tau_name'], opts['ht_name'], opts['cl'], opts['distance'])  # bpm new func
+    return cluster_labels_ref
+
+
+def precomputed_clusters(mat, cl, wasserstein_or_euclidean, ds, tau_var_name=None, ht_var_name=None):
+    """Compute cluster labels from precomputed cluster centers with the appropriate distance"""
+    if wasserstein_or_euclidean == "euclidean":
+        cluster_dists = np.sum((mat[:, :, None] - cl.T[None, :, :]) ** 2, axis=1)
+        cluster_labels_temp = np.argmin(cluster_dists, axis=1)
+    elif wasserstein_or_euclidean == "wasserstein":
+        # A function to convert mat into the form required for the EMD calculation
+        @njit()
+        def stacking(position_matrix, centroids):
+            centroid_list = []

+            for i in range(len(centroids)):
+                x = np.empty((3, len(mat[0]))).T
+                x[:, 0] = centroids[i]
+                x[:, 1] = position_matrix[0]
+                x[:, 2] = position_matrix[1]
+                centroid_list.append(x)
+
+            return centroid_list
+
+        # setting shape; the tau/height bin counts must be supplied by the caller
+        n1 = len(ds[tau_var_name])
+        n2 = len(ds[ht_var_name])
+
+        # Calculating the max distance between two points to be used as a hyperparameter in EMD
+        # This is not necessarily the only value for this variable that can be used, see the Wasserstein documentation
+        # on the R hyper-parameter for more information
+        R = (n1**2 + n2**2) ** 0.5
+
+        # Creating a flattened position matrix to pass to wasserstein.PairwiseEMD
+        position_matrix = np.zeros((2, n1, n2))
+        position_matrix[0] = np.tile(np.arange(n2), (n1, 1))
+        position_matrix[1] = np.tile(np.arange(n1), (n2, 1)).T
+        position_matrix = position_matrix.reshape(2, -1)
+
+        # Initializing wasserstein.PairwiseEMD
+        emds = wasserstein.PairwiseEMD(
+            R=R, norm=True, dtype=np.float32, verbose=1, num_threads=162
+        )
+
+        # Rearranging mat to be in the format necessary for wasserstein.PairwiseEMD
+        events = stacking(position_matrix, mat)
+        centroid_list = stacking(position_matrix, cl)
+        emds(events, centroid_list)
+        print("\t Calculating Wasserstein distances")
+        print(
+            "\t Warning: This can be slow, but scales very well with additional processors"
+        )
+        distances = emds.emds()
+        cluster_labels_temp = np.argmin(distances, axis=1)
+    else:
+        print("[CRA: precomputed_clusters] ERROR -- must specify Wasserstein or Euclidean.")
+        return
+    return cluster_labels_temp

-            # difference in RFO
-            rfo_diff = rfo - rfo_o

-            # Setting up subplots and plotting
-            aa[cluster].set_extent([-180, 180, -90, 90])
-            aa[cluster].coastlines()
-            mesh = aa[cluster].pcolormesh(X, Y, rfo_diff, transform=ccrs.PlateCarree(), rasterized = True, cmap="coolwarm",vmin=-100,vmax=100)
-
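# [illustrative sketch, not part of the patch] A minimal, self-contained example of the
# Euclidean branch of precomputed_clusters above: each flattened histogram is assigned
# the index of its nearest center. All *_demo names and values are synthetic stand-ins.
import numpy as np

_rng = np.random.default_rng(0)
mat_demo = _rng.random((5, 6))   # 5 synthetic histograms with 6 bins each
cl_demo = _rng.random((3, 6))    # 3 synthetic cluster centers
# squared Euclidean distance of every histogram to every center -> shape (5, 3)
dists_demo = np.sum((mat_demo[:, :, None] - cl_demo.T[None, :, :]) ** 2, axis=1)
labels_demo = np.argmin(dists_demo, axis=1)  # nearest-center label per histogram
assert labels_demo.shape == (5,)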
+def load_reference_data(adfobj, varname):
+    """Load reference data.

-            # Calucating area weighted rfo difference for the title of subplots
-            total_rfo_num = cluster_labels == cluster
-            total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat)))
-            total_rfo_denom = cluster_labels >= 0
-            total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat)))
-            total_rfo = total_rfo_num / total_rfo_denom * 100
+    Make usual ADF assumption that reference case could be simulation or observation.

-            total_rfo_num_o = cluster_labels_o == cluster
-            total_rfo_num_o = np.sum(total_rfo_num_o * np.cos(np.deg2rad(cluster_labels_o.lat)))
-            total_rfo_denom_o = cluster_labels_o >= 0
-            total_rfo_denom_o = np.sum(total_rfo_denom_o * np.cos(np.deg2rad(cluster_labels_o.lat)))
+    If compare_obs, returns an xr.Dataset,
+    otherwise returns a time series xr.DataArray.

-            total_rfo_o = total_rfo_num_o / total_rfo_denom_o * 100
+    """
+    base_name = adfobj.data.ref_case_label
+    ref_var_nam = adfobj.data.ref_var_nam[varname]  # should work for obs/sim
+    print(f"[CRA: load_reference_data] {base_name = }, {ref_var_nam = }")
+
+    if adfobj.compare_obs:
+        ocase = adfobj.data.ref_case_label
+        fils = adfobj.data.ref_var_loc.get(varname, None)
+        if not isinstance(fils, list):
+            fils = [fils]
+        ds = adfobj.data.load_dataset(fils)
+        if ds is None:
+            warnings.warn(f"\t WARNING: Failed to load reference data for {varname}")
+            return None
+        print(f"[CRA: load_reference_data] returning observation dataset")
+        return ds
+    else:
+        print(f"[CRA: load_reference_data] returning simulation dataarray")
+        return adfobj.data.load_reference_timeseries_da(varname)

-            # Setting title
-            aa[cluster].set_title(f"CR {cluster+1}, RFO Diff = {round(float(total_rfo-total_rfo_o),1)}", pad=4)
-            # Put latitude labels on even numbered subplots
-            if cluster % 2 == 0:
-                aa[cluster].set_yticks([-60,-30,0,30,60], crs=ccrs.PlateCarree())
-                lat_formatter = LatitudeFormatter()
-                aa[cluster].yaxis.set_major_formatter(lat_formatter)
+
+def load_cluster_centers(adf, cluster_spec: str | Path, variablename: str) -> np.ndarray | None:
+    """
+    Loads cluster center data from a specified source.

-            # Setting colorbar
-            cb = plt.colorbar(mesh, ax = ax, anchor =(-0.28,0.83), shrink = 0.6)
-            cb.set_label(label = 'Diff in RFO (%)', labelpad=-3)
+    Args:
+        cluster_spec: A string ('wasserstein', 'euclidean', or a file path)
+            or a Path object pointing to a .npy or .nc file.
+        variablename: The name of the variable to look up in ALL_VARS to
+            determine the data product name.
- # Removing extra subplot if k is an odd number - x_ticks_indicies = np.array([-1,-2]) - if k%2 == 1: - aa[-1].remove() - x_ticks_indicies -= 1 - - # plotting x labels on final two plots - aa[x_ticks_indicies[0]].set_xticks([-120,-60,0,60,120,], crs=ccrs.PlateCarree()) - lon_formatter = LongitudeFormatter(zero_direction_label=True) - aa[x_ticks_indicies[0]].xaxis.set_major_formatter(lon_formatter) - aa[x_ticks_indicies[1]].set_xticks([-120,-60,0,60,120,], crs=ccrs.PlateCarree()) - lon_formatter = LongitudeFormatter(zero_direction_label=True) - aa[x_ticks_indicies[1]].xaxis.set_major_formatter(lon_formatter) - - # Setting suptitle - bbox = aa[1].get_position() - p1 = bbox.p1 - plt.suptitle(f"CR Relative Frequency of Occurence", x= 0.43, y= p1[1]+(1/fig_height * 0.5))#, {round(cl[cluster,23],4)}") - - # Saving - save_path = adf.plot_location[0] + f'/{data}_RFO' - plt.savefig(save_path) - - if adf.create_html: - adf.add_website_data(save_path + ".png", var, multi_case=True) - - # Create a one hot matrix where lat lon coordinates are over land using cartopy - def create_land_mask(ds): + Returns: + A NumPy array containing the cluster center data, or None if an error occurs. + """ + if isinstance(cluster_spec, str): + if cluster_spec in ('wasserstein', 'euclidean'): + try: + # Use variablename to find the data product name + data = ALL_VARS[variablename].product_name + obs_data_loc = Path(adf.get_basic_info("obs_data_loc")) + data_key = f"{data}_{cluster_spec}_centers" + cluster_centers_path = adf.variable_defaults[data_key]["obs_file"] + file_path = obs_data_loc / cluster_centers_path + except KeyError as e: + print( + f"[ERROR] Could not find '{variablename}' in ALL_VARS or default file path for '{cluster_spec}'. " + f"Original error: {e}" + ) + return None + else: + # Assume it's a direct file path + file_path = Path(cluster_spec) + + elif isinstance(cluster_spec, Path): + file_path = cluster_spec + else: + print(f"[ERROR] cluster_spec must be a string or a Path object, but got {type(cluster_spec)}") + return None + + # Check that the path exists before trying to load + if not file_path.exists(): + print(f"[ERROR] File not found at: {file_path}") + return None + + # Load the data based on the file extension + try: + if file_path.suffix == ".nc": + with xr.open_dataset(file_path) as ds: + if 'centers' not in ds: + print(f"[ERROR] NetCDF file {file_path.name} does not contain a 'centers' variable.") + return None + cl = ds['centers'].values + elif file_path.suffix == ".npy": + cl = np.load(file_path) + else: + print(f"[ERROR] Unsupported file type: {file_path.suffix}") + return None + except Exception as e: + print(f"[ERROR] An unexpected error occurred while loading {file_path.name}: {e}") + return None + + return cl + + +def compute_cluster_labels(ds, tau_var_name, ht_var_name, cl, wasserstein_or_euclidean): + # Selcting only the relevant data and stacking it to shape n_histograms, n_tau * n_pc + dims = list(ds.dims) + dims.remove(tau_var_name) + dims.remove(ht_var_name) + histograms = ds.stack(spacetime=(dims), tau_ht=(tau_var_name, ht_var_name)) + weights = np.cos( + np.deg2rad(histograms.lat.values) + ) # weights array to use with emd-kmeans + + # Turning into a numpy array for clustering + mat = histograms.values + + # Removing all histograms with 1 or more nans in them + indices = np.arange(len(mat)) + is_valid = ~np.isnan(mat.mean(axis=1)) + is_valid = is_valid.astype(np.int32) + valid_inds = indices[is_valid == 1] + mat = mat[valid_inds] + weights = weights[valid_inds] + + print(f"\t 
Fitting data")
+
+    # Compute cluster labels
+    cluster_labels_temp = precomputed_clusters(
+        mat, cl, wasserstein_or_euclidean, ds, tau_var_name, ht_var_name
+    )
+
+    # taking the flattened cluster_labels_temp array,
+    # and turning it into a DataArray the shape of ds.var_name,
+    # and reinserting NaNs in place of missing data
+    cluster_labels = np.full(len(indices), np.nan, dtype=np.float32)
+    cluster_labels[valid_inds] = cluster_labels_temp
+    cluster_labels = xr.DataArray(
+        data=cluster_labels,
+        coords={"spacetime": histograms.spacetime},
+        dims=("spacetime"),
+    )
+    cluster_labels = cluster_labels.unstack()
+    return cluster_labels
+
+
+def spatial_subset(ds_o, lat_range, lon_range):
+    # Selecting lat range
+    if lat_range:
+        if ds_o.lat[0] > ds_o.lat[-1]:
+            ds_o = ds_o.sel(lat=slice(lat_range[1], lat_range[0]))
+        else:
+            ds_o = ds_o.sel(lat=slice(lat_range[0], lat_range[1]))
+
+    # Selecting lon range
+    if lon_range:
+        if ds_o.lon[0] > ds_o.lon[-1]:
+            ds_o = ds_o.sel(lon=slice(lon_range[1], lon_range[0]))
+        else:
+            ds_o = ds_o.sel(lon=slice(lon_range[0], lon_range[1]))
+    return ds_o
+
+
+def temporal_subset(ds, time_range):
+    """
+    Subset dataset by time range, handling various None/empty cases.
+
+    Parameters:
+    -----------
+    ds : xarray.Dataset
+        Input dataset with time dimension
+    time_range : list, tuple, or None
+        Time range as [start, end]. Can contain None, "None", or be None/empty

-        # Get land data and prep polygons
-        land_110m = cartopy.feature.NaturalEarthFeature('physical', 'land', '110m')
-        land_polygons = list(land_110m.geometries())
-        land_polygons = [prep(land_polygon) for land_polygon in land_polygons]
+
+    Returns:
+    --------
+    ds : xarray.Dataset
+        Time-subsetted dataset, or original if no valid time range
+    """
+    def is_valid_time(value):
+        """Check if a time value is valid (not None, "None", or empty string)"""
+        return value is not None and value != "None" and value != ""
+
+    # Handle None, empty, or too short time_range
+    if not time_range or len(time_range) < 2:
+        return ds
+
+    start, end = time_range[0], time_range[1]
+
+    # Check if we have any valid time values
+    start_valid = is_valid_time(start)
+    end_valid = is_valid_time(end)
+
+    if not start_valid and not end_valid:
+        return ds  # No valid time range, return original
+
+    # Set defaults for invalid values
+    if not start_valid:
+        start = ds.time[0]
+    if not end_valid:
+        end = ds.time[-1]
+
+    return ds.sel(time=slice(start, end))
+
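# [illustrative sketch, not part of the patch] Usage of temporal_subset with a tiny
# synthetic monthly time axis; the pandas import and the *_demo names are hypothetical.
import pandas as pd
import xarray as xr

ds_demo = xr.Dataset(coords={"time": pd.date_range("2000-01-01", periods=36, freq="MS")})
print(temporal_subset(ds_demo, ["2001", "2002"]).time.size)  # 24 months kept
print(temporal_subset(ds_demo, None).time.size)              # no valid range -> all 36 kept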
+def select_valid_tau_height(ds, tau_var_name, ht_var_name, max_value=9999999999999):
+    """
+    Select only the valid tau and height/pressure range from a dataset.
+
+    Excludes failed retrievals (typically -1 values) by selecting from 0 to max_value.
+    Handles both pressure (decreasing) and altitude (increasing) coordinate ordering.
+
+    Parameters:
+    -----------
+    ds : xarray.Dataset
+        Input dataset containing tau and height/pressure variables
+    tau_var_name : str
+        Name of the tau variable
+    ht_var_name : str
+        Name of the height/pressure variable
+    max_value : int, optional
+        Maximum value for selection range (default: 9999999999999)
+
+    Returns:
+    --------
+    ds : xarray.Dataset
+        Dataset with valid tau and height range selected
+    """
+    # Select valid tau range (exclude negative/failed retrievals)
+    tau_selection = {tau_var_name: slice(0, max_value)}
+
+    # Handle height/pressure coordinate ordering
+    # Pressure: decreasing (high to low) -> slice(max, 0)
+    # Altitude: increasing (low to high) -> slice(0, max)
+    if ds[ht_var_name][0] > ds[ht_var_name][-1]:
+        # Decreasing coordinate (pressure)
+        ht_selection = {ht_var_name: slice(max_value, 0)}
+    else:
+        # Increasing coordinate (altitude)
+        ht_selection = {ht_var_name: slice(0, max_value)}
+
+    # Apply selections
+    return ds.sel(tau_selection).sel(ht_selection)

-        # Make lat-lon grid
-        lats = ds.lat.values
-        lons = ds.lon.values
-        lon_grid, lat_grid = np.meshgrid(lons, lats)
-        points = [Point(point) for point in zip(lon_grid.ravel(), lat_grid.ravel())]
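# [illustrative sketch, not part of the patch] select_valid_tau_height on a tiny
# synthetic dataset: the -1 "failed retrieval" tau bin is excluded, and a descending
# pressure axis is handled by the reversed slice. The *_demo names are hypothetical.
import numpy as np
import xarray as xr

demo = xr.Dataset(
    {"hist": (("cosp_tau", "cosp_prs"), np.ones((4, 3)))},
    coords={"cosp_tau": [-1.0, 0.3, 1.3, 3.6], "cosp_prs": [900.0, 500.0, 100.0]},
)
trimmed = select_valid_tau_height(demo, "cosp_tau", "cosp_prs")
print(trimmed.cosp_tau.values)  # [0.3 1.3 3.6] -- the -1 bin is gone
print(trimmed.cosp_prs.values)  # [900. 500. 100.] -- descending axis preserved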
" + f"If these are fill values, convert to NaNs and try again") + return None + # Compute clusters only for valid data + valid_mat = mat_b[is_valid] + valid_weights = weights_b[is_valid] + cluster_labels_valid = precomputed_clusters( + valid_mat, cl, wasserstein_or_euclidean, ds_b + ) + + # Create output array with NaNs, then fill valid positions + cluster_labels_flat = np.full(len(mat_b), np.nan, dtype=np.float32) + cluster_labels_flat[is_valid] = cluster_labels_valid + + # Convert back to DataArray and unstack + cluster_labels_b = xr.DataArray( + data=cluster_labels_flat, + coords={"spacetime": histograms_b.spacetime}, + dims=("spacetime"), + name="cluster_labels" + ) + return cluster_labels_b.unstack() + + +################ +# REGRIDDING +################ + +def make_se_regridder(weight_file, Method="conservative"): + weights = xr.open_dataset(weight_file) + in_shape = weights.src_grid_dims.load().data + + # Since xESMF expects 2D vars, we'll insert a dummy dimension of size-1 + if len(in_shape) == 1: + in_shape = [1, in_shape.item()] + + # output variable shape + out_shape = weights.dst_grid_dims.load().data.tolist()[::-1] + + dummy_in = xr.Dataset( + { + "lat": ("lat", np.empty((in_shape[0],))), + "lon": ("lon", np.empty((in_shape[1],))), + } + ) + dummy_out = xr.Dataset( + { + "lat": ("lat", weights.yc_b.data.reshape(out_shape)[:, 0]), + "lon": ("lon", weights.xc_b.data.reshape(out_shape)[0, :]), + } + ) + regridder = xesmf.Regridder( + dummy_in, + dummy_out, + weights=weight_file, + # results seem insensitive to this method choice + # choices are coservative_normed, coservative, and bilinear + method=Method, + reuse_weights=True, + periodic=True, + ) + return regridder + + +def regrid_se_data_bilinear(regridder, data_to_regrid, column_dim_name="ncol"): + if isinstance(data_to_regrid, xr.Dataset): + vars_with_ncol = [ + name + for name in data_to_regrid.variables + if column_dim_name in data_to_regrid[name].dims + ] + updated = data_to_regrid.copy().update( + data_to_regrid[vars_with_ncol] + .transpose(..., "ncol") + .expand_dims("dummy", axis=-2) + ) + elif isinstance(data_to_regrid, xr.DataArray): + updated = data_to_regrid.transpose(..., column_dim_name).expand_dims( + "dummy", axis=-2 + ) + else: + raise ValueError( + f"Something is wrong because the data to regrid isn't xarray: {type(data_to_regrid)}" + ) + regridded = regridder(updated) + return regridded + +# +# LAND MASK CODE (probably need to simplify and move out of here) +# +def apply_land_ocean_mask(ds, only_ocean_or_land, landfrac_present=None): + """ + Apply land or ocean mask to dataset. + + Parameters: + ----------- + ds : xarray.Dataset + Input dataset with lat/lon coordinates + only_ocean_or_land : str or False + "L" for land only, "O" for ocean only, False for no masking + landfrac_present : bool, optional + Whether LANDFRAC variable is available. Auto-detected if None. 
+ + Returns: + -------- + ds : xarray.Dataset + Masked dataset, or None if invalid option + """ + # No masking requested + if only_ocean_or_land is False: + return ds + + # Validate input + if only_ocean_or_land not in ["L", "O"]: + print('[cloud_regime_analysis ERROR] Invalid option for only_ocean_or_land: ' + 'Please enter "O" for ocean only, "L" for land only, or set to False for both') + return None + + # Auto-detect LANDFRAC if not specified + if landfrac_present is None: + landfrac_present = "LANDFRAC" in ds.data_vars or "LANDFRAC" in ds.coords + + # Use LANDFRAC if available + if landfrac_present: + land_mask_value = 1 if only_ocean_or_land == "L" else 0 + return ds.where(ds.LANDFRAC == land_mask_value) + + # Otherwise use cartopy-based land mask + land_mask = create_land_mask(ds) + + # Make land mask broadcastable with dataset + land_mask = _make_mask_broadcastable(land_mask, ds) + + # Apply mask + mask_value = 1 if only_ocean_or_land == "L" else 0 + return ds.where(land_mask == mask_value) - # Creating list of cordinates that are over land - land = [] - for land_polygon in land_polygons: - land.extend([tuple(point.coords)[0] for point in filter(land_polygon.covers, points)]) +def _make_mask_broadcastable(mask, ds): + """ + Make 2D land mask broadcastable with dataset by adding dimensions. + + Parameters: + ----------- + mask : numpy.ndarray + 2D mask array (lat, lon) + ds : xarray.Dataset + Target dataset + + Returns: + -------- + mask : numpy.ndarray + Broadcastable mask array + """ + # Add dimensions for any dims that aren't lat/lon + for i, dim in enumerate(ds.dims): + if dim not in ("lat", "lon"): + mask = np.expand_dims(mask, axis=i) + return mask - landar = np.asarray(land) - lat_lon = np.empty((len(lats)*len(lons),2)) - oh_land = np.zeros((len(lats)*len(lons))) - lat_lon[:,0] = lon_grid.flatten() - lat_lon[:,1] = lat_grid.flatten() - # Function to (somewhat) quickly test if a lat-lon point is over land - @njit() - def test (oh_land, lat_lon, landar): - for i in range(len(oh_land)): - check = lat_lon[i] == landar - if np.max(np.sum(check,axis=1)) == 2: - oh_land[i] = 1 - return oh_land +def create_land_mask(ds): + """ + Create land mask using cartopy Natural Earth data. + Improved version with better performance and cleaner code. 
+ + Parameters: + ----------- + ds : xarray.Dataset + Dataset with lat/lon coordinates + + Returns: + -------- + land_mask : numpy.ndarray + 2D array (lat, lon) with 1 for land, 0 for ocean + """ + from cartopy import feature as cfeature + from shapely.geometry import Point + from shapely.prepared import prep + import numpy as np + from numba import njit + + # Get land polygons + land_110m = cfeature.NaturalEarthFeature("physical", "land", "110m") + land_polygons = [prep(geom) for geom in land_110m.geometries()] + # Create coordinate arrays + lats, lons = ds.lat.values, ds.lon.values + lon_grid, lat_grid = np.meshgrid(lons, lats) + # Flatten coordinates for easier processing + lon_flat, lat_flat = lon_grid.flatten(), lat_grid.flatten() + points = [Point(lon, lat) for lon, lat in zip(lon_flat, lat_flat)] + # Find land points + land_coords = [] + for polygon in land_polygons: + land_coords.extend([ + (point.x, point.y) for point in points if polygon.covers(point) + ]) + # Convert to numpy array for numba processing + land_array = np.array(land_coords) + coord_array = np.column_stack([lon_flat, lat_flat]) + # Use numba for fast coordinate matching + land_mask_flat = _find_land_points(coord_array, land_array) + # Reshape to original grid + return land_mask_flat.reshape(len(lats), len(lons)) + + +@njit() +def _find_land_points(coord_array, land_coords): + """ + Numba-compiled function to quickly identify land points. + + Parameters: + ----------- + coord_array : numpy.ndarray + Array of (lon, lat) coordinates + land_coords : numpy.ndarray + Array of known land coordinates - # Turn oh_land into a one hot matrix - oh_land = test (oh_land, lat_lon, landar) + Returns: + -------- + mask : numpy.ndarray + 1D mask array with 1 for land, 0 for ocean + """ + mask = np.zeros(len(coord_array), dtype=np.int32) + + for i in range(len(coord_array)): + coord = coord_array[i] + for j in range(len(land_coords)): + if np.allclose(coord, land_coords[j], atol=1e-10): + mask[i] = 1 + break + return mask + +import matplotlib.pyplot as plt +import matplotlib as mpl +import numpy as np +from math import ceil +import xarray as xr - # Reshape into original shape(n_lat, n_lon) - oh_land=oh_land.reshape((len(lats),len(lons))) +# ============================================================================= +# CORE PLOTTING FUNCTIONS +# ============================================================================= - return oh_land +def plot_hists_baseline(fld, cl, cluster_labels, cluster_labels_o, histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord, adf): + """Plot cloud regime centers for observations, baseline, and test case.""" + plot_data = _prepare_plot_data(fld, cl, cluster_labels, cluster_labels_o, histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord) + plot_data['columns'] = ['observation', 'baseline', 'test_case'] + plot_data['figsize'] = (17, plot_data['fig_height']) + plot_data['save_suffix'] = '_CR_centers' - # Checking if kwargs have been entered correctly - if wasserstein_or_euclidean not in ['euclidean', 'wasserstein']: - print(' WARNING: Invalid option for wasserstein_or_euclidean. Please enter "wasserstein" or "euclidean". Proceeding with default of euclidean distance') - wasserstein_or_euclidean = 'euclidean' - if data_product not in ['ISCCP', "MODIS", 'MISR', 'all']: - print(' WARNING: Invalid option for data_product. Please enter "ISCCP" or "MODIS", "MISR" or "all". 
Proceeding with default of "all"')
-        data_product = 'all'
-    if premade_cloud_regimes != None:
-        if type(premade_cloud_regimes) != str:
-            print(' WARNING: Invalid option for premade_cloud_regimes. Please enter a path to a numpy array of Cloud Regime centers of shape (n_clusters, n_dimensions_of_data). Proceeding with default clusters')
-            premade_cloud_regimes = None
-    if lat_range != None:
-        if type(lat_range) != list or len(lat_range) != 2:
-            print(' WARNING: Invalid option for lat_range. Please enter two values in square brackets sperated by a comma. Example: [-30,30]. Proceeding with entire latitude range')
-            lat_range = None
-    if lon_range != None:
-        if type(lon_range) != list or len(lon_range) != 2:
-            print(' WARNING: Invalid option for lon_range. Please enter two values in square brackets sperated by a comma. Example: [0,90]. Proceeding with entire longitude range')
-            lon_range = None
-    if only_ocean_or_land not in [False, 'L', 'O']:
-        print(' WARNING: Invalid option for only_ocean_or_land. Please enter "L" for land only, "O" for ocean only. Set to False or leave blank for both land and water. Proceeding with default of False')
-        only_ocean_or_land = False
+    _plot_cloud_regimes(plot_data, adf, baseline_mode=True)
+
+
+def plot_hists_obs(fld, cl, cluster_labels, cluster_labels_o, histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord, adf):
+    """Plot cloud regime centers for observations and test case."""
+    plot_data = _prepare_plot_data(fld, cl, cluster_labels, cluster_labels_o, histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord)
+    plot_data['columns'] = ['observation', 'test_case']
+    plot_data['figsize'] = (12, plot_data['fig_height'])
+    plot_data['save_suffix'] = '_CR_centers'

-    # Checking if the path we wish to save our plots to exists, and if it doesnt creating it
-    if not os.path.isdir(adf.plot_location[0]):
-        os.makedirs(adf.plot_location[0])
+    _plot_cloud_regimes(plot_data, adf, baseline_mode=False)
+
+
+# =============================================================================
+# HELPER FUNCTIONS
+# =============================================================================
+
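# [illustrative sketch, not part of the patch] How one CR-center panel is drawn by the
# helpers below: a flattened center is reshaped to (n_tau, n_ht), transposed, and rendered
# with pcolormesh on a bin-index mesh through a BoundaryNorm, so each cloud-cover band maps
# to one discrete color. The palette here is a stand-in; the script defines its own colors.
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

n_tau, n_ht = 7, 7
center_demo = np.random.default_rng(1).random(n_tau * n_ht) * 12.0  # synthetic center (%)
X2, Y2 = np.meshgrid(np.arange(n_tau + 1), np.arange(n_ht + 1))
bounds = [0, 0.2, 1, 2, 3, 4, 6, 8, 10, 15, 99]
cmap = mpl.colormaps["viridis"].resampled(len(bounds) - 1)
norm = mpl.colors.BoundaryNorm(bounds, cmap.N, clip=True)
fig, ax = plt.subplots()
ax.pcolormesh(X2, Y2, center_demo.reshape(n_tau, n_ht).T, cmap=cmap, norm=norm)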
+def _prepare_plot_data(fld, cl, cluster_labels, cluster_labels_o, histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord):
+    """Prepare all data needed for plotting.

-    # path to h0 files
-    h0_data_path = adf.get_cam_info("cam_hist_loc", required=True)[0] + "/*h0*.nc"
+    fld: str
+        name of variable

-    # Time Range min and max, or None for all time
-    time_range = [str(adf.get_cam_info("start_year")[0]), str(adf.get_cam_info("end_year")[0])]
+    cl
+        cluster centers
+
+    cluster_labels: xr.DataArray
+        labels for case ([time], lat, lon)
+    cluster_labels_o: array-like
+        labels for reference data
+    histograms: xr.DataArray
+        case histograms used for the weighted-mean panels
+
+    """

-    files = glob.glob(h0_data_path)
-    # Opening an initial dataset
-    init_ds = xr.open_mfdataset(files[0])
+    print(f"[_prepare_plot_data] {histograms.coords = }")
+
+    data = ALL_VARS[fld].product_name
+    k = len(cl)
+    ylabels = htcoord.values
+    xlabels = taucoord.values
+
+    # Create meshgrid
+    X2, Y2 = np.meshgrid(np.arange(len(xlabels) + 1), np.arange(len(ylabels) + 1))
+
+    # Calculate figure height
+    fig_height = (1 + 10 / 3 * ceil(k / 3)) * 3
+
+    # Create weights for RFO calculations
+    weights = np.cos(np.deg2rad(cluster_labels.stack(z=("time", "lat", "lon")).lat.values))
+    valid_inds = ~np.isnan(cluster_labels.stack(z=("time", "lat", "lon")))
+    weights = weights[valid_inds]
+
+    return {
+        'field': fld,
+        'data_product': data,
+        'k': k,
+        'cl': cl,
+        'cluster_labels': cluster_labels,
+        'cluster_labels_o': cluster_labels_o,
+        'xlabels': xlabels,
+        'ylabels': ylabels,
+        'X2': X2,
+        'Y2': Y2,
+        'fig_height': fig_height,
+        'weights': weights,
+        'ht_var_name': ht_var_name,
+        'tau_var_name': tau_var_name,
+        'fld': fld,
+        'histograms': histograms,
+        'histograms_ref': histograms_ref
+    }
+
+
+def _plot_cloud_regimes(plot_data, adf, baseline_mode=True):
+    """Main plotting function that handles both baseline and obs-only modes."""
+    # Setup
+    plt.rcParams.update({"font.size": 14})
+    cmap, norm = _create_colormap()
+
+    # Create figure
+    ncols = 3 if baseline_mode else 2
+    fig, ax = plt.subplots(
+        figsize=plot_data['figsize'],
+        ncols=ncols,
+        nrows=plot_data['k'],
+        sharex=True,  # was "all"
+        sharey=True
+    )
+    fig.subplots_adjust(right=0.88)  # Leave space for colorbar
+
+    # Handle y-axis inversion
+    if plot_data['data_product'] != "MISR":
+        ax.ravel()[1].invert_yaxis()
+
+    # Plot columns
+    if baseline_mode:
+        _plot_observation_column(ax[:, 0], plot_data, cmap, norm, include_rfo=True)
+        _plot_baseline_column(ax[:, 1], plot_data, cmap, norm)
+        _plot_test_case_column(ax[:, 2], plot_data, cmap, norm)
+    else:
+        _plot_observation_column(ax[:, 0], plot_data, cmap, norm, include_rfo=True)
+        _plot_test_case_column(ax[:, 1], plot_data, cmap, norm)
+
+    # Configure axes and labels
+    _configure_axes(ax, plot_data['data_product'])
+    _add_figure_labels(fig, ax, plot_data, baseline_mode)
+    _add_colorbar(fig, ax.ravel(), cmap, norm)
+
+    # Save
+    _save_figure(fig, plot_data, adf)
+    plt.close()
+
+
+def _plot_observation_column(ax_col, plot_data, cmap, norm, include_rfo=False):
+    """Plot the observation cluster centers."""
+    for i in range(plot_data['k']):
+        # Plot cluster center
+        im = ax_col[i].pcolormesh(
+            plot_data['X2'], plot_data['Y2'],
+            plot_data['cl'][i].reshape(len(plot_data['xlabels']), len(plot_data['ylabels'])).T,
+            norm=norm, cmap=cmap
+        )
+
+        # Set title with optional RFO
+        if include_rfo:
+            rfo = _calculate_rfo(plot_data['cluster_labels_o'], i)
+            ax_col[i].set_title(f"Observation CR {i+1}, RFO = {np.round(float(rfo), 1)}%")
+        else:
+            ax_col[i].set_title(f"Observation CR {i+1}")
+
+
+def _plot_baseline_column(ax_col, plot_data, cmap, norm):
+    """Plot the baseline cluster centers with weighted means and RFO."""
+    for i in range(plot_data['k']):
+        # Calculate RFO from the reference (baseline) labels
+        rfo = _calculate_rfo(plot_data['cluster_labels_o'], i)
+
+        # Calculate weighted mean
+        wmean = _calculate_weighted_mean_xr(i, plot_data['cluster_labels_o'], plot_data['histograms_ref'])
+        # Plot
+        im = ax_col[i].pcolormesh(plot_data['X2'], plot_data['Y2'], wmean, norm=norm, cmap=cmap)
+        ax_col[i].set_title(f"Baseline Case CR {i+1}, RFO = {np.round(rfo, 1)}%")
+
+
+def _plot_test_case_column(ax_col, plot_data, cmap, norm):
+    """Plot 
the test case cluster centers with weighted means and RFO.""" + for i in range(plot_data['k']): + # Calculate RFO + rfo = _calculate_rfo(plot_data['cluster_labels'], i) + + # Calculate weighted mean + wmean = _calculate_weighted_mean_xr(i, plot_data['cluster_labels'], plot_data['histograms']) + + # Plot + im = ax_col[i].pcolormesh(plot_data['X2'], plot_data['Y2'], wmean, norm=norm, cmap=cmap) + ax_col[i].set_title(f"Test Case CR {i+1}, RFO = {np.round(rfo, 1)}%") - # defining dicts for variable names for each data set - data_var_dict = {'ISCCP':'FISCCP1_COSP', "MISR":'CLD_MISR', "MODIS":"CLMODIS" } - ht_var_dict = {'ISCCP':'cosp_prs', "MISR":'cosp_htmisr', "MODIS":"cosp_prs" } - tau_var_dict = {'ISCCP':'cosp_tau', "MISR":'cosp_tau', "MODIS":"cosp_tau_modis" } - # geting names of cosp data variables for all data products that will get processed - if data_product == 'all': - var_name = list(data_var_dict.values()) +def _calculate_rfo(cluster_labels, cluster_i): + """Calculate area-weighted relative frequency of occurrence for a cluster.""" + total_rfo_num = cluster_labels == cluster_i + total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat))) + + total_rfo_denom = cluster_labels >= 0 + total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat))) + + total_rfo = total_rfo_num / total_rfo_denom * 100 + return total_rfo.values + +def _calculate_weighted_mean_xr(cluster_i, cluster_labels, hists): + weights = np.cos(np.radians(hists['lat'])) + cluster_data = xr.where(cluster_labels==cluster_i, hists, np.nan) + dims = [dim for dim in hists.dims if dim in ["ncol","lat","lon","time"]] + return cluster_data.weighted(weights).mean(dim=dims) + +def _calculate_weighted_mean(cluster_i, cluster_labels, hists, weights, xlabels, ylabels): + """Calculate area-weighted mean histogram for a cluster. 
+
+    PARAMETERS
+    ----------
+    cluster_i : int
+        cluster number (label)
+    cluster_labels : array-like
+        the (n_histograms,) array of cluster labels for the flattened data
+    hists: array-like
+        the (n_histograms, n_tau * n_ht) array of flattened histograms
+    weights: array-like
+        area weights
+    xlabels: array
+        tau values
+    ylabels: array
+        height/pressure values
+
+    RETURNS
+    -------
+    wmean: array
+        Area-weighted mean of the histograms in cluster (in percent)
+    """
+    print(f"[_calculate_weighted_mean] cluster {cluster_i}, {cluster_labels.shape = }, {hists.shape = }, {weights.shape = }")
+    pts_i = (cluster_labels == cluster_i)  # boolean mask of points in cluster
+    n = pts_i.sum()  # number of points in cluster
+    if n > 0:
+        weighted_hists = hists[pts_i] * weights[pts_i][:, np.newaxis]  # weight by cos(lat)
+        wmean = np.sum(weighted_hists, axis=0) / np.sum(weights[pts_i])
+    else:
-        var_name = [data_var_dict[data_product]]
-
-    # looping through to do analysis on each data product selected
-    for var in var_name:
-
-        # Getting data name corresponding to the variable being opened
-        key_list = list(data_var_dict.keys())
-        val_list = list(data_var_dict.values())
-        position = val_list.index(var)
-        data = key_list[position]
-        ht_var_name = ht_var_dict[data]
-        tau_var_name = tau_var_dict[data]
-
-        print(f'\n Beginning {data} Cloud Regime analysis') #testing
-
-        # variable that gets set to true if var is missing in the data file, and is used to skip that dataset
-        missing_var = False
-
-        # Trying to open time series files from cam)ts_loc
-        try: ds = xr.open_mfdataset(adf.get_cam_info("cam_ts_loc", required=True)[0] + f"/*{var}*")
-
-        # If that doesnt work trying to open the variables from the h0 files
-        except:
-            print(f" -WARNING: {data} time series file does not exist, was {var} added to the diag_var_list?")
-            print(" -Attempting to use h0 files from cam_hist_loc, but this will be slower" )
-            # Creating a list of all the variables in the dataset
-            remove = list(init_ds.keys())
-            try:
-                # Deleting the variables we want to keep in our dataset, all remaining variables will be dropped upon opening the files, this allows for faster opening of large files
-                remove.remove(var)
-                # If there's a LANDFRAC variable keep it in the dataset
-                landfrac_present = True
-                try: remove.remove('LANDFRAC')
-                except: landfrac_present = False
-
-                ds = xr.open_mfdataset(files, drop_variables = remove)
-
-            # If variables are not present in h0 tell the user the variables do not exist, and that there is not COSP output for this data
-            except:
-                print(f' -{var} does not exist in h0 files, does this run have {data} COSP output? 
Skipping {data} for now') - missing_var = True # used to skip the code below and move onto the next var name - - # executing further analysis on this data - finally: - - # Skipping var if its not present in data files - if missing_var: - continue + wmean = np.zeros(len(xlabels) * len(ylabels)) + + # Reshape and convert to percentage if needed + wmean = wmean.reshape(len(xlabels), len(ylabels)).T + if np.max(wmean) <= 1: + wmean *= 100 + + return wmean + + +def _create_colormap(): + """Create standardized colormap and normalization.""" + p = [0, 0.2, 1, 2, 3, 4, 6, 8, 10, 15, 99] + colors = [ + "white", + (0.19215686274509805, 0.25098039215686274, 0.5607843137254902), + (0.23529411764705882, 0.3333333333333333, 0.6313725490196078), + (0.32941176470588235, 0.5098039215686274, 0.6980392156862745), + (0.39215686274509803, 0.6, 0.43137254901960786), + (0.44313725490196076, 0.6588235294117647, 0.21568627450980393), + (0.4980392156862745, 0.6784313725490196, 0.1843137254901961), + (0.5725490196078431, 0.7137254901960784, 0.16862745098039217), + (0.7529411764705882, 0.8117647058823529, 0.2), + (0.9568627450980393, 0.8980392156862745, 0.1607843137254902), + ] + cmap = mpl.colors.ListedColormap(colors) + norm = mpl.colors.BoundaryNorm(p, cmap.N, clip=True) + return cmap, norm + + +# ============================================================================= +# AXIS CONFIGURATION FUNCTIONS +# ============================================================================= + +def _configure_axes(ax, data_product): + """Configure axis ticks and labels based on data product.""" + config_functions = { + "MODIS": _configure_modis_axes, + "MISR": _configure_misr_axes, + "ISCCP": _configure_isccp_axes + } + + if data_product in config_functions: + config_functions[data_product](ax[0, 0]) - # Adjusting lon to run from -180 to 180 if it doesnt already - if np.max(ds.lon) > 180: - ds.coords['lon'] = (ds.coords['lon'] + 180) % 360 - 180 - ds = ds.sortby(ds.lon) - - # Selecting only points over ocean or points over land if only_ocean_or_land has been used - if only_ocean_or_land != False: - # If LANDFRAC variable is present, use it to mask - if landfrac_present == True: - if only_ocean_or_land == 'L': ds = ds.where(ds.LANDFRAC == 1) - elif only_ocean_or_land == 'O': ds = ds.where(ds.LANDFRAC == 0) - # Otherwise use cartopy - else: - land = create_land_mask(ds) - dims = ds.dims - - # Inserting new axis to make land a broadcastable shape with ds - for n in range(len(dims)): - if dims[n] != 'lat' and dims[n] != 'lon': - land = np.expand_dims(land, n) - - # Masking out the land or water - if only_ocean_or_land == 'L': ds = ds.where(land == 1) - elif only_ocean_or_land == 'O': ds = ds.where(land == 0) - else: raise Exception('Invalid option for only_ocean_or_land: Please enter "O" for ocean only, "L" for land only, or set to False for both land and water') - - # Selecting lat range - if lat_range != None: - if ds.lat[0] > ds.lat[-1]: - ds = ds.sel(lat=slice(lat_range[1],lat_range[0])) - else: - ds = ds.sel(lat=slice(lat_range[0],lat_range[1])) +def _configure_modis_axes(ax): + """Configure axes for MODIS data.""" + ylabels = [0, 180, 310, 440, 560, 680, 800, 1000] + xlabels = [0, 0.3, 1.3, 3.6, 9.4, 23, 60, 150] + + ax.set_yticks(np.arange(8)) + ax.set_xticks(np.arange(8)) + ax.set_yticklabels(ylabels) + ax.set_xticklabels(xlabels) + + # Hide first and last x-tick labels + xticks = ax.xaxis.get_major_ticks() + xticks[0].set_visible(False) + xticks[-1].set_visible(False) - # Selecting Lon range - if lon_range != None: 
- if ds.lon[0] > ds.lon[-1]: - ds = ds.sel(lon=slice(lon_range[1],lon_range[0])) - else: - ds = ds.sel(lon=slice(lon_range[0],lon_range[1])) - - # Selecting time range - if time_range != ["None","None"]: - # Need these if statements to be robust if the adf obj only has start_year or end_year - if time_range[0] == "None": - start = ds.time[0] - end = time_range[1] - elif time_range[1] == "None": - start = time_range[0] - end = ds.time[-1] - else: - start = time_range[0] - end = time_range[1] - ds = ds.sel(time=slice(start,end)) - - # Turning dataset into a dataarray - ds = ds[var] - - # Selecting only valid tau and height/pressure range - # Many data products have a -1 bin for failed retreivals, we do not wish to include this - tau_selection = {tau_var_name:slice(0,9999999999999)} - # Making sure this works for pressure which is ordered largest to smallest and altitude which is ordered smallest to largest - if ds[ht_var_name][0] > ds[ht_var_name][-1]: ht_selection = {ht_var_name:slice(9999999999999,0)} - else: ht_selection = {ht_var_name:slice(0,9999999999999)} - ds = ds.sel(tau_selection) - ds = ds.sel(ht_selection) - +def _configure_misr_axes(ax): + """Configure axes for MISR data.""" + xlabels = [0.2, 0.8, 2.4, 6.5, 16.2, 41.5, 100] + ylabels = [0.25, 0.75, 1.25, 1.75, 2.25, 2.75, 3.5, 4.5, 6, 8, 10, 12, 14, 16, 20] + + ax.set_yticks(np.arange(0, 16, 2) + 0.5) + ax.set_yticklabels(ylabels[0::2]) + ax.set_xticks(np.array([1, 2, 3, 4, 5, 6, 7]) - 0.5) + ax.set_xticklabels(xlabels, fontsize=16) + + # Hide first and last x-tick labels + xticks = ax.xaxis.get_major_ticks() + xticks[0].set_visible(False) + xticks[-1].set_visible(False) - # Opening cluster centers - # Using premade clusters if they have been provided - if type(premade_cloud_regimes) == str: - cl = np.load(premade_cloud_regimes) - # Checking if the shape is what we'd expect - if cl.shape[1] != len(ds[tau_var_name]) * len(ds[ht_var_name]): - if data == 'ISCCP' and cl.shape[1] == 42: - None - elif data == 'MISR' and cl.shape[1] == 105: - None - else: - raise Exception (f'premade_cloud_regimes is the wrong shape. premade_cloud_regimes.shape = {cl.shape}, but must be shape (k, {len(ds[tau_var_name]) * len(ds[ht_var_name])}) to fit the loaded {data} data') - print(' -Using premade cloud regimes:') - - # If custom CRs havent been passed, use either the emd or euclidean premade ones - elif wasserstein_or_euclidean == "wasserstein": - obs_data_loc = adf.get_basic_info('obs_data_loc') + '/' - cluster_centers_path = adf.variable_defaults[f"{data}_emd_centers"]['obs_file'] - cl = np.load(obs_data_loc + cluster_centers_path) - - elif wasserstein_or_euclidean == "euclidean": - obs_data_loc = adf.get_basic_info('obs_data_loc') + '/' - cluster_centers_path = adf.variable_defaults[f"{data}_euclidean_centers"]['obs_file'] - cl = np.load(obs_data_loc + cluster_centers_path) - - # Defining k, the number of clusters - k = len(cl) - - print(f' -Preprocessing data') - - # COSP ISCCP data has one extra tau bin than the satellite data, and misr has an extra height bin. This checks roughly if we are comparing against the - # satellite data, and if so removes the extra tau or ht bin. 
If a user passes home made CRs from CESM data, no data will be removed - if data == 'ISCCP' and cl.shape[1] == 42: - # a slightly hacky way to drop the smallest tau bin, but is robust incase tau is flipped in a future version - ds = ds.sel(cosp_tau=slice(np.min(ds.cosp_tau)+1e-11,np.inf)) - print(" -Dropping smallest tau bin to be comparable with observational cloud regimes") - if data == 'MISR' and cl.shape[1] == 105: - # a slightly hacky way to drop the lowest height bin, but is robust incase height is flipped in a future version - ds = ds.sel(cosp_htmisr=slice(np.min(ds.cosp_htmisr)+1e-11,np.inf)) - print(" -Dropping lowest height bin to be comparable with observational cloud regimes") - - # Selcting only the relevant data and stacking it to shape n_histograms, n_tau * n_pc - dims = list(ds.dims) - dims.remove(tau_var_name) - dims.remove(ht_var_name) - histograms = ds.stack(spacetime=(dims), tau_ht=(tau_var_name, ht_var_name)) - weights = np.cos(np.deg2rad(histograms.lat.values)) # weights array to use with emd-kmeans - - # Turning into a numpy array for clustering - mat = histograms.values - - # Removing all histograms with 1 or more nans in them - indicies = np.arange(len(mat)) - is_valid = ~np.isnan(mat.mean(axis=1)) - is_valid = is_valid.astype(np.int32) - valid_indicies = indicies[is_valid==1] - mat=mat[valid_indicies] - weights=weights[valid_indicies] - - print(f' -Fitting data') - - # Compute cluster labels - cluster_labels_temp = precomputed_clusters(mat, cl, wasserstein_or_euclidean, ds) - - # taking the flattened cluster_labels_temp array, and turning it into a datarray the shape of ds.var_name, and reinserting NaNs in place of missing data - cluster_labels = np.full(len(indicies), np.nan, dtype=np.int32) - cluster_labels[valid_indicies]=cluster_labels_temp - cluster_labels = xr.DataArray(data=cluster_labels, coords={"spacetime":histograms.spacetime},dims=("spacetime") ) - cluster_labels = cluster_labels.unstack() - - # Comparing to observation - if adf.compare_obs == True: - # defining dicts for variable names for each data set - obs_data_var_dict = {'ISCCP':'n_pctaudist', "MISR":'clMISR', "MODIS":"MODIS_CLD_HISTO" } - obs_ht_var_dict = {'ISCCP':'levtau', "MISR":'tau', "MODIS":"COT" } - obs_tau_var_dict = {'ISCCP':'levpc', "MISR":'cth', "MODIS":"PRES" } - - # Getting data name corresponding to the variable being opened - key_list = list(obs_data_var_dict.keys()) - val_list = list(obs_data_var_dict.values()) - obs_var = obs_data_var_dict[data] - position = val_list.index(obs_var) - data = key_list[position] - obs_ht_var_name = obs_ht_var_dict[data] - obs_tau_var_name = obs_tau_var_dict[data] - - print(f' -Starting {data} observation data') - - # Opening observation files. The obs files have three variables, precomputed euclidean cluster labels, precomputed emd cluster labels - # and then the raw data to use if custom CRs are passed in. 
- obs_data_path = adf.var_obs_dict[var]['obs_file'] - - # Opening the data - ds_o = xr.open_dataset(obs_data_path) - - # Selecting either the appropriate pre-computed cluster_labels or the raw data - if premade_cloud_regimes == None: - if wasserstein_or_euclidean == 'wasserstein': - ds_o = ds_o.emd_cluster_labels - else: - ds_o = ds_o.euclidean_cluster_labels - else: - ds_o = ds_o[obs_var] - - # Adjusting lon to run from -180 to 180 if it doesnt already - if np.max(ds_o.lon) > 180: - ds_o.coords['lon'] = (ds_o.coords['lon'] + 180) % 360 - 180 - ds_o = ds_o.sortby(ds_o.lon) - - # Selecting only points over ocean or points over land if only_ocean_or_land has been used - if only_ocean_or_land != False: - land = create_land_mask(ds_o) - dims = ds_o.dims - - # inserting new axis to make land a broadcastable shape with ds_o - for n in range(len(dims)): - if dims[n] != 'lat' and dims[n] != 'lon': - land = np.expand_dims(land, n) - - # Masking out the land or water - if only_ocean_or_land == 'L': ds_o = ds_o.where(land == 1) - elif only_ocean_or_land == 'O': ds_o = ds_o.where(land == 0) - else: raise Exception('Invalid option for only_ocean_or_land: Please enter "O" for ocean only, "L" for land only, or set to False for both land and water') - - # Selecting lat range - if lat_range != None: - if ds_o.lat[0] > ds_o.lat[-1]: - ds_o = ds_o.sel(lat=slice(lat_range[1],lat_range[0])) - else: - ds_o = ds_o.sel(lat=slice(lat_range[0],lat_range[1])) - - # Selecting Lon range - if lon_range != None: - if ds_o.lon[0] > ds_o.lon[-1]: - ds_o = ds_o.sel(lon=slice(lon_range[1],lon_range[0])) - else: - ds_o = ds_o.sel(lon=slice(lon_range[0],lon_range[1])) - - # Don't select time range for obsrvation, just compare to the full record - # # Selecting time range - # if time_range != ["None","None"]: - # if time_range[0] == "None": - # start = ds.time[0] - # end = time_range[1] - # elif time_range[1] == "None": - # start = time_range[0] - # end = ds.time[-1] - # else: - # start = time_range[0] - # end = time_range[1] - - # ds = ds.sel(time=slice(start,end)) - - if premade_cloud_regimes == None: - cluster_labels_o = ds_o - cluster_labels_o_temp = cluster_labels_o.stack(spacetime=("time", 'lat', 'lon')) - else: - # Selecting only valid tau and height/pressure range - # Many data products have a -1 bin for failed retreivals, we do not wish to include this - tau_selection = {obs_tau_var_name:slice(0,9999999999999)} - # Making sure this works for pressure which is ordered largest to smallest and altitude which is ordered smallest to largest - if ds_o[obs_ht_var_name][0] > ds_o[obs_ht_var_name][-1]: ht_selection = {obs_ht_var_name:slice(9999999999999,0)} - else: ht_selection = {obs_ht_var_name:slice(0,9999999999999)} - ds_o = ds_o.sel(tau_selection) - ds_o = ds_o.sel(ht_selection) - - # Selcting only the relevant data and stacking it to shape n_histograms, n_tau * n_pc - dims = list(ds_o.dims) - dims.remove(obs_tau_var_name) - dims.remove(obs_ht_var_name) - histograms_o = ds_o.stack(spacetime=(dims), tau_ht=(obs_ht_var_name, obs_tau_var_name)) - weights_o = np.cos(np.deg2rad(histograms_o.lat.values)) # weights_o array to use with emd-kmeans - - # Turning into a numpy array for clustering - mat_o = histograms_o.values - - # Removing all histograms with 1 or more nans in them - indicies = np.arange(len(mat_o)) - is_valid = ~np.isnan(mat_o.mean(axis=1)) - is_valid = is_valid.astype(np.int32) - valid_indicies_o = indicies[is_valid==1] - mat_o=mat_o[valid_indicies_o] - weights_o=weights_o[valid_indicies_o] - - if np.min(mat_o < 
0): - raise Exception (f'Found negative value in ds_o.{var_name}, if this is a fill value for missing data, convert to nans and try again') - - print(f' -Fitting data') - - # Compute cluster labels - cluster_labels_temp_o = precomputed_clusters(mat_o, cl, wasserstein_or_euclidean, ds_o) - - # Taking the flattened cluster_labels_temp_o array, and turning it into a datarray the shape of obs_ds.var_name, and reinserting NaNs in place of missing data - cluster_labels_o = np.full(len(indicies), np.nan, dtype=np.int32) - cluster_labels_o[valid_indicies_o]=cluster_labels_temp_o - cluster_labels_o = xr.DataArray(data=cluster_labels_o, coords={"spacetime":histograms_o.spacetime},dims=("spacetime") ) - cluster_labels_o = cluster_labels_o.unstack() - - print(f' -Plotting') - - plot_hists_obs(cl, cluster_labels, cluster_labels_o, ht_var_name, tau_var_name, adf) - plot_rfo_obs_base_diff(cluster_labels, cluster_labels_o, adf) - - # Comparing to CAM baseline if not comparing to obs - else: - # path to h0 files - baseline_h0_data_path = adf.get_baseline_info("cam_hist_loc", required=True) + "/*h0*.nc" - # Time Range min and max, or None for all time - time_range_b = [str(adf.get_baseline_info("start_year")), str(adf.get_baseline_info("end_year"))] - # Creating a list of files - files = glob.glob(baseline_h0_data_path) - # Opening an initial dataset - init_ds_b = xr.open_dataset(files[0]) - - print(f' -Starting {data} CAM baseline data') #testing - - # Variable that gets set to true if var is missing in the data file, and is used to skip processing that dataset - missing_var = False - - # Trying to open time series files from cam)ts_loc - try: ds_b = xr.open_mfdataset(adf.get_baseline_info("cam_ts_loc", required=True) + f"/*{var}*") - - # If that doesnt work trying to open the variables from the h0 files - except: - print(f" -WARNING: {data} time series file does not exist, was {var} added to the diag_var_list?") - print(" Attempting to use h0 files from cam_hist_loc, but this will be slower" ) - # Creating a list of all the variables in the dataset - remove = list(init_ds_b.keys()) - try: - # Deleting the variables we want to keep in our dataset, all remaining variables will be dropped upon opening the files, this allows for faster opening of large files - remove.remove(var) - # If there's a LANDFRAC variable keep it in the dataset - landfrac_present = True - try: remove.remove('LANDFRAC') - except: landfrac_present = False - - # Opening dataset and dropping irrelevant data - ds_b = xr.open_mfdataset(files, drop_variables = remove) - - # If variables are not present in h0 tell the user the variables do not exist, and that there is not COSP output for this data - except: - print(f' {var} does not exist in h0 files, does this run have {data} COSP output? 
Skipping {data} for now') - missing_var = True # used to skip the code below and move onto the next var name - - # Executing further analysis on this data - finally: - - # Skipping var if its not present in data files - if missing_var: - continue - - # Adjusting lon to run from -180 to 180 if it doesnt already - if np.max(ds_b.lon) > 180: - ds_b.coords['lon'] = (ds_b.coords['lon'] + 180) % 360 - 180 - ds_b = ds_b.sortby(ds_b.lon) - - # Selecting only points over ocean or points over land if only_ocean_or_land has been used - if only_ocean_or_land != False: - # If LANDFRAC variable is present, use it to mask - if landfrac_present == True: - if only_ocean_or_land == 'L': ds_b = ds_b.where(ds_b.LANDFRAC == 1) - elif only_ocean_or_land == 'O': ds_b = ds_b.where(ds_b.LANDFRAC == 0) - # Otherwise use cartopy - else: - land = create_land_mask(ds_b) - dims = ds_b.dims - - # Inserting new axis to make land a broadcastable shape with ds_b - for n in range(len(dims)): - if dims[n] != 'lat' and dims[n] != 'lon': - land = np.expand_dims(land, n) - - # Masking out the land or water - if only_ocean_or_land == 'L': ds_b = ds_b.where(land == 1) - elif only_ocean_or_land == 'O': ds_b = ds_b.where(land == 0) - else: raise Exception('Invalid option for only_ocean_or_land: Please enter "O" for ocean only, "L" for land only, or set to False for both land and water') - - # Selecting lat range - if lat_range != None: - if ds_b.lat[0] > ds_b.lat[-1]: - ds_b = ds_b.sel(lat=slice(lat_range[1],lat_range[0])) - else: - ds_b = ds_b.sel(lat=slice(lat_range[0],lat_range[1])) - - # Selecting Lon range - if lon_range != None: - if ds_b.lon[0] > ds_b.lon[-1]: - ds_b = ds_b.sel(lon=slice(lon_range[1],lon_range[0])) - else: - ds_b = ds_b.sel(lon=slice(lon_range[0],lon_range[1])) - - # Selecting time range - if time_range_b != ["None","None"]: - # Need these if statements to be robust if the adf obj only has start_year or end_year - if time_range_b[0] == "None": - start = ds_b.time[0] - end = time_range_b[1] - elif time_range_b[1] == "None": - start = time_range_b[0] - end = ds_b.time[-1] - else: - start = time_range_b[0] - end = time_range_b[1] - - ds_b = ds_b.sel(time=slice(start,end)) - - # Turning dataset into a dataarray - ds_b = ds_b[var] - - # Selecting only valid tau and height/pressure range - # Many data products have a -1 bin for failed retreivals, we do not wish to include this - tau_selection = {tau_var_name:slice(0,9999999999999)} - # Making sure this works for pressure which is ordered largest to smallest and altitude which is ordered smallest to largest - if ds_b[ht_var_name][0] > ds_b[ht_var_name][-1]: ht_selection = {ht_var_name:slice(9999999999999,0)} - else: ht_selection = {ht_var_name:slice(0,9999999999999)} - ds_b = ds_b.sel(tau_selection) - ds_b = ds_b.sel(ht_selection) - - print(f' -Preprocessing data') - - # COSP ISCCP data has one extra tau bin than the satellite data, and misr has an extra height bin. This checks roughly if we are comparing against the - # satellite data, and if so removes the extra tau or ht bin. 
If a user passes home made CRs from CESM data, no data will be removed - if data == 'ISCCP' and cl.shape[1] == 42: - # A slightly hacky way to drop the smallest tau bin, but is robust incase tau is flipped in a future version - ds_b = ds_b.sel(cosp_tau=slice(np.min(ds_b.cosp_tau)+1e-11,np.inf)) - print(" -Dropping smallest tau bin to be comparable with observational cloud regimes") - if data == 'MISR' and cl.shape[1] == 105: - # A slightly hacky way to drop the lowest height bin, but is robust incase height is flipped in a future version - ds_b = ds_b.sel(cosp_htmisr=slice(np.min(ds_b.cosp_htmisr)+1e-11,np.inf)) - print(" -Dropping lowest height bin to be comparable with observational cloud regimes") - - # Selcting only the relevant data and stacking it to shape n_histograms, n_tau * n_pc - dims = list(ds_b.dims) - dims.remove(tau_var_name) - dims.remove(ht_var_name) - histograms_b = ds_b.stack(spacetime=(dims), tau_ht=(tau_var_name, ht_var_name)) - weights_b = np.cos(np.deg2rad(histograms_b.lat.values)) # weights_b array to use with emd-kmeans - - # Turning into a numpy array for clustering - mat_b = histograms_b.values - - # Removing all histograms with 1 or more nans in them - indicies = np.arange(len(mat_b)) - is_valid = ~np.isnan(mat_b.mean(axis=1)) - is_valid = is_valid.astype(np.int32) - valid_indicies_b = indicies[is_valid==1] - mat_b=mat_b[valid_indicies_b] - weights_b=weights_b[valid_indicies_b] - - if np.min(mat_b < 0): - raise Exception (f'Found negative value in ds_b.{var_name}, if this is a fill value for missing data, convert to nans and try again') - - print(f' -Fitting data') - - # Compute cluster labels - cluster_labels_temp_b = precomputed_clusters(mat_b, cl, wasserstein_or_euclidean, ds_b) - - # Taking the flattened cluster_labels_temp_b array, and turning it into a datarray the shape of ds.var_name, and reinserting NaNs in place of missing data - cluster_labels_b = np.full(len(indicies), np.nan, dtype=np.int32) - cluster_labels_b[valid_indicies_b]=cluster_labels_temp_b - cluster_labels_b = xr.DataArray(data=cluster_labels_b, coords={"spacetime":histograms_b.spacetime},dims=("spacetime") ) - cluster_labels_b = cluster_labels_b.unstack() - - print(f' -Plotting') - - # Plotting - plot_hists_baseline(cl, cluster_labels, cluster_labels_b, ht_var_name, tau_var_name, adf) - plot_rfo_obs_base_diff(cluster_labels, cluster_labels_b, adf) - - - - -# %% + +def _configure_isccp_axes(ax): + """Configure axes for ISCCP data.""" + xlabels = [0, 1.3, 3.6, 9.4, 22.6, 60.4, 450] + ylabels = [10, 180, 310, 440, 560, 680, 800, 1025] + + # Get current ticks + yticks = ax.get_yticks().tolist() + xticks = ax.get_xticks().tolist() + + ax.set_yticks(yticks) + ax.set_xticks(xticks) + ax.set_yticklabels(ylabels) + ax.set_xticklabels(xlabels) + + # Hide first and last x-tick labels + xticks = ax.xaxis.get_major_ticks() + xticks[0].label1.set_visible(False) + xticks[-1].label1.set_visible(False) + + +def _add_figure_labels(fig, ax, plot_data, baseline_mode): + """Add figure title and axis labels.""" + data = plot_data['data_product'] + ht_var_name = plot_data['ht_var_name'] + fig_height = plot_data['fig_height'] + # bpm: hacky attempt to get coordinate units: + if ("prs" in ht_var_name): + ht_label = "Pressure" + ht_unit = "hPa" + else: + ht_label = "Height" + ht_unit = "m" + + # Determine height or pressure + # height_or_pressure = "h" if data == "MISR" else "p" + + # Y-label + x_pos = 0.07 if baseline_mode else 0.05 + fig.supylabel(f"Cloud-top {ht_label} ({ht_unit})", x=x_pos) + + # Title 
positioning + bbox = ax[1, 0].get_position() # Use first column for positioning + fig.suptitle( + f"{data} Cloud Regimes", + # x=0.5, + # y=bbox.p1[1] + (1 / fig_height * 0.5) + 0.007, + fontsize=18, + ) + + # X-label positioning + bbox = ax[-1, -1].get_position() # Use last subplot for positioning + fig.supxlabel("Optical Depth", y=bbox.p0[1] - (1 / fig_height * 0.5) - 0.007) + + +def _add_colorbar(fig, ax, cmap, norm): + """Add colorbar to the figure.""" + p = [0, 0.2, 1, 2, 3, 4, 6, 8, 10, 15, 99] + # cbar_ax = fig.add_axes([1.01, 0.25, 0.04, 0.5]) + sm = ScalarMappable(norm=norm, cmap=cmap) + # sm.set_array([]) # Required for colorbar + cb = fig.colorbar(sm, + ax=ax, + orientation='vertical', + fraction=0.025, + pad=0.02, + aspect=40, + ticks=p) + cb.set_label(label="Cloud Cover (%)", size=16) + cb.ax.tick_params(labelsize=14) + + +def _save_figure(fig, plot_data, adf): + """Save figure and add to website if requested.""" + data = plot_data['data_product'] + save_path = adf.plot_location[0] + f"/{data}{plot_data['save_suffix']}" + plt.savefig(save_path) + + if adf.create_html: + if hasattr(adf, 'compare_obs') and adf.compare_obs: + # For obs comparison mode + adf.add_website_data(save_path + ".png", plot_data['field'], case_name=None, multi_case=True) + else: + # For baseline comparison mode + adf.add_website_data(save_path + ".png", plot_data['field'], adf.get_baseline_info("cam_case_name")) \ No newline at end of file From 2ce8c1eddc96b143397ca2c5ef8fb4c75c523ec9 Mon Sep 17 00:00:00 2001 From: Brian Medeiros Date: Mon, 15 Sep 2025 12:05:11 -0600 Subject: [PATCH 4/5] refactor Isaac's cloud regime analysis script --- scripts/plotting/adf_histogram.py | 2 +- scripts/plotting/cloud_regime_analysis.py | 1995 +++++++++------------ 2 files changed, 857 insertions(+), 1140 deletions(-) diff --git a/scripts/plotting/adf_histogram.py b/scripts/plotting/adf_histogram.py index bad3dc5a5..6346cf408 100644 --- a/scripts/plotting/adf_histogram.py +++ b/scripts/plotting/adf_histogram.py @@ -243,7 +243,7 @@ def make_histograms(data, land, vres): bins = np.arange(*vres['contour_levels_range']) else: print("WARNING: no sensible defaults found -- histogram will use 25 bins (bins may differ across cases)") - bins = np.linspace(data.min(), data.max(), 26) + bins = np.linspace(data.min().values, data.max().values, 26) # extend bins to catch all values hbins = np.insert(bins, 0, np.finfo(float).min) diff --git a/scripts/plotting/cloud_regime_analysis.py b/scripts/plotting/cloud_regime_analysis.py index e10af26f8..252a89022 100644 --- a/scripts/plotting/cloud_regime_analysis.py +++ b/scripts/plotting/cloud_regime_analysis.py @@ -1,4 +1,5 @@ - +import os +import joblib from math import ceil import warnings from pathlib import Path @@ -6,15 +7,6 @@ import numpy as np import xesmf -try: - import wasserstein -except: - print( - " Wasserstein package is not installed so wasserstein distance cannot be used. Attempting to use wasserstein distance will raise an error." 
-    )
-    print(
-        " To use wasserstein distance please install the wasserstein package in your environment: https://pypi.org/project/Wasserstein/ "
-    )
 import matplotlib.pyplot as plt
 from matplotlib.cm import ScalarMappable
 from mpl_toolkits.axes_grid1 import make_axes_locatable
@@ -44,11 +36,8 @@ def njit(func=None):
     return func
 
-# --- bpm refactor
-# --- set up dataclass to get metadata per variable:
 from dataclasses import dataclass
 
-
 @dataclass(frozen=True)
 class VariableNames:
     product_name: str
@@ -90,589 +79,469 @@ class VariableNames:
         obs_tau_var="PRES",
     ),
 }
-# ---
+# =============================================================================
+# MAIN ANALYSIS FUNCTION
+# =============================================================================
 
 def cloud_regime_analysis(
     adf,
     wasserstein_or_euclidean="euclidean",
+    ot_library="pot",
+    emd_method=None,
     premade_cloud_regimes=None,
     lat_range=None,
     lon_range=None,
-    only_ocean_or_land=False,
+    only_ocean_or_land=None
 ):
     """
-    This script/function is designed to generate 2-D lat/lon maps of Cloud Regimes (CRs), as well as plots of the CR
-    centers themselves. It can fit data into CRs using either Wasserstein (AKA Earth Movers Distance) or the more conventional
-    Euclidean distance. To use this script, the user should add the appropriate COSP variables to the diag_var_list in the yaml file.
-    The appropriate variables are FISCCP1_COSP for ISCCP, CLD_MISR for MISR, and CLMODIS for MODIS. All three should be added to
-    diag_var_list if you wish to perform analysis on all three. The user can also specify to perform analysis for just one or for
-    all three of the data products (ISCCP, MODIS, and MISR) that there exists COSP output for. A user can also choose to use only
-    a specfic lat and lon range, or to use data only over water or over land. Lastly if a user has CRs that they have custom made,
-    these can be passed in and the script will fit data into them rather than the premade CRs that the script already points to.
-    There are a total of 6 sets of premade CRs, two for each data product. One set made with euclidean distance and one set made
-    with Wasserstein distance for ISCCP, MODIS, and MISR. Therefore when the wasserstein_or_euclidean variables is changed it is
+    Generates 2D maps and plots of Cloud Regime (CR) centers by comparing a
+    test case against observations or a baseline simulation.
+
+    This function orchestrates the ADF workflow to generate 2-D lat/lon maps of Cloud Regimes (CRs) and plots of the CR
+    centers themselves (CTP-tau histograms). It can fit data into CRs using either Wasserstein (AKA Earth Movers Distance) or
+    Euclidean distance.
+    Checks for COSP variables in diag_var_list: FISCCP1_COSP, CLD_MISR, and CLMODIS.
+    Whichever are found will be processed.
+    Optionally do analysis on subsets by masking ocean or land and specifying latitude and/or longitude bounds.
+    User-specified CRs can be provided, but the default uses premade CRs derived from observational products (Davis & Medeiros 2024).
+
+    There are 6 sets of premade CRs, two for each data product. One made with euclidean distance and one
+    with Wasserstein distance for ISCCP, MODIS, and MISR.
+    Therefore when the wasserstein_or_euclidean variable is changed it is
     important to understand that not only the distance metric used to fit data into CRs is changing, but also the CRs themselves
     unless the user is passing in a set of premade CRs with the premade_cloud_regimes variable.
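+
+    Example (a hypothetical session; assumes `adf` is an already-initialized ADF object and that
+    the relevant COSP variables appear in diag_var_list):
+
+        cloud_regime_analysis(adf)  # euclidean distance, default CRs, all available products
+        cloud_regime_analysis(adf, wasserstein_or_euclidean="wasserstein",
+                              lat_range=[-30, 30], only_ocean_or_land="O")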
-
-    Description of kwargs:
-    wasserstein_or_euclidean -> Whether to use wasserstein or euclidean distance to fit CRs, enter "wasserstein" for wasserstein or
-                        "euclidean" for euclidean. This also changes the default CRs that data is fit into from ones created
-                        with kmeans using euclidean distance to ones using kmeans with wasserstein distance. Default is euclidean distance.
-    premade_cloud_regimes -> If the user wishes to use custom CRs rather than the pre-loaded ones, enter them here as a path to a numpy
-                        array of shape (k, n_tau_bins * n_pressure_bins)
-    lat_range -> Range of latitudes to use enetered as a list, Ex. [-30,30]. Default is use all available latitudes
-    lon_range -> Range of longitudes to use enetered as a list, Ex. [-90,90]. Default is use all available longitudes
-    only_ocean_or_land -> Set to "O" to perform analysis with only points over water, "L" for only points over land, or False
-                        to use data over land and water. Default is False
+    PARAMETERS
+    ----------
+    adf
+        The ADF object
+    wasserstein_or_euclidean : str ("wasserstein" | "euclidean")
+        Whether to use wasserstein or euclidean distance to fit CRs.
+        This also selects the default CRs, which were created with kmeans using the selected distance.
+        Default is euclidean distance *because it is much faster than wasserstein*.
+    ot_library : str ("pot" | "wasserstein")
+        When wasserstein distance is used, this chooses the backend for calculation:
+        - "pot": python optimal transport is default, see: https://pythonot.github.io/index.html
+        - "wasserstein" : wasserstein package, see: https://github.com/thaler-lab/Wasserstein
+        NOTE: wasserstein was used originally (See Davis & Medeiros 2024), but as of ADF implementation, requires Numpy < 2.
+    emd_method : str ("exact" | "sinkhorn")
+        When wasserstein distance is used AND the POT library is the backend, specify the algorithm
+        - "exact" uses the exact algorithm; it is the default and is recommended.
+        - "sinkhorn" uses the Sinkhorn algorithm, which is faster, but is **highly experimental** and not recommended.
+    premade_cloud_regimes : Path-like to numpy array file
+        Specify custom CRs to use rather than those in ADF_variable_defaults
+        - enter as a path to a numpy array of shape (k, n_tau_bins * n_pressure_bins)
+        NOTE: specifying custom CRs has not been tested in ADF (caution!)
+    lat_range : list of floats
+        Range of latitudes to use, Example: [-30,30]
+        Default is use all available latitudes
+    lon_range : list of floats
+        Range of longitudes to use, Example: [-90,90]
+        Default is use all available longitudes
+    only_ocean_or_land : str
+        Set to
+        - "O" to perform analysis with only points over water,
+        - "L" for only points over land,
+        - None or False to use data over land and water.
+        Default is None (land & water).
     """
     dask.config.set({"array.slicing.split_large_chunks": False})
 
-    # Plot LatLon plots of the frequency of occrence of the baseline/obs and test case
-    def plot_rfo_obs_base_diff(cluster_labels, cluster_labels_d, adf, field=None):
-        k = cluster_labels.attrs.get("k")
-        COLOR = "black"
-        mpl.rcParams["text.color"] = COLOR
-        mpl.rcParams["axes.labelcolor"] = COLOR
-        mpl.rcParams["xtick.color"] = COLOR
-        mpl.rcParams["ytick.color"] = COLOR
-        plt.rcParams.update({"font.size": 13})
-        plt.rcParams["figure.dpi"] = 500
-        fig_height = 7
-
-        # Comparing obs or baseline?
- if adf.compare_obs == True: - obs_or_base = "Observation" - else: - obs_or_base = "Baseline" - - for cluster in range(k): - fig, ax = plt.subplots( - ncols=2, - nrows=2, - subplot_kw={"projection": ccrs.PlateCarree()}, - figsize=(12, fig_height), - ) - plt.subplots_adjust(wspace=0.08, hspace=0.002) - aa = ax.ravel() - - # Calculating and plotting rfo of baseline/obs - X, Y = np.meshgrid(cluster_labels_d.lon, cluster_labels_d.lat) - rfo_d = ( - np.sum(cluster_labels_d == cluster, axis=0) - / np.sum(cluster_labels_d >= 0, axis=0) - * 100 - ) - aa[0].set_extent([-180, 180, -90, 90]) - aa[0].coastlines() - mesh = aa[0].pcolormesh( - X, - Y, - rfo_d, - transform=ccrs.PlateCarree(), - rasterized=True, - cmap="GnBu", - vmin=0, - vmax=100, - ) - total_rfo_num = cluster_labels_d == cluster - total_rfo_num = np.sum( - total_rfo_num * np.cos(np.deg2rad(cluster_labels_d.lat)) - ) - total_rfo_denom = cluster_labels_d >= 0 - total_rfo_denom = np.sum( - total_rfo_denom * np.cos(np.deg2rad(cluster_labels_d.lat)) - ) - total_rfo_d = total_rfo_num / total_rfo_denom * 100 - aa[0].set_title( - f"{obs_or_base}, RFO = {round(float(total_rfo_d),1)}", pad=4 - ) - - # Calculating and plotting rfo of test_case - X, Y = np.meshgrid(cluster_labels.lon, cluster_labels.lat) - rfo = ( - np.sum(cluster_labels == cluster, axis=0) - / np.sum(cluster_labels >= 0, axis=0) - * 100 - ) - aa[1].set_extent([-180, 180, -90, 90]) - aa[1].coastlines() - mesh = aa[1].pcolormesh( - X, - Y, - rfo, - transform=ccrs.PlateCarree(), - rasterized=True, - cmap="GnBu", - vmin=0, - vmax=100, - ) - total_rfo_num = cluster_labels == cluster - total_rfo_num = np.sum( - total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat)) - ) - total_rfo_denom = cluster_labels >= 0 - total_rfo_denom = np.sum( - total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat)) - ) - total_rfo = total_rfo_num / total_rfo_denom * 100 - aa[1].set_title(f"Test Case, RFO = {round(float(total_rfo),1)}", pad=4) - - # Making colorbar - cax = fig.add_axes( - [ - aa[1].get_position().x1 + 0.01, - aa[1].get_position().y0, - 0.02, - aa[1].get_position().height, - ] - ) - cb = plt.colorbar(mesh, cax=cax) - cb.set_label(label="RFO (%)") - - # Calculating and plotting difference - # If observation/baseline is a higher resolution interpolate from obs/baseline to CAM grid - if len(cluster_labels_d.lat) * len(cluster_labels_d.lon) > len( - cluster_labels.lat - ) * len(cluster_labels.lon): - rfo_d = rfo_d.interp_like(rfo, method="nearest") - - # If CAM is a higher resolution interpolate from CAM to obs/baseline grid - if len(cluster_labels_d.lat) * len(cluster_labels_d.lon) <= len( - cluster_labels.lat - ) * len(cluster_labels.lon): - rfo = rfo.interp_like(rfo_d, method="nearest") - X, Y = np.meshgrid(cluster_labels_d.lon, cluster_labels_d.lat) - - rfo_diff = rfo - rfo_d - - aa[2].set_extent([-180, 180, -90, 90]) - aa[2].coastlines() - mesh = aa[2].pcolormesh( - X, - Y, - rfo_diff, - transform=ccrs.PlateCarree(), - rasterized=True, - cmap="coolwarm", - vmin=-100, - vmax=100, - ) - total_rfo_num = cluster_labels == cluster - total_rfo_num = np.sum( - total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat)) - ) - total_rfo_denom = cluster_labels >= 0 - total_rfo_denom = np.sum( - total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat)) - ) - total_rfo = total_rfo_num / total_rfo_denom * 100 - aa[2].set_title( - f"Test - {obs_or_base}, ΔRFO = {round(float(total_rfo-total_rfo_d),1)}", - pad=4, - ) - - # Setting yticks - aa[0].set_yticks([-60, -30, 0, 30, 60], crs=ccrs.PlateCarree()) - 
aa[2].set_yticks([-60, -30, 0, 30, 60], crs=ccrs.PlateCarree()) - lat_formatter = LatitudeFormatter() - aa[0].yaxis.set_major_formatter(lat_formatter) - aa[2].yaxis.set_major_formatter(lat_formatter) - - # making colorbar for diff plot - cax = fig.add_axes( - [ - aa[2].get_position().x1 + 0.01, - aa[2].get_position().y0, - 0.02, - aa[2].get_position().height, - ] - ) - cb = plt.colorbar(mesh, cax=cax) - cb.set_label(label="ΔRFO (%)") - - # plotting x labels - aa[1].set_xticks( - [ - -120, - -60, - 0, - 60, - 120, - ], - crs=ccrs.PlateCarree(), - ) - lon_formatter = LongitudeFormatter(zero_direction_label=True) - aa[1].xaxis.set_major_formatter(lon_formatter) - aa[2].set_xticks( - [ - -120, - -60, - 0, - 60, - 120, - ], - crs=ccrs.PlateCarree(), - ) - lon_formatter = LongitudeFormatter(zero_direction_label=True) - aa[2].xaxis.set_major_formatter(lon_formatter) - - bbox = aa[1].get_position() - p1 = bbox.p1 - plt.suptitle( - f"CR{cluster+1} Relative Frequency of Occurence", - y=p1[1] + (1 / fig_height * 0.5), - ) - - aa[-1].remove() - - save_path = adf.plot_location[0] + f"/{field}_CR{cluster+1}_LatLon_mean" - plt.savefig(save_path) - - if adf.create_html: - adf.add_website_data( - save_path + ".png", field, case_name=None, multi_case=True - ) - - # Closing the figure - plt.close() - - ################################################################### - # MAIN - ################################################################### - - # Checking if kwargs have been entered correctly - if wasserstein_or_euclidean not in ["euclidean", "wasserstein"]: - print( - ' WARNING: Invalid option for wasserstein_or_euclidean. Please enter "wasserstein" or "euclidean". Proceeding with default of euclidean distance' - ) - wasserstein_or_euclidean = "euclidean" - if premade_cloud_regimes != None: - if type(premade_cloud_regimes) != str: - print( - " WARNING: Invalid option for premade_cloud_regimes. Please enter a path to a numpy array of Cloud Regime centers of shape (n_clusters, n_dimensions_of_data). Proceeding with default clusters" - ) - premade_cloud_regimes = None - if lat_range != None: - if type(lat_range) != list or len(lat_range) != 2: - print( - " WARNING: Invalid option for lat_range. Please enter two values in square brackets sperated by a comma. Example: [-30,30]. Proceeding with entire latitude range" - ) - lat_range = None - if lon_range != None: - if type(lon_range) != list or len(lon_range) != 2: - print( - " WARNING: Invalid option for lon_range. Please enter two values in square brackets sperated by a comma. Example: [0,90]. Proceeding with entire longitude range" - ) - lon_range = None - if only_ocean_or_land not in [False, "L", "O"]: - print( - ' WARNING: Invalid option for only_ocean_or_land. Please enter "L" for land only, "O" for ocean only. Set to False or leave blank for both land and water. Proceeding with default of False' - ) - only_ocean_or_land = False - - # NOTE: probably have to move into case loop - time_range = [ - str(adf.get_cam_info("start_year")[0]), - str(adf.get_cam_info("end_year")[0]), - ] - - # --- BPM refactor --- - # determine which variables to try - cr_vars = [] + # 1. 
Validate user inputs & set configuration
+    opts = _validate_user_inputs(
+        wasserstein_or_euclidean,
+        premade_cloud_regimes,
+        lat_range,
+        lon_range,
+        only_ocean_or_land,
+    )
+
+    time_range = [str(adf.get_cam_info("start_year")[0]), str(adf.get_cam_info("end_year")[0])]
+    opts['time_range'] = time_range
     landfrac_present = "LANDFRAC" in adf.diag_var_list
-    print(f"Did we find LANDFRAC in the variable list: {landfrac_present}")
-    for field in adf.diag_var_list:
-        if field in ["FISCCP1_COSP", "CLD_MISR", "CLMODIS"]:
-            cr_vars.append(field)
+    opts['landfrac_present'] = landfrac_present
+    opts['emd_method'] = emd_method
+    opts['ot_library'] = ot_library  # record the OT backend so reference labeling uses the same choice
+    opts['n_cpus'] = adf.get_basic_info('num_procs')
 
-    # process each each COSP cloud variable
+    # 2. Process each COSP cloud variable
+    cr_vars = [field for field in adf.diag_var_list if field in ALL_VARS]
     for field in cr_vars:
-        print(f"WORK ON {field}")
-        cluster_spec = premade_cloud_regimes if premade_cloud_regimes is not None else wasserstein_or_euclidean
-
-        ht_var_name = ALL_VARS[field].ht_var
-        tau_var_name = ALL_VARS[field].tau_var
-        if adf.compare_obs:
-            ref_ht_var_name = ALL_VARS[field].obs_ht_var
-            ref_tau_var_name = ALL_VARS[field].obs_tau_var
-        else:
-            ref_ht_var_name = ALL_VARS[field].ht_var
-            ref_tau_var_name = ALL_VARS[field].tau_var
-
-        # GET REFERENCE DATA, use for all cases
+        print(f"INFO: Processing variable: {field}")
+        var_info = ALL_VARS[field]
+
+        # 3. Load cluster centers
+        cluster_spec = premade_cloud_regimes if premade_cloud_regimes is not None else opts['distance']
+        cl = load_cluster_centers(adf, cluster_spec, field)
+        if cl is None:
+            warnings.warn(f"WARNING: Skipping {field} due to failed cluster center loading.")
+            continue
+        opts['cl_shape'] = cl.shape
+
+        # 4. Load and process reference data to get reference labels
        ref_data = load_reference_data(adf, field)
-        if adf.compare_obs:
-            # ref_data should be a dataset in this case
-            # reference regime labels or cloud data (to be labeled)
-            if premade_cloud_regimes is None:
-                if wasserstein_or_euclidean == "wasserstein":
-                    ds_o = ref_data.emd_cluster_labels
-                else:
-                    ds_o = ref_data.euclidean_cluster_labels
-            else:
-                ds_o = ref_data[adf.variable_defaults[field]['obs_var_name']]
-        else:
-            ds_o = ref_data  # already a dataarray
+        if ref_data is None: continue
+
+        ref_labels = _get_ref_cluster_labels(adf, ref_data, field, var_info, cl, opts)
+        if ref_labels is None:
+            warnings.warn(f"WARNING: Could not generate reference labels for {field}. Skipping.")
+            continue
 
+        # 5. Process each test case against the reference
         for case_name in adf.data.case_names:
+            print(f"\nINFO: Analyzing case: {case_name}")
+
             c_ts_da = adf.data.load_timeseries_da(case_name, field)
             if c_ts_da is None:
-                print(
-                    f"\t WARNING: Variable {field} for case '{case_name}' provides None type. Skipping this variable"
-                )
-                skip_var = True
+                warnings.warn(f"WARNING: Variable {field} for case '{case_name}' is None. Skipping.")
                 continue
-            else:
-                print(
-                    f"\t Loaded time series for {field} ==> {c_ts_da.shape = }, {c_ts_da.coords = }"
-                )
+
+            # Regrid if on unstructured grid (e.g., 'ncol' dimension)
             if "ncol" in c_ts_da.dims:
-                # right now we are remapping to fv09 grid because that
-                # is the mapping available. 
- # TODO: generalize; would save time to remap to sat data grid - print("Trigger regrid (ne30-to-fv09 ONLY)") - regrid_weights_file = Path( - "/glade/work/brianpm/mapping_ne30pg3_to_fv09_esmfbilin.nc" - ) - rg = make_se_regridder( - regrid_weights_file, Method="bilinear" - ) # algorithm needs to match - ds = regrid_se_data_bilinear( - rg, c_ts_da, column_dim_name="ncol" - ) + print("INFO: Regridding data from unstructured grid.") + regrid_weights_file = Path("/glade/work/brianpm/mapping_ne30pg3_to_fv09_esmfbilin.nc") + rg = make_se_regridder(regrid_weights_file, Method="bilinear") + ds = regrid_se_data_bilinear(rg, c_ts_da, column_dim_name="ncol") else: - ds = c_ts_da # assumption: already on lat-lon grid - - ##### DATA PRE-PROCESSING - # Adjusting lon to run from -180 to 180 if it doesnt already - if np.max(ds.lon) > 180: - ds.coords["lon"] = (ds.coords["lon"] + 180) % 360 - 180 - ds = ds.sortby(ds.lon) - - # Selecting only points over ocean or points over land if only_ocean_or_land has been used - ds = apply_land_ocean_mask(ds, only_ocean_or_land, landfrac_present) - if ds is None: - return # Error occurred - # Turning dataset into a dataarray - if isinstance(ds, xr.Dataset): - ds = ds[field] - ds = spatial_subset(ds, lat_range, lon_range) - ds = temporal_subset(ds, time_range) - ds = select_valid_tau_height(ds, tau_var_name, ht_var_name) - ##### - - # CLUSTER CENTERS - cl = load_cluster_centers(adf, cluster_spec, field) - if cl is None: - print(f"Skipping cloud regime analysis for {field} due to failed cluster center loading.") - continue # Skip to the next variable in cr_vars - + ds = c_ts_da + + # Preprocess test case data + processed_ds = _preprocess_data(ds, field, var_info, opts) + if processed_ds is None: continue + + # Compute cluster labels for the test case + test_labels = compute_cluster_labels(processed_ds, var_info.tau_var, var_info.ht_var, cl, opts['distance'], ot_library, method=opts['emd_method'], num_cpus=opts['n_cpus']) + test_labels.attrs['k'] = cl.shape[0] + + # 6. Generate all plots + print("INFO: Generating plots...") + tau_coord = processed_ds[var_info.tau_var] + ht_coord = processed_ds[var_info.ht_var] - # COSP ISCCP data has one extra tau bin than the satellite data, and misr has an extra height bin. - # This checks roughly if we are comparing against the - # satellite data, and if so removes the extra tau or ht bin. 
- # If a user passes home made CRs from CESM data, no data will be removed - if ALL_VARS[field].product_name == "ISCCP" and cl.shape[1] == 42: - sel_dict = {tau_var_name: slice(np.min(ds[tau_var_name]) + 1e-11, None)} - ds = ds.sel(sel_dict) - print(f"\t Dropping smallest tau bin ({tau_var_name}) to be comparable with observational cloud regimes") - if ALL_VARS[field].product_name == "MISR" and cl.shape[1] == 105: - sel_dict = {ht_var_name: slice(np.min(ds[ht_var_name]) + 1e-11, None)} - ds = ds.sel(sel_dict) - print(f"\t Dropping lowest height bin ({ht_var_name}) to be comparable with observational cloud regimes") - - # CASE CLUSTER LABELING: - cluster_labels = compute_cluster_labels(ds, tau_var_name, ht_var_name, cl, wasserstein_or_euclidean) - print(f"{case_name} {field} cluster labels calculated.") - - ref_opts = {"premade_cloud_regimes":premade_cloud_regimes, - "distance": wasserstein_or_euclidean, - "landsea": only_ocean_or_land, - "landfrac": landfrac_present, # need to deal with this better - "lat_range": lat_range, - "lon_range": lon_range, - "time_range": time_range, - "tau_name": ref_tau_var_name, - "ht_name": ref_ht_var_name, - "data": ALL_VARS[field].product_name - } - cluster_labels_ref = compute_ref_cluster_labels(adf, ds_o, field, ref_opts) - - # PLOTS - taucoord = ds[tau_var_name] - htcoord = ds[ht_var_name] - # let cluster_labels know the number of clusters: - cluster_labels.attrs['k'] = cl.shape[0] - # `plot_rfo_obs_base_diff` expects `cluster_labels_ref` to be latxlon if adf.compare_obs: - plot_hists_obs( - field, cl, cluster_labels, cluster_labels_ref, ds, ds_o, ht_var_name, tau_var_name, htcoord, taucoord, adf - ) - plot_rfo_obs_base_diff(cluster_labels, cluster_labels_ref, adf, field=field) + plot_hists_obs(field, cl, test_labels, ref_labels, processed_ds, ref_data, var_info.ht_var, var_info.tau_var, ht_coord, tau_coord, adf) else: - plot_hists_baseline( - field, - cl, - cluster_labels, - cluster_labels_ref, - ds, - ds_o, # only is ref histograms for simulation, right - ht_var_name, - tau_var_name, - htcoord, - taucoord, - adf, - ) - plot_rfo_obs_base_diff(cluster_labels, cluster_labels_ref, adf, field=field) - # ^^^ BPM refactor ^^^ - - -def compute_ref_cluster_labels(adf, ds_ref, field, opts): - if adf.compare_obs == True: - ds_o = ds_ref - obs_var = adf.variable_defaults[field]['obs_var_name'] - # Adjusting lon to run from -180 to 180 if it doesnt already - if np.max(ds_o.lon) > 180: - ds_o.coords["lon"] = (ds_o.coords["lon"] + 180) % 360 - 180 - ds_o = ds_o.sortby(ds_o.lon) - - # this landfrac_present is probably not for ref dataset. - ds_o = apply_land_ocean_mask(ds_o, opts['landsea'], opts['landfrac']) - if ds_o is None: - print("[CRA compute_ref_cluster_labels] reference data is None.") - return # Error occurred - ds_o = spatial_subset(ds_o, opts['lat_range'], opts['lon_range']) # bpm - if ds_o is None: - print("[CRA compute_ref_cluster_labels] reference data is None.") - return # Error occurred + plot_hists_baseline(field, cl, test_labels, ref_labels, processed_ds, ref_data, var_info.ht_var, var_info.tau_var, ht_coord, tau_coord, adf) - if opts['premade_cloud_regimes'] is None: - print(f"[CRA compute_ref_cluster_labels] {opts['premade_cloud_regimes'] = }") - cluster_labels_o = ds_o - cluster_labels_ref = cluster_labels_o.stack( - spacetime=("time", "lat", "lon") - ).unstack() ## <- do we want to unstack here? 
- else: - print(f"[CRA compute_ref_cluster_labels] {opts['premade_cloud_regimes'] = }") - ds_o = select_valid_tau_height(ds_o, opts['tau_name'], opts['ht_name']) - cluster_labels_ref = finish_cluster_labels(ds_o, opts['tau_name'], opts['ht_name']) + plot_rfo_maps(test_labels, ref_labels, adf, field) + + +# +# --- local functions --- +# +def _validate_user_inputs(distance, regimes, lat_r, lon_r, land_ocean): + """Validates inputs, returning an options dictionary.""" + opts = {} + if distance not in ["euclidean", "wasserstein"]: + warnings.warn('WARNING: Invalid distance metric. Defaulting to "euclidean".') + opts['distance'] = "euclidean" else: - # Compare to simulation case. - ds_b = ds_ref - time_range_b = [ - str(adf.get_baseline_info("start_year")), - str(adf.get_baseline_info("end_year")), - ] - landfrac_present = opts['landfrac'] - # Adjusting lon to run from -180 to 180 if it doesnt already - if np.max(ds_b.lon) > 180: - ds_b.coords["lon"] = (ds_b.coords["lon"] + 180) % 360 - 180 - ds_b = ds_b.sortby(ds_b.lon) - - # this landfrac_present is porbably not for ds_b - ds_b = apply_land_ocean_mask(ds_b, opts['landsea'], opts['landfrac']) - if ds_b is None: - return # Error occurred - ds_b = spatial_subset(ds_b, opts['lat_range'], opts['lon_range']) - ds_b = temporal_subset(ds_b, time_range) - - # Turning dataset into a dataarray - if isinstance(ds_b, xr.Dataset): - ds_b = ds_b[field] - - ds_b = select_valid_tau_height(ds_b, opts['tau_name'], opts['ht_name']) - - # COSP ISCCP data has one extra tau bin than the satellite data, and misr has an extra height bin. This checks roughly if we are comparing against the - # satellite data, and if so removes the extra tau or ht bin. If a user passes home made CRs from CESM data, no data will be removed - if data == "ISCCP" and cl.shape[1] == 42: - # A slightly hacky way to drop the smallest tau bin, but is robust incase tau is flipped in a future version - ds_b = ds_b.sel( - cosp_tau=slice(np.min(ds_b.cosp_tau) + 1e-11, np.inf) - ) - print( - "\t Dropping smallest tau bin to be comparable with observational cloud regimes" - ) - if data == "MISR" and cl.shape[1] == 105: - # A slightly hacky way to drop the lowest height bin, but is robust incase height is flipped in a future version - ds_b = ds_b.sel( - cosp_htmisr=slice(np.min(ds_b.cosp_htmisr) + 1e-11, np.inf) - ) - print( - "\t Dropping lowest height bin to be comparable with observational cloud regimes" - ) - - cluster_labels_ref = finish_cluster_labels(ds_b, opts['tau_name'], opts['ht_name']) #bpm new func - return cluster_labels_ref - - -def precomputed_clusters(mat, cl, wasserstein_or_euclidean, ds): - """Compute cluster labels from precomputed cluster centers with appropriate distance""" + opts['distance'] = distance + opts['premade_cloud_regimes'] = regimes + opts['lat_range'] = lat_r if isinstance(lat_r, list) and len(lat_r) == 2 else None + opts['lon_range'] = lon_r if isinstance(lon_r, list) and len(lon_r) == 2 else None + if not land_ocean: + print("INFO: Default to using both LAND and OCEAN points.") + opts['only_ocean_or_land'] = False + elif land_ocean not in ["L", "O"]: + warnings.warn('WARNING: Invalid land/ocean flag. 
Defaulting to False (land and ocean).')
+        opts['only_ocean_or_land'] = False
+    else:
+        opts['only_ocean_or_land'] = land_ocean
+    return opts
+
+def _get_ref_cluster_labels(adf, ref_data, field, var_info, cl, opts):
+    """Computes and returns cluster labels for the reference (obs or baseline)."""
+    if adf.compare_obs:
+        # If pre-computed labels are in the file, use them
+        if opts['premade_cloud_regimes'] is None:
+            label_var = "emd_cluster_labels" if opts['distance'] == "wasserstein" else "euclidean_cluster_labels"
+            if label_var in ref_data:
+                print(f"INFO: Using pre-computed reference labels: {label_var}")
+                # Get the labels and ensure their longitude is standardized to -180 to 180
+                labels = ref_data[label_var]
+                if 'lon' in labels.coords and labels.lon.max() > 180:
+                    print("INFO: Standardizing longitude for pre-computed reference labels.")
+                    labels = labels.assign_coords(lon=(((labels.lon + 180) % 360) - 180)).sortby("lon")
+                return labels
+        # Otherwise, compute labels from histograms
+        ref_ht_var = var_info.obs_ht_var
+        ref_tau_var = var_info.obs_tau_var
+        data_var = var_info.obs_data_var
+        processed_ref = _preprocess_data(ref_data[data_var], data_var, VariableNames("", "", ref_ht_var, ref_tau_var, "", "", ""), opts)
+    else:  # Comparing to baseline simulation
+        baseline_info = adf.get_baseline_info
+        time_range_b = [str(baseline_info("start_year")), str(baseline_info("end_year"))]
+        baseline_opts = {**opts, 'time_range': time_range_b}
+        processed_ref = _preprocess_data(ref_data, field, var_info, baseline_opts)
+        ref_ht_var, ref_tau_var = var_info.ht_var, var_info.tau_var
+
+    if processed_ref is None:
+        return None
+
+    # The OT backend is recorded in opts by cloud_regime_analysis; fall back to "pot" if absent
+    return compute_cluster_labels(processed_ref, ref_tau_var, ref_ht_var, cl, opts['distance'],
+                                  opts.get('ot_library', 'pot'), method=opts['emd_method'], num_cpus=opts['n_cpus'])
+
+def _preprocess_data(ds, field_name, var_info, opts):
+    """Performs all preprocessing steps on a data array before clustering."""
+    if isinstance(ds, xr.Dataset):
+        ds = ds[field_name]
+
+    if 'lon' in ds.coords and ds.lon.max() > 180:
+        ds = ds.assign_coords(lon=(((ds.lon + 180) % 360) - 180)).sortby("lon")
+
+    ds = apply_land_ocean_mask(ds, opts['only_ocean_or_land'], opts.get('landfrac_present'))
+    if ds is None: return None
+
+    ds = spatial_subset(ds, opts['lat_range'], opts['lon_range'])
+    ds = temporal_subset(ds, opts.get('time_range'))
+    if ds is None or 'time' not in ds.dims or ds.time.size == 0:
+        warnings.warn(f"WARNING: No data remains for {field_name} after subsetting.")
+        return None
+
+    ds = select_valid_tau_height(ds, var_info.tau_var, var_info.ht_var)
+
+    # Special handling when comparing against observational regimes.
+    # Reference data may arrive under its observational variable name, which is not a
+    # key of ALL_VARS, so look up the product name defensively to avoid a KeyError.
+    product_name = ALL_VARS[field_name].product_name if field_name in ALL_VARS else None
+    if product_name == "ISCCP" and opts['cl_shape'][1] == 42:
+        ds = ds.sel({var_info.tau_var: slice(ds[var_info.tau_var].min().item() + 1e-11, None)})
+        print(f"\t Dropping smallest tau bin ({var_info.tau_var}) for obs comparison.")
+    if product_name == "MISR" and opts['cl_shape'][1] == 105:
+        ds = ds.sel({var_info.ht_var: slice(ds[var_info.ht_var].min().item() + 1e-11, None)})
+        print(f"\t Dropping lowest height bin ({var_info.ht_var}) for obs comparison.")
+
+    return ds
+
+
+def compute_cluster_labels(ds, tau_var_name, ht_var_name, cl, wasserstein_or_euclidean, ot_library='pot', method=None, num_cpus=None):
+    """
+    Computes cluster labels for a given data array of histograms.
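+
+    Every dimension except tau and height/pressure is stacked into a single "spacetime"
+    dimension, and the two histogram dimensions are stacked into "tau_ht", yielding a
+    (n_samples, n_tau * n_ht) matrix for the distance calculation. A sketch of that
+    reshaping (assuming time/lat/lon input; the code derives the dimensions dynamically):
+
+        histograms = ds.stack(spacetime=("time", "lat", "lon"),
+                              tau_ht=(tau_var_name, ht_var_name))
+        mat = histograms.values  # shape: (n_time*n_lat*n_lon, n_tau*n_ht)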
+
+    PARAMETERS
+    ----------
+    ds : xr.DataArray
+        input data with histograms
+    tau_var_name : str
+        tau dimension name
+    ht_var_name : str
+        CTH/CTP dimension name
+    cl
+        cluster centers
+    wasserstein_or_euclidean : str
+        distance metric choice
+    ot_library : str
+        backend library for Wasserstein
+    method : str
+        algorithm for wasserstein distance (default is exact)
+    num_cpus : int
+        number of CPU cores to assume for wasserstein calculation
+
+    RETURNS
+    -------
+    cluster_labels : xr.DataArray
+        labels on the unstacked grid; NaN wherever an input histogram contained NaNs
+    """
+    dims = [dim for dim in ds.dims if dim not in [tau_var_name, ht_var_name]]
+    histograms = ds.stack(spacetime=dims, tau_ht=(tau_var_name, ht_var_name))
+    mat = histograms.values
+    is_valid = ~np.isnan(mat).any(axis=1)
+    if not np.any(is_valid):
+        warnings.warn("ERROR: No valid histograms found after removing NaNs.")
+        return None
+    mat_valid = mat[is_valid]
+    print(f"INFO: Fitting {len(mat_valid)} valid histograms to cluster centers.")
+    labels_valid = precomputed_clusters(mat_valid, cl, wasserstein_or_euclidean, ds, tau_var_name, ht_var_name, ot_library, method, num_cpus=num_cpus)
+
+    cluster_labels_flat = np.full(len(mat), np.nan, dtype=np.float32)
+    cluster_labels_flat[is_valid] = labels_valid
+
+    cluster_labels = xr.DataArray(
+        data=cluster_labels_flat,
+        coords={"spacetime": histograms.spacetime},
+        dims=("spacetime"),
+    )
+    return cluster_labels.unstack()
+
+
+def precomputed_clusters(mat, cl, wasserstein_or_euclidean, ds, tau_var_name, ht_var_name, ot_library='pot', emd_method=None, num_cpus=None):
+    """
+    Compute cluster labels from precomputed cluster centers.
+
+    PARAMETERS
+    ----------
+    mat
+        array of histograms, reshaped to (n_samples, n_tau * n_ht)
+    cl
+        cluster centers
+    wasserstein_or_euclidean : str
+        choice of distance metric
+    ds : xr.DataArray
+        the histogram DataArray, used for the dimension/coordinate information
+    tau_var_name, ht_var_name : str
+        names of the tau and vertical dimensions (in ds)
+    ot_library : str
+        The library to use for Optimal Transport.
+        Either 'pot' or 'wasserstein'. Defaults to 'pot'.
+    emd_method : str
+        When ot_library is 'pot', can use 'exact' or 'sinkhorn' for calculation
+    num_cpus : int
+        The number of CPU cores to use (passed as the number of threads for wasserstein)
+    RETURNS
+    -------
+    array of cluster labels (integers)
+    """
+    if wasserstein_or_euclidean == "euclidean":
+        # Squared Euclidean distance from every histogram to every center via broadcasting:
+        # mat is (n_samples, n_bins) and cl.T is (n_bins, k), so distances is (n_samples, k)
+        distances = np.sum((mat[:, :, None] - cl.T[None, :, :]) ** 2, axis=1)
+    elif wasserstein_or_euclidean == "wasserstein":
+        distances = None
+        # Try preferred library first
+        if ot_library == 'pot':
+            distances = _compute_distances_pot(mat, cl, ds, tau_var_name, ht_var_name, method=emd_method, num_cpus=num_cpus)
+        elif (ot_library is None) or (ot_library == 'wasserstein'):
+            distances = _compute_distances_wasserstein(mat, cl, ds, tau_var_name, ht_var_name, num_cpus=num_cpus)
+        else:
+            warnings.warn(f"precomputed_clusters needs calculation backend in (`pot`,`wasserstein`), got {ot_library}")
+            return None
+
+        if distances is None:
+            warnings.warn("ERROR: [precomputed_clusters] Calculation failed. The selected Optimal Transport backend could not be used successfully.")
+            return None
+
+    # find the smallest distance for each histogram - that is the cluster assignment
+    return np.argmin(distances, axis=1)
+
+
+def _compute_distances_pot(mat, cl, ds, tau_var_name, ht_var_name, method=None, num_cpus=None):
+    """
+    Computes pairwise Wasserstein distances using POT and parallelizes the
+    calculation with joblib for performance.
+    """
+    try:
+        import ot
+        from ot.lp import emd2
+        from ot import sinkhorn2
+    except ImportError:
+        warnings.warn("Python Optimal Transport (POT) package not found or corrupt. 
Cannot use 'pot' library.")
+        return None
+
+    print("\t INFO: Using Python Optimal Transport (POT) library with joblib for parallel execution.")
+
+    # 1. Define the ground metric (cost matrix) ONCE.
+    n_tau = len(ds[tau_var_name])
+    n_ht = len(ds[ht_var_name])
+    x_coords, y_coords = np.meshgrid(np.arange(n_ht), np.arange(n_tau))
+    coords = np.vstack([y_coords.ravel(), x_coords.ravel()]).T
+
+    M = ot.dist(coords, coords, metric='euclidean')
+    M /= M.max()
+
+    reg_val = 0.1 * M.mean()  # used for regularization w/ Sinkhorn -- tunable parameter (bigger should be faster, but can be unstable).
+
+    # 2. Normalize the cluster centers ONCE.
+    cl_sum = cl.sum(axis=1, keepdims=True)
+    cl_normalized = cl / (cl_sum + 1e-9)
+
+    # 3. Define a helper function that INCLUDES normalization for the model data.
+    def compute_single_histogram_distances(histogram, centers_normalized, cost_matrix, method=None):
+        if method is None:
+            method = 'exact'  # match the documented default: the exact EMD is recommended
+        hist_sum = histogram.sum()
+        if hist_sum < 1e-9:
+            return np.full(centers_normalized.shape[0], np.inf)
+        hist_normalized = histogram / hist_sum
+        if method == 'exact':
+            return [emd2(hist_normalized, center, cost_matrix) for center in centers_normalized]
+        elif method == 'sinkhorn':
+            # Use stabilized version. Somewhat faster than exact calculation.
+            return [sinkhorn2(hist_normalized, center, cost_matrix, reg=reg_val, method='sinkhorn_stabilized', log=False) for center in centers_normalized]
+        else:
+            warnings.warn(f"ERROR: compute_single_histogram_distances method must be (None, sinkhorn, exact), got {method}")
+            return None
+
+    # Prefer the caller-provided (allocated) core count; otherwise use all visible cores
+    n_jobs = num_cpus if num_cpus is not None else (os.cpu_count() or 1)
+    print(f"\t Distributing EMD calculation across {n_jobs} allocated cores...")
+
+    # 4. Use joblib to run the calculations in parallel
+    distances_list = joblib.Parallel(n_jobs=n_jobs, verbose=10)(
+        joblib.delayed(compute_single_histogram_distances)(mat[i, :], cl_normalized, M, method) for i in range(mat.shape[0])
+    )
+
+    distances = np.array(distances_list)
+
+    return distances
+
+
+def _compute_distances_wasserstein(mat, cl, ds, tau_var_name, ht_var_name, num_cpus=None):
+    """
+    Computes pairwise Wasserstein distances using the 'wasserstein' library.
+    """
+    try:
+        import wasserstein
+    except ImportError:
+        warnings.warn("'wasserstein' package not found. Cannot use 'wasserstein' library.")
+        return None
+
+    print("\t INFO: Using 'wasserstein' library. 
Will try to JIT compile `stacking` function.") + + # This function is defined locally as it's highly specific to this library's API + @njit() + def stacking(position_matrix, centroids): + centroid_list = [] + for i in range(len(centroids)): + x = np.empty((3, len(mat[0]))).T + x[:, 0] = centroids[i] + x[:, 1] = position_matrix[0] + x[:, 2] = position_matrix[1] + centroid_list.append(x) + return centroid_list + + n1 = len(ds[tau_var_name]) + n2 = len(ds[ht_var_name]) + R = (n1**2 + n2**2)**0.5 + + position_matrix = np.zeros((2, n1, n2)) + position_matrix[0] = np.tile(np.arange(n2), (n1, 1)) + position_matrix[1] = np.tile(np.arange(n1), (n2, 1)).T + position_matrix = position_matrix.reshape(2, -1) + + num_threads = num_cpus if num_cpus is not None else 1 + print(f"\t Using {num_threads} threads for calculation.") + emds = wasserstein.PairwiseEMD(R=R, norm=True, dtype=np.float32, verbose=0, num_threads=num_threads) + events = stacking(position_matrix, mat) + centroid_list = stacking(position_matrix, cl) + emds(events, centroid_list) + return emds.emds() + +def _calculate_rfo(labels, cluster_index): + """Calculates the spatial and total RFO for a given cluster using xarray. + + PARAMETERS + ---------- + labels : xr.DataArray + label data + cluster_index : int + cluster value to count + RETURNS + ------- + tuple of (RFO (array), total RFO (float)) + """ + if not isinstance(labels, xr.DataArray): + warnings.warn(f"ERROR: Input 'labels' must be an xarray.DataArray, got {type(labels)}") + return None + + # Spatial RFO map (% of time steps in the cluster) + rfo_map = (labels == cluster_index).mean(dim="time", skipna=True) * 100 + + # Total area-weighted RFO (scalar %) + weights = np.cos(np.deg2rad(labels.lat)) + total_rfo_num = (labels == cluster_index).weighted(weights).sum() + total_rfo_denom = (labels >= 0).weighted(weights).sum() + + total_rfo = (total_rfo_num / total_rfo_denom * 100).item() if total_rfo_denom > 0 else 0 + return rfo_map, total_rfo +def load_reference_data(adfobj, varname): + """Load reference data, which could be an observation or a baseline simulation.""" + # ... (function content is identical to your original) base_name = adfobj.data.ref_case_label ref_var_nam = adfobj.data.ref_var_nam[varname] # shuld work for obs/sim print(f"[CRA: load_reference_data] {base_name = }, {ref_var_nam = }") @@ -692,547 +561,216 @@ def load_reference_data(adfobj, varname): print(f"[CRA: load_reference_data] returning simulation dataarray") return adfobj.data.load_reference_timeseries_da(varname) - -def load_cluster_centers(adf, cluster_spec: str | Path, variablename: str) -> np.ndarray | None: - """ - Loads cluster center data from a specified source. - - Args: - cluster_spec: A string ('wasserstein', 'euclidean', or a file path) - or a Path object pointing to a .npy or .nc file. - variablename: The name of the variable to look up in ALL_VARS to - determine the data product name. - - Returns: - A NumPy array containing the cluster center data, or None if an error occurs. 
- """ +def load_cluster_centers(adf, cluster_spec, variablename): + """Loads cluster center data from a specified source.""" if isinstance(cluster_spec, str): if cluster_spec in ('wasserstein', 'euclidean'): + if cluster_spec == 'wasserstein': + algo = 'emd' + else: + algo = 'euclidean' try: # Use variablename to find the data product name data = ALL_VARS[variablename].product_name obs_data_loc = Path(adf.get_basic_info("obs_data_loc")) - data_key = f"{data}_{cluster_spec}_centers" + data_key = f"{data}_{algo}_centers" cluster_centers_path = adf.variable_defaults[data_key]["obs_file"] file_path = obs_data_loc / cluster_centers_path except KeyError as e: - print( - f"[ERROR] Could not find '{variablename}' in ALL_VARS or default file path for '{cluster_spec}'. " + warnings.warn( + f"[ERROR] Could not find '{variablename}' in ALL_VARS or default file path for '{cluster_spec} with {algo = }'. " f"Original error: {e}" ) return None else: - # Assume it's a direct file path - file_path = Path(cluster_spec) - + file_path = Path(cluster_spec) # Assume it's a direct file path elif isinstance(cluster_spec, Path): file_path = cluster_spec else: - print(f"[ERROR] cluster_spec must be a string or a Path object, but got {type(cluster_spec)}") + warnings.warn(f"ERROR: cluster_spec must be a string or Path, not {type(cluster_spec)}") return None - # Check that the path exists before trying to load if not file_path.exists(): - print(f"[ERROR] File not found at: {file_path}") + warnings.warn(f"[ERROR] Cluster center file not found at: {file_path}") return None - # Load the data based on the file extension try: if file_path.suffix == ".nc": - with xr.open_dataset(file_path) as ds: - if 'centers' not in ds: - print(f"[ERROR] NetCDF file {file_path.name} does not contain a 'centers' variable.") - return None - cl = ds['centers'].values + cl = xr.open_dataset(file_path)['centers'].values elif file_path.suffix == ".npy": cl = np.load(file_path) else: - print(f"[ERROR] Unsupported file type: {file_path.suffix}") + warnings.warn(f"[ERROR] Unsupported file type: {file_path.suffix}") return None except Exception as e: - print(f"[ERROR] An unexpected error occurred while loading {file_path.name}: {e}") + warnings.warn(f"[ERROR] Could not load {file_path.name}: {e}") return None return cl +# -------- +# PLOTTING +# -------- +def _plot_map(ax, lon, lat, data, title, cmap, vmin, vmax): + """Plots a single lat/lon map on a given axis.""" + ax.set_global() + ax.coastlines() + mesh = ax.pcolormesh( + lon, lat, data, + transform=ccrs.PlateCarree(), + rasterized=True, + cmap=cmap, + vmin=vmin, + vmax=vmax, + ) + ax.set_title(title, pad=4) + return mesh -def compute_cluster_labels(ds, tau_var_name, ht_var_name, cl, wasserstein_or_euclidean): - # Selcting only the relevant data and stacking it to shape n_histograms, n_tau * n_pc - dims = list(ds.dims) - dims.remove(tau_var_name) - dims.remove(ht_var_name) - histograms = ds.stack(spacetime=(dims), tau_ht=(tau_var_name, ht_var_name)) - weights = np.cos( - np.deg2rad(histograms.lat.values) - ) # weights array to use with emd-kmeans - - # Turning into a numpy array for clustering - mat = histograms.values +def _configure_map_axes(ax, is_left, is_bottom): + """Configures ticks and labels for a map subplot.""" + if is_left: + ax.set_yticks([-60, -30, 0, 30, 60], crs=ccrs.PlateCarree()) + ax.yaxis.set_major_formatter(LatitudeFormatter()) + if is_bottom: + ax.set_xticks([-120, -60, 0, 60, 120], crs=ccrs.PlateCarree()) + 
ax.xaxis.set_major_formatter(LongitudeFormatter(zero_direction_label=True)) - # Removing all histograms with 1 or more nans in them - indices = np.arange(len(mat)) - is_valid = ~np.isnan(mat.mean(axis=1)) - is_valid = is_valid.astype(np.int32) - valid_inds = indices[is_valid == 1] - mat = mat[valid_inds] - weights = weights[valid_inds] - print(f"\t Fitting data") +def _add_colorbar(fig, ax, cmap, norm): + """Add colorbar to the figure.""" + p = [0, 0.2, 1, 2, 3, 4, 6, 8, 10, 15, 99] + # cbar_ax = fig.add_axes([1.01, 0.25, 0.04, 0.5]) + sm = ScalarMappable(norm=norm, cmap=cmap) + # sm.set_array([]) # Required for colorbar + cb = fig.colorbar(sm, + ax=ax, + orientation='vertical', + fraction=0.025, + pad=0.02, + aspect=40, + ticks=p) + cb.set_label(label="Cloud Cover (%)", size=16) + cb.ax.tick_params(labelsize=14) - # Compute cluster labels - cluster_labels_temp = precomputed_clusters( - mat, cl, wasserstein_or_euclidean, ds - ) - # taking the flattened cluster_labels_temp array, - # and turning it into a datarray the shape of ds.var_name, - # and reinserting NaNs in place of missing data - cluster_labels = np.full(len(indices), np.nan, dtype=np.int32) - cluster_labels[valid_inds] = cluster_labels_temp - cluster_labels = xr.DataArray( - data=cluster_labels, - coords={"spacetime": histograms.spacetime}, - dims=("spacetime"), - ) - cluster_labels = cluster_labels.unstack() - return cluster_labels +def _add_colorbar2(fig, ax, mappable, label): + """Adds a colorbar next to a given axis.""" + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.1, axes_class=plt.Axes) + cb = fig.colorbar(mappable, cax=cax) + cb.set_label(label=label) -def spatial_subset(ds_o, lat_range, lon_range): - # Selecting lat range - if lat_range: - if ds_o.lat[0] > ds_o.lat[-1]: - ds_o = ds_o.sel(lat=slice(lat_range[1], lat_range[0])) - else: - ds_o = ds_o.sel(lat=slice(lat_range[0], lat_range[1])) +def plot_rfo_maps(test_labels, ref_labels, adf, field): + """ + Plots Relative Frequency of Occurrence (RFO) maps for test, reference, + and their difference for each cloud regime. + """ + k = test_labels.attrs.get("k", int(np.nanmax(test_labels.values)) + 1) + obs_or_base = "Observation" if adf.compare_obs else "Baseline" + plt.rcParams.update({"font.size": 13, "figure.dpi": 200}) + for cluster in range(k): + fig, axes = plt.subplots( + nrows=2, ncols=2, + subplot_kw={"projection": ccrs.PlateCarree()}, + figsize=(12, 7) + ) + fig.subplots_adjust(wspace=0.15, hspace=0.15) + ax = axes.ravel() + + # 1. Reference RFO + rfo_ref, total_rfo_ref = _calculate_rfo(ref_labels, cluster) + mesh1 = _plot_map(ax[0], rfo_ref.lon, rfo_ref.lat, rfo_ref, + f"{obs_or_base}, RFO = {total_rfo_ref:.1f}%", + "GnBu", 0, 100) + + # 2. Test Case RFO + rfo_test, total_rfo_test = _calculate_rfo(test_labels, cluster) + mesh2 = _plot_map(ax[1], rfo_test.lon, rfo_test.lat, rfo_test, + f"Test Case, RFO = {total_rfo_test:.1f}%", + "GnBu", 0, 100) + _add_colorbar2(fig, ax[1], mesh2, "RFO (%)") + + # 3. 
Difference RFO (regrid if necessary)
+        if rfo_ref.shape != rfo_test.shape:
+            rfo_ref = rfo_ref.interp_like(rfo_test, method="nearest")
+
+        rfo_diff = rfo_test - rfo_ref
+        total_rfo_diff = total_rfo_test - total_rfo_ref
+        mesh3 = _plot_map(ax[2], rfo_diff.lon, rfo_diff.lat, rfo_diff,
+                          f"Test - {obs_or_base}, ΔRFO = {total_rfo_diff:.1f}%",
+                          "coolwarm", -100, 100)
+        _add_colorbar2(fig, ax[2], mesh3, "ΔRFO (%)")
+
+        # Configure all axes
+        _configure_map_axes(ax[0], is_left=True, is_bottom=False)
+        _configure_map_axes(ax[1], is_left=False, is_bottom=False)
+        _configure_map_axes(ax[2], is_left=True, is_bottom=True)
+        # The 2x2 grid leaves an unused bottom-right panel; remove it
+        ax[3].remove()
+
+        fig.suptitle(f"CR{cluster+1} Relative Frequency of Occurrence", fontsize=16, y=0.95)
+
+        # Save figure
+        save_path = Path(adf.plot_location[0]) / f"{field}_CR{cluster+1}_LatLon_mean.png"
+        plt.savefig(save_path, bbox_inches='tight')
+
+        if adf.create_html:
+            adf.add_website_data(str(save_path), field, case_name=None, multi_case=True)
+
+        plt.close(fig)
+
+
+def plot_hists_baseline(fld, cl, cluster_labels, cluster_labels_o, histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord, adf):
+    """Plot cloud regime centers for observations, baseline, and test case."""
+    plot_data = _prepare_plot_data(fld, cl, cluster_labels, cluster_labels_o, histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord)
+    plot_data['columns'] = ['observation', 'baseline', 'test_case']
+    plot_data['figsize'] = (17, plot_data['fig_height'])
+    plot_data['save_suffix'] = '_CR_centers'
+
+    _plot_cloud_regimes(plot_data, adf, baseline_mode=True)
+
+
+def plot_hists_obs(fld, cl, cluster_labels, cluster_labels_o, histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord, adf):
+    """Plot cloud regime centers for observations and test case."""
+    plot_data = _prepare_plot_data(fld, cl, cluster_labels, cluster_labels_o, histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord)
+    plot_data['columns'] = ['observation', 'test_case']
+    plot_data['figsize'] = (12, plot_data['fig_height'])
+    plot_data['save_suffix'] = '_CR_centers'
+
+    _plot_cloud_regimes(plot_data, adf, baseline_mode=False)
+
+
+def _prepare_plot_data(fld, cl, cluster_labels, cluster_labels_o, 
histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord): + """Prepare all data needed for plotting. - return ds.sel(time=slice(start, end)) + fld: str + name of variable -def select_valid_tau_height(ds, tau_var_name, ht_var_name, max_value=9999999999999): - """ - Select only valid tau and height/pressure range from dataset. + cl + cluster centers - Excludes failed retrievals (typically -1 values) by selecting from 0 to max_value. - Handles both pressure (decreasing) and altitude (increasing) coordinate ordering. + cluster_labels: xr.DataArray + labels for case ([time], lat, lon) + cluster_labels_o: array-like + labels for reference data + histograms: xr.DataAray (?) - Parameters: - ----------- - ds : xarray.Dataset - Input dataset containing tau and height/pressure variables - tau_var_name : str - Name of the tau variable - ht_var_name : str - Name of the height/pressure variable - max_value : int, optional - Maximum value for selection range (default: 9999999999999) - - Returns: - -------- - ds : xarray.Dataset - Dataset with valid tau and height range selected """ - # Select valid tau range (exclude negative/failed retrievals) - tau_selection = {tau_var_name: slice(0, max_value)} + data = ALL_VARS[fld].product_name + k = len(cl) + ylabels = htcoord.values + xlabels = taucoord.values - # Handle height/pressure coordinate ordering - # Pressure: decreasing (high to low) -> slice(max, 0) - # Altitude: increasing (low to high) -> slice(0, max) - if ds[ht_var_name][0] > ds[ht_var_name][-1]: - # Decreasing coordinate (pressure) - ht_selection = {ht_var_name: slice(max_value, 0)} - else: - # Increasing coordinate (altitude) - ht_selection = {ht_var_name: slice(0, max_value)} + # Create meshgrid + X2, Y2 = np.meshgrid(np.arange(len(xlabels) + 1), np.arange(len(ylabels) + 1)) - # Apply selections - return ds.sel(tau_selection).sel(ht_selection) - - -def finish_cluster_labels(ds_b, tau_var_name, ht_var_name): - """ - Compute cluster labels for cloud regime analysis. + # Calculate figure height + fig_height = (1 + 10 / 3 * ceil(k / 3)) * 3 - Parameters: - ----------- - ds_b : xarray.Dataset - Input dataset containing histogram data - tau_var_name : str - Name of tau variable - ht_var_name : str - Name of height variable - - Returns: - -------- - cluster_labels_b : xarray.DataArray - Cluster labels with same coordinates as input, NaN for invalid data - """ - # Selcting only the relevant data and - # stacking it to shape n_histograms, n_tau * n_pc - other_dims = [dim for dim in ds_b.dims if dim not in (tau_var_name, ht_var_name)] - histograms_b = ds_b.stack( - spacetime=other_dims, - tau_ht=(tau_var_name, ht_var_name) - ) - # convert to numpy array & compute weights - # TODO: weights abstraction - weights_b = np.cos(np.deg2rad(histograms_b.lat.values)) - mat_b = histograms_b.values - - # Find valid histograms (no NaNs) using boolean indexing - is_valid = ~np.isnan(mat_b).any(axis=1) - if not is_valid.any(): - print("[cloud_regime_analysis_error] No valid histograms found") - return None - # Check for negative values in valid data only - if (mat_b[is_valid] < 0).any(): - print(f"[cloud_regime_analysis_error] Found negative values in data. 
" - f"If these are fill values, convert to NaNs and try again") - return None - # Compute clusters only for valid data - valid_mat = mat_b[is_valid] - valid_weights = weights_b[is_valid] - cluster_labels_valid = precomputed_clusters( - valid_mat, cl, wasserstein_or_euclidean, ds_b - ) - - # Create output array with NaNs, then fill valid positions - cluster_labels_flat = np.full(len(mat_b), np.nan, dtype=np.float32) - cluster_labels_flat[is_valid] = cluster_labels_valid - - # Convert back to DataArray and unstack - cluster_labels_b = xr.DataArray( - data=cluster_labels_flat, - coords={"spacetime": histograms_b.spacetime}, - dims=("spacetime"), - name="cluster_labels" - ) - return cluster_labels_b.unstack() - - -################ -# REGRIDDING -################ - -def make_se_regridder(weight_file, Method="conservative"): - weights = xr.open_dataset(weight_file) - in_shape = weights.src_grid_dims.load().data - - # Since xESMF expects 2D vars, we'll insert a dummy dimension of size-1 - if len(in_shape) == 1: - in_shape = [1, in_shape.item()] - - # output variable shape - out_shape = weights.dst_grid_dims.load().data.tolist()[::-1] - - dummy_in = xr.Dataset( - { - "lat": ("lat", np.empty((in_shape[0],))), - "lon": ("lon", np.empty((in_shape[1],))), - } - ) - dummy_out = xr.Dataset( - { - "lat": ("lat", weights.yc_b.data.reshape(out_shape)[:, 0]), - "lon": ("lon", weights.xc_b.data.reshape(out_shape)[0, :]), - } - ) - regridder = xesmf.Regridder( - dummy_in, - dummy_out, - weights=weight_file, - # results seem insensitive to this method choice - # choices are coservative_normed, coservative, and bilinear - method=Method, - reuse_weights=True, - periodic=True, - ) - return regridder - - -def regrid_se_data_bilinear(regridder, data_to_regrid, column_dim_name="ncol"): - if isinstance(data_to_regrid, xr.Dataset): - vars_with_ncol = [ - name - for name in data_to_regrid.variables - if column_dim_name in data_to_regrid[name].dims - ] - updated = data_to_regrid.copy().update( - data_to_regrid[vars_with_ncol] - .transpose(..., "ncol") - .expand_dims("dummy", axis=-2) - ) - elif isinstance(data_to_regrid, xr.DataArray): - updated = data_to_regrid.transpose(..., column_dim_name).expand_dims( - "dummy", axis=-2 - ) - else: - raise ValueError( - f"Something is wrong because the data to regrid isn't xarray: {type(data_to_regrid)}" - ) - regridded = regridder(updated) - return regridded - -# -# LAND MASK CODE (probably need to simplify and move out of here) -# -def apply_land_ocean_mask(ds, only_ocean_or_land, landfrac_present=None): - """ - Apply land or ocean mask to dataset. - - Parameters: - ----------- - ds : xarray.Dataset - Input dataset with lat/lon coordinates - only_ocean_or_land : str or False - "L" for land only, "O" for ocean only, False for no masking - landfrac_present : bool, optional - Whether LANDFRAC variable is available. Auto-detected if None. 
-
-    Returns:
-    --------
-    ds : xarray.Dataset
-        Masked dataset, or None if invalid option
-    """
-    # No masking requested
-    if only_ocean_or_land is False:
-        return ds
-
-    # Validate input
-    if only_ocean_or_land not in ["L", "O"]:
-        print('[cloud_regime_analysis ERROR] Invalid option for only_ocean_or_land: '
-              'Please enter "O" for ocean only, "L" for land only, or set to False for both')
-        return None
-
-    # Auto-detect LANDFRAC if not specified
-    if landfrac_present is None:
-        landfrac_present = "LANDFRAC" in ds.data_vars or "LANDFRAC" in ds.coords
-
-    # Use LANDFRAC if available
-    if landfrac_present:
-        land_mask_value = 1 if only_ocean_or_land == "L" else 0
-        return ds.where(ds.LANDFRAC == land_mask_value)
-
-    # Otherwise use cartopy-based land mask
-    land_mask = create_land_mask(ds)
-
-    # Make land mask broadcastable with dataset
-    land_mask = _make_mask_broadcastable(land_mask, ds)
-
-    # Apply mask
-    mask_value = 1 if only_ocean_or_land == "L" else 0
-    return ds.where(land_mask == mask_value)
-
-
-def _make_mask_broadcastable(mask, ds):
-    """
-    Make 2D land mask broadcastable with dataset by adding dimensions.
-
-    Parameters:
-    -----------
-    mask : numpy.ndarray
-        2D mask array (lat, lon)
-    ds : xarray.Dataset
-        Target dataset
-
-    Returns:
-    --------
-    mask : numpy.ndarray
-        Broadcastable mask array
-    """
-    # Add dimensions for any dims that aren't lat/lon
-    for i, dim in enumerate(ds.dims):
-        if dim not in ("lat", "lon"):
-            mask = np.expand_dims(mask, axis=i)
-    return mask
-
-
-def create_land_mask(ds):
-    """
-    Create land mask using cartopy Natural Earth data.
-    Improved version with better performance and cleaner code.
-
-    Parameters:
-    -----------
-    ds : xarray.Dataset
-        Dataset with lat/lon coordinates
-
-    Returns:
-    --------
-    land_mask : numpy.ndarray
-        2D array (lat, lon) with 1 for land, 0 for ocean
-    """
-    from cartopy import feature as cfeature
-    from shapely.geometry import Point
-    from shapely.prepared import prep
-    import numpy as np
-    from numba import njit
-
-    # Get land polygons
-    land_110m = cfeature.NaturalEarthFeature("physical", "land", "110m")
-    land_polygons = [prep(geom) for geom in land_110m.geometries()]
-    # Create coordinate arrays
-    lats, lons = ds.lat.values, ds.lon.values
-    lon_grid, lat_grid = np.meshgrid(lons, lats)
-    # Flatten coordinates for easier processing
-    lon_flat, lat_flat = lon_grid.flatten(), lat_grid.flatten()
-    points = [Point(lon, lat) for lon, lat in zip(lon_flat, lat_flat)]
-    # Find land points
-    land_coords = []
-    for polygon in land_polygons:
-        land_coords.extend([
-            (point.x, point.y) for point in points if polygon.covers(point)
-        ])
-    # Convert to numpy array for numba processing
-    land_array = np.array(land_coords)
-    coord_array = np.column_stack([lon_flat, lat_flat])
-    # Use numba for fast coordinate matching
-    land_mask_flat = _find_land_points(coord_array, land_array)
-    # Reshape to original grid
-    return land_mask_flat.reshape(len(lats), len(lons))
-
-
-@njit()
-def _find_land_points(coord_array, land_coords):
-    """
-    Numba-compiled function to quickly identify land points.
-
-    Parameters:
-    -----------
-    coord_array : numpy.ndarray
-        Array of (lon, lat) coordinates
-    land_coords : numpy.ndarray
-        Array of known land coordinates
-
-    Returns:
-    --------
-    mask : numpy.ndarray
-        1D mask array with 1 for land, 0 for ocean
-    """
-    mask = np.zeros(len(coord_array), dtype=np.int32)
-
-    for i in range(len(coord_array)):
-        coord = coord_array[i]
-        for j in range(len(land_coords)):
-            if np.allclose(coord, land_coords[j], atol=1e-10):
-                mask[i] = 1
-                break
-    return mask
-
-import matplotlib.pyplot as plt
-import matplotlib as mpl
-import numpy as np
-from math import ceil
-import xarray as xr
-
-# =============================================================================
-# CORE PLOTTING FUNCTIONS
-# =============================================================================
-
-def plot_hists_baseline(fld, cl, cluster_labels, cluster_labels_o, histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord, adf):
-    """Plot cloud regime centers for observations, baseline, and test case."""
-    plot_data = _prepare_plot_data(fld, cl, cluster_labels, cluster_labels_o, histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord)
-    plot_data['columns'] = ['observation', 'baseline', 'test_case']
-    plot_data['figsize'] = (17, plot_data['fig_height'])
-    plot_data['save_suffix'] = '_CR_centers'
-
-    _plot_cloud_regimes(plot_data, adf, baseline_mode=True)
-
-
-def plot_hists_obs(fld, cl, cluster_labels, cluster_labels_o, histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord, adf):
-    """Plot cloud regime centers for observations and test case."""
-    plot_data = _prepare_plot_data(fld, cl, cluster_labels, cluster_labels_o, histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord)
-    plot_data['columns'] = ['observation', 'test_case']
-    plot_data['figsize'] = (12, plot_data['fig_height'])
-    plot_data['save_suffix'] = '_CR_centers'
-
-    _plot_cloud_regimes(plot_data, adf, baseline_mode=False)
-
-
-# =============================================================================
-# HELPER FUNCTIONS
-# =============================================================================
-
-def _prepare_plot_data(fld, cl, cluster_labels, cluster_labels_o, histograms, histograms_ref, ht_var_name, tau_var_name, htcoord, taucoord):
-    """Prepare all data needed for plotting.
-
-    fld: str
-        name of variable
-
-    cl
-        cluster centers
-
-    cluster_labels: xr.DataArray
-        labels for case ([time], lat, lon)
-    cluster_labels_o: array-like
-        labels for reference data
-    histograms: xr.DataAray (?)
-
-    """
-
-    print(f"[_prepare_plot_data] {histograms.coords = }")
-
-    data = ALL_VARS[fld].product_name
-    k = len(cl)
-    ylabels = htcoord.values
-    xlabels = taucoord.values
-
-    # Create meshgrid
-    X2, Y2 = np.meshgrid(np.arange(len(xlabels) + 1), np.arange(len(ylabels) + 1))
-
-    # Calculate figure height
-    fig_height = (1 + 10 / 3 * ceil(k / 3)) * 3
-
-    # Create weights for RFO calculations
-    weights = np.cos(np.deg2rad(cluster_labels.stack(z=("time", "lat", "lon")).lat.values))
-    valid_inds = ~np.isnan(cluster_labels.stack(z=("time", "lat", "lon")))
-    weights = weights[valid_inds]
+    # Create weights for RFO calculations
+    weights = np.cos(np.deg2rad(cluster_labels.stack(z=("time", "lat", "lon")).lat.values))
+    valid_inds = ~np.isnan(cluster_labels.stack(z=("time", "lat", "lon")))
+    weights = weights[valid_inds]
 
     return {
         'field': fld,
@@ -1307,7 +845,7 @@ def _plot_observation_column(ax_col, plot_data, cmap, norm, include_rfo=False):
 
     # Set title with optional RFO
     if include_rfo:
-        rfo = _calculate_rfo(plot_data['cluster_labels_o'], i)
+        rfo_map, rfo = _calculate_rfo(plot_data['cluster_labels_o'], i)
         ax_col[i].set_title(f"Observation CR {i+1}, RFO = {np.round(float(rfo), 1)}%")
     else:
         ax_col[i].set_title(f"Observation CR {i+1}")
@@ -1317,7 +855,7 @@ def _plot_baseline_column(ax_col, plot_data, cmap, norm):
     """Plot the baseline cluster centers with weighted means and RFO."""
     for i in range(plot_data['k']):
         # Calculate RFO
-        rfo = _calculate_rfo(cluster_labels_b, i)  # Global variable
+        rfo_map, rfo = _calculate_rfo(cluster_labels_b, i)  # cluster_labels_b is a module-level global
 
         # Calculate weighted mean
         wmean = _calculate_weighted_mean_xr(i, plot_data['cluster_labels_o'], plot_data['histograms_ref'])
@@ -1329,74 +867,15 @@ def _plot_test_case_column(ax_col, plot_data, cmap, norm):
     """Plot the test case cluster centers with weighted means and RFO."""
     for i in range(plot_data['k']):
-        # Calculate RFO
-        rfo = _calculate_rfo(plot_data['cluster_labels'], i)
-
-        # Calculate weighted mean
-        wmean = _calculate_weighted_mean_xr(i, plot_data['cluster_labels'], plot_data['histograms'])
-
-        # Plot
-        im = ax_col[i].pcolormesh(plot_data['X2'], plot_data['Y2'], wmean, norm=norm, cmap=cmap)
-        ax_col[i].set_title(f"Test Case CR {i+1}, RFO = {np.round(rfo, 1)}%")
-
-
-def _calculate_rfo(cluster_labels, cluster_i):
-    """Calculate area-weighted relative frequency of occurrence for a cluster."""
-    total_rfo_num = cluster_labels == cluster_i
-    total_rfo_num = np.sum(total_rfo_num * np.cos(np.deg2rad(cluster_labels.lat)))
-
-    total_rfo_denom = cluster_labels >= 0
-    total_rfo_denom = np.sum(total_rfo_denom * np.cos(np.deg2rad(cluster_labels.lat)))
-
-    total_rfo = total_rfo_num / total_rfo_denom * 100
-    return total_rfo.values
-
-def _calculate_weighted_mean_xr(cluster_i, cluster_labels, hists):
-    weights = np.cos(np.radians(hists['lat']))
-    cluster_data = xr.where(cluster_labels==cluster_i, hists, np.nan)
-    dims = [dim for dim in hists.dims if dim in ["ncol","lat","lon","time"]]
-    return cluster_data.weighted(weights).mean(dim=dims)
-
-def _calculate_weighted_mean(cluster_i, cluster_labels, hists, weights, xlabels, ylabels):
-    """Calculate area-weighted mean histogram for a cluster.
-
-    PARAMETERS
-    ----------
-    cluster_i : int
-        cluster number (label)
-    cluster_labels : array-like
-        the ([time,] lat, lon) array of cluster labels for the data
-    hists: array-like
-        the ([time,] ht, tau, lat, lon) array of histograms
-    weights: array-like
-        area weights
-    xlabels: array
-        tau values
-    ylabels: array
-        height/pressure values
-
-    RETURNS
-    -------
-    wmean: array
-        Area-weigthed mean of the histograms in cluster (in percent)
-    """
-    print(f"[_calculate_weighted_mean] cluster {cluster_i}, {cluster_labels.shape = }, {hists.shape = }, {weights.shape = }")
-    pts_i = np.where(cluster_labels == cluster_i)  # identify points in cluster
-    n = pts_i.sum()  # number of points in cluster
-    w = [pts_i]  # cos(lat)
-    if n > 0:
-        weighted_hists = hists[indices_i] * weights[indices_i][:, np.newaxis]
-        wmean = np.sum(weighted_hists, axis=0) / np.sum(weights[indices_i])
-    else:
-        wmean = np.zeros(len(xlabels) * len(ylabels))
-
-    # Reshape and convert to percentage if needed
-    wmean = wmean.reshape(len(xlabels), len(ylabels)).T
-    if np.max(wmean) <= 1:
-        wmean *= 100
+        # Calculate RFO
+        rfo_map, rfo = _calculate_rfo(plot_data['cluster_labels'], i)
 
-    return wmean
-
+        # Calculate weighted mean
+        wmean = _calculate_weighted_mean_xr(i, plot_data['cluster_labels'], plot_data['histograms'])
+
+        # Plot
+        im = ax_col[i].pcolormesh(plot_data['X2'], plot_data['Y2'], wmean, norm=norm, cmap=cmap)
+        ax_col[i].set_title(f"Test Case CR {i+1}, RFO = {np.round(rfo, 1)}%")
 
 def _create_colormap():
     """Create standardized colormap and normalization."""
@@ -1417,11 +896,6 @@ def _create_colormap():
     norm = mpl.colors.BoundaryNorm(p, cmap.N, clip=True)
     return cmap, norm
 
-
-# =============================================================================
-# AXIS CONFIGURATION FUNCTIONS
-# =============================================================================
-
 def _configure_axes(ax, data_product):
     """Configure axis ticks and labels based on data product."""
     config_functions = {
@@ -1507,12 +981,11 @@ def _add_figure_labels(fig, ax, plot_data, baseline_mode):
     fig.supylabel(f"Cloud-top {ht_label} ({ht_unit})", x=x_pos)
 
     # Title positioning
-    bbox = ax[1, 0].get_position()  # Use first column for positioning
     fig.suptitle(
        f"{data} Cloud Regimes",
-        # x=0.5,
-        # y=bbox.p1[1] + (1 / fig_height * 0.5) + 0.007,
+        y=0.92,  # vertical placement of the suptitle (tuned for the default layout)
        fontsize=18,
+        fontweight='bold'
     )
 
     # X-label positioning
@@ -1520,23 +993,6 @@
     fig.supxlabel("Optical Depth", y=bbox.p0[1] - (1 / fig_height * 0.5) - 0.007)
 
 
-def _add_colorbar(fig, ax, cmap, norm):
-    """Add colorbar to the figure."""
-    p = [0, 0.2, 1, 2, 3, 4, 6, 8, 10, 15, 99]
-    # cbar_ax = fig.add_axes([1.01, 0.25, 0.04, 0.5])
-    sm = ScalarMappable(norm=norm, cmap=cmap)
-    # sm.set_array([])  # Required for colorbar
-    cb = fig.colorbar(sm,
-                      ax=ax,
-                      orientation='vertical',
-                      fraction=0.025,
-                      pad=0.02,
-                      aspect=40,
-                      ticks=p)
-    cb.set_label(label="Cloud Cover (%)", size=16)
-    cb.ax.tick_params(labelsize=14)
-
-
 def _save_figure(fig, plot_data, adf):
     """Save figure and add to website if requested."""
     data = plot_data['data_product']
@@ -1549,4 +1005,265 @@
         adf.add_website_data(save_path + ".png", plot_data['field'], case_name=None, multi_case=True)
     else:
         # For baseline comparison mode
-        adf.add_website_data(save_path + ".png", plot_data['field'], adf.get_baseline_info("cam_case_name"))
\ No newline at end of file
+        adf.add_website_data(save_path + ".png", plot_data['field'], adf.get_baseline_info("cam_case_name"))
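+
+# Illustrative sketch (not called by the plotting code): the call sites above
+# now unpack two values from _calculate_rfo as (rfo_map, rfo). The scalar is
+# assumed to follow the area-weighted arithmetic of the removed single-value
+# implementation, reproduced here for reference:
+def _example_total_rfo(cluster_labels, cluster_i):
+    """Illustrative only: percent of area-weighted valid points in cluster_i."""
+    w = np.cos(np.deg2rad(cluster_labels.lat))       # cos(lat) area weights
+    num = ((cluster_labels == cluster_i) * w).sum()  # weighted points in cluster
+    denom = ((cluster_labels >= 0) * w).sum()        # weighted valid points
+    return float(num / denom * 100.0)
+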
+ ".png", plot_data['field'], adf.get_baseline_info("cam_case_name")) + +# --------------------- +# Data handling helpers +# --------------------- +def spatial_subset(ds, lat_range, lon_range): + """Subsets a DataArray or Dataset by latitude and longitude ranges.""" + if lat_range: + if ds.lat[0] > ds.lat[-1]: + ds = ds.sel(lat=slice(lat_range[1], lat_range[0])) + else: + ds = ds.sel(lat=slice(lat_range[0], lat_range[1])) + if lon_range: + if ds.lon[0] > ds.lon[-1]: + ds = ds.sel(lon=slice(lon_range[1], lon_range[0])) + else: + ds = ds.sel(lon=slice(lon_range[0], lon_range[1])) + return ds + +def temporal_subset(ds, time_range): + """Subsets a DataArray or Dataset by a time range.""" + def is_valid_time(value): + return value is not None and value != "None" and value != "" + + if not time_range or len(time_range) < 2: return ds + + start, end = time_range[0], time_range[1] + start_valid, end_valid = is_valid_time(start), is_valid_time(end) + + if not start_valid and not end_valid: return ds + + start = start if start_valid else ds.time.min().item() + end = end if end_valid else ds.time.max().item() + + return ds.sel(time=slice(start, end)) + +def select_valid_tau_height(ds, tau_var_name, ht_var_name): + """Selects valid (non-negative) tau and height/pressure ranges.""" + ds = ds.sel({tau_var_name: slice(0, None)}) + if ds[ht_var_name][0] > ds[ht_var_name][-1]: # Pressure (decreasing) + ds = ds.sel({ht_var_name: slice(None, 0)}) + else: # Altitude (increasing) + ds = ds.sel({ht_var_name: slice(0, None)}) + return ds + +def _calculate_weighted_mean_xr(cluster_i, cluster_labels, hists): + weights = np.cos(np.radians(hists['lat'])) + cluster_data = xr.where(cluster_labels==cluster_i, hists, np.nan) + dims = [dim for dim in hists.dims if dim in ["ncol","lat","lon","time"]] + return cluster_data.weighted(weights).mean(dim=dims) + + +# -------------- +# LAND MASK CODE (probably need to simplify and move out of here) +# -------------- +def apply_land_ocean_mask(ds, only_ocean_or_land, landfrac_present=None): + """ + Apply land or ocean mask to dataset. + + Parameters: + ----------- + ds : xarray.Dataset + Input dataset with lat/lon coordinates + only_ocean_or_land : str or False + "L" for land only, "O" for ocean only, False for no masking + landfrac_present : bool, optional + Whether LANDFRAC variable is available. Auto-detected if None. + + Returns: + -------- + ds : xarray.Dataset + Masked dataset, or None if invalid option + """ + # No masking requested + if only_ocean_or_land is False: + return ds + + # Validate input + if only_ocean_or_land not in ["L", "O"]: + warnings.warn(f'[ERROR] Invalid option for only_ocean_or_land: {only_ocean_or_land}' + 'Please enter "O" for ocean only, "L" for land only, or set to False for both') + return None + + # Auto-detect LANDFRAC if not specified + if landfrac_present is None: + landfrac_present = "LANDFRAC" in ds.data_vars or "LANDFRAC" in ds.coords + + # Use LANDFRAC if available + if landfrac_present: + land_mask_value = 1 if only_ocean_or_land == "L" else 0 + return ds.where(ds.LANDFRAC == land_mask_value) + + # Otherwise use cartopy-based land mask + land_mask = create_land_mask(ds) + + # Make land mask broadcastable with dataset + land_mask = _make_mask_broadcastable(land_mask, ds) + + # Apply mask + mask_value = 1 if only_ocean_or_land == "L" else 0 + return ds.where(land_mask == mask_value) + + +def _make_mask_broadcastable(mask, ds): + """ + Make 2D land mask broadcastable with dataset by adding dimensions. 
+
+def _make_mask_broadcastable(mask, ds):
+    """
+    Make 2D land mask broadcastable with dataset by adding dimensions.
+
+    Parameters:
+    -----------
+    mask : numpy.ndarray
+        2D mask array (lat, lon)
+    ds : xarray.Dataset
+        Target dataset
+
+    Returns:
+    --------
+    mask : numpy.ndarray
+        Broadcastable mask array
+    """
+    # Add dimensions for any dims that aren't lat/lon
+    for i, dim in enumerate(ds.dims):
+        if dim not in ("lat", "lon"):
+            mask = np.expand_dims(mask, axis=i)
+    return mask
+
+
+def create_land_mask(ds):
+    """
+    Create land mask using cartopy Natural Earth data.
+
+    Parameters:
+    -----------
+    ds : xarray.Dataset
+        Dataset with lat/lon coordinates
+
+    Returns:
+    --------
+    land_mask : numpy.ndarray
+        2D array (lat, lon) with 1 for land, 0 for ocean
+    """
+    from cartopy import feature as cfeature
+    from shapely.geometry import Point
+    from shapely.prepared import prep
+    import numpy as np
+
+    # TODO: Replace this with a regionmask approach (no numba needed)
+    # Get land polygons
+    land_110m = cfeature.NaturalEarthFeature("physical", "land", "110m")
+    land_polygons = [prep(geom) for geom in land_110m.geometries()]
+    # Create coordinate arrays
+    lats, lons = ds.lat.values, ds.lon.values
+    lon_grid, lat_grid = np.meshgrid(lons, lats)
+    # Flatten coordinates for easier processing
+    lon_flat, lat_flat = lon_grid.flatten(), lat_grid.flatten()
+    points = [Point(lon, lat) for lon, lat in zip(lon_flat, lat_flat)]
+    # Find land points
+    land_coords = []
+    for polygon in land_polygons:
+        land_coords.extend([
+            (point.x, point.y) for point in points if polygon.covers(point)
+        ])
+    # Convert to numpy array for numba processing
+    land_array = np.array(land_coords)
+    coord_array = np.column_stack([lon_flat, lat_flat])
+    # Use numba for fast coordinate matching
+    land_mask_flat = _find_land_points(coord_array, land_array)
+    # Reshape to original grid
+    return land_mask_flat.reshape(len(lats), len(lons))
+
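+
+# Illustrative sketch for the TODO in create_land_mask: the same 0/1 land
+# mask via the optional regionmask package (an assumed dependency, not
+# currently required by the ADF; the import is local so the module still
+# loads without it).
+def _create_land_mask_regionmask(ds):
+    """Illustrative regionmask-based alternative to create_land_mask."""
+    import regionmask
+    # land_110 holds a single Natural Earth land region, numbered 0
+    land = regionmask.defined_regions.natural_earth_v5_0_0.land_110
+    mask2d = land.mask(ds.lon, ds.lat)  # 0.0 over land, NaN over ocean
+    return (mask2d == 0).astype(np.int32).values
+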
+
+@njit()
+def _find_land_points(coord_array, land_coords):
+    """
+    Numba-compiled function to quickly identify land points.
+
+    Parameters:
+    -----------
+    coord_array : numpy.ndarray
+        Array of (lon, lat) coordinates
+    land_coords : numpy.ndarray
+        Array of known land coordinates
+
+    Returns:
+    --------
+    mask : numpy.ndarray
+        1D mask array with 1 for land, 0 for ocean
+    """
+    mask = np.zeros(len(coord_array), dtype=np.int32)
+
+    for i in range(len(coord_array)):
+        coord = coord_array[i]
+        for j in range(len(land_coords)):
+            if np.allclose(coord, land_coords[j], atol=1e-10):
+                mask[i] = 1
+                break
+    return mask
+
+#---------------------
+# Regridding functions
+#---------------------
+# xESMF is required by these helpers. The guarded import mirrors the
+# wasserstein import at the top of the file (assumption: xesmf is not
+# already imported elsewhere in this module).
+try:
+    import xesmf
+except ImportError:
+    print(' xESMF is not installed, so the SE-grid regridding helpers below cannot be used.')
+
+def make_se_regridder(weight_file, Method="conservative"):
+    weights = xr.open_dataset(weight_file)
+    in_shape = weights.src_grid_dims.load().data
+
+    # Since xESMF expects 2D vars, we'll insert a dummy dimension of size-1
+    if len(in_shape) == 1:
+        in_shape = [1, in_shape.item()]
+
+    # output variable shape
+    out_shape = weights.dst_grid_dims.load().data.tolist()[::-1]
+
+    dummy_in = xr.Dataset(
+        {
+            "lat": ("lat", np.empty((in_shape[0],))),
+            "lon": ("lon", np.empty((in_shape[1],))),
+        }
+    )
+    dummy_out = xr.Dataset(
+        {
+            "lat": ("lat", weights.yc_b.data.reshape(out_shape)[:, 0]),
+            "lon": ("lon", weights.xc_b.data.reshape(out_shape)[0, :]),
+        }
+    )
+    regridder = xesmf.Regridder(
+        dummy_in,
+        dummy_out,
+        weights=weight_file,
+        # results seem insensitive to this method choice
+        # choices are conservative_normed, conservative, and bilinear
+        method=Method,
+        reuse_weights=True,
+        periodic=True,
+    )
+    return regridder
+
+
+def regrid_se_data_bilinear(regridder, data_to_regrid, column_dim_name="ncol"):
+    if isinstance(data_to_regrid, xr.Dataset):
+        vars_with_ncol = [
+            name
+            for name in data_to_regrid.variables
+            if column_dim_name in data_to_regrid[name].dims
+        ]
+        updated = data_to_regrid.copy().update(
+            data_to_regrid[vars_with_ncol]
+            .transpose(..., column_dim_name)
+            .expand_dims("dummy", axis=-2)
+        )
+    elif isinstance(data_to_regrid, xr.DataArray):
+        updated = data_to_regrid.transpose(..., column_dim_name).expand_dims(
+            "dummy", axis=-2
+        )
+    else:
+        warnings.warn(
+            f"[ERROR] Something is wrong because the data to regrid isn't xarray: {type(data_to_regrid)}"
+        )
+        return None
+    regridded = regridder(updated)
+    return regridded
+
+
From dcd377df022bf5cb2bf2b8990f83e3f3ce6a5a50 Mon Sep 17 00:00:00 2001
From: Brian Medeiros
Date: Mon, 15 Sep 2025 12:33:52 -0600
Subject: [PATCH 5/5] remove dups from variable defaults

---
 lib/adf_variable_defaults.yaml | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/lib/adf_variable_defaults.yaml b/lib/adf_variable_defaults.yaml
index f1b5b1482..437d28fbf 100644
--- a/lib/adf_variable_defaults.yaml
+++ b/lib/adf_variable_defaults.yaml
@@ -2231,11 +2231,6 @@ CLWMODIS:
   pct_diff_contour_levels: [-100,-75,-50,-40,-30,-20,-10,-8,-6,-4,-2,0,2,4,6,8,10,20,30,40,50,75,100]
   pct_diff_colormap: "PuOr_r"
 
-CLD_MISR:
-  category: "Clouds"
-  obs_file: 'MISR_obs_data.nc'
-  obs_name: "MISR"
-
 CLMODIS:
   category: "Clouds"
   obs_file: 'MODIS_obs_data.nc'
   obs_name: "MODIS"
 
 CLD_MISR:
   category: "Clouds"
   obs_file: 'MISR_obs_data.nc'
   obs_name: "MISR"
 
-CLMODIS:
-  category: "Clouds"
-  obs_file: 'MODIS_obs_data.nc'
-  obs_name: "MODIS"
-
 FISCCP1_COSP:
   category: "Clouds"
   obs_file: 'ISCCP_obs_data.nc'
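
Why the dedup in PATCH 5/5 matters: YAML loaders such as PyYAML (which the ADF
is assumed to use for adf_variable_defaults.yaml) keep only the last occurrence
of a duplicated top-level key, so the earlier CLD_MISR/CLMODIS entries were
silently discarded. A minimal illustration (hypothetical snippet, not part of
the patch):

    import yaml

    doc = 'CLD_MISR:\n  category: "COSP"\nCLD_MISR:\n  category: "Clouds"\n'
    print(yaml.safe_load(doc))  # {'CLD_MISR': {'category': 'Clouds'}} -- the first entry is lost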