Skip to content

Commit e0edd2c

Browse files
committed
Run black
1 parent 29ae00c commit e0edd2c

File tree

3 files changed

+242
-165
lines changed

3 files changed

+242
-165
lines changed

shepherd_utils/TRAPI_to_NetworkX.py

Lines changed: 110 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import networkx as nx
55

6+
67
def trapi_kg_to_nx(
78
trapi: Union[Dict[str, Any], str, bytes],
89
*,
@@ -11,112 +12,127 @@ def trapi_kg_to_nx(
1112
default_weight: float = 1.0,
1213
edge_weight_attr: Optional[str] = None,
1314
edge_weight_transform: Optional[Callable[[Any], float]] = None,
14-
edge_payload: str = "full", # "full" | "weight_only"
15+
edge_payload: str = "full", # "full" | "weight_only"
1516
weight_agg: Union[str, Callable[[float, float], float]] = "sum", # collapse-only
1617
) -> nx.Graph:
1718
"""
18-
Convert a TRAPI ``knowledge_graph`` into a NetworkX graph that’s ready for downstream
19-
algorithms (e.g., Personalized PageRank). Supports directed/undirected graphs, multi-edge
20-
preservation, optional weight derivation, and (when collapsed) weight aggregation.
21-
22-
Overview
23-
--------
24-
- **Input forms:** Accepts a full TRAPI Response/Message/Result or a raw KG dict.
25-
The KG is discovered at:
26-
``response['message']['knowledge_graph']`` → ``message['knowledge_graph']`` → ``trapi['knowledge_graph']``.
27-
- **Graph type:** Controlled by ``directed`` and ``multigraph``:
28-
- ``multigraph=True`` → (Multi)DiGraph/(Multi)Graph with one edge per TRAPI edge.
29-
- ``multigraph=False`` → DiGraph/Graph where parallel TRAPI edges between the same (u, v)
30-
are either collapsed with metadata retained (“full”) or reduced to a single weighted edge
31-
(“weight_only”).
32-
- **Nodes:** Created for all TRAPI nodes and any edge endpoints not listed in the node map.
33-
Node metadata is preserved and an ``attributes_flat`` dict is added, where TRAPI attributes
34-
are keyed by ``original_attribute_name`` (fallback: ``attribute_type_id``); duplicate keys
35-
become lists.
36-
- **Edges & payload control (``edge_payload``):**
37-
- ``"full"``: keep TRAPI edge metadata (predicate, qualifiers, sources, attributes, etc.).
38-
In collapsed mode, the *last encountered* edge’s metadata wins for (u, v).
39-
- ``"weight_only"``: strip all edge metadata and keep only an optional ``weight``; in
40-
collapsed mode, multiple TRAPI edges between (u, v) are combined via ``weight_agg``.
41-
In multigraph mode, full metadata is always kept; ``edge_payload`` only affects collapsed graphs.
42-
- **Edge IDs (multigraph only):** Each multiedge gets an ``id`` attribute set to the TRAPI
43-
edge id (or a synthetic string if missing) and is used as the multiedge key.
44-
45-
Weights
46-
-------
47-
- **Enable weights** by setting ``edge_weight_attr`` to the TRAPI attribute name you want to use
48-
(e.g., ``"normalized_google_distance"``). The value is looked up in the edge’s flattened
49-
attributes. If the attribute is missing or cannot be parsed as ``float``, ``default_weight``
50-
is used.
51-
- **Transform weights** by providing ``edge_weight_transform(value) -> float``; if this transform
52-
raises, the untransformed numeric value is used instead of ``default_weight``.
53-
- **Disable weights entirely** by passing ``edge_weight_attr=None``. No ``weight`` attribute will be
54-
set on any edge (including in collapsed graphs).
55-
- **Aggregation (collapsed + weight_only only):** ``weight_agg`` combines weights for multiple TRAPI
56-
edges between (u, v). Built-ins: ``"sum"``, ``"max"``, ``"min"``, ``"mean"``, ``"first"``, or
57-
a callable ``(existing_weight, new_weight) -> combined_weight``.
58-
*Note:* the built-in ``"mean"`` is a simple pairwise average; supply a custom aggregator if you
59-
need the exact arithmetic mean across many edges.
60-
61-
Parameters
62-
----------
63-
trapi : dict | str | bytes
64-
TRAPI Response/Message dict, or a JSON string/bytes containing one. The function extracts
65-
the ``knowledge_graph`` as described above.
66-
multigraph : bool, default True
67-
Preserve parallel TRAPI edges (recommended to retain provenance/predicate distinctions).
68-
directed : bool, default True
69-
Build a directed graph (TRAPI ``subject`` → ``object``). Set ``False`` for undirected.
70-
default_weight : float, default 1.0
71-
Fallback edge weight when weights are enabled but the specified attribute is missing or
72-
non-numeric. Ignored when ``edge_weight_attr=None``.
73-
edge_weight_attr : str | None, default None
74-
Name of the TRAPI edge attribute to use as the weight source (flattened lookup via
75-
``original_attribute_name``/``attribute_type_id``). If ``None``, no weights are added.
76-
edge_weight_transform : callable(value) -> float | None, default None
77-
Optional transform applied to the numeric attribute value. On error, the raw numeric value
78-
is used (not ``default_weight``).
79-
edge_payload : {"full", "weight_only"}, default "full"
80-
Controls how much edge metadata is retained in collapsed graphs. Ignored for multigraphs.
81-
weight_agg : {"sum","max","min","mean","first"} | callable, default "sum"
82-
Only used when ``multigraph=False`` and ``edge_payload="weight_only"``. Aggregates multiple
83-
weights for the same (u, v). For a callable, provide ``f(existing, new) -> combined``.
84-
85-
Returns
86-
-------
87-
networkx.(Multi)DiGraph or (Multi)Graph
88-
Graph whose node ids are TRAPI node CURIEs. Node data includes original node metadata and
89-
``attributes_flat``. Edge data depends on the mode:
90-
• multigraph: full TRAPI edge metadata + optional ``weight``.
91-
• collapsed + "full": last-seen metadata for (u, v) + optional ``weight``.
92-
• collapsed + "weight_only": only ``weight`` (or no attributes if weights disabled).
19+
Convert a TRAPI ``knowledge_graph`` into a NetworkX graph that’s ready for downstream
20+
algorithms (e.g., Personalized PageRank). Supports directed/undirected graphs, multi-edge
21+
preservation, optional weight derivation, and (when collapsed) weight aggregation.
22+
23+
Overview
24+
--------
25+
- **Input forms:** Accepts a full TRAPI Response/Message/Result or a raw KG dict.
26+
The KG is discovered at:
27+
``response['message']['knowledge_graph']`` → ``message['knowledge_graph']`` → ``trapi['knowledge_graph']``.
28+
- **Graph type:** Controlled by ``directed`` and ``multigraph``:
29+
- ``multigraph=True`` → (Multi)DiGraph/(Multi)Graph with one edge per TRAPI edge.
30+
- ``multigraph=False`` → DiGraph/Graph where parallel TRAPI edges between the same (u, v)
31+
are either collapsed with metadata retained (“full”) or reduced to a single weighted edge
32+
(“weight_only”).
33+
- **Nodes:** Created for all TRAPI nodes and any edge endpoints not listed in the node map.
34+
Node metadata is preserved and an ``attributes_flat`` dict is added, where TRAPI attributes
35+
are keyed by ``original_attribute_name`` (fallback: ``attribute_type_id``); duplicate keys
36+
become lists.
37+
- **Edges & payload control (``edge_payload``):**
38+
- ``"full"``: keep TRAPI edge metadata (predicate, qualifiers, sources, attributes, etc.).
39+
In collapsed mode, the *last encountered* edge’s metadata wins for (u, v).
40+
- ``"weight_only"``: strip all edge metadata and keep only an optional ``weight``; in
41+
collapsed mode, multiple TRAPI edges between (u, v) are combined via ``weight_agg``.
42+
In multigraph mode, full metadata is always kept; ``edge_payload`` only affects collapsed graphs.
43+
- **Edge IDs (multigraph only):** Each multiedge gets an ``id`` attribute set to the TRAPI
44+
edge id (or a synthetic string if missing) and is used as the multiedge key.
45+
46+
Weights
47+
-------
48+
- **Enable weights** by setting ``edge_weight_attr`` to the TRAPI attribute name you want to use
49+
(e.g., ``"normalized_google_distance"``). The value is looked up in the edge’s flattened
50+
attributes. If the attribute is missing or cannot be parsed as ``float``, ``default_weight``
51+
is used.
52+
- **Transform weights** by providing ``edge_weight_transform(value) -> float``; if this transform
53+
raises, the untransformed numeric value is used instead of ``default_weight``.
54+
- **Disable weights entirely** by passing ``edge_weight_attr=None``. No ``weight`` attribute will be
55+
set on any edge (including in collapsed graphs).
56+
- **Aggregation (collapsed + weight_only only):** ``weight_agg`` combines weights for multiple TRAPI
57+
edges between (u, v). Built-ins: ``"sum"``, ``"max"``, ``"min"``, ``"mean"``, ``"first"``, or
58+
a callable ``(existing_weight, new_weight) -> combined_weight``.
59+
*Note:* the built-in ``"mean"`` is a simple pairwise average; supply a custom aggregator if you
60+
need the exact arithmetic mean across many edges.
61+
62+
Parameters
63+
----------
64+
trapi : dict | str | bytes
65+
TRAPI Response/Message dict, or a JSON string/bytes containing one. The function extracts
66+
the ``knowledge_graph`` as described above.
67+
multigraph : bool, default True
68+
Preserve parallel TRAPI edges (recommended to retain provenance/predicate distinctions).
69+
directed : bool, default True
70+
Build a directed graph (TRAPI ``subject`` → ``object``). Set ``False`` for undirected.
71+
default_weight : float, default 1.0
72+
Fallback edge weight when weights are enabled but the specified attribute is missing or
73+
non-numeric. Ignored when ``edge_weight_attr=None``.
74+
edge_weight_attr : str | None, default None
75+
Name of the TRAPI edge attribute to use as the weight source (flattened lookup via
76+
``original_attribute_name``/``attribute_type_id``). If ``None``, no weights are added.
77+
edge_weight_transform : callable(value) -> float | None, default None
78+
Optional transform applied to the numeric attribute value. On error, the raw numeric value
79+
is used (not ``default_weight``).
80+
edge_payload : {"full", "weight_only"}, default "full"
81+
Controls how much edge metadata is retained in collapsed graphs. Ignored for multigraphs.
82+
weight_agg : {"sum","max","min","mean","first"} | callable, default "sum"
83+
Only used when ``multigraph=False`` and ``edge_payload="weight_only"``. Aggregates multiple
84+
weights for the same (u, v). For a callable, provide ``f(existing, new) -> combined``.
85+
86+
Returns
87+
-------
88+
networkx.(Multi)DiGraph or (Multi)Graph
89+
Graph whose node ids are TRAPI node CURIEs. Node data includes original node metadata and
90+
``attributes_flat``. Edge data depends on the mode:
91+
• multigraph: full TRAPI edge metadata + optional ``weight``.
92+
• collapsed + "full": last-seen metadata for (u, v) + optional ``weight``.
93+
• collapsed + "weight_only": only ``weight`` (or no attributes if weights disabled).
9394
9495
"""
9596
# Validate `edge_payload`. When it is "full", `weight_agg` must be left at its default ("sum"), because aggregation only applies to "weight_only".
9697
if edge_payload not in ("full", "weight_only"):
9798
raise ValueError("edge_payload must be 'full' or 'weight_only'")
9899
if edge_payload == "full" and weight_agg != "sum":
99100
raise ValueError("weight_agg is only used when edge_payload is 'weight_only'")
101+
100102
# Note: when `multigraph` is True, parallel edges are preserved, so weight aggregation is not relevant.
101103
def _as_dict(obj: Union[Dict[str, Any], str, bytes]) -> Dict[str, Any]:
102104
return json.loads(obj) if isinstance(obj, (str, bytes)) else obj
103105

104106
def _get_kg(root: Dict[str, Any]) -> Dict[str, Any]:
105-
if "message" in root and isinstance(root["message"], dict) and "knowledge_graph" in root["message"]:
107+
if (
108+
"message" in root
109+
and isinstance(root["message"], dict)
110+
and "knowledge_graph" in root["message"]
111+
):
106112
return root["message"]["knowledge_graph"]
107113
if "knowledge_graph" in root:
108114
return root["knowledge_graph"]
109115
raise KeyError("No TRAPI knowledge_graph found in input.")
110116

111-
def _flatten_attributes(attrs: Optional[Iterable[Dict[str, Any]]]) -> Dict[str, Any]:
117+
def _flatten_attributes(
118+
attrs: Optional[Iterable[Dict[str, Any]]],
119+
) -> Dict[str, Any]:
112120
flat: Dict[str, Any] = {}
113121
if not attrs:
114122
return flat
115123
for a in attrs:
116-
key = a.get("original_attribute_name") or a.get("attribute_type_id") or f"attr_{len(flat)}"
124+
key = (
125+
a.get("original_attribute_name")
126+
or a.get("attribute_type_id")
127+
or f"attr_{len(flat)}"
128+
)
117129
val = a.get("value")
118130
if key in flat:
119-
flat[key] = flat[key] + [val] if isinstance(flat[key], list) else [flat[key], val]
131+
flat[key] = (
132+
flat[key] + [val]
133+
if isinstance(flat[key], list)
134+
else [flat[key], val]
135+
)
120136
else:
121137
flat[key] = val
122138
return flat
@@ -144,11 +160,11 @@ def _agg_fn():
144160
if callable(weight_agg):
145161
return weight_agg
146162
return {
147-
"sum": lambda a, b: a + b,
148-
"max": lambda a, b: a if a >= b else b,
149-
"min": lambda a, b: a if a <= b else b,
163+
"sum": lambda a, b: a + b,
164+
"max": lambda a, b: a if a >= b else b,
165+
"min": lambda a, b: a if a <= b else b,
150166
"mean": lambda a, b: 0.5 * (a + b),
151-
"first":lambda a, b: a,
167+
"first": lambda a, b: a,
152168
}.get(weight_agg, lambda a, b: a + b)
153169

154170
data = _as_dict(trapi)
@@ -177,12 +193,15 @@ def _agg_fn():
177193
combine = _agg_fn()
178194

179195
# Edges
180-
for eid, eobj in (edges.items() if isinstance(edges, dict) else enumerate(edges)):
181-
u = eobj.get("subject"); v = eobj.get("object")
196+
for eid, eobj in edges.items() if isinstance(edges, dict) else enumerate(edges):
197+
u = eobj.get("subject")
198+
v = eobj.get("object")
182199
if u is None or v is None:
183200
continue
184-
if u not in G: G.add_node(u)
185-
if v not in G: G.add_node(v)
201+
if u not in G:
202+
G.add_node(u)
203+
if v not in G:
204+
G.add_node(v)
186205

187206
e_attr_flat = _flatten_attributes(eobj.get("attributes"))
188207
w = _derive_weight(e_attr_flat) # None if weights disabled
@@ -211,10 +230,11 @@ def _agg_fn():
211230
# edge_payload == "full": keep metadata but only add 'weight' if enabled
212231
e_attrs = dict(eobj)
213232
e_attrs["attributes_flat"] = e_attr_flat
214-
e_attrs["id"] = eid if isinstance(eid, str) else eobj.get("id", str(eid))
233+
e_attrs["id"] = (
234+
eid if isinstance(eid, str) else eobj.get("id", str(eid))
235+
)
215236
if weights_enabled and w is not None:
216237
e_attrs["weight"] = w
217238
G.add_edge(u, v, **e_attrs)
218239

219240
return G
220-

0 commit comments

Comments
 (0)