-
Notifications
You must be signed in to change notification settings - Fork 0
/
probe_sbr_shared.py
74 lines (70 loc) · 4.73 KB
/
probe_sbr_shared.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
# Hard-coded sample data: a (7, 4, 4) tensor, i.e. seven 4x4 slices.
# It is built from nested Python float lists, so its dtype is torch's default
# floating dtype (float32 unless the default has been changed).
# Reported below by ans() under the label "SBR". What the slices represent is
# not stated in this file; presumably one matrix per probe/head produced by an
# "SBR" method (TODO confirm).
# Most entries are small (|v| < 0.05); each slice has only 2-4 entries at or
# above the 0.05 magnitude used by ans() to mark an entry as active.
y = torch.tensor([[[ -0.00267595, -0.00690883, 0.00095639, 0.00144628],
[ 0.00671170, -0.02427113, -19.12420273, 0.00878605],
[ -0.04579823, 19.21648407, 0.02576208, -0.09778216],
[ 0.00306971, -0.00147089, 0.00145939, -0.00371212]],
[[ -0.00234936, -0.00543301, -0.00210993, 1.77932322],
[ 0.00624070, 0.00155828, 0.00005865, -0.00428426],
[ 0.00263186, -0.00012475, 0.00153051, 0.00137638],
[ 1.78773522, -0.00331541, 0.00533109, 0.00375492]],
[[ -0.00024781, -0.00277999, -0.00149714, -0.00022702],
[ 0.00327885, -0.00187669, 0.00014913, 0.00097528],
[ 0.00130655, 0.00004810, 0.00072407, 0.59660774],
[ 0.00078483, 0.00141406, -0.62144721, 0.00124915]],
[[ -0.35089865, 0.00267916, -0.00246031, 0.00021386],
[ -0.00273432, -0.34731507, 0.00009110, -0.00465555],
[ 0.00263027, 0.00005358, -0.36229292, 0.00272502],
[ -0.00073789, -0.00282637, 0.00535668, -0.35213637]],
[[ 0.00070698, -0.39653391, 0.00009765, 0.00379389],
[ -0.38043809, 0.00839045, 0.00017544, -0.00017725],
[ 0.00000953, 0.00006367, -0.01437722, -0.00069978],
[ -0.00507073, 0.00001247, -0.00111713, -0.00097920]],
[[ -0.00243983, 0.00952138, -0.39074254, 0.00239019],
[ -0.01045662, 0.00930998, -0.00014336, -0.00234616],
[ -0.38384521, 0.00008254, 0.00453316, -0.00000446],
[ -0.00262937, -0.00152871, 0.00007233, -0.00325850]],
[[ 0.00024792, 0.00378530, 0.00116913, 0.00031381],
[ -0.00408471, -0.00234000, 0.00027797, 0.64761943],
[ -0.00133333, -0.00038143, 0.00023685, -0.00303845],
[ -0.00019169, -0.67700928, -0.00381380, -0.00021126]]])
# Hard-coded comparison data with the same layout as `y`: a (7, 4, 4) tensor
# of seven 4x4 slices, in torch's default floating dtype.
# Reported below by ans() under the label "Cosine similarity". What the slices
# represent is not stated in this file; presumably the counterpart matrices
# obtained with a cosine-similarity method (TODO confirm).
# Unlike `y`, many entries here are at or above the 0.05 magnitude that ans()
# uses to mark an entry as active.
x = torch.tensor([[[ -0.14177431, -0.05512094, -0.05190736, 0.01966458],
[ -0.03129339, -0.13053992, -0.00536269, -0.20788656],
[ -0.07308455, -0.00110295, -0.12896767, 0.00522719],
[ 0.01264129, 0.20261863, 0.01769014, -0.13747029]],
[[ 0.00519245, 0.17335783, 0.20906085, 1.21877623],
[ 0.17944759, 0.03434239, 9.83431244, -0.18008515],
[ 0.22701925, -9.79364777, -0.01931207, 0.61105889],
[ 1.23251987, 0.22697963, -0.57771277, -0.00622910]],
[[ -0.01311911, -0.16098917, -0.01467771, -1.45433819],
[ -0.16076282, 0.01439310, 0.17141792, -0.08231727],
[ -0.02732595, -0.19809683, 0.00612975, 0.03030457],
[ -1.48880386, 0.07134882, -0.04308231, -0.00357314]],
[[ -0.15269850, 0.13718513, -0.01074478, -0.02196591],
[ 0.12449956, -0.16025540, 0.00760945, 0.17399928],
[ -0.00798067, 0.00216012, -0.14932625, 0.01805941],
[ -0.02153320, -0.16637614, -0.00584378, -0.14211686]],
[[ -0.00411133, -0.06264049, -0.05874282, 0.00252660],
[ -0.04619981, -0.00843079, 0.01711345, 0.02333736],
[ -0.04513314, -0.01807640, -0.01065661, -0.25748119],
[ -0.00154598, -0.06110601, 0.24755795, -0.01208787]],
[[ 0.01955069, 0.05289609, -0.17240576, 0.00857888],
[ 0.05645267, 0.05334385, 0.00120642, 0.01043023],
[ -0.18770154, -0.00601175, 0.03403803, 0.03063194],
[ -0.01456334, 0.00527843, -0.01035724, 0.03055911]],
[[ -0.02646191, -0.17164133, -0.04352060, 0.01944972],
[ -0.13042858, -0.02130671, -0.00935954, 0.05487733],
[ -0.05506284, -0.01165947, -0.01765519, 0.06355262],
[ 0.00920069, -0.09460083, -0.05307925, -0.01377958]]])
def ans(x, *, threshold=0.05, verbose=True):
    """Count overlaps between the thresholded masks of every pair of slices of *x*.

    Each slice ``x[k]`` along the first dimension becomes a boolean "active"
    mask ``|x[k]| >= threshold``. For every unordered pair of distinct slices
    ``(i, j)`` with ``i < j``, the element-wise AND of the two masks is formed.

    Args:
        x: Tensor whose first dimension indexes the slices to compare. All
            slices share the remaining shape.
        threshold: Magnitude at or above which an entry counts as active.
            Defaults to 0.05, the value that used to be hard-coded.
        verbose: When true (the default), print ``i``, ``j`` and the overlap
            mask for every pair whose masks share an active position.

    Returns:
        ``(p, total)`` as 0-dim int64 tensors on ``x``'s device:
        ``p`` is the number of pairs whose masks overlap at all, and
        ``total`` is the number of shared active positions summed over all
        pairs. Both are zero when ``x`` has fewer than two slices.
    """
    active = torch.abs(x) >= threshold
    # Use the real slice count instead of a hard-coded 7, so every slice is
    # compared and smaller inputs don't raise IndexError.
    n_slices = active.shape[0]
    total = torch.zeros((), dtype=torch.long, device=active.device)
    p = torch.zeros((), dtype=torch.long, device=active.device)
    for i in range(n_slices):
        for j in range(i + 1, n_slices):
            # Compute the pairwise overlap once and reuse it.
            overlap = active[i] & active[j]
            hits = overlap.sum()
            total = total + hits
            if hits > 0:
                p = p + 1
                if verbose:
                    print(i, j, overlap)
    return p, total
# Print ans()'s (overlapping-pair count, overlapping-entry count) for each
# data set, in order; ans() emits its per-pair lines before each summary.
for label, data in (("SBR", y), ("Cosine similarity", x)):
    print(label, ans(data))