I have built a NEST simulation for a binary dataset with two classes, normal and anomalous. I designed the model so that it learns to produce no spikes for normal data but produces spikes when it sees anomalous data, and this model works fine. Now I am trying to move forward and apply the same idea to an image dataset. I am using plain MNIST, I have taken only the images of one digit, and I am trying to build the same kind of architecture: it should produce almost no spikes when it sees the digit one, and it should produce spikes for other digits. However, I am facing issues. The model is not able to produce enough spikes, and when I create more excitatory connections, the number of spikes does not decrease with each epoch. Can you please suggest a better approach, or some improvements to my current approach?
import nest
import numpy as np
from sklearn.datasets import fetch_openml  # loading dataset
def preprocess_mnist_digit(digit=1, num_samples=None, scale=4.0):
    """Load MNIST and keep only the images of a single digit.

    Each image is flattened (the OpenML ``mnist_784`` dataset stores one
    784-value row per image), and its pixel values are divided by ``scale``.
    The training script feeds these scaled values directly to the Poisson
    generators as firing rates.

    Args:
        digit (int): The digit to keep (e.g. 1 for digit '1').
        num_samples (int | None): Keep at most this many images (the first
            ones in dataset order). ``None`` (default) keeps all of them.
        scale (float): Divisor applied to the raw pixel values (default 4.0,
            the value previously hard-coded).

    Returns:
        filtered_images (np.ndarray): Float array, one flattened image per row.
        filtered_labels (pandas.Series): Integer labels of the kept images
            (all equal to ``digit``).
    """
    # as_frame=True makes the DataFrame/Series return type explicit; the
    # positional .iloc access below depends on it.
    mnist = fetch_openml('mnist_784', version=1, as_frame=True)
    data = mnist['data']
    labels = mnist['target'].astype(int)  # targets are strings; convert to int

    # Row positions of the requested digit; slicing with None keeps them all.
    digit_indices = np.where(labels == digit)[0][:num_samples]

    # Select images and labels by the same positions.
    filtered_images = data.iloc[digit_indices].to_numpy(dtype=float) / scale
    filtered_labels = labels.iloc[digit_indices]

    return filtered_images, filtered_labels
# Load only images of the digit '1'
images, labels = preprocess_mnist_digit(digit=1)

# building simulation
nest.ResetKernel()

# Simulation parameters
dt = 0.01  # Simulation time step in ms
nest.SetKernelStatus({"resolution": dt})

N = 784  # Neurons per input / hidden layer (one per MNIST pixel)
images = images[:10]  # train on the first 10 images only
gmax = 3  # initial and maximum excitatory weight

neuron_params = {
    "tau_m": 140.0,  # Membrane time constant in ms
    "E_L": -70.6,  # Resting potential (in mV)
    "V_th": -40.4,  # Threshold potential (in mV)
    "V_reset": -70.6,  # Reset potential (in mV)
    # Start at rest. The previous value of 0.0 mV is above V_th, so every
    # hidden and output neuron fired a spurious spike in the first time step.
    "V_m": -70.6,
}

# Excitatory input->hidden plasticity.
# NOTE(review): lambda < 0 flips the sign of the additive STDP update
# (pre-before-post pairs then weaken the synapse). This is presumably
# intentional, so that repeatedly presented inputs lose their drive. Confirm.
nest.CopyModel(
    "stdp_synapse",
    "stdp_synapse_exc",
    {"weight": gmax, "alpha": 1, "Wmax": gmax, "tau_plus": 20.0,
     "lambda": -0.0005, "mu_minus": 0, "mu_plus": 0},
)
# Inhibitory input->hidden plasticity. Each hidden neuron receives N-1 of
# these synapses, so its summed initial inhibitory weight is -gmax.
nest.CopyModel(
    "stdp_synapse",
    "stdp_synapse_inh",
    {"weight": -gmax / (N - 1), "alpha": 1, "Wmax": -gmax, "tau_plus": 20.0,
     "lambda": 0.005, "mu_minus": 0, "mu_plus": 0},
)

# Poisson generators drive parrot neurons one-to-one. The parrots repeat the
# generator spikes and act as the presynaptic neurons of the STDP synapses.
G_spike_gen = nest.Create("poisson_generator", N)
G_input = nest.Create("parrot_neuron", N)

input_spike_recorder = nest.Create("spike_recorder", N)
nest.Connect(G_input, input_spike_recorder, "one_to_one")

G_hidden = nest.Create("iaf_psc_delta", N, params=neuron_params)
G_output = nest.Create("iaf_psc_delta", 1, params=neuron_params)

hidden_spike_recorder = nest.Create("spike_recorder", N)
nest.Connect(G_hidden, hidden_spike_recorder, "one_to_one")

