Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
project = "starbars"
copyright = "2024, Elide Brunelli"
author = "Elide Brunelli"
release = "3.0.0"
release = "3.1.0"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
190 changes: 123 additions & 67 deletions starbars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

import matplotlib.pyplot as plt

from ._utils import pvalue_to_asterisks, get_positions, get_starbars_logger
from ._utils import pvalue_to_asterisks, get_positions, get_starbars_logger, find_level

__version__ = "3.0.0"
__version__ = "3.1.0"


DEBUG = bool(os.environ.get("DEBUG_STARBARS", False))
Expand All @@ -31,6 +31,7 @@ def draw_annotation(
color="k",
text_args=None,
line_args=None,
h_gap=0.03,
):
"""
Draw statistical significance bars and p-value labels between chosen pairs of columns on existing plots.
Expand Down Expand Up @@ -75,15 +76,15 @@ def draw_annotation(
coords_to_px = create_coordinate_transformer(ax, mode)
px_to_coords = create_coordinate_transformer(ax, mode, True)

dpi = ax.figure.dpi
px_txt = (fontsize / 72) * dpi
# Transform the annot axis interval from (0, 0) to (0, 1) into pixels
px_ax = (
ax.transAxes.transform(unit_vector)[annot_axis]
- ax.transAxes.transform((0, 0))[annot_axis]
)

# Get annot axis limit maximum
annot = get_lim()[1]

# Take annot axis limit maximum and add the first bar gap in axis pixels as starting point
annot = px_to_coords((0, coords_to_px((0, annot))[annot_axis] + px_ax * bar_gap))[
annot_axis
Expand All @@ -92,45 +93,132 @@ def draw_annotation(
(0, ax.transData.transform((0, ax.get_ylim()[1]))[1] + px_ax * bar_gap)
)[1]

dpi = ax.figure.dpi
text_height = (fontsize / 72) * dpi / px_ax

# Find levels
leveled_annotations = find_level(ax, annotations, mode)

bars = []
text_positions = []
text_labels = []
max_annot_px = 0

# Get the positions of the values
for box1, box2, pvalue in annotations:
label = pvalue_to_asterisks(pvalue)
if label == "ns" and not ns_show:
continue
box1_position, box2_position = get_positions(ax, box1, box2, mode)
box1_px = coords_to_px((box1_position, annot))[+(not annot_axis)]
box2_px = coords_to_px((box2_position, annot))[+(not annot_axis)]
annot_px = coords_to_px((box1_position, annot))[annot_axis]

bar_box = [box1_px, box1_px, box2_px, box2_px]
bar_annot = [
annot_px,
px_ax * tip_length + annot_px,
px_ax * tip_length + annot_px,
annot_px,
]
text_pos = px_to_coords(
(
(box1_px + box2_px) / 2,
px_ax * tip_length + annot_px + px_ax * text_distance,
)
for annotation in leveled_annotations:
result = calculate_bar(
annotation,
ns_show,
annot,
annot_axis,
coords_to_px,
px_ax,
tip_length,
h_gap,
px_to_coords,
text_distance,
bar_gap,
text_height,
)
if result:
points, text_pos, label, bar_max = result
bars.append(points)
text_positions.append(text_pos)
text_labels.append(label)
max_annot_px = max(max_annot_px, bar_max)

draw_bars(
ax,
bars,
text_positions,
text_labels,
mode,
line_width,
color,
fontsize,
text_args,
line_args,
)


def create_coordinate_transformer(ax, mode, inverse=False):
if inverse:
transformation = lambda *args: ax.transData.inverted().transform(*args)
else:
transformation = lambda *args: ax.transData.transform(*args)

if mode == "vertical":

def coords_to_px(coords):
return transformation(coords)

else:

def coords_to_px(coords):
return transformation(tuple(reversed(coords)))

return coords_to_px


def calculate_bar(
annotation,
ns_show,
annot,
annot_axis,
coords_to_px,
px_ax,
tip_length,
h_gap,
px_to_coords,
text_distance,
bar_gap,
text_height,
):
box1_pos, box2_pos, level, pvalue = annotation

label = pvalue_to_asterisks(pvalue)
if label == "ns" and not ns_show:
return

box1_px = coords_to_px((box1_pos + h_gap / 2, annot))[+(not annot_axis)]
box2_px = coords_to_px((box2_pos - h_gap / 2, annot))[+(not annot_axis)]
annot_px = coords_to_px((box1_pos, annot))[annot_axis]

bar_box = [box1_px, box1_px, box2_px, box2_px]
level_offset = px_ax * (bar_gap + tip_length + text_distance + text_height)
offset = annot_px + level_offset * level

bar_annot = [
offset,
px_ax * tip_length + offset,
px_ax * tip_length + offset,
offset,
]

text_pos = px_to_coords(
(
(box1_px + box2_px) / 2,
px_ax * tip_length + px_ax * text_distance + offset,
)
)

points = [px_to_coords((_box, _annot)) for _box, _annot in zip(bar_box, bar_annot)]

return points, text_pos, label, annot_px + level_offset * (level + 1)

points = [
px_to_coords((_box, _annot)) for _box, _annot in zip(bar_box, bar_annot)
]
bars.append(points)
text_positions.append(text_pos)
text_labels.append(label)

# Move up the annot point to the next bar's start annot
annot = px_to_coords(
(box1_px, annot_px + px_ax * tip_length + px_ax * bar_gap + px_txt)
)[annot_axis]
def draw_bars(
ax,
bars,
text_positions,
text_labels,
mode,
line_width,
color,
fontsize,
text_args,
line_args,
):

# Draw the statistical annotation
for bar, text_pos, label in zip(bars, text_positions, text_labels):
Expand All @@ -152,35 +240,3 @@ def draw_annotation(
rotation=-90 * (mode == "horizontal"),
**text_args
)

if len(annotations) == 0:
return

# Adjust the annot axis limit of the current subplot to accommodate the top margin
annot_px = coords_to_px((box1_position, annot))[annot_axis]
annot_final = px_to_coords((0, annot_px + px_ax * top_margin))[annot_axis]
annot = px_to_coords((0, coords_to_px((0, annot_final))[annot_axis]))[annot_axis]

# If final annotation is out of bounds, adjust the annot limit to include it.
if annot_final > 1:
annot0, annot_max = get_lim()
set_lim(annot0, annot)


def create_coordinate_transformer(ax, mode, inverse=False):
if inverse:
transformation = lambda *args: ax.transData.inverted().transform(*args)
else:
transformation = lambda *args: ax.transData.transform(*args)

if mode == "vertical":

def coords_to_px(coords):
return transformation(coords)

else:

def coords_to_px(coords):
return transformation(tuple(reversed(coords)))

return coords_to_px
59 changes: 59 additions & 0 deletions starbars/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,62 @@ def get_starbars_logger(level):
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
return logger


def create_coordinate_transformer(ax, mode, inverse=False):
if inverse:
transformation = lambda *args: ax.transData.inverted().transform(*args)
else:
transformation = lambda *args: ax.transData.transform(*args)

if mode == "vertical":

def coords_to_px(coords):
return transformation(coords)

else:

def coords_to_px(coords):
return transformation(tuple(reversed(coords)))

return coords_to_px


def find_level(ax, annotations, mode):

# Sort annotations for optimized stacking
sorted_annotations = sorted(
annotations,
key=lambda x: (
min(get_positions(ax, x[0], x[1], mode)),
abs(
get_positions(ax, x[0], x[1], mode)[1]
- get_positions(ax, x[0], x[1], mode)[0]
),
),
)

levels = []

for box1, box2, pvalue in sorted_annotations:

# Retrieve positions
box_positions = get_positions(ax, box1, box2, mode)
box1_pos = min(box_positions)
box2_pos = max(box_positions)

# Find the first available level
for level_index, level in enumerate(levels):
if all(
box1_pos >= existing_end or box2_pos <= existing_start
for existing_start, existing_end, _, _ in level
):
levels[level_index].append((box1_pos, box2_pos, level_index, pvalue))
break
else:
# Create new level
level_index = len(levels)
levels.append([(box1_pos, box2_pos, level_index, pvalue)])

# Flatten list of lists
return [annotation for level in levels for annotation in level]
Loading