Skip to content

Commit 385c034

Browse files
Merge pull request #2928 from AayushSabharwal/as/param-init
feat: allow parameters to be unknowns in the initialization system
2 parents 8ce64bf + 84a1f2e commit 385c034

15 files changed

+803
-45
lines changed

Project.toml

+5-3
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ ConstructionBase = "1"
8181
DataInterpolations = "6.4"
8282
DataStructures = "0.17, 0.18"
8383
DeepDiffs = "1"
84-
DiffEqBase = "6.103.0"
84+
DiffEqBase = "6.157"
8585
DiffEqCallbacks = "2.16, 3, 4"
8686
DiffEqNoiseProcess = "5"
8787
DiffRules = "0.1, 1.0"
@@ -110,12 +110,13 @@ NonlinearSolve = "3.14"
110110
OffsetArrays = "1"
111111
OrderedCollections = "1"
112112
OrdinaryDiffEq = "6.82.0"
113+
OrdinaryDiffEqCore = "1.7.0"
113114
PrecompileTools = "1"
114115
REPL = "1"
115116
RecursiveArrayTools = "3.26"
116117
Reexport = "0.2, 1"
117118
RuntimeGeneratedFunctions = "0.5.9"
118-
SciMLBase = "2.55"
119+
SciMLBase = "2.56.1"
119120
SciMLStructures = "1.0"
120121
Serialization = "1"
121122
Setfield = "0.7, 0.8, 1"
@@ -148,6 +149,7 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
148149
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
149150
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
150151
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
152+
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
151153
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
152154
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
153155
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -162,4 +164,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
162164
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
163165

164166
[targets]
165-
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
167+
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]

docs/src/tutorials/initialization.md

+134
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,87 @@ long enough you will see that `λ = 0` is required for this equation, but since
201201
problem constructor. Additionally, any warning about not being fully determined can
202202
be suppressed via passing `warn_initialize_determined = false`.
203203

204+
## Initialization of parameters
205+
206+
Parameters may also be treated as unknowns in the initialization system. Doing so works
207+
almost identically to the standard case. For a parameter to be an initialization unknown
208+
(henceforth referred to as "solved parameter") it must represent a floating point number
209+
(have a `symtype` of `Real` or `<:AbstractFloat`) or an array of such numbers. Additionally,
210+
it must have a guess and one of the following conditions must be satisfied:
211+
212+
1. The value of the parameter as passed to `ODEProblem` is an expression involving other
213+
variables/parameters. For example, if `[p => 2q + x]` is passed to `ODEProblem`. In
214+
this case, `p ~ 2q + x` is used as an equation during initialization.
215+
2. The parameter has a default (and no value for it is given to `ODEProblem`, since
216+
that is condition 1). The default will be used as an equation during initialization.
217+
3. The parameter has a default of `missing`. If `ODEProblem` is given a value for this
218+
parameter, it is used as an equation during initialization (whether the value is an
219+
expression or not).
220+
4. `ODEProblem` is given a value of `missing` for the parameter. If the parameter has a
221+
default, it will be used as an equation during initialization.
222+
223+
All parameter dependencies (where the dependent parameter is a floating point number or
224+
array thereof) also become equations during initialization, and the dependent parameters
225+
become unknowns.
226+
227+
`remake` will reconstruct the initialization system and problem, given the new
228+
constraints provided to it. The new values will be combined with the original
229+
variable-value mapping provided to `ODEProblem` and used to construct the initialization
230+
problem.
231+
232+
### Parameter initialization by example
233+
234+
Consider the following system, where the sum of two unknowns is a constant parameter
235+
`total`.
236+
237+
```@example paraminit
238+
using ModelingToolkit, OrdinaryDiffEq # hidden
239+
using ModelingToolkit: t_nounits as t, D_nounits as D # hidden
240+
241+
@variables x(t) y(t)
242+
@parameters total
243+
@mtkbuild sys = ODESystem([D(x) ~ -x, total ~ x + y], t;
244+
defaults = [total => missing], guesses = [total => 1.0])
245+
```
246+
247+
Given any two of `x`, `y` and `total` we can determine the remaining variable.
248+
249+
```@example paraminit
250+
prob = ODEProblem(sys, [x => 1.0, y => 2.0], (0.0, 1.0))
251+
integ = init(prob, Tsit5())
252+
@assert integ.ps[total] ≈ 3.0 # hide
253+
integ.ps[total]
254+
```
255+
256+
Suppose we want to re-create this problem, but now solve for `x` given `total` and `y`:
257+
258+
```@example paraminit
259+
prob2 = remake(prob; u0 = [y => 1.0], p = [total => 4.0])
260+
initsys = prob2.f.initializeprob.f.sys
261+
```
262+
263+
The system is now overdetermined. In fact:
264+
265+
```@example paraminit
266+
[equations(initsys); observed(initsys)]
267+
```
268+
269+
The system can never be satisfied and will always lead to an `InitialFailure`. This is
270+
due to the aforementioned behavior of retaining the original variable-value mapping
271+
provided to `ODEProblem`. To fix this, we pass `x => nothing` to `remake` to remove its
272+
retained value.
273+
274+
```@example paraminit
275+
prob2 = remake(prob; u0 = [y => 1.0, x => nothing], p = [total => 4.0])
276+
initsys = prob2.f.initializeprob.f.sys
277+
```
278+
279+
The system is fully determined, and the equations are solvable.
280+
281+
```@example
282+
[equations(initsys); observed(initsys)]
283+
```
284+
204285
## Diving Deeper: Constructing the Initialization System
205286

