-
Notifications
You must be signed in to change notification settings - Fork 6
Added custom inputs and basic cluster analysis #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| ### Brian2 models for network construction. | ||
| import brian2 as b2 | ||
| import numpy as np | ||
| from automind.sim import b2_inputs | ||
|
|
||
|
|
||
| def adaptive_exp_net(all_param_dict): | ||
|
|
@@ -64,14 +65,14 @@ def adaptive_exp_net(all_param_dict): | |
| ) | ||
|
|
||
| ### TO DO: also randomly initialize w to either randint(?)*b or randn*(v-v_rest)*a | ||
|
|
||
| poisson_input_E = b2.PoissonInput( | ||
| ''' poisson_input_E = b2.PoissonInput( | ||
|
||
| target=E_pop, | ||
| target_var="ge", | ||
| N=param_dict_neuron_E["N_poisson"], | ||
| rate=param_dict_neuron_E["poisson_rate"], | ||
| weight=param_dict_neuron_E["Q_poisson"], | ||
| ) | ||
| )''' | ||
|
|
||
|
|
||
| if has_inh: | ||
| # make adlif if delta_t is 0, otherwise adex | ||
|
|
@@ -268,12 +269,21 @@ def make_clustered_network( | |
| return membership, shared_membership, conn_in, conn_out | ||
|
|
||
|
|
||
| def adaptive_exp_net_clustered(all_param_dict): | ||
| """Adaptive exponential integrate-and-fire network with clustered connections.""" | ||
| #Modified function incorporating inputs to specific clusters | ||
| def adaptive_exp_net_clustered_cog(all_param_dict, mode='default', custom_input=None, stim_cluster=None, custom_cluster_input=None): | ||
|
||
| ''' | ||
| Adaptive exponential integrate-and-fire network with clustered connections. | ||
|
|
||
| 3 modes | ||
| - Default mode - no input | ||
| - Single mode - Single input -> User can define any input sequence. DM_simple is used when no inputs are provided | ||
| - Cluster mode - Each cluster gets different input, can be defined by user. DM_simple with different mean is used for each cluster when no inputs are provided. | ||
|
||
| - Can also select number of clusters to stimulate | ||
| ''' | ||
|
|
||
| # separate parameter dictionaries | ||
| param_dict_net = all_param_dict["params_net"] | ||
| param_dict_settings = all_param_dict["params_settings"] | ||
|
|
||
| # set random seeds | ||
| b2.seed(param_dict_settings["random_seed"]) | ||
| np.random.seed(param_dict_settings["random_seed"]) | ||
|
|
@@ -287,25 +297,27 @@ def adaptive_exp_net_clustered(all_param_dict): | |
|
|
||
| #### NETWORK CONSTRUCTION ############ | ||
| ###################################### | ||
|
|
||
| ### get cell counts | ||
| N_pop, exc_prop = param_dict_net["N_pop"], param_dict_net["exc_prop"] | ||
| N_exc = int(N_pop * exc_prop) | ||
| N_inh = N_pop - N_exc | ||
rdgao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| ### define neuron equation | ||
| adex_coba_eq = """dv/dt = (-g_L * (v - v_rest) + g_L * delta_T * exp((v - v_thresh)/delta_T) - w + I)/C : volt (unless refractory)""" | ||
| adlif_coba_eq = ( | ||
| """dv/dt = (-g_L * (v - v_rest) - w + I)/C : volt (unless refractory)""" | ||
| ) | ||
|
|
||
| adlif_coba_eq = """dv/dt = (-g_L * (v - v_rest) - w + I)/C : volt (unless refractory)""" | ||
|
|
||
| network_eqs = """ | ||
| dw/dt = (-w + a * (v - v_rest))/tau_w : amp | ||
| dge/dt = -ge / tau_ge : siemens | ||
| dgi/dt = -gi / tau_gi : siemens | ||
| Ie = ge * (E_ge - v): amp | ||
| Ii = gi * (E_gi - v): amp | ||
| I = I_bias + Ie + Ii : amp | ||
| I_ext: amp | ||
rdgao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| I = I_bias + Ie + Ii + I_ext: amp | ||
| """ | ||
|
|
||
| ### get cell counts | ||
| N_pop, exc_prop = param_dict_net["N_pop"], param_dict_net["exc_prop"] | ||
| N_exc, N_inh = int(N_pop * exc_prop), int(N_pop * (1 - exc_prop)) | ||
|
|
||
| ### make neuron populations, set initial values and connect poisson inputs ### | ||
| # make adlif if delta_t is 0, otherwise adex | ||
| neuron_eq = ( | ||
|
|
@@ -422,6 +434,7 @@ def adaptive_exp_net_clustered(all_param_dict): | |
| p_out, | ||
| param_dict_net["order_clusters"], | ||
| ) | ||
| param_dict_net["membership"] = membership | ||
rdgao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # scale synaptic weight | ||
| Q_ge_out = param_dict_neuron_E["Q_ge"] | ||
|
|
@@ -491,6 +504,93 @@ def adaptive_exp_net_clustered(all_param_dict): | |
| ) | ||
| syn_i2i.connect("i!=j", p=param_dict_net["p_i2i"]) | ||
|
|
||
| ### Handle different input modes ### | ||
| if mode == 'default': #No input | ||
| stim_time_values = b2_inputs.DM_simple(all_param_dict,0,0) #Just change this to an input (pass in an stim array) | ||
|
||
| dt = param_dict_settings["dt"] | ||
| stim_timed_array = b2.TimedArray(stim_time_values * b2.amp, dt=dt) | ||
|
|
||
| # Define network operation to update I_ext | ||
| @b2.network_operation(dt=dt) | ||
| def update_test_input(t): | ||
| E_pop.I_ext = stim_timed_array(t) | ||
|
|
||
| elif mode == 'single': | ||
| # Check if custom input is provided in the parameter dictionary | ||
| custom_input = all_param_dict.get("custom_input", None) | ||
|
||
| if custom_input is not None: | ||
| stim_time_values = custom_input | ||
|
||
| else: | ||
| # Use default DM_simple if no custom input is provided | ||
| stim_time_values = b2_inputs.DM_simple(all_param_dict) | ||
|
|
||
| dt = param_dict_settings["dt"] | ||
| stim_timed_array = b2.TimedArray(stim_time_values * b2.amp, dt=dt) | ||
|
|
||
| # Define network operation to update I_ext | ||
| @b2.network_operation(dt=dt) | ||
| def update_test_input(t): | ||
| E_pop.I_ext = stim_timed_array(t) | ||
|
|
||
| elif mode == 'cluster': | ||
| # Determine if network has clusters | ||
| has_clusters = ( | ||
| "n_clusters" in param_dict_net.keys() | ||
| and param_dict_net["n_clusters"] >= 2 | ||
| and param_dict_net["R_pe2e"] != 1 | ||
| ) | ||
|
|
||
| if has_clusters: | ||
| n_clusters_original = int(param_dict_net["n_clusters"]) | ||
| stimulated_clusters_count = n_clusters_original | ||
|
|
||
| #Check if user defined number of clusters to stimulate | ||
| if stim_cluster is not None: | ||
|
||
| stimulated_clusters_count = stim_cluster | ||
|
||
| if stim_cluster > n_clusters_original: | ||
| stimulated_clusters_count = n_clusters_original | ||
| print(f"No. of clusters picked ({stim_cluster}) exceeds actual no. of clusters. Stimulating all {n_clusters_original} clusters instead.") | ||
| #Select number of clusters | ||
| selected_clusters = np.random.choice( | ||
| n_clusters_original, | ||
| stimulated_clusters_count, | ||
| replace=False | ||
| ) | ||
| cluster_lists = [[c] for c in selected_clusters] | ||
|
|
||
| # Generate cluster-specific inputs for selected clusters - see b2_inputs | ||
| if custom_cluster_input is not None: | ||
| stim_list = custom_cluster_input | ||
| _, weight_list = b2_inputs.cluster_specific_stim( | ||
|
||
| all_param_dict, | ||
| n_clusters=stimulated_clusters_count, | ||
| ) | ||
| else: | ||
| stim_list, weight_list = b2_inputs.cluster_specific_stim( | ||
| all_param_dict, | ||
| n_clusters=stimulated_clusters_count, | ||
| ) | ||
|
|
||
| # Create input configurations | ||
| input_configs = b2_inputs.get_input_configs( | ||
| cluster_lists, | ||
| stim_list, | ||
| weight_list, | ||
| ) | ||
| input_op = b2_inputs.create_input_operation(E_pop, input_configs, membership) | ||
| param_dict_net['input'] = stim_list | ||
|
||
| else: | ||
| # Fallback to test mode if network doesn't have clusters initially | ||
| print("Network does not have clusters. All neurons will receive the same DM_simple input ") | ||
| stim_time_values = b2_inputs.test_stim(all_param_dict) | ||
|
||
| dt = param_dict_settings["dt"] | ||
| stim_timed_array = b2.TimedArray(stim_time_values * b2.amp, dt=dt) | ||
|
|
||
| @b2.network_operation(dt=dt) | ||
| def update_cluster_fallback_input(t): | ||
| E_pop.I_ext = stim_timed_array(t) | ||
| param_dict_net['input'] = stim_timed_array | ||
|
|
||
| ### define monitors ### | ||
| rate_monitors, spike_monitors, trace_monitors = [], [], [] | ||
| rec_defs = param_dict_settings["record_defs"] | ||
|
|
@@ -510,12 +610,12 @@ def adaptive_exp_net_clustered(all_param_dict): | |
| # and later drop randomly before saving, otherwise | ||
| # recording only from first n neurons, which heavily overlap | ||
| # with those stimulated, and the first few clusters | ||
| rec_idx = np.arange(N_exc) | ||
| rec_idx = np.arange(N_exc) | ||
| else: | ||
| rec_idx = ( | ||
| np.arange(rec_defs[pop_name]["spikes"]) | ||
| np.arange(rec_defs[pop_name]["spikes"]) | ||
| if type(rec_defs[pop_name]["spikes"]) is int | ||
| else rec_defs[pop_name]["spikes"] | ||
| else rec_defs[pop_name]["spikes"] #Change param_settings.record_Defs to 2000 | ||
| ) | ||
| spike_monitors.append( | ||
| b2.SpikeMonitor(pop[rec_idx], name=pop_name + "_spikes") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
|
|
||
| def _filter_spikes_random(spike_trains, n_to_save): | ||
| """Filter a subset of spike trains randomly for saving.""" | ||
| np.random.seed(42) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I forgot if we talked about this but should this be reproduceably random, i.e., using the model
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess easiest is just to provide an optional seed with default value=42, otherwise can be provided, e.g., with the simulation random seed |
||
| record_subset = np.sort( | ||
| np.random.choice(len(spike_trains), n_to_save, replace=False) | ||
| ) | ||
|
|
@@ -33,13 +34,14 @@ def collect_spikes(net_collect, params_dict): | |
| sm.name.split("_")[0] | ||
| ]["spikes"] | ||
| n_to_save = pop_save_def if type(pop_save_def) == int else len(pop_save_def) | ||
| #n_to_save = len(spike_trains) | ||
| if n_to_save == len(spike_trains): | ||
| # recorded and to-be saved is the same length, go on a per usual | ||
| spike_dict[sm.name] = b2_interface._deunitize_spiketimes(spike_trains) | ||
| else: | ||
| # recorded more than necessary, subselect for saving | ||
| spike_dict[sm.name] = b2_interface._deunitize_spiketimes( | ||
| _filter_spikes_random(spike_trains, n_to_save) | ||
| _filter_spikes_random(spike_trains, n_to_save) # THIS IS WHERE THE NEURONS ARE RANDOMLY DROPPED BEFORE SAVING (AND PLOTTING??) | ||
|
||
| ) | ||
| return spike_dict | ||
|
|
||
|
|
@@ -674,3 +676,31 @@ def load_df_posteriors(path_dict): | |
| path_dict["root_path"] + path_dict["params_dict_analysis_updated"] | ||
| ) | ||
| return df_posterior_sims, posterior, params_dict_default | ||
|
|
||
| def sort_neurons(membership, sorting_method="cluster_identity"): | ||
| """ | ||
rdgao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Sort neurons based on the specified method. | ||
|
|
||
| Parameters: | ||
| membership (list/array of 2D arrays): Membership arrays for each simulation. | ||
| sorting_method (str): "cluster_identity" or "n_clusters". | ||
|
|
||
| Returns: | ||
| sorted_indices (list of arrays): Sorted indices for each simulation. | ||
| """ | ||
| sorted_indices = [] | ||
| #Sort by whether neurons are in one cluster or two clusters | ||
|
|
||
| if sorting_method == "cluster_identity": | ||
| # Sort by the first cluster identity | ||
| sorted_idx = np.argsort(membership[:, 0]) | ||
| sorted_indices.append(sorted_idx) | ||
| elif sorting_method == "n_clusters": | ||
| #Neurons in one cluster have the same values in both columns | ||
| single = np.where(membership[:,0] == membership[:,1]) | ||
| double = np.where(membership[:,0] != membership[:,1]) | ||
| sorted_indices.append(single) | ||
| sorted_indices.append(double) | ||
| else: | ||
| raise ValueError("Invalid sorting_method. Use 'cluster_identity' or 'n_clusters'.") | ||
| return sorted_indices | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -327,7 +327,6 @@ def _plot_raster_pretty( | |
| ax.set_ylabel("Raster", fontsize=fontsize) | ||
| return ax | ||
|
|
||
|
|
||
| def _plot_rates_pretty( | ||
| rates, | ||
| XL, | ||
|
|
@@ -584,3 +583,126 @@ def plot_corr_pv(pvals, ax, alpha_level=0.05, fmt="w*", ms=0.5): | |
| for j in range(pvals.shape[0]): | ||
| if pvals[i, j] < alpha_level: | ||
| ax.plot(j, i, fmt, ms=ms, alpha=1) | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so this is not being used anymore because you extracted the sorting function right? Then I think it can be removed |
||
|
|
||
| #Just use the default plotting function with the sorted spikes | ||
| ''' | ||
| def plot_raster( | ||
| spikes, | ||
| membership, | ||
| XL, | ||
| plotting_method="cluster_identity", | ||
| every_other=1, | ||
| ax=None, | ||
| fontsize=14, | ||
| plot_inh=False, | ||
| E_colors=None, | ||
| I_color="gray", | ||
| single_cluster_style="|", | ||
| double_cluster_style="x", | ||
| mew=0.5, | ||
| ms=1, | ||
| **plot_kwargs, | ||
| ): | ||
| """ | ||
| Plot raster plot with neurons sorted by cluster identity or number of clusters. | ||
|
|
||
| Parameters: | ||
| spikes (dict): Dictionary containing 'exc_spikes' and 'inh_spikes'. | ||
| membership: Array/list of 2D membership arrays (from params_net['membership']). | ||
| XL (list): X-axis limits. | ||
| plotting_method (str): "cluster_identity" or "n_clusters". | ||
| every_other (int): Plot every nth spike. | ||
| ax (matplotlib axis): Axis to plot on. | ||
| fontsize (int): Font size for labels. | ||
| plot_inh (bool): Whether to plot inhibitory spikes. | ||
| E_colors (list): Colors for excitatory clusters. | ||
| I_color (str): Color for inhibitory spikes. | ||
| single_cluster_style (str): Marker style for single-cluster neurons. | ||
| double_cluster_style (str): Marker style for two-cluster neurons. | ||
| mew (float): Marker edge width. | ||
| ms (float): Marker size. | ||
| """ | ||
| if ax is None: | ||
| ax = plt.axes() | ||
|
|
||
| exc_spikes = spikes["exc_spikes"] | ||
| inh_spikes = spikes.get("inh_spikes", {}) | ||
|
|
||
| if plotting_method == "cluster_identity": | ||
| # Sort by cluster identity | ||
| sorted_indices = data_utils.sort_neurons(membership, sorting_method='cluster_identity') | ||
| sorted_indices_list = sorted_indices[0].tolist() # Convert to list of Python integers | ||
| sorted_exc_spikes = {i: exc_spikes[idx] for i, idx in enumerate(sorted_indices_list)} | ||
| #exc_spikes_to_plot = sorted_exc_spikes.values() | ||
| elif plotting_method == "n_clusters": | ||
| # Sort by number of clusters | ||
| sorted_indices = data_utils.sort_neurons(membership, sorting_method='n_clusters') | ||
| sorted_exc_spikes_single = {i: exc_spikes[idx] for i, idx in enumerate(sorted_indices[0][0])} | ||
| sorted_exc_spikes_double = {i: exc_spikes[idx] for i, idx in enumerate(sorted_indices[1][0])} | ||
| #exc_spikes_to_plot.append(sorted_exc_spikes_single) | ||
| #exc_spikes_to_plot.append(sorted_exc_spikes_double) | ||
| else: | ||
| raise ValueError("Invalid plotting_method. Use 'cluster_identity' or 'n_clusters'.") | ||
|
|
||
| # Plot excitatory spikes, single cluster in blue and double cluster in red respectively | ||
| [ | ||
| ( | ||
| ax.plot( | ||
| v[::every_other], | ||
| i_v * np.ones_like(v[::every_other]), | ||
| single_cluster_style, | ||
| color='blue', | ||
| alpha=1, | ||
| ms=ms, | ||
| mew=mew, | ||
| ) | ||
| if len(v) > 0 | ||
| else None | ||
| ) | ||
| for i_v, (t,v) in enumerate(sorted_exc_spikes_single.items()) | ||
| ] | ||
| [ | ||
| ( | ||
| ax.plot( | ||
| v[::every_other], | ||
| (i_v+ len(sorted_indices[0][0])) * np.ones_like(v[::every_other]), | ||
| single_cluster_style, | ||
| color='red', | ||
| alpha=1, | ||
| ms=ms, | ||
| mew=mew, | ||
| ) | ||
| if len(v) > 0 | ||
| else None | ||
| ) | ||
| for i_v, (t,v) in enumerate(sorted_exc_spikes_double.items()) | ||
| ] | ||
|
|
||
| # Plot inhibitory spikes | ||
| if plot_inh: | ||
| [ | ||
| ( | ||
| ax.plot( | ||
| v[::every_other], | ||
| (i_v + len(sorted_indices[0][0]) + len(sorted_indices[1][0])) * np.ones_like(v[::every_other]), | ||
| "|", | ||
| color=I_color, | ||
| alpha=1, | ||
| ms=ms, | ||
| mew=mew, | ||
| ) | ||
| if len(v) > 0 | ||
| else None | ||
| ) | ||
| for i_v, v in enumerate(inh_spikes.values()) | ||
| ] | ||
|
|
||
| ax.set_xticks([]) | ||
| ax.set_yticks([]) | ||
| ax.spines.left.set_visible(False) | ||
| ax.spines.bottom.set_visible(False) | ||
| ax.set_xlim(XL) | ||
| ax.set_ylabel("Raster", fontsize=fontsize) | ||
| return ax | ||
| ''' | ||
Uh oh!
There was an error while loading. Please reload this page.