Skip to content

Commit 82339c5

Browse files
Merge pull request #3605 from AayushSabharwal/as/copy-guesses
feat: use guesses as temporary values for variables solved by initialization
2 parents 5a80fd0 + b315138 commit 82339c5

File tree

3 files changed

+33
-6
lines changed

3 files changed

+33
-6
lines changed

src/systems/problem_utils.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -938,8 +938,7 @@ function maybe_build_initialization_problem(
938938
is_parameter_solvable(p, pmap, defs, guesses) || continue
939939
get(op, p, missing) === missing || continue
940940
p = unwrap(p)
941-
stype = symtype(p)
942-
op[p] = get_temporary_value(p, floatT)
941+
op[p] = getu(initializeprob, p)(initializeprob)
943942
if iscall(p) && operation(p) === getindex
944943
arrp = arguments(p)[1]
945944
op[arrp] = collect(arrp)
@@ -948,7 +947,7 @@ function maybe_build_initialization_problem(
948947

949948
if is_time_dependent(sys)
950949
for v in missing_unknowns
951-
op[v] = get_temporary_value(v, floatT)
950+
op[v] = getu(initializeprob, v)(initializeprob)
952951
end
953952
empty!(missing_unknowns)
954953
end

test/initial_values.jl

+28
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,31 @@ end
281281
@test prob.p isa Vector{Float64}
282282
@test length(prob.p) == 5
283283
end
284+
285+
@testset "Temporary values for solved variables are guesses" begin
286+
@parameters σ ρ β=missing [guess = 8 / 3]
287+
@variables x(t) y(t) z(t) w(t) w2(t)
288+
289+
eqs = [D(D(x)) ~ σ * (y - x),
290+
D(y) ~ x *- z) - y,
291+
D(z) ~ x * y - β * z,
292+
w ~ x + y + z + 2 * β,
293+
0 ~ x^2 + y^2 - w2^2
294+
]
295+
296+
@mtkbuild sys = ODESystem(eqs, t)
297+
298+
u0 = [D(x) => 2.0,
299+
x => 1.0,
300+
y => 0.0,
301+
z => 0.0]
302+
303+
p ==> 28.0,
304+
ρ => 10.0]
305+
306+
tspan = (0.0, 100.0)
307+
prob = ODEProblem(sys, u0, tspan, p, jac = true, guesses = [w2 => -1.0],
308+
warn_initialize_determined = false)
309+
@test prob[w2] -1.0
310+
@test prob.ps[β] 8 / 3
311+
end

test/initializationsystem.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -1186,10 +1186,10 @@ end
11861186
@mtkbuild sys = ODESystem([D(x) ~ x * p + q, x^3 + y^3 ~ 3], t)
11871187
prob = ODEProblem(
11881188
sys, [], (0.0, 1.0), [p => 1.0]; guesses = [x => 1.0, y => 1.0, q => 1.0])
1189-
@test prob[x] == 0.0
1190-
@test prob[y] == 0.0
1189+
@test prob[x] == 1.0
1190+
@test prob[y] == 2.0
11911191
@test prob.ps[p] == 1.0
1192-
@test prob.ps[q] == 0.0
1192+
@test prob.ps[q] == 3.0
11931193
integ = init(prob)
11941194
@test integ[x] 1 / cbrt(3)
11951195
@test integ[y] 2 / cbrt(3)

0 commit comments

Comments
 (0)