Skip to content

Commit

Permalink
implement bootstrap for MMD
Browse files Browse the repository at this point in the history
  • Loading branch information
wangronin committed Jul 3, 2024
1 parent ea07736 commit 39903bb
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 31 deletions.
34 changes: 15 additions & 19 deletions examples/MMD/example_DTLZ1.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import sys

sys.path.insert(0, "./")
import random

import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -10,7 +9,7 @@
from sklearn_extra.cluster import KMedoids

from hvd.delta_p import GenerationalDistance, InvertedGenerationalDistance
from hvd.mmd_newton import MMDNewton
from hvd.mmd_newton import MMDNewton, bootstrap_reference_set
from hvd.newton import DpN
from hvd.problems import DTLZ1, PymooProblemWithAD
from hvd.reference_set import ClusteredReferenceSet
Expand All @@ -30,29 +29,22 @@
rcParams["ytick.major.width"] = 1

np.random.seed(66)
max_iters = 5
max_iters = 10
f = DTLZ1(boundry_constraints=True)
problem = PymooProblemWithAD(f)
pareto_front = problem.get_pareto_front()
reference_point = {
"DTLZ[1-6]": np.array([1, 1, 1]),
"DTLZ7": np.array([1, 1, 6]),
"IDTLZ1[1-4]": np.array([1, 1, 1]),
}

# read the reference set data
ref_ = pd.read_csv("./DTLZ1/DTLZ1_RANDOM_run_1_ref_1_gen0.csv", header=None).values
X0 = pd.read_csv("./DTLZ1/DTLZ1_RANDOM_run_1_lastpopu_x_gen0.csv", header=None).values
Y0 = pd.read_csv("./DTLZ1/DTLZ1_RANDOM_run_1_lastpopu_y_gen0.csv", header=None).values
eta = {0: pd.read_csv("./DTLZ1/DTLZ1_RANDOM_run_1_eta_1_gen0.csv", header=None).values.ravel()}
Y_idx = None

method = "alternate"
km = KMedoids(n_clusters=50, method=method, random_state=0, init="k-medoids++").fit(ref_)
ref_ = ref_[km.medoid_indices_]
km = KMedoids(n_clusters=50, method=method, random_state=0, init="k-medoids++").fit(Y0)
Y0 = Y0[km.medoid_indices_]
X0 = X0[km.medoid_indices_]

# select a subset of the reference data "evenly" since 300 points are taking too long for MMD
# km = KMedoids(n_clusters=50, method="alternate", random_state=0, init="k-medoids++").fit(ref_)
# ref_ = ref_[km.medoid_indices_]
# km = KMedoids(n_clusters=50, method="alternate", random_state=0, init="k-medoids++").fit(Y0)
# Y0 = Y0[km.medoid_indices_]
# X0 = X0[km.medoid_indices_]
N = len(X0)
ref = ClusteredReferenceSet(ref=ref_, eta=eta, Y_idx=Y_idx)
metrics = dict(GD=GenerationalDistance(pareto_front), IGD=InvertedGenerationalDistance(pareto_front))
Expand All @@ -73,9 +65,12 @@
max_iters=max_iters,
verbose=True,
metrics=metrics,
preconditioning=True,
preconditioning=False,
)
X, Y, _ = opt.run()
if 11 < 2:
X, Y, _ = bootstrap_reference_set(opt, problem, ref_, 5)
else:
X, Y, _ = opt.run()
Y = get_non_dominated(Y)
igd_mmd = igd.compute(Y=Y)

Expand Down Expand Up @@ -150,4 +145,5 @@
ax1.set_ylabel(r"$f_3$")

plt.tight_layout()
plt.show()
plt.savefig(f"MMD-{f.__class__.__name__}.pdf", dpi=1000)
59 changes: 49 additions & 10 deletions hvd/mmd_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .base import State
from .mmd import MMDMatching
from .reference_set import ClusteredReferenceSet
from .utils import get_logger, precondition_hessian, set_bounds
from .utils import compute_chim, get_logger, get_non_dominated, precondition_hessian, set_bounds

np.seterr(divide="ignore", invalid="ignore")

