Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ for f in (:eig, :eigh)
_warn_pullback_truncerror(dϵ)

# compute pullbacks
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required

# restore state
Expand Down Expand Up @@ -351,8 +351,8 @@ for f in (:eig, :eigh)
dDVtrunc = last.(arrayify.(DVtrunc, Mooncake.tangent(DVtrunc_dDVtrunc)))
function $f_adjoint!(::NoRData)
# compute pullbacks
$f_pullback!(dA, Ac, DVc, dDVtrunc, ind)
zero!.(dDVtrunc) # since this is allocated in this function this is probably not required
$f_pullback!(dA, Ac, DV, dDVtrunc, ind)
zero!.(dDV)

# restore state
copy!(A, Ac)
Expand Down Expand Up @@ -425,7 +425,7 @@ for (f!, f) in (
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
USVᴴc = copy.(USVᴴ)
output = $f!(A, Mooncake.primal(alg_dalg))
output = $f!(A, USVᴴ, Mooncake.primal(alg_dalg))
function svd_adjoint(::NoRData)
copy!(A, Ac)
if $(f! == svd_compact!)
Expand Down Expand Up @@ -590,7 +590,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUS
_warn_pullback_truncerror(dϵ)

# compute pullbacks
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind)
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
zero!.(dUSVᴴ)

Expand Down Expand Up @@ -717,8 +717,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error!)}, A_dA::CoDual, U
dUSVᴴtrunc = last.(arrayify.(USVᴴtrunc, Mooncake.tangent(USVᴴtrunc_dUSVᴴtrunc)))
function svd_trunc_adjoint(::NoRData)
# compute pullbacks
svd_pullback!(dA, Ac, USVᴴc, dUSVᴴtrunc, ind)
zero!.(dUSVᴴtrunc) # since this is allocated in this function this is probably not required
svd_pullback!(dA, Ac, USVᴴ, dUSVᴴtrunc, ind)
zero!.(dUSVᴴ)

# restore state
Expand Down
22 changes: 15 additions & 7 deletions src/pullbacks/qr.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Real = default_pullback_gauge_atol(ΔQ))
qr_rank(R; rank_atol = default_pullback_rank_atol(R)) =
@something findlast(>=(rank_atol) ∘ abs, diagview(R)) 0

