From a1dd01db21d1589f617361836d39f9789d504f17 Mon Sep 17 00:00:00 2001 From: Brian Medeiros Date: Thu, 8 Jan 2026 13:35:32 -0700 Subject: [PATCH 1/4] refactor taylor diagram; first draft --- lib/adf_variable_defaults.yaml | 3 + scripts/plotting/cam_taylor_diagram.py | 383 ++++++++++++------------- 2 files changed, 191 insertions(+), 195 deletions(-) diff --git a/lib/adf_variable_defaults.yaml b/lib/adf_variable_defaults.yaml index ed5be6da0..2fd130e11 100644 --- a/lib/adf_variable_defaults.yaml +++ b/lib/adf_variable_defaults.yaml @@ -1825,6 +1825,9 @@ LANDFRAC: category: "Surface variables" pct_diff_contour_levels: [-100,-75,-50,-40,-30,-20,-10,-8,-6,-4,-2,0,2,4,6,8,10,20,30,40,50,75,100] pct_diff_colormap: "PuOr_r" + obs_file: "ERA5_LSM_1deg_conservativeregrid.nc" + obs_name: "ERA5" + obs_var_name: "LANDFRAC" #+++++++++++++++++ # Category: State diff --git a/scripts/plotting/cam_taylor_diagram.py b/scripts/plotting/cam_taylor_diagram.py index e05a33f72..6a029714e 100644 --- a/scripts/plotting/cam_taylor_diagram.py +++ b/scripts/plotting/cam_taylor_diagram.py @@ -7,6 +7,8 @@ It depends on an ADF instance to obtain the `climo` files. It is designed to have one "reference" case (could be observations) and arbitrary test cases. When multiple test cases are provided, they are plotted with different colors. + +NOTE: THIS IS A DRAFT REFACTORING TO ALLOW OBSERVATIONS (.b) ''' # # --- imports and configuration --- @@ -30,17 +32,10 @@ # --- Main Function Shares Name with Module: cam_taylor_diagram --- # def cam_taylor_diagram(adfobj): - - #Notify user that script has started: + """Create Taylor diagrams for specified configuration.""" msg = "\n Generating Taylor Diagrams..." print(f"{msg}\n {'-' * (len(msg)-3)}") - # Taylor diagrams currently don't work for model to obs comparison - # If compare_obs is set to True, then skip this script: - if adfobj.compare_obs: - print("\tTaylor diagrams don't work when doing model vs obs, so Taylor diagrams will be skipped.") - return - # Extract needed quantities from ADF object: # ----------------------------------------- # Case names: @@ -54,8 +49,6 @@ def cam_taylor_diagram(adfobj): syear_cases = adfobj.climo_yrs["syears"] eyear_cases = adfobj.climo_yrs["eyears"] - case_climo_loc = adfobj.get_cam_info('cam_climo_loc', required=True) - # ADF variable which contains the output path for plots and tables: plot_location = adfobj.plot_location if not plot_location: @@ -75,22 +68,11 @@ def cam_taylor_diagram(adfobj): else: plot_loc = Path(plot_location) - # CAUTION: - # "data" here refers to either obs or a baseline simulation, - # Until those are both treated the same (via intake-esm or similar) - # we will do a simple check and switch options as needed: - if adfobj.get_basic_info("compare_obs"): - data_name = "obs" # does not get used, is just here as a placemarker - data_list = adfobj.read_config_var("obs_type_list") # Double caution! - data_loc = adfobj.get_basic_info("obs_climo_loc", required=True) - else: - data_name = adfobj.get_baseline_info('cam_case_name', required=True) - data_list = data_name # should not be needed (?) - data_loc = adfobj.get_baseline_info("cam_climo_loc", required=True) - #Grab baseline case nickname - base_nickname = adfobj.case_nicknames["base_nickname"] - #End if + # reference data set(s) -- if comparing with obs, these are dicts. 
+ data_name = adfobj.data.ref_case_label + data_loc = adfobj.data.ref_data_loc + base_nickname = adfobj.data.ref_nickname #Extract baseline years (which may be empty strings if using Obs): syear_baseline = adfobj.climo_yrs["syear_baseline"] @@ -110,16 +92,24 @@ def cam_taylor_diagram(adfobj): redo_plot = adfobj.get_basic_info('redo_plot') print(f"\t NOTE: redo_plot is set to {redo_plot}") - #Check if the variables needed for the Taylor diags are present, - #If not then skip this script: + # Check for required variables taylor_var_set = {'U', 'PSL', 'SWCF', 'LWCF', 'LANDFRAC', 'TREFHT', 'TAUX', 'RELHUM', 'T'} - if not taylor_var_set.issubset(adfobj.diag_var_list) or \ - (not ('PRECT' in adfobj.diag_var_list) and (not ('PRECL' in adfobj.diag_var_list) or not ('PRECC' in adfobj.diag_var_list))): - print("\tThe Taylor Diagrams require the variables: ") - print("\tU, PSL, SWCF, LWCF, PRECT (or PRECL and PRECC), LANDFRAC, TREFHT, TAUX, RELHUM,T") - print("\tSome variables are missing so Taylor diagrams will be skipped.") + available_vars = set(adfobj.diag_var_list) + missing_vars = taylor_var_set - available_vars + # Check for precipitation (Needs PRECT OR both PRECL and PRECC) + has_prect = 'PRECT' in available_vars + has_precl_precc = {'PRECL', 'PRECC'}.issubset(available_vars) + if missing_vars or not (has_prect or has_precl_precc): + print("\tTaylor Diagrams skipped due to missing variables:") + if missing_vars: + print(f"\t - Missing: {', '.join(sorted(missing_vars))}") + if not (has_prect or has_precl_precc): + if not has_prect: + print("\t - Missing: PRECT (Alternative PRECL + PRECC also incomplete)") + print("\n\tFull requirement: U, PSL, SWCF, LWCF, LANDFRAC, TREFHT, TAUX, RELHUM, T,") + print("\tAND (PRECT OR both PRECL & PRECC)") return - #End if + #Set seasonal ranges: seasons = {"ANN": np.arange(1,13,1), @@ -138,8 +128,7 @@ def cam_taylor_diagram(adfobj): # # LOOP OVER SEASON # - for s in seasons: - + for s in seasons.items(): plot_name = plot_loc / f"TaylorDiag_{s}_Special_Mean.{plot_type}" print(f"\t - Plotting Taylor Diagram, {s}") @@ -162,12 +151,24 @@ def cam_taylor_diagram(adfobj): # LOOP OVER VARIABLES # for v in var_list: - base_x = _retrieve(adfobj, v, data_name, data_loc) # get the baseline field - for casenumber, case in enumerate(case_names): # LOOP THROUGH CASES - case_x = _retrieve(adfobj, v, case, case_climo_loc[casenumber]) + # Load reference data (already regridded to target grid) + ref_data = _retrieve(adfobj, v, data_name) + if ref_data is None: + print(f"\t WARNING: No regridded reference data for {v} in {data_name}, skipping.") + continue + # ASSUMING `time` is 1-12, get the current season: + ref_data = ref_data.sel(time=s).mean(dim='time') + + for casenumber, case in enumerate(case_names): + # Load test case data regridded to match reference grid + case_data = _retrieve(adfobj, v, case) + if case_data is None: + print(f"\t WARNING: No regridded data for {v} in {case}, skipping.") + continue # ASSUMING `time` is 1-12, get the current season: - case_x = case_x.sel(time=seasons[s]).mean(dim='time') - result_by_case[case].loc[v] = taylor_stats_single(case_x, base_x) + case_data = case_data.sel(time=s).mean(dim='time') + # Now compute stats (grids are aligned) + result_by_case[case].loc[v] = taylor_stats_single(case_data, ref_data) # # -- PLOTTING (one per season) -- # @@ -208,65 +209,40 @@ def vertical_average(fld, ps, acoef, bcoef): dp_integrated = 0.5 * (pres.sel(lev=maxlev)**2 - pres.sel(lev=minlev)**2) levaxis = fld.dims.index('lev') # fld needs to be a 
dataarray assert isinstance(levaxis, int), f'the axis called lev is not an integer: {levaxis}' - fld_integrated = np.trapz(fld * pres, x=pres, axis=levaxis) + fld_integrated = np.trapezoid(fld * pres, x=pres, axis=levaxis) return fld_integrated / dp_integrated -def find_landmask(adf, casename, location): - # maybe it's in the climo files, but we might need to look in the history files: - landfrac_fils = list(Path(location).glob(f"{casename}*_LANDFRAC_*.nc")) - if landfrac_fils: - return xr.open_dataset(landfrac_fils[0])['LANDFRAC'] - else: - if casename in adf.get_cam_info('cam_case_name'): - # cases are in lists, so need to match them - caselist = adf.get_cam_info('cam_case_name') - hloc = adf.get_cam_info('cam_hist_loc') - hloc = hloc[caselist.index(casename)] - else: - hloc = adf.get_baseline_info('cam_hist_loc') - hfils = Path(hloc).glob((f"*{casename}*.nc")) - if not hfils: - raise IOError(f"No history files in expected location: {hloc}") - k = 0 - for h in hfils: - dstmp = xr.open_dataset(h) - if 'LANDFRAC' in dstmp: - print(f"\tGood news, found LANDFRAC in history file") - return dstmp['LANDFRAC'] - else: - k += 1 - else: - raise IOError(f"Checked {k} files, but did not find LANDFRAC in any of them.") - # should not reach past the `if` statement without returning landfrac or raising exception. - -def get_prect(casename, location, **kwargs): - # look for prect first: - fils = sorted(Path(location).glob(f"{casename}*_PRECT_*.nc")) - if len(fils) == 0: - print("\t Need to derive PRECT = PRECC + PRECL") - fils1 = sorted(Path(location).glob(f"{casename}*_PRECC_*.nc")) - fils2 = sorted(Path(location).glob(f"{casename}*_PRECL_*.nc")) - if (len(fils1) == 0) or (len(fils2) == 0): - raise IOError("Could not find PRECC or PRECL") - else: - if len(fils1) == 1: - precc = xr.open_dataset(fils1[0])['PRECC'] - precl = xr.open_dataset(fils2[0])['PRECL'] - prect = precc + precl - else: - raise NotImplementedError("Need to deal with mult-file case.") - elif len(fils) > 1: - prect = xr.open_mfdataset(fils)['PRECT'].load() # do we ever expect climo files split into pieces? 
- else: - prect = xr.open_dataset(fils[0])['PRECT'] - return prect +def find_landmask(adf, casename): + try: + return _retrieve(adf, 'LANDFRAC', casename) + except Exception as e: + print(f"\t WARNING: Could not find LANDFRAC for {casename}: {e}") + return None - -def get_tropical_land_precip(adf, casename, location, **kwargs): - landfrac = find_landmask(adf, casename, location) +def get_prect(adf, casename, **kwargs): + if casename == 'Obs': + return adf.data.load_reference_regrid_da('PRECT') + else: + # Try regridded PRECT first + prect = adf.data.load_regrid_da(casename, 'PRECT') + if prect is not None: + return prect + # Fallback: derive from PRECC + PRECL using regridded versions + print("\t Need to derive PRECT = PRECC + PRECL (using regridded data)") + precc = adf.data.load_regrid_da(casename, 'PRECC') + precl = adf.data.load_regrid_da(casename, 'PRECL') + if precc is None or precl is None: + print(f"\t WARNING: Could not derive PRECT for {casename} (missing PRECC or PRECL)") + return None + return precc + precl + +def get_tropical_land_precip(adf, casename, **kwargs): + landfrac = find_landmask(adf, casename) if landfrac is None: - raise ValueError("\t No landfrac returned") - prect = get_prect(casename, location) + return None + prect = get_prect(adf, casename) + if prect is None: + return None # mask to only keep land locations prect = xr.DataArray(np.where(landfrac >= .95, prect, np.nan), dims=prect.dims, @@ -275,11 +251,13 @@ def get_tropical_land_precip(adf, casename, location, **kwargs): return prect.sel(lat=slice(-30,30)) -def get_tropical_ocean_precip(adf, casename, location, **kwargs): - landfrac = find_landmask(adf, casename, location) +def get_tropical_ocean_precip(adf, casename, **kwargs): + landfrac = find_landmask(adf, casename) if landfrac is None: - raise ValueError("No landfrac returned") - prect = get_prect(casename, location) + return None + prect = get_prect(adf, casename) + if prect is None: + return None # mask to only keep ocean locations prect = xr.DataArray(np.where(landfrac <= 0.05, prect, np.nan), dims=prect.dims, @@ -287,102 +265,103 @@ def get_tropical_ocean_precip(adf, casename, location, **kwargs): attrs=prect.attrs) return prect.sel(lat=slice(-30,30)) -def get_surface_pressure(dset, casename, location): - #Find surface pressure (PS): +def get_surface_pressure(adf, dset, casename): if 'PS' in dset.variables: #Just use surface pressure in climo file: ps = dset['PS'] else: - #Check if surface pressure exists as a separate climo file: - fils = sorted(Path(location).glob(f"{casename}*_PS_*.nc")) - if (len(fils) == 0): - emsg = f"Could not find PS. This is needed as a separate variable if" - emsg += " reading time series files directly." - raise IOError(emsg) + if casename == 'Obs': + ps = adf.data.load_reference_regrid_da('PS') else: - if len(fils) == 1: - ps_ds = xr.open_dataset(fils[0]) - else: - raise NotImplementedError("Need to deal with mult-file case.") - #End if - ps = ps_ds['PS'] - #End if - - #Return values: + ps = adf.data.load_regrid_da(casename, 'PS') + if ps is None: + print(f"\t WARNING: Could not load PS for {casename}.") + return None return ps -def get_var_at_plev(adf, casename, location, variable, plev): - """ - Get `variable` from the data and then interpolate it to isobaric level `plev` (units of hPa). 
- """ - dset = _retrieve(adf, variable, casename, location, return_dataset=True) - - # Try and extract surface pressure: - ps = get_surface_pressure(dset, casename, location) - vplev = gc.interp_hybrid_to_pressure(dset['U'], ps, dset['hyam'], dset['hybm'], - new_levels=np.array([100. * plev]), lev_dim='lev') - vplev = vplev.squeeze(drop=True).load() - return vplev +def get_var_at_plev(adf, casename, variable, plev): + if casename == 'Obs': + dset = adf.data.load_reference_regrid_da(variable) + if dset is None or 'lev' not in dset.dims: + print(f"\t WARNING: Obs data for {variable} lacks lev dimension or is unavailable.") + return None + return dset.sel(lev=plev, method='nearest') if dset is not None else None + else: + dset = adf.data.load_regrid_da(casename, variable) + if dset is None: + return None + ps = get_surface_pressure(adf, dset, casename) + if ps is None: + print(f"\t WARNING: Could not load PS for {variable} interpolation in {casename}") + return None + # Proceed with gc.interp_hybrid_to_pressure using regridded data + # (Assumes hyam/hybm are available in dset or can be loaded similarly) + vplev = gc.interp_hybrid_to_pressure(dset, ps, dset['hyam'], dset['hybm'], + new_levels=np.array([100. * plev]), lev_dim='lev') + return vplev.squeeze(drop=True).load() -def get_u_at_plev(adf, casename, location): - return get_var_at_plev(adf, casename, location, "U", 300) +def get_u_at_plev(adf, casename): + return get_var_at_plev(adf, casename, "U", 300) -def get_vertical_average(adf, casename, location, varname): +def get_vertical_average(adf, casename, varname): '''Collect data from case and use `vertical_average` to get result.''' - fils = sorted(Path(location).glob(f"{casename}*_{varname}_*.nc")) - if (len(fils) == 0): - raise IOError(f"Could not find {varname}") + if casename == 'Obs': + ds = adf.data.load_reference_regrid_da(varname) + if ds is None or 'lev' not in ds.dims: + print(f"\t WARNING: Obs data for {varname} lacks lev dimension.") + return None + return ds.mean(dim='lev') else: - if len(fils) == 1: - ds = xr.open_dataset(fils[0]) - else: - raise NotImplementedError("Need to deal with mult-file case.") - # Try and extract surface pressure: - ps = get_surface_pressure(ds, casename, location) - # If the climo file is made by ADF, then hyam and hybm will be with VARIABLE: - return vertical_average(ds[varname], ps, ds['hyam'], ds['hybm']) - - -def get_virh(adf, casename, location, **kwargs): + ds = adf.data.load_regrid_da(casename, varname) + if ds is None: + return None + # Try and extract surface pressure: + ps = get_surface_pressure(adf, ds, casename) + if ps is None: + print(f"\t WARNING: Could not load PS for {varname} interpolation in {casename}") + return None + # If the climo file is made by ADF, then hyam and hybm will be with VARIABLE: + return vertical_average(ds[varname], ps, ds['hyam'], ds['hybm']) + + +def get_virh(adf, casename, **kwargs): '''Calculate vertically averaged relative humidity.''' - return get_vertical_average(adf, casename, location, "RELHUM") + return get_vertical_average(adf, casename, "RELHUM") -def get_vit(adf, casename, location, **kwargs): +def get_vit(adf, casename, **kwargs): '''Calculate vertically averaged temperature.''' - return get_vertical_average(adf, casename, location, "T") + return get_vertical_average(adf, casename, "T") - -def get_landt2m(adf, casename, location): - """Gets TREFHT (T_2m) and removes non-land points.""" - fils = sorted(Path(location).glob(f"{casename}*_TREFHT_*.nc")) - if len(fils) == 0: - raise IOError(f"TREFHT 
could not be found in the files.") - elif len(fils) > 1: - t = xr.open_mfdataset(fils)['TREFHT'].load() # do we ever expect climo files split into pieces? +def get_landt2m(adf, casename): + if casename == 'Obs': + t = adf.data.load_reference_regrid_da('TREFHT') else: - t = xr.open_dataset(fils[0])['TREFHT'] - landfrac = find_landmask(adf, casename, location) + t = adf.data.load_regrid_da(casename, 'TREFHT') + if t is None: + return None + landfrac = find_landmask(adf, casename) + if landfrac is None: + return None t = xr.DataArray(np.where(landfrac >= .95, t, np.nan), - dims=t.dims, - coords=t.coords, - attrs=t.attrs) # threshold could be 1 + dims=t.dims, coords=t.coords, attrs=t.attrs) return t -def get_eqpactaux(adf, casename, location): + +def get_eqpactaux(adf, casename): """Gets zonal surface wind stress 5S to 5N.""" - fils = sorted(Path(location).glob(f"{casename}*_TAUX_*.nc")) - if len(fils) == 0: - raise IOError(f"TAUX could not be found in the files.") - elif len(fils) > 1: - taux = xr.open_mfdataset(fils)['TAUX'].load() # do we ever expect climo files split into pieces? + if casename == 'Obs': + taux = adf.data.load_reference_regrid_da('TAUX') else: - taux = xr.open_dataset(fils[0])['TAUX'] + taux = adf.data.load_regrid_da(casename, 'TAUX') + if taux is None: + print(f"\t WARNING: Could not load TAUX for {casename}") + return None return taux.sel(lat=slice(-5, 5)) @@ -396,39 +375,51 @@ def get_derive_func(fld): 'EquatorialPacificStress': get_eqpactaux } if fld not in funcs: - raise ValueError(f"We do not have a method for variable: {fld}.") + print(f"We do not have a method for variable: {fld}.") + return None return funcs[fld] -def _retrieve(adfobj, variable, casename, location, return_dataset=False): - """Custom function that retrieves a variable. Returns the variable as a DataArray. - kwarg: - return_dataset -> if true, return the dataset object, otherwise return the DataArray - with `variable` - This option allows get_u_at_plev to use _retrieve. +def _retrieve(adfobj, variable, casename, return_dataset=False): + """Custom function that retrieves a variable using ADF loaders for grid consistency. + Returns the variable as a DataArray (or Dataset if return_dataset=True). """ - v_to_derive = ['TropicalLandPrecip', 'TropicalOceanPrecip', 'EquatorialPacificStress', - 'U300', 'ColumnRelativeHumidity', 'ColumnTemperature', 'Land2mTemperature'] - if variable not in v_to_derive: - fils = sorted(Path(location).glob(f"{casename}*_{variable}_*.nc")) - if len(fils) == 0: - raise ValueError(f"something went wrong for variable: {variable}") - elif len(fils) > 1: - ds = xr.open_mfdataset(fils) # do we ever expect climo files split into pieces? 
- else: - ds = xr.open_dataset(fils[0]) - if return_dataset: - da = ds - else: - da = ds[variable] - else: - func = get_derive_func(variable) - da = func(adfobj, casename, location) # these ONLY return DataArray + 'U300', 'ColumnRelativeHumidity', 'ColumnTemperature', 'Land2mTemperature'] + + try: + if casename == 'Obs': + if variable not in v_to_derive: + da = adfobj.data.load_reference_regrid_da(variable) + else: + func = get_derive_func(variable) + if func is None: + print(f"\t WARNING: No derivation function for {variable}.") + return None + da = func(adfobj, 'Obs') # No location needed + else: # Model cases + if variable not in v_to_derive: + da = adfobj.data.load_regrid_da(casename, variable) + else: + func = get_derive_func(variable) + if func is None: + print(f"\t WARNING: No derivation function for {variable}.") + return None + da = func(adfobj, casename) # No location needed + + if da is None: + print(f"\t WARNING: Could not load {variable} for {casename}.") + return None + if return_dataset: - da = da.to_dataset(name=variable) - return da - + if not isinstance(da, xr.Dataset): + da = da.to_dataset(name=variable) + return da + + except Exception as e: + print(f"\t WARNING: Error retrieving {variable} for {casename}: {e}") + return None + def weighted_correlation(x, y, weights): # TODO: since we expect masked fields (land/ocean), need to allow for missing values (maybe works already?) @@ -542,6 +533,8 @@ def plot_taylor_data(wks, df, **kwargs): k = 1 for ndx, row in df.iterrows(): # NOTE: ndx will be the DataFrame index, and we expect that to be the variable name + if np.isnan(row['corr']) or np.isnan(row['ratio']): + continue # Skip plotting if data is missing theta = np.pi/2 - np.arccos(row['corr']) # Transform DATA if use_bias: mk = marker_list[row['bias_digi']] From 1399a17bc3637eedc2aacc3f43a4c8a72ff30824 Mon Sep 17 00:00:00 2001 From: Brian Medeiros Date: Mon, 12 Jan 2026 20:21:10 -0700 Subject: [PATCH 2/4] Taylor for obs & multi-case; modified multicase support; modified regridding --- lib/adf_config.py | 53 ++- lib/adf_dataset.py | 112 ++++- lib/adf_info.py | 68 ++- lib/adf_utils.py | 68 ++- scripts/averaging/create_climo_files.py | 113 ++--- scripts/plotting/cam_taylor_diagram.py | 421 +++++++++++++----- .../regridding/regrid_and_vert_interp_2.py | 384 ++++++++++++++++ 7 files changed, 990 insertions(+), 229 deletions(-) create mode 100644 scripts/regridding/regrid_and_vert_interp_2.py diff --git a/lib/adf_config.py b/lib/adf_config.py index b26c55485..b5ff98b09 100644 --- a/lib/adf_config.py +++ b/lib/adf_config.py @@ -237,20 +237,55 @@ def expand_references(self, config_dict): """ #copy YAML config dictionary: - config_dict_copy = copy.copy(config_dict) + config_dict_copy = copy.deepcopy(config_dict) + + #Recursively expand references + self.__expand_dict_refs(config_dict_copy) + + #Update the original dict + config_dict.update(config_dict_copy) + + def __expand_dict_refs(self, config_dict): + + """ + Recursive helper to expand references in nested dicts and lists. 
+ """ #Loop through dictionary: - for key, value in config_dict_copy.items(): + for key, value in config_dict.items(): + + if isinstance(value, str): + #expand any keywords to their full values: + config_dict[key] = self.__expand_yaml_var_ref(value) + + elif isinstance(value, list): + #Handle lists recursively + for i, item in enumerate(value): + if isinstance(item, str): + value[i] = self.__expand_yaml_var_ref(item) + elif isinstance(item, dict): + self.__expand_dict_refs(item) + elif isinstance(item, list): + #Nested list, recurse + self.__expand_list_refs(item) + + elif isinstance(value, dict): + #Recurse into nested dict + self.__expand_dict_refs(value) - #Skip non-strings (as they won't contain a keyword): - if not isinstance(value, str): - continue + def __expand_list_refs(self, config_list): - #expand any keywords to their full values: - new_value = self.__expand_yaml_var_ref(value) + """ + Recursive helper to expand references in lists. + """ - #Set config variable to new, expanded value: - config_dict[key] = new_value + for i, item in enumerate(config_list): + if isinstance(item, str): + config_list[i] = self.__expand_yaml_var_ref(item) + elif isinstance(item, dict): + self.__expand_dict_refs(item) + elif isinstance(item, list): + self.__expand_list_refs(item) ######### diff --git a/lib/adf_dataset.py b/lib/adf_dataset.py index 2cb326cb0..5f2259971 100644 --- a/lib/adf_dataset.py +++ b/lib/adf_dataset.py @@ -31,6 +31,17 @@ # Set AdfData.ref_nickname to that. # Could be altered from "Obs" to be the data source label. +# NOTE: Standard ADF workflow creates time series files with NCO. +# Climo files are then generated with create_climo_files.py +# Since neither of these apply units conversions (add_offset, scale_factor), +# the methods here default to applying them when loading +# time series and climo files, using the kwarg apply_scaling. +# Regridded files are made with regrid_and_vert_interp[_2].py, +# which uses this module for loading climo files, so will apply +# scaling. +# Therefore the default on loading regridded files is to NOT +# apply scaling. + class AdfData: """A class instantiated with an AdfDiag object. Methods provide means to load data. @@ -151,9 +162,12 @@ def load_timeseries_da(self, case, variablename): return None return self.load_da(fils, variablename, add_offset=add_offset, scale_factor=scale_factor) - def load_reference_timeseries_da(self, field): + def load_reference_timeseries_da(self, field, apply_scaling=True): """Return a DataArray time series to be used as reference (aka baseline) for variable field. + + apply_scaling: bool + If True, apply add_offset and scale_factor to data (if present). 
""" fils = self.get_ref_timeseries_file(field) if not fils: @@ -168,6 +182,10 @@ def load_reference_timeseries_da(self, field): else: add_offset, scale_factor = self.get_value_converters(self.ref_case_label, field) + if not apply_scaling: + add_offset = 0 + scale_factor = 1 + return self.load_da(fils, field, add_offset=add_offset, scale_factor=scale_factor) @@ -178,10 +196,22 @@ def load_reference_timeseries_da(self, field): #------------------ # Test case(s) - def load_climo_da(self, case, variablename): - """Return DataArray from climo file""" + def load_climo_ds(self, case, variablename): + """Return Dataset from climo file; applies scale factor and offset to `variablename`.""" add_offset, scale_factor = self.get_value_converters(case, variablename) fils = self.get_climo_file(case, variablename) + ds = self.load_dataset(fils) + ds[variablename] = ds[variablename] * scale_factor + add_offset + return ds + + def load_climo_da(self, case, variablename, apply_scaling=True): + """Return DataArray from climo file""" + if not apply_scaling: + add_offset = 0 + scale_factor = 1 + else: + add_offset, scale_factor = self.get_value_converters(case, variablename) + fils = self.get_climo_file(case, variablename) return self.load_da(fils, variablename, add_offset=add_offset, scale_factor=scale_factor) @@ -203,11 +233,29 @@ def get_climo_file(self, case, variablename): # Reference case (baseline/obs) - def load_reference_climo_da(self, case, variablename): - """Return DataArray from reference (aka baseline) climo file""" + def load_reference_climo_ds(self, case, variablename, apply_scaling=True): + """Return Dataset from reference climo file; applies scale factor and offset to `variablename`. + """ add_offset, scale_factor = self.get_value_converters(case, variablename) fils = self.get_reference_climo_file(variablename) - return self.load_da(fils, variablename, add_offset=add_offset, scale_factor=scale_factor) + ds = self.load_dataset(fils) + vname = self.ref_var_nam[variablename] # name of variable in the reference data + if not apply_scaling: + add_offset = 0 + scale_factor = 1 + ds[vname] = ds[vname] * scale_factor + add_offset + return ds + + def load_reference_climo_da(self, case, variablename, apply_scaling=True): + """Return DataArray from reference (aka baseline) climo file""" + fils = self.get_reference_climo_file(variablename) + vname = self.ref_var_nam[variablename] + if not apply_scaling: + add_offset = 0 + scale_factor = 1 + else: + add_offset, scale_factor = self.get_value_converters(case, variablename) + return self.load_da(fils, vname, add_offset=add_offset, scale_factor=scale_factor) def get_reference_climo_file(self, var): """Return a list of files to be used as reference (aka baseline) for variable var.""" @@ -230,7 +278,15 @@ def get_reference_climo_file(self, var): # Test case(s) def get_regrid_file(self, case, field): """Return list of test regridded files""" - model_rg_loc = Path(self.adf.get_basic_info("cam_regrid_loc", required=True)) + model_rg_list = self.model_rgrid_loc + if isinstance(model_rg_list, list): + caseindex = self.case_names.index(case) + model_rg_loc = Path(model_rg_list[caseindex]) + elif isinstance(model_rg_list, (str, Path)): + model_rg_loc = Path(model_rg_list) + else: + warnings.warn(f"\t ERROR: Did not find regrid location for case: {case}, variable {field}") + return None rlbl = self.ref_labels[field] # rlbl = "reference label" = the name of the reference data that defines target grid return sorted(model_rg_loc.glob(f"{rlbl}_{case}_{field}_regridded.nc")) 
@@ -244,9 +300,13 @@ def load_regrid_dataset(self, case, field):
         return self.load_dataset(fils)
 
-    def load_regrid_da(self, case, field):
+    def load_regrid_da(self, case, field, apply_scaling=False):
         """Return a data array to be used as reference (aka baseline) for variable field."""
-        add_offset, scale_factor = self.get_value_converters(case, field)
+        if not apply_scaling:
+            add_offset = 0
+            scale_factor = 1
+        else:
+            add_offset, scale_factor = self.get_value_converters(case, field)
         fils = self.get_regrid_file(case, field)
         if not fils:
             warnings.warn(f"\t WARNING: Did not find regrid file(s) for case: {case}, variable: {field}")
@@ -264,6 +324,15 @@ def get_ref_regrid_file(self, case, field):
             else:
                 fils = []
         else:
+            model_rg_list = self.model_rgrid_loc
+            if isinstance(model_rg_list, list):
+                caseindex = self.case_names.index(case)
+                model_rg_loc = Path(model_rg_list[caseindex])
+            elif isinstance(model_rg_list, (str, Path)):
+                model_rg_loc = Path(model_rg_list)
+            else:
+                warnings.warn(f"\t ERROR: Did not find regrid location for case: {case}, variable {field}")
+                return None
-            model_rg_loc = Path(self.adf.get_basic_info("cam_regrid_loc", required=True))
             fils = sorted(model_rg_loc.glob(f"{case}_{field}_baseline.nc"))
         return fils
@@ -278,9 +347,13 @@ def load_reference_regrid_dataset(self, case, field):
         return self.load_dataset(fils)
 
-    def load_reference_regrid_da(self, case, field):
+    def load_reference_regrid_da(self, case, field, apply_scaling=False):
         """Return a data array to be used as reference (aka baseline) for variable field."""
-        add_offset, scale_factor = self.get_value_converters(case, field)
+        if not apply_scaling:
+            add_offset = 0
+            scale_factor = 1
+        else:
+            add_offset, scale_factor = self.get_value_converters(case, field)
         fils = self.get_ref_regrid_file(case, field)
         if not fils:
             warnings.warn(f"\t WARNING: Did not find regridded file(s) for case: {case}, variable: {field}")
@@ -291,13 +364,9 @@
             field = self.ref_var_nam[field]
         return self.load_da(fils, field, add_offset=add_offset, scale_factor=scale_factor)
 
-    #------------------
-
-
+    #---------------------------
     # DataSet and DataArray load
     #---------------------------
-
-    # Load DataSet
     def load_dataset(self, fils):
         """Return xarray DataSet from file(s)"""
         if (len(fils) == 0):
@@ -315,7 +384,6 @@
             warnings.warn(f"\t WARNING: invalid data on load_dataset")
         return ds
 
-    # Load DataArray
     def load_da(self, fils, variablename, **kwargs):
         """Return xarray DataArray from files(s) w/ optional scale factor, offset, and/or new units"""
         ds = self.load_dataset(fils)
@@ -323,8 +391,13 @@
             warnings.warn(f"\t WARNING: Load failed for {variablename}")
             return None
         da = (ds[variablename]).squeeze()
-        scale_factor = kwargs.get('scale_factor', 1)
-        add_offset = kwargs.get('add_offset', 0)
+        apply_scaling = kwargs.get('apply_scaling', True)
+        if not apply_scaling:
+            add_offset = 0
+            scale_factor = 1
+        else:
+            scale_factor = kwargs.get('scale_factor', 1)
+            add_offset = kwargs.get('add_offset', 0)
         da = da * scale_factor + add_offset
         if variablename in self.adf.variable_defaults:
             vres = self.adf.variable_defaults[variablename]
diff --git a/lib/adf_info.py b/lib/adf_info.py
index fb4612a9b..78ea6d1de 100644
--- a/lib/adf_info.py
+++ b/lib/adf_info.py
@@ -125,10 +125,9 @@ def __init__(self, config_file, debug=False):
                 emsg += f" {self.__num_cases} entries, instead it has {len(conf_val)}"
                 self.end_diag_fail(emsg)
             else:
-                #If not a
list, then convert it to one: - self.__cam_climo_info[conf_var] = [conf_val] - #End if - #End for + # If not a list, replicate the scalar value for each case + self.__cam_climo_info[conf_var] = [conf_val] * self.__num_cases + # End for #Initialize ADF variable list: self.__diag_var_list = self.read_config_var('diag_var_list', required=True) @@ -368,12 +367,14 @@ def __init__(self, config_file, debug=False): #Make lists of None to be iterated over for case_names if syears is None: - syears = [None]*len(case_names) - #End if - if eyears is None: - eyears = [None]*len(case_names) - #End if + syears = [None] * self.__num_cases + elif not isinstance(syears, list): + syears = [syears] * self.__num_cases + if eyears is None: + eyears = [None] * self.__num_cases + elif not isinstance(eyears, list): + eyears = [eyears] * self.__num_cases #Extract cam history files location: cam_hist_locs = self.get_cam_info('cam_hist_loc') @@ -388,7 +389,7 @@ def __init__(self, config_file, debug=False): #Initialize CAM history string nested list self.__hist_str = hist_str - + #Check if using pre-made ts files cam_ts_done = self.get_cam_info("cam_ts_done") @@ -445,13 +446,14 @@ def __init__(self, config_file, debug=False): hist_str_case = hist_str[case_idx] if any(cam_hist_locs): #Grab first possible hist string, just looking for years of run - hist_str = hist_str_case[0] + hist_str_use = hist_str_case[0] + print(f"HIST_STR_USE: {hist_str_use = }, as type: {type(hist_str_use)}") #Get climo years for verification or assignment if missing starting_location = Path(cam_hist_locs[case_idx]) print(f"\tChecking history files in '{starting_location}'") - file_list = sorted(starting_location.glob('*'+hist_str+'.*.nc')) + file_list = sorted(starting_location.glob('*'+hist_str_use+'.*.nc')) #Check if the history file location exists if not starting_location.is_dir(): @@ -464,7 +466,7 @@ def __init__(self, config_file, debug=False): self.end_diag_fail(emsg) #Check if there are any history files - file_list = sorted(starting_location.glob('*'+hist_str+'.*.nc')) + file_list = sorted(starting_location.glob('*'+hist_str_use+'.*.nc')) if len(file_list) == 0: msg = "Checking history files:\n" msg += f"\tThere are no history files in '{starting_location}'." @@ -483,7 +485,7 @@ def __init__(self, config_file, debug=False): #Since the last part always includes the time range, grab that with last index (2) #NOTE: this is based off the current CAM file name structure in the form: # $CASE.cam.h#.YYYY.nc - case_climo_yrs = [int(str(i).partition(f"{hist_str}.")[2][0:4]) for i in file_list] + case_climo_yrs = [int(str(i).partition(f"{hist_str_use}.")[2][0:4]) for i in file_list] if not case_climo_yrs: msg = f"\t ERROR: No climo years found in {cam_hist_locs[case_idx]}, " raise AdfError(msg) @@ -617,18 +619,36 @@ def __init__(self, config_file, debug=False): def hist_str_to_list(self, conf_var, conf_val): """ - Make hist_str a nested list [ncases,nfiles] of the given value(s) + Normalizes hist_str input into a nested list [ncases][nfiles]. """ - if isinstance(conf_val, list): - hist_str = conf_val - else: # one case, one hist str - hist_str = [ - conf_val - ] - self.__cam_climo_info[conf_var] = [hist_str] - ######### + n = self.__num_cases + result = None + + # 1. Handle Single String input: "h0" -> [["h0"], ["h0"], ...] + if isinstance(conf_val, str): + result = [[conf_val] for _ in range(n)] + + elif isinstance(conf_val, list): + # 2. Check if it's already a nested list: [["h0"], ["h0"]] + # We check the first element to see if it's a list. 
+ if len(conf_val) == n and all(isinstance(i, list) for i in conf_val): + result = conf_val + + # 3. Check if it's a list of strings matching N cases: ["h0", "h1"] + elif len(conf_val) == n and all(isinstance(i, str) for i in conf_val): + result = [[i] for i in conf_val] + + # 4. Otherwise, treat it as a single set of files for ALL cases: ["h0", "h1"] + else: + # We wrap the list and multiply it + result = [conf_val for _ in range(n)] + + if result is None: + raise ValueError(f"Invalid format for {conf_var}: {conf_val}") + + self.__cam_climo_info[conf_var] = result # Create property needed to return "user" name to user: + - # Create property needed to return "user" name to user: @property def user(self): """Return the "user" name if requested.""" diff --git a/lib/adf_utils.py b/lib/adf_utils.py index 179051011..72c880911 100644 --- a/lib/adf_utils.py +++ b/lib/adf_utils.py @@ -27,6 +27,8 @@ Interpolate model hybrid levels to specified pressure levels. pmid_to_plev(data, pmid, new_levels=None, convert_to_mb=False) Interpolate `data` from hybrid-sigma levels to isobaric levels using provided mid-level pressures. +plev_to_plev(data, new_levels=None, convert_to_mb=False) + Interpolate `data` from isobaric levels to new isobaric levels. zonal_mean_xr(fld) Average over all dimensions except `lev` and `lat`. validate_dims(fld, list_of_dims) @@ -502,7 +504,7 @@ def lev_to_plev(data, ps, hyam, hybm, P0=100000., new_levels=None, Parameters ---------- - data : + data : xarray.DataArray ps : surface pressure hyam, hybm : @@ -616,7 +618,71 @@ def pmid_to_plev(data, pmid, new_levels=None, convert_to_mb=False): return output +def plev_to_plev(data, new_levels=None, convert_to_mb=False): + """Interpolate data from isobaric levels to new isobaric levels. + + Parameters + ---------- + data : xarray.DataArray + field with a vertical pressure coordinate (e.g., 'lev', 'plev', 'pressure') + new_levels : optional + the output pressure levels (Pa), defaults to standard levels + convert_to_mb : bool, optional + flag to convert output to mb (i.e., hPa), defaults to False + + Returns + ------- + output : xarray.DataArray + `data` interpolated onto `new_levels` + """ + + # Try to identify the vertical coordinate name + vert_coord_names = ['lev', 'plev', 'pressure'] + vert_coord = None + for name in vert_coord_names: + if name in data.dims: + vert_coord = name + break + + if vert_coord is None: + raise AdfError(f"plev_to_plev: Could not find a vertical coordinate in {vert_coord_names}.") + + # determine pressure levels to interpolate to: + if new_levels is None: + pnew = 100.0 * np.array([1000, 925, 850, 700, 500, 400, + 300, 250, 200, 150, 100, 70, 50, + 30, 20, 10, 7, 5, 3, 2, 1]) # mandatory levels, converted to Pa + else: + pnew = new_levels + #End if + + # save name of DataArray: + data_name = data.name + + # Create a pressure field that matches the data shape + p_mdl = data[vert_coord] * xr.ones_like(data) + + # reshape data and pressure + zdims = [i for i in data.dims if i != vert_coord] + dstack = data.stack(z=zdims) + pstack = p_mdl.stack(z=zdims) + + # Need to transpose to (vert_coord, z) + dstack = dstack.transpose(vert_coord, 'z') + pstack = pstack.transpose(vert_coord, 'z') + + output = vert_remap(dstack.values, pstack.values, pnew) + output = xr.DataArray(output, name=data_name, dims=("lev", "z"), + coords={"lev":pnew, "z":pstack['z']}) + output = output.unstack() + + # convert vertical dimension to mb/hPa, if requested: + if convert_to_mb: + output["lev"] = output["lev"] / 100.0 + #End if + #Return 
interpolated output: + return output def validate_dims(fld, list_of_dims): """Check if specified dimensions are in a DataArray. diff --git a/scripts/averaging/create_climo_files.py b/scripts/averaging/create_climo_files.py index 24d1f2098..f28d5a5ff 100644 --- a/scripts/averaging/create_climo_files.py +++ b/scripts/averaging/create_climo_files.py @@ -11,7 +11,6 @@ import numpy as np import xarray as xr # module-level import so all functions can get to it. -import multiprocessing as mp def get_time_slice_by_year(time, startyear, endyear): """ @@ -106,6 +105,10 @@ def create_climo_files(adf, clobber=False, search=None): output_locs = adf.get_cam_info("cam_climo_loc", required=True) calc_climos = adf.get_cam_info("calc_cam_climo") overwrite = adf.get_cam_info("cam_overwrite_climo") + print(f"CHECK ON INPUT") + print(case_names) + print(input_ts_locs) + #Extract simulation years: start_year = adf.climo_yrs["syears"] @@ -224,34 +227,35 @@ def create_climo_files(adf, clobber=False, search=None): # end_diag_script(errmsg) # Previously we would kill the run here. continue - list_of_arguments.append((adf, ts_files, syr, eyr, output_file)) - - - #End of var_list loop - #-------------------- + list_of_arguments.append((adf.user, ts_files, syr, eyr, output_file)) # Parallelize the computation using multiprocessing pool: - with mp.Pool(processes=number_of_cpu) as p: - result = p.starmap(process_variable, list_of_arguments) - - #End of model case loop - #---------------------- - - #Notify user that script has ended: + print(f" --> Starting Pool with {number_of_cpu} workers for {len(list_of_arguments)} variables.") + import multiprocessing as mp + # Use 'spawn' to ensure a fresh memory space for each process + # Safer on HPC systems than the default 'fork' + context = mp.get_context('spawn') + with context.Pool(processes=number_of_cpu) as p: + results = p.starmap(process_variable, list_of_arguments) + # Print results to see if any specific variable failed + for res in results: + if "Failed" in res: + print(f"\t {res}") + print(" ... multiprocessing pool closed.") print(" ...CAM climatologies have been calculated successfully.") # # Local functions # -def process_variable(adf, ts_files, syr, eyr, output_file): +def process_variable(adf_user, ts_files, syr, eyr, output_file): ''' Compute and save the monthly climatology file. 
Parameters ---------- - adf - The ADF object + adf_user + The user from the ADF object ts_files : list list of paths to time series files syr : str @@ -261,46 +265,43 @@ def process_variable(adf, ts_files, syr, eyr, output_file): output_file : str or Path file path for output climatology file ''' - #Read in files via xarray (xr): - if len(ts_files) == 1: - cam_ts_data = xr.open_dataset(ts_files[0], decode_times=True) - else: - cam_ts_data = xr.open_mfdataset(ts_files, decode_times=True, combine='by_coords') - #Average time dimension over time bounds, if bounds exist: - if 'time_bnds' in cam_ts_data: - time = cam_ts_data['time'] - # NOTE: force `load` here b/c if dask & time is cftime, throws a NotImplementedError: - time = xr.DataArray(cam_ts_data['time_bnds'].load().mean(dim='nbnd').values, dims=time.dims, attrs=time.attrs) - cam_ts_data['time'] = time - cam_ts_data.assign_coords(time=time) - cam_ts_data = xr.decode_cf(cam_ts_data) - #Extract data subset using provided year bounds: - tslice = get_time_slice_by_year(cam_ts_data.time, int(syr), int(eyr)) - cam_ts_data = cam_ts_data.isel(time=tslice) - #Group time series values by month, and average those months together: - cam_climo_data = cam_ts_data.groupby('time.month').mean(dim='time') - #Rename "months" to "time": - cam_climo_data = cam_climo_data.rename({'month':'time'}) - #Set netCDF encoding method (deal with getting non-nan fill values): - enc_dv = {xname: {'_FillValue': None, 'zlib': True, 'complevel': 4} for xname in cam_climo_data.data_vars} - enc_c = {xname: {'_FillValue': None} for xname in cam_climo_data.coords} - enc = {**enc_c, **enc_dv} - - # Create a dictionary of attributes - # Convert the list to a string (join with commas) - ts_files_str = [str(path) for path in ts_files] - ts_files_str = ', '.join(ts_files_str) - attrs_dict = { - "adf_user": adf.user, - "climo_yrs": f"{syr}-{eyr}", - "time_series_files": ts_files_str, - } - cam_climo_data = cam_climo_data.assign_attrs(attrs_dict) - - #Output variable climatology to NetCDF-4 file: - cam_climo_data.to_netcdf(output_file, format='NETCDF4', encoding=enc) - return 1 # All funcs return something. Could do error checking with this if needed. 
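# A standalone sketch of the 'spawn' worker-pool pattern adopted above (illustrative
# worker, not the ADF function). 'spawn' starts each worker in a fresh interpreter,
# so workers re-import their module instead of inheriting the parent's memory, which
# is the failure mode 'fork' can hit on HPC systems:

import multiprocessing as mp

def _square(label, value):
    # Runs in a fresh interpreter under the 'spawn' start method.
    return f"{label}: {value ** 2}"

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        results = pool.starmap(_square, [("a", 3), ("b", 4)])
    print(results)  # ['a: 9', 'b: 16']
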
- + import xarray as xr + import numpy as np + import dask + import gc + dask.config.set(scheduler='synchronous') # Disable internal dask multi-threading + try: + # Using chunks={} forces xarray to use dask, which handles memory better + # than loading everything into RAM at once via open_dataset + with xr.open_mfdataset(ts_files, decode_times=True, combine='by_coords', chunks={'time': 12}) as ds: + if 'time_bnds' in ds: + new_time = ds['time_bnds'].load().mean(dim='nbnd') + ds = ds.assign_coords(time=new_time.values) + ds = xr.decode_cf(ds) + + tslice = get_time_slice_by_year(ds.time, int(syr), int(eyr)) + ds_subset = ds.isel(time=tslice) + + climo = ds_subset.groupby('time.month').mean(dim='time') + climo = climo.rename({'month': 'time'}) + + enc_dv = {xname: {'_FillValue': None, 'zlib': True, 'complevel': 4} for xname in climo.data_vars} + enc_c = {xname: {'_FillValue': None} for xname in climo.coords} + enc = {**enc_c, **enc_dv} + + climo.attrs.update({ + "adf_user": adf_user, + "climo_yrs": f"{syr}-{eyr}", + "time_series_files": ", ".join([str(f) for f in ts_files]) + }) + + climo.to_netcdf(output_file, format='NETCDF4', encoding=enc) + return f"Success: {output_file.name}" + except Exception as e: + return f"Failed: {output_file.name} with error: {str(e)}" + finally: + # Force cleanup of memory + gc.collect() def check_averaging_interval(syear_in, eyear_in): """ diff --git a/scripts/plotting/cam_taylor_diagram.py b/scripts/plotting/cam_taylor_diagram.py index 6a029714e..88d5f5d5a 100644 --- a/scripts/plotting/cam_taylor_diagram.py +++ b/scripts/plotting/cam_taylor_diagram.py @@ -4,29 +4,52 @@ Provides a Taylor diagram following the AMWG package. Uses spatial information only. This module, for better or worse, provides both the computation and plotting functionality. -It depends on an ADF instance to obtain the `climo` files. +It depends on an ADF instance to obtain the regridded `climo` files. It is designed to have one "reference" case (could be observations) and arbitrary test cases. When multiple test cases are provided, they are plotted with different colors. -NOTE: THIS IS A DRAFT REFACTORING TO ALLOW OBSERVATIONS (.b) ''' # # --- imports and configuration --- # +import sys +import logging from pathlib import Path import numpy as np import xarray as xr import pandas as pd import geocat.comp as gc # use geocat's interpolation import matplotlib as mpl +mpl.use('Agg') import matplotlib.pyplot as plt from matplotlib.lines import Line2D from matplotlib.legend_handler import HandlerTuple +try: + import xesmf as xe + XESMF_AVAILABLE = True +except ImportError: + XESMF_AVAILABLE = False + print("WARNING: xesmf not available, regridding in derived variables may fail") + import adf_utils as utils -import warnings # use to warn user about missing files. -warnings.formatwarning = utils.my_formatwarning +logger = logging.getLogger(__name__) +console_handler = logging.StreamHandler(sys.stdout) +console_handler.setLevel(logging.INFO) +logger.addHandler(console_handler) +logger.setLevel(logging.DEBUG) +logger.propagate = False + + +def get_level_dim(dset): + """Get the name of the level dimension in the dataset.""" + level_dims = ['lev', 'level', 'ilev'] + for dim in level_dims: + if dim in dset.dims: + return dim + return None + # # --- Main Function Shares Name with Module: cam_taylor_diagram --- @@ -34,20 +57,16 @@ def cam_taylor_diagram(adfobj): """Create Taylor diagrams for specified configuration.""" msg = "\n Generating Taylor Diagrams..." 
- print(f"{msg}\n {'-' * (len(msg)-3)}") + logger.info(f"{msg}\n {'-' * (len(msg)-3)}") # Extract needed quantities from ADF object: # ----------------------------------------- - # Case names: # NOTE: "baseline" == "reference" == "observations" will be called `base` # test case(s) == case(s) to be diagnosed will be called `case` (assumes a list) - case_names = adfobj.get_cam_info('cam_case_name', required=True) # Loop over these - - #Grab all case nickname(s) - test_nicknames = adfobj.case_nicknames["test_nicknames"] - - syear_cases = adfobj.climo_yrs["syears"] - eyear_cases = adfobj.climo_yrs["eyears"] + case_names: list = adfobj.get_cam_info('cam_case_name', required=True) + test_nicknames: list = adfobj.case_nicknames["test_nicknames"] + syear_cases: list = adfobj.climo_yrs["syears"] + eyear_cases: list = adfobj.climo_yrs["eyears"] # ADF variable which contains the output path for plots and tables: plot_location = adfobj.plot_location @@ -58,39 +77,32 @@ def cam_taylor_diagram(adfobj): plpth = Path(pl) #Check if plot output directory exists, and if not, then create it: if not plpth.is_dir(): - print(f"\t {pl} not found, making new directory") plpth.mkdir(parents=True) if len(plot_location) == 1: plot_loc = Path(plot_location[0]) else: - print(f"Ambiguous plotting location since all cases go on same plot. Will put them in first location: {plot_location[0]}") + logger.warning(f"Ambiguous plotting location since all cases go on same plot. Will put them in first location: {plot_location[0]}") plot_loc = Path(plot_location[0]) else: plot_loc = Path(plot_location) - # reference data set(s) -- if comparing with obs, these are dicts. data_name = adfobj.data.ref_case_label - data_loc = adfobj.data.ref_data_loc base_nickname = adfobj.data.ref_nickname #Extract baseline years (which may be empty strings if using Obs): syear_baseline = adfobj.climo_yrs["syear_baseline"] eyear_baseline = adfobj.climo_yrs["eyear_baseline"] - res = adfobj.variable_defaults # dict of variable-specific plot preferences - # or an empty dictionary if use_defaults was not specified in YAML. 
- #Set plot file type: # -- this should be set in basic_info_dict, but is not required # -- So check for it, and default to png basic_info_dict = adfobj.read_config_var("diag_basic_info") plot_type = basic_info_dict.get('plot_type', 'png') - print(f"\t NOTE: Plot type is set to {plot_type}") #Check if existing plots need to be redone redo_plot = adfobj.get_basic_info('redo_plot') - print(f"\t NOTE: redo_plot is set to {redo_plot}") + logger.info(f"\t redo_plot is set to {redo_plot}") # Check for required variables taylor_var_set = {'U', 'PSL', 'SWCF', 'LWCF', 'LANDFRAC', 'TREFHT', 'TAUX', 'RELHUM', 'T'} @@ -100,14 +112,14 @@ def cam_taylor_diagram(adfobj): has_prect = 'PRECT' in available_vars has_precl_precc = {'PRECL', 'PRECC'}.issubset(available_vars) if missing_vars or not (has_prect or has_precl_precc): - print("\tTaylor Diagrams skipped due to missing variables:") + logger.warning("\tTaylor Diagrams skipped due to missing variables:") if missing_vars: - print(f"\t - Missing: {', '.join(sorted(missing_vars))}") + logger.warning(f"\t - Missing: {', '.join(sorted(missing_vars))}") if not (has_prect or has_precl_precc): if not has_prect: - print("\t - Missing: PRECT (Alternative PRECL + PRECC also incomplete)") - print("\n\tFull requirement: U, PSL, SWCF, LWCF, LANDFRAC, TREFHT, TAUX, RELHUM, T,") - print("\tAND (PRECT OR both PRECL & PRECC)") + logger.warning("\t - Missing: PRECT (Alternative PRECL + PRECC also incomplete)") + logger.info("\n\tFull requirement: U, PSL, SWCF, LWCF, LANDFRAC, TREFHT, TAUX, RELHUM, T,") + logger.info("\tAND (PRECT OR both PRECL & PRECC)") return @@ -125,20 +137,19 @@ def cam_taylor_diagram(adfobj): 'U300', 'ColumnRelativeHumidity', 'ColumnTemperature'] case_colors = [mpl.cm.tab20(i) for i, case in enumerate(case_names)] # change color for each case + # # LOOP OVER SEASON # - for s in seasons.items(): - plot_name = plot_loc / f"TaylorDiag_{s}_Special_Mean.{plot_type}" - print(f"\t - Plotting Taylor Diagram, {s}") + for season, months in seasons.items(): + logger.debug(f"TAYLOR DIAGRAM SEASON: {season}") + plot_name = plot_loc / f"TaylorDiag_{season}_Special_Mean.{plot_type}" # Check redo_plot. 
If set to True: remove old plot, if it already exists: if (not redo_plot) and plot_name.is_file(): #Add already-existing plot to website (if enabled): adfobj.debug_log(f"'{plot_name}' exists and clobber is false.") - adfobj.add_website_data(plot_name, "TaylorDiag", None, season=s, multi_case=True) - - #Continue to next iteration: + adfobj.add_website_data(plot_name, "TaylorDiag", None, season=season, multi_case=True) continue elif (redo_plot) and plot_name.is_file(): plot_name.unlink() @@ -147,51 +158,49 @@ def cam_taylor_diagram(adfobj): # variable | correlation | stddev ratio | bias df_template = pd.DataFrame(index=var_list, columns=['corr', 'ratio', 'bias']) result_by_case = {cname: df_template.copy() for cname in case_names} - # - # LOOP OVER VARIABLES - # + for v in var_list: + logger.debug(f"TAYLOR DIAGRAM VARIABLE: {v}") # Load reference data (already regridded to target grid) ref_data = _retrieve(adfobj, v, data_name) if ref_data is None: - print(f"\t WARNING: No regridded reference data for {v} in {data_name}, skipping.") + logger.warning(f"\t WARNING: No regridded reference data for {v} in {data_name}, skipping.") continue - # ASSUMING `time` is 1-12, get the current season: - ref_data = ref_data.sel(time=s).mean(dim='time') + ref_data = ref_data.sel(time=months).mean(dim='time').compute() for casenumber, case in enumerate(case_names): # Load test case data regridded to match reference grid case_data = _retrieve(adfobj, v, case) if case_data is None: - print(f"\t WARNING: No regridded data for {v} in {case}, skipping.") + logger.warning(f"\t WARNING: No regridded data for {v} in {case}, skipping.") continue - # ASSUMING `time` is 1-12, get the current season: - case_data = case_data.sel(time=s).mean(dim='time') - # Now compute stats (grids are aligned) + case_data = case_data.sel(time=months).mean(dim='time').compute() result_by_case[case].loc[v] = taylor_stats_single(case_data, ref_data) - # + # -- PLOTTING (one per season) -- - # - fig, ax = taylor_plot_setup(title=f"Taylor Diagram - {s}", + logger.debug(f"TAYLOR DIAGRAM PLOTTING: {season}") + fig, ax = taylor_plot_setup(title=f"Taylor Diagram - {season}", baseline=f"Baseline: {base_nickname} yrs: {syear_baseline}-{eyear_baseline}") for i, case in enumerate(case_names): + logger.debug(f"\t TAYLOR DIAGRAM CASE: {case}") ax = plot_taylor_data(ax, result_by_case[case], case_color=case_colors[i], use_bias=True) ax = taylor_plot_finalize(ax, test_nicknames, case_colors, syear_cases, eyear_cases, needs_bias_labels=True) + logger.debug(f"TAYLOR DIAGRAM SAVING: {plot_name}") # add text with variable names: txtstrs = [f"{i+1} - {v}" for i, v in enumerate(var_list)] fig.text(0.9, 0.9, "\n".join(txtstrs), va='top') fig.savefig(plot_name, bbox_inches='tight') - adfobj.debug_log(f"\t Taylor Diagram: completed {s}. \n\t File: {plot_name}") + adfobj.debug_log(f"\t Taylor Diagram: completed {season}. 
\n\t File: {plot_name}") #Add plot to website (if enabled): - adfobj.add_website_data(plot_name, "TaylorDiag", None, season=s, multi_case=True) + adfobj.add_website_data(plot_name, "TaylorDiag", None, season=season, multi_case=True) + plt.close(fig) + logger.debug(f"TAYLOR DIAGRAM FINISHED WITH {season}") #Notify user that script has ended: - print(" ...Taylor Diagrams have been generated successfully.") - - return + logger.info("Taylor Diagrams have been generated successfully.") # # --- Local Functions --- @@ -199,7 +208,7 @@ def cam_taylor_diagram(adfobj): # --- DERIVED VARIABLES --- -def vertical_average(fld, ps, acoef, bcoef): +def vertical_average(fld, ps, acoef, bcoef, level_dim='lev'): """Calculate weighted vertical average using trapezoidal rule. Uses full column.""" pres = utils.pres_from_hybrid(ps, acoef, bcoef) # integral of del_pressure turns out to be just the average of the square of the boundaries: @@ -207,32 +216,150 @@ def vertical_average(fld, ps, acoef, bcoef): maxlev = pres['lev'].max().item() minlev = pres['lev'].min().item() dp_integrated = 0.5 * (pres.sel(lev=maxlev)**2 - pres.sel(lev=minlev)**2) - levaxis = fld.dims.index('lev') # fld needs to be a dataarray - assert isinstance(levaxis, int), f'the axis called lev is not an integer: {levaxis}' + levaxis = fld.dims.index(level_dim) # fld needs to be a dataarray + assert isinstance(levaxis, int), f'the axis called {level_dim} is not an integer: {levaxis}' fld_integrated = np.trapezoid(fld * pres, x=pres, axis=levaxis) return fld_integrated / dp_integrated def find_landmask(adf, casename): - try: - return _retrieve(adf, 'LANDFRAC', casename) - except Exception as e: - print(f"\t WARNING: Could not find LANDFRAC for {casename}: {e}") - return None + logger.debug(f"Finding landmask for {casename}") + return _retrieve(adf, 'LANDFRAC', casename) + + +def regrid_to_target(adf, casename, source_da, target_da, method='conservative'): + """ + Regrid source_da to match the grid of target_da using xesmf. + + Parameters: + - adf: ADF object for getting output locations + - casename: the casename of the source data - used to determine paths. + - source_da: xarray.DataArray to regrid + - target_da: xarray.DataArray with target grid + - method: regridding method ('conservative', 'bilinear', etc.) 
+ + Returns: + - Regridded DataArray + """ + if not XESMF_AVAILABLE: + logger.error("xesmf not available, cannot regrid") + return source_da + + if source_da.lat.shape == target_da.lat.shape and source_da.lon.shape == target_da.lon.shape: + logger.debug("Grids already match, no regridding needed") + return source_da + + logger.debug(f"Regridding from {source_da.lat.shape} x {source_da.lon.shape} to {target_da.lat.shape} x {target_da.lon.shape}") + + # Create clean grids for xesmf + source_grid = _create_clean_grid(source_da) + target_grid = _create_clean_grid(target_da) + + # Manage weights files -- MULTI-CASE NEEDS TO KNOW CASENAME + regrid_loc = adf.get_basic_info("cam_regrid_loc", required=True) + case_list = adf.get_cam_info("cam_case_name", required=True) + if casename == "Obs": + first_case = Path(case_list[0]) + regrid_weights_dir = first_case.parent / "obs_regrid_weights" + else: + case_index = case_list.index(casename) + regrid_loc = regrid_loc[case_index] + regrid_loc = Path(regrid_loc) + regrid_weights_dir = regrid_loc / "regrid_weights" + regrid_weights_dir.mkdir(exist_ok=True) + + # Generate grid descriptions + source_grid_type = "unstructured" if "ncol" in source_da.dims else "structured" + target_grid_type = "unstructured" if "ncol" in target_da.dims else "structured" + + source_grid_desc = f"{source_grid_type}_{len(source_da.lat)}_{len(source_da.lon)}" if source_grid_type == "structured" else f"{source_grid_type}_{len(source_da.ncol)}" + target_grid_desc = f"{target_grid_type}_{len(target_da.lat)}_{len(target_da.lon)}" if target_grid_type == "structured" else f"{target_grid_type}_{len(target_da.ncol)}" + + weights_file = regrid_weights_dir / f"weights_{source_grid_desc}_to_{target_grid_desc}_{method}.nc" + + if weights_file.exists(): + logger.debug(f"Using existing regridding weights file: {weights_file}") + regridder = xe.Regridder(source_grid, target_grid, method, weights=str(weights_file)) + else: + logger.debug(f"Creating new regridding weights file: {weights_file}") + regridder = xe.Regridder(source_grid, target_grid, method) + regridder.to_netcdf(weights_file) + + regridded = regridder(source_da) + + return regridded + + +def _create_clean_grid(da): + """ + Creates a minimal, CF-compliant xarray Dataset for xesmf from a DataArray. 
+ Adapted from regrid_and_vert_interp_2.py + """ + # Convert DataArray to Dataset if needed + if isinstance(da, xr.DataArray): + ds = da.to_dataset() + else: + ds = da + + # Extract raw values + lat_centers = ds.lat.values + lon_centers = ds.lon.values + + # Clip to avoid ESMF range errors + lat_centers = np.clip(lat_centers, -90, 90) + + # Build basic Dataset + clean_ds = xr.Dataset( + coords={ + "lat": (["lat"], lat_centers, {"units": "degrees_north", "standard_name": "latitude"}), + "lon": (["lon"], lon_centers, {"units": "degrees_east", "standard_name": "longitude"}), + } + ) + + # Add Bounds as vertices if they exist + # Check for various possible bounds names + lat_bnds_names = ['lat_bnds', 'lat_bounds', 'latitude_bnds', 'latitude_bounds'] + lon_bnds_names = ['lon_bnds', 'lon_bounds', 'longitude_bnds', 'longitude_bounds'] + + lat_bnds = None + lon_bnds = None + + for name in lat_bnds_names: + if name in ds: + lat_bnds = ds[name] + break + + for name in lon_bnds_names: + if name in ds: + lon_bnds = ds[name] + break + + if lat_bnds is not None and lon_bnds is not None: + lat_v = np.append(lat_bnds.values[:, 0], lat_bnds.values[-1, 1]) + lon_v = np.append(lon_bnds.values[:, 0], lon_bnds.values[-1, 1]) + + # Clip to avoid ESMF range errors + lat_v = np.clip(lat_v, -90, 90) + + # xesmf looks for 'lat_b' and 'lon_b' in the dataset for conservative regridding + clean_ds["lat_b"] = (["lat_f"], lat_v, {"units": "degrees_north"}) + clean_ds["lon_b"] = (["lon_f"], lon_v, {"units": "degrees_east"}) + + return clean_ds def get_prect(adf, casename, **kwargs): if casename == 'Obs': - return adf.data.load_reference_regrid_da('PRECT') + return adf.data.load_reference_regrid_da(adf.data.ref_labels["PRECT"], 'PRECT') else: # Try regridded PRECT first prect = adf.data.load_regrid_da(casename, 'PRECT') if prect is not None: return prect # Fallback: derive from PRECC + PRECL using regridded versions - print("\t Need to derive PRECT = PRECC + PRECL (using regridded data)") + logger.info("\t Need to derive PRECT = PRECC + PRECL (using regridded data)") precc = adf.data.load_regrid_da(casename, 'PRECC') precl = adf.data.load_regrid_da(casename, 'PRECL') if precc is None or precl is None: - print(f"\t WARNING: Could not derive PRECT for {casename} (missing PRECC or PRECL)") + logger.warning(f"\t WARNING: Could not derive PRECT for {casename} (missing PRECC or PRECL)") return None return precc + precl @@ -243,6 +370,9 @@ def get_tropical_land_precip(adf, casename, **kwargs): prect = get_prect(adf, casename) if prect is None: return None + + # Regrid prect to match landfrac grid if necessary + prect = regrid_to_target(adf, casename, prect, landfrac) # mask to only keep land locations prect = xr.DataArray(np.where(landfrac >= .95, prect, np.nan), dims=prect.dims, @@ -258,6 +388,10 @@ def get_tropical_ocean_precip(adf, casename, **kwargs): prect = get_prect(adf, casename) if prect is None: return None + + # Regrid prect to match landfrac grid if necessary + prect = regrid_to_target(adf, casename, prect, landfrac) + # mask to only keep ocean locations prect = xr.DataArray(np.where(landfrac <= 0.05, prect, np.nan), dims=prect.dims, @@ -267,39 +401,70 @@ def get_tropical_ocean_precip(adf, casename, **kwargs): def get_surface_pressure(adf, dset, casename): - if 'PS' in dset.variables: - #Just use surface pressure in climo file: + if isinstance(dset, xr.Dataset) and 'PS' in dset.data_vars: ps = dset['PS'] else: if casename == 'Obs': - ps = adf.data.load_reference_regrid_da('PS') + ps = 
adf.data.load_reference_regrid_da(adf.data.ref_labels['PS'], 'PS') else: ps = adf.data.load_regrid_da(casename, 'PS') if ps is None: - print(f"\t WARNING: Could not load PS for {casename}.") + logger.warning(f"\t WARNING: Could not load PS for {casename}.") return None return ps def get_var_at_plev(adf, casename, variable, plev): if casename == 'Obs': - dset = adf.data.load_reference_regrid_da(variable) - if dset is None or 'lev' not in dset.dims: - print(f"\t WARNING: Obs data for {variable} lacks lev dimension or is unavailable.") + dset = adf.data.load_reference_regrid_da(adf.data.ref_labels[variable], variable) + if dset is None: + logger.warning(f"\t WARNING: Obs data for {variable} is unavailable.") + return None + level_dim = get_level_dim(dset) + if level_dim is None: + logger.warning(f"\t WARNING: Obs data for {variable} lacks level dimension (lev/level/ilev).") return None - return dset.sel(lev=plev, method='nearest') if dset is not None else None + # For obs, assume already on pressure levels, just select + # Detect pressure units: if max(lev) > 2000, assume Pa, else hPa + lev_max = dset[level_dim].max().item() + if lev_max > 2000: + adjusted_plev = plev * 100 # Convert hPa to Pa + else: + adjusted_plev = plev + return dset.sel(**{level_dim: adjusted_plev}, method='nearest') else: dset = adf.data.load_regrid_da(casename, variable) if dset is None: return None + + # Check if data is already on pressure levels (no hybrid coords) + if 'hyam' not in dset and 'hybm' not in dset: + # Assume already on pressure levels + level_dim = get_level_dim(dset) + if level_dim is not None: + # Detect pressure units: if max(lev) > 2000, assume Pa, else hPa + lev_max = dset[level_dim].max().item() + if lev_max > 2000: + adjusted_plev = plev * 100 # Convert hPa to Pa + else: + adjusted_plev = plev + return dset.sel(**{level_dim: adjusted_plev}, method='nearest') + else: + logger.warning(f"\t WARNING: No level dimension in regridded {variable} for {casename}") + return None + + # Data is on hybrid levels, need to interpolate + level_dim = get_level_dim(dset) + if level_dim is None: + logger.warning(f"\t WARNING: No level dimension in regridded {variable} for {casename}") + return None ps = get_surface_pressure(adf, dset, casename) if ps is None: - print(f"\t WARNING: Could not load PS for {variable} interpolation in {casename}") + logger.warning(f"\t WARNING: Could not load PS for {variable} interpolation in {casename}") return None # Proceed with gc.interp_hybrid_to_pressure using regridded data - # (Assumes hyam/hybm are available in dset or can be loaded similarly) vplev = gc.interp_hybrid_to_pressure(dset, ps, dset['hyam'], dset['hybm'], - new_levels=np.array([100. * plev]), lev_dim='lev') + new_levels=np.array([100. 
* plev]), lev_dim=level_dim) return vplev.squeeze(drop=True).load() @@ -310,22 +475,42 @@ def get_u_at_plev(adf, casename): def get_vertical_average(adf, casename, varname): '''Collect data from case and use `vertical_average` to get result.''' if casename == 'Obs': - ds = adf.data.load_reference_regrid_da(varname) - if ds is None or 'lev' not in ds.dims: - print(f"\t WARNING: Obs data for {varname} lacks lev dimension.") + ds = adf.data.load_reference_regrid_da(adf.data.ref_labels[varname], varname) + if ds is None: + logger.warning(f"\t WARNING: Obs data for {varname} is unavailable.") + return None + level_dim = get_level_dim(ds) + if level_dim is None: + logger.warning(f"\t WARNING: Obs data for {varname} lacks level dimension.") return None - return ds.mean(dim='lev') + # For obs, assume already on pressure levels, just average + return ds.mean(dim=level_dim) else: ds = adf.data.load_regrid_da(casename, varname) if ds is None: return None - # Try and extract surface pressure: + + # Check if data is already on pressure levels + if 'hyam' not in ds and 'hybm' not in ds: + # Assume already on pressure levels + level_dim = get_level_dim(ds) + if level_dim is not None: + return ds.mean(dim=level_dim) + else: + logger.warning(f"\t WARNING: No level dimension in regridded {varname} for {casename}") + return None + + # Data is on hybrid levels, need vertical averaging + level_dim = get_level_dim(ds) + if level_dim is None: + logger.warning(f"\t WARNING: No level dimension in regridded {varname} for {casename}") + return None ps = get_surface_pressure(adf, ds, casename) if ps is None: - print(f"\t WARNING: Could not load PS for {varname} interpolation in {casename}") + logger.warning(f"\t WARNING: Could not load PS for {varname} interpolation in {casename}") return None # If the climo file is made by ADF, then hyam and hybm will be with VARIABLE: - return vertical_average(ds[varname], ps, ds['hyam'], ds['hybm']) + return vertical_average(ds[varname], ps, ds['hyam'], ds['hybm'], level_dim) def get_virh(adf, casename, **kwargs): @@ -339,14 +524,19 @@ def get_vit(adf, casename, **kwargs): def get_landt2m(adf, casename): if casename == 'Obs': - t = adf.data.load_reference_regrid_da('TREFHT') + t = adf.data.load_reference_regrid_da(adf.data.ref_labels["TREFHT"], 'TREFHT') else: t = adf.data.load_regrid_da(casename, 'TREFHT') if t is None: return None + landfrac = find_landmask(adf, casename) if landfrac is None: return None + + # Regrid t to match landfrac grid if necessary + t = regrid_to_target(adf, casename, t, landfrac) + t = xr.DataArray(np.where(landfrac >= .95, t, np.nan), dims=t.dims, coords=t.coords, attrs=t.attrs) return t @@ -356,16 +546,16 @@ def get_landt2m(adf, casename): def get_eqpactaux(adf, casename): """Gets zonal surface wind stress 5S to 5N.""" if casename == 'Obs': - taux = adf.data.load_reference_regrid_da('TAUX') + taux = adf.data.load_reference_regrid_da(adf.data.ref_labels["TAUX"], 'TAUX') else: taux = adf.data.load_regrid_da(casename, 'TAUX') if taux is None: - print(f"\t WARNING: Could not load TAUX for {casename}") + logger.warning(f"\t WARNING: Could not load TAUX for {casename}") return None return taux.sel(lat=slice(-5, 5)) -def get_derive_func(fld): +def get_derive_func(fld: str): funcs = {'TropicalLandPrecip': get_tropical_land_precip, 'TropicalOceanPrecip': get_tropical_ocean_precip, 'U300': get_u_at_plev, @@ -375,7 +565,7 @@ def get_derive_func(fld): 'EquatorialPacificStress': get_eqpactaux } if fld not in funcs: - print(f"We do not have a method for variable: 
{fld}.") + logger.warning(f"We do not have a method for variable: {fld}.") return None return funcs[fld] @@ -386,40 +576,30 @@ def _retrieve(adfobj, variable, casename, return_dataset=False): """ v_to_derive = ['TropicalLandPrecip', 'TropicalOceanPrecip', 'EquatorialPacificStress', 'U300', 'ColumnRelativeHumidity', 'ColumnTemperature', 'Land2mTemperature'] - - try: + if variable in v_to_derive: + func = get_derive_func(variable) + if func is None: + logger.error(f"No derivation function available for {variable}") + return None + da = func(adfobj, casename) + if da is None: + logger.warning(f"Derivation function for {variable} returned None for {casename}") + return None + else: if casename == 'Obs': - if variable not in v_to_derive: - da = adfobj.data.load_reference_regrid_da(variable) - else: - func = get_derive_func(variable) - if func is None: - print(f"\t WARNING: No derivation function for {variable}.") - return None - da = func(adfobj, 'Obs') # No location needed - else: # Model cases - if variable not in v_to_derive: - da = adfobj.data.load_regrid_da(casename, variable) - else: - func = get_derive_func(variable) - if func is None: - print(f"\t WARNING: No derivation function for {variable}.") - return None - da = func(adfobj, casename) # No location needed - + logger.debug(f"Loading reference data for {variable}") + da = adfobj.data.load_reference_regrid_da(adfobj.data.ref_labels[variable], variable) + else: + logger.debug(f"Loading regrid data for {variable} in {casename}") + da = adfobj.data.load_regrid_da(casename, variable) if da is None: - print(f"\t WARNING: Could not load {variable} for {casename}.") + logger.warning(f"Failed to load {variable} for {casename}") return None - - if return_dataset: - if not isinstance(da, xr.Dataset): - da = da.to_dataset(name=variable) - return da - - except Exception as e: - print(f"\t WARNING: Error retrieving {variable} for {casename}: {e}") - return None - + + if return_dataset and not isinstance(da, xr.Dataset): + da = da.to_dataset(name=variable) + return da + def weighted_correlation(x, y, weights): # TODO: since we expect masked fields (land/ocean), need to allow for missing values (maybe works already?) 
@@ -464,7 +644,7 @@ def taylor_stats_single(casedata, refdata, w=True): returns: pattern_correlation, ratio of standard deviation (case/ref), bias """ - lat = casedata['lat'] + lat = casedata.lat if w: wgt = np.cos(np.radians(lat)) else: @@ -478,7 +658,7 @@ def taylor_stats_single(casedata, refdata, w=True): return correlation, a_sigma/b_sigma, bias -def taylor_plot_setup(title,baseline): +def taylor_plot_setup(title, baseline): """Constructs Figure and Axes objects for basic Taylor Diagram.""" fig, ax = plt.subplots(figsize=(8,8), subplot_kw={'projection':'polar'}) corr_labels = np.array([0.0, .1, .2, .3, .4, .5, .6, .7, .8, .9, .95, .99, 1.]) @@ -561,14 +741,15 @@ def taylor_plot_finalize(wks, test_nicknames, casecolors, syear_cases, eyear_cas bottom_of_text = 0.05 height_of_lines = 0.03 - wks.text(0.052, 0.08, "Cases:", - color='k', ha='left', va='bottom', transform=wks.transAxes, fontsize=11) n = 0 for case_idx, (s, c) in enumerate(zip(test_nicknames, casecolors)): wks.text(0.052, bottom_of_text + n*height_of_lines, f"{s} yrs: {syear_cases[case_idx]}-{eyear_cases[case_idx]}", color=c, ha='left', va='bottom', transform=wks.transAxes, fontsize=10) n += 1 + wks.text(0.052, bottom_of_text + n*height_of_lines, "Cases:", + color='k', ha='left', va='bottom', transform=wks.transAxes, fontsize=11) + # BIAS LEGEND if needs_bias_labels: # produce an info-box showing the markers/sizes based on bias diff --git a/scripts/regridding/regrid_and_vert_interp_2.py b/scripts/regridding/regrid_and_vert_interp_2.py new file mode 100644 index 000000000..240f7f27a --- /dev/null +++ b/scripts/regridding/regrid_and_vert_interp_2.py @@ -0,0 +1,384 @@ +"""Driver for horizontal and vertical interpolation. +""" +from pathlib import Path + +import numpy as np +import xarray as xr +import xesmf as xe + +import adf_utils as utils + + +# Default pressure levels for vertical interpolation +DEFAULT_PLEVS = [ + 1000, 925, 850, 700, 500, 400, 300, 250, 200, 150, 100, 70, 50, + 30, 20, 10, 7, 5, 3, 2, 1 +] + +def regrid_and_vert_interp_2(adf): + """ + Regrids the test cases to the same horizontal + grid as the reference climatology and vertically + interpolates the test case (and reference if needed) + to match a default set of pressure levels (in hPa). + """ + msg = "\n Regridding CAM climatologies..." 
+ print(f"{msg}\n {'-' * (len(msg)-3)}") + + overwrite_regrid = adf.get_basic_info("cam_overwrite_regrid", required=True) + output_loc = adf.get_basic_info("cam_regrid_loc", required=True) + output_loc = [Path(i) for i in output_loc] + var_list = adf.diag_var_list + var_defaults = adf.variable_defaults + + case_names = adf.get_cam_info("cam_case_name", required=True) + syear_cases = adf.climo_yrs["syears"] + eyear_cases = adf.climo_yrs["eyears"] + + # Move critical variables to the front of the list + for var in ["PMID", "OCNFRAC", "LANDFRAC", "PS"]: + if var in var_list: + var_list.insert(0, var_list.pop(var_list.index(var))) + + for case_idx, case_name in enumerate(case_names): + print(f"\t Regridding case '{case_name}':") + syear = syear_cases[case_idx] + eyear = eyear_cases[case_idx] + case_output_loc = output_loc[case_idx] + case_output_loc.mkdir(parents=True, exist_ok=True) + + for var in var_list: + print(f"Regridding variable: {var}") + # reset variables + model_ds = None + ref_ds = None + target_name = None + regridded_file_loc = None + model_da = None + ref_da = None + regridder = None + interp_da = None + + model_ds = adf.data.load_climo_ds(case_name, var) + if var in adf.data.ref_var_nam: + ref_ds = adf.data.load_reference_climo_ds(adf.data.ref_case_label, var) + target_name = adf.data.ref_labels[var] + else: + print(f"No reference data available for {var}.") + continue + if not ref_ds: + print(f"Missing reference data for {var}. Skipping.") + continue + if not model_ds: + print(f"Missing model data for {var}. Skipping.") + continue + + regridded_file_loc = case_output_loc / f'{target_name}_{case_name}_{var}_regridded.nc' + + if regridded_file_loc.is_file() and not overwrite_regrid: + print(f"\t INFO: Regridded file already exists, skipping: {regridded_file_loc}") + continue + + if regridded_file_loc.is_file() and overwrite_regrid: + regridded_file_loc.unlink() + + model_da = model_ds[var].squeeze() + ref_da = ref_ds[adf.data.ref_var_nam[var]].squeeze() + + # --- Horizontal Regridding --- + regridded_da = _handle_horizontal_regridding(model_da, ref_ds, adf, case_index=case_idx) + + # --- Vertical Interpolation --- + vert_type = _determine_vertical_coord_type(model_da) + + ps_da = None + if vert_type == 'hybrid': + # For hybrid, we need surface pressure on the target grid. + # It's assumed PS is processed first and is available. 
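+                # (PS was moved to the front of var_list above, so its regridded
+                # file normally exists by the time a hybrid 3-D variable is
+                # processed; otherwise the fallback below regrids it on the fly.)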
+                ps_regridded_path = case_output_loc / f'{target_name}_{case_name}_PS_regridded.nc'
+                if ps_regridded_path.exists():
+                    ps_da = xr.open_dataset(ps_regridded_path)['PS']
+                else:
+                    # Regrid PS on the fly if not found
+                    print("\t INFO: Regridding PS on the fly for hybrid interpolation.")
+                    ps_da_source = adf.data.load_climo_da(case_name, 'PS').squeeze()
+                    ps_da = _handle_horizontal_regridding(ps_da_source, ref_da, adf, case_index=case_idx)
+
+            interp_da = _handle_vertical_interpolation(regridded_da, vert_type, model_ds, ps_da=ps_da)
+
+            # --- Masking ---
+            var_default_dict = var_defaults.get(var, {})
+            if 'mask' in var_default_dict and var_default_dict['mask'].lower() == 'ocean':
+                ocn_frac_regridded_path = case_output_loc / f'{target_name}_{case_name}_OCNFRAC_regridded.nc'
+                if ocn_frac_regridded_path.exists():
+                    ocn_frac_da = xr.open_dataset(ocn_frac_regridded_path)['OCNFRAC']
+                    interp_da = _apply_ocean_mask(interp_da, ocn_frac_da)
+                else:
+                    print(f"\t WARNING: OCNFRAC not found, unable to apply mask to '{var}'")
+
+            # --- Save to file ---
+            final_ds = interp_da.to_dataset(name=var)
+
+            # (PS and OCNFRAC need no special handling here: when one of them is
+            # the variable being processed, it is already the field in final_ds.)
+
+            test_attrs_dict = {
+                "adf_user": adf.user,
+                "climo_yrs": f"{case_name}: {syear}-{eyear}",
+                "climatology_files": str(adf.data.get_climo_file(case_name, var)),
+            }
+            final_ds = final_ds.assign_attrs(test_attrs_dict)
+
+            print(f"\t INFO: Saving regridded file: {regridded_file_loc}")
+            save_to_nc(final_ds, regridded_file_loc)
+
+    print(" ...CAM climatologies have been regridded successfully.")
+
+def _handle_horizontal_regridding(source_da, target_grid, adf, method='conservative', case_index=None):
+    """
+    Performs horizontal regridding using xesmf.
+    Manages and reuses regridding weight files.
+
+    Parameters
+    ----------
+    source_da : xarray.DataArray
+        The DataArray to regrid.
+    target_grid : xarray.Dataset
+        A dataset defining the target grid.
+    adf : adf_diag.AdfDiag
+        The ADF diagnostics object, used to get output locations.
+    method : str, optional
+        Regridding method. Defaults to 'conservative'.
+    case_index : int, optional
+        Index of the case in the case list; needed for multi-case runs.
+
+    Returns
+    -------
+    xarray.DataArray
+        The regridded DataArray.
+    """
+
+    # Generate a unique name for the weights file
+    source_grid_type = "unstructured" if "ncol" in source_da.dims else "structured"
+    target_grid_type = "unstructured" if "ncol" in target_grid.dims else "structured"
+
+    # A simple naming convention for weight files. 
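+    # (Note: this descriptor encodes only grid type and size, so two distinct
+    # grids with identical dimensions would silently share a weights file.)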
+    source_grid_desc = f"{source_grid_type}_{len(source_da.lat)}_{len(source_da.lon)}" if source_grid_type == "structured" else f"{source_grid_type}_{len(source_da.ncol)}"
+    target_grid_desc = f"{target_grid_type}_{len(target_grid.lat)}_{len(target_grid.lon)}" if target_grid_type == "structured" else f"{target_grid_type}_{len(target_grid.ncol)}"
+
+    if target_grid_type == "structured":
+        target_grid = _create_clean_grid(target_grid)
+    if source_grid_type == "structured":
+        source_grid = _create_clean_grid(source_da)
+    else:
+        source_grid = source_da  # unstructured input goes to xesmf as-is
+
+    regrid_loc = adf.get_basic_info("cam_regrid_loc", required=True)
+    if isinstance(regrid_loc, list) and len(regrid_loc)>1:
+        regrid_loc = regrid_loc[case_index]
+    else:
+        regrid_loc = regrid_loc[0]
+    regrid_loc = Path(regrid_loc)
+    regrid_weights_dir = regrid_loc / "regrid_weights"
+    regrid_weights_dir.mkdir(exist_ok=True)
+    weights_file = regrid_weights_dir / f"weights_{source_grid_desc}_to_{target_grid_desc}_{method}.nc"
+    if weights_file.exists():
+        print(f"INFO: Using existing regridding weights file: {weights_file}")
+        # xesmf can accept a path to a weights file
+        regridder = xe.Regridder(source_grid, target_grid, method, weights=str(weights_file))
+    else:
+        print(f"INFO: Creating new regridding weights file: {weights_file}")
+        regridder = xe.Regridder(source_grid, target_grid, method)
+        regridder.to_netcdf(weights_file)
+    return regridder(source_da)
+
+def _create_clean_grid(ds):
+    """
+    Creates a minimal, CF-compliant xarray Dataset for xesmf.
+    """
+
+    # Extract raw values
+    lat_centers = ds.lat.values
+    lon_centers = ds.lon.values
+
+    # Build basic Dataset
+    clean_ds = xr.Dataset(
+        coords={
+            "lat": (["lat"], lat_centers, {"units": "degrees_north", "standard_name": "latitude"}),
+            "lon": (["lon"], lon_centers, {"units": "degrees_east", "standard_name": "longitude"}),
+        }
+    )
+
+    # Add Bounds as vertices if they exist
+    if 'lat_bnds' in ds and 'lon_bnds' in ds:
+        lat_v = np.append(ds.lat_bnds.values[:, 0], ds.lat_bnds.values[-1, 1])
+        lon_v = np.append(ds.lon_bnds.values[:, 0], ds.lon_bnds.values[-1, 1])
+
+        # Clip to avoid ESMF range errors
+        lat_v = np.clip(lat_v, -90, 90)
+
+        # xesmf looks for 'lat_b' and 'lon_b' in the dataset for conservative regridding
+        clean_ds["lat_b"] = (["lat_f"], lat_v, {"units": "degrees_north"})
+        clean_ds["lon_b"] = (["lon_f"], lon_v, {"units": "degrees_east"})
+
+    return clean_ds
+
+def _determine_vertical_coord_type(dset):
+    """
+    Determines the type of vertical coordinate in a dataset.
+
+    Parameters
+    ----------
+    dset : xarray.Dataset
+        The dataset to inspect.
+
+    Returns
+    -------
+    str
+        The vertical coordinate type: 'hybrid', 'height', 'pressure', or 'none'.
+    """
+
+    if 'lev' in dset.dims or 'ilev' in dset.dims:
+        lev_coord_name = 'lev' if 'lev' in dset.dims else 'ilev'
+        lev_attrs = dset[lev_coord_name].attrs
+
+        if 'vert_coord' in lev_attrs:
+            return lev_attrs['vert_coord']
+
+        if 'long_name' in lev_attrs:
+            lev_long_name = lev_attrs['long_name']
+            if 'hybrid level' in lev_long_name:
+                return "hybrid"
+            if 'pressure level' in lev_long_name:
+                return "pressure"
+            if 'zeta level' in lev_long_name:
+                return "height"
+
+        # If no specific metadata is found, make an educated guess.
+        # This part might need refinement based on expected data conventions.
+        if 'hyam' in dset or 'hyai' in dset:
+            return "hybrid"
+
+        print(f"WARNING: Vertical coordinate type for '{lev_coord_name}' could not be determined. 
Assuming 'pressure'.") + return "pressure" + + return 'none' + +def _handle_vertical_interpolation(da, vert_type, source_ds, ps_da=None): + """ + Performs vertical interpolation to default pressure levels. + + Parameters + ---------- + da : xarray.DataArray + The DataArray to interpolate. + vert_type : str + The vertical coordinate type ('hybrid', 'height', 'pressure'). + source_ds : xarray.Dataset + The source dataset containing auxiliary variables (e.g., hyam, hybm). + ps_da : xarray.DataArray, optional + Surface pressure DataArray, required for hybrid coordinates. + + Returns + ------- + xarray.DataArray + The vertically interpolated DataArray. + """ + if vert_type == 'none': + return da + + if vert_type == "hybrid": + if ps_da is None: + raise ValueError("Surface pressure ('PS') is required for hybrid vertical interpolation.") + + lev_coord_name = 'lev' if 'lev' in source_ds.dims else 'ilev' + hyam_name = 'hyam' if lev_coord_name == 'lev' else 'hyai' + hybm_name = 'hybm' if lev_coord_name == 'lev' else 'hybi' + + if hyam_name not in source_ds or hybm_name not in source_ds: + raise ValueError(f"Hybrid coefficients ('{hyam_name}', '{hybm_name}') not found in dataset.") + + hyam = source_ds[hyam_name] + hybm = source_ds[hybm_name] + + if 'time' in hyam.dims: + hyam = hyam.isel(time=0).squeeze() + if 'time' in hybm.dims: + hybm = hybm.isel(time=0).squeeze() + + p0 = source_ds.get('P0', 100000.0) + if isinstance(p0, xr.DataArray): + p0 = p0.values[0] + + # hot fix for lev attributes + da[lev_coord_name].attrs["axis"] = "Z" + da[lev_coord_name].attrs["positive"] = "down" # standard for pressure/hybrid + da[lev_coord_name].attrs["standard_name"] = "atmosphere_hybrid_sigma_pressure_coordinate" + + return utils.lev_to_plev(da, ps_da, hyam, hybm, P0=p0, convert_to_mb=True, new_levels=DEFAULT_PLEVS) + + elif vert_type == "height": + pmid = source_ds.get('PMID') + if pmid is None: + raise ValueError("'PMID' is required for height vertical interpolation.") + return utils.pmid_to_plev(da, pmid, convert_to_mb=True, new_levels=DEFAULT_PLEVS) + + elif vert_type == "pressure": + return utils.plev_to_plev(da, new_levels=DEFAULT_PLEVS, convert_to_mb=True) + + else: + raise ValueError(f"Unknown vertical coordinate type: '{vert_type}'") + +def _apply_ocean_mask(da, ocn_frac_da): + """ + Applies an ocean mask to a DataArray. + + Parameters + ---------- + da : xarray.DataArray + The DataArray to mask. + ocn_frac_da : xarray.DataArray + The ocean fraction DataArray. + + Returns + ------- + xarray.DataArray + The masked DataArray. + """ + # Ensure ocean fraction is between 0 and 1 + ocn_frac_da = ocn_frac_da.clip(0, 1) + + # Apply the mask + return utils.mask_land_or_ocean(da, ocn_frac_da) + +def save_to_nc(tosave, outname, attrs=None, proc=None): + """Saves xarray variable to new netCDF file + + Parameters + ---------- + tosave : xarray.Dataset or xarray.DataArray + data to write to file + outname : str or Path + output netCDF file path + attrs : dict, optional + attributes dictionary for data + proc : str, optional + string to append to "Processing_info" attribute + """ + + xo = tosave + # deal with getting non-nan fill values. 
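+    # (xarray would otherwise write _FillValue = NaN for every variable and
+    # coordinate; setting it to None in the encoding suppresses that.)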
+    if isinstance(xo, xr.Dataset):
+        enc_dv = {xname: {'_FillValue': None} for xname in xo.data_vars}
+    else:
+        enc_dv = {}
+    #End if
+    enc_c = {xname: {'_FillValue': None} for xname in xo.coords}
+    enc = {**enc_c, **enc_dv}
+    if attrs is not None:
+        xo.attrs = attrs
+    if proc is not None:
+        origname = tosave.attrs.get('climatology_files', 'unknown')
+        xo.attrs['Processing_info'] = f"Start from file {origname}. " + proc
+    xo.to_netcdf(outname, format='NETCDF4', encoding=enc)

From 60c916b72b13cf9d037f33cd21234c6457bc7212 Mon Sep 17 00:00:00 2001
From: Brian Medeiros
Date: Thu, 15 Jan 2026 14:00:37 -0700
Subject: [PATCH 3/4] refining taylor diagram and associated changes

---
 lib/adf_dataset.py                           |  32 +-
 lib/adf_utils.py                             |  19 +-
 scripts/plotting/cam_taylor_diagram.py       | 243 +++--
 scripts/regridding/regrid_and_vert_interp.py | 983 ++++++------------
 .../regridding/regrid_and_vert_interp.py.OLD | 753 ++++++++++++++
 .../regridding/regrid_and_vert_interp_2.py   | 384 -------
 6 files changed, 1266 insertions(+), 1148 deletions(-)
 create mode 100644 scripts/regridding/regrid_and_vert_interp.py.OLD
 delete mode 100644 scripts/regridding/regrid_and_vert_interp_2.py

diff --git a/lib/adf_dataset.py b/lib/adf_dataset.py
index 5f2259971..a753d75a4 100644
--- a/lib/adf_dataset.py
+++ b/lib/adf_dataset.py
@@ -240,10 +240,20 @@ def load_reference_climo_ds(self, case, variablename, apply_scaling=True):
         fils = self.get_reference_climo_file(variablename)
         ds = self.load_dataset(fils)
         vname = self.ref_var_nam[variablename] # name of variable in the reference data
+        # Check if already transformed (via attribute or units)
+        unit_match = ds[vname].attrs.get('units') == self.adf.variable_defaults.get(variablename, {}).get('new_unit')
+        if ds[vname].attrs.get('transformed', False) or unit_match:
+            apply_scaling = False
         if not apply_scaling:
             add_offset = 0
             scale_factor = 1
+
+        attrs = ds[vname].attrs.copy()
         ds[vname] = ds[vname] * scale_factor + add_offset
+        ds[vname].attrs = attrs
+        if scale_factor != 1 or add_offset != 0:
+            ds[vname].attrs['transformed'] = True
         return ds
 
     def load_reference_climo_da(self, case, variablename, apply_scaling=True):
@@ -391,7 +401,7 @@ def load_da(self, fils, variablename, **kwargs):
         if ds is None:
             warnings.warn(f"\t WARNING: Load failed for {variablename}")
             return None
-        da = (ds[variablename]).squeeze()
+        da = ds[variablename].squeeze()
         apply_scaling = kwargs.get('apply_scaling', True)
         if not apply_scaling:
             add_offset = 0
@@ -399,12 +409,16 @@ def load_da(self, fils, variablename, **kwargs):
         else:
             scale_factor = kwargs.get('scale_factor', 1)
             add_offset = kwargs.get('add_offset', 0)
+        attrs = da.attrs.copy()
         da = da * scale_factor + add_offset
-        if variablename in self.adf.variable_defaults:
-            vres = self.adf.variable_defaults[variablename]
-            da.attrs['units'] = vres.get("new_unit", da.attrs.get('units', 'none'))
-        else:
-            da.attrs['units'] = 'none'
+        da.attrs = attrs
+
+        if scale_factor != 1 or add_offset != 0:
+            if variablename in self.adf.variable_defaults:
+                new_unit = self.adf.variable_defaults[variablename].get("new_unit")
+                if new_unit:
+                    da.attrs['units'] = new_unit
+            da.attrs['transformed'] = True
         return da
 
 # Get variable conversion defaults, if applicable
@@ -433,8 +447,4 @@ def get_value_converters(self, case, variablename):
             add_offset = vres.get("add_offset", 0)
         return add_offset, scale_factor
 
-    #------------------
-
-
-
-
\ No newline at end of file
+    #------------------
\ No newline at end of file
diff --git a/lib/adf_utils.py 
b/lib/adf_utils.py index 72c880911..b67ca1e26 100644 --- a/lib/adf_utils.py +++ b/lib/adf_utils.py @@ -126,7 +126,7 @@ def mask_land_or_ocean(arr, msk, use_nan=False): missing_value = -999. #End if - arr = xr.where(msk>=0.9,arr,missing_value) + arr = xr.where(msk>=0.9,arr,missing_value, keep_attrs=True) arr.attrs["missing_value"] = missing_value return(arr) @@ -535,7 +535,7 @@ def lev_to_plev(data, ps, hyam, hybm, P0=100000., new_levels=None, #Temporary print statement to notify users to ignore warning messages. #This should be replaced by a debug-log stdout filter at some point: print("Please ignore the interpolation warnings that follow!") - + #Apply GeoCAT hybrid->pressure interpolation: if new_levels is not None: data_interp = gcomp.interpolation.interp_hybrid_to_pressure(data, ps, @@ -559,10 +559,22 @@ def lev_to_plev(data, ps, hyam, hybm, P0=100000., new_levels=None, #Rename vertical dimension back to "lev" in order to work with #the ADF plotting functions: data_interp_rename = data_interp.rename({"plev": "lev"}) + attrs = data_interp_rename.attrs.copy() + lev_orig = data_interp_rename["lev"] + lev_orig_attrs = lev_orig.attrs.copy() #Convert vertical dimension to mb/hPa, if requested: if convert_to_mb: - data_interp_rename["lev"] = data_interp_rename["lev"] / 100.0 + lev_new = lev_orig / 100.0 + lev_new.attrs = lev_orig_attrs + lev_new.name = "lev" + lev_new.attrs["units"] = "hPa" + lev_new.attrs["history"] = f"converted to hPa by dividing by 100 in adf_utils.lev_to_plev" + data_interp_rename["lev"] = lev_new + data_interp_rename.attrs = attrs + else: + data_interp_rename.attrs['units'] = "Pa" + data_interp_rename.attrs['history'] = f"Interpolated using GeoCAT, assume units of Pa in adf_utils.lev_to_plev" return data_interp_rename @@ -678,6 +690,7 @@ def plev_to_plev(data, new_levels=None, convert_to_mb=False): # convert vertical dimension to mb/hPa, if requested: if convert_to_mb: + ## DEAL WITH METADATA BETTER HERE output["lev"] = output["lev"] / 100.0 #End if diff --git a/scripts/plotting/cam_taylor_diagram.py b/scripts/plotting/cam_taylor_diagram.py index 88d5f5d5a..6228074db 100644 --- a/scripts/plotting/cam_taylor_diagram.py +++ b/scripts/plotting/cam_taylor_diagram.py @@ -12,10 +12,13 @@ # # --- imports and configuration --- # +import os import sys import logging from pathlib import Path +import warnings import numpy as np +import numpy.typing as npt import xarray as xr import pandas as pd import geocat.comp as gc # use geocat's interpolation @@ -36,12 +39,30 @@ logger = logging.getLogger(__name__) console_handler = logging.StreamHandler(sys.stdout) -console_handler.setLevel(logging.INFO) +console_handler.setLevel(logging.DEBUG) logger.addHandler(console_handler) logger.setLevel(logging.DEBUG) logger.propagate = False +from contextlib import redirect_stdout, contextmanager + +@contextmanager +def silence_output(): + sys.stdout.flush() # Flush Python's buffer + sys.stderr.flush() + with open(os.devnull, 'w') as devnull: + with redirect_stdout(devnull): + # Also catch the C-level stderr/stdout if possible + # This is the most robust way to silence C extensions + old_stdout_fd = os.dup(sys.stdout.fileno()) + os.dup2(devnull.fileno(), sys.stdout.fileno()) + try: + yield + finally: + os.dup2(old_stdout_fd, sys.stdout.fileno()) + os.close(old_stdout_fd) + def get_level_dim(dset): """Get the name of the level dimension in the dataset.""" level_dims = ['lev', 'level', 'ilev'] @@ -114,21 +135,21 @@ def cam_taylor_diagram(adfobj): if missing_vars or not (has_prect or has_precl_precc): 
logger.warning("\tTaylor Diagrams skipped due to missing variables:") if missing_vars: - logger.warning(f"\t - Missing: {', '.join(sorted(missing_vars))}") + logger.warning(f"\t Missing: {', '.join(sorted(missing_vars))}") if not (has_prect or has_precl_precc): if not has_prect: - logger.warning("\t - Missing: PRECT (Alternative PRECL + PRECC also incomplete)") + logger.warning("\t Missing: PRECT (Alternative PRECL + PRECC also incomplete)") logger.info("\n\tFull requirement: U, PSL, SWCF, LWCF, LANDFRAC, TREFHT, TAUX, RELHUM, T,") logger.info("\tAND (PRECT OR both PRECL & PRECC)") return #Set seasonal ranges: - seasons = {"ANN": np.arange(1,13,1), - "DJF": [12, 1, 2], - "JJA": [6, 7, 8], - "MAM": [3, 4, 5], - "SON": [9, 10, 11]} + seasons = {"ANN": np.arange(1,13,1)} + # "DJF": [12, 1, 2], + # "JJA": [6, 7, 8], + # "MAM": [3, 4, 5], + # "SON": [9, 10, 11]} # TAYLOR PLOT VARIABLES: var_list = ['PSL', 'SWCF', 'LWCF', @@ -166,7 +187,8 @@ def cam_taylor_diagram(adfobj): if ref_data is None: logger.warning(f"\t WARNING: No regridded reference data for {v} in {data_name}, skipping.") continue - ref_data = ref_data.sel(time=months).mean(dim='time').compute() + with silence_output(): + ref_data = ref_data.sel(time=months).mean(dim='time').compute() for casenumber, case in enumerate(case_names): # Load test case data regridded to match reference grid @@ -208,21 +230,8 @@ def cam_taylor_diagram(adfobj): # --- DERIVED VARIABLES --- -def vertical_average(fld, ps, acoef, bcoef, level_dim='lev'): - """Calculate weighted vertical average using trapezoidal rule. Uses full column.""" - pres = utils.pres_from_hybrid(ps, acoef, bcoef) - # integral of del_pressure turns out to be just the average of the square of the boundaries: - # -- assume lev is a coordinate and is nominally in pressure units - maxlev = pres['lev'].max().item() - minlev = pres['lev'].min().item() - dp_integrated = 0.5 * (pres.sel(lev=maxlev)**2 - pres.sel(lev=minlev)**2) - levaxis = fld.dims.index(level_dim) # fld needs to be a dataarray - assert isinstance(levaxis, int), f'the axis called {level_dim} is not an integer: {levaxis}' - fld_integrated = np.trapezoid(fld * pres, x=pres, axis=levaxis) - return fld_integrated / dp_integrated def find_landmask(adf, casename): - logger.debug(f"Finding landmask for {casename}") return _retrieve(adf, 'LANDFRAC', casename) @@ -249,10 +258,12 @@ def regrid_to_target(adf, casename, source_da, target_da, method='conservative') return source_da logger.debug(f"Regridding from {source_da.lat.shape} x {source_da.lon.shape} to {target_da.lat.shape} x {target_da.lon.shape}") - + # Create clean grids for xesmf - source_grid = _create_clean_grid(source_da) - target_grid = _create_clean_grid(target_da) + # source_grid = _create_clean_grid(source_da) + # target_grid = _create_clean_grid(target_da) + source_grid = _create_clean_grid(source_da.reset_coords(drop=True)) + target_grid = _create_clean_grid(target_da.reset_coords(drop=True)) # Manage weights files -- MULTI-CASE NEEDS TO KNOW CASENAME regrid_loc = adf.get_basic_info("cam_regrid_loc", required=True) @@ -275,25 +286,34 @@ def regrid_to_target(adf, casename, source_da, target_da, method='conservative') target_grid_desc = f"{target_grid_type}_{len(target_da.lat)}_{len(target_da.lon)}" if target_grid_type == "structured" else f"{target_grid_type}_{len(target_da.ncol)}" weights_file = regrid_weights_dir / f"weights_{source_grid_desc}_to_{target_grid_desc}_{method}.nc" - + logger.debug(f">> Weights file: {weights_file}") if weights_file.exists(): 
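+        # (Reusing a stored weights file skips the expensive ESMF weight
+        # computation; only the sparse-matrix apply remains.)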
     logger.debug(f"Using existing regridding weights file: {weights_file}")
-        regridder = xe.Regridder(source_grid, target_grid, method, weights=str(weights_file))
+        with silence_output():
+            logger.debug(">> Set up regridder from existing regridding weights file.")
+            regridder = xe.Regridder(source_grid, target_grid, method, weights=str(weights_file))
+            logger.debug("<< regridder ready")

-def weighted_std(x, weights):
-    """Calculate the weighted standard deviation.
-    x -> xr.DataArray
-    weights -> array-like of weights, probably xr.DataArray
-    If weights is not the same shape as x, will use `broadcast_like` to
-    create weights array.
-    Returns the weighted standard deviation of the full x array.
-    """
+def weighted_std(x: xr.DataArray, weights: npt.ArrayLike):
+    """Calculate weighted standard deviation."""
     xshape = x.shape
     wshape = weights.shape
     if xshape != wshape:
@@ -637,12 +697,15 @@ def weighted_std(x, weights):
 
 def taylor_stats_single(casedata, refdata, w=True):
     """This replicates the basic functionality of 'taylor_stats' from NCL.
-    input:
+    PARAMETERS
+    ----------
     casedata : input data, DataArray
     refdata : reference case data, DataArray
     w : if true use cos(latitude) as spatial weight, if false assume uniform weight
-    returns:
-    pattern_correlation, ratio of standard deviation (case/ref), bias
+    RETURNS
+    -------
+    tuple:
+        pattern correlation, ratio of standard deviation (case/ref), bias
     """
     lat = casedata.lat
     if w:
@@ -715,7 +778,7 @@ def plot_taylor_data(wks, df, **kwargs):
         # NOTE: ndx will be the DataFrame index, and we expect that to be the variable name
         if np.isnan(row['corr']) or np.isnan(row['ratio']):
             continue # Skip plotting if data is missing
-        theta = np.pi/2 - np.arccos(row['corr']) # Transform DATA
+        theta = np.pi/2 - np.arccos(np.clip(row['corr'], -1.0, 1.0)) # Transform DATA
         if use_bias:
             mk = marker_list[row['bias_digi']]
             mksz = marker_size[row['bias_digi']]
diff --git a/scripts/regridding/regrid_and_vert_interp.py b/scripts/regridding/regrid_and_vert_interp.py
index 04e3cd279..5b189b3b6 100644
--- a/scripts/regridding/regrid_and_vert_interp.py
+++ b/scripts/regridding/regrid_and_vert_interp.py
@@ -1,661 +1,388 @@
 """Driver for horizontal and vertical interpolation.
 """
-import xarray as xr
-import adf_utils as utils
+from pathlib import Path
 
-def regrid_and_vert_interp(adf):
+import numpy as np
+import xarray as xr
+import xesmf as xe
 
-    """
-    Regrids the test cases to the same horizontal
-    grid as the reference climatology and vertically
-    interpolates the test case (and reference if needed)
-    to match a default set of pressure levels (in hPa).
+import adf_utils as utils
 
-    Parameters
-    ----------
-    adf
-        The ADF object
-
-    Notes
-    -----
-    Default pressure levels:
+# Default pressure levels for vertical interpolation
+DEFAULT_PLEVS = [
     1000, 925, 850, 700, 500, 400, 300, 250, 200, 150, 100, 70, 50,
     30, 20, 10, 7, 5, 3, 2, 1
+]
+DEFAULT_PLEVS_Pa = [p*100.0 for p in DEFAULT_PLEVS]
 
-    Currently any 3-D observations file needs to have equivalent pressure
-    levels in order to work properly, although in the future it is hoped
-    to enable the vertical interpolation of observations as well.
+def regrid_and_vert_interp(adf):
+    """
+    Regrids the test cases to the same horizontal
+    grid as the reference climatology and vertically
+    interpolates the test case (and reference if needed)
+    to match a default set of pressure levels (in hPa).
     """
-
-    #Import necessary modules:
-
-    from pathlib import Path
-
-    # regridding
-    # Try just using the xarray method
-    # import xesmf as xe # This package is for regridding, and is just one potential solution. 
- - # Steps: - # - load climo files for model and obs - # - calculate all-time and seasonal fields (from individual months) - # - regrid one to the other (probably should be a choice) - - #Notify user that script has started: msg = "\n Regridding CAM climatologies..." print(f"{msg}\n {'-' * (len(msg)-3)}") - #Extract needed quantities from ADF object: - #----------------------------------------- overwrite_regrid = adf.get_basic_info("cam_overwrite_regrid", required=True) - output_loc = adf.get_basic_info("cam_regrid_loc", required=True) - var_list = adf.diag_var_list - var_defaults = adf.variable_defaults + output_loc = adf.get_basic_info("cam_regrid_loc", required=True) + output_loc = [Path(i) for i in output_loc] + var_list = adf.diag_var_list + var_defaults = adf.variable_defaults - #CAM simulation variables (these quantities are always lists): case_names = adf.get_cam_info("cam_case_name", required=True) - input_climo_locs = adf.get_cam_info("cam_climo_loc", required=True) - - #Grab case years syear_cases = adf.climo_yrs["syears"] eyear_cases = adf.climo_yrs["eyears"] - #Check if mid-level pressure, ocean fraction or land fraction exist - #in the variable list: - for var in ["PMID", "OCNFRAC", "LANDFRAC"]: + # Move critical variables to the front of the list + for var in ["PMID", "OCNFRAC", "LANDFRAC", "PS"]: if var in var_list: - #If so, then move them to the front of variable list so - #that they can be used to mask or vertically interpolate - #other model variables if need be: - var_idx = var_list.index(var) - var_list.pop(var_idx) - var_list.insert(0,var) - #End if - #End for - - #Create new variables that potentially stores the re-gridded - #ocean/land fraction dataset: - ocn_frc_ds = None - tgt_ocn_frc_ds = None - - #Check if surface pressure exists in variable list: - if "PS" in var_list: - #If so, then move it to front of variable list so that - #it can be used to vertically interpolate model variables - #if need be. 
This should be done after PMID so that the order - #is PS, PMID, other variables: - ps_idx = var_list.index("PS") - var_list.pop(ps_idx) - var_list.insert(0,"PS") - #End if - - #Regrid target variables (either obs or a baseline run): - if adf.compare_obs: + var_list.insert(0, var_list.pop(var_list.index(var))) - #Set obs name to match baseline (non-obs) - target_list = ["Obs"] + for case_idx, case_name in enumerate(case_names): + # print(f"\t Regridding case '{case_name}':") + syear = syear_cases[case_idx] + eyear = eyear_cases[case_idx] + case_output_loc = output_loc[case_idx] + case_output_loc.mkdir(parents=True, exist_ok=True) - #Extract variable-obs dictionary: - var_obs_dict = adf.var_obs_dict + for var in var_list: + # print(f"Regridding variable: {var}") + # reset variables + model_ds = None + ref_ds = None + target_name = None + regridded_file_loc = None + model_da = None + ref_da = None + regridder = None + interp_da = None + + if var in adf.data.ref_var_nam: + target_name = adf.data.ref_labels[var] + else: + print(f"\t ERROR: No reference data available for {var}.") + continue + + regridded_file_loc = case_output_loc / f'{target_name}_{case_name}_{var}_regridded.nc' + + if regridded_file_loc.is_file() and not overwrite_regrid: + print(f"\t INFO: Regridded file already exists, skipping: {regridded_file_loc}") + continue + + if regridded_file_loc.is_file() and overwrite_regrid: + regridded_file_loc.unlink() + + + model_ds = adf.data.load_climo_ds(case_name, var) + ref_ds = adf.data.load_reference_climo_ds(adf.data.ref_case_label, var) + if not ref_ds: + print(f"\t ERROR: Missing reference data for {var}. Skipping.") + continue + if not model_ds: + print(f"\t ERROR: Missing model data for {var}. Skipping.") + continue + + model_da = model_ds[var].squeeze() + ref_da = ref_ds[adf.data.ref_var_nam[var]].squeeze() + original_attrs = model_da.attrs.copy() + + # --- Horizontal Regridding --- + regridded_da = _handle_horizontal_regridding(model_da, ref_ds, adf, case_index=case_idx) + regridded_da.attrs.update(original_attrs) + # --- Vertical Interpolation --- + vert_type = _determine_vertical_coord_type(model_da) + ps_da = None + if vert_type == 'hybrid': + # For hybrid, we need surface pressure on the target grid. + # It's assumed PS is processed first and is available. 
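+                # (PS is processed first, so the regridded PS file should already
+                # be on disk at this point; if not, it is regridded on the fly
+                # just below.)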
+            ps_regridded_path = case_output_loc / f'{target_name}_{case_name}_PS_regridded.nc'
+            if ps_regridded_path.exists():
+                ps_da = xr.open_dataset(ps_regridded_path)['PS']
+            else:
+                # Regrid PS on the fly if not found
+                ps_da_source = adf.data.load_climo_da(case_name, 'PS').squeeze()
+                original_ps_attrs = ps_da_source.attrs.copy()
+                ps_da = _handle_horizontal_regridding(ps_da_source, ref_da, adf, case_index=case_idx)
+                ps_da.attrs.update(original_ps_attrs)
+            interp_da = _handle_vertical_interpolation(regridded_da, vert_type, model_ds, ps_da=ps_da)
+            interp_da.attrs.update(original_attrs)
+            # --- Masking ---
+            var_default_dict = var_defaults.get(var, {})
+            if 'mask' in var_default_dict and var_default_dict['mask'].lower() == 'ocean':
+                ocn_frac_regridded_path = case_output_loc / f'{target_name}_{case_name}_OCNFRAC_regridded.nc'
+                if ocn_frac_regridded_path.exists():
+                    ocn_frac_da = xr.open_dataset(ocn_frac_regridded_path)['OCNFRAC']
+                    interp_da = _apply_ocean_mask(interp_da, ocn_frac_da)
+                else:
+                    print(f"\t WARNING: OCNFRAC not found, unable to apply mask to '{var}'")
+
+            # --- Save to file ---
+            final_ds = interp_da.to_dataset(name=var)
+
+            # (PS and OCNFRAC need no special handling here: when one of them is
+            # the variable being processed, it is already the field in final_ds.)
+
+            test_attrs_dict = {
+                "adf_user": adf.user,
+                "climo_yrs": f"{case_name}: {syear}-{eyear}",
+                "climatology_files": str(adf.data.get_climo_file(case_name, var)),
+            }
+            final_ds = final_ds.assign_attrs(test_attrs_dict)
+            save_to_nc(final_ds, regridded_file_loc)
 
-    #If dictionary is empty, then there are no observations to regrid to,
-    #so quit here:
-    if not var_obs_dict:
-        print("\t No observations found to regrid to, so no re-gridding will be done.")
-        return
-    #End if
+    print(" ...CAM climatologies have been regridded successfully.")
 
-    else:
+def _handle_horizontal_regridding(source_da, target_grid, adf, method='conservative', case_index=None):
+    """
+    Performs horizontal regridding using xesmf.
+    Manages and reuses regridding weight files.
 
-        #Extract model baseline variables:
-        target_loc = adf.get_baseline_info("cam_climo_loc", required=True)
-        target_list = [adf.get_baseline_info("cam_case_name", required=True)]
-    #End if
+    Parameters
+    ----------
+    source_da : xarray.DataArray
+        The DataArray to regrid.
+    target_grid : xarray.Dataset
+        A dataset defining the target grid.
+    adf : adf_diag.AdfDiag
+        The ADF diagnostics object, used to get output locations.
+    method : str, optional
+        Regridding method. Defaults to 'conservative'.
+    case_index : int, optional
+        Index of the case in the case list; needed for multi-case runs.
+
+    Returns
+    -------
+    xarray.DataArray
+        The regridded DataArray.
+    """
 
-    #Grab baseline years (which may be empty strings if using Obs):
-    syear_baseline = adf.climo_yrs["syear_baseline"]
-    eyear_baseline = adf.climo_yrs["eyear_baseline"]
+    # Generate a unique name for the weights file
+    source_grid_type = "unstructured" if "ncol" in source_da.dims else "structured"
+    target_grid_type = "unstructured" if "ncol" in target_grid.dims else "structured"
 
-    #Set attributes dictionary for climo years to save in the file attributes
-    base_climo_yrs_attr = f"{target_list[0]}: {syear_baseline}-{eyear_baseline}"
+    # A simple naming convention for weight files. 
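+    # (Same caveat as the draft script: the descriptor keys on grid type and
+    # size only, so distinct grids with equal dimensions share a weights file.)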
+    source_grid_desc = f"{source_grid_type}_{len(source_da.lat)}_{len(source_da.lon)}" if source_grid_type == "structured" else f"{source_grid_type}_{len(source_da.ncol)}"
+    target_grid_desc = f"{target_grid_type}_{len(target_grid.lat)}_{len(target_grid.lon)}" if target_grid_type == "structured" else f"{target_grid_type}_{len(target_grid.ncol)}"
 
-    #-----------------------------------------
+    if target_grid_type == "structured":
+        target_grid = _create_clean_grid(target_grid)
+    if source_grid_type == "structured":
+        source_grid = _create_clean_grid(source_da)
+    else:
+        source_grid = source_da  # unstructured input goes to xesmf as-is
 
-    #Set output/target data path variables:
-    #------------------------------------
-    rgclimo_loc = Path(output_loc)
-    if not adf.compare_obs:
-        tclimo_loc = Path(target_loc)
-    #------------------------------------
+    regrid_loc = adf.get_basic_info("cam_regrid_loc", required=True)
+    if isinstance(regrid_loc, list) and len(regrid_loc)>1:
+        regrid_loc = regrid_loc[case_index]
+    else:
+        regrid_loc = regrid_loc[0]
+    regrid_loc = Path(regrid_loc)
+    regrid_weights_dir = regrid_loc / "regrid_weights"
+    regrid_weights_dir.mkdir(exist_ok=True)
+    weights_file = regrid_weights_dir / f"weights_{source_grid_desc}_to_{target_grid_desc}_{method}.nc"
+    if weights_file.exists():
+        # print(f"INFO: Using existing regridding weights file: {weights_file}")
+        # xesmf can accept a path to a weights file
+        regridder = xe.Regridder(source_grid, target_grid, method, weights=str(weights_file))
+    else:
+        # print(f"INFO: Creating new regridding weights file: {weights_file}")
+        regridder = xe.Regridder(source_grid, target_grid, method)
+        regridder.to_netcdf(weights_file)
+    return regridder(source_da)
 
-    #Check if re-gridded directory exists, and if not, then create it:
-    if not rgclimo_loc.is_dir():
-        print(f" {rgclimo_loc} not found, making new directory")
-        rgclimo_loc.mkdir(parents=True)
-    #End if
 
-    #Loop over CAM cases:
-    for case_idx, case_name in enumerate(case_names):
+def _create_clean_grid(da):
+    """
+    Creates a minimal, CF-compliant xarray Dataset for xesmf from a DataArray.
+    Adapted from regrid_and_vert_interp_2.py
+    """
+    if isinstance(da, xr.DataArray):
+        ds = da.to_dataset()
+    else:
+        ds = da
 
-        #Notify user of model case being processed:
-        print(f"\t Regridding case '{case_name}' :")
+    # Extract raw values
+    lat_centers = ds.lat.values.astype(np.float64)
+    lon_centers = ds.lon.values.astype(np.float64)
 
-        #Set case climo data path:
-        mclimo_loc = Path(input_climo_locs[case_idx])
+    if np.any(np.isnan(lat_centers)) or np.any(np.isinf(lat_centers)):
+        print("ERROR: Found NaNs or Infs in latitude centers!")
+        lat_centers = np.nan_to_num(lat_centers, nan=0.0, posinf=90.0, neginf=-90.0)
 
-        #Create empty dictionaries which store the locations of regridded surface
-        #pressure and mid-level pressure fields:
-        ps_loc_dict = {}
-        pmid_loc_dict = {}
-        #Get climo years for case
-        syear = syear_cases[case_idx]
-        eyear = eyear_cases[case_idx]
+    # Clip to avoid ESMF range errors
+    lat_centers = np.clip(lat_centers, -89.999999, 89.999999).astype(np.float64)
 
-        # probably want to do this one variable at a time:
-        for var in var_list:
+    # Build basic Dataset
+    clean_ds = xr.Dataset(
+        coords={
+            "lat": (["lat"], lat_centers, {"units": "degrees_north", "standard_name": "latitude"}),
+            "lon": (["lon"], lon_centers, {"units": "degrees_east", "standard_name": "longitude"}),
+        }
+    )
 
-            if adf.compare_obs:
-                #Check if obs exist for the variable:
-                if var in var_obs_dict:
-                    #Note: In the future these may all be lists, but for
-                    #now just convert the target_list. 
- #Extract target file: - tclimo_loc = var_obs_dict[var]["obs_file"] - #Extract target list (eventually will be a list, for now need to convert): - target_list = [var_obs_dict[var]["obs_name"]] - else: - dmsg = f"No obs found for variable `{var}`, regridding skipped." - adf.debug_log(dmsg) - continue - #End if - #End if - - #Notify user of variable being regridded: - print(f"\t - regridding {var} (known targets: {target_list})") - - #loop over regridding targets: - for target in target_list: - - #Write to debug log if enabled: - adf.debug_log(f"regrid_example: regrid target = {target}") - - #Determine regridded variable file name: - regridded_file_loc = rgclimo_loc / f'{target}_{case_name}_{var}_regridded.nc' - - #If surface or mid-level pressure, then save for potential use by other variables: - if var == "PS": - ps_loc_dict[target] = regridded_file_loc - elif var == "PMID": - pmid_loc_dict[target] = regridded_file_loc - #End if - - #Check if re-gridded file already exists and over-writing is allowed: - if regridded_file_loc.is_file() and overwrite_regrid: - #If so, then delete current file: - regridded_file_loc.unlink() - #End if - - #Check again if re-gridded file already exists: - if not regridded_file_loc.is_file(): - - #Create list of regridding target files (we should explore intake as an alternative to having this kind of repeated code) - # NOTE: This breaks if you have files from different cases in same directory! - if adf.compare_obs: - #For now, only grab one file (but convert to list for use below): - tclim_fils = [tclimo_loc] - else: - tclim_fils = sorted(tclimo_loc.glob(f"{target}*_{var}_climo.nc")) - #End if - - #Write to debug log if enabled: - adf.debug_log(f"regrid_example: tclim_fils (n={len(tclim_fils)}): {tclim_fils}") - - if len(tclim_fils) > 1: - #Combine all target files together into a single data set: - tclim_ds = xr.open_mfdataset(tclim_fils, combine='by_coords') - elif len(tclim_fils) == 0: - print(f"\t WARNING: regridding {var} failed, no climo file for case '{target}'. Continuing to next variable.") - continue - else: - #Open single file as new xarray dataset: - tclim_ds = xr.open_dataset(tclim_fils[0]) - #End if - - #Generate CAM climatology (climo) file list: - mclim_fils = sorted(mclimo_loc.glob(f"{case_name}_{var}_*.nc")) - - if len(mclim_fils) > 1: - #Combine all cam files together into a single data set: - mclim_ds = xr.open_mfdataset(mclim_fils, combine='by_coords') - elif len(mclim_fils) == 0: - #wmsg = f"\t WARNING: Unable to find climo file for '{var}'." - #wmsg += " Continuing to next variable." - wmsg= f"\t WARNING: regridding {var} failed, no climo file for case '{case_name}'. Continuing to next variable." 
- print(wmsg) - continue - else: - #Open single file as new xarray dataset: - mclim_ds = xr.open_dataset(mclim_fils[0]) - #End if - - #Create keyword arguments dictionary for regridding function: - regrid_kwargs = {} - - #Check if target in relevant pressure variable dictionaries: - if target in ps_loc_dict: - regrid_kwargs.update({'ps_file': ps_loc_dict[target]}) - #End if - if target in pmid_loc_dict: - regrid_kwargs.update({'pmid_file': pmid_loc_dict[target]}) - #End if - - #Perform regridding and interpolation of variable: - rgdata_interp = _regrid_and_interpolate_levs(mclim_ds, var, - regrid_dataset=tclim_ds, - **regrid_kwargs) - - #Extract defaults for variable: - var_default_dict = var_defaults.get(var, {}) - - if 'mask' in var_default_dict: - if var_default_dict['mask'].lower() == 'ocean': - #Check if the ocean fraction has already been regridded - #and saved: - if ocn_frc_ds: - ofrac = ocn_frc_ds['OCNFRAC'] - # set the bounds of regridded ocnfrac to 0 to 1 - ofrac = xr.where(ofrac>1,1,ofrac) - ofrac = xr.where(ofrac<0,0,ofrac) - - # apply ocean fraction mask to variable - rgdata_interp['OCNFRAC'] = ofrac - var_tmp = rgdata_interp[var] - var_tmp = utils.mask_land_or_ocean(var_tmp,ofrac) - rgdata_interp[var] = var_tmp - else: - print(f"\t WARNING: OCNFRAC not found, unable to apply mask to '{var}'") - #End if - else: - #Currently only an ocean mask is supported, so print warning here: - wmsg = "\t WARNING: Currently the only variable mask option is 'ocean'," - wmsg += f"not '{var_default_dict['mask'].lower()}'" - print(wmsg) - #End if - #End if - - #If the variable is ocean fraction, then save the dataset for use later: - if var == 'OCNFRAC': - ocn_frc_ds = rgdata_interp - #End if - - #Finally, write re-gridded data to output file: - #Convert the list of Path objects to a list of strings - climatology_files_str = [str(path) for path in mclim_fils] - climatology_files_str = ', '.join(climatology_files_str) - test_attrs_dict = { - "adf_user": adf.user, - "climo_yrs": f"{case_name}: {syear}-{eyear}", - "climatology_files": climatology_files_str, - } - rgdata_interp = rgdata_interp.assign_attrs(test_attrs_dict) - save_to_nc(rgdata_interp, regridded_file_loc) - rgdata_interp.close() # bpm: we are completely done with this data - - #Now vertically interpolate baseline (target) climatology, - #if applicable: - - #Set interpolated baseline file name: - interp_bl_file = rgclimo_loc / f'{target}_{var}_baseline.nc' - - if not adf.compare_obs and not interp_bl_file.is_file(): - - #Look for a baseline climo file for surface pressure (PS): - bl_ps_fil = tclimo_loc / f'{target}_PS_climo.nc' - - #Also look for a baseline climo file for mid-level pressure (PMID): - bl_pmid_fil = tclimo_loc / f'{target}_PMID_climo.nc' - - #Create new keyword arguments dictionary for regridding function: - regrid_kwargs = {} - - #Check if PS and PMID files exist: - if bl_ps_fil.is_file(): - regrid_kwargs.update({'ps_file': bl_ps_fil}) - #End if - if bl_pmid_fil.is_file(): - regrid_kwargs.update({'pmid_file': bl_pmid_fil}) - #End if - - #Generate vertically-interpolated baseline dataset: - tgdata_interp = _regrid_and_interpolate_levs(tclim_ds, var, - **regrid_kwargs) - - if tgdata_interp is None: - #Something went wrong during interpolation, so just cycle through - #for now: - continue - #End if - - #If the variable is ocean fraction, then save the dataset for use later: - if var == 'OCNFRAC': - tgt_ocn_frc_ds = tgdata_interp - #End if - - if 'mask' in var_default_dict: - if var_default_dict['mask'].lower() == 'ocean': 
- #Check if the ocean fraction has already been regridded - #and saved: - if tgt_ocn_frc_ds: - ofrac = tgt_ocn_frc_ds['OCNFRAC'] - # set the bounds of regridded ocnfrac to 0 to 1 - ofrac = xr.where(ofrac>1,1,ofrac) - ofrac = xr.where(ofrac<0,0,ofrac) - # mask the land in TS for global means - tgdata_interp['OCNFRAC'] = ofrac - ts_tmp = tgdata_interp[var] - ts_tmp = utils.mask_land_or_ocean(ts_tmp,ofrac) - tgdata_interp[var] = ts_tmp - else: - wmsg = "\t WARNING: OCNFRAC not found in target," - wmsg += f" unable to apply mask to '{var}'" - print(wmsg) - #End if - #End if - #End if - - # Convert the list to a string (join with commas or another separator) - climatology_files_str = [str(path) for path in tclim_fils] - climatology_files_str = ', '.join(climatology_files_str) - # Create a dictionary of attributes - base_attrs_dict = { - "adf_user": adf.user, - "climo_yrs": f"{case_name}: {syear}-{eyear}; {base_climo_yrs_attr}", - "climatology_files": climatology_files_str, - } - tgdata_interp = tgdata_interp.assign_attrs(base_attrs_dict) - - #Write interpolated baseline climatology to file: - save_to_nc(tgdata_interp, interp_bl_file) - #End if - else: - print("\t INFO: Regridded file already exists, so skipping...") - #End if (file check) - #End do (target list) - #End do (variable list) - #End do (case list) + # Add Bounds as vertices if they exist + # Check for various possible bounds names + lat_bnds_names = ['lat_bnds', 'lat_bounds', 'latitude_bnds', 'latitude_bounds'] + lon_bnds_names = ['lon_bnds', 'lon_bounds', 'longitude_bnds', 'longitude_bounds'] + + lat_bnds = None + lon_bnds = None + + for name in lat_bnds_names: + if name in ds: + lat_bnds = ds[name] + break + + for name in lon_bnds_names: + if name in ds: + lon_bnds = ds[name] + break + + if lat_bnds is not None and lon_bnds is not None: + lat_v = np.append(lat_bnds.values[:, 0], lat_bnds.values[-1, 1]) + lon_v = np.append(lon_bnds.values[:, 0], lon_bnds.values[-1, 1]) - #Notify user that script has ended: - print(" ...CAM climatologies have been regridded successfully.") + # Clip to avoid ESMF range errors + lat_v = np.clip(lat_v, -89.9999, 89.9999).astype(np.float64) -################# -#Helper functions -################# + # xesmf looks for 'lat_b' and 'lon_b' in the dataset for conservative regridding + clean_ds["lat_b"] = (["lat_f"], lat_v, {"units": "degrees_north"}) + clean_ds["lon_b"] = (["lon_f"], lon_v, {"units": "degrees_east"}) + return clean_ds -def _regrid_and_interpolate_levs(model_dataset, var_name, regrid_dataset=None, **kwargs): +def _determine_vertical_coord_type(dset): """ - Function that takes a variable from a model xarray - dataset, regrids it to another dataset's lat/lon - coordinates (if applicable), and then interpolates - it vertically to a set of pre-defined pressure levels. + Determines the type of vertical coordinate in a dataset. Parameters ---------- - model_dataset : xarray.Dataset - The xarray dataset which contains the model variable data - var_name : str - The name of the variable to be regridded/interpolated. - regrid_dataset : xr.Dataset or xr.DataArray, optional - The xarray object that contains the destination lat/lon grid - If not present then only vertical interpolation will be performed. - **kwargs - Additional optional arguments: - - `ps_file` : str or Path - specify surface pressure netCDF file - - `pmid_file` : str or Path - specify vertical layer midpoint pressure netCDF file + dset : xarray.Dataset + The dataset to inspect. 
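
For reference, the vertex construction just above turns a contiguous (N, 2) bounds array into the N+1 cell-edge values that xESMF expects for conservative regridding. A minimal sketch with hypothetical three-cell latitude bounds:

    import numpy as np

    # Hypothetical contiguous latitude bounds, shape (3, 2).
    lat_bnds = np.array([[-90.0, -30.0], [-30.0, 30.0], [30.0, 90.0]])

    # Left edge of every cell plus the right edge of the last cell -> 4 vertices.
    lat_v = np.append(lat_bnds[:, 0], lat_bnds[-1, 1])
    print(lat_v)  # [-90. -30.  30.  90.]
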
Returns ------- - xarray.Dataset - This function returns a new xarray dataset that contains the regridded - and/or vertically-interpolated model variable. + str + The vertical coordinate type: 'hybrid', 'height', 'pressure', or 'none'. """ - #Import ADF-specific functions: - import plotting_functions as pf + if 'lev' in dset.dims or 'ilev' in dset.dims: + lev_coord_name = 'lev' if 'lev' in dset.dims else 'ilev' + lev_attrs = dset[lev_coord_name].attrs - #Extract keyword arguments: - if 'ps_file' in kwargs: - ps_file = kwargs['ps_file'] - else: - ps_file = None - #End if - if 'pmid_file' in kwargs: - pmid_file = kwargs['pmid_file'] - else: - pmid_file = None - #End if - - #Extract variable info from model data (and remove any degenerate dimensions): - mdata = model_dataset[var_name].squeeze() - mdat_ofrac = None - #if regrid_ofrac: - # if 'OCNFRAC' in model_dataset: - # mdat_ofrac = model_dataset['OCNFRAC'].squeeze() - - #Check if variable has a vertical component: - if 'lev' in mdata.dims or 'ilev' in mdata.dims: - has_lev = True - - #If lev exists, then determine what kind of vertical coordinate - #is being used: - if 'lev' in mdata.dims: - lev_attrs = model_dataset['lev'].attrs - elif 'ilev' in mdata.dims: - lev_attrs = model_dataset['ilev'].attrs - - #First check if there is a "vert_coord" attribute: if 'vert_coord' in lev_attrs: - vert_coord_type = lev_attrs['vert_coord'] - else: - #Next check that the "long_name" attribute exists: - if 'long_name' in lev_attrs: - #Extract long name: - lev_long_name = lev_attrs['long_name'] - - #Check for "keywords" in the long name: - if 'hybrid level' in lev_long_name: - #Set model to hybrid vertical levels: - vert_coord_type = "hybrid" - elif 'zeta level' in lev_long_name: - #Set model to height (z) vertical levels: - vert_coord_type = "height" - else: - #Print a warning, and skip variable re-gridding/interpolation: - wmsg = "WARNING! Unable to determine the vertical coordinate" - wmsg +=f" type from the 'lev' long name, which is:\n'{lev_long_name}'" - print(wmsg) - return None - #End if + return lev_attrs['vert_coord'] - else: - #Print a warning, and assume hybrid levels (for now): - wmsg = "WARNING! No long name found for the 'lev' dimension," - wmsg += f" so no re-gridding/interpolation will be done." - print(wmsg) - return None - #End if - #End if + if 'long_name' in lev_attrs: + lev_long_name = lev_attrs['long_name'] + if 'hybrid level' in lev_long_name: + return "hybrid" + if 'pressure level' in lev_long_name: + return "pressure" + if 'zeta level' in lev_long_name: + return "height" - else: - has_lev = False - #End if + # If no specific metadata is found, make an educated guess. + # This part might need refinement based on expected data conventions. + if 'hyam' in dset or 'hyai' in dset: + return "hybrid" - #Check if variable has a vertical levels dimension: - if has_lev: - - if vert_coord_type == "hybrid": - # Need hyam, hybm, and P0 for vertical interpolation of hybrid levels: - if 'lev' in mdata.dims: - if ('hyam' not in model_dataset) or ('hybm' not in model_dataset): - print(f"!! PROBLEM -- NO hyam or hybm for 3-D variable {var_name}, so it will not be re-gridded.") - return None #Return None to skip to next variable. - #End if - mhya = model_dataset['hyam'] - mhyb = model_dataset['hybm'] - elif 'ilev' in mdata.dims: - if ('hyai' not in model_dataset) or ('hybi' not in model_dataset): - print(f"!! PROBLEM -- NO hyai or hybi for 3-D variable {var_name}, so it will not be re-gridded.") - return None #Return None to skip to next variable. 
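
To illustrate the detection order implemented above (an explicit `vert_coord` attribute wins, then `long_name` keywords, then the hybrid-coefficient guess), a small sketch with hypothetical metadata:

    import numpy as np
    import xarray as xr

    # Hypothetical dataset with a 30-level 'lev' coordinate.
    ds = xr.Dataset(coords={"lev": ("lev", np.linspace(3.6, 992.5, 30))})
    ds["lev"].attrs["long_name"] = "hybrid level at midpoints (1000*(A+B))"
    print(_determine_vertical_coord_type(ds))  # -> "hybrid"

    # An explicit 'vert_coord' attribute takes precedence over the long name.
    ds["lev"].attrs["vert_coord"] = "pressure"
    print(_determine_vertical_coord_type(ds))  # -> "pressure"
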
- #End if - mhya = model_dataset['hyai'] - mhyb = model_dataset['hybi'] - if 'time' in mhya.dims: - mhya = mhya.isel(time=0).squeeze() - if 'time' in mhyb.dims: - mhyb = mhyb.isel(time=0).squeeze() - if 'P0' in model_dataset: - P0_tmp = model_dataset['P0'] - if isinstance(P0_tmp, xr.DataArray): - #All of these value should be the same, - #so just grab the first one: - P0 = P0_tmp[0] - else: - #Just rename variable: - P0 = P0_tmp - #End if - else: - P0 = 100000.0 # Pa - #End if - - elif vert_coord_type == "height": - #Initialize already-regridded PMID logical: - regridded_pmid = False + print(f"\t WARNING: Vertical coordinate type for '{lev_coord_name}' could not be determined. Assuming 'pressure'.") + return "pressure" - #Need mid-level pressure for vertical interpolation of height levels: - if 'PMID' in model_dataset: - mpmid = model_dataset['PMID'] - else: - #Check if target has an associated surface pressure field: - if pmid_file: - mpmid_ds = xr.open_dataset(pmid_file) - mpmid = mpmid_ds['PMID'] - #This mid-level pressure field has already been regridded: - regridded_pmid = True - else: - print(f"!! PROBLEM -- NO PMID for 3-D variable {var_name}, so it will not be re-gridded.") - return None - #End if - #End if - #End if (vert_coord_type) - - #It is probably good to try and acquire PS for all vertical coordinate types, so try here: - regridded_ps = False - if 'PS' in model_dataset: - mps = model_dataset['PS'] - else: - #Check if target has an associated surface pressure field: - if ps_file: - mps_ds = xr.open_dataset(ps_file) - mps = mps_ds['PS'] - #This surface pressure field has already been regridded: - regridded_ps = True - else: - print(f"!! PROBLEM -- NO PS for 3-D variable {var_name}, so it will not be re-gridded.") - return None - #End if - #End if - #End if (has_lev) - - #Regrid variable to target dataset (if available): - if regrid_dataset: - - #Extract grid info from target data: - if 'time' in regrid_dataset.coords: - if 'lev' in regrid_dataset.coords: - tgrid = regrid_dataset.isel(time=0, lev=0).squeeze() - else: - tgrid = regrid_dataset.isel(time=0).squeeze() - #End if - #End if - - #Regrid model data to match target grid: - rgdata = regrid_data(mdata, tgrid, method=1) - if mdat_ofrac: - rgofrac = regrid_data(mdat_ofrac, tgrid, method=1) - #Regrid surface pressure if need be: - if has_lev: - if not regridded_ps: - rg_ps = regrid_data(mps, tgrid, method=1) - else: - rg_ps = mps - #End if + return 'none' - #Also regrid mid-level pressure if need be: - if vert_coord_type == "height": - if not regridded_pmid: - rg_pmid = regrid_data(mpmid, tgrid, method=1) - else: - rg_pmid = mpmid - #End if - #End if - #End if - else: - #Just rename variables: - rgdata = mdata - if has_lev: - rg_ps = mps - if vert_coord_type == "height": - rg_pmid = mpmid - #End if - #End if - #End if +def _handle_vertical_interpolation(da, vert_type, source_ds, ps_da=None): + """ + Performs vertical interpolation to default pressure levels. 
-    #Vertical interpolation:
-
-    #Interpolate variable to default pressure levels:
-    if has_lev:
-
-        if vert_coord_type == "hybrid":
-            #Interpolate from hybrid sigma-pressure to the standard pressure levels:
-            rgdata_interp = utils.lev_to_plev(rgdata, rg_ps, mhya, mhyb, P0=P0, \
-                                           convert_to_mb=True)
-        elif vert_coord_type == "height":
-            #Interpolate variable using mid-level pressure (PMID):
-            rgdata_interp = utils.pmid_to_plev(rgdata, rg_pmid, convert_to_mb=True)
-        else:
-            #The vertical coordinate type is un-recognized, so print warning and
-            #skip vertical interpolation:
-            wmsg = f"WARNING! Un-recognized vertical coordinate type: '{vert_coord_type}',"
-            wmsg += f" for variable '{var_name}'. Skipping vertical interpolation."
-            print(wmsg)
-            #Don't process variable:
-            return None
-        #End if
-    else:
-        #Just rename variable:
-        rgdata_interp = rgdata
-    #End if
+    Parameters
+    ----------
+    da : xarray.DataArray
+        The DataArray to interpolate.
+    vert_type : str
+        The vertical coordinate type ('hybrid', 'height', 'pressure', or 'none').
+    source_ds : xarray.Dataset
+        The source dataset containing auxiliary variables (e.g., hyam, hybm).
+    ps_da : xarray.DataArray, optional
+        Surface pressure DataArray, required for hybrid coordinates.

-    #Convert to xarray dataset:
-    rgdata_interp = rgdata_interp.to_dataset()
-    if mdat_ofrac:
-        rgdata_interp['OCNFRAC'] = rgofrac
+    Returns
+    -------
+    xarray.DataArray
+        The vertically interpolated DataArray (returned unchanged when
+        vert_type is 'none').
+    """
+    if vert_type == 'none':
+        return da
+
+    if vert_type == "hybrid":
+        if ps_da is None:
+            raise ValueError("Surface pressure ('PS') is required for hybrid vertical interpolation.")
+
+        lev_coord_name = 'lev' if 'lev' in source_ds.dims else 'ilev'
+        hyam_name = 'hyam' if lev_coord_name == 'lev' else 'hyai'
+        hybm_name = 'hybm' if lev_coord_name == 'lev' else 'hybi'
+
+        if hyam_name not in source_ds or hybm_name not in source_ds:
+            raise ValueError(f"Hybrid coefficients ('{hyam_name}', '{hybm_name}') not found in dataset.")
+
+        hyam = source_ds[hyam_name]
+        hybm = source_ds[hybm_name]
+
+        if 'time' in hyam.dims:
+            hyam = hyam.isel(time=0).squeeze()
+        if 'time' in hybm.dims:
+            hybm = hybm.isel(time=0).squeeze()
+
+        p0 = source_ds.get('P0', 100000.0)
+        if isinstance(p0, xr.DataArray):
+            # P0 may be stored as a scalar or carry a degenerate dimension;
+            # all values should be identical, so safely take the first one:
+            p0 = float(p0.values.flat[0])
+
+        # Hot fix: ensure the level coordinate carries the CF metadata
+        # expected by the vertical interpolation utility:
+        da[lev_coord_name].attrs["axis"] = "Z"
+        da[lev_coord_name].attrs["positive"] = "down"  # standard for pressure/hybrid
+        da[lev_coord_name].attrs["standard_name"] = "atmosphere_hybrid_sigma_pressure_coordinate"
+
+        return utils.lev_to_plev(da, ps_da, hyam, hybm, P0=p0, convert_to_mb=True, new_levels=DEFAULT_PLEVS_Pa)
+
+    elif vert_type == "height":
+        pmid = source_ds.get('PMID')
+        if pmid is None:
+            raise ValueError("'PMID' is required for height vertical interpolation.")
+        return utils.pmid_to_plev(da, pmid, convert_to_mb=True, new_levels=DEFAULT_PLEVS_Pa)
+
+    elif vert_type == "pressure":
+        return utils.plev_to_plev(da, new_levels=DEFAULT_PLEVS_Pa, convert_to_mb=True)

-    #Add surface pressure to variable if a hybrid (just in case):
-    if has_lev:
-        rgdata_interp['PS'] = rg_ps
+    else:
+        raise ValueError(f"Unknown vertical coordinate type: '{vert_type}'")

-    #Update "vert_coord" attribute for variable "lev":
-    rgdata_interp['lev'].attrs.update({"vert_coord": "pressure"})
-    #End if
+def _apply_ocean_mask(da, ocn_frac_da):
+    """
+    Applies an ocean mask to a DataArray.

-    #Return dataset:
-    return rgdata_interp
+    Parameters
+    ----------
+    da : xarray.DataArray
+        The DataArray to mask.
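
For orientation, the two helpers above are meant to compose roughly as follows in the driver (a sketch only; `model_ds`, `regridded_da`, and `ps_da` are hypothetical stand-ins for objects produced earlier in the pipeline):

    # Detect the vertical coordinate once per variable, then interpolate.
    vert_type = _determine_vertical_coord_type(model_ds)

    if vert_type == "hybrid":
        # Hybrid-sigma data also needs surface pressure on the target grid.
        interp_da = _handle_vertical_interpolation(regridded_da, vert_type,
                                                   model_ds, ps_da=ps_da)
    else:
        # 'height', 'pressure', and 'none' take no surface-pressure argument.
        interp_da = _handle_vertical_interpolation(regridded_da, vert_type, model_ds)
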
+ ocn_frac_da : xarray.DataArray + The ocean fraction DataArray. -##### + Returns + ------- + xarray.DataArray + The masked DataArray. + """ + # Ensure ocean fraction is between 0 and 1 + ocn_frac_da = ocn_frac_da.clip(0, 1) + + # Apply the mask + return utils.mask_land_or_ocean(da, ocn_frac_da) def save_to_nc(tosave, outname, attrs=None, proc=None): """Saves xarray variable to new netCDF file @@ -684,70 +411,6 @@ def save_to_nc(tosave, outname, attrs=None, proc=None): if attrs is not None: xo.attrs = attrs if proc is not None: + origname = tosave.attrs.get('climatology_files', 'unknown') xo.attrs['Processing_info'] = f"Start from file {origname}. " + proc xo.to_netcdf(outname, format='NETCDF4', encoding=enc) - -##### - -def regrid_data(fromthis, tothis, method=1): - """Regrid between lat-lon grids using various different methods - - Parameters - ---------- - fromthis : xarray.DataArray - original data - tothis : xarray.DataArray - provides destination grid information (regular lat-lon) - method : int, optional - method to use for regridding - 1 - xarray, `interp_like` - 2 - xarray, `interp` - 3 - xESMF, `Regridder()` - 4 - GeoCAT, `linint2` (may be deprecated) - - Returns - ------- - xarray.DataArray - Data interpolated to destination grid - - Notes - ----- - 1. xarray's interpolation does not respect longitude's periodicity - 2. xESMF can sometimes malfunction depending on dependencies - 3. GeoCAT `linint2` might be deprecated - - A more robust regridding solution is being explored. - - """ - - if method == 1: - # kludgy: spatial regridding only, seems like can't automatically deal with time - if 'time' in fromthis.coords: - result = [fromthis.isel(time=t).interp_like(tothis) for t,time in enumerate(fromthis['time'])] - result = xr.concat(result, 'time') - return result - else: - return fromthis.interp_like(tothis) - elif method == 2: - newlat = tothis['lat'] - newlon = tothis['lon'] - coords = dict(fromthis.coords) - coords['lat'] = newlat - coords['lon'] = newlon - return fromthis.interp(coords) - elif method == 3: - newlat = tothis['lat'] - newlon = tothis['lon'] - ds_out = xr.Dataset({'lat': newlat, 'lon': newlon}) - regridder = xe.Regridder(fromthis, ds_out, 'bilinear') - return regridder(fromthis) - elif method==4: - # geocat - newlat = tothis['lat'] - newlon = tothis['lon'] - result = geocat.comp.linint2(fromthis, newlon, newlat, False) - result.name = fromthis.name - return result - #End if - -##### \ No newline at end of file diff --git a/scripts/regridding/regrid_and_vert_interp.py.OLD b/scripts/regridding/regrid_and_vert_interp.py.OLD new file mode 100644 index 000000000..04e3cd279 --- /dev/null +++ b/scripts/regridding/regrid_and_vert_interp.py.OLD @@ -0,0 +1,753 @@ +"""Driver for horizontal and vertical interpolation. +""" +import xarray as xr +import adf_utils as utils + +def regrid_and_vert_interp(adf): + + """ + Regrids the test cases to the same horizontal + grid as the reference climatology and vertically + interpolates the test case (and reference if needed) + to match a default set of pressure levels (in hPa). + + Parameters + ---------- + adf + The ADF object + + + Notes + ----- + Default pressure levels: + 1000, 925, 850, 700, 500, 400, 300, 250, 200, 150, 100, 70, 50, + 30, 20, 10, 7, 5, 3, 2, 1 + + Currently any 3-D observations file needs to have equivalent pressure + levels in order to work properly, although in the future it is hoped + to enable the vertical interpolation of observations as well. 
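
A short usage sketch for the masking and output helpers above (file and variable names are hypothetical; `OCNFRAC` is the regridded ocean-fraction field used throughout this module):

    import xarray as xr

    # Hypothetical regridded ocean-fraction file written earlier in the pipeline.
    ocn_frac = xr.open_dataset("ref_case_OCNFRAC_regridded.nc")["OCNFRAC"]
    masked = _apply_ocean_mask(interp_da, ocn_frac)

    # Write out with the module's non-NaN fill-value encoding.
    save_to_nc(masked.to_dataset(name="TS"), "ref_case_TS_regridded.nc",
               proc="ocean-masked after regridding")
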
+ """ + + #Import necessary modules: + + from pathlib import Path + + # regridding + # Try just using the xarray method + # import xesmf as xe # This package is for regridding, and is just one potential solution. + + # Steps: + # - load climo files for model and obs + # - calculate all-time and seasonal fields (from individual months) + # - regrid one to the other (probably should be a choice) + + #Notify user that script has started: + msg = "\n Regridding CAM climatologies..." + print(f"{msg}\n {'-' * (len(msg)-3)}") + + #Extract needed quantities from ADF object: + #----------------------------------------- + overwrite_regrid = adf.get_basic_info("cam_overwrite_regrid", required=True) + output_loc = adf.get_basic_info("cam_regrid_loc", required=True) + var_list = adf.diag_var_list + var_defaults = adf.variable_defaults + + #CAM simulation variables (these quantities are always lists): + case_names = adf.get_cam_info("cam_case_name", required=True) + input_climo_locs = adf.get_cam_info("cam_climo_loc", required=True) + + #Grab case years + syear_cases = adf.climo_yrs["syears"] + eyear_cases = adf.climo_yrs["eyears"] + + #Check if mid-level pressure, ocean fraction or land fraction exist + #in the variable list: + for var in ["PMID", "OCNFRAC", "LANDFRAC"]: + if var in var_list: + #If so, then move them to the front of variable list so + #that they can be used to mask or vertically interpolate + #other model variables if need be: + var_idx = var_list.index(var) + var_list.pop(var_idx) + var_list.insert(0,var) + #End if + #End for + + #Create new variables that potentially stores the re-gridded + #ocean/land fraction dataset: + ocn_frc_ds = None + tgt_ocn_frc_ds = None + + #Check if surface pressure exists in variable list: + if "PS" in var_list: + #If so, then move it to front of variable list so that + #it can be used to vertically interpolate model variables + #if need be. 
This should be done after PMID so that the order + #is PS, PMID, other variables: + ps_idx = var_list.index("PS") + var_list.pop(ps_idx) + var_list.insert(0,"PS") + #End if + + #Regrid target variables (either obs or a baseline run): + if adf.compare_obs: + + #Set obs name to match baseline (non-obs) + target_list = ["Obs"] + + #Extract variable-obs dictionary: + var_obs_dict = adf.var_obs_dict + + #If dictionary is empty, then there are no observations to regrid to, + #so quit here: + if not var_obs_dict: + print("\t No observations found to regrid to, so no re-gridding will be done.") + return + #End if + + else: + + #Extract model baseline variables: + target_loc = adf.get_baseline_info("cam_climo_loc", required=True) + target_list = [adf.get_baseline_info("cam_case_name", required=True)] + #End if + + #Grab baseline years (which may be empty strings if using Obs): + syear_baseline = adf.climo_yrs["syear_baseline"] + eyear_baseline = adf.climo_yrs["eyear_baseline"] + + #Set attributes dictionary for climo years to save in the file attributes + base_climo_yrs_attr = f"{target_list[0]}: {syear_baseline}-{eyear_baseline}" + + #----------------------------------------- + + #Set output/target data path variables: + #------------------------------------ + rgclimo_loc = Path(output_loc) + if not adf.compare_obs: + tclimo_loc = Path(target_loc) + #------------------------------------ + + #Check if re-gridded directory exists, and if not, then create it: + if not rgclimo_loc.is_dir(): + print(f" {rgclimo_loc} not found, making new directory") + rgclimo_loc.mkdir(parents=True) + #End if + + #Loop over CAM cases: + for case_idx, case_name in enumerate(case_names): + + #Notify user of model case being processed: + print(f"\t Regridding case '{case_name}' :") + + #Set case climo data path: + mclimo_loc = Path(input_climo_locs[case_idx]) + + #Create empty dictionaries which store the locations of regridded surface + #pressure and mid-level pressure fields: + ps_loc_dict = {} + pmid_loc_dict = {} + + #Get climo years for case + syear = syear_cases[case_idx] + eyear = eyear_cases[case_idx] + + # probably want to do this one variable at a time: + for var in var_list: + + if adf.compare_obs: + #Check if obs exist for the variable: + if var in var_obs_dict: + #Note: In the future these may all be lists, but for + #now just convert the target_list. + #Extract target file: + tclimo_loc = var_obs_dict[var]["obs_file"] + #Extract target list (eventually will be a list, for now need to convert): + target_list = [var_obs_dict[var]["obs_name"]] + else: + dmsg = f"No obs found for variable `{var}`, regridding skipped." 
+ adf.debug_log(dmsg) + continue + #End if + #End if + + #Notify user of variable being regridded: + print(f"\t - regridding {var} (known targets: {target_list})") + + #loop over regridding targets: + for target in target_list: + + #Write to debug log if enabled: + adf.debug_log(f"regrid_example: regrid target = {target}") + + #Determine regridded variable file name: + regridded_file_loc = rgclimo_loc / f'{target}_{case_name}_{var}_regridded.nc' + + #If surface or mid-level pressure, then save for potential use by other variables: + if var == "PS": + ps_loc_dict[target] = regridded_file_loc + elif var == "PMID": + pmid_loc_dict[target] = regridded_file_loc + #End if + + #Check if re-gridded file already exists and over-writing is allowed: + if regridded_file_loc.is_file() and overwrite_regrid: + #If so, then delete current file: + regridded_file_loc.unlink() + #End if + + #Check again if re-gridded file already exists: + if not regridded_file_loc.is_file(): + + #Create list of regridding target files (we should explore intake as an alternative to having this kind of repeated code) + # NOTE: This breaks if you have files from different cases in same directory! + if adf.compare_obs: + #For now, only grab one file (but convert to list for use below): + tclim_fils = [tclimo_loc] + else: + tclim_fils = sorted(tclimo_loc.glob(f"{target}*_{var}_climo.nc")) + #End if + + #Write to debug log if enabled: + adf.debug_log(f"regrid_example: tclim_fils (n={len(tclim_fils)}): {tclim_fils}") + + if len(tclim_fils) > 1: + #Combine all target files together into a single data set: + tclim_ds = xr.open_mfdataset(tclim_fils, combine='by_coords') + elif len(tclim_fils) == 0: + print(f"\t WARNING: regridding {var} failed, no climo file for case '{target}'. Continuing to next variable.") + continue + else: + #Open single file as new xarray dataset: + tclim_ds = xr.open_dataset(tclim_fils[0]) + #End if + + #Generate CAM climatology (climo) file list: + mclim_fils = sorted(mclimo_loc.glob(f"{case_name}_{var}_*.nc")) + + if len(mclim_fils) > 1: + #Combine all cam files together into a single data set: + mclim_ds = xr.open_mfdataset(mclim_fils, combine='by_coords') + elif len(mclim_fils) == 0: + #wmsg = f"\t WARNING: Unable to find climo file for '{var}'." + #wmsg += " Continuing to next variable." + wmsg= f"\t WARNING: regridding {var} failed, no climo file for case '{case_name}'. Continuing to next variable." 
+ print(wmsg) + continue + else: + #Open single file as new xarray dataset: + mclim_ds = xr.open_dataset(mclim_fils[0]) + #End if + + #Create keyword arguments dictionary for regridding function: + regrid_kwargs = {} + + #Check if target in relevant pressure variable dictionaries: + if target in ps_loc_dict: + regrid_kwargs.update({'ps_file': ps_loc_dict[target]}) + #End if + if target in pmid_loc_dict: + regrid_kwargs.update({'pmid_file': pmid_loc_dict[target]}) + #End if + + #Perform regridding and interpolation of variable: + rgdata_interp = _regrid_and_interpolate_levs(mclim_ds, var, + regrid_dataset=tclim_ds, + **regrid_kwargs) + + #Extract defaults for variable: + var_default_dict = var_defaults.get(var, {}) + + if 'mask' in var_default_dict: + if var_default_dict['mask'].lower() == 'ocean': + #Check if the ocean fraction has already been regridded + #and saved: + if ocn_frc_ds: + ofrac = ocn_frc_ds['OCNFRAC'] + # set the bounds of regridded ocnfrac to 0 to 1 + ofrac = xr.where(ofrac>1,1,ofrac) + ofrac = xr.where(ofrac<0,0,ofrac) + + # apply ocean fraction mask to variable + rgdata_interp['OCNFRAC'] = ofrac + var_tmp = rgdata_interp[var] + var_tmp = utils.mask_land_or_ocean(var_tmp,ofrac) + rgdata_interp[var] = var_tmp + else: + print(f"\t WARNING: OCNFRAC not found, unable to apply mask to '{var}'") + #End if + else: + #Currently only an ocean mask is supported, so print warning here: + wmsg = "\t WARNING: Currently the only variable mask option is 'ocean'," + wmsg += f"not '{var_default_dict['mask'].lower()}'" + print(wmsg) + #End if + #End if + + #If the variable is ocean fraction, then save the dataset for use later: + if var == 'OCNFRAC': + ocn_frc_ds = rgdata_interp + #End if + + #Finally, write re-gridded data to output file: + #Convert the list of Path objects to a list of strings + climatology_files_str = [str(path) for path in mclim_fils] + climatology_files_str = ', '.join(climatology_files_str) + test_attrs_dict = { + "adf_user": adf.user, + "climo_yrs": f"{case_name}: {syear}-{eyear}", + "climatology_files": climatology_files_str, + } + rgdata_interp = rgdata_interp.assign_attrs(test_attrs_dict) + save_to_nc(rgdata_interp, regridded_file_loc) + rgdata_interp.close() # bpm: we are completely done with this data + + #Now vertically interpolate baseline (target) climatology, + #if applicable: + + #Set interpolated baseline file name: + interp_bl_file = rgclimo_loc / f'{target}_{var}_baseline.nc' + + if not adf.compare_obs and not interp_bl_file.is_file(): + + #Look for a baseline climo file for surface pressure (PS): + bl_ps_fil = tclimo_loc / f'{target}_PS_climo.nc' + + #Also look for a baseline climo file for mid-level pressure (PMID): + bl_pmid_fil = tclimo_loc / f'{target}_PMID_climo.nc' + + #Create new keyword arguments dictionary for regridding function: + regrid_kwargs = {} + + #Check if PS and PMID files exist: + if bl_ps_fil.is_file(): + regrid_kwargs.update({'ps_file': bl_ps_fil}) + #End if + if bl_pmid_fil.is_file(): + regrid_kwargs.update({'pmid_file': bl_pmid_fil}) + #End if + + #Generate vertically-interpolated baseline dataset: + tgdata_interp = _regrid_and_interpolate_levs(tclim_ds, var, + **regrid_kwargs) + + if tgdata_interp is None: + #Something went wrong during interpolation, so just cycle through + #for now: + continue + #End if + + #If the variable is ocean fraction, then save the dataset for use later: + if var == 'OCNFRAC': + tgt_ocn_frc_ds = tgdata_interp + #End if + + if 'mask' in var_default_dict: + if var_default_dict['mask'].lower() == 'ocean': 
+ #Check if the ocean fraction has already been regridded + #and saved: + if tgt_ocn_frc_ds: + ofrac = tgt_ocn_frc_ds['OCNFRAC'] + # set the bounds of regridded ocnfrac to 0 to 1 + ofrac = xr.where(ofrac>1,1,ofrac) + ofrac = xr.where(ofrac<0,0,ofrac) + # mask the land in TS for global means + tgdata_interp['OCNFRAC'] = ofrac + ts_tmp = tgdata_interp[var] + ts_tmp = utils.mask_land_or_ocean(ts_tmp,ofrac) + tgdata_interp[var] = ts_tmp + else: + wmsg = "\t WARNING: OCNFRAC not found in target," + wmsg += f" unable to apply mask to '{var}'" + print(wmsg) + #End if + #End if + #End if + + # Convert the list to a string (join with commas or another separator) + climatology_files_str = [str(path) for path in tclim_fils] + climatology_files_str = ', '.join(climatology_files_str) + # Create a dictionary of attributes + base_attrs_dict = { + "adf_user": adf.user, + "climo_yrs": f"{case_name}: {syear}-{eyear}; {base_climo_yrs_attr}", + "climatology_files": climatology_files_str, + } + tgdata_interp = tgdata_interp.assign_attrs(base_attrs_dict) + + #Write interpolated baseline climatology to file: + save_to_nc(tgdata_interp, interp_bl_file) + #End if + else: + print("\t INFO: Regridded file already exists, so skipping...") + #End if (file check) + #End do (target list) + #End do (variable list) + #End do (case list) + + #Notify user that script has ended: + print(" ...CAM climatologies have been regridded successfully.") + +################# +#Helper functions +################# + +def _regrid_and_interpolate_levs(model_dataset, var_name, regrid_dataset=None, **kwargs): + + """ + Function that takes a variable from a model xarray + dataset, regrids it to another dataset's lat/lon + coordinates (if applicable), and then interpolates + it vertically to a set of pre-defined pressure levels. + + Parameters + ---------- + model_dataset : xarray.Dataset + The xarray dataset which contains the model variable data + var_name : str + The name of the variable to be regridded/interpolated. + regrid_dataset : xr.Dataset or xr.DataArray, optional + The xarray object that contains the destination lat/lon grid + If not present then only vertical interpolation will be performed. + **kwargs + Additional optional arguments: + - `ps_file` : str or Path + specify surface pressure netCDF file + - `pmid_file` : str or Path + specify vertical layer midpoint pressure netCDF file + + Returns + ------- + xarray.Dataset + This function returns a new xarray dataset that contains the regridded + and/or vertically-interpolated model variable. 
+ """ + + #Import ADF-specific functions: + import plotting_functions as pf + + #Extract keyword arguments: + if 'ps_file' in kwargs: + ps_file = kwargs['ps_file'] + else: + ps_file = None + #End if + if 'pmid_file' in kwargs: + pmid_file = kwargs['pmid_file'] + else: + pmid_file = None + #End if + + #Extract variable info from model data (and remove any degenerate dimensions): + mdata = model_dataset[var_name].squeeze() + mdat_ofrac = None + #if regrid_ofrac: + # if 'OCNFRAC' in model_dataset: + # mdat_ofrac = model_dataset['OCNFRAC'].squeeze() + + #Check if variable has a vertical component: + if 'lev' in mdata.dims or 'ilev' in mdata.dims: + has_lev = True + + #If lev exists, then determine what kind of vertical coordinate + #is being used: + if 'lev' in mdata.dims: + lev_attrs = model_dataset['lev'].attrs + elif 'ilev' in mdata.dims: + lev_attrs = model_dataset['ilev'].attrs + + #First check if there is a "vert_coord" attribute: + if 'vert_coord' in lev_attrs: + vert_coord_type = lev_attrs['vert_coord'] + else: + #Next check that the "long_name" attribute exists: + if 'long_name' in lev_attrs: + #Extract long name: + lev_long_name = lev_attrs['long_name'] + + #Check for "keywords" in the long name: + if 'hybrid level' in lev_long_name: + #Set model to hybrid vertical levels: + vert_coord_type = "hybrid" + elif 'zeta level' in lev_long_name: + #Set model to height (z) vertical levels: + vert_coord_type = "height" + else: + #Print a warning, and skip variable re-gridding/interpolation: + wmsg = "WARNING! Unable to determine the vertical coordinate" + wmsg +=f" type from the 'lev' long name, which is:\n'{lev_long_name}'" + print(wmsg) + return None + #End if + + else: + #Print a warning, and assume hybrid levels (for now): + wmsg = "WARNING! No long name found for the 'lev' dimension," + wmsg += f" so no re-gridding/interpolation will be done." + print(wmsg) + return None + #End if + #End if + + else: + has_lev = False + #End if + + #Check if variable has a vertical levels dimension: + if has_lev: + + if vert_coord_type == "hybrid": + # Need hyam, hybm, and P0 for vertical interpolation of hybrid levels: + if 'lev' in mdata.dims: + if ('hyam' not in model_dataset) or ('hybm' not in model_dataset): + print(f"!! PROBLEM -- NO hyam or hybm for 3-D variable {var_name}, so it will not be re-gridded.") + return None #Return None to skip to next variable. + #End if + mhya = model_dataset['hyam'] + mhyb = model_dataset['hybm'] + elif 'ilev' in mdata.dims: + if ('hyai' not in model_dataset) or ('hybi' not in model_dataset): + print(f"!! PROBLEM -- NO hyai or hybi for 3-D variable {var_name}, so it will not be re-gridded.") + return None #Return None to skip to next variable. 
+ #End if + mhya = model_dataset['hyai'] + mhyb = model_dataset['hybi'] + if 'time' in mhya.dims: + mhya = mhya.isel(time=0).squeeze() + if 'time' in mhyb.dims: + mhyb = mhyb.isel(time=0).squeeze() + if 'P0' in model_dataset: + P0_tmp = model_dataset['P0'] + if isinstance(P0_tmp, xr.DataArray): + #All of these value should be the same, + #so just grab the first one: + P0 = P0_tmp[0] + else: + #Just rename variable: + P0 = P0_tmp + #End if + else: + P0 = 100000.0 # Pa + #End if + + elif vert_coord_type == "height": + #Initialize already-regridded PMID logical: + regridded_pmid = False + + #Need mid-level pressure for vertical interpolation of height levels: + if 'PMID' in model_dataset: + mpmid = model_dataset['PMID'] + else: + #Check if target has an associated surface pressure field: + if pmid_file: + mpmid_ds = xr.open_dataset(pmid_file) + mpmid = mpmid_ds['PMID'] + #This mid-level pressure field has already been regridded: + regridded_pmid = True + else: + print(f"!! PROBLEM -- NO PMID for 3-D variable {var_name}, so it will not be re-gridded.") + return None + #End if + #End if + #End if (vert_coord_type) + + #It is probably good to try and acquire PS for all vertical coordinate types, so try here: + regridded_ps = False + if 'PS' in model_dataset: + mps = model_dataset['PS'] + else: + #Check if target has an associated surface pressure field: + if ps_file: + mps_ds = xr.open_dataset(ps_file) + mps = mps_ds['PS'] + #This surface pressure field has already been regridded: + regridded_ps = True + else: + print(f"!! PROBLEM -- NO PS for 3-D variable {var_name}, so it will not be re-gridded.") + return None + #End if + #End if + #End if (has_lev) + + #Regrid variable to target dataset (if available): + if regrid_dataset: + + #Extract grid info from target data: + if 'time' in regrid_dataset.coords: + if 'lev' in regrid_dataset.coords: + tgrid = regrid_dataset.isel(time=0, lev=0).squeeze() + else: + tgrid = regrid_dataset.isel(time=0).squeeze() + #End if + #End if + + #Regrid model data to match target grid: + rgdata = regrid_data(mdata, tgrid, method=1) + if mdat_ofrac: + rgofrac = regrid_data(mdat_ofrac, tgrid, method=1) + #Regrid surface pressure if need be: + if has_lev: + if not regridded_ps: + rg_ps = regrid_data(mps, tgrid, method=1) + else: + rg_ps = mps + #End if + + #Also regrid mid-level pressure if need be: + if vert_coord_type == "height": + if not regridded_pmid: + rg_pmid = regrid_data(mpmid, tgrid, method=1) + else: + rg_pmid = mpmid + #End if + #End if + #End if + else: + #Just rename variables: + rgdata = mdata + if has_lev: + rg_ps = mps + if vert_coord_type == "height": + rg_pmid = mpmid + #End if + #End if + #End if + + #Vertical interpolation: + + #Interpolate variable to default pressure levels: + if has_lev: + + if vert_coord_type == "hybrid": + #Interpolate from hybrid sigma-pressure to the standard pressure levels: + rgdata_interp = utils.lev_to_plev(rgdata, rg_ps, mhya, mhyb, P0=P0, \ + convert_to_mb=True) + elif vert_coord_type == "height": + #Interpolate variable using mid-level pressure (PMID): + rgdata_interp = utils.pmid_to_plev(rgdata, rg_pmid, convert_to_mb=True) + else: + #The vertical coordinate type is un-recognized, so print warning and + #skip vertical interpolation: + wmsg = f"WARNING! Un-recognized vertical coordinate type: '{vert_coord_type}'," + wmsg += f" for variable '{var_name}'. Skipping vertical interpolation." 
+ print(wmsg) + #Don't process variable: + return None + #End if + else: + #Just rename variable: + rgdata_interp = rgdata + #End if + + #Convert to xarray dataset: + rgdata_interp = rgdata_interp.to_dataset() + if mdat_ofrac: + rgdata_interp['OCNFRAC'] = rgofrac + + #Add surface pressure to variable if a hybrid (just in case): + if has_lev: + rgdata_interp['PS'] = rg_ps + + #Update "vert_coord" attribute for variable "lev": + rgdata_interp['lev'].attrs.update({"vert_coord": "pressure"}) + #End if + + #Return dataset: + return rgdata_interp + +##### + +def save_to_nc(tosave, outname, attrs=None, proc=None): + """Saves xarray variable to new netCDF file + + Parameters + ---------- + tosave : xarray.Dataset or xarray.DataArray + data to write to file + outname : str or Path + output netCDF file path + attrs : dict, optional + attributes dictionary for data + proc : str, optional + string to append to "Processing_info" attribute + """ + + xo = tosave + # deal with getting non-nan fill values. + if isinstance(xo, xr.Dataset): + enc_dv = {xname: {'_FillValue': None} for xname in xo.data_vars} + else: + enc_dv = {} + #End if + enc_c = {xname: {'_FillValue': None} for xname in xo.coords} + enc = {**enc_c, **enc_dv} + if attrs is not None: + xo.attrs = attrs + if proc is not None: + xo.attrs['Processing_info'] = f"Start from file {origname}. " + proc + xo.to_netcdf(outname, format='NETCDF4', encoding=enc) + +##### + +def regrid_data(fromthis, tothis, method=1): + """Regrid between lat-lon grids using various different methods + + Parameters + ---------- + fromthis : xarray.DataArray + original data + tothis : xarray.DataArray + provides destination grid information (regular lat-lon) + method : int, optional + method to use for regridding + 1 - xarray, `interp_like` + 2 - xarray, `interp` + 3 - xESMF, `Regridder()` + 4 - GeoCAT, `linint2` (may be deprecated) + + Returns + ------- + xarray.DataArray + Data interpolated to destination grid + + Notes + ----- + 1. xarray's interpolation does not respect longitude's periodicity + 2. xESMF can sometimes malfunction depending on dependencies + 3. GeoCAT `linint2` might be deprecated + + A more robust regridding solution is being explored. + + """ + + if method == 1: + # kludgy: spatial regridding only, seems like can't automatically deal with time + if 'time' in fromthis.coords: + result = [fromthis.isel(time=t).interp_like(tothis) for t,time in enumerate(fromthis['time'])] + result = xr.concat(result, 'time') + return result + else: + return fromthis.interp_like(tothis) + elif method == 2: + newlat = tothis['lat'] + newlon = tothis['lon'] + coords = dict(fromthis.coords) + coords['lat'] = newlat + coords['lon'] = newlon + return fromthis.interp(coords) + elif method == 3: + newlat = tothis['lat'] + newlon = tothis['lon'] + ds_out = xr.Dataset({'lat': newlat, 'lon': newlon}) + regridder = xe.Regridder(fromthis, ds_out, 'bilinear') + return regridder(fromthis) + elif method==4: + # geocat + newlat = tothis['lat'] + newlon = tothis['lon'] + result = geocat.comp.linint2(fromthis, newlon, newlat, False) + result.name = fromthis.name + return result + #End if + +##### \ No newline at end of file diff --git a/scripts/regridding/regrid_and_vert_interp_2.py b/scripts/regridding/regrid_and_vert_interp_2.py deleted file mode 100644 index 240f7f27a..000000000 --- a/scripts/regridding/regrid_and_vert_interp_2.py +++ /dev/null @@ -1,384 +0,0 @@ -"""Driver for horizontal and vertical interpolation. 
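
The archived `regrid_data` above selects its strategy from the integer `method` argument; a brief usage sketch (with hypothetical DataArrays `da_src` on the model grid and `da_tgt` on the reference grid):

    # Method 1 (default): xarray's interp_like, applied one time slice at a time.
    rg_default = regrid_data(da_src, da_tgt, method=1)

    # Method 3: xESMF bilinear regridding (requires the optional xesmf import).
    rg_xesmf = regrid_data(da_src, da_tgt, method=3)
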
-""" -from pathlib import Path - -import numpy as np -import xarray as xr -import xesmf as xe - -import adf_utils as utils - - -# Default pressure levels for vertical interpolation -DEFAULT_PLEVS = [ - 1000, 925, 850, 700, 500, 400, 300, 250, 200, 150, 100, 70, 50, - 30, 20, 10, 7, 5, 3, 2, 1 -] - -def regrid_and_vert_interp_2(adf): - """ - Regrids the test cases to the same horizontal - grid as the reference climatology and vertically - interpolates the test case (and reference if needed) - to match a default set of pressure levels (in hPa). - """ - msg = "\n Regridding CAM climatologies..." - print(f"{msg}\n {'-' * (len(msg)-3)}") - - overwrite_regrid = adf.get_basic_info("cam_overwrite_regrid", required=True) - output_loc = adf.get_basic_info("cam_regrid_loc", required=True) - output_loc = [Path(i) for i in output_loc] - var_list = adf.diag_var_list - var_defaults = adf.variable_defaults - - case_names = adf.get_cam_info("cam_case_name", required=True) - syear_cases = adf.climo_yrs["syears"] - eyear_cases = adf.climo_yrs["eyears"] - - # Move critical variables to the front of the list - for var in ["PMID", "OCNFRAC", "LANDFRAC", "PS"]: - if var in var_list: - var_list.insert(0, var_list.pop(var_list.index(var))) - - for case_idx, case_name in enumerate(case_names): - print(f"\t Regridding case '{case_name}':") - syear = syear_cases[case_idx] - eyear = eyear_cases[case_idx] - case_output_loc = output_loc[case_idx] - case_output_loc.mkdir(parents=True, exist_ok=True) - - for var in var_list: - print(f"Regridding variable: {var}") - # reset variables - model_ds = None - ref_ds = None - target_name = None - regridded_file_loc = None - model_da = None - ref_da = None - regridder = None - interp_da = None - - model_ds = adf.data.load_climo_ds(case_name, var) - if var in adf.data.ref_var_nam: - ref_ds = adf.data.load_reference_climo_ds(adf.data.ref_case_label, var) - target_name = adf.data.ref_labels[var] - else: - print(f"No reference data available for {var}.") - continue - if not ref_ds: - print(f"Missing reference data for {var}. Skipping.") - continue - if not model_ds: - print(f"Missing model data for {var}. Skipping.") - continue - - regridded_file_loc = case_output_loc / f'{target_name}_{case_name}_{var}_regridded.nc' - - if regridded_file_loc.is_file() and not overwrite_regrid: - print(f"\t INFO: Regridded file already exists, skipping: {regridded_file_loc}") - continue - - if regridded_file_loc.is_file() and overwrite_regrid: - regridded_file_loc.unlink() - - model_da = model_ds[var].squeeze() - ref_da = ref_ds[adf.data.ref_var_nam[var]].squeeze() - - # --- Horizontal Regridding --- - regridded_da = _handle_horizontal_regridding(model_da, ref_ds, adf, case_index=case_idx) - - # --- Vertical Interpolation --- - vert_type = _determine_vertical_coord_type(model_da) - - ps_da = None - if vert_type == 'hybrid': - # For hybrid, we need surface pressure on the target grid. - # It's assumed PS is processed first and is available. 
- ps_regridded_path = case_output_loc / f'{target_name}_{case_name}_PS_regridded.nc' - if ps_regridded_path.exists(): - ps_da = xr.open_dataset(ps_regridded_path)['PS'] - else: - # Regrid PS on the fly if not found - print("\t INFO: Regridding PS on the fly for hybrid interpolation.") - ps_da_source = adf.data.load_climo_da(case_name, 'PS')['PS'].squeeze() - ps_da = _handle_horizontal_regridding(ps_da_source, ref_da, adf, case_index=case_idx) - - interp_da = _handle_vertical_interpolation(regridded_da, vert_type, model_ds, ps_da=ps_da) - - # --- Masking --- - var_default_dict = var_defaults.get(var, {}) - if 'mask' in var_default_dict and var_default_dict['mask'].lower() == 'ocean': - ocn_frac_regridded_path = case_output_loc / f'{target_name}_{case_name}_OCNFRAC_regridded.nc' - if ocn_frac_regridded_path.exists(): - ocn_frac_da = xr.open_dataset(ocn_frac_regridded_path)['OCNFRAC'] - interp_da = _apply_ocean_mask(interp_da, ocn_frac_da) - else: - print(f"\t WARNING: OCNFRAC not found, unable to apply mask to '{var}'") - - # --- Save to file --- - final_ds = interp_da.to_dataset(name=var) - - # Add back other variables if they were in the original file (like PS, OCNFRAC) - if var == 'OCNFRAC': - final_ds = final_ds # it is already there - if var == 'PS': - final_ds = final_ds # it is already there - - - test_attrs_dict = { - "adf_user": adf.user, - "climo_yrs": f"{case_name}: {syear}-{eyear}", - "climatology_files": str(adf.data.get_climo_file(case_name, var)), - } - final_ds = final_ds.assign_attrs(test_attrs_dict) - - print(f"\t INFO: Saving regridded file: {regridded_file_loc}") - save_to_nc(final_ds, regridded_file_loc) - - print(" ...CAM climatologies have been regridded successfully.") - -def _handle_horizontal_regridding(source_da, target_grid, adf, method='conservative', case_index=None): - """ - Performs horizontal regridding using xesmf. - Manages and reuses regridding weight files. - - Parameters - ---------- - source_da : xarray.DataArray - The DataArray to regrid. - target_grid : xarray.Dataset - A dataset defining the target grid. - adf : adf_diag.AdfDiag - The ADF diagnostics object, used to get output locations. - method : str, optional - Regridding method. Defaults to 'conservative'. - case_index: str - For multi-case, need to provide the case name. - Returns - ------- - xarray.DataArray - The regridded DataArray. - """ - - # Generate a unique name for the weights file - source_grid_type = "unstructured" if "ncol" in source_da.dims else "structured" - target_grid_type = "unstructured" if "ncol" in target_grid.dims else "structured" - - # A simple naming convention for weight files. 
- source_grid_desc = f"{source_grid_type}_{len(source_da.lat)}_{len(source_da.lon)}" if source_grid_type == "structured" else f"{source_grid_type}_{len(source_da.ncol)}" - target_grid_desc = f"{target_grid_type}_{len(target_grid.lat)}_{len(target_grid.lon)}" if target_grid_type == "structured" else f"{target_grid_type}_{len(target_grid.ncol)}" - - if target_grid_type == "structured": - target_grid = _create_clean_grid(target_grid) - if source_grid_type == "structured": - source_grid = _create_clean_grid(source_da) - - regrid_loc = adf.get_basic_info("cam_regrid_loc", required=True) - if isinstance(regrid_loc, list) and len(regrid_loc)>1: - regrid_loc = regrid_loc[case_index] - else: - regrid_loc = regrid_loc[0] - regrid_loc = Path(regrid_loc) - regrid_weights_dir = regrid_loc / "regrid_weights" - regrid_weights_dir.mkdir(exist_ok=True) - weights_file = regrid_weights_dir / f"weights_{source_grid_desc}_to_{target_grid_desc}_{method}.nc" - if weights_file.exists(): - print(f"INFO: Using existing regridding weights file: {weights_file}") - # xesmf can accept a path to a weights file - regridder = xe.Regridder(source_da, target_grid, method, weights=str(weights_file)) - else: - print(f"INFO: Creating new regridding weights file: {weights_file}") - regridder = xe.Regridder(source_grid, target_grid, method) - regridder.to_netcdf(weights_file) - return regridder(source_da) - -def _create_clean_grid(ds): - """ - Creates a minimal, CF-compliant xarray Dataset for xesmf. - """ - - # Extract raw values - lat_centers = ds.lat.values - lon_centers = ds.lon.values - - # Build basic Dataset - clean_ds = xr.Dataset( - coords={ - "lat": (["lat"], lat_centers, {"units": "degrees_north", "standard_name": "latitude"}), - "lon": (["lon"], lon_centers, {"units": "degrees_east", "standard_name": "longitude"}), - } - ) - - # Add Bounds as vertices if they exist - if 'lat_bnds' in ds and 'lon_bnds' in ds: - lat_v = np.append(ds.lat_bnds.values[:, 0], ds.lat_bnds.values[-1, 1]) - lon_v = np.append(ds.lon_bnds.values[:, 0], ds.lon_bnds.values[-1, 1]) - - # Clip to avoid ESMF range errors - lat_v = np.clip(lat_v, -90, 90) - - # xesmf looks for 'lat_b' and 'lon_b' in the dataset for conservative regridding - clean_ds["lat_b"] = (["lat_f"], lat_v, {"units": "degrees_north"}) - clean_ds["lon_b"] = (["lon_f"], lon_v, {"units": "degrees_east"}) - - return clean_ds - -def _determine_vertical_coord_type(dset): - """ - Determines the type of vertical coordinate in a dataset. - - Parameters - ---------- - dset : xarray.Dataset - The dataset to inspect. - - Returns - ------- - str - The vertical coordinate type: 'hybrid', 'height', 'pressure', or 'none'. - """ - - if 'lev' in dset.dims or 'ilev' in dset.dims: - lev_coord_name = 'lev' if 'lev' in dset.dims else 'ilev' - lev_attrs = dset[lev_coord_name].attrs - - if 'vert_coord' in lev_attrs: - return lev_attrs['vert_coord'] - - if 'long_name' in lev_attrs: - lev_long_name = lev_attrs['long_name'] - if 'hybrid level' in lev_long_name: - return "hybrid" - if 'pressure level' in lev_long_name: - return "pressure" - if 'zeta level' in lev_long_name: - return "height" - - # If no specific metadata is found, make an educated guess. - # This part might need refinement based on expected data conventions. - if 'hyam' in dset or 'hyai' in dset: - return "hybrid" - - print(f"WARNING: Vertical coordinate type for '{lev_coord_name}' could not be determined. 
Assuming 'pressure'.") - return "pressure" - - return 'none' - -def _handle_vertical_interpolation(da, vert_type, source_ds, ps_da=None): - """ - Performs vertical interpolation to default pressure levels. - - Parameters - ---------- - da : xarray.DataArray - The DataArray to interpolate. - vert_type : str - The vertical coordinate type ('hybrid', 'height', 'pressure'). - source_ds : xarray.Dataset - The source dataset containing auxiliary variables (e.g., hyam, hybm). - ps_da : xarray.DataArray, optional - Surface pressure DataArray, required for hybrid coordinates. - - Returns - ------- - xarray.DataArray - The vertically interpolated DataArray. - """ - if vert_type == 'none': - return da - - if vert_type == "hybrid": - if ps_da is None: - raise ValueError("Surface pressure ('PS') is required for hybrid vertical interpolation.") - - lev_coord_name = 'lev' if 'lev' in source_ds.dims else 'ilev' - hyam_name = 'hyam' if lev_coord_name == 'lev' else 'hyai' - hybm_name = 'hybm' if lev_coord_name == 'lev' else 'hybi' - - if hyam_name not in source_ds or hybm_name not in source_ds: - raise ValueError(f"Hybrid coefficients ('{hyam_name}', '{hybm_name}') not found in dataset.") - - hyam = source_ds[hyam_name] - hybm = source_ds[hybm_name] - - if 'time' in hyam.dims: - hyam = hyam.isel(time=0).squeeze() - if 'time' in hybm.dims: - hybm = hybm.isel(time=0).squeeze() - - p0 = source_ds.get('P0', 100000.0) - if isinstance(p0, xr.DataArray): - p0 = p0.values[0] - - # hot fix for lev attributes - da[lev_coord_name].attrs["axis"] = "Z" - da[lev_coord_name].attrs["positive"] = "down" # standard for pressure/hybrid - da[lev_coord_name].attrs["standard_name"] = "atmosphere_hybrid_sigma_pressure_coordinate" - - return utils.lev_to_plev(da, ps_da, hyam, hybm, P0=p0, convert_to_mb=True, new_levels=DEFAULT_PLEVS) - - elif vert_type == "height": - pmid = source_ds.get('PMID') - if pmid is None: - raise ValueError("'PMID' is required for height vertical interpolation.") - return utils.pmid_to_plev(da, pmid, convert_to_mb=True, new_levels=DEFAULT_PLEVS) - - elif vert_type == "pressure": - return utils.plev_to_plev(da, new_levels=DEFAULT_PLEVS, convert_to_mb=True) - - else: - raise ValueError(f"Unknown vertical coordinate type: '{vert_type}'") - -def _apply_ocean_mask(da, ocn_frac_da): - """ - Applies an ocean mask to a DataArray. - - Parameters - ---------- - da : xarray.DataArray - The DataArray to mask. - ocn_frac_da : xarray.DataArray - The ocean fraction DataArray. - - Returns - ------- - xarray.DataArray - The masked DataArray. - """ - # Ensure ocean fraction is between 0 and 1 - ocn_frac_da = ocn_frac_da.clip(0, 1) - - # Apply the mask - return utils.mask_land_or_ocean(da, ocn_frac_da) - -def save_to_nc(tosave, outname, attrs=None, proc=None): - """Saves xarray variable to new netCDF file - - Parameters - ---------- - tosave : xarray.Dataset or xarray.DataArray - data to write to file - outname : str or Path - output netCDF file path - attrs : dict, optional - attributes dictionary for data - proc : str, optional - string to append to "Processing_info" attribute - """ - - xo = tosave - # deal with getting non-nan fill values. 
- if isinstance(xo, xr.Dataset): - enc_dv = {xname: {'_FillValue': None} for xname in xo.data_vars} - else: - enc_dv = {} - #End if - enc_c = {xname: {'_FillValue': None} for xname in xo.coords} - enc = {**enc_c, **enc_dv} - if attrs is not None: - xo.attrs = attrs - if proc is not None: - origname = tosave.attrs.get('climatology_files', 'unknown') - xo.attrs['Processing_info'] = f"Start from file {origname}. " + proc - xo.to_netcdf(outname, format='NETCDF4', encoding=enc) From bc838946da522fb81150e9ea73c26aeb0182ff55 Mon Sep 17 00:00:00 2001 From: Brian Medeiros Date: Thu, 15 Jan 2026 14:01:06 -0700 Subject: [PATCH 4/4] change logging level in taylor diagram --- scripts/plotting/cam_taylor_diagram.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/plotting/cam_taylor_diagram.py b/scripts/plotting/cam_taylor_diagram.py index 6228074db..d5d4b0234 100644 --- a/scripts/plotting/cam_taylor_diagram.py +++ b/scripts/plotting/cam_taylor_diagram.py @@ -39,9 +39,9 @@ logger = logging.getLogger(__name__) console_handler = logging.StreamHandler(sys.stdout) -console_handler.setLevel(logging.DEBUG) +console_handler.setLevel(logging.INFO) logger.addHandler(console_handler) -logger.setLevel(logging.DEBUG) +logger.setLevel(logging.INFO) logger.propagate = False
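
For context on why the handler and logger levels move together in the change above: the standard library filters each record first at the logger's level and then at every handler's level, so leaving either at DEBUG would defeat the change. A minimal stdlib sketch:

    import logging
    import sys

    logger = logging.getLogger("demo")
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    logger.addHandler(console_handler)
    logger.setLevel(logging.INFO)

    logger.debug("dropped: below the logger's INFO threshold")
    logger.info("emitted: clears both the logger and handler thresholds")
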