Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 68 additions & 8 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ static const StringSet<> KnownInactiveFunctionInsts = {

/// Is the use of value val as an argument of call CI known to be inactive
/// This tool can only be used when in DOWN mode
bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) {
bool ActivityAnalyzer::isFunctionArgumentConstant(CallBase *CI, Value *val) {
assert(directions & DOWN);
if (isInactiveCall(*CI))
return true;
Expand Down Expand Up @@ -645,6 +645,12 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) {
if (Name == "MPI_Waitall" || Name == "PMPI_Waitall")
return val != CI->getOperand(1);

if (Name == "__kmpc_reduce_nowait")
return val != CI->getOperand(4);

if (Name == "__kmpc_end_reduce_nowait")
return true;

// TODO interprocedural detection
// Before potential introprocedural detection, any function without definition
// may to be assumed to have an active use
Expand Down Expand Up @@ -676,6 +682,15 @@ static inline void propagateArgumentInformation(
return;
}

if (Name == "__kmpc_reduce_nowait") {
propagateFromOperand(CI.getArgOperand(4));
return;
}

if (Name == "__kmpc_end_reduce_nowait") {
return;
}

if (Name == "julia.call" || Name == "julia.call2") {
for (size_t i = 1; i < CI.arg_size(); i++) {
propagateFromOperand(CI.getOperand(i));
Expand Down Expand Up @@ -1109,6 +1124,40 @@ bool isValuePotentiallyUsedAsPointer(llvm::Value *val) {
return false;
}

bool ActivityAnalyzer::hasActiveArgumentOtherThan(Instruction *I, Value *Val,
TypeResults const &TR) {
bool has_other_active_operand = false;
if (auto CB = dyn_cast<CallBase>(I)) {
if (EnzymeGlobalActivity)
return true;
#if LLVM_VERSION_MAJOR >= 14
for (auto &arg : CB->args())
#else
for (auto &arg : CB->arg_operands())
#endif
{
if (arg == Val)
continue;
if (isFunctionArgumentConstant(CB, arg))
continue;
if (!DeducingPointers.count(arg) && isConstantValue(TR, arg)) {
continue;
}
has_other_active_operand = true;
}
} else {
for (auto &arg : I->operands()) {
if (arg == Val)
continue;
if (!DeducingPointers.count(arg) && isConstantValue(TR, arg)) {
continue;
}
has_other_active_operand = true;
}
}
return has_other_active_operand;
}

bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
// This analysis may only be called by instructions corresponding to
// the function analyzed by TypeInfo -- however if the Value
Expand Down Expand Up @@ -2188,8 +2237,10 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
// Notably need both to check the result and instruction since
// A load that has as result an active pointer is not an active
// instruction, but does have an active value
if (!Hypothesis->isConstantInstruction(TR, I) ||
(I != Val && !Hypothesis->isConstantValue(TR, I))) {
if ((!Hypothesis->isConstantInstruction(TR, I) ||
(I != Val && !Hypothesis->isConstantValue(TR, I))) &&
!(Hypothesis->isConstantValue(TR, I) &&
!Hypothesis->hasActiveArgumentOtherThan(I, Val, TR))) {
potentiallyActiveLoad = I;
// If this a potential pointer of pointer AND
// double** Val;
Expand All @@ -2208,10 +2259,14 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
// double* I = *Val;
// I[0] = active;
//
if ((I->mayWriteToMemory() &&
!Hypothesis->isConstantInstruction(TR, I)) ||
(!Hypothesis->DeducingPointers.count(I) &&
!Hypothesis->isConstantValue(TR, I) && TR.anyPointer(I))) {
// We have an exception for instructions with a single active
// argument as the only way for data to flow into it is through
// itself.
if (((I->mayWriteToMemory() &&
!Hypothesis->isConstantInstruction(TR, I)) ||
(!Hypothesis->DeducingPointers.count(I) &&
!Hypothesis->isConstantValue(TR, I) && TR.anyPointer(I))) &&
Hypothesis->hasActiveArgumentOtherThan(I, Val, TR)) {
if (EnzymePrintActivity)
llvm::errs() << "potential active store via pointer in "
"unknown inst: "
Expand Down Expand Up @@ -2729,6 +2784,10 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults const &TR,
<< "\n";
return true;
}
// GEPs may be returning a pointer to an active constant (like a vtable),
// but we can still prove they are inactive if their USERS are inactive.
// However, since this is UP search, finding non-constant args means we
// cannot inductively prove it's inactive purely from origin.
return false;
} else if (auto ci = dyn_cast<CallInst>(inst)) {
bool seenuse = false;
Expand Down Expand Up @@ -3586,7 +3645,8 @@ bool ActivityAnalyzer::isValueActivelyStoredOrReturned(TypeResults const &TR,
if (!inst->mayWriteToMemory() ||
(isa<CallInst>(inst) &&
(AA.onlyReadsMemory(cast<CallInst>(inst)) ||
isLocalReadOnlyOrThrow(cast<CallInst>(inst))))) {
isLocalReadOnlyOrThrow(cast<CallInst>(inst)))) ||
!hasActiveArgumentOtherThan(inst, val, TR)) {
// if not written to memory and returning a known constant, this
// cannot be actively returned/stored
if (inst->getParent()->getParent() == TR.getFunction() &&
Expand Down
5 changes: 4 additions & 1 deletion enzyme/Enzyme/ActivityAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ class ActivityAnalyzer {
InsertConstValueRecursionHandler = nullptr;
}

bool hasActiveArgumentOtherThan(llvm::Instruction *I, llvm::Value *Val,
TypeResults const &TR);

/// Import known constants from an existing analyzer
void insertConstantsFrom(TypeResults const &TR,
ActivityAnalyzer &Hypothesis) {
Expand Down Expand Up @@ -221,7 +224,7 @@ class ActivityAnalyzer {
}

/// Is the use of value val as an argument of call CI known to be inactive
bool isFunctionArgumentConstant(llvm::CallInst *CI, llvm::Value *val);
bool isFunctionArgumentConstant(llvm::CallBase *CI, llvm::Value *val);

/// Is the instruction guaranteed to be inactive because of its operands.
/// \p considerValue specifies that we ask whether the returned value, rather
Expand Down
54 changes: 54 additions & 0 deletions enzyme/Enzyme/CallDerivatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2344,6 +2344,60 @@ bool AdjointGenerator::handleKnownCallDerivatives(
return true;
}

if (funcName == "__kmpc_reduce_nowait") {
if (gutils->isConstantInstruction(&call)) {
return false;
}
}

if (funcName == "__kmpc_end_reduce_nowait") {
assert(gutils->isConstantInstruction(&call));
auto GV = call.getArgOperand(2);
bool active = false;

SmallVector<Value *, 4> worklist;
worklist.push_back(GV);
SmallPtrSet<Value *, 4> seen;
while (!worklist.empty()) {
auto V = worklist.pop_back_val();
if (!seen.insert(V).second)
continue;
for (auto u : V->users()) {
if (u == &call)
continue;
if (auto *CE = dyn_cast<ConstantExpr>(u)) {
if (CE->isCast()) {
worklist.push_back(CE);
continue;
}
} else if (auto *BC = dyn_cast<CastInst>(u)) {
worklist.push_back(BC);
continue;
}
auto CB = dyn_cast<CallBase>(u);
if (!CB) {
EmitFailure("UnknownKmpcUser", call.getDebugLoc(), &call,
"Unknown user ", *u, " of kmpc global\n");
return false;
}
if (CB->getParent()->getParent() != gutils->oldFunc) {
continue;
}
auto F = CB->getCalledFunction();
if (!F || F->getName() != "__kmpc_reduce_nowait") {
EmitFailure("UnknownKmpcUser", call.getDebugLoc(), &call,
"Unknown CallBase user ", *CB, " of kmpc global\n");
return false;
}
if (!gutils->isConstantInstruction(CB)) {
active = true;
}
}
}
if (!active)
return false;
}

if (startsWith(funcName, "__kmpc") &&
funcName != "__kmpc_global_thread_num") {
std::string s;
Expand Down
44 changes: 44 additions & 0 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2235,6 +2235,50 @@ bool DetectReadonlyOrThrow(Module &M) {

inverse_todo_map.erase(found);
}

for (auto &F : M) {
for (auto &B : F) {
for (auto &I : B) {
if (auto CB = dyn_cast<CallBase>(&I)) {
auto CF = CB->getCalledFunction();
if (!CF)
continue;
if (CF->getName() != "__kmpc_reduce_nowait")
continue;
auto ReduceF = dyn_cast<Function>(getBaseObject(CB->getOperand(5)));
if (!ReduceF)
continue;
if (!isNoCapture(CB, 4) && isNoCapture(ReduceF, 0) &&
isNoCapture(ReduceF, 1)) {
addCallSiteNoCapture(CB, 4);
changed = true;
}
if (!CB->hasFnAttr(Attribute::NoFree) &&
ReduceF->hasFnAttribute(Attribute::NoFree)) {
CB->addFnAttr(Attribute::NoFree);
changed = true;
}
if (isReadNone(ReduceF, 0) && isReadNone(ReduceF, 1) &&
!isReadNone(CB, 4)) {
CB->removeParamAttr(4, Attribute::ReadOnly);
CB->removeParamAttr(4, Attribute::WriteOnly);
CB->addParamAttr(4, Attribute::ReadNone);
changed = true;
}
if (isReadOnly(ReduceF, 0) && isReadOnly(ReduceF, 1) &&
!isReadOnly(CB, 4) && !isWriteOnly(CB, 4)) {
CB->addParamAttr(4, Attribute::ReadOnly);
changed = true;
}
if (isWriteOnly(ReduceF, 0) && isWriteOnly(ReduceF, 1) &&
!isReadOnly(CB, 4) && !isWriteOnly(CB, 4)) {
CB->addParamAttr(4, Attribute::WriteOnly);
changed = true;
}
}
}
}
}
return changed;
}

Expand Down
3 changes: 2 additions & 1 deletion enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4771,7 +4771,8 @@ void TypeAnalyzer::visitCallBase(CallBase &call) {
}
if (funcName == "omp_get_max_threads" || funcName == "omp_get_thread_num" ||
funcName == "omp_get_num_threads" ||
funcName == "__kmpc_global_thread_num") {
funcName == "__kmpc_global_thread_num" ||
funcName == "__kmpc_reduce_nowait") {
updateAnalysis(&call, TypeTree(BaseType::Integer).Only(-1, &call), &call);
return;
}
Expand Down
3 changes: 3 additions & 0 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ bool attributeKnownFunctions(llvm::Function &F) {
}
}
}
if (F.getName() == "__kmpc_end_reduce_nowait") {
F.addFnAttr(Attribute::NoFree);
}
if (F.getName() == "memcmp") {
changed = true;
#if LLVM_VERSION_MAJOR >= 16
Expand Down
15 changes: 10 additions & 5 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,13 @@ static inline bool isReadNone(const llvm::Function *F, ssize_t arg = -1) {
return isReadOnly(F, arg) && isWriteOnly(F, arg);
}

static inline bool isNoCapture(const llvm::Function *F, size_t idx) {
if (idx < F->arg_size() && F->getArg(idx)->hasNoCaptureAttr())
return true;
// if (F->getAttributes().hasParamAttribute(idx, "enzyme_NoCapture"))
// return true;
return false;
}
static inline bool isNoCapture(const llvm::CallBase *call, size_t idx) {
if (call->doesNotCapture(idx))
return true;
Expand All @@ -1851,11 +1858,9 @@ static inline bool isNoCapture(const llvm::CallBase *call, size_t idx) {
// call wrapping args into an array. This is because the wrapped array
// may be nocapure/readonly, but the actual arg (which will be put in the
// array) may not be.
if (F->getCallingConv() == call->getCallingConv())
if (idx < F->arg_size() && F->getArg(idx)->hasNoCaptureAttr())
return true;
// if (F->getAttributes().hasParamAttribute(idx, "enzyme_NoCapture"))
// return true;
if (F->getCallingConv() == call->getCallingConv()) {
return isNoCapture(F, idx);
}
}
return false;
}
Expand Down
12 changes: 7 additions & 5 deletions enzyme/test/Enzyme/ReverseMode/constglobal.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
; RUN: if [ %llvmver -eq 15 ]; then %opt < %s %loadEnzyme -enzyme -opaque-pointers=1 -S | FileCheck %s; fi
; RUN: if [ %llvmver -ge 15 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -opaque-pointers=1 -S | FileCheck %s; fi
; RUN: if [ %llvmver -eq 15 ]; then %opt < %s %loadEnzyme -enzyme -opaque-pointers=1 -S | FileCheck %s --check-prefixes=CHECK,CHECK; fi
; RUN: if [ %llvmver -ge 16 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -opaque-pointers=1 -S | FileCheck %s --check-prefixes=CHECK,CHECK; fi
; RUN: if [ %llvmver -eq 15 ]; then %opt < %s %newLoadEnzyme -passes="enzyme" -opaque-pointers=1 -S | FileCheck %s --check-prefixes=CHECK,CHECK; fi

%"class.std::ios_base::Init" = type { i8 }
%class.Test = type { ptr }
Expand Down Expand Up @@ -52,7 +53,7 @@ entry:
%this.addr = alloca ptr, align 8
store ptr %this, ptr %this.addr, align 8
%this1 = load ptr, ptr %this.addr, align 8
store ptr getelementptr inbounds ({ [3 x ptr] }, ptr @_ZTV4Test, i32 0, inrange i32 0, i32 2), ptr %this1, align 8
store ptr getelementptr inbounds ({ [3 x ptr] }, ptr @_ZTV4Test, i32 0, i32 0, i32 2), ptr %this1, align 8
ret void
}

Expand Down Expand Up @@ -103,6 +104,7 @@ attributes #6 = { mustprogress noinline norecurse optnone uwtable "frame-pointer
; CHECK-NEXT: %"sys'ipa" = alloca %class.Test, align 8
; CHECK-NEXT: store %class.Test zeroinitializer, ptr %"sys'ipa", align 8
; CHECK-NEXT: %sys = alloca %class.Test, align 8
; CHECK-NEXT: call void @nofree__ZN4TestC2Ev(ptr noundef nonnull align 8 dereferenceable(8) %sys)
; CHECK-NEXT: br label %invertentry

; CHECK: invertentry: ; preds = %entry
Expand All @@ -112,8 +114,8 @@ attributes #6 = { mustprogress noinline norecurse optnone uwtable "frame-pointer

; CHECK: define internal void @diffe_ZN4TestC2Ev(ptr noundef nonnull align 8 dereferenceable(8) %this, ptr align 8 %"this'")
; CHECK-NEXT: entry:
; CHECK-NEXT: store ptr getelementptr inbounds ({ [3 x ptr] }, ptr @_ZTV4Test_shadow, i32 0, inrange i32 0, i32 2), ptr %"this'", align 8, !alias.scope !7, !noalias !10
; CHECK-NEXT: store ptr getelementptr inbounds ({ [3 x ptr] }, ptr @_ZTV4Test, i32 0, inrange i32 0, i32 2), ptr %this, align 8, !alias.scope !10, !noalias !7
; CHECK-NEXT: store ptr getelementptr inbounds ({ [3 x ptr] }, ptr @_ZTV4Test_shadow, i32 0, i32 0, i32 2), ptr %"this'", align 8, !alias.scope !7, !noalias !10
; CHECK-NEXT: store ptr getelementptr inbounds ({ [3 x ptr] }, ptr @_ZTV4Test, i32 0, i32 0, i32 2), ptr %this, align 8, !alias.scope !10, !noalias !7
; CHECK-NEXT: br label %invertentry

; CHECK: invertentry: ; preds = %entry
Expand Down
5 changes: 3 additions & 2 deletions enzyme/test/Enzyme/ReverseMode/invcglobal.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme-detect-readthrow=0 -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -enzyme-detect-readthrow=0 -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme-detect-readthrow=0 -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s --check-prefixes=CHECK,CHECK; fi
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %newLoadEnzyme -enzyme-preopt=false -enzyme-detect-readthrow=0 -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s --check-prefixes=CHECK,CHECK; else %opt < %s %newLoadEnzyme -enzyme-preopt=false -enzyme-detect-readthrow=0 -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s --check-prefixes=CHECK,CHECK; fi

@g = constant i8* null, align 8

Expand Down Expand Up @@ -28,6 +28,7 @@ declare double @__enzyme_autodiff(...)
; CHECK-NEXT: store void ()* null, void ()** %"ai2'ipa", align 8
; CHECK-NEXT: %ai2 = alloca void ()*, align 8
; CHECK-NEXT: call void @diffe_Z3barv(void ()** %ai2, void ()** %"ai2'ipa")
; CHECK-NEXT: call void @nofree__Z3barv(void ()** %ai2)
; CHECK-NEXT: ret void
; CHECK-NEXT: }

Expand Down
Loading
Loading