Merge pull request SciML#795 from sathvikbhagavan/sb/NNODE_paramestim
feat: add Parameter estimation capability in NNODE
ChrisRackauckas authored Feb 10, 2024
2 parents 1c8f138 + dc49fc6 commit 15dd978
Showing 5 changed files with 167 additions and 33 deletions.
3 changes: 2 additions & 1 deletion docs/pages.jl
@@ -1,7 +1,8 @@
pages = ["index.md",
"ODE PINN Tutorials" => Any["Introduction to NeuralPDE for ODEs" => "tutorials/ode.md",
"Bayesian PINNs for Coupled ODEs" => "tutorials/Lotka_Volterra_BPINNs.md",
"PINNs DAEs" => "tutorials/dae.md"
"PINNs DAEs" => "tutorials/dae.md",
"Parameter Estimation with PINNs for ODEs" => "tutorials/ode_parameter_estimation.md",
#"examples/nnrode_example.md", # currently incorrect
],
"PDE PINN Tutorials" => Any["Introduction to NeuralPDE for PDEs" => "tutorials/pdesystem.md",
88 changes: 88 additions & 0 deletions docs/src/tutorials/ode_parameter_estimation.md
@@ -0,0 +1,88 @@
# Parameter Estimation with Physics-Informed Neural Networks for ODEs

Consider the [Lotka–Volterra system](https://en.wikipedia.org/wiki/Lotka%E2%80%93Volterra_equations)

$$
\begin{aligned}
\frac{du_1}{dt} &= \alpha u_1 - \beta u_1 u_2 \\
\frac{du_2}{dt} &= \delta u_1 u_2 - \gamma u_2
\end{aligned}
$$

with Physics-Informed Neural Networks. Here we consider the case where we also want to estimate the parameters $\alpha$, $\beta$, $\gamma$ and $\delta$.

We start by defining the problem:

```@example param_estim_lv
using NeuralPDE, OrdinaryDiffEq
using Lux, Random
using OptimizationOptimJL, LineSearches
using Plots
using Test # hide
function lv(u, p, t)
    u₁, u₂ = u
    α, β, γ, δ = p
    du₁ = α * u₁ - β * u₁ * u₂
    du₂ = δ * u₁ * u₂ - γ * u₂
    [du₁, du₂]
end
tspan = (0.0, 5.0)
u0 = [5.0, 5.0]
prob = ODEProblem(lv, u0, tspan, [1.0, 1.0, 1.0, 1.0])
```

As we want to estimate the parameters as well, let's generate some data.

```@example param_estim_lv
true_p = [1.5, 1.0, 3.0, 1.0]
prob_data = remake(prob, p = true_p)
sol_data = solve(prob_data, Tsit5(), saveat = 0.01)
t_ = sol_data.t
u_ = reduce(hcat, sol_data.u)
```

Now, let's define a neural network for the PINN using [Lux.jl](https://lux.csail.mit.edu/).

```@example param_estim_lv
rng = Random.default_rng()
Random.seed!(rng, 0)
n = 12
chain = Lux.Chain(
    Lux.Dense(1, n, Lux.σ),
    Lux.Dense(n, n, Lux.σ),
    Lux.Dense(n, n, Lux.σ),
    Lux.Dense(n, 2)
)
ps, st = Lux.setup(rng, chain) |> Lux.f64
```

Next, we define an additional loss term for the total loss, which measures how well the neural network's prediction fits the data.

```@example param_estim_lv
function additional_loss(phi, θ)
    return sum(abs2, phi(t_, θ) .- u_) / size(u_, 2)
end
```
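
In equation form, this is the data misfit averaged over the `N` saved time points:

$$
L_{\text{data}}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \left\lVert \hat{u}_{\theta}(t_i) - u_i \right\rVert_2^2,
$$

where $\hat{u}_{\theta}(t_i)$ is the PINN prediction `phi(t_, θ)` at $t_i$ and $u_i$ is the corresponding column of `u_`.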

Next, we define the optimizer and the [`NNODE`](@ref) algorithm, which is then passed to the `solve` call.

```@example param_estim_lv
opt = LBFGS(linesearch = BackTracking())
alg = NNODE(chain, opt, ps; strategy = GridTraining(0.01), param_estim = true, additional_loss = additional_loss)
```
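
As a side note, `GridTraining` is only one of the strategies wired up for `param_estim` in this PR; a hypothetical variant using random collocation points (not run in this tutorial, and assuming `StochasticTraining` behaves comparably for this problem) might look like:

```julia
# Sketch only: sample 300 random points in `tspan` per loss evaluation instead of a fixed grid.
alg_stochastic = NNODE(chain, opt, ps; strategy = StochasticTraining(300),
    param_estim = true, additional_loss = additional_loss)
```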

Now we have all the pieces to solve the optimization problem.

```@example param_estim_lv
sol = solve(prob, alg, verbose = true, abstol = 1e-8, maxiters = 5000, saveat = t_)
@test sol.k.u.p ≈ true_p rtol=1e-2 # hide
```

Let's plot the predictions from the PINN and compare them to the data.

```@example param_estim_lv
plot(sol, labels = ["u1_pinn" "u2_pinn"])
plot!(sol_data, labels = ["u1_data" "u2_data"])
```

We can see it is a good fit! Now let's check whether the parameters of the equation have also been estimated correctly.

```@example param_estim_lv
sol.k.u.p
```

We can see they are indeed close to the true values `[1.5, 1.0, 3.0, 1.0]`.
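
For a quantitative check, one can compare the estimates against `true_p` directly (a short sketch reusing the accessor above; the hidden test in this tutorial asserts agreement to `rtol = 1e-2`):

```julia
p_est = sol.k.u.p                          # estimated ODE parameters
rel_err = abs.(p_est .- true_p) ./ true_p  # element-wise relative error
maximum(rel_err)                           # roughly at the percent level
```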
1 change: 1 addition & 0 deletions src/dae_solve.jl
@@ -115,6 +115,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractDAEProblem,

if chain isa Lux.AbstractExplicitLayer || chain isa Flux.Chain
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
init_params = ComponentArrays.ComponentArray(; depvar = ComponentArrays.ComponentArray(init_params))
else
error("Only Lux.AbstractExplicitLayer and Flux.Chain neural networks are supported")
end
78 changes: 46 additions & 32 deletions src/ode_solve.jl
Expand Up @@ -32,6 +32,7 @@ of the physics-informed neural network which is used as a solver for a standard
at a time. `batch>0` means the neural network is applied at a row vector of values
`t` simultaneously, i.e. it's the batch size for the neural network evaluations.
This requires a neural network compatible with batched data.
* `param_estim`: Boolean to indicate whether parameters of the differential equations are learnt along with parameters of the neural network.
* `strategy`: The training strategy used to choose the points for the evaluations.
Default of `nothing` means that `QuadratureTraining` with QuadGK is used if no
`dt` is given, and `GridTraining` is used with `dt` if given.
@@ -71,7 +72,7 @@ is an accurate interpolation (up to the neural network training result). In addi
Lagaris, Isaac E., Aristidis Likas, and Dimitrios I. Fotiadis. "Artificial neural networks for solving
ordinary and partial differential equations." IEEE Transactions on Neural Networks 9, no. 5 (1998): 987-1000.
"""
struct NNODE{C, O, P, B, K, AL <: Union{Nothing, Function},
struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function},
S <: Union{Nothing, AbstractTrainingStrategy},
} <:
NeuralPDEAlgorithm
@@ -81,14 +82,15 @@ struct NNODE{C, O, P, B, K, AL <: Union{Nothing, Function},
autodiff::Bool
batch::B
strategy::S
param_estim::PE
additional_loss::AL
kwargs::K
end
function NNODE(chain, opt, init_params = nothing;
strategy = nothing,
autodiff = false, batch = nothing, additional_loss = nothing, kwargs...)
autodiff = false, batch = nothing, param_estim = false, additional_loss = nothing, kwargs...)
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
NNODE(chain, opt, init_params, autodiff, batch, strategy, additional_loss, kwargs)
NNODE(chain, opt, init_params, autodiff, batch, strategy, param_estim, additional_loss, kwargs)
end

"""
@@ -119,29 +121,29 @@ end

function (f::ODEPhi{C, T, U})(t::Number,
θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), [t]), θ, f.st)
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), [t]), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 + (t - f.t0) * first(y)
end

function (f::ODEPhi{C, T, U})(t::AbstractVector,
θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
# Batch via data as row vectors
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), t'), θ, f.st)
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), t'), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t' .- f.t0) .* y
end

function (f::ODEPhi{C, T, U})(t::Number, θ) where {C <: Lux.AbstractExplicitLayer, T, U}
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), [t]), θ, f.st)
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), [t]), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t .- f.t0) .* y
end

function (f::ODEPhi{C, T, U})(t::AbstractVector,
θ) where {C <: Lux.AbstractExplicitLayer, T, U}
# Batch via data as row vectors
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), t'), θ, f.st)
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), t'), θ.depvar, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
f.u0 .+ (t' .- f.t0) .* y
end
@@ -187,28 +189,32 @@ Simple L2 inner loss at a time `t` with parameters `θ` of the neural network.
function inner_loss end

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
p) where {C, T, U <: Number}
sum(abs2, ode_dfdx(phi, t, θ, autodiff) - f(phi(t, θ), p, t))
p, param_estim::Bool) where {C, T, U <: Number}
p_ = param_estim ? θ.p : p
sum(abs2, ode_dfdx(phi, t, θ, autodiff) - f(phi(t, θ), p_, t))
end

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
p) where {C, T, U <: Number}
p, param_estim::Bool) where {C, T, U <: Number}
p_ = param_estim ? θ.p : p
out = phi(t, θ)
fs = reduce(hcat, [f(out[i], p, t[i]) for i in axes(out, 2)])
fs = reduce(hcat, [f(out[i], p_, t[i]) for i in axes(out, 2)])
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
sum(abs2, dxdtguess .- fs) / length(t)
end

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
p) where {C, T, U}
sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(phi(t, θ), p, t))
p, param_estim::Bool) where {C, T, U}
p_ = param_estim ? θ.p : p
sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(phi(t, θ), p_, t))
end

function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
p) where {C, T, U}
p, param_estim::Bool) where {C, T, U}
p_ = param_estim ? θ.p : p
out = Array(phi(t, θ))
arrt = Array(t)
fs = reduce(hcat, [f(out[:, i], p, arrt[i]) for i in 1:size(out, 2)])
fs = reduce(hcat, [f(out[:, i], p_, arrt[i]) for i in 1:size(out, 2)])
dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
sum(abs2, dxdtguess .- fs) / length(t)
end
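
For orientation while reading this hunk: each `inner_loss` method computes the squared ODE residual of the trial solution `phi` at a collocation point $t$,

$$
\ell(t, \theta) = \left\lVert \frac{d\phi_\theta}{dt}(t) - f\big(\phi_\theta(t), p_\ast, t\big) \right\rVert_2^2,
$$

where $p_\ast$ is `θ.p` when `param_estim` is true and the problem's `p` otherwise; the vector-`t` (batched) methods additionally divide by `length(t)`.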
@@ -219,10 +225,10 @@ end
Representation of the loss function, parametric on the training strategy `strategy`.
"""
function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tspan, p,
batch)
integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p))
batch, param_estim::Bool)
integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim))

integrand(ts, θ) = [abs2(inner_loss(phi, f, autodiff, t, θ, p)) for t in ts]
integrand(ts, θ) = [abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim)) for t in ts]
@assert batch == 0 # not implemented

function loss(θ, _)
Expand All @@ -234,39 +240,39 @@ function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tsp
return loss
end

function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, batch)
function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, batch, param_estim::Bool)
ts = tspan[1]:(strategy.dx):tspan[2]
# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
autodiff && throw(ArgumentError("autodiff not supported for GridTraining."))
function loss(θ, _)
if batch
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, param_estim))
else
sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p, param_estim) for t in ts])
end
end
return loss
end

function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,
batch)
batch, param_estim::Bool)
# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
autodiff && throw(ArgumentError("autodiff not supported for StochasticTraining."))
function loss(θ, _)
ts = adapt(parameterless_type(θ),
[(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])

if batch
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, param_estim))
else
sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p, param_estim) for t in ts])
end
end
return loss
end

function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
batch)
batch, param_estim::Bool)
autodiff && throw(ArgumentError("autodiff not supported for WeightedIntervalTraining."))
minT = tspan[1]
maxT = tspan[2]
Expand All @@ -289,22 +295,22 @@ function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Boo

function loss(θ, _)
if batch
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p, param_estim))
else
sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in ts])
sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p, param_estim) for t in ts])
end
end
return loss
end

function evaluate_tstops_loss(phi, f, autodiff::Bool, tstops, p, batch)
function evaluate_tstops_loss(phi, f, autodiff::Bool, tstops, p, batch, param_estim::Bool)

# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
function loss(θ, _)
if batch
sum(abs2, inner_loss(phi, f, autodiff, tstops, θ, p))
sum(abs2, inner_loss(phi, f, autodiff, tstops, θ, p, param_estim))
else
sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p) for t in tstops])
sum(abs2, [inner_loss(phi, f, autodiff, t, θ, p, param_estim) for t in tstops])
end
end
return loss
@@ -351,6 +357,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
f = prob.f
p = prob.p
t0 = tspan[1]
param_estim = alg.param_estim

#hidden layer
chain = alg.chain
@@ -363,6 +370,12 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
!(chain isa Lux.AbstractExplicitLayer) && error("Only Lux.AbstractExplicitLayer neural networks are supported")
phi, init_params = generate_phi_θ(chain, t0, u0, init_params)

init_params = if alg.param_estim
ComponentArrays.ComponentArray(; depvar = ComponentArrays.ComponentArray(init_params), p = prob.p)
else
ComponentArrays.ComponentArray(; depvar = ComponentArrays.ComponentArray(init_params))
end

isinplace(prob) && throw(error("The NNODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))

try
@@ -398,8 +411,9 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
alg.batch
end

inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, batch)
inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, batch, param_estim)
additional_loss = alg.additional_loss
(param_estim && isnothing(additional_loss)) && throw(ArgumentError("Please provide `additional_loss` in `NNODE` for parameter estimation (`param_estim` is true)."))

# Creates OptimizationFunction Object from total_loss
function total_loss(θ, _)
Expand All @@ -409,7 +423,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
end
if !(tstops isa Nothing)
num_tstops_points = length(tstops)
tstops_loss_func = evaluate_tstops_loss(phi, f, autodiff, tstops, p, batch)
tstops_loss_func = evaluate_tstops_loss(phi, f, autodiff, tstops, p, batch, param_estim)
tstops_loss = tstops_loss_func(θ, phi)
if strategy isa GridTraining
num_original_points = length(tspan[1]:(strategy.dx):tspan[2])
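For readers following the `__solve` changes above: with `param_estim = true`, the optimization vector becomes a `ComponentArray` holding the network weights under `depvar` and the ODE parameters under `p`, which is why the loss code reads `θ.depvar` and `θ.p`. A minimal illustration of that layout (the `depvar` contents below are placeholders, not real network parameters):

```julia
using ComponentArrays

# Placeholder standing in for the flattened Lux parameters.
nn_params = (layer_1 = (weight = randn(4, 1), bias = zeros(4)),)

# Same structure as built in `__solve` when `param_estim = true`:
θ = ComponentArray(; depvar = ComponentArray(nn_params), p = [1.0, 1.0, 1.0, 1.0])

θ.depvar  # passed to the neural network trial solution
θ.p       # ODE parameters optimized alongside the network weights
```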
30 changes: 30 additions & 0 deletions test/NNODE_tests.jl
Expand Up @@ -3,6 +3,7 @@ using Random, NeuralPDE
using OrdinaryDiffEq, Statistics
import Lux, OptimizationOptimisers, OptimizationOptimJL
using Flux
using LineSearches

Random.seed!(100)

@@ -217,6 +218,35 @@ end
end
end

@testset "Parameter Estimation" begin
    function lorenz(u, p, t)
        return [p[1] * (u[2] - u[1]),
            u[1] * (p[2] - u[3]) - u[2],
            u[1] * u[2] - p[3] * u[3]]
    end
    prob = ODEProblem(lorenz, [1.0, 0.0, 0.0], (0.0, 1.0), [1.0, 1.0, 1.0])
    true_p = [2.0, 3.0, 2.0]
    prob2 = remake(prob, p = true_p)
    sol = solve(prob2, Tsit5(), saveat = 0.01)
    t_ = sol.t
    u_ = reduce(hcat, sol.u)
    function additional_loss(phi, θ)
        return sum(abs2, phi(t_, θ) .- u_) / 100
    end
    n = 8
    luxchain = Lux.Chain(
        Lux.Dense(1, n, Lux.σ),
        Lux.Dense(n, n, Lux.σ),
        Lux.Dense(n, n, Lux.σ),
        Lux.Dense(n, 3)
    )
    opt = OptimizationOptimJL.LBFGS(linesearch = BackTracking())
    alg = NNODE(luxchain, opt, strategy = GridTraining(0.01), param_estim = true, additional_loss = additional_loss)
    sol = solve(prob, alg, verbose = true, abstol = 1e-8, maxiters = 5000, saveat = t_)
    @test sol.k.u.p ≈ true_p atol=1e-2
    @test reduce(hcat, sol.u) ≈ u_ atol=1e-2
end

@testset "Translating from Flux" begin
linear = (u, p, t) -> cos(2pi * t)
linear_analytic = (u, p, t) -> (1 / (2pi)) * sin(2pi * t)
