I have built a NEST simulation for a binary dataset with two classes, normal and anomalous. I designed the model so that it learns to produce no spikes for normal data but does produce spikes when it sees anomalous data, and this model works fine. Now I am trying to move forward and apply it to an image dataset: I am using the simple MNIST data, have taken only the images of one digit, and am trying to build the same kind of architecture, where the network should produce almost no spikes when it sees the digit one and should produce spikes for all other digits. However, I am facing issues: the model is not able to produce enough spikes, and when I add more excitatory connections, the number of spikes does not decrease with each epoch. Can you please suggest a better approach, or some improvements to my current approach?
# loading dataset
from sklearn.datasets import fetch_openml
import numpy as np
def preprocess_mnist_digit(digit=1, num_samples=None, scale=4.0):
    """
    Load the MNIST dataset and keep only the images of a single digit.

    Pixel values are divided by ``scale`` (4 by default) and every image is
    returned as a flat vector. In this script the scaled pixel values are used
    directly as Poisson rates (Hz) of the input generators.

    Args:
        digit (int): The digit to keep (e.g. 1 for digit '1').
        num_samples (int or None): If given, keep only the first
            ``num_samples`` matching images; ``None`` (default) keeps all.
        scale (float): Divisor applied to the raw pixel values (default 4.0).

    Returns:
        flattened_images (np.ndarray): Float array, one flattened image per row.
        labels (np.ndarray): Integer labels of the kept images (all == digit).
    """
    # as_frame=False returns plain NumPy arrays whatever the scikit-learn
    # version's default is, so the positional indexing below always applies.
    mnist = fetch_openml('mnist_784', version=1, as_frame=False)
    data = mnist['data']
    labels = np.asarray(mnist['target']).astype(int)  # Targets arrive as strings
    # Positions of the samples whose label matches the requested digit
    digit_indices = np.flatnonzero(labels == digit)
    if num_samples is not None:
        digit_indices = digit_indices[:num_samples]
    # Convert only the selected rows to float before scaling
    filtered_images = np.asarray(data[digit_indices], dtype=float) / scale
    filtered_labels = labels[digit_indices]
    return filtered_images, filtered_labels
# Keep only the MNIST samples of digit '1' -- the class the network should
# learn to stay (almost) silent for.
images, labels = preprocess_mnist_digit(1)
# building simulation
nest.ResetKernel()
# Simulation parameters
dt = 0.01  # Simulation time step in ms
nest.SetKernelStatus({"resolution": dt})
N = 784  # Number of neurons per layer: one per flattened MNIST pixel
images = images[:10]  # Only the first 10 images are used for training
gmax = 3  # Maximum |weight| of the STDP synapses (mV jump for iaf_psc_delta)
neuron_params = {
    "tau_m": 140.0,  # Membrane time constant in ms
    "E_L": -70.6,  # Resting potential (in mV)
    "V_th": -40.4,  # Threshold potential (in mV)
    "V_reset": -70.6,  # Reset potential (in mV)
    # Initial membrane potential (absolute, in mV). Start at rest: the previous
    # value 0.0 mV lies above V_th, so every hidden neuron and the output
    # neuron fired a spurious spike at the start of the simulation.
    "V_m": -70.6,
}
# NOTE(review): rough estimate of why hidden neurons rarely fire. Each hidden
# neuron gets excitation from a single input pixel whose rate is at most
# 255 / 4 = 63.75 Hz, so its mean depolarisation is about
# gmax * rate * tau_m = 3 mV * 63.75 Hz * 0.14 s ~= 27 mV -- below the
# V_th - E_L = 30.2 mV needed to reach threshold, before lateral inhibition is
# even counted. Raising gmax or the input rates, or lowering V_th, would help.

# Input -> hidden excitatory STDP synapses. alpha=1 with mu_plus = mu_minus = 0
# is the additive rule; the negative lambda makes it anti-Hebbian (causal
# pre -> post pairings weaken the synapse).
# NOTE(review): NEST's stdp_synapse caps facilitation at Wmax and depression
# at 0, which assumes lambda > 0; with lambda < 0 those bounds sit on the wrong
# side, so weights may leave [0, Wmax] -- verify against the stdp_synapse docs.
nest.CopyModel(
    "stdp_synapse",
    "stdp_synapse_exc",
    {"weight": gmax, "alpha": 1, "Wmax": gmax, "tau_plus": 20.0,
     "lambda": -0.0005, "mu_minus": 0, "mu_plus": 0},
)
# Input -> hidden inhibitory STDP synapses; the N-1 inhibitory inputs of one
# hidden neuron start with a combined weight of -gmax.
nest.CopyModel(
    "stdp_synapse",
    "stdp_synapse_inh",
    {"weight": -gmax / (N - 1), "alpha": 1, "Wmax": -gmax, "tau_plus": 20.0,
     "lambda": 0.005, "mu_minus": 0, "mu_plus": 0},
)
G_spike_gen = nest.Create("poisson_generator", N)  # One Poisson source per pixel
G_input = nest.Create("parrot_neuron", N)  # Repeats generator spikes; STDP synapses start here
input_spike_recorder = nest.Create("spike_recorder", N)
nest.Connect(G_input, input_spike_recorder, "one_to_one")
G_hidden = nest.Create("iaf_psc_delta", N, params=neuron_params)
G_output = nest.Create("iaf_psc_delta", 1, params=neuron_params)
hidden_spike_recorder = nest.Create("spike_recorder", N)
nest.Connect(G_hidden, hidden_spike_recorder, "one_to_one")
S_in = nest.Connect(G_spike_gen, G_input, "one_to_one")
# Connect input to hidden layer with excitatory STDP synapses (pixel k -> hidden k)
S_exc = nest.Connect(G_input, G_hidden, "one_to_one", syn_spec="stdp_synapse_exc")
# Retrieve the neuron IDs for G_input and G_hidden
input_ids = G_input.tolist()
hidden_ids = G_hidden.tolist()
# Get the existing excitatory connections and extract their source-target ID pairs
exc_connections = nest.GetConnections(G_input, G_hidden)
exc_connection_pairs = set(zip(exc_connections.get("source"), exc_connections.get("target")))
# Every input that is not already connected excitatorily to a hidden neuron
# inhibits it (lateral inhibition). All inhibitory sources of one hidden neuron
# are connected in a single all_to_all call; a per-pair loop would need
# N * (N - 1) = 613,872 separate nest.Connect calls for the same connections.
for h in hidden_ids:
    # input_ids is ascending, so the filtered list stays sorted and unique,
    # as nest.NodeCollection requires.
    inh_sources = [i for i in input_ids if (i, h) not in exc_connection_pairs]
    if inh_sources:
        nest.Connect(
            nest.NodeCollection(inh_sources),
            nest.NodeCollection([h]),
            "all_to_all",
            syn_spec="stdp_synapse_inh",
        )
