## Environment and packages
using Pkg
Activating project at `~/Documents/GitHub/Julia-for-SciML/hands-on/lotka`
The sample code is adapted from ChrisRackauckas/universal_differential_equations, which is part of the following work:
Rackauckas, Christopher, et al. “Universal differential equations for scientific machine learning.” arXiv preprint arXiv:2001.04385 (2020).
@info "Precompile"
using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra, ComponentArrays
using Optimization, OptimizationOptimisers, OptimizationOptimJL # OptimizationOptimisers for ADAM and OptimizationOptimJL for BFGS
using DiffEqSensitivity
using Lux
using Plots
using Statistics
# Random number generator (seed it with Random.seed! for reproducible behaviour)
using Random
rng = Random.default_rng()
@info "Complete Precompilation"
For simplicity, we use the Lotka-Volterra system as an example:
\[\begin{align} \dot{x} &= \alpha x - \beta xy\\ \dot{y} &= \gamma xy- \delta y \end{align}\]
where \(\alpha, \beta, \gamma\), and \(\delta\) are positive real parameters.
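As a quick sanity check on these equations, the sketch below evaluates the right-hand side at the nontrivial equilibrium \((x^*, y^*) = (\delta/\gamma, \alpha/\beta)\), where both derivatives should vanish. The names `lotka_rhs`, `p_example`, and `u_star` are introduced here for illustration only; the parameter values match the `p_` vector used in the data-generation section.

```julia
# Right-hand side of the Lotka-Volterra system, written out-of-place for clarity
function lotka_rhs(u, p)
    α, β, γ, δ = p
    x, y = u
    return [α*x - β*x*y, γ*x*y - δ*y]
end

p_example = [1.3, 0.9, 0.8, 1.8]      # same values as p_ in the data-generation section
α, β, γ, δ = p_example
u_star = [δ/γ, α/β]                   # nontrivial equilibrium (x*, y*)

println("equilibrium: ", u_star)
println("rhs at equilibrium: ", lotka_rhs(u_star, p_example))
```

Up to floating-point rounding, both components of the right-hand side are zero at this point, which confirms the signs of the four terms.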
## Data generation
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α*u[1] - β*u[2]*u[1]
    du[2] = γ*u[1]*u[2] - δ*u[2]
end
lotka! (generic function with 1 method)
# Define the experimental parameters
tspan = (0.0,3.0)
u0 = [0.44249296,4.6280594]
p_ = [1.3, 0.9, 0.8, 1.8]
4-element Vector{Float64}:
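Vern7() is a high-order explicit Runge-Kutta solver intended for non-stiff problems, which suits the tight tolerances used here. Many other solvers are listed in the DifferentialEquations.jl documentation.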
prob = ODEProblem(lotka!, u0,tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 0.1)
retcode: Success
Interpolation: 1st order linear
t: 31-element Vector{Float64}:
u: 31-element Vector{Vector{Float64}}:
[0.44249296, 4.6280594]
[0.34212452862086234, 3.98764547181634]
[0.2793966078254349, 3.4139529441083147]
[0.2394952228707143, 2.9110318130603883]
[0.21413620714095402, 2.4758280205419836]
[0.19854852659179129, 2.1022922430734137]
[0.18991187927524103, 1.7834096349202704]
[0.18652973211225643, 1.5121821427640152]
[0.18737918127509637, 1.2820806846455604]
[0.1918587411736629, 1.087227597605956]
[0.1996432344128222, 0.9224424008592909]
[0.2105985019620811, 0.7832199752377471]
[0.22473063540355143, 0.6656774980182895]
[0.4333056937367298, 0.22471175932636067]
[0.48425346211989406, 0.1947029152564331]
[0.5425361548950363, 0.16943926722620506]
[0.6091040110729008, 0.14819092695665834]
[0.6850407509453579, 0.13034710141497852]
[0.7715795653361799, 0.11539841080610512]
[0.8701212001899306, 0.10292221843899205]
[0.9822541897624152, 0.09257065810821445]
[1.1097772412678872, 0.08406114122763123]
[1.2547236687788759, 0.07716924328704818]
[1.419387582491876, 0.07172402271816045]
[1.606351205697802, 0.06760604257226555]
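As a rough cross-check of the saved states above, the following sketch integrates the same system from t = 0 to t = 0.1 with a hand-written classical fourth-order Runge-Kutta scheme. This is not the Vern7 method, only a simpler explicit scheme swapped in for illustration; the helper names `lv`, `rk4_step`, and `integrate_rk4` and the step count are assumptions of this sketch.

```julia
# Lotka-Volterra right-hand side (out-of-place form for this sketch)
lv(u, p) = [p[1]*u[1] - p[2]*u[1]*u[2], p[3]*u[1]*u[2] - p[4]*u[2]]

# One classical fourth-order Runge-Kutta step of size h
function rk4_step(f, u, p, h)
    k1 = f(u, p)
    k2 = f(u .+ (h/2) .* k1, p)
    k3 = f(u .+ (h/2) .* k2, p)
    k4 = f(u .+ h .* k3, p)
    return u .+ (h/6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Integrate from t = 0 to t_end with n fixed steps
function integrate_rk4(f, u0, p, t_end, n)
    u = copy(u0)
    h = t_end / n
    for _ in 1:n
        u = rk4_step(f, u, p, h)
    end
    return u
end

u_check = integrate_rk4(lv, [0.44249296, 4.6280594], [1.3, 0.9, 0.8, 1.8], 0.1, 100)
println(u_check)   # compare with the second saved state in the output above
```

With a small fixed step the result should agree with the second entry of `u` to many digits, which is a cheap way to confirm that the model function and parameters are wired up as intended.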
# Ideal data
X = Array(solution)
t = solution.t
DX = Array(solution(solution.t, Val{1}))
full_problem = DataDrivenProblem(X, t = t, DX = DX)
# Add noise in terms of the mean
x̄ = mean(X, dims = 2)
noise_magnitude = 5e-3
Xₙ = X .+ (noise_magnitude*x̄) .* randn(eltype(X), size(X))
2×31 Matrix{Float64}:
0.444955 0.344412 0.277873 0.246362 … 1.2542 1.41631 1.60945
4.6231 3.98748 3.40663 2.91876 0.0810739 0.0772539 0.0672219
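The noise model above scales Gaussian noise by the mean of each state, so the state with the larger mean receives proportionally larger perturbations. A minimal sketch on toy data, assuming a hypothetical 2×3 matrix `X_demo` and a seeded generator `rng_demo` that are not part of the notebook:

```julia
using Random, Statistics

rng_demo = MersenneTwister(1)                 # hypothetical seed for this sketch
X_demo = [0.4 0.3 0.2; 4.6 4.0 3.4]           # toy 2×3 data, rows are states

row_mean = mean(X_demo, dims = 2)             # 2×1 matrix of per-state means
noise_level = 5e-3
X_noisy = X_demo .+ (noise_level .* row_mean) .* randn(rng_demo, eltype(X_demo), size(X_demo))

# Standard deviation of the added noise for each state
println(vec(noise_level .* row_mean))
```

Passing an explicit generator to `randn` is a design choice of this sketch: it makes the noisy data repeatable between runs.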
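Suppose we know only part of the Lotka-Volterra model and use a feedforward neural network as a surrogate for the unknown part:
\[\begin{align} \dot{x} &= \theta_1 x + U_1(\theta_3, x, y)\\ \dot{y} &= -\theta_2 y + U_2(\theta_3, x, y) \end{align}\]
Here \(\theta_1 = \alpha\) and \(\theta_2 = \delta\) are the known parameters, and \(\theta_3\) collects the network weights.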
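To make the structure of this hybrid right-hand side concrete, here is a plain-Julia sketch of a small dense network with the same 2 → 5 → 5 → 5 → 2 layout and a Gaussian RBF activation. It does not use Lux; the helpers `dense`, `init_network`, `forward`, and `hybrid_rhs`, as well as the small random initialisation, are assumptions made for illustration.

```julia
using Random

rbf_act(x) = exp(-x^2)                      # scalar Gaussian RBF, broadcast below

# A dense layer as a named tuple of weight matrix and bias vector
# (0.1 * randn initialisation is an assumption of this sketch)
dense(rng, n_in, n_out) = (W = 0.1 .* randn(rng, n_out, n_in), b = zeros(n_out))

# Layer sizes mirror the 2 -> 5 -> 5 -> 5 -> 2 architecture
function init_network(rng)
    return [dense(rng, 2, 5), dense(rng, 5, 5), dense(rng, 5, 5), dense(rng, 5, 2)]
end

# Forward pass: RBF activation on hidden layers, identity on the output layer
function forward(layers, u)
    h = u
    for layer in layers[1:end-1]
        h = rbf_act.(layer.W * h .+ layer.b)
    end
    last_layer = layers[end]
    return last_layer.W * h .+ last_layer.b
end

# Hybrid right-hand side: known linear terms plus the network correction
function hybrid_rhs(layers, u, θ1, θ2)
    û = forward(layers, u)
    return [θ1*u[1] + û[1], -θ2*u[2] + û[2]]
end

layers = init_network(MersenneTwister(0))
n_params = sum(length(l.W) + length(l.b) for l in layers)
println("trainable parameters: ", n_params)
println(hybrid_rhs(layers, [0.44, 4.63], 1.3, 1.8))
```

Counting weights and biases layer by layer gives 15 + 30 + 30 + 12 = 87 trainable parameters, which is the size of \(\theta_3\) for this architecture.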
## Define the network
# Gaussian RBF as activation
rbf(x) = exp.(-(x.^2))
# Multilayer FeedForward
U = Lux.Chain(
    Lux.Dense(2,5,rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,2)
)
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)
# Define the hybrid model
function ude_dynamics!(du,u, p, t, p_true)
    û = U(u, p, st)[1] # Network prediction
    du[1] = p_true[1]*u[1] + û[1]
    du[2] = -p_true[4]*u[2] + û[2]
end
ude_dynamics! (generic function with 1 method)
# Closure with the known parameter
nn_dynamics!(du,u,p,t) = ude_dynamics!(du,u,p,t,p_)
# Define the problem (Fix: https://discourse.julialang.org/t/issue-with-ude-repository-lv-scenario-1/88618/5)
prob_nn = ODEProblem{true, SciMLBase.FullSpecialize}(nn_dynamics!,Xₙ[:, 1], tspan, p)
#prob_nn = ODEProblem(nn_dynamics!,Xₙ[:, 1], tspan, p)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 3.0)
u0: 2-element Vector{Float64}:
 0.44495468189157616
 4.623098367786485
## Function to train the network
# Define a predictor
function predict(θ, X = Xₙ[:,1], T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol=1e-6, reltol=1e-6,
                sensealg = ForwardDiffSensitivity()
                ))
end
# Simple L2 loss
function loss(θ)
    X̂ = predict(θ)
    sum(abs2, Xₙ .- X̂)
end
# Container to track the losses
losses = Float64[]
callback = function (p, l)
    push!(losses, l)
    if length(losses)%50==0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end
#1 (generic function with 1 method)
Training is split into two steps:
1. ADAM, which improves convergence by moving the parameters into a favourable starting position.
2. BFGS, which then refines the parameters to a better minimum.
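The two-stage idea can be sketched on a toy problem without any optimisation packages. The block below runs a hand-written Adam loop and then hands its result to a hand-written BFGS loop with a backtracking line search. The objective `f`, its gradient `grad_f`, and the helpers `adam` and `bfgs` are hypothetical stand-ins; they are not the routines provided by Optimization.jl or Optim.jl.

```julia
using LinearAlgebra

# Toy objective and its gradient (an ill-scaled quadratic, assumed for illustration)
f(x) = (x[1] - 1.0)^2 + 10.0 * (x[2] - 2.0)^2
grad_f(x) = [2.0 * (x[1] - 1.0), 20.0 * (x[2] - 2.0)]

# Stage 1: Adam with the commonly used default moment coefficients
function adam(grad, x0; η = 0.1, β1 = 0.9, β2 = 0.999, ϵ = 1e-8, iters = 200)
    x = copy(x0); m = zero(x); v = zero(x)
    for k in 1:iters
        g = grad(x)
        m = β1 .* m .+ (1 - β1) .* g
        v = β2 .* v .+ (1 - β2) .* g .^ 2
        m_hat = m ./ (1 - β1^k)               # bias-corrected first moment
        v_hat = v ./ (1 - β2^k)               # bias-corrected second moment
        x = x .- η .* m_hat ./ (sqrt.(v_hat) .+ ϵ)
    end
    return x
end

# Stage 2: BFGS with a backtracking (Armijo) line search
function bfgs(obj, grad, x0; iters = 100, tol = 1e-10)
    x = copy(x0)
    H = Matrix{Float64}(I, length(x), length(x))   # inverse-Hessian estimate
    g = grad(x)
    for _ in 1:iters
        norm(g) < tol && break
        d = -H * g
        step = 1.0
        while obj(x .+ step .* d) > obj(x) + 1e-4 * step * dot(g, d) && step > 1e-12
            step /= 2
        end
        s = step .* d
        x_new = x .+ s
        g_new = grad(x_new)
        y = g_new .- g
        sy = dot(s, y)
        if sy > 1e-12                               # skip update if curvature is not positive
            ρ = 1 / sy
            H = (I - ρ * s * y') * H * (I - ρ * y * s') + ρ * s * s'
        end
        x, g = x_new, g_new
    end
    return x
end

x_adam = adam(grad_f, [0.0, 0.0])
x_bfgs = bfgs(f, grad_f, x_adam)
println("after Adam: ", f(x_adam), "   after BFGS: ", f(x_bfgs))
```

Adam needs only gradients and tolerates a poor starting point, while BFGS builds curvature information and converges quickly once it starts near a minimum, which is the motivation for running them in this order.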
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))
res1 = Optimization.solve(optprob, ADAM(0.1), callback=callback, maxiters = 200)
@info "Training loss after $(length(losses)) iterations: $(losses[end])"
# Train with BFGS
@time optprob2 = Optimization.OptimizationProblem(optf, res1.minimizer)
@time res2 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm=0.01), callback=callback, maxiters = 3000)
@info "Final training loss after $(length(losses)) iterations: $(losses[end])"
Current loss after 50 iterations: 3.2919945635723753
Current loss after 100 iterations: 1.7058558916650455
Current loss after 150 iterations: 1.6697049368588788
Current loss after 200 iterations: 1.6423084245679131
0.002839 seconds (946 allocations: 59.323 KiB, 97.18% compilation time)
┌ Info: Training loss after 201 iterations: 1.6423084245679131
└ @ Main In[13]:6
Current loss after 250 iterations: 0.023661287470332845
Current loss after 300 iterations: 0.013746817618177573
Current loss after 350 iterations: 0.0032781193503113072
Current loss after 400 iterations: 0.0017736994450182483
Current loss after 450 iterations: 0.0016452061747373918
Current loss after 500 iterations: 0.001415755561505311
Current loss after 550 iterations: 0.001226557886040906
Current loss after 600 iterations: 0.001107089612914693
Current loss after 650 iterations: 0.0009982934123468784
Current loss after 700 iterations: 0.0009875749757684356
Current loss after 750 iterations: 0.0009812594981828098
Current loss after 800 iterations: 0.0009796658875397177
Current loss after 850 iterations: 0.0009793537523280391
Current loss after 900 iterations: 0.0009772977493860333
Current loss after 950 iterations: 0.0009743301969273224
Current loss after 1000 iterations: 0.0009727084653298941
Current loss after 1050 iterations: 0.000972080483835656
Current loss after 1100 iterations: 0.0009717534715880698
Current loss after 1150 iterations: 0.0009656859249757323
Current loss after 1200 iterations: 0.000962561381268262
Current loss after 1250 iterations: 0.0009612859817110582
Current loss after 1300 iterations: 0.0009591475053966475
Current loss after 1350 iterations: 0.0009582234874165839
Current loss after 1400 iterations: 0.0009574676566527876
Current loss after 1450 iterations: 0.0009569122026897299
Current loss after 1500 iterations: 0.0009567527113130108
Current loss after 1550 iterations: 0.0009561199657093553
Current loss after 1600 iterations: 0.0009549884254749657
Current loss after 1650 iterations: 0.0009541600089060593
Current loss after 1700 iterations: 0.0009537452602046372
Current loss after 1750 iterations: 0.0009536569492935126
Current loss after 1800 iterations: 0.0009534426125349452
Current loss after 1850 iterations: 0.0009532818826279794
Current loss after 1900 iterations: 0.0009530863107623305
Current loss after 1950 iterations: 0.0009529698205412621
Current loss after 2000 iterations: 0.0009528567762890976
Current loss after 2050 iterations: 0.0009525373206446654
Current loss after 2100 iterations: 0.0009521593217738839
Current loss after 2150 iterations: 0.0009520958641338131
Current loss after 2200 iterations: 0.0009520568498258301
Current loss after 2250 iterations: 0.0009517387568709354
Current loss after 2300 iterations: 0.0009516121690135998
Current loss after 2350 iterations: 0.0009514360660860543
Current loss after 2400 iterations: 0.0009512526383059336
Current loss after 2450 iterations: 0.0009511495748527354
Current loss after 2500 iterations: 0.000951106817478341
Current loss after 2550 iterations: 0.0009510985159145336
Current loss after 2600 iterations: 0.0009509270459168483
Current loss after 2650 iterations: 0.000950863437522074
Current loss after 2700 iterations: 0.0009507791033767345
Current loss after 2750 iterations: 0.000950657136744276
Current loss after 2800 iterations: 0.0009506293508157388
Current loss after 2850 iterations: 0.0009506101911525282
Current loss after 2900 iterations: 0.0009504534699507922
Current loss after 2950 iterations: 0.0009504069207206256
Current loss after 3000 iterations: 0.0009503383212540466
Current loss after 3050 iterations: 0.0009503196958639988
Current loss after 3100 iterations: 0.0009502399212363976
Current loss after 3150 iterations: 0.0009501939332356801
Current loss after 3200 iterations: 0.0009501728013450467
980.071985 seconds (629.24 M allocations: 116.047 GiB, 1.16% gc time, 0.19% compilation time)
┌ Info: Final training loss after 3202 iterations: 0.0009501727548593611
└ @ Main In[13]:10
# Plot the losses
pl_losses = plot(1:200, losses[1:200], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(201:length(losses), losses[201:end], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)
## Analysis of the trained network
# Parameters found by the BFGS stage
p_trained = res2.minimizer
# Plot the data and the approximation
ts = first(solution.t):mean(diff(solution.t))/2:last(solution.t)
X̂ = predict(p_trained, Xₙ[:,1], ts)
# Trained on noisy data vs real solution
pl_trajectory = plot(ts, transpose(X̂), xlabel = "t", ylabel ="x(t), y(t)", color = :red, label = ["UDE Approximation" nothing])
scatter!(solution.t, transpose(Xₙ), color = :black, label = ["Measurements" nothing])
# Ideal unknown interactions of the predictor
Ȳ = [-p_[2]*(X̂[1,:].*X̂[2,:])';p_[3]*(X̂[1,:].*X̂[2,:])']
# Neural network guess
Ŷ = U(X̂,p_trained,st)[1]
pl_reconstruction = plot(ts, transpose(Ŷ), xlabel = "t", ylabel ="U(x,y)", color = :red, label = ["UDE Approximation" nothing])
plot!(ts, transpose(Ȳ), color = :black, label = ["True Interaction" nothing])
# Plot the error
pl_reconstruction_error = plot(ts, norm.(eachcol(Ȳ-Ŷ)), yaxis = :log, xlabel = "t", ylabel = "L2-Error", label = nothing, color = :red)
pl_missing = plot(pl_reconstruction, pl_reconstruction_error, layout = (2,1))
pl_overall = plot(pl_trajectory, pl_missing)
## Symbolic regression via sparse regression (SINDy-based)
# Create a Basis
@variables u[1:2]
# Generate the basis functions, multivariate polynomials up to deg 5
# and sine
b = [polynomial_basis(u, 5); sin.(u)]
basis = Basis(b,u);
# Create the thresholds which should be used in the search process
λ = exp10.(-3:0.01:5)
# Create an optimizer for the SINDy problem
opt = STLSQ(λ)
# Define different problems for the recovery
ideal_problem = DirectDataDrivenProblem(X̂, Ȳ)
nn_problem = DirectDataDrivenProblem(X̂, Ŷ)
# Test on ideal derivative data for unknown function ( not available )
println("Sparse regression")
full_res = solve(full_problem, basis, opt, maxiter = 10000, progress = true)
ideal_res = solve(ideal_problem, basis, opt, maxiter = 10000, progress = true)
nn_res = solve(nn_problem, basis, opt, maxiter = 10000, progress = true, sampler = DataSampler(Batcher(n = 4, shuffle = true)))
# Store the results
results = [full_res; ideal_res; nn_res]
# Show the results
map(println, results)
# Show the recovered equations
map(println ∘ result, results)
# Show the identified parameters
map(println ∘ parameter_map, results)
# Define the recovered, hybrid model
function recovered_dynamics!(du,u, p, t)
    û = nn_res(u, p) # Network prediction
    du[1] = p_[1]*u[1] + û[1]
    du[2] = -p_[4]*u[2] + û[2]
end
estimation_prob = ODEProblem(recovered_dynamics!, u0, tspan, parameters(nn_res))
estimate = solve(estimation_prob, Tsit5(), saveat = solution.t)
Sparse regression
Linear Solution with 2 equations and 20 parameters.
Returncode: solved
L₂ Norm error : [31.995291148539735, 1.2046710278865183]
AIC : [147.43325095127574, 45.77240223675897]
R² : [-1.2429510122420595, 0.990118386736816]
Linear Solution with 2 equations and 2 parameters.
Returncode: solved
L₂ Norm error : [8.108165065870263e-32, 1.3731880503480992e-31]
AIC : [-4362.980934730921, -4330.843170940274]
R² : [1.0, 1.0]
Linear Solution with 2 equations and 3 parameters.
Returncode: solved
L₂ Norm error : [1.6430538154747572, 8.039350195270902]
AIC : [36.28995216787681, 133.14524376489635]
R² : [0.7906825418979999, -0.31134160491819185]
Model ##Basis#629 with 2 equations
States : u[1] u[2]
Parameters : 20
Independent variable: t
Differential(t)(u[1]) = p₁ + p₁₀*(u[2]^2) + p₃*(u[1]^2) + p₁₇*sin(u[1]) + p₂*u[1] + p₄*(u[1]^3) + p₅*u[2] + p₁₂*(u[1]^2)*(u[2]^2) + p₁₅*(u[1]^2)*(u[2]^3) + p₁₃*(u[1]^3)*(u[2]^2) + p₁₁*(u[2]^2)*u[1] + p₁₄*(u[2]^3)*u[1] + p₇*(u[1]^2)*u[2] + p₈*(u[1]^3)*u[2] + p₉*(u[1]^4)*u[2] + p₁₆*(u[2]^4)*u[1] + p₆*u[1]*u[2]
Differential(t)(u[2]) = p₁₉*(u[1]^2)*u[2] + p₂₀*(u[1]^3)*u[2] + p₁₈*u[1]*u[2]
Model ##Basis#632 with 2 equations
States : u[1] u[2]
Parameters : p₁ p₂
Independent variable: t
φ₁ = p₁*u[1]*u[2]
φ₂ = p₂*u[1]*u[2]
Model ##Basis#635 with 2 equations
States : u[1] u[2]
Parameters : p₁ p₂ p₃
Independent variable: t
φ₁ = p₁*(u[1]^2)*u[2]
φ₂ = p₃*sin(u[1]) + p₂*u[1]
Pair{Sym{Real, Base.ImmutableDict{DataType, Any}}, Float64}[p₁ => 88.0, p₂ => 90.1, p₃ => 45.4, p₄ => 27.6, p₅ => 73.5, p₆ => -1107.4, p₇ => -2835.6, p₈ => 27.9, p₉ => 25.07, p₁₀ => 16.9, p₁₁ => -472.5, p₁₂ => 6115.8, p₁₃ => -117.4, p₁₄ => 22.016, p₁₅ => -659.7, p₁₆ => -25.4, p₁₇ => 62.5, p₁₈ => -13.9, p₁₉ => 31.25, p₂₀ => -15.5]
Pair{Sym{Real, Base.ImmutableDict{DataType, Any}}, Float64}[p₁ => -0.9, p₂ => 0.8]
Pair{Sym{Real, Base.ImmutableDict{DataType, Any}}, Float64}[p₁ => -2.012, p₂ => -2.25, p₃ => 3.3]
retcode: Success
Interpolation: 1st order linear
t: 31-element Vector{Float64}:
u: 31-element Vector{Vector{Float64}}:
[0.44249296, 4.6280594]
[0.3589998389049996, 3.90043838799414]
[0.32039236584298314, 3.288366022606705]
[0.30209737580427426, 2.7749501847648586]
[0.29519534196248753, 2.345125591932463]
[0.2957319262240388, 1.9858702305186864]
[0.30173746136416507, 1.6860822316973967]
[0.3121621269423411, 1.4363640495034167]
[0.3264408593862669, 1.228791374981506]
[0.3442811760621213, 1.05671475218947]
[0.3655278350988628, 0.9145062682055913]
[0.3901432875306497, 0.7974760103603659]
[0.41813107826611684, 0.7016575070929881]
[0.7074222345154149, 0.3968840102725108]
[0.7606094385047732, 0.3826360576519005]
[0.8159667970667692, 0.3714306439202008]
[0.8731544682099793, 0.3621229881052103]
[0.931833648528594, 0.35361775010135293]
[0.9918281148809143, 0.34494926059255426]
[1.053178574026386, 0.33524417437964893]
[1.1161906808048632, 0.3236127866772301]
[1.1815090141176419, 0.3091482705104759]
[1.250273983171013, 0.2909351024621833]
[1.3243227902154364, 0.26776728987299037]
[1.4068247104939124, 0.23796649830740021]
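The thresholded least-squares idea behind the STLSQ optimizer used above can be sketched in a few lines of plain Julia: fit all candidate terms, zero the coefficients below a threshold, and refit the remaining ones. The function `stlsq`, the grid of sample points, and the six-term candidate library below are assumptions of this sketch and do not reproduce the DataDrivenDiffEq implementation or the notebook's data.

```julia
using LinearAlgebra

# Sequentially thresholded least squares: fit, zero small coefficients, refit
function stlsq(Θ, y, λ; iters = 10)
    ξ = Θ \ y
    for _ in 1:iters
        small = abs.(ξ) .< λ
        ξ[small] .= 0.0
        big = .!small
        any(big) || break
        ξ[big] = Θ[:, big] \ y
    end
    return ξ
end

# Samples on a small grid (hypothetical data, not the trajectory from this notebook)
pts = [(a, b) for a in 0.2:0.2:1.0 for b in 0.5:0.5:2.5]
xs = [p[1] for p in pts]
ys = [p[2] for p in pts]

# Candidate library: 1, x, y, x*y, x^2, y^2
Θ = hcat(ones(length(xs)), xs, ys, xs .* ys, xs .^ 2, ys .^ 2)
target = -0.9 .* xs .* ys          # missing term of the first equation, with β = 0.9

ξ = stlsq(Θ, target, 0.1)
println(ξ)
```

On noise-free data the only surviving coefficient is the one for the x*y column. With noisy or poorly sampled data the selected terms depend strongly on the threshold, which is why a whole range of thresholds λ is scanned in the code above.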
# Plot
## Simulation
# Look at long term prediction
t_long = (0.0, 50.0)
estimation_prob = ODEProblem(recovered_dynamics!, u0, t_long, parameters(nn_res))
estimate_long = solve(estimation_prob, Tsit5()) # Using higher tolerances here results in exit of julia
true_prob = ODEProblem(lotka!, u0, t_long, p_)
true_solution_long = solve(true_prob, Tsit5(), saveat = estimate_long.t)
## Post Processing and Plots
c1 = 3 # index into the default color palette
c2 = :orange
c3 = :blue
c4 = :purple
p1 = plot(t,abs.(Array(solution) .- estimate)' .+ eps(Float32),
lw = 3, yaxis = :log, title = "Timeseries of UODE Error",
color = [3 :orange], xlabel = "t",
label = ["x(t)" "y(t)"],
titlefont = "Helvetica", legendfont = "Helvetica",
legend = :topright)
# Plot the network fit of U2 against the true missing term
p2 = plot3d(X̂[1,:], X̂[2,:], Ŷ[2,:], lw = 3,
title = "Neural Network Fit of U2(t)", color = c1,
label = "Neural Network", xaxis = "x", yaxis="y",
titlefont = "Helvetica", legendfont = "Helvetica",
legend = :bottomright)
plot!(X̂[1,:], X̂[2,:], Ȳ[2,:], lw = 3, label = "True Missing Term", color=c2)
p3 = scatter(solution, color = [c1 c2], label = ["x data" "y data"],
title = "Extrapolated Fit From Short Training Data",
titlefont = "Helvetica", legendfont = "Helvetica",
markersize = 5)
plot!(p3,true_solution_long, color = [c1 c2], linestyle = :dot, lw=5, label = ["True x(t)" "True y(t)"])
plot!(p3,estimate_long, color = [c3 c4], lw=1, label = ["Estimated x(t)" "Estimated y(t)"])
plot!(p3,[2.99,3.01],[0.0,10.0],lw=1,color=:black, label = nothing)
annotate!([(1.5,13,text("Training \nData", 10, :center, :top, :black, "Helvetica"))])
l = @layout [grid(1,2)
             grid(1,1)]
plot(p1,p2,p3,layout = l)
┌ Warning: dt(7.105427357601002e-15) <= dtmin(7.105427357601002e-15) at t=3.6948134503799572. Aborting. There is either an error in your model specification or the true solution is unstable.
└ @ SciMLBase /Users/stevenchiu/.julia/packages/SciMLBase/kTnku/src/integrator_interface.jl:516