diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 7dc7586eaf9a..feb03b3e4a27 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -28,6 +28,7 @@ #include "EnzymeLogic.h" #include "GradientUtils.h" #include "LibraryFuncs.h" +#include "PreserveNVVM.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -3354,6 +3355,53 @@ void SelectOptimization(Function *F) { } void ReplaceFunctionImplementation(Module &M) { + // For NVPTX targets, ensure __nv_* implementations are declared for any + // LLVM math intrinsics used in derivative code. Enzyme may generate calls + // to LLVM intrinsics like llvm.log.f64 or llvm.cos.f64 as part of + // derivatives, but on NVPTX these intrinsics must be lowered to __nv_log + // / __nv_cos etc. If the __nv_* function is not yet declared, declare it + // here so the replacement loop below can do the substitution correctly. + if (isTargetNVPTX(M)) { + // Base names of the LLVM math intrinsics that get an NVPTX counterpart. + // Double variant: llvm.<base>.f64 -> __nv_<base> + // Float variant: llvm.<base>.f32 -> __nv_<base>f + static const struct { + const char *base; + } nvptxMathFuncs[] = { + {"log"}, {"log2"}, {"log10"}, {"exp"}, {"exp2"}, + {"sqrt"}, {"sin"}, {"cos"}, {"tanh"}, {"sinh"}, + {"cosh"}, {"pow"}, {"fabs"}, {"floor"}, {"ceil"}, + {"round"}, {"trunc"}, + }; + for (auto &entry : nvptxMathFuncs) { + StringRef base = entry.base; + // Process double (f64) and float (f32) variants + for (bool isFloat : {false, true}) { + std::string llvmIntrName = + (Twine("llvm.") + base + (isFloat ? ".f32" : ".f64")).str(); + std::string nvFuncName = + (Twine("__nv_") + base + (isFloat ? "f" : "")).str(); + + // Only act if the LLVM intrinsic is referenced in the module but + // the __nv_* implementing function is not yet declared. 
+ Function *intrFunc = M.getFunction(llvmIntrName); + if (!intrFunc) + continue; + if (M.getFunction(nvFuncName)) + continue; + + // Declare the __nv_* function with the same type as the intrinsic. + Function *nvFunc = + Function::Create(intrFunc->getFunctionType(), + Function::ExternalLinkage, nvFuncName, M); + nvFunc->addFnAttr("implements", llvmIntrName); + nvFunc->addFnAttr("implements2", (Twine(base) + (isFloat ? "f" : "")).str()); + nvFunc->addFnAttr("enzyme_math", + (Twine(base) + (isFloat ? "f" : "")).str()); + } + } + } + for (Function &Impl : M) { for (auto attr : {"implements", "implements2"}) { if (!Impl.hasFnAttribute(attr)) diff --git a/enzyme/Enzyme/PreserveNVVM.h b/enzyme/Enzyme/PreserveNVVM.h index 0e037a5199ec..ad18348e2f67 100644 --- a/enzyme/Enzyme/PreserveNVVM.h +++ b/enzyme/Enzyme/PreserveNVVM.h @@ -37,6 +37,8 @@ class FunctionPass; llvm::ModulePass *createPreserveNVVMPass(bool Begin); llvm::FunctionPass *createPreserveNVVMFnPass(bool Begin); +bool isTargetNVPTX(llvm::Module &M); + class PreserveNVVMNewPM final : public llvm::AnalysisInfoMixin { friend struct llvm::AnalysisInfoMixin; diff --git a/enzyme/test/Enzyme/ReverseMode/nvvm_tanh.ll b/enzyme/test/Enzyme/ReverseMode/nvvm_tanh.ll new file mode 100644 index 000000000000..9a6604c65d45 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/nvvm_tanh.ll @@ -0,0 +1,37 @@ +; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme" -S | FileCheck %s + +; Test that on NVPTX targets, the derivative of __nv_tanh uses __nv_cosh (not +; plain "cosh" which doesn't exist on the CUDA device). 
+ +target triple = "nvptx64-nvidia-cuda" +target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64" + +; Declare __nv_tanh as the CUDA device implementation of tanh +declare double @__nv_tanh(double) #1 +attributes #1 = { nounwind "enzyme_math"="tanh" "implements"="llvm.tanh.f64" "implements2"="tanh" } + +; Declare __nv_cosh needed for the derivative of tanh +declare double @__nv_cosh(double) #2 +attributes #2 = { nounwind "enzyme_math"="cosh" "implements"="llvm.cosh.f64" "implements2"="cosh" } + +define void @foo(double* %x_in, double* %x_out) { +entry: + %x = load double, double* %x_in + %r = call double @__nv_tanh(double %x) + store double %r, double* %x_out + ret void +} + +declare void @__enzyme_autodiff(...) + +define void @test(double* %x, double* %d_x, double* %y, double* %d_y) { +entry: + call void (...) @__enzyme_autodiff(void (double*, double*)* @foo, + double* %x, double* %d_x, + double* %y, double* %d_y) + ret void +} + +; Derivative of tanh(x) is 1/cosh(x)^2; on NVPTX must use __nv_cosh not cosh +; CHECK: define internal void @diffefoo( +; CHECK: call double @__nv_cosh( diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index a7ace51e2bed..20cea0899f60 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -89,12 +89,38 @@ void getFunction(const Twine &curIndent, raw_ostream &os, StringRef callval, if (opName == "SameTypesFunc" || Def->isSubClassOf("SameTypesFunc")) { os << curIndent << "auto " << FT << " = cast(&" << origName << ")->getFunctionType();\n"; - os << curIndent << "auto " << callval - << " = gutils->oldFunc->getParent()->getOrInsertFunction("; - os << Def->getValueInit("name")->getAsString(); - os << ", " << FT + // When differentiating an NVPTX device function (has "implements2" attr + // set by PreserveNVVM), prefer the __nv_* version of the derivative + // helper function so the generated derivative code works on NVPTX 
device. + os << curIndent << "llvm::Value *" << callval + << " = [&]() -> llvm::Value* {\n"; + os << curIndent << " std::string _fn = " + << Def->getValueInit("name")->getAsString() << ";\n"; + os << curIndent + << " if (called && called->hasFnAttribute(\"implements2\")) {\n"; + // Look for a function already in the module with + // implements2 = targetFuncName (e.g., __nv_cosh implements "cosh") + os << curIndent + << " for (auto &_F : *gutils->oldFunc->getParent()) {\n"; + os << curIndent + << " if (_F.hasFnAttribute(\"implements2\") && " + "_F.getFnAttribute(\"implements2\").getValueAsString() == " + << Def->getValueInit("name")->getAsString() << ")\n"; + os << curIndent << " return &_F;\n"; + os << curIndent << " }\n"; + // Not found in module: for NVPTX use __nv_ naming convention. + // (On NVPTX, device math functions follow the __nv_* naming scheme.) + os << curIndent + << " if (gutils->oldFunc->getParent()->getTargetTriple().find(" + "\"nvptx\") != std::string::npos)\n"; + os << curIndent << " _fn = \"__nv_\" + _fn;\n"; + os << curIndent << " }\n"; + os << curIndent + << " return gutils->oldFunc->getParent()->getOrInsertFunction(_fn, " + << FT << ", called->getAttributes().removeFnAttribute(called->getContext(), " "\"enzymejl_needs_restoration\")).getCallee();\n"; + os << curIndent << "}();\n"; os << curIndent << "auto " << cconv << " = cast(&" << origName << ")->getCallingConv();\n"; return;