Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytorch dev #1396

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 26 additions & 31 deletions caiman/components_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import os
import peakutils
import tensorflow as tf
import torch
import scipy
from scipy.sparse import csc_matrix
from scipy.stats import norm
Expand Down Expand Up @@ -273,42 +273,37 @@ def evaluate_components_CNN(A,
if not isGPU and 'CAIMAN_ALLOW_GPU' not in os.environ:
print("GPU run not requested, disabling use of GPUs")
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
try:
os.environ["KERAS_BACKEND"] = "tensorflow"
from tensorflow.keras.models import model_from_json
use_keras = True
logger.info('Using Keras')
try:
os.environ["KERAS_BACKEND"] = "torch"
from keras.models import model_load
use_keras = True
logging.info('Using Keras')
except (ModuleNotFoundError):
use_keras = False
logger.info('Using Tensorflow')
use_keras = False
logging.info('Using Torch')

if loaded_model is None:
if use_keras:
if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".json")):
model_file = os.path.join(caiman_datadir(), model_name + ".json")
model_weights = os.path.join(caiman_datadir(), model_name + ".h5")
elif os.path.isfile(model_name + ".json"):
model_file = model_name + ".json"
model_weights = model_name + ".h5"
if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".keras")):
model_file = os.path.join(caiman_datadir(), model_name + ".keras")
elif os.path.isfile(model_name + ".keras"):
model_file = model_name + ".keras"
else:
raise FileNotFoundError(f"File for requested model {model_name} not found")
with open(model_file, 'r') as json_file:
print(f"USING MODEL (keras API): {model_file}")
loaded_model_json = json_file.read()

loaded_model = model_from_json(loaded_model_json)
loaded_model.load_weights(model_name + '.h5')

print(f"USING MODEL (keras API): {model_file}")
loaded_model = model_load(model_file)
else:
if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".h5.pb")):
model_file = os.path.join(caiman_datadir(), model_name + ".h5.pb")
elif os.path.isfile(model_name + ".h5.pb"):
model_file = model_name + ".h5.pb"
if os.path.isfile(os.path.join(caiman_datadir(), model_name + ".pt")):
model_file = os.path.join(caiman_datadir(), model_name + ".pt")
elif os.path.isfile(model_name + ".pt"):
model_file = model_name + ".pt"
else:
raise FileNotFoundError(f"File for requested model {model_name} not found")
print(f"USING MODEL (tensorflow API): {model_file}")
loaded_model = caiman.utils.utils.load_graph(model_file)
loaded_model = torch.load(model_file)

logger.debug("Loaded model from disk")
logging.debug("Loaded model from disk")

