-
Notifications
You must be signed in to change notification settings - Fork 1
/
model-selection.py
104 lines (86 loc) · 2.74 KB
/
model-selection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python3
"""
Compares resulting model likelihood vs. number of clusters
"""
import os
import matplotlib.pyplot as plt
from framework import marginalizable_mixture_model as mixmodel
# from util import util_macc as data_macc
from util import util_adni as data_adni
from util import util_state_space as util
# Matplotlib defaults applied to every figure created by this script.
plt.rcParams.update(
    {
        "figure.autolayout": True,
        "legend.loc": "upper right",
        "font.family": "serif",
    }
)

# Absolute path of the directory that contains this script.
home_dir = os.path.dirname(os.path.abspath(__file__))

# Forwarded as the ``alpha`` keyword to every mixture model built in main().
alpha = 1.0

# Candidate numbers of clusters to fit and compare: 1 through 7 inclusive.
n_cluster_list = range(1, 8)
def main():
    """Fit one mixture model per cluster count and plot model-selection curves.

    Steps:

    1. Load the ADNI trajectories and standardize the first returned array;
       the first two return values are passed to the models as ``states``
       and ``observations`` respectively.
    2. For every entry of ``n_cluster_list``, build a
       ``MMLinGaussSS_marginalizable`` model and call
       ``train_with_multiple_random_starts``; the value that call returns is
       what gets evaluated and pickled (presumably the fitted model itself --
       confirm against the framework).
    3. For each criterion (expected complete-data log likelihood, AIC, BIC),
       evaluate it on the training data for every model and save a plot of
       criterion vs. number of clusters to
       ``figures/<dataset>_elbow_plot_<ATTR>.pdf`` (relative to the current
       working directory).
    4. Call ``to_pickle()`` on every model.
    """
    ztrain_orig, xtrain, *_ = data_adni.get_trajectories()
    # ``std_param`` is only consumed by the disabled MACC evaluation below.
    ztrain, std_param = util.standardize(ztrain_orig, return_params=True)

    # NOTE: evaluation on the MACC data set is currently disabled.
    # (
    #     ztest_orig,
    #     xtest,
    #     dtest,
    #     mmsetest,
    #     lengthtest,
    #     idstest,
    #     agestest,
    # ) = data_macc.get_data()
    # ztest = util.standardize(ztest_orig, params=std_param)

    # Train models with different numbers of clusters and compare results.
    mdls = [
        mixmodel.MMLinGaussSS_marginalizable(
            n_clusters=n_clusters,
            states=ztrain,
            observations=xtrain,
            init="k-means",
            alpha=alpha,
        ).train_with_multiple_random_starts(n_starts=1000, use_cache=True)
        for n_clusters in n_cluster_list
    ]

    # Axis label -> name of the model method that computes the criterion.
    criteria = {
        "Expected complete data log likelihood": "e_complete_data_log_lik",
        "AIC": "aic",
        "BIC": "bic",
    }
    ticks = list(n_cluster_list)

    # The output directory does not depend on the loop variables; create once.
    os.makedirs("figures", exist_ok=True)

    for dset in ["ADNI"]:  # "MACC"
        for s, attr in criteria.items():
            fig, ax = plt.subplots()
            ax.spines["right"].set_visible(False)
            ax.spines["top"].set_visible(False)

            values = [
                getattr(m, attr)(
                    states=ztrain,  # if dset == "ADNI" else ztest,
                    observations=xtrain,  # if dset == "ADNI" else xtest,
                )
                for m in mdls
            ]

            # Marker and line style are given as keywords only: combining the
            # "o-" format string with ``linestyle=`` makes matplotlib warn that
            # the line style is redundantly defined.
            ax.plot(
                ticks,
                values,
                marker="o",
                linestyle="solid",
                color="#0072CE",
            )
            ax.set_xticks(ticks)
            ax.set_xticklabels(ticks)
            # ax.set_title(f"{s} vs. number of clusters")
            ax.set_xlabel("Number of clusters")
            ax.set_ylabel(s)
            fig.tight_layout()
            fig.savefig(
                os.path.join(
                    "figures",
                    f"{dset}_elbow_plot_{attr.upper()}.pdf",
                ),
                bbox_inches="tight",
                transparent=True,
            )
            # Release the figure; pyplot otherwise keeps every figure alive.
            plt.close(fig)

    for m in mdls:
        m.to_pickle()
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()