Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/probeinterface/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def plot_probe(
ylims: tuple | None = None,
zlims: tuple | None = None,
show_channel_on_click: bool = False,
side=None,
):
"""Plot a Probe object.
Generates a 2D or 3D axis, depending on Probe.ndim
Expand Down Expand Up @@ -138,6 +139,8 @@ def plot_probe(
Limits for z dimension
show_channel_on_click : bool, default: False
If True, the channel information is shown upon click
side : None | "front" | "back"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kind of think that I would not know what to expect between front and back or left or right? Both of them seem equally arbitrary, maybe we should type with literal so this surfaces to end users in docstrings more directly.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my mind, the manufacturer define the official front and back.
Wre could also put A and B like in good old time for audio tapes.

If the probe is two side, then the side must be given otherwise this raises an error.

Returns
-------
Expand All @@ -148,6 +151,15 @@ def plot_probe(
"""
import matplotlib.pyplot as plt

if probe.contact_sides is not None:
if side is None or side not in ("front", "back"):
raise ValueError(
"The probe has two side, you must give which one to plot. plot_probe(probe, side='front'|'back')"
)
mask = probe.contact_sides == side
probe = probe.get_slice(mask)
probe._contact_sides = None

if ax is None:
if probe.ndim == 2:
fig, ax = plt.subplots()
Expand Down
94 changes: 77 additions & 17 deletions src/probeinterface/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(
self.probe_planar_contour = None

# This handles the shank id per contact
# If None then one shank only
self._shank_ids = None

# This handles the wiring to device : channel index on device side.
Expand All @@ -112,6 +113,10 @@ def __init__(
# This must be unique at Probe AND ProbeGroup level
self._contact_ids = None

# Handle contact side for double face probes
# If None then one face only
self._contact_sides = None

# annotation: a dict that contains all meta information about
# the probe (name, manufacturor, date of production, ...)
self.annotations = dict()
Expand Down Expand Up @@ -153,6 +158,10 @@ def contact_ids(self):
def shank_ids(self):
return self._shank_ids

@property
def contact_sides(self):
return self._contact_sides

@property
def name(self):
return self.annotations.get("name", None)
Expand Down Expand Up @@ -237,6 +246,8 @@ def get_title(self) -> str:
if self.shank_ids is not None:
num_shank = self.get_shank_count()
txt += f" - {num_shank}shanks"
if self._contact_sides is not None:
txt += f" - 2 sides"
return txt

def __repr__(self):
Expand Down Expand Up @@ -291,7 +302,14 @@ def get_shank_count(self) -> int:
return n

def set_contacts(
self, positions, shapes="circle", shape_params={"radius": 10}, plane_axes=None, contact_ids=None, shank_ids=None
self,
positions,
shapes="circle",
shape_params={"radius": 10},
plane_axes=None,
contact_ids=None,
shank_ids=None,
contact_sides=None,
):
"""Sets contacts to a Probe.

Expand Down Expand Up @@ -320,16 +338,28 @@ def set_contacts(
shank_ids : array[str] | None, default: None
Defines the shank ids for the contacts. If None, then
these are assigned to a unique Shank.
contact_sides : array[str] | None, default: None
If probe is double sided, defines sides by a vector of ['front' | 'back']
"""
positions = np.array(positions)
if positions.shape[1] != self.ndim:
raise ValueError(f"positions.shape[1]: {positions.shape[1]} and ndim: {self.ndim} do not match!")

# Check for duplicate positions
unique_positions = np.unique(positions, axis=0)
positions_are_not_unique = unique_positions.shape[0] != positions.shape[0]
if positions_are_not_unique:
_raise_non_unique_positions_error(positions)
if contact_sides is None:
# Check for duplicate positions
unique_positions = np.unique(positions, axis=0)
positions_are_not_unique = unique_positions.shape[0] != positions.shape[0]
if positions_are_not_unique:
_raise_non_unique_positions_error(positions)
else:
# Check for duplicate positions side by side
contact_sides = np.asarray(contact_sides).astype(str)
for side in ("front", "back"):
mask = contact_sides == "front"
unique_positions = np.unique(positions[mask], axis=0)
positions_are_not_unique = unique_positions.shape[0] != positions[mask].shape[0]
if positions_are_not_unique:
_raise_non_unique_positions_error(positions[mask])

self._contact_positions = positions
n = positions.shape[0]
Expand All @@ -356,6 +386,15 @@ def set_contacts(
if self.shank_ids.size != n:
raise ValueError(f"shank_ids have wrong size: {self.shanks.ids.size} != {n}")

if contact_sides is None:
self._contact_sides = contact_sides
else:
self._contact_sides = contact_sides
if self._contact_sides.size != n:
raise ValueError(f"contact_sides have wrong size: {self._contact_sides.ids.size} != {n}")
if not np.all(np.isin(self._contact_sides, ["front", "back"])):
raise ValueError(f"contact_sides must 'front' or 'back'")

# shape
if isinstance(shapes, str):
shapes = [shapes] * n
Expand Down Expand Up @@ -592,6 +631,13 @@ def __eq__(self, other):
):
return False

if self._contact_sides is None:
if other._contact_sides is not None:
return False
else:
if not np.array_equal(self._contact_sides, other._contact_sides):
return False

# Compare contact_annotations dictionaries
if self.contact_annotations.keys() != other.contact_annotations.keys():
return False
Expand Down Expand Up @@ -842,6 +888,7 @@ def rotate_contacts(self, thetas: float | np.array[float] | list[float]):
"device_channel_indices",
"_contact_ids",
"_shank_ids",
"_contact_sides",
]

def to_dict(self, array_as_list: bool = False) -> dict:
Expand Down Expand Up @@ -895,6 +942,9 @@ def from_dict(d: dict) -> "Probe":
plane_axes=d["contact_plane_axes"],
shapes=d["contact_shapes"],
shape_params=d["contact_shape_params"],
contact_ids=d.get("contact_ids", None),
shank_ids=d.get("shank_ids", None),
contact_sides=d.get("contact_sides", None),
)

