-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtesting.py
executable file
·75 lines (58 loc) · 1.99 KB
/
testing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
Deep Spiking Convolutional Neural Network
with STDP Learning Rule on MNIST data
______________________________________________________
Research Internship
Technical University Munich
Creator: Sven Gronauer
Date: February 2018
"""
import argparse
# import logging
import numpy as np
# import scipy
import matplotlib.pyplot as plt
import pylab
# import seaborn as sns
# sns.set_style("dark")
import spynnaker8 as s
import pyNN.utility.plotting as plot
import pickle
import sklearn
from SpikingConvNet import algorithms, classes, utils
from SpikingConvNet.parameters import *
def testing_function(rc, model):
    """Apply the test set to the trained network and evaluate the classifier.

    Steps:
      1. Set up the simulator with ``TIMESTEP``, build
         ``classes.Spinnaker_Network(rc, model)``, run it for
         ``network.total_simtime`` and retrieve the recorded data.
      2. Best effort: plot the spike activity and pickle the retrieved
         spikes to ``MODEL_PATH + "spikes_layer_1.p"``. Failures are logged
         via ``rc.logging.error`` and do not abort the evaluation.
      3. Always call ``s.end()`` once the simulation phase is over, even if
         it raised.
      4. Turn the spikes into classifier features with
         ``algorithms.spikes_for_classifier`` and score them with
         ``model.classifier.predict`` against ``network.y_test``.
      5. Plot the confusion matrix (best effort) and the input patterns,
         then call ``plt.show()``.

    Args:
        rc: Run configuration. Must expose ``rc.logging`` with an ``error``
            method. It is also forwarded to the network, plotting and
            feature-extraction helpers.
        model: Trained model. Must expose ``tensors`` (the last entry is
            passed as the layer description to the helpers) and
            ``classifier`` (whose ``predict`` returns a
            ``(score, confusion_matrix)`` pair).

    Returns:
        None.
    """
    s.setup(timestep=TIMESTEP)
    try:
        network = classes.Spinnaker_Network(rc, model)
        network.print_parameters()
        s.run(network.total_simtime)
        spikes_layer, _voltage = network.retrieve_data()
        # Best effort: plot spike activity of the layer as a heatmap and
        # persist the raw spikes; a failure here must not abort evaluation.
        try:
            utils.plot_spike_activity(rc, spikes_layer, model.tensors[-1])
            with open(MODEL_PATH + "spikes_layer_1.p", "wb") as spikes_file:
                pickle.dump(spikes_layer, spikes_file)
        except Exception as err:
            rc.logging.error(
                "Capturing Spike Activity Failed: {}".format(err))
    finally:
        # Release the simulator even if building, running or retrieving
        # data raised; otherwise the backend is left allocated.
        s.end()
    network.print_parameters()

    # SVM: classify the spike-derived features of the test set.
    X_test = algorithms.spikes_for_classifier(rc,
                                              model.tensors[-1],
                                              spikes_layer)
    print("X_test")
    print(X_test)
    print("Maximum X_test")
    print(np.max(X_test))
    _score, confusion_matrix = model.classifier.predict(X_test,
                                                        network.y_test)
    try:
        utils.plot_confusion_matrix(rc, confusion_matrix)
    except Exception as err:
        rc.logging.error(
            "Could not plot confusion matrix: {}".format(err))

    # Display plots.
    # NOTE(review): "plot_heatpmap" looks like a typo of "plot_heatmap";
    # kept as-is because it must match the name defined in utils — verify.
    utils.plot_heatpmap(rc, network.images, title="Input Patterns")
    plt.show()