Applying Universal Differential Equations for Recovering Unknown Mechanism

Sample code is modified from ChrisRackauckas/universal_differential_equations. It is part of the work of

Rackauckas, Christopher, et al. “Universal differential equations for scientific machine learning.” arXiv preprint arXiv:2001.04385 (2020).

Activate the environment

## Environment and packages
cd(@__DIR__)
using Pkg
Pkg.activate("lotka")
  Activating project at `~/Documents/GitHub/Julia-for-SciML/hands-on/lotka`
@info "Instantiate"
Pkg.instantiate() # This step will take a while for numerous packages
┌ Info: Instantiate
└ @ Main In[2]:1
┌ Warning: The active manifest file is an older format with no julia version entry. Dependencies may have been resolved with a different julia version.
└ @ nothing /Users/stevenchiu/Documents/GitHub/Julia-for-SciML/hands-on/lotka/Manifest.toml:0
] st
Status `~/Documents/GitHub/Julia-for-SciML/hands-on/lotka/Project.toml`
  [c3fe647b] AbstractAlgebra v0.27.5
  [621f4979] AbstractFFTs v1.2.1
  [1520ce14] AbstractTrees v0.4.3
  [7d9f7c33] Accessors v0.1.20
  [79e6a3ab] Adapt v3.4.0
  [dce04be8] ArgCheck v2.3.0
⌅ [ec485272] ArnoldiMethod v0.1.0
  [4fba245c] ArrayInterface v6.0.23
  [30b0a656] ArrayInterfaceCore v0.1.22
  [6ba088a2] ArrayInterfaceGPUArrays v0.2.2
  [015c0d05] ArrayInterfaceOffsetArrays v0.1.6
  [b0d46f97] ArrayInterfaceStaticArrays v0.1.4
  [dd5226c6] ArrayInterfaceStaticArraysCore v0.1.3
  [a2b0951a] ArrayInterfaceTracker v0.1.1
  [4c555306] ArrayLayouts v0.8.12
  [15f4f7f2] AutoHashEquals v0.2.0
  [13072b0f] AxisAlgorithms v1.0.1
⌅ [ab4f0b2a] BFloat16s v0.2.0
  [aae01518] BandedMatrices v0.17.7
  [198e06fe] BangBang v0.3.37
  [9718e550] Baselet v0.1.1
  [e2ed5e7c] Bijections v0.1.4
  [62783981] BitTwiddlingConvenienceFunctions v0.1.4
  [8e7c35d0] BlockArrays v0.16.21
  [ffab5731] BlockBandedMatrices v0.11.9
  [fa961155] CEnum v0.4.2
  [2a0fbf3d] CPUSummary v0.1.27
  [00ebfdb7] CSTParser v3.3.6
  [052768ef] CUDA v3.12.0
  [49dc2e85] Calculus v0.5.1
  [7057c7e9] Cassette v0.3.10
  [082447d4] ChainRules v1.44.7
  [d360d2e6] ChainRulesCore v1.15.6
  [9e997f8a] ChangesOfVariables v0.1.4
  [fb6a15b2] CloseOpenIntervals v0.1.10
  [944b1d66] CodecZlib v0.7.0
  [35d6a980] ColorSchemes v3.19.0
  [3da002f7] ColorTypes v0.11.4
  [c3611d14] ColorVectorSpace v0.9.9
  [5ae59095] Colors v0.12.8
  [861a8166] Combinatorics v1.0.2
  [a80b9123] CommonMark v0.8.6
  [38540f10] CommonSolve v0.2.1
  [bbf7d656] CommonSubexpressions v0.3.0
  [34da2185] Compat v4.3.0
  [b0b7db55] ComponentArrays v0.13.4
  [b152e2b5] CompositeTypes v0.1.2
  [a33af91c] CompositionsBase v0.1.1
  [88cd18e8] ConsoleProgressMonitor v0.1.2
  [187b0558] ConstructionBase v1.4.1
  [6add18c4] ContextVariablesX v0.1.3
  [d38c429a] Contour v0.6.2
  [adafc99b] CpuId v0.3.1
  [a8cc5b0e] Crayons v4.1.1
  [9a962f9c] DataAPI v1.12.0
  [2445eb08] DataDrivenDiffEq v0.8.5
  [82cc6244] DataInterpolations v3.10.1
  [864edb3b] DataStructures v0.18.13
  [e2d170a0] DataValueInterfaces v1.0.0
  [244e2a9f] DefineSingletons v0.1.2
  [b429d917] DensityInterface v0.4.0
  [2b5f629d] DiffEqBase v6.105.1
  [459566f4] DiffEqCallbacks v2.24.2
  [c894b116] DiffEqJump v8.6.3
  [77a26b50] DiffEqNoiseProcess v5.13.1
  [9fdde737] DiffEqOperators v4.43.1
  [41bf760c] DiffEqSensitivity v6.79.0
  [163ba53b] DiffResults v1.1.0
  [b552c78f] DiffRules v1.12.0
  [b4f34e82] Distances v0.10.7
  [31c24e10] Distributions v0.25.76
  [ced4e74d] DistributionsAD v0.6.43
⌅ [ffbed154] DocStringExtensions v0.8.6
  [5b8099bc] DomainSets v0.5.14
  [fa6b7ba4] DualNumbers v0.6.8
  [7c1d4256] DynamicPolynomials v0.4.5
  [da5c29d0] EllipsisNotation v1.6.0
  [7da242da] Enzyme v0.10.12
  [d4d017d3] ExponentialUtilities v1.19.0
  [e2ba6199] ExprTools v0.1.8
  [c87230d0] FFMPEG v0.4.1
  [7a1cc6ca] FFTW v1.5.0
  [cc61a311] FLoops v0.2.1
  [b9860ae5] FLoopsBase v0.1.1
  [7034ab61] FastBroadcast v0.2.1
  [9aa1b823] FastClosures v0.3.2
  [29a986be] FastLapackInterface v1.2.7
  [5789e2e9] FileIO v1.16.0
  [1a297f60] FillArrays v0.13.5
  [6a86dc24] FiniteDiff v2.15.0
  [53c48c17] FixedPointNumbers v0.8.4
  [587475ba] Flux v0.13.6
  [9c68100b] FoldsThreads v0.1.1
  [59287772] Formatting v0.4.2
  [f6369f11] ForwardDiff v0.10.32
  [069b7b12] FunctionWrappers v1.1.3
  [d9f16b24] Functors v0.3.0
  [0c68f7d7] GPUArrays v8.5.0
  [46192b85] GPUArraysCore v0.1.2
  [61eb1bfa] GPUCompiler v0.16.4
  [28b8d3ca] GR v0.69.5
  [a75be94c] GalacticOptim v3.4.0
  [c145ed77] GenericSchur v0.5.3
  [5c1252a2] GeometryBasics v0.4.4
  [af5da776] GlobalSensitivity v2.1.2
  [86223c79] Graphs v1.7.4
  [42e2da0e] Grisu v1.0.2
  [0b43b601] Groebner v0.2.10
  [d5909c97] GroupsCore v0.4.0
  [cd3eb016] HTTP v1.5.0
  [3e5b6fbb] HostCPUFeatures v0.1.8
  [0e44f5e4] Hwloc v2.2.0
  [34004b35] HypergeometricFunctions v0.3.11
  [b5f81e59] IOCapture v0.2.2
  [7869d1d1] IRTools v0.4.7
  [615f187c] IfElse v0.1.1
  [d25df0c9] Inflate v0.1.3
  [83e8ac13] IniFile v0.5.1
  [22cec73e] InitialValues v0.3.1
  [18e54dd8] IntegerMathUtils v0.1.0
  [a98d9a8b] Interpolations v0.14.6
  [8197267c] IntervalSets v0.7.3
  [3587e190] InverseFunctions v0.1.8
  [92d709cd] IrrationalConstants v0.1.1
  [c8e1da08] IterTools v1.4.0
  [42fd0dbc] IterativeSolvers v0.9.2
  [82899510] IteratorInterfaceExtensions v1.0.0
  [033835bb] JLD2 v0.4.25
  [692b3bcd] JLLWrappers v1.4.1
  [682c06a0] JSON v0.21.3
  [98e50ef6] JuliaFormatter v1.0.13
  [b14d175d] JuliaVariables v0.2.4
  [ccbc3e58] JumpProcesses v9.2.0
  [e5e0dc1b] Juno v0.8.4
⌅ [ef3ab10e] KLU v0.3.0
  [5ab0869b] KernelDensity v0.6.5
  [ba0b0d4f] Krylov v0.8.4
  [0b1a1467] KrylovKit v0.5.4
  [929cbde3] LLVM v4.14.0
  [b964fa9f] LaTeXStrings v1.3.0
  [2ee39098] LabelledArrays v1.12.3
  [23fbe1c1] Latexify v0.15.17
  [a5e1c1ea] LatinHypercubeSampling v1.8.0
  [73f95e8e] LatticeRules v0.0.1
  [10f19ff3] LayoutPointers v0.1.11
  [50d2b5c4] Lazy v0.15.1
  [5078a376] LazyArrays v0.22.12
⌅ [d7e5e226] LazyBandedMatrices v0.7.17
  [0fc2ff8b] LeastSquaresOptim v0.8.3
  [1d6d02ad] LeftChildRightSiblingTrees v0.2.0
  [2d8b4e74] LevyArea v1.0.0
  [093fc24a] LightGraphs v1.3.5
  [d3d80556] LineSearches v7.2.0
  [7ed4a6bd] LinearSolve v1.27.0
  [98b081ad] Literate v2.14.0
  [2ab3a3ac] LogExpFunctions v0.3.18
  [e6f89c97] LoggingExtras v0.4.9
  [bdcacae8] LoopVectorization v0.12.136
  [b2108857] Lux v0.4.29
  [d8e11817] MLStyle v0.4.14
  [f1d291b0] MLUtils v0.2.11
  [1914dd2f] MacroTools v0.5.10
  [d125e4d3] ManualMemory v0.1.8
  [a3b82374] MatrixFactorizations v0.9.3
  [739be429] MbedTLS v1.1.6
  [eff96d63] Measurements v2.8.0
  [442fdcdd] Measures v0.3.1
  [e89f7d12] Media v0.5.0
  [c03570c3] Memoize v0.4.4
  [e9d8d322] Metatheory v1.3.5
  [128add7d] MicroCollections v0.1.3
  [e1d29d7a] Missings v1.0.2
  [961ee093] ModelingToolkit v8.29.1
  [46d2c3a1] MuladdMacro v0.2.2
  [102ac46a] MultivariatePolynomials v0.4.6
  [d8a4904e] MutableArithmetics v1.0.5
  [d41bc354] NLSolversBase v7.8.2
  [2774e3e8] NLsolve v4.5.1
  [872c559c] NNlib v0.8.9
  [a00861dc] NNlibCUDA v0.2.4
  [77ba4419] NaNMath v1.0.1
  [71a1bf82] NameResolution v0.1.5
  [8913a72c] NonlinearSolve v0.3.22
  [d8793406] ObjectFile v0.3.7
  [6fe1bfb0] OffsetArrays v1.12.8
  [429524aa] Optim v1.7.3
  [3bd65402] Optimisers v0.2.10
  [7f7a1694] Optimization v3.9.2
  [36348300] OptimizationOptimJL v0.1.3
  [42dfb2eb] OptimizationOptimisers v0.1.0
  [bac558e1] OrderedCollections v1.4.1
  [1dea7af3] OrdinaryDiffEq v6.29.3
  [90014a1f] PDMats v0.11.16
  [d96e819e] Parameters v0.12.3
  [69de0a69] Parsers v2.4.2
  [ccf2f8ad] PlotThemes v3.1.0
  [995b91a9] PlotUtils v1.3.1
  [91a5bcdd] Plots v1.35.4
  [e409e4f3] PoissonRandom v0.4.1
  [f517fe37] Polyester v0.6.16
  [1d0040c9] PolyesterWeave v0.1.10
  [85a6dd25] PositiveFactorizations v0.2.4
  [d236fae5] PreallocationTools v0.4.4
  [21216c6a] Preferences v1.3.0
  [8162dcfd] PrettyPrint v0.2.0
  [27ebfcd6] Primes v0.5.3
  [33c8b6b6] ProgressLogging v0.1.4
  [92933f4c] ProgressMeter v1.7.2
  [1fd47b50] QuadGK v2.5.0
  [8a4e6c94] QuasiMonteCarlo v0.2.14
  [74087812] Random123 v1.6.0
  [fb686558] RandomExtensions v0.4.3
  [e6cf234a] RandomNumbers v1.5.3
  [c84ed2f1] Ratios v0.4.3
  [c1ae055f] RealDot v0.1.0
  [3cdcf5f2] RecipesBase v1.3.1
  [01d81517] RecipesPipeline v0.6.7
  [731186ca] RecursiveArrayTools v2.32.0
  [f2c3362d] RecursiveFactorization v0.2.12
  [189a3867] Reexport v1.2.2
  [42d2dcc6] Referenceables v0.1.2
  [29dad682] RegularizationTools v0.6.0
  [05181044] RelocatableFolders v1.0.0
  [ae029012] Requires v1.3.0
  [ae5879a3] ResettableStacks v1.1.1
  [37e2e3b7] ReverseDiff v1.14.4
  [79098fc4] Rmath v0.7.0
  [7e49a35a] RuntimeGeneratedFunctions v0.5.3
  [3cdde19b] SIMDDualNumbers v0.1.1
  [94e857df] SIMDTypes v0.1.0
  [476501e8] SLEEFPirates v0.6.36
  [1bc83da4] SafeTestsets v0.0.1
  [0bca4576] SciMLBase v1.63.0
  [6c6a2e73] Scratch v1.1.1
  [efcf1570] Setfield v1.1.1
  [605ecd9f] ShowCases v0.1.0
  [992d4aef] Showoff v1.0.3
  [777ac1f9] SimpleBufferStream v1.1.0
  [699a6c99] SimpleTraits v0.9.4
  [ed01d8cd] Sobol v1.5.0
  [a2af1166] SortingAlgorithms v1.0.1
  [47a9eef4] SparseDiffTools v1.27.0
  [276daf66] SpecialFunctions v2.1.7
  [171d559e] SplittablesBase v0.1.15
  [860ef19b] StableRNGs v1.0.0
  [aedffcd0] Static v0.7.7
  [90137ffa] StaticArrays v1.5.9
  [1e83bf80] StaticArraysCore v1.4.0
  [82ae8749] StatsAPI v1.5.0
  [2913bbd2] StatsBase v0.33.21
  [4c63d2b9] StatsFuns v1.0.1
  [789caeaf] StochasticDiffEq v6.54.0
  [7792a7ef] StrideArraysCore v0.3.15
  [09ab397b] StructArrays v0.6.13
  [53d494c1] StructIO v0.3.0
  [d1185830] SymbolicUtils v0.19.11
  [0c5d862f] Symbolics v4.13.0
  [3783bdb8] TableTraits v1.0.1
  [bd369af6] Tables v1.10.0
  [62fd8b95] TensorCore v0.1.1
⌅ [8ea1fca8] TermInterface v0.2.3
  [5d786b92] TerminalLoggers v0.1.6
  [8290d209] ThreadingUtilities v0.5.0
  [ac1d9e8a] ThreadsX v0.1.11
  [a759f4b9] TimerOutputs v0.5.21
  [0796e94c] Tokenize v0.5.24
  [9f7883ad] Tracker v0.2.22
  [3bb67fe8] TranscodingStreams v0.9.9
  [28d57a85] Transducers v0.4.74
  [592b5752] Trapz v2.0.3
  [a2a6695c] TreeViews v0.3.0
  [d5829a12] TriangularSolve v0.1.14
  [5c2747f8] URIs v1.4.0
  [3a884ed6] UnPack v1.0.2
  [d9a01c3f] Underscores v3.0.0
  [1cfade01] UnicodeFun v0.4.1
  [1986cc42] Unitful v1.12.0
  [41fe7b60] Unzip v0.2.0
  [3d5dd08c] VectorizationBase v0.21.54
  [19fa3120] VertexSafeGraphs v0.2.0
  [efce3f68] WoodburyMatrices v0.5.5
  [a5390f91] ZipFile v0.10.0
  [e88e6eb3] Zygote v0.6.49
  [700de1a5] ZygoteRules v0.2.2
  [6e34b625] Bzip2_jll v1.0.8+0
  [83423d85] Cairo_jll v1.16.1+1
  [5ae413db] EarCut_jll v2.2.4+0
  [7cc45869] Enzyme_jll v0.0.43+0
  [2e619515] Expat_jll v2.4.8+0
  [b22a6f82] FFMPEG_jll v4.4.2+2
  [f5851436] FFTW_jll v3.3.10+0
  [a3f928ae] Fontconfig_jll v2.13.93+0
  [d7e528f0] FreeType2_jll v2.10.4+0
  [559328eb] FriBidi_jll v1.0.10+0
  [0656b61e] GLFW_jll v3.3.8+0
  [d2c73de3] GR_jll v0.69.1+0
  [78b55507] Gettext_jll v0.21.0+0
  [7746bdde] Glib_jll v2.74.0+1
  [3b182d85] Graphite2_jll v1.3.14+0
  [2e76f6c2] HarfBuzz_jll v2.8.1+1
  [e33a78d0] Hwloc_jll v2.8.0+1
  [1d5cc7b8] IntelOpenMP_jll v2018.0.3+2
  [aacddb02] JpegTurbo_jll v2.1.2+0
  [c1c5ebd0] LAME_jll v3.100.1+0
  [88015f11] LERC_jll v3.0.0+1
  [dad2f222] LLVMExtra_jll v0.0.16+0
  [dd4b983a] LZO_jll v2.10.1+0
  [dd192d2f] LibVPX_jll v1.10.0+0
  [e9f186c6] Libffi_jll v3.2.2+1
  [d4300ac3] Libgcrypt_jll v1.8.7+0
  [7e76a0d4] Libglvnd_jll v1.3.0+3
  [7add5ba3] Libgpg_error_jll v1.42.0+0
  [94ce4f54] Libiconv_jll v1.16.1+1
  [4b2f31a3] Libmount_jll v2.35.0+0
  [89763e89] Libtiff_jll v4.4.0+0
  [38a345b3] Libuuid_jll v2.36.0+0
  [856f044c] MKL_jll v2022.2.0+0
  [e7412a2a] Ogg_jll v1.3.5+1
  [458c3c95] OpenSSL_jll v1.1.17+0
  [efe28fd5] OpenSpecFun_jll v0.5.5+0
  [91d4177d] Opus_jll v1.3.2+0
  [2f80f16e] PCRE_jll v8.44.0+0
  [30392449] Pixman_jll v0.40.1+0
  [ea2cea3b] Qt5Base_jll v5.15.3+1
  [f50d1b31] Rmath_jll v0.3.0+0
  [a2964d1f] Wayland_jll v1.19.0+0
  [2381bf8a] Wayland_protocols_jll v1.25.0+0
  [02c8fc9c] XML2_jll v2.9.14+0
  [aed1982a] XSLT_jll v1.1.34+0
  [4f6342f7] Xorg_libX11_jll v1.6.9+4
  [0c0b7dd1] Xorg_libXau_jll v1.0.9+4
  [935fb764] Xorg_libXcursor_jll v1.2.0+4
  [a3789734] Xorg_libXdmcp_jll v1.1.3+4
  [1082639a] Xorg_libXext_jll v1.3.4+4
  [d091e8ba] Xorg_libXfixes_jll v5.0.3+4
  [a51aa0fd] Xorg_libXi_jll v1.7.10+4
  [d1454406] Xorg_libXinerama_jll v1.1.4+4
  [ec84b674] Xorg_libXrandr_jll v1.5.2+4
  [ea2f1a96] Xorg_libXrender_jll v0.9.10+4
  [14d82f49] Xorg_libpthread_stubs_jll v0.1.0+3
  [c7cfdc94] Xorg_libxcb_jll v1.13.0+3
  [cc61e674] Xorg_libxkbfile_jll v1.1.0+4
  [12413925] Xorg_xcb_util_image_jll v0.4.0+1
  [2def613f] Xorg_xcb_util_jll v0.4.0+1
  [975044d2] Xorg_xcb_util_keysyms_jll v0.4.0+1
  [0d47668e] Xorg_xcb_util_renderutil_jll v0.3.9+1
  [c22f9ab0] Xorg_xcb_util_wm_jll v0.4.1+1
  [35661453] Xorg_xkbcomp_jll v1.4.2+4
  [33bec58e] Xorg_xkeyboard_config_jll v2.27.0+4
  [c5fb5394] Xorg_xtrans_jll v1.4.0+3
  [3161d3a3] Zstd_jll v1.5.2+0
  [a4ae2306] libaom_jll v3.4.0+0
  [0ac62f75] libass_jll v0.15.1+0
  [f638f0a6] libfdk_aac_jll v2.0.2+0
  [b53b4c65] libpng_jll v1.6.38+0
  [f27f6e37] libvorbis_jll v1.3.7+1
  [1270edf5] x264_jll v2021.5.5+0
  [dfaa095f] x265_jll v3.5.0+0
  [d8fb68d0] xkbcommon_jll v1.4.1+0
  [0dad84c5] ArgTools v1.1.1
  [56f22d72] Artifacts
  [2a0f44e3] Base64
  [ade2ca70] Dates
  [8bb1440f] DelimitedFiles
  [8ba89e20] Distributed
  [f43a241f] Downloads v1.6.0
  [7b1f6079] FileWatching
  [9fa8497b] Future
  [b77e0a4c] InteractiveUtils
  [4af54fe1] LazyArtifacts
  [b27032c2] LibCURL v0.6.3
  [76f85450] LibGit2
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging
  [d6f4376e] Markdown
  [a63ad114] Mmap
  [ca575930] NetworkOptions v1.2.0
  [44cfe95a] Pkg v1.8.0
  [de0858da] Printf
  [9abbd945] Profile
  [3fa0cd96] REPL
  [9a3f8284] Random
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization
  [1a1011a3] SharedArrays
  [6462fe0b] Sockets
  [2f01184e] SparseArrays
  [10745b16] Statistics
  [4607b0f0] SuiteSparse
  [fa267f1f] TOML v1.0.0
  [a4e569a6] Tar v1.10.1
  [8dfed614] Test
  [cf7118a7] UUIDs
  [4ec0a83e] Unicode
  [e66e0078] CompilerSupportLibraries_jll v0.5.2+0
  [deac9b47] LibCURL_jll v7.84.0+0
  [29816b5a] LibSSH2_jll v1.10.2+0
  [c8ffd9c3] MbedTLS_jll v2.28.0+0
  [14a3606d] MozillaCACerts_jll v2022.2.1
  [4536629a] OpenBLAS_jll v0.3.20+0
  [05823500] OpenLibm_jll v0.8.1+0
  [bea87d4a] SuiteSparse_jll v5.10.1+0
  [83775a58] Zlib_jll v1.2.12+3
  [8e850b90] libblastrampoline_jll v5.1.1+0
  [8e850ede] nghttp2_jll v1.48.0+0
  [3f19e933] p7zip_jll v17.4.0+0
Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated`

Import packages

  • Importing multiple packages triggers precompilation, which can take some time.
@info "Precompile"
using OrdinaryDiffEq
using ModelingToolkit
using DataDrivenDiffEq
using LinearAlgebra, ComponentArrays
using Optimization, OptimizationOptimisers, OptimizationOptimJL #OptimizationFlux for ADAM and OptimizationOptimJL for BFGS
using DiffEqSensitivity
using Lux
using Plots
gr()
using Statistics

# Set a random seed for reproducible behaviour
using Random
rng = Random.default_rng()
Random.seed!(1234);
@info "Complete Precompilation"
┌ Info: Precompile
└ @ Main In[4]:1
┌ Info: Precompiling ModelingToolkit [961ee093-0014-501f-94e3-6117800e7a78]
└ @ Base loading.jl:1664
┌ Info: Precompiling DataDrivenDiffEq [2445eb08-9709-466a-b3fc-47e12bd697a2]
└ @ Base loading.jl:1664
┌ Warning: The variable syntax (u[1:n])(t) is deprecated. Use (u(t))[1:n] instead.
│                   The former creates an array of functions, while the latter creates an array valued function.
│                   The deprecated syntax will cause an error in the next major release of Symbolics.
│                   This change will facilitate better implementation of various features of Symbolics.
└ @ Symbolics ~/.julia/packages/Symbolics/FGTCH/src/variable.jl:129
┌ Warning: Type annotations on keyword arguments not currently supported in recipes. Type information has been discarded
└ @ RecipesBase ~/.julia/packages/RecipesBase/6AijY/src/RecipesBase.jl:117
┌ Info: Precompiling Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80]
└ @ Base loading.jl:1664
┌ Info: Complete Precompilation
└ @ Main In[4]:17

Building the ODE model: ground truth

For simplicity, we use the Lotka-Volterra system as an example:

\[\begin{align} \dot{x} &= \alpha x - \beta xy\\ \dot{y} &= \gamma xy- \delta y \end{align}\]

where \(\alpha, \beta, \gamma\), and \(\delta\) are positive real parameters
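Before handing the system to a high-order solver, the dynamics can be sketched with a few forward-Euler steps. This is an illustrative toy only (the notebook itself uses the adaptive `Vern7()` integrator below); `euler_lv` is a hypothetical helper using the same parameters and initial condition as the cells that follow.

```julia
# Minimal forward-Euler sketch of the Lotka-Volterra dynamics.
# Same α, β, γ, δ and u0 as used in the data-generation cells below.
function euler_lv(x, y; dt = 0.001, nsteps = 100)
    α, β, γ, δ = 1.3, 0.9, 0.8, 1.8
    for _ in 1:nsteps
        dx = α*x - β*x*y      # prey: growth minus predation
        dy = γ*x*y - δ*y      # predator: predation gain minus decay
        x, y = x + dt*dx, y + dt*dy
    end
    return x, y
end

# Integrate from u0 to t ≈ 0.1; both states decline initially because
# the predator population starts large.
x1, y1 = euler_lv(0.44249296, 4.6280594)
```

The result roughly tracks `solution(0.1)` computed later with `Vern7()`; Euler with this step size is accurate enough here only because the trajectory is smooth on this interval.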

## Data generation
function lotka!(du, u, p, t)
    α, β, γ, δ = p
    du[1] = α*u[1] - β*u[2]*u[1]
    du[2] = γ*u[1]*u[2]  - δ*u[2]
end
lotka! (generic function with 1 method)
# Define the experimental parameter
tspan = (0.0,3.0)
u0 = [0.44249296,4.6280594]
p_ = [1.3, 0.9, 0.8, 1.8]
4-element Vector{Float64}:
 1.3
 0.9
 0.8
 1.8

Solving the ODE

# Solve
prob = ODEProblem(lotka!, u0,tspan, p_)
solution = solve(prob, Vern7(), abstol=1e-12, reltol=1e-12, saveat = 0.1)
retcode: Success
Interpolation: 1st order linear
t: 31-element Vector{Float64}:
 0.0
 0.1
 0.2
 0.3
 0.4
 0.5
 0.6
 0.7
 0.8
 0.9
 1.0
 1.1
 1.2
 ⋮
 1.9
 2.0
 2.1
 2.2
 2.3
 2.4
 2.5
 2.6
 2.7
 2.8
 2.9
 3.0
u: 31-element Vector{Vector{Float64}}:
 [0.44249296, 4.6280594]
 [0.34212452862086234, 3.98764547181634]
 [0.2793966078254349, 3.4139529441083147]
 [0.2394952228707143, 2.9110318130603883]
 [0.21413620714095402, 2.4758280205419836]
 [0.19854852659179129, 2.1022922430734137]
 [0.18991187927524103, 1.7834096349202704]
 [0.18652973211225643, 1.5121821427640152]
 [0.18737918127509637, 1.2820806846455604]
 [0.1918587411736629, 1.087227597605956]
 [0.1996432344128222, 0.9224424008592909]
 [0.2105985019620811, 0.7832199752377471]
 [0.22473063540355143, 0.6656774980182895]
 ⋮
 [0.4333056937367298, 0.22471175932636067]
 [0.48425346211989406, 0.1947029152564331]
 [0.5425361548950363, 0.16943926722620506]
 [0.6091040110729008, 0.14819092695665834]
 [0.6850407509453579, 0.13034710141497852]
 [0.7715795653361799, 0.11539841080610512]
 [0.8701212001899306, 0.10292221843899205]
 [0.9822541897624152, 0.09257065810821445]
 [1.1097772412678872, 0.08406114122763123]
 [1.2547236687788759, 0.07716924328704818]
 [1.419387582491876, 0.07172402271816045]
 [1.606351205697802, 0.06760604257226555]

Data processing

# Ideal data
X = Array(solution)
t = solution.t
DX = Array(solution(solution.t, Val{1}))

full_problem = DataDrivenProblem(X, t = t, DX = DX)

# Add noise in terms of the mean
x̄ = mean(X, dims = 2)
noise_magnitude = 5e-3
Xₙ = X .+ (noise_magnitude*x̄) .* randn(eltype(X), size(X))
2×31 Matrix{Float64}:
 0.444955  0.344412  0.277873  0.246362  …  1.2542     1.41631    1.60945
 4.6231    3.98748   3.40663   2.91876      0.0810739  0.0772539  0.0672219
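Above, the derivatives `DX` come for free from the solver interpolation. With real measurements that is not available, and a finite difference is a common substitute. A minimal sketch on a known signal, using a hypothetical `central_diff` helper (not part of the notebook's pipeline):

```julia
# Central finite difference on a uniform grid, with one-sided
# differences at the two endpoints.
function central_diff(x::AbstractVector, dt)
    dx = similar(x)
    for i in 2:length(x)-1
        dx[i] = (x[i+1] - x[i-1]) / (2dt)   # O(dt²) accurate interior
    end
    dx[1]   = (x[2] - x[1]) / dt            # O(dt) accurate at the ends
    dx[end] = (x[end] - x[end-1]) / dt
    return dx
end

tgrid = 0:0.1:3
xgrid = sin.(tgrid)
dxdt  = central_diff(xgrid, 0.1)            # ≈ cos.(tgrid) away from the ends
```

Noise amplifies badly under differentiation, which is one reason the UDE approach below fits trajectories directly rather than derivative estimates.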

Data Visualization

plot(solution, alpha = 0.75, color = :black, label = ["True Data" nothing])
scatter!(t, transpose(Xₙ), color = :red, label = ["Noisy Data" nothing])

Build surrogate model

Suppose we only know part of the Lotka-Volterra model, and use a neural network as a surrogate for the unknown part:

\[\begin{align} \dot{x} &= \theta_1 x + U_1(\theta_3, x, y)\\ \dot{y} &= -\theta_2 y + U_2(\theta_3, x, y) \end{align}\]

## Define the network
# Gaussian RBF as activation
rbf(x) = exp.(-(x.^2))

# Multilayer FeedForward
U = Lux.Chain(
    Lux.Dense(2,5,rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,5, rbf), Lux.Dense(5,2)
)
# Get the initial parameters and state variables of the model
p, st = Lux.setup(rng, U)

# Define the hybrid model
function ude_dynamics!(du,u, p, t, p_true)
    û = U(u, p, st)[1] # Network prediction
    du[1] = p_true[1]*u[1] + û[1]
    du[2] = -p_true[4]*u[2] + û[2]
end
ude_dynamics! (generic function with 1 method)
# Closure with the known parameter
nn_dynamics!(du,u,p,t) = ude_dynamics!(du,u,p,t,p_)
# Define the problem (Fix: https://discourse.julialang.org/t/issue-with-ude-repository-lv-scenario-1/88618/5)
prob_nn = ODEProblem{true, SciMLBase.FullSpecialize}(nn_dynamics!,Xₙ[:, 1], tspan, p)
#prob_nn = ODEProblem(nn_dynamics!,Xₙ[:, 1], tspan, p)
ODEProblem with uType Vector{Float64} and tType Float64. In-place: true
timespan: (0.0, 3.0)
u0: 2-element Vector{Float64}:
 0.44495468189157616
 4.623098367786485

Training Setting

## Function to train the network
# Define a predictor
function predict(θ, X = Xₙ[:,1], T = t)
    _prob = remake(prob_nn, u0 = X, tspan = (T[1], T[end]), p = θ)
    Array(solve(_prob, Vern7(), saveat = T,
                abstol=1e-6, reltol=1e-6,
                sensealg = ForwardDiffSensitivity()
                ))
end

# Simple L2 loss
function loss(θ)
    X̂ = predict(θ)
    sum(abs2, Xₙ .- X̂)
end

# Container to track the losses
losses = Float64[]

callback = function (p, l)
  push!(losses, l)
  if length(losses)%50==0
      println("Current loss after $(length(losses)) iterations: $(losses[end])")
  end
  return false
end
#1 (generic function with 1 method)

Training

The training is split into two steps: 1. ADAM, for robust initial convergence; 2. BFGS, to refine toward the minimum.
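The idea behind the two stages can be sketched on a toy 1-D objective: a first-order method (plain gradient descent here, standing in for ADAM) moves cheaply into the right basin, then a second-order Newton step (standing in for BFGS) converges rapidly near the minimum. This is an illustration of the strategy, not the notebook's actual optimizers.

```julia
# Toy objective with a single minimum near x ≈ 1.42.
f(x)   = (x - 2.0)^2 + 0.1x^4
df(x)  = 2(x - 2.0) + 0.4x^3
d2f(x) = 2.0 + 1.2x^2

function two_stage(x0)
    x = x0
    for _ in 1:50               # stage 1: gradient descent (coarse, robust)
        x -= 0.05 * df(x)
    end
    for _ in 1:5                # stage 2: Newton steps (fast local refinement)
        x -= df(x) / d2f(x)
    end
    return x
end

xmin = two_stage(5.0)
```

Gradient descent alone would stall with a slowly shrinking gradient; the curvature-aware second stage drives the gradient to machine precision in a handful of iterations, mirroring why BFGS is used to polish the ADAM result.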

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(p))
res1 = Optimization.solve(optprob, ADAM(0.1), callback=callback, maxiters = 200)

@info "Training loss after $(length(losses)) iterations: $(losses[end])"
# Train with BFGS
@time optprob2 = Optimization.OptimizationProblem(optf, res1.minimizer)
@time res2 = Optimization.solve(optprob2, Optim.BFGS(initial_stepnorm=0.01), callback=callback, maxiters = 3000)
@info "Final training loss after $(length(losses)) iterations: $(losses[end])"
Current loss after 50 iterations: 3.2919945635723753
Current loss after 100 iterations: 1.7058558916650455
Current loss after 150 iterations: 1.6697049368588788
Current loss after 200 iterations: 1.6423084245679131
  0.002839 seconds (946 allocations: 59.323 KiB, 97.18% compilation time)
┌ Info: Training loss after 201 iterations: 1.6423084245679131
└ @ Main In[13]:6
Current loss after 250 iterations: 0.023661287470332845
Current loss after 300 iterations: 0.013746817618177573
Current loss after 350 iterations: 0.0032781193503113072
Current loss after 400 iterations: 0.0017736994450182483
Current loss after 450 iterations: 0.0016452061747373918
Current loss after 500 iterations: 0.001415755561505311
Current loss after 550 iterations: 0.001226557886040906
Current loss after 600 iterations: 0.001107089612914693
Current loss after 650 iterations: 0.0009982934123468784
Current loss after 700 iterations: 0.0009875749757684356
Current loss after 750 iterations: 0.0009812594981828098
Current loss after 800 iterations: 0.0009796658875397177
Current loss after 850 iterations: 0.0009793537523280391
Current loss after 900 iterations: 0.0009772977493860333
Current loss after 950 iterations: 0.0009743301969273224
Current loss after 1000 iterations: 0.0009727084653298941
Current loss after 1050 iterations: 0.000972080483835656
Current loss after 1100 iterations: 0.0009717534715880698
Current loss after 1150 iterations: 0.0009656859249757323
Current loss after 1200 iterations: 0.000962561381268262
Current loss after 1250 iterations: 0.0009612859817110582
Current loss after 1300 iterations: 0.0009591475053966475
Current loss after 1350 iterations: 0.0009582234874165839
Current loss after 1400 iterations: 0.0009574676566527876
Current loss after 1450 iterations: 0.0009569122026897299
Current loss after 1500 iterations: 0.0009567527113130108
Current loss after 1550 iterations: 0.0009561199657093553
Current loss after 1600 iterations: 0.0009549884254749657
Current loss after 1650 iterations: 0.0009541600089060593
Current loss after 1700 iterations: 0.0009537452602046372
Current loss after 1750 iterations: 0.0009536569492935126
Current loss after 1800 iterations: 0.0009534426125349452
Current loss after 1850 iterations: 0.0009532818826279794
Current loss after 1900 iterations: 0.0009530863107623305
Current loss after 1950 iterations: 0.0009529698205412621
Current loss after 2000 iterations: 0.0009528567762890976
Current loss after 2050 iterations: 0.0009525373206446654
Current loss after 2100 iterations: 0.0009521593217738839
Current loss after 2150 iterations: 0.0009520958641338131
Current loss after 2200 iterations: 0.0009520568498258301
Current loss after 2250 iterations: 0.0009517387568709354
Current loss after 2300 iterations: 0.0009516121690135998
Current loss after 2350 iterations: 0.0009514360660860543
Current loss after 2400 iterations: 0.0009512526383059336
Current loss after 2450 iterations: 0.0009511495748527354
Current loss after 2500 iterations: 0.000951106817478341
Current loss after 2550 iterations: 0.0009510985159145336
Current loss after 2600 iterations: 0.0009509270459168483
Current loss after 2650 iterations: 0.000950863437522074
Current loss after 2700 iterations: 0.0009507791033767345
Current loss after 2750 iterations: 0.000950657136744276
Current loss after 2800 iterations: 0.0009506293508157388
Current loss after 2850 iterations: 0.0009506101911525282
Current loss after 2900 iterations: 0.0009504534699507922
Current loss after 2950 iterations: 0.0009504069207206256
Current loss after 3000 iterations: 0.0009503383212540466
Current loss after 3050 iterations: 0.0009503196958639988
Current loss after 3100 iterations: 0.0009502399212363976
Current loss after 3150 iterations: 0.0009501939332356801
Current loss after 3200 iterations: 0.0009501728013450467
980.071985 seconds (629.24 M allocations: 116.047 GiB, 1.16% gc time, 0.19% compilation time)
┌ Info: Final training loss after 3202 iterations: 0.0009501727548593611
└ @ Main In[13]:10

Visualize loss

# Plot the losses
pl_losses = plot(1:200, losses[1:200], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "ADAM", color = :blue)
plot!(201:length(losses), losses[201:end], yaxis = :log10, xaxis = :log10, xlabel = "Iterations", ylabel = "Loss", label = "BFGS", color = :red)

# Rename the best candidate
p_trained = res2.minimizer;
## Analysis of the trained network
# Plot the data and the approximation
ts = first(solution.t):mean(diff(solution.t))/2:last(solution.t)
X̂ = predict(p_trained, Xₙ[:,1], ts)
# Trained on noisy data vs real solution
pl_trajectory = plot(ts, transpose(X̂), xlabel = "t", ylabel ="x(t), y(t)", color = :red, label = ["UDE Approximation" nothing])
scatter!(solution.t, transpose(Xₙ), color = :black, label = ["Measurements" nothing])

# Ideal unknown interactions of the predictor
Ȳ = [-p_[2]*(X̂[1,:].*X̂[2,:])';p_[3]*(X̂[1,:].*X̂[2,:])']
# Neural network guess
Ŷ = U(X̂,p_trained,st)[1]

pl_reconstruction = plot(ts, transpose(Ŷ), xlabel = "t", ylabel ="U(x,y)", color = :red, label = ["UDE Approximation" nothing])
plot!(ts, transpose(Ȳ), color = :black, label = ["True Interaction" nothing])

# Plot the error
pl_reconstruction_error = plot(ts, norm.(eachcol(Ȳ-Ŷ)), yaxis = :log, xlabel = "t", ylabel = "L2-Error", label = nothing, color = :red)
pl_missing = plot(pl_reconstruction, pl_reconstruction_error, layout = (2,1))

pl_overall = plot(pl_trajectory, pl_missing)
## Symbolic regression via sparse regression (SINDy-based)

Symbolic Regression
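The STLSQ optimizer used below alternates a least-squares solve with hard thresholding of small coefficients. A minimal self-contained sketch of that core loop, on toy data (this is not the DataDrivenDiffEq implementation; `stlsq` and its arguments are illustrative):

```julia
# Sequentially thresholded least squares: fit, zero out coefficients
# below λ, refit on the surviving columns, repeat.
function stlsq(Θ, y; λ = 0.1, maxiter = 10)
    ξ = Θ \ y                          # initial dense least-squares fit
    for _ in 1:maxiter
        small = abs.(ξ) .< λ           # terms to prune
        ξ[small] .= 0.0
        big = .!small
        ξ[big] .= Θ[:, big] \ y        # refit only the active terms
    end
    return ξ
end

# Toy library: columns [1, x, x², x³] on a grid; the target is 2x²
# plus tiny noise, so only the x² column should survive.
xs = range(-1, 1, length = 50)
Θ  = [xs.^0 xs.^1 xs.^2 xs.^3]
y  = 2 .* xs.^2 .+ 1e-4 .* randn(length(xs))
ξ  = stlsq(Θ, y)
```

The threshold λ trades sparsity against fit quality, which is why the notebook searches over a whole grid of thresholds (`exp10.(-3:0.01:5)`) rather than fixing one.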

# Create a Basis
@variables u[1:2]
# Generate the basis functions, multivariate polynomials up to deg 5
# and sine
b = [polynomial_basis(u, 5); sin.(u)]
basis = Basis(b,u);

# Create the thresholds which should be used in the search process
λ = exp10.(-3:0.01:5)
# Create an optimizer for the SINDy problem
opt = STLSQ(λ)

# Define different problems for the recovery
ideal_problem = DirectDataDrivenProblem(X̂, Ȳ)
nn_problem = DirectDataDrivenProblem(X̂, Ŷ)

# Test on ideal derivative data for the unknown function (not available in practice)
println("Sparse regression")
full_res = solve(full_problem, basis, opt, maxiter = 10000, progress = true)
ideal_res = solve(ideal_problem, basis, opt, maxiter = 10000, progress = true)
nn_res = solve(nn_problem, basis, opt, maxiter = 10000, progress = true, sampler = DataSampler(Batcher(n = 4, shuffle = true)))


# Store the results
results = [full_res; ideal_res; nn_res]
# Show the results
map(println, results)
# Show the resulting equations
map(println ∘ result, results)
# Show the identified parameters
map(println ∘ parameter_map, results)

# Define the recovered, hybrid model
function recovered_dynamics!(du, u, p, t)
    û = nn_res(u, p) # Recovered model prediction
    du[1] = p_[1]*u[1] + û[1]
    du[2] = -p_[4]*u[2] + û[2]
end


estimation_prob = ODEProblem(recovered_dynamics!, u0, tspan, parameters(nn_res))
estimate = solve(estimation_prob, Tsit5(), saveat = solution.t)
Sparse regression
STLSQ   0%|▏                                             |  ETA: 0:04:27
  Threshold:          0.0010232929922807535
  Best Objective:     0.0
  Best Sparsity:      23.0
  Current Objective:  0.0
STLSQ   0%|▏                                             |  ETA: 0:02:41
  Threshold:          0.0010232929922807535
  Best Objective:     0.0
  Best Sparsity:      23.0
  Current Objective:  0.0
  Current Sparsity:   23.0
Linear Solution with 2 equations and 20 parameters.
Returncode: solved
L₂ Norm error : [31.995291148539735, 1.2046710278865183]
AIC : [147.43325095127574, 45.77240223675897]
R² : [-1.2429510122420595, 0.990118386736816]

Linear Solution with 2 equations and 2 parameters.
Returncode: solved
L₂ Norm error : [8.108165065870263e-32, 1.3731880503480992e-31]
AIC : [-4362.980934730921, -4330.843170940274]
R² : [1.0, 1.0]

Linear Solution with 2 equations and 3 parameters.
Returncode: solved
L₂ Norm error : [1.6430538154747572, 8.039350195270902]
AIC : [36.28995216787681, 133.14524376489635]
R² : [0.7906825418979999, -0.31134160491819185]

Model ##Basis#629 with 2 equations
States : u[1] u[2]
Parameters : 20
Independent variable: t
Equations
Differential(t)(u[1]) = p₁ + p₁₀*(u[2]^2) + p₃*(u[1]^2) + p₁₇*sin(u[1]) + p₂*u[1] + p₄*(u[1]^3) + p₅*u[2] + p₁₂*(u[1]^2)*(u[2]^2) + p₁₅*(u[1]^2)*(u[2]^3) + p₁₃*(u[1]^3)*(u[2]^2) + p₁₁*(u[2]^2)*u[1] + p₁₄*(u[2]^3)*u[1] + p₇*(u[1]^2)*u[2] + p₈*(u[1]^3)*u[2] + p₉*(u[1]^4)*u[2] + p₁₆*(u[2]^4)*u[1] + p₆*u[1]*u[2]
Differential(t)(u[2]) = p₁₉*(u[1]^2)*u[2] + p₂₀*(u[1]^3)*u[2] + p₁₈*u[1]*u[2]
Model ##Basis#632 with 2 equations
States : u[1] u[2]
Parameters : p₁ p₂
Independent variable: t
Equations
φ₁ = p₁*u[1]*u[2]
φ₂ = p₂*u[1]*u[2]
Model ##Basis#635 with 2 equations
States : u[1] u[2]
Parameters : p₁ p₂ p₃
Independent variable: t
Equations
φ₁ = p₁*(u[1]^2)*u[2]
φ₂ = p₃*sin(u[1]) + p₂*u[1]
Pair{Sym{Real, Base.ImmutableDict{DataType, Any}}, Float64}[p₁ => 88.0, p₂ => 90.1, p₃ => 45.4, p₄ => 27.6, p₅ => 73.5, p₆ => -1107.4, p₇ => -2835.6, p₈ => 27.9, p₉ => 25.07, p₁₀ => 16.9, p₁₁ => -472.5, p₁₂ => 6115.8, p₁₃ => -117.4, p₁₄ => 22.016, p₁₅ => -659.7, p₁₆ => -25.4, p₁₇ => 62.5, p₁₈ => -13.9, p₁₉ => 31.25, p₂₀ => -15.5]
Pair{Sym{Real, Base.ImmutableDict{DataType, Any}}, Float64}[p₁ => -0.9, p₂ => 0.8]
Pair{Sym{Real, Base.ImmutableDict{DataType, Any}}, Float64}[p₁ => -2.012, p₂ => -2.25, p₃ => 3.3]
retcode: Success
Interpolation: 1st order linear
t: 31-element Vector{Float64}:
 0.0
 0.1
 0.2
 0.3
 0.4
 0.5
 0.6
 0.7
 0.8
 0.9
 1.0
 1.1
 1.2
 ⋮
 1.9
 2.0
 2.1
 2.2
 2.3
 2.4
 2.5
 2.6
 2.7
 2.8
 2.9
 3.0
u: 31-element Vector{Vector{Float64}}:
 [0.44249296, 4.6280594]
 [0.3589998389049996, 3.90043838799414]
 [0.32039236584298314, 3.288366022606705]
 [0.30209737580427426, 2.7749501847648586]
 [0.29519534196248753, 2.345125591932463]
 [0.2957319262240388, 1.9858702305186864]
 [0.30173746136416507, 1.6860822316973967]
 [0.3121621269423411, 1.4363640495034167]
 [0.3264408593862669, 1.228791374981506]
 [0.3442811760621213, 1.05671475218947]
 [0.3655278350988628, 0.9145062682055913]
 [0.3901432875306497, 0.7974760103603659]
 [0.41813107826611684, 0.7016575070929881]
 ⋮
 [0.7074222345154149, 0.3968840102725108]
 [0.7606094385047732, 0.3826360576519005]
 [0.8159667970667692, 0.3714306439202008]
 [0.8731544682099793, 0.3621229881052103]
 [0.931833648528594, 0.35361775010135293]
 [0.9918281148809143, 0.34494926059255426]
 [1.053178574026386, 0.33524417437964893]
 [1.1161906808048632, 0.3236127866772301]
 [1.1815090141176419, 0.3091482705104759]
 [1.250273983171013, 0.2909351024621833]
 [1.3243227902154364, 0.26776728987299037]
 [1.4068247104939124, 0.23796649830740021]

Visualization

# Plot
plot(solution)
plot!(estimate)

## Simulation

# Look at long term prediction
t_long = (0.0, 50.0)
estimation_prob = ODEProblem(recovered_dynamics!, u0, t_long, parameters(nn_res))
estimate_long = solve(estimation_prob, Tsit5()) # Using higher tolerances here results in exit of julia
plot(estimate_long)

true_prob = ODEProblem(lotka!, u0, t_long, p_)
true_solution_long = solve(true_prob, Tsit5(), saveat = estimate_long.t)
plot!(true_solution_long)



## Post Processing and Plots

c1 = 3 # RGBA(174/255,192/255,201/255,1) # Maroon
c2 = :orange # RGBA(132/255,159/255,173/255,1) # Red
c3 = :blue # RGBA(255/255,90/255,0,1) # Orange
c4 = :purple # RGBA(153/255,50/255,204/255,1) # Purple

p1 = plot(t,abs.(Array(solution) .- estimate)' .+ eps(Float32),
          lw = 3, yaxis = :log, title = "Timeseries of UODE Error",
          color = [3 :orange], xlabel = "t",
          label = ["x(t)" "y(t)"],
          titlefont = "Helvetica", legendfont = "Helvetica",
          legend = :topright)

# Plot L₂
p2 = plot3d(X̂[1,:], X̂[2,:], Ŷ[2,:], lw = 3,
     title = "Neural Network Fit of U2(t)", color = c1,
     label = "Neural Network", xaxis = "x", yaxis="y",
     titlefont = "Helvetica", legendfont = "Helvetica",
     legend = :bottomright)
plot!(X̂[1,:], X̂[2,:], Ȳ[2,:], lw = 3, label = "True Missing Term", color=c2)

p3 = scatter(solution, color = [c1 c2], label = ["x data" "y data"],
             title = "Extrapolated Fit From Short Training Data",
             titlefont = "Helvetica", legendfont = "Helvetica",
             markersize = 5)

plot!(p3,true_solution_long, color = [c1 c2], linestyle = :dot, lw=5, label = ["True x(t)" "True y(t)"])
plot!(p3,estimate_long, color = [c3 c4], lw=1, label = ["Estimated x(t)" "Estimated y(t)"])
plot!(p3,[2.99,3.01],[0.0,10.0],lw=1,color=:black, label = nothing)
annotate!([(1.5,13,text("Training \nData", 10, :center, :top, :black, "Helvetica"))])
l = @layout [grid(1,2)
             grid(1,1)]
plot(p1,p2,p3,layout = l)
┌ Warning: dt(7.105427357601002e-15) <= dtmin(7.105427357601002e-15) at t=3.6948134503799572. Aborting. There is either an error in your model specification or the true solution is unstable.
└ @ SciMLBase /Users/stevenchiu/.julia/packages/SciMLBase/kTnku/src/integrator_interface.jl:516