main.py
import torch
import matplotlib.pyplot as plt

from manifold import StiefelManifold
from optimizer import (
    GradientDescent,
    BBMethod,
    RegularizedNewton,
    AdaptiveRegularizedNewtonTrust,
)
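# Assumed interfaces (a sketch, inferred from usage below -- optimizer.py and
# manifold.py are local modules not shown here):
#   StiefelManifold(n, k)                 -- the Stiefel manifold St(n, k)
#   <Optimizer>(manifold).optimize(x0, f, max_iter) -> final point
#   <Optimizer>.log_data                  -- dict with "iteration" and "function_value" lists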
# Initialize parameters
rdim = 1000
kdim = 5
max_iter = 200
symmetric_mat = torch.randn(rdim, rdim)
symmetric_mat = (symmetric_mat + symmetric_mat.T) / 2  # symmetrize
G = torch.randn(rdim, kdim)
# Create a simple test function: f(X) = tr(X^T A X) + 2 tr(X^T G),
# a quadratic objective restricted to the Stiefel manifold
def test_function(x):
    return torch.trace(x.T @ symmetric_mat @ x) + 2 * torch.trace(x.T @ G)
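# Note: since symmetric_mat is symmetric, the Euclidean gradient of this
# objective is 2 * symmetric_mat @ x + 2 * G; how the optimizers obtain
# gradients (autograd or a closed form) depends on optimizer.py, not shown here.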
# Create Stiefel manifold
manifold = StiefelManifold(rdim, kdim)
# Generate random initial point and orthogonalize
start_point = torch.nn.init.orthogonal_(torch.empty(rdim, kdim))
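# Optional sanity check: confirm the start point lies on the Stiefel manifold
assert torch.allclose(start_point.T @ start_point, torch.eye(kdim), atol=1e-5)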
# Test gradient descent
print("\nTesting Gradient Descent:")
gd = GradientDescent(manifold)
result_gd = gd.optimize(start_point, test_function, max_iter)
print(f"Final function value: {test_function(result_gd):.10f}")
# Test BB method
print("\nTesting BB Method:")
bb = BBMethod(manifold)
result_bb = bb.optimize(start_point, test_function, max_iter)
print(f"Final function value: {test_function(result_bb):.10f}")
# Test regularized Newton method
print("\nTesting Regularized Newton Method:")
rn = RegularizedNewton(manifold)
result_rn = rn.optimize(start_point, test_function, max_iter)
print(f"Final function value: {test_function(result_rn):.10f}")
# Test adaptive regularized Newton trust method
print("\nTesting Adaptive Regularized Newton Trust Method:")
arnt = AdaptiveRegularizedNewtonTrust(manifold)
result_arnt = arnt.optimize(start_point, test_function, max_iter)
print(f"Final function value: {test_function(result_arnt):.10f}")
# Plot convergence curves
plt.figure(figsize=(10, 6))
# Collect the recorded function values from each optimizer
gd_values = gd.log_data["function_value"]
bb_values = bb.log_data["function_value"]
rn_values = rn.log_data["function_value"]
arnt_values = arnt.log_data["function_value"]
# Find the minimum value across all runs (used as the reference optimum)
min_value = min(min(gd_values), min(bb_values), min(rn_values), min(arnt_values))
# Shift each series by the minimum and take the log (1e-10 avoids log(0))
gd_processed = torch.log(torch.tensor(gd_values) - min_value + 1e-10)
bb_processed = torch.log(torch.tensor(bb_values) - min_value + 1e-10)
rn_processed = torch.log(torch.tensor(rn_values) - min_value + 1e-10)
arnt_processed = torch.log(torch.tensor(arnt_values) - min_value + 1e-10)
plt.plot(gd.log_data["iteration"], gd_processed, label="Gradient Descent")
plt.plot(bb.log_data["iteration"], bb_processed, label="BB Method")
plt.plot(rn.log_data["iteration"], rn_processed, label="Regularized Newton")
plt.plot(
    arnt.log_data["iteration"],
    arnt_processed,
    label="Adaptive Regularized Newton Trust",
)
plt.xlabel("Iterations")
plt.ylabel("Log(Function Value - Minimum)")
plt.title("Optimization Methods Convergence Comparison")
plt.legend()
plt.grid(True)
plt.savefig("optimization_comparison.png")
plt.close()
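# Optional feasibility check (a sketch, assuming the optimizers return torch
# tensors as used above): verify each solution still satisfies the Stiefel
# constraint X^T X = I_k after optimization.
for name, result in [
    ("Gradient Descent", result_gd),
    ("BB Method", result_bb),
    ("Regularized Newton", result_rn),
    ("Adaptive Regularized Newton Trust", result_arnt),
]:
    feasibility_error = torch.norm(result.T @ result - torch.eye(kdim)).item()
    print(f"{name}: ||X^T X - I|| = {feasibility_error:.2e}")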