Skip to content

Commit d352cc0

Browse files
Merge pull request #3213 from AayushSabharwal/as/nonlinear-scc
feat: initial implementation of `SCCNonlinearProblem` codegen
2 parents 36bd3fd + 0537715 commit d352cc0

17 files changed

+745
-232
lines changed

Project.toml

+7-3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
4444
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
4545
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
4646
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
47+
SCCNonlinearSolve = "9dfe8606-65a1-4bb3-9748-cb89d1561431"
4748
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
4849
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
4950
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -120,13 +121,15 @@ NonlinearSolve = "3.14, 4"
120121
OffsetArrays = "1"
121122
OrderedCollections = "1"
122123
OrdinaryDiffEq = "6.82.0"
123-
OrdinaryDiffEqCore = "1.7.0"
124+
OrdinaryDiffEqCore = "1.13.0"
125+
OrdinaryDiffEqNonlinearSolve = "1.3.0"
124126
PrecompileTools = "1"
125127
REPL = "1"
126128
RecursiveArrayTools = "3.26"
127129
Reexport = "0.2, 1"
128130
RuntimeGeneratedFunctions = "0.5.9"
129-
SciMLBase = "2.64"
131+
SCCNonlinearSolve = "1.0.0"
132+
SciMLBase = "2.66"
130133
SciMLStructures = "1.0"
131134
Serialization = "1"
132135
Setfield = "0.7, 0.8, 1"
@@ -160,6 +163,7 @@ OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
160163
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
161164
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
162165
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
166+
OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8"
163167
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
164168
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
165169
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -174,4 +178,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
174178
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
175179

176180
[targets]
177-
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
181+
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]

src/ModelingToolkit.jl

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ using Distributed
5050
import JuliaFormatter
5151
using MLStyle
5252
using NonlinearSolve
53+
import SCCNonlinearSolve
5354
using Reexport
5455
using RecursiveArrayTools
5556
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix

src/structural_transformation/bipartite_tearing/modia_tearing.jl

+10-4
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,15 @@ function tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, eqs, vars
6262
return nothing
6363
end
6464

65+
function build_var_eq_matching(structure::SystemStructure, ::Type{U} = Unassigned;
66+
varfilter::F2 = v -> true, eqfilter::F3 = eq -> true) where {U, F2, F3}
67+
@unpack graph, solvable_graph = structure
68+
var_eq_matching = maximal_matching(graph, eqfilter, varfilter, U)
69+
matching_len = max(length(var_eq_matching),
70+
maximum(x -> x isa Int ? x : 0, var_eq_matching, init = 0))
71+
return complete(var_eq_matching, matching_len), matching_len
72+
end
73+
6574
function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
6675
::Type{U} = Unassigned;
6776
varfilter::F2 = v -> true,
@@ -78,10 +87,7 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
7887
# find them here [TODO: It would be good to have an explicit example of this.]
7988

8089
@unpack graph, solvable_graph = structure
81-
var_eq_matching = maximal_matching(graph, eqfilter, varfilter, U)
82-
matching_len = max(length(var_eq_matching),
83-
maximum(x -> x isa Int ? x : 0, var_eq_matching, init = 0))
84-
var_eq_matching = complete(var_eq_matching, matching_len)
90+
var_eq_matching, matching_len = build_var_eq_matching(structure, U; varfilter, eqfilter)
8591
full_var_eq_matching = copy(var_eq_matching)
8692
var_sccs = find_var_sccs(graph, var_eq_matching)
8793
vargraph = DiCMOBiGraph{true}(graph, 0, Matching(matching_len))

src/systems/abstractsystem.jl

