Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,10 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {

auto found = gutils->invertedPointers.find(&I);
if (gutils->isConstantValue(&I)) {
assert(found == gutils->invertedPointers.end());
// invertPointerM may have already computed a shadow for this CONST
// instruction (e.g., when resolving free arguments through CONST
// PHI/Load chains in forward-over-reverse mode where activity analysis
// is inconsistent between alloc and free). Just leave it.
return;
}

Expand Down
64 changes: 35 additions & 29 deletions enzyme/Enzyme/CallDerivatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4093,42 +4093,44 @@ bool AdjointGenerator::handleKnownCallDerivatives(

if (Mode == DerivativeMode::ForwardMode ||
Mode == DerivativeMode::ForwardModeError) {
if (!gutils->isConstantValue(call.getArgOperand(0))) {
IRBuilder<> Builder2(&call);
getForwardBuilder(Builder2);
auto origfree = call.getArgOperand(0);
auto newfree = gutils->getNewFromOriginal(call.getArgOperand(0));
auto tofree = gutils->invertPointerM(origfree, Builder2);
// Always emit checked_free for both primal and shadow, regardless
// of isConstantValue. In forward-over-reverse mode, the activity
// analysis can be inconsistent between allocation (NON-CONST, shadow
// created) and deallocation (CONST, no shadow free). The checked_free
// function safely handles primal == shadow by not freeing.
IRBuilder<> Builder2(&call);
getForwardBuilder(Builder2);
auto origfree = call.getArgOperand(0);
auto newfree = gutils->getNewFromOriginal(call.getArgOperand(0));
auto tofree = gutils->invertPointerM(origfree, Builder2);

Function *free = getOrInsertCheckedFree(
*call.getModule(), &call, newfree->getType(), gutils->getWidth());
Function *free = getOrInsertCheckedFree(
*call.getModule(), &call, newfree->getType(), gutils->getWidth());

bool used = true;
if (auto instArg = dyn_cast<Instruction>(call.getArgOperand(0)))
used = unnecessaryInstructions.find(instArg) ==
unnecessaryInstructions.end();
bool used = true;
if (auto instArg = dyn_cast<Instruction>(call.getArgOperand(0)))
used = unnecessaryInstructions.find(instArg) ==
unnecessaryInstructions.end();

SmallVector<Value *, 3> args;
if (used)
args.push_back(newfree);
else
args.push_back(
Constant::getNullValue(call.getArgOperand(0)->getType()));
SmallVector<Value *, 3> args;
if (used)
args.push_back(newfree);
else
args.push_back(
Constant::getNullValue(call.getArgOperand(0)->getType()));

auto rule = [&args](Value *tofree) { args.push_back(tofree); };
applyChainRule(Builder2, rule, tofree);
auto rule = [&args](Value *tofree) { args.push_back(tofree); };
applyChainRule(Builder2, rule, tofree);

for (size_t i = 1; i < call.arg_size(); i++) {
args.push_back(gutils->getNewFromOriginal(call.getArgOperand(i)));
}
for (size_t i = 1; i < call.arg_size(); i++) {
args.push_back(gutils->getNewFromOriginal(call.getArgOperand(i)));
}

auto frees = Builder2.CreateCall(free->getFunctionType(), free, args);
frees->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));
auto frees = Builder2.CreateCall(free->getFunctionType(), free, args);
frees->setDebugLoc(gutils->getNewFromOriginal(call.getDebugLoc()));

eraseIfUnused(call);
return true;
}
eraseIfUnused(call);
return true;
}
auto callval = call.getCalledOperand();

Expand Down Expand Up @@ -4216,7 +4218,11 @@ bool AdjointGenerator::handleKnownCallDerivatives(
}

// TODO HANDLE FREE
llvm::errs() << "freeing without malloc " << *val << "\n";
// Suppress warning for common safe cases (phi nodes, loads) that are
// conservatively handled but can't be statically matched to allocations
if (!isa<PHINode>(val) && !isa<LoadInst>(val)) {
llvm::errs() << "freeing without malloc " << *val << "\n";
}
eraseIfUnused(call, /*erase*/ true, /*check*/ false);
return true;
}
Expand Down
23 changes: 15 additions & 8 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5386,10 +5386,24 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
return applyChainRule(oval->getType(), BuilderM, rule);
}

// Check invertedPointers before isConstantValue. If an explicit shadow was
// registered (e.g. by the forward-mode allocation handler), use it
// regardless of the activity analysis result. This prevents leaks when
// activity analysis is inconsistent between alloc (NON-CONST) and free
// (CONST) in forward-over-reverse mode.
{
auto ifound = invertedPointers.find(oval);
if (ifound != invertedPointers.end()) {
return &*ifound->second;
}
}

bool shouldNullShadow = isConstantValue(oval);
if (shouldNullShadow) {
if (isa<InsertValueInst>(oval) || isa<ExtractValueInst>(oval) ||
isa<InsertElementInst>(oval) || isa<ExtractElementInst>(oval)) {
isa<InsertElementInst>(oval) || isa<ExtractElementInst>(oval) ||
(isa<PHINode>(oval) && oval->getType()->isPointerTy()) ||
(isa<LoadInst>(oval) && oval->getType()->isPointerTy())) {
shouldNullShadow = false;
auto orig = cast<Instruction>(oval);
if (knownRecomputeHeuristic.count(orig)) {
Expand Down Expand Up @@ -5495,13 +5509,6 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
auto M = oldFunc->getParent();
assert(oval);

{
auto ifound = invertedPointers.find(oval);
if (ifound != invertedPointers.end()) {
return &*ifound->second;
}
}

if (mode != DerivativeMode::ForwardMode &&
mode != DerivativeMode::ForwardModeError &&
mode != DerivativeMode::ForwardModeSplit && nullShadow) {
Expand Down