Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion enzyme/Enzyme/CacheUtility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,13 @@ AllocaInst *CacheUtility::createCacheForScope(LimitContext ctx, Type *T,
alloc->setAlignment(Align(align));
}
if (sublimits.size() == 0) {
auto val = getUndefinedValueForType(*newFunc->getParent(), types.back());
// For i1 predicate caches, "not executed" must behave like false.
// Otherwise reverse CFG reconstruction may branch on undef/poison (issue
// #2629).
bool forceZero =
T->isIntegerTy() && cast<IntegerType>(T)->getBitWidth() == 1;
auto val = getUndefinedValueForType(*newFunc->getParent(), types.back(),
/*forceZero*/ forceZero);
if (!isa<UndefValue>(val))
scopeInstructions[alloc].push_back(entryBuilder.CreateStore(val, alloc));
}
Expand Down
103 changes: 39 additions & 64 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7891,7 +7891,12 @@ void GradientUtils::branchToCorrespondingTarget(
}
// llvm::errs() << "</DONE>\n";

if (targetToPreds.size() == 3) {
// NOTE: the 3-target "reuse two branch predicates" optimization constructs
// a synthetic staging block to evaluate the second predicate only under the
// first split. That requires actual control-flow. When replacePHIs != nullptr
// we must not use this fast-path (it would otherwise eagerly evaluate the
// inner predicate unconditionally and reproduce #2629 at -O0/-O1).
if (replacePHIs == nullptr && targetToPreds.size() == 3) {
// Try `block` as a potential first split point.
for (auto block : blocks) {
{
Expand Down Expand Up @@ -8009,73 +8014,43 @@ void GradientUtils::branchToCorrespondingTarget(
// the remainder of foundTargets.
auto cond1 = lookupM(bi1->getCondition(), BuilderM);

// Condition cond2 splits off the two blocks in
// (foundTargets-uniqueTargets) from each other.
auto cond2 = lookupM(bi2->getCondition(), BuilderM);
// Create a staging block so the second predicate is only evaluated
// on the path where the first split is taken (fixes #2629).
BasicBlock *staging =
BasicBlock::Create(oldFunc->getContext(), "staging", newFunc);

// `lookupM` requires reverse-only blocks to have an entry in
// reverseBlockToPrimal. Since `staging` is synthetic, map it to the
// same forward/primal block as the current insertion block.
BasicBlock *stagingFwd = BuilderM.GetInsertBlock();
if (!isOriginalBlock(*stagingFwd)) {
auto it = reverseBlockToPrimal.find(stagingFwd);
assert(it != reverseBlockToPrimal.end());
stagingFwd = it->second;
}
reverseBlockToPrimal[staging] = stagingFwd;

if (replacePHIs == nullptr) {
BasicBlock *staging =
BasicBlock::Create(oldFunc->getContext(), "staging", newFunc);
auto stagingIfNeeded = [&](BasicBlock *B) {
auto edge = std::make_pair(block, B);
if (done[edge].size() == 1) {
return *done[edge].begin();
} else {
assert(done[edge].size() == 2);
return staging;
}
};
BuilderM.CreateCondBr(cond1, stagingIfNeeded(bi1->getSuccessor(0)),
stagingIfNeeded(bi1->getSuccessor(1)));
BuilderM.SetInsertPoint(staging);
BuilderM.CreateCondBr(
cond2,
*done[std::make_pair(subblock, bi2->getSuccessor(0))].begin(),
*done[std::make_pair(subblock, bi2->getSuccessor(1))].begin());
} else {
Value *otherBranch = nullptr;
for (unsigned i = 0; i < 2; ++i) {
Value *val = cond1;
if (i == 1)
val = BuilderM.CreateNot(val, "anot1_");
auto edge = std::make_pair(block, bi1->getSuccessor(i));
if (done[edge].size() == 1) {
auto found = replacePHIs->find(*done[edge].begin());
if (found == replacePHIs->end())
continue;
if (&*BuilderM.GetInsertPoint() == found->second) {
if (found->second->getNextNode())
BuilderM.SetInsertPoint(found->second->getNextNode());
else
BuilderM.SetInsertPoint(found->second->getParent());
}
found->second->replaceAllUsesWith(val);
found->second->eraseFromParent();
} else {
otherBranch = val;
}
auto stagingIfNeeded = [&](BasicBlock *B) {
auto edge = std::make_pair(block, B);
if (done[edge].size() == 1) {
return *done[edge].begin();
} else {
assert(done[edge].size() == 2);
return staging;
}
};

for (unsigned i = 0; i < 2; ++i) {
auto edge = std::make_pair(subblock, bi2->getSuccessor(i));
auto found = replacePHIs->find(*done[edge].begin());
if (found == replacePHIs->end())
continue;
BuilderM.CreateCondBr(cond1, stagingIfNeeded(bi1->getSuccessor(0)),
stagingIfNeeded(bi1->getSuccessor(1)));
BuilderM.SetInsertPoint(staging);

Value *val = cond2;
if (i == 1)
val = BuilderM.CreateNot(val, "bnot1_");
val = BuilderM.CreateAnd(val, otherBranch, "andVal" + Twine(i));
if (&*BuilderM.GetInsertPoint() == found->second) {
if (found->second->getNextNode())
BuilderM.SetInsertPoint(found->second->getNextNode());
else
BuilderM.SetInsertPoint(found->second->getParent());
}
found->second->replaceAllUsesWith(val);
found->second->eraseFromParent();
}
}
// IMPORTANT: materialize cond2 *in staging* (so it is not executed
// when the outer guard path wasn't taken).
auto cond2 = lookupM(bi2->getCondition(), BuilderM);
BuilderM.CreateCondBr(
cond2,
*done[std::make_pair(subblock, bi2->getSuccessor(0))].begin(),
*done[std::make_pair(subblock, bi2->getSuccessor(1))].begin());

return;
}
Expand Down
17 changes: 11 additions & 6 deletions enzyme/test/Enzyme/ReverseMode/condtriload.ll
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ entry:

; CHECK: define internal void @diffealldiv(double* %a, double* %"a'", i1 %cmp, i32 %val, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %switch.selectcmp = icmp eq i32 %val, 17
; CHECK-NEXT: %switch.select = select i1 %switch.selectcmp, i8 0, i8 1
; CHECK-NEXT: %switch.selectcmp6 = icmp eq i32 %val, 13
; CHECK-NEXT: %switch.select7 = select i1 %switch.selectcmp6, i8 2, i8 %switch.select
; CHECK-NEXT: %_cache.1 = select i1 %cmp, i8 %switch.select7, i8 undef
; CHECK-NEXT: %0 = select {{(fast )?}}i1 %cmp, double %differeturn, double 0.000000e+00
; CHECK-NEXT: br i1 %cmp, label %invertend, label %invertentry

Expand Down Expand Up @@ -103,12 +108,12 @@ entry:
; CHECK-NEXT: %7 = phi {{(fast )?}}double [ %l1_unwrap, %invertend_phirc ], [ %l2_unwrap, %[[invertend_phirc1]] ], [ %l3_unwrap, %[[invertend_phirc2]] ]
; CHECK-NEXT: %[[m0diffep:.+]] = fmul fast double %0, %7
; CHECK-NEXT: %[[i8:.+]] = fadd fast double %[[m0diffep]], %[[m0diffep]]
; CHECK-NEXT: %anot1_ = xor i1 %c1_unwrap, true
; CHECK-NEXT: %bnot1_ = xor i1 %c2_unwrap, true
; CHECK-NEXT: %andVal1 = and i1 %bnot1_, %anot1_
; CHECK-NEXT: %[[a9]] = select {{(fast )?}}i1 %c1_unwrap, double %[[i8]], double 0.000000e+00
; CHECK-NEXT: %[[a10]] = select {{(fast )?}}i1 %andVal1, double %[[i8]], double 0.000000e+00
; CHECK-NEXT: %[[a11]] = select {{(fast )?}}i1 %c2_unwrap, double %[[i8]], double 0.000000e+00
; CHECK-NEXT: %[[icmp0:.+]] = icmp eq i8 0, %_cache.1
; CHECK-NEXT: %[[icmp1:.+]] = icmp eq i8 1, %_cache.1
; CHECK-NEXT: %[[icmp2:.+]] = icmp eq i8 2, %_cache.1
; CHECK-NEXT: %[[a9]] = select {{(fast )?}}i1 %[[icmp2]], double %[[i8]], double 0.000000e+00
; CHECK-NEXT: %[[a10]] = select {{(fast )?}}i1 %[[icmp1]], double %[[i8]], double 0.000000e+00
; CHECK-NEXT: %[[a11]] = select {{(fast )?}}i1 %[[icmp0]], double %[[i8]], double 0.000000e+00
; CHECK-NEXT: br i1 %c1_unwrap, label %invertbdef, label %staging

; CHECK: staging:
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/Enzyme/ReverseMode/insertsort.ll
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ attributes #0 = { noinline norecurse nounwind uwtable }
; CHECK-NEXT: br i1 %[[cmp3_rev:.+]], label %invertwhile.body, label %invertland.rhs

; CHECK: invertwhile.end: ; preds = %entry, %while.body, %land.rhs
; CHECK-NEXT: %[[cmp3_rev]] = phi i1 [ undef, %entry ], [ %cmp3, %while.body ], [ %cmp3, %land.rhs ]
; CHECK-NEXT: %[[cmp3_rev]] = phi i1 [ false, %entry ], [ %cmp3, %while.body ], [ %cmp3, %land.rhs ]
; CHECK-NEXT: %loopLimit_cache.0 = phi i64 [ undef, %entry ], [ %iv, %while.body ], [ %iv, %land.rhs ]
; CHECK-NEXT: br i1 %cmp29, label %invertwhile.end.loopexit, label %invertentry
; CHECK-NEXT: }
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=0 -enzyme -mem2reg -early-cse -simplifycfg -instsimplify -correlated-propagation -simplifycfg -adce -S -enzyme-detect-readthrow=0 | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=0 -passes="enzyme,function(mem2reg,early-cse,%simplifycfg,instsimplify,correlated-propagation,%simplifycfg,adce)" -S -enzyme-detect-readthrow=0 | FileCheck %s

; Regression test for issue #2629:
; An outer branch guarded by a constant (enzyme_const) bool `fan` contains an
; inner branch guarded by an active predicate `cond` (computed from `%a`).
; When fan=false, the inner predicate is never computed, so its tape cache is
; never written. The reverse pass must not branch on an uninitialized/undef
; taped predicate in that case.
;
; The fix ensures that:
; 1. The inner predicate cache (`_cache.0`) is initialized to `false` (not
; undef) for the path where the outer guard is not taken.
; 2. The `cond_unwrap` lookup in the reverse pass is deferred to a `staging`
; block that is only reached when `fan` is true.

declare double @__enzyme_autodiff(i8*, ...)

; Primal function that gets differentiated.
;   %a   - pointer to at least two doubles (a[0] is read in if.fan, a[1] in inner).
;   %fan - outer guard; when false, control goes straight from entry to merge.
; Returns 0.0 (outer guard false), 1.0 (guard true, inner predicate false), or
; a[1] (both true). `merge` therefore has three predecessors.
define double @f(double* %a, i1 %fan) {
entry:
; Outer guard: on the false edge neither %a0 nor %cond is ever computed.
br i1 %fan, label %if.fan, label %merge

if.fan:
%a0 = load double, double* %a, align 8
; Inner predicate, derived from the value loaded from a[0].
%cond = fcmp ogt double %a0, 0.000000e+00
br i1 %cond, label %inner, label %merge

inner:
; Only this path loads a[1], the sole non-constant incoming value of %res.
%gp = getelementptr inbounds double, double* %a, i32 1
%a1 = load double, double* %gp, align 8
br label %merge

merge:
; Constants on the entry and if.fan edges; the loaded a[1] on the inner edge.
%res = phi double [0.000000e+00, %entry], [1.000000e+00, %if.fan], [%a1, %inner]
ret double %res
}

; Driver that requests the derivative of @f through __enzyme_autodiff.
;   %a   - primal input pointer forwarded to @f.
;   %da  - presumably the shadow (derivative) buffer paired with %a; it lines up
;          with the %"a'" parameter of the expected @diffef signature below.
;   %fan - outer guard, forwarded unchanged.
define void @caller(double* %a, double* %da, i1 %fan) {
entry:
; The double returned by the autodiff call is not used; the function returns void.
%call = call double (i8*, ...) @__enzyme_autodiff(
i8* bitcast (double (double*, i1)* @f to i8*),
double* nonnull %a, double* nonnull %da,
i1 %fan)
ret void
}

; CHECK: define internal void @diffef(double* %a, double* %"a'", i1 %fan, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: br i1 %fan, label %if.fan, label %invertmerge

; CHECK: if.fan:
; CHECK-NEXT: %a0 = load double, double* %a
; CHECK-NEXT: %cond = fcmp ogt double %a0, 0.000000e+00
; CHECK-NEXT: br i1 %cond, label %inner, label %invertmerge

; CHECK: inner:
; CHECK-NEXT: br label %invertmerge

; CHECK: invertentry:
; CHECK-NEXT: ret void

; CHECK: invertinner:
; CHECK-NEXT: %"gp'ipg_unwrap" = getelementptr inbounds double, double* %"a'", i32 1
; CHECK-NEXT: %[[a0:.+]] = load double, double* %"gp'ipg_unwrap"
; CHECK-NEXT: %[[a1:.+]] = fadd fast double %[[a0]], %[[sel:.+]]
; CHECK-NEXT: store double %[[a1]], double* %"gp'ipg_unwrap"
; CHECK-NEXT: br label %invertentry

; CHECK: invertmerge:
; NOTE: the `false` entry for %entry predecessor is critical - without the fix
; this would be `undef`, causing the reverse to take the inner path spuriously.
; CHECK-NEXT: %_cache.0 = phi i1 [ true, %inner ], [ false, %if.fan ], [ false, %entry ]
; CHECK-NEXT: %[[sel]] = select {{(fast )?}}i1 %_cache.0, double %differeturn, double 0.000000e+00
; NOTE: the inner predicate (%cond_unwrap, named `cond2` in GradientUtils.cpp)
; is only evaluated inside %staging, which is only reachable when %fan is true,
; preventing the uninitialized-predicate crash from issue #2629.
; CHECK-NEXT: br i1 %fan, label %staging, label %invertentry

; CHECK: staging:
; CHECK-NEXT: %a0_unwrap = load double, double* %a
; CHECK-NEXT: %cond_unwrap = fcmp ogt double %a0_unwrap, 0.000000e+00
; CHECK-NEXT: br i1 %cond_unwrap, label %invertinner, label %invertentry
; CHECK-NEXT: }
25 changes: 13 additions & 12 deletions enzyme/test/Enzyme/ReverseMode/scase.ll
Original file line number Diff line number Diff line change
Expand Up @@ -100,33 +100,34 @@ attributes #8 = { noreturn nounwind }
; CHECK: define internal { double } @diffetaylorlog(double %x, i32 %SINCOSN, double %differeturn)
; CHECK-NEXT: entry:
; CHECK-NEXT: %cmp8 = icmp eq i32 %SINCOSN, 0
; CHECK-NEXT: %lcmp.mod_unwrap = icmp ne i32 %SINCOSN, 1
; CHECK-NEXT: %anot1_ = xor i1 %cmp8, true
; CHECK-NEXT: %[[andVal:.+]] = and i1 %lcmp.mod_unwrap, %anot1_
; CHECK-NEXT: %bnot1_ = xor i1 %lcmp.mod_unwrap, true
; CHECK-NEXT: %0 = select{{( fast)?}} i1 %bnot1_, double %differeturn, double 0.000000e+00
; CHECK-NEXT: %1 = select{{( fast)?}} i1 %[[andVal]], double %differeturn, double 0.000000e+00
; CHECK-NEXT: %lcmp.mod = icmp ne i32 %SINCOSN, 1
; CHECK-NEXT: %spec.select = select i1 %lcmp.mod, i8 2, i8 1
; CHECK-NEXT: %_cache.0 = select i1 %cmp8, i8 0, i8 %spec.select
; CHECK-NEXT: %[[i0:.+]] = icmp eq i8 1, %_cache.0
; CHECK-NEXT: %[[i1:.+]] = icmp eq i8 2, %_cache.0
; CHECK-NEXT: %[[sel0:.+]] = select{{( fast)?}} i1 %[[i0]], double %differeturn, double 0.000000e+00
; CHECK-NEXT: %[[sel1:.+]] = select{{( fast)?}} i1 %[[i1]], double %differeturn, double 0.000000e+00
; CHECK-NEXT: br i1 %cmp8, label %invertentry, label %staging

; CHECK: invertentry: ; preds = %invertfor.body, %entry
; CHECK-NEXT: %"x'de.0" = phi double [ %0, %entry ], [ %[[i7:.+]], %invertfor.body ]
; CHECK-NEXT: %2 = insertvalue { double } undef, double %"x'de.0", 0
; CHECK-NEXT: ret { double } %2
; CHECK-NEXT: %"x'de.0" = phi double [ %[[sel0]], %entry ], [ %[[i7:.+]], %invertfor.body ]
; CHECK-NEXT: %[[iv:.+]] = insertvalue { double } undef, double %"x'de.0", 0
; CHECK-NEXT: ret { double } %[[iv]]

; CHECK: invertfor.body: ; preds = %staging, %incinvertfor.body
; CHECK-NEXT: %"x'de.1" = phi double [ %0, %staging ], [ %[[i7]], %incinvertfor.body ]
; CHECK-NEXT: %"x'de.1" = phi double [ %[[sel0]], %staging ], [ %[[i7]], %incinvertfor.body ]
; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %[[_unwrap2:.+]], %staging ], [ %[[i10:.+]], %incinvertfor.body ]
; CHECK-NEXT: %iv.next_unwrap = add nuw nsw i64 %"iv'ac.0", 1
; CHECK-NEXT: %_unwrap = trunc i64 %iv.next_unwrap to i32
; CHECK-NEXT: %conv_unwrap = sitofp i32 %_unwrap to double
; CHECK-NEXT: %[[d0diffez:.+]] = fdiv fast double %1, %conv_unwrap
; CHECK-NEXT: %[[d0diffez:.+]] = fdiv fast double %[[sel1]], %conv_unwrap
; CHECK-NEXT: %[[i3:.+]] = fsub fast double %conv_unwrap, 1.000000e+00
; CHECK-NEXT: %[[i4:.+]] = call fast double @llvm.pow.f64(double %x, double %[[i3]])
; CHECK-NEXT: %[[i5:.+]] = fmul fast double %conv_unwrap, %[[i4]]
; CHECK-NEXT: %[[i6:.+]] = fmul fast double %[[d0diffez]], %[[i5]]
; CHECK-NEXT: %[[i7]] = fadd fast double %"x'de.1", %[[i6]]
; CHECK-NEXT: %[[i8:.+]] = icmp eq i64 %"iv'ac.0", 0
; CHECK-NEXT: %[[i9:.+]] = select{{( fast)?}} i1 %[[i8]], double 0.000000e+00, double %1
; CHECK-NEXT: %[[i9:.+]] = select{{( fast)?}} i1 %[[i8]], double 0.000000e+00, double %[[sel1]]
; CHECK-NEXT: br i1 %[[i8]], label %invertentry, label %incinvertfor.body

; CHECK: incinvertfor.body: ; preds = %invertfor.body
Expand Down