Skip to content
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
adjusttext>=1.0.0
cftime>=1.6.4
colorcet>=3.1.0
contourpy>=1.2.1
Expand Down Expand Up @@ -25,6 +26,7 @@ panel>=1.4.4
param>=2.1.1
Pillow>=10.4.0
playwright>=1.45.1
plotly>=6.2.3
pooch>=1.8.2
psutil>=5.9.0
pyparsing>=3.1.2
Expand Down
220 changes: 220 additions & 0 deletions src/temporalmapper/temporal_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from sklearn.neighbors import NearestNeighbors
from sklearn.base import ClusterMixin
from datamapplot.palette_handling import palette_from_datamap
import matplotlib as mpl
from copy import deepcopy
import plotly.graph_objects as go

"""TemporalMapper class
minimal usage example:
Expand Down Expand Up @@ -503,3 +506,220 @@ def vertex_subgraph(self, v, threshold=0.1):
def get_subgraph_data(self, vertices):
vals = [self.get_vertex_data(v) for v in vertices]
return np.concatenate(vals, axis=1)

def edge_thresholded_subgraph(self, threshold):
edges_to_remove = [
(u, v) for u, v, data in self.G.edges(data=True)
if data['weight'] < threshold
]
G_prime = deepcopy(self.G)
G_prime.remove_edges_from(edges_to_remove)
return G_prime

def temporal_plot(
self,
ax: mpl.axes = None,
title: str = None,
cluster_labels: dict = None,
cluster_label_kwargs: dict = None,
vertices: list[str] = None,
bundle: bool = False,
edge_labels: dict = None,
node_kwargs: dict = {},
edge_kwargs: dict = {},
edge_scaling: float = 1,
node_scaling: float = 1,
node_size_bounds: tuple[float] = (5,50),
edge_weight_bounds: float = 0.1,
node_size_scale: str = 'sigmoid',
layout_optimization: str = "barycenter",
layout_optimization_kwargs: dict = {},
):
"""
Generate a temporal plot of the Mapper graph on a specified matplotlib axis using sensible defaults.

Parameters
----------
ax : matplotlib.axes.Axes, optional
Matplotlib Axes to draw the plot on. If None, a new figure and axes
are created.
title : str, optional
Title of the plot.
cluster_labels : dict, optional
Mapping from node to label text. Defaults to string representations
of the node identifiers.
cluster_label_kwargs : dict, optional
Mapping from node to keyword arguments passed to `ax.text` when drawing
labels (e.g., fontsize, color).
vertices : list of str, optional
Subset of graph nodes to include in the plot. If None, all nodes in
`self.G` are used.
bundle : bool, default False
Whether to apply edge bundling in the visualization.
edge_labels : dict, optional
Mapping from edge to label text.
node_kwargs : dict, default {}
Keyword arguments controlling node appearance.
edge_kwargs : dict, default {}
Keyword arguments controlling edge appearance.
edge_scaling : float, default 1
Scaling factor applied to edge weights or widths.
node_scaling : float, default 1
Scaling factor applied to node sizes.
node_size_bounds : tuple[float], default (5,25)
Size bounds to clip the node sizes to.
edge_weight_bounds : tuple[float], default (0.1,1)
Minimum edge weight for rendering.
node_size_scale : {'linear', 'log', 'sigmoid'}, default 'sigmoid'
Scaling mode used for node sizes.
layout_optimization : str, default 'barycenter'
Layout optimization method passed to `time_semantic_plot`.
layout_optimization_kwargs : dict, optional
Additional keyword arguments for the layout optimization routine.

Returns
-------
matplotlib.axes.Axes
The Axes object containing the temporal plot.

