Skip to content

Commit cf6a8d0

Browse files
fix: allow specifying type of buffers inside MTKParameters
1 parent 30bf372 commit cf6a8d0

File tree

4 files changed

+82
-15
lines changed

4 files changed

+82
-15
lines changed

src/systems/parameter_buffer.jl

+44-12
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,11 @@ the default behavior).
2828
"""
2929
function MTKParameters(
3030
sys::AbstractSystem, p, u0 = Dict(); tofloat = false,
31-
t0 = nothing, substitution_limit = 1000, floatT = nothing)
31+
t0 = nothing, substitution_limit = 1000, floatT = nothing,
32+
container_type = Vector)
33+
if !(container_type <: AbstractArray)
34+
container_type = Array
35+
end
3236
ic = if has_index_cache(sys) && get_index_cache(sys) !== nothing
3337
get_index_cache(sys)
3438
else
@@ -133,18 +137,23 @@ function MTKParameters(
133137
end
134138
end
135139
end
136-
tunable_buffer = narrow_buffer_type(tunable_buffer)
140+
tunable_buffer = narrow_buffer_type(tunable_buffer; container_type)
137141
if isempty(tunable_buffer)
138142
tunable_buffer = SizedVector{0, Float64}()
139143
end
140-
initials_buffer = narrow_buffer_type(initials_buffer)
144+
initials_buffer = narrow_buffer_type(initials_buffer; container_type)
141145
if isempty(initials_buffer)
142146
initials_buffer = SizedVector{0, Float64}()
143147
end
144-
disc_buffer = narrow_buffer_type.(disc_buffer)
145-
const_buffer = narrow_buffer_type.(const_buffer)
148+
disc_buffer = narrow_buffer_type.(disc_buffer; container_type)
149+
const_buffer = narrow_buffer_type.(const_buffer; container_type)
146150
# Don't narrow nonnumeric types
147-
nonnumeric_buffer = nonnumeric_buffer
151+
if !isempty(nonnumeric_buffer)
152+
nonnumeric_buffer = map(nonnumeric_buffer) do buf
153+
SymbolicUtils.Code.create_array(
154+
container_type, nothing, Val(1), Val(length(buf)), buf...)
155+
end
156+
end
148157

149158
mtkps = MTKParameters{
150159
typeof(tunable_buffer), typeof(initials_buffer), typeof(disc_buffer),
@@ -160,21 +169,44 @@ function rebuild_with_caches(p::MTKParameters, cache_templates::BufferTemplate..
160169
@set p.caches = buffers
161170
end
162171

163-
function narrow_buffer_type(buffer::AbstractArray)
172+
function narrow_buffer_type(buffer::AbstractArray; container_type = typeof(buffer))
164173
type = Union{}
165174
for x in buffer
166175
type = promote_type(type, typeof(x))
167176
end
168-
return convert.(type, buffer)
177+
return SymbolicUtils.Code.create_array(
178+
container_type, type, Val(ndims(buffer)), Val(length(buffer)), buffer...)
169179
end
170180

171-
function narrow_buffer_type(buffer::AbstractArray{<:AbstractArray})
172-
buffer = narrow_buffer_type.(buffer)
181+
function narrow_buffer_type(
182+
buffer::AbstractArray{<:AbstractArray}; container_type = typeof(buffer))
183+
type = Union{}
184+
for arr in buffer
185+
for x in arr
186+
type = promote_type(type, typeof(x))
187+
end
188+
end
189+
buffer = map(buffer) do buf
190+
SymbolicUtils.Code.create_array(
191+
container_type, type, Val(ndims(buf)), Val(size(buf)), buf...)
192+
end
193+
return SymbolicUtils.Code.create_array(
194+
container_type, nothing, Val(ndims(buffer)), Val(size(buffer)), buffer...)
195+
end
196+
197+
function narrow_buffer_type(buffer::BlockedArray; container_type = typeof(parent(buffer)))
173198
type = Union{}
174199
for x in buffer
175-
type = promote_type(type, eltype(x))
200+
type = promote_type(type, typeof(x))
201+
end
202+
tmp = SymbolicUtils.Code.create_array(
203+
container_type, type, Val(ndims(buffer)), Val(size(buffer)), buffer...)
204+
blocks = ntuple(Val(ndims(buffer))) do i
205+
bsizes = blocksizes(buffer, i)
206+
SymbolicUtils.Code.create_array(
207+
container_type, Int, Val(1), Val(length(bsizes)), bsizes...)
176208
end
177-
return broadcast.(convert, type, buffer)
209+
return BlockedArray(tmp, blocks...)
178210
end
179211

180212
function buffer_to_arraypartition(buf)

src/systems/problem_utils.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1095,7 +1095,7 @@ function process_SciMLProblem(
10951095
end
10961096
evaluate_varmap!(op, ps; limit = substitution_limit)
10971097
if is_split(sys)
1098-
p = MTKParameters(sys, op; floatT = floatT)
1098+
p = MTKParameters(sys, op; floatT = floatT, container_type = pType)
10991099
else
11001100
p = better_varmap_to_vars(op, ps; tofloat, container_type = pType)
11011101
end

test/initial_values.jl

+26
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using ModelingToolkit
22
using ModelingToolkit: t_nounits as t, D_nounits as D, get_u0
33
using OrdinaryDiffEq
44
using DataInterpolations
5+
using StaticArrays
56
using SymbolicIndexingInterface: getu
67

78
@variables x(t)[1:3]=[1.0, 2.0, 3.0] y(t) z(t)[1:2]
@@ -281,3 +282,28 @@ end
281282
@test prob.p isa Vector{Float64}
282283
@test length(prob.p) == 5
283284
end
285+
286+
@testset "MTKParameters uses given `pType` for inner buffers" begin
287+
@parameters σ ρ β
288+
@variables x(t) y(t) z(t)
289+
290+
eqs = [D(D(x)) ~ σ * (y - x),
291+
D(y) ~ x *- z) - y,
292+
D(z) ~ x * y - β * z]
293+
294+
@mtkbuild sys = ODESystem(eqs, t)
295+
296+
u0 = SA[D(x) => 2.0f0,
297+
x => 1.0f0,
298+
y => 0.0f0,
299+
z => 0.0f0]
300+
301+
p = SA[σ => 28.0f0,
302+
ρ => 10.0f0,
303+
β => 8.0f0 / 3]
304+
305+
tspan = (0.0f0, 100.0f0)
306+
prob = ODEProblem(sys, u0, tspan, p)
307+
@test prob.p.tunable isa SVector
308+
@test prob.p.initials isa SVector
309+
end

test/mtkparameters.jl

+11-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
using ModelingToolkit
22
using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
3-
using SymbolicIndexingInterface
3+
using SymbolicIndexingInterface, StaticArrays
44
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants
5-
using BlockArrays: BlockedArray, Block
5+
using BlockArrays: BlockedArray, BlockedVector, Block
66
using OrdinaryDiffEq
77
using ForwardDiff
88
using JET
@@ -27,6 +27,15 @@ end
2727
@test getp(sys, a)(ps) == getp(sys, b)(ps) == getp(sys, c)(ps) == 0.0
2828
@test getp(sys, d)(ps) isa Int
2929

30+
@testset "`container_type`" begin
31+
ps2 = MTKParameters(sys, ivs; container_type = SVector)
32+
@test ps2.tunable isa SVector
33+
@test ps2.initials isa SVector
34+
@test ps2.discrete isa Tuple{<:BlockedVector{Float64, <:SVector}}
35+
@test ps2.constant isa Tuple{<:SVector, <:SVector, <:SVector{1, <:SMatrix}}
36+
@test ps2.nonnumeric isa Tuple{<:SVector}
37+
end
38+
3039
ivs[a] = 1.0
3140
ps = MTKParameters(sys, ivs)
3241
for (p, val) in ivs

0 commit comments

Comments
 (0)