# Credit (blame) to Billy
# https://fwd.gymni.ch/PaQUnB
using Enzyme
using Base.Experimental: @aliasscope, Const
using BenchmarkTools
using StaticArrays
# Global Enzyme switches for this experiment, both turned on (order preserved).
# `inlineall!` presumably makes Enzyme inline callees before differentiating, and
# `printall!` presumably dumps the IR Enzyme processes, so expect very verbose
# output. NOTE(review): confirm both against the Enzyme.API docs.
for enzyme_switch! in (Enzyme.API.inlineall!, Enzyme.API.printall!)
    enzyme_switch!(true)
end
"""
    f(pos)

Return the sum of squared distances between consecutive 3-D points stored flat
in `pos` as `[x₁, y₁, z₁, x₂, y₂, z₂, …]`. The last point is paired with the
first, so the points form a closed loop.

The accumulator starts from the `Float64` literal `0.0`, so the result is
promoted with `Float64` (e.g. integer input still yields a `Float64`).
`length(pos)` is assumed to be a multiple of 3: the loop runs `@inbounds`, so
any other length reads out of bounds.
"""
function f(pos)
    len = length(pos)
    acc = 0.0
    @inbounds for a in 1:3:len
        # `rem(k, len) + 1` maps the 1-based index `k + 1` back into `1:len`,
        # so the last point's "next" point is the first one.
        dx = pos[rem(a + 2, len) + 1] - pos[a]
        dy = pos[rem(a + 3, len) + 1] - pos[a + 1]
        dz = pos[rem(a + 4, len) + 1] - pos[a + 2]
        acc += dx^2 + dy^2 + dz^2
    end
    return acc
end
"""
    gradient_ip(in_t, out)

Run Enzyme reverse mode on `f` at `in_t`, writing the gradient into `out`, and
return `nothing`. Enzyme's reverse mode accumulates into the shadow, which is
why the callers in this file zero `out` before calling.

Both arrays are wrapped in `Base.Experimental.Const` inside an `@aliasscope`, so
the compiler may assume memory read through them is not modified by other
references within the scope.
"""
function gradient_ip(in_t, out)
    # Bug fix: the original wrapped `in`, which is `Base.in` here (no such local
    # exists), instead of the argument `in_t`.
    # `Duplicated` requires primal and shadow of the same type, so the shadow is
    # wrapped too. NOTE(review): the shadow is still written by Enzyme-generated
    # code despite the `Const` wrapper; confirm this is sound under @aliasscope.
    @aliasscope let x = Base.Experimental.Const(in_t), dx = Base.Experimental.Const(out)
        Enzyme.autodiff_deferred(Reverse, f, Duplicated(x, dx))
    end
    return
end
"""
    hess_row_ip(in, idx)

Return `H * e_idx` for the Hessian `H` of `f` at `in`. This is column `idx`,
and also row `idx` because the Hessian is symmetric.

It is computed forward-over-reverse: `gradient_ip` is differentiated in forward
mode with the one-hot tangent `e_idx` on the input. The tangent of its `out`
argument then receives the Hessian-vector product.
"""
function hess_row_ip(in, idx)
    direction = fill!(similar(in), 0.0)
    direction[idx] = 1.0
    # Tangent of `out`; receives H * e_idx.
    hvp = fill!(similar(in), 0.0)
    # Primal gradient buffer; its contents are not returned (could be reused).
    grad = fill!(similar(in), 0.0)
    # `Enzyme.Const` is qualified on purpose: the file explicitly imports
    # `Base.Experimental.Const`, which shadows Enzyme's exported `Const`.
    Enzyme.autodiff(Forward, gradient_ip, Enzyme.Const,
                    Duplicated(in, direction), DuplicatedNoNeed(grad, hvp))
    return hvp
end
"""
    test(Type, n)

Benchmark the full Hessian of `f` computed in one batched forward-over-reverse
call. The benchmark uses `n` points evenly spaced on the unit circle in the
z = 0 plane, flattened into an array of type `Type` of length `3n`.

The input is seeded with all `3n` one-hot directions at once via
`BatchDuplicated`. Each batch lane's `out` tangent receives the matching Hessian
column. The per-column alternative is `hess_row_ip(x, k)` for `k` in
`1:length(x)`.

Returns the value of the benchmarked `Enzyme.autodiff` call.
"""
function test(Type, n)
    angles = (2π * i / n for i in 1:n)
    x = Type([c for θ in angles for c in (cos(θ), sin(θ), 0.0)])
    f(x)  # plain primal call; result unused

    len = length(x)
    # One one-hot seed per input coordinate, plus one zeroed tangent buffer per
    # seed for the matching Hessian column (typed, unlike `[]`).
    seeds = map(1:len) do k
        e = fill!(similar(x), 0.0)
        e[k] = 1.0
        e
    end
    columns = [fill!(similar(x), 0.0) for _ in 1:len]

    bd_in = BatchDuplicated(x, Tuple(seeds))
    grad = fill!(similar(x), 0.0)
    bd_out = BatchDuplicatedNoNeed(grad, Tuple(columns))

    # `Enzyme.Const` is qualified on purpose: the file explicitly imports
    # `Base.Experimental.Const`, which shadows Enzyme's exported `Const`.
    # Untimed first call so compilation is excluded from the benchmark.
    Enzyme.autodiff(Forward, gradient_ip, Enzyme.Const, bd_in, bd_out)
    @show Type, n
    # NOTE(review): repeated calls presumably keep accumulating into `grad` and
    # `columns`, so their contents after the benchmark are not the Hessian.
    return @btime Enzyme.autodiff(Forward, gradient_ip, Enzyme.Const, $bd_in, $bd_out)
end
# Benchmark the batched Hessian for rings of 3, 30 and 300 points.
# StaticArrays variant (disabled): test(MVector{3*n, Float64}, n)
foreach(npoints -> test(Vector{Float64}, npoints), (3, 30, 300))
# Credit (blame) to Billy
# https://fwd.gymni.ch/PaQUnB