Skip to content

Custom function does not work #65

@LucaGiudice

Description

@LucaGiudice

Dear developer,
could you help me find the cause of the following issue?
bbknn should accept a custom distance metric, but it emits a warning and falls back to its default Euclidean distance.
Here is the code:

@numba.jit(nopython=True, fastmath=True)
def sp_distance(x: np.ndarray, y: np.ndarray) -> np.float32:
    """Mean absolute difference of two vectors after a joint min-max rescale.

    The smallest and largest values found across ``x`` and ``y`` together
    define a linear map onto the interval [0.1, 0.9]. Both vectors are passed
    through that map and the mean of their element-wise absolute differences
    is returned.

    Special cases:
      * either input has length 0            -> ``0.9 - 0.1``
      * an extremum is still at its inf seed -> ``0.9 - 0.1``
      * combined value range <= 1e-7         -> ``0.0``

    NOTE(review): the rescaling bounds are recomputed for every (x, y) pair,
    so values returned for different pairs are not on a common scale. Confirm
    this is intended for a nearest-neighbour metric.
    NOTE(review): ``fastmath=True`` may permit the compiler to assume that no
    inf/NaN values occur, which could affect the inf-sentinel comparison
    below. Confirm against the Numba documentation.
    """
    lower = np.float32(0.1)
    upper = np.float32(0.9)
    tolerance = np.float32(1e-7)

    if x.shape[0] == 0 or y.shape[0] == 0:
        return upper - lower

    # Fold the extrema of both vectors into inf-seeded accumulators.
    # Both inputs are known to be non-empty here because of the guard above.
    lowest = np.inf
    highest = -np.inf
    lowest = min(lowest, np.min(x))
    highest = max(highest, np.max(x))
    lowest = min(lowest, np.min(y))
    highest = max(highest, np.max(y))

    # An accumulator that never moved off its seed means no usable range.
    if lowest == np.inf or highest == -np.inf:
        return upper - lower

    spread = highest - lowest
    if spread <= tolerance:
        return np.float32(0.0)

    gain = (upper - lower) / spread
    x_rescaled = (x - lowest) * gain + lower
    y_rescaled = (y - lowest) * gain + lower
    return np.mean(np.abs(x_rescaled - y_rescaled))

# Run bbknn with the custom callable passed as `metric`, selecting the
# pynndescent computation and using "Patient_ID" as the batch key.
bbknn.bbknn(adata, metric=sp_distance, pynndescent_n_neighbors=40, batch_key="Patient_ID", computation="pynndescent")
WARNING: unrecognised metric for type of neighbor calculation, switching to euclidean.

I then wrote a monkey patch to let bbknn run with my distance function. The code runs without errors, but the clustering results are very strange: despite passing a very low resolution value to leiden, I end up with hundreds of clusters instead of 5. This makes me doubt my monkey patch, and maybe I am missing something about the strategy bbknn applies.

"""Batch balanced KNN"""
# NOTE(review): possibly `__version__` originally; the double underscores may
# have been lost when the post was rendered. Confirm.
version = "1.6.0"

import bbknn
import bbknn.matrix
import pynndescent
import pynndescent.distances
from scanpy import logging as logg
import numpy as np # Keep for the example in main
import numba # Keep for the example in main

# Keep a reference to the function currently bound at bbknn.matrix.bbknn so the
# wrapper defined below can delegate to it.
# NOTE: if this line is executed again after the patch has been installed, it
# captures the wrapper rather than the original function.
original_matrix_bbknn_func = bbknn.matrix.bbknn

def patched_matrix_bbknn_callable_metric(*args, **kwargs):
    """Forward to the saved ``bbknn.matrix.bbknn``, allowing a callable metric.

    Only keyword arguments are inspected. When ``computation`` is
    ``'pynndescent'`` and ``metric`` is callable, the callable is registered
    in ``pynndescent.distances.named_distances`` under its ``__name__`` (or
    ``'custom_callable_for_bbknn'`` if that name is missing or not a valid
    identifier). The saved function is then called with ``metric`` replaced
    by that name. Whatever was previously stored under the name is put back
    afterwards, even if the call raises.

    Every other combination of arguments is forwarded unchanged.

    Returns:
        Whatever the saved ``bbknn.matrix.bbknn`` returns.

    NOTE(review): ``metric`` / ``computation`` passed positionally are not
    seen by this wrapper. Confirm that callers always pass them by keyword.
    """
    requested_metric = kwargs.get('metric', 'euclidean')
    backend = kwargs.get('computation', 'annoy')

    # Guard clause: only "pynndescent + callable metric" needs special handling.
    if backend != 'pynndescent' or not callable(requested_metric):
        return original_matrix_bbknn_func(*args, **kwargs)

    # Derive the registry key from the callable, falling back to a fixed name
    # when __name__ is absent, not a string, or not a plain identifier.
    fallback_name = 'custom_callable_for_bbknn'
    metric_func_name = getattr(requested_metric, '__name__', fallback_name)
    if not (isinstance(metric_func_name, str) and metric_func_name.isidentifier()):
        metric_func_name = fallback_name

    logg.info(f"PATCH: Applying monkey patch for bbknn to use callable metric '{metric_func_name}' with pynndescent.")

    try:
        registry = pynndescent.distances.named_distances
    except AttributeError as e:
        logg.error(f"PATCH: pynndescent.distances.named_distances not accessible ({e}). "
                   "Ensure pynndescent is correctly installed and imported. Patch may fail. "
                   "Proceeding with original metric argument by calling original bbknn.matrix.bbknn.")
        return original_matrix_bbknn_func(*args, **kwargs)

    # Remember any entry about to be shadowed, then register the callable.
    displaced = registry.get(metric_func_name)
    registry[metric_func_name] = requested_metric

    # Same keyword arguments, but with the metric given by its registered name.
    forwarded = {**kwargs, 'metric': metric_func_name}

    try:
        return original_matrix_bbknn_func(*args, **forwarded)
    finally:
        # Put the registry back the way it was found.
        if displaced is None:
            registry.pop(metric_func_name, None)
        else:
            registry[metric_func_name] = displaced
        logg.info(f"PATCH: Restored pynndescent.distances for metric '{metric_func_name}'.")

# Install the wrapper: from here on, attribute lookups of `bbknn.matrix.bbknn`
# resolve to the patched function. References to the original function that
# were taken before this assignment are not affected.
bbknn.matrix.bbknn = patched_matrix_bbknn_callable_metric # Use the new patched function name
logg.info("BBKNN monkey patch (for generic callables) applied.")

What do you think could be the issue?

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions