Whittington 2020 #70

Merged
154 commits merged on Jul 28, 2023

Commits (154)
37c0e84
new TEM run file
Mar 7, 2023
244efae
TEM saving variables and running in varied environments
Mar 7, 2023
ec519a1
example plots added to notebook
Mar 8, 2023
93e7b72
tests updated
Mar 8, 2023
21f0faf
only model parameters
Mar 8, 2023
cee284d
variable saving added
Mar 8, 2023
81eba9f
updated run and notebook files
Mar 8, 2023
d7cff47
update model files
Mar 8, 2023
3881237
updated arena files
Mar 8, 2023
a70c8a0
Delete whittington_2020_examples.ipynb
LukeHollingsworth Mar 8, 2023
b211f04
Delete whittington_2020_test.py
LukeHollingsworth Mar 8, 2023
601c2c7
label readme
rodrigcd Mar 9, 2023
c41706b
includes torch dependencies
Mar 9, 2023
42415c1
includes torch
Mar 9, 2023
8e064d0
update environment classes
Mar 9, 2023
41e446a
updated agent functions
Mar 9, 2023
c9eef7d
updated training and plotting scripts
Mar 9, 2023
9400e70
requirements updated
Mar 9, 2023
ae0d05c
updated plotting notebook
Mar 9, 2023
34424ea
Merge branch 'whittington_2020' of https://github.com/ClementineDomin…
Mar 9, 2023
fd17704
added pytorch dependency
Mar 9, 2023
ede88b8
torch=1.12.1 not found
Mar 9, 2023
507c43d
added pytorch dependency
Mar 9, 2023
f720ad6
added pytorch dependency
Mar 9, 2023
aa37ee3
TEM tests added
Mar 9, 2023
99cd31b
improved plotting
Mar 9, 2023
0d4d635
summary files for example TEM results
Mar 9, 2023
17bee39
example results for plotting
Mar 9, 2023
09870aa
default parameters
Mar 9, 2023
5ce1c49
Update whittington_2020_example.ipynb
LukeHollingsworth Mar 13, 2023
793bdce
beg
Apr 6, 2023
d0a102d
environment duplication bug fix
Apr 17, 2023
fa5124d
config classes done
rodrigcd Apr 18, 2023
a34e4a6
zero shot accuracy fixed
Apr 19, 2023
909e82c
tensorboard added to requirements and model saving added
LukeHollingsworth Apr 20, 2023
1103e2b
state (x,y) causing grid plotting issues
Apr 21, 2023
a00ec7b
bug in position rounding and centering found
LukeHollingsworth May 1, 2023
7bc169a
position centering and rounding added
LukeHollingsworth May 3, 2023
93691db
plotting error in envrionment variable
LukeHollingsworth May 8, 2023
8e112d0
plotting finctions working - NPG data added under /torch_run3
May 10, 2023
878e110
gridscore metric
ClementineDomine May 23, 2023
afbe8b7
add metric
ClementineDomine May 23, 2023
62579ec
modified agent and exp for the metric
ClementineDomine May 23, 2023
1a86b6f
updated metric- allows for 2 D
ClementineDomine May 24, 2023
23d97e6
Created via Colaboratory
rodrigcd May 31, 2023
7d7ec2f
Created via Colaboratory
rodrigcd May 31, 2023
7232880
colab example
rodrigcd May 31, 2023
6dcb3f9
fixing colab
rodrigcd May 31, 2023
113e279
Creado mediante Colaboratory
rodrigcd May 31, 2023
a415c53
colab on readme
rodrigcd May 31, 2023
82a0391
open in colab
rodrigcd May 31, 2023
8592693
fixing path to colab
rodrigcd May 31, 2023
3f997e5
setting colab env
rodrigcd May 31, 2023
39c0ea4
colab
rodrigcd May 31, 2023
3250d26
2023-05-17 contains both NPG and original models
Jun 1, 2023
eb0395d
fixing merging conflicts
rodrigcd Jun 4, 2023
9d28b45
updating colab installation
rodrigcd Jun 4, 2023
e6756cf
colab example skeleton done, need markdowns and explanation
rodrigcd Jun 4, 2023
41bad78
adding use of behavioural data to batched environment
LukeHollingsworth Jun 5, 2023
8a6a5d8
adding option to use behavioural data
Jun 9, 2023
5334e88
colab example from main
rodrigcd Jun 13, 2023
d91d6b5
pulling from main
rodrigcd Jun 13, 2023
67933d6
adding behavioural trajectory
LukeHollingsworth Jun 13, 2023
0037a57
config file running properly
rodrigcd Jun 13, 2023
7234834
adding behavioural trajectories
Jun 16, 2023
7946e1c
problem plotting with non-square environments
Jun 20, 2023
bf71826
commenting config module, need to config other modules
rodrigcd Jun 20, 2023
d72512b
data path setup for behavioural trajectory use
Jun 28, 2023
40e54cb
behaviorual traj
Jun 28, 2023
160941a
TEM running on behavioural data
Jun 29, 2023
bf5326f
added robust data path creation and access
LukeHollingsworth Jun 30, 2023
889588c
fixing action-transition discrepencies
LukeHollingsworth Jun 30, 2023
468360b
fixing action-transition discrepency
LukeHollingsworth Jul 3, 2023
461023b
action generation added
LukeHollingsworth Jul 4, 2023
fa9b920
agent generation added
LukeHollingsworth Jul 5, 2023
8c467e7
config file comment
rodrigcd Jul 6, 2023
f1e728b
merge from main with Gin datasets
rodrigcd Jul 6, 2023
aca9c49
action-transition discrepency fixed
LukeHollingsworth Jul 6, 2023
c49b137
backend for automatic simulation
rodrigcd Jul 6, 2023
6d60732
single sim manager running properly
rodrigcd Jul 7, 2023
8e7d766
generating dir when runing sim
rodrigcd Jul 7, 2023
8421a77
generalised plotting function
ClementineDomine Jul 8, 2023
91ed541
merge
ClementineDomine Jul 8, 2023
6be123b
wernle_2018
ClementineDomine Jul 8, 2023
bed2a35
forgot the pre-commit
ClementineDomine Jul 8, 2023
9d538a7
generalised rate map for experiments
ClementineDomine Jul 9, 2023
9838c2b
generalised rate plotting
ClementineDomine Jul 9, 2023
2e53459
make plotting a module
rodrigcd Jul 9, 2023
12049ca
update metric
ClementineDomine Jul 9, 2023
9c039e8
testing sim manager
rodrigcd Jul 9, 2023
eff8977
metric that works + get rid of developement plot+ same fontsize
ClementineDomine Jul 10, 2023
1409b99
merge
ClementineDomine Jul 10, 2023
07cbd3a
pre commit
rodrigcd Jul 10, 2023
e04eaed
saving entire object with pickle
rodrigcd Jul 10, 2023
2dc2e9c
config input
ClementineDomine Jul 10, 2023
2b78a24
dict to json
rodrigcd Jul 11, 2023
05521b8
saving params as dict
rodrigcd Jul 11, 2023
23594bd
comparison from the run manadger
ClementineDomine Jul 11, 2023
ee91988
status checker and load simulation
rodrigcd Jul 11, 2023
32b55a6
status checker and load simulation
rodrigcd Jul 11, 2023
1f546ad
merge with comparison_board branch
Jul 12, 2023
7dab564
merge corrections
Jul 12, 2023
f6f1443
fix hafting
ClementineDomine Jul 12, 2023
bb0dc0e
get ratemap matrix done for all agents, documentation of simulation m…
rodrigcd Jul 12, 2023
96aa953
Merge branch 'comparison_board' of https://github.com/SainsburyWellco…
rodrigcd Jul 12, 2023
cdebc3b
merging room
ClementineDomine Jul 12, 2023
c3df545
weber plot_rates.py function
ClementineDomine Jul 13, 2023
65385f2
random action policy working in updated branch
LukeHollingsworth Jul 13, 2023
1cfdf5f
plot rates
rodrigcd Jul 13, 2023
27c2322
nice plotting
ClementineDomine Jul 13, 2023
7f20999
Merge branch 'comparison_board' of https://github.com/ClementineDomin…
ClementineDomine Jul 13, 2023
e4eaaf9
update rates function + test comparison figure on other agents
ClementineDomine Jul 14, 2023
b6c6c52
comparison sargo
ClementineDomine Jul 15, 2023
5a9fdc6
update
ClementineDomine Jul 17, 2023
ce72848
fixing TEM plotting
LukeHollingsworth Jul 17, 2023
c00b760
wenrnel tetrode, get_grid score, title figure , table figure and con…
ClementineDomine Jul 17, 2023
2f7f2ba
environment variables fixed
LukeHollingsworth Jul 18, 2023
9b12644
new saved TEM models added
Jul 18, 2023
1a1bfa4
update the jupyter notebooks
ClementineDomine Jul 18, 2023
f50ead6
grid scorere
ClementineDomine Jul 19, 2023
c3885ab
New plotting Functions
ClementineDomine Jul 20, 2023
9b90af1
Really cool jupyter
ClementineDomine Jul 20, 2023
d82031a
update
ClementineDomine Jul 20, 2023
3318e6b
width
ClementineDomine Jul 20, 2023
6e2cff0
score
ClementineDomine Jul 22, 2023
4480305
the score
ClementineDomine Jul 23, 2023
e79bd68
juypter
ClementineDomine Jul 23, 2023
ca47e2c
jupyter
ClementineDomine Jul 23, 2023
ec9aac7
simulation manager notebook
rodrigcd Jul 23, 2023
65905cf
Merge branch 'comparison_board' of https://github.com/SainsburyWellco…
rodrigcd Jul 23, 2023
a821e1a
simulation manager notebook
rodrigcd Jul 23, 2023
faa3185
back to the previous setting
ClementineDomine Jul 24, 2023
89e29c3
Update metrics.py
rhayman Jul 24, 2023
c4a4140
standardised plotting of TEM results added
LukeHollingsworth Jul 24, 2023
1e7fcda
passing test
rodrigcd Jul 24, 2023
5b26144
Merge pull request #65 from rhayman/rhayman_grid_scores
ClementineDomine Jul 24, 2023
83d9014
backend run example
rodrigcd Jul 25, 2023
519a261
Merge branch 'comparison_board' of https://github.com/SainsburyWellco…
rodrigcd Jul 25, 2023
726e9eb
simulation manager example almost done
rodrigcd Jul 25, 2023
b2aa203
changes
ClementineDomine Jul 25, 2023
49119b7
merge
ClementineDomine Jul 25, 2023
b5d491c
cleaning up plotting code
LukeHollingsworth Jul 25, 2023
5c76287
simulation manager example done
rodrigcd Jul 25, 2023
76ad063
Merge branch 'comparison_board' of https://github.com/SainsburyWellco…
rodrigcd Jul 25, 2023
68bc62b
plot_utils
ClementineDomine Jul 26, 2023
5fe403b
Merge branch 'comparison_board' of https://github.com/ClementineDomin…
LukeHollingsworth Jul 26, 2023
d66eaad
merge changes from comparison_board
LukeHollingsworth Jul 26, 2023
88ede3d
removed additional summaries
LukeHollingsworth Jul 26, 2023
b99e509
commenting added to new files
LukeHollingsworth Jul 27, 2023
83a3558
pre-commit changes made
LukeHollingsworth Jul 27, 2023
25c83a4
Merge branch 'main' into whittington_2020
ClementineDomine Jul 28, 2023
26a447e
for test to pass
ClementineDomine Jul 28, 2023
7400a4b
Merge branch 'main' into whittington_2020
ClementineDomine Jul 28, 2023
6b2fb71
fix manifest
ClementineDomine Jul 28, 2023
305 changes: 305 additions & 0 deletions examples/agent_examples/whittington_2020_example.ipynb

Large diffs are not rendered by default.

186 changes: 186 additions & 0 deletions examples/agent_examples/whittington_2020_plot.py
@@ -0,0 +1,186 @@
# Standard Imports
import importlib.util
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch

import neuralplayground.agents.whittington_2020_extras.whittington_2020_analyse as analyse
from neuralplayground.agents.whittington_2020 import Whittington2020
from neuralplayground.arenas.batch_environment import BatchEnvironment

# NeuralPlayground Imports
from neuralplayground.arenas.discritized_objects import DiscreteObjectEnvironment

# NeuralPlayground Experiment Imports
from neuralplayground.experiments import Sargolini2006Data

# Select trained model
date = "2023-05-17"
run = "0"
index = "19999"
base_path = "/nfs/nhome/live/lhollingsworth/Documents/NeuralPlayground/NPG/EHC_model_comparison"
npg_path = "/nfs/nhome/live/lhollingsworth/Documents/NeuralPlayground/NPG/EHC_model_comparison/examples"
base_win_path = "H:/Documents/PhD/NeuralPlayground"
win_path = "H:/Documents/PhD/NeuralPlayground/NPG/NeuralPlayground/examples"
# Load the model: use importlib to import the module from the specified path
model_spec = importlib.util.spec_from_file_location(
"model", win_path + "/Summaries/" + date + "/torch_run" + run + "/script/whittington_2020_model.py"
)
model = importlib.util.module_from_spec(model_spec)
model_spec.loader.exec_module(model)

# Load the parameters of the model
params = torch.load(win_path + "/Summaries/" + date + "/torch_run" + run + "/model/params_" + index + ".pt")
# Create a new tem model with the loaded parameters
tem = model.Model(params)
# Load the model weights after training
model_weights = torch.load(win_path + "/Summaries/" + date + "/torch_run" + run + "/model/tem_" + index + ".pt")
# Set the model weights to the loaded trained model weights
tem.load_state_dict(model_weights)
# Make sure the model is in evaluation mode (not crucial, since it doesn't currently use dropout or batchnorm layers)
tem.eval()

# Initialise environment parameters
batch_size = 16
arena_x_limits = [
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
]
arena_y_limits = [
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
]
# arena_x_limits = [[-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10],
# [-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10]]
# arena_y_limits = [[-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1],
# [-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1]]
env_name = "env_example"
mod_name = "SimpleTEM"
time_step_size = 1
state_density = 1
agent_step_size = 1 / state_density
n_objects = 45

# Init simple 2D environment with discretised objects
env_class = DiscreteObjectEnvironment
env = BatchEnvironment(
environment_name=env_name,
env_class=DiscreteObjectEnvironment,
batch_size=batch_size,
arena_x_limits=arena_x_limits,
arena_y_limits=arena_y_limits,
state_density=state_density,
n_objects=n_objects,
agent_step_size=agent_step_size,
use_behavioural_data=False,
data_path=None,
experiment_class=Sargolini2006Data,
)

# Init TEM agent
agent = Whittington2020(
model_name=mod_name,
params=params,
batch_size=batch_size,
room_widths=env.room_widths,
room_depths=env.room_depths,
state_densities=env.state_densities,
use_behavioural_data=False,
)

# # Run around environment
# observation, state = env.reset(random_state=True, custom_state=None)
# while agent.n_walk < 5000:
# if agent.n_walk % 100 == 0:
# print(agent.n_walk)
# action = agent.batch_act(observation)
# observation, state = env.step(action, normalize_step=True)
# model_input, history, environments = agent.collect_final_trajectory()
# environments = [env.collect_environment_info(model_input, history, environments)]

# # Save environments and model_input using pickle
# with open('NPG_environments.pkl', 'wb') as f:
# pickle.dump(environments, f)
# with open('NPG_model_input.pkl', 'wb') as f:
# pickle.dump(model_input, f)

# Load environments and model_input using pickle
with open("NPG_environments.pkl", "rb") as f:
environments = pickle.load(f)
with open("NPG_model_input.pkl", "rb") as f:
model_input = pickle.load(f)

with torch.no_grad():
forward = tem(model_input, prev_iter=None)
include_stay_still = False
shiny_envs = [False, False, False, False]
env_to_plot = 0
envs_to_avg = shiny_envs if shiny_envs[env_to_plot] else [not shiny_env for shiny_env in shiny_envs]

correct_model, correct_node, correct_edge = analyse.compare_to_agents(
forward, tem, environments, include_stay_still=include_stay_still
)
zero_shot = analyse.zero_shot(forward, tem, environments, include_stay_still=include_stay_still)
occupation = analyse.location_occupation(forward, tem, environments)
g, p = analyse.rate_map(forward, tem, environments)
from_acc, to_acc = analyse.location_accuracy(forward, tem, environments)

# Plot rate maps for grid or place cells
agent.plot_rate_map(g)

# Plot results of agent comparison and zero-shot inference analysis
filt_size = 41
plt.figure()
plt.plot(
analyse.smooth(
np.mean(np.array([env for env_i, env in enumerate(correct_model) if envs_to_avg[env_i]]), 0)[1:], filt_size
),
label="tem",
)
plt.plot(
analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_node) if envs_to_avg[env_i]]), 0)[1:], filt_size),
label="node",
)
plt.plot(
analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_edge) if envs_to_avg[env_i]]), 0)[1:], filt_size),
label="edge",
)
plt.ylim(0, 1)
plt.legend()
plt.title(
"Zero-shot inference: "
+ str(np.mean([np.mean(env) for env_i, env in enumerate(zero_shot) if envs_to_avg[env_i]]) * 100)
+ "%"
)

# plt.show()
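
# Note on paths: the loading block at the top of this script hard-codes NFS and Windows locations.
# A minimal sketch of the same load driven from one configurable root follows; the
# Summaries/<date>/torch_run<run> layout is taken from the calls above, while the local root
# directory is an assumption and not part of the repository.
import importlib.util
from pathlib import Path

import torch

date = "2023-05-17"
run = "0"
index = "19999"
summaries_root = Path("~/NeuralPlayground/Summaries").expanduser()  # assumed local root
run_dir = summaries_root / date / ("torch_run" + run)

# Import the copy of the model code saved alongside the run
model_spec = importlib.util.spec_from_file_location("model", str(run_dir / "script" / "whittington_2020_model.py"))
model = importlib.util.module_from_spec(model_spec)
model_spec.loader.exec_module(model)

# Rebuild the model from its saved parameters and load the trained weights
params = torch.load(run_dir / "model" / ("params_" + index + ".pt"))
tem = model.Model(params)
tem.load_state_dict(torch.load(run_dir / "model" / ("tem_" + index + ".pt")))
tem.eval()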
124 changes: 124 additions & 0 deletions examples/agent_examples/whittington_2020_run.py
@@ -0,0 +1,124 @@
"""
Run file for the Tolman-Eichenbaum Machine (TEM) model from Whittington et al. 2020. An example setup is provided, with
TEM learning to predict upcoming sensory stimuli in a set of 16 square environments of varying sizes.
"""

# Standard Imports

import matplotlib.pyplot as plt

# NeuralPlayground Agent Imports
import neuralplayground.agents.whittington_2020_extras.whittington_2020_parameters as parameters
from neuralplayground.agents.whittington_2020 import Whittington2020
from neuralplayground.arenas.batch_environment import BatchEnvironment

# NeuralPlayground Arena Imports
from neuralplayground.arenas.discritized_objects import DiscreteObjectEnvironment

# NeuralPlayground Experiment Imports
from neuralplayground.experiments import Sargolini2006Data

# Initialise TEM Parameters
pars_orig = parameters.parameters()
params = pars_orig.copy()

# Initialise environment parameters
batch_size = 16
arena_x_limits = [
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
]
arena_y_limits = [
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
]
# arena_x_limits = [[-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10],
# [-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10]]
# arena_y_limits = [[-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1],
# [-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1]]
env_name = "Sargolini2006"
mod_name = "SimpleTEM"
time_step_size = 1
state_density = 1
agent_step_size = 1 / state_density
n_objects = 45

# # Init environment from Hafting 2008 (optional; if used, comment out the environment below)
# env = Hafting2008(agent_step_size=agent_step_size,
# time_step_size=time_step_size,
# use_behavioral_data=False)

# # Init simple 2D (batched) environment with discretised objects
# env_class = DiscreteObjectEnvironment

# Init batched environment with discretised objects; set use_behavioural_data=True to use Sargolini behavioural data instead of a random walk
env = BatchEnvironment(
environment_name=env_name,
env_class=DiscreteObjectEnvironment,
batch_size=batch_size,
arena_x_limits=arena_x_limits,
arena_y_limits=arena_y_limits,
state_density=state_density,
n_objects=n_objects,
agent_step_size=agent_step_size,
use_behavioural_data=False,
data_path=None,
experiment_class=Sargolini2006Data,
)

# Init TEM agent
agent = Whittington2020(
model_name=mod_name,
params=params,
batch_size=batch_size,
room_widths=env.room_widths,
room_depths=env.room_depths,
state_densities=env.state_densities,
use_behavioural_data=False,
)

# Reset environment and begin training (random_state=True is currently necessary)
observation, state = env.reset(random_state=True, custom_state=None)
for i in range(3):
print("Iteration: ", i)
while agent.n_walk < params["n_rollout"]:
actions = agent.batch_act(observation)
observation, state = env.step(actions, normalize_step=True)
agent.update()

# Plot most recent trajectory of the first environment in batch
ax = env.plot_trajectory()
fontsize = 18
ax.grid()
ax.set_xlabel("width", fontsize=fontsize)
ax.set_ylabel("depth", fontsize=fontsize)
plt.savefig("trajectory.png")
plt.show()
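
# The companion plotting script above (whittington_2020_plot.py) loads NPG_environments.pkl and
# NPG_model_input.pkl. A minimal sketch of producing those files after training, mirroring the
# commented-out collection block in that script (the 5000-step walk length is taken from there):
import pickle

# Walk the trained agent through the environments to collect a final trajectory
observation, state = env.reset(random_state=True, custom_state=None)
while agent.n_walk < 5000:
    action = agent.batch_act(observation)
    observation, state = env.step(action, normalize_step=True)
model_input, history, environments = agent.collect_final_trajectory()
environments = [env.collect_environment_info(model_input, history, environments)]

# Save under the file names the plotting script expects
with open("NPG_environments.pkl", "wb") as f:
    pickle.dump(environments, f)
with open("NPG_model_input.pkl", "wb") as f:
    pickle.dump(model_input, f)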