From b43c50ada0f9957df0564ad1d5d27b0930030115 Mon Sep 17 00:00:00 2001 From: WT Date: Sun, 15 Aug 2021 21:39:33 +0100 Subject: [PATCH] wip rrule dot test --- src/testers.jl | 71 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 49 insertions(+), 22 deletions(-) diff --git a/src/testers.jl b/src/testers.jl index 91f21b3..eda735f 100644 --- a/src/testers.jl +++ b/src/testers.jl @@ -172,7 +172,7 @@ function test_rrule( config::RuleConfig, f, args...; - output_tangent=Auto(), + output_cotangent=Auto(), check_thunked_output_tangent=true, fdm=_fdm, rrule_f=ChainRulesCore.rrule, @@ -188,6 +188,8 @@ function test_rrule( # and define helper closure over fkwargs call(f, xs...) = f(xs...; fkwargs...) + call_on_copy(f, xs...) = deepcopy(f)(deepcopy(xs)...; deepcopy(fkwargs)...) + @testset "test_rrule: $f on $(_string_typeof(args))" begin # Check correctness of evaluation. @@ -219,30 +221,55 @@ function test_rrule( # Correctness testing via finite differencing. is_ignored = isa.(accum_cotangents, NoTangent) - fd_cotangents = _make_j′vp_call(fdm, call, ȳ, primals, is_ignored) + fd_output_tangent = _make_jvp_call( + fdm, call_on_copy, y, primals, tangents, is_ignored, + ) - for (accum_cotangent, ad_cotangent, fd_cotangent) in zip( - accum_cotangents, ad_cotangents, fd_cotangents + # Current implementation assumes that is_ignored is always false. Easy fix though. + # More consistent names for variables in this context. + inputs = primals + inputs_tangents = accum_cotangents + inputs_cotangents = ad_cotangents + output = y + output_tangent = fd_output_tangent + output_cotangent = ȳ + @test isapprox( + dot(output_cotangent, output_tangent), + dot(inputs_cotangents, inputs_tangents), ) - if accum_cotangent isa NoTangent # then we marked this argument as not differentiable - @assert fd_cotangent === NoTangent() - ad_cotangent isa ZeroTangent && error( - "The pullback in the rrule should use NoTangent()" * - " rather than ZeroTangent() for non-perturbable arguments.", - ) - @test ad_cotangent isa NoTangent # we said it wasn't differentiable. - else - ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent) - - # The main test of the actual derivative being correct: - test_approx(ad_cotangent, fd_cotangent; isapprox_kwargs...) - _test_add!!_behaviour(accum_cotangent, ad_cotangent; isapprox_kwargs...) - end - end - if check_thunked_output_tangent - test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:") - end + # Alternatively: + # x = primals + # ẋ = accum_cotangents + # x̄ = ad_cotangents + # y = y + # ẏ = fd_output_tangent + # ȳ = ȳ + # @test dot(ȳ, ẏ) ≈ dot(x̄, ẋ) + + + # for (accum_cotangent, ad_cotangent, fd_cotangent) in zip( + # accum_cotangents, ad_cotangents, fd_cotangents + # ) + # if accum_cotangent isa NoTangent # then we marked this argument as not differentiable + # @assert fd_cotangent === NoTangent() + # ad_cotangent isa ZeroTangent && error( + # "The pullback in the rrule should use NoTangent()" * + # " rather than ZeroTangent() for non-perturbable arguments.", + # ) + # @test ad_cotangent isa NoTangent # we said it wasn't differentiable. + # else + # ad_cotangent isa AbstractThunk && check_inferred && _test_inferred(unthunk, ad_cotangent) + + # # The main test of the actual derivative being correct: + # test_approx(ad_cotangent, fd_cotangent; isapprox_kwargs...) + # _test_add!!_behaviour(accum_cotangent, ad_cotangent; isapprox_kwargs...) + # end + # end + + # if check_thunked_output_tangent + # test_approx(ad_cotangents, pullback(@thunk(ȳ)), "pulling back a thunk:") + # end end # top-level testset end