-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmisc.py
147 lines (108 loc) · 4.41 KB
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import operator
import time
from operator import itemgetter
import diffrax
import ipdb
import jax
import jax.numpy as np
import matplotlib
import matplotlib.pyplot as pl
import meshcat
import meshcat.geometry as geom
import meshcat.transformations as tf
import numpy as onp
import tqdm
from jax.tree_util import tree_map as jtm
import nn_utils
import plotting_utils
import pontryagin_utils
import visualiser
# various small utility functions.
def rnd(a, b):
    """Relative norm difference between ``a`` and ``b``.

    Computes ``norm(a - b) / max(norm(a), norm(b))`` with ``np.linalg.norm``
    (default ``ord``). Useful for checking whether two matrices or vectors are
    close independently of their overall scale.

    If both inputs have zero norm they are identical, so 0 is returned instead
    of the ``0 / 0 = nan`` the plain formula would produce. NaN inputs still
    yield NaN.
    """
    diff_norm = np.linalg.norm(a - b)
    scale = np.maximum(np.linalg.norm(a), np.linalg.norm(b))
    is_zero = scale == 0
    # Divide by a placeholder 1 where scale is 0 so 0/0 is never evaluated,
    # then report 0 there (both inputs are all zeros, hence equal).
    return np.where(is_zero, 0.0, diff_norm / np.where(is_zero, 1.0, scale))
def count_floats(pytree):
    """Count the elements of all default-float-dtype leaves of a pytree.

    Only leaves whose dtype equals the current default float dtype (float32 or
    float64, depending on jax settings) contribute their ``size``; every other
    leaf (integer, boolean, other float widths) contributes 0.

    Args:
        pytree: a jax pytree whose leaves expose ``.dtype`` and ``.size``
            (e.g. jax arrays).

    Returns:
        Total number of counted elements; 0 for an empty pytree.
    """
    float_dtype = np.zeros(1).dtype  # will be float32 or float64 depending on settings.
    node_sizes = jax.tree_util.tree_map(
        lambda n: n.size if n.dtype == float_dtype else 0, pytree
    )
    # Initializer 0 so an empty pytree returns 0 instead of raising TypeError
    # ("reduce() of empty iterable with no initial value").
    return jax.tree_util.tree_reduce(operator.add, node_sizes, 0)
