@@ -206,24 +206,24 @@ function bayes_unpack_data(prob, p::AbstractVector{<:Pair})
206
206
(pdist, IndexKeyMap (prob, pkeys))
207
207
end
208
208
209
- Turing. @model function bayesianODE (prob, t, pdist, pkeys, data, noise_prior)
209
+ Turing. @model function bayesianODE (prob, alg, t, pdist, pkeys, data, datamap , noise_prior)
210
210
σ ~ noise_prior
211
211
212
212
pprior ~ product_distribution (pdist)
213
213
214
214
prob = _remake (prob, (prob. tspan[1 ], t[end ]), pkeys, pprior)
215
- sol = solve (prob, saveat = t)
215
+ sol = solve (prob, alg, saveat = t)
216
216
if ! SciMLBase. successful_retcode (sol)
217
217
Turing. DynamicPPL. acclogp!! (__varinfo__, - Inf )
218
218
return nothing
219
219
end
220
220
for i in eachindex (data)
221
- data[i]. second ~ MvNormal (sol[data[i] . first] , σ^ 2 * I)
221
+ data[i] ~ MvNormal (datamap ( sol) , σ^ 2 * I)
222
222
end
223
223
return nothing
224
224
end
225
225
226
- Turing. @model function bayesianODE (prob,
226
+ Turing. @model function bayesianODE (prob, alg,
227
227
pdist,
228
228
pkeys,
229
229
ts,
@@ -236,7 +236,7 @@ Turing.@model function bayesianODE(prob,
236
236
pprior ~ product_distribution (pdist)
237
237
238
238
prob = _remake (prob, (prob. tspan[1 ], lastt), pkeys, pprior)
239
- sol = solve (prob)
239
+ sol = solve (prob, alg )
240
240
if ! SciMLBase. successful_retcode (sol)
241
241
Turing. DynamicPPL. acclogp!! (__varinfo__, - Inf )
242
242
return nothing
@@ -264,18 +264,19 @@ end
264
264
Base. length (ws:: WeightedSol ) = length (first (ws. sols))
265
265
Base. size (ws:: WeightedSol ) = (length (first (ws. sols)),)
266
266
function Base. getindex (ws:: WeightedSol{T} , i:: Int ) where {T}
267
- s = zero (T)
268
- w = zero (T)
269
- for j in eachindex (ws. weights)
267
+ s:: T = zero (T)
268
+ w:: T = zero (T)
269
+ @inbounds for j in eachindex (ws. weights)
270
270
w += ws. weights[j]
271
271
s += ws. weights[j] * ws. sols[j][i]
272
272
end
273
273
return s + (one (T) - w) * ws. sols[end ][i]
274
274
end
275
- function WeightedSol (sols, select, weights)
276
- T = eltype (weights)
277
- s = map (Base. Fix2 (getindex, select), sols)
278
- WeightedSol {T} (s, weights)
275
+ function WeightedSol (sols, select, i:: Int , weights)
276
+ s = map (sols, select) do sol, sel
277
+ @view (sol[sel. indices[i], :])
278
+ end
279
+ WeightedSol {eltype(weights)} (s, weights)
279
280
end
280
281
function bayes_unpack_data (probs, p:: Tuple{Vararg{<:AbstractVector{<:Pair}}} , data)
281
282
pdist, pkeys = bayes_unpack_data (probs, p)
@@ -305,43 +306,46 @@ function flatten(x::Tuple)
305
306
reduce (vcat, x), Grouper (map (length, x))
306
307
end
307
308
308
- function getsols (probs, probspkeys, ppriors, t:: AbstractArray )
309
- map (probs, probspkeys, ppriors) do prob, pkeys, pprior
309
+ function getsols (probs, algs, probspkeys, ppriors, t:: AbstractArray )
310
+ map (probs, algs, probspkeys, ppriors) do prob, alg , pkeys, pprior
310
311
newprob = _remake (prob, (prob. tspan[1 ], t[end ]), pkeys, pprior)
311
- solve (newprob, saveat = t)
312
+ solve (newprob, alg, saveat = t)
312
313
end
313
314
end
314
- function getsols (probs, probspkeys, ppriors, lastt:: Number )
315
- map (probs, probspkeys, ppriors) do prob, pkeys, pprior
315
+ function getsols (probs, algs, probspkeys, ppriors, lastt:: Number )
316
+ map (probs, algs, probspkeys, ppriors) do prob, alg , pkeys, pprior
316
317
newprob = _remake (prob, (prob. tspan[1 ], lastt), pkeys, pprior)
317
- solve (newprob)
318
+ solve (newprob, alg )
318
319
end
319
320
end
320
321
321
322
Turing. @model function ensemblebayesianODE (probs:: Union{Tuple, AbstractVector} ,
323
+ algs,
322
324
t,
323
325
pdist,
324
326
grouppriorsfunc,
325
327
probspkeys,
326
328
data,
329
+ datamaps,
327
330
noise_prior)
328
331
σ ~ noise_prior
329
332
ppriors ~ product_distribution (pdist)
330
333
331
334
Nprobs = length (probs)
332
335
Nprobs⁻¹ = inv (Nprobs)
333
336
weights ~ MvNormal (Distributions. Fill (Nprobs⁻¹, Nprobs - 1 ), Nprobs⁻¹)
334
- sols = getsols (probs, probspkeys, grouppriorsfunc (ppriors), t)
337
+ sols = getsols (probs, algs, probspkeys, grouppriorsfunc (ppriors), t)
335
338
if ! all (SciMLBase. successful_retcode, sols)
336
339
Turing. DynamicPPL. acclogp!! (__varinfo__, - Inf )
337
340
return nothing
338
341
end
339
342
for i in eachindex (data)
340
- data[i]. second ~ MvNormal (WeightedSol (sols, data[i] . first , weights), σ^ 2 * I)
343
+ data[i] ~ MvNormal (WeightedSol (sols, datamaps, i , weights), σ^ 2 * I)
341
344
end
342
345
return nothing
343
346
end
344
347
Turing. @model function ensemblebayesianODE (probs:: Union{Tuple, AbstractVector} ,
348
+ algs,
345
349
pdist,
346
350
grouppriorsfunc,
347
351
probspkeys,
@@ -353,7 +357,7 @@ Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector},
353
357
σ ~ noise_prior
354
358
ppriors ~ product_distribution (pdist)
355
359
356
- sols = getsols (probs, probspkeys, grouppriorsfunc (ppriors), lastt)
360
+ sols = getsols (probs, algs, probspkeys, grouppriorsfunc (ppriors), lastt)
357
361
358
362
Nprobs = length (probs)
359
363
Nprobs⁻¹ = inv (Nprobs)
@@ -411,7 +415,14 @@ function bayesian_datafit(prob,
411
415
nchains = 4 ,
412
416
niter = 1000 )
413
417
(pdist, pkeys) = bayes_unpack_data (prob, p)
414
- model = bayesianODE (prob, t, pdist, pkeys, data, noise_prior)
418
+ model = bayesianODE (prob,
419
+ first (default_algorithm (prob)),
420
+ t,
421
+ pdist,
422
+ pkeys,
423
+ last .(data),
424
+ IndexKeyMap (prob, data),
425
+ noise_prior)
415
426
chain = Turing. sample (model,
416
427
Turing. NUTS (0.65 ),
417
428
mcmcensemble,
@@ -430,7 +441,15 @@ function bayesian_datafit(prob,
430
441
nchains = 4 ,
431
442
niter = 1_000 )
432
443
pdist, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data (prob, p, data)
433
- model = bayesianODE (prob, pdist, pkeys, ts, lastt, timeseries, datakeys, noise_prior)
444
+ model = bayesianODE (prob,
445
+ first (default_algorithm (prob)),
446
+ pdist,
447
+ pkeys,
448
+ ts,
449
+ lastt,
450
+ timeseries,
451
+ datakeys,
452
+ noise_prior)
434
453
chain = Turing. sample (model,
435
454
Turing. NUTS (0.65 ),
436
455
mcmcensemble,
@@ -451,7 +470,10 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector},
451
470
(pdist_, pkeys) = bayes_unpack_data (p)
452
471
pdist, grouppriorsfunc = flatten (pdist_)
453
472
454
- model = ensemblebayesianODE (probs, t, pdist, grouppriorsfunc, pkeys, data, noise_prior)
473
+ model = ensemblebayesianODE (probs,
474
+ map (first ∘ default_algorithm, probs),
475
+ t, pdist, grouppriorsfunc, pkeys, last .(data),
476
+ map (Base. Fix2 (IndexKeyMap, data), probs), noise_prior)
455
477
chain = Turing. sample (model,
456
478
Turing. NUTS (0.65 ),
457
479
mcmcensemble,
@@ -472,6 +494,7 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector},
472
494
pdist_, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data (p, data)
473
495
pdist, grouppriorsfunc = flatten (pdist_)
474
496
model = ensemblebayesianODE (probs,
497
+ map (first ∘ default_algorithm, probs),
475
498
pdist,
476
499
grouppriorsfunc,
477
500
pkeys,
0 commit comments