# nest.Connect returns None unless the SynapseCollection is requested.
# Request it so S_in / S_exc / S_out can be used to inspect weights.
S_in = nest.Connect(G_spike_gen, G_input, "one_to_one",
                    return_synapsecollection=True)

# Input pixel k excites hidden neuron k through a plastic synapse.
S_exc = nest.Connect(G_input, G_hidden, "one_to_one",
                     syn_spec="stdp_synapse_exc",
                     return_synapsecollection=True)

# Input pixel k inhibits every hidden neuron except neuron k. Two all_to_all
# calls per input (targets below and above index k) create the same
# N*(N-1) connections as connecting each pair individually. That takes
# 2*N Connect calls instead of ~614k.
for k in range(N):
    if k > 0:
        nest.Connect(G_input[k], G_hidden[:k], "all_to_all",
                     syn_spec="stdp_synapse_inh")
    if k < N - 1:
        nest.Connect(G_input[k], G_hidden[k + 1:], "all_to_all",
                     syn_spec="stdp_synapse_inh")

# Every hidden neuron drives the single output neuron.
# NOTE(review): each hidden spike raises the output V_m by 40.4/N ~= 0.05 mV,
# while V_th - E_L = 30.2 mV. One output spike therefore needs roughly 590
# near-coincident hidden spikes (more if they are spread out, because of the
# 140 ms leak). This is a likely cause of "not enough output spikes".
# Consider deriving this weight from (V_th - E_L) and the expected hidden
# activity instead.
S_out = nest.Connect(G_hidden, G_output, "all_to_all",
                     syn_spec={"weight": 40.4 / N},
                     return_synapsecollection=True)
# training function
def _report_hidden_spikes(since):
    """Print, per hidden neuron, the spikes it emitted after time ``since`` (ms)."""
    # One GetStatus call for all N recorders instead of one call per recorder.
    all_events = nest.GetStatus(hidden_spike_recorder, "events")
    for i, events in enumerate(all_events):
        times = np.asarray(events["times"])
        times = times[times > since]  # recorders keep every earlier image too
        if times.size:
            print("Number of spikes in hidden_spike_recorder", i, ":", times.size)
            print("Times of spikes in hidden_spike_recorder", i, ":", times)
            print("\n")


def train_net(images, spike_tol=0, *, sim_time=500.0, max_epochs=None,
              verbose=True):
    """Present ``images`` repeatedly until the output neuron stays (almost) silent.

    Each epoch presents every image once. For each image, the Poisson
    generator rates are set to the image's pixel values and the network is
    simulated for ``sim_time`` ms. Training stops when the output neuron fires
    at most ``spike_tol`` spikes within a single epoch.

    Args:
        images: Sequence of 1-D arrays with one rate (Hz) per Poisson generator.
        spike_tol: Maximum number of output spikes per epoch that counts as
            converged.
        sim_time: Presentation time per image in ms (previously fixed at 500).
        max_epochs: Stop after this many epochs even without convergence.
            ``None`` (default) loops until convergence, as before.
        verbose: Print the per-image hidden-layer spike report.

    Returns:
        ``(voltmeter, spike_recorder)`` attached to ``G_output``. Both hold the
        full recording of all epochs.
    """
    output_spike_recorder = nest.Create("spike_recorder")
    output_membrane_recorder = nest.Create("voltmeter", 1)
    nest.Connect(G_output, output_spike_recorder)  # Record output layer spikes
    nest.Connect(output_membrane_recorder, G_output, "one_to_one")

    epoch = 0
    while max_epochs is None or epoch < max_epochs:
        epoch += 1
        # Recorders accumulate events over every Simulate() call. Counting all
        # recorded spikes made the total monotonically non-decreasing, so the
        # loop could never stop once spike_tol was exceeded. Only spikes after
        # this timestamp belong to the current epoch.
        epoch_start = nest.GetKernelStatus("biological_time")
        for image in images:
            # One vectorized set instead of one SetStatus per generator.
            G_spike_gen.set(rate=np.asarray(image, dtype=float).tolist())
            image_start = nest.GetKernelStatus("biological_time")
            nest.Simulate(sim_time)
            if verbose:
                _report_hidden_spikes(image_start)

        out_times = np.asarray(output_spike_recorder.get("events")["times"])
        output_spikes_num = int(np.count_nonzero(out_times > epoch_start))
        print("# of output spikes: " + str(output_spikes_num))
        if output_spikes_num <= spike_tol:
            print("Training finished")
            break
    else:
        print("Stopped after max_epochs without reaching spike_tol")

    return output_membrane_recorder, output_spike_recorder
# Train until one epoch yields at most 2 output spikes. v is the output
# neuron's voltmeter and spk its spike recorder.
v, spk = train_net(images, spike_tol=2)