Skip to content

Commit 0a70754

Browse files
chriselrodChrisRackauckas
authored andcommitted
Type stability
1 parent 7f3f58d commit 0a70754

File tree

2 files changed

+63
-26
lines changed

2 files changed

+63
-26
lines changed

src/datafit.jl

+47-24
Original file line numberDiff line numberDiff line change
@@ -206,24 +206,24 @@ function bayes_unpack_data(prob, p::AbstractVector{<:Pair})
206206
(pdist, IndexKeyMap(prob, pkeys))
207207
end
208208

209-
Turing.@model function bayesianODE(prob, t, pdist, pkeys, data, noise_prior)
209+
Turing.@model function bayesianODE(prob, alg, t, pdist, pkeys, data, datamap, noise_prior)
210210
σ ~ noise_prior
211211

212212
pprior ~ product_distribution(pdist)
213213

214214
prob = _remake(prob, (prob.tspan[1], t[end]), pkeys, pprior)
215-
sol = solve(prob, saveat = t)
215+
sol = solve(prob, alg, saveat = t)
216216
if !SciMLBase.successful_retcode(sol)
217217
Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf)
218218
return nothing
219219
end
220220
for i in eachindex(data)
221-
data[i].second ~ MvNormal(sol[data[i].first], σ^2 * I)
221+
data[i] ~ MvNormal(datamap(sol), σ^2 * I)
222222
end
223223
return nothing
224224
end
225225

226-
Turing.@model function bayesianODE(prob,
226+
Turing.@model function bayesianODE(prob, alg,
227227
pdist,
228228
pkeys,
229229
ts,
@@ -236,7 +236,7 @@ Turing.@model function bayesianODE(prob,
236236
pprior ~ product_distribution(pdist)
237237

238238
prob = _remake(prob, (prob.tspan[1], lastt), pkeys, pprior)
239-
sol = solve(prob)
239+
sol = solve(prob, alg)
240240
if !SciMLBase.successful_retcode(sol)
241241
Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf)
242242
return nothing
@@ -264,18 +264,19 @@ end
264264
Base.length(ws::WeightedSol) = length(first(ws.sols))
265265
Base.size(ws::WeightedSol) = (length(first(ws.sols)),)
266266
function Base.getindex(ws::WeightedSol{T}, i::Int) where {T}
267-
s = zero(T)
268-
w = zero(T)
269-
for j in eachindex(ws.weights)
267+
s::T = zero(T)
268+
w::T = zero(T)
269+
@inbounds for j in eachindex(ws.weights)
270270
w += ws.weights[j]
271271
s += ws.weights[j] * ws.sols[j][i]
272272
end
273273
return s + (one(T) - w) * ws.sols[end][i]
274274
end
275-
function WeightedSol(sols, select, weights)
276-
T = eltype(weights)
277-
s = map(Base.Fix2(getindex, select), sols)
278-
WeightedSol{T}(s, weights)
275+
function WeightedSol(sols, select, i::Int, weights)
276+
s = map(sols, select) do sol, sel
277+
@view(sol[sel.indices[i], :])
278+
end
279+
WeightedSol{eltype(weights)}(s, weights)
279280
end
280281
function bayes_unpack_data(probs, p::Tuple{Vararg{<:AbstractVector{<:Pair}}}, data)
281282
pdist, pkeys = bayes_unpack_data(probs, p)
@@ -305,43 +306,46 @@ function flatten(x::Tuple)
305306
reduce(vcat, x), Grouper(map(length, x))
306307
end
307308

308-
function getsols(probs, probspkeys, ppriors, t::AbstractArray)
309-
map(probs, probspkeys, ppriors) do prob, pkeys, pprior
309+
function getsols(probs, algs, probspkeys, ppriors, t::AbstractArray)
310+
map(probs, algs, probspkeys, ppriors) do prob, alg, pkeys, pprior
310311
newprob = _remake(prob, (prob.tspan[1], t[end]), pkeys, pprior)
311-
solve(newprob, saveat = t)
312+
solve(newprob, alg, saveat = t)
312313
end
313314
end
314-
function getsols(probs, probspkeys, ppriors, lastt::Number)
315-
map(probs, probspkeys, ppriors) do prob, pkeys, pprior
315+
function getsols(probs, algs, probspkeys, ppriors, lastt::Number)
316+
map(probs, algs, probspkeys, ppriors) do prob, alg, pkeys, pprior
316317
newprob = _remake(prob, (prob.tspan[1], lastt), pkeys, pprior)
317-
solve(newprob)
318+
solve(newprob, alg)
318319
end
319320
end
320321

321322
Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector},
323+
algs,
322324
t,
323325
pdist,
324326
grouppriorsfunc,
325327
probspkeys,
326328
data,
329+
datamaps,
327330
noise_prior)
328331
σ ~ noise_prior
329332
ppriors ~ product_distribution(pdist)
330333

