@@ -162,11 +162,12 @@ object.
162
162
"""
163
163
function generate_custom_function (sys:: AbstractSystem , exprs, dvs = unknowns (sys),
164
164
ps = parameters (sys); wrap_code = nothing , postprocess_fbody = nothing , states = nothing ,
165
- expression = Val{true }, eval_expression = false , eval_module = @__MODULE__ , kwargs... )
165
+ expression = Val{true }, eval_expression = false , eval_module = @__MODULE__ ,
166
+ cachesyms:: Tuple = (), kwargs... )
166
167
if ! iscomplete (sys)
167
168
error (" A completed system is required. Call `complete` or `structural_simplify` on the system." )
168
169
end
169
- p = reorder_parameters (sys, unwrap .(ps))
170
+ p = ( reorder_parameters (sys, unwrap .(ps)) ... , cachesyms ... )
170
171
isscalar = ! (exprs isa AbstractArray)
171
172
if wrap_code === nothing
172
173
wrap_code = isscalar ? identity : (identity, identity)
@@ -187,7 +188,7 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
187
188
postprocess_fbody,
188
189
states,
189
190
wrap_code = wrap_code .∘ wrap_mtkparameters (sys, isscalar) .∘
190
- wrap_array_vars (sys, exprs; dvs) .∘
191
+ wrap_array_vars (sys, exprs; dvs, cachesyms ) .∘
191
192
wrap_parameter_dependencies (sys, isscalar),
192
193
expression = Val{true }
193
194
)
@@ -199,7 +200,7 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
199
200
postprocess_fbody,
200
201
states,
201
202
wrap_code = wrap_code .∘ wrap_mtkparameters (sys, isscalar) .∘
202
- wrap_array_vars (sys, exprs; dvs) .∘
203
+ wrap_array_vars (sys, exprs; dvs, cachesyms ) .∘
203
204
wrap_parameter_dependencies (sys, isscalar),
204
205
expression = Val{true }
205
206
)
@@ -231,133 +232,59 @@ end
231
232
232
233
function wrap_array_vars (
233
234
sys:: AbstractSystem , exprs; dvs = unknowns (sys), ps = parameters (sys),
234
- inputs = nothing , history = false )
235
+ inputs = nothing , history = false , cachesyms :: Tuple = () )
235
236
isscalar = ! (exprs isa AbstractArray)
236
- array_vars = Dict {Any, AbstractArray{Int}} ()
237
- if dvs != = nothing
238
- for (j, x) in enumerate (dvs)
239
- if iscall (x) && operation (x) == getindex
240
- arg = arguments (x)[1 ]
241
- inds = get! (() -> Int[], array_vars, arg)
242
- push! (inds, j)
243
- end
244
- end
245
- for (k, inds) in array_vars
246
- if inds == (inds′ = inds[1 ]: inds[end ])
247
- array_vars[k] = inds′
248
- end
249
- end
237
+ var_to_arridxs = Dict ()
250
238
251
- uind = 1
252
- else
239
+ if dvs === nothing
253
240
uind = 0
254
- end
255
- # values are (indexes, index of buffer, size of parameter)
256
- array_parameters = Dict{Any, Tuple{AbstractArray{Int}, Int, Tuple{Vararg{Int}}}}()
257
- # If for some reason different elements of an array parameter are in different buffers
258
- other_array_parameters = Dict {Any, Any} ()
259
-
260
- hasinputs = inputs != = nothing
261
- input_vars = Dict {Any, AbstractArray{Int}} ()
262
- if hasinputs
263
- for (j, x) in enumerate (inputs)
264
- if iscall (x) && operation (x) == getindex
265
- arg = arguments (x)[1 ]
266
- inds = get! (() -> Int[], input_vars, arg)
267
- push! (inds, j)
268
- end
269
- end
270
- for (k, inds) in input_vars
271
- if inds == (inds′ = inds[1 ]: inds[end ])
272
- input_vars[k] = inds′
273
- end
274
- end
275
- end
276
- if has_index_cache (sys)
277
- ic = get_index_cache (sys)
278
241
else
279
- ic = nothing
280
- end
281
- if ps isa Tuple && eltype (ps) <: AbstractArray
282
- ps = Iterators. flatten (ps)
283
- end
284
- for p in ps
285
- p = unwrap (p)
286
- if iscall (p) && operation (p) == getindex
287
- p = arguments (p)[1 ]
288
- end
289
- symtype (p) <: AbstractArray && Symbolics. shape (p) != Symbolics. Unknown () || continue
290
- scal = collect (p)
291
- # all scalarized variables are in `ps`
292
- any (isequal (p), ps) || all (x -> any (isequal (x), ps), scal) || continue
293
- (haskey (array_parameters, p) || haskey (other_array_parameters, p)) && continue
294
-
295
- idx = parameter_index (sys, p)
296
- idx isa Int && continue
297
- if idx isa ParameterIndex
298
- if idx. portion != SciMLStructures. Tunable ()
299
- continue
300
- end
301
- array_parameters[p] = (vec (idx. idx), 1 , size (idx. idx))
242
+ uind = 1
243
+ for (i, x) in enumerate (dvs)
244
+ iscall (x) && operation (x) == getindex || continue
245
+ arg = arguments (x)[1 ]
246
+ inds = get! (() -> [], var_to_arridxs, arg)
247
+ push! (inds, (uind, i))
248
+ end
249
+ end
250
+ p_start = uind + 1 + history
251
+ rps = (reorder_parameters (sys, ps)... , cachesyms... )
252
+ if inputs != = nothing
253
+ rps = (inputs, rps... )
254
+ end
255
+ for sym in reduce (vcat, rps; init = [])
256
+ iscall (sym) && operation (sym) == getindex || continue
257
+ arg = arguments (sym)[1 ]
258
+
259
+ bufferidx = findfirst (buf -> any (isequal (sym), buf), rps)
260
+ idxinbuffer = findfirst (isequal (sym), rps[bufferidx])
261
+ inds = get! (() -> [], var_to_arridxs, arg)
262
+ push! (inds, (p_start + bufferidx - 1 , idxinbuffer))
263
+ end
264
+
265
+ viewsyms = Dict ()
266
+ splitsyms = Dict ()
267
+ for (arrsym, idxs) in var_to_arridxs
268
+ length (idxs) == length (arrsym) || continue
269
+ # allequal(first, idxs) is a 1.11 feature
270
+ if allequal (Iterators. map (first, idxs))
271
+ viewsyms[arrsym] = (first (first (idxs)), reshape (last .(idxs), size (arrsym)))
302
272
else
303
- # idx === nothing
304
- idxs = map (Base. Fix1 (parameter_index, sys), scal)
305
- if first (idxs) isa ParameterIndex
306
- buffer_idxs = map (Base. Fix1 (iterated_buffer_index, ic), idxs)
307
- if allequal (buffer_idxs)
308
- buffer_idx = first (buffer_idxs)
309
- if first (idxs). portion == SciMLStructures. Tunable ()
310
- idxs = map (x -> x. idx, idxs)
311
- else
312
- idxs = map (x -> x. idx[end ], idxs)
313
- end
314
- else
315
- other_array_parameters[p] = scal
316
- continue
317
- end
318
- else
319
- buffer_idx = 1
320
- end
321
-
322
- sz = size (idxs)
323
- if vec (idxs) == idxs[begin ]: idxs[end ]
324
- idxs = idxs[begin ]: idxs[end ]
325
- elseif vec (idxs) == idxs[begin ]: - 1 : idxs[end ]
326
- idxs = idxs[begin ]: - 1 : idxs[end ]
327
- end
328
- idxs = vec (idxs)
329
- array_parameters[p] = (idxs, buffer_idx, sz)
273
+ splitsyms[arrsym] = reshape (idxs, size (arrsym))
330
274
end
331
275
end
332
-
333
- inputind = if history
334
- uind + 2
335
- else
336
- uind + 1
337
- end
338
- params_offset = if history && hasinputs
339
- uind + 2
340
- elseif history || hasinputs
341
- uind + 1
342
- else
343
- uind
344
- end
345
276
if isscalar
346
277
function (expr)
347
278
Func (
348
279
expr. args,
349
280
[],
350
281
Let (
351
282
vcat (
352
- [k ← :(view ($ (expr. args[uind]. name), $ v)) for (k, v) in array_vars],
353
- [k ← :(view ($ (expr. args[inputind]. name), $ v))
354
- for (k, v) in input_vars],
355
- [k ← :(reshape (
356
- view ($ (expr. args[params_offset + buffer_idx]. name), $ idxs),
357
- $ sz))
358
- for (k, (idxs, buffer_idx, sz)) in array_parameters],
359
- [k ← Code. MakeArray (v, symtype (k))
360
- for (k, v) in other_array_parameters]
283
+ [sym ← :(view ($ (expr. args[i]. name), $ idxs))
284
+ for (sym, (i, idxs)) in viewsyms],
285
+ [sym ←
286
+ MakeArray ([expr. args[bufi]. elems[vali] for (bufi, vali) in idxs],
287
+ expr. args[idxs[1 ][1 ]]) for (sym, idxs) in splitsyms]
361
288
),
362
289
expr. body,
363
290
false
@@ -371,15 +298,11 @@ function wrap_array_vars(
371
298
[],
372
299
Let (
373
300
vcat (
374
- [k ← :(view ($ (expr. args[uind]. name), $ v)) for (k, v) in array_vars],
375
- [k ← :(view ($ (expr. args[inputind]. name), $ v))
376
- for (k, v) in input_vars],
377
- [k ← :(reshape (
378
- view ($ (expr. args[params_offset + buffer_idx]. name), $ idxs),
379
- $ sz))
380
- for (k, (idxs, buffer_idx, sz)) in array_parameters],
381
- [k ← Code. MakeArray (v, symtype (k))
382
- for (k, v) in other_array_parameters]
301
+ [sym ← :(view ($ (expr. args[i]. name), $ idxs))
302
+ for (sym, (i, idxs)) in viewsyms],
303
+ [sym ←
304
+ MakeArray ([expr. args[bufi]. elems[vali] for (bufi, vali) in idxs],
305
+ expr. args[idxs[1 ][1 ]]) for (sym, idxs) in splitsyms]
383
306
),
384
307
expr. body,
385
308
false
@@ -392,17 +315,11 @@ function wrap_array_vars(
392
315
[],
393
316
Let (
394
317
vcat (
395
- [k ← :(view ($ (expr. args[uind + 1 ]. name), $ v))
396
- for (k, v) in array_vars],
397
- [k ← :(view ($ (expr. args[inputind + 1 ]. name), $ v))
398
- for (k, v) in input_vars],
399
- [k ← :(reshape (
400
- view ($ (expr. args[params_offset + buffer_idx + 1 ]. name),
401
- $ idxs),
402
- $ sz))
403
- for (k, (idxs, buffer_idx, sz)) in array_parameters],
404
- [k ← Code. MakeArray (v, symtype (k))
405
- for (k, v) in other_array_parameters]
318
+ [sym ← :(view ($ (expr. args[i + 1 ]. name), $ idxs))
319
+ for (sym, (i, idxs)) in viewsyms],
320
+ [sym ← MakeArray (
321
+ [expr. args[bufi + 1 ]. elems[vali] for (bufi, vali) in idxs],
322
+ expr. args[idxs[1 ][1 ] + 1 ]) for (sym, idxs) in splitsyms]
406
323
),
407
324
expr. body,
408
325
false
0 commit comments