From 9244d6a73d09cd898d959d49bc8b464949940d76 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Tue, 28 Oct 2025 20:30:16 -0600 Subject: [PATCH 01/24] inlined compute_energy --- PyOCN/c_src/ocn.c | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/PyOCN/c_src/ocn.c b/PyOCN/c_src/ocn.c index 8e0d75b..84c448a 100644 --- a/PyOCN/c_src/ocn.c +++ b/PyOCN/c_src/ocn.c @@ -52,6 +52,10 @@ static inline Status update_drained_area(FlowGrid *G, drainedarea_t da_inc, lini } double ocn_compute_energy(FlowGrid *G, double gamma){ + return compute_energy(G, gamma); +} + +static inline double compute_energy(FlowGrid *G, double gamma){ double energy = 0.0; for (linidx_t i = 0; i < (linidx_t)G->dims.row * (linidx_t)G->dims.col; i++){ energy += pow(G->vertices[i].drained_area, gamma); @@ -167,10 +171,10 @@ Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature){ upstream vertices) into this function, instead of just passing da_inc. */ if ((G->nroots > 1) && (gamma < 1.0)){ - // energy_old = ocn_compute_energy(G, gamma); // recompute energy from scratch + // energy_old = compute_energy(G, gamma); // recompute energy from scratch update_drained_area(G, -da_inc, a_down_old); // remove drainage from old path update_drained_area(G, da_inc, a_down_new); // add drainage to new path - energy_new = ocn_compute_energy(G, gamma); // recompute energy from scratch + energy_new = compute_energy(G, gamma); // recompute energy from scratch if (simulate_annealing(energy_new, energy_old, temperature, &G->rng)){ G->energy = energy_new; return SUCCESS; From fe6fe4afc5f8213f289fa2675e040db3477bde3f Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Tue, 28 Oct 2025 22:20:13 -0600 Subject: [PATCH 02/24] bug fix --- PyOCN/c_src/ocn.c | 9 +++++---- sandbox.py | 39 ++++++++++++++++++++++----------------- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/PyOCN/c_src/ocn.c b/PyOCN/c_src/ocn.c index 84c448a..90122c7 100644 --- a/PyOCN/c_src/ocn.c +++ b/PyOCN/c_src/ocn.c @@ -51,10 +51,6 @@ static inline Status update_drained_area(FlowGrid *G, drainedarea_t da_inc, lini return SUCCESS; } -double ocn_compute_energy(FlowGrid *G, double gamma){ - return compute_energy(G, gamma); -} - static inline double compute_energy(FlowGrid *G, double gamma){ double energy = 0.0; for (linidx_t i = 0; i < (linidx_t)G->dims.row * (linidx_t)G->dims.col; i++){ @@ -63,6 +59,11 @@ static inline double compute_energy(FlowGrid *G, double gamma){ return energy; } +double ocn_compute_energy(FlowGrid *G, double gamma){ + return compute_energy(G, gamma); +} + + /** * @brief Update the energy of the flowgrid along a single downstream path from a given vertex. Unsafe. * This function only works correctly if there is a single root in the flowgrid. diff --git a/sandbox.py b/sandbox.py index 92bb830..e7d1dac 100644 --- a/sandbox.py +++ b/sandbox.py @@ -1,25 +1,30 @@ +from time import perf_counter as timer import matplotlib.pyplot as plt import PyOCN as po -import numpy as np -ocn = po.OCN.from_net_type( - net_type="E", - dims=(16, 16), - wrap=True, - random_state=84712, -) -ocn.fit_custom_cooling(lambda t: np.ones_like(t)*1e-1, pbar=True, n_iterations=16*16*100, max_iterations_per_loop=1) -ocn.fit(max_iterations_per_loop=1) - - - -# print(ocn.history.shape) -# print(np.max(np.diff(ocn.history[:, 1]))) -# print(np.quantile(np.diff(ocn.history[:, 1]), 0.999)) +rng = 8472 +times = [] +for _ in range(5): + ocn = po.OCN.from_net_type( + net_type="I", + dims=(100, 100), + wrap=True, + random_state=rng, + ) + timer_start = timer() + ocn.fit(max_iterations_per_loop=10_000, pbar=True) + timer_end = timer() + rng += 1 + times.append(timer_end - timer_start) +print(f"Average time over 5 runs: {sum(times)/len(times):.2f} seconds") +print(f"Min time over 5 runs: {min(times):.2f} seconds") +print(f"Max time over 5 runs: {max(times):.2f} seconds") +print(f"Mean time over 5 runs: {sum(times)/len(times):.2f} seconds") +print(f"Final energy: {ocn.energy:.6f}") plt.plot(ocn.history[:, 0], ocn.history[:, 1]) plt.plot(ocn.history[:, 0], ocn.history[:, 2]) -# plt.xscale("log") -# plt.yscale("log") +plt.xscale("log") +plt.yscale("log") plt.show() \ No newline at end of file From 38c9e0264d9e1992230b145afde935f184f318e9 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Tue, 28 Oct 2025 22:29:25 -0600 Subject: [PATCH 03/24] undid some of the inlining, which were likely not helpful. --- PyOCN/c_src/ocn.c | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/PyOCN/c_src/ocn.c b/PyOCN/c_src/ocn.c index 90122c7..869faf1 100644 --- a/PyOCN/c_src/ocn.c +++ b/PyOCN/c_src/ocn.c @@ -37,7 +37,7 @@ static inline bool simulate_annealing(double energy_new, double energy_old, doub * @param a The linear index of the starting vertex. * @return Status code indicating success or failure */ -static inline Status update_drained_area(FlowGrid *G, drainedarea_t da_inc, linidx_t a){ +static Status update_drained_area(FlowGrid *G, drainedarea_t da_inc, linidx_t a){ Vertex vert; do { Status code = fg_get_lin(&vert, G, a); @@ -63,7 +63,6 @@ double ocn_compute_energy(FlowGrid *G, double gamma){ return compute_energy(G, gamma); } - /** * @brief Update the energy of the flowgrid along a single downstream path from a given vertex. Unsafe. * This function only works correctly if there is a single root in the flowgrid. @@ -74,7 +73,7 @@ double ocn_compute_energy(FlowGrid *G, double gamma){ * @param gamma The exponent used in the energy calculation. * @return Status code indicating success or failure */ -Status update_energy_single_root(FlowGrid *G, drainedarea_t da_inc, linidx_t a, double gamma){ +static Status update_energy_single_root(FlowGrid *G, drainedarea_t da_inc, linidx_t a, double gamma){ Vertex vert; double energy_old = 0.0; double energy_new = 0.0; From 9c8320c8cedb65c9fbb9f9dc7a96b7aa1ea910f5 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Tue, 28 Oct 2025 23:18:00 -0600 Subject: [PATCH 04/24] multithreading support for parallel runs --- PyOCN/utils.py | 80 +++++++++++++++++++++++++++++++++++++++++++++++++- sandbox.py | 49 ++++++++++++++++++------------- 2 files changed, 108 insertions(+), 21 deletions(-) diff --git a/PyOCN/utils.py b/PyOCN/utils.py index 910dd62..dcaa741 100644 --- a/PyOCN/utils.py +++ b/PyOCN/utils.py @@ -3,12 +3,15 @@ """ from __future__ import annotations +from concurrent.futures import ThreadPoolExecutor from itertools import product +import os from typing import Any, Literal, Callable, TYPE_CHECKING, Union from numbers import Number import networkx as nx import numpy as np from tqdm import tqdm +from functools import partial if TYPE_CHECKING: from .ocn import OCN @@ -359,4 +362,79 @@ def get_subwatersheds(dag : nx.DiGraph, node : Any) -> set[nx.DiGraph]: subwatersheds = set(set(nx.ancestors(dag, outlet)) | {outlet} for outlet in subwatershed_outlets) subwatersheds = set(dag.subgraph(wshd) for wshd in subwatersheds) return subwatersheds - \ No newline at end of file + + +def multi_fit(ocn:OCN|list[OCN], n_runs=5, n_threads=None, fit_method:Literal["fit", "fit_custom_cooling"]|None=None, increment_rng:bool=True, pbar:bool=False, fit_kwargs:dict|list[dict]=None) -> tuple[list[Any], list[OCN]]: + """Perform multiple OCN fitting operations in parallel using multithreading. + Each fit uses the same set of parameters, but different random seeds. + Each fit's random seed is set to `ocn.rng + i`, where `i` is the thread index. + + Useful for doing sensitivity analysis or ensemble fitting. + + Parameters + ---------- + ocn : OCN | list[OCN] + The OCN instance(s) to fit. If a list of OCNs is provided, each OCN will be fitted independently. + n_runs : int, default 5 + The number of fitting runs to perform. Must be equal to the length of `ocn` if `ocn` is a list. + n_threads : int, default None + The number of worker threads to use for parallel execution. If None, defaults to the number of CPU cores, times 2. + fit_method : {"fit", "fit_custom_cooling"} | None, default None + The fitting method to use. If None, defaults to "fit". + increment_rng : bool, default True + If True, each fit's random seed is set to `ocn.rng + i`, where `i` is the thread index. + pbar : bool, default False + If True, display a master progress bar that tracks the completion of each thread. + fit_kwargs : dict | list[dict], optional + Additional keyword arguments to pass to the `OCN.fit` method. If a list of dicts is provided, + it must be of length `n_runs`, and each dict will be used for the corresponding fit. + + Returns + ------- + tuple[list[Any], list[OCN]] + A tuple containing two lists: + - A list of results from each fitting operation. + - A list of fitted OCN instances corresponding to each fitting operation. + """ + + if isinstance(ocn, list): + if len(ocn) != n_runs: + raise ValueError(f"When ocn is a list, its length must equal n_runs. Got len(ocn)={len(ocn)} and n_runs={n_runs}.") + else: + ocn = [ocn.copy() for _ in range(n_runs)] + + if isinstance(fit_kwargs, list): + if len(fit_kwargs) != n_runs: + raise ValueError(f"When fit_kwargs is a list, its length must equal n_runs. Got len(fit_kwargs)={len(fit_kwargs)} and n_runs={n_runs}.") + if not all(isinstance(k, dict) for k in fit_kwargs): + raise ValueError("All elements of fit_kwargs list must be dictionaries.") + elif not isinstance(fit_kwargs, dict) and fit_kwargs is not None: + raise ValueError(f"fit_kwargs must be a dict or a list of dicts. Got {type(fit_kwargs)}.") + else: + fit_kwargs = [fit_kwargs or dict() for _ in range(n_runs)] + + def fit(ocn, i, fit_kwargs): + if increment_rng: + ocn.rng = ocn.rng + i + if fit_method is None or fit_method == "fit": + res = ocn.fit(**fit_kwargs) + elif fit_method == "fit_custom_cooling": + res = ocn.fit_custom_cooling(**fit_kwargs) + else: + raise ValueError(f"Invalid fit_method {fit_method}. Must be one of 'fit', 'fit_custom_cooling', or None.") + return res + + if n_threads is None: + n_threads = os.cpu_count() + n_threads = min(os.cpu_count()*2, n_threads) + + with ThreadPoolExecutor(max_workers=n_threads) as executor: + futures = [] + for i in range(n_runs): + futures.append(executor.submit(fit, ocn[i], i, fit_kwargs[i])) + results = [] + pbar = tqdm(futures, disable=not pbar, desc="Fitting OCNs") + for i, future in enumerate(futures): + results.append(future.result()) + pbar.update(1) + return results, ocn \ No newline at end of file diff --git a/sandbox.py b/sandbox.py index e7d1dac..c126830 100644 --- a/sandbox.py +++ b/sandbox.py @@ -1,30 +1,39 @@ from time import perf_counter as timer +from concurrent.futures import ThreadPoolExecutor + import matplotlib.pyplot as plt +import numpy as np import PyOCN as po + rng = 8472 times = [] -for _ in range(5): - ocn = po.OCN.from_net_type( - net_type="I", - dims=(100, 100), - wrap=True, - random_state=rng, - ) - timer_start = timer() - ocn.fit(max_iterations_per_loop=10_000, pbar=True) - timer_end = timer() - rng += 1 - times.append(timer_end - timer_start) -print(f"Average time over 5 runs: {sum(times)/len(times):.2f} seconds") -print(f"Min time over 5 runs: {min(times):.2f} seconds") -print(f"Max time over 5 runs: {max(times):.2f} seconds") -print(f"Mean time over 5 runs: {sum(times)/len(times):.2f} seconds") -print(f"Final energy: {ocn.energy:.6f}") - -plt.plot(ocn.history[:, 0], ocn.history[:, 1]) -plt.plot(ocn.history[:, 0], ocn.history[:, 2]) + +gammas = np.linspace(0, 1, 51) +ocns = [po.OCN.from_net_type( + net_type="I", + gamma=gamma, + dims=(100, 100), + wrap=True, + random_state=rng, +) for gamma in gammas] + +fit_kwargs = { + "n_iterations": 64*64*40, + "pbar": False, + "max_iterations_per_loop": 2_000, +} +results, fitted_ocns = po.utils.multi_fit(ocns, len(gammas), pbar=True, fit_kwargs=fit_kwargs) + +for fitted_ocn in fitted_ocns: + plt.plot(fitted_ocn.history[:, 0], (fitted_ocn.history[:, 1] - fitted_ocn.history[-1, 1]) / (fitted_ocn.history[0, 1] - fitted_ocn.history[-1, 1]), alpha=0.5, color="C0") +# print(f"Average time over 5 runs: {sum(times)/len(times):.2f} seconds") +# print(f"Min time over 5 runs: {min(times):.2f} seconds") +# print(f"Max time over 5 runs: {max(times):.2f} seconds") +# print(f"Mean time over 5 runs: {sum(times)/len(times):.2f} seconds") +# print(f"Final energy: {ocn.energy:.6f}") + plt.xscale("log") plt.yscale("log") plt.show() \ No newline at end of file From eee994291178d5d20f4b34ce6fdd565e8410f7b2 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Tue, 28 Oct 2025 23:29:12 -0600 Subject: [PATCH 05/24] update parallel_fit convenience function. --- PyOCN/utils.py | 46 +++++++++++++++++++++++++++++++++------------- sandbox.py | 2 +- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/PyOCN/utils.py b/PyOCN/utils.py index dcaa741..a107c36 100644 --- a/PyOCN/utils.py +++ b/PyOCN/utils.py @@ -364,23 +364,31 @@ def get_subwatersheds(dag : nx.DiGraph, node : Any) -> set[nx.DiGraph]: return subwatersheds -def multi_fit(ocn:OCN|list[OCN], n_runs=5, n_threads=None, fit_method:Literal["fit", "fit_custom_cooling"]|None=None, increment_rng:bool=True, pbar:bool=False, fit_kwargs:dict|list[dict]=None) -> tuple[list[Any], list[OCN]]: - """Perform multiple OCN fitting operations in parallel using multithreading. - Each fit uses the same set of parameters, but different random seeds. - Each fit's random seed is set to `ocn.rng + i`, where `i` is the thread index. - +def parallel_fit( + ocn:OCN|list[OCN], + n_runs=5, + n_threads=None, + fit_method:Literal["fit", "fit_custom_cooling"]|list[Literal["fit", "fit_custom_cooling"]]|None=None, + increment_rng:bool=True, + pbar:bool=False, + fit_kwargs:dict|list[dict]=None +) -> tuple[list[Any], list[OCN]]: + """Convenience function to perform multiple OCN fitting operations in parallel using multithreading. Useful for doing sensitivity analysis or ensemble fitting. Parameters ---------- ocn : OCN | list[OCN] The OCN instance(s) to fit. If a list of OCNs is provided, each OCN will be fitted independently. + If a single OCN is provided, it will be copied `n_runs` times. + Not modified during fitting. n_runs : int, default 5 The number of fitting runs to perform. Must be equal to the length of `ocn` if `ocn` is a list. n_threads : int, default None The number of worker threads to use for parallel execution. If None, defaults to the number of CPU cores, times 2. - fit_method : {"fit", "fit_custom_cooling"} | None, default None - The fitting method to use. If None, defaults to "fit". + fit_method : {"fit", "fit_custom_cooling"} | list[{"fit", "fit_custom_cooling"} | None], default None + The fitting method to use. If None, defaults to "fit". If a list is provided, it must be of length `n_runs`, + and each element specifies the fitting method for the corresponding fit. increment_rng : bool, default True If True, each fit's random seed is set to `ocn.rng + i`, where `i` is the thread index. pbar : bool, default False @@ -400,6 +408,7 @@ def multi_fit(ocn:OCN|list[OCN], n_runs=5, n_threads=None, fit_method:Literal["f if isinstance(ocn, list): if len(ocn) != n_runs: raise ValueError(f"When ocn is a list, its length must equal n_runs. Got len(ocn)={len(ocn)} and n_runs={n_runs}.") + ocn = [o.copy() for o in ocn] else: ocn = [ocn.copy() for _ in range(n_runs)] @@ -413,7 +422,17 @@ def multi_fit(ocn:OCN|list[OCN], n_runs=5, n_threads=None, fit_method:Literal["f else: fit_kwargs = [fit_kwargs or dict() for _ in range(n_runs)] - def fit(ocn, i, fit_kwargs): + if isinstance(fit_method, list): + if len(fit_method) != n_runs: + raise ValueError(f"When fit_method is a list, its length must equal n_runs. Got len(fit_method)={len(fit_method)} and n_runs={n_runs}.") + elif not all(m is None or m in {"fit", "fit_custom_cooling"} for m in fit_method): + raise ValueError("All elements of fit_method list must be one of 'fit', 'fit_custom_cooling', or None.") + else: + if fit_method is not None and fit_method not in {"fit", "fit_custom_cooling"}: + raise ValueError(f"Invalid fit_method {fit_method}. Must be one of 'fit', 'fit_custom_cooling', or None.") + fit_method = [fit_method for _ in range(n_runs)] + + def fit(ocn, i, fit_method, fit_kwargs): if increment_rng: ocn.rng = ocn.rng + i if fit_method is None or fit_method == "fit": @@ -422,7 +441,7 @@ def fit(ocn, i, fit_kwargs): res = ocn.fit_custom_cooling(**fit_kwargs) else: raise ValueError(f"Invalid fit_method {fit_method}. Must be one of 'fit', 'fit_custom_cooling', or None.") - return res + return res, i # return index to place result correctly, in case of out-of-order completion. if n_threads is None: n_threads = os.cpu_count() @@ -431,10 +450,11 @@ def fit(ocn, i, fit_kwargs): with ThreadPoolExecutor(max_workers=n_threads) as executor: futures = [] for i in range(n_runs): - futures.append(executor.submit(fit, ocn[i], i, fit_kwargs[i])) - results = [] + futures.append(executor.submit(fit, ocn[i], i, fit_method[i], fit_kwargs[i])) + results = [None] * n_runs pbar = tqdm(futures, disable=not pbar, desc="Fitting OCNs") - for i, future in enumerate(futures): - results.append(future.result()) + for future in futures: + res, idx = future.result() + results[idx] = res pbar.update(1) return results, ocn \ No newline at end of file diff --git a/sandbox.py b/sandbox.py index c126830..393e3ea 100644 --- a/sandbox.py +++ b/sandbox.py @@ -24,7 +24,7 @@ "pbar": False, "max_iterations_per_loop": 2_000, } -results, fitted_ocns = po.utils.multi_fit(ocns, len(gammas), pbar=True, fit_kwargs=fit_kwargs) +results, fitted_ocns = po.utils.parallel_fit(ocns, len(gammas), pbar=True, fit_kwargs=fit_kwargs) for fitted_ocn in fitted_ocns: plt.plot(fitted_ocn.history[:, 0], (fitted_ocn.history[:, 1] - fitted_ocn.history[-1, 1]) / (fitted_ocn.history[0, 1] - fitted_ocn.history[-1, 1]), alpha=0.5, color="C0") From c2c83bba7a48f0e85f8c034beefc838881e988b9 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Tue, 28 Oct 2025 23:30:06 -0600 Subject: [PATCH 06/24] added TODO item: write test for parallel fit func --- tests/test_basic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_basic.py b/tests/test_basic.py index 603fbbc..1d33a2c 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -274,5 +274,6 @@ def test_greedy_optimization(self): ocn.fit_custom_cooling(lambda t: np.ones_like(t)*1e-7, pbar=False, n_iterations=16**2*100, max_iterations_per_loop=1) self.assertLessEqual(np.quantile(np.diff(ocn.history[:, 1]), 0.999) - 1e-7, 0, "Energy did not decrease monotonically.") + # TODO: write tests for the parallel fitting utility if __name__ == "__main__": unittest.main() \ No newline at end of file From 2e306f6d0755975ccbce5925e9738eb1a05f0780 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Wed, 29 Oct 2025 09:42:41 -0600 Subject: [PATCH 07/24] wrote tests for parallel fit method --- PyOCN/utils.py | 2 +- sandbox.py | 68 ++++++++++++++++++++++++--------------------- tests/test_basic.py | 34 +++++++++++++++++++++++ 3 files changed, 72 insertions(+), 32 deletions(-) diff --git a/PyOCN/utils.py b/PyOCN/utils.py index a107c36..2ffcaa3 100644 --- a/PyOCN/utils.py +++ b/PyOCN/utils.py @@ -444,7 +444,7 @@ def fit(ocn, i, fit_method, fit_kwargs): return res, i # return index to place result correctly, in case of out-of-order completion. if n_threads is None: - n_threads = os.cpu_count() + n_threads = os.cpu_count()*2 n_threads = min(os.cpu_count()*2, n_threads) with ThreadPoolExecutor(max_workers=n_threads) as executor: diff --git a/sandbox.py b/sandbox.py index 393e3ea..3ec7990 100644 --- a/sandbox.py +++ b/sandbox.py @@ -6,34 +6,40 @@ import numpy as np import PyOCN as po - -rng = 8472 -times = [] - -gammas = np.linspace(0, 1, 51) -ocns = [po.OCN.from_net_type( - net_type="I", - gamma=gamma, - dims=(100, 100), - wrap=True, - random_state=rng, -) for gamma in gammas] - -fit_kwargs = { - "n_iterations": 64*64*40, - "pbar": False, - "max_iterations_per_loop": 2_000, -} -results, fitted_ocns = po.utils.parallel_fit(ocns, len(gammas), pbar=True, fit_kwargs=fit_kwargs) - -for fitted_ocn in fitted_ocns: - plt.plot(fitted_ocn.history[:, 0], (fitted_ocn.history[:, 1] - fitted_ocn.history[-1, 1]) / (fitted_ocn.history[0, 1] - fitted_ocn.history[-1, 1]), alpha=0.5, color="C0") -# print(f"Average time over 5 runs: {sum(times)/len(times):.2f} seconds") -# print(f"Min time over 5 runs: {min(times):.2f} seconds") -# print(f"Max time over 5 runs: {max(times):.2f} seconds") -# print(f"Mean time over 5 runs: {sum(times)/len(times):.2f} seconds") -# print(f"Final energy: {ocn.energy:.6f}") - -plt.xscale("log") -plt.yscale("log") -plt.show() \ No newline at end of file +results, fitted_ocns = po.utils.parallel_fit( + ocn= po.OCN.from_net_type(net_type="I", dims=(16, 16), random_state=542390), + n_runs=5, + increment_rng=False, +) +print(results, fitted_ocns) + +# rng = 8472 +# times = [] + +# gammas = np.linspace(0, 1, 51) +# ocns = [po.OCN.from_net_type( +# net_type="I", +# gamma=gamma, +# dims=(100, 100), +# wrap=True, +# random_state=rng, +# ) for gamma in gammas] + +# fit_kwargs = { +# "n_iterations": 64*64*40, +# "pbar": False, +# "max_iterations_per_loop": 2_000, +# } +# results, fitted_ocns = po.utils.parallel_fit(ocns, len(gammas), pbar=True, fit_kwargs=fit_kwargs) + +# for fitted_ocn in fitted_ocns: +# plt.plot(fitted_ocn.history[:, 0], (fitted_ocn.history[:, 1] - fitted_ocn.history[-1, 1]) / (fitted_ocn.history[0, 1] - fitted_ocn.history[-1, 1]), alpha=0.5, color="C0") +# # print(f"Average time over 5 runs: {sum(times)/len(times):.2f} seconds") +# # print(f"Min time over 5 runs: {min(times):.2f} seconds") +# # print(f"Max time over 5 runs: {max(times):.2f} seconds") +# # print(f"Mean time over 5 runs: {sum(times)/len(times):.2f} seconds") +# # print(f"Final energy: {ocn.energy:.6f}") + +# plt.xscale("log") +# plt.yscale("log") +# plt.show() \ No newline at end of file diff --git a/tests/test_basic.py b/tests/test_basic.py index 1d33a2c..50afb2b 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -4,6 +4,7 @@ These tests verify core functionality with deterministic outputs. """ import unittest +from time import perf_counter as timer import numpy as np import networkx as nx import PyOCN as po @@ -274,6 +275,39 @@ def test_greedy_optimization(self): ocn.fit_custom_cooling(lambda t: np.ones_like(t)*1e-7, pbar=False, n_iterations=16**2*100, max_iterations_per_loop=1) self.assertLessEqual(np.quantile(np.diff(ocn.history[:, 1]), 0.999) - 1e-7, 0, "Energy did not decrease monotonically.") + def test_parallel_fit(self): + """Test parallel fitting utility.""" + + ocn= po.OCN.from_net_type(net_type="I", dims=(42, 40), random_state=542390) + t0 = timer() + ocn.fit() + t1 = timer() + single_run_time = t1 - t0 + energy = ocn.energy + + ocn = po.OCN.from_net_type(net_type="I", dims=(42, 40), random_state=542390) + t0 = timer() + _, fitted_ocns = po.utils.parallel_fit( + ocn=ocn, + n_runs=5, + increment_rng=False, + ) + t1 = timer() + parallel_time = t1 - t0 + energies = [o.energy for o in fitted_ocns] + + self.assertEqual(len(fitted_ocns), 5, "Number of fitted OCNs does not match number of runs.") + + for e in energies: + self.assertAlmostEqual(e, energy, places=6, msg="Energies from parallel fits do not match.") + print() + print(f"\tSingle run time: {single_run_time:.2f} seconds") + print(f"\tParallel fit time for 5 runs: {parallel_time:.2f} seconds") + self.assertLess(parallel_time, single_run_time * 4, "Parallel fitting did not speed up the process as expected.") + + for o in fitted_ocns: + self.assertAlmostEqual(ocn.energy, o.history[0, 1], places=6, msg="Initial energy was not preserved.") + # TODO: write tests for the parallel fitting utility if __name__ == "__main__": unittest.main() \ No newline at end of file From 799200b9aa0352d3b7d23def1cde348e87bdb2ef Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Wed, 29 Oct 2025 15:22:17 -0600 Subject: [PATCH 08/24] uh --- .vscode/settings.json | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index fda6998..7a40653 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,6 +4,8 @@ "neighbormatrix.c": "cpp", "stdbool.h": "c", "returncodes.h": "c", - "string.h": "c" + "string.h": "c", + "__locale": "c", + "bitset": "c" } } \ No newline at end of file From d587ffff22a49cf8ab18490a6f639c218aff9e1c Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Wed, 29 Oct 2025 15:50:52 -0600 Subject: [PATCH 09/24] implemented new multi-outlet method --- PyOCN/_libocn_bindings.py | 8 ++++---- PyOCN/c_src/ocn.c | 39 ++++++++++++--------------------------- PyOCN/c_src/ocn.h | 6 ++++-- PyOCN/ocn.py | 26 ++++++++++++++++++++++++-- 4 files changed, 44 insertions(+), 35 deletions(-) diff --git a/PyOCN/_libocn_bindings.py b/PyOCN/_libocn_bindings.py index ed60aa4..00262a6 100644 --- a/PyOCN/_libocn_bindings.py +++ b/PyOCN/_libocn_bindings.py @@ -143,12 +143,12 @@ class FlowGrid_C(Structure): libocn.ocn_compute_energy.argtypes = [POINTER(FlowGrid_C), c_double] libocn.ocn_compute_energy.restype = c_double -# Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature); -libocn.ocn_single_erosion_event.argtypes = [POINTER(FlowGrid_C), c_double, c_double] +# Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature, bool calculate_full_energy); +libocn.ocn_single_erosion_event.argtypes = [POINTER(FlowGrid_C), c_double, c_double, c_bool] libocn.ocn_single_erosion_event.restype = Status -# Status ocn_outer_ocn_loop(FlowGrid *G, uint32_t niterations, double gamma, double *annealing_schedule); -libocn.ocn_outer_ocn_loop.argtypes = [POINTER(FlowGrid_C), c_uint32, c_double, POINTER(c_double)] +# Status ocn_outer_ocn_loop(FlowGrid *G, uint32_t niterations, double gamma, double *annealing_schedule, bool calculate_full_energy); +libocn.ocn_outer_ocn_loop.argtypes = [POINTER(FlowGrid_C), c_uint32, c_double, POINTER(c_double), c_bool] libocn.ocn_outer_ocn_loop.restype = Status diff --git a/PyOCN/c_src/ocn.c b/PyOCN/c_src/ocn.c index 869faf1..45ab9a9 100644 --- a/PyOCN/c_src/ocn.c +++ b/PyOCN/c_src/ocn.c @@ -90,7 +90,7 @@ static Status update_energy_single_root(FlowGrid *G, drainedarea_t da_inc, linid return SUCCESS; } -Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature){ +Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature, bool calculate_full_energy){ Status code; Vertex vert; @@ -158,20 +158,9 @@ Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature){ mh_eval: - /* - TODO: PERFORMANCE ISSUE: - This function is supposed to update the energy of the flowgrid G after a - change in drained area along the path starting at vertex a. - - Simple but inefficient fix (current): recompute the *entire* energy of the flowgrid from scratch - each time this function is called. - - More complex fix: find the set of all upstream vertices that flow into a and compute - their summed contribution to the energy. Pass this value (sum of (da^gamma) for all - upstream vertices) into this function, instead of just passing da_inc. - */ - if ((G->nroots > 1) && (gamma < 1.0)){ - // energy_old = compute_energy(G, gamma); // recompute energy from scratch + + if (calculate_full_energy){ + energy_old = compute_energy(G, gamma); // recompute energy from scratch update_drained_area(G, -da_inc, a_down_old); // remove drainage from old path update_drained_area(G, da_inc, a_down_new); // add drainage to new path energy_new = compute_energy(G, gamma); // recompute energy from scratch @@ -179,31 +168,27 @@ Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature){ G->energy = energy_new; return SUCCESS; } - // reject swap: undo everything and try again - update_drained_area(G, da_inc, a_down_old); // add removed drainage back to old path - update_drained_area(G, -da_inc, a_down_new); // remove added drainage from new path - fg_change_vertex_outflow(G, a, down_old); // undo the outflow change - } else { // if there's only one root, we can use a more efficient method + } else { update_energy_single_root(G, -da_inc, a_down_old, gamma); // remove drainage from old path and update energy update_energy_single_root(G, da_inc, a_down_new, gamma); // add drainage to new path and update energy energy_new = G->energy; if (simulate_annealing(energy_new, energy_old, temperature, &G->rng)){ return SUCCESS; } - // reject swap: undo everything and try again - update_energy_single_root(G, da_inc, a_down_old, gamma); // add removed drainage back to old path and update energy - update_energy_single_root(G, -da_inc, a_down_new, gamma); // remove added drainage from new path and update energy - fg_change_vertex_outflow(G, a, down_old); // undo the outflow change } - + // reject swap: undo everything and try again + update_drained_area(G, da_inc, a_down_old); // add removed drainage back to old path and update energy + update_drained_area(G, -da_inc, a_down_new); // remove added drainage from new path and update energy + fg_change_vertex_outflow(G, a, down_old); // undo the outflow change + G->energy = energy_old; return EROSION_FAILURE; // if we reach here, we failed to find a valid swap in many, many tries } -Status ocn_outer_ocn_loop(FlowGrid *G, uint32_t niterations, double gamma, double *annealing_schedule){ +Status ocn_outer_ocn_loop(FlowGrid *G, uint32_t niterations, double gamma, double *annealing_schedule, bool calculate_full_energy){ Status code; for (uint32_t i = 0; i < niterations; i++){ - code = ocn_single_erosion_event(G, gamma, annealing_schedule[i]); + code = ocn_single_erosion_event(G, gamma, annealing_schedule[i], calculate_full_energy); if ((code != SUCCESS) && (code != EROSION_FAILURE)) return code; } return SUCCESS; diff --git a/PyOCN/c_src/ocn.h b/PyOCN/c_src/ocn.h index 5c0aa2b..efa9f81 100644 --- a/PyOCN/c_src/ocn.h +++ b/PyOCN/c_src/ocn.h @@ -31,9 +31,10 @@ double ocn_compute_energy(FlowGrid *G, double gamma); * @param G Pointer to the FlowGrid. * @param gamma The exponent used in the energy calculation. * @param temperature The temperature parameter for the Metropolis-Hastings acceptance criterion. + * @param calculate_full_energy If true, the full energy of the graph is recalculated when considering the proposed change. If false, a more efficient incremental update is used. full_energy_recalc will be slower, but avoid accumulated numerical errors over many iterations. * @return Status code indicating success or failure. */ -Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature); +Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature, bool calculate_full_energy); /** * @brief Perform multiple erosion events on the flowgrid. @@ -43,8 +44,9 @@ Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature); * @param gamma The exponent used in the energy calculation. * @param annealing_schedule An array of temperatures (ranging from 0-1) to use for each iteration. Length must be at least niterations. * @param wrap If true, allows wrapping around the edges of the grid (toroidal). If false, no wrapping is applied. + * @param calculate_full_energy If true, the full energy of the graph is recalculated after each proposed change. If false, a more efficient incremental update is used. full_energy_recalc will be slower, but avoid accumulated numerical errors over many iterations. * @return Status code indicating success or failure */ -Status ocn_outer_ocn_loop(FlowGrid *G, uint32_t niterations, double gamma, double *annealing_schedule); +Status ocn_outer_ocn_loop(FlowGrid *G, uint32_t niterations, double gamma, double *annealing_schedule, bool calculate_full_energy); #endif // OCN_H diff --git a/PyOCN/ocn.py b/PyOCN/ocn.py index bbbe506..6ce249d 100644 --- a/PyOCN/ocn.py +++ b/PyOCN/ocn.py @@ -598,7 +598,7 @@ def to_xarray(self, unwrap:bool=True) -> "xr.Dataset": } ) - def single_iteration(self, temperature:float, array_report:bool=False, unwrap:bool=True) -> "xr.Dataset | None": + def single_iteration(self, temperature:float, array_report:bool=False, unwrap:bool=True, calculate_full_energy:bool=True) -> "xr.Dataset | None": """ Perform a single iteration of the optimization algorithm at a given temperature. Updates the internal history attribute. See :meth:`fit` for details on the algorithm. @@ -619,6 +619,17 @@ def single_iteration(self, temperature:float, array_report:bool=False, unwrap:bo with some nan values. If False or the current OCN does not have periodic boundaries, then no transformation is applied and the resulting raster will have the same dimensions as the current OCN grid. + calculate_full_energy : bool, default True + If True, the full energy of the graph is recalculated when considering + the proposed change. If False, a more efficient incremental update is used. + full_energy_recalc will be slower, but avoid accumulated numerical errors + over many iterations. + + Returns + ------- + xr.Dataset | None + If ``array_report == True``, an xarray.Dataset containing the state of the FlowGrid + after the iteration. See :meth:`to_xarray` for details. If ``array_report == False``, returns None. Raises ------ @@ -630,6 +641,7 @@ def single_iteration(self, temperature:float, array_report:bool=False, unwrap:bo self.__p_c_graph, self.gamma, temperature, + calculate_full_energy )) # append to history @@ -654,7 +666,8 @@ def fit( array_reports:int=0, tol:float=None, max_iterations_per_loop=10_000, - unwrap:bool=True,) -> "xr.Dataset | None": + unwrap:bool=True, + calculate_full_energy:bool=False) -> "xr.Dataset | None": """ Convenience function to optimize the OCN using the simulated annealing algorithm from Carraro et al (2020). For finer control over the optimization process, use :meth:`fit_custom_cooling` or use :meth:`single_erosion_event` in a loop. @@ -706,6 +719,11 @@ def fit( with some nan values. If False or the current OCN does not have periodic boundaries, then no transformation is applied and the resulting raster will have the same dimensions as the current OCN grid. + calculate_full_energy: bool, default False + If True, the full energy of the graph is recalculated when considering + the proposed change. If False, a more efficient incremental update is used. + full_energy_recalc will be slower, but avoid accumulated numerical errors + over many iterations. Returns ------- @@ -800,6 +818,7 @@ def fit( tol=tol, max_iterations_per_loop=max_iterations_per_loop, unwrap=unwrap, + calculate_full_energy=calculate_full_energy, ) def fit_custom_cooling( @@ -812,6 +831,7 @@ def fit_custom_cooling( tol:float=None, max_iterations_per_loop=10_000, unwrap:bool=True, + calculate_full_energy:bool=False, ) -> "xr.Dataset | None": """ Optimize the OCN using the a custom cooling schedule. This allows for @@ -838,6 +858,7 @@ def fit_custom_cooling( tol : float, optional max_iterations_per_loop: int, optional unwrap: bool, default True + calculate_full_energy: bool, default False Returns ------- @@ -924,6 +945,7 @@ def fit_custom_cooling( iterations_this_loop, self.gamma, anneal_ptr, + calculate_full_energy, )) e_new = self.energy completed_iterations += iterations_this_loop From 60ce7f9261225ef38aad06c009ffcbee9d12cff0 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Wed, 29 Oct 2025 15:59:05 -0600 Subject: [PATCH 10/24] writing some new tests --- PyOCN/_flowgrid_convert.py | 6 --- PyOCN/c_src/ocn.h | 2 - sandbox.py | 105 ++++++++++++++++++++++--------------- 3 files changed, 64 insertions(+), 49 deletions(-) diff --git a/PyOCN/_flowgrid_convert.py b/PyOCN/_flowgrid_convert.py index 0f05233..9dc2fab 100644 --- a/PyOCN/_flowgrid_convert.py +++ b/PyOCN/_flowgrid_convert.py @@ -1,6 +1,3 @@ -#TODO: allow users to provide multiple DAGs that partition a space. -#TODO: allow edge-wrapping (toroidal grids). -#TODO: implement "every edge vertex is a root" option """ Functions for converting between NetworkX graphs and FlowGrid_C structures. """ @@ -246,9 +243,6 @@ def direction_bit(pos1, pos2): p_c_graph.contents.resolution = float(resolution) p_c_graph.contents.nroots = len([n for n in G.nodes if G.out_degree(n) == 0]) p_c_graph.contents.wrap = wrap - - if p_c_graph.contents.nroots > 1: - warnings.warn(f"FlowGrid has {p_c_graph.contents.nroots} root nodes (nodes with no downstream). This will slow down certain operations.") # do not set energy diff --git a/PyOCN/c_src/ocn.h b/PyOCN/c_src/ocn.h index efa9f81..c04d277 100644 --- a/PyOCN/c_src/ocn.h +++ b/PyOCN/c_src/ocn.h @@ -4,8 +4,6 @@ * @brief Header file for OCN optimization. */ -//#TODO I'm worried that our method of choosing a new vertex if the last one is invalid introduces bias. Either choose a random direction to walk, or just try every vertex in random order. - #ifndef OCN_H #define OCN_H diff --git a/sandbox.py b/sandbox.py index 3ec7990..8f16629 100644 --- a/sandbox.py +++ b/sandbox.py @@ -1,45 +1,68 @@ -from time import perf_counter as timer -from concurrent.futures import ThreadPoolExecutor - +import PyOCN as po +import ctypes import matplotlib.pyplot as plt +from time import perf_counter as timer +from tqdm import trange import numpy as np -import PyOCN as po -results, fitted_ocns = po.utils.parallel_fit( - ocn= po.OCN.from_net_type(net_type="I", dims=(16, 16), random_state=542390), - n_runs=5, - increment_rng=False, -) -print(results, fitted_ocns) - -# rng = 8472 -# times = [] - -# gammas = np.linspace(0, 1, 51) -# ocns = [po.OCN.from_net_type( -# net_type="I", -# gamma=gamma, -# dims=(100, 100), -# wrap=True, -# random_state=rng, -# ) for gamma in gammas] - -# fit_kwargs = { -# "n_iterations": 64*64*40, -# "pbar": False, -# "max_iterations_per_loop": 2_000, -# } -# results, fitted_ocns = po.utils.parallel_fit(ocns, len(gammas), pbar=True, fit_kwargs=fit_kwargs) - -# for fitted_ocn in fitted_ocns: -# plt.plot(fitted_ocn.history[:, 0], (fitted_ocn.history[:, 1] - fitted_ocn.history[-1, 1]) / (fitted_ocn.history[0, 1] - fitted_ocn.history[-1, 1]), alpha=0.5, color="C0") -# # print(f"Average time over 5 runs: {sum(times)/len(times):.2f} seconds") -# # print(f"Min time over 5 runs: {min(times):.2f} seconds") -# # print(f"Max time over 5 runs: {max(times):.2f} seconds") -# # print(f"Mean time over 5 runs: {sum(times)/len(times):.2f} seconds") -# # print(f"Final energy: {ocn.energy:.6f}") - -# plt.xscale("log") -# plt.yscale("log") -# plt.show() \ No newline at end of file +# def check_upstream_function_result(numpstream_iter, upstream_indices_iter, correct_answer, status): +# # print("Iterative:") +# # print("\tNumber of upstream vertices:", numpstream_iter) +# # print("Correct answer:") +# # print("\tNumber of upstream vertices:", len(correct_answer)) +# # print("Status code:", status) +# success = (numpstream_iter == len(correct_answer)) and (upstream_indices_iter == correct_answer) and (status == 0) +# # print(f"Success: {success}") +# return success + +# def run_upstream_functions(ocn, a): +# m, n = ocn.dims +# a = po._libocn_bindings.linidx_t(a) + +# c_graph = ocn._OCN__p_c_graph +# upstream_indices = (po._libocn_bindings.linidx_t * (m * n))() +# stack = (po._libocn_bindings.linidx_t * (m * n))() +# nupstream_val = po._libocn_bindings.linidx_t(0) + +# t0 = timer() +# status = po._libocn_bindings.libocn.fg_dfs_iterative( +# upstream_indices, +# ctypes.byref(nupstream_val), +# stack, +# c_graph, +# a +# ) +# t1 = timer() +# po._libocn_bindings.libocn.ocn_compute_energy(c_graph, ocn.gamma) +# t2 = timer() + +# numpstream_iter = nupstream_val.value +# upstream_indices_iter = sorted(list(upstream_indices)[:nupstream_val.value]) + +# return numpstream_iter, upstream_indices_iter, status, t1 - t0, t2 - t1 + +# ocn = po.OCN.from_net_type("I", dims=(10, 10)) +# a = 85 +# correct_answer = [80, 81, 82, 83, 84, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] +# numpstream_iter, upstream_indices_iter, status, time_dfs, time_energy = run_upstream_functions(ocn, a) +# check_upstream_function_result(numpstream_iter, upstream_indices_iter, correct_answer, status) + +# ocn = po.OCN.from_net_type("I", dims=(10, 10)) +# a = 33 +# correct_answer = [30, 31, 32] +# numpstream_iter, upstream_indices_iter, status, time_dfs, time_energy = run_upstream_functions(ocn, a) +# check_upstream_function_result(numpstream_iter, upstream_indices_iter, correct_answer, status) + + +ocn = po.OCN.from_net_type("E", dims=(64, 64), random_state=8471) +ocn.fit(pbar=True, max_iterations_per_loop=100, calculate_full_energy=True) +print("Final energy:", ocn.energy) + +ocn = po.OCN.from_net_type("E", dims=(64, 64), random_state=8471) +ocn.fit(pbar=True, max_iterations_per_loop=100, calculate_full_energy=False) +print("Final energy:", ocn.energy) +# plt.hist(np.diff(ocn.history[:, 1]), bins=np.linspace(-1000, 0)) + +plt.show() + From 454c5e401537f4c16440618af286ce9caa2ad74d Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Wed, 29 Oct 2025 16:00:38 -0600 Subject: [PATCH 11/24] Update ocn.py --- PyOCN/ocn.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/PyOCN/ocn.py b/PyOCN/ocn.py index 6ce249d..dce411d 100644 --- a/PyOCN/ocn.py +++ b/PyOCN/ocn.py @@ -598,7 +598,7 @@ def to_xarray(self, unwrap:bool=True) -> "xr.Dataset": } ) - def single_iteration(self, temperature:float, array_report:bool=False, unwrap:bool=True, calculate_full_energy:bool=True) -> "xr.Dataset | None": + def single_iteration(self, temperature:float, array_report:bool=False, unwrap:bool=True, calculate_full_energy:bool=False) -> "xr.Dataset | None": """ Perform a single iteration of the optimization algorithm at a given temperature. Updates the internal history attribute. See :meth:`fit` for details on the algorithm. @@ -619,11 +619,10 @@ def single_iteration(self, temperature:float, array_report:bool=False, unwrap:bo with some nan values. If False or the current OCN does not have periodic boundaries, then no transformation is applied and the resulting raster will have the same dimensions as the current OCN grid. - calculate_full_energy : bool, default True + calculate_full_energy : bool, default False If True, the full energy of the graph is recalculated when considering the proposed change. If False, a more efficient incremental update is used. - full_energy_recalc will be slower, but avoid accumulated numerical errors - over many iterations. + Used in debugging and testing. Returns ------- @@ -722,8 +721,7 @@ def fit( calculate_full_energy: bool, default False If True, the full energy of the graph is recalculated when considering the proposed change. If False, a more efficient incremental update is used. - full_energy_recalc will be slower, but avoid accumulated numerical errors - over many iterations. + Used in debugging and testing. Returns ------- From 374220a9f9746288dcaf938b631479a0052ae3b4 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Wed, 29 Oct 2025 16:05:32 -0600 Subject: [PATCH 12/24] added a new test --- tests/test_basic.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/tests/test_basic.py b/tests/test_basic.py index 50afb2b..277a89b 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -308,6 +308,41 @@ def test_parallel_fit(self): for o in fitted_ocns: self.assertAlmostEqual(ocn.energy, o.history[0, 1], places=6, msg="Initial energy was not preserved.") - # TODO: write tests for the parallel fitting utility + def test_energy_update_method(self): + ocn = po.OCN.from_net_type("E", dims=(64, 64), random_state=238155) + ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=True) + energy = ocn.energy + ocn = po.OCN.from_net_type("E", dims=(64, 64), random_state=238155) + ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=False) + energy_2 = ocn.energy + self.assertAlmostEqual(energy, energy_2, places=5, msg="Energies do not match between full_energy_calc and incremental update methods.") + + ocn = po.OCN.from_net_type("E", dims=(61, 61), random_state=12638) + ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=True) + energy = ocn.energy + ocn = po.OCN.from_net_type("E", dims=(61, 61), random_state=12638) + ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=False) + energy_2 = ocn.energy + self.assertAlmostEqual(energy, energy_2, places=5, msg="Energies do not match between full_energy_calc and incremental update methods.") + + ocn = po.OCN.from_net_type("I", dims=(68, 68), random_state=19075) + ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=True) + energy = ocn.energy + ocn = po.OCN.from_net_type("I", dims=(68, 68), random_state=19075) + ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=False) + energy_2 = ocn.energy + self.assertAlmostEqual(energy, energy_2, places=5, msg="Energies do not match between full_energy_calc and incremental update methods.") + + ocn = po.OCN.from_net_type("I", dims=(50, 61), random_state=910536) + ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=True) + energy = ocn.energy + ocn = po.OCN.from_net_type("I", dims=(50, 61), random_state=910536) + ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=False) + energy_2 = ocn.energy + + self.assertAlmostEqual(energy, energy_2, places=5, msg="Energies do not match between full_energy_calc and incremental update methods.") + + def + if __name__ == "__main__": unittest.main() \ No newline at end of file From 5674170b8bc753d243a1f13618642dccebe4b6a4 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Wed, 29 Oct 2025 16:06:09 -0600 Subject: [PATCH 13/24] Multi-outlet optimization works --- tests/test_basic.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_basic.py b/tests/test_basic.py index 277a89b..af0a49b 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -340,9 +340,7 @@ def test_energy_update_method(self): ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=False) energy_2 = ocn.energy - self.assertAlmostEqual(energy, energy_2, places=5, msg="Energies do not match between full_energy_calc and incremental update methods.") - - def + self.assertAlmostEqual(energy, energy_2, places=5, msg="Energies do not match between full_energy_calc and incremental update methods.") if __name__ == "__main__": unittest.main() \ No newline at end of file From 7e0f77d17ceffc3193bb81fcc43178671eb2bad3 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Wed, 29 Oct 2025 16:13:29 -0600 Subject: [PATCH 14/24] updated a warning about the calculate_full_energy flag --- PyOCN/ocn.py | 10 +++++++--- tests/test_basic.py | 40 ++++++++++++++++++++-------------------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/PyOCN/ocn.py b/PyOCN/ocn.py index dce411d..e6093a0 100644 --- a/PyOCN/ocn.py +++ b/PyOCN/ocn.py @@ -622,7 +622,9 @@ def single_iteration(self, temperature:float, array_report:bool=False, unwrap:bo calculate_full_energy : bool, default False If True, the full energy of the graph is recalculated when considering the proposed change. If False, a more efficient incremental update is used. - Used in debugging and testing. + Small numerical differences may arise between the two methods due to floating point + precision. If precision is of the utmost importance, set this to True, but note that + this comes with a significant performance penalty. Returns ------- @@ -718,10 +720,12 @@ def fit( with some nan values. If False or the current OCN does not have periodic boundaries, then no transformation is applied and the resulting raster will have the same dimensions as the current OCN grid. - calculate_full_energy: bool, default False + calculate_full_energy : bool, default False If True, the full energy of the graph is recalculated when considering the proposed change. If False, a more efficient incremental update is used. - Used in debugging and testing. + Small numerical differences may arise between the two methods due to floating point + precision. If precision is of the utmost importance, set this to True, but note that + this comes with a significant performance penalty. Returns ------- diff --git a/tests/test_basic.py b/tests/test_basic.py index af0a49b..52320df 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -309,38 +309,38 @@ def test_parallel_fit(self): self.assertAlmostEqual(ocn.energy, o.history[0, 1], places=6, msg="Initial energy was not preserved.") def test_energy_update_method(self): - ocn = po.OCN.from_net_type("E", dims=(64, 64), random_state=238155) - ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=True) + ocn = po.OCN.from_net_type("E", dims=(44, 44), random_state=238155) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=True) energy = ocn.energy - ocn = po.OCN.from_net_type("E", dims=(64, 64), random_state=238155) - ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=False) + ocn = po.OCN.from_net_type("E", dims=(44, 44), random_state=238155) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=False) energy_2 = ocn.energy - self.assertAlmostEqual(energy, energy_2, places=5, msg="Energies do not match between full_energy_calc and incremental update methods.") + self.assertAlmostEqual(energy/energy_2, 1.0, places=3, msg="Energies do not match between full_energy_calc and incremental update methods.") - ocn = po.OCN.from_net_type("E", dims=(61, 61), random_state=12638) - ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=True) + ocn = po.OCN.from_net_type("E", dims=(31, 31), random_state=12638) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=True) energy = ocn.energy - ocn = po.OCN.from_net_type("E", dims=(61, 61), random_state=12638) - ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=False) + ocn = po.OCN.from_net_type("E", dims=(31, 31), random_state=12638) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=False) energy_2 = ocn.energy - self.assertAlmostEqual(energy, energy_2, places=5, msg="Energies do not match between full_energy_calc and incremental update methods.") + self.assertAlmostEqual(energy/energy_2, 1.0, places=3, msg="Energies do not match between full_energy_calc and incremental update methods.") - ocn = po.OCN.from_net_type("I", dims=(68, 68), random_state=19075) - ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=True) + ocn = po.OCN.from_net_type("I", dims=(38, 38), random_state=19075) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=True) energy = ocn.energy - ocn = po.OCN.from_net_type("I", dims=(68, 68), random_state=19075) - ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=False) + ocn = po.OCN.from_net_type("I", dims=(38, 38), random_state=19075) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=False) energy_2 = ocn.energy - self.assertAlmostEqual(energy, energy_2, places=5, msg="Energies do not match between full_energy_calc and incremental update methods.") + self.assertAlmostEqual(energy/energy_2, 1.0, places=3, msg="Energies do not match between full_energy_calc and incremental update methods.") - ocn = po.OCN.from_net_type("I", dims=(50, 61), random_state=910536) - ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=True) + ocn = po.OCN.from_net_type("I", dims=(20, 41), random_state=910536) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=True) energy = ocn.energy - ocn = po.OCN.from_net_type("I", dims=(50, 61), random_state=910536) - ocn.fit(pbar=True, max_iterations_per_loop=10_000, calculate_full_energy=False) + ocn = po.OCN.from_net_type("I", dims=(20, 41), random_state=910536) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=False) energy_2 = ocn.energy - self.assertAlmostEqual(energy, energy_2, places=5, msg="Energies do not match between full_energy_calc and incremental update methods.") + self.assertAlmostEqual(energy/energy_2, 1.0, places=3, msg="Energies do not match between full_energy_calc and incremental update methods.") if __name__ == "__main__": unittest.main() \ No newline at end of file From aac812d9feec76bd76c04ee564f1247f75244e54 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Wed, 29 Oct 2025 16:16:06 -0600 Subject: [PATCH 15/24] cleared up some todos --- PyOCN/ocn.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/PyOCN/ocn.py b/PyOCN/ocn.py index e6093a0..f911197 100644 --- a/PyOCN/ocn.py +++ b/PyOCN/ocn.py @@ -42,8 +42,6 @@ Helper functions for visualization and plotting """ - -# TODO: relax the even dims requirement # TODO: have to_rasterio use the option to set the root node to 0,0 by using to_xarray as the backend instead of numpy? @@ -266,8 +264,7 @@ def from_digraph(cls, dag: nx.DiGraph, resolution:float=1, gamma=0.5, random_sta return cls(dag, resolution, gamma, random_state, verbosity=verbosity, validate=True, wrap=wrap) def __repr__(self): - #TODO: too verbose? - return f"" + return f"\n\n" def __str__(self): return f"OCN(gamma={self.gamma}, energy={self.energy}, dims={self.dims}, resolution={self.resolution}m, verbosity={self.verbosity})" def __del__(self): From 49ab57c381d15f3c7e5386268999b8a255f2ceba Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Wed, 29 Oct 2025 16:21:46 -0600 Subject: [PATCH 16/24] added a test showing that energy increases with gamma --- tests/test_basic.py | 71 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 67 insertions(+), 4 deletions(-) diff --git a/tests/test_basic.py b/tests/test_basic.py index 52320df..9565c9f 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -317,18 +317,18 @@ def test_energy_update_method(self): energy_2 = ocn.energy self.assertAlmostEqual(energy/energy_2, 1.0, places=3, msg="Energies do not match between full_energy_calc and incremental update methods.") - ocn = po.OCN.from_net_type("E", dims=(31, 31), random_state=12638) + ocn = po.OCN.from_net_type("E", dims=(31, 31), random_state=12638, gamma=0.21) ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=True) energy = ocn.energy - ocn = po.OCN.from_net_type("E", dims=(31, 31), random_state=12638) + ocn = po.OCN.from_net_type("E", dims=(31, 31), random_state=12638, gamma=0.21) ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=False) energy_2 = ocn.energy self.assertAlmostEqual(energy/energy_2, 1.0, places=3, msg="Energies do not match between full_energy_calc and incremental update methods.") - ocn = po.OCN.from_net_type("I", dims=(38, 38), random_state=19075) + ocn = po.OCN.from_net_type("I", dims=(38, 38), random_state=19075, gamma=0.78) ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=True) energy = ocn.energy - ocn = po.OCN.from_net_type("I", dims=(38, 38), random_state=19075) + ocn = po.OCN.from_net_type("I", dims=(38, 38), random_state=19075, gamma=0.78) ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=False) energy_2 = ocn.energy self.assertAlmostEqual(energy/energy_2, 1.0, places=3, msg="Energies do not match between full_energy_calc and incremental update methods.") @@ -342,5 +342,68 @@ def test_energy_update_method(self): self.assertAlmostEqual(energy/energy_2, 1.0, places=3, msg="Energies do not match between full_energy_calc and incremental update methods.") + def test_gamma(self): + """Test that different gamma values affect energy as expected.""" + ocn_gamma_1 = po.OCN.from_net_type( + net_type="V", + dims=(32, 32), + gamma=1.0, + random_state=1234, + ) + energy_gamma_1 = ocn_gamma_1.energy + + ocn_gamma_2 = po.OCN.from_net_type( + net_type="V", + dims=(32, 32), + gamma=0.5, + random_state=1234, + ) + energy_gamma_2 = ocn_gamma_2.energy + + ocn_gamma_3 = po.OCN.from_net_type( + net_type="V", + dims=(32, 32), + gamma=0.25, + random_state=1234, + ) + energy_gamma_3 = ocn_gamma_3.energy + + self.assertNotAlmostEqual(energy_gamma_1, energy_gamma_2, places=1, msg="Energies for gamma=1.0 and gamma=0.5 should differ.") + self.assertNotAlmostEqual(energy_gamma_1, energy_gamma_3, places=1, msg="Energies for gamma=1.0 and gamma=0.25 should differ.") + self.assertNotAlmostEqual(energy_gamma_2, energy_gamma_3, places=1, msg="Energies for gamma=0.5 and gamma=0.25 should differ.") + self.assertLess(energy_gamma_3, energy_gamma_2, msg="Energy should decrease with lower gamma.") + self.assertLess(energy_gamma_2, energy_gamma_1, msg="Energy should decrease with lower gamma.") + + ocn_gamma_1 = po.OCN.from_net_type( + net_type="E", + dims=(32, 32), + gamma=1.0, + random_state=1234, + ) + energy_gamma_1 = ocn_gamma_1.energy + + ocn_gamma_2 = po.OCN.from_net_type( + net_type="E", + dims=(32, 32), + gamma=0.5, + random_state=1234, + ) + energy_gamma_2 = ocn_gamma_2.energy + + ocn_gamma_3 = po.OCN.from_net_type( + net_type="E", + dims=(32, 32), + gamma=0.25, + random_state=1234, + ) + energy_gamma_3 = ocn_gamma_3.energy + + self.assertNotAlmostEqual(energy_gamma_1, energy_gamma_2, places=1, msg="Energies for gamma=1.0 and gamma=0.5 should differ.") + self.assertNotAlmostEqual(energy_gamma_1, energy_gamma_3, places=1, msg="Energies for gamma=1.0 and gamma=0.25 should differ.") + self.assertNotAlmostEqual(energy_gamma_2, energy_gamma_3, places=1, msg="Energies for gamma=0.5 and gamma=0.25 should differ.") + + self.assertLess(energy_gamma_3, energy_gamma_2, msg="Energy should decrease with lower gamma.") + self.assertLess(energy_gamma_2, energy_gamma_1, msg="Energy should decrease with lower gamma.") + if __name__ == "__main__": unittest.main() \ No newline at end of file From d1751fc145401a6e85d48252081bf68038115e2c Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Wed, 29 Oct 2025 16:29:36 -0600 Subject: [PATCH 17/24] cleaned up imports, added __version__, update version --- PyOCN/__init__.py | 15 +++------------ PyOCN/__pycache__/__init__.cpython-311.pyc | Bin 656 -> 313 bytes PyOCN/_version.py | 1 + docs/conf.py | 2 +- pyproject.toml | 2 +- sandbox.py | 3 +-- 6 files changed, 7 insertions(+), 16 deletions(-) create mode 100644 PyOCN/_version.py diff --git a/PyOCN/__init__.py b/PyOCN/__init__.py index 2c925c6..99d9313 100644 --- a/PyOCN/__init__.py +++ b/PyOCN/__init__.py @@ -1,16 +1,7 @@ from .ocn import OCN -from .plotting import plot_ocn_as_dag, plot_positional_digraph, plot_ocn_raster -from .utils import net_type_to_dag, simulated_annealing_schedule, unwrap_digraph, get_subwatersheds, assign_subwatersheds +from ._version import __version__ __all__ = [ "OCN", - "flowgrid", - "plot_ocn_as_dag", - "plot_positional_digraph", - "plot_ocn_raster", - "net_type_to_dag", - "simulated_annealing_schedule", - "unwrap_digraph", - "get_subwatersheds", - "assign_subwatersheds", -] \ No newline at end of file +] + diff --git a/PyOCN/__pycache__/__init__.cpython-311.pyc b/PyOCN/__pycache__/__init__.cpython-311.pyc index aeb2167d232c6521d50e1e1ab7071484b63d7d0e..103dab24453beb8d3eb5f197adb794914682f17b 100644 GIT binary patch delta 202 zcmbQhx|4}-IWI340}%Y4!ju`zFp*CpSp~?O&XB^8!kEL5%NWJT2x2qkFy%5wF$3Am z>5NeR literal 656 zcma)3y>1jS5VrT{cJG!W0-cm5FR(?}7eF+Dh9W^iG;S<&cJ_??v$DN9uH!A}X(Ha{ z%99%ks&pr%OU3wrU6S%p%hg0zq|!C1^^{EYjLdXH8jO0xiyxRg)?K`yD{-5uQrpfl`9==uSxR>zoA8fC{O8@`> diff --git a/PyOCN/_version.py b/PyOCN/_version.py new file mode 100644 index 0000000..89182ca --- /dev/null +++ b/PyOCN/_version.py @@ -0,0 +1 @@ +__version__ = "1.4.20251029" \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 0457131..8366283 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,7 +14,7 @@ project = 'PyOCN' copyright = '2025, Alexander S. Fox' author = 'Alexander S. Fox' -release = '1.3.20251011' +version = "1.4.20251029" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 2983a9a..911ea05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "PyOCN" -version = "1.3.20251011" +version = "1.4.20251029" description = "Optimal Channel Networks (OCN) in Python with a C core" readme = "readme.md" requires-python = ">=3.10" diff --git a/sandbox.py b/sandbox.py index 8f16629..aec0363 100644 --- a/sandbox.py +++ b/sandbox.py @@ -64,5 +64,4 @@ print("Final energy:", ocn.energy) # plt.hist(np.diff(ocn.history[:, 1]), bins=np.linspace(-1000, 0)) -plt.show() - +plt.show() \ No newline at end of file From 99b6ed6761fa3a109053ddb82366a4e8098187f4 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Wed, 29 Oct 2025 21:30:02 -0600 Subject: [PATCH 18/24] graph traversal methods --- PyOCN/_libocn_bindings.py | 20 ++++++ PyOCN/c_src/flowgrid.c | 76 ++++++++++++++++++++++ PyOCN/c_src/flowgrid.h | 30 +++++++++ PyOCN/ocn.py | 128 ++++++++++++++++++++++++++++++++++++++ PyOCN/utils.py | 117 +++++++++++++++++++++++++++++++++- 5 files changed, 369 insertions(+), 2 deletions(-) diff --git a/PyOCN/_libocn_bindings.py b/PyOCN/_libocn_bindings.py index 00262a6..30897c5 100644 --- a/PyOCN/_libocn_bindings.py +++ b/PyOCN/_libocn_bindings.py @@ -115,6 +115,14 @@ class FlowGrid_C(Structure): libocn.fg_cart_to_lin.argtypes = [CartPair_C, CartPair_C] libocn.fg_cart_to_lin.restype = linidx_t +# CartPair fg_lin_to_cart(linidx_t a, CartPair dims); +libocn.fg_lin_to_cart.argtypes = [linidx_t, CartPair_C] +libocn.fg_lin_to_cart.restype = CartPair_C + +# Status fg_clockhand_to_lin(linidx_t *a_down, linidx_t a, clockhand_t down, CartPair dims, bool wrap) +libocn.fg_clockhand_to_lin.argtypes = [POINTER(linidx_t), linidx_t, clockhand_t, CartPair_C, c_bool] +libocn.fg_clockhand_to_lin.restype = Status + # Status fg_get_lin(Vertex *out, FlowGrid *G, linidx_t a); libocn.fg_get_lin.argtypes = [POINTER(Vertex_C), POINTER(FlowGrid_C), linidx_t] libocn.fg_get_lin.restype = Status @@ -135,6 +143,18 @@ class FlowGrid_C(Structure): libocn.fg_destroy.argtypes = [POINTER(FlowGrid_C)] libocn.fg_destroy.restype = Status +# Status fg_find_upstream_neighbors(linidx_t upstream_indices[8], linidx_t *nupstream, FlowGrid *G, linidx_t a); +libocn.fg_find_upstream_neighbors.argtypes = [linidx_t * 8, POINTER(linidx_t), POINTER(FlowGrid_C), linidx_t] +libocn.fg_find_upstream_neighbors.restype = Status + +# Status fg_dfs_iterative(linidx_t upstream_indices[], linidx_t *nupstream, linidx_t idx_stack[], FlowGrid *G, linidx_t a); +libocn.fg_dfs_iterative.argtypes = [POINTER(linidx_t), POINTER(linidx_t), POINTER(linidx_t), POINTER(FlowGrid_C), linidx_t] +libocn.fg_dfs_iterative.restype = Status + +# Status flow_downstream(linidx_t downstream_indices[], linidx_t *ndownstream, FlowGrid *G, linidx_t a) +libocn.flow_downstream.argtypes = [POINTER(linidx_t), POINTER(linidx_t), POINTER(FlowGrid_C), linidx_t] +libocn.flow_downstream.restype = Status + ############################## # OCN.H EQUIVALENTS # ############################## diff --git a/PyOCN/c_src/flowgrid.c b/PyOCN/c_src/flowgrid.c index 627935e..ade1394 100644 --- a/PyOCN/c_src/flowgrid.c +++ b/PyOCN/c_src/flowgrid.c @@ -302,6 +302,82 @@ Status fg_check_for_cycles(FlowGrid *G, linidx_t a, uint8_t check_number){ return SUCCESS; // found root successfully, no cycles found } +Status fg_find_upstream_neighbors(linidx_t upstream_indices[8], linidx_t *nupstream, FlowGrid *G, linidx_t a){ + if (G == NULL || G->vertices == NULL || upstream_indices == NULL || nupstream == NULL) return NULL_POINTER_ERROR; + + Status code; + Vertex vert; + code = fg_get_lin(&vert, G, a); + if (code != SUCCESS) return code; + + // get the clockhand direction of all edges + *nupstream = 0; + for (clockhand_t dir = 0; dir < 8; dir++){ + if ((vert.edges & (1u << dir)) && (vert.downstream != dir)){ // if there's an edge in this direction + linidx_t a_up; + code = fg_clockhand_to_lin(&a_up, a, dir, G->dims, G->wrap); // get the upstream vertex index + if (code != SUCCESS) return code; + upstream_indices[*nupstream] = a_up; + (*nupstream)++; + } + } + return SUCCESS; +} + +Status fg_dfs_iterative(linidx_t upstream_indices[], linidx_t *nupstream, linidx_t idx_stack[], FlowGrid *G, linidx_t a){ + if (G == NULL || G->vertices == NULL || upstream_indices == NULL || nupstream == NULL || idx_stack == NULL) return NULL_POINTER_ERROR; + Status code; + + // Seed idx_stack with immediate upstream of a + *nupstream = 0; + linidx_t local_upstream[8]; + linidx_t nlocal_upstream = 0; + code = fg_find_upstream_neighbors(local_upstream, &nlocal_upstream, G, a); + if (code != SUCCESS) return code; + if (nlocal_upstream == 0) return SUCCESS; // no upstream vertices found + + linidx_t total = (linidx_t)G->dims.row * (linidx_t)G->dims.col; + if (nlocal_upstream >= total) return OOB_ERROR; + memcpy(idx_stack, local_upstream, nlocal_upstream * sizeof(linidx_t)); + + if (*nupstream >= total || total == 0) return OOB_ERROR; + linidx_t max_out = total - 1; + linidx_t top = nlocal_upstream; + while (top > 0){ + // pop from stack + linidx_t u = idx_stack[--top]; + if (u >= total) return OOB_ERROR; + + // Append to output buffer + if (*nupstream >= max_out) return OOB_ERROR; + upstream_indices[(*nupstream)++] = u; + + // Push u's upstream neighbors + nlocal_upstream = 0; + code = fg_find_upstream_neighbors(local_upstream, &nlocal_upstream, G, u); + if (code != SUCCESS) return code; + if (top + nlocal_upstream >= total) return OOB_ERROR; + memcpy(idx_stack + top, local_upstream, nlocal_upstream * sizeof(linidx_t)); + top += nlocal_upstream; + } + return SUCCESS; +} + +Status flow_downstream(linidx_t downstream_indices[], linidx_t *ndownstream, FlowGrid *G, linidx_t a){ + Vertex vert; + *ndownstream = 0; + do { + Status code = fg_get_lin(&vert, G, a); + if (code == OOB_ERROR) return OOB_ERROR; + downstream_indices[(*ndownstream)++] = a; + + a = vert.adown; + } while (vert.downstream != IS_ROOT); + + return SUCCESS; +} + + // vibe-coded display function const char E_ARROW = '-'; const char S_ARROW = '|'; diff --git a/PyOCN/c_src/flowgrid.h b/PyOCN/c_src/flowgrid.h index 41e7f78..b4f3d38 100644 --- a/PyOCN/c_src/flowgrid.h +++ b/PyOCN/c_src/flowgrid.h @@ -161,6 +161,36 @@ Status fg_change_vertex_outflow(FlowGrid *G, linidx_t a, clockhand_t down_new); */ Status fg_check_for_cycles(FlowGrid *G, linidx_t a, uint8_t check_number); +/** + * @brief Find all upstream neighbors of a given vertex. + * @param upstream_indices Array to store the linear indices of upstream neighbors. + * @param nupstream Pointer to store the number of upstream neighbors found. + * @param G Pointer to the FlowGrid. + * @param a The linear index of the vertex to find upstream neighbors for. + * @return Status code indicating success or failure + */ +Status fg_find_upstream_neighbors(linidx_t upstream_indices[8], linidx_t *nupstream, FlowGrid *G, linidx_t a); + +/** + * @brief Perform an iterative depth-first search to find all upstream vertices from a given starting vertex. + * @param upstream_indices Array to store the linear indices of all upstream vertices found. Must be preallocated to hold enough indices (ie G.dims.row * G.dims.col). + * @param nupstream Pointer to store the total number of upstream vertices found. + * @param idx_stack Preallocated stack array for DFS traversal. Must be large enough to hold all potential upstream vertices (ie G.dims.row * G.dims.col). + * @param G Pointer to the FlowGrid. + * @param a The linear index of the starting vertex. + * @return Status code indicating success or failure + */ +Status fg_dfs_iterative(linidx_t upstream_indices[], linidx_t *nupstream, linidx_t idx_stack[], FlowGrid *G, linidx_t a); + +/** + * @brief Follow the downstream path from a given vertex, recording each vertex along the path. Stops when a root node is reached. Includes root node. Does not include starting node. + * @param downstream_indices Array to store the linear indices of downstream vertices. Must be preallocated to hold enough indices (ie G.dims.row * G.dims.col). + * @param ndownstream Pointer to store the number of downstream vertices found. + * @param G Pointer to the FlowGrid. + * @param a The linear index of the starting vertex. + * @return Status code indicating success or failure + */ +Status flow_downstream(linidx_t downstream_indices[], linidx_t *ndownstream, FlowGrid *G, linidx_t a); /** * @brief Display the flowgrid in the terminal using ASCII or UTF-8 characters. * @param G Pointer to the FlowGrid to display. diff --git a/PyOCN/ocn.py b/PyOCN/ocn.py index f911197..431f12a 100644 --- a/PyOCN/ocn.py +++ b/PyOCN/ocn.py @@ -1,6 +1,7 @@ import warnings import ctypes from typing import Any, Callable, TYPE_CHECKING, Union +from collections.abc import Generator from os import PathLike from numbers import Number from pathlib import Path @@ -120,6 +121,9 @@ class OCN: >>> plt.show() """ + ########################### + # CONSTRUCTORS # + ########################### def __init__(self, dag: nx.DiGraph, resolution: float=1.0, gamma: float = 0.5, random_state=None, verbosity: int = 0, validate:bool=True, wrap : bool = False): """ Construct an :class:`OCN` from a valid NetworkX ``DiGraph``. @@ -263,6 +267,9 @@ def from_digraph(cls, dag: nx.DiGraph, resolution:float=1, gamma=0.5, random_sta return cls(dag, resolution, gamma, random_state, verbosity=verbosity, validate=True, wrap=wrap) + ########################### + # DUNDER METHODS # + ########################### def __repr__(self): return f"\n\n" def __str__(self): @@ -318,6 +325,9 @@ def copy(self) -> "OCN": """ return self.__copy__() + ########################### + # PROPERTIES # + ########################### def compute_energy(self) -> float: """ Compute the current energy of the network. @@ -366,6 +376,121 @@ def rng(self, random_state:int|None|np.random.Generator=None): def history(self) -> np.ndarray: return self.__history + ########################### + # GRAPH TRAVERSAL # + ########################### + def predecessors(self, pos:tuple[int, int]) -> Generator[tuple[int, int], None, None]: + """Returns an iterator over predecessor nodes of the node at position `pos` in the OCN. + A predecessor is defined as an immediate upstream neighbor. + + A predecessor of a node n is a node m such that there exists a directed edge from m to n. + + Mirrors the behavior of `networkx.DiGraph.predecessors`, but works directly on the OCN C graph structure. + + Yields (row, col) of successors. + + Parameters + ---------- + ocn : OCN + The OCN instance. + pos : tuple[int, int] + The (row, col) position of the node whose predecessors are to be found. + + Raises + ------- + TypeError + If `pos` is not a tuple of two integers. + IndexError + If `pos` is out of bounds for the current OCN grid. + + See Also + -------- + :meth:`OCN.successors` + """ + + pos = tuple(pos) + if len(pos) != 2 or not all(isinstance(p, int) for p in pos): + raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") + if (pos[0] < 0 or pos[0] >= self.dims[0]) or (pos[1] < 0 or pos[1] >= self.dims[1]): + raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {self.dims}.") + + # convert (row, col) to linear index + a = _bindings.libocn.fg_cart_to_lin( + _bindings.CartPair_C(row=pos[0], col=pos[1]), + _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]), + ) + + upstream_indices = (_bindings.linidx_t * 8)() + nupstream = _bindings.linidx_t() + check_status(_bindings.libocn.fg_find_upstream_neighbors( + upstream_indices, + ctypes.byref(nupstream), + self.__p_c_graph, + _bindings.linidx_t(a) + )) + count = int(nupstream.value) + + # convert back to cartesian (row, col) + for idx in upstream_indices[:count]: + cart = _bindings.libocn.fg_lin_to_cart( + idx, + _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]) + ) + yield (int(cart.row), int(cart.col)) + + + def successors(self, pos: tuple[int, int]) -> Generator[tuple[int, int], None, None]: + """Returns an iterator over successor nodes of the node at position `pos` in the OCN. + A successor is the immediate downstream neighbor (at most one; none if root). + + Mirrors the behavior of `networkx.DiGraph.successors`, but works directly on the OCN C graph structure. + Yields (row, col) of successors. + + Parameters + ---------- + pos : tuple[int, int] + The (row, col) position of the node whose successors are to be found. + + Raises + ------- + TypeError + If `pos` is not a tuple of two integers. + IndexError + If `pos` is out of bounds for the current OCN grid. + """ + pos = tuple(pos) + if len(pos) != 2 or not all(isinstance(p, int) for p in pos): + raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") + if (pos[0] < 0 or pos[0] >= self.dims[0]) or (pos[1] < 0 or pos[1] >= self.dims[1]): + raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {self.dims}.") + + # (row, col) -> linear index + a = _bindings.libocn.fg_cart_to_lin( + _bindings.CartPair_C(row=pos[0], col=pos[1]), + _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]), + ) + + # Read vertex to get downstream direction + vert = _bindings.Vertex_C() + check_status(_bindings.libocn.fg_get_lin( + ctypes.byref(vert), + self.__p_c_graph, + _bindings.linidx_t(a), + )) + + # If root, no successors + if int(vert.downstream) == _bindings.IS_ROOT: + return + + cart = _bindings.libocn.fg_lin_to_cart( + vert.adown, + _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]) + ) + yield (int(cart.row), int(cart.col)) + + ########################### + # EXPORT # + ########################### def to_digraph(self) -> nx.DiGraph: """ Create a NetworkX ``DiGraph`` view of the current grid. @@ -595,6 +720,9 @@ def to_xarray(self, unwrap:bool=True) -> "xr.Dataset": } ) + ########################### + # OPTIMIZATION # + ########################### def single_iteration(self, temperature:float, array_report:bool=False, unwrap:bool=True, calculate_full_energy:bool=False) -> "xr.Dataset | None": """ Perform a single iteration of the optimization algorithm at a given temperature. Updates the internal history attribute. diff --git a/PyOCN/utils.py b/PyOCN/utils.py index 2ffcaa3..a312042 100644 --- a/PyOCN/utils.py +++ b/PyOCN/utils.py @@ -3,7 +3,9 @@ """ from __future__ import annotations +from collections.abc import Generator from concurrent.futures import ThreadPoolExecutor +from ctypes import byref from itertools import product import os from typing import Any, Literal, Callable, TYPE_CHECKING, Union @@ -13,6 +15,9 @@ from tqdm import tqdm from functools import partial +import PyOCN._libocn_bindings as _bindings +from PyOCN._statushandler import check_status + if TYPE_CHECKING: from .ocn import OCN @@ -363,7 +368,6 @@ def get_subwatersheds(dag : nx.DiGraph, node : Any) -> set[nx.DiGraph]: subwatersheds = set(dag.subgraph(wshd) for wshd in subwatersheds) return subwatersheds - def parallel_fit( ocn:OCN|list[OCN], n_runs=5, @@ -457,4 +461,113 @@ def fit(ocn, i, fit_method, fit_kwargs): res, idx = future.result() results[idx] = res pbar.update(1) - return results, ocn \ No newline at end of file + return results, ocn + +def ancestors(ocn:OCN, pos:tuple[int, int]) -> set[tuple[int, int]]: + """Returns all nodes that drain into the node at position `pos` in the OCN. + + Parameters + ---------- + ocn : OCN + The OCN instance. + pos : tuple[int, int] + The (row, col) position of the node whose predecessors are to be found. + + Returns + ------- + set() + A set of (row, col) tuples representing the positions of all ancestor nodes + that eventually drain into the specified node, not including the node itself. + + See Also + -------- + :meth:`OCN.predecessors` + :meth:`OCN.successors` + :meth:`utils.descendants` + """ + + pos = tuple(pos) + if len(pos) != 2 or not all(isinstance(p, int) for p in pos): + raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") + if (pos[0] < 0 or pos[0] >= ocn.dims[0]) or (pos[1] < 0 or pos[1] >= ocn.dims[1]): + raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {ocn.dims}.") + + a = _bindings.libocn.fg_cart_to_lin( + _bindings.CartPair_C(row=pos[0], col=pos[1]), + _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) + ) + upstream_indices = (_bindings.linidx_t * (ocn.dims[0]*ocn.dims[1]))() + nupstream = _bindings.linidx_t(0) + check_status(_bindings.libocn.fg_dfs_iterative( + upstream_indices, + byref(nupstream), + (_bindings.linidx_t * (ocn.dims[0]*ocn.dims[1]))(), + ocn._OCN__p_c_graph, + a + )) + + return set( + (int(c.row), int(c.col)) + for c in ( + _bindings.libocn.fg_lin_to_cart( + idx, + _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) + ) + for idx in upstream_indices[:nupstream.value] + ) + ) + + +def descendants(ocn:OCN, pos:tuple[int, int]) -> set[tuple[int, int]]: + """Returns all nodes that are reachable from the node at position `pos` in the OCN. + Mirrors the functionality of `networkx.descendants`. + + Parameters + ---------- + ocn : OCN + The OCN instance. + pos : tuple[int, int] + The (row, col) position of the node whose successors are to be found. + + Returns + ------- + set() + A set of (row, col) tuples representing the positions of all descendant nodes + that the specified node eventually drains into, not including the node itself. + + See Also + -------- + :meth:`OCN.predecessors` + :meth:`OCN.successors` + :meth:`utils.ancestors` + """ + + pos = tuple(pos) + if len(pos) != 2 or not all(isinstance(p, int) for p in pos): + raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") + if (pos[0] < 0 or pos[0] >= ocn.dims[0]) or (pos[1] < 0 or pos[1] >= ocn.dims[1]): + raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {ocn.dims}.") + + a = _bindings.libocn.fg_cart_to_lin( + _bindings.CartPair_C(row=pos[0], col=pos[1]), + _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) + ) + downstream_indices = (_bindings.linidx_t * (ocn.dims[0]*ocn.dims[1]))() + ndownstream = _bindings.linidx_t(0) + check_status(_bindings.libocn.flow_downstream( + downstream_indices, + byref(ndownstream), + ocn._OCN__p_c_graph, + a + )) + + return set( + (int(c.row), int(c.col)) + for c in ( + _bindings.libocn.fg_lin_to_cart( + idx, + _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) + ) + for idx in downstream_indices[:ndownstream.value] + ) + ) \ No newline at end of file From c64e8f7f8c3ce99d24724d42513ebe8c04cc016d Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Wed, 29 Oct 2025 22:54:45 -0600 Subject: [PATCH 19/24] traversal functions I guess --- PyOCN/__init__.py | 7 +++ PyOCN/__pycache__/__init__.cpython-311.pyc | Bin 313 -> 430 bytes PyOCN/plotting.py | 1 + PyOCN/utils.py | 2 +- sandbox.py | 64 +++------------------ 5 files changed, 17 insertions(+), 57 deletions(-) diff --git a/PyOCN/__init__.py b/PyOCN/__init__.py index 99d9313..1051aa8 100644 --- a/PyOCN/__init__.py +++ b/PyOCN/__init__.py @@ -1,7 +1,14 @@ from .ocn import OCN from ._version import __version__ +from . import utils +from . import plotting + + __all__ = [ "OCN", + "utils", + "plotting", + "__version__", ] diff --git a/PyOCN/__pycache__/__init__.cpython-311.pyc b/PyOCN/__pycache__/__init__.cpython-311.pyc index 103dab24453beb8d3eb5f197adb794914682f17b..b5f2afb85b5e8006e198e06723600f0c9136c6e9 100644 GIT binary patch delta 228 zcmdnVw2ql~IWI340}w2`!<4yUBCjN)-9&XcRfZJi9F|6owQQAe$|UjgcXp zA%%4jV-$M|TQGwr`@{h008PeQtfeKHImJK*M?p@0Nl9j2x+Y5zGtkH)RuI7kB3OWg zpC-pG=KSP5u-q+<__EZZ;>`R!u-q+};)$K9Tzo*qj6htRH1UUy-2(>U3#jM@gV+UB V^ns0)o9P1sjA&p7!6F`@836M{Hah?S delta 112 zcmZ3-ypxG{IWI340}%Y4!ju_2kynyYWum$qH*-2;6iW(AFoP!R#4KsKB4(fpKTXzK z%=yWAV9qU$__EZZ;>`TKB9@6)Q@OZ-3K)U7*mAN9qmJkYHU?I%56n#5ObzTHSi}jG F0|2|_88QF> diff --git a/PyOCN/plotting.py b/PyOCN/plotting.py index a149460..545e376 100644 --- a/PyOCN/plotting.py +++ b/PyOCN/plotting.py @@ -98,6 +98,7 @@ def plot_ocn_as_dag(ocn: OCN, attribute: str | None = None, ax=None, norm=None, kwargs["vmax"] = 1 node_color = norm(node_color) + p = nx.draw_networkx(dag, node_color=node_color, pos=pos, ax=ax, **kwargs) return p, ax diff --git a/PyOCN/utils.py b/PyOCN/utils.py index a312042..af5fb10 100644 --- a/PyOCN/utils.py +++ b/PyOCN/utils.py @@ -517,7 +517,7 @@ def ancestors(ocn:OCN, pos:tuple[int, int]) -> set[tuple[int, int]]: ) ) - +# TODO write tests for the traversal functions def descendants(ocn:OCN, pos:tuple[int, int]) -> set[tuple[int, int]]: """Returns all nodes that are reachable from the node at position `pos` in the OCN. Mirrors the functionality of `networkx.descendants`. diff --git a/sandbox.py b/sandbox.py index aec0363..994cb96 100644 --- a/sandbox.py +++ b/sandbox.py @@ -6,62 +6,14 @@ from tqdm import trange import numpy as np -# def check_upstream_function_result(numpstream_iter, upstream_indices_iter, correct_answer, status): -# # print("Iterative:") -# # print("\tNumber of upstream vertices:", numpstream_iter) -# # print("Correct answer:") -# # print("\tNumber of upstream vertices:", len(correct_answer)) -# # print("Status code:", status) -# success = (numpstream_iter == len(correct_answer)) and (upstream_indices_iter == correct_answer) and (status == 0) -# # print(f"Success: {success}") -# return success +ocn = po.OCN.from_net_type("H", dims=(10, 10)) -# def run_upstream_functions(ocn, a): -# m, n = ocn.dims -# a = po._libocn_bindings.linidx_t(a) +pos = (7, 7) +print(list(ocn.predecessors(pos))) +print(list(ocn.successors(pos))) +print(po.utils.ancestors(ocn, pos)) +print(po.utils.descendants(ocn, pos)) -# c_graph = ocn._OCN__p_c_graph -# upstream_indices = (po._libocn_bindings.linidx_t * (m * n))() -# stack = (po._libocn_bindings.linidx_t * (m * n))() -# nupstream_val = po._libocn_bindings.linidx_t(0) +po.plotting.plot_ocn_as_dag(ocn) +plt.show() -# t0 = timer() -# status = po._libocn_bindings.libocn.fg_dfs_iterative( -# upstream_indices, -# ctypes.byref(nupstream_val), -# stack, -# c_graph, -# a -# ) -# t1 = timer() -# po._libocn_bindings.libocn.ocn_compute_energy(c_graph, ocn.gamma) -# t2 = timer() - -# numpstream_iter = nupstream_val.value -# upstream_indices_iter = sorted(list(upstream_indices)[:nupstream_val.value]) - -# return numpstream_iter, upstream_indices_iter, status, t1 - t0, t2 - t1 - -# ocn = po.OCN.from_net_type("I", dims=(10, 10)) -# a = 85 -# correct_answer = [80, 81, 82, 83, 84, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] -# numpstream_iter, upstream_indices_iter, status, time_dfs, time_energy = run_upstream_functions(ocn, a) -# check_upstream_function_result(numpstream_iter, upstream_indices_iter, correct_answer, status) - -# ocn = po.OCN.from_net_type("I", dims=(10, 10)) -# a = 33 -# correct_answer = [30, 31, 32] -# numpstream_iter, upstream_indices_iter, status, time_dfs, time_energy = run_upstream_functions(ocn, a) -# check_upstream_function_result(numpstream_iter, upstream_indices_iter, correct_answer, status) - - -ocn = po.OCN.from_net_type("E", dims=(64, 64), random_state=8471) -ocn.fit(pbar=True, max_iterations_per_loop=100, calculate_full_energy=True) -print("Final energy:", ocn.energy) - -ocn = po.OCN.from_net_type("E", dims=(64, 64), random_state=8471) -ocn.fit(pbar=True, max_iterations_per_loop=100, calculate_full_energy=False) -print("Final energy:", ocn.energy) -# plt.hist(np.diff(ocn.history[:, 1]), bins=np.linspace(-1000, 0)) - -plt.show() \ No newline at end of file From efdaa5be4508957f1e48828405cc1327654d708f Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Thu, 30 Oct 2025 08:29:06 -0600 Subject: [PATCH 20/24] temporarily commenting out graph traversal methods --- PyOCN/ocn.py | 203 ++++++++++++++++++++++----------------------- PyOCN/utils.py | 217 +++++++++++++++++++++++++------------------------ 2 files changed, 211 insertions(+), 209 deletions(-) diff --git a/PyOCN/ocn.py b/PyOCN/ocn.py index 431f12a..feb5811 100644 --- a/PyOCN/ocn.py +++ b/PyOCN/ocn.py @@ -379,114 +379,115 @@ def history(self) -> np.ndarray: ########################### # GRAPH TRAVERSAL # ########################### - def predecessors(self, pos:tuple[int, int]) -> Generator[tuple[int, int], None, None]: - """Returns an iterator over predecessor nodes of the node at position `pos` in the OCN. - A predecessor is defined as an immediate upstream neighbor. + # # Commenting out for now. May decide to re-introduce later + # def predecessors(self, pos:tuple[int, int]) -> Generator[tuple[int, int], None, None]: + # """Returns an iterator over predecessor nodes of the node at position `pos` in the OCN. + # A predecessor is defined as an immediate upstream neighbor. - A predecessor of a node n is a node m such that there exists a directed edge from m to n. + # A predecessor of a node n is a node m such that there exists a directed edge from m to n. - Mirrors the behavior of `networkx.DiGraph.predecessors`, but works directly on the OCN C graph structure. + # Mirrors the behavior of `networkx.DiGraph.predecessors`, but works directly on the OCN C graph structure. - Yields (row, col) of successors. + # Yields (row, col) of successors. - Parameters - ---------- - ocn : OCN - The OCN instance. - pos : tuple[int, int] - The (row, col) position of the node whose predecessors are to be found. + # Parameters + # ---------- + # ocn : OCN + # The OCN instance. + # pos : tuple[int, int] + # The (row, col) position of the node whose predecessors are to be found. - Raises - ------- - TypeError - If `pos` is not a tuple of two integers. - IndexError - If `pos` is out of bounds for the current OCN grid. - - See Also - -------- - :meth:`OCN.successors` - """ - - pos = tuple(pos) - if len(pos) != 2 or not all(isinstance(p, int) for p in pos): - raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") - if (pos[0] < 0 or pos[0] >= self.dims[0]) or (pos[1] < 0 or pos[1] >= self.dims[1]): - raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {self.dims}.") - - # convert (row, col) to linear index - a = _bindings.libocn.fg_cart_to_lin( - _bindings.CartPair_C(row=pos[0], col=pos[1]), - _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]), - ) - - upstream_indices = (_bindings.linidx_t * 8)() - nupstream = _bindings.linidx_t() - check_status(_bindings.libocn.fg_find_upstream_neighbors( - upstream_indices, - ctypes.byref(nupstream), - self.__p_c_graph, - _bindings.linidx_t(a) - )) - count = int(nupstream.value) - - # convert back to cartesian (row, col) - for idx in upstream_indices[:count]: - cart = _bindings.libocn.fg_lin_to_cart( - idx, - _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]) - ) - yield (int(cart.row), int(cart.col)) - - - def successors(self, pos: tuple[int, int]) -> Generator[tuple[int, int], None, None]: - """Returns an iterator over successor nodes of the node at position `pos` in the OCN. - A successor is the immediate downstream neighbor (at most one; none if root). - - Mirrors the behavior of `networkx.DiGraph.successors`, but works directly on the OCN C graph structure. - Yields (row, col) of successors. - - Parameters - ---------- - pos : tuple[int, int] - The (row, col) position of the node whose successors are to be found. + # Raises + # ------- + # TypeError + # If `pos` is not a tuple of two integers. + # IndexError + # If `pos` is out of bounds for the current OCN grid. + + # See Also + # -------- + # :meth:`OCN.successors` + # """ + + # pos = tuple(pos) + # if len(pos) != 2 or not all(isinstance(p, int) for p in pos): + # raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") + # if (pos[0] < 0 or pos[0] >= self.dims[0]) or (pos[1] < 0 or pos[1] >= self.dims[1]): + # raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {self.dims}.") + + # # convert (row, col) to linear index + # a = _bindings.libocn.fg_cart_to_lin( + # _bindings.CartPair_C(row=pos[0], col=pos[1]), + # _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]), + # ) + + # upstream_indices = (_bindings.linidx_t * 8)() + # nupstream = _bindings.linidx_t() + # check_status(_bindings.libocn.fg_find_upstream_neighbors( + # upstream_indices, + # ctypes.byref(nupstream), + # self.__p_c_graph, + # _bindings.linidx_t(a) + # )) + # count = int(nupstream.value) + + # # convert back to cartesian (row, col) + # for idx in upstream_indices[:count]: + # cart = _bindings.libocn.fg_lin_to_cart( + # idx, + # _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]) + # ) + # yield (int(cart.row), int(cart.col)) + + + # def successors(self, pos: tuple[int, int]) -> Generator[tuple[int, int], None, None]: + # """Returns an iterator over successor nodes of the node at position `pos` in the OCN. + # A successor is the immediate downstream neighbor (at most one; none if root). + + # Mirrors the behavior of `networkx.DiGraph.successors`, but works directly on the OCN C graph structure. + # Yields (row, col) of successors. + + # Parameters + # ---------- + # pos : tuple[int, int] + # The (row, col) position of the node whose successors are to be found. - Raises - ------- - TypeError - If `pos` is not a tuple of two integers. - IndexError - If `pos` is out of bounds for the current OCN grid. - """ - pos = tuple(pos) - if len(pos) != 2 or not all(isinstance(p, int) for p in pos): - raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") - if (pos[0] < 0 or pos[0] >= self.dims[0]) or (pos[1] < 0 or pos[1] >= self.dims[1]): - raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {self.dims}.") - - # (row, col) -> linear index - a = _bindings.libocn.fg_cart_to_lin( - _bindings.CartPair_C(row=pos[0], col=pos[1]), - _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]), - ) - - # Read vertex to get downstream direction - vert = _bindings.Vertex_C() - check_status(_bindings.libocn.fg_get_lin( - ctypes.byref(vert), - self.__p_c_graph, - _bindings.linidx_t(a), - )) - - # If root, no successors - if int(vert.downstream) == _bindings.IS_ROOT: - return + # Raises + # ------- + # TypeError + # If `pos` is not a tuple of two integers. + # IndexError + # If `pos` is out of bounds for the current OCN grid. + # """ + # pos = tuple(pos) + # if len(pos) != 2 or not all(isinstance(p, int) for p in pos): + # raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") + # if (pos[0] < 0 or pos[0] >= self.dims[0]) or (pos[1] < 0 or pos[1] >= self.dims[1]): + # raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {self.dims}.") + + # # (row, col) -> linear index + # a = _bindings.libocn.fg_cart_to_lin( + # _bindings.CartPair_C(row=pos[0], col=pos[1]), + # _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]), + # ) + + # # Read vertex to get downstream direction + # vert = _bindings.Vertex_C() + # check_status(_bindings.libocn.fg_get_lin( + # ctypes.byref(vert), + # self.__p_c_graph, + # _bindings.linidx_t(a), + # )) + + # # If root, no successors + # if int(vert.downstream) == _bindings.IS_ROOT: + # return - cart = _bindings.libocn.fg_lin_to_cart( - vert.adown, - _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]) - ) - yield (int(cart.row), int(cart.col)) + # cart = _bindings.libocn.fg_lin_to_cart( + # vert.adown, + # _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]) + # ) + # yield (int(cart.row), int(cart.col)) ########################### # EXPORT # diff --git a/PyOCN/utils.py b/PyOCN/utils.py index af5fb10..186d63e 100644 --- a/PyOCN/utils.py +++ b/PyOCN/utils.py @@ -463,111 +463,112 @@ def fit(ocn, i, fit_method, fit_kwargs): pbar.update(1) return results, ocn -def ancestors(ocn:OCN, pos:tuple[int, int]) -> set[tuple[int, int]]: - """Returns all nodes that drain into the node at position `pos` in the OCN. - - Parameters - ---------- - ocn : OCN - The OCN instance. - pos : tuple[int, int] - The (row, col) position of the node whose predecessors are to be found. - - Returns - ------- - set() - A set of (row, col) tuples representing the positions of all ancestor nodes - that eventually drain into the specified node, not including the node itself. - - See Also - -------- - :meth:`OCN.predecessors` - :meth:`OCN.successors` - :meth:`utils.descendants` - """ - - pos = tuple(pos) - if len(pos) != 2 or not all(isinstance(p, int) for p in pos): - raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") - if (pos[0] < 0 or pos[0] >= ocn.dims[0]) or (pos[1] < 0 or pos[1] >= ocn.dims[1]): - raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {ocn.dims}.") - - a = _bindings.libocn.fg_cart_to_lin( - _bindings.CartPair_C(row=pos[0], col=pos[1]), - _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) - ) - upstream_indices = (_bindings.linidx_t * (ocn.dims[0]*ocn.dims[1]))() - nupstream = _bindings.linidx_t(0) - check_status(_bindings.libocn.fg_dfs_iterative( - upstream_indices, - byref(nupstream), - (_bindings.linidx_t * (ocn.dims[0]*ocn.dims[1]))(), - ocn._OCN__p_c_graph, - a - )) - - return set( - (int(c.row), int(c.col)) - for c in ( - _bindings.libocn.fg_lin_to_cart( - idx, - _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) - ) - for idx in upstream_indices[:nupstream.value] - ) - ) - -# TODO write tests for the traversal functions -def descendants(ocn:OCN, pos:tuple[int, int]) -> set[tuple[int, int]]: - """Returns all nodes that are reachable from the node at position `pos` in the OCN. - Mirrors the functionality of `networkx.descendants`. - - Parameters - ---------- - ocn : OCN - The OCN instance. - pos : tuple[int, int] - The (row, col) position of the node whose successors are to be found. - - Returns - ------- - set() - A set of (row, col) tuples representing the positions of all descendant nodes - that the specified node eventually drains into, not including the node itself. - - See Also - -------- - :meth:`OCN.predecessors` - :meth:`OCN.successors` - :meth:`utils.ancestors` - """ - - pos = tuple(pos) - if len(pos) != 2 or not all(isinstance(p, int) for p in pos): - raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") - if (pos[0] < 0 or pos[0] >= ocn.dims[0]) or (pos[1] < 0 or pos[1] >= ocn.dims[1]): - raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {ocn.dims}.") - - a = _bindings.libocn.fg_cart_to_lin( - _bindings.CartPair_C(row=pos[0], col=pos[1]), - _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) - ) - downstream_indices = (_bindings.linidx_t * (ocn.dims[0]*ocn.dims[1]))() - ndownstream = _bindings.linidx_t(0) - check_status(_bindings.libocn.flow_downstream( - downstream_indices, - byref(ndownstream), - ocn._OCN__p_c_graph, - a - )) - - return set( - (int(c.row), int(c.col)) - for c in ( - _bindings.libocn.fg_lin_to_cart( - idx, - _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) - ) - for idx in downstream_indices[:ndownstream.value] - ) - ) \ No newline at end of file +# # Commenting out for now. May decide to re-introduce later +# def ancestors(ocn:OCN, pos:tuple[int, int]) -> set[tuple[int, int]]: +# """Returns all nodes that drain into the node at position `pos` in the OCN. + +# Parameters +# ---------- +# ocn : OCN +# The OCN instance. +# pos : tuple[int, int] +# The (row, col) position of the node whose predecessors are to be found. + +# Returns +# ------- +# set() +# A set of (row, col) tuples representing the positions of all ancestor nodes +# that eventually drain into the specified node, not including the node itself. + +# See Also +# -------- +# :meth:`OCN.predecessors` +# :meth:`OCN.successors` +# :meth:`utils.descendants` +# """ + +# pos = tuple(pos) +# if len(pos) != 2 or not all(isinstance(p, int) for p in pos): +# raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") +# if (pos[0] < 0 or pos[0] >= ocn.dims[0]) or (pos[1] < 0 or pos[1] >= ocn.dims[1]): +# raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {ocn.dims}.") + +# a = _bindings.libocn.fg_cart_to_lin( +# _bindings.CartPair_C(row=pos[0], col=pos[1]), +# _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) +# ) +# upstream_indices = (_bindings.linidx_t * (ocn.dims[0]*ocn.dims[1]))() +# nupstream = _bindings.linidx_t(0) +# check_status(_bindings.libocn.fg_dfs_iterative( +# upstream_indices, +# byref(nupstream), +# (_bindings.linidx_t * (ocn.dims[0]*ocn.dims[1]))(), +# ocn._OCN__p_c_graph, +# a +# )) + +# return set( +# (int(c.row), int(c.col)) +# for c in ( +# _bindings.libocn.fg_lin_to_cart( +# idx, +# _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) +# ) +# for idx in upstream_indices[:nupstream.value] +# ) +# ) + +# # TODO write tests for the traversal functions +# def descendants(ocn:OCN, pos:tuple[int, int]) -> set[tuple[int, int]]: +# """Returns all nodes that are reachable from the node at position `pos` in the OCN. +# Mirrors the functionality of `networkx.descendants`. + +# Parameters +# ---------- +# ocn : OCN +# The OCN instance. +# pos : tuple[int, int] +# The (row, col) position of the node whose successors are to be found. + +# Returns +# ------- +# set() +# A set of (row, col) tuples representing the positions of all descendant nodes +# that the specified node eventually drains into, not including the node itself. + +# See Also +# -------- +# :meth:`OCN.predecessors` +# :meth:`OCN.successors` +# :meth:`utils.ancestors` +# """ + +# pos = tuple(pos) +# if len(pos) != 2 or not all(isinstance(p, int) for p in pos): +# raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") +# if (pos[0] < 0 or pos[0] >= ocn.dims[0]) or (pos[1] < 0 or pos[1] >= ocn.dims[1]): +# raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {ocn.dims}.") + +# a = _bindings.libocn.fg_cart_to_lin( +# _bindings.CartPair_C(row=pos[0], col=pos[1]), +# _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) +# ) +# downstream_indices = (_bindings.linidx_t * (ocn.dims[0]*ocn.dims[1]))() +# ndownstream = _bindings.linidx_t(0) +# check_status(_bindings.libocn.flow_downstream( +# downstream_indices, +# byref(ndownstream), +# ocn._OCN__p_c_graph, +# a +# )) + +# return set( +# (int(c.row), int(c.col)) +# for c in ( +# _bindings.libocn.fg_lin_to_cart( +# idx, +# _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) +# ) +# for idx in downstream_indices[:ndownstream.value] +# ) +# ) \ No newline at end of file From 7bf516e4c1b12acaa2509ba31d7ee8fd9a790926 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Thu, 30 Oct 2025 17:48:21 -0600 Subject: [PATCH 21/24] fixed a bug in get_subwatersheds --- PyOCN/_flowgrid_convert.py | 7 +++++++ PyOCN/plotting.py | 8 +++++++- PyOCN/utils.py | 13 +++++++++++-- sandbox.py | 30 +++++++++++++----------------- 4 files changed, 38 insertions(+), 20 deletions(-) diff --git a/PyOCN/_flowgrid_convert.py b/PyOCN/_flowgrid_convert.py index 9dc2fab..3b2a9fb 100644 --- a/PyOCN/_flowgrid_convert.py +++ b/PyOCN/_flowgrid_convert.py @@ -285,3 +285,10 @@ def validate_flowgrid(c_graph:_bindings.FlowGrid_C, verbose:bool=False) -> bool| except Exception as e: # _digraph_to_flowgrid_c will destroy p_c_graph on failure return str(e) return "Graph is valid." + +__all__ = [ + "to_digraph", + "from_digraph", + "validate_digraph", + "validate_flowgrid", +] diff --git a/PyOCN/plotting.py b/PyOCN/plotting.py index 545e376..f79ab76 100644 --- a/PyOCN/plotting.py +++ b/PyOCN/plotting.py @@ -173,4 +173,10 @@ def plot_positional_digraph(dag: nx.DiGraph, ax=None, **kwargs): _, ax = plt.subplots() p = nx.draw_networkx(dag, pos=pos, ax=ax, **kwargs) - return p, ax \ No newline at end of file + return p, ax + +__all__ = [ + "plot_ocn_as_dag", + "plot_ocn_raster", + "plot_positional_digraph", +] \ No newline at end of file diff --git a/PyOCN/utils.py b/PyOCN/utils.py index 186d63e..38ba486 100644 --- a/PyOCN/utils.py +++ b/PyOCN/utils.py @@ -364,7 +364,7 @@ def get_subwatersheds(dag : nx.DiGraph, node : Any) -> set[nx.DiGraph]: in the subwatersheds will affect the original graph. """ subwatershed_outlets = [n for n in dag.predecessors(node)] - subwatersheds = set(set(nx.ancestors(dag, outlet)) | {outlet} for outlet in subwatershed_outlets) + subwatersheds = (nx.ancestors(dag, outlet) | {outlet} for outlet in subwatershed_outlets) subwatersheds = set(dag.subgraph(wshd) for wshd in subwatersheds) return subwatersheds @@ -571,4 +571,13 @@ def fit(ocn, i, fit_method, fit_kwargs): # ) # for idx in downstream_indices[:ndownstream.value] # ) -# ) \ No newline at end of file +# ) + +__all__ = [ + "net_type_to_dag", + "simulated_annealing_schedule", + "unwrap_digraph", + "assign_subwatersheds", + "get_subwatersheds", + "parallel_fit", +] \ No newline at end of file diff --git a/sandbox.py b/sandbox.py index 994cb96..b8d3106 100644 --- a/sandbox.py +++ b/sandbox.py @@ -1,19 +1,15 @@ - import PyOCN as po -import ctypes -import matplotlib.pyplot as plt -from time import perf_counter as timer -from tqdm import trange import numpy as np - -ocn = po.OCN.from_net_type("H", dims=(10, 10)) - -pos = (7, 7) -print(list(ocn.predecessors(pos))) -print(list(ocn.successors(pos))) -print(po.utils.ancestors(ocn, pos)) -print(po.utils.descendants(ocn, pos)) - -po.plotting.plot_ocn_as_dag(ocn) -plt.show() - +import matplotlib.pyplot as plt +import networkx as nx + +ocn = po.OCN.from_net_type("E", dims=(32, 32), random_state=8472) +ocn.fit(pbar=True) +G = ocn.to_digraph() +for n in G.nodes: + if G.in_degree(n) > 2: + print(n) + break +n = 5 +po.utils.get_subwatersheds(G, node=n) +# po.plotting.plot_ocn_as_dag(ocn, ax=ax, node_size=5, with_labels="False") \ No newline at end of file From 1f2afa909b9546ed561e34c279c81ec207353731 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Thu, 30 Oct 2025 19:56:02 -0600 Subject: [PATCH 22/24] stream orders, fixed some plotting bugs --- PyOCN/plotting.py | 19 +++++++++++------- PyOCN/utils.py | 51 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 7 deletions(-) diff --git a/PyOCN/plotting.py b/PyOCN/plotting.py index f79ab76..520480c 100644 --- a/PyOCN/plotting.py +++ b/PyOCN/plotting.py @@ -24,7 +24,7 @@ from .ocn import OCN from .utils import unwrap_digraph -def _pos_to_xy(dag: nx.DiGraph) -> dict[Any, tuple[float, float]]: +def _pos_to_xy(dag: nx.DiGraph, nrows=None) -> dict[Any, tuple[float, float]]: """ Convert node ``pos`` from (row, col) to plotting coordinates (x, y). @@ -32,6 +32,9 @@ def _pos_to_xy(dag: nx.DiGraph) -> dict[Any, tuple[float, float]]: ---------- dag : nx.DiGraph Graph whose nodes have ``pos=(row, col)`` attributes. + nrows : int, optional + Number of rows in the grid. If ``None``, it is inferred from the + maximum row index in the node positions. Returns ------- @@ -45,7 +48,7 @@ def _pos_to_xy(dag: nx.DiGraph) -> dict[Any, tuple[float, float]]: match typical plotting conventions. """ pos = nx.get_node_attributes(dag, 'pos') - nrows = max(r for r, _ in pos.values()) + 1 + nrows = nrows if nrows is not None else max(r for r, _ in pos.values()) + 1 for node, (r, c) in pos.items(): pos[node] = (c, nrows - r - 1) return pos @@ -82,7 +85,7 @@ def plot_ocn_as_dag(ocn: OCN, attribute: str | None = None, ax=None, norm=None, dag = ocn.to_digraph() if ocn.wrap: dag = unwrap_digraph(dag, ocn.dims) - pos = _pos_to_xy(dag) + pos = _pos_to_xy(dag, nrows=ocn.dims[0]) if ax is None: _, ax = plt.subplots() @@ -97,8 +100,7 @@ def plot_ocn_as_dag(ocn: OCN, attribute: str | None = None, ax=None, norm=None, kwargs["vmin"] = 0 kwargs["vmax"] = 1 node_color = norm(node_color) - - + p = nx.draw_networkx(dag, node_color=node_color, pos=pos, ax=ax, **kwargs) return p, ax @@ -146,7 +148,7 @@ def plot_ocn_raster(ocn: OCN, attribute:str='energy', ax=None, **kwargs): return ax -def plot_positional_digraph(dag: nx.DiGraph, ax=None, **kwargs): +def plot_positional_digraph(dag: nx.DiGraph, ax=None, nrows:int|None=None, **kwargs): """ Plot a DAG with node positions taken from their ``pos`` attributes. @@ -156,6 +158,9 @@ def plot_positional_digraph(dag: nx.DiGraph, ax=None, **kwargs): Graph whose nodes have ``pos=(row, col)``. ax : matplotlib.axes.Axes, optional Target axes. If ``None``, a new figure and axes are created. + nrows : int | None, default None + Number of rows in the grid. If ``None``, infer from the maximum row index in + the node positions. If ``False``, the vertical coordinates are not **kwargs Additional keyword arguments forwarded to :func:`networkx.draw_networkx`. @@ -167,7 +172,7 @@ def plot_positional_digraph(dag: nx.DiGraph, ax=None, **kwargs): ``networkx.draw_networkx`` (often ``None``) and ``ax`` is the axes used for drawing. """ - pos = _pos_to_xy(dag) + pos = _pos_to_xy(dag, nrows) if ax is None: _, ax = plt.subplots() diff --git a/PyOCN/utils.py b/PyOCN/utils.py index 38ba486..2ac5159 100644 --- a/PyOCN/utils.py +++ b/PyOCN/utils.py @@ -368,6 +368,55 @@ def get_subwatersheds(dag : nx.DiGraph, node : Any) -> set[nx.DiGraph]: subwatersheds = set(dag.subgraph(wshd) for wshd in subwatersheds) return subwatersheds +def assign_strahler_orders(dag: nx.DiGraph) -> None: + """Assign Strahler order to each node in the DAG as a 'strahler_order' attribute. + Modifies the input graph in place. + + Parameters + ---------- + dag : nx.DiGraph + The input directed acyclic graph. + """ + # Initialize Strahler order for leaf nodes + leaf_nodes = [n for n, d in dag.in_degree() if d == 0] + nx.set_node_attributes(dag, {n: 1 for n in leaf_nodes}, 'strahler_order') + + # Compute Strahler order for other nodes + strahler_orders = dict(nx.get_node_attributes(dag, 'strahler_order')) + for n in nx.topological_sort(dag): + if dag.in_degree(n): + in_orders = [strahler_orders[p] for p in dag.predecessors(n)] + max_order = max(in_orders) + if sum(o == max_order for o in in_orders) > 1: + order = max_order + 1 + else: + order = max_order + strahler_orders[n] = order + nx.set_node_attributes(dag, strahler_orders, 'strahler_order') + +def assign_shreve_orders(dag: nx.DiGraph) -> None: + """Assign Shreve order to each node in the DAG as a 'shreve_order' attribute. + Shreve order is defined as the total number of upstream sources that contribute to a node. + + Parameters + ---------- + dag : nx.DiGraph + The input directed acyclic graph. + """ + # Initialize Shreve order for leaf nodes + leaf_nodes = [n for n, d in dag.in_degree() if d == 0] + nx.set_node_attributes(dag, {n: 1 for n in leaf_nodes}, 'shreve_order') + + # Compute Shreve order for other nodes + shreve_orders = dict(nx.get_node_attributes(dag, 'shreve_order')) + for n in nx.topological_sort(dag): + if dag.in_degree(n): + in_orders = [shreve_orders[p] for p in dag.predecessors(n)] + order = sum(in_orders) + shreve_orders[n] = order + nx.set_node_attributes(dag, shreve_orders, 'shreve_order') + + def parallel_fit( ocn:OCN|list[OCN], n_runs=5, @@ -580,4 +629,6 @@ def fit(ocn, i, fit_method, fit_kwargs): "assign_subwatersheds", "get_subwatersheds", "parallel_fit", + "assign_strahler_orders", + "assign_shreve_orders", ] \ No newline at end of file From d06a8ad282778d23f1e53eba06ab717e44b881a9 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Thu, 30 Oct 2025 20:08:17 -0600 Subject: [PATCH 23/24] added tests for subwatersheds, deferring stream order calculations for now --- PyOCN/utils.py | 91 +++++++++++++++++++++++---------------------- tests/test_basic.py | 18 +++++++++ 2 files changed, 64 insertions(+), 45 deletions(-) diff --git a/PyOCN/utils.py b/PyOCN/utils.py index 2ac5159..843af0c 100644 --- a/PyOCN/utils.py +++ b/PyOCN/utils.py @@ -368,53 +368,54 @@ def get_subwatersheds(dag : nx.DiGraph, node : Any) -> set[nx.DiGraph]: subwatersheds = set(dag.subgraph(wshd) for wshd in subwatersheds) return subwatersheds -def assign_strahler_orders(dag: nx.DiGraph) -> None: - """Assign Strahler order to each node in the DAG as a 'strahler_order' attribute. - Modifies the input graph in place. +# # deferring stream order calculations for now. +# def assign_strahler_orders(dag: nx.DiGraph) -> None: +# """Assign Strahler order to each node in the DAG as a 'strahler_order' attribute. +# Modifies the input graph in place. - Parameters - ---------- - dag : nx.DiGraph - The input directed acyclic graph. - """ - # Initialize Strahler order for leaf nodes - leaf_nodes = [n for n, d in dag.in_degree() if d == 0] - nx.set_node_attributes(dag, {n: 1 for n in leaf_nodes}, 'strahler_order') - - # Compute Strahler order for other nodes - strahler_orders = dict(nx.get_node_attributes(dag, 'strahler_order')) - for n in nx.topological_sort(dag): - if dag.in_degree(n): - in_orders = [strahler_orders[p] for p in dag.predecessors(n)] - max_order = max(in_orders) - if sum(o == max_order for o in in_orders) > 1: - order = max_order + 1 - else: - order = max_order - strahler_orders[n] = order - nx.set_node_attributes(dag, strahler_orders, 'strahler_order') - -def assign_shreve_orders(dag: nx.DiGraph) -> None: - """Assign Shreve order to each node in the DAG as a 'shreve_order' attribute. - Shreve order is defined as the total number of upstream sources that contribute to a node. +# Parameters +# ---------- +# dag : nx.DiGraph +# The input directed acyclic graph. +# """ +# # Initialize Strahler order for leaf nodes +# leaf_nodes = [n for n, d in dag.in_degree() if d == 0] +# nx.set_node_attributes(dag, {n: 1 for n in leaf_nodes}, 'strahler_order') + +# # Compute Strahler order for other nodes +# strahler_orders = dict(nx.get_node_attributes(dag, 'strahler_order')) +# for n in nx.topological_sort(dag): +# if dag.in_degree(n): +# in_orders = [strahler_orders[p] for p in dag.predecessors(n)] +# max_order = max(in_orders) +# if sum(o == max_order for o in in_orders) > 1: +# order = max_order + 1 +# else: +# order = max_order +# strahler_orders[n] = order +# nx.set_node_attributes(dag, strahler_orders, 'strahler_order') + +# def assign_shreve_orders(dag: nx.DiGraph) -> None: +# """Assign Shreve order to each node in the DAG as a 'shreve_order' attribute. +# Shreve order is defined as the total number of upstream sources that contribute to a node. - Parameters - ---------- - dag : nx.DiGraph - The input directed acyclic graph. - """ - # Initialize Shreve order for leaf nodes - leaf_nodes = [n for n, d in dag.in_degree() if d == 0] - nx.set_node_attributes(dag, {n: 1 for n in leaf_nodes}, 'shreve_order') - - # Compute Shreve order for other nodes - shreve_orders = dict(nx.get_node_attributes(dag, 'shreve_order')) - for n in nx.topological_sort(dag): - if dag.in_degree(n): - in_orders = [shreve_orders[p] for p in dag.predecessors(n)] - order = sum(in_orders) - shreve_orders[n] = order - nx.set_node_attributes(dag, shreve_orders, 'shreve_order') +# Parameters +# ---------- +# dag : nx.DiGraph +# The input directed acyclic graph. +# """ +# # Initialize Shreve order for leaf nodes +# leaf_nodes = [n for n, d in dag.in_degree() if d == 0] +# nx.set_node_attributes(dag, {n: 1 for n in leaf_nodes}, 'shreve_order') + +# # Compute Shreve order for other nodes +# shreve_orders = dict(nx.get_node_attributes(dag, 'shreve_order')) +# for n in nx.topological_sort(dag): +# if dag.in_degree(n): +# in_orders = [shreve_orders[p] for p in dag.predecessors(n)] +# order = sum(in_orders) +# shreve_orders[n] = order +# nx.set_node_attributes(dag, shreve_orders, 'shreve_order') def parallel_fit( diff --git a/tests/test_basic.py b/tests/test_basic.py index 9565c9f..f42ae70 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -405,5 +405,23 @@ def test_gamma(self): self.assertLess(energy_gamma_3, energy_gamma_2, msg="Energy should decrease with lower gamma.") self.assertLess(energy_gamma_2, energy_gamma_1, msg="Energy should decrease with lower gamma.") + def test_watershed_partitioning(self): + ocn = po.OCN.from_net_type("E", dims=(32, 32), random_state=83) + ocn.fit(pbar=True) + G = ocn.to_digraph() + n = 997 + subgraphs = po.utils.get_subwatersheds(G, n) + + node_check = [ + set([932, 931, 964]), + set([965]), + set([900, 901, 902, 903, 904, 933, 934, 935, 936, 837, 838, 967, 840, 839, 966, 868, 869, 870, 871, 872]) + ] + for wshd in subgraphs: + self.assertIn(set(wshd.nodes), node_check) + + def test_stream_ordering(self): + + if __name__ == "__main__": unittest.main() \ No newline at end of file From e85782ae309df5c221f0faaa9cfaeb12383c9085 Mon Sep 17 00:00:00 2001 From: Alex Fox Date: Thu, 30 Oct 2025 20:10:52 -0600 Subject: [PATCH 24/24] minor bug fix --- tests/test_basic.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_basic.py b/tests/test_basic.py index f42ae70..9383336 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -420,8 +420,5 @@ def test_watershed_partitioning(self): for wshd in subgraphs: self.assertIn(set(wshd.nodes), node_check) - def test_stream_ordering(self): - - if __name__ == "__main__": unittest.main() \ No newline at end of file