Skip to content

Commit 2e31d7a

Browse files
add tgrad to ODESystem
1 parent 15e4f80 commit 2e31d7a

File tree

2 files changed

+52
-3
lines changed

2 files changed

+52
-3
lines changed

src/systems/diffeqs/diffeqsystem.jl

+31-2
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ struct ODESystem <: AbstractSystem
6565
"""Parameter variables."""
6666
ps::Vector{Variable}
6767
"""
68+
Time-derivative matrix. Note: this field will not be defined until
69+
[`calculate_tgrad`](@ref) is called on the system.
70+
"""
71+
tgrad::RefValue{Vector{Expression}}
72+
"""
6873
Jacobian matrix. Note: this field will not be defined until
6974
[`calculate_jacobian`](@ref) is called on the system.
7075
"""
@@ -99,10 +104,11 @@ function ODESystem(eqs)
99104
end
100105

101106
function ODESystem(deqs::AbstractVector{DiffEq}, iv, dvs, ps)
107+
tgrad = RefValue(Vector{Expression}(undef, 0))
102108
jac = RefValue(Matrix{Expression}(undef, 0, 0))
103109
Wfact = RefValue(Matrix{Expression}(undef, 0, 0))
104110
Wfact_t = RefValue(Matrix{Expression}(undef, 0, 0))
105-
ODESystem(deqs, iv, dvs, ps, jac, Wfact, Wfact_t)
111+
ODESystem(deqs, iv, dvs, ps, tgrad, jac, Wfact, Wfact_t)
106112
end
107113

108114
function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps)
@@ -133,6 +139,15 @@ independent_variables(sys::ODESystem) = Set{Variable}([sys.iv])
133139
dependent_variables(sys::ODESystem) = Set{Variable}(sys.dvs)
134140
parameters(sys::ODESystem) = Set{Variable}(sys.ps)
135141

142+
function calculate_tgrad(sys::ODESystem)
143+
isempty(sys.tgrad[]) || return sys.tgrad[] # use cached tgrad, if possible
144+
rhs = [detime_dvs(eq.rhs) for eq sys.eqs]
145+
iv = sys.iv()
146+
notime_tgrad = [expand_derivatives(ModelingToolkit.Differential(iv)(r)) for r in rhs]
147+
tgrad = retime_dvs.(notime_tgrad,(sys.dvs,),iv)
148+
sys.tgrad[] = tgrad
149+
return tgrad
150+
end
136151

137152
function calculate_jacobian(sys::ODESystem)
138153
isempty(sys.jac[]) || return sys.jac[] # use cached Jacobian, if possible
@@ -160,6 +175,11 @@ function (f::ODEToExpr)(O::Operation)
160175
end
161176
(f::ODEToExpr)(x) = convert(Expr, x)
162177

178+
function generate_tgrad(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
179+
tgrad = calculate_tgrad(sys)
180+
return build_function(tgrad, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
181+
end
182+
163183
function generate_jacobian(sys::ODESystem, dvs = sys.dvs, ps = sys.ps, expression = Val{true}; kwargs...)
164184
jac = calculate_jacobian(sys)
165185
return build_function(jac, dvs, ps, (sys.iv.name,), ODEToExpr(sys), expression; kwargs...)
@@ -218,13 +238,21 @@ are used to set the order of the dependent variable and parameter vectors,
218238
respectively.
219239
"""
220240
function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs = sys.dvs, ps = sys.ps;
221-
version = nothing,
241+
version = nothing, tgrad=false,
222242
jac = false, Wfact = false) where {iip}
223243
f_oop,f_iip = generate_function(sys, dvs, ps, Val{false})
224244

225245
f(u,p,t) = f_oop(u,p,t)
226246
f(du,u,p,t) = f_iip(du,u,p,t)
227247

248+
if tgrad
249+
tgrad_oop,tgrad_iip = generate_tgrad(sys, dvs, ps, Val{false})
250+
_tgrad(u,p,t) = tgrad_oop(u,p,t)
251+
_tgrad(J,u,p,t) = tgrad_iip(J,u,p,t)
252+
else
253+
_tgrad = nothing
254+
end
255+
228256
if jac
229257
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps, Val{false})
230258
_jac(u,p,t) = jac_oop(u,p,t)
@@ -246,6 +274,7 @@ function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs = sys.dvs, ps = sys.ps;
246274
end
247275

248276
ODEFunction{iip}(f,jac=_jac,
277+
tgrad = tgrad,
249278
Wfact = _Wfact,
250279
Wfact_t = _Wfact_t,
251280
syms = string.(sys.dvs))

src/utils.jl

+21-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using MacroTools
22

3-
43
function Base.convert(::Type{Expression}, ex::Expr)
54
ex.head === :if && (ex = Expr(:call, ifelse, ex.args...))
65
ex.head === :call || throw(ArgumentError("internal representation does not support non-call Expr"))
@@ -13,6 +12,9 @@ end
1312
Base.convert(::Type{Expression}, x::Expression) = x
1413
Base.convert(::Type{Expression}, x::Number) = Constant(x)
1514
Base.convert(::Type{Expression}, x::Bool) = Constant(x)
15+
Base.convert(::Type{Expression}, x::Variable) = convert(Operation,x)
16+
Base.convert(::Type{Expression}, x::Operation) = x
17+
Base.convert(::Type{Expression}, x::Symbol) = Base.convert(Expression,eval(x))
1618
Expression(x::Bool) = Constant(x)
1719

1820
function build_expr(head::Symbol, args)
@@ -33,6 +35,24 @@ function flatten_expr!(x)
3335
x
3436
end
3537

38+
function detime_dvs(op::Operation)
39+
if op.op isa Variable
40+
Operation(Variable(op.op.name),Expression[])
41+
else
42+
Operation(op.op,detime_dvs.(op.args))
43+
end
44+
end
45+
detime_dvs(op::Constant) = op
46+
47+
function retime_dvs(op::Operation,dvs,iv)
48+
if op.op isa Variable && op.op dvs
49+
Operation(Variable(op.op.name),Expression[iv])
50+
else
51+
Operation(op.op,retime_dvs.(op.args,(dvs,),iv))
52+
end
53+
end
54+
retime_dvs(op::Constant,dvs,iv) = op
55+
3656
function build_function(rhss, vs, ps = (), args = (), conv = simplified_expr, expression = Val{true};
3757
checkbounds = false, constructor=nothing, linenumbers = true)
3858
_vs = map(x-> x isa Operation ? x.op : x, vs)

0 commit comments

Comments
 (0)