-
Notifications
You must be signed in to change notification settings - Fork 27
Description
Dear developer,
can you help me find the cause of this behaviour?
bbknn should accept a custom distance metric, but it emits a warning and falls back to its default Euclidean distance.
Here is the code:
@numba.jit(nopython=True, fastmath=True)
def sp_distance(x: np.ndarray, y: np.ndarray) -> np.float32:
    """Range-normalised mean absolute difference between two vectors.

    Conceptually, both vectors are min-max rescaled onto [min_val, max_val]
    (0.1 .. 0.9) using the *joint* min/max of x and y, and the mean absolute
    element-wise difference of the rescaled vectors is returned. Because the
    same affine map is applied to both vectors, the offset cancels. The result
    is therefore computed directly as

        (max_val - min_val) * mean(|x - y|) / (joint_max - joint_min)

    This avoids building the rescaled temporaries in what is the innermost
    loop of the nearest-neighbour search.

    Returns:
        max_val - min_val (0.8) if either vector is empty or the joint
        extremes are infinite; 0.0 if the joint range is <= 1e-7; otherwise
        a value in [0, 0.8].

    WARNING: the normalisation depends on the pair being compared, so this is
    NOT a metric. The triangle inequality fails, e.g. for 4-d vectors
    x = (0,0,0,0), y = (1,1,1,1), z = (0,0,0,100):
    d(x, y) = 0.8 but d(x, z) + d(z, y) = 0.2 + 0.204.
    NOTE(review): NN-descent and bbknn's connectivity step presumably assume
    metric-like distances; this property is a plausible cause of the
    heavily fragmented Leiden clustering -- compare against a true metric
    before suspecting the monkey patch.

    NOTE(review): with fastmath=True, numba may assume inputs contain no
    NaN/inf, so the infinity check below is best-effort only.
    """
    min_val = np.float32(0.1)
    max_val = np.float32(0.9)
    epsilon = np.float32(1e-7)
    span = max_val - min_val

    if x.shape[0] == 0 or y.shape[0] == 0:
        return span

    # Both vectors are non-empty from here on, so np.min / np.max are safe.
    joint_min = min(np.min(x), np.min(y))
    joint_max = max(np.max(x), np.max(y))
    if joint_min == np.inf or joint_max == -np.inf:
        return span

    joint_range = joint_max - joint_min
    if joint_range <= epsilon:
        # All values (nearly) identical: the rescaling is undefined.
        return np.float32(0.0)

    # The shared offset (-joint_min, +min_val) cancels in x_scaled - y_scaled,
    # leaving only the scale factor span / joint_range.
    return span * np.mean(np.abs(x - y)) / joint_range
# Batch-balanced KNN over the "Patient_ID" batches, using pynndescent with the
# numba-jitted sp_distance as metric. For this call bbknn reports
# "unrecognised metric ... switching to euclidean".
bbknn.bbknn(
    adata,
    batch_key="Patient_ID",
    computation="pynndescent",
    metric=sp_distance,
    pynndescent_n_neighbors=40,
)
WARNING: unrecognised metric for type of neighbor calculation, switching to euclidean.
So I wrote a monkey patch that lets bbknn run with my distance function. The code runs, but the clustering results are very strange: even with a very low Leiden resolution I end up with hundreds of clusters instead of 5. This makes me doubt my monkey patch, or perhaps I am missing something about the strategy bbknn applies.
"""Batch balanced KNN"""
version = "1.6.0"
import bbknn
import bbknn.matrix
import pynndescent
import pynndescent.distances
from scanpy import logging as logg
import numpy as np # Keep for the example in main
import numba # Keep for the example in main
# Capture the *unpatched* bbknn.matrix.bbknn exactly once. Consider re-running
# this code in the same namespace (re-executing a notebook cell, or
# importlib.reload, which keeps module globals) without this guard: it would
# capture the already-installed wrapper here. The wrapper looks this global
# up at call time. It would therefore call itself with the string metric,
# take its pass-through branch, and call itself again until RecursionError.
try:
    original_matrix_bbknn_func  # noqa: B018 -- only tests whether the name is bound
except NameError:
    original_matrix_bbknn_func = bbknn.matrix.bbknn
def _unused_metric_key(registry, base_name):
    """Return a key derived from ``base_name`` that is not yet in ``registry``.

    Using a fresh key means the patch never overwrites, even temporarily, an
    existing pynndescent metric that happens to share the callable's
    ``__name__`` (e.g. a user function called ``euclidean``; pynndescent may
    also special-case such well-known names). Cleanup then reduces to removing
    the key again.
    """
    key = f"bbknn_patch_{base_name}"
    suffix = 0
    while key in registry:
        suffix += 1
        key = f"bbknn_patch_{base_name}_{suffix}"
    return key


def patched_matrix_bbknn_callable_metric(*args, **kwargs):
    """Forward to the original ``bbknn.matrix.bbknn``, allowing a callable metric.

    When ``computation='pynndescent'`` and ``metric`` is callable, three steps
    run:

    1. The callable is registered in ``pynndescent.distances.named_distances``
       under a fresh string key.
    2. The original function is called with that key as ``metric``.
    3. The key is removed again, even if the call raises.

    Only ``metric`` and ``computation`` passed as *keyword* arguments are
    inspected; positional values are forwarded untouched. Every other call is
    forwarded unchanged. If ``named_distances`` is not accessible, an error is
    logged and the call is forwarded with the original (callable) metric.

    Returns:
        Whatever the original ``bbknn.matrix.bbknn`` returns.
    """
    metric_arg = kwargs.get('metric', 'euclidean')
    computation_arg = kwargs.get('computation', 'annoy')
    if computation_arg != 'pynndescent' or not callable(metric_arg):
        # String metrics and other back ends need no patching.
        return original_matrix_bbknn_func(*args, **kwargs)

    display_name = getattr(metric_arg, '__name__', 'custom_callable_for_bbknn')
    # Some callables lack a usable __name__ (partials, lambdas, objects).
    if not isinstance(display_name, str) or not display_name.isidentifier():
        display_name = 'custom_callable_for_bbknn'

    logg.info(f"PATCH: Applying monkey patch for bbknn to use callable metric '{display_name}' with pynndescent.")
    try:
        registry = pynndescent.distances.named_distances
    except AttributeError as e:
        logg.error(f"PATCH: pynndescent.distances.named_distances not accessible ({e}). "
                   "Ensure pynndescent is correctly installed and imported. Patch may fail. "
                   "Proceeding with original metric argument by calling original bbknn.matrix.bbknn.")
        return original_matrix_bbknn_func(*args, **kwargs)

    metric_key = _unused_metric_key(registry, display_name)
    registry[metric_key] = metric_arg
    try:
        # Pass the registered string key instead of the callable itself.
        return original_matrix_bbknn_func(*args, **{**kwargs, 'metric': metric_key})
    finally:
        # The key was unused before registration, so removing it fully
        # restores the registry.
        registry.pop(metric_key, None)
        logg.info(f"PATCH: Removed temporary pynndescent metric '{metric_key}' for '{display_name}'.")
# Install the wrapper in place of bbknn's matrix-level entry point, then
# announce it through scanpy's logger.
setattr(bbknn.matrix, "bbknn", patched_matrix_bbknn_callable_metric)
logg.info("BBKNN monkey patch (for generic callables) applied.")
What do you think could be the issue?