diff --git a/bluecellulab/circuit/circuit_access/sonata_circuit_access.py b/bluecellulab/circuit/circuit_access/sonata_circuit_access.py index 1f1746d..ade149d 100644 --- a/bluecellulab/circuit/circuit_access/sonata_circuit_access.py +++ b/bluecellulab/circuit/circuit_access/sonata_circuit_access.py @@ -52,6 +52,10 @@ def __init__(self, simulation_config: str | Path | SimulationConfig) -> None: self.config = SonataSimulationConfig(simulation_config) circuit_config = self.config.impl.config["network"] self._circuit = SnapCircuit(circuit_config) + self._inner_edge_pop_names = { + name for name, epop in self._circuit.edges.items() + if getattr(epop.source, "type", None) != "virtual" + } @property def available_cell_properties(self) -> set: @@ -105,6 +109,47 @@ def get_population_ids( source_population_name, target_population_name) return source_popid, target_popid + def _select_edge_pop_names(self, projections) -> list[str]: + edges = self._circuit.edges + all_names = list(edges.keys()) + + inner = [n for n in all_names if n in self._inner_edge_pop_names] + proj = [n for n in all_names if getattr(edges[n].source, "type", None) == "virtual"] + + if projections is False: + return inner + + elif projections is True: + # intrinsic + all projections + out, seen = [], set() + for n in inner + proj: + if n not in seen: + out.append(n); seen.add(n) + return out + else: # str / list[str]: intrinsic + requested + requested = [projections] if isinstance(projections, str) else list(projections or []) + + out, seen = [], set() + by_source = {} + for n in all_names: + by_source.setdefault(edges[n].source.name, []).append(n) + + for n in inner: + if n not in seen: + out.append(n); seen.add(n) + + for token in requested: + if token in edges: + if token not in seen: + out.append(token); seen.add(token) + else: + # legacy support: token as source node population name + for n in by_source.get(token, []): + if n not in seen: + out.append(n); seen.add(n) + + return out + def extract_synapses( self, cell_id: CellId, projections: Optional[list[str] | str] ) -> pd.DataFrame: @@ -114,13 +159,8 @@ def extract_synapses( """ snap_node_id = CircuitNodeId(cell_id.population_name, cell_id.id) edges = self._circuit.edges - # select edges that are in the projections, if there are projections - if projections is None or len(projections) == 0: - edge_population_names = [x for x in edges] - elif isinstance(projections, str): - edge_population_names = [x for x in edges if edges[x].source.name == projections] - else: - edge_population_names = [x for x in edges if edges[x].source.name in projections] + + edge_population_names = self._select_edge_pop_names(projections) all_synapses_dfs: list[pd.DataFrame] = [] for edge_population_name in edge_population_names: diff --git a/bluecellulab/circuit/config/sonata_simulation_config.py b/bluecellulab/circuit/config/sonata_simulation_config.py index f10f0c9..7f8691c 100644 --- a/bluecellulab/circuit/config/sonata_simulation_config.py +++ b/bluecellulab/circuit/config/sonata_simulation_config.py @@ -43,12 +43,11 @@ def __init__(self, config: str | Path | SnapSimulation) -> None: raise TypeError("Invalid config type.") def get_all_projection_names(self) -> list[str]: - unique_names = { - n - for n in self.impl.circuit.nodes - if self.impl.circuit.nodes[n].type == "virtual" - } - return list(unique_names) + return [ + edge_name + for edge_name, edge_pop in self.impl.circuit.edges.items() + if getattr(edge_pop.source, "type", None) == "virtual" + ] def get_all_stimuli_entries(self) -> list[Stimulus]: result: list[Stimulus] = [] diff --git a/bluecellulab/circuit_simulation.py b/bluecellulab/circuit_simulation.py index a37bc84..8c9280e 100644 --- a/bluecellulab/circuit_simulation.py +++ b/bluecellulab/circuit_simulation.py @@ -119,7 +119,7 @@ def __init__( self.spike_threshold = self.circuit_access.config.spike_threshold self.spike_location = self.circuit_access.config.spike_location - self.projections: list[str] = [] + self.projections: list[str] | None = None condition_parameters = self.circuit_access.config.condition_parameters() set_global_condition_parameters(condition_parameters) @@ -183,10 +183,12 @@ def instantiate_gids( Setting add_stimuli=True, will automatically set this option to True. - add_projections: - If True, adds all of the projection blocks of the - circuit config. If False, no projections are added. - If list, adds only the projections in the list. + add_projections: Control whether projection edge populations are considered when adding synapses. + - False (default): intrinsic connectivity only (no projection edge populations) + - True: intrinsic connectivity + all projection edge populations + - list[str]: intrinsic connectivity + the specified projection edge population names + Note: + Names refer to SONATA edge population names (SnapCircuit.edges keys). intersect_pre_gids : list of gids Only add synapses to the cells if their presynaptic gid is in this list @@ -255,10 +257,14 @@ def instantiate_gids( "if you want to specify use add_replay or " "pre_spike_trains") + # legacy for backward compatibility + if add_projections is None: + add_projections = False + if add_projections is True: self.projections = self.circuit_access.config.get_all_projection_names() elif add_projections is False: - self.projections = [] + self.projections = None else: self.projections = add_projections diff --git a/tests/test_cell/test_core.py b/tests/test_cell/test_core.py index 24a7d53..2cd1c5e 100644 --- a/tests/test_cell/test_core.py +++ b/tests/test_cell/test_core.py @@ -670,6 +670,7 @@ def test_add_synapse_replay(): cell_id = ("hippocampus_neurons", 0) circuit_sim.instantiate_gids(cell_id, add_stimuli=True, add_synapses=True, + add_projections=True, interconnect_cells=False) cell = circuit_sim.cells[cell_id] assert len(cell.connections) == 3 diff --git a/tests/test_circuit/test_circuit_access.py b/tests/test_circuit/test_circuit_access.py index c3adb4d..207281b 100644 --- a/tests/test_circuit/test_circuit_access.py +++ b/tests/test_circuit/test_circuit_access.py @@ -100,28 +100,33 @@ def test_get_cell_properties(self): def test_extract_synapses(self): cell_id = CellId("hippocampus_neurons", 1) - projections = None - res = self.circuit_access.extract_synapses(cell_id, projections) + + # intrinsic-only + res = self.circuit_access.extract_synapses(cell_id, False) + assert res.empty + + # intrinsic + all projections + res = self.circuit_access.extract_synapses(cell_id, True) assert res.shape == (1742, 16) assert all(res["source_popid"] == 2126) - assert all(res["source_population_name"] == "hippocampus_projections") assert all(res["target_popid"] == 378) assert all(res[SynapseProperty.POST_SEGMENT_ID] != -1) assert SynapseProperty.U_HILL_COEFFICIENT not in res.columns assert SynapseProperty.CONDUCTANCE_RATIO not in res.columns - # projection parameter + # specific projection selection (by edge pop name or legacy alias) projection = "hippocampus_projections" res = self.circuit_access.extract_synapses(cell_id, projection) assert res.shape == (1742, 16) - list_of_single_projection = [projection] - res = self.circuit_access.extract_synapses(cell_id, list_of_single_projection) - assert res.shape == (1742, 16) - empty_projection = [] - res = self.circuit_access.extract_synapses(cell_id, empty_projection) + + res = self.circuit_access.extract_synapses(cell_id, [projection]) assert res.shape == (1742, 16) + # empty list == no projections requested (intrinsic-only) + res = self.circuit_access.extract_synapses(cell_id, []) + assert res.empty + def test_target_contains_cell(self): target = "most_central_10_SP_PC" cell = CellId("hippocampus_neurons", 1)