diff --git a/docs/source/conf.py b/docs/source/conf.py index bca57d9..5e5bc68 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -9,7 +9,7 @@ project = "starbars" copyright = "2024, Elide Brunelli" author = "Elide Brunelli" -release = "3.0.0" +release = "3.1.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/starbars/__init__.py b/starbars/__init__.py index c035572..c05bab8 100644 --- a/starbars/__init__.py +++ b/starbars/__init__.py @@ -8,9 +8,9 @@ import matplotlib.pyplot as plt -from ._utils import pvalue_to_asterisks, get_positions, get_starbars_logger +from ._utils import pvalue_to_asterisks, get_positions, get_starbars_logger, find_level -__version__ = "3.0.0" +__version__ = "3.1.0" DEBUG = bool(os.environ.get("DEBUG_STARBARS", False)) @@ -31,6 +31,7 @@ def draw_annotation( color="k", text_args=None, line_args=None, + h_gap=0.03, ): """ Draw statistical significance bars and p-value labels between chosen pairs of columns on existing plots. @@ -75,15 +76,15 @@ def draw_annotation( coords_to_px = create_coordinate_transformer(ax, mode) px_to_coords = create_coordinate_transformer(ax, mode, True) - dpi = ax.figure.dpi - px_txt = (fontsize / 72) * dpi # Transform the annot axis interval from (0, 0) to (0, 1) into pixels px_ax = ( ax.transAxes.transform(unit_vector)[annot_axis] - ax.transAxes.transform((0, 0))[annot_axis] ) + # Get annot axis limit maximum annot = get_lim()[1] + # Take annot axis limit maximum and add the first bar gap in axis pixels as starting point annot = px_to_coords((0, coords_to_px((0, annot))[annot_axis] + px_ax * bar_gap))[ annot_axis @@ -92,45 +93,132 @@ def draw_annotation( (0, ax.transData.transform((0, ax.get_ylim()[1]))[1] + px_ax * bar_gap) )[1] + dpi = ax.figure.dpi + text_height = (fontsize / 72) * dpi / px_ax + + # Find levels + leveled_annotations = find_level(ax, annotations, mode) + bars = [] text_positions = [] text_labels = [] + max_annot_px = 0 # Get the positions of the values - for box1, box2, pvalue in annotations: - label = pvalue_to_asterisks(pvalue) - if label == "ns" and not ns_show: - continue - box1_position, box2_position = get_positions(ax, box1, box2, mode) - box1_px = coords_to_px((box1_position, annot))[+(not annot_axis)] - box2_px = coords_to_px((box2_position, annot))[+(not annot_axis)] - annot_px = coords_to_px((box1_position, annot))[annot_axis] - - bar_box = [box1_px, box1_px, box2_px, box2_px] - bar_annot = [ - annot_px, - px_ax * tip_length + annot_px, - px_ax * tip_length + annot_px, - annot_px, - ] - text_pos = px_to_coords( - ( - (box1_px + box2_px) / 2, - px_ax * tip_length + annot_px + px_ax * text_distance, - ) + for annotation in leveled_annotations: + result = calculate_bar( + annotation, + ns_show, + annot, + annot_axis, + coords_to_px, + px_ax, + tip_length, + h_gap, + px_to_coords, + text_distance, + bar_gap, + text_height, + ) + if result: + points, text_pos, label, bar_max = result + bars.append(points) + text_positions.append(text_pos) + text_labels.append(label) + max_annot_px = max(max_annot_px, bar_max) + + draw_bars( + ax, + bars, + text_positions, + text_labels, + mode, + line_width, + color, + fontsize, + text_args, + line_args, + ) + + +def create_coordinate_transformer(ax, mode, inverse=False): + if inverse: + transformation = lambda *args: ax.transData.inverted().transform(*args) + else: + transformation = lambda *args: ax.transData.transform(*args) + + if mode == "vertical": + + def coords_to_px(coords): + return transformation(coords) + + else: + + def coords_to_px(coords): + return transformation(tuple(reversed(coords))) + + return coords_to_px + + +def calculate_bar( + annotation, + ns_show, + annot, + annot_axis, + coords_to_px, + px_ax, + tip_length, + h_gap, + px_to_coords, + text_distance, + bar_gap, + text_height, +): + box1_pos, box2_pos, level, pvalue = annotation + + label = pvalue_to_asterisks(pvalue) + if label == "ns" and not ns_show: + return + + box1_px = coords_to_px((box1_pos + h_gap / 2, annot))[+(not annot_axis)] + box2_px = coords_to_px((box2_pos - h_gap / 2, annot))[+(not annot_axis)] + annot_px = coords_to_px((box1_pos, annot))[annot_axis] + + bar_box = [box1_px, box1_px, box2_px, box2_px] + level_offset = px_ax * (bar_gap + tip_length + text_distance + text_height) + offset = annot_px + level_offset * level + + bar_annot = [ + offset, + px_ax * tip_length + offset, + px_ax * tip_length + offset, + offset, + ] + + text_pos = px_to_coords( + ( + (box1_px + box2_px) / 2, + px_ax * tip_length + px_ax * text_distance + offset, ) + ) + + points = [px_to_coords((_box, _annot)) for _box, _annot in zip(bar_box, bar_annot)] + + return points, text_pos, label, annot_px + level_offset * (level + 1) - points = [ - px_to_coords((_box, _annot)) for _box, _annot in zip(bar_box, bar_annot) - ] - bars.append(points) - text_positions.append(text_pos) - text_labels.append(label) - # Move up the annot point to the next bar's start annot - annot = px_to_coords( - (box1_px, annot_px + px_ax * tip_length + px_ax * bar_gap + px_txt) - )[annot_axis] +def draw_bars( + ax, + bars, + text_positions, + text_labels, + mode, + line_width, + color, + fontsize, + text_args, + line_args, +): # Draw the statistical annotation for bar, text_pos, label in zip(bars, text_positions, text_labels): @@ -152,35 +240,3 @@ def draw_annotation( rotation=-90 * (mode == "horizontal"), **text_args ) - - if len(annotations) == 0: - return - - # Adjust the annot axis limit of the current subplot to accommodate the top margin - annot_px = coords_to_px((box1_position, annot))[annot_axis] - annot_final = px_to_coords((0, annot_px + px_ax * top_margin))[annot_axis] - annot = px_to_coords((0, coords_to_px((0, annot_final))[annot_axis]))[annot_axis] - - # If final annotation is out of bounds, adjust the annot limit to include it. - if annot_final > 1: - annot0, annot_max = get_lim() - set_lim(annot0, annot) - - -def create_coordinate_transformer(ax, mode, inverse=False): - if inverse: - transformation = lambda *args: ax.transData.inverted().transform(*args) - else: - transformation = lambda *args: ax.transData.transform(*args) - - if mode == "vertical": - - def coords_to_px(coords): - return transformation(coords) - - else: - - def coords_to_px(coords): - return transformation(tuple(reversed(coords))) - - return coords_to_px diff --git a/starbars/_utils.py b/starbars/_utils.py index fcd64e8..1a457ef 100644 --- a/starbars/_utils.py +++ b/starbars/_utils.py @@ -120,3 +120,62 @@ def get_starbars_logger(level): console_handler.setFormatter(formatter) logger.addHandler(console_handler) return logger + + +def create_coordinate_transformer(ax, mode, inverse=False): + if inverse: + transformation = lambda *args: ax.transData.inverted().transform(*args) + else: + transformation = lambda *args: ax.transData.transform(*args) + + if mode == "vertical": + + def coords_to_px(coords): + return transformation(coords) + + else: + + def coords_to_px(coords): + return transformation(tuple(reversed(coords))) + + return coords_to_px + + +def find_level(ax, annotations, mode): + + # Sort annotations for optimized stacking + sorted_annotations = sorted( + annotations, + key=lambda x: ( + min(get_positions(ax, x[0], x[1], mode)), + abs( + get_positions(ax, x[0], x[1], mode)[1] + - get_positions(ax, x[0], x[1], mode)[0] + ), + ), + ) + + levels = [] + + for box1, box2, pvalue in sorted_annotations: + + # Retrieve positions + box_positions = get_positions(ax, box1, box2, mode) + box1_pos = min(box_positions) + box2_pos = max(box_positions) + + # Find the first available level + for level_index, level in enumerate(levels): + if all( + box1_pos >= existing_end or box2_pos <= existing_start + for existing_start, existing_end, _, _ in level + ): + levels[level_index].append((box1_pos, box2_pos, level_index, pvalue)) + break + else: + # Create new level + level_index = len(levels) + levels.append([(box1_pos, box2_pos, level_index, pvalue)]) + + # Flatten list of lists + return [annotation for level in levels for annotation in level]