By Turing
Spiking Neural Networks
To use this notebook, either download it, or open it in Google Colab:
In [3]:
# Spiking Neural Network with STDP Learning - Debug Version
import numpy as np
import matplotlib.pyplot as plt
class LIFNeuron:
"""
Leaky Integrate-and-Fire Neuron with adjusted parameters for more activity
"""
def __init__(self, tau_m=10.0, v_rest=-70.0, v_reset=-65.0, v_thresh=-50.0,
refractory_period=1.0):
self.tau_m = tau_m # Decreased time constant for faster response
self.v_rest = v_rest
self.v_reset = v_reset
self.v_thresh = v_thresh
self.refractory_period = refractory_period # Shorter refractory period
self.v = v_rest
self.last_spike = -np.inf
self.I = 0
self.spike_count = 0 # For debugging
def step(self, t, dt):
if (t - self.last_spike) <= self.refractory_period:
self.v = self.v_reset
return 0
# More sensitive membrane potential update
dv = (-(self.v - self.v_rest) + self.I * 50) * (dt / self.tau_m) # Increased current sensitivity
self.v += dv
if self.v >= self.v_thresh:
self.last_spike = t
self.v = self.v_reset
self.spike_count += 1
return 1
return 0
class Synapse:
def __init__(self, pre_neuron, post_neuron, w_init=0.5):
self.pre_neuron = pre_neuron
self.post_neuron = post_neuron
self.weight = w_init
# Increased learning rates and adjusted time constants
self.tau_pre = 10.0 # Faster trace decay
self.tau_post = 10.0
self.A_plus = 0.2 # Much stronger LTP
self.A_minus = -0.21 # Slightly stronger LTD
self.w_max = 1.0
self.w_min = 0.0
self.trace_pre = 0.0
self.trace_post = 0.0
self.weight_changes = [] # Track all weight changes
def update(self, t, dt, pre_spike, post_spike):
# Decay traces
self.trace_pre *= np.exp(-dt / self.tau_pre)
self.trace_post *= np.exp(-dt / self.tau_post)
old_weight = self.weight
if pre_spike:
self.trace_pre += 1.0
dw = self.A_minus * self.trace_post
self.weight = np.clip(self.weight + dw, self.w_min, self.w_max)
if post_spike:
self.trace_post += 1.0
dw = self.A_plus * self.trace_pre
self.weight = np.clip(self.weight + dw, self.w_min, self.w_max)
if abs(old_weight - self.weight) > 0.001:
self.weight_changes.append((t, old_weight, self.weight))
class SNN:
def __init__(self, n_input, n_output):
self.input_neurons = [LIFNeuron() for _ in range(n_input)]
self.output_neurons = [LIFNeuron() for _ in range(n_output)]
# Initialize synapses with random weights
self.synapses = []
for pre in self.input_neurons:
for post in self.output_neurons:
w_init = np.random.uniform(0.4, 0.6) # Random initial weights
self.synapses.append(Synapse(pre, post, w_init))
def step(self, t, dt, input_currents):
for n in self.output_neurons:
n.I = 0
input_spikes = []
for n, I in zip(self.input_neurons, input_currents):
n.I = I
input_spikes.append(n.step(t, dt))
# Stronger synaptic transmission
for syn in self.synapses:
pre_idx = self.input_neurons.index(syn.pre_neuron)
if input_spikes[pre_idx]:
syn.post_neuron.I += 100 * syn.weight # Increased synaptic strength
output_spikes = []
for n in self.output_neurons:
output_spikes.append(n.step(t, dt))
for syn in self.synapses:
pre_idx = self.input_neurons.index(syn.pre_neuron)
post_idx = self.output_neurons.index(syn.post_neuron)
syn.update(t, dt, input_spikes[pre_idx], output_spikes[post_idx])
return input_spikes, output_spikes
def get_weights_matrix(self):
W = np.zeros((len(self.input_neurons), len(self.output_neurons)))
for syn in self.synapses:
i = self.input_neurons.index(syn.pre_neuron)
j = self.output_neurons.index(syn.post_neuron)
W[i, j] = syn.weight
return W
# Simulation parameters
T = 1000 # Simulation time
dt = 0.1 # Time step
time = np.arange(0, T, dt)
# Create network
n_input = 4
n_output = 3
network = SNN(n_input, n_output)
# Generate stronger input patterns
def generate_input_pattern(t, pattern_id=0):
if pattern_id == 0:
return [
30 * (1 + np.sin(2 * np.pi * t / 50)), # Faster oscillation
30 * (1 + np.sin(2 * np.pi * t / 50 + np.pi/2)),
10 * np.sin(2 * np.pi * t / 100),
10 * np.sin(2 * np.pi * t / 100 + np.pi/2)
]
else:
return [
10 * np.sin(2 * np.pi * t / 100),
10 * np.sin(2 * np.pi * t / 100 + np.pi/2),
30 * (1 + np.sin(2 * np.pi * t / 50)),
30 * (1 + np.sin(2 * np.pi * t / 50 + np.pi/2))
]
# Run simulation
input_spikes_history = []
output_spikes_history = []
weight_history = []
v_history = []
pattern_switch = T/2
for t in time:
pattern_id = 0 if t < pattern_switch else 1
input_currents = generate_input_pattern(t, pattern_id)
input_spikes, output_spikes = network.step(t, dt, input_currents)
input_spikes_history.append(input_spikes)
output_spikes_history.append(output_spikes)
v_history.append([n.v for n in network.output_neurons])
if int(t/dt) % 100 == 0:
weight_history.append(network.get_weights_matrix())
# Convert to numpy arrays
input_spikes_history = np.array(input_spikes_history)
output_spikes_history = np.array(output_spikes_history)
v_history = np.array(v_history)
weight_history = np.array(weight_history)
# Plotting
plt.figure(figsize=(15, 12))
# Plot membrane potentials
plt.subplot(411)
for i in range(n_output):
plt.plot(time, v_history[:, i], label=f'Neuron {i+1}')
plt.ylabel('Membrane Potential (mV)')
plt.title('Output Neuron Membrane Potentials')
plt.legend()
plt.grid(True)
# Plot input spike raster
plt.subplot(412)
for i in range(n_input):
spike_times = time[input_spikes_history[:, i] > 0]
plt.plot(spike_times, [i] * len(spike_times), '|', markersize=10, label=f'Input {i+1}')
plt.ylabel('Input Neuron')
plt.title('Input Spike Raster')
plt.legend()
plt.grid(True)
# Plot output spike raster
plt.subplot(413)
for i in range(n_output):
spike_times = time[output_spikes_history[:, i] > 0]
plt.plot(spike_times, [i] * len(spike_times), '|', markersize=10, label=f'Output {i+1}')
plt.ylabel('Output Neuron')
plt.title('Output Spike Raster')
plt.legend()
plt.grid(True)
# Plot weight evolution with distinct colors
plt.subplot(414)
colors = ['red', 'blue', 'green', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan', 'magenta', 'yellow']
for idx, (i, j) in enumerate([(i, j) for i in range(n_input) for j in range(n_output)]):
plt.plot(np.arange(len(weight_history)) * 100 * dt, weight_history[:, i, j],
label=f'In{i+1}→Out{j+1}',
color=colors[idx % len(colors)],
marker='.',
markersize=2,
linewidth=1)
plt.xlabel('Time (ms)')
plt.ylabel('Weight')
plt.title('Synaptic Weight Evolution')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True)
plt.tight_layout()
plt.show()
# Plot detailed weight matrix
final_weights = network.get_weights_matrix()
plt.figure(figsize=(10, 8))
im = plt.imshow(final_weights, cmap='coolwarm', aspect='auto', vmin=0, vmax=1)
plt.colorbar(im, label='Weight')
plt.xlabel('Output Neuron')
plt.ylabel('Input Neuron')
plt.title('Final Synaptic Weight Matrix')
# Add text annotations
for i in range(final_weights.shape[0]):
for j in range(final_weights.shape[1]):
plt.text(j, i, f'{final_weights[i,j]:.3f}',
ha='center', va='center',
color='white' if final_weights[i,j] > 0.5 else 'black')
plt.tight_layout()
plt.show()
# Print statistics
print("\nNeuron Spiking Statistics:")
print("Input neurons:")
for i, neuron in enumerate(network.input_neurons):
print(f"Input {i+1}: {neuron.spike_count} spikes")
print("\nOutput neurons:")
for i, neuron in enumerate(network.output_neurons):
print(f"Output {i+1}: {neuron.spike_count} spikes")
print("\nWeight Changes Summary:")
initial_weights = weight_history[0]
final_weights = weight_history[-1]
for i in range(n_input):
for j in range(n_output):
change = final_weights[i,j] - initial_weights[i,j]
print(f"In{i+1}→Out{j+1}: {initial_weights[i,j]:.3f} → {final_weights[i,j]:.3f} (Δ = {change:.3f})")
# Print significant weight changes from synapses
print("\nDetailed Weight Change Events:")
for syn in network.synapses:
if len(syn.weight_changes) > 0:
i = network.input_neurons.index(syn.pre_neuron)
j = network.output_neurons.index(syn.post_neuron)
print(f"\nSynapse In{i+1}→Out{j+1} changes:")
for t, old_w, new_w in syn.weight_changes[:5]: # Show first 5 changes
print(f"Time {t:.1f}ms: {old_w:.3f} → {new_w:.3f} (Δ = {new_w - old_w:.3f})")
Neuron Spiking Statistics: Input neurons: Input 1: 504 spikes Input 2: 500 spikes Input 3: 503 spikes Input 4: 501 spikes Output neurons: Output 1: 861 spikes Output 2: 861 spikes Output 3: 861 spikes Weight Changes Summary: In1→Out1: 0.550 → 1.000 (Δ = 0.450) In1→Out2: 0.535 → 1.000 (Δ = 0.465) In1→Out3: 0.524 → 1.000 (Δ = 0.476) In2→Out1: 0.671 → 0.630 (Δ = -0.041) In2→Out2: 0.749 → 0.630 (Δ = -0.119) In2→Out3: 0.695 → 0.630 (Δ = -0.064) In3→Out1: 0.551 → 1.000 (Δ = 0.449) In3→Out2: 0.495 → 1.000 (Δ = 0.505) In3→Out3: 0.421 → 1.000 (Δ = 0.579) In4→Out1: 0.506 → 1.000 (Δ = 0.494) In4→Out2: 0.539 → 1.000 (Δ = 0.461) In4→Out3: 0.523 → 1.000 (Δ = 0.477) Detailed Weight Change Events: Synapse In1→Out1 changes: Time 0.1ms: 0.550 → 0.342 (Δ = -0.208) Time 1.1ms: 0.342 → 0.523 (Δ = 0.181) Time 1.2ms: 0.523 → 0.129 (Δ = -0.394) Time 2.2ms: 0.129 → 0.472 (Δ = 0.343) Time 2.3ms: 0.472 → 0.000 (Δ = -0.472) Synapse In1→Out2 changes: Time 0.1ms: 0.535 → 0.327 (Δ = -0.208) Time 1.1ms: 0.327 → 0.508 (Δ = 0.181) Time 1.2ms: 0.508 → 0.114 (Δ = -0.394) Time 2.2ms: 0.114 → 0.457 (Δ = 0.343) Time 2.3ms: 0.457 → 0.000 (Δ = -0.457) Synapse In1→Out3 changes: Time 0.1ms: 0.524 → 0.316 (Δ = -0.208) Time 1.1ms: 0.316 → 0.497 (Δ = 0.181) Time 1.2ms: 0.497 → 0.102 (Δ = -0.394) Time 2.2ms: 0.102 → 0.445 (Δ = 0.343) Time 2.3ms: 0.445 → 0.000 (Δ = -0.445) Synapse In2→Out1 changes: Time 0.0ms: 0.471 → 0.671 (Δ = 0.200) Time 1.1ms: 0.671 → 0.862 (Δ = 0.191) Time 2.2ms: 0.862 → 1.000 (Δ = 0.138) Time 15.6ms: 1.000 → 0.000 (Δ = -1.000) Time 16.6ms: 0.000 → 1.000 (Δ = 1.000) Synapse In2→Out2 changes: Time 0.0ms: 0.549 → 0.749 (Δ = 0.200) Time 1.1ms: 0.749 → 0.940 (Δ = 0.191) Time 2.2ms: 0.940 → 1.000 (Δ = 0.060) Time 15.6ms: 1.000 → 0.000 (Δ = -1.000) Time 16.6ms: 0.000 → 1.000 (Δ = 1.000) Synapse In2→Out3 changes: Time 0.0ms: 0.495 → 0.695 (Δ = 0.200) Time 1.1ms: 0.695 → 0.886 (Δ = 0.191) Time 2.2ms: 0.886 → 1.000 (Δ = 0.114) Time 15.6ms: 1.000 → 0.000 (Δ = -1.000) Time 16.6ms: 0.000 → 1.000 (Δ = 1.000) Synapse In3→Out1 changes: Time 3.8ms: 0.551 → 0.000 (Δ = -0.551) Time 4.4ms: 0.000 → 0.188 (Δ = 0.188) Time 5.5ms: 0.188 → 0.357 (Δ = 0.169) Time 5.7ms: 0.357 → 0.000 (Δ = -0.357) Time 6.6ms: 0.000 → 0.334 (Δ = 0.334) Synapse In3→Out2 changes: Time 3.8ms: 0.495 → 0.000 (Δ = -0.495) Time 4.4ms: 0.000 → 0.188 (Δ = 0.188) Time 5.5ms: 0.188 → 0.357 (Δ = 0.169) Time 5.7ms: 0.357 → 0.000 (Δ = -0.357) Time 6.6ms: 0.000 → 0.334 (Δ = 0.334) Synapse In3→Out3 changes: Time 3.8ms: 0.421 → 0.000 (Δ = -0.421) Time 4.4ms: 0.000 → 0.188 (Δ = 0.188) Time 5.5ms: 0.188 → 0.357 (Δ = 0.169) Time 5.7ms: 0.357 → 0.000 (Δ = -0.357) Time 6.6ms: 0.000 → 0.334 (Δ = 0.334) Synapse In4→Out1 changes: Time 0.4ms: 0.506 → 0.304 (Δ = -0.202) Time 1.1ms: 0.304 → 0.491 (Δ = 0.186) Time 1.8ms: 0.491 → 0.120 (Δ = -0.371) Time 2.2ms: 0.120 → 0.479 (Δ = 0.359) Time 3.1ms: 0.479 → 0.000 (Δ = -0.479) Synapse In4→Out2 changes: Time 0.4ms: 0.539 → 0.338 (Δ = -0.202) Time 1.1ms: 0.338 → 0.524 (Δ = 0.186) Time 1.8ms: 0.524 → 0.153 (Δ = -0.371) Time 2.2ms: 0.153 → 0.512 (Δ = 0.359) Time 3.1ms: 0.512 → 0.000 (Δ = -0.512) Synapse In4→Out3 changes: Time 0.4ms: 0.523 → 0.322 (Δ = -0.202) Time 1.1ms: 0.322 → 0.508 (Δ = 0.186) Time 1.8ms: 0.508 → 0.137 (Δ = -0.371) Time 2.2ms: 0.137 → 0.496 (Δ = 0.359) Time 3.1ms: 0.496 → 0.000 (Δ = -0.496)