Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions enzyme/Enzyme/FunctionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "EnzymeLogic.h"
#include "GradientUtils.h"
#include "LibraryFuncs.h"
#include "PreserveNVVM.h"

#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
Expand Down Expand Up @@ -3354,6 +3355,53 @@ void SelectOptimization(Function *F) {
}

void ReplaceFunctionImplementation(Module &M) {
// For NVPTX targets, ensure __nv_* implementations are declared for any
// LLVM math intrinsics used in derivative code. Enzyme may generate calls
// to LLVM intrinsics like llvm.log.f64 or llvm.cos.f64 as part of
// derivatives, but on NVPTX these intrinsics must be lowered to __nv_log
// / __nv_cos etc. If the __nv_* function is not yet declared, declare it
// here so the replacement loop below can do the substitution correctly.
if (isTargetNVPTX(M)) {
// Pairs of {intrinsic base name, f32 suffix flag} -> NVPTX function name.
// Double variant: llvm.<base>.f64 -> __nv_<base>
// Float variant: llvm.<base>.f32 -> __nv_<base>f
static const struct {
const char *base;
} nvptxMathFuncs[] = {
{"log"}, {"log2"}, {"log10"}, {"exp"}, {"exp2"},
{"sqrt"}, {"sin"}, {"cos"}, {"tanh"}, {"sinh"},
{"cosh"}, {"pow"}, {"fabs"}, {"floor"}, {"ceil"},
{"round"}, {"trunc"},
};
for (auto &entry : nvptxMathFuncs) {
StringRef base = entry.base;
// Process double (f64) and float (f32) variants
for (bool isFloat : {false, true}) {
std::string llvmIntrName =
(Twine("llvm.") + base + (isFloat ? ".f32" : ".f64")).str();
std::string nvFuncName =
(Twine("__nv_") + base + (isFloat ? "f" : "")).str();

// Only act if the LLVM intrinsic is referenced in the module but
// the __nv_* implementing function is not yet declared.
Function *intrFunc = M.getFunction(llvmIntrName);
if (!intrFunc)
continue;
if (M.getFunction(nvFuncName))
continue;

// Declare the __nv_* function with the same type as the intrinsic.
Function *nvFunc =
Function::Create(intrFunc->getFunctionType(),
Function::ExternalLinkage, nvFuncName, M);
nvFunc->addFnAttr("implements", llvmIntrName);
nvFunc->addFnAttr("implements2", (Twine(base) + (isFloat ? "f" : "")).str());
nvFunc->addFnAttr("enzyme_math",
(Twine(base) + (isFloat ? "f" : "")).str());
}
}
}

for (Function &Impl : M) {
for (auto attr : {"implements", "implements2"}) {
if (!Impl.hasFnAttribute(attr))
Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/PreserveNVVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class FunctionPass;
llvm::ModulePass *createPreserveNVVMPass(bool Begin);
llvm::FunctionPass *createPreserveNVVMFnPass(bool Begin);

bool isTargetNVPTX(llvm::Module &M);

class PreserveNVVMNewPM final
: public llvm::AnalysisInfoMixin<PreserveNVVMNewPM> {
friend struct llvm::AnalysisInfoMixin<PreserveNVVMNewPM>;
Expand Down
37 changes: 37 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/nvvm_tanh.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme" -S | FileCheck %s

; Test that on NVPTX targets, the derivative of __nv_tanh uses __nv_cosh (not
; plain "cosh" which doesn't exist on the CUDA device).

target triple = "nvptx64-nvidia-cuda"
target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"

; Declare __nv_tanh as the CUDA device implementation of tanh
declare double @__nv_tanh(double) #1
attributes #1 = { nounwind "enzyme_math"="tanh" "implements"="llvm.tanh.f64" "implements2"="tanh" }

; Declare __nv_cosh needed for the derivative of tanh
declare double @__nv_cosh(double) #2
attributes #2 = { nounwind "enzyme_math"="cosh" "implements"="llvm.cosh.f64" "implements2"="cosh" }

define void @foo(double* %x_in, double* %x_out) {
entry:
%x = load double, double* %x_in
%r = call double @__nv_tanh(double %x)
store double %r, double* %x_out
ret void
}

declare void @__enzyme_autodiff(...)

define void @test(double* %x, double* %d_x, double* %y, double* %d_y) {
entry:
call void (...) @__enzyme_autodiff(void (double*, double*)* @foo,
double* %x, double* %d_x,
double* %y, double* %d_y)
ret void
}

; Derivative of tanh(x) is 1/cosh(x)^2; on NVPTX must use __nv_cosh not cosh
; CHECK: define internal void @diffefoo(
; CHECK: call double @__nv_cosh(
34 changes: 30 additions & 4 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,38 @@ void getFunction(const Twine &curIndent, raw_ostream &os, StringRef callval,
if (opName == "SameTypesFunc" || Def->isSubClassOf("SameTypesFunc")) {
os << curIndent << "auto " << FT << " = cast<CallInst>(&" << origName
<< ")->getFunctionType();\n";
os << curIndent << "auto " << callval
<< " = gutils->oldFunc->getParent()->getOrInsertFunction(";
os << Def->getValueInit("name")->getAsString();
os << ", " << FT
// When differentiating an NVPTX device function (has "implements2" attr
// set by PreserveNVVM), prefer the __nv_* version of the derivative
// helper function so the generated derivative code works on NVPTX device.
os << curIndent << "llvm::Value *" << callval
<< " = [&]() -> llvm::Value* {\n";
os << curIndent << " std::string _fn = "
<< Def->getValueInit("name")->getAsString() << ";\n";
os << curIndent
<< " if (called && called->hasFnAttribute(\"implements2\")) {\n";
// Look for a function already in the module with
// implements2 = targetFuncName (e.g., __nv_cosh implements "cosh")
os << curIndent
<< " for (auto &_F : *gutils->oldFunc->getParent()) {\n";
os << curIndent
<< " if (_F.hasFnAttribute(\"implements2\") && "
"_F.getFnAttribute(\"implements2\").getValueAsString() == "
<< Def->getValueInit("name")->getAsString() << ")\n";
os << curIndent << " return &_F;\n";
os << curIndent << " }\n";
// Not found in module: for NVPTX use __nv_<funcname> naming convention.
// (On NVPTX, device math functions follow the __nv_* naming scheme.)
os << curIndent
<< " if (gutils->oldFunc->getParent()->getTargetTriple().find("
"\"nvptx\") != std::string::npos)\n";
os << curIndent << " _fn = \"__nv_\" + _fn;\n";
os << curIndent << " }\n";
os << curIndent
<< " return gutils->oldFunc->getParent()->getOrInsertFunction(_fn, "
<< FT
<< ", called->getAttributes().removeFnAttribute(called->getContext(), "
"\"enzymejl_needs_restoration\")).getCallee();\n";
os << curIndent << "}();\n";
os << curIndent << "auto " << cconv << " = cast<CallInst>(&" << origName
<< ")->getCallingConv();\n";
return;
Expand Down