-
Notifications
You must be signed in to change notification settings - Fork 0
/
HopfieldUtils.py
158 lines (129 loc) · 6.36 KB
/
HopfieldUtils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import datetime
from typing import List
import matplotlib.pyplot as plt
import numpy as np
import json
def plotSingleTaskStability(taskPatternStability:np.ndarray, taskEpochStart:int, title:str=None,
        epochMarkers:str=None, legend:List[str]=None, figsize=(8,6), fileName:str=None):
    """
    Plot the task pattern stability of one task over many epochs.

    Only epochs from taskEpochStart onwards are drawn, so the task is not shown
    before it starts being learned.

    Args:
        taskPatternStability (np.ndarray): A numpy array of dimension (numEpochs)
        taskEpochStart (int): The start epoch of the task.
        title (str or None, optional): Title of the graph. Defaults to None.
        epochMarkers (str, optional): The marker to use for epoch data points. Defaults to None
        legend (List[str] or None, optional): If not None, legend[0] is used as the label of the
            task line and a legend is drawn. If None, no legend is drawn. Defaults to None
        figsize (Tuple[int, int]): The size of the figure. Defaults to (8, 6).
        fileName (str, optional): If not None, saves the plot to the file name instead of
            showing it. Defaults to None.
    """
    xRange = np.arange(taskPatternStability.shape[0])
    fig = plt.figure(figsize=figsize)
    label = legend[0] if legend is not None else None
    plt.plot(xRange[taskEpochStart:], taskPatternStability[taskEpochStart:], marker=epochMarkers, label=label)
    if legend is not None:
        plt.legend(bbox_to_anchor=(1.04, 0.5), loc='center left')
    # Small margins so points at 0 and at the maximum are not clipped by the axes
    plt.ylim(-0.05, np.max(taskPatternStability)*1.05)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Task Accuracy")
    plt.tight_layout()
    if fileName is None:
        plt.show()
    else:
        plt.savefig(fileName)
    # pyplot keeps figures alive until explicitly closed; clearing the axes alone
    # leaks one figure per call when plotting in a loop.
    plt.close(fig)
def plotTaskPatternStability(taskPatternStabilities:np.ndarray, taskEpochBoundaries:List[int], epochMarkers:str=None, plotAverage:bool=True, title:str=None,
        legend:List[str]=None, figsize=(8,6), fileName:str=None):
    """
    Plot the task pattern stability over many epochs, with each task getting its own line.

    Each task is drawn only from the epoch it starts being learned onwards.

    Args:
        taskPatternStabilities (np.ndarray): A numpy array of dimension (numEpochs, numTasks).
            The first index walks over epochs, while the second index walks over tasks.
        taskEpochBoundaries (List[int]): The list of epochs where each task starts being learned.
            Must contain at least numTasks entries; entry i is the start epoch of task i.
        epochMarkers (str, optional): The marker to use for epoch data points. Defaults to None
        plotAverage (bool, optional): If True, also plot, for every epoch, the average stability
            over the tasks that have started being learned by that epoch (tasks whose boundary
            is <= the epoch). Epochs before any task has started are left blank.
            Defaults to True.
        title (str or None, optional): Title of the graph. Defaults to None.
        legend (List[str] or None, optional): One label per task for the legend of the plot.
            Do not include the average legend. If None, tasks are labelled "Task 1", "Task 2", ...
            Defaults to None
        figsize (Tuple[int, int]): The size of the figure. Defaults to (8, 6).
        fileName (str, optional): If not None, saves the plot to the file name instead of
            showing it. Defaults to None.
    """
    numEpochs, numTasks = taskPatternStabilities.shape
    xRange = np.arange(numEpochs)
    fig = plt.figure(figsize=figsize)
    for i in range(numTasks):
        label = f"Task {i+1}" if legend is None else legend[i]
        start = taskEpochBoundaries[i]
        # Select column i (task i) only from its start epoch onwards, so a task
        # is not plotted before it starts being learned.
        plt.plot(xRange[start:], taskPatternStabilities[start:, i], marker=epochMarkers, label=label)
    if plotAverage:
        boundaries = np.asarray(taskEpochBoundaries[:numTasks])
        avgStability = []
        for epoch in range(numEpochs):
            # Average only over tasks that have started by this epoch, matching
            # the per-task lines above.
            started = boundaries <= epoch
            if started.any():
                avgStability.append(np.average(taskPatternStabilities[epoch, started]))
            else:
                # No task learned yet: leave a gap instead of averaging an empty slice
                avgStability.append(np.nan)
        plt.plot(xRange, avgStability, color='k', linestyle="-.", linewidth=3, label="Average Stability")
    plt.ylim(-0.05,1.05)
    plt.legend(bbox_to_anchor=(1.04, 0.5), loc='center left')
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Task Accuracy")
    plt.tight_layout()
    if fileName is None:
        plt.show()
    else:
        plt.savefig(fileName)
    # pyplot keeps figures alive until explicitly closed; release it here.
    plt.close(fig)
