Skip to content

Commit

Permalink
updating example
Browse files Browse the repository at this point in the history
  • Loading branch information
harsh306 committed Jun 7, 2021
1 parent 2b26fdf commit a172f31
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 23 deletions.
69 changes: 47 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ pip install continuation-jax

```python
import cjax
help(cjax)
```

#### Simple Math on Pytrees
#### Math operations on Pytrees
```python
>>> import cjax
>>> from cjax.utils import math_trees
Expand All @@ -35,52 +34,78 @@ help(cjax)

#### Examples:
- Examples: https://github.com/harsh306/continuation-jax/tree/main/examples
- Sample Runner: https://github.com/harsh306/continuation-jax/blob/main/run.py
- Sample Runner: https://github.com/harsh306/continuation-jax/blob/main/model_simple_classifier/run.py

```python
"""
Main file to run contination on the user defined problem. Examples can be found in the examples/ directory.
Continuation is topological procedure to train a neural network. This module tracks all
the critical points or fixed points and dumps them to output file provided in hparams.json file.
Typical usage example:
continuation = ContinuationCreator(
problem=problem, hparams=hparams
).get_continuation_method()
continuation.run()
"""
from cjax.continuation.creator.continuation_creator import ContinuationCreator
from examples.toy.vectror_pitchfork import SigmoidFold
from examples.model_simple_classifier.model_classifier import ModelContClassifier
from cjax.utils.abstract_problem import ProblemWraper
import json
from jax.config import config
from datetime import datetime
from cjax.utils.visualizer import bif_plot, pick_array
import mlflow
from cjax.utils.visualizer import pick_array, bif_plot

config.update("jax_debug_nans", True)

# TODO: use **kwargs to reduce params

if __name__ == "__main__":
problem = SigmoidFold()
problem = ModelContClassifier()
problem = ProblemWraper(problem)

with open(problem.HPARAMS_PATH, "r") as hfile:
hparams = json.load(hfile)
start_time = datetime.now()

if hparams["n_perturbs"] > 1:
for perturb in range(hparams["n_perturbs"]):
print(f"Running perturb {perturb}")
mlflow.set_tracking_uri(hparams['meta']["mlflow_uri"])
mlflow.set_experiment(hparams['meta']["name"])

with mlflow.start_run(run_name=hparams['meta']["method"]+"-"+hparams["meta"]["optimizer"]) as run:
mlflow.log_dict(hparams, artifact_file="hparams/hparams.json")
mlflow.log_text("", artifact_file="output/_touch.txt")
artifact_uri = mlflow.get_artifact_uri("output/")
hparams["meta"]["output_dir"] = artifact_uri
print(f"URI: {artifact_uri}")
start_time = datetime.now()

if hparams["n_perturbs"] > 1:
for perturb in range(hparams["n_perturbs"]):
print(f"Running perturb {perturb}")
continuation = ContinuationCreator(
problem=problem, hparams=hparams, key=perturb
).get_continuation_method()
continuation.run()
else:
continuation = ContinuationCreator(
problem=problem, hparams=hparams, key=perturb
problem=problem, hparams=hparams
).get_continuation_method()
continuation.run()
else:
continuation = ContinuationCreator(
problem=problem, hparams=hparams
).get_continuation_method()
continuation.run()

end_time = datetime.now()
print(f"Duration: {end_time-start_time}")

bif_plot(hparams['output_dir'], pick_array, hparams['n_perturbs'])
end_time = datetime.now()
print(f"Duration: {end_time-start_time}")

figure = bif_plot(hparams["meta"]["output_dir"], pick_array)
mlflow.log_figure(figure, artifact_file="plots/fig.png")
```

#### Note on Hyperparameters

#### Papers:


#### Contact:
`harshnpathak@gmail.com`
2 changes: 1 addition & 1 deletion examples/toy/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""
from cjax.continuation.creator.continuation_creator import ContinuationCreator
from examples.toy.vectror_pitchfork import SigmoidFold, PitchForkProblem
from examples.toy.vectror_pitchfork import PitchForkProblem
from cjax.utils.abstract_problem import ProblemWraper
import json
from jax.config import config
Expand Down

0 comments on commit a172f31

Please sign in to comment.