331334
Nprobs = length(probs)
332335
Nprobs⁻¹ = inv(Nprobs)
333336
weights ~ MvNormal(Distributions.Fill(Nprobs⁻¹, Nprobs - 1), Nprobs⁻¹)
334-
sols = getsols(probs, probspkeys, grouppriorsfunc(ppriors), t)
337+
sols = getsols(probs, algs, probspkeys, grouppriorsfunc(ppriors), t)
335338
if !all(SciMLBase.successful_retcode, sols)
336339
Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf)
337340
return nothing
338341
end
339342
for i in eachindex(data)
340-
data[i].second ~ MvNormal(WeightedSol(sols, data[i].first, weights), σ^2 * I)
343+
data[i] ~ MvNormal(WeightedSol(sols, datamaps, i, weights), σ^2 * I)
341344
end
342345
return nothing
343346
end
344347
Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector},
348+
algs,
345349
pdist,
346350
grouppriorsfunc,
347351
probspkeys,
@@ -353,7 +357,7 @@ Turing.@model function ensemblebayesianODE(probs::Union{Tuple, AbstractVector},
353357
σ ~ noise_prior
354358
ppriors ~ product_distribution(pdist)
355359

356-
sols = getsols(probs, probspkeys, grouppriorsfunc(ppriors), lastt)
360+
sols = getsols(probs, algs, probspkeys, grouppriorsfunc(ppriors), lastt)
357361

358362
Nprobs = length(probs)
359363
Nprobs⁻¹ = inv(Nprobs)
@@ -411,7 +415,14 @@ function bayesian_datafit(prob,
411415
nchains = 4,
412416
niter = 1000)
413417
(pdist, pkeys) = bayes_unpack_data(prob, p)
414-
model = bayesianODE(prob, t, pdist, pkeys, data, noise_prior)
418+
model = bayesianODE(prob,
419+
first(default_algorithm(prob)),
420+
t,
421+
pdist,
422+
pkeys,
423+
last.(data),
424+
IndexKeyMap(prob, data),
425+
noise_prior)
415426
chain = Turing.sample(model,
416427
Turing.NUTS(0.65),
417428
mcmcensemble,
@@ -430,7 +441,15 @@ function bayesian_datafit(prob,
430441
nchains = 4,
431442
niter = 1_000)
432443
pdist, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data(prob, p, data)
433-
model = bayesianODE(prob, pdist, pkeys, ts, lastt, timeseries, datakeys, noise_prior)
444+
model = bayesianODE(prob,
445+
first(default_algorithm(prob)),
446+
pdist,
447+
pkeys,
448+
ts,
449+
lastt,
450+
timeseries,
451+
datakeys,
452+
noise_prior)
434453
chain = Turing.sample(model,
435454
Turing.NUTS(0.65),
436455
mcmcensemble,
@@ -451,7 +470,10 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector},
451470
(pdist_, pkeys) = bayes_unpack_data(p)
452471
pdist, grouppriorsfunc = flatten(pdist_)
453472

454-
model = ensemblebayesianODE(probs, t, pdist, grouppriorsfunc, pkeys, data, noise_prior)
473+
model = ensemblebayesianODE(probs,
474+
map(first default_algorithm, probs),
475+
t, pdist, grouppriorsfunc, pkeys, last.(data),
476+
map(Base.Fix2(IndexKeyMap, data), probs), noise_prior)
455477
chain = Turing.sample(model,
456478
Turing.NUTS(0.65),
457479
mcmcensemble,
@@ -472,6 +494,7 @@ function bayesian_datafit(probs::Union{Tuple, AbstractVector},
472494
pdist_, pkeys, ts, lastt, timeseries, datakeys = bayes_unpack_data(p, data)
473495
pdist, grouppriorsfunc = flatten(pdist_)
474496
model = ensemblebayesianODE(probs,
497+
map(first default_algorithm, probs),
475498
pdist,
476499
grouppriorsfunc,
477500
pkeys,

src/keyindexmap.jl

+16-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ struct IndexKeyMap
33
indices::Vector{Int}
44
end
55

6+
# probs support
67
function IndexKeyMap(prob, keys)
78
params = ModelingToolkit.parameters(prob.f.sys)
89
indices = Vector{Int}(undef, length(keys))
@@ -12,7 +13,8 @@ function IndexKeyMap(prob, keys)
1213
return IndexKeyMap(indices)
1314
end
1415

15-
Base.@propagate_inbounds function (ikm::IndexKeyMap)(prob, v::AbstractVector)
16+
Base.@propagate_inbounds function (ikm::IndexKeyMap)(prob::SciMLBase.AbstractDEProblem,
17+
v::AbstractVector)
1618
@boundscheck checkbounds(v, length(ikm.indices))
1719
def = prob.p
1820
ret = Vector{Base.promote_eltype(v, def)}(undef, length(def))
@@ -22,8 +24,20 @@ Base.@propagate_inbounds function (ikm::IndexKeyMap)(prob, v::AbstractVector)
2224
end
2325
return ret
2426
end
25-
2627
function _remake(prob, tspan, ikm::IndexKeyMap, pprior)
2728
p = ikm(prob, pprior)
2829
remake(prob; tspan, p)
2930
end
31+
32+
# data support
33+
function IndexKeyMap(prob, data::AbstractVector{<:Pair})
34+
states = ModelingToolkit.states(prob.f.sys)
35+
indices = Vector{Int}(undef, length(data))
36+
for i in eachindex(data)
37+
indices[i] = findfirst(Base.Fix1(isequal, data[i].first), states)
38+
end
39+
return IndexKeyMap(indices)
40+
end
41+
function (ikm::IndexKeyMap)(sol::SciMLBase.AbstractTimeseriesSolution)
42+
(@view(sol[i, :]) for i in ikm.indices)
43+
end

0 commit comments

Comments
 (0)