Expand Down Expand Up @@ -102,14 +102,7 @@ def __init__(
self.state: State = State(self.dim_p, self.n_eq, self.n_ieq, func, jac, h, h_jac, g, g_jac)
# TODO: move indicator out of this class
self.indicator = MMDMatching(
self.dim_p,
self.n_obj,
ref=self.ref,
func=func,
jac=jac,
hessian=hessian,
theta=1.0 / N,
beta=0.3,
self.dim_p, self.n_obj, self.ref, func, jac, hessian, theta=1.0 / N, beta=0.2
)
self._initialize(X0)
self._set_logging(verbose)
Expand Down Expand Up @@ -287,7 +280,6 @@ def _shift_reference_set(self):
masks = np.bitwise_and(np.isclose(distance, 0), np.isclose(step_len, 0))
indices = np.nonzero(masks)[0]
self.ref.shift(0.08, indices)
self.indicator.ref = self.ref # TODO: check if this is needed
for k in indices: # log the updated medoids
self.history_medoids[k].append(self.ref.medoids[k].copy())
self.logger.info(f"{len(indices)} target points are shifted")
Expand Down Expand Up @@ -376,3 +368,50 @@ def _handle_box_constraint(self, step: np.ndarray) -> Tuple[np.ndarray, np.ndarr
s = np.array([dist[i] / np.abs(np.minimum(0, vv)) for i, vv in enumerate(v)])
max_step_size = np.array([min(1, np.nanmin(_)) for _ in s])
return step, max_step_size


def _plot_reference_set(
    ref_points: np.ndarray, medoids: np.ndarray, pareto_front: np.ndarray
) -> None:
    """Show a 3D debugging plot of a freshly bootstrapped reference set.

    Draws the reference points as green dots, the Pareto front as black dots and the
    medoids as red crosses. Only the first three columns of each array are plotted.

    Args:
        ref_points (np.ndarray): the shifted reference points
        medoids (np.ndarray): the medoids of the new reference set
        pareto_front (np.ndarray): the Pareto front returned by the problem
    """
    # imported lazily so that matplotlib is only required when plotting is requested
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=plt.figaspect(1))
    plt.subplots_adjust(bottom=0.08, top=0.9, right=0.93, left=0.05)
    ax = fig.add_subplot(1, 1, 1, projection="3d")
    ax.set_box_aspect((1, 1, 1))
    ax.view_init(45, 45)
    for points, fmt in ((ref_points, "g."), (pareto_front, "k."), (medoids, "r+")):
        ax.plot(points[:, 0], points[:, 1], points[:, 2], fmt, ms=6, alpha=0.6)
    plt.show()


def bootstrap_reference_set(
    optimizer,
    problem,
    init_ref: np.ndarray,
    interval: int = 5,
    *,
    shift: float = 0.08,
    plot: bool = False,
) -> Tuple[np.ndarray, np.ndarray, Dict]:
    """Bootstrap the reference set with the intermediate population of an MOO algorithm

    Runs ``optimizer.max_iters`` iterations of ``optimizer``. Every ``interval`` iterations
    (but never at iteration 0) the reference set is rebuilt from the non-dominated subset of
    the optimizer's current objective values stacked on top of ``init_ref``; the result is
    shifted by ``shift * eta`` and installed on both the optimizer and its indicator.

    Args:
        optimizer (_type_): an MOO algorithm; it must expose `max_iters`, `state`, `ref`,
            `indicator`, `stop_dict`, `newton_iteration()` and `log()`
        problem (_type_): the MOO problem to solve. Only used when `plot` is True, to
            obtain the Pareto front via `problem.get_pareto_front()`
        init_ref (np.ndarray): the initial reference set
        interval (int, optional): intervals at which bootstrapping is performed. Defaults to 5.
        shift (float, optional): multiplier of `eta` used to shift the rebuilt reference
            points. Defaults to 0.08.
        plot (bool, optional): if True, show a 3D plot of each rebuilt reference set
            (intended for debugging three-objective problems). Defaults to False.

    Raises:
        ValueError: if `interval` is smaller than 1.

    Returns:
        Tuple[np.ndarray, np.ndarray, Dict]: (efficient set, Pareto front approximation,
            the stopping criteria)
    """
    # fail early with a clear message instead of a ZeroDivisionError from `i % interval`
    if interval < 1:
        raise ValueError(f"`interval` must be a positive integer, got {interval}")

    for i in range(optimizer.max_iters):
        if i > 0 and i % interval == 0:
            ref_ = get_non_dominated(np.r_[optimizer.state.Y, init_ref])
            # NOTE(review): `eta` is presumably the shifting direction derived from the
            # CHIM of the reference points -- confirm against `compute_chim`
            eta = compute_chim(ref_)
            ref_ += shift * eta
            ref = ClusteredReferenceSet(ref=ref_, eta={0: eta}, Y_idx=None)
            # the optimizer and its indicator each hold a reference set; keep them in sync
            optimizer.ref = ref
            optimizer.indicator.ref = ref
            optimizer.indicator.compute(Y=optimizer.state.Y)  # To compute the medoids

            if plot:
                _plot_reference_set(ref_, ref.medoids, problem.get_pareto_front())

        optimizer.newton_iteration()
        optimizer.log()
    return optimizer.state.primal, optimizer.state.Y, optimizer.stop_dict
3 changes: 1 addition & 2 deletions hvd/reference_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ def _cluster(self, component: np.ndarray, k: int) -> np.ndarray:
"""
if len(component) == k:
return component
method = "alternate"
km = KMedoids(n_clusters=k, method=method, random_state=0, init="k-medoids++").fit(component)
km = KMedoids(n_clusters=k, method="alternate", random_state=0, init="k-medoids++").fit(component)
return component[km.medoid_indices_]

def _match(self, X: np.ndarray, Y: np.ndarray) -> np.ndarray:
Expand Down

0 comments on commit 39903bb

Please sign in to comment.