-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtemporal_clustering_analysis.py
248 lines (203 loc) · 9.73 KB
/
temporal_clustering_analysis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from visualization import plot_clustered_movement, plot_movement
import random
import os
import sys
from tqdm import tqdm
import argparse
def temporal_dfs_clustering(session, distance_threshold=30, max_neighbors=5, time_threshold=100):
    """
    Perform DFS-based clustering on worm movement paths, considering both
    spatial and temporal proximity.

    Points are visited in row order; from each clustered point, up to the next
    `max_neighbors` rows are linked into the same cluster when they are within
    both the temporal and the spatial threshold.

    Args:
        session (pd.DataFrame): Session data with 'X', 'Y' and 'Timestamp' columns.
        distance_threshold (float): Distance threshold for spatial clustering.
        max_neighbors (int): Maximum number of subsequent points to consider.
        time_threshold (float): Time threshold (in seconds) for temporal vicinity.

    Returns:
        pd.Series: Integer cluster labels aligned to ``session.index``,
        named "Cluster". Every row receives a label (labels start at 0).
    """
    n_points = len(session)
    # Pull the columns into plain NumPy arrays ONCE: repeated scalar
    # `session.iloc[...]` lookups inside the DFS inner loop are orders of
    # magnitude slower than array indexing and dominated the runtime.
    xs = session['X'].to_numpy(dtype=float)
    ys = session['Y'].to_numpy(dtype=float)
    ts = session['Timestamp'].to_numpy(dtype=float)

    visited = np.zeros(n_points, dtype=bool)
    clusters = -np.ones(n_points, dtype=int)  # -1 marks "not yet clustered"
    cluster_id = 0
    # Progress bar over individual points; updated once per finished cluster.
    pbar = tqdm(total=n_points, desc="Clustering points", leave=False)
    for start_idx in range(n_points):
        if visited[start_idx]:
            continue
        # Start a new cluster and grow it depth-first.
        stack = [start_idx]
        points_in_current_cluster = 0
        while stack:
            idx = stack.pop()
            if visited[idx]:
                continue
            visited[idx] = True
            clusters[idx] = cluster_id
            points_in_current_cluster += 1
            # Only the next `max_neighbors` rows are candidates: the path is
            # row-ordered, so temporal neighbors are close in row order.
            last_neighbor = min(idx + max_neighbors, n_points - 1)
            for neighbor_idx in range(idx + 1, last_neighbor + 1):
                if visited[neighbor_idx]:
                    continue
                # Check the temporal threshold first (it is cheaper).
                if abs(ts[idx] - ts[neighbor_idx]) > time_threshold:
                    continue
                # Then the spatial (Euclidean) threshold.
                distance = np.hypot(xs[idx] - xs[neighbor_idx],
                                    ys[idx] - ys[neighbor_idx])
                if distance < distance_threshold:
                    stack.append(neighbor_idx)
        cluster_id += 1
        pbar.update(points_in_current_cluster)
    pbar.close()
    print(f"Found {cluster_id} clusters")
    return pd.Series(clusters, index=session.index, name="Cluster")
def process_treatment_group(input_path, output_path, distance_threshold=30, max_neighbors=5, time_threshold=100):
    """
    Process all files in a treatment group and save clustered data.

    Each CSV in `input_path` is loaded, NaN coordinate rows are dropped, the
    remaining rows are clustered with `temporal_dfs_clustering`, and the
    result (with an added 'Cluster' column) is written to `output_path`
    under the same file name.

    Args:
        input_path (Path): Path to input treatment directory.
        output_path (Path): Path to output treatment directory.
        distance_threshold (float): Maximum distance between consecutive points.
        max_neighbors (int): Maximum number of subsequent frames to check.
        time_threshold (float): Maximum time between consecutive points.
    """
    # Create output directory if it doesn't exist
    output_path.mkdir(parents=True, exist_ok=True)
    # Get all CSV files in the treatment directory
    csv_files = list(input_path.glob("*.csv"))
    print(f"Found {len(csv_files)} files to process")
    # Process each file; one bad file must not abort the whole group.
    for csv_file in tqdm(csv_files, desc="Processing files", leave=True):
        try:
            # Load data
            data = pd.read_csv(csv_file)
            print(f"\nProcessing {csv_file.name}")
            print(f"Initial shape: {data.shape}")
            # Remove rows where X or Y are NaN. `.copy()` ensures the
            # 'Cluster' assignment below writes to an independent frame
            # instead of a possible view of `data` (avoids
            # SettingWithCopyWarning / silently-lost writes).
            data_clean = data.dropna(subset=['X', 'Y']).copy()
            print(f"Shape after removing NaN values: {data_clean.shape}")
            if len(data_clean) == 0:
                print(f"Warning: All rows were NaN in {csv_file.name}")
                continue
            # Perform temporal clustering with specified parameters
            clusters = temporal_dfs_clustering(
                data_clean,
                distance_threshold=distance_threshold,
                max_neighbors=max_neighbors,
                time_threshold=time_threshold
            )
            # Add cluster information to the data (index-aligned assignment)
            data_clean['Cluster'] = clusters
            # Save processed data
            output_file = output_path / csv_file.name
            data_clean.to_csv(output_file, index=False)
            print(f"Saved clustered data to: {output_file}")
        except Exception as e:
            # Best-effort: report and continue with the next file.
            print(f"Error processing file {csv_file.name}: {str(e)}")