def plotTotalStablePatterns(numStableOverEpochs:List[int], N:int=None, hebbianMaximumCapacity:np.float64=None,
        title:str=None, figsize=(8,6), fileName:str=None):
    """
    Plot the total number of stable learned patterns over epochs.

    Args:
        numStableOverEpochs (List[int]): The number of stable learned patterns by epoch.
            A dashed "Actual Max" line is drawn at its maximum (skipped when it is empty).
        N (int or None, optional): The number of units in the network. If not None, plots a
            line at the Hebbian maximum number of stable patterns (0.138 * N). Defaults to None.
        hebbianMaximumCapacity (np.float64, optional): The maximum capacity of the network
            (expressed as a ratio of total units), or None (default). When given together
            with N, a line is plotted at hebbianMaximumCapacity * N. It is ignored when N is
            None, because the ratio cannot be turned into a pattern count without N.
        title (str or None, optional): Title of the graph. Defaults to None.
        figsize (Tuple[int, int]): The size of the figure. Defaults to (8, 6).
        fileName (str, optional): If not None, saves the plot to the file name instead of
            showing it. Defaults to None.
    """
    # Ratio of units giving the "Hebbian Max" number of stable patterns
    hebbianRatio = 0.138
    fig = plt.figure(figsize=figsize)
    plt.plot(numStableOverEpochs)
    if N is not None:
        plt.axhline(hebbianRatio*N, color='r', linestyle='--', label="Hebbian Max")
        if hebbianMaximumCapacity is not None:
            plt.axhline(hebbianMaximumCapacity*N, color='r', linestyle='-.', label="Allowable Error Max Constraint")
    # max() of an empty sequence raises, so only draw the line when there is data
    if len(numStableOverEpochs) > 0:
        plt.axhline(max(numStableOverEpochs), color='b', linestyle='--', label="Actual Max")
    plt.legend(bbox_to_anchor=(1.04, 0.5), loc='center left')
    plt.ylim(bottom=-0.05)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Total stable learned states")
    plt.tight_layout()
    if fileName is None:
        plt.show()
    else:
        plt.savefig(fileName)
    # pyplot keeps figures alive until explicitly closed; release it here.
    plt.close(fig)
def saveDataAsJSON(fileName:str=None, **kwargs):
    """
    Save the given keyword arguments as a single JSON object in a file.

    The intention is to give a network description (from network.getNetworkDescriptionJSON), as well as
    any taskPatternStability and numStableOverEpochs. numpy arrays and numpy scalars are
    converted to the equivalent (nested) lists / Python scalars so they can be stored.

    Args:
        fileName (str, optional): The name of the file to save to. If None, the current local
            time formatted as "%d-%m-%Y %H-%M-%S" is used as the file name. Defaults to None.
        kwargs: The names and values of items to store in the JSON file. Each keyword becomes a
            top-level key of the JSON object, so use the same names across runs to keep saved
            files comparable.

    Raises:
        TypeError: If a value is neither JSON serializable nor a numpy array/scalar.
    """
    if fileName is None:
        fileName = datetime.datetime.now().strftime("%d-%m-%Y %H-%M-%S")

    def _toJSONable(obj):
        # json.dump only calls this for objects it cannot encode natively.
        if isinstance(obj, (np.ndarray, np.generic)):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    with open(fileName, 'w', encoding='utf-8') as f:
        json.dump(kwargs, f, default=_toJSONable)
    print(f"SAVED TO {fileName}")