diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index bfadc6bca08b..a5b5b964d00e 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp @@ -505,6 +505,49 @@ struct EnzymeInactiveAttrInfo : public ParsedAttrInfo { static ParsedAttrInfoRegistry::Add X4("enzyme_inactive", ""); +struct EnzymeElementwiseReadAttrInfo : public ParsedAttrInfo { + EnzymeElementwiseReadAttrInfo() { + OptArgs = 1; + static constexpr Spelling S[] = { + {ParsedAttr::AS_GNU, "enzyme_elementwise_read"}, +#if LLVM_VERSION_MAJOR > 17 + {ParsedAttr::AS_C23, "enzyme_elementwise_read"}, +#else + {ParsedAttr::AS_C2x, "enzyme_elementwise_read"}, +#endif + {ParsedAttr::AS_CXX11, "enzyme_elementwise_read"}, + {ParsedAttr::AS_CXX11, "enzyme::elementwise_read"} + }; + Spellings = S; + } + + bool diagAppertainsToDecl(Sema &S, const ParsedAttr &Attr, + const Decl *D) const override { + if (isa(D)) + return true; + S.Diag(Attr.getLoc(), diag::warn_attribute_wrong_decl_type_str) + << Attr << "functions"; + return false; + } + + AttrHandling handleDeclAttribute(Sema &S, Decl *D, + const ParsedAttr &Attr) const override { + if (Attr.getNumArgs() != 0) { + unsigned ID = S.getDiagnostics().getCustomDiagID( + DiagnosticsEngine::Error, + "'enzyme_elementwise_read' attribute requires zero arguments"); + S.Diag(Attr.getLoc(), ID); + return AttributeNotApplied; + } + D->addAttr(AnnotateAttr::Create(S.Context, "enzyme_elementwise_read", + nullptr, 0, Attr.getRange())); + return AttributeApplied; + } +}; + +static ParsedAttrInfoRegistry::Add + XElemRead("enzyme_elementwise_read", ""); + struct EnzymeNoFreeAttrInfo : public ParsedAttrInfo { EnzymeNoFreeAttrInfo() { OptArgs = 1; diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index 19e3c947b04b..dd75209aab7f 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -51,6 +51,26 @@ using namespace llvm; +namespace { +bool elementwiseReadForContext(const Instruction *orig, const Value *origptr) { + if (orig) { + if (const Function *F = orig->getFunction()) { + if (F->hasFnAttribute("enzyme_elementwise_read")) { + return true; + } + } + } + const Value *base = getBaseObject(origptr); + if (auto *arg = dyn_cast(base)) { + if (const Function *F = arg->getParent()) { + return F->getAttributes().hasParamAttr(arg->getArgNo(), + "enzyme_elementwise_read"); + } + } + return false; +} +} // namespace + DiffeGradientUtils::DiffeGradientUtils( EnzymeLogic &Logic, Function *newFunc_, Function *oldFunc_, TargetLibraryInfo &TLI, TypeAnalysis &TA, TypeResults TR, @@ -931,6 +951,8 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(Instruction *orig, // all additional parallelism in this function is outlined. if (backwardsOnlyShadows.find(TmpOrig) != backwardsOnlyShadows.end()) Atomic = false; + if (Atomic && elementwiseReadForContext(orig, origptr)) + Atomic = false; if (Atomic) { // For amdgcn constant AS is 4 and if the primal is in it we need to cast diff --git a/enzyme/Enzyme/PreserveNVVM.cpp b/enzyme/Enzyme/PreserveNVVM.cpp index 270570eddcba..8c98c0de0257 100644 --- a/enzyme/Enzyme/PreserveNVVM.cpp +++ b/enzyme/Enzyme/PreserveNVVM.cpp @@ -387,6 +387,15 @@ bool preserveNVVM(bool Begin, Module &M) { continue; } + if (AS == "enzyme_elementwise_read" && Func) { + Func->addAttribute(AttributeList::FunctionIndex, + Attribute::get(Func->getContext(), + "enzyme_elementwise_read")); + changed = true; + replacements.push_back(Constant::getNullValue(CAOp->getType())); + continue; + } + if (AS == "enzyme_shouldrecompute" && Func) { Func->addAttribute( AttributeList::FunctionIndex, diff --git a/enzyme/test/Enzyme/ReverseMode/elementwise-read.ll b/enzyme/test/Enzyme/ReverseMode/elementwise-read.ll new file mode 100644 index 000000000000..fec23962873b --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/elementwise-read.ll @@ -0,0 +1,38 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -preserve-nvvm -enzyme -enzyme-detect-readthrow=0 -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -enzyme-detect-readthrow=0 -passes="preserve-nvvm,enzyme" -S | FileCheck %s + +; ModuleID = 'elementwise-read.ll' +source_filename = "elementwise-read.ll" +target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64-ni:10:11:12:13" +target triple = "nvptx64-nvidia-cuda" + +@.str.enzyme_elementwise_read = private unnamed_addr constant [24 x i8] c"enzyme_elementwise_read\00", section "llvm.metadata" +@.str.file = private unnamed_addr constant [17 x i8] c"elementwise-read\00", section "llvm.metadata" +@llvm.global.annotations = appending global [1 x { i8*, i8*, i8*, i32, i8* }] [{ i8*, i8*, i8*, i32, i8* } { i8* bitcast (float (float addrspace(1)*)* @vmul to i8*), i8* getelementptr inbounds ([24 x i8], [24 x i8]* @.str.enzyme_elementwise_read, i32 0, i32 0), i8* getelementptr inbounds ([17 x i8], [17 x i8]* @.str.file, i32 0, i32 0), i32 1, i8* null }], section "llvm.metadata" + +declare float @llvm.nvvm.ldg.global.f.f32.p1f32(float addrspace(1)* nocapture, i32) + +define float @vmul(float addrspace(1)* %inp) { +top: + %ld = call float @llvm.nvvm.ldg.global.f.f32.p1f32(float addrspace(1)* %inp, i32 4) + ret float %ld +} + +define float @test_derivative(float addrspace(1)* %inp, float addrspace(1)* %dinp) { +entry: + %0 = tail call float (float (float addrspace(1)*)*, ...) @__enzyme_autodiff(float (float addrspace(1)*)* nonnull @vmul, float addrspace(1)* %inp, float addrspace(1)* %dinp) + ret float %0 +} + +declare float @__enzyme_autodiff(float (float addrspace(1)*)*, ...) + +; CHECK-LABEL: define float @vmul(float addrspace(1)* %inp) #[[VMULATTR:[0-9]+]] +; CHECK: attributes #[[VMULATTR]] = { {{[^}]*}}"enzyme_elementwise_read"{{[^}]*}} } + +; CHECK-LABEL: define internal void @diffevmul(float addrspace(1)* %inp, float addrspace(1)* %"inp'", float %differeturn) +; CHECK-NOT: atomicrmw +; CHECK-NEXT: top: +; CHECK-NEXT: %[[OLD:[^ ]+]] = load float, float addrspace(1)* %"inp'" +; CHECK-NEXT: %[[NEW:[^ ]+]] = fadd fast float %[[OLD]], %differeturn +; CHECK-NEXT: store float %[[NEW]], float addrspace(1)* %"inp'" +; CHECK-NEXT: ret void