"""
y_initial_pos = np.arctan2(self.data[:,1], self.data[:,0])

if ax is None:
fig, ax = mpl.pyplot.subplots(figsize=(12,8))
if vertices is None:
vertices = self.G.nodes()
G = self.G.subgraph(vertices)

if cluster_labels is None:
cluster_labels = {node:str(node) for node in vertices}
if cluster_label_kwargs is None:
cluster_label_kwargs = {node:{} for node in vertices}

clr_dict = nx.get_node_attributes(G, "colour")
edge_color_list = [
clr_dict[u]
for u, v in G.edges()
]
edge_kwargs = {'edge_color':edge_color_list}

ax = time_semantic_plot(
self,
y_initial_pos,
ax = ax,
vertices = vertices,
bundle = bundle,
edge_labels = edge_labels,
cluster_labels = cluster_labels,
cluster_label_kwargs = cluster_label_kwargs,
layout_optimization = layout_optimization,
node_kwargs = node_kwargs,
edge_kwargs = edge_kwargs,
edge_scaling = edge_scaling,
node_scaling = node_scaling,
node_size_bounds = node_size_bounds,
edge_weight_bounds = edge_weight_bounds,
node_size_scale = node_size_scale
)
if title is not None:
ax.set_title(title)
return ax

def interactive_temporal_plot(
self,
cluster_labels: dict = {},
vertices = None,
hover_text = {},
graph_layout: go.Layout = None,
layout_optimization: str = "barycenter",
layout_optimization_kwargs: dict = {},
edge_scaling: float = 1,
node_scaling: float = 1,
node_size_bounds: tuple[float] = (5,50),
edge_weight_bounds: tuple[float] = (0.1,1),
node_size_scale: str = 'sigmoid',
):
"""
Generate an interactive (plotly) temporal plot of the Mapper graph on a specified matplotlib axis using sensible defaults.

Parameters
----------
cluster_labels : dict, optional
Mapping from node to label text. Defaults to string representations
of the node identifiers.
vertices : list of str, optional
Subset of graph nodes to include in the plot. If None, all nodes in
`self.G` are used.
hover_text : dict, default {}
A dictionary with `hover_text[node]` containing a string with the text
to display when hovering over vertex `node`.
edge_scaling : float, default 1
Scaling factor applied to edge weights or widths.
node_scaling : float, default 1
Scaling factor applied to node sizes.
node_size_bounds : tuple[float], default (5,25)
Size bounds to clip the node sizes to.
edge_weight_bounds : tuple[float], default (0.1,1)
Minimum edge weight for rendering.
node_size_scale : {'linear', 'log', 'sigmoid'}, default 'sigmoid'
Scaling mode used for node sizes.
layout_optimization : str, default 'barycenter'
Layout optimization method passed to `time_semantic_plot`.
layout_optimization_kwargs : dict, optional
Additional keyword arguments for the layout optimization routine.

Returns
-------
matplotlib.axes.Axes
The Axes object containing the temporal plot.

"""
if vertices is None:
vertices = self.G.nodes()
G = self.G.subgraph(vertices)

if len(hover_text.keys())==0:
# construct some default hover text.
for node in vertices:
idx = self.get_vertex_data(node)
median_time = np.median(self.time[idx])
if cluster_labels.get(node,'') != '':
label_str = cluster_labels[node]+"<br>"
else:
label_str = ''
label_str += f'Node {node}<br>Time: {median_time}'
hover_text[node] = label_str

y_initial_pos = np.arctan2(self.data[:,1], self.data[:,0])
compute_time_semantic_positions(
self,
y_initial_pos,
layout_optimization = layout_optimization,
layout_optimization_kwargs = layout_optimization_kwargs
)
positions = nx.get_node_attributes(self.G,'ts_pos')
edge_traces, node_trace = prepare_plotly_graph_objects(
self,
positions,
hover_text = hover_text,
edge_scaling = edge_scaling,
node_scaling = node_scaling,
node_size_bounds = node_size_bounds,
edge_weight_bounds = edge_weight_bounds,
node_size_scale = node_size_scale,
)
if graph_layout is None:
graph_layout = go.Layout(
hovermode = 'closest',
showlegend = False,
margin=dict(b=20,l=5,r=5,t=40),
xaxis=dict(showgrid=False, zeroline=False),
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
)

traces = edge_traces+[node_trace]
fig = go.Figure(
data=traces,
layout = graph_layout,
)
return fig
Loading