# Hidden -> output with static synapses.
# NOTE(review): if all N hidden neurons spike together the output receives
# 40.4 mV, whereas reaching threshold only needs V_th - E_L = 30.2 mV --
# confirm which value was intended.
S_out = nest.Connect(G_hidden, G_output, "all_to_all", syn_spec={"weight": 40.4 / N})
# training function
def _print_hidden_spikes(t_from):
    """Print, per hidden neuron, the spikes recorded strictly after ``t_from`` ms."""
    all_events = nest.GetStatus(hidden_spike_recorder, "events")
    for i, events in enumerate(all_events):
        times = np.asarray(events["times"])
        times = times[times > t_from]
        if len(times) != 0:
            print("Number of spikes in hidden_spike_recorder", i, ":", len(times))
            print("Times of spikes in hidden_spike_recorder", i, ":", times)
            print("\n")


def train_net(images, spike_tol=0, sim_time=500.0, max_epochs=None, rest_time=0.0):
    """Present ``images`` repeatedly until the output neuron is (nearly) silent.

    Each epoch drives the Poisson generators with every image in turn (pixel
    value -> rate in Hz) for ``sim_time`` ms, prints the hidden-layer spikes
    produced by that image, and then counts the output spikes emitted during
    that epoch only. Training stops once the per-epoch count is <= ``spike_tol``.

    Uses the module-level NEST objects ``G_spike_gen``, ``G_output`` and
    ``hidden_spike_recorder``.

    Args:
        images: Iterable of flattened images, each with at least
            ``len(G_spike_gen)`` values.
        spike_tol: Largest number of output spikes per epoch regarded as silent.
        sim_time: Presentation time per image in ms (default 500.0).
        max_epochs: Optional cap on the number of epochs; ``None`` (default)
            keeps training until ``spike_tol`` is met.
        rest_time: Optional silent period in ms after each image (all rates 0)
            so membranes relax before the next image; 0.0 (default) presents
            images back to back.

    Returns:
        (voltmeter, spike_recorder) attached to the output neuron; both keep
        the events of all epochs.
    """
    output_spike_recorder = nest.Create("spike_recorder")
    output_membrane_recorder = nest.Create("voltmeter", 1)
    nest.Connect(G_output, output_spike_recorder)  # Record output layer spikes
    nest.Connect(output_membrane_recorder, G_output, "one_to_one")

    n_inputs = len(G_spike_gen)
    # Spikes stamped exactly at a window's start time were emitted during the
    # previous Simulate() call; half a step of margin excludes them robustly.
    half_step = 0.5 * nest.GetKernelStatus("resolution")

    epoch = 0
    while max_epochs is None or epoch < max_epochs:
        epoch += 1
        epoch_start = nest.GetKernelStatus("biological_time")
        for image in images:
            # One vectorised call sets every generator's rate from its pixel.
            G_spike_gen.set(rate=[float(image[i]) for i in range(n_inputs)])
            image_start = nest.GetKernelStatus("biological_time")
            nest.Simulate(sim_time)
            _print_hidden_spikes(image_start + half_step)
            if rest_time > 0:
                G_spike_gen.set(rate=0.0)
                nest.Simulate(rest_time)
        # The recorder accumulates events over all epochs; counting its whole
        # history can never decrease, so only this epoch's spikes are counted.
        out_times = np.asarray(output_spike_recorder.get("events")["times"])
        output_spikes_num = int(np.count_nonzero(out_times > epoch_start + half_step))
        print("# of output spikes: " + str(output_spikes_num))
        if output_spikes_num <= spike_tol:
            print("Training finished")
            break
    else:
        # Only reachable when max_epochs is set and spike_tol was never met.
        print("Stopped after max_epochs without reaching spike_tol")
    return output_membrane_recorder, output_spike_recorder
# Train until the output neuron emits at most 2 spikes in one epoch.
v, spk = train_net(images, spike_tol=2)