Skip to content

Commit 08226a4

Browse files
define mapobs behavior for vector of indexes (#147)
* define mapobs behavior for vector of indexes * add batched keyword * type * test colon
1 parent 134ee91 commit 08226a4

File tree

2 files changed

+109
-14
lines changed

2 files changed

+109
-14
lines changed

src/obstransform.jl

+47-14
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,66 @@
11

22
# mapobs
33

4-
struct MappedData{F,D} <: AbstractDataContainer
4+
# Lazy wrapper pairing a function `f` with a data container `data`.
# The first type parameter `batched` is a value (the `getindex` methods in this file are
# defined for `:auto`, `:never` and `:always`) used purely for dispatch; it has no field.
struct MappedData{batched, F, D} <: AbstractDataContainer
55
f::F
66
data::D
77
end
88

9-
Base.show(io::IO, data::MappedData) = print(io, "mapobs($(data.f), $(summary(data.data)))")
10-
Base.show(io::IO, data::MappedData{F,<:AbstractArray}) where {F} =
11-
print(io, "mapobs($(data.f), $(ShowLimit(data.data, limit=80)))")
9+
# Print `data` on one line in the form `mapobs(f, data; batched=:mode)`.
# Both `f` and the wrapped container are printed through an `IOContext` with
# `:compact => true`; `batched` is read from the type parameter, not from a field.
function Base.show(io::IO, data::MappedData{batched}) where {batched}
10+
print(io, "mapobs(")
11+
print(IOContext(io, :compact=>true), data.f)
12+
print(io, ", ")
13+
print(IOContext(io, :compact=>true), data.data)
14+
print(io, "; batched=:$(batched))")
15+
end
16+
1217
# The mapped container has exactly as many observations as the container it wraps.
function Base.length(data::MappedData)
    return numobs(data.data)
end
13-
Base.getindex(data::MappedData, idx::Int) = data.f(getobs(data.data, idx))
14-
Base.getindex(data::MappedData, idxs::AbstractVector) = data.f.(getobs(data.data, idxs))
18+
# `data[:]` asks for every observation: forward to the vector-index method with the
# full index range.
function Base.getindex(data::MappedData, ::Colon)
    allidxs = 1:length(data)
    return data[allidxs]
end
19+
20+
# `:auto`, single index: fetch one observation and hand it to `f` as-is.
function Base.getindex(data::MappedData{:auto}, idx::Int)
    obs = getobs(data.data, idx)
    return data.f(obs)
end
21+
# `:auto`, vector of indices: fetch the observations in one `getobs` call and let `f`
# deal with the batch itself.
function Base.getindex(data::MappedData{:auto}, idxs::AbstractVector)
    batch = getobs(data.data, idxs)
    return data.f(batch)
end
22+
23+
# `:never`, single index: apply `f` to the one observation at `idx`.
function Base.getindex(data::MappedData{:never}, idx::Int)
    obs = getobs(data.data, idx)
    return data.f(obs)
end
24+
# `:never`, vector of indices: `f` must only ever see single observations, so fetch and
# transform them one index at a time and collect the individual results.
function Base.getindex(data::MappedData{:never}, idxs::AbstractVector)
    return collect(data.f(getobs(data.data, i)) for i in idxs)
end
25+
26+
# `:always`, single index: `f` must only ever see batches, so fetch a one-element batch,
# transform it, and then take observation 1 of the transformed result.
function Base.getindex(data::MappedData{:always}, idx::Int)
    batch = getobs(data.data, [idx])
    return getobs(data.f(batch), 1)
end
27+
# `:always`, vector of indices: fetch the whole batch and apply `f` to it once.
function Base.getindex(data::MappedData{:always}, idxs::AbstractVector)
    batch = getobs(data.data, idxs)
    return data.f(batch)
end
1528

1629

1730
"""
18-
mapobs(f, data)
31+
mapobs(f, data; batched=:auto)
1932
2033
Lazily map `f` over the observations in a data container `data`.
34+
Returns a new data container `mdata` that can be indexed and has a length.
35+
Indexing triggers the transformation `f`.
36+
37+
The `batched` keyword argument controls the behavior of `mdata[idx]` and `mdata[idxs]`
38+
where `idx` is an integer and `idxs` is a vector of integers:
39+
- `batched=:auto` (default). Let `f` handle the two cases.
40+
Call `f(getobs(data, idx))` and `f(getobs(data, idxs))`.
41+
- `batched=:never`. `f` is always called on a single observation.
42+
Call `f(getobs(data, idx))` and `[f(getobs(data, idx)) for idx in idxs]`.
43+
- `batched=:always`. `f` is always called on a batch of observations.
44+
Call `getobs(f(getobs(data, [idx])), 1)` and `f(getobs(data, idxs))`.
45+
46+
# Examples
47+
2148
```julia
22-
data = 1:10
23-
getobs(data, 8) == 8
24-
mdata = mapobs(-, data)
25-
getobs(mdata, 8) == -8
49+
julia> data = (a=[1,2,3], b=[1,2,3]);
50+
51+
julia> mdata = mapobs(data) do x
52+
(c = x.a .+ x.b, d = x.a .- x.b)
53+
end
54+
mapobs(#25, (a = [1, 2, 3], b = [1, 2, 3]); batched=:auto)
55+
56+
julia> mdata[1]
57+
(c = 2, d = 0)
58+
59+
julia> mdata[1:2]
60+
(c = [2, 4], d = [0, 0])
2661
```
2762
"""
28-
mapobs(f, data) = MappedData(f, data)
29-
mapobs(f::typeof(identity), data) = data
30-
63+
function mapobs(f::F, data::D; batched=:auto) where {F,D}
    # `batched` is lifted into a type parameter and the `getindex` methods exist only for
    # these three symbols. Reject anything else here, with a clear message, instead of
    # building an object that fails with a `MethodError` the first time it is indexed.
    batched in (:auto, :never, :always) || throw(ArgumentError(
        "batched must be :auto, :never or :always, got $(repr(batched))"))
    return MappedData{batched, F, D}(f, data)
end
3164

3265
"""
3366
mapobs(fs, data)

test/obstransform.jl

+62
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,74 @@
33
mdata = mapobs(-, data)
44
@test getobs(mdata, 8) == -8
55

6+
@test length(mdata) == 10
7+
@test numobs(mdata) == 10
8+
69
mdata2 = mapobs((-, x -> 2x), data)
710
@test getobs(mdata2, 8) == (-8, 16)
811

912
nameddata = mapobs((x = sqrt, y = log), data)
1013
@test getobs(nameddata, 10) == (x = sqrt(10), y = log(10))
1114
@test getobs(nameddata.x, 10) == sqrt(10)
15+
16+
# colon
17+
@test mapobs(x -> 2x, [1:10;])[:] == [2:2:20;]
18+
19+
# With `batched=:auto`, `f` sees a single observation for an integer index and a batch
# for a vector of indices. The first half passes the keyword explicitly for both index
# kinds; the second half repeats the checks relying on the default.
@testset "batched = :auto" begin
20+
data = (a = [1:10;],)
21+
22+
m = mapobs(data; batched=:auto) do x
23+
@test x.a isa Int
24+
return (; c = 2 .* x.a)
25+
end[1]
26+
@test m == (; c = 2)
27+
m = mapobs(data; batched=:auto) do x
28+
@test x.a isa Vector{Int}
29+
return (; c = 2 .* x.a)
30+
end[1:2]
31+
@test m == (; c = [2, 4])
32+
33+
# check that :auto is the default
34+
m = mapobs(data) do x
35+
@test x.a isa Int
36+
return (; c = 2 .* x.a)
37+
end[1]
38+
@test m == (; c = 2)
39+
m = mapobs(data) do x
40+
@test x.a isa Vector{Int}
41+
return (; c = 2 .* x.a)
42+
end[1:2]
43+
@test m == (; c = [2, 4])
44+
end
45+
46+
# With `batched=:always`, `f` receives a vector-valued `x.a` for both an integer index
# and a range; the integer-index result is still a single observation.
@testset "batched = :always" begin
47+
data = (; a = [1:10;],)
48+
49+
m = mapobs(data; batched=:always) do x
50+
@test x.a isa Vector{Int}
51+
return (; c = 2 .* x.a)
52+
end[1]
53+
@test m == (; c = 2)
54+
m = mapobs(data; batched=:always) do x
55+
@test x.a isa Vector{Int}
56+
return (; c = 2 .* x.a)
57+
end[1:2]
58+
@test m == (; c = [2, 4])
59+
end
60+
61+
# With `batched=:never`, `f` receives a scalar `x.a` for both an integer index and a
# range; range indexing yields a vector with one transformed observation per index.
@testset "batched = :never" begin
62+
data = (; a = [1:10;],)
63+
m = mapobs(data; batched=:never) do x
64+
@test x.a isa Int
65+
return (; c = 2 .* x.a)
66+
end[1]
67+
@test m == (; c = 2)
68+
m = mapobs(data; batched=:never) do x
69+
@test x.a isa Int
70+
return (; c = 2 .* x.a)
71+
end[1:2]
72+
@test m == [(; c = 2), (; c = 4)]
73+
end
1274
end
1375

1476
@testset "filterobs" begin

0 commit comments

Comments
 (0)