+54-137
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,12 @@ object.
162162
"""
163163
function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys),
164164
ps = parameters(sys); wrap_code = nothing, postprocess_fbody = nothing, states = nothing,
165-
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__, kwargs...)
165+
expression = Val{true}, eval_expression = false, eval_module = @__MODULE__,
166+
cachesyms::Tuple = (), kwargs...)
166167
if !iscomplete(sys)
167168
error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
168169
end
169-
p = reorder_parameters(sys, unwrap.(ps))
170+
p = (reorder_parameters(sys, unwrap.(ps))..., cachesyms...)
170171
isscalar = !(exprs isa AbstractArray)
171172
if wrap_code === nothing
172173
wrap_code = isscalar ? identity : (identity, identity)
@@ -187,7 +188,7 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
187188
postprocess_fbody,
188189
states,
189190
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
190-
wrap_array_vars(sys, exprs; dvs) .∘
191+
wrap_array_vars(sys, exprs; dvs, cachesyms) .∘
191192
wrap_parameter_dependencies(sys, isscalar),
192193
expression = Val{true}
193194
)
@@ -199,7 +200,7 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
199200
postprocess_fbody,
200201
states,
201202
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘
202-
wrap_array_vars(sys, exprs; dvs) .∘
203+
wrap_array_vars(sys, exprs; dvs, cachesyms) .∘
203204
wrap_parameter_dependencies(sys, isscalar),
204205
expression = Val{true}
205206
)
@@ -231,133 +232,59 @@ end
231232

232233
function wrap_array_vars(
233234
sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys),
234-
inputs = nothing, history = false)
235+
inputs = nothing, history = false, cachesyms::Tuple = ())
235236
isscalar = !(exprs isa AbstractArray)
236-
array_vars = Dict{Any, AbstractArray{Int}}()
237-
if dvs !== nothing
238-
for (j, x) in enumerate(dvs)
239-
if iscall(x) && operation(x) == getindex
240-
arg = arguments(x)[1]
241-
inds = get!(() -> Int[], array_vars, arg)
242-
push!(inds, j)
243-
end
244-
end
245-
for (k, inds) in array_vars
246-
if inds == (inds′ = inds[1]:inds[end])
247-
array_vars[k] = inds′
248-
end
249-
end
237+
var_to_arridxs = Dict()
250238

251-
uind = 1
252-
else
239+
if dvs === nothing
253240
uind = 0
254-
end
255-
# values are (indexes, index of buffer, size of parameter)
256-
array_parameters = Dict{Any, Tuple{AbstractArray{Int}, Int, Tuple{Vararg{Int}}}}()
257-
# If for some reason different elements of an array parameter are in different buffers
258-
other_array_parameters = Dict{Any, Any}()
259-
260-
hasinputs = inputs !== nothing
261-
input_vars = Dict{Any, AbstractArray{Int}}()
262-
if hasinputs
263-
for (j, x) in enumerate(inputs)
264-
if iscall(x) && operation(x) == getindex
265-
arg = arguments(x)[1]
266-
inds = get!(() -> Int[], input_vars, arg)
267-
push!(inds, j)
268-
end
269-
end
270-
for (k, inds) in input_vars
271-
if inds == (inds′ = inds[1]:inds[end])
272-
input_vars[k] = inds′
273-
end
274-
end
275-
end
276-
if has_index_cache(sys)
277-
ic = get_index_cache(sys)
278241
else
279-
ic = nothing
280-
end
281-
if ps isa Tuple && eltype(ps) <: AbstractArray
282-
ps = Iterators.flatten(ps)
283-
end
284-
for p in ps
285-
p = unwrap(p)
286-
if iscall(p) && operation(p) == getindex
287-
p = arguments(p)[1]
288-
end
289-
symtype(p) <: AbstractArray && Symbolics.shape(p) != Symbolics.Unknown() || continue
290-
scal = collect(p)
291-
# all scalarized variables are in `ps`
292-
any(isequal(p), ps) || all(x -> any(isequal(x), ps), scal) || continue
293-
(haskey(array_parameters, p) || haskey(other_array_parameters, p)) && continue
294-
295-
idx = parameter_index(sys, p)
296-
idx isa Int && continue
297-
if idx isa ParameterIndex
298-
if idx.portion != SciMLStructures.Tunable()
299-
continue
300-
end
301-
array_parameters[p] = (vec(idx.idx), 1, size(idx.idx))
242+
uind = 1
243+
for (i, x) in enumerate(dvs)
244+
iscall(x) && operation(x) == getindex || continue
245+
arg = arguments(x)[1]
246+
inds = get!(() -> [], var_to_arridxs, arg)
247+
push!(inds, (uind, i))
248+
end
249+
end
250+
p_start = uind + 1 + history
251+
rps = (reorder_parameters(sys, ps)..., cachesyms...)
252+
if inputs !== nothing
253+
rps = (inputs, rps...)
254+
end
255+
for sym in reduce(vcat, rps; init = [])
256+
iscall(sym) && operation(sym) == getindex || continue
257+
arg = arguments(sym)[1]
258+
259+
bufferidx = findfirst(buf -> any(isequal(sym), buf), rps)
260+
idxinbuffer = findfirst(isequal(sym), rps[bufferidx])
261+
inds = get!(() -> [], var_to_arridxs, arg)
262+
push!(inds, (p_start + bufferidx - 1, idxinbuffer))
263+
end
264+
265+
viewsyms = Dict()
266+
splitsyms = Dict()
267+
for (arrsym, idxs) in var_to_arridxs
268+
length(idxs) == length(arrsym) || continue
269+
# allequal(first, idxs) is a 1.11 feature
270+
if allequal(Iterators.map(first, idxs))
271+
viewsyms[arrsym] = (first(first(idxs)), reshape(last.(idxs), size(arrsym)))
302272
else
303-
# idx === nothing
304-
idxs = map(Base.Fix1(parameter_index, sys), scal)
305-
if first(idxs) isa ParameterIndex
306-
buffer_idxs = map(Base.Fix1(iterated_buffer_index, ic), idxs)
307-
if allequal(buffer_idxs)
308-
buffer_idx = first(buffer_idxs)
309-
if first(idxs).portion == SciMLStructures.Tunable()
310-
idxs = map(x -> x.idx, idxs)
311-
else
312-
idxs = map(x -> x.idx[end], idxs)
313-
end
314-
else
315-
other_array_parameters[p] = scal
316-
continue
317-
end
318-
else
319-
buffer_idx = 1
320-
end
321-
322-
sz = size(idxs)
323-
if vec(idxs) == idxs[begin]:idxs[end]
324-
idxs = idxs[begin]:idxs[end]
325-
elseif vec(idxs) == idxs[begin]:-1:idxs[end]
326-
idxs = idxs[begin]:-1:idxs[end]
327-
end
328-
idxs = vec(idxs)
329-
array_parameters[p] = (idxs, buffer_idx, sz)
273+
splitsyms[arrsym] = reshape(idxs, size(arrsym))
330274
end
331275
end
332-
333-
inputind = if history
334-
uind + 2
335-
else
336-
uind + 1
337-
end
338-
params_offset = if history && hasinputs
339-
uind + 2
340-
elseif history || hasinputs
341-
uind + 1
342-
else
343-
uind
344-
end
345276
if isscalar
346277
function (expr)
347278
Func(
348279
expr.args,
349280
[],
350281
Let(
351282
vcat(
352-
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
353-
[k :(view($(expr.args[inputind].name), $v))
354-
for (k, v) in input_vars],
355-
[k :(reshape(
356-
view($(expr.args[params_offset + buffer_idx].name), $idxs),
357-
$sz))
358-
for (k, (idxs, buffer_idx, sz)) in array_parameters],
359-
[k Code.MakeArray(v, symtype(k))
360-
for (k, v) in other_array_parameters]
283+
[sym :(view($(expr.args[i].name), $idxs))
284+
for (sym, (i, idxs)) in viewsyms],
285+
[sym
286+
MakeArray([expr.args[bufi].elems[vali] for (bufi, vali) in idxs],
287+
expr.args[idxs[1][1]]) for (sym, idxs) in splitsyms]
361288
),
362289
expr.body,
363290
false
@@ -371,15 +298,11 @@ function wrap_array_vars(
371298
[],
372299
Let(
373300
vcat(
374-
[k :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars],
375-
[k :(view($(expr.args[inputind].name), $v))
376-
for (k, v) in input_vars],
377-
[k :(reshape(
378-
view($(expr.args[params_offset + buffer_idx].name), $idxs),
379-
$sz))
380-
for (k, (idxs, buffer_idx, sz)) in array_parameters],
381-
[k Code.MakeArray(v, symtype(k))
382-
for (k, v) in other_array_parameters]
301+
[sym :(view($(expr.args[i].name), $idxs))
302+
for (sym, (i, idxs)) in viewsyms],
303+
[sym
304+
MakeArray([expr.args[bufi].elems[vali] for (bufi, vali) in idxs],
305+
expr.args[idxs[1][1]]) for (sym, idxs) in splitsyms]
383306
),
384307
expr.body,
385308
false
@@ -392,17 +315,11 @@ function wrap_array_vars(
392315
[],
393316
Let(
394317
vcat(
395-
[k :(view($(expr.args[uind + 1].name), $v))
396-
for (k, v) in array_vars],
397-
[k :(view($(expr.args[inputind + 1].name), $v))
398-
for (k, v) in input_vars],
399-
[k :(reshape(
400-
view($(expr.args[params_offset + buffer_idx + 1].name),
401-
$idxs),
402-
$sz))
403-
for (k, (idxs, buffer_idx, sz)) in array_parameters],
404-
[k Code.MakeArray(v, symtype(k))
405-
for (k, v) in other_array_parameters]
318+
[sym :(view($(expr.args[i + 1].name), $idxs))
319+
for (sym, (i, idxs)) in viewsyms],
320+
[sym MakeArray(
321+
[expr.args[bufi + 1].elems[vali] for (bufi, vali) in idxs],
322+
expr.args[idxs[1][1] + 1]) for (sym, idxs) in splitsyms]
406323
),
407324
expr.body,
408325
false

0 commit comments

Comments
 (0)