diff --git a/.vscode/settings.json b/.vscode/settings.json index fda6998..7a40653 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,6 +4,8 @@ "neighbormatrix.c": "cpp", "stdbool.h": "c", "returncodes.h": "c", - "string.h": "c" + "string.h": "c", + "__locale": "c", + "bitset": "c" } } \ No newline at end of file diff --git a/PyOCN/__init__.py b/PyOCN/__init__.py index 2c925c6..1051aa8 100644 --- a/PyOCN/__init__.py +++ b/PyOCN/__init__.py @@ -1,16 +1,14 @@ from .ocn import OCN -from .plotting import plot_ocn_as_dag, plot_positional_digraph, plot_ocn_raster -from .utils import net_type_to_dag, simulated_annealing_schedule, unwrap_digraph, get_subwatersheds, assign_subwatersheds +from ._version import __version__ + +from . import utils +from . import plotting + __all__ = [ "OCN", - "flowgrid", - "plot_ocn_as_dag", - "plot_positional_digraph", - "plot_ocn_raster", - "net_type_to_dag", - "simulated_annealing_schedule", - "unwrap_digraph", - "get_subwatersheds", - "assign_subwatersheds", -] \ No newline at end of file + "utils", + "plotting", + "__version__", +] + diff --git a/PyOCN/__pycache__/__init__.cpython-311.pyc b/PyOCN/__pycache__/__init__.cpython-311.pyc index aeb2167..b5f2afb 100644 Binary files a/PyOCN/__pycache__/__init__.cpython-311.pyc and b/PyOCN/__pycache__/__init__.cpython-311.pyc differ diff --git a/PyOCN/_flowgrid_convert.py b/PyOCN/_flowgrid_convert.py index 0f05233..3b2a9fb 100644 --- a/PyOCN/_flowgrid_convert.py +++ b/PyOCN/_flowgrid_convert.py @@ -1,6 +1,3 @@ -#TODO: allow users to provide multiple DAGs that partition a space. -#TODO: allow edge-wrapping (toroidal grids). -#TODO: implement "every edge vertex is a root" option """ Functions for converting between NetworkX graphs and FlowGrid_C structures. """ @@ -246,9 +243,6 @@ def direction_bit(pos1, pos2): p_c_graph.contents.resolution = float(resolution) p_c_graph.contents.nroots = len([n for n in G.nodes if G.out_degree(n) == 0]) p_c_graph.contents.wrap = wrap - - if p_c_graph.contents.nroots > 1: - warnings.warn(f"FlowGrid has {p_c_graph.contents.nroots} root nodes (nodes with no downstream). This will slow down certain operations.") # do not set energy @@ -291,3 +285,10 @@ def validate_flowgrid(c_graph:_bindings.FlowGrid_C, verbose:bool=False) -> bool| except Exception as e: # _digraph_to_flowgrid_c will destroy p_c_graph on failure return str(e) return "Graph is valid." + +__all__ = [ + "to_digraph", + "from_digraph", + "validate_digraph", + "validate_flowgrid", +] diff --git a/PyOCN/_libocn_bindings.py b/PyOCN/_libocn_bindings.py index ed60aa4..30897c5 100644 --- a/PyOCN/_libocn_bindings.py +++ b/PyOCN/_libocn_bindings.py @@ -115,6 +115,14 @@ class FlowGrid_C(Structure): libocn.fg_cart_to_lin.argtypes = [CartPair_C, CartPair_C] libocn.fg_cart_to_lin.restype = linidx_t +# CartPair fg_lin_to_cart(linidx_t a, CartPair dims); +libocn.fg_lin_to_cart.argtypes = [linidx_t, CartPair_C] +libocn.fg_lin_to_cart.restype = CartPair_C + +# Status fg_clockhand_to_lin(linidx_t *a_down, linidx_t a, clockhand_t down, CartPair dims, bool wrap) +libocn.fg_clockhand_to_lin.argtypes = [POINTER(linidx_t), linidx_t, clockhand_t, CartPair_C, c_bool] +libocn.fg_clockhand_to_lin.restype = Status + # Status fg_get_lin(Vertex *out, FlowGrid *G, linidx_t a); libocn.fg_get_lin.argtypes = [POINTER(Vertex_C), POINTER(FlowGrid_C), linidx_t] libocn.fg_get_lin.restype = Status @@ -135,6 +143,18 @@ class FlowGrid_C(Structure): libocn.fg_destroy.argtypes = [POINTER(FlowGrid_C)] libocn.fg_destroy.restype = Status +# Status fg_find_upstream_neighbors(linidx_t upstream_indices[8], linidx_t *nupstream, FlowGrid *G, linidx_t a); +libocn.fg_find_upstream_neighbors.argtypes = [linidx_t * 8, POINTER(linidx_t), POINTER(FlowGrid_C), linidx_t] +libocn.fg_find_upstream_neighbors.restype = Status + +# Status fg_dfs_iterative(linidx_t upstream_indices[], linidx_t *nupstream, linidx_t idx_stack[], FlowGrid *G, linidx_t a); +libocn.fg_dfs_iterative.argtypes = [POINTER(linidx_t), POINTER(linidx_t), POINTER(linidx_t), POINTER(FlowGrid_C), linidx_t] +libocn.fg_dfs_iterative.restype = Status + +# Status flow_downstream(linidx_t downstream_indices[], linidx_t *ndownstream, FlowGrid *G, linidx_t a) +libocn.flow_downstream.argtypes = [POINTER(linidx_t), POINTER(linidx_t), POINTER(FlowGrid_C), linidx_t] +libocn.flow_downstream.restype = Status + ############################## # OCN.H EQUIVALENTS # ############################## @@ -143,12 +163,12 @@ class FlowGrid_C(Structure): libocn.ocn_compute_energy.argtypes = [POINTER(FlowGrid_C), c_double] libocn.ocn_compute_energy.restype = c_double -# Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature); -libocn.ocn_single_erosion_event.argtypes = [POINTER(FlowGrid_C), c_double, c_double] +# Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature, bool calculate_full_energy); +libocn.ocn_single_erosion_event.argtypes = [POINTER(FlowGrid_C), c_double, c_double, c_bool] libocn.ocn_single_erosion_event.restype = Status -# Status ocn_outer_ocn_loop(FlowGrid *G, uint32_t niterations, double gamma, double *annealing_schedule); -libocn.ocn_outer_ocn_loop.argtypes = [POINTER(FlowGrid_C), c_uint32, c_double, POINTER(c_double)] +# Status ocn_outer_ocn_loop(FlowGrid *G, uint32_t niterations, double gamma, double *annealing_schedule, bool calculate_full_energy); +libocn.ocn_outer_ocn_loop.argtypes = [POINTER(FlowGrid_C), c_uint32, c_double, POINTER(c_double), c_bool] libocn.ocn_outer_ocn_loop.restype = Status diff --git a/PyOCN/_version.py b/PyOCN/_version.py new file mode 100644 index 0000000..89182ca --- /dev/null +++ b/PyOCN/_version.py @@ -0,0 +1 @@ +__version__ = "1.4.20251029" \ No newline at end of file diff --git a/PyOCN/c_src/flowgrid.c b/PyOCN/c_src/flowgrid.c index 627935e..ade1394 100644 --- a/PyOCN/c_src/flowgrid.c +++ b/PyOCN/c_src/flowgrid.c @@ -302,6 +302,82 @@ Status fg_check_for_cycles(FlowGrid *G, linidx_t a, uint8_t check_number){ return SUCCESS; // found root successfully, no cycles found } +Status fg_find_upstream_neighbors(linidx_t upstream_indices[8], linidx_t *nupstream, FlowGrid *G, linidx_t a){ + if (G == NULL || G->vertices == NULL || upstream_indices == NULL || nupstream == NULL) return NULL_POINTER_ERROR; + + Status code; + Vertex vert; + code = fg_get_lin(&vert, G, a); + if (code != SUCCESS) return code; + + // get the clockhand direction of all edges + *nupstream = 0; + for (clockhand_t dir = 0; dir < 8; dir++){ + if ((vert.edges & (1u << dir)) && (vert.downstream != dir)){ // if there's an edge in this direction + linidx_t a_up; + code = fg_clockhand_to_lin(&a_up, a, dir, G->dims, G->wrap); // get the upstream vertex index + if (code != SUCCESS) return code; + upstream_indices[*nupstream] = a_up; + (*nupstream)++; + } + } + return SUCCESS; +} + +Status fg_dfs_iterative(linidx_t upstream_indices[], linidx_t *nupstream, linidx_t idx_stack[], FlowGrid *G, linidx_t a){ + if (G == NULL || G->vertices == NULL || upstream_indices == NULL || nupstream == NULL || idx_stack == NULL) return NULL_POINTER_ERROR; + Status code; + + // Seed idx_stack with immediate upstream of a + *nupstream = 0; + linidx_t local_upstream[8]; + linidx_t nlocal_upstream = 0; + code = fg_find_upstream_neighbors(local_upstream, &nlocal_upstream, G, a); + if (code != SUCCESS) return code; + if (nlocal_upstream == 0) return SUCCESS; // no upstream vertices found + + linidx_t total = (linidx_t)G->dims.row * (linidx_t)G->dims.col; + if (nlocal_upstream >= total) return OOB_ERROR; + memcpy(idx_stack, local_upstream, nlocal_upstream * sizeof(linidx_t)); + + if (*nupstream >= total || total == 0) return OOB_ERROR; + linidx_t max_out = total - 1; + linidx_t top = nlocal_upstream; + while (top > 0){ + // pop from stack + linidx_t u = idx_stack[--top]; + if (u >= total) return OOB_ERROR; + + // Append to output buffer + if (*nupstream >= max_out) return OOB_ERROR; + upstream_indices[(*nupstream)++] = u; + + // Push u's upstream neighbors + nlocal_upstream = 0; + code = fg_find_upstream_neighbors(local_upstream, &nlocal_upstream, G, u); + if (code != SUCCESS) return code; + if (top + nlocal_upstream >= total) return OOB_ERROR; + memcpy(idx_stack + top, local_upstream, nlocal_upstream * sizeof(linidx_t)); + top += nlocal_upstream; + } + return SUCCESS; +} + +Status flow_downstream(linidx_t downstream_indices[], linidx_t *ndownstream, FlowGrid *G, linidx_t a){ + Vertex vert; + *ndownstream = 0; + do { + Status code = fg_get_lin(&vert, G, a); + if (code == OOB_ERROR) return OOB_ERROR; + downstream_indices[(*ndownstream)++] = a; + + a = vert.adown; + } while (vert.downstream != IS_ROOT); + + return SUCCESS; +} + + // vibe-coded display function const char E_ARROW = '-'; const char S_ARROW = '|'; diff --git a/PyOCN/c_src/flowgrid.h b/PyOCN/c_src/flowgrid.h index 41e7f78..b4f3d38 100644 --- a/PyOCN/c_src/flowgrid.h +++ b/PyOCN/c_src/flowgrid.h @@ -161,6 +161,36 @@ Status fg_change_vertex_outflow(FlowGrid *G, linidx_t a, clockhand_t down_new); */ Status fg_check_for_cycles(FlowGrid *G, linidx_t a, uint8_t check_number); +/** + * @brief Find all upstream neighbors of a given vertex. + * @param upstream_indices Array to store the linear indices of upstream neighbors. + * @param nupstream Pointer to store the number of upstream neighbors found. + * @param G Pointer to the FlowGrid. + * @param a The linear index of the vertex to find upstream neighbors for. + * @return Status code indicating success or failure + */ +Status fg_find_upstream_neighbors(linidx_t upstream_indices[8], linidx_t *nupstream, FlowGrid *G, linidx_t a); + +/** + * @brief Perform an iterative depth-first search to find all upstream vertices from a given starting vertex. + * @param upstream_indices Array to store the linear indices of all upstream vertices found. Must be preallocated to hold enough indices (ie G.dims.row * G.dims.col). + * @param nupstream Pointer to store the total number of upstream vertices found. + * @param idx_stack Preallocated stack array for DFS traversal. Must be large enough to hold all potential upstream vertices (ie G.dims.row * G.dims.col). + * @param G Pointer to the FlowGrid. + * @param a The linear index of the starting vertex. + * @return Status code indicating success or failure + */ +Status fg_dfs_iterative(linidx_t upstream_indices[], linidx_t *nupstream, linidx_t idx_stack[], FlowGrid *G, linidx_t a); + +/** + * @brief Follow the downstream path from a given vertex, recording each vertex along the path. Stops when a root node is reached. Includes root node. Does not include starting node. + * @param downstream_indices Array to store the linear indices of downstream vertices. Must be preallocated to hold enough indices (ie G.dims.row * G.dims.col). + * @param ndownstream Pointer to store the number of downstream vertices found. + * @param G Pointer to the FlowGrid. + * @param a The linear index of the starting vertex. + * @return Status code indicating success or failure + */ +Status flow_downstream(linidx_t downstream_indices[], linidx_t *ndownstream, FlowGrid *G, linidx_t a); /** * @brief Display the flowgrid in the terminal using ASCII or UTF-8 characters. * @param G Pointer to the FlowGrid to display. diff --git a/PyOCN/c_src/ocn.c b/PyOCN/c_src/ocn.c index 8e0d75b..45ab9a9 100644 --- a/PyOCN/c_src/ocn.c +++ b/PyOCN/c_src/ocn.c @@ -37,7 +37,7 @@ static inline bool simulate_annealing(double energy_new, double energy_old, doub * @param a The linear index of the starting vertex. * @return Status code indicating success or failure */ -static inline Status update_drained_area(FlowGrid *G, drainedarea_t da_inc, linidx_t a){ +static Status update_drained_area(FlowGrid *G, drainedarea_t da_inc, linidx_t a){ Vertex vert; do { Status code = fg_get_lin(&vert, G, a); @@ -51,7 +51,7 @@ static inline Status update_drained_area(FlowGrid *G, drainedarea_t da_inc, lini return SUCCESS; } -double ocn_compute_energy(FlowGrid *G, double gamma){ +static inline double compute_energy(FlowGrid *G, double gamma){ double energy = 0.0; for (linidx_t i = 0; i < (linidx_t)G->dims.row * (linidx_t)G->dims.col; i++){ energy += pow(G->vertices[i].drained_area, gamma); @@ -59,6 +59,10 @@ double ocn_compute_energy(FlowGrid *G, double gamma){ return energy; } +double ocn_compute_energy(FlowGrid *G, double gamma){ + return compute_energy(G, gamma); +} + /** * @brief Update the energy of the flowgrid along a single downstream path from a given vertex. Unsafe. * This function only works correctly if there is a single root in the flowgrid. @@ -69,7 +73,7 @@ double ocn_compute_energy(FlowGrid *G, double gamma){ * @param gamma The exponent used in the energy calculation. * @return Status code indicating success or failure */ -Status update_energy_single_root(FlowGrid *G, drainedarea_t da_inc, linidx_t a, double gamma){ +static Status update_energy_single_root(FlowGrid *G, drainedarea_t da_inc, linidx_t a, double gamma){ Vertex vert; double energy_old = 0.0; double energy_new = 0.0; @@ -86,7 +90,7 @@ Status update_energy_single_root(FlowGrid *G, drainedarea_t da_inc, linidx_t a, return SUCCESS; } -Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature){ +Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature, bool calculate_full_energy){ Status code; Vertex vert; @@ -154,52 +158,37 @@ Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature){ mh_eval: - /* - TODO: PERFORMANCE ISSUE: - This function is supposed to update the energy of the flowgrid G after a - change in drained area along the path starting at vertex a. - - Simple but inefficient fix (current): recompute the *entire* energy of the flowgrid from scratch - each time this function is called. - - More complex fix: find the set of all upstream vertices that flow into a and compute - their summed contribution to the energy. Pass this value (sum of (da^gamma) for all - upstream vertices) into this function, instead of just passing da_inc. - */ - if ((G->nroots > 1) && (gamma < 1.0)){ - // energy_old = ocn_compute_energy(G, gamma); // recompute energy from scratch + + if (calculate_full_energy){ + energy_old = compute_energy(G, gamma); // recompute energy from scratch update_drained_area(G, -da_inc, a_down_old); // remove drainage from old path update_drained_area(G, da_inc, a_down_new); // add drainage to new path - energy_new = ocn_compute_energy(G, gamma); // recompute energy from scratch + energy_new = compute_energy(G, gamma); // recompute energy from scratch if (simulate_annealing(energy_new, energy_old, temperature, &G->rng)){ G->energy = energy_new; return SUCCESS; } - // reject swap: undo everything and try again - update_drained_area(G, da_inc, a_down_old); // add removed drainage back to old path - update_drained_area(G, -da_inc, a_down_new); // remove added drainage from new path - fg_change_vertex_outflow(G, a, down_old); // undo the outflow change - } else { // if there's only one root, we can use a more efficient method + } else { update_energy_single_root(G, -da_inc, a_down_old, gamma); // remove drainage from old path and update energy update_energy_single_root(G, da_inc, a_down_new, gamma); // add drainage to new path and update energy energy_new = G->energy; if (simulate_annealing(energy_new, energy_old, temperature, &G->rng)){ return SUCCESS; } - // reject swap: undo everything and try again - update_energy_single_root(G, da_inc, a_down_old, gamma); // add removed drainage back to old path and update energy - update_energy_single_root(G, -da_inc, a_down_new, gamma); // remove added drainage from new path and update energy - fg_change_vertex_outflow(G, a, down_old); // undo the outflow change } - + // reject swap: undo everything and try again + update_drained_area(G, da_inc, a_down_old); // add removed drainage back to old path and update energy + update_drained_area(G, -da_inc, a_down_new); // remove added drainage from new path and update energy + fg_change_vertex_outflow(G, a, down_old); // undo the outflow change + G->energy = energy_old; return EROSION_FAILURE; // if we reach here, we failed to find a valid swap in many, many tries } -Status ocn_outer_ocn_loop(FlowGrid *G, uint32_t niterations, double gamma, double *annealing_schedule){ +Status ocn_outer_ocn_loop(FlowGrid *G, uint32_t niterations, double gamma, double *annealing_schedule, bool calculate_full_energy){ Status code; for (uint32_t i = 0; i < niterations; i++){ - code = ocn_single_erosion_event(G, gamma, annealing_schedule[i]); + code = ocn_single_erosion_event(G, gamma, annealing_schedule[i], calculate_full_energy); if ((code != SUCCESS) && (code != EROSION_FAILURE)) return code; } return SUCCESS; diff --git a/PyOCN/c_src/ocn.h b/PyOCN/c_src/ocn.h index 5c0aa2b..c04d277 100644 --- a/PyOCN/c_src/ocn.h +++ b/PyOCN/c_src/ocn.h @@ -4,8 +4,6 @@ * @brief Header file for OCN optimization. */ -//#TODO I'm worried that our method of choosing a new vertex if the last one is invalid introduces bias. Either choose a random direction to walk, or just try every vertex in random order. - #ifndef OCN_H #define OCN_H @@ -31,9 +29,10 @@ double ocn_compute_energy(FlowGrid *G, double gamma); * @param G Pointer to the FlowGrid. * @param gamma The exponent used in the energy calculation. * @param temperature The temperature parameter for the Metropolis-Hastings acceptance criterion. + * @param calculate_full_energy If true, the full energy of the graph is recalculated when considering the proposed change. If false, a more efficient incremental update is used. full_energy_recalc will be slower, but avoid accumulated numerical errors over many iterations. * @return Status code indicating success or failure. */ -Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature); +Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature, bool calculate_full_energy); /** * @brief Perform multiple erosion events on the flowgrid. @@ -43,8 +42,9 @@ Status ocn_single_erosion_event(FlowGrid *G, double gamma, double temperature); * @param gamma The exponent used in the energy calculation. * @param annealing_schedule An array of temperatures (ranging from 0-1) to use for each iteration. Length must be at least niterations. * @param wrap If true, allows wrapping around the edges of the grid (toroidal). If false, no wrapping is applied. + * @param calculate_full_energy If true, the full energy of the graph is recalculated after each proposed change. If false, a more efficient incremental update is used. full_energy_recalc will be slower, but avoid accumulated numerical errors over many iterations. * @return Status code indicating success or failure */ -Status ocn_outer_ocn_loop(FlowGrid *G, uint32_t niterations, double gamma, double *annealing_schedule); +Status ocn_outer_ocn_loop(FlowGrid *G, uint32_t niterations, double gamma, double *annealing_schedule, bool calculate_full_energy); #endif // OCN_H diff --git a/PyOCN/ocn.py b/PyOCN/ocn.py index bbbe506..feb5811 100644 --- a/PyOCN/ocn.py +++ b/PyOCN/ocn.py @@ -1,6 +1,7 @@ import warnings import ctypes from typing import Any, Callable, TYPE_CHECKING, Union +from collections.abc import Generator from os import PathLike from numbers import Number from pathlib import Path @@ -42,8 +43,6 @@ Helper functions for visualization and plotting """ - -# TODO: relax the even dims requirement # TODO: have to_rasterio use the option to set the root node to 0,0 by using to_xarray as the backend instead of numpy? @@ -122,6 +121,9 @@ class OCN: >>> plt.show() """ + ########################### + # CONSTRUCTORS # + ########################### def __init__(self, dag: nx.DiGraph, resolution: float=1.0, gamma: float = 0.5, random_state=None, verbosity: int = 0, validate:bool=True, wrap : bool = False): """ Construct an :class:`OCN` from a valid NetworkX ``DiGraph``. @@ -265,9 +267,11 @@ def from_digraph(cls, dag: nx.DiGraph, resolution:float=1, gamma=0.5, random_sta return cls(dag, resolution, gamma, random_state, verbosity=verbosity, validate=True, wrap=wrap) + ########################### + # DUNDER METHODS # + ########################### def __repr__(self): - #TODO: too verbose? - return f"" + return f"\n\n" def __str__(self): return f"OCN(gamma={self.gamma}, energy={self.energy}, dims={self.dims}, resolution={self.resolution}m, verbosity={self.verbosity})" def __del__(self): @@ -321,6 +325,9 @@ def copy(self) -> "OCN": """ return self.__copy__() + ########################### + # PROPERTIES # + ########################### def compute_energy(self) -> float: """ Compute the current energy of the network. @@ -369,6 +376,122 @@ def rng(self, random_state:int|None|np.random.Generator=None): def history(self) -> np.ndarray: return self.__history + ########################### + # GRAPH TRAVERSAL # + ########################### + # # Commenting out for now. May decide to re-introduce later + # def predecessors(self, pos:tuple[int, int]) -> Generator[tuple[int, int], None, None]: + # """Returns an iterator over predecessor nodes of the node at position `pos` in the OCN. + # A predecessor is defined as an immediate upstream neighbor. + + # A predecessor of a node n is a node m such that there exists a directed edge from m to n. + + # Mirrors the behavior of `networkx.DiGraph.predecessors`, but works directly on the OCN C graph structure. + + # Yields (row, col) of successors. + + # Parameters + # ---------- + # ocn : OCN + # The OCN instance. + # pos : tuple[int, int] + # The (row, col) position of the node whose predecessors are to be found. + + # Raises + # ------- + # TypeError + # If `pos` is not a tuple of two integers. + # IndexError + # If `pos` is out of bounds for the current OCN grid. + + # See Also + # -------- + # :meth:`OCN.successors` + # """ + + # pos = tuple(pos) + # if len(pos) != 2 or not all(isinstance(p, int) for p in pos): + # raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") + # if (pos[0] < 0 or pos[0] >= self.dims[0]) or (pos[1] < 0 or pos[1] >= self.dims[1]): + # raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {self.dims}.") + + # # convert (row, col) to linear index + # a = _bindings.libocn.fg_cart_to_lin( + # _bindings.CartPair_C(row=pos[0], col=pos[1]), + # _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]), + # ) + + # upstream_indices = (_bindings.linidx_t * 8)() + # nupstream = _bindings.linidx_t() + # check_status(_bindings.libocn.fg_find_upstream_neighbors( + # upstream_indices, + # ctypes.byref(nupstream), + # self.__p_c_graph, + # _bindings.linidx_t(a) + # )) + # count = int(nupstream.value) + + # # convert back to cartesian (row, col) + # for idx in upstream_indices[:count]: + # cart = _bindings.libocn.fg_lin_to_cart( + # idx, + # _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]) + # ) + # yield (int(cart.row), int(cart.col)) + + + # def successors(self, pos: tuple[int, int]) -> Generator[tuple[int, int], None, None]: + # """Returns an iterator over successor nodes of the node at position `pos` in the OCN. + # A successor is the immediate downstream neighbor (at most one; none if root). + + # Mirrors the behavior of `networkx.DiGraph.successors`, but works directly on the OCN C graph structure. + # Yields (row, col) of successors. + + # Parameters + # ---------- + # pos : tuple[int, int] + # The (row, col) position of the node whose successors are to be found. + + # Raises + # ------- + # TypeError + # If `pos` is not a tuple of two integers. + # IndexError + # If `pos` is out of bounds for the current OCN grid. + # """ + # pos = tuple(pos) + # if len(pos) != 2 or not all(isinstance(p, int) for p in pos): + # raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") + # if (pos[0] < 0 or pos[0] >= self.dims[0]) or (pos[1] < 0 or pos[1] >= self.dims[1]): + # raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {self.dims}.") + + # # (row, col) -> linear index + # a = _bindings.libocn.fg_cart_to_lin( + # _bindings.CartPair_C(row=pos[0], col=pos[1]), + # _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]), + # ) + + # # Read vertex to get downstream direction + # vert = _bindings.Vertex_C() + # check_status(_bindings.libocn.fg_get_lin( + # ctypes.byref(vert), + # self.__p_c_graph, + # _bindings.linidx_t(a), + # )) + + # # If root, no successors + # if int(vert.downstream) == _bindings.IS_ROOT: + # return + + # cart = _bindings.libocn.fg_lin_to_cart( + # vert.adown, + # _bindings.CartPair_C(row=self.dims[0], col=self.dims[1]) + # ) + # yield (int(cart.row), int(cart.col)) + + ########################### + # EXPORT # + ########################### def to_digraph(self) -> nx.DiGraph: """ Create a NetworkX ``DiGraph`` view of the current grid. @@ -598,7 +721,10 @@ def to_xarray(self, unwrap:bool=True) -> "xr.Dataset": } ) - def single_iteration(self, temperature:float, array_report:bool=False, unwrap:bool=True) -> "xr.Dataset | None": + ########################### + # OPTIMIZATION # + ########################### + def single_iteration(self, temperature:float, array_report:bool=False, unwrap:bool=True, calculate_full_energy:bool=False) -> "xr.Dataset | None": """ Perform a single iteration of the optimization algorithm at a given temperature. Updates the internal history attribute. See :meth:`fit` for details on the algorithm. @@ -619,6 +745,18 @@ def single_iteration(self, temperature:float, array_report:bool=False, unwrap:bo with some nan values. If False or the current OCN does not have periodic boundaries, then no transformation is applied and the resulting raster will have the same dimensions as the current OCN grid. + calculate_full_energy : bool, default False + If True, the full energy of the graph is recalculated when considering + the proposed change. If False, a more efficient incremental update is used. + Small numerical differences may arise between the two methods due to floating point + precision. If precision is of the utmost importance, set this to True, but note that + this comes with a significant performance penalty. + + Returns + ------- + xr.Dataset | None + If ``array_report == True``, an xarray.Dataset containing the state of the FlowGrid + after the iteration. See :meth:`to_xarray` for details. If ``array_report == False``, returns None. Raises ------ @@ -630,6 +768,7 @@ def single_iteration(self, temperature:float, array_report:bool=False, unwrap:bo self.__p_c_graph, self.gamma, temperature, + calculate_full_energy )) # append to history @@ -654,7 +793,8 @@ def fit( array_reports:int=0, tol:float=None, max_iterations_per_loop=10_000, - unwrap:bool=True,) -> "xr.Dataset | None": + unwrap:bool=True, + calculate_full_energy:bool=False) -> "xr.Dataset | None": """ Convenience function to optimize the OCN using the simulated annealing algorithm from Carraro et al (2020). For finer control over the optimization process, use :meth:`fit_custom_cooling` or use :meth:`single_erosion_event` in a loop. @@ -706,6 +846,12 @@ def fit( with some nan values. If False or the current OCN does not have periodic boundaries, then no transformation is applied and the resulting raster will have the same dimensions as the current OCN grid. + calculate_full_energy : bool, default False + If True, the full energy of the graph is recalculated when considering + the proposed change. If False, a more efficient incremental update is used. + Small numerical differences may arise between the two methods due to floating point + precision. If precision is of the utmost importance, set this to True, but note that + this comes with a significant performance penalty. Returns ------- @@ -800,6 +946,7 @@ def fit( tol=tol, max_iterations_per_loop=max_iterations_per_loop, unwrap=unwrap, + calculate_full_energy=calculate_full_energy, ) def fit_custom_cooling( @@ -812,6 +959,7 @@ def fit_custom_cooling( tol:float=None, max_iterations_per_loop=10_000, unwrap:bool=True, + calculate_full_energy:bool=False, ) -> "xr.Dataset | None": """ Optimize the OCN using the a custom cooling schedule. This allows for @@ -838,6 +986,7 @@ def fit_custom_cooling( tol : float, optional max_iterations_per_loop: int, optional unwrap: bool, default True + calculate_full_energy: bool, default False Returns ------- @@ -924,6 +1073,7 @@ def fit_custom_cooling( iterations_this_loop, self.gamma, anneal_ptr, + calculate_full_energy, )) e_new = self.energy completed_iterations += iterations_this_loop diff --git a/PyOCN/plotting.py b/PyOCN/plotting.py index a149460..520480c 100644 --- a/PyOCN/plotting.py +++ b/PyOCN/plotting.py @@ -24,7 +24,7 @@ from .ocn import OCN from .utils import unwrap_digraph -def _pos_to_xy(dag: nx.DiGraph) -> dict[Any, tuple[float, float]]: +def _pos_to_xy(dag: nx.DiGraph, nrows=None) -> dict[Any, tuple[float, float]]: """ Convert node ``pos`` from (row, col) to plotting coordinates (x, y). @@ -32,6 +32,9 @@ def _pos_to_xy(dag: nx.DiGraph) -> dict[Any, tuple[float, float]]: ---------- dag : nx.DiGraph Graph whose nodes have ``pos=(row, col)`` attributes. + nrows : int, optional + Number of rows in the grid. If ``None``, it is inferred from the + maximum row index in the node positions. Returns ------- @@ -45,7 +48,7 @@ def _pos_to_xy(dag: nx.DiGraph) -> dict[Any, tuple[float, float]]: match typical plotting conventions. """ pos = nx.get_node_attributes(dag, 'pos') - nrows = max(r for r, _ in pos.values()) + 1 + nrows = nrows if nrows is not None else max(r for r, _ in pos.values()) + 1 for node, (r, c) in pos.items(): pos[node] = (c, nrows - r - 1) return pos @@ -82,7 +85,7 @@ def plot_ocn_as_dag(ocn: OCN, attribute: str | None = None, ax=None, norm=None, dag = ocn.to_digraph() if ocn.wrap: dag = unwrap_digraph(dag, ocn.dims) - pos = _pos_to_xy(dag) + pos = _pos_to_xy(dag, nrows=ocn.dims[0]) if ax is None: _, ax = plt.subplots() @@ -97,7 +100,7 @@ def plot_ocn_as_dag(ocn: OCN, attribute: str | None = None, ax=None, norm=None, kwargs["vmin"] = 0 kwargs["vmax"] = 1 node_color = norm(node_color) - + p = nx.draw_networkx(dag, node_color=node_color, pos=pos, ax=ax, **kwargs) return p, ax @@ -145,7 +148,7 @@ def plot_ocn_raster(ocn: OCN, attribute:str='energy', ax=None, **kwargs): return ax -def plot_positional_digraph(dag: nx.DiGraph, ax=None, **kwargs): +def plot_positional_digraph(dag: nx.DiGraph, ax=None, nrows:int|None=None, **kwargs): """ Plot a DAG with node positions taken from their ``pos`` attributes. @@ -155,6 +158,9 @@ def plot_positional_digraph(dag: nx.DiGraph, ax=None, **kwargs): Graph whose nodes have ``pos=(row, col)``. ax : matplotlib.axes.Axes, optional Target axes. If ``None``, a new figure and axes are created. + nrows : int | None, default None + Number of rows in the grid. If ``None``, infer from the maximum row index in + the node positions. If ``False``, the vertical coordinates are not **kwargs Additional keyword arguments forwarded to :func:`networkx.draw_networkx`. @@ -166,10 +172,16 @@ def plot_positional_digraph(dag: nx.DiGraph, ax=None, **kwargs): ``networkx.draw_networkx`` (often ``None``) and ``ax`` is the axes used for drawing. """ - pos = _pos_to_xy(dag) + pos = _pos_to_xy(dag, nrows) if ax is None: _, ax = plt.subplots() p = nx.draw_networkx(dag, pos=pos, ax=ax, **kwargs) - return p, ax \ No newline at end of file + return p, ax + +__all__ = [ + "plot_ocn_as_dag", + "plot_ocn_raster", + "plot_positional_digraph", +] \ No newline at end of file diff --git a/PyOCN/utils.py b/PyOCN/utils.py index 910dd62..843af0c 100644 --- a/PyOCN/utils.py +++ b/PyOCN/utils.py @@ -3,12 +3,20 @@ """ from __future__ import annotations +from collections.abc import Generator +from concurrent.futures import ThreadPoolExecutor +from ctypes import byref from itertools import product +import os from typing import Any, Literal, Callable, TYPE_CHECKING, Union from numbers import Number import networkx as nx import numpy as np from tqdm import tqdm +from functools import partial + +import PyOCN._libocn_bindings as _bindings +from PyOCN._statushandler import check_status if TYPE_CHECKING: from .ocn import OCN @@ -356,7 +364,272 @@ def get_subwatersheds(dag : nx.DiGraph, node : Any) -> set[nx.DiGraph]: in the subwatersheds will affect the original graph. """ subwatershed_outlets = [n for n in dag.predecessors(node)] - subwatersheds = set(set(nx.ancestors(dag, outlet)) | {outlet} for outlet in subwatershed_outlets) + subwatersheds = (nx.ancestors(dag, outlet) | {outlet} for outlet in subwatershed_outlets) subwatersheds = set(dag.subgraph(wshd) for wshd in subwatersheds) return subwatersheds - \ No newline at end of file + +# # deferring stream order calculations for now. +# def assign_strahler_orders(dag: nx.DiGraph) -> None: +# """Assign Strahler order to each node in the DAG as a 'strahler_order' attribute. +# Modifies the input graph in place. + +# Parameters +# ---------- +# dag : nx.DiGraph +# The input directed acyclic graph. +# """ +# # Initialize Strahler order for leaf nodes +# leaf_nodes = [n for n, d in dag.in_degree() if d == 0] +# nx.set_node_attributes(dag, {n: 1 for n in leaf_nodes}, 'strahler_order') + +# # Compute Strahler order for other nodes +# strahler_orders = dict(nx.get_node_attributes(dag, 'strahler_order')) +# for n in nx.topological_sort(dag): +# if dag.in_degree(n): +# in_orders = [strahler_orders[p] for p in dag.predecessors(n)] +# max_order = max(in_orders) +# if sum(o == max_order for o in in_orders) > 1: +# order = max_order + 1 +# else: +# order = max_order +# strahler_orders[n] = order +# nx.set_node_attributes(dag, strahler_orders, 'strahler_order') + +# def assign_shreve_orders(dag: nx.DiGraph) -> None: +# """Assign Shreve order to each node in the DAG as a 'shreve_order' attribute. +# Shreve order is defined as the total number of upstream sources that contribute to a node. + +# Parameters +# ---------- +# dag : nx.DiGraph +# The input directed acyclic graph. +# """ +# # Initialize Shreve order for leaf nodes +# leaf_nodes = [n for n, d in dag.in_degree() if d == 0] +# nx.set_node_attributes(dag, {n: 1 for n in leaf_nodes}, 'shreve_order') + +# # Compute Shreve order for other nodes +# shreve_orders = dict(nx.get_node_attributes(dag, 'shreve_order')) +# for n in nx.topological_sort(dag): +# if dag.in_degree(n): +# in_orders = [shreve_orders[p] for p in dag.predecessors(n)] +# order = sum(in_orders) +# shreve_orders[n] = order +# nx.set_node_attributes(dag, shreve_orders, 'shreve_order') + + +def parallel_fit( + ocn:OCN|list[OCN], + n_runs=5, + n_threads=None, + fit_method:Literal["fit", "fit_custom_cooling"]|list[Literal["fit", "fit_custom_cooling"]]|None=None, + increment_rng:bool=True, + pbar:bool=False, + fit_kwargs:dict|list[dict]=None +) -> tuple[list[Any], list[OCN]]: + """Convenience function to perform multiple OCN fitting operations in parallel using multithreading. + Useful for doing sensitivity analysis or ensemble fitting. + + Parameters + ---------- + ocn : OCN | list[OCN] + The OCN instance(s) to fit. If a list of OCNs is provided, each OCN will be fitted independently. + If a single OCN is provided, it will be copied `n_runs` times. + Not modified during fitting. + n_runs : int, default 5 + The number of fitting runs to perform. Must be equal to the length of `ocn` if `ocn` is a list. + n_threads : int, default None + The number of worker threads to use for parallel execution. If None, defaults to the number of CPU cores, times 2. + fit_method : {"fit", "fit_custom_cooling"} | list[{"fit", "fit_custom_cooling"} | None], default None + The fitting method to use. If None, defaults to "fit". If a list is provided, it must be of length `n_runs`, + and each element specifies the fitting method for the corresponding fit. + increment_rng : bool, default True + If True, each fit's random seed is set to `ocn.rng + i`, where `i` is the thread index. + pbar : bool, default False + If True, display a master progress bar that tracks the completion of each thread. + fit_kwargs : dict | list[dict], optional + Additional keyword arguments to pass to the `OCN.fit` method. If a list of dicts is provided, + it must be of length `n_runs`, and each dict will be used for the corresponding fit. + + Returns + ------- + tuple[list[Any], list[OCN]] + A tuple containing two lists: + - A list of results from each fitting operation. + - A list of fitted OCN instances corresponding to each fitting operation. + """ + + if isinstance(ocn, list): + if len(ocn) != n_runs: + raise ValueError(f"When ocn is a list, its length must equal n_runs. Got len(ocn)={len(ocn)} and n_runs={n_runs}.") + ocn = [o.copy() for o in ocn] + else: + ocn = [ocn.copy() for _ in range(n_runs)] + + if isinstance(fit_kwargs, list): + if len(fit_kwargs) != n_runs: + raise ValueError(f"When fit_kwargs is a list, its length must equal n_runs. Got len(fit_kwargs)={len(fit_kwargs)} and n_runs={n_runs}.") + if not all(isinstance(k, dict) for k in fit_kwargs): + raise ValueError("All elements of fit_kwargs list must be dictionaries.") + elif not isinstance(fit_kwargs, dict) and fit_kwargs is not None: + raise ValueError(f"fit_kwargs must be a dict or a list of dicts. Got {type(fit_kwargs)}.") + else: + fit_kwargs = [fit_kwargs or dict() for _ in range(n_runs)] + + if isinstance(fit_method, list): + if len(fit_method) != n_runs: + raise ValueError(f"When fit_method is a list, its length must equal n_runs. Got len(fit_method)={len(fit_method)} and n_runs={n_runs}.") + elif not all(m is None or m in {"fit", "fit_custom_cooling"} for m in fit_method): + raise ValueError("All elements of fit_method list must be one of 'fit', 'fit_custom_cooling', or None.") + else: + if fit_method is not None and fit_method not in {"fit", "fit_custom_cooling"}: + raise ValueError(f"Invalid fit_method {fit_method}. Must be one of 'fit', 'fit_custom_cooling', or None.") + fit_method = [fit_method for _ in range(n_runs)] + + def fit(ocn, i, fit_method, fit_kwargs): + if increment_rng: + ocn.rng = ocn.rng + i + if fit_method is None or fit_method == "fit": + res = ocn.fit(**fit_kwargs) + elif fit_method == "fit_custom_cooling": + res = ocn.fit_custom_cooling(**fit_kwargs) + else: + raise ValueError(f"Invalid fit_method {fit_method}. Must be one of 'fit', 'fit_custom_cooling', or None.") + return res, i # return index to place result correctly, in case of out-of-order completion. + + if n_threads is None: + n_threads = os.cpu_count()*2 + n_threads = min(os.cpu_count()*2, n_threads) + + with ThreadPoolExecutor(max_workers=n_threads) as executor: + futures = [] + for i in range(n_runs): + futures.append(executor.submit(fit, ocn[i], i, fit_method[i], fit_kwargs[i])) + results = [None] * n_runs + pbar = tqdm(futures, disable=not pbar, desc="Fitting OCNs") + for future in futures: + res, idx = future.result() + results[idx] = res + pbar.update(1) + return results, ocn + +# # Commenting out for now. May decide to re-introduce later +# def ancestors(ocn:OCN, pos:tuple[int, int]) -> set[tuple[int, int]]: +# """Returns all nodes that drain into the node at position `pos` in the OCN. + +# Parameters +# ---------- +# ocn : OCN +# The OCN instance. +# pos : tuple[int, int] +# The (row, col) position of the node whose predecessors are to be found. + +# Returns +# ------- +# set() +# A set of (row, col) tuples representing the positions of all ancestor nodes +# that eventually drain into the specified node, not including the node itself. + +# See Also +# -------- +# :meth:`OCN.predecessors` +# :meth:`OCN.successors` +# :meth:`utils.descendants` +# """ + +# pos = tuple(pos) +# if len(pos) != 2 or not all(isinstance(p, int) for p in pos): +# raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") +# if (pos[0] < 0 or pos[0] >= ocn.dims[0]) or (pos[1] < 0 or pos[1] >= ocn.dims[1]): +# raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {ocn.dims}.") + +# a = _bindings.libocn.fg_cart_to_lin( +# _bindings.CartPair_C(row=pos[0], col=pos[1]), +# _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) +# ) +# upstream_indices = (_bindings.linidx_t * (ocn.dims[0]*ocn.dims[1]))() +# nupstream = _bindings.linidx_t(0) +# check_status(_bindings.libocn.fg_dfs_iterative( +# upstream_indices, +# byref(nupstream), +# (_bindings.linidx_t * (ocn.dims[0]*ocn.dims[1]))(), +# ocn._OCN__p_c_graph, +# a +# )) + +# return set( +# (int(c.row), int(c.col)) +# for c in ( +# _bindings.libocn.fg_lin_to_cart( +# idx, +# _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) +# ) +# for idx in upstream_indices[:nupstream.value] +# ) +# ) + +# # TODO write tests for the traversal functions +# def descendants(ocn:OCN, pos:tuple[int, int]) -> set[tuple[int, int]]: +# """Returns all nodes that are reachable from the node at position `pos` in the OCN. +# Mirrors the functionality of `networkx.descendants`. + +# Parameters +# ---------- +# ocn : OCN +# The OCN instance. +# pos : tuple[int, int] +# The (row, col) position of the node whose successors are to be found. + +# Returns +# ------- +# set() +# A set of (row, col) tuples representing the positions of all descendant nodes +# that the specified node eventually drains into, not including the node itself. + +# See Also +# -------- +# :meth:`OCN.predecessors` +# :meth:`OCN.successors` +# :meth:`utils.ancestors` +# """ + +# pos = tuple(pos) +# if len(pos) != 2 or not all(isinstance(p, int) for p in pos): +# raise TypeError(f"Position must be a tuple of two integers. Got {pos}.") +# if (pos[0] < 0 or pos[0] >= ocn.dims[0]) or (pos[1] < 0 or pos[1] >= ocn.dims[1]): +# raise IndexError(f"Position {pos} is out of bounds for OCN with dimensions {ocn.dims}.") + +# a = _bindings.libocn.fg_cart_to_lin( +# _bindings.CartPair_C(row=pos[0], col=pos[1]), +# _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) +# ) +# downstream_indices = (_bindings.linidx_t * (ocn.dims[0]*ocn.dims[1]))() +# ndownstream = _bindings.linidx_t(0) +# check_status(_bindings.libocn.flow_downstream( +# downstream_indices, +# byref(ndownstream), +# ocn._OCN__p_c_graph, +# a +# )) + +# return set( +# (int(c.row), int(c.col)) +# for c in ( +# _bindings.libocn.fg_lin_to_cart( +# idx, +# _bindings.CartPair_C(row=ocn.dims[0], col=ocn.dims[1]) +# ) +# for idx in downstream_indices[:ndownstream.value] +# ) +# ) + +__all__ = [ + "net_type_to_dag", + "simulated_annealing_schedule", + "unwrap_digraph", + "assign_subwatersheds", + "get_subwatersheds", + "parallel_fit", + "assign_strahler_orders", + "assign_shreve_orders", +] \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index 0457131..8366283 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,7 +14,7 @@ project = 'PyOCN' copyright = '2025, Alexander S. Fox' author = 'Alexander S. Fox' -release = '1.3.20251011' +version = "1.4.20251029" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/pyproject.toml b/pyproject.toml index 2983a9a..911ea05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "PyOCN" -version = "1.3.20251011" +version = "1.4.20251029" description = "Optimal Channel Networks (OCN) in Python with a C core" readme = "readme.md" requires-python = ">=3.10" diff --git a/sandbox.py b/sandbox.py index 92bb830..b8d3106 100644 --- a/sandbox.py +++ b/sandbox.py @@ -1,25 +1,15 @@ - -import matplotlib.pyplot as plt import PyOCN as po import numpy as np - -ocn = po.OCN.from_net_type( - net_type="E", - dims=(16, 16), - wrap=True, - random_state=84712, -) -ocn.fit_custom_cooling(lambda t: np.ones_like(t)*1e-1, pbar=True, n_iterations=16*16*100, max_iterations_per_loop=1) -ocn.fit(max_iterations_per_loop=1) - - - -# print(ocn.history.shape) -# print(np.max(np.diff(ocn.history[:, 1]))) -# print(np.quantile(np.diff(ocn.history[:, 1]), 0.999)) - -plt.plot(ocn.history[:, 0], ocn.history[:, 1]) -plt.plot(ocn.history[:, 0], ocn.history[:, 2]) -# plt.xscale("log") -# plt.yscale("log") -plt.show() \ No newline at end of file +import matplotlib.pyplot as plt +import networkx as nx + +ocn = po.OCN.from_net_type("E", dims=(32, 32), random_state=8472) +ocn.fit(pbar=True) +G = ocn.to_digraph() +for n in G.nodes: + if G.in_degree(n) > 2: + print(n) + break +n = 5 +po.utils.get_subwatersheds(G, node=n) +# po.plotting.plot_ocn_as_dag(ocn, ax=ax, node_size=5, with_labels="False") \ No newline at end of file diff --git a/tests/test_basic.py b/tests/test_basic.py index 603fbbc..9383336 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -4,6 +4,7 @@ These tests verify core functionality with deterministic outputs. """ import unittest +from time import perf_counter as timer import numpy as np import networkx as nx import PyOCN as po @@ -274,5 +275,150 @@ def test_greedy_optimization(self): ocn.fit_custom_cooling(lambda t: np.ones_like(t)*1e-7, pbar=False, n_iterations=16**2*100, max_iterations_per_loop=1) self.assertLessEqual(np.quantile(np.diff(ocn.history[:, 1]), 0.999) - 1e-7, 0, "Energy did not decrease monotonically.") + def test_parallel_fit(self): + """Test parallel fitting utility.""" + + ocn= po.OCN.from_net_type(net_type="I", dims=(42, 40), random_state=542390) + t0 = timer() + ocn.fit() + t1 = timer() + single_run_time = t1 - t0 + energy = ocn.energy + + ocn = po.OCN.from_net_type(net_type="I", dims=(42, 40), random_state=542390) + t0 = timer() + _, fitted_ocns = po.utils.parallel_fit( + ocn=ocn, + n_runs=5, + increment_rng=False, + ) + t1 = timer() + parallel_time = t1 - t0 + energies = [o.energy for o in fitted_ocns] + + self.assertEqual(len(fitted_ocns), 5, "Number of fitted OCNs does not match number of runs.") + + for e in energies: + self.assertAlmostEqual(e, energy, places=6, msg="Energies from parallel fits do not match.") + print() + print(f"\tSingle run time: {single_run_time:.2f} seconds") + print(f"\tParallel fit time for 5 runs: {parallel_time:.2f} seconds") + self.assertLess(parallel_time, single_run_time * 4, "Parallel fitting did not speed up the process as expected.") + + for o in fitted_ocns: + self.assertAlmostEqual(ocn.energy, o.history[0, 1], places=6, msg="Initial energy was not preserved.") + + def test_energy_update_method(self): + ocn = po.OCN.from_net_type("E", dims=(44, 44), random_state=238155) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=True) + energy = ocn.energy + ocn = po.OCN.from_net_type("E", dims=(44, 44), random_state=238155) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=False) + energy_2 = ocn.energy + self.assertAlmostEqual(energy/energy_2, 1.0, places=3, msg="Energies do not match between full_energy_calc and incremental update methods.") + + ocn = po.OCN.from_net_type("E", dims=(31, 31), random_state=12638, gamma=0.21) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=True) + energy = ocn.energy + ocn = po.OCN.from_net_type("E", dims=(31, 31), random_state=12638, gamma=0.21) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=False) + energy_2 = ocn.energy + self.assertAlmostEqual(energy/energy_2, 1.0, places=3, msg="Energies do not match between full_energy_calc and incremental update methods.") + + ocn = po.OCN.from_net_type("I", dims=(38, 38), random_state=19075, gamma=0.78) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=True) + energy = ocn.energy + ocn = po.OCN.from_net_type("I", dims=(38, 38), random_state=19075, gamma=0.78) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=False) + energy_2 = ocn.energy + self.assertAlmostEqual(energy/energy_2, 1.0, places=3, msg="Energies do not match between full_energy_calc and incremental update methods.") + + ocn = po.OCN.from_net_type("I", dims=(20, 41), random_state=910536) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=True) + energy = ocn.energy + ocn = po.OCN.from_net_type("I", dims=(20, 41), random_state=910536) + ocn.fit(pbar=False, max_iterations_per_loop=10_000, calculate_full_energy=False) + energy_2 = ocn.energy + + self.assertAlmostEqual(energy/energy_2, 1.0, places=3, msg="Energies do not match between full_energy_calc and incremental update methods.") + + def test_gamma(self): + """Test that different gamma values affect energy as expected.""" + ocn_gamma_1 = po.OCN.from_net_type( + net_type="V", + dims=(32, 32), + gamma=1.0, + random_state=1234, + ) + energy_gamma_1 = ocn_gamma_1.energy + + ocn_gamma_2 = po.OCN.from_net_type( + net_type="V", + dims=(32, 32), + gamma=0.5, + random_state=1234, + ) + energy_gamma_2 = ocn_gamma_2.energy + + ocn_gamma_3 = po.OCN.from_net_type( + net_type="V", + dims=(32, 32), + gamma=0.25, + random_state=1234, + ) + energy_gamma_3 = ocn_gamma_3.energy + + self.assertNotAlmostEqual(energy_gamma_1, energy_gamma_2, places=1, msg="Energies for gamma=1.0 and gamma=0.5 should differ.") + self.assertNotAlmostEqual(energy_gamma_1, energy_gamma_3, places=1, msg="Energies for gamma=1.0 and gamma=0.25 should differ.") + self.assertNotAlmostEqual(energy_gamma_2, energy_gamma_3, places=1, msg="Energies for gamma=0.5 and gamma=0.25 should differ.") + self.assertLess(energy_gamma_3, energy_gamma_2, msg="Energy should decrease with lower gamma.") + self.assertLess(energy_gamma_2, energy_gamma_1, msg="Energy should decrease with lower gamma.") + + ocn_gamma_1 = po.OCN.from_net_type( + net_type="E", + dims=(32, 32), + gamma=1.0, + random_state=1234, + ) + energy_gamma_1 = ocn_gamma_1.energy + + ocn_gamma_2 = po.OCN.from_net_type( + net_type="E", + dims=(32, 32), + gamma=0.5, + random_state=1234, + ) + energy_gamma_2 = ocn_gamma_2.energy + + ocn_gamma_3 = po.OCN.from_net_type( + net_type="E", + dims=(32, 32), + gamma=0.25, + random_state=1234, + ) + energy_gamma_3 = ocn_gamma_3.energy + + self.assertNotAlmostEqual(energy_gamma_1, energy_gamma_2, places=1, msg="Energies for gamma=1.0 and gamma=0.5 should differ.") + self.assertNotAlmostEqual(energy_gamma_1, energy_gamma_3, places=1, msg="Energies for gamma=1.0 and gamma=0.25 should differ.") + self.assertNotAlmostEqual(energy_gamma_2, energy_gamma_3, places=1, msg="Energies for gamma=0.5 and gamma=0.25 should differ.") + + self.assertLess(energy_gamma_3, energy_gamma_2, msg="Energy should decrease with lower gamma.") + self.assertLess(energy_gamma_2, energy_gamma_1, msg="Energy should decrease with lower gamma.") + + def test_watershed_partitioning(self): + ocn = po.OCN.from_net_type("E", dims=(32, 32), random_state=83) + ocn.fit(pbar=True) + G = ocn.to_digraph() + n = 997 + subgraphs = po.utils.get_subwatersheds(G, n) + + node_check = [ + set([932, 931, 964]), + set([965]), + set([900, 901, 902, 903, 904, 933, 934, 935, 936, 837, 838, 967, 840, 839, 966, 868, 869, 870, 871, 872]) + ] + for wshd in subgraphs: + self.assertIn(set(wshd.nodes), node_check) + if __name__ == "__main__": unittest.main() \ No newline at end of file