Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions bluecellulab/circuit/circuit_access/sonata_circuit_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def __init__(self, simulation_config: str | Path | SimulationConfig) -> None:
self.config = SonataSimulationConfig(simulation_config)
circuit_config = self.config.impl.config["network"]
self._circuit = SnapCircuit(circuit_config)
self._inner_edge_pop_names = {
name for name, epop in self._circuit.edges.items()
if getattr(epop.source, "type", None) != "virtual"
}

@property
def available_cell_properties(self) -> set:
Expand Down Expand Up @@ -105,6 +109,47 @@ def get_population_ids(
source_population_name, target_population_name)
return source_popid, target_popid

def _select_edge_pop_names(self, projections) -> list[str]:
edges = self._circuit.edges
all_names = list(edges.keys())

inner = [n for n in all_names if n in self._inner_edge_pop_names]
proj = [n for n in all_names if getattr(edges[n].source, "type", None) == "virtual"]

if projections is False:
return inner

elif projections is True:
# intrinsic + all projections
out, seen = [], set()
for n in inner + proj:
if n not in seen:
out.append(n); seen.add(n)
return out
else: # str / list[str]: intrinsic + requested
requested = [projections] if isinstance(projections, str) else list(projections or [])

out, seen = [], set()
by_source = {}
for n in all_names:
by_source.setdefault(edges[n].source.name, []).append(n)

for n in inner:
if n not in seen:
out.append(n); seen.add(n)

for token in requested:
if token in edges:
if token not in seen:
out.append(token); seen.add(token)
else:
# legacy support: token as source node population name
for n in by_source.get(token, []):
if n not in seen:
out.append(n); seen.add(n)

return out

def extract_synapses(
self, cell_id: CellId, projections: Optional[list[str] | str]
) -> pd.DataFrame:
Expand All @@ -114,13 +159,8 @@ def extract_synapses(
"""
snap_node_id = CircuitNodeId(cell_id.population_name, cell_id.id)
edges = self._circuit.edges
# select edges that are in the projections, if there are projections
if projections is None or len(projections) == 0:
edge_population_names = [x for x in edges]
elif isinstance(projections, str):
edge_population_names = [x for x in edges if edges[x].source.name == projections]
else:
edge_population_names = [x for x in edges if edges[x].source.name in projections]

edge_population_names = self._select_edge_pop_names(projections)

all_synapses_dfs: list[pd.DataFrame] = []
for edge_population_name in edge_population_names:
Expand Down
11 changes: 5 additions & 6 deletions bluecellulab/circuit/config/sonata_simulation_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,11 @@ def __init__(self, config: str | Path | SnapSimulation) -> None:
raise TypeError("Invalid config type.")

def get_all_projection_names(self) -> list[str]:
unique_names = {
n
for n in self.impl.circuit.nodes
if self.impl.circuit.nodes[n].type == "virtual"
}
return list(unique_names)
return [
edge_name
for edge_name, edge_pop in self.impl.circuit.edges.items()
if getattr(edge_pop.source, "type", None) == "virtual"
]

def get_all_stimuli_entries(self) -> list[Stimulus]:
result: list[Stimulus] = []
Expand Down
18 changes: 12 additions & 6 deletions bluecellulab/circuit_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(
self.spike_threshold = self.circuit_access.config.spike_threshold
self.spike_location = self.circuit_access.config.spike_location

self.projections: list[str] = []
self.projections: list[str] | None = None

condition_parameters = self.circuit_access.config.condition_parameters()
set_global_condition_parameters(condition_parameters)
Expand Down Expand Up @@ -183,10 +183,12 @@ def instantiate_gids(
Setting add_stimuli=True,
will automatically set this option to
True.
add_projections:
If True, adds all of the projection blocks of the
circuit config. If False, no projections are added.
If list, adds only the projections in the list.
add_projections: Control whether projection edge populations are considered when adding synapses.
- False (default): intrinsic connectivity only (no projection edge populations)
- True: intrinsic connectivity + all projection edge populations
- list[str]: intrinsic connectivity + the specified projection edge population names
Note:
Names refer to SONATA edge population names (SnapCircuit.edges keys).
intersect_pre_gids : list of gids
Only add synapses to the cells if their
presynaptic gid is in this list
Expand Down Expand Up @@ -255,10 +257,14 @@ def instantiate_gids(
"if you want to specify use add_replay or "
"pre_spike_trains")

# legacy for backward compatibility
if add_projections is None:
add_projections = False

if add_projections is True:
self.projections = self.circuit_access.config.get_all_projection_names()
elif add_projections is False:
self.projections = []
self.projections = None
else:
self.projections = add_projections

Expand Down
1 change: 1 addition & 0 deletions tests/test_cell/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,7 @@ def test_add_synapse_replay():
cell_id = ("hippocampus_neurons", 0)
circuit_sim.instantiate_gids(cell_id,
add_stimuli=True, add_synapses=True,
add_projections=True,
interconnect_cells=False)
cell = circuit_sim.cells[cell_id]
assert len(cell.connections) == 3
Expand Down
23 changes: 14 additions & 9 deletions tests/test_circuit/test_circuit_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,28 +100,33 @@ def test_get_cell_properties(self):

def test_extract_synapses(self):
cell_id = CellId("hippocampus_neurons", 1)
projections = None
res = self.circuit_access.extract_synapses(cell_id, projections)

# intrinsic-only
res = self.circuit_access.extract_synapses(cell_id, False)
assert res.empty

# intrinsic + all projections
res = self.circuit_access.extract_synapses(cell_id, True)
assert res.shape == (1742, 16)
assert all(res["source_popid"] == 2126)

assert all(res["source_population_name"] == "hippocampus_projections")
assert all(res["target_popid"] == 378)
assert all(res[SynapseProperty.POST_SEGMENT_ID] != -1)
assert SynapseProperty.U_HILL_COEFFICIENT not in res.columns
assert SynapseProperty.CONDUCTANCE_RATIO not in res.columns

# projection parameter
# specific projection selection (by edge pop name or legacy alias)
projection = "hippocampus_projections"
res = self.circuit_access.extract_synapses(cell_id, projection)
assert res.shape == (1742, 16)
list_of_single_projection = [projection]
res = self.circuit_access.extract_synapses(cell_id, list_of_single_projection)
assert res.shape == (1742, 16)
empty_projection = []
res = self.circuit_access.extract_synapses(cell_id, empty_projection)

res = self.circuit_access.extract_synapses(cell_id, [projection])
assert res.shape == (1742, 16)

# empty list == no projections requested (intrinsic-only)
res = self.circuit_access.extract_synapses(cell_id, [])
assert res.empty

def test_target_contains_cell(self):
target = "most_central_10_SP_PC"
cell = CellId("hippocampus_neurons", 1)
Expand Down
Loading