v = d.get("probe_planar_contour", None)
Expand All @@ -905,14 +955,6 @@ def from_dict(d: dict) -> "Probe":
if v is not None:
probe.set_device_channel_indices(v)

v = d.get("shank_ids", None)
if v is not None:
probe.set_shank_ids(v)

v = d.get("contact_ids", None)
if v is not None:
probe.set_contact_ids(v)

if "annotations" in d:
probe.annotate(**d["annotations"])
if "contact_annotations" in d:
Expand Down Expand Up @@ -955,6 +997,7 @@ def to_numpy(self, complete: bool = False) -> np.array:
...
('shank_ids', 'U64'),
('contact_ids', 'U64'),
('contact_sides', 'U8'),

# The rest is added only if `complete=True`
('device_channel_indices', 'int64', optional),
Expand Down Expand Up @@ -991,6 +1034,11 @@ def to_numpy(self, complete: bool = False) -> np.array:
dtype += [(k, "float64")]
dtype += [("shank_ids", "U64"), ("contact_ids", "U64")]

if self._contact_sides is not None:
dtype += [
("contact_sides", "U8"),
]

if complete:
dtype += [("device_channel_indices", "int64")]
dtype += [("si_units", "U64")]
Expand All @@ -1014,6 +1062,9 @@ def to_numpy(self, complete: bool = False) -> np.array:

arr["shank_ids"] = self.shank_ids

if self._contact_sides is not None:
arr["contact_sides"] = self.contact_sides

if self.contact_ids is None:
arr["contact_ids"] = [""] * self.get_contact_count()
else:
Expand Down Expand Up @@ -1062,6 +1113,7 @@ def from_numpy(arr: np.ndarray) -> "Probe":
"contact_shapes",
"shank_ids",
"contact_ids",
"contact_sides",
"device_channel_indices",
"radius",
"width",
Expand Down Expand Up @@ -1118,14 +1170,22 @@ def from_numpy(arr: np.ndarray) -> "Probe":
else:
plane_axes = None

probe.set_contacts(positions=positions, plane_axes=plane_axes, shapes=shapes, shape_params=shape_params)
shank_ids = arr["shank_ids"] if "shank_ids" in fields else None
contact_sides = arr["contact_sides"] if "contact_sides" in fields else None

probe.set_contacts(
positions=positions,
plane_axes=plane_axes,
shapes=shapes,
shape_params=shape_params,
shank_ids=shank_ids,
contact_sides=contact_sides,
)

if "device_channel_indices" in fields:
dev_channel_indices = arr["device_channel_indices"]
if not np.all(dev_channel_indices == -1):
probe.set_device_channel_indices(dev_channel_indices)
if "shank_ids" in fields:
probe.set_shank_ids(arr["shank_ids"])
if "contact_ids" in fields:
probe.set_contact_ids(arr["contact_ids"])

Expand Down
25 changes: 24 additions & 1 deletion tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,30 @@ def test_plot_probegroup():
plot_probegroup(probegroup_3d, same_axes=True)


def test_plot_probe_two_side():
probe = Probe()
probe.set_contacts(
positions=np.array(
[
[0, 0],
[0, 10],
[0, 20],
[0, 0],
[0, 10],
[0, 20],
]
),
shapes="circle",
contact_ids=["F1", "F2", "F3", "B1", "B2", "B3"],
contact_sides=["front", "front", "front", "back", "back", "back"],
)

plot_probe(probe, with_contact_id=True, side="front")
plot_probe(probe, with_contact_id=True, side="back")


if __name__ == "__main__":
test_plot_probe()
# test_plot_probe()
# test_plot_probe_group()
test_plot_probe_two_side()
plt.show()
33 changes: 33 additions & 0 deletions tests/test_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,42 @@ def test_position_uniqueness():
probe.set_contacts(positions=positions_with_dups, shapes="circle", shape_params={"radius": 5})


def test_double_side_probe():

probe = Probe()
probe.set_contacts(
positions=np.array(
[
[0, 0],
[0, 10],
[0, 20],
[0, 0],
[0, 10],
[0, 20],
]
),
shapes="circle",
contact_sides=["front", "front", "front", "back", "back", "back"],
)
print(probe)

assert "contact_sides" in probe.to_dict()

probe2 = Probe.from_dict(probe.to_dict())
assert probe2 == probe

probe3 = Probe.from_numpy(probe.to_numpy())
assert probe3 == probe

probe4 = Probe.from_dataframe(probe.to_dataframe())
assert probe4 == probe


if __name__ == "__main__":
test_probe()

tmp_path = Path("tmp")
tmp_path.mkdir(exist_ok=True)
test_save_to_zarr(tmp_path)

test_double_side_probe()
Loading