function check_qr_cotangents(
Q, R, ΔQ, ΔR, p::Int;
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
)
minmn = min(size(Q, 1), size(R, 2))
if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
Expand All @@ -7,11 +14,13 @@ function check_qr_cotangents(Q, R, ΔQ, ΔR, minmn::Int, p::Int; gauge_atol::Rea
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, :, (p + 1):size(Q, 2))
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
Δgauge_Q = norm(ΔQ2, Inf)
Δgauge = max(Δgauge, Δgauge_Q)
end
if !iszerotangent(ΔR)
ΔR22 = view(ΔR, (p + 1):minmn, (p + 1):size(R, 2))
Δgauge = max(Δgauge, norm(ΔR22, Inf))
Δgauge_R = norm(ΔR22, Inf)
Δgauge = max(Δgauge, Δgauge_R)
end
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
Expand All @@ -29,7 +38,7 @@ function check_qr_full_cotangents(Q1, ΔQ2, Q1dΔQ2; gauge_atol::Real = default_
# Q2' * ΔQ2 as a gauge dependent quantity.
Δgauge = norm(mul!(copy(ΔQ2), Q1, Q1dΔQ2, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`qr` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
@warn "`qr` full cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return
end

Expand Down Expand Up @@ -60,9 +69,8 @@ function qr_pullback!(
Q, R = QR
m = size(Q, 1)
n = size(R, 2)
minmn = min(m, n)
Rd = diagview(R)
p = @something findlast(>=(rank_atol) ∘ abs, Rd) 0
p = qr_rank(R)

ΔQ, ΔR = ΔQR

Expand All @@ -72,7 +80,7 @@ function qr_pullback!(
ΔA1 = view(ΔA, :, 1:p)
ΔA2 = view(ΔA, :, (p + 1):n)

check_qr_cotangents(Q, R, ΔQ, ΔR, minmn, p; gauge_atol)
check_qr_cotangents(Q, R, ΔQ, ΔR, p; gauge_atol)

ΔQ̃ = zero!(similar(Q, (m, p)))
if !iszerotangent(ΔQ)
Expand Down
33 changes: 24 additions & 9 deletions test/testsuite/TestSuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using MatrixAlgebraKit
using MatrixAlgebraKit: diagview
using LinearAlgebra: Diagonal, norm, istriu, istril, I
using Random, StableRNGs
using Mooncake
using AMDGPU, CUDA

const tests = Dict()
Expand Down Expand Up @@ -86,16 +87,30 @@ instantiate_unitary(::Type{<:Diagonal}, A, sz) = Diagonal(fill!(similar(parent(A

include("ad_utils.jl")

include("qr.jl")
include("lq.jl")
include("polar.jl")
include("projections.jl")
include("schur.jl")
include("eig.jl")
include("eigh.jl")
include("orthnull.jl")
include("svd.jl")
include("mooncake.jl")

# Decompositions
# --------------
include("decompositions/qr.jl")
include("decompositions/lq.jl")
include("decompositions/polar.jl")
include("decompositions/schur.jl")
include("decompositions/eig.jl")
include("decompositions/eigh.jl")
include("decompositions/orthnull.jl")
include("decompositions/svd.jl")

# Mooncake
# --------
include("mooncake/mooncake.jl")
include("mooncake/qr.jl")
include("mooncake/lq.jl")
include("mooncake/eig.jl")
include("mooncake/eigh.jl")
include("mooncake/svd.jl")
include("mooncake/polar.jl")
include("mooncake/orthnull.jl")

include("enzyme.jl")
include("chainrules.jl")

Expand Down
13 changes: 2 additions & 11 deletions test/testsuite/ad_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,6 @@ function remove_svdgauge_dependence!(
mul!(ΔU, U, gaugepart, -1, 1)
return ΔU, ΔVᴴ
end
function remove_eiggauge_dependence!(
ΔV, D, V;
degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D)
)
gaugepart = V' * ΔV
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
return ΔV
end
function remove_eighgauge_dependence!(
ΔV, D, V;
degeneracy_atol = MatrixAlgebraKit.default_pullback_gauge_atol(D)
Expand Down Expand Up @@ -204,7 +195,7 @@ function ad_eig_full_setup(A)
D, V = DV
Ddiag = diagview(D)
ΔV = randn!(similar(A, complex(T), m, m))
ΔV = remove_eiggauge_dependence!(ΔV, D, V)
ΔV = remove_eig_gauge_dependence!(ΔV, D, V)
ΔD = randn!(similar(A, complex(T), m, m))
ΔD2 = Diagonal(randn!(similar(A, complex(T), m)))
return DV, (ΔD, ΔV), (ΔD2, ΔV)
Expand All @@ -216,7 +207,7 @@ function ad_eig_full_setup(A::Diagonal)
DV = eig_full(A)
D, V = DV
ΔV = randn!(similar(A.diag, T, m, m))
ΔV = remove_eiggauge_dependence!(ΔV, D, V)
ΔV = remove_eig_gauge_dependence!(ΔV, D, V)
ΔD = Diagonal(randn!(similar(A.diag, T, m)))
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ΔD = Diagonal(randn!(similar(A.diag, T, m)))
ΔD = Diagonal(randn!(similar(D.diag)))

or also using diagview(D) if we don't want to access .diag.

ΔD2 = Diagonal(randn!(similar(A.diag, T, m)))
return DV, (ΔD, ΔV), (ΔD2, ΔV)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using TestExtras
using LinearAlgebra

include("../linearmap.jl")
include("../../linearmap.jl")

_left_orth_svd(x; kwargs...) = left_orth(x; alg = :svd, kwargs...)
_left_orth_svd!(x, VC; kwargs...) = left_orth!(x, VC; alg = :svd, kwargs...)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading