Skip to content

Commit

Permalink
Update SARunner to be able to take decay objects directly in temperat…
Browse files Browse the repository at this point in the history
…ure list
  • Loading branch information
hiive committed Mar 1, 2021
1 parent 4dbcf15 commit 1658a73
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions mlrose_hiive/runners/sa_runner.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
import mlrose_hiive
from mlrose_hiive.decorators import short_name
from mlrose_hiive.runners._runner_base import _RunnerBase
import numpy as np

"""
Example usage:
experiment_name = 'example_experiment'
problem = TSPGenerator.generate(seed=SEED, number_of_cities=22)
# note that you can also initialize a temperature_list this way
# temperature_list = [mlrose_hiive.GeomDecay(init_temp=t, decay=d) for (t, d) in [(1, 0.99), (1e2, 0.999)]]
# if you use this form, the decay_list parameter is ignored.
sa = SARunner(problem=problem,
experiment_name=experiment_name,
output_directory=OUTPUT_DIRECTORY,
seed=SEED,
iteration_list=2 ** np.arange(14),
max_attempts=5000,
temperature_list=[1, 10, 50, 100, 250, 500, 1000, 2500, 5000, 10000])
temperature_list=[1, 10, 50, 100, 250, 500, 1000, 2500, 5000, 10000],
decay_list=[mlrose_hiive.GeomDecay])
# the two data frames will contain the results
df_run_stats, df_run_curves = sa.run()
Expand All @@ -29,12 +35,18 @@ def __init__(self, problem, experiment_name, seed, iteration_list, temperature_l
super().__init__(problem=problem, experiment_name=experiment_name, seed=seed, iteration_list=iteration_list,
max_attempts=max_attempts, generate_curves=generate_curves,
**kwargs)
self.use_raw_temp = True
self.temperature_list = temperature_list
if decay_list is None:
decay_list = [mlrose_hiive.GeomDecay]
self.decay_list = decay_list
if all([np.isscalar(x) for x in temperature_list]):
print('all numbers')

if decay_list is None:
decay_list = [mlrose_hiive.GeomDecay]
self.decay_list = decay_list
self.use_raw_temp = False

def run(self):
temperatures = [decay(init_temp=t) for t in self.temperature_list for decay in self.decay_list]
temperatures = self.temperature_list if self.use_raw_temp else [d(init_temp=t) for t in self.temperature_list
for d in self.decay_list]
return super().run_experiment_(algorithm=mlrose_hiive.simulated_annealing,
schedule=('Temperature', temperatures))

0 comments on commit 1658a73

Please sign in to comment.