Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 35 additions & 23 deletions src/spikeinterface/core/baserecordingsnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@ def is_filtered(self):
# the is_filtered is handle with annotation
return self._annotations.get("is_filtered", False)

def set_probe(self, probe, group_mode="by_probe", in_place=False):
def set_probe(self, probe, group_mode="auto", in_place=False):
"""
Attach a list of Probe object to a recording.

Parameters
----------
probe_or_probegroup: Probe, list of Probe, or ProbeGroup
The probe(s) to be attached to the recording
group_mode: "by_probe" | "by_shank", default: "by_probe
"by_probe" or "by_shank". Adds grouping property to the recording based on the probes ("by_probe")
or shanks ("by_shanks")
group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto"
How to add the "group" property.
"auto" is the best splitting possible that can be all at once when multiple, probe with multiple shanks and 2 sides.
in_place: bool
False by default.
Useful internally when extractor do self.set_probegroup(probe)
Expand All @@ -86,10 +86,10 @@ def set_probe(self, probe, group_mode="by_probe", in_place=False):
probegroup.add_probe(probe)
return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place)

def set_probegroup(self, probegroup, group_mode="by_probe", in_place=False):
def set_probegroup(self, probegroup, group_mode="auto", in_place=False):
return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place)

def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False):
def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False):
"""
Attach a list of Probe objects to a recording.
For this Probe.device_channel_indices is used to link contacts to recording channels.
Expand All @@ -103,9 +103,9 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False
----------
probe_or_probegroup: Probe, list of Probe, or ProbeGroup
The probe(s) to be attached to the recording
group_mode: "by_probe" | "by_shank", default: "by_probe"
"by_probe" or "by_shank". Adds grouping property to the recording based on the probes ("by_probe")
or shanks ("by_shank")
group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto"
How to add the "group" property.
"auto" is the best splitting possible that can be all at once when multiple, probe with multiple shanks and 2 sides.
in_place: bool
False by default.
Useful internally when extractor do self.set_probegroup(probe)
Expand All @@ -115,7 +115,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False
sub_recording: BaseRecording
A view of the recording (ChannelSlice or clone or itself)
"""
assert group_mode in ("by_probe", "by_shank"), "'group_mode' can be 'by_probe' or 'by_shank'"
assert group_mode in ("auto", "by_probe", "by_shank", "by_side"), "'group_mode' can be 'auto' 'by_probe' 'by_shank' or 'by_side'"

# handle several input possibilities
if isinstance(probe_or_probegroup, Probe):
Expand Down Expand Up @@ -199,20 +199,32 @@ def _set_probes(self, probe_or_probegroup, group_mode="by_probe", in_place=False
sub_recording.set_property("location", locations, ids=None)

# handle groups
groups = np.zeros(probe_as_numpy_array.size, dtype="int64")
if group_mode == "by_probe":
for group, probe_index in enumerate(np.unique(probe_as_numpy_array["probe_index"])):
mask = probe_as_numpy_array["probe_index"] == probe_index
groups[mask] = group
all_has_shank_id = all(probe.shank_ids is not None for probe in probegroup.probes)
all_has_contact_side = all(probe.contact_sides is not None for probe in probegroup.probes)
if group_mode == "auto":
group_keys = ["probe_index"]
if all_has_shank_id:
group_keys += ["shank_ids"]
if all_has_contact_side:
group_keys += ["contact_sides"]
elif group_mode == "by_probe":
group_keys = ["probe_index"]
elif group_mode == "by_shank":
assert all(
probe.shank_ids is not None for probe in probegroup.probes
), "shank_ids is None in probe, you cannot group by shank"
for group, a in enumerate(np.unique(probe_as_numpy_array[["probe_index", "shank_ids"]])):
mask = (probe_as_numpy_array["probe_index"] == a["probe_index"]) & (
probe_as_numpy_array["shank_ids"] == a["shank_ids"]
)
groups[mask] = group
assert all_has_shank_id, "shank_ids is None in probe, you cannot group by shank"
group_keys = ["probe_index", "shank_ids"]
elif group_mode == "by_side":
assert all_has_contact_side, "contact_sides is None in probe, you cannot group by side"
if all_has_shank_id:
group_keys = ["probe_index", "shank_ids", "contact_sides"]
else:
group_keys = ["probe_index", "contact_sides"]
groups = np.zeros(probe_as_numpy_array.size, dtype="int64")
unique_keys = np.unique(probe_as_numpy_array[group_keys])
for group, a in enumerate(unique_keys):
mask = np.ones(probe_as_numpy_array.size, dtype=bool)
for k in group_keys:
mask &= (probe_as_numpy_array[k] == a[k])
groups[mask] = group
sub_recording.set_property("group", groups, ids=None)

# add probe annotations to recording
Expand Down
12 changes: 8 additions & 4 deletions src/spikeinterface/core/tests/test_baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,11 +179,15 @@ def test_BaseRecording(create_cache_folder):

# set/get Probe only 2 channels
probe = Probe(ndim=2)
positions = [[0.0, 0.0], [0.0, 15.0], [0, 30.0]]
probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5})
probe.set_device_channel_indices([2, -1, 0])
positions = [[0.0, 0.0], [0.0, 15.0], [0, 30.0],
[100.0, 0.0], [100.0, 15.0], [100.0, 30.0],
]
probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5}, shank_ids=["a"]*3 + ["b"]*3)
probe.set_device_channel_indices([2, -1, 0, -1, -1, -1 ], )
probe.create_auto_shape()

print("ici", probe.shank_ids)

rec_p = rec.set_probe(probe, group_mode="by_shank")
rec_p = rec.set_probe(probe, group_mode="by_probe")
positions2 = rec_p.get_channel_locations()
Expand Down Expand Up @@ -216,7 +220,7 @@ def test_BaseRecording(create_cache_folder):
# set unconnected probe
probe = Probe(ndim=2)
positions = [[0.0, 0.0], [0.0, 15.0], [0, 30.0]]
probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5})
probe.set_contacts(positions=positions, shapes="circle", shape_params={"radius": 5}, shank_ids=["a", "a", "a"])
probe.set_device_channel_indices([-1, -1, -1])
probe.create_auto_shape()

Expand Down