Skip to content
3 changes: 2 additions & 1 deletion pychunkedgraph/graph/chunkedgraph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# pylint: disable=invalid-name, missing-docstring, too-many-lines, import-outside-toplevel, unsupported-binary-operation

import time
import typing
import datetime
Expand Down Expand Up @@ -810,6 +809,7 @@ def add_edges(
sink_coords: typing.Sequence[int] = None,
allow_same_segment_merge: typing.Optional[bool] = False,
do_sanity_check: typing.Optional[bool] = True,
stitch_mode: typing.Optional[bool] = False,
) -> operation.GraphEditOperation.Result:
"""
Adds an edge to the chunkedgraph
Expand All @@ -827,6 +827,7 @@ def add_edges(
sink_coords=sink_coords,
allow_same_segment_merge=allow_same_segment_merge,
do_sanity_check=do_sanity_check,
stitch_mode=stitch_mode,
).execute()

def remove_edges(
Expand Down
4 changes: 2 additions & 2 deletions pychunkedgraph/graph/cutting.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,8 @@ def _remap_cut_edge_set(self, cut_edge_set):

remapped_cutset = np.array(remapped_cutset, dtype=np.uint64)

remapped_cutset_flattened_view = remapped_cutset.view(dtype="u8,u8")
edges_flattened_view = self.cg_edges.view(dtype="u8,u8")
remapped_cutset_flattened_view = remapped_cutset.view(dtype="u8,u8").ravel()
edges_flattened_view = self.cg_edges.view(dtype="u8,u8").ravel()

cutset_mask = np.isin(remapped_cutset_flattened_view, edges_flattened_view).ravel()

Expand Down
4 changes: 3 additions & 1 deletion pychunkedgraph/graph/edges/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def merge_cross_edge_dicts(x_edges_d1: Dict, x_edges_d2: Dict) -> Dict:
Combines two cross chunk dictionaries of form
{node_id: {layer id : edge list}}.
"""
node_ids = np.unique(list(x_edges_d1.keys()) + list(x_edges_d2.keys()))
node_ids = np.unique(
np.array(list(x_edges_d1.keys()) + list(x_edges_d2.keys()), dtype=basetypes.NODE_ID)
)
result_d = {}
for node_id in node_ids:
cross_edge_ds = [x_edges_d1.get(node_id, {}), x_edges_d2.get(node_id, {})]
Expand Down
85 changes: 55 additions & 30 deletions pychunkedgraph/graph/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from . import attributes
from .edges import Edges
from .edges.utils import get_edges_status
from .edits import get_profiler
from .utils import basetypes
from .utils import serializers
from .cache import CacheService
Expand Down Expand Up @@ -420,6 +421,7 @@ def execute(
op_type = "merge" if is_merge else "split"
self.parent_ts = parent_ts
root_ids = self._update_root_ids()
self.privileged_mode = self.privileged_mode or (is_merge and self.stitch_mode)
with locks.RootLock(
self.cg,
root_ids,
Expand Down Expand Up @@ -563,6 +565,7 @@ class MergeOperation(GraphEditOperation):
"bbox_offset",
"allow_same_segment_merge",
"do_sanity_check",
"stitch_mode",
]

def __init__(
Expand All @@ -577,6 +580,7 @@ def __init__(
affinities: Optional[Sequence[np.float32]] = None,
allow_same_segment_merge: Optional[bool] = False,
do_sanity_check: Optional[bool] = True,
stitch_mode: bool = False,
) -> None:
super().__init__(
cg, user_id=user_id, source_coords=source_coords, sink_coords=sink_coords
Expand All @@ -585,6 +589,7 @@ def __init__(
self.bbox_offset = np.atleast_1d(bbox_offset).astype(basetypes.COORDINATES)
self.allow_same_segment_merge = allow_same_segment_merge
self.do_sanity_check = do_sanity_check
self.stitch_mode = stitch_mode

self.affinities = None
if affinities is not None:
Expand All @@ -609,40 +614,55 @@ def _update_root_ids(self) -> np.ndarray:
def _apply(
self, *, operation_id, timestamp
) -> Tuple[np.ndarray, np.ndarray, List["bigtable.row.Row"]]:
root_ids = set(
self.cg.get_roots(
self.added_edges.ravel(), assert_roots=True, time_stamp=self.parent_ts
profiler = get_profiler()

with profiler.profile("merge_apply_get_roots"):
root_ids = set(
self.cg.get_roots(
self.added_edges.ravel(), assert_roots=True, time_stamp=self.parent_ts
)
)
)
if len(root_ids) < 2 and not self.allow_same_segment_merge:
raise PreconditionError("Supervoxels must belong to different objects.")
bbox = get_bbox(self.source_coords, self.sink_coords, self.bbox_offset)
with TimeIt("subgraph", self.cg.graph_id, operation_id):
edges = self.cg.get_subgraph(
root_ids,
bbox=bbox,
bbox_is_coordinate=True,
edges_only=True,
raise PreconditionError(
"Supervoxels must belong to different objects."
f" Tried to merge {self.added_edges.ravel()},"
f" which all belong to {tuple(root_ids)[0]}."
)

if self.allow_same_segment_merge:
inactive_edges = types.empty_2d
else:
with TimeIt("preprocess", self.cg.graph_id, operation_id):
inactive_edges = edits.merge_preprocess(
atomic_edges = self.added_edges
fake_edge_rows = []
if not self.stitch_mode:
bbox = get_bbox(self.source_coords, self.sink_coords, self.bbox_offset)
with profiler.profile("get_subgraph"):
with TimeIt("subgraph", self.cg.graph_id, operation_id):
edges = self.cg.get_subgraph(
root_ids,
bbox=bbox,
bbox_is_coordinate=True,
edges_only=True,
)

if self.allow_same_segment_merge:
inactive_edges = types.empty_2d
else:
with profiler.profile("merge_preprocess"):
with TimeIt("preprocess", self.cg.graph_id, operation_id):
inactive_edges = edits.merge_preprocess(
self.cg,
subgraph_edges=edges,
supervoxels=self.added_edges.ravel(),
parent_ts=self.parent_ts,
)

with profiler.profile("check_fake_edges"):
atomic_edges, fake_edge_rows = edits.check_fake_edges(
self.cg,
subgraph_edges=edges,
supervoxels=self.added_edges.ravel(),
atomic_edges=self.added_edges,
inactive_edges=inactive_edges,
time_stamp=timestamp,
parent_ts=self.parent_ts,
)

atomic_edges, fake_edge_rows = edits.check_fake_edges(
self.cg,
atomic_edges=self.added_edges,
inactive_edges=inactive_edges,
time_stamp=timestamp,
parent_ts=self.parent_ts,
)
with TimeIt("add_edges", self.cg.graph_id, operation_id):
new_roots, new_l2_ids, new_entries = edits.add_edges(
self.cg,
Expand All @@ -652,6 +672,7 @@ def _apply(
parent_ts=self.parent_ts,
allow_same_segment_merge=self.allow_same_segment_merge,
do_sanity_check=self.do_sanity_check,
stitch_mode=self.stitch_mode,
)
return new_roots, new_l2_ids, fake_edge_rows + new_entries

Expand Down Expand Up @@ -874,12 +895,14 @@ def __init__(
"try placing the points further apart."
)

ids = np.concatenate([self.source_ids, self.sink_ids])
ids = np.concatenate([self.source_ids, self.sink_ids]).astype(basetypes.NODE_ID)
layers = self.cg.get_chunk_layers(ids)
assert np.sum(layers) == layers.size, "IDs must be supervoxels."

def _update_root_ids(self) -> np.ndarray:
sink_and_source_ids = np.concatenate((self.source_ids, self.sink_ids))
sink_and_source_ids = np.concatenate((self.source_ids, self.sink_ids)).astype(
basetypes.NODE_ID
)
root_ids = np.unique(
self.cg.get_roots(
sink_and_source_ids, assert_roots=True, time_stamp=self.parent_ts
Expand All @@ -895,7 +918,9 @@ def _apply(
# Verify that sink and source are from the same root object
root_ids = set(
self.cg.get_roots(
np.concatenate([self.source_ids, self.sink_ids]),
np.concatenate([self.source_ids, self.sink_ids]).astype(
basetypes.NODE_ID
),
assert_roots=True,
time_stamp=self.parent_ts,
)
Expand All @@ -916,7 +941,7 @@ def _apply(
edges = reduce(lambda x, y: x + y, edges_tuple, Edges([], []))
supervoxels = np.concatenate(
[agg.supervoxels for agg in l2id_agglomeration_d.values()]
)
).astype(basetypes.NODE_ID)
mask0 = np.isin(edges.node_ids1, supervoxels)
mask1 = np.isin(edges.node_ids2, supervoxels)
edges = edges[mask0 & mask1]
Expand Down
2 changes: 1 addition & 1 deletion pychunkedgraph/graph/segmenthistory.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def operation_id_root_id_dict(self):

@property
def operation_ids(self):
return np.array(list(self.operation_id_root_id_dict.keys()))
return np.array(list(self.operation_id_root_id_dict.keys()), dtype=basetypes.OPERATION_ID)

@property
def _log_rows(self):
Expand Down
16 changes: 10 additions & 6 deletions pychunkedgraph/ingest/create/atomic_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,20 @@ def _get_chunk_nodes_and_edges(chunk_edges_d: dict, isolated_ids: Sequence[int])
in-chunk edges and nodes_ids
"""
isolated_nodes_self_edges = np.vstack([isolated_ids, isolated_ids]).T
node_ids = [isolated_ids]
edge_ids = [isolated_nodes_self_edges]
node_ids = [isolated_ids] if len(isolated_ids) != 0 else []
edge_ids = (
[isolated_nodes_self_edges] if len(isolated_nodes_self_edges) != 0 else []
)
for edge_type in EDGE_TYPES:
edges = chunk_edges_d[edge_type]
node_ids.append(edges.node_ids1)
if edge_type == EDGE_TYPES.in_chunk:
node_ids.append(edges.node_ids2)
edge_ids.append(edges.get_pairs())

chunk_node_ids = np.unique(np.concatenate(node_ids))
chunk_node_ids = np.unique(np.concatenate(node_ids).astype(basetypes.NODE_ID))
edge_ids.append(np.vstack([chunk_node_ids, chunk_node_ids]).T)
return (chunk_node_ids, np.concatenate(edge_ids))
return (chunk_node_ids, np.concatenate(edge_ids).astype(basetypes.NODE_ID))


def _get_remapping(chunk_edges_d: dict):
Expand Down Expand Up @@ -116,7 +118,7 @@ def _process_component(
r_key = serializers.serialize_uint64(node_id)
nodes.append(cg.client.mutate_row(r_key, val_dict, time_stamp=time_stamp))

chunk_out_edges = np.concatenate(chunk_out_edges)
chunk_out_edges = np.concatenate(chunk_out_edges).astype(basetypes.NODE_ID)
cce_layers = cg.get_cross_chunk_edges_layer(chunk_out_edges)
u_cce_layers = np.unique(cce_layers)

Expand Down Expand Up @@ -147,5 +149,7 @@ def _get_outgoing_edges(node_id, chunk_edges_d, sparse_indices, remapping):
]
row_ids = row_ids[column_ids == 0]
# edges that this node is part of
chunk_out_edges = np.concatenate([chunk_out_edges, edges[row_ids]])
chunk_out_edges = np.concatenate([chunk_out_edges, edges[row_ids]]).astype(
basetypes.NODE_ID
)
return chunk_out_edges
4 changes: 2 additions & 2 deletions pychunkedgraph/ingest/create/parent_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _read_children_chunks(
children_ids = [types.empty_1d]
for child_coord in children_coords:
children_ids.append(_read_chunk([], cg, layer_id - 1, child_coord))
return np.concatenate(children_ids)
return np.concatenate(children_ids).astype(basetypes.NODE_ID)

with mp.Manager() as manager:
children_ids_shared = manager.list()
Expand All @@ -92,7 +92,7 @@ def _read_children_chunks(
multi_args,
n_threads=min(len(multi_args), mp.cpu_count()),
)
return np.concatenate(children_ids_shared)
return np.concatenate(children_ids_shared).astype(basetypes.NODE_ID)


def _read_chunk_helper(args):
Expand Down
6 changes: 4 additions & 2 deletions pychunkedgraph/ingest/ran_agglomeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ def get_active_edges(edges_d, mapping):
if edge_type == EDGE_TYPES.in_chunk:
pseudo_isolated_ids.append(edges.node_ids2)

return chunk_edges_active, np.unique(np.concatenate(pseudo_isolated_ids))
return chunk_edges_active, np.unique(
np.concatenate(pseudo_isolated_ids).astype(basetypes.NODE_ID)
)


def define_active_edges(edge_dict, mapping) -> Union[Dict, np.ndarray]:
Expand Down Expand Up @@ -380,7 +382,7 @@ def read_raw_agglomeration_data(imanager: IngestionManager, chunk_coord: np.ndar

edges_list = _read_agg_files(filenames, chunk_ids, path)
G = nx.Graph()
G.add_edges_from(np.concatenate(edges_list))
G.add_edges_from(np.concatenate(edges_list).astype(basetypes.NODE_ID))
mapping = {}
components = list(nx.connected_components(G))
for i_cc, cc in enumerate(components):
Expand Down
8 changes: 7 additions & 1 deletion pychunkedgraph/meshing/meshgen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,13 @@ def recursive_helper(cur_node_ids):
only_child_mask = np.array(
[len(children_for_node) == 1 for children_for_node in children_array]
)
only_children = children_array[only_child_mask].astype(np.uint64).ravel()
# Extract children from object array - each filtered element is a 1-element array
filtered_children = children_array[only_child_mask]
only_children = (
np.concatenate(filtered_children).astype(np.uint64)
if filtered_children.size
else np.array([], dtype=np.uint64)
)
if np.any(only_child_mask):
temp_array = cur_node_ids[stop_layer_mask]
temp_array[only_child_mask] = recursive_helper(only_children)
Expand Down
2 changes: 1 addition & 1 deletion requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ zmesh>=1.7.0
fastremap>=1.14.0
task-queue>=2.14.0
messagingclient
dracopy>=1.3.0
dracopy>=1.5.0
datastoreflex>=0.5.0
zstandard>=0.23.0

Expand Down