def find_max_l_bisection(sols, v_k, problem_params):
    """Max running cost l(x, u*) over the points where each solution hits v = v_k.

    NOTE: this function is unnecessary (aimed towards a misguided goal), but it
    has a working implementation of bisection across the time axis of an
    interpolated solution. Come here if we need this again anytime.

    Args:
        sols: batch of solution objects, vmapped over. Each must provide
            ``t0``, ``t1`` and ``evaluate(t)`` returning a dict with keys
            ``'v'``, ``'x'`` and ``'vx'`` (presumably a batched diffrax
            solution -- TODO confirm against callers).
        v_k: target value level.
        problem_params: dict providing the running cost ``'l'`` (called as
            ``l(x, u)``) plus whatever ``pontryagin_utils.u_star_2d`` reads.

    Returns:
        ``np.nanmax`` of l over all solutions. A solution whose endpoint values
        do not bracket ``v_k`` is evaluated at t = NaN. This presumably yields a
        NaN state and cost, so it is skipped by nanmax (TODO confirm for the
        interpolation used).
    """

    def state_at_vk_bisection(sol):
        # Find the state at which v = v_k by bisection on the time axis,
        # using the interpolated solution throughout.
        # Halvings needed: log2(time interval / final t tolerance),
        # here log2(5 / 5e-6) ~= 20.
        iters = 20

        # Invariant: v(left) >= v_k >= v(right), starting from
        # left = sol.t1, right = sol.t0 (checked via result_usable below).
        def f_scan(time_interval, _):
            # time_interval is an array of shape (2,) so that jax.lax.select
            # can pick between the two candidate intervals.
            left, right = time_interval
            # Plain bisection; regula falsi would account for skewed v(t),
            # but it is not worth the marginal gains.
            mid = time_interval.mean()
            vmid = sol.evaluate(mid)['v']
            # vmid below target -> mid becomes right; otherwise it becomes left.
            next_time_interval = jax.lax.select(
                vmid < v_k,
                np.array([left, mid]),   # on_true
                np.array([mid, right]),  # on_false
            )
            return next_time_interval, vmid

        init_time_interval = np.array([sol.t1, sol.t0])

        # If v_k is not bracketed by the endpoint values, the bisection
        # result is meaningless.
        init_v_interval = jax.vmap(sol.evaluate)(init_time_interval)['v']
        result_usable = np.logical_and(init_v_interval[0] >= v_k, v_k >= init_v_interval[1])

        ts_final, _ = jax.lax.scan(f_scan, init_time_interval, None, length=iters)

        # Replace the result by NaN only when v_k was not bracketed. (Adding
        # `nan * ~result_usable` would be NaN in every case, since nan * 0 is nan.)
        ts_final = np.where(result_usable, ts_final, np.nan)

        return sol.evaluate(ts_final.mean())

    # TODO consider case where the solution does not intersect the value level.
    ys = jax.vmap(state_at_vk_bisection)(sols)

    def l_of_y(y):
        x = y['x']
        vx = y['vx']
        u = pontryagin_utils.u_star_2d(x, vx, problem_params)
        return problem_params['l'](x, u)

    all_ls = jax.vmap(l_of_y)(ys)
    return np.nanmax(all_ls)
import sys
import contextlib
class IndentStdout:
    """File-like wrapper that indents each line written through it.

    Text is forwarded to the ``sys.stdout`` that was active when the wrapper
    was created. ``num_spaces`` spaces are inserted at the start of every line
    that has content; blank lines pass through unindented.

    A line may arrive over several ``write`` calls (``print`` writes the text
    and the trailing newline separately). The wrapper therefore remembers
    whether the next character starts a new line (i.e. follows a ``'\\n'`` or
    ``'\\r'``), instead of treating each call as a complete line.
    """

    def __init__(self, num_spaces):
        self.num_spaces = num_spaces
        self._original_stdout = sys.stdout
        # True when the next character written begins a new output line.
        self._at_line_start = True

    def write(self, message):
        """Write ``message`` with indentation; return ``len(message)`` like ``TextIO.write``."""
        indent = ' ' * self.num_spaces
        out = []
        for piece in message.splitlines(keepends=True):
            # Indent only at a line start, and only lines with content.
            if self._at_line_start and piece.rstrip('\r\n'):
                out.append(indent)
            out.append(piece)
            self._at_line_start = piece.endswith(('\n', '\r'))
        self._original_stdout.write(''.join(out))
        return len(message)

    def flush(self):
        self._original_stdout.flush()

    def __getattr__(self, name):
        # Delegate other stream attributes (encoding, isatty, fileno, ...) to
        # the wrapped stream so code inspecting sys.stdout keeps working.
        # Private names are excluded to avoid recursion before __init__ ran.
        if name.startswith('_'):
            raise AttributeError(name)
        return getattr(self._original_stdout, name)
@contextlib.contextmanager
def stdout_spaces(num_spaces):
    """Route everything printed inside the ``with`` body through ``IndentStdout``.

    ``sys.stdout`` is swapped for an ``IndentStdout(num_spaces)`` instance on
    entry and the previous stream is put back on exit, including when the
    body raises. Yields ``None``.
    """
    with contextlib.redirect_stdout(IndentStdout(num_spaces)):
        yield
# Example usage: run this module directly to print two demo lines inside
# stdout_spaces(4).
if __name__ == "__main__":
    with stdout_spaces(4):
        for demo_line in ("This is a test.", "Each line will be indented."):
            print(demo_line)