
Commit 880b2f2

minor fix

1 parent 938522d · commit 880b2f2

26 files changed: +108 -98 lines

Dockerfile (+6)

@@ -0,0 +1,6 @@
+FROM julia:1.1
+
+ADD . /RLIntro
+WORKDIR /RLIntro
+RUN ["julia", "-e", "using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.instantiate(); pkg\"precompile\""]
+CMD ["julia"]
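For readers unfamiliar with the quoted one-liner in the RUN step, it is roughly equivalent to the following Julia session executed from the repository root (a sketch; `pkg"precompile"` is the Pkg string macro that runs the `precompile` REPL command):

```julia
# Rough equivalent of the Docker RUN step, executed from the repository root.
using Pkg

Pkg.develop(PackageSpec(path = pwd()))  # track the local checkout as a development dependency
Pkg.instantiate()                       # install the dependencies listed in Project.toml
pkg"precompile"                         # precompile the dependencies ahead of time
```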

Project.toml (+1 -2)

@@ -7,15 +7,14 @@ version = "0.1.0"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
 Ju = "449ae9ca-b987-11e8-3919-0764a06dfe61"
-LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
-StatPlots = "60ddc479-9b66-56df-82fc-76a74619b69c"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"

 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

README.md (-2)

@@ -33,8 +33,6 @@ julia> @show [f for f in names(RLIntro) if startswith(string(f), "fig")]; # lis
 julia> fig_2_2() # reproduce figure_2_2
 ```

-**Notice** that for some figures you may need to install *pdflatex*.
-
 ## Develop

 If you would like to make some improvements, I'd suggest the following workflow:

src/RLIntro.jl (+10)

@@ -1,5 +1,7 @@
 module RLIntro

+export plot_all
+
 include("environments/environments.jl")

 using Reexport
@@ -17,4 +19,12 @@ include("chapter11/chapter11.jl")
 include("chapter12/chapter12.jl")
 include("chapter13/chapter13.jl")

+function plot_all(fig_dir=".")
+    for f in names(RLIntro)
+        if startswith(string(f), "fig")
+            @eval $f()
+        end
+    end
+end
+
 end # module
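A minimal usage sketch of the newly exported `plot_all` (an assumption about intended use, not part of the diff itself); with the `figpath` helpers removed in the changes below, each figure function now writes its PNG into the current working directory:

```julia
# Minimal usage sketch (assumes RLIntro is dev'ed and instantiated in the active environment).
using RLIntro

# Calls every exported `fig_*` function in turn; each one saves its figure,
# e.g. "figure_2_2.png", into the current working directory.
plot_all()
```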

src/chapter02/ten_armed_testbed.jl (+22 -23)

@@ -1,11 +1,10 @@
 using Ju
 using ..MultiArmBandits
 using Statistics
-using LaTeXStrings
 using Plots
 gr()

-figpath(f) = "docs/src/assets/figures/figure_$f.png"
+

 function collect_best_actions()
     isbest = Vector{Bool}()
@@ -26,55 +25,55 @@ end

 ##############################

-function fig_2_1()
-    env = MultiArmBanditsEnv()
-    f = render(env)
-    savefig(f, figpath("2_1"))
-    f
-end
+# function fig_2_1()
+#     env = MultiArmBanditsEnv()
+#     f = render(env)
+#     savefig(f, "figure_2_1.png")
+#     f
+# end


 function fig_2_2()
     learner(ϵ) = QLearner(TabularQ(1, 10), EpsilonGreedySelector(ϵ), 0., cached_inverse_decay())
     p = plot(layout=(2, 1), dpi=200)
     for ϵ in [0.1, 0.01, 0.0]
         stats = [bandit_testbed(learner(ϵ)) for _ in 1:2000]
-        plot!(p, mean(x[1] for x in stats), subplot=1, legend=:bottomright, label=latexstring("\\epsilon="))
-        plot!(p, mean(x[2] for x in stats), subplot=2, legend=:bottomright, label=latexstring("\\epsilon="))
+        plot!(p, mean(x[1] for x in stats), subplot=1, legend=:bottomright, label="epsilon=")
+        plot!(p, mean(x[2] for x in stats), subplot=2, legend=:bottomright, label="epsilon=")
     end
-    savefig(p, figpath("2_2"))
+    savefig(p, "figure_2_2.png")
     p
 end

 function fig_2_3()
     learner1() = QLearner(TabularQ(1, 10, 5.), EpsilonGreedySelector(0.0), 0., 0.1)
     learner2() = QLearner(TabularQ(1, 10), EpsilonGreedySelector(0.1), 0., 0.1)
     p = plot(legend=:bottomright, dpi=200)
-    plot!(p, mean(bandit_testbed(learner1())[2] for _ in 1:2000), label=latexstring("Q_1=5, \\epsilon=0."))
-    plot!(p, mean(bandit_testbed(learner2())[2] for _ in 1:2000), label=latexstring("Q_1=0, \\epsilon=0.1"))
-    savefig(p, figpath("2_3"))
+    plot!(p, mean(bandit_testbed(learner1())[2] for _ in 1:2000), label="Q_1=5, epsilon=0.")
+    plot!(p, mean(bandit_testbed(learner2())[2] for _ in 1:2000), label="Q_1=0, epsilon=0.1")
+    savefig(p, "figure_2_3.png")
     p
 end

 function fig_2_4()
     learner1() = QLearner(TabularQ(1, 10), UpperConfidenceBound(10), 0., 0.1)
     learner2() = QLearner(TabularQ(1, 10), EpsilonGreedySelector(0.1), 0., 0.1)
     p = plot(legend=:bottomright, dpi=200)
-    plot!(p, mean(bandit_testbed(learner1())[1] for _ in 1:2000), label=latexstring("UpperConfidenceBound, c=2"))
-    plot!(p, mean(bandit_testbed(learner2())[1] for _ in 1:2000), label=latexstring("\\epsilon-greedy, \\epsilon=0.1"))
-    savefig(p, figpath("2_4"))
+    plot!(p, mean(bandit_testbed(learner1())[1] for _ in 1:2000), label="UpperConfidenceBound, c=2")
+    plot!(p, mean(bandit_testbed(learner2())[1] for _ in 1:2000), label="epsilon-greedy, epsilon=0.1")
+    savefig(p, "figure_2_4.png")
     p
 end

 function fig_2_5()
     learner(alpha, baseline) = GradientBanditLearner(TabularQ(1, 10), WeightedSample(), alpha, baseline)
     truevalue = 4.0
     p = plot(legend=:bottomright, dpi=200)
-    plot!(p, mean(bandit_testbed(learner(0.1, sample_avg()), truevalue)[2] for _ in 1:2000), label=latexstring("\\alpha = 0.1, with baseline"))
-    plot!(p, mean(bandit_testbed(learner(0.4, sample_avg()), truevalue)[2] for _ in 1:2000), label=latexstring("\\alpha = 0.4, with baseline"))
-    plot!(p, mean(bandit_testbed(learner(0.1, 0.), truevalue)[2] for _ in 1:2000), label=latexstring("\\alpha = 0.1, without baseline"))
-    plot!(p, mean(bandit_testbed(learner(0.4, 0.), truevalue)[2] for _ in 1:2000), label=latexstring("\\alpha = 0.4, without baseline"))
-    savefig(p, figpath("2_5"))
+    plot!(p, mean(bandit_testbed(learner(0.1, sample_avg()), truevalue)[2] for _ in 1:2000), label="alpha = 0.1, with baseline")
+    plot!(p, mean(bandit_testbed(learner(0.4, sample_avg()), truevalue)[2] for _ in 1:2000), label="alpha = 0.4, with baseline")
+    plot!(p, mean(bandit_testbed(learner(0.1, 0.), truevalue)[2] for _ in 1:2000), label="alpha = 0.1, without baseline")
+    plot!(p, mean(bandit_testbed(learner(0.4, 0.), truevalue)[2] for _ in 1:2000), label="alpha = 0.4, without baseline")
+    savefig(p, "figure_2_5.png")
     p
 end

@@ -89,6 +88,6 @@ function fig_2_6()
     plot!(p, -5:1, [mean(mean(bandit_testbed(gradient_learner(2.0^i))[1] for _ in 1:2000)) for i in -5:1], label="gradient")
     plot!(p, -4:2, [mean(mean(bandit_testbed(UpperConfidenceBound_learner(2.0^i))[1] for _ in 1:2000)) for i in -4:2], label="UCB")
     plot!(p, -2:2, [mean(mean(bandit_testbed(greedy_with_init_learner(2.0^i))[1] for _ in 1:2000)) for i in -2:2], label="greedy with initialization")
-    savefig(p, figpath("2_6"))
+    savefig(p, "figure_2_6.png")
     p
 end
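A side note on the label changes above: with LaTeXStrings dropped, the legend entries become plain strings such as "epsilon=", so the value of ϵ no longer appears. A hypothetical variant (not part of this commit) keeps the value visible by interpolating it into a plain Unicode label, which the GR backend renders directly; the sketch assumes the same file context (Ju, ..MultiArmBandits and Plots already loaded):

```julia
# Hypothetical variant of fig_2_2 (not part of this commit): interpolate ϵ into a
# plain string label so the legend still shows the value without LaTeXStrings.
function fig_2_2_plain_labels()
    learner(ϵ) = QLearner(TabularQ(1, 10), EpsilonGreedySelector(ϵ), 0., cached_inverse_decay())
    p = plot(layout=(2, 1), dpi=200)
    for ϵ in [0.1, 0.01, 0.0]
        stats = [bandit_testbed(learner(ϵ)) for _ in 1:2000]
        plot!(p, mean(x[1] for x in stats), subplot=1, legend=:bottomright, label="ϵ = $ϵ")
        plot!(p, mean(x[2] for x in stats), subplot=2, legend=:bottomright, label="ϵ = $ϵ")
    end
    savefig(p, "figure_2_2.png")
    p
end
```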

src/chapter03/grid_world.jl (+3 -3)

@@ -2,7 +2,7 @@ using Ju
 using Plots
 gr()

-figpath(f) = "docs/src/assets/figures/figure_$f.png"
+

 const GridWorldLinearIndices = LinearIndices((5,5))
 const GridWorldCartesianIndices = CartesianIndices((5,5))
@@ -34,14 +34,14 @@ function fig_3_2()
     V, π = TabularV(25), RandomPolicy(fill(0.25, 25, 4))
     policy_evaluation!(V, π, GridWorldEnvModel)
     p = heatmap(1:5, 1:5, reshape(V.table, 5,5), yflip=true)
-    savefig(p, figpath("3_2"))
+    savefig(p, "figure_3_2.png")
     p
 end

 function fig_3_5()
     V, π = TabularV(25), DeterministicPolicy(rand(1:4, 25), 4)
     policy_iteration!(V, π, GridWorldEnvModel)
     p = heatmap(1:5, 1:5, reshape(V.table, 5,5), yflip=true)
-    savefig(p, figpath("3_5"))
+    savefig(p, "figure_3_5.png")
     p
 end

src/chapter04/car_rental.jl (+3 -3)

@@ -3,7 +3,7 @@ using Distributions
 using Plots
 gr()

-figpath(f) = "docs/src/assets/figures/figure_$f.png"
+

 const PoissonUpperBound = 10
 const MaxCars= 20
@@ -52,8 +52,8 @@ function fig_4_2(max_iter=100)
     V, π = TabularV((1+MaxCars)^2), DeterministicPolicy(zeros(Int,21^2), length(Actions))
     policy_iteration!(V, π, CarRentalEnvModel; γ=0.9, max_iter=max_iter)
     p1 = heatmap(0:MaxCars, 0:MaxCars, reshape([decode_action(x) for x in π.table], 1+MaxCars,1+MaxCars))
-    savefig(p1, figpath("4_2_policy"))
+    savefig(p1, "figure_4_2_policy.png")
     p2 = heatmap(0:MaxCars, 0:MaxCars, reshape(V.table, 1+MaxCars,1+MaxCars))
-    savefig(p2, figpath("4_2_value"))
+    savefig(p2, "figure_4_2_value.png")
     p1, p2
 end

src/chapter04/gambler_problem.jl (+2 -2)

@@ -3,7 +3,7 @@ using Distributions
 using Plots
 gr()

-figpath(f) = "docs/src/assets/figures/figure_$f.png"
+

 const pₕ = 0.4
 const WinCapital = 100
@@ -30,6 +30,6 @@ function fig_4_3(max_iter=typemax(Int))
     V = TabularV(1+WinCapital)
     value_iteration!(V, GamblerProblemEnvModel; γ=1.0, max_iter=max_iter)
     p = plot(V.table[2:end-1])
-    savefig(p, figpath("4_3"))
+    savefig(p, "figure_4_3.png")
     p
 end

src/chapter04/grid_world.jl (+2 -2)

@@ -2,7 +2,7 @@ using Ju
 using Plots
 gr()

-figpath(f) = "docs/src/assets/figures/figure_$f.png"
+

 const GridWorldLinearIndices = LinearIndices((4,4))
 const GridWorldCartesianIndices = CartesianIndices((4,4))
@@ -29,6 +29,6 @@ function fig_4_1()
     V, π = TabularV(16), RandomPolicy(fill(0.25, 16, 4))
     policy_evaluation!(V, π, GridWorldEnvModel; γ=1.0)
     p = heatmap(1:4, 1:4, reshape(V.table, 4,4), yflip=true)
-    savefig(p, figpath("4_1"))
+    savefig(p, "figure_4_1.png")
     p
 end

src/chapter05/blackjack.jl (+8 -8)

@@ -4,7 +4,7 @@ using StatsBase:mean
 using ..BlackJack
 gr()

-figpath(f) = "docs/src/assets/figures/figure_$f.png"
+

 const Indices = LinearIndices(size(observationspace(BlackJackEnv)))

@@ -29,8 +29,8 @@ function fig_5_1(n=10000)
                                for dealer_card in 2:11, player_sum in 11:21]
     p1 = heatmap(usable_ace_values)
     p2 = heatmap(no_usable_ace_values)
-    savefig(p1, figpath("5_1_usable_ace_n_$n"))
-    savefig(p2, figpath("5_1_no_usable_ace_n_$n"))
+    savefig(p1, "figure_5_1_usable_ace_n_$n.png")
+    savefig(p2, "figure_5_1_no_usable_ace_n_$n.png")
     p1, p2
 end

@@ -56,10 +56,10 @@ function fig_5_2(n=1000000)
     p2 = heatmap(no_usable_ace_values)
     p3 = heatmap(usable_ace_policy)
     p4 = heatmap(no_usable_ace_policy)
-    savefig(p1, figpath("5_2_usable_ace_n_$n"))
-    savefig(p2, figpath("5_2_no_usable_ace_n_$n"))
-    savefig(p3, figpath("5_2_usable_ace_policy_n_$n"))
-    savefig(p4, figpath("5_2_no_usable_ace_policy_n_$n"))
+    savefig(p1, "figure_5_2_usable_ace_n_$n.png")
+    savefig(p2, "figure_5_2_no_usable_ace_n_$n.png")
+    savefig(p3, "figure_5_2_usable_ace_policy_n_$n.png")
+    savefig(p4, "figure_5_2_no_usable_ace_policy_n_$n.png")
     p1, p2, p3, p4
 end

@@ -93,6 +93,6 @@ function fig_5_3(n=10000)
     end
     p = plot(mean((run() .- (-0.27726)).^2 for _ in 1:100), label="Weighted Importance Sampling")
     p = plot!(p, mean((run(:OrdinaryImportanceSampling) .- (-0.27726)).^2 for _ in 1:100), xscale=:log10, label="Ordinary Importance Sampling")
-    savefig(p, figpath("5_3"))
+    savefig(p, "figure_5_3.png")
     p
 end

src/chapter05/leftright.jl (+2 -2)

@@ -4,7 +4,7 @@ using ..LeftRight
 using Plots
 gr()

-figpath(f) = "docs/src/assets/figures/figure_$f.png"
+

 function fig_5_4()
     function value_collect()
@@ -31,6 +31,6 @@ function fig_5_4()
         train!(LeftRightEnv(), agent; callbacks = callbacks)
         plot!(p, callbacks[2](), xscale = :log10)
     end
-    savefig(p, figpath("5_4"))
+    savefig(p, "figure_5_4.png")
     p
 end

src/chapter06/cliff_walking.jl (+3 -3)

@@ -4,7 +4,7 @@ using ..CliffWalking
 using Plots
 gr()

-figpath(f) = "docs/src/assets/figures/figure_$f.png"
+

 function rewards_of_each_episode()
     rewards = []
@@ -61,7 +61,7 @@ function fig_6_3_a()
     p = plot(legend=:bottomright, dpi=200)
     plot!(p, mean(rewards(gen_env_Qagent()...) for _ in 1:100), label="QLearning")
     plot!(p, mean(rewards(gen_env_SARSAagent()...) for _ in 1:100), label="SARSA")
-    savefig(p, figpath("6_3_a"))
+    savefig(p, "figure_6_3_a.png")
     p
 end

@@ -82,6 +82,6 @@ function fig_6_3_b()
     plot!(p, A, [mean(avg_reward_per_episode(1000, gen_env_Qagent(α)...) for _ in 1:10) for α in A], label="Asymptotic interim Q")
     plot!(p, A, [mean(avg_reward_per_episode(1000, gen_env_SARSAagent(α)...) for _ in 1:10) for α in A], label="Asymptotic SARSA")
     plot!(p, A, [mean(avg_reward_per_episode(1000, gen_env_ExpectedSARSAagent(α)...) for _ in 1:10) for α in A], label="Asymptotic ExpectedSARSA")
-    savefig(p, figpath("6_3_b"))
+    savefig(p, "figure_6_3_b.png")
     p
 end

src/chapter06/maximization_bias.jl (+2 -2)

@@ -4,7 +4,7 @@ using StatsBase:mean
 using Plots
 gr()

-figpath(f) = "docs/src/assets/figures/figure_$f.png"
+

 function count_left_actions_from_A()
     counts_per_episode = []
@@ -54,6 +54,6 @@ function fig_6_5()
     p = plot(legend=:topright, dpi=200)
     plot!(p, mean(run_once(gen_env_DQagent()...) for _ in 1:10000), label="Double-Q")
     plot!(p, mean(run_once(gen_env_Qagent()...) for _ in 1:10000), label="Q")
-    savefig(p, figpath("6_5"))
+    savefig(p, "figure_6_5.png")
     p
 end

src/chapter06/randomwalk.jl (+6 -7)

@@ -1,11 +1,10 @@
 using Ju
 using Statistics
-using LaTeXStrings
 using ..RandomWalk
 using Plots
 gr()

-figpath(f) = "docs/src/assets/figures/figure_$f.png"
+

 const true_values = [i/6 for i in 1:5]

@@ -54,7 +53,7 @@ function fig_6_2_a()
         train!(env, agent; callbacks = (stop_at_episode(i),))
         plot!(p, agent.learner.approximator.table[2:end - 1])
     end
-    savefig(p, figpath("6_2_a"))
+    savefig(p, "figure_6_2_a.png")
     p
 end

@@ -63,15 +62,15 @@ function fig_6_2_b()
     for α in [0.05, 0.1, 0.15]
         callbacks = (stop_at_episode(100), record_rms())
         train!(gen_env_TDagent(α)...;callbacks = callbacks)
-        plot!(p, callbacks[2](), label = latexstring("TD \\alpha="))
+        plot!(p, callbacks[2](), label ="TD alpha=")
     end

     for α in [0.01, 0.02, 0.03, 0.04]
         callbacks = (stop_at_episode(100), record_rms())
         train!(gen_env_MCagent(α)...;callbacks = callbacks)
-        plot!(p, callbacks[2](), label = latexstring("MC \\alpha="))
+        plot!(p, callbacks[2](), label ="MC alpha=")
     end
-    savefig(p, figpath("6_2_b"))
+    savefig(p, "figure_6_2_b.png")
     p
 end

@@ -93,6 +92,6 @@ function fig_6_2_c()
     end
     plot!(mean(avg_rms), color=:red, label="MC")

-    savefig(p, figpath("6_2_c"))
+    savefig(p, "figure_6_2_c.png")
     p
 end
