@@ -28,7 +28,11 @@ the default behavior).
28
28
"""
29
29
function MTKParameters (
30
30
sys:: AbstractSystem , p, u0 = Dict (); tofloat = false ,
31
- t0 = nothing , substitution_limit = 1000 , floatT = nothing )
31
+ t0 = nothing , substitution_limit = 1000 , floatT = nothing ,
32
+ container_type = Vector)
33
+ if ! (container_type <: AbstractArray )
34
+ container_type = Array
35
+ end
32
36
ic = if has_index_cache (sys) && get_index_cache (sys) != = nothing
33
37
get_index_cache (sys)
34
38
else
@@ -133,18 +137,23 @@ function MTKParameters(
133
137
end
134
138
end
135
139
end
136
- tunable_buffer = narrow_buffer_type (tunable_buffer)
140
+ tunable_buffer = narrow_buffer_type (tunable_buffer; container_type )
137
141
if isempty (tunable_buffer)
138
142
tunable_buffer = SizedVector {0, Float64} ()
139
143
end
140
- initials_buffer = narrow_buffer_type (initials_buffer)
144
+ initials_buffer = narrow_buffer_type (initials_buffer; container_type )
141
145
if isempty (initials_buffer)
142
146
initials_buffer = SizedVector {0, Float64} ()
143
147
end
144
- disc_buffer = narrow_buffer_type .(disc_buffer)
145
- const_buffer = narrow_buffer_type .(const_buffer)
148
+ disc_buffer = narrow_buffer_type .(disc_buffer; container_type )
149
+ const_buffer = narrow_buffer_type .(const_buffer; container_type )
146
150
# Don't narrow nonnumeric types
147
- nonnumeric_buffer = nonnumeric_buffer
151
+ if ! isempty (nonnumeric_buffer)
152
+ nonnumeric_buffer = map (nonnumeric_buffer) do buf
153
+ SymbolicUtils. Code. create_array (
154
+ container_type, nothing , Val (1 ), Val (length (buf)), buf... )
155
+ end
156
+ end
148
157
149
158
mtkps = MTKParameters{
150
159
typeof (tunable_buffer), typeof (initials_buffer), typeof (disc_buffer),
@@ -160,21 +169,44 @@ function rebuild_with_caches(p::MTKParameters, cache_templates::BufferTemplate..
160
169
@set p. caches = buffers
161
170
end
162
171
163
- function narrow_buffer_type (buffer:: AbstractArray )
172
+ function narrow_buffer_type (buffer:: AbstractArray ; container_type = typeof (buffer) )
164
173
type = Union{}
165
174
for x in buffer
166
175
type = promote_type (type, typeof (x))
167
176
end
168
- return convert .(type, buffer)
177
+ return SymbolicUtils. Code. create_array (
178
+ container_type, type, Val (ndims (buffer)), Val (length (buffer)), buffer... )
169
179
end
170
180
171
- function narrow_buffer_type (buffer:: AbstractArray{<:AbstractArray} )
172
- buffer = narrow_buffer_type .(buffer)
181
+ function narrow_buffer_type (
182
+ buffer:: AbstractArray{<:AbstractArray} ; container_type = typeof (buffer))
183
+ type = Union{}
184
+ for arr in buffer
185
+ for x in arr
186
+ type = promote_type (type, typeof (x))
187
+ end
188
+ end
189
+ buffer = map (buffer) do buf
190
+ SymbolicUtils. Code. create_array (
191
+ container_type, type, Val (ndims (buf)), Val (size (buf)), buf... )
192
+ end
193
+ return SymbolicUtils. Code. create_array (
194
+ container_type, nothing , Val (ndims (buffer)), Val (size (buffer)), buffer... )
195
+ end
196
+
197
+ function narrow_buffer_type (buffer:: BlockedArray ; container_type = typeof (parent (buffer)))
173
198
type = Union{}
174
199
for x in buffer
175
- type = promote_type (type, eltype (x))
200
+ type = promote_type (type, typeof (x))
201
+ end
202
+ tmp = SymbolicUtils. Code. create_array (
203
+ container_type, type, Val (ndims (buffer)), Val (size (buffer)), buffer... )
204
+ blocks = ntuple (Val (ndims (buffer))) do i
205
+ bsizes = blocksizes (buffer, i)
206
+ SymbolicUtils. Code. create_array (
207
+ container_type, Int, Val (1 ), Val (length (bsizes)), bsizes... )
176
208
end
177
- return broadcast .(convert, type, buffer )
209
+ return BlockedArray (tmp, blocks ... )
178
210
end
179
211
180
212
function buffer_to_arraypartition (buf)
0 commit comments