Skip to content

Commit 30bf372

Browse files
Merge pull request #3547 from AayushSabharwal/as/lex-sort-eqs
feat: lexicographically sort equations in `structural_simplify`
2 parents 21dd529 + 8bf5d9e commit 30bf372

File tree

13 files changed

+54
-33
lines changed

13 files changed

+54
-33
lines changed

src/structural_transformation/tearing.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ function algebraic_variables_scc(state::TearingState)
6363
all(v -> !isdervar(state.structure, v),
6464
𝑠neighbors(graph, eq))
6565
end))
66-
var_eq_matching = complete(maximal_matching(graph, e -> e in algeqs, v -> v in algvars))
66+
var_eq_matching = complete(
67+
maximal_matching(graph, e -> e in algeqs, v -> v in algvars), ndsts(graph))
6768
var_sccs = find_var_sccs(complete(graph), var_eq_matching)
6869

6970
return var_eq_matching, var_sccs

src/systems/systems.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ topological sort of the observed equations in `sys`.
2424
+ When `simplify=true`, the `simplify` function will be applied during the tearing process.
2525
+ `allow_symbolic=false`, `allow_parameter=true`, and `conservative=false` limit the coefficient types during tearing. In particular, `conservative=true` limits tearing to only solve for trivial linear systems where the coefficient has the absolute value of ``1``.
2626
+ `fully_determined=true` controls whether or not an error will be thrown if the number of equations don't match the number of inputs, outputs, and equations.
27+
+ `sort_eqs=true` controls whether equations are sorted lexicographically before simplification or not.
2728
"""
2829
function structural_simplify(
2930
sys::AbstractSystem, io = nothing; additional_passes = [], simplify = false, split = true,
@@ -69,10 +70,11 @@ function __structural_simplify(sys::SDESystem, args...; kwargs...)
6970
return __structural_simplify(ODESystem(sys), args...; kwargs...)
7071
end
7172

72-
function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = false,
73+
function __structural_simplify(
74+
sys::AbstractSystem, io = nothing; simplify = false, sort_eqs = true,
7375
kwargs...)
7476
sys = expand_connections(sys)
75-
state = TearingState(sys)
77+
state = TearingState(sys; sort_eqs)
7678

7779
@unpack structure, fullvars = state
7880
@unpack graph, var_to_diff, var_types = structure

src/systems/systemstructure.jl

+9-1
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ function is_time_dependent_parameter(p, iv)
260260
(args = arguments(p); length(args)) == 1 && isequal(only(args), iv))
261261
end
262262

263-
function TearingState(sys; quick_cancel = false, check = true)
263+
function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
264264
sys = flatten(sys)
265265
ivs = independent_variables(sys)
266266
iv = length(ivs) == 1 ? ivs[1] : nothing
@@ -381,6 +381,14 @@ function TearingState(sys; quick_cancel = false, check = true)
381381
neqs = length(eqs)
382382
symbolic_incidence = symbolic_incidence[eqs_to_retain]
383383

384+
if sort_eqs
385+
# sort equations lexicographically to reduce simplification issues
386+
# depending on order due to NP-completeness of tearing.
387+
sortidxs = Base.sortperm(eqs, by = string)
388+
eqs = eqs[sortidxs]
389+
symbolic_incidence = symbolic_incidence[sortidxs]
390+
end
391+
384392
### Handle discrete variables
385393
lowest_shift = Dict()
386394
for var in fullvars

test/clock.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,9 @@ sss, = ModelingToolkit._structural_simplify!(deepcopy(tss[1]), (inputs[1], ()))
7272
@test isempty(equations(sss))
7373
d = Clock(dt)
7474
k = ShiftIndex(d)
75-
@test observed(sss) == [yd(k + 1) ~ Sample(dt)(y); r(k + 1) ~ 1.0;
76-
ud(k + 1) ~ kp * (r(k + 1) - yd(k + 1))]
75+
@test issetequal(observed(sss),
76+
[yd(k + 1) ~ Sample(dt)(y); r(k + 1) ~ 1.0;
77+
ud(k + 1) ~ kp * (r(k + 1) - yd(k + 1))])
7778

7879
d = Clock(dt)
7980
# Note that TearingState reorders the equations

test/discrete_system.jl

+13-11
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
- https://github.com/epirecipes/sir-julia/blob/master/markdown/function_map/function_map.md
44
- https://en.wikipedia.org/wiki/Compartmental_models_in_epidemiology#Deterministic_versus_stochastic_epidemic_models
55
=#
6-
using ModelingToolkit, Test
6+
using ModelingToolkit, SymbolicIndexingInterface, Test
77
using ModelingToolkit: t_nounits as t
88
using ModelingToolkit: get_metadata, MTKParameters
99

@@ -37,13 +37,15 @@ syss = structural_simplify(sys)
3737
df = DiscreteFunction(syss)
3838
# iip
3939
du = zeros(3)
40-
u = collect(1:3)
40+
u = ModelingToolkit.better_varmap_to_vars(Dict([S => 1, I => 2, R => 3]), unknowns(syss))
4141
p = MTKParameters(syss, [c, nsteps, δt, β, γ] .=> collect(1:5))
4242
df.f(du, u, p, 0)
43-
@test du [0.01831563888873422, 0.9816849729159067, 4.999999388195359]
43+
reorderer = getu(syss, [S, I, R])
44+
@test reorderer(du) [0.01831563888873422, 0.9816849729159067, 4.999999388195359]
4445

4546
# oop
46-
@test df.f(u, p, 0) [0.01831563888873422, 0.9816849729159067, 4.999999388195359]
47+
@test reorderer(df.f(u, p, 0))
48+
[0.01831563888873422, 0.9816849729159067, 4.999999388195359]
4749

4850
# Problem
4951
u0 = [S(k - 1) => 990.0, I(k - 1) => 10.0, R(k - 1) => 0.0]
@@ -98,12 +100,12 @@ function sir_map!(u_diff, u, p, t)
98100
end
99101
nothing
100102
end;
101-
u0 = prob_map2.u0;
103+
u0 = prob_map2[[S, I, R]];
102104
p = [0.05, 10.0, 0.25, 0.1];
103105
prob_map = DiscreteProblem(sir_map!, u0, tspan, p);
104106
sol_map2 = solve(prob_map, FunctionMap());
105107

106-
@test Array(sol_map) Array(sol_map2)
108+
@test reduce(hcat, sol_map[[S, I, R]]) Array(sol_map2)
107109

108110
# Delayed difference equation
109111
# @variables x(..) y(..) z(t)
@@ -317,9 +319,9 @@ end
317319

318320
import ModelingToolkit: shift2term
319321
# unknowns(de) = xₜ₋₁, x, zₜ₋₁, xₜ₋₂, z
320-
vars = ModelingToolkit.value.(unknowns(de))
321-
@test isequal(shift2term(Shift(t, 1)(vars[1])), vars[2])
322-
@test isequal(shift2term(Shift(t, 1)(vars[4])), vars[1])
323-
@test isequal(shift2term(Shift(t, -1)(vars[5])), vars[3])
324-
@test isequal(shift2term(Shift(t, -2)(vars[2])), vars[4])
322+
vars = sort(ModelingToolkit.value.(unknowns(de)); by = string)
323+
@test isequal(shift2term(Shift(t, 1)(vars[2])), vars[1])
324+
@test isequal(shift2term(Shift(t, 1)(vars[3])), vars[2])
325+
@test isequal(shift2term(Shift(t, -1)(vars[4])), vars[5])
326+
@test isequal(shift2term(Shift(t, -2)(vars[1])), vars[3])
325327
end

test/downstream/linearize.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,9 @@ lsys = ModelingToolkit.reorder_unknowns(lsys0, unknowns(ssys), desired_order)
125125
@test lsys.C == [400 -4000]
126126
@test lsys.D == [4400 -4400]
127127

128-
lsyss, _ = ModelingToolkit.linearize_symbolic(pid, [reference.u, measurement.u],
128+
lsyss0, ssys2 = ModelingToolkit.linearize_symbolic(pid, [reference.u, measurement.u],
129129
[ctr_output.u])
130+
lsyss = ModelingToolkit.reorder_unknowns(lsyss0, unknowns(ssys2), desired_order)
130131

131132
@test ModelingToolkit.fixpoint_sub(
132133
lsyss.A, ModelingToolkit.defaults_and_guesses(pid)) == lsys.A
@@ -138,7 +139,7 @@ lsyss, _ = ModelingToolkit.linearize_symbolic(pid, [reference.u, measurement.u],
138139
lsyss.D, ModelingToolkit.defaults_and_guesses(pid)) == lsys.D
139140

140141
# Test with the reverse desired unknown order as well to verify that similarity transform and reoreder_unknowns really works
141-
lsys = ModelingToolkit.reorder_unknowns(lsys, unknowns(ssys), reverse(desired_order))
142+
lsys = ModelingToolkit.reorder_unknowns(lsys, desired_order, reverse(desired_order))
142143

143144
@test lsys.A == [-10 0; 0 0]
144145
@test lsys.B == [10 -10; 2 -2]

test/initializationsystem.jl

+11-6
Original file line numberDiff line numberDiff line change
@@ -801,12 +801,17 @@ end
801801
end
802802

803803
@parameters p=2.0 q=missing [guess = 1.0] c=1.0
804-
@variables x=1.0 y=2.0 z=3.0
805-
806-
eqs = [0 ~ p * (y - x),
807-
0 ~ x * (q - z) - y,
808-
0 ~ x * y - c * z]
809-
@mtkbuild sys = NonlinearSystem(eqs; initialization_eqs = [p^2 + q^2 + 2p * q ~ 0])
804+
@variables x=1.0 z=3.0
805+
806+
# eqs = [0 ~ p * (y - x),
807+
# 0 ~ x * (q - z) - y,
808+
# 0 ~ x * y - c * z]
809+
# specifically written this way due to
810+
# https://github.com/SciML/NonlinearSolve.jl/issues/586
811+
eqs = [0 ~ -c * z + (q - z) * (x^2)
812+
0 ~ p * (-x + (q - z) * x)]
813+
@named sys = NonlinearSystem(eqs; initialization_eqs = [p^2 + q^2 + 2p * q ~ 0])
814+
sys = complete(sys)
810815
# @mtkbuild sys = NonlinearSystem(
811816
# [p * x^2 + q * y^3 ~ 0, x - q ~ 0]; defaults = [q => missing],
812817
# guesses = [q => 1.0], initialization_eqs = [p^2 + q^2 + 2p * q ~ 0])

test/input_output_handling.jl

+2
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,8 @@ matrices, ssys = linearize(augmented_sys,
414414
augmented_sys.d
415415
], outs;
416416
op = [augmented_sys.u => 0.0, augmented_sys.input.u[2] => 0.0, augmented_sys.d => 0.0])
417+
matrices = ModelingToolkit.reorder_unknowns(
418+
matrices, unknowns(ssys), [ssys.x[2], ssys.integrator.x[1], ssys.x[1]])
417419
@test matrices.A [A [1; 0]; zeros(1, 2) -0.001]
418420
@test matrices.B == I
419421
@test matrices.C == [C zeros(2)]

test/lowering_solving.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,7 @@ u0 = [lorenz1.x => 1.0,
5959
lorenz1.z => 0.0,
6060
lorenz2.x => 0.0,
6161
lorenz2.y => 1.0,
62-
lorenz2.z => 0.0,
63-
α => 2.0]
62+
lorenz2.z => 0.0]
6463

6564
p = [lorenz1.σ => 10.0,
6665
lorenz1.ρ => 28.0,
@@ -73,5 +72,5 @@ p = [lorenz1.σ => 10.0,
7372
tspan = (0.0, 100.0)
7473
prob = ODEProblem(connected, u0, tspan, p)
7574
sol = solve(prob, Rodas5())
76-
@test maximum(sol[2, :] + sol[6, :] + 2sol[1, :]) < 1e-12
75+
@test maximum(sol[lorenz1.x] + sol[lorenz2.y] + 2sol[α]) < 1e-12
7776
#using Plots; plot(sol,idxs=(:α,Symbol(lorenz1.x),Symbol(lorenz2.y)))

test/odesystem.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -927,7 +927,7 @@ let
927927
sys_simp = structural_simplify(sys_con)
928928
true_eqs = [D(sys.x) ~ sys.v
929929
D(sys.v) ~ ctrl.kv * sys.v + ctrl.kx * sys.x]
930-
@test isequal(full_equations(sys_simp), true_eqs)
930+
@test issetequal(full_equations(sys_simp), true_eqs)
931931
end
932932

933933
let

test/structural_transformation/index_reduction.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ state = TearingState(pendulum)
3737
@test StructuralTransformations.maximal_matching(graph, eq -> true,
3838
v -> var_to_diff[v] === nothing) ==
3939
map(x -> x == 0 ? StructuralTransformations.unassigned : x,
40-
[1, 2, 3, 4, 0, 0, 0, 0, 0])
40+
[3, 4, 2, 5, 0, 0, 0, 0, 0])
4141

4242
using ModelingToolkit
4343
@parameters L g

test/structural_transformation/utils.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ state = TearingState(pendulum)
2323
StructuralTransformations.find_solvables!(state)
2424
sss = state.structure
2525
@unpack graph, solvable_graph, var_to_diff = sss
26-
@test graph.fadjlist == [[1, 7], [2, 8], [3, 5, 9], [4, 6, 9], [5, 6]]
26+
@test sort(graph.fadjlist) == [[1, 7], [2, 8], [3, 5, 9], [4, 6, 9], [5, 6]]
2727
@test length(graph.badjlist) == 9
2828
@test ne(graph) == nnz(incidence_matrix(graph)) == 12
2929
@test nv(solvable_graph) == 9 + 5

test/symbolic_events.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -829,8 +829,8 @@ let
829829
sol = solve(prob, Tsit5(), saveat = 0.1)
830830

831831
@test typeof(oneosc_ce_simpl) == ODESystem
832-
@test sol[1, 6] < 1.0 # test whether x(t) decreases over time
833-
@test sol[1, 18] > 0.5 # test whether event happened
832+
@test sol[oscce.x, 6] < 1.0 # test whether x(t) decreases over time
833+
@test sol[oscce.x, 18] > 0.5 # test whether event happened
834834
end
835835

836836
@testset "Additional SymbolicContinuousCallback options" begin

0 commit comments

Comments
 (0)