Skip to content

Commit 8eec388

Browse files
Merge pull request #254 from JuliaDiffEq/tgrad
add tgrad to ODESystem
2 parents 15e4f80 + 7ba257f commit 8eec388

File tree

3 files changed

+66
-3
lines changed

3 files changed

+66
-3
lines changed

src/systems/diffeqs/diffeqsystem.jl

+33-2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ struct ODESystem <: AbstractSystem
6565
"""Parameter variables."""
6666
ps::Vector{Variable}
6767
"""
68+
Time-derivative matrix. Note: this field will not be defined until
69+
[`calculate_tgrad`](@ref) is called on the system.
70+
"""
71+
tgrad::RefValue{Vector{Expression}}
72+
"""
6873
Jacobian matrix. Note: this field will not be defined until
6974
[`calculate_jacobian`](@ref) is called on the system.
7075
"""
@@ -99,10 +104,11 @@ function ODESystem(eqs)
99104
end
100105

101106
function ODESystem(deqs::AbstractVector{DiffEq}, iv, dvs, ps)
107+
tgrad = RefValue(Vector{Expression}(undef, 0))
102108
jac = RefValue(Matrix{Expression}(undef, 0, 0))
103109
Wfact = RefValue(Matrix{Expression}(undef, 0, 0))
104110
Wfact_t = RefValue(Matrix{Expression}(undef, 0, 0))
105-
ODESystem(deqs, iv, dvs, ps, jac, Wfact, Wfact_t)
111+
ODESystem(deqs, iv, dvs, ps, tgrad, jac, Wfact, Wfact_t)
106112
end
107113

108114
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps)
@@ -133,6 +139,17 @@ independent_variables(sys::ODESystem) = Set{Variable}([sys.iv])
133139
dependent_variables(sys::ODESystem) = Set{Variable}(sys.dvs)
134140
parameters(sys::ODESystem) = Set{Variable}(sys.ps)
135141

142+
function calculate_tgrad(sys::ODESystem)
143+
isempty(sys.tgrad[]) || return sys.tgrad[] # use cached tgrad, if possible
144+
rhs = [detime_dvs(eq.rhs) for eq sys.eqs]
145+
iv = sys.iv()
146+
notime_tgrad = [expand_derivatives(ModelingToolkit.Differential(iv)(r)) for r in rhs]
147+
@show notime_tgrad
148+
tgrad = retime_dvs.(notime_tgrad,(sys.dvs,),iv)
149+
@show tgrad
150+
sys.tgrad[] = tgrad
151+
return tgrad
152+
end
136153

137154
function calculate_jacobian(sys::ODESystem)
138155
isempty(sys.jac[]) || return sys.jac[] # use cached Jacobian, if possible
@@ -160,6 +177,11 @@ function (f::ODEToExpr)(O::Operation)
160177
end
161178
(f::ODEToExpr)(x) = convert(Expr, x)
162179

180+
function generate_tgrad(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
181+
tgrad = calculate_tgrad(sys)
182+
return build_function(tgrad, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
183+
end
184+
163185
function generate_jacobian(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
164186
jac = calculate_jacobian(sys)
165187
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
@@ -218,13 +240,21 @@ are used to set the order of the dependent variable and parameter vectors,
218240
respectively.
219241
"""
220242
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs = sys.dvs, ps = sys.ps;
221-
version = nothing,
243+
version = nothing, tgrad=false,
222244
jac = false, Wfact = false) where {iip}
223245
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false})
224246

225247
f(u,p,t) = f_oop(u,p,t)
226248
f(du,u,p,t) = f_iip(du,u,p,t)
227249

250+
if tgrad
251+
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps, Val{false})
252+
_tgrad(u,p,t) = tgrad_oop(u,p,t)
253+
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
254+
else
255+
_tgrad = nothing
256+
end
257+
228258
if jac
229259
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false})
230260
_jac(u,p,t) = jac_oop(u,p,t)
@@ -246,6 +276,7 @@ function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs = sys.dvs, ps = sys.ps;
246276
end
247277

248278
ODEFunction{iip}(f,jac=_jac,
279+
tgrad = tgrad,
249280
Wfact = _Wfact,
250281
Wfact_t = _Wfact_t,
251282
syms = string.(sys.dvs))

src/utils.jl

+21-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using MacroTools
22

3-
43
function Base.convert(::Type{Expression}, ex::Expr)
54
ex.head === :if && (ex = Expr(:call, ifelse, ex.args...))
65
ex.head === :call || throw(ArgumentError("internal representation does not support non-call Expr"))
@@ -13,6 +12,9 @@ end
1312
Base.convert(::Type{Expression}, x::Expression) = x
1413
Base.convert(::Type{Expression}, x::Number) = Constant(x)
1514
Base.convert(::Type{Expression}, x::Bool) = Constant(x)
15+
Base.convert(::Type{Expression}, x::Variable) = convert(Operation,x)
16+
Base.convert(::Type{Expression}, x::Operation) = x
17+
Base.convert(::Type{Expression}, x::Symbol) = Base.convert(Expression,eval(x))
1618
Expression(x::Bool) = Constant(x)
1719

1820
function build_expr(head::Symbol, args)
@@ -33,6 +35,24 @@ function flatten_expr!(x)
3335
x
3436
end
3537

38+
function detime_dvs(op::Operation)
39+
if op.op isa Variable
40+
Operation(Variable(op.op.name,known=op.op.known),Expression[])
41+
else
42+
Operation(op.op,detime_dvs.(op.args))
43+
end
44+
end
45+
detime_dvs(op::Constant) = op
46+
47+
function retime_dvs(op::Operation,dvs,iv)
48+
if op.op isa Variable && op.op dvs
49+
Operation(Variable(op.op.name),Expression[iv])
50+
else
51+
Operation(op.op,retime_dvs.(op.args,(dvs,),iv))
52+
end
53+
end
54+
retime_dvs(op::Constant,dvs,iv) = op
55+
3656
function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, expression = Val{true};
3757
checkbounds = false, constructor=nothing, linenumbers = true)
3858
_vs = map(x-> x isa Operation ? x.op : x, vs)

test/system_construction.jl

+12
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,18 @@ sol = Sfw_t \ @SArray ones(3)
7474
@test sol isa SArray
7575
@test sol -(I/0.2 - J)\ones(3)
7676

77+
eqs = [D(x) ~ σ*(y-x),
78+
D(y) ~ x*-z)-y*t,
79+
D(z) ~ x*y - β*z]
80+
de = ODESystem(eqs)
81+
ModelingToolkit.calculate_tgrad(de)
82+
83+
tgrad_oop, tgrad_iip = eval.(ModelingToolkit.generate_tgrad(de))
84+
@test tgrad_oop(u,p,t) == [0.0,-u[2],0.0]
85+
du = zeros(3)
86+
tgrad_iip(du,u,p,t)
87+
@test du == [0.0,-u[2],0.0]
88+
7789
@testset "time-varying parameters" begin
7890
@parameters σ′(t-1)
7991
eqs = [D(x) ~ σ′*(y-x),

0 commit comments

Comments
 (0)