-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use DifferentiationInterface for autodiff, allow ADTypes #153
base: master
Are you sure you want to change the base?
Changes from 1 commit
5d007f3
8be2259
7783026
dbf30ab
250fd89
ebaa4bd
5a7d1c4
30ec326
914c6c5
1374805
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,37 +43,16 @@ function OnceDifferentiable(f, x_seed::AbstractArray{T}, | |
|
||
return OnceDifferentiable(fF, dfF, fdfF, x_seed, F, DF) | ||
else | ||
if is_finitediff(autodiff) | ||
|
||
# Figure out which Val-type to use for FiniteDiff based on our | ||
# symbol interface. | ||
fdtype = finitediff_fdtype(autodiff) | ||
df_array_spec = DF | ||
x_array_spec = x_seed | ||
return_spec = typeof(F) | ||
gcache = FiniteDiff.GradientCache(df_array_spec, x_array_spec, fdtype, return_spec) | ||
|
||
function g!(storage, x) | ||
FiniteDiff.finite_difference_gradient!(storage, f, x, gcache) | ||
return | ||
end | ||
function fg!(storage, x) | ||
g!(storage, x) | ||
return f(x) | ||
end | ||
elseif is_forwarddiff(autodiff) | ||
gcfg = ForwardDiff.GradientConfig(f, x_seed, chunk) | ||
g! = (out, x) -> ForwardDiff.gradient!(out, f, x, gcfg) | ||
|
||
fg! = (out, x) -> begin | ||
gr_res = DiffResults.DiffResult(zero(T), out) | ||
ForwardDiff.gradient!(gr_res, f, x, gcfg) | ||
DiffResults.value(gr_res) | ||
end | ||
else | ||
error("The autodiff value $autodiff is not supported. Use :finite or :forward.") | ||
backend = get_adtype(autodiff, chunk) | ||
grad_prep = DI.prepare_gradient(f, backend, x_seed) | ||
function g!(_g, _x) | ||
DI.gradient!(f, _g, grad_prep, backend, _x) | ||
return nothing | ||
end | ||
function fg!(_g, _x) | ||
y, _ = DI.value_and_gradient!(f, _g, grad_prep, backend, _x) | ||
return y | ||
end | ||
|
||
return OnceDifferentiable(f, g!, fg!, x_seed, F, DF) | ||
end | ||
end | ||
|
@@ -99,72 +78,18 @@ function OnceDifferentiable(f, x_seed::AbstractArray, F::AbstractArray, DF::Abst | |
fdfF = make_fdf(f, x_seed, F) | ||
return OnceDifferentiable(fF, dfF, fdfF, x_seed, F, DF) | ||
else | ||
if is_finitediff(autodiff) | ||
# Figure out which Val-type to use for FiniteDiff based on our | ||
# symbol interface. | ||
fdtype = finitediff_fdtype(autodiff) | ||
# Apparently only the third input is aliased. | ||
j_finitediff_cache = FiniteDiff.JacobianCache(copy(x_seed), copy(F), copy(F), fdtype) | ||
if autodiff == :finiteforward | ||
# These copies can be done away with if we add a keyword for | ||
# reusing arrays instead for overwriting them. | ||
Fx = copy(F) | ||
DF = copy(DF) | ||
|
||
x_f, x_df = x_of_nans(x_seed), x_of_nans(x_seed) | ||
f_calls, j_calls = [0,], [0,] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Should we try to preserve this call count? I don't think it makes a lot of sense for autodiff in general, because some backends will not go through the actual code to compute the gradient (unlike ForwardDiff and FiniteDiff).
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
In addition, I don't think this call count was present for every operator or every autodiff method.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I have to look. It's done this way because people complained that it didn't match what they printed from their objective, but I agree that people will have to calculate the number themselves in the special case of finite diff. |
||
function j_finiteforward!(J, x) | ||
# Exploit the possibility that it might be that x_f == x | ||
# then we don't have to call f again. | ||
|
||
# if at least one element of x_f is different from x, update | ||
if any(x_f .!= x) | ||
f(Fx, x) | ||
f_calls .+= 1 | ||
end | ||
|
||
FiniteDiff.finite_difference_jacobian!(J, f, x, j_finitediff_cache, Fx) | ||
end | ||
function fj_finiteforward!(F, J, x) | ||
f(F, x) | ||
FiniteDiff.finite_difference_jacobian!(J, f, x, j_finitediff_cache, F) | ||
end | ||
|
||
|
||
return OnceDifferentiable(f, j_finiteforward!, fj_finiteforward!, Fx, DF, x_f, x_df, f_calls, j_calls) | ||
end | ||
|
||
function fj_finitediff!(F, J, x) | ||
f(F, x) | ||
FiniteDiff.finite_difference_jacobian!(J, f, x, j_finitediff_cache) | ||
F | ||
end | ||
function j_finitediff!(J, x) | ||
F_cache = copy(F) | ||
fj_finitediff!(F_cache, J, x) | ||
end | ||
|
||
return OnceDifferentiable(f, j_finitediff!, fj_finitediff!, x_seed, F, DF) | ||
|
||
elseif is_forwarddiff(autodiff) | ||
|
||
jac_cfg = ForwardDiff.JacobianConfig(f, F, x_seed, chunk) | ||
ForwardDiff.checktag(jac_cfg, f, x_seed) | ||
|
||
F2 = copy(F) | ||
function j_forwarddiff!(J, x) | ||
ForwardDiff.jacobian!(J, f, F2, x, jac_cfg, Val{false}()) | ||
end | ||
function fj_forwarddiff!(F, J, x) | ||
jac_res = DiffResults.DiffResult(F, J) | ||
ForwardDiff.jacobian!(jac_res, f, F2, x, jac_cfg, Val{false}()) | ||
DiffResults.value(jac_res) | ||
end | ||
|
||
return OnceDifferentiable(f, j_forwarddiff!, fj_forwarddiff!, x_seed, F, DF) | ||
else | ||
error("The autodiff value $(autodiff) is not supported. Use :finite or :forward.") | ||
F2 = similar(F) | ||
backend = get_adtype(autodiff, chunk) | ||
jac_prep = DI.prepare_jacobian(f, F2, backend, x_seed) | ||
function j!(_j, _x) | ||
DI.jacobian!(f, F2, _j, jac_prep, backend, _x) | ||
return _j | ||
end | ||
function fj!(_y, _j, _x) | ||
y, _ = DI.value_and_jacobian!(f, _y, _j, jac_prep, backend, _x) | ||
return y | ||
end | ||
return OnceDifferentiable(f, j!, fj!, x_seed, F, DF) | ||
end | ||
end | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe these could be made callable structs to avoid closing over `c!`, `ccache`, `jac_prep` and `backend`?
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These were already closures in the existing code, I thought I'd go for the "minimum diff" changes and then we could improve later on
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I saw that. It seemed like you're closing over more variables in the DI version though, so it might be even more valuable to change the design.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I seem to recall that closures are not an issue if the variable we close over does not get assigned to more than once? So we might be good here?
In any case, to hunt down this kind of type inference barriers we would also need to improve annotation of functions in the various structs of the package.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know. Even if it's currently the case, personally I wouldn't rely on it since in my experience the exact behaviour of the compiler is inherently unstable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See for instance the explanations in https://discourse.julialang.org/t/can-someone-explain-closures-to-me/105605