@@ -65,6 +65,11 @@ struct ODESystem <: AbstractSystem
65
65
""" Parameter variables."""
66
66
ps:: Vector{Variable}
67
67
"""
68
+ Time-derivative matrix. Note: this field will not be defined until
69
+ [`calculate_tgrad`](@ref) is called on the system.
70
+ """
71
+ tgrad:: RefValue{Vector{Expression}}
72
+ """
68
73
Jacobian matrix. Note: this field will not be defined until
69
74
[`calculate_jacobian`](@ref) is called on the system.
70
75
"""
@@ -99,10 +104,11 @@ function ODESystem(eqs)
99
104
end
100
105
101
106
function ODESystem (deqs:: AbstractVector{DiffEq} , iv, dvs, ps)
107
+ tgrad = RefValue (Vector {Expression} (undef, 0 ))
102
108
jac = RefValue (Matrix {Expression} (undef, 0 , 0 ))
103
109
Wfact = RefValue (Matrix {Expression} (undef, 0 , 0 ))
104
110
Wfact_t = RefValue (Matrix {Expression} (undef, 0 , 0 ))
105
- ODESystem (deqs, iv, dvs, ps, jac, Wfact, Wfact_t)
111
+ ODESystem (deqs, iv, dvs, ps, tgrad, jac, Wfact, Wfact_t)
106
112
end
107
113
108
114
function ODESystem (deqs:: AbstractVector{<:Equation} , iv, dvs, ps)
@@ -133,6 +139,15 @@ independent_variables(sys::ODESystem) = Set{Variable}([sys.iv])
133
139
dependent_variables (sys:: ODESystem ) = Set {Variable} (sys. dvs)
134
140
parameters (sys:: ODESystem ) = Set {Variable} (sys. ps)
135
141
142
+ function calculate_tgrad (sys:: ODESystem )
143
+ isempty (sys. tgrad[]) || return sys. tgrad[] # use cached tgrad, if possible
144
+ rhs = [detime_dvs (eq. rhs) for eq ∈ sys. eqs]
145
+ iv = sys. iv ()
146
+ notime_tgrad = [expand_derivatives (ModelingToolkit. Differential (iv)(r)) for r in rhs]
147
+ tgrad = retime_dvs .(notime_tgrad,(sys. dvs,),iv)
148
+ sys. tgrad[] = tgrad
149
+ return tgrad
150
+ end
136
151
137
152
function calculate_jacobian (sys:: ODESystem )
138
153
isempty (sys. jac[]) || return sys. jac[] # use cached Jacobian, if possible
@@ -160,6 +175,11 @@ function (f::ODEToExpr)(O::Operation)
160
175
end
161
176
(f:: ODEToExpr )(x) = convert (Expr, x)
162
177
178
+ function generate_tgrad (sys:: ODESystem , dvs = sys. dvs, ps = sys. ps, expression = Val{true }; kwargs... )
179
+ tgrad = calculate_tgrad (sys)
180
+ return build_function (tgrad, dvs, ps, (sys. iv. name,), ODEToExpr (sys), expression; kwargs... )
181
+ end
182
+
163
183
function generate_jacobian (sys:: ODESystem , dvs = sys. dvs, ps = sys. ps, expression = Val{true }; kwargs... )
164
184
jac = calculate_jacobian (sys)
165
185
return build_function (jac, dvs, ps, (sys. iv. name,), ODEToExpr (sys), expression; kwargs... )
@@ -218,13 +238,21 @@ are used to set the order of the dependent variable and parameter vectors,
218
238
respectively.
219
239
"""
220
240
function DiffEqBase. ODEFunction {iip} (sys:: ODESystem , dvs = sys. dvs, ps = sys. ps;
221
- version = nothing ,
241
+ version = nothing , tgrad = false ,
222
242
jac = false , Wfact = false ) where {iip}
223
243
f_oop,f_iip = generate_function (sys, dvs, ps, Val{false })
224
244
225
245
f (u,p,t) = f_oop (u,p,t)
226
246
f (du,u,p,t) = f_iip (du,u,p,t)
227
247
248
+ if tgrad
249
+ tgrad_oop,tgrad_iip = generate_tgrad (sys, dvs, ps, Val{false })
250
+ _tgrad (u,p,t) = tgrad_oop (u,p,t)
251
+ _tgrad (J,u,p,t) = tgrad_iip (J,u,p,t)
252
+ else
253
+ _tgrad = nothing
254
+ end
255
+
228
256
if jac
229
257
jac_oop,jac_iip = generate_jacobian (sys, dvs, ps, Val{false })
230
258
_jac (u,p,t) = jac_oop (u,p,t)
@@ -246,6 +274,7 @@ function DiffEqBase.ODEFunction{iip}(sys::ODESystem, dvs = sys.dvs, ps = sys.ps;
246
274
end
247
275
248
276
ODEFunction {iip} (f,jac= _jac,
277
+ tgrad = tgrad,
249
278
Wfact = _Wfact,
250
279
Wfact_t = _Wfact_t,
251
280
syms = string .(sys. dvs))
0 commit comments