Skip to content

Commit

Permalink
Merge pull request SciML#789 from sathvikbhagavan/sb/remove_flux
Browse files Browse the repository at this point in the history
Remove Flux Support
  • Loading branch information
ChrisRackauckas authored Feb 1, 2024
2 parents a0589ae + c303e13 commit e45047d
Show file tree
Hide file tree
Showing 29 changed files with 2,018 additions and 3,093 deletions.
11 changes: 5 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
Expand All @@ -26,8 +25,8 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -53,7 +52,7 @@ DiffEqNoiseProcess = "5.1"
Distributions = "0.23, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
DomainSets = "0.6, 0.7"
Flux = "0.13, 0.14"
Flux = "0.14"
ForwardDiff = "0.10"
Functors = "0.4"
Integrals = "4"
Expand All @@ -64,8 +63,8 @@ MCMCChains = "6"
ModelingToolkit = "8"
MonteCarloMeasurements = "1"
Optim = "1.7.8"
Optimisers = "0.2, 0.3"
Optimization = "3"
OptimizationOptimisers = "0.1"
QuasiMonteCarlo = "0.3.2"
Reexport = "1.0"
RuntimeGeneratedFunctions = "0.5"
Expand All @@ -80,15 +79,15 @@ julia = "1.6"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["Test", "CUDA", "SafeTestsets", "OptimizationOptimisers", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "cuDNN", "LuxCUDA"]
test = ["Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "cuDNN", "LuxCUDA", "Flux"]
47 changes: 20 additions & 27 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
# HIGH level API for BPINN ODE solver

"""
```julia
BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCargs = (n_leapfrog=30), nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8, Metric = DiagEuclideanMetric),
Integratorkwargs = (Integrator = Leapfrog,), autodiff = false,
progress = false, verbose = false)
```
BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCargs = (n_leapfrog=30), nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8, Metric = DiagEuclideanMetric),
Integratorkwargs = (Integrator = Leapfrog,), autodiff = false,
progress = false, verbose = false)
Algorithm for solving ordinary differential equations using a Bayesian neural network. This is a specialization
of the physics-informed neural network which is used as a solver for a standard `ODEProblem`.
Expand All @@ -22,7 +20,7 @@ of the physics-informed neural network which is used as a solver for a standard
## Positional Arguments
* `chain`: A neural network architecture, defined as either a `Flux.Chain` or a `Lux.AbstractExplicitLayer`.
* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer`.
* `Kernel`: Choice of MCMC Sampling Algorithm. Defaults to `AdvancedHMC.HMC`
## Keyword Arguments
Expand All @@ -46,18 +44,18 @@ dataset = [x̂, time]
chainlux = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1))
alg = NeuralPDE.BNNODE(chainlux, draw_samples = 2000,
alg = BNNODE(chainlux, draw_samples = 2000,
l2std = [0.05], phystd = [0.05],
priorsNNw = (0.0, 3.0), progress = true)
sol_lux = solve(prob, alg)
# with parameter estimation
alg = NeuralPDE.BNNODE(chainlux,dataset = dataset,
draw_samples = 2000,l2std = [0.05],
phystd = [0.05],priorsNNw = (0.0, 10.0),
param = [Normal(6.5, 0.5), Normal(-3, 0.5)],
progress = true)
alg = BNNODE(chainlux,dataset = dataset,
draw_samples = 2000,l2std = [0.05],
phystd = [0.05],priorsNNw = (0.0, 10.0),
param = [Normal(6.5, 0.5), Normal(-3, 0.5)],
progress = true)
sol_lux_pestim = solve(prob, alg)
```
Expand All @@ -74,11 +72,10 @@ is an accurate interpolation (up to the neural network training result). In addi
## References
Liu Yanga, Xuhui Menga, George Em Karniadakis. "B-PINNs: Bayesian Physics-Informed Neural Networks for
Forward and Inverse PDE Problems with Noisy Data"
Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, Ellen Kuhl.
"Bayesian Physics Informed Neural Networks for real-world nonlinear dynamical systems"
Forward and Inverse PDE Problems with Noisy Data".
Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, Ellen Kuhl
"Bayesian Physics Informed Neural Networks for real-world nonlinear dynamical systems".
"""
struct BNNODE{C, K, IT <: NamedTuple,
A <: NamedTuple, H <: NamedTuple,
Expand Down Expand Up @@ -116,6 +113,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
autodiff = false, progress = false, verbose = false)
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
BNNODE(chain, Kernel, strategy,
draw_samples, priorsNNw, param, l2std,
phystd, dataset, physdt, MCMCkwargs,
Expand Down Expand Up @@ -222,13 +220,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
# only need for size
θinit = collect(ComponentArrays.ComponentArray(θinit))
elseif chain isa Flux.Chain
θinit, re1 = Flux.destructure(chain)
out = re1.([samples[i][1:(end - ninv)]
for i in (draw_samples - numensemble):draw_samples])
luxar = collect(out[i](t') for i in eachindex(out))
else
throw(error("Only Lux.AbstractExplicitLayer and Flux.Chain neural networks are supported"))
throw(error("Only Lux.AbstractExplicitLayer neural networks are supported"))
end

# contructing ensemble predictions
Expand Down Expand Up @@ -272,4 +265,4 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
end

BPINNsolution(fullsolution, ensemblecurves, estimnnparams, estimated_params, t)
end
end
7 changes: 3 additions & 4 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using Reexport, Statistics
using Zygote, ForwardDiff, Random, Distributions
using Adapt, DiffEqNoiseProcess, StochasticDiffEq
using Optimization
using OptimizationOptimisers
using Integrals, Cubature
using QuasiMonteCarlo
using RuntimeGeneratedFunctions
Expand All @@ -24,14 +25,12 @@ using Symbolics: wrap, unwrap, arguments, operation
using SymbolicUtils
using AdvancedHMC, LogDensityProblems, LinearAlgebra, Functors, MCMCChains
using MonteCarloMeasurements

import ModelingToolkit: value, nameof, toexpr, build_expr, expand_derivatives
import DomainSets: Domain, ClosedInterval
import ModelingToolkit: Interval, infimum, supremum #,Ball
import SciMLBase: @add_kwonly, parameterless_type
import Optimisers
import UnPack: @unpack
import ChainRulesCore, Flux, Lux, ComponentArrays
import ChainRulesCore, Lux, ComponentArrays
import ChainRulesCore: @non_differentiable

RuntimeGeneratedFunctions.init(@__MODULE__)
Expand All @@ -45,7 +44,7 @@ include("symbolic_utilities.jl")
include("training_strategies.jl")
include("adaptive_losses.jl")
include("ode_solve.jl")
include("rode_solve.jl")
# include("rode_solve.jl")
include("dae_solve.jl")
include("transform_inf_integral.jl")
include("discretize.jl")
Expand Down
Loading

0 comments on commit e45047d

Please sign in to comment.