In [1]:
# import Pkg; Pkg.add("Flux")
using OrdinaryDiffEq, RecursiveArrayTools, LinearAlgebra, Test, SparseArrays, SparseDiffTools, Sundials
using Plots; pyplot()
using BenchmarkTools
using DifferentialEquations
using Flux, DiffEqFlux, Optim

Optimizing the parameters of an ODE (the Lotka–Volterra system) so that its solution fits target data — a parameter-estimation form of an optimal control problem

In [21]:
"""
    lotka_volterra(du, u, p, t)

In-place right-hand side of the Lotka–Volterra predator–prey system, in the
`f(du, u, p, t)` form passed to `ODEProblem`.

With state `u = [x, y]` and parameters `p = [α, β, δ, γ]`:

    dx/dt =  α*x - β*x*y
    dy/dt = -δ*y + γ*x*y

The derivatives are written into `du`. `t` is unused because the system is
autonomous. Returns `nothing`, since the result is communicated only through `du`.
"""
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end
# Reference problem: initial populations, time span and the parameter vector
# p = [α, β, δ, γ] consumed by `lotka_volterra`.
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra, u0, tspan, p)
sol = solve(prob)  # default (automatically selected) solver
# Sample times 0.0, 0.1, …, 10.0 (101 points). Solutions saved later with
# `saveat = 0.1` line up with this grid.
t = first(tspan):0.1:last(tspan)
Out[21]:
0.0:0.1:10.0
In [22]:
# Build the "ground-truth" problem, with the initial condition and time span
# passed to `ODEProblem` as functions of the parameters. Then record the first
# state component (prey x) as the target data `A` for the fit.

# Initial condition read from the parameter vector: u0 = [p[2], p[4]].
u0_f(p, t0) = [p[2], p[4]]
# Fixed time span as a homogeneous Float64 tuple. There is a single definition,
# so no earlier method with the same signature is silently overwritten.
tspan_f(p) = (0.0, 10.0)
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra, u0_f, tspan_f, p)
sol = solve(prob, Tsit5(), saveat = 0.1);
A = sol[1, :]; # x(t) at t = 0.0:0.1:10.0 -> length-101 vector
plot(sol)
scatter!(t, A)
Out[22]:
In [23]:
# Initial parameter vector [α, β, δ, γ] for the fit. It differs from the values
# used to generate `A`, so the plot shows its solution against the target data.
p = [2.2, 1.0, 2.0, 0.4]
prob = ODEProblem(lotka_volterra, u0, tspan, p)
sol = solve(prob, Tsit5(); saveat = 0.1);
plot(sol)
scatter!(t, A)
Out[23]:
In [24]:
# Parameters tracked by Flux for training: the array `p` is what gets optimized.
params = Flux.params(p)

"""
    predict_rd()

Solve `prob` with the current global `u0` and `p` using `Tsit5`, saving every
0.1 time units, and return the first state component (prey `x`) at the saved
times.

This is an ODE solve, not a neural network. It reads the globals `prob`, `u0`
and `p`. `concrete_solve` takes `u0` and `p` explicitly, presumably so that
Flux can differentiate the solve with respect to them during training.
"""
function predict_rd()
    return concrete_solve(prob, Tsit5(), u0, p, saveat = 0.1)[1, :]
end

"""
    loss_rd1()

Sum of squared errors between the target data `A` and the predicted prey
trajectory `predict_rd()`.
"""
loss_rd1() = sum(abs2, A .- predict_rd())
Out[24]:
loss_rd1 (generic function with 1 method)
In [25]:
# Training setup:
# - 100 iterations with empty per-step data (the loss reads globals only);
# - ADAM with learning rate 0.1;
# - an animation that collects one frame per callback call.
data = Iterators.repeated((), 100)
opt = ADAM(0.1)
anim = Animation()

# Callback for `Flux.train!`. It re-solves the problem with the current
# parameters `p`, plots the solution together with the target data `A`, and
# appends the plot to `anim` as a frame.
cb = function ()
    # `remake` re-creates `prob` with the current parameter values `p`.
    current = solve(remake(prob, p = p), Tsit5(), saveat = 0.1)
    # The fixed y-limits keep every frame on the same axis scale.
    plot(current, ylim = (0, 10))
    scatter!(t, A)
    frame(anim)
    return nothing
end

# Record the ODE with the initial parameter values as the first frame.
cb()

Flux.train!(loss_rd1, params, data, opt, cb = cb)
In [27]:
# Write the recorded training frames to a GIF at 10 frames per second.
gif(anim, "lotka_volterra.gif"; fps = 10)
┌ Info: Saved animation to 
│   fn = C:\Users\amira\Dropbox (MIT)\CBA\_Classes\4_Spring 2020\NMM\01_Code\project\lotka_volterra.gif
└ @ Plots C:\Users\amira\.julia\packages\Plots\cc8wh\src\animation.jl:98
Out[27]:
In [ ]: