-
Notifications
You must be signed in to change notification settings - Fork 0
/
3_annotate.py
65 lines (51 loc) · 1.97 KB
/
3_annotate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import json
import numpy as np
import pandas as pd
from scipy import optimize
from scipy.spatial import distance
import tqdm
# Ground-truth satellite positions from the spotGEO training annotations,
# one row per (sequence, frame, satellite), with satellites numbered from 1
# within each annotation entry. Annotated coordinates are stored as
# (x, y) = (column, row) pairs; they are rounded half-up to whole pixels via
# int(v + .5), which assumes non-negative coordinates.
with open('data/spotGEO/train_anno.json', encoding='utf-8') as f:
    annotations = json.load(f)

satellites = pd.DataFrame(
    [
        {
            'sequence': ann['sequence_id'],
            'frame': ann['frame'],
            'satellite': i + 1,
            'r': int(coords[1] + .5),
            'c': int(coords[0] + .5),
        }
        for ann in annotations
        for i, coords in enumerate(ann['object_coords'])
    ],
    # Explicit columns keep set_index working even when no annotation
    # contains any object.
    columns=['sequence', 'frame', 'satellite', 'r', 'c'],
)
# A lexsorted MultiIndex keeps repeated partial-key `.loc` lookups fast and
# free of pandas PerformanceWarnings.
satellites = satellites.set_index(['sequence', 'frame', 'satellite']).sort_index()
# Chebyshev distance (in pixels) up to which a candidate location matched to
# an annotated satellite is accepted as being that satellite.
MAX_MATCH_DISTANCE = 2

df = pd.read_pickle('data/interesting.pkl')

# Label the candidate locations of every training frame: a location is a
# satellite if it is matched to an annotated satellite that lies close enough.
labels = []
train_groups = df.query('part == "train"').groupby(['sequence', 'frame'])
for (sequence, frame), group in tqdm.tqdm(train_groups):
    try:
        sats = satellites.loc[sequence, frame]
    except KeyError:
        # The frame has no entry in `satellites`; its rows get no label here.
        continue

    # Compute the distance between each satellite and each interesting
    # location, thus forming a weighted bipartite graph
    centers = group[['r', 'c']]
    distances = distance.cdist(sats, centers, metric='chebyshev')

    # Minimum-cost one-to-one matching between satellites (rows) and
    # locations (columns). If there are more satellites than locations, some
    # satellites stay unmatched; matched pairs that are too far apart are
    # rejected as unlikely.
    row_ind, col_ind = optimize.linear_sum_assignment(distances)
    likely = distances[row_ind, col_ind] <= MAX_MATCH_DISTANCE
    is_satellite = np.zeros(len(centers), dtype=bool)
    is_satellite[col_ind[likely]] = True

    labels.append(pd.DataFrame({
        'part': 'train',
        'sequence': sequence,
        'frame': frame,
        'r': group['r'],
        'c': group['c'],
        'is_satellite': is_satellite,
    }))
# Attach the train labels to every candidate location. This script reads
# 'data/interesting.pkl' and writes its result back to the same file, so a
# re-run sees an existing 'is_satellite' column; drop it first, otherwise
# `join` fails because of the overlapping column name.
labels = pd.concat(labels).set_index(['part', 'sequence', 'frame', 'r', 'c'])
df = df.drop(columns='is_satellite', errors='ignore')
df = df.join(labels, on=labels.index.names)

# Locations without a label (NaN after the join) are negatives; test
# locations are unknown and set to NaN. Building the column as object dtype
# directly avoids pandas' deprecated fillna downcast (object -> bool) and the
# deprecated bool -> object upcast when NaN is assigned afterwards.
is_satellite = df['is_satellite'].eq(True).astype(object)
is_satellite.loc[df['part'] == 'test'] = np.nan
df['is_satellite'] = is_satellite

df.to_pickle('data/interesting.pkl')

# Number of locations labelled as satellites, relative to the number of rows
# in `satellites`.
print(f'Recall is {df.is_satellite.sum() / len(satellites):.2%}')