From db2fc744d4481968000185fd9d39cb8213cd7489 Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 21 Nov 2024 16:33:33 +0100 Subject: [PATCH 01/77] create separate module for preprocessing steps --- src/vame/__init__.py | 2 +- src/vame/preprocessing/__init__.py | 0 src/vame/{util => preprocessing}/align_egocentrical.py | 0 3 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 src/vame/preprocessing/__init__.py rename src/vame/{util => preprocessing}/align_egocentrical.py (100%) diff --git a/src/vame/__init__.py b/src/vame/__init__.py index 4156c778..d9d56628 100644 --- a/src/vame/__init__.py +++ b/src/vame/__init__.py @@ -14,7 +14,7 @@ from vame.analysis import generative_model from vame.analysis import gif from vame.util.csv_to_npy import pose_to_numpy -from vame.util.align_egocentrical import egocentric_alignment +from vame.preprocessing.align_egocentrical import egocentric_alignment from vame.util import model_util from vame.util import auxiliary from vame.util.report import report diff --git a/src/vame/preprocessing/__init__.py b/src/vame/preprocessing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/vame/util/align_egocentrical.py b/src/vame/preprocessing/align_egocentrical.py similarity index 100% rename from src/vame/util/align_egocentrical.py rename to src/vame/preprocessing/align_egocentrical.py From 2b2ead03ae40bb83410f28b6026919a03f5f659e Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 21 Nov 2024 17:00:30 +0100 Subject: [PATCH 02/77] wip --- .github/workflows/testing.yaml | 2 +- src/vame/initialize_project/new.py | 2 +- src/vame/preprocessing/align_egocentrical.py | 76 +++----------------- src/vame/util/report.py | 8 +-- src/vame/util/video.py | 8 --- src/vame/video/__init__.py | 4 ++ src/vame/video/video.py | 58 +++++++++++++++ 7 files changed, 75 insertions(+), 83 deletions(-) delete mode 100644 src/vame/util/video.py create mode 100644 src/vame/video/__init__.py create mode 100644 src/vame/video/video.py diff --git a/.github/workflows/testing.yaml b/.github/workflows/testing.yaml index 2fba845d..3289e280 100644 --- a/.github/workflows/testing.yaml +++ b/.github/workflows/testing.yaml @@ -5,7 +5,7 @@ on: branches: - main - dev - + - preprocessing jobs: run: diff --git a/src/vame/initialize_project/new.py b/src/vame/initialize_project/new.py index e342342b..784d72fa 100644 --- a/src/vame/initialize_project/new.py +++ b/src/vame/initialize_project/new.py @@ -9,7 +9,7 @@ from vame.schemas.states import VAMEPipelineStatesSchema from vame.logging.logger import VameLogger from vame.util.auxiliary import write_config -from vame.util.video import get_video_frame_rate +from vame.video.video import get_video_frame_rate from vame.io.load_poses import load_pose_estimation diff --git a/src/vame/preprocessing/align_egocentrical.py b/src/vame/preprocessing/align_egocentrical.py index 7ab47099..e2d95b18 100644 --- a/src/vame/preprocessing/align_egocentrical.py +++ b/src/vame/preprocessing/align_egocentrical.py @@ -4,8 +4,9 @@ import pandas as pd import tqdm from typing import Tuple, List, Union -from vame.logging.logger import VameLogger, TqdmToLogger from pathlib import Path + +from vame.logging.logger import VameLogger, TqdmToLogger from vame.util.auxiliary import read_config from vame.schemas.states import EgocentricAlignmentFunctionSchema, save_state from vame.schemas.project import PoseEstimationFiletype @@ -15,6 +16,7 @@ background, read_pose_estimation_file, ) +from vame.video import get_video_frame_rate logger_config = VameLogger(__name__) @@ -160,13 
+162,6 @@ def align_mouse( rect = tuple(lst) center, size, theta = rect - # lst2 = list(rect) - # lst2[0][0] = center[0] - size[0]//2 - # lst2[0][1] = center[1] - size[1]//2 - # rect = tuple(lst2) - # center[0] -= size[0]//2 - # center[1] -= size[0]//2 # added this shift to change center to belly 2/28/2024 - # crop image out, shifted_points = crop_and_flip( rect, @@ -192,52 +187,6 @@ def align_mouse( return images, points, time_series -def play_aligned_video( - a: List[np.ndarray], - n: List[List[np.ndarray]], - frame_count: int, -) -> None: - """ - Play the aligned video. - - Parameters - ---------- - a : List[np.ndarray] - List of aligned images. - n : List[List[np.ndarray]] - List of aligned DLC points. - frame_count : int - Number of frames in the video. - """ - colors = [ - (255, 0, 0), - (0, 255, 0), - (0, 0, 255), - (255, 255, 0), - (255, 0, 255), - (0, 255, 255), - (0, 0, 0), - (255, 255, 255), - ] - for i in range(frame_count): - # Capture frame-by-frame - ret, frame = True, a[i] - if ret is True: - # Display the resulting frame - frame = cv.cvtColor(frame.astype("uint8") * 255, cv.COLOR_GRAY2BGR) - im_color = cv.applyColorMap(frame, cv.COLORMAP_JET) - for c, j in enumerate(n[i]): - cv.circle(im_color, (j[0], j[1]), 5, colors[c], -1) - cv.imshow("Frame", im_color) - # Press Q on keyboard to exit - # Break the loop - if cv.waitKey(25) & 0xFF == ord("q"): - break - else: - break - cv.destroyAllWindows() - - def alignment( project_path: str, session: str, @@ -248,7 +197,6 @@ def alignment( pose_estimation_filetype: PoseEstimationFiletype, path_to_pose_nwb_series_data: Union[str, None] = None, use_video: bool = False, - check_video: bool = False, tqdm_stream: Union[TqdmToLogger, None] = None, ) -> Tuple[np.ndarray, List[np.ndarray]]: """ @@ -274,8 +222,6 @@ def alignment( Path to the pose series data in nwb files. Defaults to None. use_video : bool, optional Whether to use video for alignment. Defaults to False. - check_video : bool, optional - Whether to check the aligned video. Defaults to False. tqdm_stream : Union[TqdmToLogger, None], optional Tqdm stream to log the progress. Defaults to None. 
@@ -320,11 +266,7 @@ def alignment(
             video_path=video_path,
             save_background=False,
         )
-        capture = cv.VideoCapture(video_path)
-        if not capture.isOpened():
-            raise Exception(f"Unable to open video file: {video_path}")
-        frame_count = int(capture.get(cv.CAP_PROP_FRAME_COUNT))
-        capture.release()
+        # Align one frame per pose estimation sample; this assumes the video
+        # has at least as many frames as the pose series. Note that the frame
+        # rate helper returns fps, not a frame count, so it is not used here.
+        frame_count = len(data)
     else:
         bg = 0
         # Change this to an arbitrary number if you first want to test the code
         frame_count = len(data)
 
@@ -345,9 +287,6 @@ def alignment(
         tqdm_stream=tqdm_stream,
     )
 
-    if check_video:
-        play_aligned_video(frames, n, frame_count)
-
     return time_series, frames
 
 
@@ -442,7 +381,6 @@ def egocentric_alignment(
                 else paths_to_pose_nwb_series_data[i]
             ),
             use_video=use_video,
-            check_video=check_video,
             tqdm_stream=tqdm_stream,
         )
 
@@ -458,7 +396,11 @@ def egocentric_alignment(
             # Save new shifted file
             np.save(
                 os.path.join(
-                    project_path, "data", "processed", session, session + "-PE-seq.npy"
+                    project_path,
+                    "data",
+                    "processed",
+                    session,
+                    session + "-PE-seq.npy",
                 ),
                 egocentric_time_series_shifted,
             )
diff --git a/src/vame/util/report.py b/src/vame/util/report.py
index e4560ea6..1bb36fd3 100644
--- a/src/vame/util/report.py
+++ b/src/vame/util/report.py
@@ -30,12 +30,8 @@ def report(
     with open(project_path / "states" / "states.json") as f:
         project_states = json.load(f)
 
-    pose_estimation_files = list(
-        (project_path / "data" / "raw").glob("*.nc")
-    )
-    video_files = list(
-        (project_path / "data" / "raw").glob("*.mp4")
-    )
+    pose_estimation_files = list((project_path / "data" / "raw").glob("*.nc"))
+    video_files = list((project_path / "data" / "raw").glob("*.mp4"))
 
     # Create a report folder for the project, if it does not exist
     report_folder = project_path / "reports"
diff --git a/src/vame/util/video.py b/src/vame/util/video.py
deleted file mode 100644
index 7481a7da..00000000
--- a/src/vame/util/video.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import cv2
-
-
-def get_video_frame_rate(video_path):
-    video = cv2.VideoCapture(video_path)
-    frame_rate = int(video.get(cv2.CAP_PROP_FPS))
-    video.release()
-    return frame_rate
diff --git a/src/vame/video/__init__.py b/src/vame/video/__init__.py
new file mode 100644
index 00000000..3cec722d
--- /dev/null
+++ b/src/vame/video/__init__.py
@@ -0,0 +1,4 @@
+from vame.video.video import (
+    get_video_frame_rate,
+    play_aligned_video,
+)
diff --git a/src/vame/video/video.py b/src/vame/video/video.py
new file mode 100644
index 00000000..aeee4b80
--- /dev/null
+++ b/src/vame/video/video.py
@@ -0,0 +1,58 @@
+import cv2
+from typing import List
+import numpy as np
+
+
+def get_video_frame_rate(video_path):
+    video = cv2.VideoCapture(video_path)
+    if not video.isOpened():
+        raise Exception(f"Unable to open video file: {video_path}")
+    frame_rate = int(video.get(cv2.CAP_PROP_FPS))
+    video.release()
+    return frame_rate
+
+
+def play_aligned_video(
+    a: List[np.ndarray],
+    n: List[List[np.ndarray]],
+    frame_count: int,
+) -> None:
+    """
+    Play the aligned video.
+
+    Parameters
+    ----------
+    a : List[np.ndarray]
+        List of aligned images.
+    n : List[List[np.ndarray]]
+        List of aligned DLC points.
+    frame_count : int
+        Number of frames in the video.
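+
+    Notes
+    -----
+    Frames are expected to be binary (0/1) single-channel arrays; they are
+    scaled to 0/255, colormapped, and shown in an OpenCV window. Press "q"
+    to stop playback early.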
+    """
+    colors = [
+        (255, 0, 0),
+        (0, 255, 0),
+        (0, 0, 255),
+        (255, 255, 0),
+        (255, 0, 255),
+        (0, 255, 255),
+        (0, 0, 0),
+        (255, 255, 255),
+    ]
+    for i in range(frame_count):
+        # Frames come from the pre-aligned list, not from a video capture
+        frame = a[i]
+        # Display the resulting frame
+        frame = cv2.cvtColor(frame.astype("uint8") * 255, cv2.COLOR_GRAY2BGR)
+        im_color = cv2.applyColorMap(frame, cv2.COLORMAP_JET)
+        for c, j in enumerate(n[i]):
+            # Cycle through the palette if there are more keypoints than colors
+            cv2.circle(im_color, (j[0], j[1]), 5, colors[c % len(colors)], -1)
+        cv2.imshow("Frame", im_color)
+        # Press Q on keyboard to exit
+        if cv2.waitKey(25) & 0xFF == ord("q"):
+            break
+    cv2.destroyAllWindows()

From d538d8679c90a311db509ef32bdce043e610340f Mon Sep 17 00:00:00 2001
From: luiz
Date: Thu, 28 Nov 2024 18:37:55 +0100
Subject: [PATCH 03/77] clean up the egocentric alignment module a bit

---
 src/vame/__init__.py                         |   5 +-
 src/vame/preprocessing/align_egocentrical.py | 315 ++++++++++++++++++-
 src/vame/preprocessing/align_new.py          |  83 +++++
 src/vame/util/data_manipulation.py           |  13 +-
 src/vame/util/gif_pose_helper.py             |   4 +-
 tests/test_util.py                           |  12 +-
 6 files changed, 412 insertions(+), 20 deletions(-)
 create mode 100644 src/vame/preprocessing/align_new.py

diff --git a/src/vame/__init__.py b/src/vame/__init__.py
index d9d56628..3e43ad94 100644
--- a/src/vame/__init__.py
+++ b/src/vame/__init__.py
@@ -14,7 +14,10 @@ from vame.analysis import generative_model
 from vame.analysis import gif
 from vame.util.csv_to_npy import pose_to_numpy
-from vame.preprocessing.align_egocentrical import egocentric_alignment
+from vame.preprocessing.align_egocentrical import (
+    egocentric_alignment_legacy,
+    egocentric_alignment,
+)
 from vame.util import model_util
 from vame.util import auxiliary
 from vame.util.report import report
diff --git a/src/vame/preprocessing/align_egocentrical.py b/src/vame/preprocessing/align_egocentrical.py
index e2d95b18..c2d3b1cb 100644
--- a/src/vame/preprocessing/align_egocentrical.py
+++ b/src/vame/preprocessing/align_egocentrical.py
@@ -12,7 +12,7 @@ from vame.schemas.project import PoseEstimationFiletype
 from vame.util.data_manipulation import (
     interpol_first_rows_nans,
-    crop_and_flip,
+    crop_and_flip_legacy,
     background,
     read_pose_estimation_file,
 )
@@ -23,7 +23,7 @@ logger = logger_config.logger
 
 
-def align_mouse(
+def align_mouse_legacy(
     project_path: str,
     session: str,
     video_format: str,
@@ -163,7 +163,7 @@ def align_mouse(
     center, size, theta = rect
 
     # crop image
-    out, shifted_points = crop_and_flip(
+    out, shifted_points = crop_and_flip_legacy(
         rect,
         img,
         pose_list_bordered,
@@ -187,7 +187,7 @@ def align_mouse(
     return images, points, time_series
 
 
-def alignment(
+def alignment_legacy(
     project_path: str,
     session: str,
     pose_ref_index: Tuple[int, int],
@@ -232,7 +232,7 @@ def alignment(
     """
     # read out data
     file_path = str(Path(project_path) / "data" / "raw" / f"{session}.nc")
-    data, data_mat = read_pose_estimation_file(
+    data, data_mat, _ = read_pose_estimation_file(
         file_path=file_path,
         file_type=pose_estimation_filetype,
         path_to_pose_nwb_series_data=path_to_pose_nwb_series_data,
@@ -272,7 +272,7 @@ def alignment(
         # Change this to an arbitrary number if you first want to test the code
         frame_count = len(data)
 
-    frames, n, time_series = align_mouse(
+    frames, n, time_series = align_mouse_legacy(
         project_path=project_path,
         session=session,
         video_format=video_format,
@@ -290,8 +290,8 @@ def alignment(
     return time_series, frames
 
 
-@save_state(model=EgocentricAlignmentFunctionSchema)
-def egocentric_alignment(
+# @save_state(model=EgocentricAlignmentFunctionSchema)
+def egocentric_alignment_legacy(
     config: str,
     pose_ref_index: Tuple[int, int] = (0, 1),
     crop_size: Tuple[int, int] = (300, 300),
@@ -367,7 +367,7 @@ def egocentric_alignment(
                 "Aligning session %s, Pose confidence value: %.2f"
                 % (session, confidence)
             )
-            egocentric_time_series, frames = alignment(
+            egocentric_time_series, frames = alignment_legacy(
                 project_path=project_path,
                 session=session,
                 pose_ref_index=pose_ref_index,
@@ -393,6 +393,132 @@ def egocentric_alignment(
             egocentric_time_series_shifted[y_shifted_indices, :] -= belly_Y_shift
             egocentric_time_series_shifted[x_shifted_indices, :] -= belly_X_shift
 
+            # Save new shifted file
+            np.save(
+                os.path.join(
+                    project_path,
+                    "data",
+                    "processed",
+                    session,
+                    session + "-PE-seq-legacy.npy",
+                ),
+                egocentric_time_series_shifted,
+            )
+
+        logger.info(
+            "Your data is now in the right format and you can call vame.create_trainset()"
+        )
+    except Exception as e:
+        logger.exception(f"{e}")
+        raise e
+    finally:
+        logger_config.remove_file_handler()
+
+
+@save_state(model=EgocentricAlignmentFunctionSchema)
+def egocentric_alignment(
+    config: str,
+    pose_ref_1: str = "snout",
+    pose_ref_2: str = "tailbase",
+    crop_size: Tuple[int, int] = (300, 300),
+    save_logs: bool = False,
+) -> None:
+    """
+    Egocentric alignment of behavioral videos.
+    Fills in the values in the "egocentric_alignment" key of the states.json file.
+    Creates the aligned time series for VAME at:
+    - project_name/
+        - data/
+            - processed/
+                - session_name/
+                    - session_name-PE-seq.npy
+    The produced .npy files contain the aligned time series data in the
+    shape of (num_dlc_features, num_video_frames).
+
+    Parameters
+    ----------
+    config : str
+        Path for the project config file.
+    pose_ref_1 : str, optional
+        Name of the first reference keypoint used for alignment. Defaults to "snout".
+    pose_ref_2 : str, optional
+        Name of the second reference keypoint used for alignment. Defaults to "tailbase".
+    crop_size : tuple, optional
+        Size to crop the video. Defaults to (300,300).
+    save_logs : bool, optional
+        Whether to also write logs to a file. Defaults to False.
+
+    Raises
+    ------
+    ValueError
+        If the config.yaml indicates that the data is already egocentric.
+    """
+    try:
+        config_file = Path(config).resolve()
+        cfg = read_config(str(config_file))
+        if cfg["egocentric_data"]:
+            raise ValueError(
+                "The config.yaml indicates that the data is egocentric. Please check the parameter 'egocentric_data'."
+            )
+        tqdm_stream = None
+
+        if save_logs:
+            log_path = Path(cfg["project_path"]) / "logs" / "egocentric_alignment.log"
+            logger_config.add_file_handler(str(log_path))
+            tqdm_stream = TqdmToLogger(logger=logger)
+
+        logger.info("Starting egocentric alignment")
+        project_path = cfg["project_path"]
+        sessions = cfg["session_names"]
+        confidence = cfg["pose_confidence"]
+        num_features = cfg["num_features"]
+
+        y_shifted_indices = np.arange(0, num_features, 2)
+        x_shifted_indices = np.arange(1, num_features, 2)
+        # reference_Y_ind = pose_ref_index[0] * 2
+        # reference_X_ind = (pose_ref_index[0] * 2) + 1
+
+        # call function and save into your VAME data folder
+        for session in sessions:
+            logger.info(
+                "Aligning session %s, Pose confidence value: %.2f"
+                % (session, confidence)
+            )
+            # read out data
+            file_path = str(Path(project_path) / "data" / "raw" / f"{session}.nc")
+            _, data_mat, ds = read_pose_estimation_file(file_path=file_path)
+
+            # get the coordinates for alignment from data table
+            # pose_list dimensions: (num_body_parts, num_frames, 3)
+            pose_list = []
+            for i in range(int(data_mat.shape[1] / 3)):
+                pose_list.append(data_mat[:, i * 3 : (i + 1) * 3])
+
+            frame_count = ds.position.time.shape[0]
+            keypoints_names = ds.keypoints.values
+
+            # Rows of the reference keypoint; the even/odd split must match
+            # y_shifted_indices (even rows) and x_shifted_indices (odd rows)
+            # used for the shift below, following the legacy row convention.
+            reference_Y_ind = np.where(keypoints_names == pose_ref_1)[0][0] * 2
+            reference_X_ind = reference_Y_ind + 1
+
+            pose_ref_index = (
+                np.where(keypoints_names == pose_ref_1)[0][0],
+                np.where(keypoints_names == pose_ref_2)[0][0],
+            )
+
+            egocentric_time_series = alignment(
+                crop_size=crop_size,
+                pose_list=pose_list,
+                pose_ref_index=pose_ref_index,
+                confidence=confidence,
+                frame_count=frame_count,
+                tqdm_stream=tqdm_stream,
+            )
+
+            # Shifting section added 2/29/2024 PN
+            egocentric_time_series_shifted = egocentric_time_series
+            reference_Y_shift = egocentric_time_series[reference_Y_ind, :]
+            reference_X_shift = egocentric_time_series[reference_X_ind, :]
+
+            egocentric_time_series_shifted[y_shifted_indices, :] -= reference_Y_shift
+            egocentric_time_series_shifted[x_shifted_indices, :] -= reference_X_shift
 
             # Save new shifted file
             np.save(
                 os.path.join(
                     project_path,
                     "data",
                     "processed",
                     session,
                     session + "-PE-seq.npy",
                 ),
                 egocentric_time_series_shifted,
             )
 
+            # Add new variable to the dataset
+            ds["position_aligned"] = (
+                ("time", "individuals", "keypoints", "space"),
+                egocentric_time_series_shifted.T.reshape(frame_count, 1, len(keypoints_names), 2),
+            )
+            # save to file
+            result_file = Path(project_path) / "data" / "processed" / session / f"{session}-aligned.nc"
+            ds.to_netcdf(result_file, engine="scipy")
 
         logger.info(
             "Your data is now in the right format and you can call vame.create_trainset()"
         )
     except Exception as e:
         logger.exception(f"{e}")
         raise e
     finally:
         logger_config.remove_file_handler()
+
+
+def alignment(
+    crop_size: Tuple[int, int],
+    pose_list: List[np.ndarray],
+    pose_ref_index: Tuple[int, int],
+    confidence: float,
+    frame_count: int,
+    tqdm_stream: Union[TqdmToLogger, None] = None,
+) -> np.ndarray:
+    """
+    Egocentric alignment of pose estimation data.
+
+    Parameters
+    ----------
+    crop_size : Tuple[int, int]
+        Size to crop the video frames.
+    pose_list : List[np.ndarray]
+        List of pose coordinates.
+    pose_ref_index : Tuple[int, int]
+        Pose reference indices.
+    confidence : float
+        Pose confidence threshold (currently unused; see the commented-out
+        filtering at the top of the function body).
+    frame_count : int
+        Number of frames to align.
+    tqdm_stream : Union[TqdmToLogger, None], optional
+        Tqdm stream to log the progress. Defaults to None.
+
+    Returns
+    -------
+    np.ndarray
+        Aligned time series data.
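+        Shape (len(pose_list) * 2, frame_count): two coordinate rows per
+        keypoint, one column per frame.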
+    """
+    points = []
+
+    # for i in pose_list:
+    #     for j in i:
+    #         if j[2] <= confidence:
+    #             j[0], j[1] = np.nan, np.nan
+
+    # for i in pose_list:
+    #     i = interpol_first_rows_nans(i)
+
+    for idx in tqdm.tqdm(
+        range(frame_count),
+        disable=False,
+        file=tqdm_stream,
+        desc="Align frames",
+    ):
+        # Read coordinates and add border
+        pose_list_bordered = []
+
+        for i in pose_list:
+            pose_list_bordered.append(
+                (int(i[idx][0] + crop_size[0]), int(i[idx][1] + crop_size[1]))
+            )
+
+        punkte = []
+        for i in pose_ref_index:
+            coord = [
+                pose_list_bordered[i][0],
+                pose_list_bordered[i][1],
+            ]
+            punkte.append(coord)
+
+        punkte = [punkte]
+        punkte = np.asarray(punkte)
+
+        # calculate minimal rectangle around the two reference keypoints
+        rect = cv.minAreaRect(punkte)
+
+        # change size in rect tuple structure to be equal to crop_size
+        lst = list(rect)
+        # lst[0] = center_belly
+        lst[1] = crop_size
+        rect = tuple(lst)
+
+        # rotate and shift the points into the aligned frame
+        shifted_points = crop_and_flip(
+            rect=rect,
+            points=pose_list_bordered,
+            ref_index=pose_ref_index,
+        )
+
+        points.append(shifted_points)
+
+    time_series = np.zeros((len(pose_list) * 2, frame_count))
+    for i in range(frame_count):
+        idx = 0
+        for j in range(len(pose_list)):
+            time_series[idx : idx + 2, i] = points[i][j]
+            idx += 2
+
+    return time_series
+
+
+def crop_and_flip(
+    rect: Tuple,
+    points: List[np.ndarray],
+    ref_index: Tuple[int, int],
+) -> List[np.ndarray]:
+    """
+    Rotate and shift the pose points into the given rectangle's frame,
+    flipping them if the reference points end up reversed.
+
+    Parameters
+    ----------
+    rect : Tuple
+        Rectangle coordinates (center, size, theta).
+    points : List[np.ndarray]
+        List of points.
+    ref_index : Tuple[int, int]
+        Reference indices for alignment.
+
+    Returns
+    -------
+    List[np.ndarray]
+        The shifted points.
+    """
+    # Read out rect structures and convert
+    center, size, theta = rect
+    center, size = tuple(map(int, center)), tuple(map(int, size))
+
+    # Get rotation matrix
+    M = cv.getRotationMatrix2D(center, theta, 1)
+
+    # shift DLC points
+    x_diff = center[0] - size[0] // 2
+    y_diff = center[1] - size[1] // 2
+    dlc_points_shifted = []
+    for i in points:
+        point = cv.transform(np.array([[[i[0], i[1]]]]), M)[0][0]
+        point[0] -= x_diff
+        point[1] -= y_diff
+        dlc_points_shifted.append(point)
+
+    # check if flipped correctly, otherwise flip again
+    if dlc_points_shifted[ref_index[1]][0] >= dlc_points_shifted[ref_index[0]][0]:
+        rect = (
+            (size[0] // 2, size[0] // 2),
+            size,
+            180,
+        )  # should second value be size[1]? Is this relevant to the flip? 
3/5/24 KKL + center, size, theta = rect + center, size = tuple(map(int, center)), tuple(map(int, size)) + + # Get rotation matrix + M = cv.getRotationMatrix2D(center, theta, 1) + + # shift DLC points + x_diff = center[0] - size[0] // 2 + y_diff = center[1] - size[1] // 2 + + points = dlc_points_shifted + dlc_points_shifted = [] + + for i in points: + point = cv.transform(np.array([[[i[0], i[1]]]]), M)[0][0] + point[0] -= x_diff + point[1] -= y_diff + dlc_points_shifted.append(point) + + return dlc_points_shifted diff --git a/src/vame/preprocessing/align_new.py b/src/vame/preprocessing/align_new.py new file mode 100644 index 00000000..94635414 --- /dev/null +++ b/src/vame/preprocessing/align_new.py @@ -0,0 +1,83 @@ +import numpy as np +import pandas as pd +import xarray as xr + + +def align_time_series(data, keypoint1, keypoint2, confidence_threshold): + """ + Aligns the time series by first centralizing all positions around the first keypoint + and then applying rotation to align with the line connecting the two keypoints. + Handles low-confidence points by replacing them with NaNs and interpolating. + + Parameters: + - data (xarray.Dataset): The input dataset. + - keypoint1 (str): The name of the first reference keypoint. + - keypoint2 (str): The name of the second reference keypoint. + - confidence_threshold (float): Confidence threshold below which points are replaced with NaNs. + + Returns: + - xarray.Dataset: The dataset with a new 'position_aligned' variable. + """ + # Extract keypoint indices + keypoints = data.coords["keypoints"].values + idx1 = np.where(keypoints == keypoint1)[0][0] + idx2 = np.where(keypoints == keypoint2)[0][0] + + # Extract positions and confidence values + positions = data["position"].values # Shape: (time, individuals, keypoints, space) + confidence = data["confidence"].values # Shape: (time, individuals, keypoints) + + aligned_positions = np.empty_like(positions) # Preallocate aligned positions + + # Loop over individuals + for ind in range(positions.shape[1]): + individual_positions = positions[ + :, ind, :, : + ] # Shape: (time, keypoints, space) + individual_confidence = confidence[:, ind, :] # Shape: (time, keypoints) + + # Replace low-confidence points with NaN + for kp in range(individual_positions.shape[1]): # Loop over keypoints + for dim in range(2): # Loop over x and y + low_confidence = individual_confidence[:, kp] < confidence_threshold + individual_positions[low_confidence, kp, dim] = np.nan + + # Interpolate NaN values + for kp in range(individual_positions.shape[1]): # Loop over keypoints + for dim in range(2): # Loop over x and y + series = pd.Series(individual_positions[:, kp, dim]) + individual_positions[:, kp, dim] = ( + series.interpolate(method="linear", limit_direction="both") + .bfill() # Backward fill for initial NaNs + .ffill() # Forward fill for final NaNs + .values + ) + + # Centralize all positions around the first keypoint + centralized_positions = ( + individual_positions - individual_positions[:, idx1, :][:, np.newaxis, :] + ) + + # Calculate vectors between keypoints + vector = centralized_positions[:, idx2, :] # Vector from keypoint1 to keypoint2 + angles = np.arctan2(vector[:, 1], vector[:, 0]) # Angles in radians + + # Apply rotation to align the second keypoint along the x-axis + for t in range(centralized_positions.shape[0]): + rotation_matrix = np.array( + [ + [np.cos(-angles[t]), -np.sin(-angles[t])], + [np.sin(-angles[t]), np.cos(-angles[t])], + ] + ) + frame_positions = centralized_positions[t, :, :] + 
rotated_positions = (rotation_matrix @ frame_positions.T).T + aligned_positions[t, ind, :, :] = rotated_positions + + # Add new variable to the dataset + data["position_aligned"] = ( + ("time", "individuals", "keypoints", "space"), + aligned_positions, + ) + + return data diff --git a/src/vame/util/data_manipulation.py b/src/vame/util/data_manipulation.py index 39f9789b..97bba269 100644 --- a/src/vame/util/data_manipulation.py +++ b/src/vame/util/data_manipulation.py @@ -1,13 +1,14 @@ -import numpy as np from typing import List, Tuple, Optional +import numpy as np +import pandas as pd +import xarray as xr import cv2 as cv import os -from scipy.ndimage import median_filter import tqdm +from scipy.ndimage import median_filter from pynwb import NWBHDF5IO from pynwb.file import NWBFile from hdmf.utils import LabelledDict -import pandas as pd from vame.schemas.project import PoseEstimationFiletype from vame.logging.logger import VameLogger @@ -85,7 +86,7 @@ def read_pose_estimation_file( file_path: str, file_type: Optional[PoseEstimationFiletype] = None, path_to_pose_nwb_series_data: Optional[str] = None, -) -> Tuple[pd.DataFrame, np.ndarray]: +) -> Tuple[pd.DataFrame, np.ndarray, xr.Dataset]: """ Read pose estimation file. @@ -106,7 +107,7 @@ def read_pose_estimation_file( ds = load_vame_dataset(ds_path=file_path) data = nc_to_dataframe(ds) data_mat = pd.DataFrame.to_numpy(data) - return data, data_mat + return data, data_mat, ds # if file_type == PoseEstimationFiletype.csv: # data = pd.read_csv(file_path, skiprows=2, index_col=0) # if "coords" in data: @@ -211,7 +212,7 @@ def interpol_first_rows_nans(arr: np.ndarray) -> np.ndarray: return arr -def crop_and_flip( +def crop_and_flip_legacy( rect: Tuple, src: np.ndarray, points: List[np.ndarray], diff --git a/src/vame/util/gif_pose_helper.py b/src/vame/util/gif_pose_helper.py index b9a731fb..9c384497 100644 --- a/src/vame/util/gif_pose_helper.py +++ b/src/vame/util/gif_pose_helper.py @@ -6,7 +6,7 @@ from vame.logging.logger import VameLogger from vame.util.data_manipulation import ( interpol_first_rows_nans, - crop_and_flip, + crop_and_flip_legacy, background, read_pose_estimation_file, ) @@ -191,7 +191,7 @@ def get_animal_frames( center, size, theta = rect # crop image - out, shifted_points = crop_and_flip( + out, shifted_points = crop_and_flip_legacy( rect, img, pose_list_bordered, diff --git a/tests/test_util.py b/tests/test_util.py index 354d91c4..28426e53 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -13,7 +13,11 @@ def test_pose_to_numpy_file_exists(setup_project_and_convert_pose_to_numpy): 0 ] file_path = os.path.join( - project_path, "data", "processed", file_name, f"{file_name}-PE-seq.npy" + project_path, + "data", + "processed", + file_name, + f"{file_name}-PE-seq.npy", ) assert os.path.exists(file_path) @@ -25,7 +29,11 @@ def test_egocentric_alignment_file_is_created(setup_project_and_align_egocentric project_path = setup_project_and_align_egocentric["config_data"]["project_path"] file_name = setup_project_and_align_egocentric["config_data"]["session_names"][0] file_path = os.path.join( - project_path, "data", "processed", file_name, f"{file_name}-PE-seq.npy" + project_path, + "data", + "processed", + file_name, + f"{file_name}-PE-seq.npy", ) assert os.path.exists(file_path) From 6103a4f5022d868a44cc1abdd5761a24ea383508 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 29 Nov 2024 12:24:54 +0100 Subject: [PATCH 04/77] black --- pyproject.toml | 3 + src/vame/__init__.py | 1 + src/vame/analysis/community_analysis.py | 39 
++-------- src/vame/analysis/generative_functions.py | 7 +- src/vame/analysis/gif_creator.py | 15 +--- src/vame/analysis/pose_segmentation.py | 14 +--- src/vame/analysis/tree_hierarchy.py | 8 +- src/vame/analysis/umap.py | 7 +- src/vame/analysis/videowriter.py | 15 +--- src/vame/initialize_project/new.py | 16 +--- src/vame/logging/logger.py | 15 +--- src/vame/model/create_training.py | 31 ++------ src/vame/model/evaluate.py | 48 +++--------- src/vame/model/rnn_vae.py | 67 ++++------------- src/vame/pipeline.py | 8 +- src/vame/preprocessing/align_egocentrical.py | 26 ++----- src/vame/preprocessing/align_new.py | 8 +- src/vame/preprocessing/clean_timeseries.py | 77 ++++++++++++++++++++ src/vame/preprocessing/preprocessing.py | 27 +++++++ src/vame/schemas/project.py | 4 +- src/vame/schemas/states.py | 20 ++--- src/vame/util/auxiliary.py | 5 +- src/vame/util/cli.py | 4 +- src/vame/util/csv_to_npy.py | 8 +- src/vame/util/data_manipulation.py | 34 +++++++-- src/vame/util/gif_pose_helper.py | 4 +- src/vame/util/report.py | 42 +++-------- tests/test_analysis.py | 46 +++--------- tests/test_initialize_project.py | 4 +- tests/test_model.py | 56 +++----------- tests/test_util.py | 8 +- 31 files changed, 251 insertions(+), 416 deletions(-) create mode 100644 src/vame/preprocessing/clean_timeseries.py create mode 100644 src/vame/preprocessing/preprocessing.py diff --git a/pyproject.toml b/pyproject.toml index a423e742..61f7aecc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,3 +28,6 @@ where = ["src"] [tool.pytest.ini_options] pythonpath = [".", "src"] testpaths = ["tests"] + +[tool.black] +line-length = 119 diff --git a/src/vame/__init__.py b/src/vame/__init__.py index 3e43ad94..b3bc56a6 100644 --- a/src/vame/__init__.py +++ b/src/vame/__init__.py @@ -3,6 +3,7 @@ sys.dont_write_bytecode = True from vame.initialize_project import init_new_project +from vame.preprocessing.preprocessing import preprocessing from vame.model import create_trainset from vame.model import train_model from vame.model import evaluate_model diff --git a/src/vame/analysis/community_analysis.py b/src/vame/analysis/community_analysis.py index ca0232ba..60dbcffb 100644 --- a/src/vame/analysis/community_analysis.py +++ b/src/vame/analysis/community_analysis.py @@ -250,12 +250,7 @@ def get_motif_labels( file_labels = np.load( os.path.join( path_to_dir, - str(n_clusters) - + "_" - + segmentation_algorithm - + "_label_" - + session - + ".npy", + str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy", ) ) shape = len(file_labels) @@ -276,12 +271,7 @@ def get_motif_labels( file_labels = np.load( os.path.join( path_to_dir, - str(n_clusters) - + "_" - + segmentation_algorithm - + "_label_" - + session - + ".npy", + str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy", ) )[:min_frames] community_label.extend(file_labels) @@ -390,11 +380,7 @@ def create_cohort_community_bag( add = input("Extend list or add in the end? (ext/end)") if add == "ext": motif_idx = int(input("Which motif number? ")) - list_idx = int( - input( - "At which position in the list? (pythonic indexing starts at 0) " - ) - ) + list_idx = int(input("At which position in the list? (pythonic indexing starts at 0) ")) community_bag[list_idx].append(motif_idx) if add == "end": motif_idx = int(input("Which motif number? 
")) @@ -440,9 +426,7 @@ def get_cohort_community_labels( for j in range(len(clust)): find_clust = np.where(motif_labels == clust[j])[0] community_labels[find_clust] = i - community_labels = np.int64( - scipy.signal.medfilt(community_labels, median_filter_size) - ) + community_labels = np.int64(scipy.signal.medfilt(community_labels, median_filter_size)) community_labels_all.append(community_labels) return community_labels_all @@ -468,12 +452,7 @@ def save_cohort_community_labels_per_file( file_labels = np.load( os.path.join( path_to_dir, - str(n_clusters) - + "_" - + segmentation_algorithm - + "_label_" - + session - + ".npy", + str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy", ) ) community_labels = get_cohort_community_labels( @@ -640,9 +619,7 @@ def community( ), cohort_community_bag, ) - with open( - os.path.join(path_to_dir, "hierarchy" + ".pkl"), "wb" - ) as fp: # Pickling + with open(os.path.join(path_to_dir, "hierarchy" + ".pkl"), "wb") as fp: # Pickling pickle.dump(cohort_community_bag, fp) # Added by Luiz - 11/10/2024 @@ -659,9 +636,7 @@ def community( # # Work in Progress - cohort is False else: - raise NotImplementedError( - "Community analysis for cohort=False is not supported yet." - ) + raise NotImplementedError("Community analysis for cohort=False is not supported yet.") # labels = get_labels(cfg, files, model_name, n_clusters, parametrization) # transition_matrices = compute_transition_matrices( # files, diff --git a/src/vame/analysis/generative_functions.py b/src/vame/analysis/generative_functions.py index 6849e5d2..075a6b0e 100644 --- a/src/vame/analysis/generative_functions.py +++ b/src/vame/analysis/generative_functions.py @@ -333,12 +333,7 @@ def generative_model( os.path.join( path_to_file, "", - str(n_clusters) - + "_" - + segmentation_algorithm - + "_label_" - + session - + ".npy", + str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy", ) ) return random_generative_samples_motif( diff --git a/src/vame/analysis/gif_creator.py b/src/vame/analysis/gif_creator.py index 6a864298..d28c4d5c 100644 --- a/src/vame/analysis/gif_creator.py +++ b/src/vame/analysis/gif_creator.py @@ -102,9 +102,7 @@ def create_video( frame = frames[i] ax2.imshow(frame, cmap=cmap_reversed) # ax2.set_title("Motif %d,\n Community: %s" % (lbl, motifs[lbl]), fontsize=10) - fig.savefig( - os.path.join(path_to_file, "gif_frames", session + "gif_%d.png") % i - ) + fig.savefig(os.path.join(path_to_file, "gif_frames", session + "gif_%d.png") % i) def gif( @@ -205,9 +203,7 @@ def gif( random_state=cfg["random_state"], ) - latent_vector = np.load( - os.path.join(path_to_file, "", "latent_vector_" + session + ".npy") - ) + latent_vector = np.load(os.path.join(path_to_file, "", "latent_vector_" + session + ".npy")) num_points = cfg["num_points"] if num_points > latent_vector.shape[0]: @@ -228,12 +224,7 @@ def gif( umap_label = np.load( os.path.join( path_to_file, - str(n_clusters) - + "_" - + segmentation_algorithm - + "_label_" - + session - + ".npy", + str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy", ) ) elif label == "community": diff --git a/src/vame/analysis/pose_segmentation.py b/src/vame/analysis/pose_segmentation.py index 1fce1670..e6520835 100644 --- a/src/vame/analysis/pose_segmentation.py +++ b/src/vame/analysis/pose_segmentation.py @@ -82,15 +82,9 @@ def embedd_latent_vectors( data_sample_np = data[:, i : temp_win + i].T data_sample_np = np.reshape(data_sample_np, (1, temp_win, num_features)) if use_gpu: - h_n = 
model.encoder( - torch.from_numpy(data_sample_np) - .type("torch.FloatTensor") - .cuda() - ) + h_n = model.encoder(torch.from_numpy(data_sample_np).type("torch.FloatTensor").cuda()) else: - h_n = model.encoder( - torch.from_numpy(data_sample_np).type("torch.FloatTensor").to() - ) + h_n = model.encoder(torch.from_numpy(data_sample_np).type("torch.FloatTensor").to()) mu, _, _ = model.lmbda(h_n) latent_vector_list.append(mu.cpu().data.numpy()) @@ -406,9 +400,7 @@ def segment_session( ) else: - logger.info( - f"\nSegmentation with {n_clusters} k-means clusters already exists for model {model_name}" - ) + logger.info(f"\nSegmentation with {n_clusters} k-means clusters already exists for model {model_name}") if os.path.exists( os.path.join( diff --git a/src/vame/analysis/tree_hierarchy.py b/src/vame/analysis/tree_hierarchy.py index a718e861..652648a4 100644 --- a/src/vame/analysis/tree_hierarchy.py +++ b/src/vame/analysis/tree_hierarchy.py @@ -42,9 +42,7 @@ def hierarchy_pos( raise TypeError("cannot use hierarchy_pos on a graph that is not a tree") if root is None: if isinstance(G, nx.DiGraph): - root = next( - iter(nx.topological_sort(G)) - ) # allows back compatibility with nx version 1.11 + root = next(iter(nx.topological_sort(G))) # allows back compatibility with nx version 1.11 else: root = random.choice(list(G.nodes)) @@ -121,9 +119,7 @@ def merge_func( for i in range(n_clusters): for j in range(n_clusters): try: - cost = motif_norm[i] + motif_norm[j] / np.abs( - transition_matrix[i, j] + transition_matrix[j, i] - ) + cost = motif_norm[i] + motif_norm[j] / np.abs(transition_matrix[i, j] + transition_matrix[j, i]) except ZeroDivisionError: print( "Error: Transition probabilities between motif " diff --git a/src/vame/analysis/umap.py b/src/vame/analysis/umap.py index 2216d43d..1fd67b25 100644 --- a/src/vame/analysis/umap.py +++ b/src/vame/analysis/umap.py @@ -328,12 +328,7 @@ def visualization( os.path.join( path_to_file, "", - str(n_clusters) - + "_" - + segmentation_algorithm - + "_label_" - + session - + ".npy", + str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy", ) ) output_figure = umap_label_vis( diff --git a/src/vame/analysis/videowriter.py b/src/vame/analysis/videowriter.py index 40f78bdf..a03656aa 100644 --- a/src/vame/analysis/videowriter.py +++ b/src/vame/analysis/videowriter.py @@ -71,19 +71,12 @@ def create_cluster_videos( labels = np.load( os.path.join( path_to_file, - str(n_clusters) - + "_" - + segmentation_algorithm - + "_label_" - + session - + ".npy", + str(n_clusters) + "_" + segmentation_algorithm + "_label_" + session + ".npy", ) ) if flag == "community": if cohort: - logger.info( - "Cohort community videos getting created for " + session + " ..." - ) + logger.info("Cohort community videos getting created for " + session + " ...") labels = np.load( os.path.join( path_to_file, @@ -109,9 +102,7 @@ def create_cluster_videos( ) capture = cv.VideoCapture(video_file_path) if not capture.isOpened(): - raise ValueError( - f"Video capture could not be opened. Ensure the video file is valid.\n {video_file_path}" - ) + raise ValueError(f"Video capture could not be opened. 
Ensure the video file is valid.\n {video_file_path}") width = capture.get(cv.CAP_PROP_FRAME_WIDTH) height = capture.get(cv.CAP_PROP_FRAME_HEIGHT) fps = 25 # capture.get(cv.CAP_PROP_FPS) diff --git a/src/vame/initialize_project/new.py b/src/vame/initialize_project/new.py index 784d72fa..b6f05328 100644 --- a/src/vame/initialize_project/new.py +++ b/src/vame/initialize_project/new.py @@ -110,20 +110,14 @@ def init_new_project( for i in videos: # Check if it is a folder if os.path.isdir(i): - vids_in_dir = [ - os.path.join(i, vp) for vp in os.listdir(i) if video_type in vp - ] + vids_in_dir = [os.path.join(i, vp) for vp in os.listdir(i) if video_type in vp] vids = vids + vids_in_dir if len(vids_in_dir) == 0: logger.info(f"No videos found in {i}") - logger.info( - f"Perhaps change the video_type, which is currently set to: {video_type}" - ) + logger.info(f"Perhaps change the video_type, which is currently set to: {video_type}") else: videos = vids - logger.info( - f"{len(vids_in_dir)} videos from the directory {i} were added to the project." - ) + logger.info(f"{len(vids_in_dir)} videos from the directory {i} were added to the project.") else: if os.path.isfile(i): vids = vids + [i] @@ -210,9 +204,7 @@ def init_new_project( unique_num_features = list(set(num_features_list)) if len(unique_num_features) > 1: - raise ValueError( - "All pose estimation files must have the same number of features." - ) + raise ValueError("All pose estimation files must have the same number of features.") if config_kwargs is None: config_kwargs = {} diff --git a/src/vame/logging/logger.py b/src/vame/logging/logger.py index 9fd68753..f5fb0b61 100644 --- a/src/vame/logging/logger.py +++ b/src/vame/logging/logger.py @@ -6,8 +6,7 @@ class VameLogger: LOG_FORMAT = ( - "%(asctime)-15s.%(msecs)d %(levelname)-5s --- [%(threadName)s]" - " %(name)-15s : %(lineno)d : %(message)s" + "%(asctime)-15s.%(msecs)d %(levelname)-5s --- [%(threadName)s]" " %(name)-15s : %(lineno)d : %(message)s" ) LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S" @@ -19,9 +18,7 @@ def __init__( ): self.log_level = log_level self.file_handler = None - logging.basicConfig( - level=log_level, format=self.LOG_FORMAT, datefmt=self.LOG_DATE_FORMAT - ) + logging.basicConfig(level=log_level, format=self.LOG_FORMAT, datefmt=self.LOG_DATE_FORMAT) self.logger = logging.getLogger(f"{base_name}") if self.logger.hasHandlers(): self.logger.handlers.clear() @@ -29,9 +26,7 @@ def __init__( self.logger.setLevel(self.log_level) # Stream handler for logging to stdout stream_handler = logging.StreamHandler() - stream_handler.setFormatter( - logging.Formatter(self.LOG_FORMAT, self.LOG_DATE_FORMAT) - ) + stream_handler.setFormatter(logging.Formatter(self.LOG_FORMAT, self.LOG_DATE_FORMAT)) self.logger.addHandler(stream_handler) self.logger.propagate = False @@ -56,9 +51,7 @@ def add_file_handler(self, file_path: str): f.write(f"{line_break}[LOGGING STARTED AT: {handler_datetime}]") self.file_handler = logging.FileHandler(file_path, mode="a") - self.file_handler.setFormatter( - logging.Formatter(self.LOG_FORMAT, self.LOG_DATE_FORMAT) - ) + self.file_handler.setFormatter(logging.Formatter(self.LOG_FORMAT, self.LOG_DATE_FORMAT)) self.logger.addHandler(self.file_handler) def remove_file_handler(self): diff --git a/src/vame/model/create_training.py b/src/vame/model/create_training.py index ea19312f..a9f554c4 100644 --- a/src/vame/model/create_training.py +++ b/src/vame/model/create_training.py @@ -98,9 +98,7 @@ def plot_check_parameter( plt.title("Original signal z-scored") plt.legend() - 
logger.info( - "Please run the function with check_parameter=False if you are happy with the results" - ) + logger.info("Please run the function with check_parameter=False if you are happy with the results") def traindata_aligned( @@ -170,10 +168,7 @@ def traindata_aligned( if cfg["robust"]: iqr_val = iqr(X_z) - logger.info( - "IQR value: %.2f, IQR cutoff: %.2f" - % (iqr_val, cfg["iqr_factor"] * iqr_val) - ) + logger.info("IQR value: %.2f, IQR cutoff: %.2f" % (iqr_val, cfg["iqr_factor"] * iqr_val)) for i in range(X_z.shape[0]): for marker in range(X_z.shape[1]): if X_z[i, marker] > cfg["iqr_factor"] * iqr_val: @@ -330,10 +325,7 @@ def traindata_fixed( if cfg["robust"]: iqr_val = iqr(X_z) - logger.info( - "IQR value: %.2f, IQR cutoff: %.2f" - % (iqr_val, cfg["iqr_factor"] * iqr_val) - ) + logger.info("IQR value: %.2f, IQR cutoff: %.2f" % (iqr_val, cfg["iqr_factor"] * iqr_val)) for i in range(X_z.shape[0]): for marker in range(X_z.shape[1]): if X_z[i, marker] > cfg["iqr_factor"] * iqr_val: @@ -373,9 +365,7 @@ def traindata_fixed( else: if pose_ref_index is None: - raise ValueError( - "Please provide a pose reference index for training on fixed data. E.g. [0,5]" - ) + raise ValueError("Please provide a pose reference index for training on fixed data. E.g. [0,5]") # save numpy arrays the the test/train info: np.save( os.path.join( @@ -493,15 +483,10 @@ def create_trainset( logger.info("Creating training dataset...") if cfg["robust"]: - logger.info( - "Using robust setting to eliminate outliers! IQR factor: %d" - % cfg["iqr_factor"] - ) + logger.info("Using robust setting to eliminate outliers! IQR factor: %d" % cfg["iqr_factor"]) if not fixed: - logger.info( - "Creating trainset from the vame.egocentrical_alignment() output " - ) + logger.info("Creating trainset from the vame.egocentrical_alignment() output ") traindata_aligned( cfg, sessions, @@ -522,9 +507,7 @@ def create_trainset( ) if not check_parameter: - logger.info( - "A training and test set has been created. Next step: vame.train_model()" - ) + logger.info("A training and test set has been created. 
Next step: vame.train_model()") except Exception as e: logger.exception(str(e)) diff --git a/src/vame/model/evaluate.py b/src/vame/model/evaluate.py index 5372c922..de2111c4 100644 --- a/src/vame/model/evaluate.py +++ b/src/vame/model/evaluate.py @@ -69,18 +69,10 @@ def plot_reconstruction( x = x.permute(0, 2, 1) if use_gpu: data = x[:, :seq_len_half, :].type("torch.FloatTensor").cuda() - data_fut = ( - x[:, seq_len_half : seq_len_half + FUTURE_STEPS, :] - .type("torch.FloatTensor") - .cuda() - ) + data_fut = x[:, seq_len_half : seq_len_half + FUTURE_STEPS, :].type("torch.FloatTensor").cuda() else: data = x[:, :seq_len_half, :].type("torch.FloatTensor").to() - data_fut = ( - x[:, seq_len_half : seq_len_half + FUTURE_STEPS, :] - .type("torch.FloatTensor") - .to() - ) + data_fut = x[:, seq_len_half : seq_len_half + FUTURE_STEPS, :].type("torch.FloatTensor").to() if FUTURE_DECODER: x_tilde, future, latent, mu, logvar = model(data) @@ -99,9 +91,7 @@ def plot_reconstruction( if FUTURE_DECODER: fig, axs = plt.subplots(2, 5) - fig.suptitle( - "Reconstruction [top] and future prediction [bottom] of input sequence" - ) + fig.suptitle("Reconstruction [top] and future prediction [bottom] of input sequence") for i in range(5): axs[0, i].plot(data_orig[i, ...], color="k", label="Sequence Data") axs[0, i].plot( @@ -129,9 +119,7 @@ def plot_reconstruction( fig.set_tight_layout(True) if not suffix: fig.savefig( - os.path.join( - filepath, "evaluate", "Reconstruction_" + model_name + ".png" - ), + os.path.join(filepath, "evaluate", "Reconstruction_" + model_name + ".png"), bbox_inches="tight", ) elif suffix: @@ -174,12 +162,8 @@ def plot_loss( basepath = os.path.join(cfg["project_path"], "model", "model_losses") train_loss = np.load(os.path.join(basepath, "train_losses_" + model_name + ".npy")) test_loss = np.load(os.path.join(basepath, "test_losses_" + model_name + ".npy")) - mse_loss_train = np.load( - os.path.join(basepath, "mse_train_losses_" + model_name + ".npy") - ) - mse_loss_test = np.load( - os.path.join(basepath, "mse_test_losses_" + model_name + ".npy") - ) + mse_loss_train = np.load(os.path.join(basepath, "mse_train_losses_" + model_name + ".npy")) + mse_loss_test = np.load(os.path.join(basepath, "mse_test_losses_" + model_name + ".npy")) km_losses = np.load(os.path.join(basepath, "kmeans_losses_" + model_name + ".npy")) kl_loss = np.load(os.path.join(basepath, "kl_losses_" + model_name + ".npy")) fut_loss = np.load(os.path.join(basepath, "fut_losses_" + model_name + ".npy")) @@ -196,9 +180,7 @@ def plot_loss( ax1.plot(kl_loss, label="KL-Loss") ax1.plot(fut_loss, label="Prediction-Loss") ax1.legend() - fig.savefig( - os.path.join(filepath, "evaluate", "MSE-and-KL-Loss" + model_name + ".png") - ) + fig.savefig(os.path.join(filepath, "evaluate", "MSE-and-KL-Loss" + model_name + ".png")) def eval_temporal( @@ -308,9 +290,7 @@ def eval_temporal( ) ) elif snapshot: - model.load_state_dict( - torch.load(snapshot), map_location=torch.device("cpu") - ) + model.load_state_dict(torch.load(snapshot), map_location=torch.device("cpu")) model.eval() # toggle evaluation mode testset = SEQUENCE_DATASET( @@ -320,9 +300,7 @@ def eval_temporal( temporal_window=TEMPORAL_WINDOW, logger_config=logger_config, ) - test_loader = Data.DataLoader( - testset, batch_size=TEST_BATCH_SIZE, shuffle=True, drop_last=True - ) + test_loader = Data.DataLoader(testset, batch_size=TEST_BATCH_SIZE, shuffle=True, drop_last=True) if not snapshot: plot_reconstruction( @@ -405,13 +383,9 @@ def evaluate_model( if not use_snapshots: 
eval_temporal(cfg, use_gpu, model_name, fixed) # suffix=suffix elif use_snapshots: - snapshots = os.listdir( - os.path.join(cfg["project_path"], "model", "best_model", "snapshots") - ) + snapshots = os.listdir(os.path.join(cfg["project_path"], "model", "best_model", "snapshots")) for snap in snapshots: - fullpath = os.path.join( - cfg["project_path"], "model", "best_model", "snapshots", snap - ) + fullpath = os.path.join(cfg["project_path"], "model", "best_model", "snapshots", snap) epoch = snap.split("_")[-1] eval_temporal( cfg, diff --git a/src/vame/model/rnn_vae.py b/src/vame/model/rnn_vae.py index ca0e5b49..cc1c88b6 100644 --- a/src/vame/model/rnn_vae.py +++ b/src/vame/model/rnn_vae.py @@ -175,9 +175,7 @@ def kl_annealing( elif function == "sigmoid": new_weight = float(1 / (1 + np.exp(-0.9 * (epoch - annealtime)))) else: - raise NotImplementedError( - 'currently only "linear" and "sigmoid" are implemented' - ) + raise NotImplementedError('currently only "linear" and "sigmoid" are implemented') return new_weight else: @@ -302,18 +300,10 @@ def train( data_item = data_item.permute(0, 2, 1) if use_gpu: data = data_item[:, :seq_len_half, :].type("torch.FloatTensor").cuda() - fut = ( - data_item[:, seq_len_half : seq_len_half + future_steps, :] - .type("torch.FloatTensor") - .cuda() - ) + fut = data_item[:, seq_len_half : seq_len_half + future_steps, :].type("torch.FloatTensor").cuda() else: data = data_item[:, :seq_len_half, :].type("torch.FloatTensor").to() - fut = ( - data_item[:, seq_len_half : seq_len_half + future_steps, :] - .type("torch.FloatTensor") - .to() - ) + fut = data_item[:, seq_len_half : seq_len_half + future_steps, :].type("torch.FloatTensor").to() if noise is True: data_gaussian = gaussian(data, True, seq_len_half) @@ -327,12 +317,7 @@ def train( kmeans_loss = cluster_loss(latent.T, kloss, klmbda, bsize) kl_loss = kullback_leibler_loss(mu, logvar) kl_weight = kl_annealing(epoch, kl_start, annealtime, anneal_function) - loss = ( - rec_loss - + fut_rec_loss - + BETA * kl_weight * kl_loss - + kl_weight * kmeans_loss - ) + loss = rec_loss + fut_rec_loss + BETA * kl_weight * kl_loss + kl_weight * kmeans_loss fut_loss += fut_rec_loss.item() else: data_tilde, latent, mu, logvar = model(data_gaussian) @@ -536,15 +521,9 @@ def train_model(config: str, save_logs: bool = False) -> None: fixed = cfg["egocentric_data"] logger.info("Train Variational Autoencoder - model name: %s \n" % model_name) - if not os.path.exists( - os.path.join(cfg["project_path"], "model", "best_model", "") - ): + if not os.path.exists(os.path.join(cfg["project_path"], "model", "best_model", "")): os.mkdir(os.path.join(cfg["project_path"], "model", "best_model", "")) - os.mkdir( - os.path.join( - cfg["project_path"], "model", "best_model", "snapshots", "" - ) - ) + os.mkdir(os.path.join(cfg["project_path"], "model", "best_model", "snapshots", "")) os.mkdir(os.path.join(cfg["project_path"], "model", "model_losses", "")) # make sure torch uses cuda for GPU computing @@ -555,9 +534,7 @@ def train_model(config: str, save_logs: bool = False) -> None: logger.info("GPU used: {}".format(torch.cuda.get_device_name(0))) else: torch.device("cpu") - logger.info( - "warning, a GPU was not found... proceeding with CPU (slow!) \n" - ) + logger.info("warning, a GPU was not found... proceeding with CPU (slow!) 
\n") # raise NotImplementedError('GPU Computing is required!') # HYPERPARAMETERS @@ -687,16 +664,12 @@ def train_model(config: str, save_logs: bool = False) -> None: ) ) try: - logger.info( - "Loading pretrained weights from %s\n" % pretrained_model - ) + logger.info("Loading pretrained weights from %s\n" % pretrained_model) model.load_state_dict(torch.load(pretrained_model)) KL_START = 0 ANNEALTIME = 1 except Exception: - logger.error( - "Could not load pretrained model. Check file path in config.yaml." - ) + logger.error("Could not load pretrained model. Check file path in config.yaml.") """ DATASET """ trainset = SEQUENCE_DATASET( @@ -712,19 +685,14 @@ def train_model(config: str, save_logs: bool = False) -> None: temporal_window=TEMPORAL_WINDOW, ) - train_loader = Data.DataLoader( - trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, drop_last=True - ) - test_loader = Data.DataLoader( - testset, batch_size=TEST_BATCH_SIZE, shuffle=True, drop_last=True - ) + train_loader = Data.DataLoader(trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, drop_last=True) + test_loader = Data.DataLoader(testset, batch_size=TEST_BATCH_SIZE, shuffle=True, drop_last=True) optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, amsgrad=True) if optimizer_scheduler: logger.info( - "Scheduler step size: %d, Scheduler gamma: %.2f\n" - % (scheduler_step_size, cfg["scheduler_gamma"]) + "Scheduler step size: %d, Scheduler gamma: %.2f\n" % (scheduler_step_size, cfg["scheduler_gamma"]) ) # Thanks to @alexcwsmith for the optimized scheduler contribution scheduler = ReduceLROnPlateau( @@ -737,9 +705,7 @@ def train_model(config: str, save_logs: bool = False) -> None: verbose=True, ) else: - scheduler = StepLR( - optimizer, step_size=scheduler_step_size, gamma=1, last_epoch=-1 - ) + scheduler = StepLR(optimizer, step_size=scheduler_step_size, gamma=1, last_epoch=-1) logger.info("Start training... 
") for epoch in tqdm( @@ -817,12 +783,7 @@ def train_model(config: str, save_logs: bool = False) -> None: "model", "best_model", "snapshots", - model_name - + "_" - + cfg["project_name"] - + "_epoch_" - + str(epoch) - + ".pkl", + model_name + "_" + cfg["project_name"] + "_epoch_" + str(epoch) + ".pkl", ), ) diff --git a/src/vame/pipeline.py b/src/vame/pipeline.py index 5fb73371..cba1b344 100644 --- a/src/vame/pipeline.py +++ b/src/vame/pipeline.py @@ -64,9 +64,7 @@ def get_raw_datasets(self) -> xr.Dataset: datasets = list() attributes = list() for session in sessions: - ds_path = ( - Path(self.config["project_path"]) / "data" / "raw" / f"{session}.nc" - ) + ds_path = Path(self.config["project_path"]) / "data" / "raw" / f"{session}.nc" ds = load_vame_dataset(ds_path=ds_path) ds = ds.expand_dims({"session": [session]}) datasets.append(ds) @@ -78,9 +76,7 @@ def get_raw_datasets(self) -> xr.Dataset: dss_attrs.setdefault(key, []).append(value) for key, values in dss_attrs.items(): unique_values = unique_in_order(values) # Maintain order of unique values - dss_attrs[key] = ( - unique_values[0] if len(unique_values) == 1 else unique_values - ) + dss_attrs[key] = unique_values[0] if len(unique_values) == 1 else unique_values for key, value in dss_attrs.items(): dss.attrs[key] = value return dss diff --git a/src/vame/preprocessing/align_egocentrical.py b/src/vame/preprocessing/align_egocentrical.py index c2d3b1cb..36ba8240 100644 --- a/src/vame/preprocessing/align_egocentrical.py +++ b/src/vame/preprocessing/align_egocentrical.py @@ -118,9 +118,7 @@ def align_mouse_legacy( pose_list_bordered = [] for i in pose_list: - pose_list_bordered.append( - (int(i[idx][0] + crop_size[0]), int(i[idx][1] + crop_size[1])) - ) + pose_list_bordered.append((int(i[idx][0] + crop_size[0]), int(i[idx][1] + crop_size[1]))) img = cv.copyMakeBorder( frame, @@ -363,10 +361,7 @@ def egocentric_alignment_legacy( # call function and save into your VAME data folder paths_to_pose_nwb_series_data = cfg["paths_to_pose_nwb_series_data"] for i, session in enumerate(sessions): - logger.info( - "Aligning session %s, Pose confidence value: %.2f" - % (session, confidence) - ) + logger.info("Aligning session %s, Pose confidence value: %.2f" % (session, confidence)) egocentric_time_series, frames = alignment_legacy( project_path=project_path, session=session, @@ -405,9 +400,7 @@ def egocentric_alignment_legacy( egocentric_time_series_shifted, ) - logger.info( - "Your data is now in the right format and you can call vame.create_trainset()" - ) + logger.info("Your data is now in the right format and you can call vame.create_trainset()") except Exception as e: logger.exception(f"{e}") raise e @@ -477,10 +470,7 @@ def egocentric_alignment( # call function and save into your VAME data folder for i, session in enumerate(sessions): - logger.info( - "Aligning session %s, Pose confidence value: %.2f" - % (session, confidence) - ) + logger.info("Aligning session %s, Pose confidence value: %.2f" % (session, confidence)) # read out data file_path = str(Path(project_path) / "data" / "raw" / f"{session}.nc") _, data_mat, ds = read_pose_estimation_file(file_path=file_path) @@ -540,9 +530,7 @@ def egocentric_alignment( result_file = Path(project_path) / "data" / "processed" / session / f"{session}-aligned.nc" ds.to_netcdf(result_file, engine="scipy") - logger.info( - "Your data is now in the right format and you can call vame.create_trainset()" - ) + logger.info("Your data is now in the right format and you can call vame.create_trainset()") except 
Exception as e:
         logger.exception(f"{e}")
         raise e
@@ -601,9 +589,7 @@ def alignment(
         pose_list_bordered = []
 
         for i in pose_list:
-            pose_list_bordered.append(
-                (int(i[idx][0] + crop_size[0]), int(i[idx][1] + crop_size[1]))
-            )
+            pose_list_bordered.append((int(i[idx][0] + crop_size[0]), int(i[idx][1] + crop_size[1])))
 
         punkte = []
         for i in pose_ref_index:
diff --git a/src/vame/preprocessing/align_new.py b/src/vame/preprocessing/align_new.py
index 94635414..7bce5e91 100644
--- a/src/vame/preprocessing/align_new.py
+++ b/src/vame/preprocessing/align_new.py
@@ -31,9 +31,7 @@ def align_time_series(data, keypoint1, keypoint2, confidence_threshold):
 
     # Loop over individuals
     for ind in range(positions.shape[1]):
-        individual_positions = positions[
-            :, ind, :, :
-        ]  # Shape: (time, keypoints, space)
+        individual_positions = positions[:, ind, :, :]  # Shape: (time, keypoints, space)
         individual_confidence = confidence[:, ind, :]  # Shape: (time, keypoints)
 
         # Replace low-confidence points with NaN
@@ -54,9 +52,7 @@ def align_time_series(data, keypoint1, keypoint2, confidence_threshold):
         )
 
         # Centralize all positions around the first keypoint
-        centralized_positions = (
-            individual_positions - individual_positions[:, idx1, :][:, np.newaxis, :]
-        )
+        centralized_positions = individual_positions - individual_positions[:, idx1, :][:, np.newaxis, :]
 
         # Calculate vectors between keypoints
         vector = centralized_positions[:, idx2, :]  # Vector from keypoint1 to keypoint2
diff --git a/src/vame/preprocessing/clean_timeseries.py b/src/vame/preprocessing/clean_timeseries.py
new file mode 100644
index 00000000..87351204
--- /dev/null
+++ b/src/vame/preprocessing/clean_timeseries.py
@@ -0,0 +1,77 @@
+from pathlib import Path
+import numpy as np
+from scipy.stats import iqr
+
+from vame.logging.logger import VameLogger
+from vame.io.load_poses import load_vame_dataset
+from vame.util.data_manipulation import interpolate_nans_with_pandas
+
+
+logger_config = VameLogger(__name__)
+logger = logger_config.logger
+
+
+def clean_timeseries(
+    config: dict,
+):
+    X_all_sessions = []
+    pos = [0]
+    pos_temp = 0
+
+    session_names = config["session_names"]
+    for session in session_names:
+        logger.info("z-scoring of session %s" % session)
+
+        # path_to_file = Path(config["project_path"]) / "data" / "processed" / session / f"{session}-PE-seq.npy"
+        # data = np.load(path_to_file)
+
+        path_to_file = Path(config["project_path"]) / "data" / "processed" / session / f"{session}-aligned.nc"
+        ds = load_vame_dataset(path_to_file)
+        X = ds.position_aligned.sel(individuals="individual_0").values
+
+        # Standardize data
+        X_mean = np.mean(X, axis=0)
+        X_std = np.std(X, axis=0)
+        X_z = (X - X_mean) / X_std
+
+        # Robust interquartile range outlier detection
+        if config["robust"]:
+            iqr_val = iqr(X_z, axis=0)
+            logger.info(
+                "Mean IQR value: %.2f, mean IQR cutoff: %.2f"
+                % (float(np.mean(iqr_val)), float(np.mean(config["iqr_factor"] * iqr_val)))
+            )
+            for t in range(X_z.shape[0]):  # Iterate over time dimension
+                for kp in range(X_z.shape[1]):  # Iterate over keypoints dimension
+                    for sp in range(X_z.shape[2]):  # Iterate over space dimension (x, y)
+                        if X_z[t, kp, sp] > config["iqr_factor"] * iqr_val[kp, sp]:
+                            X_z[t, kp, sp] = np.nan
+                        elif X_z[t, kp, sp] < -config["iqr_factor"] * iqr_val[kp, sp]:
+                            X_z[t, kp, sp] = np.nan
+            X_z = interpolate_nans_with_pandas(X_z)
+
+        X_len = X.shape[0]
+        pos_temp += X_len
+        pos.append(pos_temp)
+        X_all_sessions.append(X_z)
+
+    X_all_sessions = np.concatenate(X_all_sessions, axis=0)
+    # Flatten (keypoints, space) into a single feature axis, matching the
+    # legacy 2D (time, features) layout the anchor logic below was written for.
+    X_all_sessions = X_all_sessions.reshape(X_all_sessions.shape[0], -1)
+
+    # Detect and delete anchors
+    detect_anchors = np.std(X_all_sessions, axis=0)
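+    # After egocentric alignment the reference keypoints are (near-)constant,
+    # so the lowest-variance feature columns are treated as anchors and
+    # removed to avoid feeding degenerate features to the model.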
+ sort_anchors = np.sort(detect_anchors) + if sort_anchors[0] == sort_anchors[1]: + anchors = np.where(detect_anchors == sort_anchors[0])[0] + anchor_1_temp = anchors[0] + anchor_2_temp = anchors[1] + else: + anchor_1_temp = int(np.where(detect_anchors == sort_anchors[0])[0]) + anchor_2_temp = int(np.where(detect_anchors == sort_anchors[1])[0]) + + if anchor_1_temp > anchor_2_temp: + anchor_1 = anchor_1_temp + anchor_2 = anchor_2_temp + else: + anchor_1 = anchor_2_temp + anchor_2 = anchor_1_temp + + X = np.delete(X, anchor_1, 1) + X = np.delete(X, anchor_2, 1) + X = X.T diff --git a/src/vame/preprocessing/preprocessing.py b/src/vame/preprocessing/preprocessing.py new file mode 100644 index 00000000..fff20f28 --- /dev/null +++ b/src/vame/preprocessing/preprocessing.py @@ -0,0 +1,27 @@ +from pathlib import Path +import xarray as xr + +from vame.logging.logger import VameLogger +from vame.preprocessing.align_egocentrical import ( + egocentric_alignment_legacy, + egocentric_alignment, +) + + +def preprocessing( + config: dict, + pose_ref_1: str = "snout", + pose_ref_2: str = "tailbase", + save_logs: bool = False, +): + + egocentric_alignment( + config=config, + pose_ref_1=pose_ref_1, + pose_ref_2=pose_ref_2, + ) + + clean_timeseries( + config=config, + save_logs=save_logs, + ) diff --git a/src/vame/schemas/project.py b/src/vame/schemas/project.py index 60ef0de3..037f75e1 100644 --- a/src/vame/schemas/project.py +++ b/src/vame/schemas/project.py @@ -29,9 +29,7 @@ class ProjectSchema(BaseModel): title="Project name", ) creation_datetime: str = Field( - default_factory=lambda: datetime.now(timezone.utc).isoformat( - timespec="seconds" - ), + default_factory=lambda: datetime.now(timezone.utc).isoformat(timespec="seconds"), title="Creation datetime", ) model_name: str = Field( diff --git a/src/vame/schemas/states.py b/src/vame/schemas/states.py index aba933a8..00724e41 100644 --- a/src/vame/schemas/states.py +++ b/src/vame/schemas/states.py @@ -84,9 +84,7 @@ class MotifVideosFunctionSchema(BaseStateSchema): title="Type of video", default=".mp4", ) - segmentation_algorithm: SegmentationAlgorithms = Field( - title="Segmentation algorithm" - ) + segmentation_algorithm: SegmentationAlgorithms = Field(title="Segmentation algorithm") output_video_type: str = Field( title="Type of output video", default=".mp4", @@ -95,9 +93,7 @@ class MotifVideosFunctionSchema(BaseStateSchema): class CommunityFunctionSchema(BaseStateSchema): cohort: bool = Field(title="Cohort", default=True) - segmentation_algorithm: SegmentationAlgorithms = Field( - title="Segmentation algorithm" - ) + segmentation_algorithm: SegmentationAlgorithms = Field(title="Segmentation algorithm") cut_tree: int | None = Field( title="Cut tree", default=None, @@ -105,9 +101,7 @@ class CommunityFunctionSchema(BaseStateSchema): class CommunityVideosFunctionSchema(BaseStateSchema): - segmentation_algorithm: SegmentationAlgorithms = Field( - title="Segmentation algorithm" - ) + segmentation_algorithm: SegmentationAlgorithms = Field(title="Segmentation algorithm") cohort: bool = Field(title="Cohort", default=True) video_type: str = Field( title="Type of video", @@ -120,9 +114,7 @@ class CommunityVideosFunctionSchema(BaseStateSchema): class VisualizationFunctionSchema(BaseStateSchema): - segmentation_algorithm: SegmentationAlgorithms = Field( - title="Segmentation algorithm" - ) + segmentation_algorithm: SegmentationAlgorithms = Field(title="Segmentation algorithm") label: Optional[str] = Field( title="Type of labels to visualize", default=None, @@ 
-130,9 +122,7 @@ class VisualizationFunctionSchema(BaseStateSchema): class GenerativeModelFunctionSchema(BaseStateSchema): - segmentation_algorithm: SegmentationAlgorithms = Field( - title="Segmentation algorithm" - ) + segmentation_algorithm: SegmentationAlgorithms = Field(title="Segmentation algorithm") mode: GenerativeModelModeEnum = Field( title="Mode for generating samples", default=GenerativeModelModeEnum.sampling, diff --git a/src/vame/util/auxiliary.py b/src/vame/util/auxiliary.py index c206818b..560f6293 100644 --- a/src/vame/util/auxiliary.py +++ b/src/vame/util/auxiliary.py @@ -137,10 +137,7 @@ def read_config(configname: str) -> dict: write_config(configname, cfg) except Exception as err: if len(err.args) > 2: - if ( - err.args[2] - == "could not determine a constructor for the tag '!!python/tuple'" - ): + if err.args[2] == "could not determine a constructor for the tag '!!python/tuple'": with open(path, "r") as ymlfile: cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader) write_config(configname, cfg) diff --git a/src/vame/util/cli.py b/src/vame/util/cli.py index e0a09fdf..f7aa5ce7 100644 --- a/src/vame/util/cli.py +++ b/src/vame/util/cli.py @@ -24,6 +24,4 @@ def get_sessions_from_user_input( if user_input in cfg["session_names"]: sessions = [user_input] else: - raise ValueError( - "Invalid input. Please enter yes, no, or a valid session name." - ) + raise ValueError("Invalid input. Please enter yes, no, or a valid session name.") diff --git a/src/vame/util/csv_to_npy.py b/src/vame/util/csv_to_npy.py index 4010c03e..7ce2c382 100644 --- a/src/vame/util/csv_to_npy.py +++ b/src/vame/util/csv_to_npy.py @@ -89,9 +89,7 @@ def pose_to_numpy( i = interpol_first_rows_nans(i) positions = np.concatenate(pose_list, axis=1) - final_positions = np.zeros( - (data_mat.shape[0], int(data_mat.shape[1] / 3) * 2) - ) + final_positions = np.zeros((data_mat.shape[0], int(data_mat.shape[1] / 3) * 2)) jdx = 0 idx = 0 @@ -113,9 +111,7 @@ def pose_to_numpy( ) logger.info("conversion from DeepLabCut csv to numpy complete...") - logger.info( - "Your data is now in right format and you can call vame.create_trainset()" - ) + logger.info("Your data is now in right format and you can call vame.create_trainset()") except Exception as e: logger.exception(f"{e}") raise e diff --git a/src/vame/util/data_manipulation.py b/src/vame/util/data_manipulation.py index 97bba269..563ddcc3 100644 --- a/src/vame/util/data_manipulation.py +++ b/src/vame/util/data_manipulation.py @@ -212,6 +212,32 @@ def interpol_first_rows_nans(arr: np.ndarray) -> np.ndarray: return arr +def interpolate_nans_with_pandas(data: np.ndarray) -> np.ndarray: + """ + Interpolate NaN values along the time axis of a 3D NumPy array using Pandas. + + Parameters: + ----------- + data : numpy.ndarray + Input 3D array of shape (time, keypoints, space). + + Returns: + -------- + numpy.ndarray: + Array with NaN values interpolated. 
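+
+    Notes:
+    ------
+    Each (keypoint, space) series is interpolated independently along the
+    time axis; ``limit_direction="both"`` also fills leading and trailing NaNs.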
+ """ + for kp in range(data.shape[1]): # Loop over keypoints dimension + for sp in range(data.shape[2]): # Loop over space dimension (x, y) + series = pd.Series(data[:, kp, sp]) + series_interpolated = series.interpolate( + method="linear", + limit_direction="both", + axis=0, + ) + data[:, kp, sp] = series_interpolated.values + return data + + def crop_and_flip_legacy( rect: Tuple, src: np.ndarray, @@ -363,9 +389,7 @@ def nc_to_dataframe(nc_data): # Flatten position data position_data = nc_data["position"].isel(individuals=0).values - position_column_names = [ - f"{keypoint}_{sp}" for keypoint in keypoints for sp in space - ] + position_column_names = [f"{keypoint}_{sp}" for keypoint in keypoints for sp in space] position_flattened = position_data.reshape(position_data.shape[0], -1) # Create a DataFrame for position data @@ -383,9 +407,7 @@ def nc_to_dataframe(nc_data): # Reorder columns: keypoint_x, keypoint_y, keypoint_confidence reordered_columns = [] for keypoint in keypoints: - reordered_columns.extend( - [f"{keypoint}_x", f"{keypoint}_y", f"{keypoint}_confidence"] - ) + reordered_columns.extend([f"{keypoint}_x", f"{keypoint}_y", f"{keypoint}_confidence"]) combined_df = combined_df[reordered_columns] diff --git a/src/vame/util/gif_pose_helper.py b/src/vame/util/gif_pose_helper.py index 9c384497..752efcba 100644 --- a/src/vame/util/gif_pose_helper.py +++ b/src/vame/util/gif_pose_helper.py @@ -145,9 +145,7 @@ def get_animal_frames( frame = frame - bg frame[frame <= 0] = 0 except Exception: - logger.info( - f"Couldn't find a frame in capture.read(). #Frame: {idx + start + lag}" - ) + logger.info(f"Couldn't find a frame in capture.read(). #Frame: {idx + start + lag}") continue # Read coordinates and add border diff --git a/src/vame/util/report.py b/src/vame/util/report.py index 1bb36fd3..e13504b5 100644 --- a/src/vame/util/report.py +++ b/src/vame/util/report.py @@ -38,15 +38,10 @@ def report( report_folder.mkdir(exist_ok=True) # Motifs and Communities - if ( - project_states.get("segment_session", {}).get("execution_state", "") - != "success" - ): + if project_states.get("segment_session", {}).get("execution_state", "") != "success": raise Exception("Segmentation failed. Skipping motifs and communities report.") if project_states.get("community", {}).get("execution_state", "") != "success": - raise Exception( - "Community detection failed. Skipping motifs and communities report." - ) + raise Exception("Community detection failed. 
Skipping motifs and communities report.") ml = np.load( project_path @@ -96,8 +91,7 @@ def report( title=f"Community and Motif Counts - Cohort - {model_name} - {segmentation_algorithm} - {n_clusters}", save_to_file=True, save_path=str( - report_folder - / f"community_motifs_cohort_{model_name}_{segmentation_algorithm}-{n_clusters}.png" + report_folder / f"community_motifs_cohort_{model_name}_{segmentation_algorithm}-{n_clusters}.png" ), ) @@ -141,8 +135,7 @@ def report( title=f"Community and Motif Counts - {session} - {model_name} - {segmentation_algorithm} - {n_clusters}", save_to_file=True, save_path=str( - report_folder - / f"community_motifs_{session}_{model_name}_{segmentation_algorithm}-{n_clusters}.png" + report_folder / f"community_motifs_{session}_{model_name}_{segmentation_algorithm}-{n_clusters}.png" ), ) @@ -165,9 +158,7 @@ def plot_community_motifs( community_indices = [community for community, count in communities] community_counts = [count for community, count in communities] total_community_counts = sum(community_counts) - community_percentages = [ - (count / total_community_counts) * 100 for count in community_counts - ] + community_percentages = [(count / total_community_counts) * 100 for count in community_counts] # Define positions and bar widths bar_width = 0.8 @@ -203,9 +194,7 @@ def plot_community_motifs( ax2 = ax1.twinx() ax2.set_ylim(ax1.get_ylim()) ax2.set_yticks(ax1.get_yticks()) - ax2.set_yticklabels( - [f"{(tick / total_community_counts) * 100:.1f}%" for tick in ax1.get_yticks()] - ) + ax2.set_yticklabels([f"{(tick / total_community_counts) * 100:.1f}%" for tick in ax1.get_yticks()]) ax2.set_ylabel("Percentage") # Overlay motif bars within each community @@ -217,16 +206,12 @@ def plot_community_motifs( motifs_sorted = [motif for motif, count in motif_counts] counts_sorted = [count for motif, count in motif_counts] total_motif_counts = sum(counts_sorted) - motif_percentages = [ - (count / total_motif_counts) * 100 for count in counts_sorted - ] + motif_percentages = [(count / total_motif_counts) * 100 for count in counts_sorted] num_motifs = len(motifs_sorted) # Adjust motif bar width to fill the community bar width if num_motifs > 0: - motif_width = ( - motif_bar_width / num_motifs * 0.9 - ) # Slightly reduce width to create space between bars + motif_width = motif_bar_width / num_motifs * 0.9 # Slightly reduce width to create space between bars else: motif_width = motif_bar_width @@ -240,8 +225,7 @@ def plot_community_motifs( bars = ax1.bar( motif_positions, counts_sorted, - width=motif_width - * 0.9, # Slightly reduce width to create space between bars + width=motif_width * 0.9, # Slightly reduce width to create space between bars label=f"Motifs in Community {community}", ) @@ -254,9 +238,7 @@ def plot_community_motifs( ha="center", va="bottom", fontsize=9, - color=( - "white" if bar.get_facecolor()[0] < 0.5 else "black" - ), # Contrast with bar color + color=("white" if bar.get_facecolor()[0] < 0.5 else "black"), # Contrast with bar color ) # Add percentage values on top of motif bars @@ -268,9 +250,7 @@ def plot_community_motifs( ha="center", va="bottom", fontsize=8, - color=( - "white" if bar.get_facecolor()[0] < 0.5 else "black" - ), # Contrast with bar color + color=("white" if bar.get_facecolor()[0] < 0.5 else "black"), # Contrast with bar color ) # Formatting diff --git a/tests/test_analysis.py b/tests/test_analysis.py index e41206d4..e80502d5 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -27,9 +27,7 @@ def 
test_pose_segmentation_hmm_files_exists( "individual_segmentation": individual_segmentation, } mock_config["hmm_trained"] = hmm_trained - with patch( - "vame.analysis.pose_segmentation.read_config", return_value=mock_config - ) as mock_read_config: + with patch("vame.analysis.pose_segmentation.read_config", return_value=mock_config) as mock_read_config: with patch("builtins.input", return_value="yes"): vame.segment_session( setup_project_and_train_model["config_path"], @@ -39,13 +37,7 @@ def test_pose_segmentation_hmm_files_exists( file = setup_project_and_train_model["config_data"]["session_names"][0] model_name = setup_project_and_train_model["config_data"]["model_name"] n_clusters = setup_project_and_train_model["config_data"]["n_clusters"] - save_base_path = ( - Path(project_path) - / "results" - / file - / model_name - / f"{segmentation_algorithm}-{n_clusters}" - ) + save_base_path = Path(project_path) / "results" / file / model_name / f"{segmentation_algorithm}-{n_clusters}" latent_vector_path = save_base_path / f"latent_vector_{file}.npy" motif_usage_path = save_base_path / f"motif_usage_{file}.npy" @@ -54,9 +46,7 @@ def test_pose_segmentation_hmm_files_exists( @pytest.mark.parametrize("segmentation_algorithm", ["hmm", "kmeans"]) -def test_motif_videos_mp4_files_exists( - setup_project_and_train_model, segmentation_algorithm -): +def test_motif_videos_mp4_files_exists(setup_project_and_train_model, segmentation_algorithm): vame.motif_videos( setup_project_and_train_model["config_path"], segmentation_algorithm=segmentation_algorithm, @@ -82,9 +72,7 @@ def test_motif_videos_mp4_files_exists( @pytest.mark.parametrize("segmentation_algorithm", ["hmm", "kmeans"]) -def test_motif_videos_avi_files_exists( - setup_project_and_train_model, segmentation_algorithm -): +def test_motif_videos_avi_files_exists(setup_project_and_train_model, segmentation_algorithm): # Check if the files are created vame.motif_videos( setup_project_and_train_model["config_path"], @@ -144,9 +132,7 @@ def test_motif_videos_avi_files_exists( @pytest.mark.parametrize("segmentation_algorithm", ["hmm", "kmeans"]) -def test_cohort_community_files_exists( - setup_project_and_train_model, segmentation_algorithm -): +def test_cohort_community_files_exists(setup_project_and_train_model, segmentation_algorithm): # Check if the files are created vame.community( setup_project_and_train_model["config_path"], @@ -158,17 +144,10 @@ def test_cohort_community_files_exists( project_path = setup_project_and_train_model["config_data"]["project_path"] n_clusters = setup_project_and_train_model["config_data"]["n_clusters"] - base_path = ( - Path(project_path) - / "results" - / "community_cohort" - / f"{segmentation_algorithm}-{n_clusters}" - ) + base_path = Path(project_path) / "results" / "community_cohort" / f"{segmentation_algorithm}-{n_clusters}" cohort_path = base_path / "cohort_transition_matrix.npy" community_path = base_path / "cohort_community_label.npy" - cohort_segmentation_algorithm_path = ( - base_path / f"cohort_{segmentation_algorithm}_label.npy" - ) + cohort_segmentation_algorithm_path = base_path / f"cohort_{segmentation_algorithm}_label.npy" cohort_community_bag_path = base_path / "cohort_community_bag.npy" assert cohort_path.exists() @@ -268,12 +247,7 @@ def test_visualization_output_files( project_path = setup_project_and_train_model["config_data"]["project_path"] save_base_path = ( - Path(project_path) - / "results" - / file - / model_name - / f"{segmentation_algorithm}-{n_clusters}" - / "community" + 
Path(project_path) / "results" / file / model_name / f"{segmentation_algorithm}-{n_clusters}" / "community" ) assert len(list(save_base_path.glob(f"umap_vis*{file}.png"))) > 0 @@ -316,9 +290,7 @@ def test_report( config=setup_project_and_train_model["config_path"], segmentation_algorithm=segmentation_algorithm, ) - reports_path = ( - Path(setup_project_and_train_model["config_data"]["project_path"]) / "reports" - ) + reports_path = Path(setup_project_and_train_model["config_data"]["project_path"]) / "reports" assert len(list(reports_path.glob("*.png"))) > 0 diff --git a/tests/test_initialize_project.py b/tests/test_initialize_project.py index 0f39cae4..7fea631e 100644 --- a/tests/test_initialize_project.py +++ b/tests/test_initialize_project.py @@ -17,9 +17,7 @@ def test_project_name_config(setup_project_not_aligned_data): """ config = Path(setup_project_not_aligned_data["config_path"]) config_values = read_config(config) - assert ( - config_values["project_name"] == setup_project_not_aligned_data["project_name"] - ) + assert config_values["project_name"] == setup_project_not_aligned_data["project_name"] def test_existing_project(): diff --git a/tests/test_model.py b/tests/test_model.py index f2f827da..5b92489f 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -47,42 +47,14 @@ def test_train_model_losses_files_exists(setup_project_and_train_model): model_name = setup_project_and_train_model["config_data"]["model_name"] # save logged losses - train_losses_path = ( - Path(project_path) / "model" / "model_losses" / f"train_losses_{model_name}.npy" - ) - test_losses_path = ( - Path(project_path) / "model" / "model_losses" / f"test_losses_{model_name}.npy" - ) - kmeans_losses_path = ( - Path(project_path) - / "model" - / "model_losses" - / f"kmeans_losses_{model_name}.npy" - ) - kl_losses_path = ( - Path(project_path) / "model" / "model_losses" / f"kl_losses_{model_name}.npy" - ) - weight_values_path = ( - Path(project_path) - / "model" - / "model_losses" - / f"weight_values_{model_name}.npy" - ) - mse_train_losses_path = ( - Path(project_path) - / "model" - / "model_losses" - / f"mse_train_losses_{model_name}.npy" - ) - mse_test_losses_path = ( - Path(project_path) - / "model" - / "model_losses" - / f"mse_test_losses_{model_name}.npy" - ) - fut_losses_path = ( - Path(project_path) / "model" / "model_losses" / f"fut_losses_{model_name}.npy" - ) + train_losses_path = Path(project_path) / "model" / "model_losses" / f"train_losses_{model_name}.npy" + test_losses_path = Path(project_path) / "model" / "model_losses" / f"test_losses_{model_name}.npy" + kmeans_losses_path = Path(project_path) / "model" / "model_losses" / f"kmeans_losses_{model_name}.npy" + kl_losses_path = Path(project_path) / "model" / "model_losses" / f"kl_losses_{model_name}.npy" + weight_values_path = Path(project_path) / "model" / "model_losses" / f"weight_values_{model_name}.npy" + mse_train_losses_path = Path(project_path) / "model" / "model_losses" / f"mse_train_losses_{model_name}.npy" + mse_test_losses_path = Path(project_path) / "model" / "model_losses" / f"mse_test_losses_{model_name}.npy" + fut_losses_path = Path(project_path) / "model" / "model_losses" / f"fut_losses_{model_name}.npy" assert train_losses_path.exists() assert test_losses_path.exists() @@ -98,9 +70,7 @@ def test_train_model_best_model_file_exists(setup_project_and_train_model): project_path = setup_project_and_train_model["config_data"]["project_path"] model_name = setup_project_and_train_model["config_data"]["model_name"] project_name = 
setup_project_and_train_model["config_data"]["project_name"] - best_model_path = ( - Path(project_path) / "model" / "best_model" / f"{model_name}_{project_name}.pkl" - ) + best_model_path = Path(project_path) / "model" / "best_model" / f"{model_name}_{project_name}.pkl" assert best_model_path.exists() @@ -108,12 +78,8 @@ def test_train_model_best_model_file_exists(setup_project_and_train_model): def test_evaluate_model_images_exists(setup_project_and_evaluate_model): project_path = setup_project_and_evaluate_model["config_data"]["project_path"] model_name = setup_project_and_evaluate_model["config_data"]["model_name"] - reconstruction_image_path = ( - Path(project_path) / "model" / "evaluate" / "Future_Reconstruction.png" - ) - loss_image_path = ( - Path(project_path) / "model" / "evaluate" / f"MSE-and-KL-Loss{model_name}.png" - ) + reconstruction_image_path = Path(project_path) / "model" / "evaluate" / "Future_Reconstruction.png" + loss_image_path = Path(project_path) / "model" / "evaluate" / f"MSE-and-KL-Loss{model_name}.png" assert reconstruction_image_path.exists() assert loss_image_path.exists() diff --git a/tests/test_util.py b/tests/test_util.py index 28426e53..52f1a769 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -6,12 +6,8 @@ def test_pose_to_numpy_file_exists(setup_project_and_convert_pose_to_numpy): """ Test if the pose-estimation file was converted to a numpy array file. """ - project_path = setup_project_and_convert_pose_to_numpy["config_data"][ - "project_path" - ] - file_name = setup_project_and_convert_pose_to_numpy["config_data"]["session_names"][ - 0 - ] + project_path = setup_project_and_convert_pose_to_numpy["config_data"]["project_path"] + file_name = setup_project_and_convert_pose_to_numpy["config_data"]["session_names"][0] file_path = os.path.join( project_path, "data", From 96ce0b85d7cbb2f081f278d7baa36a207f0c524f Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 29 Nov 2024 15:03:06 +0100 Subject: [PATCH 05/77] remove anchor --- src/vame/preprocessing/clean_timeseries.py | 27 ++++------------------ 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/src/vame/preprocessing/clean_timeseries.py b/src/vame/preprocessing/clean_timeseries.py index 87351204..0d9fb342 100644 --- a/src/vame/preprocessing/clean_timeseries.py +++ b/src/vame/preprocessing/clean_timeseries.py @@ -18,6 +18,9 @@ def clean_timeseries( pos = [0] pos_temp = 0 + pose_ref_1 = config["pose_ref_1"] + pose_ref_2 = config["pose_ref_2"] + session_names = config["session_names"] for session in session_names: logger.info("z-scoring of session %s" % session) @@ -27,7 +30,7 @@ def clean_timeseries( path_to_file = Path(config["project_path"]) / "data" / "processed" / session / session + "-aligned.nc" ds = load_vame_dataset(path_to_file) - X = ds.position_aligned.sel(individuals="individual_0").values + X = ds.position_aligned.sel(individuals="individual_0").drop_sel(keypoints=pose_ref_1).values # Standardize data X_mean = np.mean(X, axis=0) @@ -53,25 +56,3 @@ def clean_timeseries( X_all_sessions.append(X_z) X_all_sessions = np.concatenate(X_all_sessions, axis=0) - - # Detect and delete anchors - detect_anchors = np.std(X_all_sessions, axis=0) - sort_anchors = np.sort(detect_anchors) - if sort_anchors[0] == sort_anchors[1]: - anchors = np.where(detect_anchors == sort_anchors[0])[0] - anchor_1_temp = anchors[0] - anchor_2_temp = anchors[1] - else: - anchor_1_temp = int(np.where(detect_anchors == sort_anchors[0])[0]) - anchor_2_temp = int(np.where(detect_anchors == sort_anchors[1])[0]) - - if 
anchor_1_temp > anchor_2_temp: - anchor_1 = anchor_1_temp - anchor_2 = anchor_2_temp - else: - anchor_1 = anchor_2_temp - anchor_2 = anchor_1_temp - - X = np.delete(X, anchor_1, 1) - X = np.delete(X, anchor_2, 1) - X = X.T From c04366902ce5ddc6ebc4776146544746b48bda36 Mon Sep 17 00:00:00 2001 From: luiz Date: Wed, 18 Dec 2024 20:59:06 +0100 Subject: [PATCH 06/77] lowconf_cleaning --- src/vame/__init__.py | 2 + src/vame/preprocessing/__init__.py | 1 + src/vame/preprocessing/cleaning.py | 67 +++++++++++++++++++++++++ src/vame/preprocessing/preprocessing.py | 36 ++++++++----- src/vame/util/auxiliary.py | 12 ++--- 5 files changed, 99 insertions(+), 19 deletions(-) create mode 100644 src/vame/preprocessing/cleaning.py diff --git a/src/vame/__init__.py b/src/vame/__init__.py index b3bc56a6..e249f36f 100644 --- a/src/vame/__init__.py +++ b/src/vame/__init__.py @@ -22,3 +22,5 @@ from vame.util import model_util from vame.util import auxiliary from vame.util.report import report + +from vame.preprocessing import preprocessing diff --git a/src/vame/preprocessing/__init__.py b/src/vame/preprocessing/__init__.py index e69de29b..d4f87ed8 100644 --- a/src/vame/preprocessing/__init__.py +++ b/src/vame/preprocessing/__init__.py @@ -0,0 +1 @@ +from vame.preprocessing.preprocessing import preprocessing diff --git a/src/vame/preprocessing/cleaning.py b/src/vame/preprocessing/cleaning.py new file mode 100644 index 00000000..bccf59ab --- /dev/null +++ b/src/vame/preprocessing/cleaning.py @@ -0,0 +1,67 @@ +from pathlib import Path +import numpy as np + +from vame.logging.logger import VameLogger +from vame.util.data_manipulation import read_pose_estimation_file + + +logger_config = VameLogger(__name__) +logger = logger_config.logger + + +def lowconf_cleaning(config: dict): + """ + Clean the low confidence data points from the dataset. 
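+
+    A point is considered low-confidence when its pose estimation confidence
+    is below the ``pose_confidence`` threshold from the project config.
+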
+ Processes position data by: + - setting low-confidence points to NaN + - interpolating NaN points + """ + project_path = config["project_path"] + sessions = config["session_names"] + pose_confidence = config["pose_confidence"] + + for i, session in enumerate(sessions): + logger.info(f"Low-confidence cleaning: session {session}, confidence threshold {pose_confidence}") + # Read raw session data + file_path = str(Path(project_path) / "data" / "raw" / f"{session}.nc") + _, data_mat, ds = read_pose_estimation_file(file_path=file_path) + + position = ds["position"].values + cleaned_position = np.empty_like(position) + confidence = ds["confidence"].values + + for individual in range(position.shape[1]): + for keypoint in range(position.shape[2]): + for space in range(position.shape[3]): + series = np.copy(position[:, individual, keypoint, space]) + conf_series = confidence[:, individual, keypoint] + + # Set low-confidence positions to NaN + nan_mask = conf_series < pose_confidence + series[nan_mask] = np.nan + + # Interpolate NaN values + if not nan_mask.all(): + series[nan_mask] = np.interp( + np.flatnonzero(nan_mask), np.flatnonzero(~nan_mask), series[~nan_mask] + ) + + # Update the position array + cleaned_position[:, individual, keypoint, space] = series + + # Update the dataset with the cleaned position values + ds["position_processed"] = (ds["position"].dims, cleaned_position) + processed_metadata = { + "processed_confidence": True, + "processed_egocentric": False, + "processed_outlier": False, + "processed_savgol": False, + } + ds.attrs.update(processed_metadata) + + # Save the cleaned dataset to file + cleaned_file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") + ds.to_netcdf( + path=cleaned_file_path, + engine="scipy", + ) diff --git a/src/vame/preprocessing/preprocessing.py b/src/vame/preprocessing/preprocessing.py index fff20f28..474c0458 100644 --- a/src/vame/preprocessing/preprocessing.py +++ b/src/vame/preprocessing/preprocessing.py @@ -2,10 +2,16 @@ import xarray as xr from vame.logging.logger import VameLogger -from vame.preprocessing.align_egocentrical import ( - egocentric_alignment_legacy, - egocentric_alignment, -) +from vame.preprocessing.cleaning import lowconf_cleaning + +# from vame.preprocessing.align_egocentrical import ( +# egocentric_alignment_legacy, +# egocentric_alignment, +# ) + + +logger_config = VameLogger(__name__) +logger = logger_config.logger def preprocessing( @@ -15,13 +21,17 @@ def preprocessing( save_logs: bool = False, ): - egocentric_alignment( - config=config, - pose_ref_1=pose_ref_1, - pose_ref_2=pose_ref_2, - ) + # Low-confidence cleaning + logger.info("Cleaning low confidence data points...") + lowconf_cleaning(config=config) + + # egocentric_alignment( + # config=config, + # pose_ref_1=pose_ref_1, + # pose_ref_2=pose_ref_2, + # ) - clean_timeseries( - config=config, - save_logs=save_logs, - ) + # clean_timeseries( + # config=config, + # save_logs=save_logs, + # ) diff --git a/src/vame/util/auxiliary.py b/src/vame/util/auxiliary.py index 560f6293..b535c218 100644 --- a/src/vame/util/auxiliary.py +++ b/src/vame/util/auxiliary.py @@ -111,13 +111,13 @@ def create_config_template() -> Tuple[dict, ruamel.yaml.YAML]: return (cfg_file, ruamelFile) -def read_config(configname: str) -> dict: +def read_config(config_file: str) -> dict: """ Reads structured config file defining a project. Parameters ---------- - configname : str + config_file : str Path to the config file. 
Returns @@ -126,21 +126,21 @@ def read_config(configname: str) -> dict: The contents of the config file as a dictionary. """ ruamelFile = ruamel.yaml.YAML() - path = Path(configname) + path = Path(config_file) if os.path.exists(path): try: with open(path, "r") as f: cfg = ruamelFile.load(f) - curr_dir = os.path.dirname(configname) + curr_dir = os.path.dirname(config_file) if cfg["project_path"] != curr_dir: cfg["project_path"] = curr_dir - write_config(configname, cfg) + write_config(config_file, cfg) except Exception as err: if len(err.args) > 2: if err.args[2] == "could not determine a constructor for the tag '!!python/tuple'": with open(path, "r") as ymlfile: cfg = yaml.load(ymlfile, Loader=yaml.SafeLoader) - write_config(configname, cfg) + write_config(config_file, cfg) else: raise else: From 32706f7daac5a9936d4628ab0f8c9123ace87acc Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 19 Dec 2024 09:26:10 +0100 Subject: [PATCH 07/77] alignment wip --- src/vame/__init__.py | 5 +- ...trical.py => align_egocentrical_legacy.py} | 0 src/vame/preprocessing/align_new.py | 79 --------------- src/vame/preprocessing/alignment.py | 95 +++++++++++++++++++ src/vame/preprocessing/cleaning.py | 6 +- src/vame/preprocessing/preprocessing.py | 25 ++--- 6 files changed, 110 insertions(+), 100 deletions(-) rename src/vame/preprocessing/{align_egocentrical.py => align_egocentrical_legacy.py} (100%) delete mode 100644 src/vame/preprocessing/align_new.py create mode 100644 src/vame/preprocessing/alignment.py diff --git a/src/vame/__init__.py b/src/vame/__init__.py index e249f36f..288650bb 100644 --- a/src/vame/__init__.py +++ b/src/vame/__init__.py @@ -15,10 +15,7 @@ from vame.analysis import generative_model from vame.analysis import gif from vame.util.csv_to_npy import pose_to_numpy -from vame.preprocessing.align_egocentrical import ( - egocentric_alignment_legacy, - egocentric_alignment, -) +from vame.preprocessing.align_egocentrical_legacy import egocentric_alignment_legacy from vame.util import model_util from vame.util import auxiliary from vame.util.report import report diff --git a/src/vame/preprocessing/align_egocentrical.py b/src/vame/preprocessing/align_egocentrical_legacy.py similarity index 100% rename from src/vame/preprocessing/align_egocentrical.py rename to src/vame/preprocessing/align_egocentrical_legacy.py diff --git a/src/vame/preprocessing/align_new.py b/src/vame/preprocessing/align_new.py deleted file mode 100644 index 7bce5e91..00000000 --- a/src/vame/preprocessing/align_new.py +++ /dev/null @@ -1,79 +0,0 @@ -import numpy as np -import pandas as pd -import xarray as xr - - -def align_time_series(data, keypoint1, keypoint2, confidence_threshold): - """ - Aligns the time series by first centralizing all positions around the first keypoint - and then applying rotation to align with the line connecting the two keypoints. - Handles low-confidence points by replacing them with NaNs and interpolating. - - Parameters: - - data (xarray.Dataset): The input dataset. - - keypoint1 (str): The name of the first reference keypoint. - - keypoint2 (str): The name of the second reference keypoint. - - confidence_threshold (float): Confidence threshold below which points are replaced with NaNs. - - Returns: - - xarray.Dataset: The dataset with a new 'position_aligned' variable. 
- """ - # Extract keypoint indices - keypoints = data.coords["keypoints"].values - idx1 = np.where(keypoints == keypoint1)[0][0] - idx2 = np.where(keypoints == keypoint2)[0][0] - - # Extract positions and confidence values - positions = data["position"].values # Shape: (time, individuals, keypoints, space) - confidence = data["confidence"].values # Shape: (time, individuals, keypoints) - - aligned_positions = np.empty_like(positions) # Preallocate aligned positions - - # Loop over individuals - for ind in range(positions.shape[1]): - individual_positions = positions[:, ind, :, :] # Shape: (time, keypoints, space) - individual_confidence = confidence[:, ind, :] # Shape: (time, keypoints) - - # Replace low-confidence points with NaN - for kp in range(individual_positions.shape[1]): # Loop over keypoints - for dim in range(2): # Loop over x and y - low_confidence = individual_confidence[:, kp] < confidence_threshold - individual_positions[low_confidence, kp, dim] = np.nan - - # Interpolate NaN values - for kp in range(individual_positions.shape[1]): # Loop over keypoints - for dim in range(2): # Loop over x and y - series = pd.Series(individual_positions[:, kp, dim]) - individual_positions[:, kp, dim] = ( - series.interpolate(method="linear", limit_direction="both") - .bfill() # Backward fill for initial NaNs - .ffill() # Forward fill for final NaNs - .values - ) - - # Centralize all positions around the first keypoint - centralized_positions = individual_positions - individual_positions[:, idx1, :][:, np.newaxis, :] - - # Calculate vectors between keypoints - vector = centralized_positions[:, idx2, :] # Vector from keypoint1 to keypoint2 - angles = np.arctan2(vector[:, 1], vector[:, 0]) # Angles in radians - - # Apply rotation to align the second keypoint along the x-axis - for t in range(centralized_positions.shape[0]): - rotation_matrix = np.array( - [ - [np.cos(-angles[t]), -np.sin(-angles[t])], - [np.sin(-angles[t]), np.cos(-angles[t])], - ] - ) - frame_positions = centralized_positions[t, :, :] - rotated_positions = (rotation_matrix @ frame_positions.T).T - aligned_positions[t, ind, :, :] = rotated_positions - - # Add new variable to the dataset - data["position_aligned"] = ( - ("time", "individuals", "keypoints", "space"), - aligned_positions, - ) - - return data diff --git a/src/vame/preprocessing/alignment.py b/src/vame/preprocessing/alignment.py new file mode 100644 index 00000000..d725ce4d --- /dev/null +++ b/src/vame/preprocessing/alignment.py @@ -0,0 +1,95 @@ +import numpy as np +from pathlib import Path + +from vame.logging.logger import VameLogger +from vame.util.data_manipulation import read_pose_estimation_file + + +logger_config = VameLogger(__name__) +logger = logger_config.logger + + +def egocentrically_align_and_center( + config: dict, + centered_reference_keypoint: str = "snout", + orientation_reference_keypoint: str = "tailbase", +) -> None: + """ + Aligns the time series by first centralizing all positions around the first keypoint + and then applying rotation to align with the line connecting the two keypoints. + + Parameters: + ----------- + config : dict + Configuration dictionary + centered_reference_keypoint : str + Name of the keypoint to use as centered reference. + orientation_reference_keypoint : str + Name of the keypoint to use as orientation reference. 
+ + Returns: + -------- + None + """ + logger.info(f"Egocentric alignment with references: {centered_reference_keypoint} and {orientation_reference_keypoint}") + project_path = config["project_path"] + sessions = config["session_names"] + + for i, session in enumerate(sessions): + logger.info(f"Session {session}") + # Read raw session data + file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") + _, _, ds = read_pose_estimation_file(file_path=file_path) + + # Extract keypoint indices + keypoints = ds.coords["keypoints"].values + idx1 = np.where(keypoints == centered_reference_keypoint)[0][0] + idx2 = np.where(keypoints == orientation_reference_keypoint)[0][0] + + # Extract processed positions values, with shape: (time, individuals, keypoints, space) + position_processed = np.copy(ds["position_processed"].values) + position_aligned = np.empty_like(position_processed) + + # Loop over individuals + for individual in range(position_processed.shape[1]): + # Shape: (time, keypoints, space) + individual_positions = position_processed[:, individual, :, :] + + # Centralize all positions around the first keypoint + centralized_positions = individual_positions - individual_positions[:, idx1, :][:, np.newaxis, :] + + # Calculate vectors between keypoints + vector = centralized_positions[:, idx2, :] # Vector from keypoint1 to keypoint2 + angles = np.arctan2(vector[:, 1], vector[:, 0]) # Angles in radians + + # Apply rotation to align the second keypoint along the x-axis + for t in range(centralized_positions.shape[0]): + rotation_matrix = np.array( + [ + [np.cos(-angles[t]), -np.sin(-angles[t])], + [np.sin(-angles[t]), np.cos(-angles[t])], + ] + ) + frame_positions = centralized_positions[t, :, :] + rotated_positions = (rotation_matrix @ frame_positions.T).T + position_aligned[t, individual, :, :] = rotated_positions + + # Update the dataset with the cleaned position values + ds["position_processed"] = ( + ds["position"].dims, + position_aligned, + ) + processed_metadata = { + "processed_confidence": True, + "processed_egocentric": True, + "processed_outlier": False, + "processed_savgol": False, + } + ds.attrs.update(processed_metadata) + + # Save the aligned dataset to file + cleaned_file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") + ds.to_netcdf( + path=cleaned_file_path, + engine="scipy", + ) diff --git a/src/vame/preprocessing/cleaning.py b/src/vame/preprocessing/cleaning.py index bccf59ab..293e6bce 100644 --- a/src/vame/preprocessing/cleaning.py +++ b/src/vame/preprocessing/cleaning.py @@ -24,7 +24,7 @@ def lowconf_cleaning(config: dict): logger.info(f"Low-confidence cleaning: session {session}, confidence threshold {pose_confidence}") # Read raw session data file_path = str(Path(project_path) / "data" / "raw" / f"{session}.nc") - _, data_mat, ds = read_pose_estimation_file(file_path=file_path) + _, _, ds = read_pose_estimation_file(file_path=file_path) position = ds["position"].values cleaned_position = np.empty_like(position) @@ -43,7 +43,9 @@ def lowconf_cleaning(config: dict): # Interpolate NaN values if not nan_mask.all(): series[nan_mask] = np.interp( - np.flatnonzero(nan_mask), np.flatnonzero(~nan_mask), series[~nan_mask] + np.flatnonzero(nan_mask), + np.flatnonzero(~nan_mask), + series[~nan_mask], ) # Update the position array diff --git a/src/vame/preprocessing/preprocessing.py b/src/vame/preprocessing/preprocessing.py index 474c0458..c7bc1696 100644 --- a/src/vame/preprocessing/preprocessing.py +++ 
b/src/vame/preprocessing/preprocessing.py @@ -1,13 +1,6 @@ -from pathlib import Path -import xarray as xr - from vame.logging.logger import VameLogger from vame.preprocessing.cleaning import lowconf_cleaning - -# from vame.preprocessing.align_egocentrical import ( -# egocentric_alignment_legacy, -# egocentric_alignment, -# ) +from vame.preprocessing.alignment import egocentrically_align_and_center logger_config = VameLogger(__name__) @@ -16,8 +9,8 @@ def preprocessing( config: dict, - pose_ref_1: str = "snout", - pose_ref_2: str = "tailbase", + centered_reference_keypoint: str = "snout", + orientation_reference_keypoint: str = "tailbase", save_logs: bool = False, ): @@ -25,11 +18,13 @@ def preprocessing( logger.info("Cleaning low confidence data points...") lowconf_cleaning(config=config) - # egocentric_alignment( - # config=config, - # pose_ref_1=pose_ref_1, - # pose_ref_2=pose_ref_2, - # ) + # Egocentric alignment + logger.info("Egocentrically aligning and centering...") + egocentrically_align_and_center( + config=config, + centered_reference_keypoint=centered_reference_keypoint, + orientation_reference_keypoint=orientation_reference_keypoint, + ) # clean_timeseries( # config=config, From 55ed85c6142b19155a4d018728230aaef0aa394f Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 19 Dec 2024 15:04:49 +0100 Subject: [PATCH 08/77] alignment and vis funcs --- src/vame/preprocessing/alignment.py | 13 +- src/vame/preprocessing/cleaning.py | 3 +- src/vame/preprocessing/preprocessing.py | 7 +- src/vame/preprocessing/visualization.py | 185 ++++++++++++++++++++++++ 4 files changed, 200 insertions(+), 8 deletions(-) create mode 100644 src/vame/preprocessing/visualization.py diff --git a/src/vame/preprocessing/alignment.py b/src/vame/preprocessing/alignment.py index d725ce4d..f0d8f341 100644 --- a/src/vame/preprocessing/alignment.py +++ b/src/vame/preprocessing/alignment.py @@ -31,12 +31,14 @@ def egocentrically_align_and_center( -------- None """ - logger.info(f"Egocentric alignment with references: {centered_reference_keypoint} and {orientation_reference_keypoint}") + logger.info( + f"Egocentric alignment with references: {centered_reference_keypoint} and {orientation_reference_keypoint}" + ) project_path = config["project_path"] sessions = config["session_names"] for i, session in enumerate(sessions): - logger.info(f"Session {session}") + logger.info(f"Session: {session}") # Read raw session data file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") _, _, ds = read_pose_estimation_file(file_path=file_path) @@ -54,9 +56,14 @@ def egocentrically_align_and_center( for individual in range(position_processed.shape[1]): # Shape: (time, keypoints, space) individual_positions = position_processed[:, individual, :, :] + centralized_positions = np.empty_like(individual_positions) # Centralize all positions around the first keypoint - centralized_positions = individual_positions - individual_positions[:, idx1, :][:, np.newaxis, :] + for kp in range(individual_positions.shape[1]): + for space in range(individual_positions.shape[2]): + centralized_positions[:, kp, space] = ( + individual_positions[:, kp, space] - individual_positions[:, idx1, space] + ) # Calculate vectors between keypoints vector = centralized_positions[:, idx2, :] # Vector from keypoint1 to keypoint2 diff --git a/src/vame/preprocessing/cleaning.py b/src/vame/preprocessing/cleaning.py index 293e6bce..703103b4 100644 --- a/src/vame/preprocessing/cleaning.py +++ b/src/vame/preprocessing/cleaning.py @@ -19,9 +19,10 
@@ def lowconf_cleaning(config: dict): project_path = config["project_path"] sessions = config["session_names"] pose_confidence = config["pose_confidence"] + logger.info(f"Cleaning low confidence data points. Confidence threshold: {pose_confidence}") for i, session in enumerate(sessions): - logger.info(f"Low-confidence cleaning: session {session}, confidence threshold {pose_confidence}") + logger.info(f"Session: {session}") # Read raw session data file_path = str(Path(project_path) / "data" / "raw" / f"{session}.nc") _, _, ds = read_pose_estimation_file(file_path=file_path) diff --git a/src/vame/preprocessing/preprocessing.py b/src/vame/preprocessing/preprocessing.py index c7bc1696..dcbdd3db 100644 --- a/src/vame/preprocessing/preprocessing.py +++ b/src/vame/preprocessing/preprocessing.py @@ -26,7 +26,6 @@ def preprocessing( orientation_reference_keypoint=orientation_reference_keypoint, ) - # clean_timeseries( - # config=config, - # save_logs=save_logs, - # ) + # outlier_cleaning(config=config) + + # savgol_filtering(config=config) diff --git a/src/vame/preprocessing/visualization.py b/src/vame/preprocessing/visualization.py new file mode 100644 index 00000000..9a05003d --- /dev/null +++ b/src/vame/preprocessing/visualization.py @@ -0,0 +1,185 @@ +from pathlib import Path +import matplotlib.pyplot as plt +from matplotlib.cm import get_cmap + +from vame.util.data_manipulation import read_pose_estimation_file + + +def visualize_preprocessing_scatter( + config: dict, + session_index: int = 0, + frames: list = [], + save_fig_path: str | None = None, +): + """ + Visualize the preprocessing results by plotting the original and aligned positions + of the keypoints in a scatter plot. + """ + project_path = config["project_path"] + sessions = config["session_names"] + session = sessions[session_index] + + # Read session data + file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") + _, _, ds = read_pose_estimation_file(file_path=file_path) + + original_positions = ds["position"].values + processed_positions = ds["position_processed"].values + keypoints_labels = ds.keypoints.values + + # Fixed axis limits + centralized_limits = { + "x": (-150, 150), + "y": (-150, 150), + } + + if not frames: + frames = [int(i * len(original_positions)) for i in [0.1, 0.3, 0.5, 0.7, 0.9]] + num_frames = len(frames) + fig, axes = plt.subplots(num_frames, 2, figsize=(14, 6 * num_frames)) # Increased figure size + + for i, frame in enumerate(frames): + # Centralized Original positions + ax_original = axes[i, 0] + x_orig, y_orig = original_positions[frame, 0, :, 0], original_positions[frame, 0, :, 1] + x_orig -= x_orig[0] # Centralize around the first keypoint + y_orig -= y_orig[0] + ax_original.scatter(x_orig, y_orig, c="blue", label="Original") + for k, (x, y) in enumerate(zip(x_orig, y_orig)): + if ( + centralized_limits["x"][0] <= x <= centralized_limits["x"][1] + and centralized_limits["y"][0] <= y <= centralized_limits["y"][1] + ): + ax_original.text( + x, y, keypoints_labels[k], fontsize=8, color="blue" + ) # Annotate only points within limits + ax_original.set_title(f"Original - Frame {frame}", fontsize=10) # Reduced title font size + ax_original.set_xlabel("X", fontsize=8) + ax_original.set_ylabel("Y", fontsize=8) + ax_original.axhline(0, color="gray", linestyle="--") + ax_original.axvline(0, color="gray", linestyle="--") + ax_original.axis("equal") + ax_original.set_xlim(*centralized_limits["x"]) + ax_original.set_ylim(*centralized_limits["y"]) + + # Centralized Aligned positions 
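+        # (no re-centering is applied here: the egocentric alignment step has
+        # already placed the centered reference keypoint at the origin)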
+ ax_aligned = axes[i, 1] + x_aligned, y_aligned = processed_positions[frame, 0, :, 0], processed_positions[frame, 0, :, 1] + # x_aligned -= x_aligned[0] # Centralize around the first keypoint + # y_aligned -= y_aligned[0] + ax_aligned.scatter(x_aligned, y_aligned, c="green", label="Aligned") + for k, (x, y) in enumerate(zip(x_aligned, y_aligned)): + if ( + centralized_limits["x"][0] <= x <= centralized_limits["x"][1] + and centralized_limits["y"][0] <= y <= centralized_limits["y"][1] + ): + ax_aligned.text( + x, y, keypoints_labels[k], fontsize=8, color="green" + ) # Annotate only points within limits + ax_aligned.set_title(f"Aligned - Frame {frame}", fontsize=10) # Reduced title font size + ax_aligned.set_xlabel("X", fontsize=8) + ax_aligned.set_ylabel("Y", fontsize=8) + ax_aligned.axhline(0, color="gray", linestyle="--") + ax_aligned.axvline(0, color="gray", linestyle="--") + ax_aligned.axis("equal") + ax_aligned.set_xlim(*centralized_limits["x"]) + ax_aligned.set_ylim(*centralized_limits["y"]) + + plt.tight_layout(pad=2.0) # Add padding to reduce overlap + + if save_fig_path: + plt.savefig(save_fig_path) + + +def visualize_preprocessing_timeseries( + config: dict, + session_index: int = 0, + n_samples: int = 1000, + save_fig_path: str | None = None, +): + """ + Visualize the preprocessing results by plotting the original and aligned positions + of the keypoints in a timeseries plot. + """ + project_path = config["project_path"] + sessions = config["session_names"] + session = sessions[session_index] + + # Read session data + file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") + _, _, ds = read_pose_estimation_file(file_path=file_path) + + fig, ax = plt.subplots(4, 1, figsize=(10, 12)) + + individual = "individual_0" + keypoints_labels = ds.keypoints.values + + # Create a colormap with distinguishable colors + cmap = get_cmap("tab10") if len(keypoints_labels) <= 10 else get_cmap("tab20") + colors = [cmap(i / len(keypoints_labels)) for i in range(len(keypoints_labels))] + + for i, kp in enumerate(keypoints_labels): + sel_x = dict( + individuals=individual, + keypoints=kp, + space="x", + ) + sel_y = dict( + individuals=individual, + keypoints=kp, + space="y", + ) + + ds.position.sel(**sel_x)[0:n_samples].plot( + linewidth=2, + ax=ax[0], + label=kp, + color=colors[i], + ) + ds.position.sel(**sel_y)[0:n_samples].plot( + linewidth=2, + ax=ax[1], + label=kp, + color=colors[i], + ) + ds.position_processed.sel(**sel_x)[0:n_samples].plot( + linewidth=2, + ax=ax[2], + label=kp, + color=colors[i], + ) + ds.position_processed.sel(**sel_y)[0:n_samples].plot( + linewidth=2, + ax=ax[3], + label=kp, + color=colors[i], + ) + + ax[0].set_title("") + ax[1].set_title("") + ax[2].set_title("") + ax[3].set_title("") + + ax[0].set_xlabel("") + ax[1].set_xlabel("") + ax[2].set_xlabel("") + + ax[0].set_ylabel("Allocentric X") + ax[1].set_ylabel("Allocentric Y") + ax[2].set_ylabel("Egocentric X") + ax[3].set_ylabel("Egocentric Y") + + # Add a single legend for all subplots + handles, labels = ax[0].get_legend_handles_labels() + fig.legend( + handles, + labels, + loc="upper center", + ncol=5, + bbox_to_anchor=(0.5, 1.02), + ) + + plt.tight_layout(rect=[0, 0, 1, 0.98]) + + if save_fig_path: + plt.savefig(save_fig_path) From 8fd9560b02c56709c3c8f0d092318db45375c99b Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 19 Dec 2024 17:36:46 +0100 Subject: [PATCH 09/77] iqr cleaning --- src/vame/preprocessing/alignment.py | 13 +---- src/vame/preprocessing/cleaning.py | 74 
++++++++++++++++++++++--- src/vame/preprocessing/preprocessing.py | 6 +- 3 files changed, 73 insertions(+), 20 deletions(-) diff --git a/src/vame/preprocessing/alignment.py b/src/vame/preprocessing/alignment.py index f0d8f341..c7bd5484 100644 --- a/src/vame/preprocessing/alignment.py +++ b/src/vame/preprocessing/alignment.py @@ -82,17 +82,8 @@ def egocentrically_align_and_center( position_aligned[t, individual, :, :] = rotated_positions # Update the dataset with the cleaned position values - ds["position_processed"] = ( - ds["position"].dims, - position_aligned, - ) - processed_metadata = { - "processed_confidence": True, - "processed_egocentric": True, - "processed_outlier": False, - "processed_savgol": False, - } - ds.attrs.update(processed_metadata) + ds["position_processed"] = (ds["position"].dims, position_aligned) + ds.attrs.update({"processed_confidence": True}) # Save the aligned dataset to file cleaned_file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") diff --git a/src/vame/preprocessing/cleaning.py b/src/vame/preprocessing/cleaning.py index 703103b4..6b4b61ab 100644 --- a/src/vame/preprocessing/cleaning.py +++ b/src/vame/preprocessing/cleaning.py @@ -1,5 +1,6 @@ from pathlib import Path import numpy as np +from scipy.stats import iqr from vame.logging.logger import VameLogger from vame.util.data_manipulation import read_pose_estimation_file @@ -54,13 +55,72 @@ def lowconf_cleaning(config: dict): # Update the dataset with the cleaned position values ds["position_processed"] = (ds["position"].dims, cleaned_position) - processed_metadata = { - "processed_confidence": True, - "processed_egocentric": False, - "processed_outlier": False, - "processed_savgol": False, - } - ds.attrs.update(processed_metadata) + ds.attrs.update({"processed_confidence": True}) + + # Save the cleaned dataset to file + cleaned_file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") + ds.to_netcdf( + path=cleaned_file_path, + engine="scipy", + ) + + +def outlier_cleaning(config: dict): + """ + Clean the outliers from the dataset. 
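+
+    A point is flagged as an outlier when the absolute value of its z-scored
+    series exceeds ``iqr_factor`` times the interquartile range of that series.
+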
+    Processes position data by:
+        - setting outlier points to NaN
+        - interpolating NaN points
+    """
+    logger.info("Cleaning outliers with Z-score transformation and IQR cutoff.")
+    project_path = config["project_path"]
+    sessions = config["session_names"]
+
+    for i, session in enumerate(sessions):
+        logger.info(f"Session: {session}")
+        # Read raw session data
+        file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc")
+        _, _, ds = read_pose_estimation_file(file_path=file_path)
+
+        position = np.copy(ds["position_processed"].values)
+        cleaned_position = np.empty_like(position)
+
+        for individual in range(position.shape[1]):
+            for keypoint in range(position.shape[2]):
+                for space in range(position.shape[3]):
+                    series = np.copy(position[:, individual, keypoint, space])
+
+                    # Check if all values are zero, then skip
+                    if np.all(series == 0):
+                        continue
+
+                    # Calculate Z-score
+                    z_series = (series - np.nanmean(series)) / np.nanstd(series)
+
+                    # Initialize the outlier mask so the interpolation step
+                    # below is a safe no-op when robust cleaning is disabled
+                    outlier_mask = np.zeros_like(z_series, dtype=bool)
+
+                    # Set outlier positions to NaN, based on IQR cutoff
+                    if config["robust"]:
+                        iqr_factor = config["iqr_factor"]
+                        iqr_val = iqr(z_series)
+                        outlier_mask = np.abs(z_series) > iqr_factor * iqr_val
+                        z_series[outlier_mask] = np.nan
+
+                    # Interpolate NaN values
+                    if not outlier_mask.all():
+                        z_series[outlier_mask] = np.interp(
+                            np.flatnonzero(outlier_mask),
+                            np.flatnonzero(~outlier_mask),
+                            z_series[~outlier_mask],
+                        )
+
+                    # Redo the z-score to remove the bias of the now-removed outliers
+                    z_series = (z_series - np.nanmean(z_series)) / np.nanstd(z_series)
+
+                    # Update the processed position array
+                    cleaned_position[:, individual, keypoint, space] = z_series
+
+        # Update the dataset with the cleaned position values
+        ds["position_processed"] = (ds["position"].dims, cleaned_position)
+        ds.attrs.update({"processed_outliers": True})
 
         # Save the cleaned dataset to file
         cleaned_file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc")
diff --git a/src/vame/preprocessing/preprocessing.py b/src/vame/preprocessing/preprocessing.py
index dcbdd3db..80a45b92 100644
--- a/src/vame/preprocessing/preprocessing.py
+++ b/src/vame/preprocessing/preprocessing.py
@@ -1,5 +1,5 @@
 from vame.logging.logger import VameLogger
-from vame.preprocessing.cleaning import lowconf_cleaning
+from vame.preprocessing.cleaning import lowconf_cleaning, outlier_cleaning
 from vame.preprocessing.alignment import egocentrically_align_and_center
@@ -26,6 +26,8 @@ def preprocessing(
         orientation_reference_keypoint=orientation_reference_keypoint,
     )
 
-    # outlier_cleaning(config=config)
+    # Outlier cleaning
+    logger.info("Cleaning outliers...")
+    outlier_cleaning(config=config)
 
     # savgol_filtering(config=config)
From 4c7d5f65c37f2ac020213f51c68caa31d45cf900 Mon Sep 17 00:00:00 2001
From: luiz
Date: Thu, 19 Dec 2024 18:24:01 +0100
Subject: [PATCH 10/77] filter

---
 src/vame/preprocessing/alignment.py        |  2 +-
 src/vame/preprocessing/clean_timeseries.py | 58 ----------------------
 src/vame/preprocessing/cleaning.py         |  2 +-
 src/vame/preprocessing/filter.py           | 58 ++++++++++++++++++++++
 src/vame/preprocessing/preprocessing.py    | 18 ++++++-
 src/vame/preprocessing/visualization.py    | 24 +++++----
 6 files changed, 91 insertions(+), 71 deletions(-)
 delete mode 100644 src/vame/preprocessing/clean_timeseries.py
 create mode 100644 src/vame/preprocessing/filter.py

diff --git a/src/vame/preprocessing/alignment.py b/src/vame/preprocessing/alignment.py
index c7bd5484..13c742ec 100644
--- a/src/vame/preprocessing/alignment.py
+++ b/src/vame/preprocessing/alignment.py
@@ -83,7
+83,7 @@ def egocentrically_align_and_center( # Update the dataset with the cleaned position values ds["position_processed"] = (ds["position"].dims, position_aligned) - ds.attrs.update({"processed_confidence": True}) + ds.attrs.update({"processed_alignment": True}) # Save the aligned dataset to file cleaned_file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") diff --git a/src/vame/preprocessing/clean_timeseries.py b/src/vame/preprocessing/clean_timeseries.py deleted file mode 100644 index 0d9fb342..00000000 --- a/src/vame/preprocessing/clean_timeseries.py +++ /dev/null @@ -1,58 +0,0 @@ -from pathlib import Path -import numpy as np -from scipy.stats import iqr - -from vame.logging.logger import VameLogger -from vame.io.load_poses import load_vame_dataset -from vame.util.data_manipulation import interpolate_nans_with_pandas - - -logger_config = VameLogger(__name__) -logger = logger_config.logger - - -def clean_timeseries( - config: dict, -): - X_all_sessions = [] - pos = [0] - pos_temp = 0 - - pose_ref_1 = config["pose_ref_1"] - pose_ref_2 = config["pose_ref_2"] - - session_names = config["session_names"] - for session in session_names: - logger.info("z-scoring of session %s" % session) - - # path_to_file = Path(config["project_path"]) / "data" / "processed" / session / session + "-PE-seq.npy" - # data = np.load(path_to_file) - - path_to_file = Path(config["project_path"]) / "data" / "processed" / session / session + "-aligned.nc" - ds = load_vame_dataset(path_to_file) - X = ds.position_aligned.sel(individuals="individual_0").drop_sel(keypoints=pose_ref_1).values - - # Standardize data - X_mean = np.mean(X, axis=0) - X_std = np.std(X, axis=0) - X_z = (X - X_mean) / X_std - - # Robust interquartile range outlier detection - if config["robust"]: - iqr_val = iqr(X_z, axis=0) - logger.info("IQR value: %.2f, IQR cutoff: %.2f" % (iqr_val, config["iqr_factor"] * iqr_val)) - for t in range(X_z.shape[0]): # Iterate over time dimension - for kp in range(X_z.shape[1]): # Iterate over keypoints dimension - for sp in range(X_z.shape[2]): # Iterate over space dimennsion (x, y) - if X_z[t, kp, sp] > config["iqr_factor"] * iqr_val[kp, sp]: - X_z[t, kp, sp] = np.nan - elif X_z[t, kp, sp] < -config["iqr_factor"] * iqr_val[kp, sp]: - X_z[t, kp, sp] = np.nan - X_z = interpolate_nans_with_pandas(X_z) - - X_len = X.shape[0] - pos_temp += X_len - pos.append(pos_temp) - X_all_sessions.append(X_z) - - X_all_sessions = np.concatenate(X_all_sessions, axis=0) diff --git a/src/vame/preprocessing/cleaning.py b/src/vame/preprocessing/cleaning.py index 6b4b61ab..ab976810 100644 --- a/src/vame/preprocessing/cleaning.py +++ b/src/vame/preprocessing/cleaning.py @@ -83,7 +83,7 @@ def outlier_cleaning(config: dict): _, _, ds = read_pose_estimation_file(file_path=file_path) position = np.copy(ds["position_processed"].values) - cleaned_position = np.empty_like(position) + cleaned_position = np.copy(position) for individual in range(position.shape[1]): for keypoint in range(position.shape[2]): diff --git a/src/vame/preprocessing/filter.py b/src/vame/preprocessing/filter.py new file mode 100644 index 00000000..200531e0 --- /dev/null +++ b/src/vame/preprocessing/filter.py @@ -0,0 +1,58 @@ +from scipy.signal import savgol_filter +import numpy as np +from pathlib import Path + +from vame.logging.logger import VameLogger +from vame.util.data_manipulation import read_pose_estimation_file + + +logger_config = VameLogger(__name__) +logger = logger_config.logger + + +def savgol_filtering(config: dict): + 
""" + Apply Savitzky-Golay filter to the data. + """ + logger.info("Applying Savitzky-Golay filter...") + project_path = config["project_path"] + sessions = config["session_names"] + + savgol_length = config["savgol_length"] + savgol_order = config["savgol_order"] + for i, session in enumerate(sessions): + logger.info(f"Session: {session}") + # Read raw session data + file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") + _, _, ds = read_pose_estimation_file(file_path=file_path) + + # Extract processed positions values, with shape: (time, individuals, keypoints, space) + position = np.copy(ds["position_processed"].values) + filtered_position = np.copy(position) + for individual in range(position.shape[1]): + for keypoint in range(position.shape[2]): + for space in range(position.shape[3]): + series = np.copy(position[:, individual, keypoint, space]) + + # Check if all values are zero, then skip + if np.all(series == 0): + continue + + # Apply Savitzky-Golay filter + filtered_position[:, individual, keypoint, space] = savgol_filter( + x=series, + window_length=savgol_length, + polyorder=savgol_order, + axis=0, + ) + + # Update the dataset with the filtered position values + ds["position_processed"] = (ds["position"].dims, filtered_position) + ds.attrs.update({"processed_filtered": True}) + + # Save the filtered dataset to file + filtered_file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") + ds.to_netcdf( + path=filtered_file_path, + engine="scipy", + ) diff --git a/src/vame/preprocessing/preprocessing.py b/src/vame/preprocessing/preprocessing.py index 80a45b92..c4eaef9e 100644 --- a/src/vame/preprocessing/preprocessing.py +++ b/src/vame/preprocessing/preprocessing.py @@ -1,6 +1,11 @@ from vame.logging.logger import VameLogger from vame.preprocessing.cleaning import lowconf_cleaning, outlier_cleaning from vame.preprocessing.alignment import egocentrically_align_and_center +from vame.preprocessing.filter import savgol_filtering +from vame.preprocessing.visualization import ( + visualize_preprocessing_scatter, + visualize_preprocessing_timeseries, +) logger_config = VameLogger(__name__) @@ -25,9 +30,20 @@ def preprocessing( centered_reference_keypoint=centered_reference_keypoint, orientation_reference_keypoint=orientation_reference_keypoint, ) + visualize_preprocessing_scatter( + config, + save_to_file=True, + ) # Outlier cleaning logger.info("Cleaning outliers...") outlier_cleaning(config=config) - # savgol_filtering(config=config) + # Savgol filtering + logger.info("Applying Savitzky-Golay filter...") + savgol_filtering(config=config) + + visualize_preprocessing_timeseries( + config, + save_to_file=True, + ) diff --git a/src/vame/preprocessing/visualization.py b/src/vame/preprocessing/visualization.py index 9a05003d..2aeab686 100644 --- a/src/vame/preprocessing/visualization.py +++ b/src/vame/preprocessing/visualization.py @@ -9,7 +9,7 @@ def visualize_preprocessing_scatter( config: dict, session_index: int = 0, frames: list = [], - save_fig_path: str | None = None, + save_to_file: bool = False, ): """ Visualize the preprocessing results by plotting the original and aligned positions @@ -87,15 +87,17 @@ def visualize_preprocessing_scatter( plt.tight_layout(pad=2.0) # Add padding to reduce overlap - if save_fig_path: - plt.savefig(save_fig_path) + if save_to_file: + save_fig_path = Path(project_path) / "reports" / "figures" / f"{session}_preprocessing_scatter.png" + save_fig_path.parent.mkdir(parents=True, exist_ok=True) + 
plt.savefig(str(save_fig_path)) def visualize_preprocessing_timeseries( config: dict, session_index: int = 0, n_samples: int = 1000, - save_fig_path: str | None = None, + save_to_file: bool = False, ): """ Visualize the preprocessing results by plotting the original and aligned positions @@ -131,25 +133,25 @@ def visualize_preprocessing_timeseries( ) ds.position.sel(**sel_x)[0:n_samples].plot( - linewidth=2, + linewidth=1.5, ax=ax[0], label=kp, color=colors[i], ) ds.position.sel(**sel_y)[0:n_samples].plot( - linewidth=2, + linewidth=1.5, ax=ax[1], label=kp, color=colors[i], ) ds.position_processed.sel(**sel_x)[0:n_samples].plot( - linewidth=2, + linewidth=1.5, ax=ax[2], label=kp, color=colors[i], ) ds.position_processed.sel(**sel_y)[0:n_samples].plot( - linewidth=2, + linewidth=1.5, ax=ax[3], label=kp, color=colors[i], @@ -181,5 +183,7 @@ def visualize_preprocessing_timeseries( plt.tight_layout(rect=[0, 0, 1, 0.98]) - if save_fig_path: - plt.savefig(save_fig_path) + if save_to_file: + save_fig_path = Path(project_path) / "reports" / "figures" / f"{session}_preprocessing_timeseries.png" + save_fig_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(str(save_fig_path)) From a90e9a49aca97dfcc0867b1048e92b21cef6fff9 Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 19 Dec 2024 18:35:45 +0100 Subject: [PATCH 11/77] keep track of percent of cleaned points --- src/vame/preprocessing/cleaning.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/vame/preprocessing/cleaning.py b/src/vame/preprocessing/cleaning.py index ab976810..ae73394a 100644 --- a/src/vame/preprocessing/cleaning.py +++ b/src/vame/preprocessing/cleaning.py @@ -32,14 +32,15 @@ def lowconf_cleaning(config: dict): cleaned_position = np.empty_like(position) confidence = ds["confidence"].values + perc_interp_points = np.zeros((position.shape[1], position.shape[2])) for individual in range(position.shape[1]): for keypoint in range(position.shape[2]): + conf_series = confidence[:, individual, keypoint] + nan_mask = conf_series < pose_confidence + perc_interp_points[individual, keypoint] = 100 * np.sum(nan_mask) / len(nan_mask) for space in range(position.shape[3]): - series = np.copy(position[:, individual, keypoint, space]) - conf_series = confidence[:, individual, keypoint] - # Set low-confidence positions to NaN - nan_mask = conf_series < pose_confidence + series = np.copy(position[:, individual, keypoint, space]) series[nan_mask] = np.nan # Interpolate NaN values @@ -57,6 +58,8 @@ def lowconf_cleaning(config: dict): ds["position_processed"] = (ds["position"].dims, cleaned_position) ds.attrs.update({"processed_confidence": True}) + ds["percentage_low_confidence"] = (["individual", "keypoint"], perc_interp_points) + # Save the cleaned dataset to file cleaned_file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") ds.to_netcdf( @@ -85,6 +88,8 @@ def outlier_cleaning(config: dict): position = np.copy(ds["position_processed"].values) cleaned_position = np.copy(position) + perc_interp_points = np.zeros((position.shape[1], position.shape[2], position.shape[3])) + for individual in range(position.shape[1]): for keypoint in range(position.shape[2]): for space in range(position.shape[3]): @@ -103,6 +108,7 @@ def outlier_cleaning(config: dict): iqr_val = iqr(z_series) outlier_mask = np.abs(z_series) > iqr_factor * iqr_val z_series[outlier_mask] = np.nan + perc_interp_points[individual, keypoint, space] = 100 * np.sum(outlier_mask) / len(outlier_mask) # Interpolate NaN 
values if not outlier_mask.all(): @@ -122,6 +128,8 @@ def outlier_cleaning(config: dict): ds["position_processed"] = (ds["position"].dims, cleaned_position) ds.attrs.update({"processed_outliers": True}) + ds["percentage_iqr_outliers"] = (["individual", "keypoint", "space"], perc_interp_points) + # Save the cleaned dataset to file cleaned_file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") ds.to_netcdf( From 433a4bcb36123c2ef04a09e75b46686033ef84f8 Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 19 Dec 2024 19:17:50 +0100 Subject: [PATCH 12/77] update alignment wip --- src/vame/preprocessing/alignment.py | 13 ++++-- src/vame/preprocessing/visualization.py | 55 +++++++++++-------------- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/vame/preprocessing/alignment.py b/src/vame/preprocessing/alignment.py index 13c742ec..bb806662 100644 --- a/src/vame/preprocessing/alignment.py +++ b/src/vame/preprocessing/alignment.py @@ -67,18 +67,23 @@ def egocentrically_align_and_center( # Calculate vectors between keypoints vector = centralized_positions[:, idx2, :] # Vector from keypoint1 to keypoint2 - angles = np.arctan2(vector[:, 1], vector[:, 0]) # Angles in radians + angles = np.arctan2(vector[:, 0], vector[:, 1]) # Angles in radians - # Apply rotation to align the second keypoint along the x-axis + # Apply rotation to align the second keypoint along the y-axis for t in range(centralized_positions.shape[0]): rotation_matrix = np.array( [ - [np.cos(-angles[t]), -np.sin(-angles[t])], - [np.sin(-angles[t]), np.cos(-angles[t])], + [np.cos(angles[t]), -np.sin(angles[t])], + [np.sin(angles[t]), np.cos(angles[t])], ] ) frame_positions = centralized_positions[t, :, :] rotated_positions = (rotation_matrix @ frame_positions.T).T + + # Check and ensure the y-value of orientation_reference_keypoint is negative + if rotated_positions[idx2, 1] > 0: + rotated_positions[:, :] *= -1 # Flip all coordinates + position_aligned[t, individual, :, :] = rotated_positions # Update the dataset with the cleaned position values diff --git a/src/vame/preprocessing/visualization.py b/src/vame/preprocessing/visualization.py index 2aeab686..d7546211 100644 --- a/src/vame/preprocessing/visualization.py +++ b/src/vame/preprocessing/visualization.py @@ -27,63 +27,58 @@ def visualize_preprocessing_scatter( processed_positions = ds["position_processed"].values keypoints_labels = ds.keypoints.values - # Fixed axis limits - centralized_limits = { - "x": (-150, 150), - "y": (-150, 150), - } + # # Fixed axis limits + # centralized_limits = { + # "x": (-150, 150), + # "y": (-150, 150), + # } if not frames: frames = [int(i * len(original_positions)) for i in [0.1, 0.3, 0.5, 0.7, 0.9]] num_frames = len(frames) + fig, axes = plt.subplots(num_frames, 2, figsize=(14, 6 * num_frames)) # Increased figure size for i, frame in enumerate(frames): - # Centralized Original positions - ax_original = axes[i, 0] + # Compute dynamic limits for the original positions x_orig, y_orig = original_positions[frame, 0, :, 0], original_positions[frame, 0, :, 1] x_orig -= x_orig[0] # Centralize around the first keypoint y_orig -= y_orig[0] + x_min, x_max = x_orig.min() - 10, x_orig.max() + 10 # Add a margin + y_min, y_max = y_orig.min() - 10, y_orig.max() + 10 + + # Centralized Original positions + ax_original = axes[i, 0] ax_original.scatter(x_orig, y_orig, c="blue", label="Original") for k, (x, y) in enumerate(zip(x_orig, y_orig)): - if ( - centralized_limits["x"][0] <= x <= centralized_limits["x"][1] - and 
centralized_limits["y"][0] <= y <= centralized_limits["y"][1]
-                ):
-                    ax_original.text(
-                        x, y, keypoints_labels[k], fontsize=8, color="blue"
-                    )  # Annotate only points within limits
-        ax_original.set_title(f"Original - Frame {frame}", fontsize=10)  # Reduced title font size
+            ax_original.text(x, y, keypoints_labels[k], fontsize=8, color="blue")
+        ax_original.set_title(f"Original - Frame {frame}", fontsize=10)
         ax_original.set_xlabel("X", fontsize=8)
         ax_original.set_ylabel("Y", fontsize=8)
         ax_original.axhline(0, color="gray", linestyle="--")
         ax_original.axvline(0, color="gray", linestyle="--")
         ax_original.axis("equal")
-        ax_original.set_xlim(*centralized_limits["x"])
-        ax_original.set_ylim(*centralized_limits["y"])
+        ax_original.set_xlim(x_min, x_max)
+        ax_original.set_ylim(y_min, y_max)
+
+        # Compute dynamic limits for the aligned positions
+        x_aligned, y_aligned = processed_positions[frame, 0, :, 0], processed_positions[frame, 0, :, 1]
+        x_min_aligned, x_max_aligned = x_aligned.min() - 10, x_aligned.max() + 10  # Add a margin
+        y_min_aligned, y_max_aligned = y_aligned.min() - 10, y_aligned.max() + 10

         # Centralized Aligned positions
         ax_aligned = axes[i, 1]
-        x_aligned, y_aligned = processed_positions[frame, 0, :, 0], processed_positions[frame, 0, :, 1]
-        # x_aligned -= x_aligned[0]  # Centralize around the first keypoint
-        # y_aligned -= y_aligned[0]
-        ax_aligned.scatter(x_aligned, y_aligned, c="green", label="Aligned")
+        ax_aligned.scatter(x_aligned, y_aligned, c="green", label="Aligned")
         for k, (x, y) in enumerate(zip(x_aligned, y_aligned)):
-            if (
-                centralized_limits["x"][0] <= x <= centralized_limits["x"][1]
-                and centralized_limits["y"][0] <= y <= centralized_limits["y"][1]
-            ):
-                ax_aligned.text(
-                    x, y, keypoints_labels[k], fontsize=8, color="green"
-                )  # Annotate only points within limits
-        ax_aligned.set_title(f"Aligned - Frame {frame}", fontsize=10)  # Reduced title font size
+            ax_aligned.text(x, y, keypoints_labels[k], fontsize=8, color="green")
+        ax_aligned.set_title(f"Aligned - Frame {frame}", fontsize=10)
         ax_aligned.set_xlabel("X", fontsize=8)
         ax_aligned.set_ylabel("Y", fontsize=8)
         ax_aligned.axhline(0, color="gray", linestyle="--")
         ax_aligned.axvline(0, color="gray", linestyle="--")
         ax_aligned.axis("equal")
-        ax_aligned.set_xlim(*centralized_limits["x"])
-        ax_aligned.set_ylim(*centralized_limits["y"])
+        ax_aligned.set_xlim(x_min_aligned, x_max_aligned)
+        ax_aligned.set_ylim(y_min_aligned, y_max_aligned)

         plt.tight_layout(pad=2.0)  # Add padding to reduce overlap
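Note: the alignment change in this patch (arctan2 taking x before y, the un-negated rotation matrix, and the final mirror) amounts to the per-frame transform sketched below. The three-point toy frame and the keypoint roles are made up; the math mirrors the diff above:

    import numpy as np

    def align_frame(points, center_idx, orient_idx):
        # Center the frame on the reference keypoint
        p = points - points[center_idx]
        vx, vy = p[orient_idx]
        theta = np.arctan2(vx, vy)  # note: (x, y) argument order, as in the patch
        rot = np.array([[np.cos(theta), -np.sin(theta)],
                        [np.sin(theta),  np.cos(theta)]])
        out = (rot @ p.T).T
        # Mirror (a 180-degree turn) so the orientation keypoint ends on the negative y-axis
        if out[orient_idx, 1] > 0:
            out = -out
        return out

    frame = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]])  # center, orientation, extra keypoint
    print(np.round(align_frame(frame, 0, 1), 6))  # orientation keypoint lands at (0, -5)

Rotating onto the y-axis rather than the x-axis is also what lets the later trainset code drop the x column of the orientation keypoint: after alignment it is identically zero.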
From 3b161890792849c7a40a98e7a3e420a9f52526c0 Mon Sep 17 00:00:00 2001
From: luiz
Date: Thu, 19 Dec 2024 19:22:20 +0100
Subject: [PATCH 13/77] .

---
 src/vame/preprocessing/visualization.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/src/vame/preprocessing/visualization.py b/src/vame/preprocessing/visualization.py
index d7546211..e0382898 100644
--- a/src/vame/preprocessing/visualization.py
+++ b/src/vame/preprocessing/visualization.py
@@ -27,12 +27,6 @@ def visualize_preprocessing_scatter(
     processed_positions = ds["position_processed"].values
     keypoints_labels = ds.keypoints.values

-    # # Fixed axis limits
-    # centralized_limits = {
-    #     "x": (-150, 150),
-    #     "y": (-150, 150),
-    # }
-
     if not frames:
         frames = [int(i * len(original_positions)) for i in [0.1, 0.3, 0.5, 0.7, 0.9]]
     num_frames = len(frames)

From 3970e120d975bb379b0f9fae25ddefda13554ba4 Mon Sep 17 00:00:00 2001
From: luiz
Date: Sat, 21 Dec 2024 13:02:03 +0100
Subject: [PATCH 14/77] args and visualizations

---
 src/vame/initialize_project/new.py      |  39 +++---
 src/vame/preprocessing/alignment.py     |   8 +-
 src/vame/preprocessing/cleaning.py      |  26 ++--
 src/vame/preprocessing/filter.py        |  10 +-
 src/vame/preprocessing/preprocessing.py |  25 +++-
 src/vame/preprocessing/visualization.py | 171 ++++++++++++++++++------
 6 files changed, 207 insertions(+), 72 deletions(-)

diff --git a/src/vame/initialize_project/new.py b/src/vame/initialize_project/new.py
index b6f05328..7ef6d625 100644
--- a/src/vame/initialize_project/new.py
+++ b/src/vame/initialize_project/new.py
@@ -34,9 +34,16 @@ def init_new_project(
     A VAME project is a directory with the following structure:
         - project_name/
             - data/
-                - video1/
-                - video2/
-                - ...
+                - raw/
+                    - session1.mp4
+                    - session1.nc
+                    - session2.mp4
+                    - session2.nc
+                    - ...
+                - processed/
+                    - session1_processed.nc
+                    - session2_processed.nc
+                    - ...
             - model/
                 - pretrained_model/
             - results/
                 - video1/
                 - video2/
                 - ...
             - states/
                 - states.json
-            - videos/
-                - pose_estimation/
-                    - video1.csv
-                    - video2.csv
-                - video1.mp4
-                - video2.mp4
-                - ...
             - config.yaml

     Parameters:
@@ -160,13 +160,16 @@ def init_new_project(
                 "If the pose estimation file is in nwb format, you must provide the path to the pose series data for each nwb file."
) - # Creates directories under project/data/processed/ + # Session names videos_paths = [Path(vp).resolve() for vp in videos] session_names = [] - dirs_processed_data = [data_processed_path / Path(i.stem) for i in videos_paths] - for p in dirs_processed_data: - p.mkdir(parents=True, exist_ok=True) - session_names.append(p.stem) + for s in videos_paths: + session_names.append(s.stem) + + # # Creates directories under project/data/processed/ + # dirs_processed_data = [data_processed_path / Path(i.stem) for i in videos_paths] + # for p in dirs_processed_data: + # p.mkdir(parents=True, exist_ok=True) # Creates directories under project/results/ dirs_results = [results_path / Path(i.stem) for i in videos_paths] @@ -202,6 +205,12 @@ def init_new_project( ) num_features_list.append(ds.space.shape[0] * ds.keypoints.shape[0]) + output_processed_name = data_processed_path / Path(video_path).stem + ds.to_netcdf( + path=f"{output_processed_name}_processed.nc", + engine="scipy", + ) + unique_num_features = list(set(num_features_list)) if len(unique_num_features) > 1: raise ValueError("All pose estimation files must have the same number of features.") diff --git a/src/vame/preprocessing/alignment.py b/src/vame/preprocessing/alignment.py index bb806662..af064021 100644 --- a/src/vame/preprocessing/alignment.py +++ b/src/vame/preprocessing/alignment.py @@ -13,6 +13,8 @@ def egocentrically_align_and_center( config: dict, centered_reference_keypoint: str = "snout", orientation_reference_keypoint: str = "tailbase", + read_from_variable: str = "position_processed", + save_to_variable: str = "position_egocentric_aligned", ) -> None: """ Aligns the time series by first centralizing all positions around the first keypoint @@ -39,7 +41,7 @@ def egocentrically_align_and_center( for i, session in enumerate(sessions): logger.info(f"Session: {session}") - # Read raw session data + # Read session data file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") _, _, ds = read_pose_estimation_file(file_path=file_path) @@ -49,7 +51,7 @@ def egocentrically_align_and_center( idx2 = np.where(keypoints == orientation_reference_keypoint)[0][0] # Extract processed positions values, with shape: (time, individuals, keypoints, space) - position_processed = np.copy(ds["position_processed"].values) + position_processed = np.copy(ds[read_from_variable].values) position_aligned = np.empty_like(position_processed) # Loop over individuals @@ -87,7 +89,7 @@ def egocentrically_align_and_center( position_aligned[t, individual, :, :] = rotated_positions # Update the dataset with the cleaned position values - ds["position_processed"] = (ds["position"].dims, position_aligned) + ds[save_to_variable] = (ds[read_from_variable].dims, position_aligned) ds.attrs.update({"processed_alignment": True}) # Save the aligned dataset to file diff --git a/src/vame/preprocessing/cleaning.py b/src/vame/preprocessing/cleaning.py index ae73394a..a03e545c 100644 --- a/src/vame/preprocessing/cleaning.py +++ b/src/vame/preprocessing/cleaning.py @@ -10,7 +10,11 @@ logger = logger_config.logger -def lowconf_cleaning(config: dict): +def lowconf_cleaning( + config: dict, + read_from_variable: str = "position_processed", + save_to_variable: str = "position_processed", +): """ Clean the low confidence data points from the dataset. 
Processes position data by:
@@ -25,10 +29,10 @@ def lowconf_cleaning(config: dict):

     for i, session in enumerate(sessions):
         logger.info(f"Session: {session}")
         # Read raw session data
-        file_path = str(Path(project_path) / "data" / "raw" / f"{session}.nc")
+        file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc")
         _, _, ds = read_pose_estimation_file(file_path=file_path)

-        position = ds["position"].values
+        position = ds[read_from_variable].values
         cleaned_position = np.empty_like(position)

         confidence = ds["confidence"].values
@@ -55,7 +59,7 @@ def lowconf_cleaning(config: dict):
             cleaned_position[:, individual, keypoint, space] = series

         # Update the dataset with the cleaned position values
-        ds["position_processed"] = (ds["position"].dims, cleaned_position)
+        ds[save_to_variable] = (ds[read_from_variable].dims, cleaned_position)
         ds.attrs.update({"processed_confidence": True})

         ds["percentage_low_confidence"] = (["individual", "keypoint"], perc_interp_points)
@@ -68,7 +72,11 @@ def lowconf_cleaning(config: dict):
     )


-def outlier_cleaning(config: dict):
+def outlier_cleaning(
+    config: dict,
+    read_from_variable: str = "position_processed",
+    save_to_variable: str = "position_processed",
+):
     """
     Clean the outliers from the dataset.
     Processes position data by:
@@ -85,6 +93,7 @@ def outlier_cleaning(config: dict):
         file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc")
         _, _, ds = read_pose_estimation_file(file_path=file_path)

-        position = np.copy(ds["position_processed"].values)
+        position = np.copy(ds[read_from_variable].values)
         cleaned_position = np.copy(position)

         perc_interp_points = np.zeros((position.shape[1], position.shape[2], position.shape[3]))
@@ -108,7 +116,9 @@ def outlier_cleaning(config: dict):
                         iqr_val = iqr(z_series)
                         outlier_mask = np.abs(z_series) > iqr_factor * iqr_val
                         z_series[outlier_mask] = np.nan
-                        perc_interp_points[individual, keypoint, space] = 100 * np.sum(outlier_mask) / len(outlier_mask)
+                        perc_interp_points[individual, keypoint, space] = (
+                            100 * np.sum(outlier_mask) / len(outlier_mask)
+                        )

                     # Interpolate NaN values
                     if not outlier_mask.all():
@@ -125,7 +135,7 @@ def outlier_cleaning(config: dict):
                     cleaned_position[:, individual, keypoint, space] = z_series

         # Update the dataset with the cleaned position values
-        ds["position_processed"] = (ds["position"].dims, cleaned_position)
+        ds[save_to_variable] = (ds[read_from_variable].dims, cleaned_position)
         ds.attrs.update({"processed_outliers": True})

         ds["percentage_iqr_outliers"] = (["individual", "keypoint", "space"], perc_interp_points)
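Note: the Savitzky-Golay step configured in the next file is scipy.signal.savgol_filter applied per coordinate series. A small sketch on synthetic data (the window and order values stand in for the savgol_length and savgol_order config keys; the numbers are arbitrary):

    import numpy as np
    from scipy.signal import savgol_filter

    rng = np.random.default_rng(42)
    t = np.linspace(0, 4 * np.pi, 400)
    clean = np.sin(t)                      # idealized keypoint coordinate over time
    noisy = clean + rng.normal(scale=0.3, size=t.size)

    # window_length must be odd and greater than polyorder
    smoothed = savgol_filter(x=noisy, window_length=5, polyorder=2)

    print(float(np.mean(np.abs(noisy - clean))))     # error before smoothing
    print(float(np.mean(np.abs(smoothed - clean))))  # smaller error after smoothing

Unlike a plain moving average, the local polynomial fit preserves the peaks of fast movements reasonably well, which is presumably why it is the filter of choice here.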
""" @@ -27,7 +31,7 @@ def savgol_filtering(config: dict): _, _, ds = read_pose_estimation_file(file_path=file_path) # Extract processed positions values, with shape: (time, individuals, keypoints, space) - position = np.copy(ds["position_processed"].values) + position = np.copy(ds[read_from_variable].values) filtered_position = np.copy(position) for individual in range(position.shape[1]): for keypoint in range(position.shape[2]): @@ -47,7 +51,7 @@ def savgol_filtering(config: dict): ) # Update the dataset with the filtered position values - ds["position_processed"] = (ds["position"].dims, filtered_position) + ds[save_to_variable] = (ds[read_from_variable].dims, filtered_position) ds.attrs.update({"processed_filtered": True}) # Save the filtered dataset to file diff --git a/src/vame/preprocessing/preprocessing.py b/src/vame/preprocessing/preprocessing.py index c4eaef9e..55b55f80 100644 --- a/src/vame/preprocessing/preprocessing.py +++ b/src/vame/preprocessing/preprocessing.py @@ -21,7 +21,11 @@ def preprocessing( # Low-confidence cleaning logger.info("Cleaning low confidence data points...") - lowconf_cleaning(config=config) + lowconf_cleaning( + config=config, + read_from_variable="position", + save_to_variable="position_cleaned_lowconf", + ) # Egocentric alignment logger.info("Egocentrically aligning and centering...") @@ -29,21 +33,36 @@ def preprocessing( config=config, centered_reference_keypoint=centered_reference_keypoint, orientation_reference_keypoint=orientation_reference_keypoint, + read_from_variable="position_cleaned_lowconf", + save_to_variable="position_egocentric_aligned", ) + + # Create visualization of the preprocessing results up to this point visualize_preprocessing_scatter( config, save_to_file=True, + show_figure=False, ) # Outlier cleaning logger.info("Cleaning outliers...") - outlier_cleaning(config=config) + outlier_cleaning( + config=config, + read_from_variable="position_egocentric_aligned", + save_to_variable="position_processed", + ) # Savgol filtering logger.info("Applying Savitzky-Golay filter...") - savgol_filtering(config=config) + savgol_filtering( + config=config, + read_from_variable="position_processed", + save_to_variable="position_processed", + ) + # Create visualization of the preprocessing results visualize_preprocessing_timeseries( config, save_to_file=True, + show_figure=False, ) diff --git a/src/vame/preprocessing/visualization.py b/src/vame/preprocessing/visualization.py index e0382898..8c55f413 100644 --- a/src/vame/preprocessing/visualization.py +++ b/src/vame/preprocessing/visualization.py @@ -9,11 +9,15 @@ def visualize_preprocessing_scatter( config: dict, session_index: int = 0, frames: list = [], + original_positions_key: str = "position", + cleaned_positions_key: str = "position_cleaned_lowconf", + aligned_positions_key: str = "position_egocentric_aligned", save_to_file: bool = False, + show_figure: bool = True, ): """ - Visualize the preprocessing results by plotting the original and aligned positions - of the keypoints in a scatter plot. + Visualize the preprocessing results by plotting the original, cleaned low-confidence, + and egocentric aligned positions of the keypoints in a scatter plot. 
""" project_path = config["project_path"] sessions = config["session_names"] @@ -23,21 +27,22 @@ def visualize_preprocessing_scatter( file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") _, _, ds = read_pose_estimation_file(file_path=file_path) - original_positions = ds["position"].values - processed_positions = ds["position_processed"].values + original_positions = ds[original_positions_key].values + cleaned_positions = ds[cleaned_positions_key].values + aligned_positions = ds[aligned_positions_key].values keypoints_labels = ds.keypoints.values if not frames: frames = [int(i * len(original_positions)) for i in [0.1, 0.3, 0.5, 0.7, 0.9]] num_frames = len(frames) - fig, axes = plt.subplots(num_frames, 2, figsize=(14, 6 * num_frames)) # Increased figure size + fig, axes = plt.subplots(num_frames, 3, figsize=(21, 6 * num_frames)) # Increased figure size and columns for i, frame in enumerate(frames): # Compute dynamic limits for the original positions x_orig, y_orig = original_positions[frame, 0, :, 0], original_positions[frame, 0, :, 1] - x_orig -= x_orig[0] # Centralize around the first keypoint - y_orig -= y_orig[0] + # x_orig -= x_orig[0] # Centralize around the first keypoint + # y_orig -= y_orig[0] x_min, x_max = x_orig.min() - 10, x_orig.max() + 10 # Add a margin y_min, y_max = y_orig.min() - 10, y_orig.max() + 10 @@ -45,51 +50,86 @@ def visualize_preprocessing_scatter( ax_original = axes[i, 0] ax_original.scatter(x_orig, y_orig, c="blue", label="Original") for k, (x, y) in enumerate(zip(x_orig, y_orig)): - ax_original.text(x, y, keypoints_labels[k], fontsize=8, color="blue") - ax_original.set_title(f"Original - Frame {frame}", fontsize=10) - ax_original.set_xlabel("X", fontsize=8) - ax_original.set_ylabel("Y", fontsize=8) + ax_original.text(x, y, keypoints_labels[k], fontsize=10, color="blue") + ax_original.set_title(f"Original - Frame {frame}", fontsize=14) + ax_original.set_xlabel("X", fontsize=12) + ax_original.set_ylabel("Y", fontsize=12) ax_original.axhline(0, color="gray", linestyle="--") ax_original.axvline(0, color="gray", linestyle="--") ax_original.axis("equal") ax_original.set_xlim(x_min, x_max) ax_original.set_ylim(y_min, y_max) + # Compute dynamic limits for the cleaned positions + x_cleaned, y_cleaned = cleaned_positions[frame, 0, :, 0], cleaned_positions[frame, 0, :, 1] + x_min_cleaned, x_max_cleaned = x_cleaned.min() - 10, x_cleaned.max() + 10 # Add a margin + y_min_cleaned, y_max_cleaned = y_cleaned.min() - 10, y_cleaned.max() + 10 + + # Centralized Cleaned positions + ax_cleaned = axes[i, 1] + ax_cleaned.scatter(x_cleaned, y_cleaned, c="orange", label="Cleaned Low-Conf") + for k, (x, y) in enumerate(zip(x_cleaned, y_cleaned)): + ax_cleaned.text(x, y, keypoints_labels[k], fontsize=10, color="orange") + ax_cleaned.set_title(f"Cleaned - Frame {frame}", fontsize=14) + ax_cleaned.set_xlabel("X", fontsize=12) + ax_cleaned.set_ylabel("Y", fontsize=12) + ax_cleaned.axhline(0, color="gray", linestyle="--") + ax_cleaned.axvline(0, color="gray", linestyle="--") + ax_cleaned.axis("equal") + ax_cleaned.set_xlim(x_min_cleaned, x_max_cleaned) + ax_cleaned.set_ylim(y_min_cleaned, y_max_cleaned) + # Compute dynamic limits for the aligned positions - x_aligned, y_aligned = processed_positions[frame, 0, :, 0], processed_positions[frame, 0, :, 1] + x_aligned, y_aligned = aligned_positions[frame, 0, :, 0], aligned_positions[frame, 0, :, 1] x_min_aligned, x_max_aligned = x_aligned.min() - 10, x_aligned.max() + 10 # Add a margin y_min_aligned, y_max_aligned 
= y_aligned.min() - 10, y_aligned.max() + 10 # Centralized Aligned positions - ax_aligned = axes[i, 1] - ax_aligned.scatter(x_aligned, y_aligned, c="green", label="Aligned") + ax_aligned = axes[i, 2] + ax_aligned.scatter(x_aligned, y_aligned, c="green", label="Egocentric Aligned") for k, (x, y) in enumerate(zip(x_aligned, y_aligned)): - ax_aligned.text(x, y, keypoints_labels[k], fontsize=8, color="green") - ax_aligned.set_title(f"Aligned - Frame {frame}", fontsize=10) - ax_aligned.set_xlabel("X", fontsize=8) - ax_aligned.set_ylabel("Y", fontsize=8) + ax_aligned.text(x, y, keypoints_labels[k], fontsize=10, color="green") + ax_aligned.set_title(f"Aligned - Frame {frame}", fontsize=14) + ax_aligned.set_xlabel("X", fontsize=12) + ax_aligned.set_ylabel("Y", fontsize=12) ax_aligned.axhline(0, color="gray", linestyle="--") ax_aligned.axvline(0, color="gray", linestyle="--") ax_aligned.axis("equal") ax_aligned.set_xlim(x_min_aligned, x_max_aligned) ax_aligned.set_ylim(y_min_aligned, y_max_aligned) - plt.tight_layout(pad=2.0) # Add padding to reduce overlap + # Add a figure-level title + fig.suptitle( + f"{session}, Confidence threshold: {config['pose_confidence']}", + fontsize=16, + ) + + # Add padding to reduce overlap between subplots + plt.tight_layout(pad=3.0) if save_to_file: save_fig_path = Path(project_path) / "reports" / "figures" / f"{session}_preprocessing_scatter.png" save_fig_path.parent.mkdir(parents=True, exist_ok=True) plt.savefig(str(save_fig_path)) + if show_figure: + plt.show() + else: + plt.close(fig) + def visualize_preprocessing_timeseries( config: dict, session_index: int = 0, n_samples: int = 1000, + original_positions_key: str = "position", + aligned_positions_key: str = "position_egocentric_aligned", + processed_positions_key: str = "position_processed", save_to_file: bool = False, + show_figure: bool = True, ): """ - Visualize the preprocessing results by plotting the original and aligned positions + Visualize the preprocessing results by plotting the original, aligned, and processed positions of the keypoints in a timeseries plot. 
""" project_path = config["project_path"] @@ -100,7 +140,7 @@ def visualize_preprocessing_timeseries( file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") _, _, ds = read_pose_estimation_file(file_path=file_path) - fig, ax = plt.subplots(4, 1, figsize=(10, 12)) + fig, ax = plt.subplots(6, 1, figsize=(10, 16)) # Adjusted for 6 subplots individual = "individual_0" keypoints_labels = ds.keypoints.values @@ -121,44 +161,88 @@ def visualize_preprocessing_timeseries( space="y", ) - ds.position.sel(**sel_x)[0:n_samples].plot( + # Original positions (first two subplots) + ds[original_positions_key].sel(**sel_x)[0:n_samples].plot( linewidth=1.5, ax=ax[0], label=kp, color=colors[i], ) - ds.position.sel(**sel_y)[0:n_samples].plot( + ds[original_positions_key].sel(**sel_y)[0:n_samples].plot( linewidth=1.5, ax=ax[1], label=kp, color=colors[i], ) - ds.position_processed.sel(**sel_x)[0:n_samples].plot( + + # Aligned positions (next two subplots) + ds[aligned_positions_key].sel(**sel_x)[0:n_samples].plot( linewidth=1.5, ax=ax[2], label=kp, color=colors[i], ) - ds.position_processed.sel(**sel_y)[0:n_samples].plot( + ds[aligned_positions_key].sel(**sel_y)[0:n_samples].plot( linewidth=1.5, ax=ax[3], label=kp, color=colors[i], ) - ax[0].set_title("") - ax[1].set_title("") - ax[2].set_title("") - ax[3].set_title("") + # Processed positions (last two subplots) + ds[processed_positions_key].sel(**sel_x)[0:n_samples].plot( + linewidth=1.5, + ax=ax[4], + label=kp, + color=colors[i], + ) + ds[processed_positions_key].sel(**sel_y)[0:n_samples].plot( + linewidth=1.5, + ax=ax[5], + label=kp, + color=colors[i], + ) + + # Set common labels for Y axes + ax[0].set_ylabel( + "Original Allocentric X", + fontsize=12, + ) + ax[1].set_ylabel( + "Original Allocentric Y", + fontsize=12, + ) + ax[2].set_ylabel( + "Aligned Egocentric X", + fontsize=12, + ) + ax[3].set_ylabel( + "Aligned Egocentric Y", + fontsize=12, + ) + ax[4].set_ylabel( + "Processed Egocentric X", + fontsize=12, + ) + ax[5].set_ylabel( + "Processed Egocentric Y", + fontsize=12, + ) - ax[0].set_xlabel("") - ax[1].set_xlabel("") - ax[2].set_xlabel("") + # Labels for X axes + for idx, a in enumerate(ax): + a.set_title("") + if idx % 2 == 0: + a.set_xlabel("") + else: + a.set_xlabel( + "Time", + fontsize=10, + ) - ax[0].set_ylabel("Allocentric X") - ax[1].set_ylabel("Allocentric Y") - ax[2].set_ylabel("Egocentric X") - ax[3].set_ylabel("Egocentric Y") + # Adjust padding + fig.subplots_adjust(hspace=0.4) + fig.tight_layout(rect=[0, 0, 1, 0.96], h_pad=1.2) # Add a single legend for all subplots handles, labels = ax[0].get_legend_handles_labels() @@ -167,12 +251,19 @@ def visualize_preprocessing_timeseries( labels, loc="upper center", ncol=5, - bbox_to_anchor=(0.5, 1.02), + bbox_to_anchor=(0.5, 0.98), ) - plt.tight_layout(rect=[0, 0, 1, 0.98]) - if save_to_file: save_fig_path = Path(project_path) / "reports" / "figures" / f"{session}_preprocessing_timeseries.png" save_fig_path.parent.mkdir(parents=True, exist_ok=True) - plt.savefig(str(save_fig_path)) + plt.savefig( + str(save_fig_path), + ) + + if show_figure: + plt.show() + else: + plt.close( + fig, + ) From 1f535f7dbfbf293f3760e77fccca235ef150ce45 Mon Sep 17 00:00:00 2001 From: luiz Date: Sat, 21 Dec 2024 15:38:56 +0100 Subject: [PATCH 15/77] create trainset --- src/vame/model/create_training.py | 251 +++++++++------------------- src/vame/preprocessing/alignment.py | 6 +- src/vame/preprocessing/cleaning.py | 2 +- src/vame/schemas/states.py | 19 +-- 4 files changed, 88 
insertions(+), 190 deletions(-)

diff --git a/src/vame/model/create_training.py b/src/vame/model/create_training.py
index a9f554c4..f5aa8471 100644
--- a/src/vame/model/create_training.py
+++ b/src/vame/model/create_training.py
@@ -5,10 +5,12 @@
 from scipy.stats import iqr
 import matplotlib.pyplot as plt
 from typing import List, Optional
+
 from vame.logging.logger import VameLogger
 from vame.util.auxiliary import read_config
 from vame.schemas.states import CreateTrainsetFunctionSchema, save_state
 from vame.util.data_manipulation import interpol_all_nans
+from vame.util.data_manipulation import read_pose_estimation_file


 logger_config = VameLogger(__name__)
@@ -102,164 +104,80 @@ def plot_check_parameter(


 def traindata_aligned(
-    cfg: dict,
-    sessions: List[str],
-    testfraction: float,
-    savgol_filter: bool,
-    check_parameter: bool,
+    config: dict,
+    sessions: List[str] | None = None,
+    test_fraction: float | None = None,
+    read_from_variable: str = "position_processed",
 ) -> None:
     """
     Create training dataset for aligned data.
+    Save numpy arrays with the test/train info to the project folder.

     Parameters
     ----------
-    cfg : dict
-        Configuration parameters.
-    sessions : List[str]
-        List of sessions.
-    testfraction : float
-        Fraction of data to use as test data.
-    savgol_filter : bool
-        Flag indicating whether to apply Savitzky-Golay filter.
-    check_parameter : bool
-        If True, the function will plot the z-scored data and the filtered data.
+    config : dict
+        Configuration parameters dictionary.
+    sessions : List[str], optional
+        List of session names. If None, all sessions will be used. Defaults to None.
+    test_fraction : float, optional
+        Fraction of data to use as test data. If None, the value is taken from the project configuration. Defaults to None.
+    read_from_variable : str, optional
+        Name of the dataset variable to read the position data from. Defaults to "position_processed".

     Returns
     -------
     None
-        Save numpy arrays with the test/train info to the project folder.
""" - X_train = [] - pos = [] - pos_temp = 0 - pos.append(0) - - if check_parameter: - X_true = [] - sessions = [sessions[0]] + project_path = config["project_path"] + if sessions is None: + sessions = config["session_names"] + if test_fraction is None: + test_fraction = config["test_fraction"] + all_data_list = [] for session in sessions: - logger.info("z-scoring of session %s" % session) - path_to_file = os.path.join( - cfg["project_path"], - "data", - "processed", - session, - session + "-PE-seq.npy", - ) - data = np.load(path_to_file) - - X_mean = np.mean(data, axis=None) - X_std = np.std(data, axis=None) - X_z = (data.T - X_mean) / X_std - - # Introducing artificial error spikes - # rang = [1.5, 2, 2.5, 3, 3.5, 3, 3, 2.5, 2, 1.5] - # for i in range(num_frames): - # if i % 300 == 0: - # rnd = np.random.choice(12,2) - # for j in range(10): - # X_z[i+j, rnd[0]] = X_z[i+j, rnd[0]] * rang[j] - # X_z[i+j, rnd[1]] = X_z[i+j, rnd[1]] * rang[j] - - if check_parameter: - X_z_copy = X_z.copy() - X_true.append(X_z_copy) - - if cfg["robust"]: - iqr_val = iqr(X_z) - logger.info("IQR value: %.2f, IQR cutoff: %.2f" % (iqr_val, cfg["iqr_factor"] * iqr_val)) - for i in range(X_z.shape[0]): - for marker in range(X_z.shape[1]): - if X_z[i, marker] > cfg["iqr_factor"] * iqr_val: - X_z[i, marker] = np.nan - - elif X_z[i, marker] < -cfg["iqr_factor"] * iqr_val: - X_z[i, marker] = np.nan + # Read session data + file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") + _, _, ds = read_pose_estimation_file(file_path=file_path) - X_z = interpol_all_nans(X_z) + position_data = ds[read_from_variable] + centered_reference_keypoint = ds.attrs['centered_reference_keypoint'] + orientation_reference_keypoint = ds.attrs['orientation_reference_keypoint'] - X_len = len(data.T) - pos_temp += X_len - pos.append(pos_temp) - X_train.append(X_z) + # Get the coordinates + individuals = position_data.coords['individuals'].values + keypoints = position_data.coords['keypoints'].values + spaces = position_data.coords['space'].values - X = np.concatenate(X_train, axis=0) - # X_std = np.std(X) + # Create a flattened array and infer column indices + flattened_array = position_data.values.reshape(position_data.shape[0], -1) + columns = [f"{ind}_{kp}_{sp}" for ind in individuals for kp in keypoints for sp in spaces] - detect_anchors = np.std(X.T, axis=1) - sort_anchors = np.sort(detect_anchors) - if sort_anchors[0] == sort_anchors[1]: - anchors = np.where(detect_anchors == sort_anchors[0])[0] - anchor_1_temp = anchors[0] - anchor_2_temp = anchors[1] - else: - anchor_1_temp = int(np.where(detect_anchors == sort_anchors[0])[0]) - anchor_2_temp = int(np.where(detect_anchors == sort_anchors[1])[0]) + # Identify columns to exclude + excluded_columns = [] + for ind in individuals: + excluded_columns.append(f"{ind}_{centered_reference_keypoint}_x") # Exclude both x and y for centered_reference_keypoint + excluded_columns.append(f"{ind}_{centered_reference_keypoint}_y") + excluded_columns.append(f"{ind}_{orientation_reference_keypoint}_x") # Exclude only x for orientation_reference_keypoint - if anchor_1_temp > anchor_2_temp: - anchor_1 = anchor_1_temp - anchor_2 = anchor_2_temp - else: - anchor_1 = anchor_2_temp - anchor_2 = anchor_1_temp + # Filter out the excluded columns + included_indices = [i for i, col in enumerate(columns) if col not in excluded_columns] + filtered_array = flattened_array[:, included_indices] - X = np.delete(X, anchor_1, 1) - X = np.delete(X, anchor_2, 1) - X = X.T + 
all_data_list.append(filtered_array)
-    if savgol_filter:
-        X_med = scipy.signal.savgol_filter(X, cfg["savgol_length"], cfg["savgol_order"])
-    else:
-        X_med = X
+    all_data_array = np.concatenate(all_data_list, axis=0).T
+    test_size = int(all_data_array.shape[1] * test_fraction)
+    data_test = all_data_array[:, :test_size]
+    data_train = all_data_array[:, test_size:]

-    num_frames = len(X_med.T)
-    test = int(num_frames * testfraction)
+    # Save numpy arrays with the test/train info:
+    train_data_path = Path(project_path) / "data" / "train" / "train_seq.npy"
+    np.save(str(train_data_path), data_train)

-    z_test = X_med[:, :test]
-    z_train = X_med[:, test:]
+    test_data_path = Path(project_path) / "data" / "train" / "test_seq.npy"
+    np.save(str(test_data_path), data_test)

-    if check_parameter:
-        plot_check_parameter(
-            cfg=cfg,
-            iqr_val=iqr_val,
-            num_frames=num_frames,
-            X_true=X_true,
-            X_med=X_med,
-        )
-    else:
-        # save numpy arrays the the test/train info:
-        np.save(
-            os.path.join(
-                cfg["project_path"],
-                "data",
-                "train",
-                "train_seq.npy",
-            ),
-            z_train,
-        )
-        np.save(
-            os.path.join(
-                cfg["project_path"],
-                "data",
-                "train",
-                "test_seq.npy",
-            ),
-            z_test,
-        )
-        for i, session in enumerate(sessions):
-            np.save(
-                os.path.join(
-                    cfg["project_path"],
-                    "data",
-                    "processed",
-                    session,
-                    session + "-PE-seq-clean.npy",
-                ),
-                X_med[:, pos[i] : pos[i + 1]],
-            )
-        logger.info("Lenght of train data: %d" % len(z_train.T))
-        logger.info("Lenght of test data: %d" % len(z_test.T))
+    logger.info(f"Length of train data: {data_train.shape[1]}")
+    logger.info(f"Length of test data: {data_test.shape[1]}")


 def traindata_fixed(
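Note: the column bookkeeping in the rewritten traindata_aligned above is the subtle part. The (time, individuals, keypoints, space) array is flattened in C order, so the nested name comprehension lines up with the flattened columns, and the reference-keypoint columns that alignment pins to constants can then be dropped by name. A toy illustration (keypoint names are made up):

    import numpy as np

    time, individuals, keypoints, space = 2, 1, 3, 2
    position = np.arange(time * individuals * keypoints * space, dtype=float).reshape(
        time, individuals, keypoints, space
    )

    inds = ["individual_0"]
    kps = ["snout", "centroid", "tailbase"]  # centered ref, free keypoint, orientation ref
    sps = ["x", "y"]

    flat = position.reshape(position.shape[0], -1)
    cols = [f"{i}_{k}_{s}" for i in inds for k in kps for s in sps]

    # After alignment, snout (centered) is pinned to the origin and tailbase
    # (orientation) sits on the y-axis, so these columns carry no information:
    excluded = {"individual_0_snout_x", "individual_0_snout_y", "individual_0_tailbase_x"}
    keep = [idx for idx, c in enumerate(cols) if c not in excluded]

    print(cols)                  # column order matches the C-order flattening
    print(flat[:, keep].shape)   # (2, 3): six columns minus the three constant ones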
yes/no: ") if use_session == "yes": sessions.append(session) if use_session == "no": continue else: - sessions = cfg["session_names"] + sessions = config["session_names"] logger.info("Creating training dataset...") - if cfg["robust"]: - logger.info("Using robust setting to eliminate outliers! IQR factor: %d" % cfg["iqr_factor"]) if not fixed: logger.info("Creating trainset from the vame.egocentrical_alignment() output ") traindata_aligned( - cfg, - sessions, - cfg["test_fraction"], - cfg["savgol_filter"], - check_parameter, + config=config, + sessions=sessions, ) else: - logger.info("Creating trainset from the vame.pose_to_numpy() output ") - traindata_fixed( - cfg, - sessions, - cfg["test_fraction"], - cfg["num_features"], - cfg["savgol_filter"], - check_parameter, - pose_ref_index, - ) - - if not check_parameter: - logger.info("A training and test set has been created. Next step: vame.train_model()") + raise NotImplementedError("Fixed data training is not implemented yet") + # logger.info("Creating trainset from the vame.pose_to_numpy() output ") + # traindata_fixed( + # cfg, + # sessions, + # cfg["test_fraction"], + # cfg["num_features"], + # cfg["savgol_filter"], + # check_parameter, + # pose_ref_index, + # ) + + logger.info("A training and test set has been created. Next step: vame.train_model()") except Exception as e: logger.exception(str(e)) diff --git a/src/vame/preprocessing/alignment.py b/src/vame/preprocessing/alignment.py index af064021..50dded18 100644 --- a/src/vame/preprocessing/alignment.py +++ b/src/vame/preprocessing/alignment.py @@ -90,7 +90,11 @@ def egocentrically_align_and_center( # Update the dataset with the cleaned position values ds[save_to_variable] = (ds[read_from_variable].dims, position_aligned) - ds.attrs.update({"processed_alignment": True}) + ds.attrs.update({ + "processed_alignment": True, + "centered_reference_keypoint": centered_reference_keypoint, + "orientation_reference_keypoint": orientation_reference_keypoint, + }) # Save the aligned dataset to file cleaned_file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") diff --git a/src/vame/preprocessing/cleaning.py b/src/vame/preprocessing/cleaning.py index a03e545c..91185535 100644 --- a/src/vame/preprocessing/cleaning.py +++ b/src/vame/preprocessing/cleaning.py @@ -28,7 +28,7 @@ def lowconf_cleaning( for i, session in enumerate(sessions): logger.info(f"Session: {session}") - # Read raw session data + # Read session data file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") _, _, ds = read_pose_estimation_file(file_path=file_path) diff --git a/src/vame/schemas/states.py b/src/vame/schemas/states.py index 00724e41..3d76b33b 100644 --- a/src/vame/schemas/states.py +++ b/src/vame/schemas/states.py @@ -22,7 +22,7 @@ class GenerativeModelModeEnum(str, Enum): class BaseStateSchema(BaseModel): - config: str = Field(title="Configuration file path") + config: dict = Field(title="Configuration dictionary") execution_state: StatesEnum | None = Field( title="Method execution state", default=None, @@ -55,15 +55,7 @@ class EgocentricAlignmentFunctionSchema(BaseStateSchema): class PoseToNumpyFunctionSchema(BaseStateSchema): ... -class CreateTrainsetFunctionSchema(BaseStateSchema): - pose_ref_index: Optional[list] = Field( - title="Pose reference index", - default=None, - ) - check_parameter: bool = Field( - title="Check parameter", - default=False, - ) +class CreateTrainsetFunctionSchema(BaseStateSchema): ... 
class TrainModelFunctionSchema(BaseStateSchema): ... @@ -176,14 +168,11 @@ class VAMEPipelineStatesSchema(BaseModel): ) -def _save_state(model: BaseModel, function_name: str, state: StatesEnum) -> None: +def _save_state(model: BaseStateSchema, function_name: str, state: StatesEnum) -> None: """ Save the state of the function to the project states json file. """ - config_file_path = Path(model.config) - project_path = config_file_path.parent - states_file_path = project_path / "states/states.json" - + states_file_path = Path(model.config["project_path"]) / "states" / "states.json" with open(states_file_path, "r") as f: states = json.load(f) From 0488fc6de2b1095f07906494bf845e548e9be5ad Mon Sep 17 00:00:00 2001 From: luiz Date: Sat, 21 Dec 2024 15:41:59 +0100 Subject: [PATCH 16/77] comment old --- src/vame/model/create_training.py | 395 +++++++++++------------------ src/vame/util/data_manipulation.py | 38 +-- 2 files changed, 171 insertions(+), 262 deletions(-) diff --git a/src/vame/model/create_training.py b/src/vame/model/create_training.py index f5aa8471..0d91f011 100644 --- a/src/vame/model/create_training.py +++ b/src/vame/model/create_training.py @@ -1,15 +1,10 @@ import os import numpy as np from pathlib import Path -import scipy.signal -from scipy.stats import iqr -import matplotlib.pyplot as plt -from typing import List, Optional +from typing import List from vame.logging.logger import VameLogger -from vame.util.auxiliary import read_config from vame.schemas.states import CreateTrainsetFunctionSchema, save_state -from vame.util.data_manipulation import interpol_all_nans from vame.util.data_manipulation import read_pose_estimation_file @@ -17,92 +12,6 @@ logger = logger_config.logger -def plot_check_parameter( - cfg: dict, - iqr_val: float, - num_frames: int, - X_true: List[np.ndarray], - X_med: np.ndarray, -) -> None: - """ - Plot the check parameter - z-scored data and the filtered data. - - Parameters - ---------- - cfg : dict - Configuration parameters. - iqr_val : float - IQR value. - num_frames : int - Number of frames. - X_true : List[np.ndarray] - List of true data. - X_med : np.ndarray - Filtered data. - - Returns - ------- - None - Plot the z-scored data and the filtered data. 
- """ - plot_X_orig = np.concatenate(X_true, axis=0).T - plot_X_med = X_med.copy() - iqr_cutoff = cfg["iqr_factor"] * iqr_val - - plt.figure() - plt.plot(plot_X_orig.T) - plt.axhline(y=iqr_cutoff, color="r", linestyle="--", label="IQR cutoff") - plt.axhline(y=-iqr_cutoff, color="r", linestyle="--") - plt.title("Full Signal z-scored") - plt.legend() - - if num_frames > 1000: - rnd = np.random.choice(num_frames) - - plt.figure() - plt.plot(plot_X_med[:, rnd : rnd + 1000].T) - plt.axhline(y=iqr_cutoff, color="r", linestyle="--", label="IQR cutoff") - plt.axhline(y=-iqr_cutoff, color="r", linestyle="--") - plt.title("Filtered signal z-scored") - plt.legend() - - plt.figure() - plt.plot(plot_X_orig[:, rnd : rnd + 1000].T) - plt.axhline(y=iqr_cutoff, color="r", linestyle="--", label="IQR cutoff") - plt.axhline(y=-iqr_cutoff, color="r", linestyle="--") - plt.title("Original signal z-scored") - plt.legend() - - plt.figure() - plt.plot(plot_X_orig[:, rnd : rnd + 1000].T, "g", alpha=0.5) - plt.plot(plot_X_med[:, rnd : rnd + 1000].T, "--m", alpha=0.6) - plt.axhline(y=iqr_cutoff, color="r", linestyle="--", label="IQR cutoff") - plt.axhline(y=-iqr_cutoff, color="r", linestyle="--") - plt.title("Overlayed z-scored") - plt.legend() - - # plot_X_orig = np.delete(plot_X_orig.T, anchor_1, 1) - # plot_X_orig = np.delete(plot_X_orig, anchor_2, 1) - # mse = (np.square(plot_X_orig[rnd:rnd+1000, :] - plot_X_med[:,rnd:rnd+1000].T)).mean(axis=0) - - else: - plt.figure() - plt.plot(plot_X_med.T) - plt.axhline(y=iqr_cutoff, color="r", linestyle="--", label="IQR cutoff") - plt.axhline(y=-iqr_cutoff, color="r", linestyle="--") - plt.title("Filtered signal z-scored") - plt.legend() - - plt.figure() - plt.plot(plot_X_orig.T) - plt.axhline(y=iqr_cutoff, color="r", linestyle="--", label="IQR cutoff") - plt.axhline(y=-iqr_cutoff, color="r", linestyle="--") - plt.title("Original signal z-scored") - plt.legend() - - logger.info("Please run the function with check_parameter=False if you are happy with the results") - - def traindata_aligned( config: dict, sessions: List[str] | None = None, @@ -180,157 +89,157 @@ def traindata_aligned( logger.info(f"Lenght of test data: {data_test.shape[1]}") -def traindata_fixed( - cfg: dict, - sessions: List[str], - testfraction: float, - num_features: int, - savgol_filter: bool, - check_parameter: bool, - pose_ref_index: Optional[List[int]], -) -> None: - """ - Create training dataset for fixed data. - - Parameters - ---------- - cfg : dict - Configuration parameters. - sessions : List[str] - List of sessions. - testfraction : float - Fraction of data to use as test data. - num_features : int - Number of features. - savgol_filter : bool - Flag indicating whether to apply Savitzky-Golay filter. - check_parameter : bool - If True, the function will plot the z-scored data and the filtered data. - pose_ref_index : Optional[List[int]] - List of reference coordinate indices for alignment. - - Returns: - None - Save numpy arrays with the test/train info to the project folder. 
- """ - X_train = [] - pos = [] - pos_temp = 0 - pos.append(0) - - if check_parameter: - X_true = [] - sessions = [sessions[0]] - - for session in sessions: - logger.info("z-scoring of file %s" % session) - path_to_file = os.path.join( - cfg["project_path"], - "data", - "processed", - session, - session + "-PE-seq.npy", - ) - data = np.load(path_to_file) - - X_mean = np.mean(data, axis=None) - X_std = np.std(data, axis=None) - X_z = (data.T - X_mean) / X_std - - if check_parameter: - X_z_copy = X_z.copy() - X_true.append(X_z_copy) - - if cfg["robust"]: - iqr_val = iqr(X_z) - logger.info("IQR value: %.2f, IQR cutoff: %.2f" % (iqr_val, cfg["iqr_factor"] * iqr_val)) - for i in range(X_z.shape[0]): - for marker in range(X_z.shape[1]): - if X_z[i, marker] > cfg["iqr_factor"] * iqr_val: - X_z[i, marker] = np.nan - - elif X_z[i, marker] < -cfg["iqr_factor"] * iqr_val: - X_z[i, marker] = np.nan - - X_z[i, :] = interpol_all_nans(X_z[i, :]) - - X_len = len(data.T) - pos_temp += X_len - pos.append(pos_temp) - X_train.append(X_z) - - X = np.concatenate(X_train, axis=0).T - - if savgol_filter: - X_med = scipy.signal.savgol_filter(X, cfg["savgol_length"], cfg["savgol_order"]) - else: - X_med = X - - num_frames = len(X_med.T) - test = int(num_frames * testfraction) - - z_test = X_med[:, :test] - z_train = X_med[:, test:] - - if check_parameter: - plot_check_parameter( - cfg, - iqr_val, - num_frames, - X_true, - X_med, - ) - - else: - if pose_ref_index is None: - raise ValueError("Please provide a pose reference index for training on fixed data. E.g. [0,5]") - # save numpy arrays the the test/train info: - np.save( - os.path.join( - cfg["project_path"], - "data", - "train", - "train_seq.npy", - ), - z_train, - ) - np.save( - os.path.join( - cfg["project_path"], - "data", - "train", - "test_seq.npy", - ), - z_test, - ) - - y_shifted_indices = np.arange(0, num_features, 2) - x_shifted_indices = np.arange(1, num_features, 2) - belly_Y_ind = pose_ref_index[0] * 2 - belly_X_ind = (pose_ref_index[0] * 2) + 1 - - for i, session in enumerate(sessions): - # Shifting section added 2/29/2024 PN - X_med_shifted_file = X_med[:, pos[i] : pos[i + 1]] - belly_Y_shift = X_med[belly_Y_ind, pos[i] : pos[i + 1]] - belly_X_shift = X_med[belly_X_ind, pos[i] : pos[i + 1]] - - X_med_shifted_file[y_shifted_indices, :] -= belly_Y_shift - X_med_shifted_file[x_shifted_indices, :] -= belly_X_shift - - np.save( - os.path.join( - cfg["project_path"], - "data", - "processed", - session, - session + "-PE-seq-clean.npy", - ), - X_med_shifted_file, - ) # saving new shifted file - - logger.info("Lenght of train data: %d" % len(z_train.T)) - logger.info("Lenght of test data: %d" % len(z_test.T)) +# def traindata_fixed( +# cfg: dict, +# sessions: List[str], +# testfraction: float, +# num_features: int, +# savgol_filter: bool, +# check_parameter: bool, +# pose_ref_index: Optional[List[int]], +# ) -> None: +# """ +# Create training dataset for fixed data. + +# Parameters +# ---------- +# cfg : dict +# Configuration parameters. +# sessions : List[str] +# List of sessions. +# testfraction : float +# Fraction of data to use as test data. +# num_features : int +# Number of features. +# savgol_filter : bool +# Flag indicating whether to apply Savitzky-Golay filter. +# check_parameter : bool +# If True, the function will plot the z-scored data and the filtered data. +# pose_ref_index : Optional[List[int]] +# List of reference coordinate indices for alignment. + +# Returns: +# None +# Save numpy arrays with the test/train info to the project folder. 
+# """ +# X_train = [] +# pos = [] +# pos_temp = 0 +# pos.append(0) + +# if check_parameter: +# X_true = [] +# sessions = [sessions[0]] + +# for session in sessions: +# logger.info("z-scoring of file %s" % session) +# path_to_file = os.path.join( +# cfg["project_path"], +# "data", +# "processed", +# session, +# session + "-PE-seq.npy", +# ) +# data = np.load(path_to_file) + +# X_mean = np.mean(data, axis=None) +# X_std = np.std(data, axis=None) +# X_z = (data.T - X_mean) / X_std + +# if check_parameter: +# X_z_copy = X_z.copy() +# X_true.append(X_z_copy) + +# if cfg["robust"]: +# iqr_val = iqr(X_z) +# logger.info("IQR value: %.2f, IQR cutoff: %.2f" % (iqr_val, cfg["iqr_factor"] * iqr_val)) +# for i in range(X_z.shape[0]): +# for marker in range(X_z.shape[1]): +# if X_z[i, marker] > cfg["iqr_factor"] * iqr_val: +# X_z[i, marker] = np.nan + +# elif X_z[i, marker] < -cfg["iqr_factor"] * iqr_val: +# X_z[i, marker] = np.nan + +# X_z[i, :] = interpol_all_nans(X_z[i, :]) + +# X_len = len(data.T) +# pos_temp += X_len +# pos.append(pos_temp) +# X_train.append(X_z) + +# X = np.concatenate(X_train, axis=0).T + +# if savgol_filter: +# X_med = scipy.signal.savgol_filter(X, cfg["savgol_length"], cfg["savgol_order"]) +# else: +# X_med = X + +# num_frames = len(X_med.T) +# test = int(num_frames * testfraction) + +# z_test = X_med[:, :test] +# z_train = X_med[:, test:] + +# if check_parameter: +# plot_check_parameter( +# cfg, +# iqr_val, +# num_frames, +# X_true, +# X_med, +# ) + +# else: +# if pose_ref_index is None: +# raise ValueError("Please provide a pose reference index for training on fixed data. E.g. [0,5]") +# # save numpy arrays the the test/train info: +# np.save( +# os.path.join( +# cfg["project_path"], +# "data", +# "train", +# "train_seq.npy", +# ), +# z_train, +# ) +# np.save( +# os.path.join( +# cfg["project_path"], +# "data", +# "train", +# "test_seq.npy", +# ), +# z_test, +# ) + +# y_shifted_indices = np.arange(0, num_features, 2) +# x_shifted_indices = np.arange(1, num_features, 2) +# belly_Y_ind = pose_ref_index[0] * 2 +# belly_X_ind = (pose_ref_index[0] * 2) + 1 + +# for i, session in enumerate(sessions): +# # Shifting section added 2/29/2024 PN +# X_med_shifted_file = X_med[:, pos[i] : pos[i + 1]] +# belly_Y_shift = X_med[belly_Y_ind, pos[i] : pos[i + 1]] +# belly_X_shift = X_med[belly_X_ind, pos[i] : pos[i + 1]] + +# X_med_shifted_file[y_shifted_indices, :] -= belly_Y_shift +# X_med_shifted_file[x_shifted_indices, :] -= belly_X_shift + +# np.save( +# os.path.join( +# cfg["project_path"], +# "data", +# "processed", +# session, +# session + "-PE-seq-clean.npy", +# ), +# X_med_shifted_file, +# ) # saving new shifted file + +# logger.info("Lenght of train data: %d" % len(z_train.T)) +# logger.info("Lenght of test data: %d" % len(z_test.T)) @save_state(model=CreateTrainsetFunctionSchema) diff --git a/src/vame/util/data_manipulation.py b/src/vame/util/data_manipulation.py index 563ddcc3..0ae05474 100644 --- a/src/vame/util/data_manipulation.py +++ b/src/vame/util/data_manipulation.py @@ -168,25 +168,25 @@ def nan_helper(y: np.ndarray) -> Tuple: return np.isnan(y), lambda z: z.nonzero()[0] -def interpol_all_nans(arr: np.ndarray) -> np.ndarray: - """ - Interpolates all NaN values in the given array. - - Parameters - ---------- - arr : np.ndarray - Input array containing NaN values. - - Returns - ------- - np.ndarray - Array with NaN values replaced by interpolated values. 
- """ - y = np.transpose(arr) - nans, x = nan_helper(y) - y[nans] = np.interp(x(nans), x(~nans), y[~nans]) - arr = np.transpose(y) - return arr +# def interpol_all_nans(arr: np.ndarray) -> np.ndarray: +# """ +# Interpolates all NaN values in the given array. + +# Parameters +# ---------- +# arr : np.ndarray +# Input array containing NaN values. + +# Returns +# ------- +# np.ndarray +# Array with NaN values replaced by interpolated values. +# """ +# y = np.transpose(arr) +# nans, x = nan_helper(y) +# y[nans] = np.interp(x(nans), x(~nans), y[~nans]) +# arr = np.transpose(y) +# return arr def interpol_first_rows_nans(arr: np.ndarray) -> np.ndarray: From eeec6676d8adc9b696fb5c52c2a44b4466719f67 Mon Sep 17 00:00:00 2001 From: luiz Date: Sat, 21 Dec 2024 15:48:59 +0100 Subject: [PATCH 17/77] train model wip --- src/vame/model/rnn_vae.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/vame/model/rnn_vae.py b/src/vame/model/rnn_vae.py index cc1c88b6..36bf2d6d 100644 --- a/src/vame/model/rnn_vae.py +++ b/src/vame/model/rnn_vae.py @@ -471,7 +471,10 @@ def test( @save_state(model=TrainModelFunctionSchema) -def train_model(config: str, save_logs: bool = False) -> None: +def train_model( + config: dict, + save_logs: bool = False, +) -> None: """ Train Variational Autoencoder using the configuration file values. Fills in the values in the "train_model" key of the states.json file. @@ -497,8 +500,8 @@ def train_model(config: str, save_logs: bool = False) -> None: Parameters ---------- - config : str - Path to the configuration file. + config : dict + Configuration dictionary. save_logs : bool, optional Whether to save the logs, by default False. @@ -506,10 +509,9 @@ def train_model(config: str, save_logs: bool = False) -> None: ------- None """ + cfg = config try: tqdm_logger_stream = None - config_file = Path(config).resolve() - cfg = read_config(str(config_file)) if save_logs: tqdm_logger_stream = TqdmToLogger(logger) log_path = Path(cfg["project_path"]) / "logs" / "train_model.log" From fd38d554342204523f8a244017478e7f5f7af216 Mon Sep 17 00:00:00 2001 From: luiz Date: Sat, 21 Dec 2024 16:24:07 +0100 Subject: [PATCH 18/77] model train with correct num features --- src/vame/model/rnn_vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vame/model/rnn_vae.py b/src/vame/model/rnn_vae.py index 36bf2d6d..0f51dc66 100644 --- a/src/vame/model/rnn_vae.py +++ b/src/vame/model/rnn_vae.py @@ -552,7 +552,7 @@ def train_model( LEARNING_RATE = cfg["learning_rate"] NUM_FEATURES = cfg["num_features"] if not fixed: - NUM_FEATURES = NUM_FEATURES - 2 + NUM_FEATURES = NUM_FEATURES - 3 TEMPORAL_WINDOW = cfg["time_window"] * 2 FUTURE_DECODER = cfg["prediction_decoder"] FUTURE_STEPS = cfg["prediction_steps"] From 9ab2c487f4c53cdd21fffb512436496211964254 Mon Sep 17 00:00:00 2001 From: luiz Date: Sun, 22 Dec 2024 11:20:07 +0100 Subject: [PATCH 19/77] evaluate and pipeline --- src/vame/initialize_project/new.py | 2 ++ src/vame/model/create_training.py | 18 ++++++++----- src/vame/model/evaluate.py | 41 +++++++++++++++-------------- src/vame/pipeline.py | 28 +++++++++++++------- src/vame/preprocessing/alignment.py | 12 +++++---- 5 files changed, 59 insertions(+), 42 deletions(-) diff --git a/src/vame/initialize_project/new.py b/src/vame/initialize_project/new.py index 7ef6d625..f982c91a 100644 --- a/src/vame/initialize_project/new.py +++ b/src/vame/initialize_project/new.py @@ -94,6 +94,7 @@ def init_new_project( data_processed_path = data_path / "processed" results_path 
= project_path / "results" model_path = project_path / "model" + model_evaluate_path = model_path / "evaluate" model_pretrained_path = model_path / "pretrained_model" for p in [ data_path, @@ -102,6 +103,7 @@ def init_new_project( results_path, model_path, model_pretrained_path, + model_evaluate_path, ]: p.mkdir(parents=True) logger.info('Created "{}"'.format(p)) diff --git a/src/vame/model/create_training.py b/src/vame/model/create_training.py index 0d91f011..16ac04e6 100644 --- a/src/vame/model/create_training.py +++ b/src/vame/model/create_training.py @@ -48,13 +48,13 @@ def traindata_aligned( _, _, ds = read_pose_estimation_file(file_path=file_path) position_data = ds[read_from_variable] - centered_reference_keypoint = ds.attrs['centered_reference_keypoint'] - orientation_reference_keypoint = ds.attrs['orientation_reference_keypoint'] + centered_reference_keypoint = ds.attrs["centered_reference_keypoint"] + orientation_reference_keypoint = ds.attrs["orientation_reference_keypoint"] # Get the coordinates - individuals = position_data.coords['individuals'].values - keypoints = position_data.coords['keypoints'].values - spaces = position_data.coords['space'].values + individuals = position_data.coords["individuals"].values + keypoints = position_data.coords["keypoints"].values + spaces = position_data.coords["space"].values # Create a flattened array and infer column indices flattened_array = position_data.values.reshape(position_data.shape[0], -1) @@ -63,9 +63,13 @@ def traindata_aligned( # Identify columns to exclude excluded_columns = [] for ind in individuals: - excluded_columns.append(f"{ind}_{centered_reference_keypoint}_x") # Exclude both x and y for centered_reference_keypoint + excluded_columns.append( + f"{ind}_{centered_reference_keypoint}_x" + ) # Exclude both x and y for centered_reference_keypoint excluded_columns.append(f"{ind}_{centered_reference_keypoint}_y") - excluded_columns.append(f"{ind}_{orientation_reference_keypoint}_x") # Exclude only x for orientation_reference_keypoint + excluded_columns.append( + f"{ind}_{orientation_reference_keypoint}_x" + ) # Exclude only x for orientation_reference_keypoint # Filter out the excluded columns included_indices = [i for i, col in enumerate(columns) if col not in excluded_columns] diff --git a/src/vame/model/evaluate.py b/src/vame/model/evaluate.py index de2111c4..7ec3f8f4 100644 --- a/src/vame/model/evaluate.py +++ b/src/vame/model/evaluate.py @@ -220,7 +220,7 @@ def eval_temporal( FUTURE_STEPS = cfg["prediction_steps"] NUM_FEATURES = cfg["num_features"] if not fixed: - NUM_FEATURES = NUM_FEATURES - 2 + NUM_FEATURES = NUM_FEATURES - 3 TEST_BATCH_SIZE = 64 hidden_size_layer_1 = cfg["hidden_size_layer_1"] hidden_size_layer_2 = cfg["hidden_size_layer_2"] @@ -332,7 +332,7 @@ def eval_temporal( @save_state(model=EvaluateModelFunctionSchema) def evaluate_model( - config: str, + config: dict, use_snapshots: bool = False, save_logs: bool = False, ) -> None: @@ -346,8 +346,8 @@ def evaluate_model( Parameters ---------- - config : str - Path to config file. + config : dict + Configuration dictionary. use_snapshots : bool, optional Whether to plot for all snapshots or only the best model. Defaults to False. 
save_logs : bool, optional @@ -357,18 +357,14 @@ def evaluate_model( ------- None """ + project_path = Path(config["project_path"]).resolve() try: - config_file = Path(config).resolve() - cfg = read_config(str(config_file)) if save_logs: - log_path = Path(cfg["project_path"]) / "logs" / "evaluate_model.log" + log_path = project_path / "logs" / "evaluate_model.log" logger_config.add_file_handler(str(log_path)) - model_name = cfg["model_name"] - fixed = cfg["egocentric_data"] - - if not os.path.exists(os.path.join(cfg["project_path"], "model", "evaluate")): - os.mkdir(os.path.join(cfg["project_path"], "model", "evaluate")) + model_name = config["model_name"] + fixed = config["egocentric_data"] use_gpu = torch.cuda.is_available() if use_gpu: @@ -379,19 +375,24 @@ def evaluate_model( torch.device("cpu") logger.info("CUDA is not working, or a GPU is not found; using CPU!") - logger.info("Evaluation of %s model. " % model_name) + logger.info(f"Evaluation of model: {model_name}") if not use_snapshots: - eval_temporal(cfg, use_gpu, model_name, fixed) # suffix=suffix + eval_temporal( + cfg=config, + use_gpu=use_gpu, + model_name=model_name, + fixed=fixed, + ) elif use_snapshots: - snapshots = os.listdir(os.path.join(cfg["project_path"], "model", "best_model", "snapshots")) + snapshots = os.listdir(os.path.join(str(project_path), "model", "best_model", "snapshots")) for snap in snapshots: - fullpath = os.path.join(cfg["project_path"], "model", "best_model", "snapshots", snap) + fullpath = os.path.join(str(project_path), "model", "best_model", "snapshots", snap) epoch = snap.split("_")[-1] eval_temporal( - cfg, - use_gpu, - model_name, - fixed, + cfg=config, + use_gpu=use_gpu, + model_name=model_name, + fixed=fixed, snapshot=fullpath, suffix="snapshot" + str(epoch), ) diff --git a/src/vame/pipeline.py b/src/vame/pipeline.py index cba1b344..b9906a67 100644 --- a/src/vame/pipeline.py +++ b/src/vame/pipeline.py @@ -5,6 +5,10 @@ import vame from vame.util.auxiliary import read_config, read_states from vame.io.load_poses import load_vame_dataset +from vame.preprocessing.visualization import ( + visualize_preprocessing_scatter, + visualize_preprocessing_timeseries, +) from vame.logging.logger import VameLogger @@ -81,23 +85,27 @@ def get_raw_datasets(self) -> xr.Dataset: dss.attrs[key] = value return dss - def preprocessing(self, pose_ref_index=[0, 1]): - vame.egocentric_alignment( - config=self.config_path, - pose_ref_index=pose_ref_index, + def preprocessing( + self, + centered_reference_keypoint: str = "snout", + orientation_reference_keypoint: str = "tailbase", + ): + vame.preprocessing( + config=self.config, + centered_reference_keypoint=centered_reference_keypoint, + orientation_reference_keypoint=orientation_reference_keypoint, ) + visualize_preprocessing_scatter(config=self.config) + visualize_preprocessing_timeseries(config=self.config) def create_training_set(self): - vame.create_trainset( - config=self.config_path, - check_parameter=False, - ) + vame.create_trainset(config=self.config) def train_model(self): - vame.train_model(config=self.config_path) + vame.train_model(config=self.config) def evaluate_model(self): - vame.evaluate_model(config=self.config_path) + vame.evaluate_model(config=self.config) def run_segmentation(self): vame.segment_session(config=self.config_path) diff --git a/src/vame/preprocessing/alignment.py b/src/vame/preprocessing/alignment.py index 50dded18..0e774a8a 100644 --- a/src/vame/preprocessing/alignment.py +++ b/src/vame/preprocessing/alignment.py @@ -90,11 +90,13 @@ 
def egocentrically_align_and_center( # Update the dataset with the cleaned position values ds[save_to_variable] = (ds[read_from_variable].dims, position_aligned) - ds.attrs.update({ - "processed_alignment": True, - "centered_reference_keypoint": centered_reference_keypoint, - "orientation_reference_keypoint": orientation_reference_keypoint, - }) + ds.attrs.update( + { + "processed_alignment": True, + "centered_reference_keypoint": centered_reference_keypoint, + "orientation_reference_keypoint": orientation_reference_keypoint, + } + ) # Save the aligned dataset to file cleaned_file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") From 2c5fc35bb8f6b9240b0035109ee63a7fe15684bf Mon Sep 17 00:00:00 2001 From: luiz Date: Sun, 22 Dec 2024 11:43:59 +0100 Subject: [PATCH 20/77] segmentation wip --- src/vame/analysis/pose_segmentation.py | 95 ++++++++++++++------------ src/vame/model/create_training.py | 10 ++- src/vame/pipeline.py | 2 +- src/vame/util/model_util.py | 6 +- 4 files changed, 57 insertions(+), 56 deletions(-) diff --git a/src/vame/analysis/pose_segmentation.py b/src/vame/analysis/pose_segmentation.py index e6520835..d69d39ee 100644 --- a/src/vame/analysis/pose_segmentation.py +++ b/src/vame/analysis/pose_segmentation.py @@ -11,9 +11,7 @@ from vame.schemas.states import save_state, SegmentSessionFunctionSchema from vame.logging.logger import VameLogger, TqdmToLogger from vame.model.rnn_model import RNN_VAE -from vame.util.auxiliary import read_config - -# from vame.util.data_manipulation import consecutive +from vame.util.data_manipulation import read_pose_estimation_file from vame.util.cli import get_sessions_from_user_input from vame.util.model_util import load_model @@ -27,7 +25,8 @@ def embedd_latent_vectors( sessions: List[str], model: RNN_VAE, fixed: bool, - tqdm_stream: Union[TqdmToLogger, None], + read_from_variable: str = "position_processed", + tqdm_stream: Union[TqdmToLogger, None] = None, ) -> List[np.ndarray]: """ Embed latent vectors for the given files using the VAME model. @@ -54,7 +53,7 @@ def embedd_latent_vectors( temp_win = cfg["time_window"] num_features = cfg["num_features"] if not fixed: - num_features = num_features - 2 + num_features = num_features - 3 use_gpu = torch.cuda.is_available() if use_gpu: @@ -66,15 +65,24 @@ def embedd_latent_vectors( for session in sessions: logger.info(f"Embedding of latent vector for file {session}") - data = np.load( - os.path.join( - project_path, - "data", - "processed", - session, - session + "-PE-seq-clean.npy", - ) - ) + # Read session data + file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") + _, _, ds = read_pose_estimation_file(file_path=file_path) + data = np.copy(ds[read_from_variable].values) + + # WIP - need to fix data loading here + # look at how it's done in `traindata_aligned()` function + # probably a good idea to write a function `data_load_for_rnn()` + + # data = np.load( + # os.path.join( + # project_path, + # "data", + # "processed", + # session, + # session + "-PE-seq-clean.npy", + # ) + # ) latent_vector_list = [] with torch.no_grad(): for i in tqdm.tqdm(range(data.shape[1] - temp_win), file=tqdm_stream): @@ -258,7 +266,7 @@ def individual_segmentation( @save_state(model=SegmentSessionFunctionSchema) def segment_session( - config: str, + config: dict, save_logs: bool = False, ) -> None: """ @@ -291,8 +299,8 @@ def segment_session( Parameters ---------- - config : str - Path to the configuration file. 
+ config : dict + Configuration dictionary. save_logs : bool, optional Whether to save logs, by default False. @@ -300,29 +308,28 @@ def segment_session( ------- None """ + project_path = Path(config["project_path"]).resolve() try: - config_file = Path(config).resolve() - cfg = read_config(str(config_file)) tqdm_stream = None if save_logs: - log_path = Path(cfg["project_path"]) / "logs" / "pose_segmentation.log" + log_path = project_path / "logs" / "pose_segmentation.log" logger_config.add_file_handler(str(log_path)) tqdm_stream = TqdmToLogger(logger) - model_name = cfg["model_name"] - n_clusters = cfg["n_clusters"] - fixed = cfg["egocentric_data"] - segmentation_algorithms = cfg["segmentation_algorithms"] + model_name = config["model_name"] + n_clusters = config["n_clusters"] + fixed = config["egocentric_data"] + segmentation_algorithms = config["segmentation_algorithms"] + ind_seg = config["individual_segmentation"] logger.info("Pose segmentation for VAME model: %s \n" % model_name) - ind_seg = cfg["individual_segmentation"] logger.info(f"Segmentation algorithms: {segmentation_algorithms}") for seg in segmentation_algorithms: logger.info(f"Running pose segmentation using {seg} algorithm...") - for session in cfg["session_names"]: + for session in config["session_names"]: if not os.path.exists( os.path.join( - cfg["project_path"], + str(project_path), "results", session, model_name, @@ -331,7 +338,7 @@ def segment_session( ): os.mkdir( os.path.join( - cfg["project_path"], + str(project_path), "results", session, model_name, @@ -340,11 +347,11 @@ def segment_session( ) # Get sessions - if cfg["all_data"] in ["Yes", "yes"]: - sessions = cfg["session_names"] + if config["all_data"] in ["Yes", "yes"]: + sessions = config["session_names"] else: sessions = get_sessions_from_user_input( - cfg=cfg, + cfg=config, action_message="run segmentation", ) @@ -359,7 +366,7 @@ def segment_session( if not os.path.exists( os.path.join( - cfg["project_path"], + str(project_path), "results", sessions[0], model_name, @@ -368,9 +375,9 @@ def segment_session( ) ): new = True - model = load_model(cfg, model_name, fixed) + model = load_model(config, model_name, fixed) latent_vectors = embedd_latent_vectors( - cfg, + config, sessions, model, fixed, @@ -382,7 +389,7 @@ def segment_session( f"Apply individual segmentation of latent vectors for each session, {n_clusters} clusters" ) labels, cluster_center, motif_usages = individual_segmentation( - cfg=cfg, + cfg=config, sessions=sessions, latent_vectors=latent_vectors, n_clusters=n_clusters, @@ -392,7 +399,7 @@ def segment_session( f"Apply the same segmentation of latent vectors for all sessions, {n_clusters} clusters" ) labels, cluster_center, motif_usages = same_segmentation( - cfg=cfg, + cfg=config, sessions=sessions, latent_vectors=latent_vectors, n_clusters=n_clusters, @@ -404,7 +411,7 @@ def segment_session( if os.path.exists( os.path.join( - cfg["project_path"], + str(project_path), "results", sessions[0], model_name, @@ -424,7 +431,7 @@ def segment_session( latent_vectors = [] for session in sessions: path_to_latent_vector = os.path.join( - cfg["project_path"], + str(project_path), "results", session, model_name, @@ -445,7 +452,7 @@ def segment_session( ) # [SRM, 10/28/24] rename to cluster_centers labels, cluster_center, motif_usages = individual_segmentation( - cfg=cfg, + cfg=config, sessions=sessions, latent_vectors=latent_vectors, n_clusters=n_clusters, @@ -456,7 +463,7 @@ def segment_session( ) # [SRM, 10/28/24] rename to cluster_centers labels, 
cluster_center, motif_usages = same_segmentation( - cfg=cfg, + cfg=config, sessions=sessions, latent_vectors=latent_vectors, n_clusters=n_clusters, @@ -471,7 +478,7 @@ def segment_session( for idx, session in enumerate(sessions): logger.info( os.path.join( - cfg["project_path"], + project_path, "results", session, "", @@ -482,7 +489,7 @@ def segment_session( ) if not os.path.exists( os.path.join( - cfg["project_path"], + project_path, "results", session, model_name, @@ -493,7 +500,7 @@ def segment_session( try: os.mkdir( os.path.join( - cfg["project_path"], + project_path, "results", session, "", @@ -506,7 +513,7 @@ def segment_session( logger.error(error) save_data = os.path.join( - cfg["project_path"], + str(project_path), "results", session, model_name, diff --git a/src/vame/model/create_training.py b/src/vame/model/create_training.py index 16ac04e6..10a4b2fa 100644 --- a/src/vame/model/create_training.py +++ b/src/vame/model/create_training.py @@ -63,13 +63,11 @@ def traindata_aligned( # Identify columns to exclude excluded_columns = [] for ind in individuals: - excluded_columns.append( - f"{ind}_{centered_reference_keypoint}_x" - ) # Exclude both x and y for centered_reference_keypoint + # Exclude both x and y for centered_reference_keypoint + excluded_columns.append(f"{ind}_{centered_reference_keypoint}_x") excluded_columns.append(f"{ind}_{centered_reference_keypoint}_y") - excluded_columns.append( - f"{ind}_{orientation_reference_keypoint}_x" - ) # Exclude only x for orientation_reference_keypoint + # Exclude only x for orientation_reference_keypoint + excluded_columns.append(f"{ind}_{orientation_reference_keypoint}_x") # Filter out the excluded columns included_indices = [i for i, col in enumerate(columns) if col not in excluded_columns] diff --git a/src/vame/pipeline.py b/src/vame/pipeline.py index b9906a67..dfe1f3f7 100644 --- a/src/vame/pipeline.py +++ b/src/vame/pipeline.py @@ -108,7 +108,7 @@ def evaluate_model(self): vame.evaluate_model(config=self.config) def run_segmentation(self): - vame.segment_session(config=self.config_path) + vame.segment_session(config=self.config) def generate_motif_videos(self): vame.motif_videos( diff --git a/src/vame/util/model_util.py b/src/vame/util/model_util.py index 127001a0..315c8b24 100644 --- a/src/vame/util/model_util.py +++ b/src/vame/util/model_util.py @@ -1,8 +1,4 @@ import os -import yaml -import ruamel.yaml -from pathlib import Path -from typing import Tuple import torch from vame.logging.logger import VameLogger from vame.model.rnn_model import RNN_VAE @@ -31,7 +27,7 @@ def load_model(cfg: dict, model_name: str, fixed: bool = True) -> RNN_VAE: NUM_FEATURES = cfg["num_features"] if not fixed: - NUM_FEATURES = NUM_FEATURES - 2 + NUM_FEATURES = NUM_FEATURES - 3 hidden_size_layer_1 = cfg["hidden_size_layer_1"] hidden_size_layer_2 = cfg["hidden_size_layer_2"] hidden_size_rec = cfg["hidden_size_rec"] From 7b62b84cc1b3841d6c75ca1320d4141faec779ef Mon Sep 17 00:00:00 2001 From: luiz Date: Sun, 22 Dec 2024 19:00:07 +0100 Subject: [PATCH 21/77] fix segment session --- src/vame/analysis/pose_segmentation.py | 22 +-- src/vame/io/load_poses.py | 81 +++++++++- src/vame/io/nwb.py | 67 ++++++++ src/vame/model/create_training.py | 35 +---- .../align_egocentrical_legacy.py | 2 +- src/vame/preprocessing/alignment.py | 2 +- src/vame/preprocessing/cleaning.py | 2 +- src/vame/preprocessing/filter.py | 2 +- src/vame/preprocessing/to_model.py | 53 +++++++ src/vame/preprocessing/visualization.py | 2 +- src/vame/util/csv_to_npy.py | 7 +- 
src/vame/util/data_manipulation.py | 146 +----------------- src/vame/util/gif_pose_helper.py | 3 +- 13 files changed, 227 insertions(+), 197 deletions(-) create mode 100644 src/vame/io/nwb.py create mode 100644 src/vame/preprocessing/to_model.py diff --git a/src/vame/analysis/pose_segmentation.py b/src/vame/analysis/pose_segmentation.py index d69d39ee..7eae0192 100644 --- a/src/vame/analysis/pose_segmentation.py +++ b/src/vame/analysis/pose_segmentation.py @@ -11,9 +11,10 @@ from vame.schemas.states import save_state, SegmentSessionFunctionSchema from vame.logging.logger import VameLogger, TqdmToLogger from vame.model.rnn_model import RNN_VAE -from vame.util.data_manipulation import read_pose_estimation_file +from vame.io.load_poses import read_pose_estimation_file from vame.util.cli import get_sessions_from_user_input from vame.util.model_util import load_model +from vame.preprocessing.to_model import format_xarray_for_rnn logger_config = VameLogger(__name__) @@ -70,19 +71,12 @@ def embedd_latent_vectors( _, _, ds = read_pose_estimation_file(file_path=file_path) data = np.copy(ds[read_from_variable].values) - # WIP - need to fix data loading here - # look at how it's done in `traindata_aligned()` function - # probably a good idea to write a function `data_load_for_rnn()` - - # data = np.load( - # os.path.join( - # project_path, - # "data", - # "processed", - # session, - # session + "-PE-seq-clean.npy", - # ) - # ) + # Format the data for the RNN model + data = format_xarray_for_rnn( + ds=ds, + read_from_variable=read_from_variable, + ) + latent_vector_list = [] with torch.no_grad(): for i in tqdm.tqdm(range(data.shape[1] - temp_win), file=tqdm_stream): diff --git a/src/vame/io/load_poses.py b/src/vame/io/load_poses.py index b48c8b82..de61f6eb 100644 --- a/src/vame/io/load_poses.py +++ b/src/vame/io/load_poses.py @@ -1,7 +1,11 @@ -from typing import Literal +from typing import Literal, Optional, Tuple from pathlib import Path from movement.io import load_poses as mio_load_poses import xarray as xr +import numpy as np +import pandas as pd + +from vame.schemas.project import PoseEstimationFiletype def load_pose_estimation( @@ -51,3 +55,78 @@ def load_vame_dataset(ds_path: Path | str) -> xr.Dataset: -------- """ return xr.open_dataset(ds_path, engine="scipy") + + +def nc_to_dataframe(nc_data): + keypoints = nc_data["keypoints"].values + space = nc_data["space"].values + + # Flatten position data + position_data = nc_data["position"].isel(individuals=0).values + position_column_names = [f"{keypoint}_{sp}" for keypoint in keypoints for sp in space] + position_flattened = position_data.reshape(position_data.shape[0], -1) + + # Create a DataFrame for position data + position_df = pd.DataFrame(position_flattened, columns=position_column_names) + + # Extract and flatten confidence data + confidence_data = nc_data["confidence"].isel(individuals=0).values + confidence_column_names = [f"{keypoint}_confidence" for keypoint in keypoints] + confidence_flattened = confidence_data.reshape(confidence_data.shape[0], -1) + confidence_df = pd.DataFrame(confidence_flattened, columns=confidence_column_names) + + # Combine position and confidence data + combined_df = pd.concat([position_df, confidence_df], axis=1) + + # Reorder columns: keypoint_x, keypoint_y, keypoint_confidence + reordered_columns = [] + for keypoint in keypoints: + reordered_columns.extend([f"{keypoint}_x", f"{keypoint}_y", f"{keypoint}_confidence"]) + + combined_df = combined_df[reordered_columns] + + return combined_df + + +def 
read_pose_estimation_file(
+    file_path: str,
+    file_type: Optional[PoseEstimationFiletype] = None,
+    path_to_pose_nwb_series_data: Optional[str] = None,
+) -> Tuple[pd.DataFrame, np.ndarray, xr.Dataset]:
+    """
+    Read pose estimation file.
+
+    Parameters
+    ----------
+    file_path : str
+        Path to the pose estimation file.
+    file_type : PoseEstimationFiletype, optional
+        Type of the pose estimation file. Supported types are 'csv' and 'nwb'.
+    path_to_pose_nwb_series_data : str, optional
+        Path to the pose data inside the nwb file, by default None
+
+    Returns
+    -------
+    Tuple[pd.DataFrame, np.ndarray, xr.Dataset]
+        Tuple containing the pose estimation data as a pandas DataFrame, a numpy array, and an xarray Dataset.
+    """
+    ds = load_vame_dataset(ds_path=file_path)
+    data = nc_to_dataframe(ds)
+    data_mat = pd.DataFrame.to_numpy(data)
+    return data, data_mat, ds
+    # if file_type == PoseEstimationFiletype.csv:
+    #     data = pd.read_csv(file_path, skiprows=2, index_col=0)
+    #     if "coords" in data:
+    #         data = data.drop(columns=["coords"], axis=1)
+    #     data_mat = pd.DataFrame.to_numpy(data)
+    #     return data, data_mat
+    # elif file_type == PoseEstimationFiletype.nwb:
+    #     if not path_to_pose_nwb_series_data:
+    #         raise ValueError("Path to pose nwb series data is required.")
+    #     data = get_dataframe_from_pose_nwb_file(
+    #         file_path=file_path,
+    #         path_to_pose_nwb_series_data=path_to_pose_nwb_series_data,
+    #     )
+    #     data_mat = pd.DataFrame.to_numpy(data)
+    #     return data, data_mat
+    # raise ValueError(f"Filetype {file_type} not supported")
diff --git a/src/vame/io/nwb.py b/src/vame/io/nwb.py
new file mode 100644
index 00000000..da142191
--- /dev/null
+++ b/src/vame/io/nwb.py
@@ -0,0 +1,67 @@
+from pynwb import NWBHDF5IO
+from pynwb.file import NWBFile
+from hdmf.utils import LabelledDict
+import pandas as pd
+
+
+def get_pose_data_from_nwb_file(
+    nwbfile: NWBFile,
+    path_to_pose_nwb_series_data: str,
+) -> LabelledDict:
+    """
+    Get pose data from an nwb file using a path to the pose data inside the file.
+
+    Parameters
+    ----------
+    nwbfile : NWBFile
+        NWB file object.
+    path_to_pose_nwb_series_data : str
+        Path to the pose data inside the nwb file.
+
+    Returns
+    -------
+    LabelledDict
+        Pose data.
+    """
+    if not path_to_pose_nwb_series_data:
+        raise ValueError("Path to pose nwb series data is required.")
+    pose_data = nwbfile
+    for key in path_to_pose_nwb_series_data.split("/"):
+        if isinstance(pose_data, dict):
+            pose_data = pose_data.get(key)
+            continue
+        pose_data = getattr(pose_data, key)
+    return pose_data
+
+
+def get_dataframe_from_pose_nwb_file(
+    file_path: str,
+    path_to_pose_nwb_series_data: str,
+) -> pd.DataFrame:
+    """
+    Get pose data from nwb file and return it as a pandas DataFrame.
+
+    Parameters
+    ----------
+    file_path : str
+        Path to the nwb file.
+    path_to_pose_nwb_series_data : str
+        Path to the pose data inside the nwb file.
+
+    Returns
+    -------
+    pd.DataFrame
+        Pose data as a pandas DataFrame.
+ """ + with NWBHDF5IO(file_path, "r") as io: + nwbfile = io.read() + pose = get_pose_data_from_nwb_file(nwbfile, path_to_pose_nwb_series_data) + dataframes = [] + for label, pose_series in pose.items(): + data = pose_series.data[:] + confidence = pose_series.confidence[:] + df = pd.DataFrame(data, columns=[f"{label}_x", f"{label}_y"]) + df[f"likelihood_{label}"] = confidence + dataframes.append(df) + final_df = pd.concat(dataframes, axis=1) + return final_df diff --git a/src/vame/model/create_training.py b/src/vame/model/create_training.py index 10a4b2fa..8db45adf 100644 --- a/src/vame/model/create_training.py +++ b/src/vame/model/create_training.py @@ -5,7 +5,8 @@ from vame.logging.logger import VameLogger from vame.schemas.states import CreateTrainsetFunctionSchema, save_state -from vame.util.data_manipulation import read_pose_estimation_file +from vame.io.load_poses import read_pose_estimation_file +from vame.preprocessing.to_model import format_xarray_for_rnn logger_config = VameLogger(__name__) @@ -47,35 +48,15 @@ def traindata_aligned( file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") _, _, ds = read_pose_estimation_file(file_path=file_path) - position_data = ds[read_from_variable] - centered_reference_keypoint = ds.attrs["centered_reference_keypoint"] - orientation_reference_keypoint = ds.attrs["orientation_reference_keypoint"] - - # Get the coordinates - individuals = position_data.coords["individuals"].values - keypoints = position_data.coords["keypoints"].values - spaces = position_data.coords["space"].values - - # Create a flattened array and infer column indices - flattened_array = position_data.values.reshape(position_data.shape[0], -1) - columns = [f"{ind}_{kp}_{sp}" for ind in individuals for kp in keypoints for sp in spaces] - - # Identify columns to exclude - excluded_columns = [] - for ind in individuals: - # Exclude both x and y for centered_reference_keypoint - excluded_columns.append(f"{ind}_{centered_reference_keypoint}_x") - excluded_columns.append(f"{ind}_{centered_reference_keypoint}_y") - # Exclude only x for orientation_reference_keypoint - excluded_columns.append(f"{ind}_{orientation_reference_keypoint}_x") - - # Filter out the excluded columns - included_indices = [i for i, col in enumerate(columns) if col not in excluded_columns] - filtered_array = flattened_array[:, included_indices] + # Format the data for the RNN model + filtered_array = format_xarray_for_rnn( + ds=ds, + read_from_variable=read_from_variable, + ) all_data_list.append(filtered_array) - all_data_array = np.concatenate(all_data_list, axis=0).T + all_data_array = np.concatenate(all_data_list, axis=1) test_size = int(all_data_array.shape[1] * test_fraction) data_test = all_data_array[:, :test_size] data_train = all_data_array[:, test_size:] diff --git a/src/vame/preprocessing/align_egocentrical_legacy.py b/src/vame/preprocessing/align_egocentrical_legacy.py index 36ba8240..947e1705 100644 --- a/src/vame/preprocessing/align_egocentrical_legacy.py +++ b/src/vame/preprocessing/align_egocentrical_legacy.py @@ -10,11 +10,11 @@ from vame.util.auxiliary import read_config from vame.schemas.states import EgocentricAlignmentFunctionSchema, save_state from vame.schemas.project import PoseEstimationFiletype +from vame.io.load_poses import read_pose_estimation_file from vame.util.data_manipulation import ( interpol_first_rows_nans, crop_and_flip_legacy, background, - read_pose_estimation_file, ) from vame.video import get_video_frame_rate diff --git 
a/src/vame/preprocessing/alignment.py b/src/vame/preprocessing/alignment.py
index 0e774a8a..9a2851f7 100644
--- a/src/vame/preprocessing/alignment.py
+++ b/src/vame/preprocessing/alignment.py
@@ -2,7 +2,7 @@
 from pathlib import Path
 
 from vame.logging.logger import VameLogger
-from vame.util.data_manipulation import read_pose_estimation_file
+from vame.io.load_poses import read_pose_estimation_file
 
 
 logger_config = VameLogger(__name__)
diff --git a/src/vame/preprocessing/cleaning.py b/src/vame/preprocessing/cleaning.py
index 91185535..71ed680d 100644
--- a/src/vame/preprocessing/cleaning.py
+++ b/src/vame/preprocessing/cleaning.py
@@ -3,7 +3,7 @@
 from scipy.stats import iqr
 
 from vame.logging.logger import VameLogger
-from vame.util.data_manipulation import read_pose_estimation_file
+from vame.io.load_poses import read_pose_estimation_file
 
 
 logger_config = VameLogger(__name__)
diff --git a/src/vame/preprocessing/filter.py b/src/vame/preprocessing/filter.py
index 86cf3476..d3c15da5 100644
--- a/src/vame/preprocessing/filter.py
+++ b/src/vame/preprocessing/filter.py
@@ -3,7 +3,7 @@
 from pathlib import Path
 
 from vame.logging.logger import VameLogger
-from vame.util.data_manipulation import read_pose_estimation_file
+from vame.io.load_poses import read_pose_estimation_file
 
 
 logger_config = VameLogger(__name__)
diff --git a/src/vame/preprocessing/to_model.py b/src/vame/preprocessing/to_model.py
new file mode 100644
index 00000000..f7b9903a
--- /dev/null
+++ b/src/vame/preprocessing/to_model.py
@@ -0,0 +1,53 @@
+import xarray as xr
+
+
+def format_xarray_for_rnn(
+    ds: xr.Dataset,
+    read_from_variable: str = "position_processed",
+):
+    """
+    Formats the xarray dataset for use with VAME's RNN model:
+    - The x and y coordinates of the centered_reference_keypoint are excluded.
+    - The x coordinate of the orientation_reference_keypoint is excluded.
+    - The remaining data is flattened and transposed.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The xarray dataset to format.
+    read_from_variable : str, default="position_processed"
+        The variable to read from the dataset.
+
+    Returns
+    -------
+    np.ndarray
+        The formatted array in the shape (n_features, n_samples),
+        where n_features = n_keypoints * n_spaces - 3 for each individual.
+ """ + data = ds[read_from_variable] + centered_reference_keypoint = ds.attrs["centered_reference_keypoint"] + orientation_reference_keypoint = ds.attrs["orientation_reference_keypoint"] + + # Get the coordinates + individuals = data.coords["individuals"].values + keypoints = data.coords["keypoints"].values + spaces = data.coords["space"].values + + # Create a flattened array and infer column indices + flattened_array = data.values.reshape(data.shape[0], -1) + columns = [f"{ind}_{kp}_{sp}" for ind in individuals for kp in keypoints for sp in spaces] + + # Identify columns to exclude + excluded_columns = [] + for ind in individuals: + # Exclude both x and y for centered_reference_keypoint + excluded_columns.append(f"{ind}_{centered_reference_keypoint}_x") + excluded_columns.append(f"{ind}_{centered_reference_keypoint}_y") + # Exclude only x for orientation_reference_keypoint + excluded_columns.append(f"{ind}_{orientation_reference_keypoint}_x") + + # Filter out the excluded columns + included_indices = [i for i, col in enumerate(columns) if col not in excluded_columns] + filtered_array = flattened_array[:, included_indices] + + return filtered_array.T diff --git a/src/vame/preprocessing/visualization.py b/src/vame/preprocessing/visualization.py index 8c55f413..2bc34592 100644 --- a/src/vame/preprocessing/visualization.py +++ b/src/vame/preprocessing/visualization.py @@ -2,7 +2,7 @@ import matplotlib.pyplot as plt from matplotlib.cm import get_cmap -from vame.util.data_manipulation import read_pose_estimation_file +from vame.io.load_poses import read_pose_estimation_file def visualize_preprocessing_scatter( diff --git a/src/vame/util/csv_to_npy.py b/src/vame/util/csv_to_npy.py index 7ce2c382..97b47d60 100644 --- a/src/vame/util/csv_to_npy.py +++ b/src/vame/util/csv_to_npy.py @@ -2,13 +2,12 @@ import numpy as np import pandas as pd from pathlib import Path + from vame.util.auxiliary import read_config from vame.schemas.states import PoseToNumpyFunctionSchema, save_state from vame.logging.logger import VameLogger -from vame.util.data_manipulation import ( - interpol_first_rows_nans, - read_pose_estimation_file, -) +from vame.util.data_manipulation import interpol_first_rows_nans +from vame.io.load_poses import read_pose_estimation_file logger_config = VameLogger(__name__) diff --git a/src/vame/util/data_manipulation.py b/src/vame/util/data_manipulation.py index 0ae05474..5e2de8f8 100644 --- a/src/vame/util/data_manipulation.py +++ b/src/vame/util/data_manipulation.py @@ -1,131 +1,18 @@ -from typing import List, Tuple, Optional +from typing import List, Tuple import numpy as np import pandas as pd -import xarray as xr import cv2 as cv import os import tqdm from scipy.ndimage import median_filter -from pynwb import NWBHDF5IO -from pynwb.file import NWBFile -from hdmf.utils import LabelledDict -from vame.schemas.project import PoseEstimationFiletype from vame.logging.logger import VameLogger -from vame.io.load_poses import load_vame_dataset logger_config = VameLogger(__name__) logger = logger_config.logger -def get_pose_data_from_nwb_file( - nwbfile: NWBFile, - path_to_pose_nwb_series_data: str, -) -> LabelledDict: - """ - Get pose data from nwb file using a inside path to the nwb data. - - Parameters: - ---------- - nwbfile : NWBFile) - NWB file object. - path_to_pose_nwb_series_data : str - Path to the pose data inside the nwb file. - - Returns - ------- - LabelledDict - Pose data. 
- """ - if not path_to_pose_nwb_series_data: - raise ValueError("Path to pose nwb series data is required.") - pose_data = nwbfile - for key in path_to_pose_nwb_series_data.split("/"): - if isinstance(pose_data, dict): - pose_data = pose_data.get(key) - continue - pose_data = getattr(pose_data, key) - return pose_data - - -def get_dataframe_from_pose_nwb_file( - file_path: str, - path_to_pose_nwb_series_data: str, -) -> pd.DataFrame: - """ - Get pose data from nwb file and return it as a pandas DataFrame. - - Parameters - ---------- - file_path : str - Path to the nwb file. - path_to_pose_nwb_series_data : str - Path to the pose data inside the nwb file. - - Returns - ------- - pd.DataFrame - Pose data as a pandas DataFrame. - """ - with NWBHDF5IO(file_path, "r") as io: - nwbfile = io.read() - pose = get_pose_data_from_nwb_file(nwbfile, path_to_pose_nwb_series_data) - dataframes = [] - for label, pose_series in pose.items(): - data = pose_series.data[:] - confidence = pose_series.confidence[:] - df = pd.DataFrame(data, columns=[f"{label}_x", f"{label}_y"]) - df[f"likelihood_{label}"] = confidence - dataframes.append(df) - final_df = pd.concat(dataframes, axis=1) - return final_df - - -def read_pose_estimation_file( - file_path: str, - file_type: Optional[PoseEstimationFiletype] = None, - path_to_pose_nwb_series_data: Optional[str] = None, -) -> Tuple[pd.DataFrame, np.ndarray, xr.Dataset]: - """ - Read pose estimation file. - - Parameters - ---------- - file_path : str - Path to the pose estimation file. - file_type : PoseEstimationFiletype - Type of the pose estimation file. Supported types are 'csv' and 'nwb'. - path_to_pose_nwb_series_data : str, optional - Path to the pose data inside the nwb file, by default None - - Returns - ------- - Tuple[pd.DataFrame, np.ndarray] - Tuple containing the pose estimation data as a pandas DataFrame and a numpy array. 
- """ - ds = load_vame_dataset(ds_path=file_path) - data = nc_to_dataframe(ds) - data_mat = pd.DataFrame.to_numpy(data) - return data, data_mat, ds - # if file_type == PoseEstimationFiletype.csv: - # data = pd.read_csv(file_path, skiprows=2, index_col=0) - # if "coords" in data: - # data = data.drop(columns=["coords"], axis=1) - # data_mat = pd.DataFrame.to_numpy(data) - # return data, data_mat - # elif file_type == PoseEstimationFiletype.nwb: - # if not path_to_pose_nwb_series_data: - # raise ValueError("Path to pose nwb series data is required.") - # data = get_dataframe_from_pose_nwb_file( - # file_path=file_path, - # path_to_pose_nwb_series_data=path_to_pose_nwb_series_data, - # ) - # data_mat = pd.DataFrame.to_numpy(data) - # return data, data_mat - # raise ValueError(f"Filetype {file_type} not supported") - - def consecutive( data: np.ndarray, stepsize: int = 1, @@ -381,34 +268,3 @@ def background( capture.release() return background - - -def nc_to_dataframe(nc_data): - keypoints = nc_data["keypoints"].values - space = nc_data["space"].values - - # Flatten position data - position_data = nc_data["position"].isel(individuals=0).values - position_column_names = [f"{keypoint}_{sp}" for keypoint in keypoints for sp in space] - position_flattened = position_data.reshape(position_data.shape[0], -1) - - # Create a DataFrame for position data - position_df = pd.DataFrame(position_flattened, columns=position_column_names) - - # Extract and flatten confidence data - confidence_data = nc_data["confidence"].isel(individuals=0).values - confidence_column_names = [f"{keypoint}_confidence" for keypoint in keypoints] - confidence_flattened = confidence_data.reshape(confidence_data.shape[0], -1) - confidence_df = pd.DataFrame(confidence_flattened, columns=confidence_column_names) - - # Combine position and confidence data - combined_df = pd.concat([position_df, confidence_df], axis=1) - - # Reorder columns: keypoint_x, keypoint_y, keypoint_confidence - reordered_columns = [] - for keypoint in keypoints: - reordered_columns.extend([f"{keypoint}_x", f"{keypoint}_y", f"{keypoint}_confidence"]) - - combined_df = combined_df[reordered_columns] - - return combined_df diff --git a/src/vame/util/gif_pose_helper.py b/src/vame/util/gif_pose_helper.py index 752efcba..7b9eebc6 100644 --- a/src/vame/util/gif_pose_helper.py +++ b/src/vame/util/gif_pose_helper.py @@ -3,12 +3,13 @@ import cv2 as cv import numpy as np import pandas as pd + from vame.logging.logger import VameLogger +from vame.io.load_poses import read_pose_estimation_file from vame.util.data_manipulation import ( interpol_first_rows_nans, crop_and_flip_legacy, background, - read_pose_estimation_file, ) From 25a2d006fff9bb3240d3d4e2d569d947db6a29a1 Mon Sep 17 00:00:00 2001 From: luiz Date: Mon, 23 Dec 2024 13:37:14 +0100 Subject: [PATCH 22/77] config --- src/vame/analysis/community_analysis.py | 27 ++++++-------- src/vame/analysis/videowriter.py | 49 +++++++++++-------------- src/vame/pipeline.py | 8 ++-- src/vame/util/report.py | 10 ++--- 4 files changed, 42 insertions(+), 52 deletions(-) diff --git a/src/vame/analysis/community_analysis.py b/src/vame/analysis/community_analysis.py index 60dbcffb..28b31210 100644 --- a/src/vame/analysis/community_analysis.py +++ b/src/vame/analysis/community_analysis.py @@ -473,7 +473,7 @@ def save_cohort_community_labels_per_file( @save_state(model=CommunityFunctionSchema) def community( - config: str, + config: dict, segmentation_algorithm: SegmentationAlgorithms, cohort: bool = True, cut_tree: int | None = None, 
@@ -513,8 +513,8 @@ def community( Parameters ---------- - config : str - Path to the configuration file. + config : dict + Configuration parameters. segmentation_algorithm : SegmentationAlgorithms Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'. cohort : bool, optional @@ -529,22 +529,19 @@ def community( None """ try: - config_file = Path(config).resolve() - cfg = read_config(str(config_file)) - if save_logs: - log_path = Path(cfg["project_path"]) / "logs" / "community.log" + log_path = Path(config["project_path"]) / "logs" / "community.log" logger_config.add_file_handler(str(log_path)) - model_name = cfg["model_name"] - n_clusters = cfg["n_clusters"] + model_name = config["model_name"] + n_clusters = config["n_clusters"] # Get sessions - if cfg["all_data"] in ["Yes", "yes"]: - sessions = cfg["session_names"] + if config["all_data"] in ["Yes", "yes"]: + sessions = config["session_names"] else: sessions = get_sessions_from_user_input( - cfg=cfg, + cfg=config, action_message="run community analysis", ) @@ -552,7 +549,7 @@ def community( if cohort: path_to_dir = Path( os.path.join( - cfg["project_path"], + config["project_path"], "results", "community_cohort", segmentation_algorithm + "-" + str(n_clusters), @@ -563,7 +560,7 @@ def community( path_to_dir.mkdir(parents=True, exist_ok=True) motif_labels = get_motif_labels( - config=cfg, + config=config, sessions=sessions, model_name=model_name, n_clusters=n_clusters, @@ -626,7 +623,7 @@ def community( # Saves the full community labels list to each of the original video files # This is useful for further analysis when cohort=True save_cohort_community_labels_per_file( - config=cfg, + config=config, sessions=sessions, model_name=model_name, n_clusters=n_clusters, diff --git a/src/vame/analysis/videowriter.py b/src/vame/analysis/videowriter.py index a03656aa..35e0adcb 100644 --- a/src/vame/analysis/videowriter.py +++ b/src/vame/analysis/videowriter.py @@ -164,7 +164,7 @@ def create_cluster_videos( @save_state(model=MotifVideosFunctionSchema) def motif_videos( - config: Union[str, Path], + config: dict, segmentation_algorithm: SegmentationAlgorithms, video_type: str = ".mp4", output_video_type: str = ".mp4", @@ -186,8 +186,8 @@ def motif_videos( Parameters ---------- - config : Union[str, Path] - Path to the configuration file. + config : dict + Configuration parameters. segmentation_algorithm : SegmentationAlgorithms Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'. 
video_type : str, optional @@ -203,31 +203,28 @@ def motif_videos( """ try: tqdm_logger_stream = None - config_file = Path(config).resolve() - cfg = read_config(str(config_file)) - if save_logs: - log_path = Path(cfg["project_path"]) / "logs" / "motif_videos.log" + log_path = Path(config["project_path"]) / "logs" / "motif_videos.log" logger_config.add_file_handler(str(log_path)) tqdm_logger_stream = TqdmToLogger(logger=logger) - model_name = cfg["model_name"] - n_clusters = cfg["n_clusters"] + model_name = config["model_name"] + n_clusters = config["n_clusters"] logger.info(f"Creating motif videos for algorithm: {segmentation_algorithm}...") # Get sessions - if cfg["all_data"] in ["Yes", "yes"]: - sessions = cfg["session_names"] + if config["all_data"] in ["Yes", "yes"]: + sessions = config["session_names"] else: sessions = get_sessions_from_user_input( - cfg=cfg, + cfg=config, action_message="write motif videos", ) logger.info("Cluster size is: %d " % n_clusters) for session in sessions: path_to_file = os.path.join( - cfg["project_path"], + config["project_path"], "results", session, model_name, @@ -238,7 +235,7 @@ def motif_videos( os.mkdir(os.path.join(path_to_file, "cluster_videos")) create_cluster_videos( - config=cfg, + config=config, path_to_file=path_to_file, session=session, n_clusters=n_clusters, @@ -258,7 +255,7 @@ def motif_videos( @save_state(model=CommunityVideosFunctionSchema) def community_videos( - config: Union[str, Path], + config: dict, segmentation_algorithm: SegmentationAlgorithms, cohort: bool = True, video_type: str = ".mp4", @@ -286,8 +283,8 @@ def community_videos( Parameters: ----------- - config : Union[str, Path] - Path to the configuration file. + config : dict + Configuration parameters. segmentation_algorithm : SegmentationAlgorithms Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'. 
cohort : bool, optional @@ -305,29 +302,27 @@ def community_videos( """ try: tqdm_logger_stream = None - config_file = Path(config).resolve() - cfg = read_config(str(config_file)) if save_logs: - log_path = Path(cfg["project_path"]) / "logs" / "community_videos.log" + log_path = Path(config["project_path"]) / "logs" / "community_videos.log" logger_config.add_file_handler(str(log_path)) tqdm_logger_stream = TqdmToLogger(logger=logger) - model_name = cfg["model_name"] - n_clusters = cfg["n_clusters"] + model_name = config["model_name"] + n_clusters = config["n_clusters"] # Get sessions - if cfg["all_data"] in ["Yes", "yes"]: - sessions = cfg["session_names"] + if config["all_data"] in ["Yes", "yes"]: + sessions = config["session_names"] else: sessions = get_sessions_from_user_input( - cfg=cfg, + cfg=config, action_message="write community videos", ) logger.info("Cluster size is: %d " % n_clusters) for session in sessions: path_to_file = os.path.join( - cfg["project_path"], + config["project_path"], "results", session, model_name, @@ -338,7 +333,7 @@ def community_videos( os.mkdir(os.path.join(path_to_file, "community_videos")) create_cluster_videos( - config=cfg, + config=config, path_to_file=path_to_file, session=session, n_clusters=n_clusters, diff --git a/src/vame/pipeline.py b/src/vame/pipeline.py index dfe1f3f7..b8cd172c 100644 --- a/src/vame/pipeline.py +++ b/src/vame/pipeline.py @@ -112,14 +112,14 @@ def run_segmentation(self): def generate_motif_videos(self): vame.motif_videos( - config=self.config_path, + config=self.config, video_type=".mp4", segmentation_algorithm="hmm", ) def run_community_clustering(self): vame.community( - config=self.config_path, + config=self.config, segmentation_algorithm="hmm", cohort=True, cut_tree=2, @@ -127,7 +127,7 @@ def run_community_clustering(self): def generate_community_videos(self): vame.community_videos( - config=self.config_path, + config=self.config, video_type=".mp4", segmentation_algorithm="hmm", ) @@ -141,7 +141,7 @@ def visualization(self): def report(self): vame.report( - config=self.config_path, + config=self.config, segmentation_algorithm="hmm", ) diff --git a/src/vame/util/report.py b/src/vame/util/report.py index e13504b5..b119d879 100644 --- a/src/vame/util/report.py +++ b/src/vame/util/report.py @@ -15,17 +15,15 @@ def report( - config: str, + config: dict, segmentation_algorithm: str = "hmm", ) -> None: """ Report for a project. 
""" - config_file = Path(config).resolve() - cfg = read_config(str(config_file)) - project_path = Path(cfg["project_path"]) - n_clusters = cfg["n_clusters"] - model_name = cfg["model_name"] + project_path = Path(config["project_path"]) + n_clusters = config["n_clusters"] + model_name = config["model_name"] with open(project_path / "states" / "states.json") as f: project_states = json.load(f) From 2170a2f4479a020003506bcf7a58953b4c68fe09 Mon Sep 17 00:00:00 2001 From: luiz Date: Mon, 23 Dec 2024 13:46:54 +0100 Subject: [PATCH 23/77] keypoints present test --- src/vame/preprocessing/alignment.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/vame/preprocessing/alignment.py b/src/vame/preprocessing/alignment.py index 9a2851f7..72e55cf6 100644 --- a/src/vame/preprocessing/alignment.py +++ b/src/vame/preprocessing/alignment.py @@ -47,6 +47,10 @@ def egocentrically_align_and_center( # Extract keypoint indices keypoints = ds.coords["keypoints"].values + if centered_reference_keypoint not in keypoints: + raise ValueError(f"Centered reference keypoint {centered_reference_keypoint} not found in dataset.") + if orientation_reference_keypoint not in keypoints: + raise ValueError(f"Orientation reference keypoint {orientation_reference_keypoint} not found in dataset.") idx1 = np.where(keypoints == centered_reference_keypoint)[0][0] idx2 = np.where(keypoints == orientation_reference_keypoint)[0][0] From 2f43a97b7ba746e38248436a9a4da6438179a2fc Mon Sep 17 00:00:00 2001 From: luiz Date: Mon, 23 Dec 2024 13:53:05 +0100 Subject: [PATCH 24/77] pipeline --- src/vame/pipeline.py | 8 ++++++-- tests/test_pipeline.py | 6 +++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/vame/pipeline.py b/src/vame/pipeline.py index b8cd172c..5c55c4cd 100644 --- a/src/vame/pipeline.py +++ b/src/vame/pipeline.py @@ -161,9 +161,13 @@ def get_states(self, summary: bool = True) -> dict: logger.info(f"{key}: {value.get('execution_state', 'Not executed')}") return states - def run_pipeline(self, from_step: int = 0): + def run_pipeline( + self, + from_step: int = 0, + preprocessing_kwargs: dict = {}, + ): if from_step == 0: - self.preprocessing() + self.preprocessing(**preprocessing_kwargs) if from_step <= 1: self.create_training_set() if from_step <= 2: diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 7ad6901d..f8c25934 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -44,4 +44,8 @@ def test_pipeline(setup_pipeline): ds = pipeline.get_raw_datasets() assert isinstance(ds, xr.Dataset) - pipeline.run_pipeline() + preprocessing_kwargs = { + "centered_reference_keypoint": "Nose", + "orientation_reference_keypoint": "Tailroot", + } + pipeline.run_pipeline(preprocessing_kwargs=preprocessing_kwargs) From 520d74963e5c1349b1cbd5e47f759532c69191db Mon Sep 17 00:00:00 2001 From: luiz Date: Mon, 23 Dec 2024 14:02:57 +0100 Subject: [PATCH 25/77] visualization --- src/vame/analysis/umap.py | 29 +++++++++++++---------------- src/vame/pipeline.py | 14 +++++++++++--- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/src/vame/analysis/umap.py b/src/vame/analysis/umap.py index 1fd67b25..ab8898e8 100644 --- a/src/vame/analysis/umap.py +++ b/src/vame/analysis/umap.py @@ -218,7 +218,7 @@ def umap_vis_comm( @save_state(model=VisualizationFunctionSchema) def visualization( - config: Union[str, Path], + config: dict, segmentation_algorithm: SegmentationAlgorithms, label: Optional[str] = None, save_logs: bool = False, @@ -242,8 +242,8 @@ def visualization( Parameters ---------- - 
config : Union[str, Path] - Path to the configuration file. + config : dict + Configuration parameters. segmentation_algorithm : SegmentationAlgorithms Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'. label : str, optional @@ -256,28 +256,25 @@ def visualization( None """ try: - config_file = Path(config).resolve() - cfg = read_config(str(config_file)) - if save_logs: - logs_path = Path(cfg["project_path"]) / "logs" / "visualization.log" + logs_path = Path(config["project_path"]) / "logs" / "visualization.log" logger_config.add_file_handler(str(logs_path)) - model_name = cfg["model_name"] - n_clusters = cfg["n_clusters"] + model_name = config["model_name"] + n_clusters = config["n_clusters"] # Get sessions - if cfg["all_data"] in ["Yes", "yes"]: - sessions = cfg["session_names"] + if config["all_data"] in ["Yes", "yes"]: + sessions = config["session_names"] else: sessions = get_sessions_from_user_input( - cfg=cfg, + cfg=config, action_message="generate visualization", ) for idx, session in enumerate(sessions): path_to_file = os.path.join( - cfg["project_path"], + config["project_path"], "results", session, "", @@ -296,7 +293,7 @@ def visualization( "umap_embedding_" + session + ".npy", ) ) - num_points = cfg["num_points"] + num_points = config["num_points"] if num_points > embed.shape[0]: num_points = embed.shape[0] except Exception: @@ -304,13 +301,13 @@ def visualization( os.mkdir(os.path.join(path_to_file, "community")) logger.info(f"Compute embedding for session {session}") embed = umap_embedding( - cfg, + config, session, model_name, n_clusters, segmentation_algorithm, ) - num_points = cfg["num_points"] + num_points = config["num_points"] if num_points > embed.shape[0]: num_points = embed.shape[0] diff --git a/src/vame/pipeline.py b/src/vame/pipeline.py index 5c55c4cd..fb595e55 100644 --- a/src/vame/pipeline.py +++ b/src/vame/pipeline.py @@ -95,8 +95,16 @@ def preprocessing( centered_reference_keypoint=centered_reference_keypoint, orientation_reference_keypoint=orientation_reference_keypoint, ) - visualize_preprocessing_scatter(config=self.config) - visualize_preprocessing_timeseries(config=self.config) + visualize_preprocessing_scatter( + config=self.config, + show_figure=False, + save_to_file=True, + ) + visualize_preprocessing_timeseries( + config=self.config, + show_figure=False, + save_to_file=True, + ) def create_training_set(self): vame.create_trainset(config=self.config) @@ -134,7 +142,7 @@ def generate_community_videos(self): def visualization(self): vame.visualization( - config=self.config_path, + config=self.config, label="community", segmentation_algorithm="hmm", ) From 706b9425d73a0e6ce8d6e20ac2d1b0ac0f4748ac Mon Sep 17 00:00:00 2001 From: luiz Date: Mon, 23 Dec 2024 17:35:49 +0100 Subject: [PATCH 26/77] some test fixes --- src/vame/initialize_project/new.py | 14 +++--- src/vame/pipeline.py | 2 +- tests/conftest.py | 68 +++++++++++++++--------------- 3 files changed, 42 insertions(+), 42 deletions(-) diff --git a/src/vame/initialize_project/new.py b/src/vame/initialize_project/new.py index f982c91a..66a7ed37 100644 --- a/src/vame/initialize_project/new.py +++ b/src/vame/initialize_project/new.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Literal +from typing import List, Optional, Literal, Tuple from datetime import datetime, timezone from pathlib import Path import shutil @@ -8,7 +8,7 @@ from vame.schemas.project import ProjectSchema, PoseEstimationFiletype from vame.schemas.states import VAMEPipelineStatesSchema from vame.logging.logger 
import VameLogger -from vame.util.auxiliary import write_config +from vame.util.auxiliary import write_config, read_config from vame.video.video import get_video_frame_rate from vame.io.load_poses import load_pose_estimation @@ -28,7 +28,7 @@ def init_new_project( copy_videos: bool = False, paths_to_pose_nwb_series_data: Optional[str] = None, config_kwargs: Optional[dict] = None, -) -> str: +) -> Tuple[str, dict]: """ Creates a new VAME project with the given parameters. A VAME project is a directory with the following structure: @@ -79,15 +79,15 @@ def init_new_project( Returns ------- - projconfigfile : str - Path to the new vame project config file. + Tuple[str, dict] + Tuple containing the path to the config file and the config data. """ creation_datetime = datetime.now(timezone.utc).isoformat(timespec="seconds") project_path = Path(working_directory).resolve() / project_name if project_path.exists(): logger.info('Project "{}" already exists!'.format(project_path)) projconfigfile = os.path.join(str(project_path), "config.yaml") - return projconfigfile + return projconfigfile, read_config(projconfigfile) data_path = project_path / "data" data_raw_path = data_path / "raw" @@ -244,4 +244,4 @@ def init_new_project( logger.info(f"A VAME project has been created at {project_path}") - return projconfigfile + return projconfigfile, cfg_data diff --git a/src/vame/pipeline.py b/src/vame/pipeline.py index fb595e55..9297e574 100644 --- a/src/vame/pipeline.py +++ b/src/vame/pipeline.py @@ -30,7 +30,7 @@ def __init__( paths_to_pose_nwb_series_data: Optional[str] = None, config_kwargs: Optional[dict] = None, ): - self.config_path = vame.init_new_project( + self.config_path, self.config = vame.init_new_project( project_name=project_name, videos=videos, poses_estimations=poses_estimations, diff --git a/tests/conftest.py b/tests/conftest.py index dbf04fc7..fe6dc278 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,9 +14,11 @@ def init_project( source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"], working_directory: str, egocentric_data: bool = False, + centered_reference_keypoint: str = "Nose", + orientation_reference_keypoint: str = "Tailroot", paths_to_pose_nwb_series_data: Optional[List[str]] = None, ): - config = vame.init_new_project( + config_path, config_values = vame.init_new_project( project_name=project_name, videos=videos, poses_estimations=poses_estimations, @@ -27,21 +29,21 @@ def init_project( ) # Override config values with test values to speed up tests - config_values = read_config(config) config_values["egocentric_data"] = egocentric_data config_values["max_epochs"] = 10 config_values["batch_size"] = 10 - write_config(config, config_values) + write_config(config_path, config_values) project_data = { "project_name": project_name, "videos": videos, - "config_path": config, + "config_path": config_path, "config_data": config_values, - "pose_ref_index": [0, 5], + "centered_reference_keypoint": centered_reference_keypoint, + "orientation_reference_keypoint": orientation_reference_keypoint, } - return config, project_data + return project_data @fixture(scope="session") @@ -52,7 +54,7 @@ def setup_project_from_folder(): working_directory = "./tests" # Initialize project - config, project_data = init_project( + project_data = init_project( project_name=project_name, videos=videos, poses_estimations=poses_estimations, @@ -64,7 +66,8 @@ def setup_project_from_folder(): yield project_data # Clean up - shutil.rmtree(Path(config).parent) + config_path = 
project_data["config_path"] + shutil.rmtree(Path(config_path).parent) @fixture(scope="session") @@ -75,7 +78,7 @@ def setup_project_not_aligned_data(): working_directory = "./tests" # Initialize project - config, project_data = init_project( + project_data = init_project( project_name=project_name, videos=videos, poses_estimations=poses_estimations, @@ -87,7 +90,8 @@ def setup_project_not_aligned_data(): yield project_data # Clean up - shutil.rmtree(Path(config).parent) + config_path = project_data["config_path"] + shutil.rmtree(Path(config_path).parent) # # TODO change to test fixed (already egocentrically aligned) data when have it @@ -99,7 +103,7 @@ def setup_project_fixed_data(): working_directory = "./tests" # Initialize project - config, project_data = init_project( + project_data = init_project( project_name=project_name, videos=videos, poses_estimations=poses_estimations, @@ -111,7 +115,8 @@ def setup_project_fixed_data(): yield project_data # Clean up - shutil.rmtree(Path(config).parent) + config_path = project_data["config_path"] + shutil.rmtree(Path(config_path).parent) # @fixture(scope="session") @@ -149,10 +154,13 @@ def setup_project_and_convert_pose_to_numpy(setup_project_fixed_data): @fixture(scope="session") def setup_project_and_align_egocentric(setup_project_not_aligned_data): - config_path = setup_project_not_aligned_data["config_path"] - vame.egocentric_alignment( - config_path, - pose_ref_index=setup_project_not_aligned_data["pose_ref_index"], + config_data = setup_project_not_aligned_data["config_data"] + centered_reference_keypoint = setup_project_not_aligned_data["centered_reference_keypoint"] + orientation_reference_keypoint = setup_project_not_aligned_data["orientation_reference_keypoint"] + vame.preprocessing( + config=config_data, + centered_reference_keypoint=centered_reference_keypoint, + orientation_reference_keypoint=orientation_reference_keypoint, save_logs=True, ) return setup_project_not_aligned_data @@ -160,11 +168,9 @@ def setup_project_and_align_egocentric(setup_project_not_aligned_data): @fixture(scope="function") def setup_project_and_check_param_aligned_dataset(setup_project_and_align_egocentric): - config = setup_project_and_align_egocentric["config_path"] + config = setup_project_and_align_egocentric["config_data"] vame.create_trainset( - config, - check_parameter=True, - pose_ref_index=setup_project_and_align_egocentric["pose_ref_index"], + config=config, save_logs=True, ) return setup_project_and_align_egocentric @@ -175,11 +181,9 @@ def setup_project_and_check_param_fixed_dataset( setup_project_and_convert_pose_to_numpy, ): # use setup_project_and_align_egocentric fixture or setup_project_and_convert_pose_to_numpy based on value of egocentric_aligned - config = setup_project_and_convert_pose_to_numpy["config_path"] + config = setup_project_and_convert_pose_to_numpy["config_data"] vame.create_trainset( - config, - check_parameter=True, - pose_ref_index=setup_project_and_convert_pose_to_numpy["pose_ref_index"], + config=config, save_logs=True, ) return setup_project_and_convert_pose_to_numpy @@ -187,11 +191,9 @@ def setup_project_and_check_param_fixed_dataset( @fixture(scope="session") def setup_project_and_create_train_aligned_dataset(setup_project_and_align_egocentric): - config = setup_project_and_align_egocentric["config_path"] + config = setup_project_and_align_egocentric["config_data"] vame.create_trainset( - config, - check_parameter=False, - pose_ref_index=setup_project_and_align_egocentric["pose_ref_index"], + config=config, 
save_logs=True, ) return setup_project_and_align_egocentric @@ -202,11 +204,9 @@ def setup_project_and_create_train_fixed_dataset( setup_project_and_convert_pose_to_numpy, ): # use setup_project_and_align_egocentric fixture or setup_project_and_convert_pose_to_numpy based on value of egocentric_aligned - config = setup_project_and_convert_pose_to_numpy["config_path"] + config = setup_project_and_convert_pose_to_numpy["config_data"] vame.create_trainset( - config, - check_parameter=False, - pose_ref_index=setup_project_and_convert_pose_to_numpy["pose_ref_index"], + config=config, save_logs=True, ) return setup_project_and_convert_pose_to_numpy @@ -214,13 +214,13 @@ def setup_project_and_create_train_fixed_dataset( @fixture(scope="session") def setup_project_and_train_model(setup_project_and_create_train_aligned_dataset): - config = setup_project_and_create_train_aligned_dataset["config_path"] + config = setup_project_and_create_train_aligned_dataset["config_data"] vame.train_model(config, save_logs=True) return setup_project_and_create_train_aligned_dataset @fixture(scope="session") def setup_project_and_evaluate_model(setup_project_and_train_model): - config = setup_project_and_train_model["config_path"] + config = setup_project_and_train_model["config_data"] vame.evaluate_model(config, save_logs=True) return setup_project_and_train_model From 965cf2af6bcd122beb7d321fa2d96b4352a8d77e Mon Sep 17 00:00:00 2001 From: luiz Date: Mon, 23 Dec 2024 18:01:31 +0100 Subject: [PATCH 27/77] fixes tests --- tests/test_analysis.py | 35 +++++++++++++++----------------- tests/test_initialize_project.py | 4 ++-- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/tests/test_analysis.py b/tests/test_analysis.py index e80502d5..98822738 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -30,7 +30,7 @@ def test_pose_segmentation_hmm_files_exists( with patch("vame.analysis.pose_segmentation.read_config", return_value=mock_config) as mock_read_config: with patch("builtins.input", return_value="yes"): vame.segment_session( - setup_project_and_train_model["config_path"], + config=setup_project_and_train_model["config_data"], save_logs=True, ) project_path = setup_project_and_train_model["config_data"]["project_path"] @@ -48,7 +48,7 @@ def test_pose_segmentation_hmm_files_exists( @pytest.mark.parametrize("segmentation_algorithm", ["hmm", "kmeans"]) def test_motif_videos_mp4_files_exists(setup_project_and_train_model, segmentation_algorithm): vame.motif_videos( - setup_project_and_train_model["config_path"], + config=setup_project_and_train_model["config_data"], segmentation_algorithm=segmentation_algorithm, output_video_type=".mp4", save_logs=True, @@ -75,7 +75,7 @@ def test_motif_videos_mp4_files_exists(setup_project_and_train_model, segmentati def test_motif_videos_avi_files_exists(setup_project_and_train_model, segmentation_algorithm): # Check if the files are created vame.motif_videos( - setup_project_and_train_model["config_path"], + config=setup_project_and_train_model["config_data"], segmentation_algorithm=segmentation_algorithm, output_video_type=".avi", save_logs=True, @@ -102,7 +102,7 @@ def test_motif_videos_avi_files_exists(setup_project_and_train_model, segmentati # def test_community_files_exists(setup_project_and_train_model, segmentation_algorithm): # # Check if the files are created # vame.community( -# setup_project_and_train_model["config_path"], +# config=setup_project_and_train_model["config_data"], # cut_tree=2, # cohort=False, # 
segmentation_algorithm=segmentation_algorithm, @@ -135,11 +135,11 @@ def test_motif_videos_avi_files_exists(setup_project_and_train_model, segmentati def test_cohort_community_files_exists(setup_project_and_train_model, segmentation_algorithm): # Check if the files are created vame.community( - setup_project_and_train_model["config_path"], - cut_tree=2, + config=setup_project_and_train_model["config_data"], + segmentation_algorithm=segmentation_algorithm, cohort=True, + cut_tree=2, save_logs=True, - segmentation_algorithm=segmentation_algorithm, ) project_path = setup_project_and_train_model["config_data"]["project_path"] n_clusters = setup_project_and_train_model["config_data"]["n_clusters"] @@ -161,9 +161,8 @@ def test_community_videos_mp4_files_exists( setup_project_and_train_model, segmentation_algorithm, ): - vame.community_videos( - config=setup_project_and_train_model["config_path"], + config=setup_project_and_train_model["config_data"], segmentation_algorithm=segmentation_algorithm, save_logs=True, output_video_type=".mp4", @@ -191,9 +190,8 @@ def test_community_videos_avi_files_exists( setup_project_and_train_model, segmentation_algorithm, ): - vame.community_videos( - config=setup_project_and_train_model["config_path"], + config=setup_project_and_train_model["config_data"], segmentation_algorithm=segmentation_algorithm, save_logs=True, output_video_type=".avi", @@ -233,7 +231,7 @@ def test_visualization_output_files( segmentation_algorithm, ): vame.visualization( - setup_project_and_train_model["config_path"], + setup_project_and_train_model["config_data"], segmentation_algorithm=segmentation_algorithm, label=label, save_logs=True, @@ -270,7 +268,7 @@ def test_generative_model_figures( segmentation_algorithm, ): generative_figure = vame.generative_model( - config=setup_project_and_train_model["config_path"], + config=setup_project_and_train_model["config_data"], segmentation_algorithm=segmentation_algorithm, mode=mode, save_logs=True, @@ -287,7 +285,7 @@ def test_report( segmentation_algorithm, ): vame.report( - config=setup_project_and_train_model["config_path"], + config=setup_project_and_train_model["config_data"], segmentation_algorithm=segmentation_algorithm, ) reports_path = Path(setup_project_and_train_model["config_data"]["project_path"]) / "reports" @@ -297,7 +295,7 @@ def test_report( def test_generative_kmeans_wrong_mode(setup_project_and_train_model): with pytest.raises(ValueError): vame.generative_model( - config=setup_project_and_train_model["config_path"], + config=setup_project_and_train_model["config_data"], segmentation_algorithm="hmm", mode="centers", save_logs=True, @@ -306,9 +304,8 @@ def test_generative_kmeans_wrong_mode(setup_project_and_train_model): @pytest.mark.parametrize("label", [None, "community", "motif"]) def test_gif_frames_files_exists(setup_project_and_evaluate_model, label): - with patch("builtins.input", return_value="yes"): - vame.segment_session(setup_project_and_evaluate_model["config_path"]) + vame.segment_session(setup_project_and_evaluate_model["config_data"]) def mock_background( project_path=None, @@ -329,14 +326,14 @@ def mock_background( SEGMENTATION_ALGORITHM = "hmm" VIDEO_LEN = 30 vame.community( - setup_project_and_evaluate_model["config_path"], + config=setup_project_and_evaluate_model["config_data"], cut_tree=2, cohort=True, save_logs=False, segmentation_algorithm=SEGMENTATION_ALGORITHM, ) vame.visualization( - setup_project_and_evaluate_model["config_path"], + config=setup_project_and_evaluate_model["config_data"], 
segmentation_algorithm=SEGMENTATION_ALGORITHM, label=label, save_logs=False, diff --git a/tests/test_initialize_project.py b/tests/test_initialize_project.py index 7fea631e..d4971406 100644 --- a/tests/test_initialize_project.py +++ b/tests/test_initialize_project.py @@ -26,14 +26,14 @@ def test_existing_project(): poses_estimations = ["./tests/tests_project_sample_data/cropped_video.csv"] working_directory = "./tests" - config_path_creation = init_new_project( + config_path_creation, config_creation = init_new_project( project_name=project_name, videos=videos, poses_estimations=poses_estimations, source_software="DeepLabCut", working_directory=working_directory, ) - config_path_duplicated = init_new_project( + config_path_duplicated, config_duplicated = init_new_project( project_name=project_name, videos=videos, poses_estimations=poses_estimations, From 400e3cae433832210c00a9d846c04cdaa6552742 Mon Sep 17 00:00:00 2001 From: luiz Date: Wed, 25 Dec 2024 16:10:08 +0100 Subject: [PATCH 28/77] test --- tests/test_analysis.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_analysis.py b/tests/test_analysis.py index 98822738..cccc811c 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -27,12 +27,12 @@ def test_pose_segmentation_hmm_files_exists( "individual_segmentation": individual_segmentation, } mock_config["hmm_trained"] = hmm_trained - with patch("vame.analysis.pose_segmentation.read_config", return_value=mock_config) as mock_read_config: - with patch("builtins.input", return_value="yes"): - vame.segment_session( - config=setup_project_and_train_model["config_data"], - save_logs=True, - ) + # with patch("vame.analysis.pose_segmentation.read_config", return_value=mock_config) as mock_read_config: + with patch("builtins.input", return_value="yes"): + vame.segment_session( + config=setup_project_and_train_model["config_data"], + save_logs=True, + ) project_path = setup_project_and_train_model["config_data"]["project_path"] file = setup_project_and_train_model["config_data"]["session_names"][0] model_name = setup_project_and_train_model["config_data"]["model_name"] From 05924cb0d04adcda1f379d3c4c5c97c0e585ddf7 Mon Sep 17 00:00:00 2001 From: luiz Date: Wed, 25 Dec 2024 16:13:35 +0100 Subject: [PATCH 29/77] tests --- src/vame/analysis/generative_functions.py | 32 +++++++++++------------ 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/src/vame/analysis/generative_functions.py b/src/vame/analysis/generative_functions.py index 075a6b0e..29128c5c 100644 --- a/src/vame/analysis/generative_functions.py +++ b/src/vame/analysis/generative_functions.py @@ -228,7 +228,7 @@ def visualize_cluster_center( @save_state(model=GenerativeModelFunctionSchema) def generative_model( - config: str, + config: dict, segmentation_algorithm: SegmentationAlgorithms, mode: str = "sampling", save_logs: bool = False, @@ -238,8 +238,8 @@ def generative_model( Parameters: ----------- - config : str - Path to the configuration file. + config : dict + Configuration dictionary. mode : str, optional Mode for generating samples. Defaults to "sampling". @@ -249,29 +249,27 @@ def generative_model( Plots of generated samples for each segmentation algorithm. 
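+
+    Example:
+    --------
+    Illustrative only — assumes an existing VAME project whose configuration
+    has already been loaded into a dict (the path and algorithm below are
+    placeholders):
+
+    >>> from vame.util.auxiliary import read_config
+    >>> config = read_config("my_project/config.yaml")
+    >>> fig = generative_model(
+    ...     config=config,
+    ...     segmentation_algorithm="hmm",
+    ...     mode="sampling",
+    ... )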
""" try: - config_file = str(Path(config).resolve()) - cfg = read_config(config_file) if save_logs: - logs_path = Path(cfg["project_path"]) / "logs" / "generative_model.log" + logs_path = Path(config["project_path"]) / "logs" / "generative_model.log" logger_config.add_file_handler(str(logs_path)) logger.info(f"Running generative model with mode {mode}...") - model_name = cfg["model_name"] - n_clusters = cfg["n_clusters"] + model_name = config["model_name"] + n_clusters = config["n_clusters"] # Get sessions - if cfg["all_data"] in ["Yes", "yes"]: - sessions = cfg["session_names"] + if config["all_data"] in ["Yes", "yes"]: + sessions = config["session_names"] else: sessions = get_sessions_from_user_input( - cfg=cfg, + cfg=config, action_message="generate samples", ) - model = load_model(cfg, model_name, fixed=False) + model = load_model(config, model_name, fixed=False) for session in sessions: path_to_file = os.path.join( - cfg["project_path"], + config["project_path"], "results", session, model_name, @@ -287,7 +285,7 @@ def generative_model( ) ) return random_generative_samples( - cfg, + config, model, latent_vector, ) @@ -300,7 +298,7 @@ def generative_model( ) ) return random_reconstruction_samples( - cfg, + config, model, latent_vector, ) @@ -317,7 +315,7 @@ def generative_model( ) ) return visualize_cluster_center( - cfg, + config, model, cluster_center, ) @@ -337,7 +335,7 @@ def generative_model( ) ) return random_generative_samples_motif( - cfg=cfg, + cfg=config, model=model, latent_vector=latent_vector, labels=labels, From 42def9acd07f5f5b2f4258f71dd1c8d01eb3b229 Mon Sep 17 00:00:00 2001 From: luiz Date: Wed, 25 Dec 2024 16:20:21 +0100 Subject: [PATCH 30/77] tests --- src/vame/util/gif_pose_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vame/util/gif_pose_helper.py b/src/vame/util/gif_pose_helper.py index 7b9eebc6..4247f834 100644 --- a/src/vame/util/gif_pose_helper.py +++ b/src/vame/util/gif_pose_helper.py @@ -84,7 +84,7 @@ def get_animal_frames( "raw", session + ".nc", ) - data, data_mat = read_pose_estimation_file(file_path=file_path) + data, data_mat, ds = read_pose_estimation_file(file_path=file_path) # get the coordinates for alignment from data table pose_list = [] From 587caa6345a41d90ff545bd21ee47bcb07ce4a13 Mon Sep 17 00:00:00 2001 From: luiz Date: Wed, 25 Dec 2024 16:28:59 +0100 Subject: [PATCH 31/77] tests --- src/vame/util/csv_to_npy.py | 25 +++++++++++-------------- tests/conftest.py | 4 ++-- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/src/vame/util/csv_to_npy.py b/src/vame/util/csv_to_npy.py index 97b47d60..1588ab66 100644 --- a/src/vame/util/csv_to_npy.py +++ b/src/vame/util/csv_to_npy.py @@ -16,7 +16,7 @@ @save_state(model=PoseToNumpyFunctionSchema) def pose_to_numpy( - config: str, + config: dict, save_logs=False, ) -> None: """ @@ -26,8 +26,8 @@ def pose_to_numpy( Parameters ---------- - config : str - Path to the config.yaml file. + config : dict + Configuration dictionary. save_logs : bool, optional If True, the logs will be saved to a file, by default False. @@ -37,23 +37,20 @@ def pose_to_numpy( If the config.yaml file indicates that the data is not egocentric. 
""" try: - config_file = Path(config).resolve() - cfg = read_config(str(config_file)) - if save_logs: - log_path = Path(cfg["project_path"]) / "logs" / "pose_to_numpy.log" + log_path = Path(config["project_path"]) / "logs" / "pose_to_numpy.log" logger_config.add_file_handler(str(log_path)) - project_path = cfg["project_path"] - sessions = cfg["session_names"] - confidence = cfg["pose_confidence"] - if not cfg["egocentric_data"]: + project_path = config["project_path"] + sessions = config["session_names"] + confidence = config["pose_confidence"] + if not config["egocentric_data"]: raise ValueError( "The config.yaml indicates that the data is not egocentric. Please check the parameter egocentric_data" ) - file_type = cfg["pose_estimation_filetype"] - paths_to_pose_nwb_series_data = cfg["paths_to_pose_nwb_series_data"] + file_type = config["pose_estimation_filetype"] + paths_to_pose_nwb_series_data = config["paths_to_pose_nwb_series_data"] for i, session in enumerate(sessions): file_path = os.path.join( project_path, @@ -61,7 +58,7 @@ def pose_to_numpy( "raw", session + ".nc", ) - data, data_mat = read_pose_estimation_file( + data, data_mat, ds = read_pose_estimation_file( file_path=file_path, file_type=file_type, path_to_pose_nwb_series_data=( diff --git a/tests/conftest.py b/tests/conftest.py index fe6dc278..70080051 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -147,8 +147,8 @@ def setup_project_fixed_data(): @fixture(scope="session") def setup_project_and_convert_pose_to_numpy(setup_project_fixed_data): - config_path = setup_project_fixed_data["config_path"] - vame.pose_to_numpy(config_path, save_logs=True) + config = setup_project_fixed_data["config_data"] + vame.pose_to_numpy(config, save_logs=True) return setup_project_fixed_data From 6dba1e8fb3a567868e7466938852195f677b6f9f Mon Sep 17 00:00:00 2001 From: luiz Date: Wed, 25 Dec 2024 16:42:54 +0100 Subject: [PATCH 32/77] path fix --- src/vame/util/csv_to_npy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/vame/util/csv_to_npy.py b/src/vame/util/csv_to_npy.py index 1588ab66..590c3192 100644 --- a/src/vame/util/csv_to_npy.py +++ b/src/vame/util/csv_to_npy.py @@ -100,7 +100,6 @@ def pose_to_numpy( project_path, "data", "processed", - session, session + "-PE-seq.npy", ), final_positions.T, From a08365243729bad1d4d12f2ceb9c02fcc321eb7e Mon Sep 17 00:00:00 2001 From: luiz Date: Wed, 25 Dec 2024 17:10:12 +0100 Subject: [PATCH 33/77] tests --- tests/conftest.py | 33 +++++++++++++++++---------------- tests/test_model.py | 4 ++-- tests/test_util.py | 6 ++---- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 70080051..b75ce3e2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -from pytest import fixture +from pytest import fixture, raises from pathlib import Path import shutil from typing import List, Optional, Literal @@ -182,10 +182,11 @@ def setup_project_and_check_param_fixed_dataset( ): # use setup_project_and_align_egocentric fixture or setup_project_and_convert_pose_to_numpy based on value of egocentric_aligned config = setup_project_and_convert_pose_to_numpy["config_data"] - vame.create_trainset( - config=config, - save_logs=True, - ) + with raises(NotImplementedError, match="Fixed data training is not implemented yet"): + vame.create_trainset( + config=config, + save_logs=True, + ) return setup_project_and_convert_pose_to_numpy @@ -199,17 +200,17 @@ def setup_project_and_create_train_aligned_dataset(setup_project_and_align_egoce return 
setup_project_and_align_egocentric -@fixture(scope="session") -def setup_project_and_create_train_fixed_dataset( - setup_project_and_convert_pose_to_numpy, -): - # use setup_project_and_align_egocentric fixture or setup_project_and_convert_pose_to_numpy based on value of egocentric_aligned - config = setup_project_and_convert_pose_to_numpy["config_data"] - vame.create_trainset( - config=config, - save_logs=True, - ) - return setup_project_and_convert_pose_to_numpy +# @fixture(scope="session") +# def setup_project_and_create_train_fixed_dataset( +# setup_project_and_convert_pose_to_numpy, +# ): +# # use setup_project_and_align_egocentric fixture or setup_project_and_convert_pose_to_numpy based on value of egocentric_aligned +# config = setup_project_and_convert_pose_to_numpy["config_data"] +# vame.create_trainset( +# config=config, +# save_logs=True, +# ) +# return setup_project_and_convert_pose_to_numpy @fixture(scope="session") diff --git a/tests/test_model.py b/tests/test_model.py index 5b92489f..913d3e9f 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -8,7 +8,7 @@ "fixture", [ "setup_project_and_create_train_aligned_dataset", - "setup_project_and_create_train_fixed_dataset", + # "setup_project_and_create_train_fixed_dataset", ], ) def test_create_train_dataset_output_files_exists(request, fixture): @@ -28,7 +28,7 @@ def test_create_train_dataset_output_files_exists(request, fixture): "fixture", [ "setup_project_and_check_param_aligned_dataset", - "setup_project_and_check_param_fixed_dataset", + # "setup_project_and_check_param_fixed_dataset", ], ) def test_create_check_param_train_dataset(request, fixture): diff --git a/tests/test_util.py b/tests/test_util.py index 52f1a769..4f2e8eb4 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -12,7 +12,6 @@ def test_pose_to_numpy_file_exists(setup_project_and_convert_pose_to_numpy): project_path, "data", "processed", - file_name, f"{file_name}-PE-seq.npy", ) assert os.path.exists(file_path) @@ -23,13 +22,12 @@ def test_egocentric_alignment_file_is_created(setup_project_and_align_egocentric Test if the egocentric alignment function creates the expected file. 
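+
+    A quick manual check of the produced dataset (illustrative; the session
+    name is a placeholder):
+
+    >>> import xarray as xr
+    >>> ds = xr.open_dataset("data/processed/my_session_processed.nc", engine="scipy")
+    >>> "position_egocentric_aligned" in ds
+    True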
""" project_path = setup_project_and_align_egocentric["config_data"]["project_path"] - file_name = setup_project_and_align_egocentric["config_data"]["session_names"][0] + session_name = setup_project_and_align_egocentric["config_data"]["session_names"][0] file_path = os.path.join( project_path, "data", "processed", - file_name, - f"{file_name}-PE-seq.npy", + f"{session_name}_processed.nc", ) assert os.path.exists(file_path) From b328ea7d69c5da800daabc1594922074b882e753 Mon Sep 17 00:00:00 2001 From: luiz Date: Wed, 25 Dec 2024 17:39:31 +0100 Subject: [PATCH 34/77] version --- CHANGELOG.md | 15 +++++++++++++++ pyproject.toml | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 591e2ea1..7aeba75e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,18 @@ +# v0.7.0 + +### Fixes + +- Egocentric alignment outputs with incorrect column order ([Issue #96](https://github.com/EthoML/VAME/issues/96)) +- Slow align egocentrical data ([Issue #113](https://github.com/EthoML/VAME/issues/113)) +- Standardized config argument across all functions + +### Features + +- Adopt movement Xarray data format ([Issue #111](https://github.com/EthoML/VAME/issues/111)) +- Relocate IQR cleaning into preprocessing ([Issue #22](https://github.com/EthoML/VAME/issues/22)) +- Created preprocessing module ([Issue #119](https://github.com/EthoML/VAME/issues/119)) + + # v0.6.0 ### Fixes diff --git a/pyproject.toml b/pyproject.toml index f26b7f14..086ac161 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vame-py" -version = '0.6.0' +version = '0.7.0' dynamic = ["dependencies", "readme"] description = "Variational Animal Motion Embedding." authors = [{ name = "K. Luxem & " }, { name = "P. Bauer" }] From 2242c79b918f15409f0f7a97d97c171d213195f1 Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 26 Dec 2024 17:49:46 +0100 Subject: [PATCH 35/77] fix for nans --- src/vame/preprocessing/alignment.py | 10 +- src/vame/preprocessing/cleaning.py | 12 ++- src/vame/preprocessing/preprocessing.py | 3 + src/vame/preprocessing/visualization.py | 131 ++++++++++++++++++++---- 4 files changed, 132 insertions(+), 24 deletions(-) diff --git a/src/vame/preprocessing/alignment.py b/src/vame/preprocessing/alignment.py index 72e55cf6..afe54ba7 100644 --- a/src/vame/preprocessing/alignment.py +++ b/src/vame/preprocessing/alignment.py @@ -48,9 +48,15 @@ def egocentrically_align_and_center( # Extract keypoint indices keypoints = ds.coords["keypoints"].values if centered_reference_keypoint not in keypoints: - raise ValueError(f"Centered reference keypoint {centered_reference_keypoint} not found in dataset.") + raise ValueError( + f"Centered reference keypoint {centered_reference_keypoint} not found in dataset.", + f"Available keypoints: {keypoints}", + ) if orientation_reference_keypoint not in keypoints: - raise ValueError(f"Orientation reference keypoint {orientation_reference_keypoint} not found in dataset.") + raise ValueError( + f"Orientation reference keypoint {orientation_reference_keypoint} not found in dataset.", + f"Available keypoints: {keypoints}", + ) idx1 = np.where(keypoints == centered_reference_keypoint)[0][0] idx2 = np.where(keypoints == orientation_reference_keypoint)[0][0] diff --git a/src/vame/preprocessing/cleaning.py b/src/vame/preprocessing/cleaning.py index 71ed680d..100ca849 100644 --- a/src/vame/preprocessing/cleaning.py +++ b/src/vame/preprocessing/cleaning.py @@ -36,17 +36,21 @@ def lowconf_cleaning( cleaned_position = np.empty_like(position) 
confidence = ds["confidence"].values - perc_interp_points = np.zeros((position.shape[1], position.shape[2])) + perc_interp_points = np.zeros((position.shape[1], position.shape[2], position.shape[3])) for individual in range(position.shape[1]): for keypoint in range(position.shape[2]): conf_series = confidence[:, individual, keypoint] - nan_mask = conf_series < pose_confidence - perc_interp_points[individual, keypoint] = 100 * np.sum(nan_mask) / len(nan_mask) for space in range(position.shape[3]): # Set low-confidence positions to NaN + nan_mask = conf_series < pose_confidence series = np.copy(position[:, individual, keypoint, space]) series[nan_mask] = np.nan + # Update nan_mask because the series might come with NaN values previously + nan_mask = np.isnan(series) + + perc_interp_points[individual, keypoint, space] = 100 * np.sum(nan_mask) / len(nan_mask) + # Interpolate NaN values if not nan_mask.all(): series[nan_mask] = np.interp( @@ -62,7 +66,7 @@ def lowconf_cleaning( ds[save_to_variable] = (ds[read_from_variable].dims, cleaned_position) ds.attrs.update({"processed_confidence": True}) - ds["percentage_low_confidence"] = (["individual", "keypoint"], perc_interp_points) + ds["percentage_low_confidence"] = (["individual", "keypoint", "space"], perc_interp_points) # Save the cleaned dataset to file cleaned_file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") diff --git a/src/vame/preprocessing/preprocessing.py b/src/vame/preprocessing/preprocessing.py index 55b55f80..39d48fae 100644 --- a/src/vame/preprocessing/preprocessing.py +++ b/src/vame/preprocessing/preprocessing.py @@ -40,6 +40,9 @@ def preprocessing( # Create visualization of the preprocessing results up to this point visualize_preprocessing_scatter( config, + original_positions_key="position", + cleaned_positions_key="position_cleaned_lowconf", + aligned_positions_key="position_egocentric_aligned", save_to_file=True, show_figure=False, ) diff --git a/src/vame/preprocessing/visualization.py b/src/vame/preprocessing/visualization.py index 2bc34592..337e9adf 100644 --- a/src/vame/preprocessing/visualization.py +++ b/src/vame/preprocessing/visualization.py @@ -1,6 +1,7 @@ from pathlib import Path import matplotlib.pyplot as plt from matplotlib.cm import get_cmap +import numpy as np from vame.io.load_poses import read_pose_estimation_file @@ -41,24 +42,40 @@ def visualize_preprocessing_scatter( for i, frame in enumerate(frames): # Compute dynamic limits for the original positions x_orig, y_orig = original_positions[frame, 0, :, 0], original_positions[frame, 0, :, 1] - # x_orig -= x_orig[0] # Centralize around the first keypoint - # y_orig -= y_orig[0] - x_min, x_max = x_orig.min() - 10, x_orig.max() + 10 # Add a margin - y_min, y_max = y_orig.min() - 10, y_orig.max() + 10 - - # Centralized Original positions - ax_original = axes[i, 0] - ax_original.scatter(x_orig, y_orig, c="blue", label="Original") - for k, (x, y) in enumerate(zip(x_orig, y_orig)): - ax_original.text(x, y, keypoints_labels[k], fontsize=10, color="blue") - ax_original.set_title(f"Original - Frame {frame}", fontsize=14) - ax_original.set_xlabel("X", fontsize=12) - ax_original.set_ylabel("Y", fontsize=12) - ax_original.axhline(0, color="gray", linestyle="--") - ax_original.axvline(0, color="gray", linestyle="--") - ax_original.axis("equal") - ax_original.set_xlim(x_min, x_max) - ax_original.set_ylim(y_min, y_max) + + # Identify keypoints that are NaN + nan_keypoints = [keypoints_labels[k] for k in range(len(keypoints_labels)) if 
np.isnan(x_orig[k]) or np.isnan(y_orig[k])] + + # Check if original positions contain all NaNs + if np.all(np.isnan(x_orig)) or np.all(np.isnan(y_orig)): + ax_original = axes[i, 0] + ax_original.set_title(f"Original - Frame {frame} (All NaNs)", fontsize=14, color="red") + ax_original.axis("off") # Hide axis since there is no data to plot + else: + x_min, x_max = np.nanmin(x_orig) - 10, np.nanmax(x_orig) + 10 # Add a margin + y_min, y_max = np.nanmin(y_orig) - 10, np.nanmax(y_orig) + 10 + + ax_original = axes[i, 0] + ax_original.scatter(x_orig, y_orig, c="blue", label="Original") + for k, (x, y) in enumerate(zip(x_orig, y_orig)): + ax_original.text(x, y, keypoints_labels[k], fontsize=10, color="blue") + + # Include NaN keypoints in the title + if nan_keypoints: + nan_text = ", ".join(nan_keypoints) + title_text = f"Original - Frame {frame}\nNaNs: {nan_text}" + else: + title_text = f"Original - Frame {frame}" + + ax_original.set_title(title_text, fontsize=14) + + ax_original.set_xlabel("X", fontsize=12) + ax_original.set_ylabel("Y", fontsize=12) + ax_original.axhline(0, color="gray", linestyle="--") + ax_original.axvline(0, color="gray", linestyle="--") + ax_original.axis("equal") + ax_original.set_xlim(x_min, x_max) + ax_original.set_ylim(y_min, y_max) # Compute dynamic limits for the cleaned positions x_cleaned, y_cleaned = cleaned_positions[frame, 0, :, 0], cleaned_positions[frame, 0, :, 1] @@ -267,3 +284,81 @@ def visualize_preprocessing_timeseries( plt.close( fig, ) + + +def visualize_timeseries( + config: dict, + session_index: int = 0, + n_samples: int = 1000, + positions_key: str = "position", + keypoints_labels: list[str] | None = None, + save_to_file: bool = False, + show_figure: bool = True, +): + """ + Visualize the original positions of the keypoints in a timeseries plot. 
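+
+    Parameters:
+    -----------
+    config : dict
+        Configuration dictionary.
+    session_index : int, optional
+        Index of the session to plot, within config["session_names"].
+    n_samples : int, optional
+        Number of time points to plot.
+    positions_key : str, optional
+        Variable in the processed dataset to read the positions from.
+    keypoints_labels : list[str], optional
+        Keypoints to plot. Defaults to all keypoints in the dataset.
+    save_to_file : bool, optional
+        Whether to save the figure to file.
+    show_figure : bool, optional
+        Whether to display the figure.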
+ """ + project_path = config["project_path"] + sessions = config["session_names"] + session = sessions[session_index] + + # Read session data + file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") + _, _, ds = read_pose_estimation_file(file_path=file_path) + + fig, ax = plt.subplots(2, 1, figsize=(10, 8)) + + individual = "individual_0" + if keypoints_labels is None: + keypoints_labels = ds.keypoints.values + + # Create a colormap with distinguishable colors + cmap = get_cmap("tab10") if len(keypoints_labels) <= 10 else get_cmap("tab20") + colors = [cmap(i / len(keypoints_labels)) for i in range(len(keypoints_labels))] + + for i, kp in enumerate(keypoints_labels): + sel_x = dict( + individuals=individual, + keypoints=kp, + space="x", + ) + sel_y = dict( + individuals=individual, + keypoints=kp, + space="y", + ) + + # Original positions (first two subplots) + ds[positions_key].sel(**sel_x)[0:n_samples].plot( + linewidth=1.5, + ax=ax[0], + label=kp, + color=colors[i], + ) + ds[positions_key].sel(**sel_y)[0:n_samples].plot( + linewidth=1.5, + ax=ax[1], + label=kp, + color=colors[i], + ) + + # Set common labels for Y axes + ax[0].set_ylabel( + "Allocentric X", + fontsize=12, + ) + ax[1].set_ylabel( + "Allocentric Y", + fontsize=12, + ) + + # Labels for X axes + for idx, a in enumerate(ax): + a.set_title("") + if idx % 2 == 0: + a.set_xlabel("") + else: + a.set_xlabel( + "Time", + fontsize=10, + ) From ec16d20176dea299151100e38b33f037d9c0cb51 Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 26 Dec 2024 18:43:30 +0100 Subject: [PATCH 36/77] update example notebook --- examples/pipeline.ipynb | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/pipeline.ipynb b/examples/pipeline.ipynb index 7f09ee47..cbbf4fe0 100644 --- a/examples/pipeline.ipynb +++ b/examples/pipeline.ipynb @@ -35,6 +35,13 @@ "videos = [ps[\"video\"]]\n", "poses_estimations = [ps[\"poses\"]]\n", "\n", + "# Customize the configuration for the project\n", + "config_kwargs = {\n", + " \"n_clusters\": 15,\n", + " \"pose_confidence\": 0.9,\n", + " \"max_epochs\": 10,\n", + "}\n", + "\n", "# Instantiate the pipeline\n", "# this will create a VAME project and prepare the data\n", "pipeline = VAMEPipeline(\n", @@ -43,6 +50,7 @@ " videos=videos,\n", " poses_estimations=poses_estimations,\n", " source_software=source_software,\n", + " config_kwargs=config_kwargs,\n", ")" ] }, @@ -64,7 +72,11 @@ "outputs": [], "source": [ "# Run the pipeline\n", - "pipeline.run_pipeline()" + "preprocessing_kwargs = {\n", + " \"centered_reference_keypoint\": \"snout\",\n", + " \"orientation_reference_keypoint\": \"tail_base\",\n", + "}\n", + "pipeline.run_pipeline(preprocessing_kwargs=preprocessing_kwargs)" ] }, { From 6e9836434d59a601b027c916fb3bf4500b346169 Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 26 Dec 2024 19:27:53 +0100 Subject: [PATCH 37/77] docstrings --- src/vame/preprocessing/cleaning.py | 36 ++++++++++++++++++++----- src/vame/preprocessing/filter.py | 15 ++++++++++- src/vame/preprocessing/visualization.py | 4 ++- 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/vame/preprocessing/cleaning.py b/src/vame/preprocessing/cleaning.py index 100ca849..8e63eb71 100644 --- a/src/vame/preprocessing/cleaning.py +++ b/src/vame/preprocessing/cleaning.py @@ -14,12 +14,24 @@ def lowconf_cleaning( config: dict, read_from_variable: str = "position_processed", save_to_variable: str = "position_processed", -): +) -> None: """ - Clean the low confidence data points from the 
dataset. - Processes position data by: + Clean the low confidence data points from the dataset. Processes position data by: - setting low-confidence points to NaN - interpolating NaN points + + Parameters: + ----------- + config : dict + Configuration dictionary. + read_from_variable : str, optional + Variable to read from the dataset. + save_to_variable : str, optional + Variable to save the cleaned data to. + + Returns: + -------- + None """ project_path = config["project_path"] sessions = config["session_names"] @@ -80,12 +92,24 @@ def outlier_cleaning( config: dict, read_from_variable: str = "position_processed", save_to_variable: str = "position_processed", -): +) -> None: """ - Clean the outliers from the dataset. - Processes position data by: + Clean the outliers from the dataset. Processes position data by: - setting outlier points to NaN - interpolating NaN points + + Parameters: + ----------- + config : dict + Configuration dictionary. + read_from_variable : str, optional + Variable to read from the dataset. + save_to_variable : str, optional + Variable to save the cleaned data to. + + Returns: + -------- + None """ logger.info("Cleaning outliers with Z-score transformation and IQR cutoff.") project_path = config["project_path"] diff --git a/src/vame/preprocessing/filter.py b/src/vame/preprocessing/filter.py index d3c15da5..a4f81702 100644 --- a/src/vame/preprocessing/filter.py +++ b/src/vame/preprocessing/filter.py @@ -14,9 +14,22 @@ def savgol_filtering( config: dict, read_from_variable: str = "position_processed", save_to_variable: str = "position_processed", -): +) -> None: """ Apply Savitzky-Golay filter to the data. + + Parameters: + ----------- + config : dict + Configuration dictionary. + read_from_variable : str, optional + Variable to read from the dataset. + save_to_variable : str, optional + Variable to save the filtered data to. 
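+        Defaults to overwriting "position_processed" in place.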
+ + Returns: + -------- + None """ logger.info("Applying Savitzky-Golay filter...") project_path = config["project_path"] diff --git a/src/vame/preprocessing/visualization.py b/src/vame/preprocessing/visualization.py index 337e9adf..946d8ad9 100644 --- a/src/vame/preprocessing/visualization.py +++ b/src/vame/preprocessing/visualization.py @@ -44,7 +44,9 @@ def visualize_preprocessing_scatter( x_orig, y_orig = original_positions[frame, 0, :, 0], original_positions[frame, 0, :, 1] # Identify keypoints that are NaN - nan_keypoints = [keypoints_labels[k] for k in range(len(keypoints_labels)) if np.isnan(x_orig[k]) or np.isnan(y_orig[k])] + nan_keypoints = [ + keypoints_labels[k] for k in range(len(keypoints_labels)) if np.isnan(x_orig[k]) or np.isnan(y_orig[k]) + ] # Check if original positions contain all NaNs if np.all(np.isnan(x_orig)) or np.all(np.isnan(y_orig)): From d6e93c73b575b3bafd6728dfd1d2ea77c19ff772 Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 26 Dec 2024 19:29:31 +0100 Subject: [PATCH 38/77] docstring --- src/vame/preprocessing/preprocessing.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/vame/preprocessing/preprocessing.py b/src/vame/preprocessing/preprocessing.py index 39d48fae..3eb15869 100644 --- a/src/vame/preprocessing/preprocessing.py +++ b/src/vame/preprocessing/preprocessing.py @@ -17,8 +17,27 @@ def preprocessing( centered_reference_keypoint: str = "snout", orientation_reference_keypoint: str = "tailbase", save_logs: bool = False, -): +) -> None: + """ + Preprocess the data by: + - Cleaning low confidence data points + - Egocentric alignment + - Outlier cleaning + - Savitzky-Golay filtering + Parameters: + ----------- + config : dict + Configuration dictionary. + centered_reference_keypoint : str, optional + Keypoint to use as centered reference. + orientation_reference_keypoint : str, optional + Keypoint to use as orientation reference. 
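+        Together with the centered keypoint it is used to rotate all frames
+        into a common heading direction.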
+ + Returns: + -------- + None + """ # Low-confidence cleaning logger.info("Cleaning low confidence data points...") lowconf_cleaning( From f9b4e5c100739fbcd5c801bd0200a5886b13c6fd Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Dec 2024 09:19:54 +0100 Subject: [PATCH 39/77] try to fix tests for windows --- tests/conftest.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b75ce3e2..ee29e3f9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,9 +49,9 @@ def init_project( @fixture(scope="session") def setup_project_from_folder(): project_name = "test_project_from_folder" - videos = ["./tests/tests_project_sample_data"] - poses_estimations = ["./tests/tests_project_sample_data"] - working_directory = "./tests" + videos = [str(Path("./tests/tests_project_sample_data").resolve())] + poses_estimations = [str(Path("./tests/tests_project_sample_data").resolve())] + working_directory = str(Path("./tests").resolve()) # Initialize project project_data = init_project( @@ -73,9 +73,9 @@ def setup_project_from_folder(): @fixture(scope="session") def setup_project_not_aligned_data(): project_name = "test_project_align" - videos = ["./tests/tests_project_sample_data/cropped_video.mp4"] - poses_estimations = ["./tests/tests_project_sample_data/cropped_video.csv"] - working_directory = "./tests" + videos = [str(Path("./tests/tests_project_sample_data/cropped_video.mp4").resolve())] + poses_estimations = [str(Path("./tests/tests_project_sample_data/cropped_video.csv").resolve())] + working_directory = str(Path("./tests").resolve()) # Initialize project project_data = init_project( @@ -98,9 +98,9 @@ def setup_project_not_aligned_data(): @fixture(scope="session") def setup_project_fixed_data(): project_name = "test_project_fixed" - videos = ["./tests/tests_project_sample_data/cropped_video.mp4"] - poses_estimations = ["./tests/tests_project_sample_data/cropped_video.csv"] - working_directory = "./tests" + videos = [str(Path("./tests/tests_project_sample_data/cropped_video.mp4").resolve())] + poses_estimations = [str(Path("./tests/tests_project_sample_data/cropped_video.csv").resolve())] + working_directory = str(Path("./tests").resolve()) # Initialize project project_data = init_project( From 8323b5d4e8ace2562aa0fdba55a6c7c1a4eea148 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Dec 2024 09:48:51 +0100 Subject: [PATCH 40/77] early stop tests --- .github/workflows/testing.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/testing.yaml b/.github/workflows/testing.yaml index 3289e280..49745dca 100644 --- a/.github/workflows/testing.yaml +++ b/.github/workflows/testing.yaml @@ -31,7 +31,7 @@ jobs: pip install -r tests/requirements-tests.txt --no-cache-dir - name: Run tests. 
- run: pytest --cov=src/vame --cov-report=xml --cov-report=term-missing -v + run: pytest --cov=src/vame --cov-report=xml --cov-report=term-missing -vx - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v4.0.1 From f96a965bcf663262385766ed908de8809874b495 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Dec 2024 10:05:01 +0100 Subject: [PATCH 41/77] try better teardown strategy --- tests/conftest.py | 33 +++++++++++++++++++++++++++++---- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index ee29e3f9..aab17bb8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,12 @@ from pytest import fixture, raises from pathlib import Path import shutil +import psutil +import time from typing import List, Optional, Literal import vame -from vame.util.auxiliary import read_config, write_config +from vame.util.auxiliary import write_config def init_project( @@ -46,6 +48,29 @@ def init_project( return project_data +def cleanup_directory(directory): + """Helper function to clean up the directory and handle Windows-specific issues.""" + try: + # Wait a moment to ensure all files are closed + time.sleep(1) + + # Check for any open file handles and warn about them + for proc in psutil.process_iter(['open_files']): + if any(file.path.startswith(str(directory)) for file in proc.info['open_files'] or []): + print(f"Process {proc.pid} is holding files in {directory}.") + + # Try to delete the directory + shutil.rmtree(directory) + except PermissionError as e: + print(f"PermissionError during cleanup: {e}. Retrying...") + # Retry after a short delay + time.sleep(2) + try: + shutil.rmtree(directory) + except Exception as final_error: + print(f"Final cleanup failed: {final_error}") + + @fixture(scope="session") def setup_project_from_folder(): project_name = "test_project_from_folder" @@ -67,7 +92,7 @@ def setup_project_from_folder(): # Clean up config_path = project_data["config_path"] - shutil.rmtree(Path(config_path).parent) + cleanup_directory(Path(config_path).parent) @fixture(scope="session") @@ -91,7 +116,7 @@ def setup_project_not_aligned_data(): # Clean up config_path = project_data["config_path"] - shutil.rmtree(Path(config_path).parent) + cleanup_directory(Path(config_path).parent) # # TODO change to test fixed (already egocentrically aligned) data when have it @@ -116,7 +141,7 @@ def setup_project_fixed_data(): # Clean up config_path = project_data["config_path"] - shutil.rmtree(Path(config_path).parent) + cleanup_directory(Path(config_path).parent) # @fixture(scope="session") From a9bf2bc39a2b9b8323d0fdba0ce1408a39520f7e Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Dec 2024 10:09:09 +0100 Subject: [PATCH 42/77] r --- tests/requirements-tests.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/requirements-tests.txt b/tests/requirements-tests.txt index 00d5b63d..3bd2f0a0 100644 --- a/tests/requirements-tests.txt +++ b/tests/requirements-tests.txt @@ -1,2 +1,3 @@ pytest==8.2.0 -pytest-cov==5.0.0 \ No newline at end of file +pytest-cov==5.0.0 +psutil==6.1.1 \ No newline at end of file From 9ef8ca68255429885cd08ee94b2894b157d4dcba Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Dec 2024 10:30:51 +0100 Subject: [PATCH 43/77] rename tests --- tests/{test_pipeline.py => 01_pipeline_test.py} | 0 .../{test_initialize_project.py => 02_initialize_project_test.py} | 0 tests/{test_util.py => 03_util_test.py} | 0 tests/{test_model.py => 04_model_test.py} | 0 tests/{test_analysis.py => 
05_analysis_test.py} | 0 5 files changed, 0 insertions(+), 0 deletions(-) rename tests/{test_pipeline.py => 01_pipeline_test.py} (100%) rename tests/{test_initialize_project.py => 02_initialize_project_test.py} (100%) rename tests/{test_util.py => 03_util_test.py} (100%) rename tests/{test_model.py => 04_model_test.py} (100%) rename tests/{test_analysis.py => 05_analysis_test.py} (100%) diff --git a/tests/test_pipeline.py b/tests/01_pipeline_test.py similarity index 100% rename from tests/test_pipeline.py rename to tests/01_pipeline_test.py diff --git a/tests/test_initialize_project.py b/tests/02_initialize_project_test.py similarity index 100% rename from tests/test_initialize_project.py rename to tests/02_initialize_project_test.py diff --git a/tests/test_util.py b/tests/03_util_test.py similarity index 100% rename from tests/test_util.py rename to tests/03_util_test.py diff --git a/tests/test_model.py b/tests/04_model_test.py similarity index 100% rename from tests/test_model.py rename to tests/04_model_test.py diff --git a/tests/test_analysis.py b/tests/05_analysis_test.py similarity index 100% rename from tests/test_analysis.py rename to tests/05_analysis_test.py From 8f5a26a2565ec22662ae9d820232404034f8b04a Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Dec 2024 10:40:57 +0100 Subject: [PATCH 44/77] pipeline test win --- tests/01_pipeline_test.py | 34 ---------------------------------- tests/conftest.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/tests/01_pipeline_test.py b/tests/01_pipeline_test.py index f8c25934..9e775a85 100644 --- a/tests/01_pipeline_test.py +++ b/tests/01_pipeline_test.py @@ -1,38 +1,4 @@ -from pytest import fixture -from pathlib import Path -import shutil import xarray as xr -from vame.pipeline import VAMEPipeline - - -@fixture(scope="session") -def setup_pipeline(): - """ - Setup a Pipeline for testing. - """ - project_name = "test_pipeline" - videos = ["./tests/tests_project_sample_data"] - poses_estimations = ["./tests/tests_project_sample_data"] - working_directory = "./tests" - source_software = "DeepLabCut" - - config_kwargs = { - "egocentric_data": False, - "max_epochs": 10, - "batch_size": 10, - } - pipeline = VAMEPipeline( - working_directory=working_directory, - project_name=project_name, - videos=videos, - poses_estimations=poses_estimations, - source_software=source_software, - config_kwargs=config_kwargs, - ) - yield {"pipeline": pipeline} - - # Clean up - shutil.rmtree(Path(pipeline.config_path).parent) def test_pipeline(setup_pipeline): diff --git a/tests/conftest.py b/tests/conftest.py index aab17bb8..c82bc07c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ from typing import List, Optional, Literal import vame +from vame.pipeline import VAMEPipeline from vame.util.auxiliary import write_config @@ -250,3 +251,33 @@ def setup_project_and_evaluate_model(setup_project_and_train_model): config = setup_project_and_train_model["config_data"] vame.evaluate_model(config, save_logs=True) return setup_project_and_train_model + + +@fixture(scope="session") +def setup_pipeline(): + """ + Setup a Pipeline for testing. 
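+
+    Tests can then drive the full workflow with the yielded pipeline, e.g.
+    (illustrative; the keypoint names are placeholders):
+
+    >>> pipeline.run_pipeline(preprocessing_kwargs={
+    ...     "centered_reference_keypoint": "snout",
+    ...     "orientation_reference_keypoint": "tail_base",
+    ... })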
+ """ + project_name = "test_pipeline" + videos = ["./tests/tests_project_sample_data"] + poses_estimations = ["./tests/tests_project_sample_data"] + working_directory = "./tests" + source_software = "DeepLabCut" + + config_kwargs = { + "egocentric_data": False, + "max_epochs": 10, + "batch_size": 10, + } + pipeline = VAMEPipeline( + working_directory=working_directory, + project_name=project_name, + videos=videos, + poses_estimations=poses_estimations, + source_software=source_software, + config_kwargs=config_kwargs, + ) + yield {"pipeline": pipeline} + + # Clean up + cleanup_directory(Path(pipeline.config_path).parent) From f4be51477eba2a5c3df2d641b4ae3c8374e79a1a Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Dec 2024 10:56:59 +0100 Subject: [PATCH 45/77] try to save to netcdf with Path --- src/vame/preprocessing/cleaning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vame/preprocessing/cleaning.py b/src/vame/preprocessing/cleaning.py index 8e63eb71..93047d28 100644 --- a/src/vame/preprocessing/cleaning.py +++ b/src/vame/preprocessing/cleaning.py @@ -81,7 +81,7 @@ def lowconf_cleaning( ds["percentage_low_confidence"] = (["individual", "keypoint", "space"], perc_interp_points) # Save the cleaned dataset to file - cleaned_file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") + cleaned_file_path = Path(project_path) / "data" / "processed" / f"{session}_processed.nc" ds.to_netcdf( path=cleaned_file_path, engine="scipy", From 83f23161946921f32cdd4b6460d255557f62f6ba Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Dec 2024 11:15:20 +0100 Subject: [PATCH 46/77] fix for windows --- VAME.yaml | 3 ++- src/vame/io/load_poses.py | 8 ++++++-- tests/conftest.py | 6 +++--- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/VAME.yaml b/VAME.yaml index feb3f0fe..dc0d7eec 100644 --- a/VAME.yaml +++ b/VAME.yaml @@ -9,4 +9,5 @@ dependencies: - python=3.11 - pip - pip: - - -r requirements.txt \ No newline at end of file + - -r requirements.txt + - . \ No newline at end of file diff --git a/src/vame/io/load_poses.py b/src/vame/io/load_poses.py index de61f6eb..c55b2d34 100644 --- a/src/vame/io/load_poses.py +++ b/src/vame/io/load_poses.py @@ -54,9 +54,13 @@ def load_vame_dataset(ds_path: Path | str) -> xr.Dataset: Returns: -------- """ - return xr.open_dataset(ds_path, engine="scipy") - + # Windows will not allow opened files to be overwritten, + # so we need to load data into memory, close the file and move on with the operations + with xr.open_dataset(ds_path, engine="scipy") as tmp_ds: + ds_in_memory = tmp_ds.load() # read entire file into memory + return ds_in_memory + def nc_to_dataframe(nc_data): keypoints = nc_data["keypoints"].values space = nc_data["space"].values diff --git a/tests/conftest.py b/tests/conftest.py index c82bc07c..1b127d13 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -259,9 +259,9 @@ def setup_pipeline(): Setup a Pipeline for testing. 
""" project_name = "test_pipeline" - videos = ["./tests/tests_project_sample_data"] - poses_estimations = ["./tests/tests_project_sample_data"] - working_directory = "./tests" + videos = [str(Path("./tests/tests_project_sample_data").resolve())] + poses_estimations = [str(Path("./tests/tests_project_sample_data").resolve())] + working_directory = str(Path("./tests").resolve()) source_software = "DeepLabCut" config_kwargs = { From 3ce74a2c90c7174443340df056fbaeff1b341792 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Dec 2024 14:25:29 +0100 Subject: [PATCH 47/77] comment legacy code --- src/vame/__init__.py | 2 +- .../align_egocentrical_legacy.py | 1390 ++++++++--------- 2 files changed, 696 insertions(+), 696 deletions(-) diff --git a/src/vame/__init__.py b/src/vame/__init__.py index 288650bb..14a59332 100644 --- a/src/vame/__init__.py +++ b/src/vame/__init__.py @@ -15,7 +15,7 @@ from vame.analysis import generative_model from vame.analysis import gif from vame.util.csv_to_npy import pose_to_numpy -from vame.preprocessing.align_egocentrical_legacy import egocentric_alignment_legacy +# from vame.preprocessing.align_egocentrical_legacy import egocentric_alignment_legacy from vame.util import model_util from vame.util import auxiliary from vame.util.report import report diff --git a/src/vame/preprocessing/align_egocentrical_legacy.py b/src/vame/preprocessing/align_egocentrical_legacy.py index 947e1705..f594faa4 100644 --- a/src/vame/preprocessing/align_egocentrical_legacy.py +++ b/src/vame/preprocessing/align_egocentrical_legacy.py @@ -1,698 +1,698 @@ -import os -import cv2 as cv -import numpy as np -import pandas as pd -import tqdm -from typing import Tuple, List, Union -from pathlib import Path - -from vame.logging.logger import VameLogger, TqdmToLogger -from vame.util.auxiliary import read_config -from vame.schemas.states import EgocentricAlignmentFunctionSchema, save_state -from vame.schemas.project import PoseEstimationFiletype -from vame.io.load_poses import read_pose_estimation_file -from vame.util.data_manipulation import ( - interpol_first_rows_nans, - crop_and_flip_legacy, - background, -) -from vame.video import get_video_frame_rate - - -logger_config = VameLogger(__name__) -logger = logger_config.logger - - -def align_mouse_legacy( - project_path: str, - session: str, - video_format: str, - crop_size: Tuple[int, int], - pose_list: List[np.ndarray], - pose_ref_index: Tuple[int, int], - confidence: float, - pose_flip_ref: Tuple[int, int], - bg: np.ndarray, - frame_count: int, - use_video: bool = True, - tqdm_stream: Union[TqdmToLogger, None] = None, -) -> Tuple[List[np.ndarray], List[List[np.ndarray]], np.ndarray]: - """ - Align the mouse in the video frames. - - Parameters: - ----------- - project_path : str - Path to the project directory. - session : str - Name of the session. - video_format : str - Format of the video file. - crop_size : Tuple[int, int] - Size to crop the video frames. - pose_list : List[np.ndarray] - List of pose coordinates. - pose_ref_index : Tuple[int, int] - Pose reference indices. - confidence : float - Pose confidence threshold. - pose_flip_ref : Tuple[int, int] - Reference indices for flipping. - bg : np.ndarray - Background image. - frame_count : int - Number of frames to align. - use_video : bool, optional - bool if video should be cropped or DLC points only. Defaults to True. - tqdm_stream : Union[TqdmToLogger, None], optional - Tqdm stream to log the progress. Defaults to None. 
- - Returns - ------- - Tuple[List[np.ndarray], List[List[np.ndarray]], np.ndarray] - List of aligned images, list of aligned DLC points, and aligned time series data. - """ - images = [] - points = [] - for i in pose_list: - for j in i: - if j[2] <= confidence: - j[0], j[1] = np.nan, np.nan - - for i in pose_list: - i = interpol_first_rows_nans(i) - - if use_video: - video_path = str( - os.path.join( - project_path, - "data", - "raw", - session + video_format, - ) - ) - capture = cv.VideoCapture(video_path) - if not capture.isOpened(): - raise Exception(f"Unable to open video file: {video_path}") - - for idx in tqdm.tqdm( - range(frame_count), - disable=not True, - file=tqdm_stream, - desc="Align frames", - ): - if use_video: - # Read frame - try: - ret, frame = capture.read() - frame = cv.cvtColor(frame, cv.COLOR_BGR2GRAY) - frame = frame - bg - frame[frame <= 0] = 0 - except Exception: - logger.info("Couldn't find a frame in capture.read(). #Frame: %d" % idx) - continue - else: - frame = np.zeros((1, 1)) - - # Read coordinates and add border - pose_list_bordered = [] - - for i in pose_list: - pose_list_bordered.append((int(i[idx][0] + crop_size[0]), int(i[idx][1] + crop_size[1]))) - - img = cv.copyMakeBorder( - frame, - crop_size[1], - crop_size[1], - crop_size[0], - crop_size[0], - cv.BORDER_CONSTANT, - 0, - ) - - coord_center = [] - punkte = [] - - for i in pose_ref_index: - coord = [] - # changed from pose_list_bordered[i][0] 2/28/2024 PN - coord.append(pose_list_bordered[i][0]) - # changed from pose_list_bordered[i][1] 2/28/2024 PN - coord.append(pose_list_bordered[i][1]) - punkte.append(coord) - - # coord_center.append(pose_list_bordered[5][0]-5) - # coord_center.append(pose_list_bordered[5][0]+5) - # coord_center = [coord_center] - punkte = [punkte] - - # coord_center = np.asarray(coord_center) - punkte = np.asarray(punkte) - - # calculate minimal rectangle around snout and tail - rect = cv.minAreaRect(punkte) - # rect_belly = cv.minAreaRect(coord_center) - # center_belly, size_belly, theta_belly = rect_belly - # change size in rect tuple structure to be equal to crop_size - lst = list(rect) - lst[1] = crop_size - # lst[0] = center_belly - rect = tuple(lst) - center, size, theta = rect - - # crop image - out, shifted_points = crop_and_flip_legacy( - rect, - img, - pose_list_bordered, - pose_flip_ref, - ) - - if use_video: # for memory optimization, just save images when video is used. - images.append(out) - points.append(shifted_points) - - if use_video: - capture.release() - - time_series = np.zeros((len(pose_list) * 2, frame_count)) - for i in range(frame_count): - idx = 0 - for j in range(len(pose_list)): - time_series[idx : idx + 2, i] = points[i][j] - idx += 2 - - return images, points, time_series - - -def alignment_legacy( - project_path: str, - session: str, - pose_ref_index: Tuple[int, int], - video_format: str, - crop_size: Tuple[int, int], - confidence: float, - pose_estimation_filetype: PoseEstimationFiletype, - path_to_pose_nwb_series_data: Union[str, None] = None, - use_video: bool = False, - tqdm_stream: Union[TqdmToLogger, None] = None, -) -> Tuple[np.ndarray, List[np.ndarray]]: - """ - Perform alignment of egocentric data. - - Parameters: - ----------- - project_path : str - Path to the project directory. - session : str - Name of the session. - pose_ref_index : List[int] - Pose reference indices. - video_format : str - Format of the video file. - crop_size : Tuple[int, int] - Size to crop the video frames. - confidence : float - Pose confidence threshold. 
- pose_estimation_filetype : PoseEstimationFiletype - Pose estimation file type. Can be .csv or .nwb. - path_to_pose_nwb_series_data : Union[str, None], optional - Path to the pose series data in nwb files. Defaults to None. - use_video : bool, optional - Whether to use video for alignment. Defaults to False. - tqdm_stream : Union[TqdmToLogger, None], optional - Tqdm stream to log the progress. Defaults to None. - - Returns - ------- - Tuple[np.ndarray, List[np.ndarray]] - Aligned time series data and list of aligned frames. - """ - # read out data - file_path = str(Path(project_path) / "data" / "raw" / f"{session}.nc") - data, data_mat, _ = read_pose_estimation_file( - file_path=file_path, - file_type=pose_estimation_filetype, - path_to_pose_nwb_series_data=path_to_pose_nwb_series_data, - ) - - # get the coordinates for alignment from data table - # pose_list dimensions: (num_body_parts, num_frames, 3) - pose_list = [] - for i in range(int(data_mat.shape[1] / 3)): - pose_list.append(data_mat[:, i * 3 : (i + 1) * 3]) - - # list of reference coordinate indices for alignment - # 0: snout, 1: forehand_left, 2: forehand_right, - # 3: hindleft, 4: hindright, 5: tail - # list of 2 reference coordinate indices for avoiding flipping - pose_flip_ref = pose_ref_index - - if use_video: - # compute background - video_path = str( - os.path.join( - project_path, - "data", - "raw", - session + video_format, - ) - ) - bg = background( - project_path=project_path, - session=session, - video_path=video_path, - save_background=False, - ) - frame_count = get_video_frame_rate(video_path) - else: - bg = 0 - # Change this to an abitrary number if you first want to test the code - frame_count = len(data) - - frames, n, time_series = align_mouse_legacy( - project_path=project_path, - session=session, - video_format=video_format, - crop_size=crop_size, - pose_list=pose_list, - pose_ref_index=pose_ref_index, - confidence=confidence, - pose_flip_ref=pose_flip_ref, - bg=bg, - frame_count=frame_count, - use_video=use_video, - tqdm_stream=tqdm_stream, - ) - - return time_series, frames +# import os +# import cv2 as cv +# import numpy as np +# import pandas as pd +# import tqdm +# from typing import Tuple, List, Union +# from pathlib import Path + +# from vame.logging.logger import VameLogger, TqdmToLogger +# from vame.util.auxiliary import read_config +# from vame.schemas.states import EgocentricAlignmentFunctionSchema, save_state +# from vame.schemas.project import PoseEstimationFiletype +# from vame.io.load_poses import read_pose_estimation_file +# from vame.util.data_manipulation import ( +# interpol_first_rows_nans, +# crop_and_flip_legacy, +# background, +# ) +# from vame.video import get_video_frame_rate + + +# logger_config = VameLogger(__name__) +# logger = logger_config.logger + + +# def align_mouse_legacy( +# project_path: str, +# session: str, +# video_format: str, +# crop_size: Tuple[int, int], +# pose_list: List[np.ndarray], +# pose_ref_index: Tuple[int, int], +# confidence: float, +# pose_flip_ref: Tuple[int, int], +# bg: np.ndarray, +# frame_count: int, +# use_video: bool = True, +# tqdm_stream: Union[TqdmToLogger, None] = None, +# ) -> Tuple[List[np.ndarray], List[List[np.ndarray]], np.ndarray]: +# """ +# Align the mouse in the video frames. + +# Parameters: +# ----------- +# project_path : str +# Path to the project directory. +# session : str +# Name of the session. +# video_format : str +# Format of the video file. +# crop_size : Tuple[int, int] +# Size to crop the video frames. 
+# pose_list : List[np.ndarray] +# List of pose coordinates. +# pose_ref_index : Tuple[int, int] +# Pose reference indices. +# confidence : float +# Pose confidence threshold. +# pose_flip_ref : Tuple[int, int] +# Reference indices for flipping. +# bg : np.ndarray +# Background image. +# frame_count : int +# Number of frames to align. +# use_video : bool, optional +# bool if video should be cropped or DLC points only. Defaults to True. +# tqdm_stream : Union[TqdmToLogger, None], optional +# Tqdm stream to log the progress. Defaults to None. + +# Returns +# ------- +# Tuple[List[np.ndarray], List[List[np.ndarray]], np.ndarray] +# List of aligned images, list of aligned DLC points, and aligned time series data. +# """ +# images = [] +# points = [] +# for i in pose_list: +# for j in i: +# if j[2] <= confidence: +# j[0], j[1] = np.nan, np.nan + +# for i in pose_list: +# i = interpol_first_rows_nans(i) + +# if use_video: +# video_path = str( +# os.path.join( +# project_path, +# "data", +# "raw", +# session + video_format, +# ) +# ) +# capture = cv.VideoCapture(video_path) +# if not capture.isOpened(): +# raise Exception(f"Unable to open video file: {video_path}") + +# for idx in tqdm.tqdm( +# range(frame_count), +# disable=not True, +# file=tqdm_stream, +# desc="Align frames", +# ): +# if use_video: +# # Read frame +# try: +# ret, frame = capture.read() +# frame = cv.cvtColor(frame, cv.COLOR_BGR2GRAY) +# frame = frame - bg +# frame[frame <= 0] = 0 +# except Exception: +# logger.info("Couldn't find a frame in capture.read(). #Frame: %d" % idx) +# continue +# else: +# frame = np.zeros((1, 1)) + +# # Read coordinates and add border +# pose_list_bordered = [] + +# for i in pose_list: +# pose_list_bordered.append((int(i[idx][0] + crop_size[0]), int(i[idx][1] + crop_size[1]))) + +# img = cv.copyMakeBorder( +# frame, +# crop_size[1], +# crop_size[1], +# crop_size[0], +# crop_size[0], +# cv.BORDER_CONSTANT, +# 0, +# ) + +# coord_center = [] +# punkte = [] + +# for i in pose_ref_index: +# coord = [] +# # changed from pose_list_bordered[i][0] 2/28/2024 PN +# coord.append(pose_list_bordered[i][0]) +# # changed from pose_list_bordered[i][1] 2/28/2024 PN +# coord.append(pose_list_bordered[i][1]) +# punkte.append(coord) + +# # coord_center.append(pose_list_bordered[5][0]-5) +# # coord_center.append(pose_list_bordered[5][0]+5) +# # coord_center = [coord_center] +# punkte = [punkte] + +# # coord_center = np.asarray(coord_center) +# punkte = np.asarray(punkte) + +# # calculate minimal rectangle around snout and tail +# rect = cv.minAreaRect(punkte) +# # rect_belly = cv.minAreaRect(coord_center) +# # center_belly, size_belly, theta_belly = rect_belly +# # change size in rect tuple structure to be equal to crop_size +# lst = list(rect) +# lst[1] = crop_size +# # lst[0] = center_belly +# rect = tuple(lst) +# center, size, theta = rect + +# # crop image +# out, shifted_points = crop_and_flip_legacy( +# rect, +# img, +# pose_list_bordered, +# pose_flip_ref, +# ) + +# if use_video: # for memory optimization, just save images when video is used. 
+# images.append(out) +# points.append(shifted_points) + +# if use_video: +# capture.release() + +# time_series = np.zeros((len(pose_list) * 2, frame_count)) +# for i in range(frame_count): +# idx = 0 +# for j in range(len(pose_list)): +# time_series[idx : idx + 2, i] = points[i][j] +# idx += 2 + +# return images, points, time_series + + +# def alignment_legacy( +# project_path: str, +# session: str, +# pose_ref_index: Tuple[int, int], +# video_format: str, +# crop_size: Tuple[int, int], +# confidence: float, +# pose_estimation_filetype: PoseEstimationFiletype, +# path_to_pose_nwb_series_data: Union[str, None] = None, +# use_video: bool = False, +# tqdm_stream: Union[TqdmToLogger, None] = None, +# ) -> Tuple[np.ndarray, List[np.ndarray]]: +# """ +# Perform alignment of egocentric data. + +# Parameters: +# ----------- +# project_path : str +# Path to the project directory. +# session : str +# Name of the session. +# pose_ref_index : List[int] +# Pose reference indices. +# video_format : str +# Format of the video file. +# crop_size : Tuple[int, int] +# Size to crop the video frames. +# confidence : float +# Pose confidence threshold. +# pose_estimation_filetype : PoseEstimationFiletype +# Pose estimation file type. Can be .csv or .nwb. +# path_to_pose_nwb_series_data : Union[str, None], optional +# Path to the pose series data in nwb files. Defaults to None. +# use_video : bool, optional +# Whether to use video for alignment. Defaults to False. +# tqdm_stream : Union[TqdmToLogger, None], optional +# Tqdm stream to log the progress. Defaults to None. + +# Returns +# ------- +# Tuple[np.ndarray, List[np.ndarray]] +# Aligned time series data and list of aligned frames. +# """ +# # read out data +# file_path = str(Path(project_path) / "data" / "raw" / f"{session}.nc") +# data, data_mat, _ = read_pose_estimation_file( +# file_path=file_path, +# file_type=pose_estimation_filetype, +# path_to_pose_nwb_series_data=path_to_pose_nwb_series_data, +# ) + +# # get the coordinates for alignment from data table +# # pose_list dimensions: (num_body_parts, num_frames, 3) +# pose_list = [] +# for i in range(int(data_mat.shape[1] / 3)): +# pose_list.append(data_mat[:, i * 3 : (i + 1) * 3]) + +# # list of reference coordinate indices for alignment +# # 0: snout, 1: forehand_left, 2: forehand_right, +# # 3: hindleft, 4: hindright, 5: tail +# # list of 2 reference coordinate indices for avoiding flipping +# pose_flip_ref = pose_ref_index + +# if use_video: +# # compute background +# video_path = str( +# os.path.join( +# project_path, +# "data", +# "raw", +# session + video_format, +# ) +# ) +# bg = background( +# project_path=project_path, +# session=session, +# video_path=video_path, +# save_background=False, +# ) +# frame_count = get_video_frame_rate(video_path) +# else: +# bg = 0 +# # Change this to an abitrary number if you first want to test the code +# frame_count = len(data) + +# frames, n, time_series = align_mouse_legacy( +# project_path=project_path, +# session=session, +# video_format=video_format, +# crop_size=crop_size, +# pose_list=pose_list, +# pose_ref_index=pose_ref_index, +# confidence=confidence, +# pose_flip_ref=pose_flip_ref, +# bg=bg, +# frame_count=frame_count, +# use_video=use_video, +# tqdm_stream=tqdm_stream, +# ) + +# return time_series, frames + + +# # @save_state(model=EgocentricAlignmentFunctionSchema) +# def egocentric_alignment_legacy( +# config: str, +# pose_ref_index: Tuple[int, int] = (0, 1), +# crop_size: Tuple[int, int] = (300, 300), +# use_video: bool = False, +# 
video_format: str = ".mp4", +# check_video: bool = False, +# save_logs: bool = False, +# ) -> None: +# """ +# Egocentric alignment of bevarioral videos. +# Fills in the values in the "egocentric_alignment" key of the states.json file. +# Creates training dataset for VAME at: +# - project_name/ +# - data/ +# - filename/ +# - filename-PE-seq.npy +# - filename/ +# - filename-PE-seq.npy +# The produced .npy files contain the aligned time series data in the +# shape of (num_dlc_features, num_video_frames). + +# Parameters +# ---------- +# config : str +# Path for the project config file. +# pose_ref_index : list, optional +# Pose reference index to be used to align. Defaults to [0, 1]. +# crop_size : tuple, optional +# Size to crop the video. Defaults to (300,300). +# use_video : bool, optional +# Weather to use video to do the post alignment. Defaults to False. +# video_format : str, optional +# Video format, can be .mp4 or .avi. Defaults to '.mp4'. +# check_video : bool, optional +# Weather to check the video. Defaults to False. + +# Raises: +# ------ +# ValueError +# If the config.yaml indicates that the data is not egocentric. +# """ +# try: +# config_file = Path(config).resolve() +# cfg = read_config(str(config_file)) +# if cfg["egocentric_data"]: +# raise ValueError( +# "The config.yaml indicates that the data is egocentric. Please check the parameter 'egocentric_data'." +# ) +# tqdm_stream = None + +# if save_logs: +# log_path = Path(cfg["project_path"]) / "logs" / "egocentric_alignment.log" +# logger_config.add_file_handler(str(log_path)) +# tqdm_stream = TqdmToLogger(logger=logger) + +# logger.info("Starting egocentric alignment") +# project_path = cfg["project_path"] +# sessions = cfg["session_names"] +# confidence = cfg["pose_confidence"] +# num_features = cfg["num_features"] +# video_format = video_format +# crop_size = crop_size + +# y_shifted_indices = np.arange(0, num_features, 2) +# x_shifted_indices = np.arange(1, num_features, 2) +# belly_Y_ind = pose_ref_index[0] * 2 +# belly_X_ind = (pose_ref_index[0] * 2) + 1 + +# # call function and save into your VAME data folder +# paths_to_pose_nwb_series_data = cfg["paths_to_pose_nwb_series_data"] +# for i, session in enumerate(sessions): +# logger.info("Aligning session %s, Pose confidence value: %.2f" % (session, confidence)) +# egocentric_time_series, frames = alignment_legacy( +# project_path=project_path, +# session=session, +# pose_ref_index=pose_ref_index, +# video_format=video_format, +# crop_size=crop_size, +# confidence=confidence, +# pose_estimation_filetype=cfg["pose_estimation_filetype"], +# path_to_pose_nwb_series_data=( +# paths_to_pose_nwb_series_data +# if not paths_to_pose_nwb_series_data +# else paths_to_pose_nwb_series_data[i] +# ), +# use_video=use_video, +# tqdm_stream=tqdm_stream, +# ) + +# # Shifiting section added 2/29/2024 PN +# # TODO - should this be hardcoded like that? 
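+# # The two lines below read out the reference ("belly") keypoint's coordinate
+# # series, which are then subtracted from every keypoint, frame by frame, so
+# # all coordinates become relative to that reference point. Note that the
+# # assignment aliases egocentric_time_series rather than copying it, so the
+# # subtraction also modifies the original array in place.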
+# egocentric_time_series_shifted = egocentric_time_series +# belly_Y_shift = egocentric_time_series[belly_Y_ind, :] +# belly_X_shift = egocentric_time_series[belly_X_ind, :] + +# egocentric_time_series_shifted[y_shifted_indices, :] -= belly_Y_shift +# egocentric_time_series_shifted[x_shifted_indices, :] -= belly_X_shift + +# # Save new shifted file +# np.save( +# os.path.join( +# project_path, +# "data", +# "processed", +# session, +# session + "-PE-seq-legacy.npy", +# ), +# egocentric_time_series_shifted, +# ) + +# logger.info("Your data is now in the right format and you can call vame.create_trainset()") +# except Exception as e: +# logger.exception(f"{e}") +# raise e +# finally: +# logger_config.remove_file_handler() # @save_state(model=EgocentricAlignmentFunctionSchema) -def egocentric_alignment_legacy( - config: str, - pose_ref_index: Tuple[int, int] = (0, 1), - crop_size: Tuple[int, int] = (300, 300), - use_video: bool = False, - video_format: str = ".mp4", - check_video: bool = False, - save_logs: bool = False, -) -> None: - """ - Egocentric alignment of bevarioral videos. - Fills in the values in the "egocentric_alignment" key of the states.json file. - Creates training dataset for VAME at: - - project_name/ - - data/ - - filename/ - - filename-PE-seq.npy - - filename/ - - filename-PE-seq.npy - The produced .npy files contain the aligned time series data in the - shape of (num_dlc_features, num_video_frames). - - Parameters - ---------- - config : str - Path for the project config file. - pose_ref_index : list, optional - Pose reference index to be used to align. Defaults to [0, 1]. - crop_size : tuple, optional - Size to crop the video. Defaults to (300,300). - use_video : bool, optional - Weather to use video to do the post alignment. Defaults to False. - video_format : str, optional - Video format, can be .mp4 or .avi. Defaults to '.mp4'. - check_video : bool, optional - Weather to check the video. Defaults to False. - - Raises: - ------ - ValueError - If the config.yaml indicates that the data is not egocentric. - """ - try: - config_file = Path(config).resolve() - cfg = read_config(str(config_file)) - if cfg["egocentric_data"]: - raise ValueError( - "The config.yaml indicates that the data is egocentric. Please check the parameter 'egocentric_data'." 
- ) - tqdm_stream = None - - if save_logs: - log_path = Path(cfg["project_path"]) / "logs" / "egocentric_alignment.log" - logger_config.add_file_handler(str(log_path)) - tqdm_stream = TqdmToLogger(logger=logger) - - logger.info("Starting egocentric alignment") - project_path = cfg["project_path"] - sessions = cfg["session_names"] - confidence = cfg["pose_confidence"] - num_features = cfg["num_features"] - video_format = video_format - crop_size = crop_size - - y_shifted_indices = np.arange(0, num_features, 2) - x_shifted_indices = np.arange(1, num_features, 2) - belly_Y_ind = pose_ref_index[0] * 2 - belly_X_ind = (pose_ref_index[0] * 2) + 1 - - # call function and save into your VAME data folder - paths_to_pose_nwb_series_data = cfg["paths_to_pose_nwb_series_data"] - for i, session in enumerate(sessions): - logger.info("Aligning session %s, Pose confidence value: %.2f" % (session, confidence)) - egocentric_time_series, frames = alignment_legacy( - project_path=project_path, - session=session, - pose_ref_index=pose_ref_index, - video_format=video_format, - crop_size=crop_size, - confidence=confidence, - pose_estimation_filetype=cfg["pose_estimation_filetype"], - path_to_pose_nwb_series_data=( - paths_to_pose_nwb_series_data - if not paths_to_pose_nwb_series_data - else paths_to_pose_nwb_series_data[i] - ), - use_video=use_video, - tqdm_stream=tqdm_stream, - ) - - # Shifiting section added 2/29/2024 PN - # TODO - should this be hardcoded like that? - egocentric_time_series_shifted = egocentric_time_series - belly_Y_shift = egocentric_time_series[belly_Y_ind, :] - belly_X_shift = egocentric_time_series[belly_X_ind, :] - - egocentric_time_series_shifted[y_shifted_indices, :] -= belly_Y_shift - egocentric_time_series_shifted[x_shifted_indices, :] -= belly_X_shift - - # Save new shifted file - np.save( - os.path.join( - project_path, - "data", - "processed", - session, - session + "-PE-seq-legacy.npy", - ), - egocentric_time_series_shifted, - ) - - logger.info("Your data is now in the right format and you can call vame.create_trainset()") - except Exception as e: - logger.exception(f"{e}") - raise e - finally: - logger_config.remove_file_handler() - - -@save_state(model=EgocentricAlignmentFunctionSchema) -def egocentric_alignment( - config: str, - pose_ref_1: str = "snout", - pose_ref_2: str = "tailbase", - crop_size: Tuple[int, int] = (300, 300), - save_logs: bool = False, -) -> None: - """ - Egocentric alignment of bevarioral videos. - Fills in the values in the "egocentric_alignment" key of the states.json file. - Creates training dataset for VAME at: - - project_name/ - - data/ - - filename/ - - filename-PE-seq.npy - - filename/ - - filename-PE-seq.npy - The produced .npy files contain the aligned time series data in the - shape of (num_dlc_features, num_video_frames). - - Parameters - ---------- - config : str - Path for the project config file. - pose_ref_index : list, optional - Pose reference index to be used to align. Defaults to [0, 1]. - crop_size : tuple, optional - Size to crop the video. Defaults to (300,300). - - Raises: - ------ - ValueError - If the config.yaml indicates that the data is not egocentric. - """ - try: - config_file = Path(config).resolve() - cfg = read_config(str(config_file)) - if cfg["egocentric_data"]: - raise ValueError( - "The config.yaml indicates that the data is egocentric. Please check the parameter 'egocentric_data'." 
- ) - tqdm_stream = None - - if save_logs: - log_path = Path(cfg["project_path"]) / "logs" / "egocentric_alignment.log" - logger_config.add_file_handler(str(log_path)) - tqdm_stream = TqdmToLogger(logger=logger) - - logger.info("Starting egocentric alignment") - project_path = cfg["project_path"] - sessions = cfg["session_names"] - confidence = cfg["pose_confidence"] - num_features = cfg["num_features"] - - y_shifted_indices = np.arange(0, num_features, 2) - x_shifted_indices = np.arange(1, num_features, 2) - # reference_Y_ind = pose_ref_index[0] * 2 - # reference_X_ind = (pose_ref_index[0] * 2) + 1 - - # call function and save into your VAME data folder - for i, session in enumerate(sessions): - logger.info("Aligning session %s, Pose confidence value: %.2f" % (session, confidence)) - # read out data - file_path = str(Path(project_path) / "data" / "raw" / f"{session}.nc") - _, data_mat, ds = read_pose_estimation_file(file_path=file_path) - - # get the coordinates for alignment from data table - # pose_list dimensions: (num_body_parts, num_frames, 3) - pose_list = [] - for i in range(int(data_mat.shape[1] / 3)): - pose_list.append(data_mat[:, i * 3 : (i + 1) * 3]) - - frame_count = ds.position.time.shape[0] - keypoints_names = ds.keypoints.values - - reference_X_ind = np.where(ds.keypoints.values == pose_ref_1)[0][0] * 2 - reference_Y_ind = reference_X_ind + 1 - - pose_ref_index = ( - np.where(keypoints_names == pose_ref_1)[0][0], - np.where(keypoints_names == pose_ref_2)[0][0], - ) - - egocentric_time_series = alignment( - crop_size=crop_size, - pose_list=pose_list, - pose_ref_index=pose_ref_index, - confidence=confidence, - frame_count=frame_count, - tqdm_stream=tqdm_stream, - ) - - # Shifiting section added 2/29/2024 PN - egocentric_time_series_shifted = egocentric_time_series - reference_Y_shift = egocentric_time_series[reference_Y_ind, :] - reference_X_shift = egocentric_time_series[reference_X_ind, :] - - egocentric_time_series_shifted[y_shifted_indices, :] -= reference_Y_shift - egocentric_time_series_shifted[x_shifted_indices, :] -= reference_X_shift - - # Save new shifted file - np.save( - os.path.join( - project_path, - "data", - "processed", - session, - session + "-PE-seq.npy", - ), - egocentric_time_series_shifted, - ) - - # Add new variable to the dataset - ds["position_aligned"] = ( - ("time", "individuals", "keypoints", "space"), - egocentric_time_series_shifted.T.reshape(frame_count, 1, len(keypoints_names), 2), - ) - # save to file - result_file = Path(project_path) / "data" / "processed" / session / f"{session}-aligned.nc" - ds.to_netcdf(result_file, engine="scipy") - - logger.info("Your data is now in the right format and you can call vame.create_trainset()") - except Exception as e: - logger.exception(f"{e}") - raise e - finally: - logger_config.remove_file_handler() - - -def alignment( - crop_size: Tuple[int, int], - pose_list: List[np.ndarray], - pose_ref_index: Tuple[int, int], - confidence: float, - frame_count: int, - tqdm_stream: Union[TqdmToLogger, None] = None, -) -> np.ndarray: - """ - Egocentric alignment of pose estimation data. - - Parameters: - ----------- - crop_size : Tuple[int, int] - Size to crop the video frames. - pose_list : List[np.ndarray] - List of pose coordinates. - pose_ref_index : Tuple[int, int] - Pose reference indices. - confidence : float - Pose confidence threshold. - frame_count : int - Number of frames to align. - tqdm_stream : Union[TqdmToLogger, None], optional - Tqdm stream to log the progress. Defaults to None. 
- - Returns - ------- - np.ndarray - Aligned time series data. - """ - points = [] - - # for i in pose_list: - # for j in i: - # if j[2] <= confidence: - # j[0], j[1] = np.nan, np.nan - - # for i in pose_list: - # i = interpol_first_rows_nans(i) - - for idx in tqdm.tqdm( - range(frame_count), - disable=not True, - file=tqdm_stream, - desc="Align frames", - ): - # Read coordinates and add border - pose_list_bordered = [] - - for i in pose_list: - pose_list_bordered.append((int(i[idx][0] + crop_size[0]), int(i[idx][1] + crop_size[1]))) - - punkte = [] - for i in pose_ref_index: - coord = [ - pose_list_bordered[i][0], - pose_list_bordered[i][1], - ] - punkte.append(coord) - - punkte = [punkte] - punkte = np.asarray(punkte) - - # calculate minimal rectangle around snout and tail - rect = cv.minAreaRect(punkte) - - # change size in rect tuple structure to be equal to crop_size - lst = list(rect) - # lst[0] = center_belly - lst[1] = crop_size - rect = tuple(lst) - - # crop image - shifted_points = crop_and_flip( - rect=rect, - points=pose_list_bordered, - ref_index=pose_ref_index, - ) - - points.append(shifted_points) - - time_series = np.zeros((len(pose_list) * 2, frame_count)) - for i in range(frame_count): - idx = 0 - for j in range(len(pose_list)): - time_series[idx : idx + 2, i] = points[i][j] - idx += 2 - - return time_series - - -def crop_and_flip( - rect: Tuple, - points: List[np.ndarray], - ref_index: Tuple[int, int], -) -> List[np.ndarray]: - """ - Crop and flip the image based on the given rectangle and points. - - Parameters - ---------- - rect : Tuple - Rectangle coordinates (center, size, theta). - points : List[np.ndarray] - List of points. - ref_index : Tuple[int, int] - Reference indices for alignment. - - Returns - ------- - Tuple[np.ndarray, List[np.ndarray]] - Cropped and flipped image, and shifted points. - """ - # Read out rect structures and convert - center, size, theta = rect - center, size = tuple(map(int, center)), tuple(map(int, size)) - - # Get rotation matrix - M = cv.getRotationMatrix2D(center, theta, 1) - - # shift DLC points - x_diff = center[0] - size[0] // 2 - y_diff = center[1] - size[1] // 2 - dlc_points_shifted = [] - for i in points: - point = cv.transform(np.array([[[i[0], i[1]]]]), M)[0][0] - point[0] -= x_diff - point[1] -= y_diff - dlc_points_shifted.append(point) - - # check if flipped correctly, otherwise flip again - if dlc_points_shifted[ref_index[1]][0] >= dlc_points_shifted[ref_index[0]][0]: - rect = ( - (size[0] // 2, size[0] // 2), - size, - 180, - ) # should second value be size[1]? Is this relevant to the flip? 3/5/24 KKL - center, size, theta = rect - center, size = tuple(map(int, center)), tuple(map(int, size)) - - # Get rotation matrix - M = cv.getRotationMatrix2D(center, theta, 1) - - # shift DLC points - x_diff = center[0] - size[0] // 2 - y_diff = center[1] - size[1] // 2 - - points = dlc_points_shifted - dlc_points_shifted = [] - - for i in points: - point = cv.transform(np.array([[[i[0], i[1]]]]), M)[0][0] - point[0] -= x_diff - point[1] -= y_diff - dlc_points_shifted.append(point) - - return dlc_points_shifted +# def egocentric_alignment( +# config: str, +# pose_ref_1: str = "snout", +# pose_ref_2: str = "tailbase", +# crop_size: Tuple[int, int] = (300, 300), +# save_logs: bool = False, +# ) -> None: +# """ +# Egocentric alignment of bevarioral videos. +# Fills in the values in the "egocentric_alignment" key of the states.json file. 
+# Creates training dataset for VAME at: +# - project_name/ +# - data/ +# - filename/ +# - filename-PE-seq.npy +# - filename/ +# - filename-PE-seq.npy +# The produced .npy files contain the aligned time series data in the +# shape of (num_dlc_features, num_video_frames). + +# Parameters +# ---------- +# config : str +# Path for the project config file. +# pose_ref_index : list, optional +# Pose reference index to be used to align. Defaults to [0, 1]. +# crop_size : tuple, optional +# Size to crop the video. Defaults to (300,300). + +# Raises: +# ------ +# ValueError +# If the config.yaml indicates that the data is not egocentric. +# """ +# try: +# config_file = Path(config).resolve() +# cfg = read_config(str(config_file)) +# if cfg["egocentric_data"]: +# raise ValueError( +# "The config.yaml indicates that the data is egocentric. Please check the parameter 'egocentric_data'." +# ) +# tqdm_stream = None + +# if save_logs: +# log_path = Path(cfg["project_path"]) / "logs" / "egocentric_alignment.log" +# logger_config.add_file_handler(str(log_path)) +# tqdm_stream = TqdmToLogger(logger=logger) + +# logger.info("Starting egocentric alignment") +# project_path = cfg["project_path"] +# sessions = cfg["session_names"] +# confidence = cfg["pose_confidence"] +# num_features = cfg["num_features"] + +# y_shifted_indices = np.arange(0, num_features, 2) +# x_shifted_indices = np.arange(1, num_features, 2) +# # reference_Y_ind = pose_ref_index[0] * 2 +# # reference_X_ind = (pose_ref_index[0] * 2) + 1 + +# # call function and save into your VAME data folder +# for i, session in enumerate(sessions): +# logger.info("Aligning session %s, Pose confidence value: %.2f" % (session, confidence)) +# # read out data +# file_path = str(Path(project_path) / "data" / "raw" / f"{session}.nc") +# _, data_mat, ds = read_pose_estimation_file(file_path=file_path) + +# # get the coordinates for alignment from data table +# # pose_list dimensions: (num_body_parts, num_frames, 3) +# pose_list = [] +# for i in range(int(data_mat.shape[1] / 3)): +# pose_list.append(data_mat[:, i * 3 : (i + 1) * 3]) + +# frame_count = ds.position.time.shape[0] +# keypoints_names = ds.keypoints.values + +# reference_X_ind = np.where(ds.keypoints.values == pose_ref_1)[0][0] * 2 +# reference_Y_ind = reference_X_ind + 1 + +# pose_ref_index = ( +# np.where(keypoints_names == pose_ref_1)[0][0], +# np.where(keypoints_names == pose_ref_2)[0][0], +# ) + +# egocentric_time_series = alignment( +# crop_size=crop_size, +# pose_list=pose_list, +# pose_ref_index=pose_ref_index, +# confidence=confidence, +# frame_count=frame_count, +# tqdm_stream=tqdm_stream, +# ) + +# # Shifiting section added 2/29/2024 PN +# egocentric_time_series_shifted = egocentric_time_series +# reference_Y_shift = egocentric_time_series[reference_Y_ind, :] +# reference_X_shift = egocentric_time_series[reference_X_ind, :] + +# egocentric_time_series_shifted[y_shifted_indices, :] -= reference_Y_shift +# egocentric_time_series_shifted[x_shifted_indices, :] -= reference_X_shift + +# # Save new shifted file +# np.save( +# os.path.join( +# project_path, +# "data", +# "processed", +# session, +# session + "-PE-seq.npy", +# ), +# egocentric_time_series_shifted, +# ) + +# # Add new variable to the dataset +# ds["position_aligned"] = ( +# ("time", "individuals", "keypoints", "space"), +# egocentric_time_series_shifted.T.reshape(frame_count, 1, len(keypoints_names), 2), +# ) +# # save to file +# result_file = Path(project_path) / "data" / "processed" / session / f"{session}-aligned.nc" +# 
ds.to_netcdf(result_file, engine="scipy") + +# logger.info("Your data is now in the right format and you can call vame.create_trainset()") +# except Exception as e: +# logger.exception(f"{e}") +# raise e +# finally: +# logger_config.remove_file_handler() + + +# def alignment( +# crop_size: Tuple[int, int], +# pose_list: List[np.ndarray], +# pose_ref_index: Tuple[int, int], +# confidence: float, +# frame_count: int, +# tqdm_stream: Union[TqdmToLogger, None] = None, +# ) -> np.ndarray: +# """ +# Egocentric alignment of pose estimation data. + +# Parameters: +# ----------- +# crop_size : Tuple[int, int] +# Size to crop the video frames. +# pose_list : List[np.ndarray] +# List of pose coordinates. +# pose_ref_index : Tuple[int, int] +# Pose reference indices. +# confidence : float +# Pose confidence threshold. +# frame_count : int +# Number of frames to align. +# tqdm_stream : Union[TqdmToLogger, None], optional +# Tqdm stream to log the progress. Defaults to None. + +# Returns +# ------- +# np.ndarray +# Aligned time series data. +# """ +# points = [] + +# # for i in pose_list: +# # for j in i: +# # if j[2] <= confidence: +# # j[0], j[1] = np.nan, np.nan + +# # for i in pose_list: +# # i = interpol_first_rows_nans(i) + +# for idx in tqdm.tqdm( +# range(frame_count), +# disable=not True, +# file=tqdm_stream, +# desc="Align frames", +# ): +# # Read coordinates and add border +# pose_list_bordered = [] + +# for i in pose_list: +# pose_list_bordered.append((int(i[idx][0] + crop_size[0]), int(i[idx][1] + crop_size[1]))) + +# punkte = [] +# for i in pose_ref_index: +# coord = [ +# pose_list_bordered[i][0], +# pose_list_bordered[i][1], +# ] +# punkte.append(coord) + +# punkte = [punkte] +# punkte = np.asarray(punkte) + +# # calculate minimal rectangle around snout and tail +# rect = cv.minAreaRect(punkte) + +# # change size in rect tuple structure to be equal to crop_size +# lst = list(rect) +# # lst[0] = center_belly +# lst[1] = crop_size +# rect = tuple(lst) + +# # crop image +# shifted_points = crop_and_flip( +# rect=rect, +# points=pose_list_bordered, +# ref_index=pose_ref_index, +# ) + +# points.append(shifted_points) + +# time_series = np.zeros((len(pose_list) * 2, frame_count)) +# for i in range(frame_count): +# idx = 0 +# for j in range(len(pose_list)): +# time_series[idx : idx + 2, i] = points[i][j] +# idx += 2 + +# return time_series + + +# def crop_and_flip( +# rect: Tuple, +# points: List[np.ndarray], +# ref_index: Tuple[int, int], +# ) -> List[np.ndarray]: +# """ +# Crop and flip the image based on the given rectangle and points. + +# Parameters +# ---------- +# rect : Tuple +# Rectangle coordinates (center, size, theta). +# points : List[np.ndarray] +# List of points. +# ref_index : Tuple[int, int] +# Reference indices for alignment. + +# Returns +# ------- +# Tuple[np.ndarray, List[np.ndarray]] +# Cropped and flipped image, and shifted points. 
+# """ +# # Read out rect structures and convert +# center, size, theta = rect +# center, size = tuple(map(int, center)), tuple(map(int, size)) + +# # Get rotation matrix +# M = cv.getRotationMatrix2D(center, theta, 1) + +# # shift DLC points +# x_diff = center[0] - size[0] // 2 +# y_diff = center[1] - size[1] // 2 +# dlc_points_shifted = [] +# for i in points: +# point = cv.transform(np.array([[[i[0], i[1]]]]), M)[0][0] +# point[0] -= x_diff +# point[1] -= y_diff +# dlc_points_shifted.append(point) + +# # check if flipped correctly, otherwise flip again +# if dlc_points_shifted[ref_index[1]][0] >= dlc_points_shifted[ref_index[0]][0]: +# rect = ( +# (size[0] // 2, size[0] // 2), +# size, +# 180, +# ) # should second value be size[1]? Is this relevant to the flip? 3/5/24 KKL +# center, size, theta = rect +# center, size = tuple(map(int, center)), tuple(map(int, size)) + +# # Get rotation matrix +# M = cv.getRotationMatrix2D(center, theta, 1) + +# # shift DLC points +# x_diff = center[0] - size[0] // 2 +# y_diff = center[1] - size[1] // 2 + +# points = dlc_points_shifted +# dlc_points_shifted = [] + +# for i in points: +# point = cv.transform(np.array([[[i[0], i[1]]]]), M)[0][0] +# point[0] -= x_diff +# point[1] -= y_diff +# dlc_points_shifted.append(point) + +# return dlc_points_shifted From 1d42f98cbe8ca78d740d956f5be86b867e765ad7 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Dec 2024 15:35:34 +0100 Subject: [PATCH 48/77] actions --- .github/workflows/deploy-docs.yaml | 3 ++- .github/workflows/test-deploy.yaml | 34 ------------------------------ 2 files changed, 2 insertions(+), 35 deletions(-) delete mode 100644 .github/workflows/test-deploy.yaml diff --git a/.github/workflows/deploy-docs.yaml b/.github/workflows/deploy-docs.yaml index aba08154..1ef83799 100644 --- a/.github/workflows/deploy-docs.yaml +++ b/.github/workflows/deploy-docs.yaml @@ -3,7 +3,7 @@ name: Deploy VAME Docs to GitHub Pages on: push: branches: - - docs + - main jobs: deploy: @@ -59,3 +59,4 @@ jobs: with: github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: ./docs/vame-docs-app/build + publish_branch: gh-pages diff --git a/.github/workflows/test-deploy.yaml b/.github/workflows/test-deploy.yaml deleted file mode 100644 index 63b3e2e1..00000000 --- a/.github/workflows/test-deploy.yaml +++ /dev/null @@ -1,34 +0,0 @@ -name: Test deployment - -on: - pull_request: - branches: - - docs - -jobs: - test-deploy: - name: Test deployment - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.11 - uses: actions/setup-python@v5 - with: - python-version: 3.11 - - - name: Install docs dependencies. - run: pip install -r docs/requirements-docs.txt - - - name: Auto generate API Reference. 
- run: cd docs && pydoc-markdown - - uses: actions/setup-node@v3 - with: - node-version: 18 - cache: yarn - working-directory: docs/vame-docs-app - cache-dependency-path: docs/vame-docs-app/yarn.lock - - - name: Install dependencies - run: cd docs/vame-docs-app && yarn install --frozen-lockfile - - name: Test build website - run: cd docs/vame-docs-app && yarn build \ No newline at end of file From 3ebc0932b3dbbfc36f272ea11856e440a73cf917 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Dec 2024 15:40:08 +0100 Subject: [PATCH 49/77] comment unused code --- src/vame/io/nwb.py | 118 +++++++++--------- src/vame/preprocessing/visualization.py | 152 ++++++++++++------------ src/vame/video/__init__.py | 2 +- src/vame/video/video.py | 86 +++++++------- 4 files changed, 179 insertions(+), 179 deletions(-) diff --git a/src/vame/io/nwb.py b/src/vame/io/nwb.py index da142191..91f19a61 100644 --- a/src/vame/io/nwb.py +++ b/src/vame/io/nwb.py @@ -1,67 +1,67 @@ -from pynwb import NWBHDF5IO -from pynwb.file import NWBFile -from hdmf.utils import LabelledDict -import pandas as pd +# from pynwb import NWBHDF5IO +# from pynwb.file import NWBFile +# from hdmf.utils import LabelledDict +# import pandas as pd -def get_pose_data_from_nwb_file( - nwbfile: NWBFile, - path_to_pose_nwb_series_data: str, -) -> LabelledDict: - """ - Get pose data from nwb file using a inside path to the nwb data. +# def get_pose_data_from_nwb_file( +# nwbfile: NWBFile, +# path_to_pose_nwb_series_data: str, +# ) -> LabelledDict: +# """ +# Get pose data from nwb file using a inside path to the nwb data. - Parameters: - ---------- - nwbfile : NWBFile) - NWB file object. - path_to_pose_nwb_series_data : str - Path to the pose data inside the nwb file. +# Parameters: +# ---------- +# nwbfile : NWBFile) +# NWB file object. +# path_to_pose_nwb_series_data : str +# Path to the pose data inside the nwb file. - Returns - ------- - LabelledDict - Pose data. - """ - if not path_to_pose_nwb_series_data: - raise ValueError("Path to pose nwb series data is required.") - pose_data = nwbfile - for key in path_to_pose_nwb_series_data.split("/"): - if isinstance(pose_data, dict): - pose_data = pose_data.get(key) - continue - pose_data = getattr(pose_data, key) - return pose_data +# Returns +# ------- +# LabelledDict +# Pose data. +# """ +# if not path_to_pose_nwb_series_data: +# raise ValueError("Path to pose nwb series data is required.") +# pose_data = nwbfile +# for key in path_to_pose_nwb_series_data.split("/"): +# if isinstance(pose_data, dict): +# pose_data = pose_data.get(key) +# continue +# pose_data = getattr(pose_data, key) +# return pose_data -def get_dataframe_from_pose_nwb_file( - file_path: str, - path_to_pose_nwb_series_data: str, -) -> pd.DataFrame: - """ - Get pose data from nwb file and return it as a pandas DataFrame. +# def get_dataframe_from_pose_nwb_file( +# file_path: str, +# path_to_pose_nwb_series_data: str, +# ) -> pd.DataFrame: +# """ +# Get pose data from nwb file and return it as a pandas DataFrame. - Parameters - ---------- - file_path : str - Path to the nwb file. - path_to_pose_nwb_series_data : str - Path to the pose data inside the nwb file. +# Parameters +# ---------- +# file_path : str +# Path to the nwb file. +# path_to_pose_nwb_series_data : str +# Path to the pose data inside the nwb file. - Returns - ------- - pd.DataFrame - Pose data as a pandas DataFrame. 
- """ - with NWBHDF5IO(file_path, "r") as io: - nwbfile = io.read() - pose = get_pose_data_from_nwb_file(nwbfile, path_to_pose_nwb_series_data) - dataframes = [] - for label, pose_series in pose.items(): - data = pose_series.data[:] - confidence = pose_series.confidence[:] - df = pd.DataFrame(data, columns=[f"{label}_x", f"{label}_y"]) - df[f"likelihood_{label}"] = confidence - dataframes.append(df) - final_df = pd.concat(dataframes, axis=1) - return final_df +# Returns +# ------- +# pd.DataFrame +# Pose data as a pandas DataFrame. +# """ +# with NWBHDF5IO(file_path, "r") as io: +# nwbfile = io.read() +# pose = get_pose_data_from_nwb_file(nwbfile, path_to_pose_nwb_series_data) +# dataframes = [] +# for label, pose_series in pose.items(): +# data = pose_series.data[:] +# confidence = pose_series.confidence[:] +# df = pd.DataFrame(data, columns=[f"{label}_x", f"{label}_y"]) +# df[f"likelihood_{label}"] = confidence +# dataframes.append(df) +# final_df = pd.concat(dataframes, axis=1) +# return final_df diff --git a/src/vame/preprocessing/visualization.py b/src/vame/preprocessing/visualization.py index 946d8ad9..088b381f 100644 --- a/src/vame/preprocessing/visualization.py +++ b/src/vame/preprocessing/visualization.py @@ -288,79 +288,79 @@ def visualize_preprocessing_timeseries( ) -def visualize_timeseries( - config: dict, - session_index: int = 0, - n_samples: int = 1000, - positions_key: str = "position", - keypoints_labels: list[str] | None = None, - save_to_file: bool = False, - show_figure: bool = True, -): - """ - Visualize the original positions of the keypoints in a timeseries plot. - """ - project_path = config["project_path"] - sessions = config["session_names"] - session = sessions[session_index] - - # Read session data - file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") - _, _, ds = read_pose_estimation_file(file_path=file_path) - - fig, ax = plt.subplots(2, 1, figsize=(10, 8)) - - individual = "individual_0" - if keypoints_labels is None: - keypoints_labels = ds.keypoints.values - - # Create a colormap with distinguishable colors - cmap = get_cmap("tab10") if len(keypoints_labels) <= 10 else get_cmap("tab20") - colors = [cmap(i / len(keypoints_labels)) for i in range(len(keypoints_labels))] - - for i, kp in enumerate(keypoints_labels): - sel_x = dict( - individuals=individual, - keypoints=kp, - space="x", - ) - sel_y = dict( - individuals=individual, - keypoints=kp, - space="y", - ) - - # Original positions (first two subplots) - ds[positions_key].sel(**sel_x)[0:n_samples].plot( - linewidth=1.5, - ax=ax[0], - label=kp, - color=colors[i], - ) - ds[positions_key].sel(**sel_y)[0:n_samples].plot( - linewidth=1.5, - ax=ax[1], - label=kp, - color=colors[i], - ) - - # Set common labels for Y axes - ax[0].set_ylabel( - "Allocentric X", - fontsize=12, - ) - ax[1].set_ylabel( - "Allocentric Y", - fontsize=12, - ) - - # Labels for X axes - for idx, a in enumerate(ax): - a.set_title("") - if idx % 2 == 0: - a.set_xlabel("") - else: - a.set_xlabel( - "Time", - fontsize=10, - ) +# def visualize_timeseries( +# config: dict, +# session_index: int = 0, +# n_samples: int = 1000, +# positions_key: str = "position", +# keypoints_labels: list[str] | None = None, +# save_to_file: bool = False, +# show_figure: bool = True, +# ): +# """ +# Visualize the original positions of the keypoints in a timeseries plot. 
+# """ +# project_path = config["project_path"] +# sessions = config["session_names"] +# session = sessions[session_index] + +# # Read session data +# file_path = str(Path(project_path) / "data" / "processed" / f"{session}_processed.nc") +# _, _, ds = read_pose_estimation_file(file_path=file_path) + +# fig, ax = plt.subplots(2, 1, figsize=(10, 8)) + +# individual = "individual_0" +# if keypoints_labels is None: +# keypoints_labels = ds.keypoints.values + +# # Create a colormap with distinguishable colors +# cmap = get_cmap("tab10") if len(keypoints_labels) <= 10 else get_cmap("tab20") +# colors = [cmap(i / len(keypoints_labels)) for i in range(len(keypoints_labels))] + +# for i, kp in enumerate(keypoints_labels): +# sel_x = dict( +# individuals=individual, +# keypoints=kp, +# space="x", +# ) +# sel_y = dict( +# individuals=individual, +# keypoints=kp, +# space="y", +# ) + +# # Original positions (first two subplots) +# ds[positions_key].sel(**sel_x)[0:n_samples].plot( +# linewidth=1.5, +# ax=ax[0], +# label=kp, +# color=colors[i], +# ) +# ds[positions_key].sel(**sel_y)[0:n_samples].plot( +# linewidth=1.5, +# ax=ax[1], +# label=kp, +# color=colors[i], +# ) + +# # Set common labels for Y axes +# ax[0].set_ylabel( +# "Allocentric X", +# fontsize=12, +# ) +# ax[1].set_ylabel( +# "Allocentric Y", +# fontsize=12, +# ) + +# # Labels for X axes +# for idx, a in enumerate(ax): +# a.set_title("") +# if idx % 2 == 0: +# a.set_xlabel("") +# else: +# a.set_xlabel( +# "Time", +# fontsize=10, +# ) diff --git a/src/vame/video/__init__.py b/src/vame/video/__init__.py index 3cec722d..87119b8a 100644 --- a/src/vame/video/__init__.py +++ b/src/vame/video/__init__.py @@ -1,4 +1,4 @@ from vame.video.video import ( get_video_frame_rate, - play_aligned_video, + # play_aligned_video, ) diff --git a/src/vame/video/video.py b/src/vame/video/video.py index aeee4b80..13d640b7 100644 --- a/src/vame/video/video.py +++ b/src/vame/video/video.py @@ -12,47 +12,47 @@ def get_video_frame_rate(video_path): return frame_rate -def play_aligned_video( - a: List[np.ndarray], - n: List[List[np.ndarray]], - frame_count: int, -) -> None: - """ - Play the aligned video. +# def play_aligned_video( +# a: List[np.ndarray], +# n: List[List[np.ndarray]], +# frame_count: int, +# ) -> None: +# """ +# Play the aligned video. - Parameters - ---------- - a : List[np.ndarray] - List of aligned images. - n : List[List[np.ndarray]] - List of aligned DLC points. - frame_count : int - Number of frames in the video. - """ - colors = [ - (255, 0, 0), - (0, 255, 0), - (0, 0, 255), - (255, 255, 0), - (255, 0, 255), - (0, 255, 255), - (0, 0, 0), - (255, 255, 255), - ] - for i in range(frame_count): - # Capture frame-by-frame - ret, frame = True, a[i] - if ret is True: - # Display the resulting frame - frame = cv2.cvtColor(frame.astype("uint8") * 255, cv2.COLOR_GRAY2BGR) - im_color = cv2.applyColorMap(frame, cv2.COLORMAP_JET) - for c, j in enumerate(n[i]): - cv2.circle(im_color, (j[0], j[1]), 5, colors[c], -1) - cv2.imshow("Frame", im_color) - # Press Q on keyboard to exit - # Break the loop - if cv2.waitKey(25) & 0xFF == ord("q"): - break - else: - break - cv2.destroyAllWindows() +# Parameters +# ---------- +# a : List[np.ndarray] +# List of aligned images. +# n : List[List[np.ndarray]] +# List of aligned DLC points. +# frame_count : int +# Number of frames in the video. 
+# """ +# colors = [ +# (255, 0, 0), +# (0, 255, 0), +# (0, 0, 255), +# (255, 255, 0), +# (255, 0, 255), +# (0, 255, 255), +# (0, 0, 0), +# (255, 255, 255), +# ] +# for i in range(frame_count): +# # Capture frame-by-frame +# ret, frame = True, a[i] +# if ret is True: +# # Display the resulting frame +# frame = cv2.cvtColor(frame.astype("uint8") * 255, cv2.COLOR_GRAY2BGR) +# im_color = cv2.applyColorMap(frame, cv2.COLORMAP_JET) +# for c, j in enumerate(n[i]): +# cv2.circle(im_color, (j[0], j[1]), 5, colors[c], -1) +# cv2.imshow("Frame", im_color) +# # Press Q on keyboard to exit +# # Break the loop +# if cv2.waitKey(25) & 0xFF == ord("q"): +# break +# else: +# break +# cv2.destroyAllWindows() From b88db1115bcc17cb97bea12f9fba2fb2c24783e3 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Dec 2024 16:08:23 +0100 Subject: [PATCH 50/77] docs --- .github/workflows/deploy-docs.yaml | 1 + docs/pydoc-markdown.yml | 7 +- docs/requirements-docs.txt | 3 +- .../docs/reference/__init__/__init__.md | 5 + .../docs/reference/analysis/__init__.md | 5 + .../reference/analysis/community_analysis.md | 245 ++++++++++++++++++ .../analysis/generative_functions.md | 110 ++++++++ .../docs/reference/analysis/gif_creator.md | 70 +++++ .../reference/analysis/pose_segmentation.md | 139 ++++++++++ .../docs/reference/analysis/tree_hierarchy.md | 169 ++++++++++++ .../docs/reference/analysis/umap.md | 123 +++++++++ .../docs/reference/analysis/videowriter.md | 123 +++++++++ .../reference/initialize_project/__init__.md | 5 + .../docs/reference/initialize_project/new.md | 66 +++++ .../docs/reference/io/__init__.md | 5 + .../docs/reference/io/load_poses.md | 71 +++++ docs/vame-docs-app/docs/reference/io/nwb.md | 5 + .../docs/reference/model/__init__.md | 5 + .../docs/reference/model/create_training.md | 67 +++++ .../docs/reference/model/dataloader.md | 65 +++++ .../docs/reference/model/evaluate.md | 121 +++++++++ .../docs/reference/model/rnn_model.md | 230 ++++++++++++++++ .../docs/reference/model/rnn_vae.md | 240 +++++++++++++++++ docs/vame-docs-app/docs/reference/pipeline.md | 139 ++++++++++ .../docs/reference/preprocessing/__init__.md | 5 + .../align_egocentrical_legacy.md | 5 + .../docs/reference/preprocessing/alignment.md | 33 +++ .../docs/reference/preprocessing/cleaning.md | 53 ++++ .../docs/reference/preprocessing/filter.md | 29 +++ .../reference/preprocessing/preprocessing.md | 34 +++ .../docs/reference/preprocessing/to_model.md | 27 ++ .../reference/preprocessing/visualization.md | 39 +++ .../vame-docs-app/docs/reference/sidebar.json | 60 +---- .../docs/reference/util/__init__.md | 5 + .../docs/reference/util/auxiliary.md | 62 +++++ docs/vame-docs-app/docs/reference/util/cli.md | 13 + .../docs/reference/util/csv_to_npy.md | 29 +++ .../docs/reference/util/data_manipulation.md | 120 +++++++++ .../docs/reference/util/gif_pose_helper.md | 40 +++ .../docs/reference/util/model_util.md | 25 ++ .../docs/reference/util/report.md | 30 +++ .../docs/reference/util/sample_data.md | 25 ++ .../docs/reference/video/__init__.md | 5 + .../docs/reference/video/video.md | 11 + src/vame/analysis/community_analysis.py | 12 +- src/vame/analysis/generative_functions.py | 40 +-- src/vame/analysis/gif_creator.py | 16 +- src/vame/analysis/tree_hierarchy.py | 56 ++-- src/vame/analysis/umap.py | 2 +- src/vame/analysis/videowriter.py | 4 +- src/vame/initialize_project/new.py | 2 +- src/vame/io/load_poses.py | 22 +- src/vame/io/nwb.py | 6 +- src/vame/model/create_training.py | 4 +- src/vame/model/rnn_model.py | 60 ++--- src/vame/pipeline.py | 12 
+- .../align_egocentrical_legacy.py | 18 +- src/vame/preprocessing/alignment.py | 8 +- src/vame/preprocessing/cleaning.py | 16 +- src/vame/preprocessing/filter.py | 8 +- src/vame/preprocessing/preprocessing.py | 8 +- src/vame/preprocessing/to_model.py | 8 +- src/vame/util/auxiliary.py | 2 +- src/vame/util/data_manipulation.py | 10 +- src/vame/util/gif_pose_helper.py | 8 +- src/vame/util/model_util.py | 2 +- src/vame/video/video.py | 2 +- 67 files changed, 2769 insertions(+), 226 deletions(-) create mode 100644 docs/vame-docs-app/docs/reference/__init__/__init__.md create mode 100644 docs/vame-docs-app/docs/reference/analysis/__init__.md create mode 100644 docs/vame-docs-app/docs/reference/analysis/community_analysis.md create mode 100644 docs/vame-docs-app/docs/reference/analysis/generative_functions.md create mode 100644 docs/vame-docs-app/docs/reference/analysis/gif_creator.md create mode 100644 docs/vame-docs-app/docs/reference/analysis/pose_segmentation.md create mode 100644 docs/vame-docs-app/docs/reference/analysis/tree_hierarchy.md create mode 100644 docs/vame-docs-app/docs/reference/analysis/umap.md create mode 100644 docs/vame-docs-app/docs/reference/analysis/videowriter.md create mode 100644 docs/vame-docs-app/docs/reference/initialize_project/__init__.md create mode 100644 docs/vame-docs-app/docs/reference/initialize_project/new.md create mode 100644 docs/vame-docs-app/docs/reference/io/__init__.md create mode 100644 docs/vame-docs-app/docs/reference/io/load_poses.md create mode 100644 docs/vame-docs-app/docs/reference/io/nwb.md create mode 100644 docs/vame-docs-app/docs/reference/model/__init__.md create mode 100644 docs/vame-docs-app/docs/reference/model/create_training.md create mode 100644 docs/vame-docs-app/docs/reference/model/dataloader.md create mode 100644 docs/vame-docs-app/docs/reference/model/evaluate.md create mode 100644 docs/vame-docs-app/docs/reference/model/rnn_model.md create mode 100644 docs/vame-docs-app/docs/reference/model/rnn_vae.md create mode 100644 docs/vame-docs-app/docs/reference/pipeline.md create mode 100644 docs/vame-docs-app/docs/reference/preprocessing/__init__.md create mode 100644 docs/vame-docs-app/docs/reference/preprocessing/align_egocentrical_legacy.md create mode 100644 docs/vame-docs-app/docs/reference/preprocessing/alignment.md create mode 100644 docs/vame-docs-app/docs/reference/preprocessing/cleaning.md create mode 100644 docs/vame-docs-app/docs/reference/preprocessing/filter.md create mode 100644 docs/vame-docs-app/docs/reference/preprocessing/preprocessing.md create mode 100644 docs/vame-docs-app/docs/reference/preprocessing/to_model.md create mode 100644 docs/vame-docs-app/docs/reference/preprocessing/visualization.md create mode 100644 docs/vame-docs-app/docs/reference/util/__init__.md create mode 100644 docs/vame-docs-app/docs/reference/util/auxiliary.md create mode 100644 docs/vame-docs-app/docs/reference/util/cli.md create mode 100644 docs/vame-docs-app/docs/reference/util/csv_to_npy.md create mode 100644 docs/vame-docs-app/docs/reference/util/data_manipulation.md create mode 100644 docs/vame-docs-app/docs/reference/util/gif_pose_helper.md create mode 100644 docs/vame-docs-app/docs/reference/util/model_util.md create mode 100644 docs/vame-docs-app/docs/reference/util/report.md create mode 100644 docs/vame-docs-app/docs/reference/util/sample_data.md create mode 100644 docs/vame-docs-app/docs/reference/video/__init__.md create mode 100644 docs/vame-docs-app/docs/reference/video/video.md diff --git 
a/.github/workflows/deploy-docs.yaml b/.github/workflows/deploy-docs.yaml index 1ef83799..97e0c0b6 100644 --- a/.github/workflows/deploy-docs.yaml +++ b/.github/workflows/deploy-docs.yaml @@ -36,6 +36,7 @@ jobs: - name: Install dependencies run: cd docs/vame-docs-app && yarn install --frozen-lockfile + - name: Build website run: cd docs/vame-docs-app && yarn build diff --git a/docs/pydoc-markdown.yml b/docs/pydoc-markdown.yml index 001921af..03bccd69 100644 --- a/docs/pydoc-markdown.yml +++ b/docs/pydoc-markdown.yml @@ -1,11 +1,8 @@ loaders: - type: python - search_path: ['../src'] + search_path: ["../src/vame"] processors: - - type: filter - skip_empty_modules: true - - type: smart - - type: crossref + - type: numpy renderer: type: docusaurus docs_base_path: vame-docs-app/docs diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 6dcdf89d..b9ad3b85 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -1 +1,2 @@ -pydoc-markdown==4.8.2 \ No newline at end of file +# pydoc-markdown==4.8.2 +git+https://github.com/luiztauffer/pydoc-markdown.git@develop \ No newline at end of file diff --git a/docs/vame-docs-app/docs/reference/__init__/__init__.md b/docs/vame-docs-app/docs/reference/__init__/__init__.md new file mode 100644 index 00000000..7d29be34 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/__init__/__init__.md @@ -0,0 +1,5 @@ +--- +sidebar_label: __init__ +title: __init__ +--- + diff --git a/docs/vame-docs-app/docs/reference/analysis/__init__.md b/docs/vame-docs-app/docs/reference/analysis/__init__.md new file mode 100644 index 00000000..15aa5ec5 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/analysis/__init__.md @@ -0,0 +1,5 @@ +--- +sidebar_label: analysis +title: analysis +--- + diff --git a/docs/vame-docs-app/docs/reference/analysis/community_analysis.md b/docs/vame-docs-app/docs/reference/analysis/community_analysis.md new file mode 100644 index 00000000..0552f9ff --- /dev/null +++ b/docs/vame-docs-app/docs/reference/analysis/community_analysis.md @@ -0,0 +1,245 @@ +--- +sidebar_label: community_analysis +title: analysis.community_analysis +--- + +#### logger\_config + +#### logger + +#### get\_adjacency\_matrix + +```python +def get_adjacency_matrix( + labels: np.ndarray, + n_clusters: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray] +``` + +Calculate the adjacency matrix, transition matrix, and temporal matrix. + +**Arguments** + +* **labels** (`np.ndarray`): Array of cluster labels. +* **n_clusters** (`int`): Number of clusters. + +**Returns** + +* `Tuple[np.ndarray, np.ndarray, np.ndarray]`: Tuple containing: adjacency matrix, transition matrix, and temporal matrix. + +#### get\_transition\_matrix + +```python +def get_transition_matrix(adjacency_matrix: np.ndarray, + threshold: float = 0.0) -> np.ndarray +``` + +Compute the transition matrix from the adjacency matrix. + +**Arguments** + +* **adjacency_matrix** (`np.ndarray`): Adjacency matrix. +* **threshold** (`float, optional`): Threshold for considering transitions. Defaults to 0.0. + +**Returns** + +* `np.ndarray`: Transition matrix. + +#### fill\_motifs\_with\_zero\_counts + +```python +def fill_motifs_with_zero_counts(unique_motif_labels: np.ndarray, + motif_counts: np.ndarray, + n_clusters: int) -> np.ndarray +``` + +Find motifs that never occur in the dataset, and fill the motif_counts array with zeros for those motifs. 
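+
+A minimal NumPy sketch of this zero-filling behavior (illustrative only;
+`fill_with_zero_counts_sketch` is a hypothetical name, not the packaged
+implementation):
+
+```python
+import numpy as np
+
+def fill_with_zero_counts_sketch(unique_motif_labels, motif_counts, n_clusters):
+    # Start from zero counts for every possible motif label, then scatter
+    # the observed counts into their motif positions.
+    full_counts = np.zeros(n_clusters, dtype=int)
+    full_counts[np.asarray(unique_motif_labels, dtype=int)] = motif_counts
+    return full_counts
+
+print(fill_with_zero_counts_sketch([0, 1, 3, 4], [10, 20, 30, 40], 5))
+# -> [10 20  0 30 40]
+print(fill_with_zero_counts_sketch([0, 1, 3, 4], [10, 20, 30, 40], 6))
+# -> [10 20  0 30 40  0]
+```
+
+The worked examples below correspond to the two calls above.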
+Example 1: + - unique_motif_labels = [0, 1, 3, 4] + - motif_counts = [10, 20, 30, 40], + - n_clusters = 5 + - the function will return [10, 20, 0, 30, 40]. +Example 2: + - unique_motif_labels = [0, 1, 3, 4] + - motif_counts = [10, 20, 30, 40], + - n_clusters = 6 + - the function will return [10, 20, 0, 30, 40, 0]. + +**Arguments** + +* **unique_motif_labels** (`np.ndarray`): Array of unique motif labels. +* **motif_counts** (`np.ndarray`): Array of motif counts (in number of frames). +* **n_clusters** (`int`): Number of clusters. + +**Returns** + +* `np.ndarray`: List of motif counts (in number of frame) with 0's for motifs that never happened. + +#### augment\_motif\_timeseries + +```python +def augment_motif_timeseries(labels: np.ndarray, + n_clusters: int) -> Tuple[np.ndarray, np.ndarray] +``` + +Augment motif time series by filling zero motifs. + +**Arguments** + +* **labels** (`np.ndarray`): Original array of labels. +* **n_clusters** (`int`): Number of clusters. + +**Returns** + +* `Tuple[np.ndarray, np.ndarray]`: Tuple with: + - Array of labels augmented with motifs that never occurred, artificially inputed + at the end of the original labels array + - Indices of the motifs that never occurred. + +#### get\_motif\_labels + +```python +def get_motif_labels(config: dict, sessions: List[str], model_name: str, + n_clusters: int, + segmentation_algorithm: str) -> np.ndarray +``` + +Get motif labels for given files. + +**Arguments** + +* **config** (`dict`): Configuration parameters. +* **sessions** (`List[str]`): List of session names. +* **model_name** (`str`): Model name. +* **n_clusters** (`int`): Number of clusters. +* **segmentation_algorithm** (`str`): Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'. + +**Returns** + +* `np.ndarray`: Array of community labels (integers). + +#### compute\_transition\_matrices + +```python +def compute_transition_matrices(files: List[str], labels: List[np.ndarray], + n_clusters: int) -> List[np.ndarray] +``` + +Compute transition matrices for given files and labels. + +**Arguments** + +* **files** (`List[str]`): List of file paths. +* **labels** (`List[np.ndarray]`): List of label arrays. +* **n_clusters** (`int`): Number of clusters. + +**Returns** + +* `List[np.ndarray]:`: List of transition matrices. + +#### create\_cohort\_community\_bag + +```python +def create_cohort_community_bag(motif_labels: List[np.ndarray], + trans_mat_full: np.ndarray, + cut_tree: int | None, n_clusters: int) -> list +``` + +Create cohort community bag for given motif labels, transition matrix, +cut tree, and number of clusters. (markov chain to tree -> community detection) + +**Arguments** + +* **motif_labels** (`List[np.ndarray]`): List of motif label arrays. +* **trans_mat_full** (`np.ndarray`): Full transition matrix. +* **cut_tree** (`int | None`): Cut line for tree. +* **n_clusters** (`int`): Number of clusters. + +**Returns** + +* `List`: List of community bags. + +#### get\_cohort\_community\_labels + +```python +def get_cohort_community_labels( + motif_labels: List[np.ndarray], + cohort_community_bag: list, + median_filter_size: int = 7) -> List[np.ndarray] +``` + +Transform kmeans/hmm parameterized latent vector motifs into communities. +Get cohort community labels for given labels, and community bags. + +**Arguments** + +* **labels** (`List[np.ndarray]`): List of label arrays. +* **cohort_community_bag** (`np.ndarray`): List of community bags. 
Dimensions: (n_communities, n_clusters_in_community) +* **median_filter_size** (`int, optional`): Size of the median filter, in number of frames. Defaults to 7. + +**Returns** + +* `List[np.ndarray]`: List of cohort community labels for each file. + +#### save\_cohort\_community\_labels\_per\_file + +```python +def save_cohort_community_labels_per_file(config: dict, sessions: List[str], + model_name: str, n_clusters: int, + segmentation_algorithm: str, + cohort_community_bag: list) -> None +``` + +#### community + +```python +@save_state(model=CommunityFunctionSchema) +def community(config: dict, + segmentation_algorithm: SegmentationAlgorithms, + cohort: bool = True, + cut_tree: int | None = None, + save_logs: bool = False) -> None +``` + +Perform community analysis. +Fills in the values in the "community" key of the states.json file. +Saves results files at: + +1. If cohort is True: +- project_name/ + - results/ + - community_cohort/ + - segmentation_algorithm-n_clusters/ + - cohort_community_bag.npy + - cohort_community_label.npy + - cohort_segmentation_algorithm_label.npy + - cohort_transition_matrix.npy + - hierarchy.pkl + - file_name/ + - model_name/ + - segmentation_algorithm-n_clusters/ + - community/ + - cohort_community_label_file_name.npy + +2. If cohort is False: +- project_name/ + - results/ + - file_name/ + - model_name/ + - segmentation_algorithm-n_clusters/ + - community/ + - transition_matrix_file_name.npy + - community_label_file_name.npy + - hierarchy_file_name.pkl + +**Arguments** + +* **config** (`dict`): Configuration parameters. +* **segmentation_algorithm** (`SegmentationAlgorithms`): Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'. +* **cohort** (`bool, optional`): Flag indicating cohort analysis. Defaults to True. +* **cut_tree** (`int, optional`): Cut line for tree. Defaults to None. +* **save_logs** (`bool, optional`): Flag indicating whether to save logs. Defaults to False. + +**Returns** + +* `None` + diff --git a/docs/vame-docs-app/docs/reference/analysis/generative_functions.md b/docs/vame-docs-app/docs/reference/analysis/generative_functions.md new file mode 100644 index 00000000..6b500d92 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/analysis/generative_functions.md @@ -0,0 +1,110 @@ +--- +sidebar_label: generative_functions +title: analysis.generative_functions +--- + +#### logger\_config + +#### logger + +#### random\_generative\_samples\_motif + +```python +def random_generative_samples_motif(cfg: dict, model: torch.nn.Module, + latent_vector: np.ndarray, + labels: np.ndarray, + n_clusters: int) -> plt.Figure +``` + +Generate random samples for motifs. + +**Arguments** + +* **cfg** (`dict`): Configuration dictionary. +* **model** (`torch.nn.Module`): PyTorch model. +* **latent_vector** (`np.ndarray`): Latent vectors. +* **labels** (`np.ndarray`): Labels. +* **n_clusters** (`int`): Number of clusters. + +**Returns** + +* `plt.Figure`: Figure of generated samples. + +#### random\_generative\_samples + +```python +def random_generative_samples(cfg: dict, model: torch.nn.Module, + latent_vector: np.ndarray) -> plt.Figure +``` + +Generate random generative samples. + +**Arguments** + +* **cfg** (`dict`): Configuration dictionary. +* **model** (`torch.nn.Module`): PyTorch model. +* **latent_vector** (`np.ndarray`): Latent vectors. + +**Returns** + +* `plt.Figure`: Figure of generated samples. 
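+
+**Example**
+
+A minimal usage sketch, assuming `cfg` is the loaded configuration dictionary and
+`model` is a trained `RNN_VAE` instance; the latent-vector path below is illustrative
+and depends on your project layout.
+
+```python
+import numpy as np
+
+# Latent vectors are produced by segment_session(); this path is hypothetical.
+latent_vector = np.load("results/Session001/VAME/kmeans-15/latent_vector_Session001.npy")
+
+# Decode random draws from the latent space and save the resulting figure.
+fig = random_generative_samples(cfg, model, latent_vector)
+fig.savefig("random_generative_samples.png")
+```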
+
+#### random\_reconstruction\_samples
+
+```python
+def random_reconstruction_samples(cfg: dict, model: torch.nn.Module,
+                                  latent_vector: np.ndarray) -> plt.Figure
+```
+
+Generate random reconstruction samples.
+
+**Arguments**
+
+* **cfg** (`dict`): Configuration dictionary.
+* **model** (`torch.nn.Module`): PyTorch model to use.
+* **latent_vector** (`np.ndarray`): Latent vectors.
+
+**Returns**
+
+* `plt.Figure`: Figure of reconstructed samples.
+
+#### visualize\_cluster\_center
+
+```python
+def visualize_cluster_center(cfg: dict, model: torch.nn.Module,
+                             cluster_center: np.ndarray) -> plt.Figure
+```
+
+Visualize cluster centers.
+
+**Arguments**
+
+* **cfg** (`dict`): Configuration dictionary.
+* **model** (`torch.nn.Module`): PyTorch model.
+* **cluster_center** (`np.ndarray`): Cluster centers.
+
+**Returns**
+
+* `plt.Figure`: Figure of cluster centers.
+
+#### generative\_model
+
+```python
+@save_state(model=GenerativeModelFunctionSchema)
+def generative_model(config: dict,
+                     segmentation_algorithm: SegmentationAlgorithms,
+                     mode: str = "sampling",
+                     save_logs: bool = False) -> plt.Figure
+```
+
+Generate samples from the trained generative model.
+
+**Arguments**
+
+* **config** (`dict`): Configuration dictionary.
+* **segmentation_algorithm** (`SegmentationAlgorithms`): Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'.
+* **mode** (`str, optional`): Mode for generating samples. Defaults to "sampling".
+* **save_logs** (`bool, optional`): Whether to save logs. Defaults to False.
+
+**Returns**
+
+* `plt.Figure`: Plots of generated samples for each segmentation algorithm.
+
diff --git a/docs/vame-docs-app/docs/reference/analysis/gif_creator.md b/docs/vame-docs-app/docs/reference/analysis/gif_creator.md
new file mode 100644
index 00000000..85099c91
--- /dev/null
+++ b/docs/vame-docs-app/docs/reference/analysis/gif_creator.md
@@ -0,0 +1,70 @@
+---
+sidebar_label: gif_creator
+title: analysis.gif_creator
+---
+
+#### logger\_config
+
+#### logger
+
+#### create\_video
+
+```python
+def create_video(path_to_file: str, session: str, embed: np.ndarray,
+                 clabel: np.ndarray, frames: List[np.ndarray], start: int,
+                 length: int, max_lag: int, num_points: int) -> None
+```
+
+Create video frames for the given embedding.
+
+**Arguments**
+
+* **path_to_file** (`str`): Path to the file.
+* **session** (`str`): Session name.
+* **embed** (`np.ndarray`): Embedding array.
+* **clabel** (`np.ndarray`): Cluster labels.
+* **frames** (`List[np.ndarray]`): List of frames.
+* **start** (`int`): Starting index.
+* **length** (`int`): Length of the video.
+* **max_lag** (`int`): Maximum lag.
+* **num_points** (`int`): Number of points.
+
+**Returns**
+
+* `None`
+
+#### gif
+
+```python
+def gif(
+        config: str,
+        pose_ref_index: list,
+        segmentation_algorithm: SegmentationAlgorithms,
+        subtract_background: bool = True,
+        start: int | None = None,
+        length: int = 500,
+        max_lag: int = 30,
+        label: str = "community",
+        file_format: str = ".mp4",
+        crop_size: Tuple[int, int] = (300, 300)) -> None
+```
+
+Create a GIF from the given configuration.
+
+**Arguments**
+
+* **config** (`str`): Path to the configuration file.
+* **pose_ref_index** (`list`): List of reference coordinate indices for alignment.
+* **segmentation_algorithm** (`SegmentationAlgorithms`): Segmentation algorithm.
+* **subtract_background** (`bool, optional`): Whether to subtract background. Defaults to True.
+* **start** (`int, optional`): Starting index. Defaults to None.
+* **length** (`int, optional`): Length of the video. Defaults to 500.
+* **max_lag** (`int, optional`): Maximum lag. Defaults to 30.
+* **label** (`str, optional`): Label type [None, community, motif]. Defaults to 'community'.
+* **file_format** (`str, optional`): File format.
Defaults to '.mp4'. +* **crop_size** (`Tuple[int, int], optional`): Crop size. Defaults to (300,300). + +**Returns** + +* `None` + diff --git a/docs/vame-docs-app/docs/reference/analysis/pose_segmentation.md b/docs/vame-docs-app/docs/reference/analysis/pose_segmentation.md new file mode 100644 index 00000000..038152bb --- /dev/null +++ b/docs/vame-docs-app/docs/reference/analysis/pose_segmentation.md @@ -0,0 +1,139 @@ +--- +sidebar_label: pose_segmentation +title: analysis.pose_segmentation +--- + +#### logger\_config + +#### logger + +#### embedd\_latent\_vectors + +```python +def embedd_latent_vectors( + cfg: dict, + sessions: List[str], + model: RNN_VAE, + fixed: bool, + read_from_variable: str = "position_processed", + tqdm_stream: Union[TqdmToLogger, None] = None) -> List[np.ndarray] +``` + +Embed latent vectors for the given files using the VAME model. + +**Arguments** + +* **cfg** (`dict`): Configuration dictionary. +* **sessions** (`List[str]`): List of session names. +* **model** (`RNN_VAE`): VAME model. +* **fixed** (`bool`): Whether the model is fixed. +* **tqdm_stream** (`TqdmToLogger, optional`): TQDM Stream to redirect the tqdm output to logger. + +**Returns** + +* `List[np.ndarray]`: List of latent vectors for each file. + +#### get\_motif\_usage + +```python +def get_motif_usage(session_labels: np.ndarray, n_clusters: int) -> np.ndarray +``` + +Count motif usage from session label array. + +**Arguments** + +* **session_labels** (`np.ndarray`): Array of session labels. +* **n_clusters** (`int`): Number of clusters. + +**Returns** + +* `np.ndarray`: Array of motif usage counts. + +#### same\_segmentation + +```python +def same_segmentation( + cfg: dict, sessions: List[str], latent_vectors: List[np.ndarray], + n_clusters: int, segmentation_algorithm: str +) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]] +``` + +Apply the same segmentation to all animals. + +**Arguments** + +* **cfg** (`dict`): Configuration dictionary. +* **sessions** (`List[str]`): List of session names. +* **latent_vectors** (`List[np.ndarray]`): List of latent vector arrays. +* **n_clusters** (`int`): Number of clusters. +* **segmentation_algorithm** (`str`): Segmentation algorithm. + +**Returns** + +* `Tuple`: Tuple of labels, cluster centers, and motif usages. + +#### individual\_segmentation + +```python +def individual_segmentation(cfg: dict, sessions: List[str], + latent_vectors: List[np.ndarray], + n_clusters: int) -> Tuple +``` + +Apply individual segmentation to each session. + +**Arguments** + +* **cfg** (`dict`): Configuration dictionary. +* **sessions** (`List[str]`): List of session names. +* **latent_vectors** (`List[np.ndarray]`): List of latent vector arrays. +* **n_clusters** (`int`): Number of clusters. + +**Returns** + +* `Tuple`: Tuple of labels, cluster centers, and motif usages. + +#### segment\_session + +```python +@save_state(model=SegmentSessionFunctionSchema) +def segment_session(config: dict, save_logs: bool = False) -> None +``` + +Perform pose segmentation using the VAME model. +Fills in the values in the "segment_session" key of the states.json file. 
+Creates files at:
+- project_name/
+    - results/
+        - hmm_trained.pkl
+        - session/
+            - model_name/
+                - hmm-n_clusters/
+                    - latent_vector_session.npy
+                    - motif_usage_session.npy
+                    - n_cluster_label_session.npy
+                - kmeans-n_clusters/
+                    - latent_vector_session.npy
+                    - motif_usage_session.npy
+                    - n_cluster_label_session.npy
+                    - cluster_center_session.npy
+
+latent_vector_session.npy contains the projection of the data into the latent space,
+for each frame of the video. Dimensions: (n_frames, n_latent_features)
+
+motif_usage_session.npy contains the number of times each motif was used in the video.
+Dimensions: (n_motifs,)
+
+n_cluster_label_session.npy contains the label of the cluster assigned to each frame.
+Dimensions: (n_frames,)
+
+**Arguments**
+
+* **config** (`dict`): Configuration dictionary.
+* **save_logs** (`bool, optional`): Whether to save logs, by default False.
+
+**Returns**
+
+* `None`
+
diff --git a/docs/vame-docs-app/docs/reference/analysis/tree_hierarchy.md b/docs/vame-docs-app/docs/reference/analysis/tree_hierarchy.md
new file mode 100644
index 00000000..6a3e8790
--- /dev/null
+++ b/docs/vame-docs-app/docs/reference/analysis/tree_hierarchy.md
@@ -0,0 +1,169 @@
+---
+sidebar_label: tree_hierarchy
+title: analysis.tree_hierarchy
+---
+
+#### hierarchy\_pos
+
+```python
+def hierarchy_pos(G: nx.Graph,
+                  root: str | None = None,
+                  width: float = 0.5,
+                  vert_gap: float = 0.2,
+                  vert_loc: float = 0,
+                  xcenter: float = 0.5) -> Dict[str, Tuple[float, float]]
+```
+
+Positions nodes in a tree-like layout.
+Ref: From Joel's answer at https://stackoverflow.com/a/29597209/2966723.
+
+**Arguments**
+
+* **G** (`nx.Graph`): The input graph. Must be a tree.
+* **root** (`str, optional`): The root node of the tree. If None, the function selects a root node based on graph type.
+Defaults to None.
+* **width** (`float, optional`): The horizontal space assigned to each level. Defaults to 0.5.
+* **vert_gap** (`float, optional`): The vertical gap between levels. Defaults to 0.2.
+* **vert_loc** (`float, optional`): The vertical location of the root node. Defaults to 0.
+* **xcenter** (`float, optional`): The horizontal location of the root node. Defaults to 0.5.
+
+**Returns**
+
+* `Dict[str, Tuple[float, float]]`: A dictionary mapping node names to their positions (x, y).
+
+#### merge\_func
+
+```python
+def merge_func(transition_matrix: np.ndarray, n_clusters: int,
+               motif_norm: np.ndarray,
+               merge_sel: int) -> Tuple[np.ndarray, np.ndarray]
+```
+
+Merge nodes in a graph based on a selection criterion.
+
+**Arguments**
+
+* **transition_matrix** (`np.ndarray`): The transition matrix of the graph.
+* **n_clusters** (`int`): The number of clusters.
+* **motif_norm** (`np.ndarray`): The normalized motif matrix.
+* **merge_sel** (`int`): The merge selection criterion.
+- 0: Merge nodes with highest transition probability.
+- 1: Merge nodes with lowest cost.
+
+**Returns**
+
+* `Tuple[np.ndarray, np.ndarray]`: A tuple containing the merged nodes.
+
+#### graph\_to\_tree
+
+```python
+def graph_to_tree(motif_usage: np.ndarray,
+                  transition_matrix: np.ndarray,
+                  n_clusters: int,
+                  merge_sel: int = 1) -> nx.Graph
+```
+
+Convert a graph to a tree.
+
+**Arguments**
+
+* **motif_usage** (`np.ndarray`): The motif usage matrix.
+* **transition_matrix** (`np.ndarray`): The transition matrix of the graph.
+* **n_clusters** (`int`): The number of clusters.
+* **merge_sel** (`int, optional`): The merge selection criterion. Defaults to 1.
+- 0: Merge nodes with highest transition probability.
+- 1: Merge nodes with lowest cost.
+
+**Returns**
+
+* `nx.Graph`: The tree.
+
+#### draw\_tree
+
+```python
+def draw_tree(
+        T: nx.Graph,
+        fig_width: float = 200.0,
+        usage_dict: Dict[str, float] = dict()) -> None
+```
+
+Draw a tree.
+
+**Arguments**
+
+* **T** (`nx.Graph`): The tree to be drawn.
+* **fig_width** (`float, optional`): The width of the figure. Defaults to 200.0.
+* **usage_dict** (`Dict[str, float], optional`): Dictionary of usage values keyed by node name. Defaults to an empty dict.
+
+**Returns**
+
+* `None`
+
+#### \_traverse\_tree\_cutline
+
+```python
+def _traverse_tree_cutline(
+        T: nx.Graph,
+        node: List[str],
+        traverse_list: List[str],
+        cutline: int,
+        level: int,
+        community_bag: List[List[str]],
+        community_list: List[str] = None) -> List[List[str]]
+```
+
+DEPRECATED in favor of bag_nodes_by_cutline.
+Helper function for tree traversal with a cutline.
+
+**Arguments**
+
+* **T** (`nx.Graph`): The tree to be traversed.
+* **node** (`List[str]`): Current node being traversed.
+* **traverse_list** (`List[str]`): List of traversed nodes.
+* **cutline** (`int`): The cutline level.
+* **level** (`int`): The current level in the tree.
+* **community_bag** (`List[List[str]]`): List of community bags.
+* **community_list** (`List[str], optional`): List of nodes in the current community bag.
+
+**Returns**
+
+* `List[List[str]]`: List of community bags.
+
+#### traverse\_tree\_cutline
+
+```python
+def traverse_tree_cutline(T: nx.Graph,
+                          root_node: str | None = None,
+                          cutline: int = 2) -> List[List[str]]
+```
+
+DEPRECATED in favor of bag_nodes_by_cutline.
+Traverse a tree with a cutline and return the community bags.
+
+**Arguments**
+
+* **T** (`nx.Graph`): The tree to be traversed.
+* **root_node** (`str, optional`): The root node of the tree. If None, a root node is selected automatically.
+* **cutline** (`int, optional`): The cutline level.
+
+**Returns**
+
+* `List[List[str]]`: List of community bags.
+
+#### bag\_nodes\_by\_cutline
+
+```python
+def bag_nodes_by_cutline(tree: nx.Graph, cutline: int = 2, root: str = "Root")
+```
+
+Bag nodes of a tree by a cutline.
+
+**Arguments**
+
+* **tree** (`nx.Graph`): The tree to be bagged.
+* **cutline** (`int, optional`): The cutline level. Defaults to 2.
+* **root** (`str, optional`): The root node of the tree. Defaults to 'Root'.
+
+**Returns**
+
+* `List[List[str]]`: List of bags of nodes.
+
diff --git a/docs/vame-docs-app/docs/reference/analysis/umap.md b/docs/vame-docs-app/docs/reference/analysis/umap.md
new file mode 100644
index 00000000..6b306602
--- /dev/null
+++ b/docs/vame-docs-app/docs/reference/analysis/umap.md
@@ -0,0 +1,123 @@
+---
+sidebar_label: umap
+title: analysis.umap
+---
+
+#### logger\_config
+
+#### logger
+
+#### umap\_embedding
+
+```python
+def umap_embedding(
+        cfg: dict, session: str, model_name: str, n_clusters: int,
+        segmentation_algorithm: SegmentationAlgorithms) -> np.ndarray
+```
+
+Perform UMAP embedding for the given session and parameters.
+
+**Arguments**
+
+* **cfg** (`dict`): Configuration parameters.
+* **session** (`str`): Session name.
+* **model_name** (`str`): Model name.
+* **n_clusters** (`int`): Number of clusters.
+* **segmentation_algorithm** (`str`): Segmentation algorithm.
+
+**Returns**
+
+* `np.ndarray`: UMAP embedding.
+
+#### umap\_vis
+
+```python
+def umap_vis(embed: np.ndarray, num_points: int) -> plt.Figure
+```
+
+Visualize UMAP embedding without labels.
+
+**Arguments**
+
+* **embed** (`np.ndarray`): UMAP embedding.
+* **num_points** (`int`): Number of data points to visualize.
+
+**Returns**
+
+* `plt.Figure`: Figure of the UMAP embedding visualization.
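+
+**Example**
+
+A minimal sketch of plotting a saved embedding; the `.npy` path is illustrative
+(embeddings are written by the visualization() function documented below).
+
+```python
+import numpy as np
+
+embed = np.load("results/Session001/VAME/kmeans-15/community/umap_embedding_Session001.npy")
+# Cap the number of plotted points to keep the scatter readable.
+fig = umap_vis(embed=embed, num_points=min(30000, embed.shape[0]))
+fig.savefig("umap_vis_label_none.png")
+```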
+ +#### umap\_label\_vis + +```python +def umap_label_vis(embed: np.ndarray, label: np.ndarray, + num_points: int) -> plt.Figure +``` + +Visualize UMAP embedding with motif labels. + +**Arguments** + +* **embed** (`np.ndarray`): UMAP embedding. +* **label** (`np.ndarray`): Motif labels. +* **num_points** (`int`): Number of data points to visualize. + +**Returns** + +* `plt.Figure`: Plot figure of UMAP visualization embedding with motif labels. + +#### umap\_vis\_comm + +```python +def umap_vis_comm(embed: np.ndarray, community_label: np.ndarray, + num_points: int) -> plt.Figure +``` + +Visualize UMAP embedding with community labels. + +**Arguments** + +* **embed** (`np.ndarray`): UMAP embedding. +* **community_label** (`np.ndarray`): Community labels. +* **num_points** (`int`): Number of data points to visualize. + +**Returns** + +* `plt.Figure`: Plot figure of UMAP visualization embedding with community labels. + +#### visualization + +```python +@save_state(model=VisualizationFunctionSchema) +def visualization(config: dict, + segmentation_algorithm: SegmentationAlgorithms, + label: Optional[str] = None, + save_logs: bool = False) -> None +``` + +Visualize UMAP embeddings based on configuration settings. +Fills in the values in the "visualization" key of the states.json file. +Saves results files at: + +If label is None (UMAP visualization without labels): +- project_name/ + - results/ + - file_name/ + - model_name/ + - segmentation_algorithm-n_clusters/ + - community/ + - umap_embedding_file_name.npy + - umap_vis_label_none_file_name.png (UMAP visualization without labels) + - umap_vis_motif_file_name.png (UMAP visualization with motif labels) + - umap_vis_community_file_name.png (UMAP visualization with community labels) + +**Arguments** + +* **config** (`dict`): Configuration parameters. +* **segmentation_algorithm** (`SegmentationAlgorithms`): Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'. +* **label** (`str, optional`): Type of labels to visualize. Options are None, 'motif' or 'community'. Default is None. +* **save_logs** (`bool, optional`): Save logs to file. Default is False. + +**Returns** + +* `None` + diff --git a/docs/vame-docs-app/docs/reference/analysis/videowriter.md b/docs/vame-docs-app/docs/reference/analysis/videowriter.md new file mode 100644 index 00000000..fa570458 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/analysis/videowriter.md @@ -0,0 +1,123 @@ +--- +sidebar_label: videowriter +title: analysis.videowriter +--- + +#### logger\_config + +#### logger + +#### create\_cluster\_videos + +```python +def create_cluster_videos( + config: dict, + path_to_file: str, + session: str, + n_clusters: int, + video_type: str, + flag: str, + segmentation_algorithm: SegmentationAlgorithms, + cohort: bool = True, + output_video_type: str = ".mp4", + tqdm_logger_stream: Union[TqdmToLogger, None] = None) -> None +``` + +Generate cluster videos and save them to filesystem on project folder. + +**Arguments** + +* **config** (`dict`): Configuration parameters. +* **path_to_file** (`str`): Path to the file. +* **session** (`str`): Name of the session. +* **n_clusters** (`int`): Number of clusters. +* **video_type** (`str`): Type of input video. +* **flag** (`str`): Flag indicating the type of video (motif or community). +* **segmentation_algorithm** (`SegmentationAlgorithms`): Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'. +* **cohort** (`bool, optional`): Flag indicating cohort analysis. Defaults to True. 
+* **output_video_type** (`str, optional`): Type of output video. Default is '.mp4'. +* **tqdm_logger_stream** (`TqdmToLogger, optional`): Tqdm logger stream. Default is None. + +**Returns** + +* `None` + +#### motif\_videos + +```python +@save_state(model=MotifVideosFunctionSchema) +def motif_videos(config: dict, + segmentation_algorithm: SegmentationAlgorithms, + video_type: str = ".mp4", + output_video_type: str = ".mp4", + save_logs: bool = False) -> None +``` + +Generate motif videos and save them to filesystem. +Fills in the values in the "motif_videos" key of the states.json file. +Files are saved at: +- project_name/ + - results/ + - session/ + - model_name/ + - segmentation_algorithm-n_clusters/ + - cluster_videos/ + - session-motif_0.mp4 + - session-motif_1.mp4 + - ... + +**Arguments** + +* **config** (`dict`): Configuration parameters. +* **segmentation_algorithm** (`SegmentationAlgorithms`): Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'. +* **video_type** (`str, optional`): Type of video. Default is '.mp4'. +* **output_video_type** (`str, optional`): Type of output video. Default is '.mp4'. +* **save_logs** (`bool, optional`): Save logs to filesystem. Default is False. + +**Returns** + +* `None` + +#### community\_videos + +```python +@save_state(model=CommunityVideosFunctionSchema) +def community_videos(config: dict, + segmentation_algorithm: SegmentationAlgorithms, + cohort: bool = True, + video_type: str = ".mp4", + save_logs: bool = False, + output_video_type: str = ".mp4") -> None +``` + +Generate community videos and save them to filesystem on project community_videos folder. +Fills in the values in the "community_videos" key of the states.json file. +Files are saved at: + +1. If cohort is True: +TODO: Add cohort analysis + +2. If cohort is False: +- project_name/ + - results/ + - file_name/ + - model_name/ + - segmentation_algorithm-n_clusters/ + - community_videos/ + - file_name-community_0.mp4 + - file_name-community_1.mp4 + - ... + +**Arguments** + +* **config** (`dict`): Configuration parameters. +* **segmentation_algorithm** (`SegmentationAlgorithms`): Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'. +* **cohort** (`bool, optional`): Flag indicating cohort analysis. Defaults to True. +* **video_type** (`str, optional`): Type of video. Default is '.mp4'. +* **save_logs** (`bool, optional`): Save logs to filesystem. Default is False. +* **output_video_type** (`str, optional`): Type of output video. Default is '.mp4'. 
+ +**Returns** + +* `None` + diff --git a/docs/vame-docs-app/docs/reference/initialize_project/__init__.md b/docs/vame-docs-app/docs/reference/initialize_project/__init__.md new file mode 100644 index 00000000..193558b0 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/initialize_project/__init__.md @@ -0,0 +1,5 @@ +--- +sidebar_label: initialize_project +title: initialize_project +--- + diff --git a/docs/vame-docs-app/docs/reference/initialize_project/new.md b/docs/vame-docs-app/docs/reference/initialize_project/new.md new file mode 100644 index 00000000..5c62c0d6 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/initialize_project/new.md @@ -0,0 +1,66 @@ +--- +sidebar_label: new +title: initialize_project.new +--- + +#### logger\_config + +#### logger + +#### init\_new\_project + +```python +def init_new_project(project_name: str, + videos: List[str], + poses_estimations: List[str], + source_software: Literal["DeepLabCut", "SLEAP", + "LightningPose"], + working_directory: str = ".", + video_type: str = ".mp4", + fps: int | None = None, + copy_videos: bool = False, + paths_to_pose_nwb_series_data: Optional[str] = None, + config_kwargs: Optional[dict] = None) -> Tuple[str, dict] +``` + +Creates a new VAME project with the given parameters. +A VAME project is a directory with the following structure: +- project_name/ + - data/ + - raw/ + - session1.mp4 + - session1.nc + - session2.mp4 + - session2.nc + - ... + - processed/ + - session1_processed.nc + - session2_processed.nc + - ... + - model/ + - pretrained_model/ + - results/ + - video1/ + - video2/ + - ... + - states/ + - states.json + - config.yaml + +**Arguments** + +* **project_name** (`str`): Project name. +* **videos** (`List[str]`): List of videos paths to be used in the project. E.g. ['./sample_data/Session001.mp4'] +* **poses_estimations** (`List[str]`): List of pose estimation files paths to be used in the project. E.g. ['./sample_data/pose estimation/Session001.csv'] +* **source_software** (`Literal["DeepLabCut", "SLEAP", "LightningPose"]`): Source software used for pose estimation. +* **working_directory** (`str, optional`): Working directory. Defaults to '.'. +* **video_type** (`str, optional`): Video extension (.mp4 or .avi). Defaults to '.mp4'. +* **fps** (`int, optional`): Sampling rate of the video. If not passed, it will be estimated from the video file. Defaults to None. +* **copy_videos** (`bool, optional`): If True, the videos will be copied to the project directory. If False, symbolic links will be created instead. Defaults to False. +* **paths_to_pose_nwb_series_data** (`Optional[str], optional`): List of paths to the pose series data in nwb files. Defaults to None. +* **config_kwargs** (`Optional[dict], optional`): Additional configuration parameters. Defaults to None. + +**Returns** + +* `Tuple[str, dict]`: Tuple containing the path to the config file and the config data. 
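+
+**Example**
+
+A minimal sketch of creating a project; the file paths are illustrative.
+
+```python
+from vame.initialize_project.new import init_new_project
+
+config_path, config = init_new_project(
+    project_name="my_vame_project",
+    videos=["./sample_data/Session001.mp4"],
+    poses_estimations=["./sample_data/pose_estimation/Session001.csv"],
+    source_software="DeepLabCut",
+    working_directory=".",
+    video_type=".mp4",
+)
+print(config_path)  # path to the newly created config.yaml
+```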
+
diff --git a/docs/vame-docs-app/docs/reference/io/__init__.md b/docs/vame-docs-app/docs/reference/io/__init__.md
new file mode 100644
index 00000000..26c3b3a8
--- /dev/null
+++ b/docs/vame-docs-app/docs/reference/io/__init__.md
@@ -0,0 +1,5 @@
+---
+sidebar_label: io
+title: io
+---
+
diff --git a/docs/vame-docs-app/docs/reference/io/load_poses.md b/docs/vame-docs-app/docs/reference/io/load_poses.md
new file mode 100644
index 00000000..7fc29412
--- /dev/null
+++ b/docs/vame-docs-app/docs/reference/io/load_poses.md
@@ -0,0 +1,71 @@
+---
+sidebar_label: load_poses
+title: io.load_poses
+---
+
+#### load\_pose\_estimation
+
+```python
+def load_pose_estimation(
+        pose_estimation_file: Path | str, video_file: Path | str, fps: int,
+        source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"]
+) -> xr.Dataset
+```
+
+Load pose estimation data.
+
+**Arguments**
+
+* **pose_estimation_file** (`Path or str`): Path to the pose estimation file.
+* **video_file** (`Path or str`): Path to the video file.
+* **fps** (`int`): Sampling rate of the video.
+* **source_software** (`Literal["DeepLabCut", "SLEAP", "LightningPose"]`): Source software used for pose estimation.
+
+**Returns**
+
+* **ds** (`xarray.Dataset`): Pose estimation dataset.
+
+#### load\_vame\_dataset
+
+```python
+def load_vame_dataset(ds_path: Path | str) -> xr.Dataset
+```
+
+Load VAME dataset.
+
+**Arguments**
+
+* **ds_path** (`Path or str`): Path to the netCDF dataset.
+
+**Returns**
+
+* `xr.Dataset`: VAME dataset.
+
+#### nc\_to\_dataframe
+
+```python
+def nc_to_dataframe(nc_data)
+```
+
+#### read\_pose\_estimation\_file
+
+```python
+def read_pose_estimation_file(
+        file_path: str,
+        file_type: Optional[PoseEstimationFiletype] = None,
+        path_to_pose_nwb_series_data: Optional[str] = None
+) -> Tuple[pd.DataFrame, np.ndarray, xr.Dataset]
+```
+
+Read pose estimation file.
+
+**Arguments**
+
+* **file_path** (`str`): Path to the pose estimation file.
+* **file_type** (`PoseEstimationFiletype`): Type of the pose estimation file. Supported types are 'csv' and 'nwb'.
+* **path_to_pose_nwb_series_data** (`str, optional`): Path to the pose data inside the nwb file, by default None
+
+**Returns**
+
+* `Tuple[pd.DataFrame, np.ndarray, xr.Dataset]`: Tuple containing the pose estimation data as a pandas DataFrame, a numpy array, and an xarray Dataset.
+
diff --git a/docs/vame-docs-app/docs/reference/io/nwb.md b/docs/vame-docs-app/docs/reference/io/nwb.md
new file mode 100644
index 00000000..2096edc6
--- /dev/null
+++ b/docs/vame-docs-app/docs/reference/io/nwb.md
@@ -0,0 +1,5 @@
+---
+sidebar_label: nwb
+title: io.nwb
+---
+
diff --git a/docs/vame-docs-app/docs/reference/model/__init__.md b/docs/vame-docs-app/docs/reference/model/__init__.md
new file mode 100644
index 00000000..14b9c8f3
--- /dev/null
+++ b/docs/vame-docs-app/docs/reference/model/__init__.md
@@ -0,0 +1,5 @@
+---
+sidebar_label: model
+title: model
+---
+
diff --git a/docs/vame-docs-app/docs/reference/model/create_training.md b/docs/vame-docs-app/docs/reference/model/create_training.md
new file mode 100644
index 00000000..fd9e8e24
--- /dev/null
+++ b/docs/vame-docs-app/docs/reference/model/create_training.md
@@ -0,0 +1,67 @@
+---
+sidebar_label: create_training
+title: model.create_training
+---
+
+#### logger\_config
+
+#### logger
+
+#### traindata\_aligned
+
+```python
+def traindata_aligned(config: dict,
+                      sessions: List[str] | None = None,
+                      test_fraction: float | None = None,
+                      read_from_variable: str = "position_processed") -> None
+```
+
+Create training dataset for aligned data.
+Save numpy arrays with the test/train info to the project folder.
+
+**Arguments**
+
+* **config** (`dict`): Configuration parameters dictionary.
+* **sessions** (`List[str], optional`): List of session names. If None, all sessions will be used. Defaults to None.
+* **test_fraction** (`float, optional`): Fraction of data to use as test data. If None, defaults to 0.1.
+* **read_from_variable** (`str, optional`): Variable to read from the dataset. Defaults to 'position_processed'.
+
+**Returns**
+
+* `None`
+
+#### create\_trainset
+
+```python
+@save_state(model=CreateTrainsetFunctionSchema)
+def create_trainset(config: dict, save_logs: bool = False) -> None
+```
+
+Creates training and test datasets for the VAME model.
+Fills in the values in the "create_trainset" key of the states.json file.
+Creates the training dataset for VAME at:
+- project_name/
+    - data/
+        - session00/
+            - session00-PE-seq-clean.npy
+        - session01/
+            - session01-PE-seq-clean.npy
+        - train/
+            - test_seq.npy
+            - train_seq.npy
+
+The produced -clean.npy files contain the aligned time series data in the
+shape of (num_dlc_features - 2, num_video_frames).
+
+The produced test_seq.npy contains the combined data in the shape of (num_dlc_features - 2, num_video_frames * test_fraction).
+
+The produced train_seq.npy contains the combined data in the shape of (num_dlc_features - 2, num_video_frames * (1 - test_fraction)).
+
+**Arguments**
+
+* **config** (`dict`): Configuration parameters dictionary.
+* **save_logs** (`bool, optional`): If True, the function will save logs to the project folder. Defaults to False.
+
+**Returns**
+
+* `None`
+
diff --git a/docs/vame-docs-app/docs/reference/model/dataloader.md b/docs/vame-docs-app/docs/reference/model/dataloader.md
new file mode 100644
index 00000000..bce649f9
--- /dev/null
+++ b/docs/vame-docs-app/docs/reference/model/dataloader.md
@@ -0,0 +1,65 @@
+---
+sidebar_label: dataloader
+title: model.dataloader
+---
+
+## SEQUENCE\_DATASET Objects
+
+```python
+class SEQUENCE_DATASET(Dataset)
+```
+
+#### \_\_init\_\_
+
+```python
+def __init__(path_to_file: str, data: str, train: bool, temporal_window: int,
+             **kwargs) -> None
+```
+
+Initialize the Sequence Dataset.
+Creates files at:
+- project_name/
+    - data/
+        - train/
+            - seq_mean.npy
+            - seq_std.npy
+
+**Arguments**
+
+* **path_to_file** (`str`): Path to the dataset files.
+* **data** (`str`): Name of the data file.
+* **train** (`bool`): Flag indicating whether it's training data.
+* **temporal_window** (`int`): Size of the temporal window.
+
+**Returns**
+
+* `None`
+
+#### \_\_len\_\_
+
+```python
+def __len__() -> int
+```
+
+Return the number of data points.
+
+**Returns**
+
+* `int`: Number of data points.
+
+#### \_\_getitem\_\_
+
+```python
+def __getitem__(index: int) -> torch.Tensor
+```
+
+Get a normalized sequence at the specified index.
+
+**Arguments**
+
+* **index** (`int`): Index of the item.
+
+**Returns**
+
+* `torch.Tensor`: Normalized sequence data at the specified index.
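+
+**Example**
+
+A minimal sketch of wrapping the dataset in a PyTorch DataLoader; the path and
+`temporal_window` value are illustrative (`train_seq.npy` is produced by create_trainset()).
+
+```python
+from torch.utils.data import DataLoader
+
+trainset = SEQUENCE_DATASET(
+    path_to_file="my_vame_project/data/train/",
+    data="train_seq.npy",
+    train=True,
+    temporal_window=60,
+)
+train_loader = DataLoader(trainset, batch_size=256, shuffle=True, drop_last=True)
+batch = next(iter(train_loader))  # one batch of normalized sequence windows
+```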
+ diff --git a/docs/vame-docs-app/docs/reference/model/evaluate.md b/docs/vame-docs-app/docs/reference/model/evaluate.md new file mode 100644 index 00000000..28f7e972 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/model/evaluate.md @@ -0,0 +1,121 @@ +--- +sidebar_label: evaluate +title: model.evaluate +--- + +#### logger\_config + +#### logger + +#### use\_gpu + +#### plot\_reconstruction + +```python +def plot_reconstruction(filepath: str, + test_loader: Data.DataLoader, + seq_len_half: int, + model: RNN_VAE, + model_name: str, + FUTURE_DECODER: bool, + FUTURE_STEPS: int, + suffix: Optional[str] = None) -> None +``` + +Plot the reconstruction and future prediction of the input sequence. +Saves the plot to: +- project_name/ + - model/ + - evaluate/ + - Reconstruction_model_name.png + +**Arguments** + +* **filepath** (`str`): Path to save the plot. +* **test_loader** (`Data.DataLoader`): DataLoader for the test dataset. +* **seq_len_half** (`int`): Half of the temporal window size. +* **model** (`RNN_VAE`): Trained VAE model. +* **model_name** (`str`): Name of the model. +* **FUTURE_DECODER** (`bool`): Flag indicating whether the model has a future prediction decoder. +* **FUTURE_STEPS** (`int`): Number of future steps to predict. +* **suffix** (`str, optional`): Suffix for the saved plot filename. Defaults to None. + +**Returns** + +* `None` + +#### plot\_loss + +```python +def plot_loss(cfg: dict, filepath: str, model_name: str) -> None +``` + +Plot the losses of the trained model. +Saves the plot to: +- project_name/ + - model/ + - evaluate/ + - MSE-and-KL-Loss_model_name.png + +**Arguments** + +* **cfg** (`dict`): Configuration dictionary. +* **filepath** (`str`): Path to save the plot. +* **model_name** (`str`): Name of the model. + +**Returns** + +* `None` + +#### eval\_temporal + +```python +def eval_temporal(cfg: dict, + use_gpu: bool, + model_name: str, + fixed: bool, + snapshot: Optional[str] = None, + suffix: Optional[str] = None) -> None +``` + +Evaluate the temporal aspects of the trained model. + +**Arguments** + +* **cfg** (`dict`): Configuration dictionary. +* **use_gpu** (`bool`): Flag indicating whether to use GPU for evaluation. +* **model_name** (`str`): Name of the model. +* **fixed** (`bool`): Flag indicating whether the data is fixed or not. +* **snapshot** (`str, optional`): Path to the model snapshot. Defaults to None. +* **suffix** (`str, optional`): Suffix for the saved plot filename. Defaults to None. + +**Returns** + +* `None` + +#### evaluate\_model + +```python +@save_state(model=EvaluateModelFunctionSchema) +def evaluate_model(config: dict, + use_snapshots: bool = False, + save_logs: bool = False) -> None +``` + +Evaluate the trained model. +Fills in the values in the "evaluate_model" key of the states.json file. +Saves the evaluation results to: +- project_name/ + - model/ + - evaluate/ + +**Arguments** + +* **config** (`dict`): Configuration dictionary. +* **use_snapshots** (`bool, optional`): Whether to plot for all snapshots or only the best model. Defaults to False. +* **save_logs** (`bool, optional`): Flag indicating whether to save logs. Defaults to False. 
+ +**Returns** + +* `None` + diff --git a/docs/vame-docs-app/docs/reference/model/rnn_model.md b/docs/vame-docs-app/docs/reference/model/rnn_model.md new file mode 100644 index 00000000..9c6b9773 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/model/rnn_model.md @@ -0,0 +1,230 @@ +--- +sidebar_label: rnn_model +title: model.rnn_model +--- + +## Encoder Objects + +```python +class Encoder(nn.Module) +``` + +Encoder module of the Variational Autoencoder. + +#### \_\_init\_\_ + +```python +def __init__(NUM_FEATURES: int, hidden_size_layer_1: int, + hidden_size_layer_2: int, dropout_encoder: float) +``` + +Initialize the Encoder module. + +**Arguments** + +* **NUM_FEATURES** (`int`): Number of input features. +* **hidden_size_layer_1** (`int`): Size of the first hidden layer. +* **hidden_size_layer_2** (`int`): Size of the second hidden layer. +* **dropout_encoder** (`float`): Dropout rate for regularization. + +#### forward + +```python +def forward(inputs: torch.Tensor) -> torch.Tensor +``` + +Forward pass of the Encoder module. + +**Arguments** + +* **inputs** (`torch.Tensor`): Input tensor of shape (batch_size, sequence_length, num_features). + +**Returns** + +* `torch.Tensor:`: Encoded representation tensor of shape (batch_size, hidden_size_layer_1 * 4). + +## Lambda Objects + +```python +class Lambda(nn.Module) +``` + +Lambda module for computing the latent space parameters. + +#### \_\_init\_\_ + +```python +def __init__(ZDIMS: int, hidden_size_layer_1: int, softplus: bool) +``` + +Initialize the Lambda module. + +**Arguments** + +* **ZDIMS** (`int`): Size of the latent space. +* **hidden_size_layer_1** (`int`): Size of the first hidden layer. +* **hidden_size_layer_2** (`int, deprecated`): Size of the second hidden layer. +* **softplus** (`bool`): Whether to use softplus activation for logvar. + +#### forward + +```python +def forward( + hidden: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] +``` + +Forward pass of the Lambda module. + +**Arguments** + +* **hidden** (`torch.Tensor`): Hidden representation tensor of shape (batch_size, hidden_size_layer_1 * 4). + +**Returns** + +* `tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: Latent space tensor, mean tensor, logvar tensor. + +## Decoder Objects + +```python +class Decoder(nn.Module) +``` + +Decoder module of the Variational Autoencoder. + +#### \_\_init\_\_ + +```python +def __init__(TEMPORAL_WINDOW: int, ZDIMS: int, NUM_FEATURES: int, + hidden_size_rec: int, dropout_rec: float) +``` + +Initialize the Decoder module. + +**Arguments** + +* **TEMPORAL_WINDOW** (`int`): Size of the temporal window. +* **ZDIMS** (`int`): Size of the latent space. +* **NUM_FEATURES** (`int`): Number of input features. +* **hidden_size_rec** (`int`): Size of the recurrent hidden layer. +* **dropout_rec** (`float`): Dropout rate for regularization. + +#### forward + +```python +def forward(inputs: torch.Tensor, z: torch.Tensor) -> torch.Tensor +``` + +Forward pass of the Decoder module. + +**Arguments** + +* **inputs** (`torch.Tensor`): Input tensor of shape (batch_size, seq_len, ZDIMS). +* **z** (`torch.Tensor`): Latent space tensor of shape (batch_size, ZDIMS). + +**Returns** + +* `torch.Tensor:`: Decoded output tensor of shape (batch_size, seq_len, NUM_FEATURES). + +## Decoder\_Future Objects + +```python +class Decoder_Future(nn.Module) +``` + +Decoder module for predicting future sequences. 
+
+#### \_\_init\_\_
+
+```python
+def __init__(TEMPORAL_WINDOW: int, ZDIMS: int, NUM_FEATURES: int,
+             FUTURE_STEPS: int, hidden_size_pred: int, dropout_pred: float)
+```
+
+Initialize the Decoder_Future module.
+
+**Arguments**
+
+* **TEMPORAL_WINDOW** (`int`): Size of the temporal window.
+* **ZDIMS** (`int`): Size of the latent space.
+* **NUM_FEATURES** (`int`): Number of input features.
+* **FUTURE_STEPS** (`int`): Number of future steps to predict.
+* **hidden_size_pred** (`int`): Size of the prediction hidden layer.
+* **dropout_pred** (`float`): Dropout rate for regularization.
+
+#### forward
+
+```python
+def forward(inputs: torch.Tensor, z: torch.Tensor) -> torch.Tensor
+```
+
+Forward pass of the Decoder_Future module.
+
+**Arguments**
+
+* **inputs** (`torch.Tensor`): Input tensor of shape (batch_size, seq_len, ZDIMS).
+* **z** (`torch.Tensor`): Latent space tensor of shape (batch_size, ZDIMS).
+
+**Returns**
+
+* `torch.Tensor`: Predicted future tensor of shape (batch_size, FUTURE_STEPS, NUM_FEATURES).
+
+## RNN\_VAE Objects
+
+```python
+class RNN_VAE(nn.Module)
+```
+
+Variational Autoencoder module.
+
+#### \_\_init\_\_
+
+```python
+def __init__(TEMPORAL_WINDOW: int, ZDIMS: int, NUM_FEATURES: int,
+             FUTURE_DECODER: bool, FUTURE_STEPS: int, hidden_size_layer_1: int,
+             hidden_size_layer_2: int, hidden_size_rec: int,
+             hidden_size_pred: int, dropout_encoder: float, dropout_rec: float,
+             dropout_pred: float, softplus: bool)
+```
+
+Initialize the VAE module.
+
+**Arguments**
+
+* **TEMPORAL_WINDOW** (`int`): Size of the temporal window.
+* **ZDIMS** (`int`): Size of the latent space.
+* **NUM_FEATURES** (`int`): Number of input features.
+* **FUTURE_DECODER** (`bool`): Whether to include a future decoder.
+* **FUTURE_STEPS** (`int`): Number of future steps to predict.
+* **hidden_size_layer_1** (`int`): Size of the first hidden layer.
+* **hidden_size_layer_2** (`int`): Size of the second hidden layer.
+* **hidden_size_rec** (`int`): Size of the recurrent hidden layer.
+* **hidden_size_pred** (`int`): Size of the prediction hidden layer.
+* **dropout_encoder** (`float`): Dropout rate for the encoder.
+* **dropout_rec** (`float`): Dropout rate for the reconstruction decoder.
+* **dropout_pred** (`float`): Dropout rate for the future prediction decoder.
+* **softplus** (`bool`): Whether to use softplus activation for logvar.
+
+#### forward
+
+```python
+def forward(seq: torch.Tensor) -> tuple
+```
+
+Forward pass of the VAE.
+
+**Arguments**
+
+* **seq** (`torch.Tensor`): Input sequence tensor of shape (batch_size, seq_len, NUM_FEATURES).
+
+**Returns**
+
+* `tuple`: Tuple containing:
+    - If FUTURE_DECODER is True:
+        - prediction (torch.Tensor): Reconstructed input sequence tensor.
+        - future (torch.Tensor): Predicted future sequence tensor.
+        - z (torch.Tensor): Latent representation tensor.
+        - mu (torch.Tensor): Mean of the latent distribution tensor.
+        - logvar (torch.Tensor): Log variance of the latent distribution tensor.
+    - If FUTURE_DECODER is False:
+        - prediction (torch.Tensor): Reconstructed input sequence tensor.
+        - z (torch.Tensor): Latent representation tensor.
+        - mu (torch.Tensor): Mean of the latent distribution tensor.
+        - logvar (torch.Tensor): Log variance of the latent distribution tensor.
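+
+**Example**
+
+A minimal instantiation sketch with illustrative hyperparameters; real values come from
+config.yaml. The input window length is assumed here to be TEMPORAL_WINDOW // 2, matching
+how the training loop windows the data.
+
+```python
+import torch
+
+model = RNN_VAE(
+    TEMPORAL_WINDOW=60,
+    ZDIMS=30,
+    NUM_FEATURES=12,
+    FUTURE_DECODER=False,
+    FUTURE_STEPS=15,
+    hidden_size_layer_1=256,
+    hidden_size_layer_2=256,
+    hidden_size_rec=256,
+    hidden_size_pred=256,
+    dropout_encoder=0.0,
+    dropout_rec=0.0,
+    dropout_pred=0.0,
+    softplus=False,
+)
+seq = torch.randn(8, 30, 12)  # (batch_size, seq_len, NUM_FEATURES)
+prediction, z, mu, logvar = model(seq)  # four outputs when FUTURE_DECODER is False
+```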
+ diff --git a/docs/vame-docs-app/docs/reference/model/rnn_vae.md b/docs/vame-docs-app/docs/reference/model/rnn_vae.md new file mode 100644 index 00000000..4b308601 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/model/rnn_vae.md @@ -0,0 +1,240 @@ +--- +sidebar_label: rnn_vae +title: model.rnn_vae +--- + +#### logger\_config + +#### logger + +#### tqdm\_to\_logger + +#### use\_gpu + +#### reconstruction\_loss + +```python +def reconstruction_loss(x: torch.Tensor, x_tilde: torch.Tensor, + reduction: str) -> torch.Tensor +``` + +Compute the reconstruction loss between input and reconstructed data. + +**Arguments** + +* **x** (`torch.Tensor`): Input data tensor. +* **x_tilde** (`torch.Tensor`): Reconstructed data tensor. +* **reduction** (`str`): Type of reduction for the loss. + +**Returns** + +* `torch.Tensor`: Reconstruction loss. + +#### future\_reconstruction\_loss + +```python +def future_reconstruction_loss(x: torch.Tensor, x_tilde: torch.Tensor, + reduction: str) -> torch.Tensor +``` + +Compute the future reconstruction loss between input and predicted future data. + +**Arguments** + +* **x** (`torch.Tensor`): Input future data tensor. +* **x_tilde** (`torch.Tensor`): Reconstructed future data tensor. +* **reduction** (`str`): Type of reduction for the loss. + +**Returns** + +* `torch.Tensor`: Future reconstruction loss. + +#### cluster\_loss + +```python +def cluster_loss(H: torch.Tensor, kloss: int, lmbda: float, + batch_size: int) -> torch.Tensor +``` + +Compute the cluster loss. + +**Arguments** + +* **H** (`torch.Tensor`): Latent representation tensor. +* **kloss** (`int`): Number of clusters. +* **lmbda** (`float`): Lambda value for the loss. +* **batch_size** (`int`): Size of the batch. + +**Returns** + +* `torch.Tensor`: Cluster loss. + +#### kullback\_leibler\_loss + +```python +def kullback_leibler_loss(mu: torch.Tensor, + logvar: torch.Tensor) -> torch.Tensor +``` + +Compute the Kullback-Leibler divergence loss. +See Appendix B from VAE paper: Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 - https://arxiv.org/abs/1312.6114 + +Formula: 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + +**Arguments** + +* **mu** (`torch.Tensor`): Mean of the latent distribution. +* **logvar** (`torch.Tensor`): Log variance of the latent distribution. + +**Returns** + +* `torch.Tensor`: Kullback-Leibler divergence loss. + +#### kl\_annealing + +```python +def kl_annealing(epoch: int, kl_start: int, annealtime: int, + function: str) -> float +``` + +Anneal the Kullback-Leibler loss to let the model learn first the reconstruction of the data +before the KL loss term gets introduced. + +**Arguments** + +* **epoch** (`int`): Current epoch number. +* **kl_start** (`int`): Epoch number to start annealing the loss. +* **annealtime** (`int`): Annealing time. +* **function** (`str`): Annealing function type. + +**Returns** + +* `float`: Annealed weight value for the loss. + +#### gaussian + +```python +def gaussian(ins: torch.Tensor, + is_training: bool, + seq_len: int, + std_n: float = 0.8) -> torch.Tensor +``` + +Add Gaussian noise to the input data. + +**Arguments** + +* **ins** (`torch.Tensor`): Input data tensor. +* **is_training** (`bool`): Whether it is training mode. +* **seq_len** (`int`): Length of the sequence. +* **std_n** (`float`): Standard deviation for the Gaussian noise. + +**Returns** + +* `torch.Tensor`: Noisy input data tensor. 
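+
+**Example**
+
+As a quick sanity check on the KL term above: for a posterior that already matches the
+standard normal prior (mu = 0, logvar = 0), the stated formula evaluates to zero, up to
+the sign and averaging conventions used in the implementation.
+
+```python
+import torch
+
+mu = torch.zeros(4, 30)      # batch of 4 samples, 30 latent dimensions
+logvar = torch.zeros(4, 30)  # log(sigma^2) = 0, i.e. sigma = 1 everywhere
+print(kullback_leibler_loss(mu, logvar))  # expected: tensor(0.)
+```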
+
+#### train
+
+```python
+def train(train_loader: Data.DataLoader, epoch: int, model: nn.Module,
+          optimizer: torch.optim.Optimizer, anneal_function: str, BETA: float,
+          kl_start: int, annealtime: int, seq_len: int, future_decoder: bool,
+          future_steps: int, scheduler: torch.optim.lr_scheduler._LRScheduler,
+          mse_red: str, mse_pred: str, kloss: int, klmbda: float, bsize: int,
+          noise: bool) -> Tuple[float, float, float, float, float, float]
+```
+
+Train the model.
+
+**Arguments**
+
+* **train_loader** (`DataLoader`): Training data loader.
+* **epoch** (`int`): Current epoch number.
+* **model** (`nn.Module`): Model to be trained.
+* **optimizer** (`Optimizer`): Optimizer for training.
+* **anneal_function** (`str`): Annealing function type.
+* **BETA** (`float`): Beta value for the loss.
+* **kl_start** (`int`): Epoch number to start annealing the loss.
+* **annealtime** (`int`): Annealing time.
+* **seq_len** (`int`): Length of the sequence.
+* **future_decoder** (`bool`): Whether a future decoder is used.
+* **future_steps** (`int`): Number of future steps to predict.
+* **scheduler** (`lr_scheduler._LRScheduler`): Learning rate scheduler.
+* **mse_red** (`str`): Reduction type for MSE reconstruction loss.
+* **mse_pred** (`str`): Reduction type for MSE prediction loss.
+* **kloss** (`int`): Number of clusters for cluster loss.
+* **klmbda** (`float`): Lambda value for cluster loss.
+* **bsize** (`int`): Size of the batch.
+* **noise** (`bool`): Whether to add Gaussian noise to the input.
+
+**Returns**
+
+* `Tuple[float, float, float, float, float, float]`: Kullback-Leibler weight, train loss, K-means loss, KL loss,
+MSE loss, future loss.
+
+#### test
+
+```python
+def test(test_loader: Data.DataLoader, model: nn.Module, BETA: float,
+         kl_weight: float, seq_len: int, mse_red: str, kloss: str,
+         klmbda: float, future_decoder: bool,
+         bsize: int) -> Tuple[float, float, float]
+```
+
+Evaluate the model on the test dataset.
+
+**Arguments**
+
+* **test_loader** (`DataLoader`): DataLoader for the test dataset.
+* **model** (`nn.Module`): The trained model.
+* **BETA** (`float`): Beta value for the VAE loss.
+* **kl_weight** (`float`): Weighting factor for the KL divergence loss.
+* **seq_len** (`int`): Length of the sequence.
+* **mse_red** (`str`): Reduction method for the MSE loss.
+* **kloss** (`str`): Loss function for K-means clustering.
+* **klmbda** (`float`): Lambda value for K-means loss.
+* **future_decoder** (`bool`): Flag indicating whether to use a future decoder.
+* **bsize** (`int`): Batch size.
+
+**Returns**
+
+* `Tuple[float, float, float]`: Tuple containing MSE loss per item, total test loss per item,
+and K-means loss weighted by the kl_weight.
+
+#### train\_model
+
+```python
+@save_state(model=TrainModelFunctionSchema)
+def train_model(config: dict, save_logs: bool = False) -> None
+```
+
+Train the Variational Autoencoder using the configuration file values.
+Fills in the values in the "train_model" key of the states.json file.
+Creates files at:
+- project_name/
+    - model/
+        - best_model/
+            - snapshots/
+                - model_name_Project_epoch_0.pkl
+                - ...
+            - model_name_Project.pkl
+        - model_losses/
+            - fut_losses_VAME.npy
+            - kl_losses_VAME.npy
+            - kmeans_losses_VAME.npy
+            - mse_test_losses_VAME.npy
+            - mse_train_losses_VAME.npy
+            - test_losses_VAME.npy
+            - train_losses_VAME.npy
+            - weight_values_VAME.npy
+        - pretrained_model/
+
+**Arguments**
+
+* **config** (`dict`): Configuration dictionary.
+* **save_logs** (`bool, optional`): Whether to save the logs, by default False.
+ +**Returns** + +* `None` + diff --git a/docs/vame-docs-app/docs/reference/pipeline.md b/docs/vame-docs-app/docs/reference/pipeline.md new file mode 100644 index 00000000..b3a3996a --- /dev/null +++ b/docs/vame-docs-app/docs/reference/pipeline.md @@ -0,0 +1,139 @@ +--- +sidebar_label: pipeline +title: pipeline +--- + +#### logger\_config + +#### logger + +## VAMEPipeline Objects + +```python +class VAMEPipeline() +``` + +#### \_\_init\_\_ + +```python +def __init__(project_name: str, + videos: List[str], + poses_estimations: List[str], + source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"], + working_directory: str = ".", + video_type: str = ".mp4", + fps: int | None = None, + copy_videos: bool = False, + paths_to_pose_nwb_series_data: Optional[str] = None, + config_kwargs: Optional[dict] = None) +``` + +#### get\_sessions + +```python +def get_sessions() -> List[str] +``` + +Returns a list of session names. + +**Returns** + +* `List[str]`: Session names. + +#### get\_raw\_datasets + +```python +def get_raw_datasets() -> xr.Dataset +``` + +Returns a xarray dataset which combines all the raw data from the project. + +**Returns** + +* **dss** (`xarray.Dataset`): Combined raw dataset. + +#### preprocessing + +```python +def preprocessing(centered_reference_keypoint: str = "snout", + orientation_reference_keypoint: str = "tailbase") +``` + +#### create\_training\_set + +```python +def create_training_set() +``` + +#### train\_model + +```python +def train_model() +``` + +#### evaluate\_model + +```python +def evaluate_model() +``` + +#### run\_segmentation + +```python +def run_segmentation() +``` + +#### generate\_motif\_videos + +```python +def generate_motif_videos() +``` + +#### run\_community\_clustering + +```python +def run_community_clustering() +``` + +#### generate\_community\_videos + +```python +def generate_community_videos() +``` + +#### visualization + +```python +def visualization() +``` + +#### report + +```python +def report() +``` + +#### get\_states + +```python +def get_states(summary: bool = True) -> dict +``` + +Returns the pipeline states. + +**Returns** + +* `dict`: Pipeline states. 
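+
+**Example**
+
+A minimal end-to-end sketch, assuming the class is importable from `vame.pipeline` as
+this page's module title suggests; file paths and keypoint names are illustrative.
+See run_pipeline below.
+
+```python
+from vame.pipeline import VAMEPipeline
+
+pipeline = VAMEPipeline(
+    project_name="my_vame_project",
+    videos=["./sample_data/Session001.mp4"],
+    poses_estimations=["./sample_data/pose_estimation/Session001.csv"],
+    source_software="DeepLabCut",
+)
+pipeline.run_pipeline(
+    from_step=0,
+    preprocessing_kwargs={
+        "centered_reference_keypoint": "snout",
+        "orientation_reference_keypoint": "tailbase",
+    },
+)
+print(pipeline.get_states(summary=True))
+```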
+ +#### run\_pipeline + +```python +def run_pipeline(from_step: int = 0, preprocessing_kwargs: dict = {}) +``` + +#### unique\_in\_order + +```python +def unique_in_order(sequence) +``` + diff --git a/docs/vame-docs-app/docs/reference/preprocessing/__init__.md b/docs/vame-docs-app/docs/reference/preprocessing/__init__.md new file mode 100644 index 00000000..4a11d213 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/preprocessing/__init__.md @@ -0,0 +1,5 @@ +--- +sidebar_label: preprocessing +title: preprocessing +--- + diff --git a/docs/vame-docs-app/docs/reference/preprocessing/align_egocentrical_legacy.md b/docs/vame-docs-app/docs/reference/preprocessing/align_egocentrical_legacy.md new file mode 100644 index 00000000..55435de5 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/preprocessing/align_egocentrical_legacy.md @@ -0,0 +1,5 @@ +--- +sidebar_label: align_egocentrical_legacy +title: preprocessing.align_egocentrical_legacy +--- + diff --git a/docs/vame-docs-app/docs/reference/preprocessing/alignment.md b/docs/vame-docs-app/docs/reference/preprocessing/alignment.md new file mode 100644 index 00000000..5e77c351 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/preprocessing/alignment.md @@ -0,0 +1,33 @@ +--- +sidebar_label: alignment +title: preprocessing.alignment +--- + +#### logger\_config + +#### logger + +#### egocentrically\_align\_and\_center + +```python +def egocentrically_align_and_center( + config: dict, + centered_reference_keypoint: str = "snout", + orientation_reference_keypoint: str = "tailbase", + read_from_variable: str = "position_processed", + save_to_variable: str = "position_egocentric_aligned") -> None +``` + +Aligns the time series by first centralizing all positions around the first keypoint +and then applying rotation to align with the line connecting the two keypoints. + +**Arguments** + +* **config** (`dict`): Configuration dictionary +* **centered_reference_keypoint** (`str`): Name of the keypoint to use as centered reference. +* **orientation_reference_keypoint** (`str`): Name of the keypoint to use as orientation reference. + +**Returns** + +* `None` + diff --git a/docs/vame-docs-app/docs/reference/preprocessing/cleaning.md b/docs/vame-docs-app/docs/reference/preprocessing/cleaning.md new file mode 100644 index 00000000..1959e5f5 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/preprocessing/cleaning.md @@ -0,0 +1,53 @@ +--- +sidebar_label: cleaning +title: preprocessing.cleaning +--- + +#### logger\_config + +#### logger + +#### lowconf\_cleaning + +```python +def lowconf_cleaning(config: dict, + read_from_variable: str = "position_processed", + save_to_variable: str = "position_processed") -> None +``` + +Clean the low confidence data points from the dataset. Processes position data by: + - setting low-confidence points to NaN + - interpolating NaN points + +**Arguments** + +* **config** (`dict`): Configuration dictionary. +* **read_from_variable** (`str, optional`): Variable to read from the dataset. +* **save_to_variable** (`str, optional`): Variable to save the cleaned data to. + +**Returns** + +* `None` + +#### outlier\_cleaning + +```python +def outlier_cleaning(config: dict, + read_from_variable: str = "position_processed", + save_to_variable: str = "position_processed") -> None +``` + +Clean the outliers from the dataset. Processes position data by: + - setting outlier points to NaN + - interpolating NaN points + +**Arguments** + +* **config** (`dict`): Configuration dictionary. 
+* **read_from_variable** (`str, optional`): Variable to read from the dataset.
+* **save_to_variable** (`str, optional`): Variable to save the cleaned data to.
+
+**Returns**
+
+* `None`
+
diff --git a/docs/vame-docs-app/docs/reference/preprocessing/filter.md b/docs/vame-docs-app/docs/reference/preprocessing/filter.md
new file mode 100644
index 00000000..cba225cc
--- /dev/null
+++ b/docs/vame-docs-app/docs/reference/preprocessing/filter.md
@@ -0,0 +1,29 @@
+---
+sidebar_label: filter
+title: preprocessing.filter
+---
+
+#### logger\_config
+
+#### logger
+
+#### savgol\_filtering
+
+```python
+def savgol_filtering(config: dict,
+                     read_from_variable: str = "position_processed",
+                     save_to_variable: str = "position_processed") -> None
+```
+
+Apply Savitzky-Golay filter to the data.
+
+**Arguments**
+
+* **config** (`dict`): Configuration dictionary.
+* **read_from_variable** (`str, optional`): Variable to read from the dataset.
+* **save_to_variable** (`str, optional`): Variable to save the filtered data to.
+
+**Returns**
+
+* `None`
+
diff --git a/docs/vame-docs-app/docs/reference/preprocessing/preprocessing.md b/docs/vame-docs-app/docs/reference/preprocessing/preprocessing.md
new file mode 100644
index 00000000..9065e7b7
--- /dev/null
+++ b/docs/vame-docs-app/docs/reference/preprocessing/preprocessing.md
@@ -0,0 +1,34 @@
+---
+sidebar_label: preprocessing
+title: preprocessing.preprocessing
+---
+
+#### logger\_config
+
+#### logger
+
+#### preprocessing
+
+```python
+def preprocessing(config: dict,
+                  centered_reference_keypoint: str = "snout",
+                  orientation_reference_keypoint: str = "tailbase",
+                  save_logs: bool = False) -> None
+```
+
+Preprocess the data by:
+    - Cleaning low confidence data points
+    - Egocentric alignment
+    - Outlier cleaning
+    - Savitzky-Golay filtering
+
+**Arguments**
+
+* **config** (`dict`): Configuration dictionary.
+* **centered_reference_keypoint** (`str, optional`): Keypoint to use as centered reference.
+* **orientation_reference_keypoint** (`str, optional`): Keypoint to use as orientation reference.
+
+**Returns**
+
+* `None`
+
diff --git a/docs/vame-docs-app/docs/reference/preprocessing/to_model.md b/docs/vame-docs-app/docs/reference/preprocessing/to_model.md
new file mode 100644
index 00000000..44dbd198
--- /dev/null
+++ b/docs/vame-docs-app/docs/reference/preprocessing/to_model.md
@@ -0,0 +1,27 @@
+---
+sidebar_label: to_model
+title: preprocessing.to_model
+---
+
+#### format\_xarray\_for\_rnn
+
+```python
+def format_xarray_for_rnn(ds: xr.Dataset,
+                          read_from_variable: str = "position_processed")
+```
+
+Formats the xarray dataset for use in VAME's RNN model:
+- The x and y coordinates of the centered_reference_keypoint are excluded.
+- The x coordinate of the orientation_reference_keypoint is excluded.
+- The remaining data is flattened and transposed.
+
+**Arguments**
+
+* **ds** (`xr.Dataset`): The xarray dataset to format.
+* **read_from_variable** (`str, default="position_processed"`): The variable to read from the dataset.
+
+**Returns**
+
+* `np.ndarray`: The formatted array in the shape (n_features, n_samples),
+where n_features = n_keypoints * n_spaces - 3.
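+
+**Example**
+
+A minimal usage sketch; the dataset path is illustrative (processed datasets are written
+to data/processed/ by the preprocessing steps).
+
+```python
+import xarray as xr
+
+ds = xr.open_dataset("my_vame_project/data/processed/Session001_processed.nc")
+arr = format_xarray_for_rnn(ds, read_from_variable="position_processed")
+print(arr.shape)  # (n_features, n_samples)
+```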
+ diff --git a/docs/vame-docs-app/docs/reference/preprocessing/visualization.md b/docs/vame-docs-app/docs/reference/preprocessing/visualization.md new file mode 100644 index 00000000..7fe559dc --- /dev/null +++ b/docs/vame-docs-app/docs/reference/preprocessing/visualization.md @@ -0,0 +1,39 @@ +--- +sidebar_label: visualization +title: preprocessing.visualization +--- + +#### visualize\_preprocessing\_scatter + +```python +def visualize_preprocessing_scatter( + config: dict, + session_index: int = 0, + frames: list = [], + original_positions_key: str = "position", + cleaned_positions_key: str = "position_cleaned_lowconf", + aligned_positions_key: str = "position_egocentric_aligned", + save_to_file: bool = False, + show_figure: bool = True) +``` + +Visualize the preprocessing results by plotting the original, cleaned low-confidence, +and egocentric aligned positions of the keypoints in a scatter plot. + +#### visualize\_preprocessing\_timeseries + +```python +def visualize_preprocessing_timeseries( + config: dict, + session_index: int = 0, + n_samples: int = 1000, + original_positions_key: str = "position", + aligned_positions_key: str = "position_egocentric_aligned", + processed_positions_key: str = "position_processed", + save_to_file: bool = False, + show_figure: bool = True) +``` + +Visualize the preprocessing results by plotting the original, aligned, and processed positions +of the keypoints in a timeseries plot. + diff --git a/docs/vame-docs-app/docs/reference/sidebar.json b/docs/vame-docs-app/docs/reference/sidebar.json index 77e66f21..27c7e95c 100644 --- a/docs/vame-docs-app/docs/reference/sidebar.json +++ b/docs/vame-docs-app/docs/reference/sidebar.json @@ -1,63 +1,7 @@ { "items": [ - { - "items": [ - "reference/vame/analysis/community_analysis", - "reference/vame/analysis/generative_functions", - "reference/vame/analysis/gif_creator", - "reference/vame/analysis/pose_segmentation", - "reference/vame/analysis/tree_hierarchy", - "reference/vame/analysis/umap", - "reference/vame/analysis/videowriter" - ], - "label": "vame.analysis", - "type": "category" - }, - { - "items": [ - "reference/vame/initialize_project/new" - ], - "label": "vame.initialize_project", - "type": "category" - }, - { - "items": [ - "reference/vame/logging/logger" - ], - "label": "vame.logging", - "type": "category" - }, - { - "items": [ - "reference/vame/model/create_training", - "reference/vame/model/dataloader", - "reference/vame/model/evaluate", - "reference/vame/model/rnn_model", - "reference/vame/model/rnn_vae" - ], - "label": "vame.model", - "type": "category" - }, - { - "items": [ - "reference/vame/schemas/states" - ], - "label": "vame.schemas", - "type": "category" - }, - { - "items": [ - "reference/vame/util/align_egocentrical", - "reference/vame/util/auxiliary", - "reference/vame/util/csv_to_npy", - "reference/vame/util/data_manipulation", - "reference/vame/util/gif_pose_helper", - "reference/vame/util/model_util" - ], - "label": "vame.util", - "type": "category" - } + "reference/__init__/__init__" ], - "label": "vame", + "label": "__init__", "type": "category" } \ No newline at end of file diff --git a/docs/vame-docs-app/docs/reference/util/__init__.md b/docs/vame-docs-app/docs/reference/util/__init__.md new file mode 100644 index 00000000..993b9f1e --- /dev/null +++ b/docs/vame-docs-app/docs/reference/util/__init__.md @@ -0,0 +1,5 @@ +--- +sidebar_label: util +title: util +--- + diff --git a/docs/vame-docs-app/docs/reference/util/auxiliary.md b/docs/vame-docs-app/docs/reference/util/auxiliary.md new 
file mode 100644 index 00000000..38e56e49 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/util/auxiliary.md @@ -0,0 +1,62 @@ +--- +sidebar_label: auxiliary +title: util.auxiliary +--- + +#### create\_config\_template + +```python +def create_config_template() -> Tuple[dict, ruamel.yaml.YAML] +``` + +Creates a template for the config.yaml file. + +**Returns** + +* `Tuple[dict, ruamel.yaml.YAML]`: A tuple containing the template dictionary and the Ruamel YAML instance. + +#### read\_config + +```python +def read_config(config_file: str) -> dict +``` + +Reads structured config file defining a project. + +**Arguments** + +* **config_file** (`str`): Path to the config file. + +**Returns** + +* `dict`: The contents of the config file as a dictionary. + +#### write\_config + +```python +def write_config(configname: str, cfg: dict) -> None +``` + +Write structured config file. + +**Arguments** + +* **configname** (`str`): Path to the config file. +* **cfg** (`dict`): Dictionary containing the config data. + +#### read\_states + +```python +def read_states(config: dict) -> dict +``` + +Reads the states.json file. + +**Arguments** + +* **config** (`dict`): Dictionary containing the config data. + +**Returns** + +* `dict`: The contents of the states.json file as a dictionary. + diff --git a/docs/vame-docs-app/docs/reference/util/cli.md b/docs/vame-docs-app/docs/reference/util/cli.md new file mode 100644 index 00000000..d1e5cca5 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/util/cli.md @@ -0,0 +1,13 @@ +--- +sidebar_label: cli +title: util.cli +--- + +#### get\_sessions\_from\_user\_input + +```python +def get_sessions_from_user_input(cfg: dict, + action_message: str = "run this step" + ) -> List[str] +``` + diff --git a/docs/vame-docs-app/docs/reference/util/csv_to_npy.md b/docs/vame-docs-app/docs/reference/util/csv_to_npy.md new file mode 100644 index 00000000..734e6a45 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/util/csv_to_npy.md @@ -0,0 +1,29 @@ +--- +sidebar_label: csv_to_npy +title: util.csv_to_npy +--- + +#### logger\_config + +#### logger + +#### pose\_to\_numpy + +```python +@save_state(model=PoseToNumpyFunctionSchema) +def pose_to_numpy(config: dict, save_logs=False) -> None +``` + +Converts a pose-estimation.csv file to a numpy array. +Note that this code is only useful for data which is a priori egocentric, i.e. head-fixed +or otherwise restrained animals. + +**Arguments** + +* **config** (`dict`): Configuration dictionary. +* **save_logs** (`bool, optional`): If True, the logs will be saved to a file, by default False. + +**Raises** + +* `ValueError`: If the config.yaml file indicates that the data is not egocentric. + diff --git a/docs/vame-docs-app/docs/reference/util/data_manipulation.md b/docs/vame-docs-app/docs/reference/util/data_manipulation.md new file mode 100644 index 00000000..f651108a --- /dev/null +++ b/docs/vame-docs-app/docs/reference/util/data_manipulation.md @@ -0,0 +1,120 @@ +--- +sidebar_label: data_manipulation +title: util.data_manipulation +--- + +#### logger\_config + +#### logger + +#### consecutive + +```python +def consecutive(data: np.ndarray, stepsize: int = 1) -> List[np.ndarray] +``` + +Find consecutive sequences in the data array. + +**Arguments** + +* **data** (`np.ndarray`): Input array. +* **stepsize** (`int, optional`): Step size. Defaults to 1. + +**Returns** + +* `List[np.ndarray]`: List of consecutive sequences. 
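
A short self-contained sketch of `consecutive`, assuming the split-on-gap behavior its docstring describes:

```python
import numpy as np

from vame.util.data_manipulation import consecutive

frame_indices = np.array([0, 1, 2, 10, 11, 40])

# Splits wherever the gap between neighbors exceeds the step size,
# yielding runs of consecutive indices.
runs = consecutive(frame_indices, stepsize=1)
print(runs)  # expected: [array([0, 1, 2]), array([10, 11]), array([40])]
```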
+
+#### nan\_helper
+
+```python
+def nan_helper(y: np.ndarray) -> Tuple
+```
+
+Identifies indices of NaN values in an array and provides a function to convert them to non-NaN indices.
+
+**Arguments**
+
+* **y** (`np.ndarray`): Input array containing NaN values.
+
+**Returns**
+
+* `Tuple[np.ndarray, Union[np.ndarray, None]]`: A tuple containing two elements:
+- An array of boolean values indicating the positions of NaN values.
+- A lambda function to convert NaN indices to non-NaN indices.
+
+#### interpol\_first\_rows\_nans
+
+```python
+def interpol_first_rows_nans(arr: np.ndarray) -> np.ndarray
+```
+
+Interpolates NaN values in the given array.
+
+**Arguments**
+
+* **arr** (`np.ndarray`): Input array with NaN values.
+
+**Returns**
+
+* `np.ndarray`: Array with interpolated NaN values.
+
+#### interpolate\_nans\_with\_pandas
+
+```python
+def interpolate_nans_with_pandas(data: np.ndarray) -> np.ndarray
+```
+
+Interpolate NaN values along the time axis of a 3D NumPy array using Pandas.
+
+**Arguments**
+
+* **data** (`numpy.ndarray`): Input 3D array of shape (time, keypoints, space).
+
+**Returns**
+
+* `numpy.ndarray`: Array with NaN values interpolated.
+
+#### crop\_and\_flip\_legacy
+
+```python
+def crop_and_flip_legacy(
+        rect: Tuple, src: np.ndarray, points: List[np.ndarray],
+        ref_index: Tuple[int, int]) -> Tuple[np.ndarray, List[np.ndarray]]
+```
+
+Crop and flip the image based on the given rectangle and points.
+
+**Arguments**
+
+* **rect** (`Tuple`): Rectangle coordinates (center, size, theta).
+* **src** (`np.ndarray`): Source image.
+* **points** (`List[np.ndarray]`): List of points.
+* **ref_index** (`Tuple[int, int]`): Reference indices for alignment.
+
+**Returns**
+
+* `Tuple[np.ndarray, List[np.ndarray]]`: Cropped and flipped image, and shifted points.
+
+#### background
+
+```python
+def background(project_path: str,
+               session: str,
+               video_path: str,
+               num_frames: int = 1000,
+               save_background: bool = True) -> np.ndarray
+```
+
+Compute background image from fixed camera.
+
+**Arguments**
+
+* **project_path** (`str`): Path to the project directory.
+* **session** (`str`): Name of the session.
+* **video_path** (`str`): Path to the video file.
+* **num_frames** (`int, optional`): Number of frames to use for background computation. Defaults to 1000.
+* **save_background** (`bool, optional`): Whether to save the background image. Defaults to True.
+
+**Returns**
+
+* `np.ndarray`: Background image.
+
diff --git a/docs/vame-docs-app/docs/reference/util/gif_pose_helper.md b/docs/vame-docs-app/docs/reference/util/gif_pose_helper.md
new file mode 100644
index 00000000..6003bb95
--- /dev/null
+++ b/docs/vame-docs-app/docs/reference/util/gif_pose_helper.md
@@ -0,0 +1,40 @@
+---
+sidebar_label: gif_pose_helper
+title: util.gif_pose_helper
+---
+
+#### logger\_config
+
+#### logger
+
+#### get\_animal\_frames
+
+```python
+def get_animal_frames(
+        cfg: dict,
+        session: str,
+        pose_ref_index: list,
+        start: int,
+        length: int,
+        subtract_background: bool,
+        file_format: str = ".mp4",
+        crop_size: tuple = (300, 300)) -> list
+```
+
+Extracts frames of an animal from a video file and returns them as a list.
+
+**Arguments**
+
+* **cfg** (`dict`): Configuration dictionary containing project information.
+* **session** (`str`): Name of the session.
+* **pose_ref_index** (`list`): List of reference coordinate indices for alignment.
+* **start** (`int`): Starting frame index.
+* **length** (`int`): Number of frames to extract.
+* **subtract_background** (`bool`): Whether to subtract background or not.
+* **file_format** (`str, optional`): Format of the video file.
Defaults to '.mp4'. +* **crop_size** (`tuple, optional`): Size of the cropped area. Defaults to (300, 300). + +**Returns** + +* `list:`: List of extracted frames. + diff --git a/docs/vame-docs-app/docs/reference/util/model_util.md b/docs/vame-docs-app/docs/reference/util/model_util.md new file mode 100644 index 00000000..f30209bf --- /dev/null +++ b/docs/vame-docs-app/docs/reference/util/model_util.md @@ -0,0 +1,25 @@ +--- +sidebar_label: model_util +title: util.model_util +--- + +#### logger\_config + +#### logger + +#### load\_model + +```python +def load_model(cfg: dict, model_name: str, fixed: bool = True) -> RNN_VAE +``` + +Load the VAME model. + +Args: + cfg (dict): Configuration dictionary. + model_name (str): Name of the model. + fixed (bool): Fixed or variable length sequences. + +Returns + RNN_VAE: Loaded VAME model. + diff --git a/docs/vame-docs-app/docs/reference/util/report.md b/docs/vame-docs-app/docs/reference/util/report.md new file mode 100644 index 00000000..061fe1dc --- /dev/null +++ b/docs/vame-docs-app/docs/reference/util/report.md @@ -0,0 +1,30 @@ +--- +sidebar_label: report +title: util.report +--- + +#### logger\_config + +#### logger + +#### report + +```python +def report(config: dict, segmentation_algorithm: str = "hmm") -> None +``` + +Report for a project. + +#### plot\_community\_motifs + +```python +def plot_community_motifs(motif_labels, + community_labels, + community_bag, + title: str = "Community and Motif Counts", + save_to_file: bool = False, + save_path: str = "") +``` + +Generates a bar plot to represent community and motif counts with percentages. + diff --git a/docs/vame-docs-app/docs/reference/util/sample_data.md b/docs/vame-docs-app/docs/reference/util/sample_data.md new file mode 100644 index 00000000..887aa825 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/util/sample_data.md @@ -0,0 +1,25 @@ +--- +sidebar_label: sample_data +title: util.sample_data +--- + +#### DOWNLOAD\_PATH + +#### dataset\_options + +#### download\_sample\_data + +```python +def download_sample_data(source_software: str) -> dict +``` + +Download sample data. + +**Arguments** + +* **source_software** (`str`): Source software used for pose estimation. + +**Returns** + +* `dict`: Dictionary with the paths to the downloaded sample data. + diff --git a/docs/vame-docs-app/docs/reference/video/__init__.md b/docs/vame-docs-app/docs/reference/video/__init__.md new file mode 100644 index 00000000..7b4e68d7 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/video/__init__.md @@ -0,0 +1,5 @@ +--- +sidebar_label: video +title: video +--- + diff --git a/docs/vame-docs-app/docs/reference/video/video.md b/docs/vame-docs-app/docs/reference/video/video.md new file mode 100644 index 00000000..944dd3a4 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/video/video.md @@ -0,0 +1,11 @@ +--- +sidebar_label: video +title: video.video +--- + +#### get\_video\_frame\_rate + +```python +def get_video_frame_rate(video_path) +``` + diff --git a/src/vame/analysis/community_analysis.py b/src/vame/analysis/community_analysis.py index 28b31210..034a3c3f 100644 --- a/src/vame/analysis/community_analysis.py +++ b/src/vame/analysis/community_analysis.py @@ -287,8 +287,8 @@ def compute_transition_matrices( """ Compute transition matrices for given files and labels. - Parameters: - ----------- + Parameters + ---------- files : List[str] List of file paths. labels : List[np.ndarray] @@ -296,8 +296,8 @@ def compute_transition_matrices( n_clusters : int Number of clusters. 
- Returns: - -------- + Returns + ------- List[np.ndarray]: List of transition matrices. """ @@ -706,7 +706,7 @@ def community( # cut_tree (int): Cut line for tree. # n_clusters (int): Number of clusters. -# Returns: +# Returns # Tuple: Tuple containing list of community bags and list of trees. # """ # trees = [] @@ -768,7 +768,7 @@ def community( # labels (List[np.ndarray]): List of label arrays. # communities_all (List[List[List[int]]]): List of community bags. -# Returns: +# Returns # List[np.ndarray]: List of community labels for each file. # """ # community_labels_all = [] diff --git a/src/vame/analysis/generative_functions.py b/src/vame/analysis/generative_functions.py index 29128c5c..282c92af 100644 --- a/src/vame/analysis/generative_functions.py +++ b/src/vame/analysis/generative_functions.py @@ -27,8 +27,8 @@ def random_generative_samples_motif( """ Generate random samples for motifs. - Parameters: - ----------- + Parameters + ---------- cfg : dict Configuration dictionary. model : torch.nn.Module @@ -40,8 +40,8 @@ def random_generative_samples_motif( n_clusters : int Number of clusters. - Returns: - -------- + Returns + ------- plt.Figure Figure of generated samples. """ @@ -85,8 +85,8 @@ def random_generative_samples( """ Generate random generative samples. - Parameters: - ----------- + Parameters + ---------- cfg : dict Configuration dictionary. model : torch.nn.Module @@ -94,8 +94,8 @@ def random_generative_samples( latent_vector : np.ndarray Latent vectors. - Returns: - -------- + Returns + ------- plt.Figure Figure of generated samples. """ @@ -136,8 +136,8 @@ def random_reconstruction_samples( """ Generate random reconstruction samples. - Parameters: - ----------- + Parameters + ---------- cfg : dict Configuration dictionary. model : torch.nn.Module @@ -145,8 +145,8 @@ def random_reconstruction_samples( latent_vector : np.ndarray Latent vectors. - Returns: - -------- + Returns + ------- plt.Figure Figure of reconstructed samples. """ @@ -183,8 +183,8 @@ def visualize_cluster_center( """ Visualize cluster centers. - Parameters: - ----------- + Parameters + ---------- cfg : dict Configuration dictionary. model : torch.nn.Module @@ -192,8 +192,8 @@ def visualize_cluster_center( cluster_center : np.ndarray Cluster centers. - Returns: - -------- + Returns + ------- plt.Figure Figure of cluster centers. """ @@ -236,15 +236,15 @@ def generative_model( """ Generative model. - Parameters: - ----------- + Parameters + ---------- config : dict Configuration dictionary. mode : str, optional Mode for generating samples. Defaults to "sampling". - Returns: - -------- + Returns + ------- plt.Figure Plots of generated samples for each segmentation algorithm. """ diff --git a/src/vame/analysis/gif_creator.py b/src/vame/analysis/gif_creator.py index d28c4d5c..1f4dc2b8 100644 --- a/src/vame/analysis/gif_creator.py +++ b/src/vame/analysis/gif_creator.py @@ -33,8 +33,8 @@ def create_video( """ Create video frames for the given embedding. - Parameters: - ----------- + Parameters + ---------- path_to_file : str Path to the file. session : str @@ -54,8 +54,8 @@ def create_video( num_points : int Number of points. - Returns: - -------- + Returns + ------- None """ # set matplotlib colormap @@ -119,8 +119,8 @@ def gif( ) -> None: """Create a GIF from the given configuration. - Parameters: - ----------- + Parameters + ---------- config : str Path to the configuration file. pose_ref_index : list @@ -142,8 +142,8 @@ def gif( crop_size : Tuple[int, int], optional Crop size. Defaults to (300,300). 
- Returns: - -------- + Returns + ------- None """ config_file = Path(config).resolve() diff --git a/src/vame/analysis/tree_hierarchy.py b/src/vame/analysis/tree_hierarchy.py index 652648a4..05741481 100644 --- a/src/vame/analysis/tree_hierarchy.py +++ b/src/vame/analysis/tree_hierarchy.py @@ -17,8 +17,8 @@ def hierarchy_pos( Positions nodes in a tree-like layout. Ref: From Joel's answer at https://stackoverflow.com/a/29597209/2966723. - Parameters: - ----------- + Parameters + ---------- G : nx.Graph The input graph. Must be a tree. root : str, optional @@ -33,8 +33,8 @@ def hierarchy_pos( xcenter : float, optional The horizontal location of the root node. Defaults to 0.5. - Returns: - -------- + Returns + ------- Dict[str, Tuple[float, float]] A dictionary mapping node names to their positions (x, y). """ @@ -92,8 +92,8 @@ def merge_func( """ Merge nodes in a graph based on a selection criterion. - Parameters: - ----------- + Parameters + ---------- transition_matrix : np.ndarray The transition matrix of the graph. n_clusters : int @@ -105,8 +105,8 @@ def merge_func( - 0: Merge nodes with highest transition probability. - 1: Merge nodes with lowest cost. - Returns: - -------- + Returns + ------- Tuple[np.ndarray, np.ndarray] A tuple containing the merged nodes. """ @@ -145,8 +145,8 @@ def graph_to_tree( """ Convert a graph to a tree. - Parameters: - ----------- + Parameters + ---------- motif_usage : np.ndarray The motif usage matrix. transition_matrix : np.ndarray @@ -158,8 +158,8 @@ def graph_to_tree( - 0: Merge nodes with highest transition probability. - 1: Merge nodes with lowest cost. - Returns: - -------- + Returns + ------- nx.Graph The tree. """ @@ -334,15 +334,15 @@ def draw_tree( """ Draw a tree. - Parameters: - ----------- + Parameters + ---------- T : nx.Graph The tree to be drawn. fig_width : int, optional The width of the figure. Defaults to 10. - Returns: - -------- + Returns + ------- None """ # pos = nx.drawing.layout.fruchterman_reingold_layout(T) @@ -400,8 +400,8 @@ def _traverse_tree_cutline( DEPRECATED in favor of bag_nodes_by_cutline. Helper function for tree traversal with a cutline. - Parameters: - ----------- + Parameters + ---------- T : nx.Graph The tree to be traversed. node : List[str] @@ -417,8 +417,8 @@ def _traverse_tree_cutline( community_list : List[str], optional List of nodes in the current community bag. - Returns: - -------- + Returns + ------- List[List[str]] List of lists community bags. """ @@ -493,8 +493,8 @@ def traverse_tree_cutline( DEPRECATED in favor of bag_nodes_by_cutline. Traverse a tree with a cutline and return the community bags. - Parameters: - ----------- + Parameters + ---------- T : nx.Graph The tree to be traversed. root_node : str, optional @@ -502,8 +502,8 @@ def traverse_tree_cutline( cutline : int, optional The cutline level. - Returns: - -------- + Returns + ------- List[List[str]] List of community bags. """ @@ -537,8 +537,8 @@ def bag_nodes_by_cutline( """ Bag nodes of a tree by a cutline. - Parameters: - ----------- + Parameters + ---------- tree : nx.Graph The tree to be bagged. cutline : int, optional @@ -546,8 +546,8 @@ def bag_nodes_by_cutline( root : str, optional The root node of the tree. Defaults to 'Root'. - Returns: - -------- + Returns + ------- List[List[str]] List of bags of nodes. 
""" diff --git a/src/vame/analysis/umap.py b/src/vame/analysis/umap.py index ab8898e8..39f575bb 100644 --- a/src/vame/analysis/umap.py +++ b/src/vame/analysis/umap.py @@ -81,7 +81,7 @@ def umap_embedding( # community_labels_all (np.ndarray): Community labels. # save_path: Path to save the plot. If None it will not save the plot. -# Returns: +# Returns # None # """ # num_points = cfg['num_points'] diff --git a/src/vame/analysis/videowriter.py b/src/vame/analysis/videowriter.py index 35e0adcb..5dc16b9a 100644 --- a/src/vame/analysis/videowriter.py +++ b/src/vame/analysis/videowriter.py @@ -281,8 +281,8 @@ def community_videos( - file_name-community_1.mp4 - ... - Parameters: - ----------- + Parameters + ---------- config : dict Configuration parameters. segmentation_algorithm : SegmentationAlgorithms diff --git a/src/vame/initialize_project/new.py b/src/vame/initialize_project/new.py index 66a7ed37..4b021d95 100644 --- a/src/vame/initialize_project/new.py +++ b/src/vame/initialize_project/new.py @@ -54,7 +54,7 @@ def init_new_project( - states.json - config.yaml - Parameters: + Parameters ---------- project_name : str Project name. diff --git a/src/vame/io/load_poses.py b/src/vame/io/load_poses.py index c55b2d34..27334d32 100644 --- a/src/vame/io/load_poses.py +++ b/src/vame/io/load_poses.py @@ -17,8 +17,8 @@ def load_pose_estimation( """ Load pose estimation data. - Parameters: - ----------- + Parameters + ---------- pose_estimation_file : Path or str Path to the pose estimation file. video_file : Path or str @@ -28,8 +28,8 @@ def load_pose_estimation( source_software : Literal["DeepLabCut", "SLEAP", "LightningPose"] Source software used for pose estimation. - Returns: - -------- + Returns + ------- ds : xarray.Dataset Pose estimation dataset. """ @@ -46,21 +46,23 @@ def load_vame_dataset(ds_path: Path | str) -> xr.Dataset: """ Load VAME dataset. - Parameters: - ----------- + Parameters + ---------- ds_path : Path or str Path to the netCDF dataset. - Returns: - -------- + Returns + ------- + xr.Dataset + VAME dataset """ - # Windows will not allow opened files to be overwritten, + # Windows will not allow opened files to be overwritten, # so we need to load data into memory, close the file and move on with the operations with xr.open_dataset(ds_path, engine="scipy") as tmp_ds: ds_in_memory = tmp_ds.load() # read entire file into memory return ds_in_memory - + def nc_to_dataframe(nc_data): keypoints = nc_data["keypoints"].values space = nc_data["space"].values diff --git a/src/vame/io/nwb.py b/src/vame/io/nwb.py index 91f19a61..d773a3de 100644 --- a/src/vame/io/nwb.py +++ b/src/vame/io/nwb.py @@ -11,8 +11,8 @@ # """ # Get pose data from nwb file using a inside path to the nwb data. -# Parameters: -# ---------- +# Parameters +# --------- # nwbfile : NWBFile) # NWB file object. # path_to_pose_nwb_series_data : str @@ -42,7 +42,7 @@ # Get pose data from nwb file and return it as a pandas DataFrame. # Parameters -# ---------- +# --------- # file_path : str # Path to the nwb file. # path_to_pose_nwb_series_data : str diff --git a/src/vame/model/create_training.py b/src/vame/model/create_training.py index 8db45adf..d4c1b5b1 100644 --- a/src/vame/model/create_training.py +++ b/src/vame/model/create_training.py @@ -85,7 +85,7 @@ def traindata_aligned( # Create training dataset for fixed data. # Parameters -# ---------- +# --------- # cfg : dict # Configuration parameters. 
# sessions : List[str] @@ -101,7 +101,7 @@ def traindata_aligned( # pose_ref_index : Optional[List[int]] # List of reference coordinate indices for alignment. -# Returns: +# Returns # None # Save numpy arrays with the test/train info to the project folder. # """ diff --git a/src/vame/model/rnn_model.py b/src/vame/model/rnn_model.py index ec8038fe..d022a28c 100644 --- a/src/vame/model/rnn_model.py +++ b/src/vame/model/rnn_model.py @@ -18,8 +18,8 @@ def __init__( """ Initialize the Encoder module. - Parameters: - ----------- + Parameters + ---------- NUM_FEATURES : int Number of input features. hidden_size_layer_1 : int @@ -54,13 +54,13 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: """ Forward pass of the Encoder module. - Parameters: - ----------- + Parameters + ---------- inputs : torch.Tensor Input tensor of shape (batch_size, sequence_length, num_features). - Returns: - -------- + Returns + ------- torch.Tensor: Encoded representation tensor of shape (batch_size, hidden_size_layer_1 * 4). """ @@ -89,8 +89,8 @@ def __init__( """ Initialize the Lambda module. - Parameters: - ----------- + Parameters + ---------- ZDIMS : int Size of the latent space. hidden_size_layer_1 : int @@ -122,13 +122,13 @@ def forward( """ Forward pass of the Lambda module. - Parameters: - ----------- + Parameters + ---------- hidden : torch.Tensor Hidden representation tensor of shape (batch_size, hidden_size_layer_1 * 4). - Returns: - -------- + Returns + ------- tuple[torch.Tensor, torch.Tensor, torch.Tensor] Latent space tensor, mean tensor, logvar tensor. """ @@ -160,8 +160,8 @@ def __init__( """ Initialize the Decoder module. - Parameters: - ----------- + Parameters + ---------- TEMPORAL_WINDOW : int Size of the temporal window. ZDIMS : int @@ -212,15 +212,15 @@ def forward( """ Forward pass of the Decoder module. - Parameters: - ----------- + Parameters + ---------- inputs : torch.Tensor Input tensor of shape (batch_size, seq_len, ZDIMS). z : torch.Tensor Latent space tensor of shape (batch_size, ZDIMS). - Returns: - -------- + Returns + ------- torch.Tensor: Decoded output tensor of shape (batch_size, seq_len, NUM_FEATURES). """ @@ -247,8 +247,8 @@ def __init__( """ Initialize the Decoder_Future module. - Parameters: - ----------- + Parameters + ---------- TEMPORAL_WINDOW : int Size of the temporal window. ZDIMS : int @@ -299,15 +299,15 @@ def forward( """ Forward pass of the Decoder_Future module. - Parameters: - ----------- + Parameters + ---------- inputs : torch.Tensor Input tensor of shape (batch_size, seq_len, ZDIMS). z : torch.Tensor Latent space tensor of shape (batch_size, ZDIMS). - Returns: - -------- + Returns + ------- torch.Tensor: Predicted future tensor of shape (batch_size, FUTURE_STEPS, NUM_FEATURES). """ @@ -342,8 +342,8 @@ def __init__( """ Initialize the VAE module. - Parameters: - ----------- + Parameters + ---------- TEMPORAL_WINDOW : int Size of the temporal window. ZDIMS : int @@ -401,13 +401,13 @@ def forward(self, seq: torch.Tensor) -> tuple: """ Forward pass of the VAE. - Parameters: - ----------- + Parameters + ---------- seq : torch.Tensor Input sequence tensor of shape (batch_size, seq_len, NUM_FEATURES). - Returns: - -------- + Returns + ------- Tuple containing: - If FUTURE_DECODER is True: - prediction (torch.Tensor): Reconstructed input sequence tensor. 
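
The Lambda module in the hunks above implements the standard VAE reparameterization trick; a generic, self-contained PyTorch sketch of the same idea (not VAME's exact code):

```python
import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # z = mu + sigma * eps with eps ~ N(0, I); sampling stays
    # differentiable with respect to mu and logvar.
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

mu = torch.zeros(8, 30)       # (batch_size, ZDIMS)
logvar = torch.zeros(8, 30)   # log variance of the latent distribution
z = reparameterize(mu, logvar)
print(z.shape)  # torch.Size([8, 30])
```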
diff --git a/src/vame/pipeline.py b/src/vame/pipeline.py index 9297e574..f325bc22 100644 --- a/src/vame/pipeline.py +++ b/src/vame/pipeline.py @@ -48,8 +48,8 @@ def get_sessions(self) -> List[str]: """ Returns a list of session names. - Returns: - -------- + Returns + ------- List[str] Session names. """ @@ -59,8 +59,8 @@ def get_raw_datasets(self) -> xr.Dataset: """ Returns a xarray dataset which combines all the raw data from the project. - Returns: - -------- + Returns + ------- dss : xarray.Dataset Combined raw dataset. """ @@ -157,8 +157,8 @@ def get_states(self, summary: bool = True) -> dict: """ Returns the pipeline states. - Returns: - -------- + Returns + ------- dict Pipeline states. """ diff --git a/src/vame/preprocessing/align_egocentrical_legacy.py b/src/vame/preprocessing/align_egocentrical_legacy.py index f594faa4..8c1ab624 100644 --- a/src/vame/preprocessing/align_egocentrical_legacy.py +++ b/src/vame/preprocessing/align_egocentrical_legacy.py @@ -40,8 +40,8 @@ # """ # Align the mouse in the video frames. -# Parameters: -# ----------- +# Parameters +# ---------- # project_path : str # Path to the project directory. # session : str @@ -200,8 +200,8 @@ # """ # Perform alignment of egocentric data. -# Parameters: -# ----------- +# Parameters +# ---------- # project_path : str # Path to the project directory. # session : str @@ -312,7 +312,7 @@ # shape of (num_dlc_features, num_video_frames). # Parameters -# ---------- +# --------- # config : str # Path for the project config file. # pose_ref_index : list, optional @@ -430,7 +430,7 @@ # shape of (num_dlc_features, num_video_frames). # Parameters -# ---------- +# --------- # config : str # Path for the project config file. # pose_ref_index : list, optional @@ -549,8 +549,8 @@ # """ # Egocentric alignment of pose estimation data. -# Parameters: -# ----------- +# Parameters +# ---------- # crop_size : Tuple[int, int] # Size to crop the video frames. # pose_list : List[np.ndarray] @@ -639,7 +639,7 @@ # Crop and flip the image based on the given rectangle and points. # Parameters -# ---------- +# --------- # rect : Tuple # Rectangle coordinates (center, size, theta). # points : List[np.ndarray] diff --git a/src/vame/preprocessing/alignment.py b/src/vame/preprocessing/alignment.py index afe54ba7..7a398a94 100644 --- a/src/vame/preprocessing/alignment.py +++ b/src/vame/preprocessing/alignment.py @@ -20,8 +20,8 @@ def egocentrically_align_and_center( Aligns the time series by first centralizing all positions around the first keypoint and then applying rotation to align with the line connecting the two keypoints. - Parameters: - ----------- + Parameters + ---------- config : dict Configuration dictionary centered_reference_keypoint : str @@ -29,8 +29,8 @@ def egocentrically_align_and_center( orientation_reference_keypoint : str Name of the keypoint to use as orientation reference. - Returns: - -------- + Returns + ------- None """ logger.info( diff --git a/src/vame/preprocessing/cleaning.py b/src/vame/preprocessing/cleaning.py index 93047d28..48348db3 100644 --- a/src/vame/preprocessing/cleaning.py +++ b/src/vame/preprocessing/cleaning.py @@ -20,8 +20,8 @@ def lowconf_cleaning( - setting low-confidence points to NaN - interpolating NaN points - Parameters: - ----------- + Parameters + ---------- config : dict Configuration dictionary. read_from_variable : str, optional @@ -29,8 +29,8 @@ def lowconf_cleaning( save_to_variable : str, optional Variable to save the cleaned data to. 
- Returns: - -------- + Returns + ------- None """ project_path = config["project_path"] @@ -98,8 +98,8 @@ def outlier_cleaning( - setting outlier points to NaN - interpolating NaN points - Parameters: - ----------- + Parameters + ---------- config : dict Configuration dictionary. read_from_variable : str, optional @@ -107,8 +107,8 @@ def outlier_cleaning( save_to_variable : str, optional Variable to save the cleaned data to. - Returns: - -------- + Returns + ------- None """ logger.info("Cleaning outliers with Z-score transformation and IQR cutoff.") diff --git a/src/vame/preprocessing/filter.py b/src/vame/preprocessing/filter.py index a4f81702..40cee29f 100644 --- a/src/vame/preprocessing/filter.py +++ b/src/vame/preprocessing/filter.py @@ -18,8 +18,8 @@ def savgol_filtering( """ Apply Savitzky-Golay filter to the data. - Parameters: - ----------- + Parameters + ---------- config : dict Configuration dictionary. read_from_variable : str, optional @@ -27,8 +27,8 @@ def savgol_filtering( save_to_variable : str, optional Variable to save the filtered data to. - Returns: - -------- + Returns + ------- None """ logger.info("Applying Savitzky-Golay filter...") diff --git a/src/vame/preprocessing/preprocessing.py b/src/vame/preprocessing/preprocessing.py index 3eb15869..adf09954 100644 --- a/src/vame/preprocessing/preprocessing.py +++ b/src/vame/preprocessing/preprocessing.py @@ -25,8 +25,8 @@ def preprocessing( - Outlier cleaning - Savitzky-Golay filtering - Parameters: - ----------- + Parameters + ---------- config : dict Configuration dictionary. centered_reference_keypoint : str, optional @@ -34,8 +34,8 @@ def preprocessing( orientation_reference_keypoint : str, optional Keypoint to use as orientation reference. - Returns: - -------- + Returns + ------- None """ # Low-confidence cleaning diff --git a/src/vame/preprocessing/to_model.py b/src/vame/preprocessing/to_model.py index f7b9903a..1b8e9ac6 100644 --- a/src/vame/preprocessing/to_model.py +++ b/src/vame/preprocessing/to_model.py @@ -11,15 +11,15 @@ def format_xarray_for_rnn( - The x coordinate of the orientation_reference_keypoint is excluded. - The remaining data is flattened and transposed. - Parameters: - ----------- + Parameters + ---------- ds : xr.Dataset The xarray dataset to format. read_from_variable : str, default="position_processed" The variable to read from the dataset. - Returns: - -------- + Returns + ------- np.ndarray The formatted array in the shape (n_features, n_samples). Where n_features = 2 * n_keypoints * n_spaces - 3. diff --git a/src/vame/util/auxiliary.py b/src/vame/util/auxiliary.py index b535c218..ad9875c2 100644 --- a/src/vame/util/auxiliary.py +++ b/src/vame/util/auxiliary.py @@ -81,7 +81,7 @@ def create_config_template() -> Tuple[dict, ruamel.yaml.YAML]: random_state: num_points: \n -#--------------------------------------------------------------- +#-------------------------------------------------------- # ONLY CHANGE ANYTHING BELOW IF YOU ARE FAMILIAR WITH RNN MODELS # RNN encoder hyperparamter: hidden_size_layer_1: diff --git a/src/vame/util/data_manipulation.py b/src/vame/util/data_manipulation.py index 5e2de8f8..c40c2806 100644 --- a/src/vame/util/data_manipulation.py +++ b/src/vame/util/data_manipulation.py @@ -60,7 +60,7 @@ def nan_helper(y: np.ndarray) -> Tuple: # Interpolates all NaN values in the given array. # Parameters -# ---------- +# --------- # arr : np.ndarray # Input array containing NaN values. 
@@ -103,13 +103,13 @@ def interpolate_nans_with_pandas(data: np.ndarray) -> np.ndarray: """ Interpolate NaN values along the time axis of a 3D NumPy array using Pandas. - Parameters: - ----------- + Parameters + ---------- data : numpy.ndarray Input 3D array of shape (time, keypoints, space). - Returns: - -------- + Returns + ------- numpy.ndarray: Array with NaN values interpolated. """ diff --git a/src/vame/util/gif_pose_helper.py b/src/vame/util/gif_pose_helper.py index 4247f834..f00f02cb 100644 --- a/src/vame/util/gif_pose_helper.py +++ b/src/vame/util/gif_pose_helper.py @@ -30,8 +30,8 @@ def get_animal_frames( """ Extracts frames of an animal from a video file and returns them as a list. - Parameters: - ----------- + Parameters + ---------- cfg : dict Configuration dictionary containing project information. session : str @@ -49,8 +49,8 @@ def get_animal_frames( crop_size : tuple, optional Size of the cropped area. Defaults to (300, 300). - Returns: - -------- + Returns + ------- list: List of extracted frames. """ diff --git a/src/vame/util/model_util.py b/src/vame/util/model_util.py index 315c8b24..20b32f18 100644 --- a/src/vame/util/model_util.py +++ b/src/vame/util/model_util.py @@ -16,7 +16,7 @@ def load_model(cfg: dict, model_name: str, fixed: bool = True) -> RNN_VAE: model_name (str): Name of the model. fixed (bool): Fixed or variable length sequences. - Returns: + Returns RNN_VAE: Loaded VAME model. """ # load Model diff --git a/src/vame/video/video.py b/src/vame/video/video.py index 13d640b7..b6b6ef28 100644 --- a/src/vame/video/video.py +++ b/src/vame/video/video.py @@ -21,7 +21,7 @@ def get_video_frame_rate(video_path): # Play the aligned video. # Parameters -# ---------- +# --------- # a : List[np.ndarray] # List of aligned images. # n : List[List[np.ndarray]] From ad83c6f9f45c54d0d72a7b606144708d0752cfb7 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 27 Dec 2024 16:11:47 +0100 Subject: [PATCH 51/77] action --- .github/workflows/{deploy-docs.yaml => publish_docs.yaml} | 2 ++ 1 file changed, 2 insertions(+) rename .github/workflows/{deploy-docs.yaml => publish_docs.yaml} (94%) diff --git a/.github/workflows/deploy-docs.yaml b/.github/workflows/publish_docs.yaml similarity index 94% rename from .github/workflows/deploy-docs.yaml rename to .github/workflows/publish_docs.yaml index 97e0c0b6..9017d2ba 100644 --- a/.github/workflows/deploy-docs.yaml +++ b/.github/workflows/publish_docs.yaml @@ -4,6 +4,7 @@ on: push: branches: - main + - updated-docs jobs: deploy: @@ -61,3 +62,4 @@ jobs: github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: ./docs/vame-docs-app/build publish_branch: gh-pages + destination_dir: ${{ github.ref_name == 'main' && '' || github.ref_name }} From f9d016eab4cb17fd514ba669e666c1b8b424135e Mon Sep 17 00:00:00 2001 From: luiz Date: Sat, 28 Dec 2024 11:25:49 +0100 Subject: [PATCH 52/77] docs --- .github/workflows/publish_docs.yaml | 6 +- .../docs/reference/_category_.json | 8 - .../reference/analysis/community_analysis.md | 18 +- .../analysis/generative_functions.md | 10 +- .../docs/reference/analysis/gif_creator.md | 4 +- .../reference/analysis/pose_segmentation.md | 10 +- .../docs/reference/analysis/tree_hierarchy.md | 14 +- .../docs/reference/analysis/umap.md | 10 +- .../docs/reference/analysis/videowriter.md | 6 +- .../docs/reference/initialize_project/new.md | 2 +- .../docs/reference/io/load_poses.md | 6 +- .../docs/reference/model/create_training.md | 4 +- .../docs/reference/model/dataloader.md | 4 +- .../docs/reference/model/evaluate.md | 8 +- 
.../docs/reference/model/rnn_model.md | 20 +- .../docs/reference/model/rnn_vae.md | 18 +- .../docs/reference/preprocessing/alignment.md | 2 +- .../docs/reference/preprocessing/cleaning.md | 4 +- .../docs/reference/preprocessing/filter.md | 2 +- .../reference/preprocessing/preprocessing.md | 2 +- .../docs/reference/preprocessing/to_model.md | 2 +- .../docs/reference/util/auxiliary.md | 6 +- .../docs/reference/util/csv_to_npy.md | 2 +- .../docs/reference/util/data_manipulation.md | 12 +- .../docs/reference/util/gif_pose_helper.md | 2 +- .../docs/reference/util/sample_data.md | 2 +- .../vame/analysis/community_analysis.md | 267 ------------------ .../vame/analysis/generative_functions.md | 118 -------- .../reference/vame/analysis/gif_creator.md | 74 ----- .../vame/analysis/pose_segmentation.md | 116 -------- .../vame/analysis/segment_behavior.md | 136 --------- .../reference/vame/analysis/tree_hierarchy.md | 132 --------- .../docs/reference/vame/analysis/umap.md | 115 -------- .../vame/analysis/umap_visualization.md | 95 ------- .../reference/vame/analysis/videowriter.md | 89 ------ .../reference/vame/initialize_project/new.md | 47 --- .../docs/reference/vame/logging/logger.md | 14 - .../reference/vame/model/create_training.md | 103 ------- .../docs/reference/vame/model/dataloader.md | 68 ----- .../docs/reference/vame/model/evaluate.md | 90 ------ .../docs/reference/vame/model/rnn_model.md | 246 ---------------- .../docs/reference/vame/model/rnn_vae.md | 225 --------------- .../docs/reference/vame/schemas/states.md | 14 - .../reference/vame/util/align_egocentrical.md | 132 --------- .../docs/reference/vame/util/auxiliary.md | 63 ----- .../docs/reference/vame/util/csv_to_npy.md | 26 -- .../reference/vame/util/data_manipulation.md | 141 --------- .../reference/vame/util/gif_pose_helper.md | 44 --- .../docs/reference/vame/util/model_util.md | 24 -- .../src/components/HomepageFeatures/index.js | 5 +- docs/yarn.lock | 4 + 51 files changed, 94 insertions(+), 2478 deletions(-) delete mode 100644 docs/vame-docs-app/docs/reference/_category_.json delete mode 100644 docs/vame-docs-app/docs/reference/vame/analysis/community_analysis.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/analysis/generative_functions.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/analysis/gif_creator.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/analysis/pose_segmentation.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/analysis/segment_behavior.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/analysis/tree_hierarchy.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/analysis/umap.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/analysis/umap_visualization.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/analysis/videowriter.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/initialize_project/new.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/logging/logger.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/model/create_training.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/model/dataloader.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/model/evaluate.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/model/rnn_model.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/model/rnn_vae.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/schemas/states.md delete mode 100644 docs/vame-docs-app/docs/reference/vame/util/align_egocentrical.md 
delete mode 100644 docs/vame-docs-app/docs/reference/vame/util/auxiliary.md
delete mode 100644 docs/vame-docs-app/docs/reference/vame/util/csv_to_npy.md
delete mode 100644 docs/vame-docs-app/docs/reference/vame/util/data_manipulation.md
delete mode 100644 docs/vame-docs-app/docs/reference/vame/util/gif_pose_helper.md
delete mode 100644 docs/vame-docs-app/docs/reference/vame/util/model_util.md
create mode 100644 docs/yarn.lock

diff --git a/.github/workflows/publish_docs.yaml b/.github/workflows/publish_docs.yaml
index 9017d2ba..062fa992 100644
--- a/.github/workflows/publish_docs.yaml
+++ b/.github/workflows/publish_docs.yaml
@@ -5,11 +5,14 @@ on:
   push:
     branches:
       - main
       - updated-docs
+  # paths:
+  #   - '.github/workflows/publish_docs.yaml'
+  #   - 'docs/**'
 
 jobs:
   deploy:
     name: Deploy VAME Docs to GitHub Pages
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-22.04
     permissions:
       contents: write
       pages: write
@@ -62,4 +65,3 @@ jobs:
       github_token: ${{ secrets.GITHUB_TOKEN }}
       publish_dir: ./docs/vame-docs-app/build
       publish_branch: gh-pages
-      destination_dir: ${{ github.ref_name == 'main' && '' || github.ref_name }}
diff --git a/docs/vame-docs-app/docs/reference/_category_.json b/docs/vame-docs-app/docs/reference/_category_.json
deleted file mode 100644
index d15c769f..00000000
--- a/docs/vame-docs-app/docs/reference/_category_.json
+++ /dev/null
@@ -1,8 +0,0 @@
-{
-  "label": "API reference",
-  "position": 3,
-  "link": {
-    "type": "generated-index",
-    "description": "VAME package API reference"
-  }
-}
\ No newline at end of file
diff --git a/docs/vame-docs-app/docs/reference/analysis/community_analysis.md b/docs/vame-docs-app/docs/reference/analysis/community_analysis.md
index 0552f9ff..bd160d42 100644
--- a/docs/vame-docs-app/docs/reference/analysis/community_analysis.md
+++ b/docs/vame-docs-app/docs/reference/analysis/community_analysis.md
@@ -17,7 +17,7 @@ def get_adjacency_matrix(
 
 Calculate the adjacency matrix, transition matrix, and temporal matrix.
 
-**Arguments**
+**Parameters**
 
 * **labels** (`np.ndarray`): Array of cluster labels.
 * **n_clusters** (`int`): Number of clusters.
@@ -35,7 +35,7 @@ def get_transition_matrix(adjacency_matrix: np.ndarray,
 
 Compute the transition matrix from the adjacency matrix.
 
-**Arguments**
+**Parameters**
 
 * **adjacency_matrix** (`np.ndarray`): Adjacency matrix.
 * **threshold** (`float, optional`): Threshold for considering transitions. Defaults to 0.0.
@@ -64,7 +64,7 @@ Example 2:
 - n_clusters = 6
 - the function will return [10, 20, 0, 30, 40, 0].
 
-**Arguments**
+**Parameters**
 
 * **unique_motif_labels** (`np.ndarray`): Array of unique motif labels.
 * **motif_counts** (`np.ndarray`): Array of motif counts (in number of frames).
@@ -83,7 +83,7 @@ def augment_motif_timeseries(labels: np.ndarray,
 
 Augment motif time series by filling zero motifs.
 
-**Arguments**
+**Parameters**
 
 * **labels** (`np.ndarray`): Original array of labels.
 * **n_clusters** (`int`): Number of clusters.
@@ -105,7 +105,7 @@ def get_motif_labels(config: dict, sessions: List[str], model_name: str,
 
 Get motif labels for given files.
 
-**Arguments**
+**Parameters**
 
 * **config** (`dict`): Configuration parameters.
 * **sessions** (`List[str]`): List of session names.
@@ -126,7 +126,7 @@ def compute_transition_matrices(files: List[str], labels: List[np.ndarray],
 
 Compute transition matrices for given files and labels.
 
-**Arguments**
+**Parameters**
 
 * **files** (`List[str]`): List of file paths.
 * **labels** (`List[np.ndarray]`): List of label arrays.
@@ -147,7 +147,7 @@ def create_cohort_community_bag(motif_labels: List[np.ndarray], Create cohort community bag for given motif labels, transition matrix, cut tree, and number of clusters. (markov chain to tree -> community detection) -**Arguments** +**Parameters** * **motif_labels** (`List[np.ndarray]`): List of motif label arrays. * **trans_mat_full** (`np.ndarray`): Full transition matrix. @@ -170,7 +170,7 @@ def get_cohort_community_labels( Transform kmeans/hmm parameterized latent vector motifs into communities. Get cohort community labels for given labels, and community bags. -**Arguments** +**Parameters** * **labels** (`List[np.ndarray]`): List of label arrays. * **cohort_community_bag** (`np.ndarray`): List of community bags. Dimensions: (n_communities, n_clusters_in_community) @@ -231,7 +231,7 @@ Saves results files at: - community_label_file_name.npy - hierarchy_file_name.pkl -**Arguments** +**Parameters** * **config** (`dict`): Configuration parameters. * **segmentation_algorithm** (`SegmentationAlgorithms`): Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'. diff --git a/docs/vame-docs-app/docs/reference/analysis/generative_functions.md b/docs/vame-docs-app/docs/reference/analysis/generative_functions.md index 6b500d92..74e11587 100644 --- a/docs/vame-docs-app/docs/reference/analysis/generative_functions.md +++ b/docs/vame-docs-app/docs/reference/analysis/generative_functions.md @@ -18,7 +18,7 @@ def random_generative_samples_motif(cfg: dict, model: torch.nn.Module, Generate random samples for motifs. -**Arguments** +**Parameters** * **cfg** (`dict`): Configuration dictionary. * **model** (`torch.nn.Module`): PyTorch model. @@ -39,7 +39,7 @@ def random_generative_samples(cfg: dict, model: torch.nn.Module, Generate random generative samples. -**Arguments** +**Parameters** * **cfg** (`dict`): Configuration dictionary. * **model** (`torch.nn.Module`): PyTorch model. @@ -58,7 +58,7 @@ def random_reconstruction_samples(cfg: dict, model: torch.nn.Module, Generate random reconstruction samples. -**Arguments** +**Parameters** * **cfg** (`dict`): Configuration dictionary. * **model** (`torch.nn.Module`): PyTorch model to use. @@ -77,7 +77,7 @@ def visualize_cluster_center(cfg: dict, model: torch.nn.Module, Visualize cluster centers. -**Arguments** +**Parameters** * **cfg** (`dict`): Configuration dictionary. * **model** (`torch.nn.Module`): PyTorch model. @@ -99,7 +99,7 @@ def generative_model(config: dict, Generative model. -**Arguments** +**Parameters** * **config** (`dict`): Configuration dictionary. * **mode** (`str, optional`): Mode for generating samples. Defaults to "sampling". diff --git a/docs/vame-docs-app/docs/reference/analysis/gif_creator.md b/docs/vame-docs-app/docs/reference/analysis/gif_creator.md index 85099c91..d95358c6 100644 --- a/docs/vame-docs-app/docs/reference/analysis/gif_creator.md +++ b/docs/vame-docs-app/docs/reference/analysis/gif_creator.md @@ -17,7 +17,7 @@ def create_video(path_to_file: str, session: str, embed: np.ndarray, Create video frames for the given embedding. -**Arguments** +**Parameters** * **path_to_file** (`str`): Path to the file. * **session** (`str`): Session name. @@ -51,7 +51,7 @@ def gif( Create a GIF from the given configuration. -**Arguments** +**Parameters** * **config** (`str`): Path to the configuration file. * **pose_ref_index** (`list`): List of reference coordinate indices for alignment. 
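
The community-analysis docs updated above are built around adjacency and transition matrices over motif label sequences; a self-contained NumPy sketch of that computation (illustrative only, not VAME's exact implementation):

```python
import numpy as np

labels = np.array([0, 0, 1, 2, 1, 0])  # toy motif label sequence
n_clusters = 3

# Count transitions between consecutive motif labels.
adjacency = np.zeros((n_clusters, n_clusters))
for a, b in zip(labels[:-1], labels[1:]):
    adjacency[a, b] += 1

# Row-normalize into a transition matrix; rows without outgoing
# transitions are left as zeros.
row_sums = adjacency.sum(axis=1, keepdims=True)
transition = np.divide(
    adjacency,
    row_sums,
    out=np.zeros_like(adjacency),
    where=row_sums > 0,
)
print(transition)
```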
diff --git a/docs/vame-docs-app/docs/reference/analysis/pose_segmentation.md b/docs/vame-docs-app/docs/reference/analysis/pose_segmentation.md index 038152bb..13d043e9 100644 --- a/docs/vame-docs-app/docs/reference/analysis/pose_segmentation.md +++ b/docs/vame-docs-app/docs/reference/analysis/pose_segmentation.md @@ -21,7 +21,7 @@ def embedd_latent_vectors( Embed latent vectors for the given files using the VAME model. -**Arguments** +**Parameters** * **cfg** (`dict`): Configuration dictionary. * **sessions** (`List[str]`): List of session names. @@ -41,7 +41,7 @@ def get_motif_usage(session_labels: np.ndarray, n_clusters: int) -> np.ndarray Count motif usage from session label array. -**Arguments** +**Parameters** * **session_labels** (`np.ndarray`): Array of session labels. * **n_clusters** (`int`): Number of clusters. @@ -61,7 +61,7 @@ def same_segmentation( Apply the same segmentation to all animals. -**Arguments** +**Parameters** * **cfg** (`dict`): Configuration dictionary. * **sessions** (`List[str]`): List of session names. @@ -83,7 +83,7 @@ def individual_segmentation(cfg: dict, sessions: List[str], Apply individual segmentation to each session. -**Arguments** +**Parameters** * **cfg** (`dict`): Configuration dictionary. * **sessions** (`List[str]`): List of session names. @@ -128,7 +128,7 @@ Dimmentions: (n_motifs,) n_cluster_label_session.npy contains the label of the cluster assigned to each frame. Dimmentions: (n_frames,) -**Arguments** +**Parameters** * **config** (`dict`): Configuration dictionary. * **save_logs** (`bool, optional`): Whether to save logs, by default False. diff --git a/docs/vame-docs-app/docs/reference/analysis/tree_hierarchy.md b/docs/vame-docs-app/docs/reference/analysis/tree_hierarchy.md index 6a3e8790..21a7b234 100644 --- a/docs/vame-docs-app/docs/reference/analysis/tree_hierarchy.md +++ b/docs/vame-docs-app/docs/reference/analysis/tree_hierarchy.md @@ -17,7 +17,7 @@ def hierarchy_pos(G: nx.Graph, Positions nodes in a tree-like layout. Ref: From Joel's answer at https://stackoverflow.com/a/29597209/2966723. -**Arguments** +**Parameters** * **G** (`nx.Graph`): The input graph. Must be a tree. * **root** (`str, optional`): The root node of the tree. If None, the function selects a root node based on graph type. @@ -41,7 +41,7 @@ def merge_func(transition_matrix: np.ndarray, n_clusters: int, Merge nodes in a graph based on a selection criterion. -**Arguments** +**Parameters** * **transition_matrix** (`np.ndarray`): The transition matrix of the graph. * **n_clusters** (`int`): The number of clusters. @@ -65,7 +65,7 @@ def graph_to_tree(motif_usage: np.ndarray, Convert a graph to a tree. -**Arguments** +**Parameters** * **motif_usage** (`np.ndarray`): The motif usage matrix. * **transition_matrix** (`np.ndarray`): The transition matrix of the graph. @@ -89,7 +89,7 @@ def draw_tree( Draw a tree. -**Arguments** +**Parameters** * **T** (`nx.Graph`): The tree to be drawn. * **fig_width** (`int, optional`): The width of the figure. Defaults to 10. @@ -114,7 +114,7 @@ def _traverse_tree_cutline( DEPRECATED in favor of bag_nodes_by_cutline. Helper function for tree traversal with a cutline. -**Arguments** +**Parameters** * **T** (`nx.Graph`): The tree to be traversed. * **node** (`List[str]`): Current node being traversed. @@ -139,7 +139,7 @@ def traverse_tree_cutline(T: nx.Graph, DEPRECATED in favor of bag_nodes_by_cutline. Traverse a tree with a cutline and return the community bags. -**Arguments** +**Parameters** * **T** (`nx.Graph`): The tree to be traversed. 
* **root_node** (`str, optional`): The root node of the tree. If None, traversal starts from the root. @@ -157,7 +157,7 @@ def bag_nodes_by_cutline(tree: nx.Graph, cutline: int = 2, root: str = "Root") Bag nodes of a tree by a cutline. -**Arguments** +**Parameters** * **tree** (`nx.Graph`): The tree to be bagged. * **cutline** (`int, optional`): The cutline level. Defaults to 2. diff --git a/docs/vame-docs-app/docs/reference/analysis/umap.md b/docs/vame-docs-app/docs/reference/analysis/umap.md index 6b306602..79841311 100644 --- a/docs/vame-docs-app/docs/reference/analysis/umap.md +++ b/docs/vame-docs-app/docs/reference/analysis/umap.md @@ -17,7 +17,7 @@ def umap_embedding( Perform UMAP embedding for given file and parameters. -**Arguments** +**Parameters** * **cfg** (`dict`): Configuration parameters. * **session** (`str`): Session name. @@ -37,7 +37,7 @@ def umap_vis(embed: np.ndarray, num_points: int) -> plt.Figure Visualize UMAP embedding without labels. -**Arguments** +**Parameters** * **embed** (`np.ndarray`): UMAP embedding. * **num_points** (`int`): Number of data points to visualize. @@ -55,7 +55,7 @@ def umap_label_vis(embed: np.ndarray, label: np.ndarray, Visualize UMAP embedding with motif labels. -**Arguments** +**Parameters** * **embed** (`np.ndarray`): UMAP embedding. * **label** (`np.ndarray`): Motif labels. @@ -74,7 +74,7 @@ def umap_vis_comm(embed: np.ndarray, community_label: np.ndarray, Visualize UMAP embedding with community labels. -**Arguments** +**Parameters** * **embed** (`np.ndarray`): UMAP embedding. * **community_label** (`np.ndarray`): Community labels. @@ -110,7 +110,7 @@ If label is None (UMAP visualization without labels): - umap_vis_motif_file_name.png (UMAP visualization with motif labels) - umap_vis_community_file_name.png (UMAP visualization with community labels) -**Arguments** +**Parameters** * **config** (`dict`): Configuration parameters. * **segmentation_algorithm** (`SegmentationAlgorithms`): Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'. diff --git a/docs/vame-docs-app/docs/reference/analysis/videowriter.md b/docs/vame-docs-app/docs/reference/analysis/videowriter.md index fa570458..0dfe6371 100644 --- a/docs/vame-docs-app/docs/reference/analysis/videowriter.md +++ b/docs/vame-docs-app/docs/reference/analysis/videowriter.md @@ -25,7 +25,7 @@ def create_cluster_videos( Generate cluster videos and save them to filesystem on project folder. -**Arguments** +**Parameters** * **config** (`dict`): Configuration parameters. * **path_to_file** (`str`): Path to the file. @@ -66,7 +66,7 @@ Files are saved at: - session-motif_1.mp4 - ... -**Arguments** +**Parameters** * **config** (`dict`): Configuration parameters. * **segmentation_algorithm** (`SegmentationAlgorithms`): Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'. @@ -108,7 +108,7 @@ TODO: Add cohort analysis - file_name-community_1.mp4 - ... -**Arguments** +**Parameters** * **config** (`dict`): Configuration parameters. * **segmentation_algorithm** (`SegmentationAlgorithms`): Which segmentation algorithm to use. Options are 'hmm' or 'kmeans'. 
diff --git a/docs/vame-docs-app/docs/reference/initialize_project/new.md b/docs/vame-docs-app/docs/reference/initialize_project/new.md index 5c62c0d6..3e136d00 100644 --- a/docs/vame-docs-app/docs/reference/initialize_project/new.md +++ b/docs/vame-docs-app/docs/reference/initialize_project/new.md @@ -47,7 +47,7 @@ A VAME project is a directory with the following structure: - states.json - config.yaml -**Arguments** +**Parameters** * **project_name** (`str`): Project name. * **videos** (`List[str]`): List of videos paths to be used in the project. E.g. ['./sample_data/Session001.mp4'] diff --git a/docs/vame-docs-app/docs/reference/io/load_poses.md b/docs/vame-docs-app/docs/reference/io/load_poses.md index 7fc29412..01112961 100644 --- a/docs/vame-docs-app/docs/reference/io/load_poses.md +++ b/docs/vame-docs-app/docs/reference/io/load_poses.md @@ -14,7 +14,7 @@ def load_pose_estimation( Load pose estimation data. -**Arguments** +**Parameters** * **pose_estimation_file** (`Path or str`): Path to the pose estimation file. * **video_file** (`Path or str`): Path to the video file. @@ -33,7 +33,7 @@ def load_vame_dataset(ds_path: Path | str) -> xr.Dataset Load VAME dataset. -**Arguments** +**Parameters** * **ds_path** (`Path or str`): Path to the netCDF dataset. @@ -59,7 +59,7 @@ def read_pose_estimation_file( Read pose estimation file. -**Arguments** +**Parameters** * **file_path** (`str`): Path to the pose estimation file. * **file_type** (`PoseEstimationFiletype`): Type of the pose estimation file. Supported types are 'csv' and 'nwb'. diff --git a/docs/vame-docs-app/docs/reference/model/create_training.md b/docs/vame-docs-app/docs/reference/model/create_training.md index fd9e8e24..c4473f92 100644 --- a/docs/vame-docs-app/docs/reference/model/create_training.md +++ b/docs/vame-docs-app/docs/reference/model/create_training.md @@ -19,7 +19,7 @@ def traindata_aligned(config: dict, Create training dataset for aligned data. Save numpy arrays with the test/train info to the project folder. -**Arguments** +**Parameters** * **config** (`dict`): Configuration parameters dictionary. * **sessions** (`List[str], optional`): List of session names. If None, all sessions will be used. Defaults to None. @@ -56,7 +56,7 @@ The produced test_seq.npy contains the combined data in the shape of (num_dlc_fe The produced train_seq.npy contains the combined data in the shape of (num_dlc_features - 2, num_video_frames * (1 - test_fraction)). -**Arguments** +**Parameters** * **config** (`dict`): Configuration parameters dictionary. * **save_logs** (`bool, optional`): If True, the function will save logs to the project folder. Defaults to False. diff --git a/docs/vame-docs-app/docs/reference/model/dataloader.md b/docs/vame-docs-app/docs/reference/model/dataloader.md index bce649f9..330c6d2d 100644 --- a/docs/vame-docs-app/docs/reference/model/dataloader.md +++ b/docs/vame-docs-app/docs/reference/model/dataloader.md @@ -24,7 +24,7 @@ Creates files at: - seq_mean.npy - seq_std.npy -**Arguments** +**Parameters** * **path_to_file** (`str`): Path to the dataset files. * **data** (`str`): Name of the data file. @@ -55,7 +55,7 @@ def __getitem__(index: int) -> torch.Tensor Get a normalized sequence at the specified index. -**Arguments** +**Parameters** * **index** (`int`): Index of the item. 
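
A usage sketch for the pose-loading API documented above (paths are hypothetical, and the `fps` keyword is an assumption; check the actual signature in `vame.io.load_poses`):

```python
from vame.io.load_poses import load_pose_estimation

# Paths are hypothetical; source_software must be one of
# "DeepLabCut", "SLEAP", or "LightningPose".
ds = load_pose_estimation(
    pose_estimation_file="./sample_data/Session001.csv",
    video_file="./sample_data/Session001.mp4",
    fps=30,  # assumed keyword for the video sampling rate
    source_software="DeepLabCut",
)
print(ds)  # xarray.Dataset with time, keypoints and space dimensions
```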
diff --git a/docs/vame-docs-app/docs/reference/model/evaluate.md b/docs/vame-docs-app/docs/reference/model/evaluate.md index 28f7e972..d33a5b50 100644 --- a/docs/vame-docs-app/docs/reference/model/evaluate.md +++ b/docs/vame-docs-app/docs/reference/model/evaluate.md @@ -29,7 +29,7 @@ Saves the plot to: - evaluate/ - Reconstruction_model_name.png -**Arguments** +**Parameters** * **filepath** (`str`): Path to save the plot. * **test_loader** (`Data.DataLoader`): DataLoader for the test dataset. @@ -57,7 +57,7 @@ Saves the plot to: - evaluate/ - MSE-and-KL-Loss_model_name.png -**Arguments** +**Parameters** * **cfg** (`dict`): Configuration dictionary. * **filepath** (`str`): Path to save the plot. @@ -80,7 +80,7 @@ def eval_temporal(cfg: dict, Evaluate the temporal aspects of the trained model. -**Arguments** +**Parameters** * **cfg** (`dict`): Configuration dictionary. * **use_gpu** (`bool`): Flag indicating whether to use GPU for evaluation. @@ -109,7 +109,7 @@ Saves the evaluation results to: - model/ - evaluate/ -**Arguments** +**Parameters** * **config** (`dict`): Configuration dictionary. * **use_snapshots** (`bool, optional`): Whether to plot for all snapshots or only the best model. Defaults to False. diff --git a/docs/vame-docs-app/docs/reference/model/rnn_model.md b/docs/vame-docs-app/docs/reference/model/rnn_model.md index 9c6b9773..9b7a70b8 100644 --- a/docs/vame-docs-app/docs/reference/model/rnn_model.md +++ b/docs/vame-docs-app/docs/reference/model/rnn_model.md @@ -20,7 +20,7 @@ def __init__(NUM_FEATURES: int, hidden_size_layer_1: int, Initialize the Encoder module. -**Arguments** +**Parameters** * **NUM_FEATURES** (`int`): Number of input features. * **hidden_size_layer_1** (`int`): Size of the first hidden layer. @@ -35,7 +35,7 @@ def forward(inputs: torch.Tensor) -> torch.Tensor Forward pass of the Encoder module. -**Arguments** +**Parameters** * **inputs** (`torch.Tensor`): Input tensor of shape (batch_size, sequence_length, num_features). @@ -59,7 +59,7 @@ def __init__(ZDIMS: int, hidden_size_layer_1: int, softplus: bool) Initialize the Lambda module. -**Arguments** +**Parameters** * **ZDIMS** (`int`): Size of the latent space. * **hidden_size_layer_1** (`int`): Size of the first hidden layer. @@ -76,7 +76,7 @@ def forward( Forward pass of the Lambda module. -**Arguments** +**Parameters** * **hidden** (`torch.Tensor`): Hidden representation tensor of shape (batch_size, hidden_size_layer_1 * 4). @@ -101,7 +101,7 @@ def __init__(TEMPORAL_WINDOW: int, ZDIMS: int, NUM_FEATURES: int, Initialize the Decoder module. -**Arguments** +**Parameters** * **TEMPORAL_WINDOW** (`int`): Size of the temporal window. * **ZDIMS** (`int`): Size of the latent space. @@ -117,7 +117,7 @@ def forward(inputs: torch.Tensor, z: torch.Tensor) -> torch.Tensor Forward pass of the Decoder module. -**Arguments** +**Parameters** * **inputs** (`torch.Tensor`): Input tensor of shape (batch_size, seq_len, ZDIMS). * **z** (`torch.Tensor`): Latent space tensor of shape (batch_size, ZDIMS). @@ -143,7 +143,7 @@ def __init__(TEMPORAL_WINDOW: int, ZDIMS: int, NUM_FEATURES: int, Initialize the Decoder_Future module. -**Arguments** +**Parameters** * **TEMPORAL_WINDOW** (`int`): Size of the temporal window. * **ZDIMS** (`int`): Size of the latent space. @@ -160,7 +160,7 @@ def forward(inputs: torch.Tensor, z: torch.Tensor) -> torch.Tensor Forward pass of the Decoder_Future module. -**Arguments** +**Parameters** * **inputs** (`torch.Tensor`): Input tensor of shape (batch_size, seq_len, ZDIMS). 
* **z** (`torch.Tensor`): Latent space tensor of shape (batch_size, ZDIMS). @@ -189,7 +189,7 @@ def __init__(TEMPORAL_WINDOW: int, ZDIMS: int, NUM_FEATURES: int, Initialize the VAE module. -**Arguments** +**Parameters** * **TEMPORAL_WINDOW** (`int`): Size of the temporal window. * **ZDIMS** (`int`): Size of the latent space. @@ -210,7 +210,7 @@ def forward(seq: torch.Tensor) -> tuple Forward pass of the VAE. -**Arguments** +**Parameters** * **seq** (`torch.Tensor`): Input sequence tensor of shape (batch_size, seq_len, NUM_FEATURES). diff --git a/docs/vame-docs-app/docs/reference/model/rnn_vae.md b/docs/vame-docs-app/docs/reference/model/rnn_vae.md index 4b308601..3d624db4 100644 --- a/docs/vame-docs-app/docs/reference/model/rnn_vae.md +++ b/docs/vame-docs-app/docs/reference/model/rnn_vae.md @@ -20,7 +20,7 @@ def reconstruction_loss(x: torch.Tensor, x_tilde: torch.Tensor, Compute the reconstruction loss between input and reconstructed data. -**Arguments** +**Parameters** * **x** (`torch.Tensor`): Input data tensor. * **x_tilde** (`torch.Tensor`): Reconstructed data tensor. @@ -39,7 +39,7 @@ def future_reconstruction_loss(x: torch.Tensor, x_tilde: torch.Tensor, Compute the future reconstruction loss between input and predicted future data. -**Arguments** +**Parameters** * **x** (`torch.Tensor`): Input future data tensor. * **x_tilde** (`torch.Tensor`): Reconstructed future data tensor. @@ -58,7 +58,7 @@ def cluster_loss(H: torch.Tensor, kloss: int, lmbda: float, Compute the cluster loss. -**Arguments** +**Parameters** * **H** (`torch.Tensor`): Latent representation tensor. * **kloss** (`int`): Number of clusters. @@ -81,7 +81,7 @@ See Appendix B from VAE paper: Kingma and Welling. Auto-Encoding Variational Bay Formula: 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) -**Arguments** +**Parameters** * **mu** (`torch.Tensor`): Mean of the latent distribution. * **logvar** (`torch.Tensor`): Log variance of the latent distribution. @@ -100,7 +100,7 @@ def kl_annealing(epoch: int, kl_start: int, annealtime: int, Anneal the Kullback-Leibler loss to let the model learn first the reconstruction of the data before the KL loss term gets introduced. -**Arguments** +**Parameters** * **epoch** (`int`): Current epoch number. * **kl_start** (`int`): Epoch number to start annealing the loss. @@ -122,7 +122,7 @@ def gaussian(ins: torch.Tensor, Add Gaussian noise to the input data. -**Arguments** +**Parameters** * **ins** (`torch.Tensor`): Input data tensor. * **is_training** (`bool`): Whether it is training mode. @@ -146,7 +146,7 @@ def train(train_loader: Data.DataLoader, epoch: int, model: nn.Module, Train the model. -**Arguments** +**Parameters** * **train_loader** (`DataLoader`): Training data loader. * **epoch** (`int`): Current epoch number. @@ -183,7 +183,7 @@ def test(test_loader: Data.DataLoader, model: nn.Module, BETA: float, Evaluate the model on the test dataset. -**Arguments** +**Parameters** * **test_loader** (`DataLoader`): DataLoader for the test dataset. * **model** (`nn.Module`): The trained model. @@ -229,7 +229,7 @@ Creates files at: - weight_values_VAME.npy - pretrained_model/ -**Arguments** +**Parameters** * **config** (`dict`): Configuration dictionary. * **save_logs** (`bool, optional`): Whether to save the logs, by default False. 
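The two loss formulas documented above are compact enough to state in code. The KL term follows the stated closed form 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), negated so it can be minimized; the linear ramp is an assumption about the 'linear' annealing option, not a copy of the library code:

```python
import torch


def kl_loss(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), negated for minimization.
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())


def linear_kl_weight(epoch: int, kl_start: int, annealtime: int) -> float:
    # Weight stays at 0 until kl_start, then ramps linearly to 1 over annealtime.
    if epoch <= kl_start:
        return 0.0
    return min(1.0, (epoch - kl_start) / annealtime)
```

Multiplying the KL term by this weight lets the reconstruction terms dominate early epochs, which is exactly the motivation the kl_annealing docstring gives.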
diff --git a/docs/vame-docs-app/docs/reference/preprocessing/alignment.md b/docs/vame-docs-app/docs/reference/preprocessing/alignment.md index 5e77c351..2132364f 100644 --- a/docs/vame-docs-app/docs/reference/preprocessing/alignment.md +++ b/docs/vame-docs-app/docs/reference/preprocessing/alignment.md @@ -21,7 +21,7 @@ def egocentrically_align_and_center( Aligns the time series by first centralizing all positions around the first keypoint and then applying rotation to align with the line connecting the two keypoints. -**Arguments** +**Parameters** * **config** (`dict`): Configuration dictionary * **centered_reference_keypoint** (`str`): Name of the keypoint to use as centered reference. diff --git a/docs/vame-docs-app/docs/reference/preprocessing/cleaning.md b/docs/vame-docs-app/docs/reference/preprocessing/cleaning.md index 1959e5f5..c971c964 100644 --- a/docs/vame-docs-app/docs/reference/preprocessing/cleaning.md +++ b/docs/vame-docs-app/docs/reference/preprocessing/cleaning.md @@ -19,7 +19,7 @@ Clean the low confidence data points from the dataset. Processes position data b - setting low-confidence points to NaN - interpolating NaN points -**Arguments** +**Parameters** * **config** (`dict`): Configuration dictionary. * **read_from_variable** (`str, optional`): Variable to read from the dataset. @@ -41,7 +41,7 @@ Clean the outliers from the dataset. Processes position data by: - setting outlier points to NaN - interpolating NaN points -**Arguments** +**Parameters** * **config** (`dict`): Configuration dictionary. * **read_from_variable** (`str, optional`): Variable to read from the dataset. diff --git a/docs/vame-docs-app/docs/reference/preprocessing/filter.md b/docs/vame-docs-app/docs/reference/preprocessing/filter.md index cba225cc..d9762eb5 100644 --- a/docs/vame-docs-app/docs/reference/preprocessing/filter.md +++ b/docs/vame-docs-app/docs/reference/preprocessing/filter.md @@ -17,7 +17,7 @@ def savgol_filtering(config: dict, Apply Savitzky-Golay filter to the data. -**Arguments** +**Parameters** * **config** (`dict`): Configuration dictionary. * **read_from_variable** (`str, optional`): Variable to read from the dataset. diff --git a/docs/vame-docs-app/docs/reference/preprocessing/preprocessing.md b/docs/vame-docs-app/docs/reference/preprocessing/preprocessing.md index 9065e7b7..6bb630e3 100644 --- a/docs/vame-docs-app/docs/reference/preprocessing/preprocessing.md +++ b/docs/vame-docs-app/docs/reference/preprocessing/preprocessing.md @@ -22,7 +22,7 @@ Preprocess the data by: - Outlier cleaning - Savitzky-Golay filtering -**Arguments** +**Parameters** * **config** (`dict`): Configuration dictionary. * **centered_reference_keypoint** (`str, optional`): Keypoint to use as centered reference. diff --git a/docs/vame-docs-app/docs/reference/preprocessing/to_model.md b/docs/vame-docs-app/docs/reference/preprocessing/to_model.md index 44dbd198..801194bf 100644 --- a/docs/vame-docs-app/docs/reference/preprocessing/to_model.md +++ b/docs/vame-docs-app/docs/reference/preprocessing/to_model.md @@ -15,7 +15,7 @@ Formats the xarray dataset for use VAME's RNN model: - The x coordinate of the orientation_reference_keypoint is excluded. - The remaining data is flattened and transposed. -**Arguments** +**Parameters** * **ds** (`xr.Dataset`): The xarray dataset to format. * **read_from_variable** (`str, default="position_processed"`): The variable to read from the dataset. 
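The geometry behind egocentrically_align_and_center, as described above, reduces to a translation plus a 2D rotation per frame. A sketch on a single (keypoints, 2) frame, with array layout and keypoint indices as illustrative assumptions:

```python
import numpy as np


def align_frame(positions: np.ndarray, center_idx: int, orient_idx: int) -> np.ndarray:
    """Center one (keypoints, 2) frame on center_idx, then rotate so the
    vector to orient_idx lies along the x-axis."""
    centered = positions - positions[center_idx]
    dx, dy = centered[orient_idx]
    theta = np.arctan2(dy, dx)
    c, s = np.cos(-theta), np.sin(-theta)
    rot = np.array([[c, -s], [s, c]])
    return centered @ rot.T
```

After this alignment the x coordinate of the orientation keypoint carries no information, which is why the to_model.md entry above notes it is excluded before flattening for the model.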
diff --git a/docs/vame-docs-app/docs/reference/util/auxiliary.md b/docs/vame-docs-app/docs/reference/util/auxiliary.md index 38e56e49..7fcac602 100644 --- a/docs/vame-docs-app/docs/reference/util/auxiliary.md +++ b/docs/vame-docs-app/docs/reference/util/auxiliary.md @@ -23,7 +23,7 @@ def read_config(config_file: str) -> dict Reads structured config file defining a project. -**Arguments** +**Parameters** * **config_file** (`str`): Path to the config file. @@ -39,7 +39,7 @@ def write_config(configname: str, cfg: dict) -> None Write structured config file. -**Arguments** +**Parameters** * **configname** (`str`): Path to the config file. * **cfg** (`dict`): Dictionary containing the config data. @@ -52,7 +52,7 @@ def read_states(config: dict) -> dict Reads the states.json file. -**Arguments** +**Parameters** * **config** (`dict`): Dictionary containing the config data. diff --git a/docs/vame-docs-app/docs/reference/util/csv_to_npy.md b/docs/vame-docs-app/docs/reference/util/csv_to_npy.md index 734e6a45..00dbb70a 100644 --- a/docs/vame-docs-app/docs/reference/util/csv_to_npy.md +++ b/docs/vame-docs-app/docs/reference/util/csv_to_npy.md @@ -18,7 +18,7 @@ Converts a pose-estimation.csv file to a numpy array. Note that this code is only useful for data which is a priori egocentric, i.e. head-fixed or otherwise restrained animals. -**Arguments** +**Parameters** * **config** (`dict`): Configuration dictionary. * **save_logs** (`bool, optional`): If True, the logs will be saved to a file, by default False. diff --git a/docs/vame-docs-app/docs/reference/util/data_manipulation.md b/docs/vame-docs-app/docs/reference/util/data_manipulation.md index f651108a..5e143461 100644 --- a/docs/vame-docs-app/docs/reference/util/data_manipulation.md +++ b/docs/vame-docs-app/docs/reference/util/data_manipulation.md @@ -15,7 +15,7 @@ def consecutive(data: np.ndarray, stepsize: int = 1) -> List[np.ndarray] Find consecutive sequences in the data array. -**Arguments** +**Parameters** * **data** (`np.ndarray`): Input array. * **stepsize** (`int, optional`): Step size. Defaults to 1. @@ -32,7 +32,7 @@ def nan_helper(y: np.ndarray) -> Tuple Identifies indices of NaN values in an array and provides a function to convert them to non-NaN indices. -**Arguments** +**Parameters** * **y** (`np.ndarray`): Input array containing NaN values. @@ -50,7 +50,7 @@ def interpol_first_rows_nans(arr: np.ndarray) -> np.ndarray Interpolates NaN values in the given array. -**Arguments** +**Parameters** * **arr** (`np.ndarray`): Input array with NaN values. @@ -66,7 +66,7 @@ def interpolate_nans_with_pandas(data: np.ndarray) -> np.ndarray Interpolate NaN values along the time axis of a 3D NumPy array using Pandas. -**Arguments** +**Parameters** * **data** (`numpy.ndarray`): Input 3D array of shape (time, keypoints, space). @@ -84,7 +84,7 @@ def crop_and_flip_legacy( Crop and flip the image based on the given rectangle and points. -**Arguments** +**Parameters** * **rect** (`Tuple`): Rectangle coordinates (center, size, theta). * **src: np.ndarray**: Source image. @@ -107,7 +107,7 @@ def background(project_path: str, Compute background image from fixed camera. -**Arguments** +**Parameters** * **project_path** (`str`): Path to the project directory. * **session** (`str`): Name of the session. 
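The interpolate_nans_with_pandas entry above describes a simple recipe: flatten the (time, keypoints, space) array to 2D, interpolate each column along the time axis with pandas, and restore the shape. A sketch under those assumptions:

```python
import numpy as np
import pandas as pd


def interp_nans(data: np.ndarray) -> np.ndarray:
    # Flatten (time, keypoints, space) -> (time, keypoints*space), interpolate
    # linearly along the time axis, then restore the original shape.
    t, k, s = data.shape
    frame = pd.DataFrame(data.reshape(t, k * s))
    frame = frame.interpolate(method="linear", axis=0, limit_direction="both")
    return frame.to_numpy().reshape(t, k, s)
```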
diff --git a/docs/vame-docs-app/docs/reference/util/gif_pose_helper.md b/docs/vame-docs-app/docs/reference/util/gif_pose_helper.md index 6003bb95..ed3a1abf 100644 --- a/docs/vame-docs-app/docs/reference/util/gif_pose_helper.md +++ b/docs/vame-docs-app/docs/reference/util/gif_pose_helper.md @@ -23,7 +23,7 @@ def get_animal_frames( Extracts frames of an animal from a video file and returns them as a list. -**Arguments** +**Parameters** * **cfg** (`dict`): Configuration dictionary containing project information. * **session** (`str`): Name of the session. diff --git a/docs/vame-docs-app/docs/reference/util/sample_data.md b/docs/vame-docs-app/docs/reference/util/sample_data.md index 887aa825..59795b87 100644 --- a/docs/vame-docs-app/docs/reference/util/sample_data.md +++ b/docs/vame-docs-app/docs/reference/util/sample_data.md @@ -15,7 +15,7 @@ def download_sample_data(source_software: str) -> dict Download sample data. -**Arguments** +**Parameters** * **source_software** (`str`): Source software used for pose estimation. diff --git a/docs/vame-docs-app/docs/reference/vame/analysis/community_analysis.md b/docs/vame-docs-app/docs/reference/vame/analysis/community_analysis.md deleted file mode 100644 index 34b4b4fc..00000000 --- a/docs/vame-docs-app/docs/reference/vame/analysis/community_analysis.md +++ /dev/null @@ -1,267 +0,0 @@ ---- -sidebar_label: community_analysis -title: vame.analysis.community_analysis ---- - -Variational Animal Motion Embedding 1.0-alpha Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -Updated 5/11/2022 with PH edits - -#### get\_adjacency\_matrix - -```python -def get_adjacency_matrix( - labels: np.ndarray, - n_cluster: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray] -``` - -Calculate the adjacency matrix, transition matrix, and temporal matrix. - -**Arguments**: - -- `labels` _np.ndarray_ - Array of cluster labels. -- `n_cluster` _int_ - Number of clusters. - - -**Returns**: - - Tuple[np.ndarray, np.ndarray, np.ndarray]: Tuple containing adjacency matrix, transition matrix, and temporal matrix. - -#### get\_transition\_matrix - -```python -def get_transition_matrix(adjacency_matrix: np.ndarray, - threshold: float = 0.0) -> np.ndarray -``` - -Compute the transition matrix from the adjacency matrix. - -**Arguments**: - -- `adjacency_matrix` _np.ndarray_ - Adjacency matrix. -- `threshold` _float, optional_ - Threshold for considering transitions. Defaults to 0.0. - - -**Returns**: - -- `np.ndarray` - Transition matrix. - -#### find\_zero\_labels - -```python -def find_zero_labels(motif_usage: Tuple[np.ndarray, np.ndarray], - n_cluster: int) -> np.ndarray -``` - -Find zero labels in motif usage and fill them. - -**Arguments**: - -- `motif_usage` _Tuple[np.ndarray, np.ndarray]_ - 2D list where the first index is a unique list of motif used and the second index is the motif usage in frames. -- `n_cluster` _int_ - Number of clusters. - - -**Returns**: - -- `np.ndarray` - List of motif usage frames with 0's where motifs weren't used (array with zero labels filled). - -#### augment\_motif\_timeseries - -```python -def augment_motif_timeseries(label: np.ndarray, - n_cluster: int) -> Tuple[np.ndarray, np.ndarray] -``` - -Augment motif time series by filling zero motifs. - -**Arguments**: - -- `label` _np.ndarray_ - Original label array. -- `n_cluster` _int_ - Number of clusters. 
- - -**Returns**: - - Tuple[np.ndarray, np.ndarray]: Augmented label array and indices of zero motifs. - -#### get\_labels - -```python -def get_labels(cfg: dict, files: List[str], model_name: str, n_cluster: int, - parametrization: str) -> List[np.ndarray] -``` - -Get cluster labels for given videos files. - -**Arguments**: - -- `cfg` _dict_ - Configuration parameters. -- `files` _List[str]_ - List of video files paths. -- `model_name` _str_ - Model name. -- `n_cluster` _int_ - Number of clusters. -- `parametrization` _str_ - parametrization. - - -**Returns**: - -- `List[np.ndarray]` - List of cluster labels for each file. - -#### get\_community\_label - -```python -def get_community_label(cfg: dict, files: List[str], model_name: str, - n_cluster: int, parametrization: str) -> np.ndarray -``` - -Get community labels for given files. - -**Arguments**: - -- `cfg` _dict_ - Configuration parameters. -- `files` _List[str]_ - List of files paths. -- `model_name` _str_ - Model name. -- `n_cluster` _int_ - Number of clusters. -- `parametrization` _str_ - parametrization. - - -**Returns**: - -- `np.ndarray` - Array of community labels. - -#### compute\_transition\_matrices - -```python -def compute_transition_matrices(files: List[str], labels: List[np.ndarray], - n_cluster: int) -> List[np.ndarray] -``` - -Compute transition matrices for given files and labels. - -**Arguments**: - -- `files` _List[str]_ - List of file paths. -- `labels` _List[np.ndarray]_ - List of label arrays. -- `n_cluster` _int_ - Number of clusters. - - -**Returns**: - -- `List[np.ndarray]` - List of transition matrices. - -#### create\_community\_bag - -```python -def create_community_bag(files: List[str], labels: List[np.ndarray], - transition_matrices: List[np.ndarray], cut_tree: int, - n_cluster: int) -> Tuple -``` - -Create community bag for given files and labels (Markov chain to tree -> community detection). - -**Arguments**: - -- `files` _List[str]_ - List of file paths. -- `labels` _List[np.ndarray]_ - List of label arrays. -- `transition_matrices` _List[np.ndarray]_ - List of transition matrices. -- `cut_tree` _int_ - Cut line for tree. -- `n_cluster` _int_ - Number of clusters. - - -**Returns**: - -- `Tuple` - Tuple containing list of community bags and list of trees. - -#### create\_cohort\_community\_bag - -```python -def create_cohort_community_bag(labels: List[np.ndarray], - trans_mat_full: np.ndarray, cut_tree: int, - n_cluster: int) -> Tuple -``` - -Create cohort community bag for given labels, transition matrix, cut tree, and number of clusters. -(markov chain to tree -> community detection) - -**Arguments**: - -- `labels` _List[np.ndarray]_ - List of label arrays. -- `trans_mat_full` _np.ndarray_ - Full transition matrix. -- `cut_tree` _int_ - Cut line for tree. -- `n_cluster` _int_ - Number of clusters. - - -**Returns**: - -- `Tuple` - Tuple containing list of community bags and list of trees. - -#### get\_community\_labels - -```python -def get_community_labels( - files: List[str], labels: List[np.ndarray], - communities_all: List[List[List[int]]]) -> List[np.ndarray] -``` - -Transform kmeans parameterized latent vector into communities. Get community labels for given files and community bags. - -**Arguments**: - -- `files` _List[str]_ - List of file paths. -- `labels` _List[np.ndarray]_ - List of label arrays. -- `communities_all` _List[List[List[int]]]_ - List of community bags. - - -**Returns**: - -- `List[np.ndarray]` - List of community labels for each file. 
- -#### get\_cohort\_community\_labels - -```python -def get_cohort_community_labels( - files: List[str], labels: List[np.ndarray], - communities_all: List[List[List[int]]]) -> List[np.ndarray] -``` - -Transform kmeans parameterized latent vector into communities. Get cohort community labels for given labels, and community bags. - -**Arguments**: - -- `files` _List[str], deprecated_ - List of file paths. -- `labels` _List[np.ndarray]_ - List of label arrays. -- `communities_all` _List[List[List[int]]]_ - List of community bags. - - -**Returns**: - -- `List[np.ndarray]` - List of cohort community labels for each file. - -#### community - -```python -@save_state(model=CommunityFunctionSchema) -def community(config: str, - parametrization: Parametrizations, - cohort: bool = True, - cut_tree: int | None = None, - save_logs: bool = False) -> None -``` - -Perform community analysis. - -**Arguments**: - -- `config` _str_ - Path to the configuration file. -- `cohort` _bool, optional_ - Flag indicating cohort analysis. Defaults to True. -- `cut_tree` _int, optional_ - Cut line for tree. Defaults to None. - - -**Returns**: - - None - diff --git a/docs/vame-docs-app/docs/reference/vame/analysis/generative_functions.md b/docs/vame-docs-app/docs/reference/vame/analysis/generative_functions.md deleted file mode 100644 index 778bd611..00000000 --- a/docs/vame-docs-app/docs/reference/vame/analysis/generative_functions.md +++ /dev/null @@ -1,118 +0,0 @@ ---- -sidebar_label: generative_functions -title: vame.analysis.generative_functions ---- - -Variational Animal Motion Embedding 1.0-alpha Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -#### random\_generative\_samples\_motif - -```python -def random_generative_samples_motif(cfg: dict, model: torch.nn.Module, - latent_vector: np.ndarray, - labels: np.ndarray, - n_cluster: int) -> None -``` - -Generate random samples for motifs. - -**Arguments**: - -- `cfg` _dict_ - Configuration dictionary. -- `model` _torch.nn.Module_ - PyTorch model. -- `latent_vector` _np.ndarray_ - Latent vectors. -- `labels` _np.ndarray_ - Labels. -- `n_cluster` _int_ - Number of clusters. - - -**Returns**: - -- `None` - Plot of generated samples. - -#### random\_generative\_samples - -```python -def random_generative_samples(cfg: dict, model: torch.nn.Module, - latent_vector: np.ndarray) -> None -``` - -Generate random generative samples. - -**Arguments**: - -- `cfg` _dict_ - Configuration dictionary. -- `model` _torch.nn.Module_ - PyTorch model. -- `latent_vector` _np.ndarray_ - Latent vectors. - - -**Returns**: - - None - -#### random\_reconstruction\_samples - -```python -def random_reconstruction_samples(cfg: dict, model: torch.nn.Module, - latent_vector: np.ndarray) -> None -``` - -Generate random reconstruction samples. - -**Arguments**: - -- `cfg` _dict_ - Configuration dictionary. -- `model` _torch.nn.Module_ - PyTorch model to use. -- `latent_vector` _np.ndarray_ - Latent vectors. - - -**Returns**: - - None - -#### visualize\_cluster\_center - -```python -def visualize_cluster_center(cfg: dict, model: torch.nn.Module, - cluster_center: np.ndarray) -> None -``` - -Visualize cluster centers. - -**Arguments**: - -- `cfg` _dict_ - Configuration dictionary. -- `model` _torch.nn.Module_ - PyTorch model. -- `cluster_center` _np.ndarray_ - Cluster centers. 
- - -**Returns**: - - None - -#### generative\_model - -```python -@save_state(model=GenerativeModelFunctionSchema) -def generative_model(config: str, - parametrization: Parametrizations, - mode: str = "sampling", - save_logs: bool = False) -> Dict[str, plt.Figure] -``` - -Generative model. - -**Arguments**: - -- `config` _str_ - Path to the configuration file. -- `mode` _str, optional_ - Mode for generating samples. Defaults to "sampling". - - -**Returns**: - - Dict[str, plt.Figure]: Plots of generated samples for each parametrization. - diff --git a/docs/vame-docs-app/docs/reference/vame/analysis/gif_creator.md b/docs/vame-docs-app/docs/reference/vame/analysis/gif_creator.md deleted file mode 100644 index 6073f8b0..00000000 --- a/docs/vame-docs-app/docs/reference/vame/analysis/gif_creator.md +++ /dev/null @@ -1,74 +0,0 @@ ---- -sidebar_label: gif_creator -title: vame.analysis.gif_creator ---- - -Variational Animal Motion Embedding 1.0-alpha Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -#### create\_video - -```python -def create_video(path_to_file: str, file: str, embed: np.ndarray, - clabel: np.ndarray, frames: List[np.ndarray], start: int, - length: int, max_lag: int, num_points: int) -> None -``` - -Create video frames for the given embedding. - -**Arguments**: - -- `path_to_file` _str_ - Path to the file. -- `file` _str_ - File name. -- `embed` _np.ndarray_ - Embedding array. -- `clabel` _np.ndarray_ - Cluster labels. -- `frames` _List[np.ndarray]_ - List of frames. -- `start` _int_ - Starting index. -- `length` _int_ - Length of the video. -- `max_lag` _int_ - Maximum lag. -- `num_points` _int_ - Number of points. - - -**Returns**: - - None - -#### gif - -```python -def gif( - config: str, - pose_ref_index: int, - parametrization: Parametrizations, - subtract_background: bool = True, - start: int | None = None, - length: int = 500, - max_lag: int = 30, - label: str = 'community', - file_format: str = '.mp4', - crop_size: Tuple[int, int] = (300, 300)) -> None -``` - -Create a GIF from the given configuration. - -**Arguments**: - -- `config` _str_ - Path to the configuration file. -- `pose_ref_index` _int_ - Pose reference index. -- `subtract_background` _bool, optional_ - Whether to subtract background. Defaults to True. -- `start` _int, optional_ - Starting index. Defaults to None. -- `length` _int, optional_ - Length of the video. Defaults to 500. -- `max_lag` _int, optional_ - Maximum lag. Defaults to 30. -- `label` _str, optional_ - Label type [None, community, motif]. Defaults to 'community'. -- `file_format` _str, optional_ - File format. Defaults to '.mp4'. -- `crop_size` _Tuple[int, int], optional_ - Crop size. Defaults to (300,300). - - -**Returns**: - - None - diff --git a/docs/vame-docs-app/docs/reference/vame/analysis/pose_segmentation.md b/docs/vame-docs-app/docs/reference/vame/analysis/pose_segmentation.md deleted file mode 100644 index 8a520605..00000000 --- a/docs/vame-docs-app/docs/reference/vame/analysis/pose_segmentation.md +++ /dev/null @@ -1,116 +0,0 @@ ---- -sidebar_label: pose_segmentation -title: vame.analysis.pose_segmentation ---- - -Variational Animal Motion Embedding 1.0-alpha Toolbox -© K. Luxem & P. 
Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -#### embedd\_latent\_vectors - -```python -def embedd_latent_vectors( - cfg: dict, files: List[str], model: RNN_VAE, fixed: bool, - tqdm_stream: TqdmToLogger | None) -> List[np.ndarray] -``` - -Embed latent vectors for the given files using the VAME model. - -**Arguments**: - -- `cfg` _dict_ - Configuration dictionary. -- `files` _List[str]_ - List of files names. -- `model` _RNN_VAE_ - VAME model. -- `fixed` _bool_ - Whether the model is fixed. -- `tqdm_stream` _TqdmToLogger_ - TQDM Stream to redirect the tqdm output to logger. - - -**Returns**: - -- `List[np.ndarray]` - List of latent vectors for each file. - -#### get\_motif\_usage - -```python -def get_motif_usage(label: np.ndarray) -> np.ndarray -``` - -Compute motif usage from the label array. - -**Arguments**: - -- `label` _np.ndarray_ - Label array. - - -**Returns**: - -- `np.ndarray` - Array of motif usage counts. - -#### same\_parametrization - -```python -def same_parametrization( - cfg: dict, files: List[str], latent_vector_files: List[np.ndarray], - states: int, parametrization: str -) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]] -``` - -Apply the same parametrization to all animals. - -**Arguments**: - -- `cfg` _dict_ - Configuration dictionary. -- `files` _List[str]_ - List of file names. -- `latent_vector_files` _List[np.ndarray]_ - List of latent vector arrays. -- `states` _int_ - Number of states. -- `parametrization` _str_ - parametrization method. - - -**Returns**: - -- `Tuple` - Tuple of labels, cluster centers, and motif usages. - -#### individual\_parametrization - -```python -def individual_parametrization(cfg: dict, files: List[str], - latent_vector_files: List[np.ndarray], - cluster: int) -> Tuple -``` - -Apply individual parametrization to each animal. - -**Arguments**: - -- `cfg` _dict_ - Configuration dictionary. -- `files` _List[str]_ - List of file names. -- `latent_vector_files` _List[np.ndarray]_ - List of latent vector arrays. -- `cluster` _int_ - Number of clusters. - - -**Returns**: - -- `Tuple` - Tuple of labels, cluster centers, and motif usages. - -#### pose\_segmentation - -```python -@save_state(model=PoseSegmentationFunctionSchema) -def pose_segmentation(config: str, save_logs: bool = False) -> None -``` - -Perform pose segmentation using the VAME model. - -**Arguments**: - -- `config` _str_ - Path to the configuration file. - - -**Returns**: - - None - diff --git a/docs/vame-docs-app/docs/reference/vame/analysis/segment_behavior.md b/docs/vame-docs-app/docs/reference/vame/analysis/segment_behavior.md deleted file mode 100644 index 6ac26482..00000000 --- a/docs/vame-docs-app/docs/reference/vame/analysis/segment_behavior.md +++ /dev/null @@ -1,136 +0,0 @@ ---- -sidebar_label: segment_behavior -title: vame.analysis.segment_behavior ---- - -Variational Animal Motion Embedding 0.1 Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -#### load\_data - -```python -def load_data(PROJECT_PATH: str, file: str, data: str) -> np.ndarray -``` - -Load data for the given file. - -**Arguments**: - -- `PROJECT_PATH` _str_ - Path to the project directory. -- `file` _str_ - Name of the file. -- `data` _str_ - Data to load. 
- - -**Returns**: - -- `np.ndarray` - Loaded data. - -#### kmeans\_clustering - -```python -def kmeans_clustering(context: np.ndarray, n_clusters: int) -> np.ndarray -``` - -Perform k-Means clustering. - -**Arguments**: - -- `context` _np.ndarray_ - Input data for clustering. -- `n_clusters` _int_ - Number of clusters. - - -**Returns**: - -- `np.ndarray` - Cluster labels. - -#### gmm\_clustering - -```python -def gmm_clustering(context: np.ndarray, n_components: int) -> np.ndarray -``` - -Perform Gaussian Mixture Model (GMM) clustering. - -**Arguments**: - -- `context` _np.ndarray_ - Input data for clustering. -- `n_components` _int_ - Number of components. - - -**Returns**: - -- `np.ndarray` - Cluster labels. - -#### behavior\_segmentation - -```python -def behavior_segmentation(config: str, - model_name: str = None, - cluster_method: str = 'kmeans', - n_cluster: List[int] = [30]) -> None -``` - -Perform behavior segmentation. - -**Arguments**: - -- `config` _str_ - Path to the configuration file. -- `model_name` _str, optional_ - Name of the model. Defaults to None. -- `cluster_method` _str, optional_ - Clustering method. Defaults to 'kmeans'. -- `n_cluster` _List[int], optional_ - List of number of clusters. Defaults to [30]. - - -**Returns**: - -- `None` - Save data to the results directory. - -#### temporal\_quant - -```python -def temporal_quant(cfg: dict, model_name: str, files: List[str], - use_gpu: bool) -> Tuple -``` - -Quantify the temporal latent space. - -**Arguments**: - -- `cfg` _dict_ - Configuration dictionary. -- `model_name` _str_ - Name of the model. -- `files` _List[str]_ - List of file names. -- `use_gpu` _bool_ - Whether to use GPU. - - -**Returns**: - -- `Tuple` - Tuple of latent space array and logger. - -#### cluster\_latent\_space - -```python -def cluster_latent_space(cfg: dict, files: List[str], z_data: np.ndarray, - z_logger: List[int], cluster_method: str, - n_cluster: List[int], model_name: str) -> None -``` - -Cluster the latent space. - -**Arguments**: - -- `cfg` _dict_ - Configuration dictionary. -- `files` _List[str]_ - List of file names. -- `z_data` _np.ndarray_ - Latent space data. -- `z_logger` _List[int]_ - Logger for the latent space. -- `cluster_method` _str_ - Clustering method. -- `n_cluster` _List[int]_ - List of number of clusters. -- `model_name` _str_ - Name of the model. - - -**Returns**: - - None -> Save data to the results directory. - diff --git a/docs/vame-docs-app/docs/reference/vame/analysis/tree_hierarchy.md b/docs/vame-docs-app/docs/reference/vame/analysis/tree_hierarchy.md deleted file mode 100644 index 30d39816..00000000 --- a/docs/vame-docs-app/docs/reference/vame/analysis/tree_hierarchy.md +++ /dev/null @@ -1,132 +0,0 @@ ---- -sidebar_label: tree_hierarchy -title: vame.analysis.tree_hierarchy ---- - -Variational Animal Motion Embedding 1.0-alpha Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -#### hierarchy\_pos - -```python -def hierarchy_pos(G: nx.Graph, - root: str | None = None, - width: float = 0.5, - vert_gap: float = 0.2, - vert_loc: float = 0, - xcenter: float = 0.5) -> Dict[str, Tuple[float, float]] -``` - -Positions nodes in a tree-like layout. -Ref: From Joel's answer at https://stackoverflow.com/a/29597209/2966723. - -**Arguments**: - -- `G` _nx.Graph_ - The input graph. Must be a tree. -- `root` _str, optional_ - The root node of the tree. 
If None, the function selects a root node based on graph type. -- `width` _float, optional_ - The horizontal space assigned to each level. -- `vert_gap` _float, optional_ - The vertical gap between levels. -- `vert_loc` _float, optional_ - The vertical location of the root node. -- `xcenter` _float, optional_ - The horizontal location of the root node. - - -**Returns**: - - Dict[str, Tuple[float, float]]: A dictionary mapping node names to their positions (x, y). - -#### merge\_func - -```python -def merge_func(transition_matrix: np.ndarray, n_cluster: int, - motif_norm: np.ndarray, - merge_sel: int) -> Tuple[np.ndarray, np.ndarray] -``` - -Merge nodes in a graph based on a selection criterion. - -**Arguments**: - -- `transition_matrix` _np.ndarray_ - The transition matrix of the graph. -- `n_cluster` _int_ - The number of clusters. -- `motif_norm` _np.ndarray_ - The normalized motif matrix. -- `merge_sel` _int_ - The merge selection criterion. - - 0: Merge nodes with highest transition probability. - - 1: Merge nodes with lowest cost. - - -**Raises**: - -- `ValueError` - If an invalid merge selection criterion is provided. - - -**Returns**: - - Tuple[np.ndarray, np.ndarray]: A tuple containing the merged nodes. - -#### graph\_to\_tree - -```python -def graph_to_tree(motif_usage: np.ndarray, - transition_matrix: np.ndarray, - n_cluster: int, - merge_sel: int = 1) -> nx.Graph -``` - -Convert a graph to a tree. - -**Arguments**: - -- `motif_usage` _np.ndarray_ - The motif usage matrix. -- `transition_matrix` _np.ndarray_ - The transition matrix of the graph. -- `n_cluster` _int_ - The number of clusters. -- `merge_sel` _int, optional_ - The merge selection criterion. Defaults to 1. - - 0: Merge nodes with highest transition probability. - - 1: Merge nodes with lowest cost. - - -**Returns**: - -- `nx.Graph` - The tree. - -#### draw\_tree - -```python -def draw_tree(T: nx.Graph) -> None -``` - -Draw a tree. - -**Arguments**: - -- `T` _nx.Graph_ - The tree to be drawn. - - -**Returns**: - - None - -#### traverse\_tree\_cutline - -```python -def traverse_tree_cutline(T: nx.Graph, - root_node: str | None = None, - cutline: int = 2) -> List[List[str]] -``` - -Traverse a tree with a cutline and return the community bags. - -**Arguments**: - -- `T` _nx.Graph_ - The tree to be traversed. -- `root_node` _str, optional_ - The root node of the tree. If None, traversal starts from the root. -- `cutline` _int, optional_ - The cutline level. - - -**Returns**: - -- `List[List[str]]` - List of community bags. - diff --git a/docs/vame-docs-app/docs/reference/vame/analysis/umap.md b/docs/vame-docs-app/docs/reference/vame/analysis/umap.md deleted file mode 100644 index 91dd1fb8..00000000 --- a/docs/vame-docs-app/docs/reference/vame/analysis/umap.md +++ /dev/null @@ -1,115 +0,0 @@ ---- -sidebar_label: umap -title: vame.analysis.umap ---- - -Variational Animal Motion Embedding 1.0-alpha Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -#### umap\_embedding - -```python -def umap_embedding(cfg: dict, file: str, model_name: str, n_cluster: int, - parametrization: str) -> np.ndarray -``` - -Perform UMAP embedding for given file and parameters. - -**Arguments**: - -- `cfg` _dict_ - Configuration parameters. -- `file` _str_ - File path. -- `model_name` _str_ - Model name. -- `n_cluster` _int_ - Number of clusters. 
-- `parametrization` _str_ - parametrization. - - -**Returns**: - -- `np.ndarray` - UMAP embedding. - -#### umap\_vis - -```python -def umap_vis(embed: np.ndarray, num_points: int) -> None -``` - -Visualize UMAP embedding without labels. - -**Arguments**: - -- `embed` _np.ndarray_ - UMAP embedding. -- `num_points` _int_ - Number of data points to visualize. - - -**Returns**: - - None - Plot Visualization of UMAP embedding. - -#### umap\_label\_vis - -```python -def umap_label_vis(embed: np.ndarray, label: np.ndarray, n_cluster: int, - num_points: int) -> None -``` - -Visualize UMAP embedding with motif labels. - -**Arguments**: - -- `embed` _np.ndarray_ - UMAP embedding. -- `label` _np.ndarray_ - Motif labels. -- `n_cluster` _int_ - Number of clusters. -- `num_points` _int_ - Number of data points to visualize. - - -**Returns**: - - fig - Plot figure of UMAP visualization embedding with motif labels. - -#### umap\_vis\_comm - -```python -def umap_vis_comm(embed: np.ndarray, community_label: np.ndarray, - num_points: int) -> None -``` - -Visualize UMAP embedding with community labels. - -**Arguments**: - -- `embed` _np.ndarray_ - UMAP embedding. -- `community_label` _np.ndarray_ - Community labels. -- `num_points` _int_ - Number of data points to visualize. - - -**Returns**: - - fig - Plot figure of UMAP visualization embedding with community labels. - -#### visualization - -```python -@save_state(model=VisualizationFunctionSchema) -def visualization(config: Union[str, Path], - parametrization: Parametrizations, - label: Optional[str] = None, - save_logs: bool = False) -> None -``` - -Visualize UMAP embeddings based on configuration settings. - -**Arguments**: - -- `config` _Union[str, Path]_ - Path to the configuration file. -- `label` _str, optional_ - Type of labels to visualize. Default is None. - - -**Returns**: - - None - Plot Visualization of UMAP embeddings. - diff --git a/docs/vame-docs-app/docs/reference/vame/analysis/umap_visualization.md b/docs/vame-docs-app/docs/reference/vame/analysis/umap_visualization.md deleted file mode 100644 index 469fff9e..00000000 --- a/docs/vame-docs-app/docs/reference/vame/analysis/umap_visualization.md +++ /dev/null @@ -1,95 +0,0 @@ ---- -sidebar_label: umap_visualization -title: vame.analysis.umap_visualization ---- - -Variational Animal Motion Embedding 1.0-alpha Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -#### umap\_vis - -```python -def umap_vis(file: str, embed: np.ndarray, num_points: int) -> None -``` - -Visualize UMAP embedding without labels. - -**Arguments**: - -- `file` _str_ - Name of the file (deprecated). -- `embed` _np.ndarray_ - UMAP embedding. -- `num_points` _int_ - Number of data points to visualize. - - -**Returns**: - - None - Plot Visualization of UMAP embedding. - -#### umap\_label\_vis - -```python -def umap_label_vis(file: str, embed: np.ndarray, label: np.ndarray, - n_cluster: int, num_points: int) -> None -``` - -Visualize UMAP embedding with motif labels. - -**Arguments**: - -- `file` _str_ - Name of the file (deprecated). -- `embed` _np.ndarray_ - UMAP embedding. -- `label` _np.ndarray_ - Motif labels. -- `n_cluster` _int_ - Number of clusters. -- `num_points` _int_ - Number of data points to visualize. - - -**Returns**: - - fig - Plot figure of UMAP visualization embedding with motif labels. 
- -#### umap\_vis\_comm - -```python -def umap_vis_comm(file: str, embed: np.ndarray, community_label: np.ndarray, - num_points: int) -> None -``` - -Visualize UMAP embedding with community labels. - -**Arguments**: - -- `file` _str_ - Name of the file (deprecated). -- `embed` _np.ndarray_ - UMAP embedding. -- `community_label` _np.ndarray_ - Community labels. -- `num_points` _int_ - Number of data points to visualize. - - -**Returns**: - - fig - Plot figure of UMAP visualization embedding with community labels. - -#### visualization - -```python -@save_state(model=VisualizationFunctionSchema) -def visualization(config: Union[str, Path], - label: Optional[str] = None, - save_logs: bool = False) -> None -``` - -Visualize UMAP embeddings based on configuration settings. - -**Arguments**: - -- `config` _Union[str, Path]_ - Path to the configuration file. -- `label` _str, optional_ - Type of labels to visualize. Default is None. - - -**Returns**: - - None - Plot Visualization of UMAP embeddings. - diff --git a/docs/vame-docs-app/docs/reference/vame/analysis/videowriter.md b/docs/vame-docs-app/docs/reference/vame/analysis/videowriter.md deleted file mode 100644 index 5baddf24..00000000 --- a/docs/vame-docs-app/docs/reference/vame/analysis/videowriter.md +++ /dev/null @@ -1,89 +0,0 @@ ---- -sidebar_label: videowriter -title: vame.analysis.videowriter ---- - -Variational Animal Motion Embedding 1.0-alpha Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -#### get\_cluster\_vid - -```python -def get_cluster_vid(cfg: dict, - path_to_file: str, - file: str, - n_cluster: int, - videoType: str, - flag: str, - param: Parametrizations, - output_video_type: str = ".mp4", - tqdm_logger_stream: TqdmToLogger | None = None) -> None -``` - -Generate cluster videos. - -**Arguments**: - -- `cfg` _dict_ - Configuration parameters. -- `path_to_file` _str_ - Path to the file. -- `file` _str_ - Name of the file. -- `n_cluster` _int_ - Number of clusters. -- `videoType` _str_ - Type of input video. -- `flag` _str_ - Flag indicating the type of video (motif or community). - - -**Returns**: - - None - Generate cluster videos and save them to fs on project folder. - -#### motif\_videos - -```python -@save_state(model=MotifVideosFunctionSchema) -def motif_videos(config: Union[str, Path], - parametrization: Parametrizations, - videoType: str = '.mp4', - output_video_type: str = '.mp4', - save_logs: bool = False) -> None -``` - -Generate motif videos. - -**Arguments**: - -- `config` _Union[str, Path]_ - Path to the configuration file. -- `videoType` _str, optional_ - Type of video. Default is '.mp4'. -- `output_video_type` _str, optional_ - Type of output video. Default is '.mp4'. - - -**Returns**: - - None - Generate motif videos and save them to filesystem on project cluster_videos folder. - -#### community\_videos - -```python -@save_state(model=CommunityVideosFunctionSchema) -def community_videos(config: Union[str, Path], - parametrization: Parametrizations, - videoType: str = '.mp4', - save_logs: bool = False, - output_video_type: str = '.mp4') -> None -``` - -Generate community videos. - -**Arguments**: - -- `config` _Union[str, Path]_ - Path to the configuration file. -- `videoType` _str, optional_ - Type of video. Default is '.mp4'. 
- - -**Returns**: - - None - Generate community videos and save them to filesystem on project community_videos folder. - diff --git a/docs/vame-docs-app/docs/reference/vame/initialize_project/new.md b/docs/vame-docs-app/docs/reference/vame/initialize_project/new.md deleted file mode 100644 index 955f97a1..00000000 --- a/docs/vame-docs-app/docs/reference/vame/initialize_project/new.md +++ /dev/null @@ -1,47 +0,0 @@ ---- -sidebar_label: new -title: vame.initialize_project.new ---- - -Variational Animal Motion Embedding 1.0-alpha Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -The following code is adapted from: - -DeepLabCut2.0 Toolbox (deeplabcut.org) -© A. & M. Mathis Labs -https://github.com/AlexEMG/DeepLabCut -Please see AUTHORS for contributors. -https://github.com/AlexEMG/DeepLabCut/blob/master/AUTHORS -Licensed under GNU Lesser General Public License v3.0 - -#### init\_new\_project - -```python -def init_new_project( - project: str, - videos: List[str], - poses_estimations: List[str], - working_directory: str = '.', - videotype: str = '.mp4', - paths_to_pose_nwb_series_data: Optional[str] = None) -> str -``` - -Creates a new VAME project with the given parameters. - -**Arguments**: - -- `project` _str_ - Project name. -- `videos` _List[str]_ - List of videos paths to be used in the project. E.g. ['./sample_data/Session001.mp4'] -- `poses_estimations` _List[str]_ - List of pose estimation files paths to be used in the project. E.g. ['./sample_data/pose estimation/Session001.csv'] working_directory (str, optional): _description_. Defaults to None. -- `videotype` _str, optional_ - Video extension (.mp4 or .avi). Defaults to '.mp4'. - - -**Returns**: - -- `projconfigfile` _str_ - Path to the new vame project config file. - diff --git a/docs/vame-docs-app/docs/reference/vame/logging/logger.md b/docs/vame-docs-app/docs/reference/vame/logging/logger.md deleted file mode 100644 index e8b77d5e..00000000 --- a/docs/vame-docs-app/docs/reference/vame/logging/logger.md +++ /dev/null @@ -1,14 +0,0 @@ ---- -sidebar_label: logger -title: vame.logging.logger ---- - -## TqdmToLogger Objects - -```python -class TqdmToLogger(io.StringIO) -``` - -Output stream for TQDM which will output to logger module instead of -the StdOut. - diff --git a/docs/vame-docs-app/docs/reference/vame/model/create_training.md b/docs/vame-docs-app/docs/reference/vame/model/create_training.md deleted file mode 100644 index 707a7c1e..00000000 --- a/docs/vame-docs-app/docs/reference/vame/model/create_training.md +++ /dev/null @@ -1,103 +0,0 @@ ---- -sidebar_label: create_training -title: vame.model.create_training ---- - -Variational Animal Motion Embedding 1.0-alpha Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -#### plot\_check\_parameter - -```python -def plot_check_parameter(cfg: dict, iqr_val: float, num_frames: int, - X_true: List[np.ndarray], X_med: np.ndarray) -> None -``` - -Plot the check parameter - z-scored data and the filtered data. - -**Arguments**: - -- `cfg` _dict_ - Configuration parameters. -- `iqr_val` _float_ - IQR value. -- `num_frames` _int_ - Number of frames. -- `X_true` _List[np.ndarray]_ - List of true data. -- `X_med` _np.ndarray_ - Filtered data. 
-- `anchor_1` _int_ - Index of the first anchor point (deprecated). -- `anchor_2` _int_ - Index of the second anchor point (deprecated). - - -**Returns**: - - None - Plot the z-scored data and the filtered data. - -#### traindata\_aligned - -```python -def traindata_aligned(cfg: dict, files: List[str], testfraction: float, - savgol_filter: bool, check_parameter: bool) -> None -``` - -Create training dataset for aligned data. - -**Arguments**: - -- `cfg` _dict_ - Configuration parameters. -- `files` _List[str]_ - List of files. -- `testfraction` _float_ - Fraction of data to use as test data. -- `num_features` _int_ - Number of features (deprecated). -- `savgol_filter` _bool_ - Flag indicating whether to apply Savitzky-Golay filter. -- `check_parameter` _bool_ - If True, the function will plot the z-scored data and the filtered data. - - -**Returns**: - - None - Save numpy arrays with the test/train info to the project folder. - -#### traindata\_fixed - -```python -def traindata_fixed(cfg: dict, files: List[str], testfraction: float, - num_features: int, savgol_filter: bool, - check_parameter: bool, - pose_ref_index: Optional[List[int]]) -> None -``` - -Create training dataset for fixed data. - -**Arguments**: - -- `cfg` _dict_ - Configuration parameters. -- `files` _List[str]_ - List of files. -- `testfraction` _float_ - Fraction of data to use as test data. -- `num_features` _int_ - Number of features. -- `savgol_filter` _bool_ - Flag indicating whether to apply Savitzky-Golay filter. -- `check_parameter` _bool_ - If True, the function will plot the z-scored data and the filtered data. -- `pose_ref_index` _Optional[List[int]]_ - List of reference coordinate indices for alignment. - - -**Returns**: - - None - Save numpy arrays with the test/train info to the project folder. - -#### create\_trainset - -```python -@save_state(model=CreateTrainsetFunctionSchema) -def create_trainset(config: str, - pose_ref_index: Optional[List] = None, - check_parameter: bool = False, - save_logs: bool = False) -> None -``` - -Creates a training dataset for the VAME model. - -**Arguments**: - -- `config` _str_ - Path to the config file. -- `pose_ref_index` _Optional[List], optional_ - List of reference coordinate indices for alignment. Defaults to None. -- `check_parameter` _bool, optional_ - If True, the function will plot the z-scored data and the filtered data. Defaults to False. - diff --git a/docs/vame-docs-app/docs/reference/vame/model/dataloader.md b/docs/vame-docs-app/docs/reference/vame/model/dataloader.md deleted file mode 100644 index c14f5586..00000000 --- a/docs/vame-docs-app/docs/reference/vame/model/dataloader.md +++ /dev/null @@ -1,68 +0,0 @@ ---- -sidebar_label: dataloader -title: vame.model.dataloader ---- - -Variational Animal Motion Embedding 0.1 Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -## SEQUENCE\_DATASET Objects - -```python -class SEQUENCE_DATASET(Dataset) -``` - -#### \_\_init\_\_ - -```python -def __init__(path_to_file: str, data: str, train: bool, temporal_window: int, - **kwargs) -> None -``` - -Initialize the Sequence Dataset. - -**Arguments**: - -- `path_to_file` _str_ - Path to the dataset files. -- `data` _str_ - Name of the data file. -- `train` _bool_ - Flag indicating whether it's training data. -- `temporal_window` _int_ - Size of the temporal window. 
- - -**Returns**: - - None - -#### \_\_len\_\_ - -```python -def __len__() -> int -``` - -Return the number of data points. - -**Returns**: - -- `int` - Number of data points. - -#### \_\_getitem\_\_ - -```python -def __getitem__(index: int) -> torch.Tensor -``` - -Get a normalized sequence at the specified index. - -**Arguments**: - -- `index` _int_ - Index of the item. - - -**Returns**: - -- `torch.Tensor` - Normalized sequence data at the specified index. - diff --git a/docs/vame-docs-app/docs/reference/vame/model/evaluate.md b/docs/vame-docs-app/docs/reference/vame/model/evaluate.md deleted file mode 100644 index 69f12af5..00000000 --- a/docs/vame-docs-app/docs/reference/vame/model/evaluate.md +++ /dev/null @@ -1,90 +0,0 @@ ---- -sidebar_label: evaluate -title: vame.model.evaluate ---- - -Variational Animal Motion Embedding 0.1 Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -#### plot\_reconstruction - -```python -def plot_reconstruction(filepath: str, - test_loader: Data.DataLoader, - seq_len_half: int, - model: RNN_VAE, - model_name: str, - FUTURE_DECODER: bool, - FUTURE_STEPS: int, - suffix: Optional[str] = None) -> None -``` - -Plot the reconstruction and future prediction of the input sequence. - -**Arguments**: - -- `filepath` _str_ - Path to save the plot. -- `test_loader` _Data.DataLoader_ - DataLoader for the test dataset. -- `seq_len_half` _int_ - Half of the temporal window size. -- `model` _RNN_VAE_ - Trained VAE model. -- `model_name` _str_ - Name of the model. -- `FUTURE_DECODER` _bool_ - Flag indicating whether the model has a future prediction decoder. -- `FUTURE_STEPS` _int_ - Number of future steps to predict. -- `suffix` _Optional[str], optional_ - Suffix for the saved plot filename. Defaults to None. - -#### plot\_loss - -```python -def plot_loss(cfg: dict, filepath: str, model_name: str) -> None -``` - -Plot the losses of the trained model. - -**Arguments**: - -- `cfg` _dict_ - Configuration dictionary. -- `filepath` _str_ - Path to save the plot. -- `model_name` _str_ - Name of the model. - -#### eval\_temporal - -```python -def eval_temporal(cfg: dict, - use_gpu: bool, - model_name: str, - fixed: bool, - snapshot: Optional[str] = None, - suffix: Optional[str] = None) -> None -``` - -Evaluate the temporal aspects of the trained model. - -**Arguments**: - -- `cfg` _dict_ - Configuration dictionary. -- `use_gpu` _bool_ - Flag indicating whether to use GPU for evaluation. -- `model_name` _str_ - Name of the model. -- `fixed` _bool_ - Flag indicating whether the data is fixed or not. -- `snapshot` _Optional[str], optional_ - Path to the model snapshot. Defaults to None. -- `suffix` _Optional[str], optional_ - Suffix for the saved plot filename. Defaults to None. - -#### evaluate\_model - -```python -@save_state(model=EvaluateModelFunctionSchema) -def evaluate_model(config: str, - use_snapshots: bool = False, - save_logs: bool = False) -> None -``` - -Evaluate the trained model. - -**Arguments**: - -- `config` _str_ - Path to config file. -- `use_snapshots` _bool, optional_ - Whether to plot for all snapshots or only the best model. Defaults to False. 
- diff --git a/docs/vame-docs-app/docs/reference/vame/model/rnn_model.md b/docs/vame-docs-app/docs/reference/vame/model/rnn_model.md deleted file mode 100644 index b74d1e80..00000000 --- a/docs/vame-docs-app/docs/reference/vame/model/rnn_model.md +++ /dev/null @@ -1,246 +0,0 @@ ---- -sidebar_label: rnn_model -title: vame.model.rnn_model ---- - -Variational Animal Motion Embedding 0.1 Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -The Model is partially adapted from the Timeseries Clustering repository developed by Tejas Lodaya: -https://github.com/tejaslodaya/timeseries-clustering-vae/blob/master/vrae/vrae.py - -## Encoder Objects - -```python -class Encoder(nn.Module) -``` - -Encoder module of the Variational Autoencoder. - -#### \_\_init\_\_ - -```python -def __init__(NUM_FEATURES: int, hidden_size_layer_1: int, - hidden_size_layer_2: int, dropout_encoder: float) -``` - -Initialize the Encoder module. - -**Arguments**: - -- `NUM_FEATURES` _int_ - Number of input features. -- `hidden_size_layer_1` _int_ - Size of the first hidden layer. -- `hidden_size_layer_2` _int_ - Size of the second hidden layer. -- `dropout_encoder` _float_ - Dropout rate for regularization. - -#### forward - -```python -def forward(inputs: torch.Tensor) -> torch.Tensor -``` - -Forward pass of the Encoder module. - -**Arguments**: - -- `inputs` _torch.Tensor_ - Input tensor of shape (batch_size, sequence_length, num_features). - - -**Returns**: - -- `torch.Tensor` - Encoded representation tensor of shape (batch_size, hidden_size_layer_1 * 4). - -## Lambda Objects - -```python -class Lambda(nn.Module) -``` - -Lambda module for computing the latent space parameters. - -#### \_\_init\_\_ - -```python -def __init__(ZDIMS: int, hidden_size_layer_1: int, softplus: bool) -``` - -Initialize the Lambda module. - -**Arguments**: - -- `ZDIMS` _int_ - Size of the latent space. -- `hidden_size_layer_1` _int_ - Size of the first hidden layer. -- `hidden_size_layer_2` _int, deprecated_ - Size of the second hidden layer. -- `softplus` _bool_ - Whether to use softplus activation for logvar. - -#### forward - -```python -def forward( - hidden: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] -``` - -Forward pass of the Lambda module. - -**Arguments**: - -- `hidden` _torch.Tensor_ - Hidden representation tensor of shape (batch_size, hidden_size_layer_1 * 4). - - -**Returns**: - - tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Latent space tensor, mean tensor, logvar tensor. - -## Decoder Objects - -```python -class Decoder(nn.Module) -``` - -Decoder module of the Variational Autoencoder. - -#### \_\_init\_\_ - -```python -def __init__(TEMPORAL_WINDOW: int, ZDIMS: int, NUM_FEATURES: int, - hidden_size_rec: int, dropout_rec: float) -``` - -Initialize the Decoder module. - -**Arguments**: - -- `TEMPORAL_WINDOW` _int_ - Size of the temporal window. -- `ZDIMS` _int_ - Size of the latent space. -- `NUM_FEATURES` _int_ - Number of input features. -- `hidden_size_rec` _int_ - Size of the recurrent hidden layer. -- `dropout_rec` _float_ - Dropout rate for regularization. - -#### forward - -```python -def forward(inputs: torch.Tensor, z: torch.Tensor) -> torch.Tensor -``` - -Forward pass of the Decoder module. - -**Arguments**: - -- `inputs` _torch.Tensor_ - Input tensor of shape (batch_size, seq_len, ZDIMS). 
-- `z` _torch.Tensor_ - Latent space tensor of shape (batch_size, ZDIMS). - - -**Returns**: - -- `torch.Tensor` - Decoded output tensor of shape (batch_size, seq_len, NUM_FEATURES). - -## Decoder\_Future Objects - -```python -class Decoder_Future(nn.Module) -``` - -Decoder module for predicting future sequences. - -#### \_\_init\_\_ - -```python -def __init__(TEMPORAL_WINDOW: int, ZDIMS: int, NUM_FEATURES: int, - FUTURE_STEPS: int, hidden_size_pred: int, dropout_pred: float) -``` - -Initialize the Decoder_Future module. - -**Arguments**: - -- `TEMPORAL_WINDOW` _int_ - Size of the temporal window. -- `ZDIMS` _int_ - Size of the latent space. -- `NUM_FEATURES` _int_ - Number of input features. -- `FUTURE_STEPS` _int_ - Number of future steps to predict. -- `hidden_size_pred` _int_ - Size of the prediction hidden layer. -- `dropout_pred` _float_ - Dropout rate for regularization. - -#### forward - -```python -def forward(inputs: torch.Tensor, z: torch.Tensor) -> torch.Tensor -``` - -Forward pass of the Decoder_Future module. - -**Arguments**: - -- `inputs` _torch.Tensor_ - Input tensor of shape (batch_size, seq_len, ZDIMS). -- `z` _torch.Tensor_ - Latent space tensor of shape (batch_size, ZDIMS). - - -**Returns**: - -- `torch.Tensor` - Predicted future tensor of shape (batch_size, FUTURE_STEPS, NUM_FEATURES). - -## RNN\_VAE Objects - -```python -class RNN_VAE(nn.Module) -``` - -Variational Autoencoder module. - -#### \_\_init\_\_ - -```python -def __init__(TEMPORAL_WINDOW: int, ZDIMS: int, NUM_FEATURES: int, - FUTURE_DECODER: bool, FUTURE_STEPS: int, hidden_size_layer_1: int, - hidden_size_layer_2: int, hidden_size_rec: int, - hidden_size_pred: int, dropout_encoder: float, dropout_rec: float, - dropout_pred: float, softplus: bool) -``` - -Initialize the VAE module. - -**Arguments**: - -- `TEMPORAL_WINDOW` _int_ - Size of the temporal window. -- `ZDIMS` _int_ - Size of the latent space. -- `NUM_FEATURES` _int_ - Number of input features. -- `FUTURE_DECODER` _bool_ - Whether to include a future decoder. -- `FUTURE_STEPS` _int_ - Number of future steps to predict. -- `hidden_size_layer_1` _int_ - Size of the first hidden layer. -- `hidden_size_layer_2` _int_ - Size of the second hidden layer. -- `hidden_size_rec` _int_ - Size of the recurrent hidden layer. -- `hidden_size_pred` _int_ - Size of the prediction hidden layer. -- `dropout_encoder` _float_ - Dropout rate for encoder. - -#### forward - -```python -def forward(seq: torch.Tensor) -> tuple -``` - -Forward pass of the VAE. - -**Arguments**: - -- `seq` _torch.Tensor_ - Input sequence tensor of shape (batch_size, seq_len, NUM_FEATURES). - - -**Returns**: - - Tuple containing: - - If FUTURE_DECODER is True: - - prediction (torch.Tensor): Reconstructed input sequence tensor. - - future (torch.Tensor): Predicted future sequence tensor. - - z (torch.Tensor): Latent representation tensor. - - mu (torch.Tensor): Mean of the latent distribution tensor. - - logvar (torch.Tensor): Log variance of the latent distribution tensor. - - If FUTURE_DECODER is False: - - prediction (torch.Tensor): Reconstructed input sequence tensor. - - z (torch.Tensor): Latent representation tensor. - - mu (torch.Tensor): Mean of the latent distribution tensor. - - logvar (torch.Tensor): Log variance of the latent distribution tensor. 
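As a minimal sketch, the `RNN_VAE` constructor documented above could be instantiated as follows; every hyperparameter value below is illustrative only (real values come from the project config):

```python
from vame.model.rnn_model import RNN_VAE

# Illustrative hyperparameters; not taken from this patch.
model = RNN_VAE(
    TEMPORAL_WINDOW=60,    # size of the temporal window
    ZDIMS=30,              # latent space dimensionality
    NUM_FEATURES=12,       # number of input features
    FUTURE_DECODER=True,   # include the future-prediction decoder
    FUTURE_STEPS=15,       # number of future steps to predict
    hidden_size_layer_1=256,
    hidden_size_layer_2=256,
    hidden_size_rec=256,
    hidden_size_pred=256,
    dropout_encoder=0.0,
    dropout_rec=0.0,
    dropout_pred=0.0,
    softplus=False,        # plain logvar output instead of softplus
)
```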
- diff --git a/docs/vame-docs-app/docs/reference/vame/model/rnn_vae.md b/docs/vame-docs-app/docs/reference/vame/model/rnn_vae.md deleted file mode 100644 index 8a5624c8..00000000 --- a/docs/vame-docs-app/docs/reference/vame/model/rnn_vae.md +++ /dev/null @@ -1,225 +0,0 @@ ---- -sidebar_label: rnn_vae -title: vame.model.rnn_vae ---- - -Variational Animal Motion Embedding 0.1 Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -#### reconstruction\_loss - -```python -def reconstruction_loss(x: torch.Tensor, x_tilde: torch.Tensor, - reduction: str) -> torch.Tensor -``` - -Compute the reconstruction loss between input and reconstructed data. - -**Arguments**: - -- `x` _torch.Tensor_ - Input data tensor. -- `x_tilde` _torch.Tensor_ - Reconstructed data tensor. -- `reduction` _str_ - Type of reduction for the loss. - - -**Returns**: - -- `torch.Tensor` - Reconstruction loss. - -#### future\_reconstruction\_loss - -```python -def future_reconstruction_loss(x: torch.Tensor, x_tilde: torch.Tensor, - reduction: str) -> torch.Tensor -``` - -Compute the future reconstruction loss between input and predicted future data. - -**Arguments**: - -- `x` _torch.Tensor_ - Input future data tensor. -- `x_tilde` _torch.Tensor_ - Reconstructed future data tensor. -- `reduction` _str_ - Type of reduction for the loss. - - -**Returns**: - -- `torch.Tensor` - Future reconstruction loss. - -#### cluster\_loss - -```python -def cluster_loss(H: torch.Tensor, kloss: int, lmbda: float, - batch_size: int) -> torch.Tensor -``` - -Compute the cluster loss. - -**Arguments**: - -- `H` _torch.Tensor_ - Latent representation tensor. -- `kloss` _int_ - Number of clusters. -- `lmbda` _float_ - Lambda value for the loss. -- `batch_size` _int_ - Size of the batch. - - -**Returns**: - -- `torch.Tensor` - Cluster loss. - -#### kullback\_leibler\_loss - -```python -def kullback_leibler_loss(mu: torch.Tensor, - logvar: torch.Tensor) -> torch.Tensor -``` - -Compute the Kullback-Leibler divergence loss. -see Appendix B from VAE paper: Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 - https://arxiv.org/abs/1312.6114 - -Formula: 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) - -**Arguments**: - -- `mu` _torch.Tensor_ - Mean of the latent distribution. -- `logvar` _torch.Tensor_ - Log variance of the latent distribution. - - -**Returns**: - -- `torch.Tensor` - Kullback-Leibler divergence loss. - -#### kl\_annealing - -```python -def kl_annealing(epoch: int, kl_start: int, annealtime: int, - function: str) -> float -``` - -Anneal the Kullback-Leibler loss to let the model learn first the reconstruction of the data -before the KL loss term gets introduced. - -**Arguments**: - -- `epoch` _int_ - Current epoch number. -- `kl_start` _int_ - Epoch number to start annealing the loss. -- `annealtime` _int_ - Annealing time. -- `function` _str_ - Annealing function type. - - -**Returns**: - -- `float` - Annealed weight value for the loss. - -#### gaussian - -```python -def gaussian(ins: torch.Tensor, - is_training: bool, - seq_len: int, - std_n: float = 0.8) -> torch.Tensor -``` - -Add Gaussian noise to the input data. - -**Arguments**: - -- `ins` _torch.Tensor_ - Input data tensor. -- `is_training` _bool_ - Whether it is training mode. -- `seq_len` _int_ - Length of the sequence. -- `std_n` _float_ - Standard deviation for the Gaussian noise. 
- - -**Returns**: - -- `torch.Tensor` - Noisy input data tensor. - -#### train - -```python -def train(train_loader: Data.DataLoader, epoch: int, model: nn.Module, - optimizer: torch.optim.Optimizer, anneal_function: str, BETA: float, - kl_start: int, annealtime: int, seq_len: int, future_decoder: bool, - future_steps: int, scheduler: torch.optim.lr_scheduler._LRScheduler, - mse_red: str, mse_pred: str, kloss: int, klmbda: float, bsize: int, - noise: bool) -> Tuple[float, float, float, float, float, float] -``` - -Train the model. - -**Arguments**: - -- `train_loader` _DataLoader_ - Training data loader. -- `epoch` _int_ - Current epoch number. -- `model` _nn.Module_ - Model to be trained. -- `optimizer` _Optimizer_ - Optimizer for training. -- `anneal_function` _str_ - Annealing function type. -- `BETA` _float_ - Beta value for the loss. -- `kl_start` _int_ - Epoch number to start annealing the loss. -- `annealtime` _int_ - Annealing time. -- `seq_len` _int_ - Length of the sequence. -- `future_decoder` _bool_ - Whether a future decoder is used. -- `epoch`0 _int_ - Number of future steps to predict. -- `epoch`1 _lr_scheduler._LRScheduler_ - Learning rate scheduler. -- `epoch`2 _str_ - Reduction type for MSE reconstruction loss. -- `epoch`3 _str_ - Reduction type for MSE prediction loss. -- `epoch`4 _int_ - Number of clusters for cluster loss. -- `epoch`5 _float_ - Lambda value for cluster loss. -- `epoch`6 _int_ - Size of the batch. -- `epoch`7 _bool_ - Whether to add Gaussian noise to the input. - - -**Returns**: - - Tuple[float, float, float, float, float, float]: Kullback-Leibler weight, train loss, K-means loss, KL loss, - MSE loss, future loss. - -#### test - -```python -def test(test_loader: Data.DataLoader, model: nn.Module, BETA: float, - kl_weight: float, seq_len: int, mse_red: str, kloss: str, - klmbda: float, future_decoder: bool, - bsize: int) -> Tuple[float, float, float] -``` - -Evaluate the model on the test dataset. - -**Arguments**: - -- `test_loader` _DataLoader_ - DataLoader for the test dataset. -- `epoch` _int, deprecated_ - Current epoch number. -- `model` _nn.Module_ - The trained model. -- `optimizer` _Optimizer, deprecated_ - The optimizer used for training. -- `BETA` _float_ - Beta value for the VAE loss. -- `kl_weight` _float_ - Weighting factor for the KL divergence loss. -- `seq_len` _int_ - Length of the sequence. -- `mse_red` _str_ - Reduction method for the MSE loss. -- `kloss` _str_ - Loss function for K-means clustering. -- `klmbda` _float_ - Lambda value for K-means loss. -- `epoch`0 _bool_ - Flag indicating whether to use a future decoder. -- `epoch`1 _int_ - Batch size. - - -**Returns**: - - Tuple[float, float, float]: Tuple containing MSE loss per item, total test loss per item, - and K-means loss weighted by the kl_weight. - -#### train\_model - -```python -@save_state(model=TrainModelFunctionSchema) -def train_model(config: str, save_logs: bool = False) -> None -``` - -Train Variational Autoencoder using the configuration file values. - -**Arguments**: - -- `config` _str_ - Path to the configuration file. 
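For reference, the `train_model` entry above corresponds to a call like this minimal sketch (the path is illustrative, and `vame.train_model(config)` is the top-level form used elsewhere in these docs):

```python
import vame

# Train the VAME variational autoencoder for an existing project.
vame.train_model(config="./my-vame-project/config.yaml")
```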
- diff --git a/docs/vame-docs-app/docs/reference/vame/schemas/states.md b/docs/vame-docs-app/docs/reference/vame/schemas/states.md deleted file mode 100644 index d178598d..00000000 --- a/docs/vame-docs-app/docs/reference/vame/schemas/states.md +++ /dev/null @@ -1,14 +0,0 @@ ---- -sidebar_label: states -title: vame.schemas.states ---- - -#### save\_state - -```python -def save_state(model: BaseModel) -``` - -Decorator responsible for validating function arguments using pydantic and -saving the state of the called function to the project states json file. - diff --git a/docs/vame-docs-app/docs/reference/vame/util/align_egocentrical.md b/docs/vame-docs-app/docs/reference/vame/util/align_egocentrical.md deleted file mode 100644 index b965814b..00000000 --- a/docs/vame-docs-app/docs/reference/vame/util/align_egocentrical.md +++ /dev/null @@ -1,132 +0,0 @@ ---- -sidebar_label: align_egocentrical -title: vame.util.align_egocentrical ---- - -Variational Animal Motion Embedding 0.1 Toolbox -© K. Luxem & J. Kürsch & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -#### align\_mouse - -```python -def align_mouse( - path_to_file: str, - filename: str, - video_format: str, - crop_size: Tuple[int, int], - pose_list: List[np.ndarray], - pose_ref_index: Tuple[int, int], - confidence: float, - pose_flip_ref: Tuple[int, int], - bg: np.ndarray, - frame_count: int, - use_video: bool = True, - tqdm_stream: TqdmToLogger = None -) -> Tuple[List[np.ndarray], List[List[np.ndarray]], np.ndarray] -``` - -Align the mouse in the video frames. - -**Arguments**: - -- `path_to_file` _str_ - Path to the file directory. -- `filename` _str_ - Name of the video file without the format. -- `video_format` _str_ - Format of the video file. -- `crop_size` _Tuple[int, int]_ - Size to crop the video frames. -- `pose_list` _List[np.ndarray]_ - List of pose coordinates. -- `pose_ref_index` _Tuple[int, int]_ - Pose reference indices. -- `confidence` _float_ - Pose confidence threshold. -- `pose_flip_ref` _Tuple[int, int]_ - Reference indices for flipping. -- `bg` _np.ndarray_ - Background image. -- `frame_count` _int_ - Number of frames to align. -- `filename`0 _bool, optional_ - bool if video should be cropped or DLC points only. Defaults to True. - - -**Returns**: - - Tuple[List[np.ndarray], List[List[np.ndarray]], np.ndarray]: List of aligned images, list of aligned DLC points, and time series data. - -#### play\_aligned\_video - -```python -def play_aligned_video(a: List[np.ndarray], n: List[List[np.ndarray]], - frame_count: int) -> None -``` - -Play the aligned video. - -**Arguments**: - -- `a` _List[np.ndarray]_ - List of aligned images. -- `n` _List[List[np.ndarray]]_ - List of aligned DLC points. -- `frame_count` _int_ - Number of frames in the video. - -#### alignment - -```python -def alignment( - path_to_file: str, - filename: str, - pose_ref_index: List[int], - video_format: str, - crop_size: Tuple[int, int], - confidence: float, - pose_estimation_filetype: PoseEstimationFiletype, - path_to_pose_nwb_series_data: str = None, - use_video: bool = False, - check_video: bool = False, - tqdm_stream: TqdmToLogger = None -) -> Tuple[np.ndarray, List[np.ndarray]] -``` - -Perform alignment of egocentric data. - -**Arguments**: - -- `path_to_file` _str_ - Path to the file directory. -- `filename` _str_ - Name of the video file without the format. 
-- `pose_ref_index` _List[int]_ - Pose reference indices. -- `video_format` _str_ - Format of the video file. -- `crop_size` _Tuple[int, int]_ - Size to crop the video frames. -- `confidence` _float_ - Pose confidence threshold. -- `use_video` _bool, optional_ - Whether to use video for alignment. Defaults to False. -- `check_video` _bool, optional_ - Whether to check the aligned video. Defaults to False. - - -**Returns**: - - Tuple[np.ndarray, List[np.ndarray]]: Aligned time series data and list of aligned frames. - -#### egocentric\_alignment - -```python -@save_state(model=EgocentricAlignmentFunctionSchema) -def egocentric_alignment(config: str, - pose_ref_index: list = [5, 6], - crop_size: tuple = (300, 300), - use_video: bool = False, - video_format: str = '.mp4', - check_video: bool = False, - save_logs: bool = False) -> None -``` - -Aligns egocentric data for VAME training - -**Arguments**: - -- `config` _str_ - Path for the project config file. -- `pose_ref_index` _list, optional_ - Pose reference index to be used to align. Defaults to [5,6]. -- `crop_size` _tuple, optional_ - Size to crop the video. Defaults to (300,300). -- `use_video` _bool, optional_ - Weather to use video to do the post alignment. Defaults to False. # TODO check what to put in this docstring -- `video_format` _str, optional_ - Video format, can be .mp4 or .avi. Defaults to '.mp4'. -- `check_video` _bool, optional_ - Weather to check the video. Defaults to False. - - -**Raises**: - -- `ValueError` - If the config.yaml indicates that the data is not egocentric. - diff --git a/docs/vame-docs-app/docs/reference/vame/util/auxiliary.md b/docs/vame-docs-app/docs/reference/vame/util/auxiliary.md deleted file mode 100644 index bc8dff7c..00000000 --- a/docs/vame-docs-app/docs/reference/vame/util/auxiliary.md +++ /dev/null @@ -1,63 +0,0 @@ ---- -sidebar_label: auxiliary -title: vame.util.auxiliary ---- - -Variational Animal Motion Embedding 1.0-alpha Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -The following code is adapted from: - -DeepLabCut2.0 Toolbox (deeplabcut.org) -© A. & M. Mathis Labs -https://github.com/AlexEMG/DeepLabCut -Please see AUTHORS for contributors. -https://github.com/AlexEMG/DeepLabCut/blob/master/AUTHORS -Licensed under GNU Lesser General Public License v3.0 - -#### create\_config\_template - -```python -def create_config_template() -> Tuple[dict, ruamel.yaml.YAML] -``` - -Creates a template for the config.yaml file. - -**Returns**: - - Tuple[dict, ruamel.yaml.YAML]: A tuple containing the template dictionary and the Ruamel YAML instance. - -#### read\_config - -```python -def read_config(configname: str) -> dict -``` - -Reads structured config file defining a project. - -**Arguments**: - -- `configname` _str_ - Path to the config file. - - -**Returns**: - -- `dict` - The contents of the config file as a dictionary. - -#### write\_config - -```python -def write_config(configname: str, cfg: dict) -> None -``` - -Write structured config file. - -**Arguments**: - -- `configname` _str_ - Path to the config file. -- `cfg` _dict_ - Dictionary containing the config data. 
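A minimal round-trip sketch using the `read_config`/`write_config` pair documented above (the path and the edited key are illustrative):

```python
from vame.util.auxiliary import read_config, write_config

config_file = "./my-vame-project/config.yaml"  # illustrative path
cfg = read_config(config_file)   # load the project configuration as a dict
cfg["max_epochs"] = 50           # hypothetical edit of one config entry
write_config(config_file, cfg)   # persist the updated configuration
```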
- diff --git a/docs/vame-docs-app/docs/reference/vame/util/csv_to_npy.md b/docs/vame-docs-app/docs/reference/vame/util/csv_to_npy.md deleted file mode 100644 index 7db81d4b..00000000 --- a/docs/vame-docs-app/docs/reference/vame/util/csv_to_npy.md +++ /dev/null @@ -1,26 +0,0 @@ ---- -sidebar_label: csv_to_npy -title: vame.util.csv_to_npy ---- - -Variational Animal Motion Embedding 1.0-alpha Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -#### pose\_to\_numpy - -```python -@save_state(model=PoseToNumpyFunctionSchema) -def pose_to_numpy(config: str, save_logs=False) -> None -``` - -Converts a pose-estimation.csv file to a numpy array. Note that this code is only useful for data which is a priori egocentric, i.e. head-fixed -or otherwise restrained animals. - -**Raises**: - -- `ValueError` - If the config.yaml file indicates that the data is not egocentric. - diff --git a/docs/vame-docs-app/docs/reference/vame/util/data_manipulation.md b/docs/vame-docs-app/docs/reference/vame/util/data_manipulation.md deleted file mode 100644 index 8d3ea693..00000000 --- a/docs/vame-docs-app/docs/reference/vame/util/data_manipulation.md +++ /dev/null @@ -1,141 +0,0 @@ ---- -sidebar_label: data_manipulation -title: vame.util.data_manipulation ---- - -#### get\_pose\_data\_from\_nwb\_file - -```python -def get_pose_data_from_nwb_file( - nwbfile: NWBFile, path_to_pose_nwb_series_data: str) -> LabelledDict -``` - -Get pose data from nwb file using a inside path to the nwb data. - -**Arguments**: - -- `nwbfile` _NWBFile_ - NWB file object. -- `path_to_pose_nwb_series_data` _str_ - Path to the pose data inside the nwb file. - - -**Returns**: - -- `LabelledDict` - Pose data. - -#### consecutive - -```python -def consecutive(data: np.ndarray, stepsize: int = 1) -> List[np.ndarray] -``` - -Find consecutive sequences in the data array. - -**Arguments**: - -- `data` _np.ndarray_ - Input array. -- `stepsize` _int, optional_ - Step size. Defaults to 1. - - -**Returns**: - -- `List[np.ndarray]` - List of consecutive sequences. - -#### nan\_helper - -```python -def nan_helper(y: np.ndarray) -> Tuple -``` - -Identifies indices of NaN values in an array and provides a function to convert them to non-NaN indices. - -**Arguments**: - -- `y` _np.ndarray_ - Input array containing NaN values. - - -**Returns**: - - Tuple[np.ndarray, Union[np.ndarray, None]]: A tuple containing two elements: - - An array of boolean values indicating the positions of NaN values. - - A lambda function to convert NaN indices to non-NaN indices. - -#### interpol\_all\_nans - -```python -def interpol_all_nans(arr: np.ndarray) -> np.ndarray -``` - -Interpolates all NaN values in the given array. - -**Arguments**: - -- `arr` _np.ndarray_ - Input array containing NaN values. - - -**Returns**: - -- `np.ndarray` - Array with NaN values replaced by interpolated values. - -#### interpol\_first\_rows\_nans - -```python -def interpol_first_rows_nans(arr: np.ndarray) -> np.ndarray -``` - -Interpolates NaN values in the given array. - -**Arguments**: - -- `arr` _np.ndarray_ - Input array with NaN values. - - -**Returns**: - -- `np.ndarray` - Array with interpolated NaN values. 
- -#### crop\_and\_flip - -```python -def crop_and_flip( - rect: Tuple, src: np.ndarray, points: List[np.ndarray], - ref_index: Tuple[int, int]) -> Tuple[np.ndarray, List[np.ndarray]] -``` - -Crop and flip the image based on the given rectangle and points. - -**Arguments**: - -- `rect` _Tuple_ - Rectangle coordinates (center, size, theta). -- `src` _np.ndarray_ - Source image. -- `points` _List[np.ndarray]_ - List of points. -- `ref_index` _Tuple[int, int]_ - Reference indices for alignment. - - -**Returns**: - - Tuple[np.ndarray, List[np.ndarray]]: Cropped and flipped image, and shifted points. - -#### background - -```python -def background(path_to_file: str, - filename: str, - file_format: str = '.mp4', - num_frames: int = 1000, - save_background: bool = True) -> np.ndarray -``` - -Compute background image from fixed camera. - -**Arguments**: - -- `path_to_file` _str_ - Path to the directory containing the video files. -- `filename` _str_ - Name of the video file. -- `file_format` _str, optional_ - Format of the video file. Defaults to '.mp4'. -- `num_frames` _int, optional_ - Number of frames to use for background computation. Defaults to 1000. - - -**Returns**: - -- `np.ndarray` - Background image. - diff --git a/docs/vame-docs-app/docs/reference/vame/util/gif_pose_helper.md b/docs/vame-docs-app/docs/reference/vame/util/gif_pose_helper.md deleted file mode 100644 index 9811fd4c..00000000 --- a/docs/vame-docs-app/docs/reference/vame/util/gif_pose_helper.md +++ /dev/null @@ -1,44 +0,0 @@ ---- -sidebar_label: gif_pose_helper -title: vame.util.gif_pose_helper ---- - -Variational Animal Motion Embedding 1.0-alpha Toolbox -© K. Luxem & P. Bauer, Department of Cellular Neuroscience -Leibniz Institute for Neurobiology, Magdeburg, Germany - -https://github.com/LINCellularNeuroscience/VAME -Licensed under GNU General Public License v3.0 - -#### get\_animal\_frames - -```python -def get_animal_frames( - cfg: dict, - filename: str, - pose_ref_index: list, - start: int, - length: int, - subtract_background: bool, - file_format: str = '.mp4', - crop_size: tuple = (300, 300)) -> list -``` - -Extracts frames of an animal from a video file and returns them as a list. - -**Arguments**: - -- `cfg` _dict_ - Configuration dictionary containing project information. -- `filename` _str_ - Name of the video file. -- `pose_ref_index` _list_ - List of reference coordinate indices for alignment. -- `start` _int_ - Starting frame index. -- `length` _int_ - Number of frames to extract. -- `subtract_background` _bool_ - Whether to subtract background or not. -- `file_format` _str, optional_ - Format of the video file. Defaults to '.mp4'. -- `crop_size` _tuple, optional_ - Size of the cropped area. Defaults to (300, 300). - - -**Returns**: - -- `list` - List of extracted frames. - diff --git a/docs/vame-docs-app/docs/reference/vame/util/model_util.md b/docs/vame-docs-app/docs/reference/vame/util/model_util.md deleted file mode 100644 index 7e5dd81c..00000000 --- a/docs/vame-docs-app/docs/reference/vame/util/model_util.md +++ /dev/null @@ -1,24 +0,0 @@ ---- -sidebar_label: model_util -title: vame.util.model_util ---- - -#### load\_model - -```python -def load_model(cfg: dict, model_name: str, fixed: bool = True) -> RNN_VAE -``` - -Load the VAME model. - -**Arguments**: - -- `cfg` _dict_ - Configuration dictionary. -- `model_name` _str_ - Name of the model. -- `fixed` _bool_ - Fixed or variable length sequences. - - -**Returns**: - -- `RNN_VAE` - Loaded VAME model. 
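As a sketch, the `load_model` helper documented above would be used like this (the path is illustrative, and the `"model_name"` key is assumed from its use in the test suite later in this series):

```python
from vame.util.auxiliary import read_config
from vame.util.model_util import load_model

cfg = read_config("./my-vame-project/config.yaml")  # illustrative path
# Load the trained RNN_VAE with fixed-length sequences.
model = load_model(cfg=cfg, model_name=cfg["model_name"], fixed=True)
```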
- diff --git a/docs/vame-docs-app/src/components/HomepageFeatures/index.js b/docs/vame-docs-app/src/components/HomepageFeatures/index.js index 17599cc2..9da30a32 100644 --- a/docs/vame-docs-app/src/components/HomepageFeatures/index.js +++ b/docs/vame-docs-app/src/components/HomepageFeatures/index.js @@ -5,7 +5,6 @@ import styles from './styles.module.css'; const FeatureList = [ { title: 'Behavioral Segmentation', - // Svg: require('@site/static/img/undraw_docusaurus_mountain.svg').default, description: ( <> VAME provides advanced algorithms for precise behavioral segmentation, allowing researchers to analyze animal motion patterns efficiently. @@ -14,7 +13,6 @@ const FeatureList = [ }, { title: 'Machine Learning Framework', - Svg: require('@site/static/img/undraw_docusaurus_tree.svg').default, description: ( <> Utilizing state-of-the-art machine learning techniques, VAME extracts meaningful insights from behavioral data to facilitate scientific discoveries. @@ -23,7 +21,6 @@ const FeatureList = [ }, { title: 'Python API', - Svg: require('@site/static/img/undraw_docusaurus_react.svg').default, description: ( <> VAME offers a straight forward Python API, making it easy for users to integrate into their workflows and perform behavioral analysis with minimal effort. @@ -32,7 +29,7 @@ const FeatureList = [ }, ]; -function Feature({Svg, title, description}) { +function Feature({ Svg, title, description }) { return (
{/*
diff --git a/docs/yarn.lock b/docs/yarn.lock new file mode 100644 index 00000000..fb57ccd1 --- /dev/null +++ b/docs/yarn.lock @@ -0,0 +1,4 @@ +# THIS IS AN AUTOGENERATED FILE. DO NOT EDIT THIS FILE DIRECTLY. +# yarn lockfile v1 + + From 6c119352f73a459e9ffebce6fc31c82114bf542d Mon Sep 17 00:00:00 2001 From: luiz Date: Sat, 28 Dec 2024 13:10:57 +0100 Subject: [PATCH 53/77] category --- docs/vame-docs-app/docs/reference/_category_.json | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 docs/vame-docs-app/docs/reference/_category_.json diff --git a/docs/vame-docs-app/docs/reference/_category_.json b/docs/vame-docs-app/docs/reference/_category_.json new file mode 100644 index 00000000..8bfab157 --- /dev/null +++ b/docs/vame-docs-app/docs/reference/_category_.json @@ -0,0 +1,8 @@ +{ + "label": "API Reference", + "position": 6, + "link": { + "type": "generated-index", + "description": "API Reference" + } +} \ No newline at end of file From 054a0b103207939c241434a600b73c3394595920 Mon Sep 17 00:00:00 2001 From: luiz Date: Sat, 28 Dec 2024 14:50:14 +0100 Subject: [PATCH 54/77] render notebook --- .../docs/getting_started/installation.md | 2 +- .../docs/getting_started/pipeline.mdx | 13 + docs/vame-docs-app/sidebars.js | 6 +- .../src/components/IframeResizer/index.js | 58 + .../static/notebooks_html/pipeline.html | 8479 +++++++++++++++++ 5 files changed, 8554 insertions(+), 4 deletions(-) create mode 100644 docs/vame-docs-app/docs/getting_started/pipeline.mdx create mode 100644 docs/vame-docs-app/src/components/IframeResizer/index.js create mode 100644 docs/vame-docs-app/static/notebooks_html/pipeline.html diff --git a/docs/vame-docs-app/docs/getting_started/installation.md b/docs/vame-docs-app/docs/getting_started/installation.md index bf4d4fe7..96655076 100644 --- a/docs/vame-docs-app/docs/getting_started/installation.md +++ b/docs/vame-docs-app/docs/getting_started/installation.md @@ -1,6 +1,6 @@ --- title: Installation -sidebar_position: 2 +sidebar_position: 1 --- diff --git a/docs/vame-docs-app/docs/getting_started/pipeline.mdx b/docs/vame-docs-app/docs/getting_started/pipeline.mdx new file mode 100644 index 00000000..5e69ecc0 --- /dev/null +++ b/docs/vame-docs-app/docs/getting_started/pipeline.mdx @@ -0,0 +1,13 @@ +--- +id: run-pipeline +title: Run Pipeline +sidebar_position: 3 +slug: /getting_started/run-pipeline +--- + +import React from 'react'; +import useBaseUrl from '@docusaurus/useBaseUrl'; + +import IframeResizer from '@site/src/components/IframeResizer'; + + \ No newline at end of file diff --git a/docs/vame-docs-app/sidebars.js b/docs/vame-docs-app/sidebars.js index e9c53432..2e743330 100644 --- a/docs/vame-docs-app/sidebars.js +++ b/docs/vame-docs-app/sidebars.js @@ -14,9 +14,9 @@ /** @type {import('@docusaurus/plugin-content-docs').SidebarsConfig} */ const sidebars = { // By default, Docusaurus generates a sidebar from the docs folder structure - docsSidebar: [{type: 'autogenerated', dirName: '.'}], - - + docsSidebar: [ + { type: 'autogenerated', dirName: '.' 
}, + ], }; export default sidebars; diff --git a/docs/vame-docs-app/src/components/IframeResizer/index.js b/docs/vame-docs-app/src/components/IframeResizer/index.js new file mode 100644 index 00000000..902b73a3 --- /dev/null +++ b/docs/vame-docs-app/src/components/IframeResizer/index.js @@ -0,0 +1,58 @@ +import React, { useRef, useEffect } from 'react'; + +const IframeResizer = ({ src, heightBuffer = 20, maxCellHeight = 200, ...props }) => { + const iframeRef = useRef(null); + + useEffect(() => { + const injectCustomStyles = () => { + const iframe = iframeRef.current; + if (iframe) { + try { + const iframeDocument = iframe.contentDocument || iframe.contentWindow.document; + + // Inject custom CSS for output content + const style = iframeDocument.createElement('style'); + style.innerHTML = ` + .jp-OutputArea, .jp-Cell-outputArea { + max-height: ${maxCellHeight}px; + overflow-y: auto; + } + .jp-RenderedText, .jp-OutputArea-output { + word-wrap: break-word; + white-space: pre-wrap; + } + `; + iframeDocument.head.appendChild(style); + + // Adjust iframe height after styles are applied + const height = iframeDocument.body.scrollHeight + heightBuffer; + iframe.style.height = `${height}px`; + } catch (error) { + console.warn('Could not access iframe content due to cross-origin restrictions.', error); + } + } + }; + + const iframe = iframeRef.current; + if (iframe) { + iframe.addEventListener('load', injectCustomStyles); + } + + return () => { + if (iframe) { + iframe.removeEventListener('load', injectCustomStyles); + } + }; + }, [heightBuffer, maxCellHeight]); + + return ( + + ); +}; + +export default IframeResizer; diff --git a/docs/vame-docs-app/static/notebooks_html/pipeline.html b/docs/vame-docs-app/static/notebooks_html/pipeline.html new file mode 100644 index 00000000..bebb9b32 --- /dev/null +++ b/docs/vame-docs-app/static/notebooks_html/pipeline.html @@ -0,0 +1,8479 @@ + + + + + +pipeline + + + + + + + + + + + + +
[… 8,479 lines of generated notebook HTML (docs/vame-docs-app/static/notebooks_html/pipeline.html) omitted …]
+ + From cd2153ff5213f1f7bf345420a007e11653b20cc6 Mon Sep 17 00:00:00 2001 From: luiz Date: Mon, 30 Dec 2024 11:19:56 +0100 Subject: [PATCH 55/77] some fixes for pipeline, docstring and notebook --- .github/workflows/publish_docs.yaml | 4 + .gitignore | 3 + docs/README.md | 9 +- docs/requirements-docs.txt | 3 +- .../src/components/IframeResizer/index.js | 27 +- .../static/notebooks_html/pipeline.html | 1033 +++-------------- examples/pipeline.ipynb | 121 +- src/vame/__init__.py | 2 +- src/vame/analysis/__init__.py | 2 +- src/vame/analysis/umap.py | 2 +- src/vame/initialize_project/new.py | 2 +- src/vame/pipeline.py | 253 +++- 12 files changed, 515 insertions(+), 946 deletions(-) diff --git a/.github/workflows/publish_docs.yaml b/.github/workflows/publish_docs.yaml index 062fa992..4f30365b 100644 --- a/.github/workflows/publish_docs.yaml +++ b/.github/workflows/publish_docs.yaml @@ -31,6 +31,10 @@ jobs: - name: Auto generate API Reference. run: cd docs && pydoc-markdown + - name: Convert Jupyter Notebooks to HTML + run: | + jupyter nbconvert --to html examples/pipeline.ipynb --embed-images --no-prompt --output-dir=docs/vame-docs-app/static/notebooks_html + - uses: actions/setup-node@v3 with: node-version: 18 diff --git a/.gitignore b/.gitignore index c3952ffb..c9226fc1 100644 --- a/.gitignore +++ b/.gitignore @@ -112,3 +112,6 @@ venv_docs/ # mypy .mypy_cache/ + +# Others +**/pipeline_example/** \ No newline at end of file diff --git a/docs/README.md b/docs/README.md index 6d4da88c..d86f10e9 100644 --- a/docs/README.md +++ b/docs/README.md @@ -7,13 +7,17 @@ This folder contains the documentation for the VAME project. The docs are a docu The API Reference documentation is automatically generated from the docstrings and type annotations in the codebase using [pydoc-markdown](https://github.com/NiklasRosenstein/pydoc-markdown). 1. Install pydoc-markdown: -First install `pydoc-markdown` package following their guide [here](https://niklasrosenstein.github.io/pydoc-markdown/#installation-). +First install `pydoc-markdown` package. For the moment, we are using a fork of the original package, so you need to install it from the forked repository: + +```bash +pip install git+https://github.com/luiztauffer/pydoc-markdown.git@develop +``` 2. In the `docs/` directory, run the following command to generate the API Reference documentation: ```bash pydoc-markdown ``` -This command will generate the API Reference documentation from the project and save it in the `docs/vame-docs-app/docs/reference/vame` folder. +This command will generate the API Reference documentation from the project and save it in the `docs/vame-docs-app/docs/reference/` folder. 
### Running the documentation app locally @@ -29,3 +33,4 @@ yarn yarn start ``` +The Docusaurus website should be running locally at: http://localhost:3000/VAME/ diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index b9ad3b85..8ef92f3a 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -1,2 +1,3 @@ # pydoc-markdown==4.8.2 -git+https://github.com/luiztauffer/pydoc-markdown.git@develop \ No newline at end of file +git+https://github.com/luiztauffer/pydoc-markdown.git@develop +nbconvert==7.16.4 \ No newline at end of file diff --git a/docs/vame-docs-app/src/components/IframeResizer/index.js b/docs/vame-docs-app/src/components/IframeResizer/index.js index 902b73a3..66e42cbd 100644 --- a/docs/vame-docs-app/src/components/IframeResizer/index.js +++ b/docs/vame-docs-app/src/components/IframeResizer/index.js @@ -10,13 +10,38 @@ const IframeResizer = ({ src, heightBuffer = 20, maxCellHeight = 200, ...props } try { const iframeDocument = iframe.contentDocument || iframe.contentWindow.document; - // Inject custom CSS for output content + // Inject custom CSS for Markdown cells and other notebook elements const style = iframeDocument.createElement('style'); style.innerHTML = ` + /* Apply Docusaurus font style */ + html { + background-color: var(--ifm-background-color); + color: var(--ifm-font-color-base); + color-scheme: var(--ifm-color-scheme); + font: var(--ifm-font-size-base) / var(--ifm-line-height-base) var(--ifm-font-family-base); + -webkit-font-smoothing: antialiased; + -webkit-tap-highlight-color: transparent; + text-rendering: optimizelegibility; + text-size-adjust: 100%; + } + + body { + margin: 0; + padding: 0; + } + + /* Markdown cell styling */ + .jp-MarkdownOutput { + margin: 0rem 0; + padding: 0rem; + } + + /* Ensure long outputs in other cells scroll vertically */ .jp-OutputArea, .jp-Cell-outputArea { max-height: ${maxCellHeight}px; overflow-y: auto; } + .jp-RenderedText, .jp-OutputArea-output { word-wrap: break-word; white-space: pre-wrap; diff --git a/docs/vame-docs-app/static/notebooks_html/pipeline.html b/docs/vame-docs-app/static/notebooks_html/pipeline.html index bebb9b32..8d35e79e 100644 --- a/docs/vame-docs-app/static/notebooks_html/pipeline.html +++ b/docs/vame-docs-app/static/notebooks_html/pipeline.html @@ -7331,11 +7331,12 @@ if (!diagrams.length) { return; } - const mermaid = (await import("https://cdnjs.cloudflare.com/ajax/libs/mermaid/10.6.0/mermaid.esm.min.mjs")).default; + const mermaid = (await import("https://cdnjs.cloudflare.com/ajax/libs/mermaid/10.7.0/mermaid.esm.min.mjs")).default; const parser = new DOMParser(); mermaid.initialize({ maxTextSize: 100000, + maxEdges: 100000, startOnLoad: false, fontFamily: window .getComputedStyle(document.body) @@ -7406,7 +7407,8 @@ let results = null; let output = null; try { - const { svg } = await mermaid.render(id, raw, el); + let { svg } = await mermaid.render(id, raw, el); + svg = cleanMermaidSvg(svg); results = makeMermaidImage(svg); output = document.createElement("figure"); results.map(output.appendChild, output); @@ -7421,6 +7423,38 @@ parent.appendChild(output); } + + /** + * Post-process to ensure mermaid diagrams contain only valid SVG and XHTML. + */ + function cleanMermaidSvg(svg) { + return svg.replace(RE_VOID_ELEMENT, replaceVoidElement); + } + + + /** + * A regular expression for all void elements, which may include attributes and + * a slash. + * + * @see https://developer.mozilla.org/en-US/docs/Glossary/Void_element + * + * Of these, only `
` is generated by Mermaid in place of `\n`, + * but _any_ "malformed" tag will break the SVG rendering entirely. + */ + const RE_VOID_ELEMENT = + /<\s*(area|base|br|col|embed|hr|img|input|link|meta|param|source|track|wbr)\s*([^>]*?)\s*>/gi; + + /** + * Ensure a void element is closed with a slash, preserving any attributes. + */ + function replaceVoidElement(match, tag, rest) { + rest = rest.trim(); + if (!rest.endsWith('/')) { + rest = `${rest} /`; + } + return `<${tag} ${rest}>`; + } + void Promise.all([...diagrams].map(renderOneMarmaid)); }); @@ -7476,104 +7510,109 @@
[… remainder of the modified pipeline.html hunk (generated notebook HTML) omitted …]
From 58ec485bae6c0218217a4d985eefd80fba42a955 Mon Sep 17 00:00:00 2001 From: luiz Date: Tue, 31 Dec 2024 11:07:34 +0100 Subject: [PATCH 67/77] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7aeba75e..a9e58486 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,9 +8,12 @@ ### Features +- User friendly way to select reference points for alignment ([Issue #89](https://github.com/EthoML/VAME/issues/89)) - Adopt movement Xarray data format ([Issue #111](https://github.com/EthoML/VAME/issues/111)) - Relocate IQR cleaning into preprocessing ([Issue #22](https://github.com/EthoML/VAME/issues/22)) - Created preprocessing module ([Issue #119](https://github.com/EthoML/VAME/issues/119)) +- Separate module for visualization of results ([Issue #127](https://github.com/EthoML/VAME/issues/127)) +- Further improvements to Pipeline # v0.6.0 From 451daf6e4d0478de2586b6ce305c0d268c20614f Mon Sep 17 00:00:00 2001 From: luiz Date: Tue, 31 Dec 2024 11:16:03 +0100 Subject: [PATCH 68/77] actions and fix docs --- .github/workflows/publish_docs.yaml | 5 ++ .github/workflows/testing.yaml | 5 +- .../docs/getting_started/installation.md | 5 +- .../src/components/IframeResizer/index.js | 76 +++++++++---------- 4 files changed, 50 insertions(+), 41 deletions(-) diff --git a/.github/workflows/publish_docs.yaml b/.github/workflows/publish_docs.yaml index 97618db0..ea647611 100644 --- a/.github/workflows/publish_docs.yaml +++ b/.github/workflows/publish_docs.yaml @@ -5,6 +5,11 @@ on: branches: - main - dev + paths: + - '.github/workflows/publush_docs.yaml' + - 'docs/**' + - 'examples/**' + - 'src/**' jobs: deploy: diff --git a/.github/workflows/testing.yaml b/.github/workflows/testing.yaml index 49745dca..84c80acd 100644 --- a/.github/workflows/testing.yaml +++ b/.github/workflows/testing.yaml @@ -5,7 +5,10 @@ on: branches: - main - dev - - preprocessing + paths: + - '.github/workflows/testing.yaml' + - 'src/**' + - 'tests/**' jobs: run: diff --git a/docs/vame-docs-app/docs/getting_started/installation.md b/docs/vame-docs-app/docs/getting_started/installation.md index 96655076..962ac219 100644 --- a/docs/vame-docs-app/docs/getting_started/installation.md +++ b/docs/vame-docs-app/docs/getting_started/installation.md @@ -19,7 +19,8 @@ pip install vame-py 1. Clone the VAME repository to your local machine by running ```bash -git clone https://github.com/LINCellularNeuroscience/VAME.git +git clone https://github.com/EthoML/VAME.git +cd VAME ``` @@ -29,9 +30,9 @@ git clone https://github.com/LINCellularNeuroscience/VAME.git ```bash conda env create -f VAME.yaml ``` + **Option 2:** Installing local VAME with pip in your active virtual environment by running ```bash -cd VAME pip install . 
``` diff --git a/docs/vame-docs-app/src/components/IframeResizer/index.js b/docs/vame-docs-app/src/components/IframeResizer/index.js index 66e42cbd..6ff1fd58 100644 --- a/docs/vame-docs-app/src/components/IframeResizer/index.js +++ b/docs/vame-docs-app/src/components/IframeResizer/index.js @@ -1,18 +1,18 @@ import React, { useRef, useEffect } from 'react'; -const IframeResizer = ({ src, heightBuffer = 20, maxCellHeight = 200, ...props }) => { - const iframeRef = useRef(null); +const IframeResizer = ({ src, heightBuffer = 20, maxCellHeight = 400, ...props }) => { + const iframeRef = useRef(null); - useEffect(() => { - const injectCustomStyles = () => { - const iframe = iframeRef.current; - if (iframe) { - try { - const iframeDocument = iframe.contentDocument || iframe.contentWindow.document; + useEffect(() => { + const injectCustomStyles = () => { + const iframe = iframeRef.current; + if (iframe) { + try { + const iframeDocument = iframe.contentDocument || iframe.contentWindow.document; - // Inject custom CSS for Markdown cells and other notebook elements - const style = iframeDocument.createElement('style'); - style.innerHTML = ` + // Inject custom CSS for Markdown cells and other notebook elements + const style = iframeDocument.createElement('style'); + style.innerHTML = ` /* Apply Docusaurus font style */ html { background-color: var(--ifm-background-color); @@ -47,37 +47,37 @@ const IframeResizer = ({ src, heightBuffer = 20, maxCellHeight = 200, ...props } white-space: pre-wrap; } `; - iframeDocument.head.appendChild(style); + iframeDocument.head.appendChild(style); - // Adjust iframe height after styles are applied - const height = iframeDocument.body.scrollHeight + heightBuffer; - iframe.style.height = `${height}px`; - } catch (error) { - console.warn('Could not access iframe content due to cross-origin restrictions.', error); - } - } - }; - - const iframe = iframeRef.current; - if (iframe) { - iframe.addEventListener('load', injectCustomStyles); + // Adjust iframe height after styles are applied + const height = iframeDocument.body.scrollHeight + heightBuffer; + iframe.style.height = `${height}px`; + } catch (error) { + console.warn('Could not access iframe content due to cross-origin restrictions.', error); } + } + }; - return () => { - if (iframe) { - iframe.removeEventListener('load', injectCustomStyles); - } - }; - }, [heightBuffer, maxCellHeight]); + const iframe = iframeRef.current; + if (iframe) { + iframe.addEventListener('load', injectCustomStyles); + } + + return () => { + if (iframe) { + iframe.removeEventListener('load', injectCustomStyles); + } + }; + }, [heightBuffer, maxCellHeight]); - return ( - - ); + return ( + + ); }; export default IframeResizer; From ce2278e93fa084bc2dddfba8257aa067e5e493d8 Mon Sep 17 00:00:00 2001 From: luiz Date: Tue, 31 Dec 2024 17:13:43 +0100 Subject: [PATCH 69/77] fix test model fig --- tests/04_model_test.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/04_model_test.py b/tests/04_model_test.py index 913d3e9f..66e8e64b 100644 --- a/tests/04_model_test.py +++ b/tests/04_model_test.py @@ -76,6 +76,14 @@ def test_train_model_best_model_file_exists(setup_project_and_train_model): def test_evaluate_model_images_exists(setup_project_and_evaluate_model): + from vame.visualization.model import plot_loss + + plot_loss( + cfg=setup_project_and_evaluate_model["config_data"], + model_name=setup_project_and_evaluate_model["config_data"]["model_name"], + save_to_file=True, + show_figure=False, + ) project_path = 
setup_project_and_evaluate_model["config_data"]["project_path"] model_name = setup_project_and_evaluate_model["config_data"]["model_name"] reconstruction_image_path = Path(project_path) / "model" / "evaluate" / "Future_Reconstruction.png" From 50e166658af6a1116d01091db67dbabb0dd391f8 Mon Sep 17 00:00:00 2001 From: luiz Date: Tue, 31 Dec 2024 17:22:53 +0100 Subject: [PATCH 70/77] fix umap fig test --- tests/05_analysis_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/05_analysis_test.py b/tests/05_analysis_test.py index cccc811c..9f2e6788 100644 --- a/tests/05_analysis_test.py +++ b/tests/05_analysis_test.py @@ -230,8 +230,9 @@ def test_visualization_output_files( label, segmentation_algorithm, ): - vame.visualization( - setup_project_and_train_model["config_data"], + from vame.visualization.umap import visualize_umap + visualize_umap( + config=setup_project_and_train_model["config_data"], segmentation_algorithm=segmentation_algorithm, label=label, save_logs=True, From aa07d2d0522b7cd9eb69a86567802628407980ce Mon Sep 17 00:00:00 2001 From: luiz Date: Tue, 31 Dec 2024 17:29:07 +0100 Subject: [PATCH 71/77] fix test --- tests/05_analysis_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/05_analysis_test.py b/tests/05_analysis_test.py index 9f2e6788..7602b782 100644 --- a/tests/05_analysis_test.py +++ b/tests/05_analysis_test.py @@ -4,6 +4,7 @@ from matplotlib.figure import Figure from unittest.mock import patch from vame.util.gif_pose_helper import background +from vame.visualization.umap import visualize_umap @pytest.mark.parametrize( @@ -230,7 +231,6 @@ def test_visualization_output_files( label, segmentation_algorithm, ): - from vame.visualization.umap import visualize_umap visualize_umap( config=setup_project_and_train_model["config_data"], segmentation_algorithm=segmentation_algorithm, @@ -333,7 +333,7 @@ def mock_background( save_logs=False, segmentation_algorithm=SEGMENTATION_ALGORITHM, ) - vame.visualization( + visualize_umap( config=setup_project_and_evaluate_model["config_data"], segmentation_algorithm=SEGMENTATION_ALGORITHM, label=label, From e65cb7656064cd1da2008b999342100c4e8dde2b Mon Sep 17 00:00:00 2001 From: luiz Date: Wed, 1 Jan 2025 10:23:37 +0100 Subject: [PATCH 72/77] documentation examples - wip --- .../docs/getting_started/pipeline.mdx | 2 + .../{running.mdx => step_by_step.mdx} | 38 ++++++++++--------- 2 files changed, 22 insertions(+), 18 deletions(-) rename docs/vame-docs-app/docs/getting_started/{running.mdx => step_by_step.mdx} (83%) diff --git a/docs/vame-docs-app/docs/getting_started/pipeline.mdx b/docs/vame-docs-app/docs/getting_started/pipeline.mdx index 5e69ecc0..4b129b71 100644 --- a/docs/vame-docs-app/docs/getting_started/pipeline.mdx +++ b/docs/vame-docs-app/docs/getting_started/pipeline.mdx @@ -10,4 +10,6 @@ import useBaseUrl from '@docusaurus/useBaseUrl'; import IframeResizer from '@site/src/components/IframeResizer'; +The notebook below is available [here](https://github.com/EthoML/VAME/blob/main/examples/pipeline.ipynb). 
+ \ No newline at end of file diff --git a/docs/vame-docs-app/docs/getting_started/running.mdx b/docs/vame-docs-app/docs/getting_started/step_by_step.mdx similarity index 83% rename from docs/vame-docs-app/docs/getting_started/running.mdx rename to docs/vame-docs-app/docs/getting_started/step_by_step.mdx index 588e3449..52d329ac 100644 --- a/docs/vame-docs-app/docs/getting_started/running.mdx +++ b/docs/vame-docs-app/docs/getting_started/step_by_step.mdx @@ -1,39 +1,41 @@ --- -title: Running VAME Workflow +title: VAME step-by-step sidebar_position: 2 --- - - -## Workflow Overview -The below diagram shows the workflow of the VAME application, which consists of four main steps and optional steps to analyse your data. -![Workflow Overview](/img/workflow_overview.png) -1. Initialize project: This is step is responsible by starting the project, getting your data into the right format and creating a training dataset for the VAME deep learning model. -2. Train neural network: Train a variational autoencoder which is parameterized with recurrent neural network to embed behavioural dynamics -3. Evaluate performance: Evaluate the trained model based on its reconstruction capabilities -4. Segment behavior: Segment behavioural motifs/poses/states from the input time series -5. Quantify behavior: +The VAME workflow consists of four main steps, plus optional steps to analyse your data: +1. Initialize project: This is step is responsible by starting the project, getting your pose estimation data into the right format +2. Preprocess: This step is responsible for cleaning, filtering and aligning the pose estimation data +3. Train neural network: + - Create a training dataset for the VAME deep learning model. + - Train a variational autoencoder which is parameterized with recurrent neural network to embed behavioural dynamics. + - Evaluate the performance of the trained model based on its reconstruction capabilities +4. Segment behavior: + - Segment pose estimation time series into behavioral motifs, using HMM or K-means. + - Group similar motifs into communities, using hierarchical clustering. +5. Analysis: - Optional: Create motif videos to get insights about the fine grained poses. - - Optional: Investigate the hierarchical order of your behavioural states by detecting communities in the resulting markov chain. - Optional: Create community videos to get more insights about behaviour on a hierarchical scale. - Optional: Visualization and projection of latent vectors onto a 2D plane via UMAP. - Optional: Use the generative model (reconstruction decoder) to sample from the learned data distribution, reconstruct random real samples or visualize the cluster centre for validation. - - Optional: Create a video of an egocentrically aligned animal + path through the community space (similar to our gif on github readme). + - Optional: Create a video of an egocentrically aligned animal + path through the community space. :::tip -⚠️ Check out also the published VAME Workflow Guide, including more hands-on recommendations and tricks [HERE](https://www.nature.com/articles/s42003-022-04080-7#Sec8). +Check out also the published VAME Workflow Guide, including more hands-on recommendations [HERE](https://www.nature.com/articles/s42003-022-04080-7#Sec8). ::: -## Running a demo workflow -In our github in `/examples` folder there is a demo script called `demo.py` that you can use to run a simple example of the VAME workflow. 
To run this workflow you will need to do the following: +## Running VAME, step-by-step +In this tutorial we will show you how to run the VAME workflow step-by-step using a simple example. The code below can also be found in the `demo.py` script in our Github repository ([link here](https://github.com/EthoML/VAME/blob/main/examples/demo.py)). -### 1. Download the necessary resources: + +### 0. Download the necessary resources: To run the demo you will need a video and a csv file with the pose estimation results. You can use the following files links: - `video-1.mp4`: Video file [link](https://drive.google.com/file/d/1w6OW9cN_-S30B7rOANvSaR9c3O5KeF0c/view) - `video-1.csv`: Pose estimation results [link](https://github.com/EthoML/VAME/blob/master/examples/video-1.csv) + ### 2. Running the demo main pipeline We will now show you how to run the main pipeline of the VAME workflow using snnipets of code. We suggest you to run these snippets in a jupyter notebook. @@ -170,7 +172,7 @@ To create behavioral hierarchies and communities detection run: ```python vame.community(config, parametrization='hmm', cut_tree=2, cohort=False) ``` -It will produce a tree plot of the behavioural hierarchies using hmm motifs. +It will produce a tree plot of the behavioural hierarchies using hmm motifs. #### 3.3 Community Videos Create community videos to get insights about behavior on a hierarchical scale. From 36960360d7133f1ded8ea82177dbf6aa6bacfd03 Mon Sep 17 00:00:00 2001 From: luiz Date: Wed, 1 Jan 2025 10:30:48 +0100 Subject: [PATCH 73/77] fix broken link --- docs/vame-docs-app/docs/getting_started/installation.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/vame-docs-app/docs/getting_started/installation.md b/docs/vame-docs-app/docs/getting_started/installation.md index 962ac219..c50f0742 100644 --- a/docs/vame-docs-app/docs/getting_started/installation.md +++ b/docs/vame-docs-app/docs/getting_started/installation.md @@ -44,7 +44,7 @@ You should make sure that you have a GPU powerful enough to train deep learning VAME can also be trained in Google Colab or on a HPC cluster. ::: -Once you have your computing setup ready, begin using VAME by following the [demo workflow guide](/docs/getting_started/running). +Once you have your computing setup ready, begin using VAME by following the [step-by-step guide](/docs/getting_started/step_by_step). ## References Original VAME publication: [Identifying Behavioral Structure from Deep Variational Embeddings of Animal Motion](https://www.biorxiv.org/content/10.1101/2020.05.14.095430v2)
From c360c69d264390173a29c2ae08a3681237058de9 Mon Sep 17 00:00:00 2001 From: luiz Date: Wed, 1 Jan 2025 11:00:56 +0100 Subject: [PATCH 74/77] guide wip --- .../docs/getting_started/step_by_step.mdx | 101 ++++++------------ src/vame/pipeline.py | 1 - 2 files changed, 31 insertions(+), 71 deletions(-) diff --git a/docs/vame-docs-app/docs/getting_started/step_by_step.mdx b/docs/vame-docs-app/docs/getting_started/step_by_step.mdx index 52d329ac..51dda48c 100644 --- a/docs/vame-docs-app/docs/getting_started/step_by_step.mdx +++ b/docs/vame-docs-app/docs/getting_started/step_by_step.mdx @@ -20,99 +20,60 @@ The VAME workflow consists of four main steps, plus optional steps to analyse yo - Optional: Use the generative model (reconstruction decoder) to sample from the learned data distribution, reconstruct random real samples or visualize the cluster centre for validation. - Optional: Create a video of an egocentrically aligned animal + path through the community space. + +In this tutorial we will show you how to run the VAME workflow using a simple example. The code below can also be found in the `demo.py` script in our Github repository ([link here](https://github.com/EthoML/VAME/blob/main/examples/demo.py)). + :::tip Check out also the published VAME Workflow Guide, including more hands-on recommendations [HERE](https://www.nature.com/articles/s42003-022-04080-7#Sec8). ::: - -## Running VAME, step-by-step - -In this tutorial we will show you how to run the VAME workflow step-by-step using a simple example. The code below can also be found in the `demo.py` script in our Github repository ([link here](https://github.com/EthoML/VAME/blob/main/examples/demo.py)). - - -### 0. Download the necessary resources: -To run the demo you will need a video and a csv file with the pose estimation results. You can use the following files links: -- `video-1.mp4`: Video file [link](https://drive.google.com/file/d/1w6OW9cN_-S30B7rOANvSaR9c3O5KeF0c/view) -- `video-1.csv`: Pose estimation results [link](https://github.com/EthoML/VAME/blob/master/examples/video-1.csv) +:::tip +You can run an entire VAME workflow with just a few lines, using the [Pipeline method](/docs/getting_started/pipeline). +::: -### 2. Running the demo main pipeline -We will now show you how to run the main pipeline of the VAME workflow using snnipets of code. We suggest you to run these snippets in a jupyter notebook. +## 0. [Optional] Download input data +To run VAME you will need a video and a pose estimation file. If you don't have your own data, you download sample data: -#### 2.1a Setting the demo variables using CSV files -To start the demo you must define 4 variables: ```python -import vame - +from vame.util.sample_data import download_sample_data -# The directory where the project will be saved -working_directory = '.' +source_software = "DeepLabCut" # "DeepLabCut", "SLEAP" or "LightningPose" +ps = download_sample_data(source_software) # Data will be downloaded to ~/.movement/data/ +videos = [ps["video"]] # List of paths to the video files +poses_estimations = [ps["poses"]] # List of paths to the pose estimation files +``` -# The name you want for the project -project = 'my-vame-project' -# A list of paths to the videos file -videos = ['video-1.mp4'] +## 1. Initialize the project +VAME organizes around projects. To start a new project, you need to define a few things: -# A list of paths to the poses estimations files. -# Important: The name (without the extension) of the video file and the pose estimation file must be the same. E.g. 
`video-1.mp4` and `video-1.csv` -poses_estimations = ['video-1.csv'] -``` - -#### 2.1b Setting the demo variables using NWB files -Alternativaly you can use `.nwb` files as pose estimation files. In this case you must define 4 variables: ```python import vame +working_directory = '.' # The directory where the project will be saved +project = 'my-vame-project' # The name of the project -# The directory where the project will be saved -working_directory = '.' - -# The name you want for the project -project = 'my-vame-project' - -# A list of paths to the videos file -videos = ['video-1.mp4'] +# [Optional] Customized configuration for the project +config_kwargs = { + "n_clusters": 15, + "pose_confidence": 0.9, + "max_epochs": 100, +} -# A list of paths to the poses estimations files. -# Important: The name (without the extension) of the video file and the pose estimation file must be the same. E.g. `video-1.mp4` and `video-1.nwb` -poses_estimations = ['video-1.nwb'] - -# A list of paths in the NWB file where the pose estimation data is stored. -paths_to_pose_nwb_series_data = ['processing/behavior/data_interfaces/PoseEstimation/pose_estimation_series'] -``` - - - -#### 2.2 Initializing the project -With the variables set, you can initialize the project by running the following code: - -If you are using CSV files you can run the following code to initialize the project: -```python -config = vame.init_new_project( - project=project, - videos=videos, - poses_estimations=poses_estimations, - working_directory=working_directory, - videotype='.mp4' -) -``` - -If you are using NWB files you can run the following code to initialize the project: -```python -config = vame.init_new_project( - project=project, +config_path, config = vame.init_new_project( + project_name=project_name, videos=videos, poses_estimations=poses_estimations, + source_software=source_software, working_directory=working_directory, - videotype='.mp4', - paths_to_pose_nwb_series_data=paths_to_pose_nwb_series_data + config_kwargs=config_kwargs, ) ``` -This command will create a project folder in the defined working directory with the name you set in the `project` variable and a date suffix, e.g: `my-vame-project-May-9-2024`. -In this folder you can find a config file called `config.yaml` where you can set the parameters for the VAME algorithm. -The videos and poses estimations files will be copied to the project videos folder. It is really important to define in the `config.yaml` file if your data is egocentrically aligned or not before running the rest of the workflow. +This command will create a project folder in the defined working directory with the project name you defined. +In this folder you can find a config file called `config.yaml` which holds the main parameters for the VAME workflow. +The videos and pose estimation files will be linked or copied to the project folder. 
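As a quick sanity check you could then inspect a couple of values — a minimal sketch assuming the returned `config` behaves like a dictionary holding the keys set via `config_kwargs` above:

```python
print(config_path)  # location of the generated config.yaml
# Keys below were set through config_kwargs above.
print(config["n_clusters"], config["pose_confidence"], config["max_epochs"])
```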
#### 2.3 Egocentric alignment diff --git a/src/vame/pipeline.py b/src/vame/pipeline.py index 56edf70b..dabd54fd 100644 --- a/src/vame/pipeline.py +++ b/src/vame/pipeline.py @@ -77,7 +77,6 @@ def __init__( paths_to_pose_nwb_series_data=paths_to_pose_nwb_series_data, config_kwargs=config_kwargs, ) - self.config = read_config(self.config_path) def get_states(self, summary: bool = True) -> dict: """ From 501629d7d80355618424da718cb9ed24d0ecae0f Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 2 Jan 2025 09:56:34 +0100 Subject: [PATCH 75/77] fix broken link --- docs/vame-docs-app/docs/getting_started/pipeline.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/vame-docs-app/docs/getting_started/pipeline.mdx b/docs/vame-docs-app/docs/getting_started/pipeline.mdx index 4b129b71..d2961ea2 100644 --- a/docs/vame-docs-app/docs/getting_started/pipeline.mdx +++ b/docs/vame-docs-app/docs/getting_started/pipeline.mdx @@ -1,8 +1,8 @@ --- -id: run-pipeline +id: pipeline title: Run Pipeline sidebar_position: 3 -slug: /getting_started/run-pipeline +slug: /getting_started/pipeline --- import React from 'react'; From e25e945e0a1e82d36dd0ce5e87e163b9f554ac87 Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 2 Jan 2025 11:11:33 +0100 Subject: [PATCH 76/77] docs --- .../docs/getting_started/step_by_step.mdx | 141 ++++++++++-------- docs/vame-docs-app/docs/project_config.mdx | 94 ++++++++++++ src/vame/schemas/project.py | 8 +- 3 files changed, 178 insertions(+), 65 deletions(-) create mode 100644 docs/vame-docs-app/docs/project_config.mdx diff --git a/docs/vame-docs-app/docs/getting_started/step_by_step.mdx b/docs/vame-docs-app/docs/getting_started/step_by_step.mdx index 51dda48c..ab951b4a 100644 --- a/docs/vame-docs-app/docs/getting_started/step_by_step.mdx +++ b/docs/vame-docs-app/docs/getting_started/step_by_step.mdx @@ -51,8 +51,8 @@ VAME organizes around projects. To start a new project, you need to define a few ```python import vame -working_directory = '.' # The directory where the project will be saved -project = 'my-vame-project' # The name of the project +working_directory = "." # The directory where the project will be saved +project = "my-vame-project" # The name of the project # [Optional] Customized configuration for the project config_kwargs = { @@ -72,102 +72,123 @@ config_path, config = vame.init_new_project( ``` This command will create a project folder in the defined working directory with the project name you defined. -In this folder you can find a config file called `config.yaml` which holds the main parameters for the VAME workflow. +In this folder you can find a config file called [config.yaml](/docs/project-config) which holds the main parameters for the VAME workflow. The videos and pose estimation files will be linked or copied to the project folder. - -#### 2.3 Egocentric alignment -If your data is not egocentrically aligned, you can align it by running the following code: +## 2. Preprocess the data +The preprocessing step is responsible for cleaning, filtering and aligning the pose estimation data. ```python -vame.egocentric_alignment(config, pose_ref_index=[0, 5]) +vame.preprocessing( + config=config, + centered_reference_keypoint=centered_reference_keypoint, + orientation_reference_keypoint=orientation_reference_keypoint, +) ``` -But if your experiment is by design egocentrical (e.g. head-fixed experiment on treadmill etc) you can use the following to convert your .csv to a .npy array, ready to train vame on it. 
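Note that `centered_reference_keypoint` and `orientation_reference_keypoint` must be defined before the `vame.preprocessing` call above. A minimal sketch, where both keypoint names are placeholders and must be replaced with keypoints that actually exist in your pose estimation file:

```python
# Placeholder keypoint names -- replace with keypoints present in your
# own pose estimation file (e.g. check the CSV header or NWB series).
centered_reference_keypoint = "snout"        # assumed example name
orientation_reference_keypoint = "tailbase"  # assumed example name
```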
+Internally, this function will:

-```python
-vame.csv_to_numpy(config)
-```
+### 2.1 Clean low confidence data points
+Pose estimation data points with confidence below the threshold will be cleared and interpolated.

-#### 2.4 Creating the training dataset
-To create the training dataset you can run the following code:
+### 2.2 Egocentric alignment
+Based on two reference keypoints, the data will be aligned to an egocentric coordinate system:
+- `centered_reference_keypoint`: The keypoint that will be centered in the frame.
+- `orientation_reference_keypoint`: The keypoint that will be used to determine the rotation of the frame.
+
+As a consequence, the `x` and `y` coordinates of the `centered_reference_keypoint` and the `x` coordinate of the `orientation_reference_keypoint` will be set to zeros, and subsequently removed from the dataset.
+
+### 2.3 Clean outliers
+Outliers will be removed based on the interquartile range (IQR) method. This means that data points that are below `Q1 - iqr_factor * IQR` or above `Q3 + iqr_factor * IQR` will be cleared and interpolated.
+
+### 2.4 Savitzky-Golay filter
+The data will be further smoothed using a Savitzky-Golay filter.

+
+## 3. Train the neural network
+At this point, we will prepare the data for training the VAME model, run the training and evaluate the model.
+
+### 3.1 Create the training dataset
+To create the training dataset, which will put the data in the right format for the VAME model and split it into training and test sets, you can run:
 ```python
-vame.create_trainset(config, pose_ref_index=[0,5])
+vame.create_trainset(config=config)
 ```

+### 3.2 Training the model
+Training the VAME model might take a while depending on the size of your dataset and your machine settings. To train the model you can run:

-#### 2.5 Training the model
-Training the vame model might take a while depending on the size of your dataset and your machine settings. To train the model you can run the following code:
 ```python
-vame.train_model(config)
+vame.train_model(config=config)
 ```

-#### 2.6 Evaluate the model
+### 3.3 Evaluate the model
 The model evaluation produces two plots, one showing the loss of the model during training and the other showing the reconstruction and future prediction of the input sequence.
+
 ```python
-vame.evaluate_model(config)
+vame.evaluate_model(config=config)
 ```

-#### 2.7 Segmenting the behavior
-To perform pose segmentation you can run the following code:
+## 4. Segment behavior
+Behavioral segmentation in VAME is done in two steps: pose segmentation into motifs and community detection.
+
+### 4.1 Pose segmentation
+To perform pose segmentation you can run:
 ```python
-vame.pose_segmentation(config)
+vame.segment_session(config=config)
 ```
+This will perform the segmentation using two different algorithms, HMM and K-means. The results will be saved in the project folder.

-### 3. Running Optional Steps of the Pipeline
-:::tip
-The following steps are optional and can be run if you want to create motif VideoColorSpace, communities/hierarchies of behavior and community VideoColorSpace.
-:::
+### 4.2 Community detection
+Community detection is done by grouping similar motifs into communities using hierarchical clustering.
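To build intuition for this grouping, here is a standalone sketch (not VAME's internal implementation) that merges motifs bottom-up with SciPy and cuts the tree into a fixed number of communities; the feature matrix is random placeholder data:

```python
# Standalone illustration of hierarchical motif grouping (NOT VAME's code):
# motifs with similar feature vectors are merged bottom-up, and the tree
# is cut at a chosen level to yield communities.
import numpy as np
from scipy.cluster.hierarchy import linkage, fcluster

rng = np.random.default_rng(seed=0)
motif_features = rng.random((15, 8))   # 15 motifs x 8 placeholder features
tree = linkage(motif_features, method="ward")             # bottom-up merges
communities = fcluster(tree, t=3, criterion="maxclust")   # cut into 3 groups
print(communities)  # community label assigned to each motif
```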
To run community detection you can run: -#### 3.1 Creating motif videos -To create motif videos and get insights about the fine grained poses you can run: ```python -vame.motif_videos(config, videoType='.mp4') +vame.community( + config=config, + segmentation_algorithm="hmm", + cut_tree=2, +) ``` -#### 3.2 Run community detection -To create behavioral hierarchies and communities detection run: -```python -vame.community(config, parametrization='hmm', cut_tree=2, cohort=False) -``` -It will produce a tree plot of the behavioural hierarchies using hmm motifs. +where `segmentation_algorithm` can be either "hmm" or "kmeans" and `cut_tree` is the cut level for the hierarchical clustering. + + +## 5. Visualize and analyze + +### 5.1 Creating motif and community videos +To create motif videos and get insights about the fine grained poses you can run: -#### 3.3 Community Videos -Create community videos to get insights about behavior on a hierarchical scale. ```python -vame.community_videos(config) +vame.motif_videos( + config=config, + segmentation_algorithm="hmm", + video_type=".mp4", +) ``` -#### 3.4 UMAP Visualization - Down projection of latent vectors and visualization via UMAP. +Create community videos: ```python -fig = vame.visualization(config, label=None) #options: label: None, "motif", "community" +vame.community_videos( + config=config, + segmentation_algorithm="hmm", + video_type=".mp4", +) ``` +### 5.2 UMAP visualization +To visualize and project latent vectors onto a 2D plane via UMAP you can run: -#### 3.5 Generative Model (Reconstruction decoder) -Use the generative model (reconstruction decoder) to sample from the learned data distribution, reconstruct random real samples or visualize -the cluster center for validation. ```python -vame.generative_model(config, mode="centers") #options: mode: "sampling", "reconstruction", "centers", "motifs" -``` +from vame.visualization.umap import visualize_umap -#### 3.6 Create output video -Create a video of an egocentrically aligned mouse + path through -the community space (similar to our gif on github) to learn more about your representation -and have something cool to show around. +visualize_umap( + config=config, + label="motif", + segmentation_algorithm="hmm", +) +``` -:::warning -This function is currently very slow. -::: +where `label` can be either None, "motif" or "community", and `segmentation_algorithm` can be either 'hmm' or 'kmeans'. -```python -vame.gif(config, pose_ref_index=[0,5], subtract_background=True, start=None, - length=500, max_lag=30, label='community', file_format='.mp4', crop_size=(300,300)) -``` -:::tip -Once the frames are saved you can create a video or gif via e.g. ImageJ or other tools -::: diff --git a/docs/vame-docs-app/docs/project_config.mdx b/docs/vame-docs-app/docs/project_config.mdx new file mode 100644 index 00000000..bb2ed218 --- /dev/null +++ b/docs/vame-docs-app/docs/project_config.mdx @@ -0,0 +1,94 @@ +--- +title: Project Configuration +sidebar_position: 3 +slug: /project-config +--- + +The project configuration YAML file exists in the root of the project folder and holds the main parameters for the VAME workflow. The configuration file is created when initializing a new project with the [init_new_project](/docs/reference/initialize_project/new) function. + +The configuration file contains the following parameters: + +### Project parameters +- **project_name** (`str`): The name of the project. +- **creation_datetime** (`str`): The creation datetime of the project. 
+- **model_name** (`str`): The name of the model.
+- **n_clusters** (`int`): The number of motif clusters to segment the behavior into.
+- **pose_confidence** (`float`): The confidence threshold below which pose estimation values are cleared and interpolated.
+- **project_path** (`str`): The path to the project folder.
+- **session_names** (`List[str]`): The names of the sessions.
+- **pose_estimation_filetype** (`PoseEstimationFiletype`): The filetype of the pose estimation files (e.g. csv or nwb).
+- **paths_to_pose_nwb_series_data** (`Optional[List[str]]`): Paths to the pose series data inside NWB files.
+
+### Data
+- **all_data** (`str`): Whether to use all of the data for training.
+- **egocentric_data** (`bool`): Whether the input data is already egocentrically aligned.
+- **robust** (`bool`): Whether to apply IQR-based outlier cleaning.
+- **iqr_factor** (`int`): The factor multiplying the IQR to set the outlier cutoffs.
+- **axis** (`str`): The axis used for the egocentric alignment.
+- **savgol_filter** (`bool`): Whether to smooth the time series with a Savitzky-Golay filter.
+- **savgol_length** (`int`): The window length of the Savitzky-Golay filter (must be odd).
+- **savgol_order** (`int`): The polynomial order of the Savitzky-Golay filter (must be smaller than savgol_length).
+- **test_fraction** (`float`): The fraction of the data held out as the test set.
+
+### RNN model general hyperparameters
+- **pretrained_model** (`str`): The name of a pretrained model to start from.
+- **pretrained_weights** (`bool`): Whether to load pretrained weights.
+- **num_features** (`int`): The number of input features (typically 2 x the number of keypoints).
+- **batch_size** (`int`): The training batch size.
+- **max_epochs** (`int`): The maximum number of training epochs.
+- **model_snapshot** (`int`): The interval, in epochs, at which model snapshots are saved.
+- **model_convergence** (`int`): The number of epochs without improvement after which training is considered converged.
+- **transition_function** (`str`): The recurrent transition function (e.g. GRU).
+- **beta** (`float`): The weight of the KL divergence term in the loss.
+- **beta_norm** (`bool`): Whether to normalize the beta term.
+- **zdims** (`int`): The dimensionality of the latent space.
+- **learning_rate** (`float`): The learning rate.
+- **time_window** (`int`): The number of time steps per input sequence.
+- **prediction_decoder** (`int`): Whether to use the future prediction decoder.
+- **prediction_steps** (`int`): The number of future time steps to predict.
+- **noise** (`bool`): Whether to add noise to the input during training.
+- **scheduler** (`int`): Whether to use a learning rate scheduler.
+- **scheduler_step_size** (`int`): The step size of the learning rate scheduler.
+- **scheduler_gamma** (`float`): The decay factor of the learning rate scheduler.
+- **scheduler_threshold** (`float`): The threshold of the learning rate scheduler.
+- **softplus** (`bool`): Whether to use a softplus activation.
+
+### Segmentation
+- **segmentation_algorithms** (`List[SegmentationAlgorithms]`): The segmentation algorithms to run ("hmm" and/or "kmeans").
+- **hmm_trained** (`bool`): Whether the HMM has already been trained.
+- **load_data** (`str`): The filename suffix of the processed data to load.
+- **individual_segmentation** (`bool`): Whether to segment each session individually instead of across all sessions.
+- **random_state_kmeans** (`int`): The random seed for K-means.
+- **n_init_kmeans** (`int`): The number of K-means initializations.
+
+### Video writer
+- **length_of_motif_video** (`int`): The number of frames per motif video.
+
+### UMAP parameter
+- **min_dist** (`float`): The UMAP `min_dist` parameter (minimum distance between embedded points).
+- **n_neighbors** (`int`): The UMAP `n_neighbors` parameter (size of the local neighborhood).
+- **random_state** (`int`): The random seed for UMAP.
+- **num_points** (`int`): The maximum number of points to embed and plot.
+
+### RNN encoder hyperparameters
+- **hidden_size_layer_1** (`int`): The hidden size of the first encoder layer.
+- **hidden_size_layer_2** (`int`): The hidden size of the second encoder layer.
+- **dropout_encoder** (`float`): The dropout rate of the encoder.
+
+### RNN reconstruction hyperparameters
+- **hidden_size_rec** (`int`): The hidden size of the reconstruction decoder.
+- **dropout_rec** (`float`): The dropout rate of the reconstruction decoder.
+- **n_layers** (`int`): The number of recurrent layers.
+
+### RNN prediction hyperparameters
+- **hidden_size_pred** (`int`): The hidden size of the prediction decoder.
+- **dropout_pred** (`float`): The dropout rate of the prediction decoder.
+
+### RNN loss hyperparameters
+- **mse_reconstruction_reduction** (`str`): The reduction applied to the reconstruction MSE loss.
+- **mse_prediction_reduction** (`str`): The reduction applied to the prediction MSE loss.
+- **kmeans_loss** (`int`): The number of dimensions used by the K-means loss term.
+- **kmeans_lambda** (`float`): The weight of the K-means loss term.
+- **anneal_function** (`str`): The annealing function for the KL weight (e.g. "linear").
+- **kl_start** (`int`): The epoch at which KL weight annealing starts.
+- **annealtime** (`int`): The number of epochs over which the KL weight is annealed.
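Since the configuration is plain YAML, you can always inspect it directly. A minimal sketch, assuming PyYAML is installed and using an illustrative project path (replace it with your own project folder):

```python
# Read the project configuration directly (illustrative path; PyYAML assumed).
from pathlib import Path
import yaml

config_file = Path("my-vame-project") / "config.yaml"
with open(config_file) as f:
    cfg = yaml.safe_load(f)

print(cfg["n_clusters"], cfg["zdims"], cfg["time_window"])
```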
diff --git a/src/vame/schemas/project.py b/src/vame/schemas/project.py index 037f75e1..17d4efd1 100644 --- a/src/vame/schemas/project.py +++ b/src/vame/schemas/project.py @@ -23,7 +23,7 @@ class Config: class ProjectSchema(BaseModel): - # Project attributes + # Project parameters project_name: str = Field( ..., title="Project name", @@ -44,8 +44,6 @@ class ProjectSchema(BaseModel): default=0.99, title="Pose confidence", ) - - # Project path and videos project_path: str = Field( ..., title="Project path", @@ -100,7 +98,7 @@ class ProjectSchema(BaseModel): title="Test fraction", ) - # RNN model general hyperparameter: + # RNN model general hyperparameters pretrained_model: str = Field( default="None", title="Pretrained model", @@ -186,7 +184,7 @@ class ProjectSchema(BaseModel): title="Softplus", ) - # Segmentation: + # Segmentation segmentation_algorithms: List[SegmentationAlgorithms] = Field( title="Segmentation algorithms", default_factory=lambda: [ From e16913498d60b5385238157d7bd9c10fd9897f1e Mon Sep 17 00:00:00 2001 From: luiz Date: Thu, 2 Jan 2025 11:38:12 +0100 Subject: [PATCH 77/77] test pipeline vis --- tests/01_pipeline_test.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/01_pipeline_test.py b/tests/01_pipeline_test.py index 9e775a85..7a9d424b 100644 --- a/tests/01_pipeline_test.py +++ b/tests/01_pipeline_test.py @@ -1,9 +1,10 @@ import xarray as xr +from pathlib import Path def test_pipeline(setup_pipeline): pipeline = setup_pipeline["pipeline"] - + project_path = pipeline.config["project_path"] sessions = pipeline.get_sessions() assert len(sessions) == 1 @@ -15,3 +16,12 @@ def test_pipeline(setup_pipeline): "orientation_reference_keypoint": "Tailroot", } pipeline.run_pipeline(preprocessing_kwargs=preprocessing_kwargs) + + pipeline.visualize_preprocessing( + show_figure=False, + save_to_file=True, + ) + save_fig_path_0 = Path(project_path) / "reports" / "figures" / f"{sessions[0]}_preprocessing_scatter.png" + save_fig_path_1 = Path(project_path) / "reports" / "figures" / f"{sessions[0]}_preprocessing_timeseries.png" + assert save_fig_path_0.exists() + assert save_fig_path_1.exists()
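For context, the test above relies on a `setup_pipeline` fixture defined elsewhere (e.g. in `conftest.py`). A rough sketch of what such a fixture could look like; the `Pipeline` constructor arguments here are assumptions based on the docs above, not the repository's actual fixture:

```python
# Hypothetical conftest.py sketch; constructor argument names are assumptions
# mirroring init_new_project, not the repository's real fixture.
import pytest
from vame.pipeline import Pipeline
from vame.util.sample_data import download_sample_data


@pytest.fixture(scope="session")
def setup_pipeline(tmp_path_factory):
    ps = download_sample_data("DeepLabCut")
    pipeline = Pipeline(
        project_name="test-project",
        videos=[ps["video"]],
        poses_estimations=[ps["poses"]],
        source_software="DeepLabCut",
        working_directory=str(tmp_path_factory.mktemp("vame-project")),
    )
    return {"pipeline": pipeline}
```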