diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp
index 7dc7586eaf9a..feb03b3e4a27 100644
--- a/enzyme/Enzyme/FunctionUtils.cpp
+++ b/enzyme/Enzyme/FunctionUtils.cpp
@@ -28,6 +28,7 @@
#include "EnzymeLogic.h"
#include "GradientUtils.h"
#include "LibraryFuncs.h"
+#include "PreserveNVVM.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
@@ -3354,6 +3355,53 @@ void SelectOptimization(Function *F) {
}
void ReplaceFunctionImplementation(Module &M) {
+ // For NVPTX targets, ensure __nv_* implementations are declared for any
+ // LLVM math intrinsics used in derivative code. Enzyme may generate calls
+ // to LLVM intrinsics like llvm.log.f64 or llvm.cos.f64 as part of
+ // derivatives, but on NVPTX these intrinsics must be lowered to __nv_log
+ // / __nv_cos etc. If the __nv_* function is not yet declared, declare it
+ // here so the replacement loop below can do the substitution correctly.
+ if (isTargetNVPTX(M)) {
+ // Base names of the LLVM math intrinsics to map onto __nv_* functions.
+ // Double variant: llvm.<base>.f64 -> __nv_<base>
+ // Float variant: llvm.<base>.f32 -> __nv_<base>f
+ static const struct {
+ const char *base;
+ } nvptxMathFuncs[] = {
+ {"log"}, {"log2"}, {"log10"}, {"exp"}, {"exp2"},
+ {"sqrt"}, {"sin"}, {"cos"}, {"tanh"}, {"sinh"},
+ {"cosh"}, {"pow"}, {"fabs"}, {"floor"}, {"ceil"},
+ {"round"}, {"trunc"},
+ };
+ for (auto &entry : nvptxMathFuncs) {
+ StringRef base = entry.base;
+ // Process double (f64) and float (f32) variants
+ for (bool isFloat : {false, true}) {
+ std::string llvmIntrName =
+ (Twine("llvm.") + base + (isFloat ? ".f32" : ".f64")).str();
+ std::string nvFuncName =
+ (Twine("__nv_") + base + (isFloat ? "f" : "")).str();
+
+ // Only act if the LLVM intrinsic is referenced in the module but
+ // the __nv_* implementing function is not yet declared.
+ Function *intrFunc = M.getFunction(llvmIntrName);
+ if (!intrFunc)
+ continue;
+ if (M.getFunction(nvFuncName))
+ continue;
+
+ // Declare the __nv_* function with the same type as the intrinsic.
+ Function *nvFunc =
+ Function::Create(intrFunc->getFunctionType(),
+ Function::ExternalLinkage, nvFuncName, M);
+ nvFunc->addFnAttr("implements", llvmIntrName);
+ nvFunc->addFnAttr("implements2", (Twine(base) + (isFloat ? "f" : "")).str());
+ nvFunc->addFnAttr("enzyme_math",
+ (Twine(base) + (isFloat ? "f" : "")).str());
+ }
+ }
+ }
+
for (Function &Impl : M) {
for (auto attr : {"implements", "implements2"}) {
if (!Impl.hasFnAttribute(attr))
diff --git a/enzyme/Enzyme/PreserveNVVM.h b/enzyme/Enzyme/PreserveNVVM.h
index 0e037a5199ec..ad18348e2f67 100644
--- a/enzyme/Enzyme/PreserveNVVM.h
+++ b/enzyme/Enzyme/PreserveNVVM.h
@@ -37,6 +37,8 @@ class FunctionPass;
llvm::ModulePass *createPreserveNVVMPass(bool Begin);
llvm::FunctionPass *createPreserveNVVMFnPass(bool Begin);
+bool isTargetNVPTX(llvm::Module &M);
+
class PreserveNVVMNewPM final
: public llvm::AnalysisInfoMixin {
friend struct llvm::AnalysisInfoMixin;
diff --git a/enzyme/test/Enzyme/ReverseMode/nvvm_tanh.ll b/enzyme/test/Enzyme/ReverseMode/nvvm_tanh.ll
new file mode 100644
index 000000000000..9a6604c65d45
--- /dev/null
+++ b/enzyme/test/Enzyme/ReverseMode/nvvm_tanh.ll
@@ -0,0 +1,37 @@
+; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme" -S | FileCheck %s
+
+; Test that on NVPTX targets, the derivative of __nv_tanh uses __nv_cosh (not
+; plain "cosh" which doesn't exist on the CUDA device).
+
+target triple = "nvptx64-nvidia-cuda"
+target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
+
+; Declare __nv_tanh as the CUDA device implementation of tanh
+declare double @__nv_tanh(double) #1
+attributes #1 = { nounwind "enzyme_math"="tanh" "implements"="llvm.tanh.f64" "implements2"="tanh" }
+
+; Declare __nv_cosh needed for the derivative of tanh
+declare double @__nv_cosh(double) #2
+attributes #2 = { nounwind "enzyme_math"="cosh" "implements"="llvm.cosh.f64" "implements2"="cosh" }
+
+define void @foo(double* %x_in, double* %x_out) {
+entry:
+ %x = load double, double* %x_in
+ %r = call double @__nv_tanh(double %x)
+ store double %r, double* %x_out
+ ret void
+}
+
+declare void @__enzyme_autodiff(...)
+
+define void @test(double* %x, double* %d_x, double* %y, double* %d_y) {
+entry:
+ call void (...) @__enzyme_autodiff(void (double*, double*)* @foo,
+ double* %x, double* %d_x,
+ double* %y, double* %d_y)
+ ret void
+}
+
+; Derivative of tanh(x) is 1/cosh(x)^2; on NVPTX must use __nv_cosh not cosh
+; CHECK: define internal void @diffefoo(
+; CHECK: call double @__nv_cosh(
diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
index a7ace51e2bed..20cea0899f60 100644
--- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
+++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
@@ -89,12 +89,38 @@ void getFunction(const Twine &curIndent, raw_ostream &os, StringRef callval,
if (opName == "SameTypesFunc" || Def->isSubClassOf("SameTypesFunc")) {
os << curIndent << "auto " << FT << " = cast(&" << origName
<< ")->getFunctionType();\n";
- os << curIndent << "auto " << callval
- << " = gutils->oldFunc->getParent()->getOrInsertFunction(";
- os << Def->getValueInit("name")->getAsString();
- os << ", " << FT
+ // When differentiating an NVPTX device function (has "implements2" attr
+ // set by PreserveNVVM), prefer the __nv_* version of the derivative
+ // helper function so the generated derivative code works on NVPTX device.
+ os << curIndent << "llvm::Value *" << callval
+ << " = [&]() -> llvm::Value* {\n";
+ os << curIndent << " std::string _fn = "
+ << Def->getValueInit("name")->getAsString() << ";\n";
+ os << curIndent
+ << " if (called && called->hasFnAttribute(\"implements2\")) {\n";
+ // Look for a function already in the module with
+ // implements2 = targetFuncName (e.g., __nv_cosh implements "cosh")
+ os << curIndent
+ << " for (auto &_F : *gutils->oldFunc->getParent()) {\n";
+ os << curIndent
+ << " if (_F.hasFnAttribute(\"implements2\") && "
+ "_F.getFnAttribute(\"implements2\").getValueAsString() == "
+ << Def->getValueInit("name")->getAsString() << ")\n";
+ os << curIndent << " return &_F;\n";
+ os << curIndent << " }\n";
+ // Not found in module: for NVPTX use __nv_ naming convention.
+ // (On NVPTX, device math functions follow the __nv_* naming scheme.)
+ os << curIndent
+ << " if (gutils->oldFunc->getParent()->getTargetTriple().find("
+ "\"nvptx\") != std::string::npos)\n";
+ os << curIndent << " _fn = \"__nv_\" + _fn;\n";
+ os << curIndent << " }\n";
+ os << curIndent
+ << " return gutils->oldFunc->getParent()->getOrInsertFunction(_fn, "
+ << FT
<< ", called->getAttributes().removeFnAttribute(called->getContext(), "
"\"enzymejl_needs_restoration\")).getCallee();\n";
+ os << curIndent << "}();\n";
os << curIndent << "auto " << cconv << " = cast(&" << origName
<< ")->getCallingConv();\n";
return;