half_crop = np.minimum(gSig[0] * 4 + 1, patch_size), np.minimum(gSig[1] * 4 + 1, patch_size)
dims = np.array(dims)
Expand All @@ -323,11 +318,11 @@ def evaluate_components_CNN(A,
if use_keras:
predictions = loaded_model.predict(final_crops[:, :, :, np.newaxis], batch_size=32, verbose=1)
else:
tf_in = loaded_model.get_tensor_by_name('prefix/conv2d_20_input:0')
tf_out = loaded_model.get_tensor_by_name('prefix/output_node0:0')
with tf.Session(graph=loaded_model) as sess:
predictions = sess.run(tf_out, feed_dict={tf_in: final_crops[:, :, :, np.newaxis]})
sess.close()
final_crops = torch.tensor(final_crops, dtype=torch.float32)
final_crops = torch.reshape(final_crops, (-1, final_crops.shape[-1],
final_crops.shape[1], final_crops.shape[2]))
with torch.no_grad():
prediction = loaded_model(final_crops[:, np.newaxis, :, :])

return predictions, final_crops

Expand Down
62 changes: 36 additions & 26 deletions caiman/source_extraction/cnmf/online_cnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
imaging data in real time. In Advances in Neural Information Processing Systems
(pp. 2381-2391).
@url http://papers.nips.cc/paper/6832-onacid-online-analysis-of-calcium-imaging-data-in-real-time

Implemented in PyTorch
Date: July 18, 2024
"""

import cv2
Expand All @@ -26,7 +29,7 @@
from scipy.stats import norm
from sklearn.decomposition import NMF
from sklearn.preprocessing import normalize
import tensorflow as tf
import torch
from time import time

import caiman
Expand Down Expand Up @@ -322,32 +325,28 @@ def _prepare_object(self, Yr, T, new_dims=None, idx_components=None):
self.params.set('online', {'sniper_mode': False})
self.tf_in = None
self.tf_out = None
# self.use_torch = None
else:
try:
from tensorflow.keras.models import model_from_json
logger.info('Using Keras')
try:
from keras.models import load_model
logging.info('Using Keras')
use_keras = True
except(ModuleNotFoundError):
use_keras = False
logger.info('Using Tensorflow')
use_keras = False
logging.info('Using Torch')

path = self.params.get('online', 'path_to_model').split(".")[:-1]
if use_keras:
path = self.params.get('online', 'path_to_model').split(".")[:-1]
json_path = ".".join(path + ["json"])
model_path = ".".join(path + ["h5"])
json_file = open(json_path, 'r')
loaded_model_json = json_file.read()
json_file.close()
loaded_model = model_from_json(loaded_model_json)
loaded_model.load_weights(model_path)
self.tf_in = None
self.tf_out = None
else:
path = self.params.get('online', 'path_to_model').split(".")[:-1]
model_path = '.'.join(path + ['h5', 'pb'])
# uses online model -> be careful
model_path = ".".join(path + ["keras"])
loaded_model = model_load(model_path)
# self.use_torch = False
else:
model_path = '.'.join(path + ['pt'])
loaded_model = load_graph(model_path)
self.tf_in = loaded_model.get_tensor_by_name('prefix/conv2d_1_input:0')
self.tf_out = loaded_model.get_tensor_by_name('prefix/output_node0:0')
loaded_model = tf.Session(graph=loaded_model)
loaded_model = torch.load(model_file)
# self.use_torch = True

self.loaded_model = loaded_model

if self.is1p:
Expand Down Expand Up @@ -549,6 +548,7 @@ def fit_next(self, t, frame_in, num_iters_hals=3):
use_peak_max=self.params.get('online', 'use_peak_max'),
mean_buff=self.estimates.mean_buff,
tf_in=self.tf_in, tf_out=self.tf_out,
# use_torch=self.use_torch,
ssub_B=ssub_B, W=self.estimates.W if self.is1p else None,
b0=self.estimates.b0 if self.is1p else None,
corr_img=self.estimates.corr_img if use_corr else None,
Expand Down Expand Up @@ -2004,6 +2004,7 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
thresh_CNN_noisy=0.5, use_peak_max=False,
thresh_std_peak_resid = 1, mean_buff=None,
tf_in=None, tf_out=None):
# use_torch=None):
"""
Extract new candidate components from the residual buffer and test them
using space correlation or the CNN classifier. The function runs the CNN
Expand Down Expand Up @@ -2084,12 +2085,19 @@ def get_candidate_components(sv, dims, Yres_buf, min_num_trial=3, gSig=(5, 5),
Ain2 /= np.std(Ain2,axis=1)[:,None]
Ain2 = np.reshape(Ain2,(-1,) + tuple(np.diff(ijSig_cnn).squeeze()),order= 'F')
Ain2 = np.stack([cv2.resize(ain,(patch_size ,patch_size)) for ain in Ain2])
if tf_in is None:
if use_torch is None:
predictions = loaded_model.predict(Ain2[:,:,:,np.newaxis], batch_size=min_num_trial, verbose=0)
keep_cnn = list(np.where(predictions[:, 0] > thresh_CNN_noisy)[0])
else:
predictions = loaded_model.run(tf_out, feed_dict={tf_in: Ain2[:, :, :, np.newaxis]})
keep_cnn = list(np.where(predictions[:, 0] > thresh_CNN_noisy)[0])
cnn_pos = Ain2[keep_cnn]
final_crops = torch.tensor(Ain2, dtype=torch.float32)
final_crops = torch.reshape(Ain2, (-1, Ain2.shape[-1],
Ain2.shape[1], Ain2.shape[2]))
with torch.no_grad():
prediction = loaded_model(Ain2[:, np.newaxis, :, :])
keep_cnn = list(torch.where(predictions[:, 0] > thresh_CNN_noisy)[0])

cnn_pos = Ain2[keep_cnn] #Make sure this works
# tensor.numpy() also works
else:
keep_cnn = [] # list(range(len(Ain_cnn)))

Expand Down Expand Up @@ -2139,6 +2147,7 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
crosscorr=None, col_ind=None, row_ind=None, corr_img_mode=None,
max_img=None, downscale_matrix=None, upscale_matrix=None,
tf_in=None, tf_out=None):
# torch_in=None, torch_out=None):
"""
Checks for new components in the residual buffer and incorporates them if they pass the acceptance tests
"""
Expand Down Expand Up @@ -2169,6 +2178,7 @@ def update_num_components(t, sv, Ab, Cf, Yres_buf, Y_buf, rho_buf,
loaded_model=loaded_model, thresh_CNN_noisy=thresh_CNN_noisy,
use_peak_max=use_peak_max, test_both=test_both, mean_buff=mean_buff,
tf_in=tf_in, tf_out=tf_out)
#torch_in=torch_in, torch_out=torch_out)

ind_new_all = ijsig_all

Expand Down
49 changes: 49 additions & 0 deletions caiman/tests/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#!/usr/bin/env python

import numpy as np
import os
import keras

from caiman.paths import caiman_datadir
from caiman.utils.utils import load_graph

try:
os.environ["KERAS_BACKEND"] = "torch"
from keras.models import load_model
use_keras = True
except(ModuleNotFoundError):
import torch
use_keras = False

def test_torch():
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

try:
model_name = os.path.join(caiman_datadir(), 'model', 'cnn_model')
if use_keras:
model_file = model_name + ".keras"
print('USING MODEL:' + model_file)

loaded_model = load_model(model_file)
loaded_model.compile('sgd', 'mse')
elif use_keras == True:
model_file = model_name + ".pth"
loaded_model = torch.load(model_file)
except:
raise Exception(f'NN model could not be loaded. use_keras = {use_keras}')

A = np.random.randn(10, 50, 50, 1)
try:
if use_keras == False:
predictions = loaded_model.predict(A, batch_size=32)
elif use_keras == True:
A = torch.tensor(A, dtype=torch.float32)
A = torch.reshape(A, (-1, A.shape[-1], A.shape[1], A.shape[2]))
with torch.no_grad():
predictions = loaded_model(A)
pass
except:
raise Exception('NN model could not be deployed. use_keras = ' + str(use_keras))

if __name__ == "__main__":
test_torch()
6 changes: 6 additions & 0 deletions caiman/train/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/usr/bin/env python
import pkg_resources

from caiman.train.train_cnn_model_helper import cnn_model_pytorch, train_test_split, train, validate, get_batch_accuracy, save_model_pytorch, load_model_pytorch, cnn_model_keras, save_model_keras, load_model_keras

__version__ = pkg_resources.get_distribution('caiman').version
Loading