Skip to content

Commit eb5934a

Browse files
authored
More unthunking in ∇chunk (#180)
* Update utils.jl * version 0.4.5
1 parent 204b958 commit eb5934a

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLUtils"
22
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
33
authors = ["Carlo Lucibello <[email protected]> and contributors"]
4-
version = "0.4.4"
4+
version = "0.4.5"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/utils.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,8 @@ end
237237
@non_differentiable _partition_idxs(::Any...)
238238

239239
# Similar to ∇eachslice https://github.com/JuliaDiff/ChainRules.jl/blob/8108a77a96af5d4b0c460aac393e44f8943f3c5e/src/rulesets/Base/indexing.jl#L77
240-
function ∇chunk(dys, x, idxs, vd::Val{dim}) where {dim}
240+
function ∇chunk(dys_raw, x, idxs, vd::Val{dim}) where {dim}
241+
dys = unthunk.(unthunk(dys_raw)) # https://github.com/FluxML/Zygote.jl/pull/966#issuecomment-2569227272
241242
i1 = findfirst(dy -> !(dy isa AbstractZero), dys)
242243
if i1 === nothing # all slices are Zero!
243244
return _zero_fill!(similar(x, float(eltype(x))))

0 commit comments

Comments
 (0)