def plot_sample_files(base_path, output_dir, samples_per_treatment=2):
    """
    Plot randomly sampled files from each treatment group.

    For every treatment subdirectory of `base_path`, a handful of clustered
    CSV files are drawn at random and rendered as a side-by-side figure
    (raw trajectory vs. cluster-colored trajectory) saved into `output_dir`.

    Args:
        base_path (Path): Path to the clustered data directory.
        output_dir (Path): Directory to save plots.
        samples_per_treatment (int): Number of files to sample from each treatment.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    for treatment_dir in tqdm(list(base_path.iterdir()), desc="Processing treatment groups"):
        # Skip stray files sitting next to the treatment folders.
        if not treatment_dir.is_dir():
            continue
        treatment = treatment_dir.name
        print(f"\nProcessing {treatment}")
        csv_files = list(treatment_dir.glob("*.csv"))
        if not csv_files:
            print(f"No files found in {treatment}")
            continue
        # Draw without replacement; never more than what exists.
        n_samples = min(samples_per_treatment, len(csv_files))
        for csv_file in random.sample(csv_files, n_samples):
            print(f"Plotting {csv_file.name}")
            try:
                frame = pd.read_csv(csv_file)
                # Two panels: raw path on the left, clustered path on the right.
                fig, (raw_ax, cluster_ax) = plt.subplots(1, 2, figsize=(20, 8))
                fig.suptitle(f"{treatment} - {csv_file.stem}", fontsize=16)
                plot_movement(raw_ax, frame, "Raw Movement")
                plot_clustered_movement(cluster_ax, frame, frame['Cluster'],
                                        "Temporal Clustering")
                plt.tight_layout()
                target = output_dir / f"{treatment}_{csv_file.stem}_comparison.png"
                plt.savefig(target, dpi=300, bbox_inches='tight')
                plt.close()
            except Exception as e:
                print(f"Error plotting file {csv_file.name}: {str(e)}")
                # Make sure a half-built figure does not leak to the next file.
                plt.close()
def main():
    """
    Command-line entry point.

    Parses CLI arguments, runs temporal clustering over every treatment
    subdirectory of the input directory, and (unless --noplot is given)
    renders comparison plots of sampled clustered files.
    """
    # Add argument parser
    parser = argparse.ArgumentParser(description='Perform temporal clustering analysis.')
    parser.add_argument('--input', type=str, required=True, help='Input directory containing processed data')
    parser.add_argument('--output', type=str, required=True, help='Output directory for clustered data')
    parser.add_argument('--plots', type=str, help='Directory for saving plots (optional)', default='data/temporal_clustering_plots')
    parser.add_argument('--noplot', action='store_true', help='Disable plotting')
    # Add clustering parameters
    parser.add_argument('--distance-threshold', type=float, default=30,
                        help='Maximum distance between consecutive points (default: 30)')
    parser.add_argument('--max-neighbors', type=int, default=5,
                        help='Maximum number of subsequent frames to check (default: 5)')
    parser.add_argument('--time-threshold', type=float, default=100,
                        help='Maximum time between consecutive points (default: 100)')
    args = parser.parse_args()
    # Set up paths
    input_base = Path(args.input)
    output_base = Path(args.output)
    plot_dir = Path(args.plots)
    print("Step 1: Processing files")
    print(f"Input directory: {input_base}")
    print(f"Output directory: {output_base}")
    print("\nClustering parameters:")
    print(f"Distance threshold: {args.distance_threshold}")
    print(f"Max neighbors: {args.max_neighbors}")
    print(f"Time threshold: {args.time_threshold}")
    # Process each treatment group (one subdirectory per treatment)
    for treatment_dir in tqdm(list(input_base.iterdir()), desc="Processing treatment groups"):
        if not treatment_dir.is_dir():
            continue
        treatment = treatment_dir.name
        print(f"\nProcessing {treatment} group...")
        # Mirror the treatment folder under the output base; the input path
        # is simply the directory we are already iterating over.
        output_path = output_base / treatment
        # Process the treatment group with specified parameters
        process_treatment_group(
            treatment_dir,
            output_path,
            distance_threshold=args.distance_threshold,
            max_neighbors=args.max_neighbors,
            time_threshold=args.time_threshold
        )
        print(f"Completed {treatment} group")
    if not args.noplot:
        print("\nStep 2: Plotting sample files")
        plot_sample_files(output_base, plot_dir)
# Run the CLI pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()