206287
To get a better sense of the initialization system and to help debug it, you can construct
@@ -383,3 +464,56 @@ sol[α * x - β * x * y]
383464
```@example init
384465
plot(sol)
385466
```
467+
468+
## Solving for parameters during initialization
469+
470+
Sometimes, it is necessary to solve for a parameter during initialization. For example,
471+
given a spring-mass system we want to find the un-stretched length of the spring given
472+
that the initial condition of the system is its steady state.
473+
474+
```@example init
475+
using ModelingToolkitStandardLibrary.Mechanical.TranslationalModelica: Fixed, Mass, Spring,
476+
Force, Damper
477+
using ModelingToolkitStandardLibrary.Blocks: Constant
478+
479+
@named mass = Mass(; m = 1.0, s = 1.0, v = 0.0, a = 0.0)
480+
@named fixed = Fixed(; s0 = 0.0)
481+
@named spring = Spring(; c = 2.0, s_rel0 = missing)
482+
@named gravity = Force()
483+
@named constant = Constant(; k = 9.81)
484+
@named damper = Damper(; d = 0.1)
485+
@mtkbuild sys = ODESystem(
486+
[connect(fixed.flange, spring.flange_a), connect(spring.flange_b, mass.flange_a),
487+
connect(mass.flange_a, gravity.flange), connect(constant.output, gravity.f),
488+
connect(fixed.flange, damper.flange_a), connect(damper.flange_b, mass.flange_a)],
489+
t;
490+
systems = [fixed, spring, mass, gravity, constant, damper],
491+
guesses = [spring.s_rel0 => 1.0])
492+
```
493+
494+
Note that we explicitly provide `s_rel0 = missing` to the spring. Parameters are only
495+
solved for during initialization if their value (either default, or explicitly passed
496+
to the `ODEProblem` constructor) is `missing`. We also need to provide a guess for the
497+
parameter.
498+
499+
If a parameter is not given a value of `missing`, and does not have a default or initial
500+
value, the `ODEProblem` constructor will throw an error. If the parameter _does_ have a
501+
value of `missing`, it must be given a guess.
502+
503+
```@example init
504+
prob = ODEProblem(sys, [], (0.0, 1.0))
505+
prob.ps[spring.s_rel0]
506+
```
507+
508+
Note that the value of the parameter in the problem is zero, similar to unknowns that
509+
are solved for during initialization.
510+
511+
```@example init
512+
integ = init(prob)
513+
integ.ps[spring.s_rel0]
514+
```
515+
516+
The un-stretched length of the spring is now correctly calculated. The same result can be
517+
achieved if `s_rel0 = missing` is omitted when constructing `spring`, and instead
518+
`spring.s_rel0 => missing` is passed to the `ODEProblem` constructor along with values
519+
of other parameters.

src/systems/abstractsystem.jl

+22-1
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,15 @@ function has_observed_with_lhs(sys, sym)
736736
end
737737
end
738738

739+
function has_parameter_dependency_with_lhs(sys, sym)
740+
has_parameter_dependencies(sys) || return false
741+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
742+
return any(isequal(sym), ic.dependent_pars)
743+
else
744+
return any(isequal(sym), [eq.lhs for eq in parameter_dependencies(sys)])
745+
end
746+
end
747+
739748
function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym)
740749
if is_variable(sys, sym) || is_independent_variable(sys, sym)
741750
push!(ts_idxs, ContinuousTimeseries())
@@ -1344,9 +1353,21 @@ function namespace_assignment(eq::Assignment, sys)
13441353
Assignment(_lhs, _rhs)
13451354
end
13461355

1356+
function is_array_of_symbolics(x)
1357+
symbolic_type(x) == ArraySymbolic() && return true
1358+
symbolic_type(x) == ScalarSymbolic() && return false
1359+
x isa AbstractArray &&
1360+
any(y -> symbolic_type(y) != NotSymbolic() || is_array_of_symbolics(y), x)
1361+
end
1362+
13471363
function namespace_expr(
13481364
O, sys, n = nameof(sys); ivs = independent_variables(sys))
13491365
O = unwrap(O)
1366+
# Exceptions for arrays of symbolic and Ref of a symbolic, the latter
1367+
# of which shows up in broadcasts
1368+
if symbolic_type(O) == NotSymbolic() && !(O isa AbstractArray) && !(O isa Ref)
1369+
return O
1370+
end
13501371
if any(isequal(O), ivs)
13511372
return O
13521373
elseif iscall(O)
@@ -1368,7 +1389,7 @@ function namespace_expr(
13681389
end
13691390
elseif isvariable(O)
13701391
renamespace(n, O)
1371-
elseif O isa Array
1392+
elseif O isa AbstractArray && is_array_of_symbolics(O)
13721393
let sys = sys, n = n
13731394
map(o -> namespace_expr(o, sys, n; ivs), O)
13741395
end

0 commit comments

Comments
 (0)