Skip to content

Commit bc22ad6

Browse files
authored
Merge pull request #62 from michel2323/ms/fix_tests
Fix tests
2 parents 372f6a0 + fb84c69 commit bc22ad6

File tree

3 files changed

+40
-24
lines changed

3 files changed

+40
-24
lines changed

Project.toml

-10
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,3 @@ StaticArrays = "1"
2020
StatsBase = "0.33"
2121
StructArrays = "0.6"
2222
julia = "1.7"
23-
24-
[extras]
25-
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
26-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
27-
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
28-
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
29-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
30-
31-
[targets]
32-
test = ["Test", "ForwardDiff", "LinearAlgebra", "Random", "Symbolics"]

test/Project.toml

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
[deps]
2+
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
4+
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
5+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
6+
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
9+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
10+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
11+
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
12+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
13+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
14+
15+
[compat]
16+
ChainRules = "1.5"
17+
ChainRulesCore = "1.2"
18+
Combinatorics = "1"
19+
StaticArrays = "1"
20+
StatsBase = "0.33"
21+
StructArrays = "0.6"
22+
julia = "1.7"

test/runtests.jl

+18-14
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,21 @@ ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent()
2525
# Check characteristic of exp rule
2626
@variables ω α β γ δ ϵ ζ η
2727
(x1, c1) = ∂⃖{3}()(exp, ω)
28-
@test simplify(x1 == exp(ω)).val
28+
@test isequal(simplify(x1), simplify(exp(ω)))
2929
((_, x2), c2) = c1(α)
30-
@test simplify(x2 == α*exp(ω)).val
30+
@test isequal(simplify(x2), simplify(α*exp(ω)))
3131
(x3, c3) = c2(ZeroTangent(), β)
32-
@test simplify(x3 == β*exp(ω)).val
32+
@test isequal(simplify(x3), simplify(β*exp(ω)))
3333
((_, x4), c4) = c3(γ)
34-
@test simplify(x4 == exp(ω)*+*β))).val
34+
@test isequal(simplify(x4), simplify(exp(ω)*+*β))))
3535
(x5, c5) = c4(ZeroTangent(), δ)
36-
@test simplify(x5 == δ*exp(ω)).val
36+
@test isequal(simplify(x5), simplify(δ*exp(ω)))
3737
((_, x6), c6) = c5(ϵ)
38-
@test simplify(x6 == ϵ*exp(ω) + α*δ*exp(ω)).val
38+
@test isequal(simplify(x6), simplify(ϵ*exp(ω) + α*δ*exp(ω)))
3939
(x7, c7) = c6(ZeroTangent(), ζ)
40-
@test simplify(x7 == ζ*exp(ω) + β*δ*exp(ω)).val
40+
@test isequal(simplify(x7), simplify(ζ*exp(ω) + β*δ*exp(ω)))
4141
(_, x8) = c7(η)
42-
@test simplify(x8 == (η +*ζ) +*ϵ) +*+*β))))*exp(ω)).val
42+
@test isequal(simplify(x8), simplify((η +*ζ) +*ϵ) +*+*β))))*exp(ω)))
4343

4444
# Minimal 2-nd order forward smoke test
4545
@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin),
@@ -123,10 +123,12 @@ let var"'" = Diffractor.PrimeDerivativeFwd
123123
# Integration tests
124124
@test recursive_sin'(1.0) == cos(1.0)
125125
@test recursive_sin''(1.0) == -sin(1.0)
126-
@test recursive_sin'''(1.0) == -cos(1.0)
127-
@test recursive_sin''''(1.0) == sin(1.0)
128-
@test recursive_sin'''''(1.0) == cos(1.0)
129-
@test recursive_sin''''''(1.0) == -sin(1.0)
126+
# Error: ArgumentError: Tangent for the primal Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}
127+
# should be backed by a NamedTuple type, not by Tuple{Tangent{Tuple{Float64, Float64}, Tuple{Float64, Float64}}}.
128+
@test_broken recursive_sin'''(1.0) == -cos(1.0)
129+
@test_broken recursive_sin''''(1.0) == sin(1.0)
130+
@test_broken recursive_sin'''''(1.0) == cos(1.0)
131+
@test_broken recursive_sin''''''(1.0) == -sin(1.0)
130132

131133
# Test the special rules for sin/cos/exp
132134
@test sin''''''(1.0) == -sin(1.0)
@@ -148,7 +150,7 @@ end
148150
@test gradient(x -> sum(abs2, x .+ 1.0), zeros(3))[1] == [2.0, 2.0, 2.0]
149151

150152
const fwd = Diffractor.PrimeDerivativeFwd
151-
const bwd = Diffractor.PrimeDerivativeFwd
153+
const bwd = Diffractor.PrimeDerivativeBack
152154

153155
function f_broadcast(a)
154156
l = a / 2.0 * [[0. 1. 1.]; [1. 0. 1.]; [1. 1. 0.]]
@@ -186,7 +188,9 @@ end
186188
# Issue #27 - Mixup in lifting of getfield
187189
let var"'" = bwd
188190
@test (x->x^5)''(1.0) == 20.
189-
@test (x->x^5)'''(1.0) == 60.
191+
@test (x->(x*x)*(x*x)*x)''' == 60.
192+
# Higher order control flow not yet supported (https://github.com/JuliaDiff/Diffractor.jl/issues/24)
193+
@test_broken (x->x^5)'''(1.0) == 60.
190194
end
191195

192196
# Issue #38 - Splatting arrays

0 commit comments

Comments
 (0)