diff --git a/.github/workflows/poseidon.yml b/.github/workflows/poseidon.yml new file mode 100644 index 000000000000..65809eecf9e9 --- /dev/null +++ b/.github/workflows/poseidon.yml @@ -0,0 +1,56 @@ +name: Poseidon CI + +on: + push: + branches: + - main + pull_request: + merge_group: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + build-linux: + name: Poseidon CI LLVM ${{ matrix.llvm }} ${{ matrix.build }} ${{ matrix.os }} + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + llvm: ["20"] + build: ["Release", "Debug"] + os: [ubuntu-22.04] + + timeout-minutes: 60 + + steps: + - name: Install dependencies + run: | + wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - + sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true + sudo apt-get update + sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev clang-${{ matrix.llvm }} libclang-${{ matrix.llvm }}-dev libzstd-dev libmpfr-dev + sudo python3 -m pip install --upgrade pip lit + + - name: Install Racket + uses: Bogdanp/setup-racket@v1.12 + with: + version: '8.15' + + - uses: actions/checkout@v4 + + - name: Build Enzyme with Poseidon + run: | + mkdir build && cd build + cmake ../enzyme -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DENABLE_POSEIDON=1 -DLLVM_EXTERNAL_LIT=$(which lit) -DLLVM_DIR=/usr/lib/llvm-${{ matrix.llvm }}/lib/cmake/llvm + make -j $(nproc) + + - name: check-poseidon + working-directory: build + run: make -j $(nproc) check-poseidon + + - name: check-poseidon-integration + working-directory: build + run: make -j3 check-poseidon-integration diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 86341a9fadae..450fb1f303f8 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -48,6 +48,10 @@ #include 
"TraceUtils.h" #include "TypeAnalysis/TBAA.h" +#ifdef ENABLE_POSEIDON +#include "Poseidon/Poseidon.h" +#endif + #define DEBUG_TYPE "enzyme" // Helper instruction visitor that generates adjoints @@ -4610,7 +4614,8 @@ class AdjointGenerator : public llvm::InstVisitor { .forceAnonymousTape = false, .typeInfo = nextTypeInfo, .runtimeActivity = gutils->runtimeActivity, - .strongZero = gutils->strongZero}, + .strongZero = gutils->strongZero, + .profiled = gutils->profiled}, TR.analyzer->interprocedural, subdata, /*omp*/ true); @@ -5989,7 +5994,8 @@ class AdjointGenerator : public llvm::InstVisitor { .forceAnonymousTape = false, .typeInfo = nextTypeInfo, .runtimeActivity = gutils->runtimeActivity, - .strongZero = gutils->strongZero}, + .strongZero = gutils->strongZero, + .profiled = gutils->profiled}, TR.analyzer->interprocedural, subdata); if (!newcalled) return; diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index bd18c0794a1a..8ba08fa7def5 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -717,7 +717,8 @@ LLVMValueRef EnzymeCreatePrimalAndGradient( .forceAnonymousTape = (bool)forceAnonymousTape, .typeInfo = eunwrap(typeInfo, cast(unwrap(todiff))), .runtimeActivity = (bool)runtimeActivity, - .strongZero = (bool)strongZero}, + .strongZero = (bool)strongZero, + .profiled = false}, eunwrap(TA), eunwrap(augmented))); } EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal( diff --git a/enzyme/Enzyme/CMakeLists.txt b/enzyme/Enzyme/CMakeLists.txt index 93be6d6f0de2..f86b247a4ba8 100644 --- a/enzyme/Enzyme/CMakeLists.txt +++ b/enzyme/Enzyme/CMakeLists.txt @@ -50,6 +50,54 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) list(APPEND ENZYME_SRC TypeAnalysis/TypeTree.cpp TypeAnalysis/TypeAnalysis.cpp TypeAnalysis/TypeAnalysisPrinter.cpp TypeAnalysis/RustDebugInfo.cpp) +set(ENABLE_POSEIDON 0 CACHE BOOL "Enable Poseidon") + +if(ENABLE_POSEIDON) + include(ExternalProject) + ExternalProject_Add(herbie + GIT_REPOSITORY https://github.com/sbrantq/herbie + GIT_TAG 
193e2a4e6e4902fe5ee7fa8bd5747c19f70ce19b + UPDATE_COMMAND "" + CONFIGURE_COMMAND "" + BUILD_COMMAND make clean + COMMAND raco pkg install --auto --update-deps fpbench || true + COMMAND raco pkg install --auto --update-deps rival || true + COMMAND cargo build --release --manifest-path=egg-herbie/Cargo.toml + COMMAND raco pkg install ./egg-herbie + COMMAND mkdir -p herbie-compiled/ + COMMAND raco exe -o herbie --orig-exe --embed-dlls --vv src/main.rkt + COMMAND raco distribute herbie-compiled herbie + BUILD_IN_SOURCE true + INSTALL_COMMAND COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/herbie/install + COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_BINARY_DIR}/herbie-prefix/src/herbie/herbie-compiled ${CMAKE_CURRENT_BINARY_DIR}/herbie/install/herbie + ) + list(APPEND ENZYME_SRC + Poseidon/Poseidon.cpp + Poseidon/PoseidonEvaluators.cpp + Poseidon/PoseidonHerbieUtils.cpp + Poseidon/PoseidonProfUtils.cpp + Poseidon/PoseidonPrecUtils.cpp + Poseidon/PoseidonSolvers.cpp + Poseidon/PoseidonTypes.cpp + Poseidon/PoseidonUtils.cpp + ) + add_compile_definitions(ENABLE_POSEIDON=1) + set_source_files_properties(Poseidon/PoseidonHerbieUtils.cpp PROPERTIES COMPILE_DEFINITIONS HERBIE_BINARY="${CMAKE_CURRENT_BINARY_DIR}/herbie/install/herbie/bin/herbie") + + add_library(EnzymeFPProfile STATIC + Runtimes/FPProfiler/FPProfiler.cpp + ) + target_compile_features(EnzymeFPProfile PRIVATE cxx_std_17) + target_include_directories(EnzymeFPProfile PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/Runtimes/FPProfiler) + + install(TARGETS EnzymeFPProfile + EXPORT EnzymeTargets + ARCHIVE DESTINATION lib + COMPONENT dev + ) +endif() + + if (ENZYME_ENABLE_PLUGINS) # on windows `PLUGIN_TOOL` doesn't link against LLVM.dll if ((WIN32 OR CYGWIN) AND LLVM_LINK_LLVM_DYLIB) @@ -119,6 +167,14 @@ if (ENZYME_ENABLE_PLUGINS) ) target_compile_definitions(LLDEnzyme-${LLVM_VERSION_MAJOR} PUBLIC ENZYME_RUNPASS) endif() + + if(ENABLE_POSEIDON) + 
target_link_libraries(LLVMEnzyme-${LLVM_VERSION_MAJOR} PRIVATE mpfr) + if (${Clang_FOUND}) + target_link_libraries(ClangEnzyme-${LLVM_VERSION_MAJOR} PRIVATE mpfr) + endif() + target_link_libraries(LLDEnzyme-${LLVM_VERSION_MAJOR} PRIVATE mpfr) + endif() endif() if (${ENZYME_STATIC_LIB}) @@ -129,6 +185,9 @@ if (${ENZYME_STATIC_LIB}) DEPENDS intrinsics_gen ) + if(ENABLE_POSEIDON) + target_link_libraries(EnzymeStatic-${LLVM_VERSION_MAJOR} PRIVATE mpfr) + endif() endif() if (${ENZYME_EXTERNAL_SHARED_LIB}) @@ -154,6 +213,10 @@ if (${ENZYME_EXTERNAL_SHARED_LIB}) target_link_libraries(Enzyme-${LLVM_VERSION_MAJOR} LLVM) endif() + if(ENABLE_POSEIDON) + target_link_libraries(Enzyme-${LLVM_VERSION_MAJOR} PRIVATE mpfr) + endif() + install(TARGETS Enzyme-${LLVM_VERSION_MAJOR} EXPORT EnzymeTargets LIBRARY DESTINATION lib COMPONENT shlib diff --git a/enzyme/Enzyme/DiffeGradientUtils.cpp b/enzyme/Enzyme/DiffeGradientUtils.cpp index 19e3c947b04b..140d492bddf8 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.cpp +++ b/enzyme/Enzyme/DiffeGradientUtils.cpp @@ -89,7 +89,8 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( bool strongZero, unsigned width, Function *todiff, TargetLibraryInfo &TLI, TypeAnalysis &TA, FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool shadowReturn, bool diffeReturnArg, ArrayRef constant_args, - bool returnTape, bool returnPrimal, Type *additionalArg, bool omp) { + bool returnTape, bool returnPrimal, Type *additionalArg, bool omp, + bool profiled) { Function *oldFunc = todiff; assert(mode == DerivativeMode::ReverseModeGradient || mode == DerivativeMode::ReverseModeCombined || @@ -108,14 +109,16 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( std::string prefix; switch (mode) { - case DerivativeMode::ForwardModeError: case DerivativeMode::ForwardMode: case DerivativeMode::ForwardModeSplit: prefix = "fwddiffe"; break; + case DerivativeMode::ForwardModeError: + prefix = "fwderr"; + break; case DerivativeMode::ReverseModeCombined: case 
DerivativeMode::ReverseModeGradient: - prefix = "diffe"; + prefix = profiled ? "instr" : "diffe"; break; case DerivativeMode::ReverseModePrimal: llvm_unreachable("invalid DerivativeMode: ReverseModePrimal\n"); @@ -129,7 +132,7 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( nonconstant_values, returnvals, returnTape, returnPrimal, (mode == DerivativeMode::ReverseModeGradient) ? false : shadowReturn, prefix + oldFunc->getName(), &originalToNew, - /*diffeReturnArg*/ diffeReturnArg, additionalArg); + /*diffeReturnArg*/ diffeReturnArg, additionalArg, profiled); // Convert overwritten args from the input function to the preprocessed // function @@ -165,6 +168,7 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone( Logic, newFunc, oldFunc, TLI, TA, TR, invertedPointers, constant_values, nonconstant_values, retType, shadowReturn, constant_args, originalToNew, mode, runtimeActivity, strongZero, width, omp); + res->profiled = profiled; return res; } diff --git a/enzyme/Enzyme/DiffeGradientUtils.h b/enzyme/Enzyme/DiffeGradientUtils.h index d999b08fdbb6..fc53c2aae4d1 100644 --- a/enzyme/Enzyme/DiffeGradientUtils.h +++ b/enzyme/Enzyme/DiffeGradientUtils.h @@ -84,7 +84,8 @@ class DiffeGradientUtils final : public GradientUtils { FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool shadowReturnArg, bool diffeReturnArg, llvm::ArrayRef constant_args, bool returnTape, - bool returnPrimal, llvm::Type *additionalArg, bool omp); + bool returnPrimal, llvm::Type *additionalArg, bool omp, + bool profiled = false); llvm::AllocaInst *getDifferential(llvm::Value *val); diff --git a/enzyme/Enzyme/DifferentialUseAnalysis.h b/enzyme/Enzyme/DifferentialUseAnalysis.h index 841b67379f25..93b1be2c5e92 100644 --- a/enzyme/Enzyme/DifferentialUseAnalysis.h +++ b/enzyme/Enzyme/DifferentialUseAnalysis.h @@ -123,12 +123,13 @@ inline bool is_value_needed_in_reverse( } } } - if (gutils->mode == DerivativeMode::ForwardModeError && + if ((gutils->profiled || + gutils->mode == 
DerivativeMode::ForwardModeError) && !gutils->isConstantValue(const_cast(inst))) { if (EnzymePrintDiffUse) - llvm::errs() - << " Need: " << to_string(VT) << " of " << *inst - << " in reverse as forward mode error always needs result\n"; + llvm::errs() << " Need: " << to_string(VT) << " of " << *inst + << " in reverse as profiling mode or forward mode " + "error always needs result\n"; return seen[idx] = true; } } diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index bc0e6343ee97..098a03d0ec41 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -68,8 +68,11 @@ #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/AbstractCallSite.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/Path.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -77,6 +80,9 @@ #include "DiffeGradientUtils.h" #include "EnzymeLogic.h" #include "GradientUtils.h" +#ifdef ENABLE_POSEIDON +#include "Poseidon/Poseidon.h" +#endif #include "TraceInterface.h" #include "TraceUtils.h" #include "Utils.h" @@ -382,6 +388,7 @@ static bool ReplaceOriginalCall(IRBuilder<> &Builder, Value *ret, class EnzymeBase { public: EnzymeLogic Logic; + bool poseidonProfiling = false; EnzymeBase(bool PostOpt) : Logic(EnzymePostOpt.getNumOccurrences() ? 
EnzymePostOpt : PostOpt) { // initializeLowerAutodiffIntrinsicPass(*PassRegistry::getPassRegistry()); @@ -487,6 +494,7 @@ class EnzymeBase { bool runtimeActivity; bool strongZero; bool subsequent_calls_may_write; + double errTol; }; #if LLVM_VERSION_MAJOR > 16 @@ -527,6 +535,7 @@ class EnzymeBase { mode != DerivativeMode::ForwardModeError && mode != DerivativeMode::ReverseModeCombined; StringSet<> ActiveRandomVariables; + double errTol = 0.0; DIFFE_TYPE retType = whatType(fn->getReturnType(), mode); @@ -849,6 +858,21 @@ class EnzymeBase { } skipArg = true; break; + } else if (*metaString == "enzyme_err_tol") { + ++i; + Value *relErrorArg = CI->getArgOperand(i); + if (auto *CFP = dyn_cast(relErrorArg)) { + errTol = CFP->getValueAPF().convertToDouble(); + } else { + EmitFailure( + "InvalidErrorTolerance", CI->getDebugLoc(), CI, + "Relative error tolerance must be a constant floating-point " + "value, got ", + *relErrorArg); + return {}; + } + skipArg = true; + break; } else { EmitFailure("IllegalDiffeType", CI->getDebugLoc(), CI, "illegal enzyme metadata classification ", *CI, @@ -1129,7 +1153,8 @@ class EnzymeBase { overwritten_args, runtimeActivity, strongZero, - subsequent_calls_may_write}); + subsequent_calls_may_write, + errTol}); } static FnTypeInfo populate_type_args(TypeAnalysis &TA, llvm::Function *fn, @@ -1543,8 +1568,10 @@ class EnzymeBase { .forceAnonymousTape = false, .typeInfo = type_args, .runtimeActivity = options.runtimeActivity, - .strongZero = options.strongZero}, + .strongZero = options.strongZero, + .profiled = poseidonProfiling}, TA, /*augmented*/ nullptr); + poseidonProfiling = false; break; case DerivativeMode::ReverseModePrimal: case DerivativeMode::ReverseModeGradient: { @@ -1615,7 +1642,8 @@ class EnzymeBase { .forceAnonymousTape = forceAnonymousTape, .typeInfo = type_args, .runtimeActivity = options.runtimeActivity, - .strongZero = options.strongZero}, + .strongZero = options.strongZero, + .profiled = false}, TA, aug); } } @@ -1947,6 
+1975,166 @@ class EnzymeBase {
     return status;
   }
 
+#ifdef ENABLE_POSEIDON
+  /// Lowers an `__enzyme_fp_optimize` call when -fpprofile-generate is set.
+  ///
+  /// The call is handed to HandleAutoDiff as a ReverseModeCombined request
+  /// with `poseidonProfiling` raised, so the derivative it requests carries
+  /// `.profiled = true`.
+  ///
+  /// \param CI    the `__enzyme_fp_optimize` call being lowered.
+  /// \param calls forwarded to HandleAutoDiff.
+  /// \returns false if the target function or the call arguments cannot be
+  ///          parsed, otherwise whatever HandleAutoDiff returns.
+  bool HandlePoseidonProf(CallInst *CI, SmallVectorImpl &calls) {
+    assert(FPProfileGenerate);
+
+    IRBuilder<> Builder(CI);
+    Function *F = parseFunctionParameter(CI);
+    if (!F)
+      return false;
+
+    std::map byVal;
+    std::vector constants;
+    SmallVector args;
+
+    auto mode = DerivativeMode::ReverseModeCombined;
+    // Read (and cleared) by HandleAutoDiff's ReverseModeCombined case.
+    poseidonProfiling = true;
+
+    auto options =
+        handleArguments(Builder, CI, F, mode, false, constants, args, byVal);
+    if (!options) {
+      poseidonProfiling = false;
+      return false;
+    }
+
+    Value *ret = CI;
+    Type *retElemType = nullptr;
+    if (CI->hasStructRetAttr()) {
+      // sret call: take the result slot from operand 0 instead of CI.
+      ret = CI->getArgOperand(0);
+      retElemType =
+          CI->getAttribute(AttributeList::FirstArgIndex, Attribute::StructRet)
+              .getValueAsType();
+    }
+
+    bool status =
+        HandleAutoDiff(CI, CI->getCallingConv(), ret, retElemType, args, byVal,
+                       constants, F, mode, *options, false, calls);
+    // Defensive: make sure the flag cannot outlive this call should
+    // HandleAutoDiff return without reaching the code that consumes it.
+    poseidonProfiling = false;
+    return status;
+  }
+
+  /// Lowers an `__enzyme_fp_optimize` call when not generating a profile.
+  ///
+  /// The call is replaced by a direct call to the target function on the
+  /// primal arguments only. If -fpprofile-use was given and is non-empty, the
+  /// target is first preprocessed and passed to fpOptimize with the requested
+  /// error tolerance; otherwise a warning is emitted and the original function
+  /// is called unchanged.
+  ///
+  /// \param CI    the `__enzyme_fp_optimize` call being lowered; erased on
+  ///              success.
+  /// \param calls receives the replacement call.
+  /// \returns false if the target function or the call arguments cannot be
+  ///          parsed or the profile file is missing, true otherwise.
+  bool HandlePoseidonOpt(CallInst *CI, SmallVectorImpl &calls) {
+    IRBuilder<> Builder(CI);
+    Function *F = parseFunctionParameter(CI);
+    if (!F)
+      return false;
+
+    // `poseidonProfiling` is intentionally not touched here: its only reader
+    // is HandleAutoDiff, which this path never calls, so raising it would
+    // leave it set for the next unrelated ReverseModeCombined lowering.
+    auto mode = DerivativeMode::ReverseModeCombined;
+
+    std::map byVal;
+    std::vector constants;
+    SmallVector args;
+
+    auto options =
+        handleArguments(Builder, CI, F, mode, false, constants, args, byVal);
+    if (!options)
+      return false;
+
+    // Keep only the primal operands: entry 0 is skipped when F's first
+    // parameter is sret, as is the slot after each DUP_ARG / DUP_NONEED entry.
+    SmallVector primalArgs;
+    size_t argIdx = 0;
+    for (size_t i = 0; i < constants.size(); ++i) {
+      if (argIdx >= args.size())
+        break;
+
+      if (i == 0 && F->hasParamAttribute(0, Attribute::StructRet)) {
+        argIdx++;
+        if (constants[i] == DIFFE_TYPE::DUP_ARG ||
+            constants[i] == DIFFE_TYPE::DUP_NONEED) {
+          argIdx++;
+        }
+        continue;
+      }
+
+      primalArgs.push_back(args[argIdx]);
+      argIdx++;
+
+      if (constants[i] == DIFFE_TYPE::DUP_ARG ||
+          constants[i] == DIFFE_TYPE::DUP_NONEED) {
+        argIdx++;
+      }
+    }
+
+    if (!FPProfileUse.getNumOccurrences() || FPProfileUse.empty()) {
+      EmitWarning("MissingProfileMode", *CI,
+                  "__enzyme_fp_optimize called without -fpprofile-generate or "
+                  "-fpprofile-use=. "
+                  "Emitting unoptimized function call. Use -fpprofile-generate "
+                  "to create a profile or -fpprofile-use= to use an "
+                  "existing profile.");
+    } else {
+      F = Logic.PPC.preprocessForClone(F, mode, /*profiled=*/true);
+      setPoseidonMetadata(*F);
+
+      SmallString<128> profilePath(FPProfileUse);
+      llvm::sys::path::append(profilePath, F->getName() + ".fpprofile");
+      if (!llvm::sys::fs::exists(profilePath.str())) {
+        EmitFailure("NoProfile", CI->getDebugLoc(), CI, "No profile found at ",
+                    profilePath, " (FPProfileUse: ", FPProfileUse, ")");
+        return false;
+      }
+
+      auto &TTI = Logic.PPC.FAM.getResult(*CI->getFunction());
+
+      if (FPOptPrint) {
+        llvm::errs() << "FPOpt: Optimizing " << F->getName()
+                     << " with relative error tolerance: " << options->errTol
+                     << "\n";
+      }
+
+      bool optimized = fpOptimize(*F, TTI, options->errTol);
+
+      if (!optimized) {
+        EmitWarning("NoChange", *CI, "Poseidon returned false (no change) for ",
+                    F->getName());
+      }
+    }
+
+    CallInst *optCall = Builder.CreateCall(F->getFunctionType(), F, primalArgs);
+    optCall->setCallingConv(CI->getCallingConv());
+    optCall->setDebugLoc(CI->getDebugLoc());
+
+    CI->replaceAllUsesWith(optCall);
+    CI->eraseFromParent();
+
+    calls.push_back(optCall);
+
+    return true;
+  }
+#endif
+
   bool handleFullModuleTrunc(Function &F) {
     if (startsWith(F.getName(), EnzymeFPRTPrefix))
       return false;
@@ -2096,6 +2256,7 @@ class EnzymeBase {
     SmallVector toTruncateFuncOp;
     SmallVector toTruncateValue;
     SmallVector toExpandValue;
+    SmallVector toFPOpt;
     MapVector toProbProg;
     SetVector InactiveCalls;
     SetVector IterCalls;
@@ -2413,6 +2574,7 @@ class EnzymeBase {
         bool truncateFuncMem = false;
         bool truncateValue = false;
         bool expandValue = false;
+        bool fpOpt = false;
         bool probProg = false;
         DerivativeMode derivativeMode;
         ProbProgMode probProgMode;
@@ -2469,6 +2631,9 @@ class EnzymeBase {
           enableEnzyme = true;
           probProgMode
= ProbProgMode::Condition; probProg = true; + } else if (Fn->getName().contains("__enzyme_fp_optimize")) { + enableEnzyme = true; + fpOpt = true; } if (enableEnzyme) { @@ -2522,9 +2687,11 @@ class EnzymeBase { toTruncateValue.push_back(CI); else if (expandValue) toExpandValue.push_back(CI); - else if (probProg) { + else if (probProg) toProbProg[CI] = probProgMode; - } else + else if (fpOpt) + toFPOpt.push_back(CI); + else toLower[CI] = derivativeMode; if (auto dc = dyn_cast(fn)) { @@ -2628,6 +2795,18 @@ class EnzymeBase { HandleProbProg(call, mode, calls); } +#ifdef ENABLE_POSEIDON + for (auto CI : toFPOpt) { + Changed |= FPProfileGenerate ? HandlePoseidonProf(CI, calls) + : HandlePoseidonOpt(CI, calls); + } +#else + if (!toFPOpt.empty()) { + llvm_unreachable("Poseidon is not enabled. Please specify " + "-DENABLE_POSEIDON=1 when building Enzyme."); + } +#endif + if (Logic.PostOpt) { auto Params = llvm::getInlineParams(); @@ -2977,8 +3156,13 @@ class EnzymeBase { call->eraseFromParent(); } - for (const auto &pair : Logic.PPC.cache) - pair.second->eraseFromParent(); + // Poseidon opt uses preprocessed functions + for (const auto &pair : Logic.PPC.cache) { + if (pair.second->use_empty()) { + pair.second->eraseFromParent(); + } + } + Logic.clear(); if (changed && Logic.PostOpt) { @@ -3492,6 +3676,12 @@ extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB, if (registerFixupJuliaPass(Name, MPM)) { return true; } +#ifdef ENABLE_POSEIDON + if (Name == "fp-opt") { + MPM.addPass(FPOptNewPM()); + return true; + } +#endif if (Name == "preserve-nvvm") { MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true)); return true; diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index ae7839d6dcd8..6efcdd5833db 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -3845,7 +3845,8 @@ Function *EnzymeLogic::CreatePrimalAndGradient( .forceAnonymousTape = key.forceAnonymousTape, .typeInfo = key.typeInfo, .runtimeActivity = 
key.runtimeActivity, - .strongZero = key.strongZero}, + .strongZero = key.strongZero, + .profiled = key.profiled}, TA, &aug, omp); SmallVector revargs; @@ -3944,7 +3945,8 @@ Function *EnzymeLogic::CreatePrimalAndGradient( .forceAnonymousTape = key.forceAnonymousTape, .typeInfo = key.typeInfo, .runtimeActivity = key.runtimeActivity, - .strongZero = key.strongZero}, + .strongZero = key.strongZero, + .profiled = key.profiled}, TA, augmenteddata, omp); { @@ -4228,6 +4230,26 @@ Function *EnzymeLogic::CreatePrimalAndGradient( assert(augmenteddata->constant_args == key.constant_args); } + if (key.profiled) { + Module *M = key.todiff->getParent(); + LLVMContext &Ctx = M->getContext(); + + Type *DoubleTy = Type::getDoubleTy(Ctx); + Type *PtrTy = PointerType::getUnqual(Ctx); + Type *Int32Ty = Type::getInt32Ty(Ctx); + Type *SizeTy = Type::getInt64Ty(Ctx); + + M->getOrInsertGlobal("ENZYME_FPPROFILE_RUNTIME_VAR", Int32Ty); + + FunctionType *LogValueFT = FunctionType::get( + Type::getVoidTy(Ctx), {PtrTy, SizeTy, DoubleTy, Int32Ty, PtrTy}, false); + M->getOrInsertFunction("enzymeLogValue", LogValueFT); + + FunctionType *LogGradFT = FunctionType::get( + Type::getVoidTy(Ctx), {PtrTy, SizeTy, DoubleTy, DoubleTy}, false); + M->getOrInsertFunction("enzymeLogGrad", LogGradFT); + } + bool diffeReturnArg = key.retType == DIFFE_TYPE::OUT_DIFF; DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone( @@ -4235,7 +4257,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient( key.todiff, TLI, TA, oldTypeInfo, key.retType, augmenteddata ? 
augmenteddata->shadowReturnUsed : key.shadowReturnUsed, diffeReturnArg, key.constant_args, /*returnTape*/ false, key.returnUsed, - key.additionalType, omp); + key.additionalType, omp, key.profiled); gutils->AtomicAdd = key.AtomicAdd; gutils->FreeMemory = key.freeMemory; diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index fbb54947110d..e32a5603807b 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -174,6 +174,7 @@ struct ReverseCacheKey { const FnTypeInfo typeInfo; bool runtimeActivity; bool strongZero; + bool profiled; ReverseCacheKey replaceTypeInfo(const FnTypeInfo &newTypeInfo) const { return {todiff, @@ -191,7 +192,8 @@ struct ReverseCacheKey { forceAnonymousTape, newTypeInfo, runtimeActivity, - strongZero}; + strongZero, + profiled}; } /* inline bool operator==(const ReverseCacheKey& rhs) const { @@ -298,6 +300,11 @@ struct ReverseCacheKey { if (rhs.strongZero < strongZero) return false; + if (profiled < rhs.profiled) + return true; + if (rhs.profiled < profiled) + return false; + // equal return false; } diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 7dc7586eaf9a..705fb25bd763 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -28,6 +28,7 @@ #include "EnzymeLogic.h" #include "GradientUtils.h" #include "LibraryFuncs.h" +#include "Utils.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" @@ -115,6 +116,10 @@ #include "CacheUtility.h" #include "Utils.h" +#ifdef ENABLE_POSEIDON +#include "Poseidon/Poseidon.h" +#endif + #define addAttribute addAttributeAtIndex #define removeAttribute removeAttributeAtIndex #define getAttribute getAttributeAtIndex @@ -2244,8 +2249,8 @@ bool DetectReadonlyOrThrow(Module &M) { return changed; } -Function *PreProcessCache::preprocessForClone(Function *F, - DerivativeMode mode) { +Function *PreProcessCache::preprocessForClone(Function *F, DerivativeMode mode, + bool profiled) { TimeTraceScope 
timeScope("preprocessForClone", F->getName()); @@ -2691,7 +2696,8 @@ Function *PreProcessCache::preprocessForClone(Function *F, FAM.invalidate(*NewF, PA); } - if (mode != DerivativeMode::ForwardMode) + if (mode != DerivativeMode::ForwardMode && + mode != DerivativeMode::ForwardModeError) ReplaceReallocs(NewF); { @@ -2727,7 +2733,8 @@ Function *PreProcessCache::preprocessForClone(Function *F, PA.preserve(); } - if (mode != DerivativeMode::ForwardMode) + if (mode != DerivativeMode::ForwardMode && + mode != DerivativeMode::ForwardModeError) ReplaceReallocs(NewF); if (mode == DerivativeMode::ReverseModePrimal || @@ -2739,6 +2746,19 @@ Function *PreProcessCache::preprocessForClone(Function *F, UpgradeAllocasToMallocs(NewF, mode, unreachable); } +#ifdef ENABLE_POSEIDON + if (profiled) { + preprocessForPoseidon(NewF); + + auto PA = InstCombinePass().run(*NewF, FAM); + FAM.invalidate(*NewF, PA); + PA = EarlyCSEPass().run(*NewF, FAM); + FAM.invalidate(*NewF, PA); + PA = GVNPass().run(*NewF, FAM); + FAM.invalidate(*NewF, PA); + } +#endif + CanonicalizeLoops(NewF, FAM); RemoveRedundantPHI(NewF, FAM); @@ -2999,9 +3019,14 @@ Function *PreProcessCache::CloneFunctionWithReturns( SmallPtrSetImpl &returnvals, bool returnTape, bool returnPrimal, bool returnShadow, const Twine &name, llvm::ValueMap *VMapO, - bool diffeReturnArg, llvm::Type *additionalArg) { + bool diffeReturnArg, llvm::Type *additionalArg, bool profiled) { if (!F->empty()) - F = preprocessForClone(F, mode); + F = preprocessForClone(F, mode, profiled); +#ifdef ENABLE_POSEIDON + if (profiled) { + setPoseidonMetadata(*F); + } +#endif llvm::ValueToValueMapTy VMap; llvm::FunctionType *FTy = getFunctionTypeForClone( F->getFunctionType(), mode, width, additionalArg, constant_args, diff --git a/enzyme/Enzyme/FunctionUtils.h b/enzyme/Enzyme/FunctionUtils.h index d3e85c1b93d3..6dd73a50badd 100644 --- a/enzyme/Enzyme/FunctionUtils.h +++ b/enzyme/Enzyme/FunctionUtils.h @@ -95,7 +95,8 @@ class PreProcessCache { std::map, 
llvm::Function *> cache; std::map CloneOrigin; - llvm::Function *preprocessForClone(llvm::Function *F, DerivativeMode mode); + llvm::Function *preprocessForClone(llvm::Function *F, DerivativeMode mode, + bool profiled = false); llvm::AAResults &getAAResultsFromFunction(llvm::Function *NewF); @@ -108,7 +109,8 @@ class PreProcessCache { llvm::SmallPtrSetImpl &returnvals, bool returnTape, bool returnPrimal, bool returnShadow, const llvm::Twine &name, llvm::ValueMap *VMapO, - bool diffeReturnArg, llvm::Type *additionalArg = nullptr); + bool diffeReturnArg, llvm::Type *additionalArg = nullptr, + bool profiled = false); void ReplaceReallocs(llvm::Function *NewF, bool mem2reg = false); void LowerAllocAddr(llvm::Function *NewF); diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 1cdfa35f7c1d..a97bd026bacb 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -163,8 +163,8 @@ GradientUtils::GradientUtils( llvm::ValueMap &originalToNewFn_, DerivativeMode mode, bool runtimeActivity, bool strongZero, unsigned width, bool omp) - : CacheUtility(TLI_, newFunc_), Logic(Logic), mode(mode), oldFunc(oldFunc_), - invertedPointers(), + : CacheUtility(TLI_, newFunc_), Logic(Logic), mode(mode), profiled(false), + oldFunc(oldFunc_), invertedPointers(), OrigDT(oldFunc_->empty() ? 
((DominatorTree *)nullptr) : &Logic.PPC.FAM.getResult( @@ -4908,7 +4908,8 @@ Constant *GradientUtils::GetOrCreateShadowFunction( .forceAnonymousTape = true, .typeInfo = type_args, .runtimeActivity = runtimeActivity, - .strongZero = strongZero}, + .strongZero = strongZero, + .profiled = false}, TA, /*map*/ &augdata); assert(newf); diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index cdeb6d739fa1..bb6effdf5933 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -127,6 +127,7 @@ class GradientUtils : public CacheUtility { EnzymeLogic &Logic; bool AtomicAdd; DerivativeMode mode; + bool profiled; llvm::Function *oldFunc; llvm::ValueMap invertedPointers; llvm::DominatorTree *OrigDT; diff --git a/enzyme/Enzyme/Poseidon/Poseidon.cpp b/enzyme/Enzyme/Poseidon/Poseidon.cpp new file mode 100644 index 000000000000..d5dd8201723b --- /dev/null +++ b/enzyme/Enzyme/Poseidon/Poseidon.cpp @@ -0,0 +1,1348 @@ +#include + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +#include "llvm/Analysis/TargetTransformInfo.h" + +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Verifier.h" + +#include "llvm/Passes/PassBuilder.h" + +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/InstructionCost.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/raw_ostream.h" + +#include "llvm/Pass.h" + +#include "llvm/Transforms/Utils.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "../Utils.h" +#include 
"Poseidon.h" +#include "PoseidonHerbieUtils.h" +#include "PoseidonPrecUtils.h" +#include "PoseidonProfUtils.h" +#include "PoseidonSolvers.h" +#include "PoseidonTypes.h" +#include "PoseidonUtils.h" + +using namespace llvm; +#ifdef DEBUG_TYPE +#undef DEBUG_TYPE +#endif +#define DEBUG_TYPE "fp-opt" + +extern "C" { +cl::opt FPProfileGenerate( + "fpprofile-generate", cl::init(false), cl::Hidden, + cl::desc("Generate instrumented program for FP profiling")); +cl::opt FPProfileUse( + "fpprofile-use", cl::Hidden, cl::value_desc("directory"), cl::ValueOptional, + cl::desc("FP profile directory to read from for FP optimization")); +cl::opt FPOptPrint("fpopt-print", cl::init(false), cl::Hidden, + cl::desc("Print FPOpt debug info")); +cl::opt FPOptEnableHerbie( + "fpopt-enable-herbie", cl::init(true), cl::Hidden, + cl::desc("Use Herbie to rewrite floating-point expressions")); +cl::opt FPOptEnablePT( + "fpopt-enable-pt", cl::init(true), cl::Hidden, + cl::desc("Consider precision changes of floating-point expressions")); +cl::opt FPOptCachePath("fpopt-cache-path", cl::init("cache"), + cl::Hidden, + cl::desc("Path to cache Herbie results")); +cl::opt FPOptEnableSolver( + "fpopt-enable-solver", cl::init(true), cl::Hidden, + cl::desc("Use the solver to select desirable rewrite candidates; when " + "disabled, apply all Herbie's first choices")); +cl::opt FPOptMaxExprDepth( + "fpopt-max-expr-depth", cl::init(100), cl::Hidden, + cl::desc( + "The maximum depth of expression construction; abort if exceeded")); +cl::opt FPOptMaxExprLength( + "fpopt-max-expr-length", cl::init(10000), cl::Hidden, + cl::desc("The maximum length of an expression; abort if exceeded")); +cl::opt FPOptReductionEval( + "fpopt-reduction-eval", cl::init("arithmean"), cl::Hidden, + cl::desc("Which reduction result to use in candidate evaluation. 
" + "Options are 'geomean', 'arithmean', and 'maxabs'")); +cl::opt FPOptMinUsesForSplit( + "fpopt-min-uses-split", cl::init(99), cl::Hidden, + cl::desc("Minimum number of uses of bottleneck node to trigger split")); +cl::opt + FPOptMinOpsForSplit("fpopt-min-ops-split", cl::init(99), cl::Hidden, + cl::desc("Minimum number of upstream operations of " + "bottleneck node to trigger split")); +cl::opt + FPOptAggressiveDCE("fpopt-aggressive-dce", cl::init(false), cl::Hidden, + cl::desc("Aggressively eliminate zero gradient outputs " + "as dead code (non-conditional only)")); +cl::opt FPOptMultiOutputPTOnly( + "fpopt-multi-output-pt-only", cl::init(false), cl::Hidden, + cl::desc("Skip Herbie expression generation for subgraphs with multiple " + "outputs (only apply precision changes)")); +} + +bool Poseidonable(const llvm::Value &V) { + const Instruction *I = dyn_cast(&V); + if (!I) + return false; + + switch (I->getOpcode()) { + case Instruction::FNeg: + case Instruction::FAdd: + case Instruction::FSub: + case Instruction::FMul: + case Instruction::FDiv: + case Instruction::FRem: + return I->getType()->isFloatTy() || I->getType()->isDoubleTy(); + case Instruction::Call: { + const CallInst *CI = dyn_cast(I); + if (CI && CI->getCalledFunction() && + (CI->getType()->isFloatTy() || CI->getType()->isDoubleTy())) { + StringRef funcName = CI->getCalledFunction()->getName(); + return + // LLVM intrinsics + startsWith(funcName, "llvm.sin.") || + startsWith(funcName, "llvm.cos.") || + startsWith(funcName, "llvm.tan.") || + startsWith(funcName, "llvm.asin.") || + startsWith(funcName, "llvm.acos.") || + startsWith(funcName, "llvm.atan.") || + startsWith(funcName, "llvm.atan2.") || + startsWith(funcName, "llvm.sinh.") || + startsWith(funcName, "llvm.cosh.") || + startsWith(funcName, "llvm.tanh.") || + startsWith(funcName, "llvm.exp.") || + startsWith(funcName, "llvm.log.") || + startsWith(funcName, "llvm.sqrt.") || + startsWith(funcName, "llvm.pow.") || + startsWith(funcName, 
"llvm.powi.") || + startsWith(funcName, "llvm.fabs.") || + startsWith(funcName, "llvm.fma.") || + startsWith(funcName, "llvm.fmuladd.") || + startsWith(funcName, "llvm.maxnum.") || + startsWith(funcName, "llvm.minnum.") || + startsWith(funcName, "llvm.ceil.") || + startsWith(funcName, "llvm.floor.") || + startsWith(funcName, "llvm.exp2.") || + startsWith(funcName, "llvm.log10.") || + startsWith(funcName, "llvm.log2.") || + startsWith(funcName, "llvm.rint.") || + startsWith(funcName, "llvm.round.") || + startsWith(funcName, "llvm.trunc.") || + startsWith(funcName, "llvm.copysign.") || + startsWith(funcName, "llvm.fdim.") || + startsWith(funcName, "llvm.fmod.") || + + // libm functions + funcName == "sin" || funcName == "sinf" || funcName == "cos" || + funcName == "cosf" || funcName == "tan" || funcName == "tanf" || + funcName == "asin" || funcName == "asinf" || funcName == "acos" || + funcName == "acosf" || funcName == "atan" || funcName == "atanf" || + funcName == "atan2" || funcName == "atan2f" || funcName == "sinh" || + funcName == "sinhf" || funcName == "cosh" || funcName == "coshf" || + funcName == "tanh" || funcName == "tanhf" || funcName == "asinh" || + funcName == "asinhf" || funcName == "acosh" || funcName == "acoshf" || + funcName == "atanh" || funcName == "atanhf" || funcName == "sqrt" || + funcName == "sqrtf" || funcName == "cbrt" || funcName == "cbrtf" || + funcName == "pow" || funcName == "powf" || funcName == "exp" || + funcName == "expf" || funcName == "log" || funcName == "logf" || + funcName == "fabs" || funcName == "fabsf" || funcName == "fma" || + funcName == "fmaf" || funcName == "hypot" || funcName == "hypotf" || + funcName == "expm1" || funcName == "expm1f" || funcName == "log1p" || + funcName == "log1pf" || funcName == "ceil" || funcName == "ceilf" || + funcName == "floor" || funcName == "floorf" || funcName == "erf" || + funcName == "erff" || funcName == "exp2" || funcName == "exp2f" || + funcName == "lgamma" || funcName == "lgammaf" || + 
funcName == "log10" || funcName == "log10f" || funcName == "log2" || + funcName == "log2f" || funcName == "rint" || funcName == "rintf" || + funcName == "round" || funcName == "roundf" || funcName == "tgamma" || + funcName == "tgammaf" || funcName == "trunc" || + funcName == "truncf" || funcName == "copysign" || + funcName == "copysignf" || funcName == "fdim" || + funcName == "fdimf" || funcName == "fmod" || funcName == "fmodf" || + funcName == "remainder" || funcName == "remainderf"; + } + return false; + } + default: + return false; + } +} + +void setPoseidonMetadata(Function &F) { + for (auto [idx, I] : enumerate(instructions(F))) { + if (Poseidonable(I)) { + I.setMetadata("enzyme_active", MDNode::get(I.getContext(), {})); + I.setMetadata("enzyme_fpprofile_idx", + MDNode::get(I.getContext(), + {ConstantAsMetadata::get(ConstantInt::get( + Type::getInt64Ty(I.getContext()), idx))})); + } + } +} + +void preprocessForPoseidon(Function *F) { + using namespace llvm::PatternMatch; + + // fmul + fadd -> fmuladd + for (auto &BB : *F) { + for (auto &I : make_early_inc_range(BB)) { + Value *X, *Y, *Z; + + if (auto *FAdd = dyn_cast(&I)) { + if (!isa(FAdd) || !FAdd->hasAllowReassoc() || + !FAdd->hasAllowContract()) + continue; + + // fadd (fmul X, Y), Z + if (match(FAdd, m_FAdd(m_OneUse(m_FMul(m_Value(X), m_Value(Y))), + m_Value(Z)))) { + IRBuilder<> B(FAdd); + B.setFastMathFlags(FAdd->getFastMathFlags()); + + Value *FMulAdd = + B.CreateIntrinsic(Intrinsic::fmuladd, FAdd->getType(), {X, Y, Z}); + FAdd->replaceAllUsesWith(FMulAdd); + FAdd->eraseFromParent(); + } + // fadd Z, (fmul X, Y) + else if (match(FAdd, + m_FAdd(m_Value(Z), + m_OneUse(m_FMul(m_Value(X), m_Value(Y)))))) { + IRBuilder<> B(FAdd); + B.setFastMathFlags(FAdd->getFastMathFlags()); + + Value *FMulAdd = + B.CreateIntrinsic(Intrinsic::fmuladd, FAdd->getType(), {X, Y, Z}); + FAdd->replaceAllUsesWith(FMulAdd); + + FAdd->eraseFromParent(); + } + } + } + } + + for (auto &BB : *F) { + for (auto &I : 
make_early_inc_range(BB)) { + Value *X, *Y, *Z; + + if (auto *FSub = dyn_cast(&I)) { + if (!isa(FSub) || !FSub->hasAllowReassoc() || + !FSub->hasAllowContract()) + continue; + + // Pattern: fsub (fmul X, Y), Z -> fmuladd(X, Y, -Z) + if (match(FSub, m_FSub(m_OneUse(m_FMul(m_Value(X), m_Value(Y))), + m_Value(Z)))) { + IRBuilder<> B(FSub); + B.setFastMathFlags(FSub->getFastMathFlags()); + + Value *NegZ = B.CreateFNeg(Z); + Value *FMulAdd = B.CreateIntrinsic(Intrinsic::fmuladd, + FSub->getType(), {X, Y, NegZ}); + FSub->replaceAllUsesWith(FMulAdd); + FSub->eraseFromParent(); + } + // Pattern: fsub Z, (fmul X, Y) -> fmuladd(-X, Y, Z) + else if (match(FSub, + m_FSub(m_Value(Z), + m_OneUse(m_FMul(m_Value(X), m_Value(Y)))))) { + IRBuilder<> B(FSub); + B.setFastMathFlags(FSub->getFastMathFlags()); + + Value *NegX = B.CreateFNeg(X); + Value *FMulAdd = B.CreateIntrinsic(Intrinsic::fmuladd, + FSub->getType(), {NegX, Y, Z}); + FSub->replaceAllUsesWith(FMulAdd); + FSub->eraseFromParent(); + } + } + } + } + + // fcmp + select -> fmax/fmin + for (auto &BB : *F) { + for (auto &I : make_early_inc_range(BB)) { + if (auto *Select = dyn_cast(&I)) { + Value *Cond = Select->getCondition(); + Value *TrueVal = Select->getTrueValue(); + Value *FalseVal = Select->getFalseValue(); + + if (!Select->getType()->isFloatingPointTy()) + continue; + + CmpPredicate Pred; + Value *CmpLHS, *CmpRHS; + + if (match(Cond, m_FCmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) { + IRBuilder<> B(Select); + Value *Result = nullptr; + + // select (fcmp ogt X, 0.0), X, 0.0 -> maxnum(X, 0.0) + if (Pred == FCmpInst::FCMP_OGT && match(CmpRHS, m_AnyZeroFP()) && + CmpLHS == TrueVal && match(FalseVal, m_AnyZeroFP())) { + Result = B.CreateIntrinsic( + Intrinsic::maxnum, CmpLHS->getType(), + {CmpLHS, ConstantFP::get(CmpLHS->getType(), 0.0)}); + } + // select (fcmp olt X, 0.0), 0.0, X -> maxnum(X, 0.0) + else if (Pred == FCmpInst::FCMP_OLT && match(CmpRHS, m_AnyZeroFP()) && + CmpLHS == FalseVal && match(TrueVal, 
m_AnyZeroFP())) { + Result = B.CreateIntrinsic( + Intrinsic::maxnum, CmpLHS->getType(), + {CmpLHS, ConstantFP::get(CmpLHS->getType(), 0.0)}); + } + // select (fcmp ogt X, Y), X, Y -> maxnum(X, Y) + else if (Pred == FCmpInst::FCMP_OGT && CmpLHS == TrueVal && + CmpRHS == FalseVal) { + Result = B.CreateIntrinsic(Intrinsic::maxnum, CmpLHS->getType(), + {CmpLHS, CmpRHS}); + } + // select (fcmp olt X, Y), X, Y -> minnum(X, Y) + else if (Pred == FCmpInst::FCMP_OLT && CmpLHS == TrueVal && + CmpRHS == FalseVal) { + Result = B.CreateIntrinsic(Intrinsic::minnum, CmpLHS->getType(), + {CmpLHS, CmpRHS}); + } + + if (Result) { + Select->replaceAllUsesWith(Result); + Select->eraseFromParent(); + + if (auto *FCmp = dyn_cast(Cond)) { + if (FCmp->use_empty()) { + FCmp->eraseFromParent(); + } + } + } + } + } + } + } +} + +// Run (our choice of) floating point optimizations on function `F`. +// Return whether or not we change the function. +bool fpOptimize(Function &F, const TargetTransformInfo &TTI, double errorTol) { + bool changed = false; + + // llvm::errs() << "FPOpt: Starting optimization for " << F.getName() << "\n"; + // F.print(llvm::errs()); + + const std::string functionName = F.getName().str(); + assert(!FPProfileUse.empty()); + SmallString<128> profilePathBuf(FPProfileUse); + llvm::sys::path::append(profilePathBuf, F.getName() + ".fpprofile"); + const std::string profilePath = profilePathBuf.str().str(); + + if (!FPOptCachePath.empty()) { + if (auto EC = llvm::sys::fs::create_directories(FPOptCachePath, true)) + llvm::errs() << "Warning: Could not create cache directory: " + << EC.message() << "\n"; + } + + std::unordered_map profileMap; + if (!profilePath.empty()) { + parseProfileFile(profilePath, profileMap); + if (profileMap.empty()) { + llvm::errs() << "Warning: No profile data found in " << profilePath + << "\n"; + } + } + + int symbolCounter = 0; + auto getNextSymbol = [&symbolCounter]() -> std::string { + return "v" + std::to_string(symbolCounter++); + }; + + // 
Extract change: + + // E1) create map for all instructions I, map[I] = FPLLValue(I) + // E2) for all instructions, if Poseidonable(I), map[I] = FPNode(operation(I), + // map[operands(I)]) + // E3) floodfill for all starting locations I to find all distinct graphs / + // outputs. + + /* + B1: + x = sin(arg) + + B2: + y = 1 - x * x + + + -> result y = cos(arg)^2 + +B1: + nothing + +B2: + costmp = cos(arg) + y = costmp * costmp + + */ + + std::unordered_map> valueToNodeMap; + std::unordered_map symbolToValueMap; + + llvm::errs() << "FPOpt: Starting Floodfill for " << F.getName() << "\n"; + + for (auto &BB : F) { + for (auto &I : BB) { + if (!Poseidonable(I)) { + valueToNodeMap[&I] = + std::make_shared(&I, "__nh", "__nh"); // Non-Poseidonable + if (FPOptPrint) + llvm::errs() + << "Registered FPLLValue for non-Poseidonable instruction: " << I + << "\n"; + continue; + } + + std::string dtype; + if (I.getType()->isFloatTy()) { + dtype = "f32"; + } else if (I.getType()->isDoubleTy()) { + dtype = "f64"; + } else { + llvm_unreachable("Unexpected floating point type for instruction"); + } + auto node = std::make_shared(&I, getHerbieOperator(I), dtype); + + auto operands = + isa(I) ? 
cast(I).args() : I.operands(); + for (auto &operand : operands) { + if (!valueToNodeMap.count(operand)) { + if (auto Arg = dyn_cast(operand)) { + std::string dtype; + if (Arg->getType()->isFloatTy()) { + dtype = "f32"; + } else if (Arg->getType()->isDoubleTy()) { + dtype = "f64"; + } else { + llvm_unreachable("Unexpected floating point type for argument"); + } + valueToNodeMap[operand] = + std::make_shared(Arg, "__arg", dtype); + if (FPOptPrint) + llvm::errs() << "Registered FPNode for argument: " << *Arg + << "\n"; + } else if (auto C = dyn_cast(operand)) { + SmallString<10> value; + C->getValueAPF().toString(value); + std::string dtype; + if (C->getType()->isFloatTy()) { + dtype = "f32"; + } else if (C->getType()->isDoubleTy()) { + dtype = "f64"; + } else { + llvm_unreachable("Unexpected floating point type for constant"); + } + valueToNodeMap[operand] = + std::make_shared(value.c_str(), dtype); + if (FPOptPrint) + llvm::errs() << "Registered FPNode for " << dtype + << " constant: " << value << "\n"; + } else if (auto CI = dyn_cast(operand)) { + // e.g., powi intrinsic has a constant int as its exponent + double exponent = static_cast(CI->getSExtValue()); + std::string dtype = "f64"; + std::string doubleStr = std::to_string(exponent); + valueToNodeMap[operand] = + std::make_shared(doubleStr.c_str(), dtype); + if (FPOptPrint) + llvm::errs() << "Registered FPNode for " << dtype + << " constant (casted from integer): " << doubleStr + << "\n"; + } else if (auto GV = dyn_cast(operand)) { + Type *elemType = GV->getValueType(); + + assert(elemType->isFloatingPointTy() && + "Global variable is not floating point type"); + std::string dtype; + if (elemType->isFloatTy()) { + dtype = "f32"; + } else if (elemType->isDoubleTy()) { + dtype = "f64"; + } else { + llvm_unreachable( + "Unexpected floating point type for global variable"); + } + valueToNodeMap[operand] = + std::make_shared(GV, "__gv", dtype); + if (FPOptPrint) + llvm::errs() << "Registered FPNode for global 
variable: " << *GV + << "\n"; + } else { + assert(0 && "Unknown operand"); + } + } + node->addOperand(valueToNodeMap[operand]); + } + valueToNodeMap[&I] = node; + } + } + + SmallSet processed; + SmallVector subgraphs; + for (auto &BB : F) { + for (auto &I : BB) { + // Not a Poseidonable instruction, doesn't make sense to create graph node + // out of. + if (!Poseidonable(I)) { + if (FPOptPrint) + llvm::errs() << "Skipping non-Poseidonable instruction: " << I + << "\n"; + continue; + } + + // Instruction is already in a set + if (processed.contains(&I)) { + if (FPOptPrint) + llvm::errs() << "Skipping already seen instruction: " << I << "\n"; + continue; + } + + if (FPOptPrint) + llvm::errs() << "Starting floodfill from: " << I << "\n"; + + SmallVector todo; + SetVector input_seen; + SetVector output_seen; + SetVector operation_seen; + todo.push_back(&I); + while (!todo.empty()) { + auto cur = todo.pop_back_val(); + assert(valueToNodeMap.count(cur) && "Node not found in valueToNodeMap"); + + // We now can assume that this is a Poseidonable expression + // Since we can only herbify instructions, let's assert that + assert(isa(cur)); + auto I2 = cast(cur); + + // Don't repeat any instructions we've already seen (to avoid loops + // for phi nodes) + if (operation_seen.contains(I2)) { + if (FPOptPrint) + llvm::errs() << "Skipping already seen instruction: " << *I2 + << "\n"; + continue; + } + + assert(!processed.contains(cur)); + + if (FPOptPrint) + llvm::errs() << "Insert to operation_seen and processed: " << *I2 + << "\n"; + operation_seen.insert(I2); + processed.insert(cur); + + auto operands = + isa(I2) ? 
cast(I2)->args() : I2->operands(); + + for (const auto &operand : operands) { + if (!Poseidonable(*operand)) { + if (FPOptPrint) + llvm::errs() << "Non-Poseidonable input found: " << *operand + << "\n"; + + // Don't mark constants as input `llvm::Value`s + if (!isa(operand)) + input_seen.insert(operand); + } else { + if (FPOptPrint) + llvm::errs() << "Adding operand to todo list: " << *operand + << "\n"; + todo.push_back(operand); + } + } + + for (auto U : I2->users()) { + if (auto I3 = dyn_cast(U)) { + if (!Poseidonable(*I3)) { + if (FPOptPrint) + llvm::errs() << "Output instruction found: " << *I2 << "\n"; + output_seen.insert(I2); + } else { + if (FPOptPrint) + llvm::errs() << "Adding user to todo list: " << *I3 << "\n"; + todo.push_back(I3); + } + } + } + } + + // Don't bother with graphs without any Poseidonable operations + if (!operation_seen.empty()) { + if (FPOptPrint) { + llvm::errs() << "Found a subgraph with " << operation_seen.size() + << " operations and " << input_seen.size() + << " inputs and " << output_seen.size() << " outputs\n"; + + llvm::errs() << "Inputs:\n"; + + for (auto &input : input_seen) { + llvm::errs() << *input << "\n"; + } + + llvm::errs() << "Outputs:\n"; + for (auto &output : output_seen) { + llvm::errs() << *output << "\n"; + } + + llvm::errs() << "Operations:\n"; + for (auto &operation : operation_seen) { + llvm::errs() << *operation << "\n"; + } + } + + if (operation_seen.size() == 1) { + if (FPOptPrint) + llvm::errs() << "Skipping trivial subgraph\n"; + continue; + } + + subgraphs.emplace_back(input_seen, output_seen, operation_seen); + } + } + } + + if (FPOptPrint) { + llvm::errs() << "FPOpt: Found " << subgraphs.size() + << " initial subgraphs in " << F.getName() << "\n"; + } + + // Profile read must happen before aggressive DCE as it requires gradients + for (auto &subgraph : subgraphs) { + for (auto op : subgraph.operations) { + if (auto MD = op->getMetadata("enzyme_fpprofile_idx")) { + if (auto C = 
dyn_cast(MD->getOperand(0))) { + size_t idx = cast(C->getValue())->getZExtValue(); + auto it = profileMap.find(idx); + + if (it != profileMap.end()) { + const auto &profileInfo = it->second; + + auto node = valueToNodeMap[op]; + node->sens = profileInfo.sumSens; + node->grad = profileInfo.sumGrad; + node->executions = profileInfo.exec; + node->updateBounds(profileInfo.minRes, profileInfo.maxRes); + + if (FPOptPrint) { + llvm::errs() << "Range of " << *op << " is [" + << node->getLowerBound() << ", " + << node->getUpperBound() << "]\n"; + llvm::errs() << "Sensitivity score of " << *op + << " is: " << node->sens << "\n" + << "Gradient sum of " << *op << " is: " << node->grad + << "\n" + << "Execution count of " << *op + << " is: " << node->executions << "\n"; + } + + auto operands = + isa(op) ? cast(op)->args() : op->operands(); + + for (const auto &operand_ : enumerate(operands)) { + auto &operand = operand_.value(); + auto i = operand_.index(); + + if (i < profileInfo.minOperands.size()) { + auto operandNode = valueToNodeMap[operand]; + operandNode->updateBounds(profileInfo.minOperands[i], + profileInfo.maxOperands[i]); + if (FPOptPrint) { + llvm::errs() << "Range of " << *operand << " is [" + << operandNode->getLowerBound() << ", " + << operandNode->getUpperBound() << "]\n"; + } + } + } + } else { + if (!FPOptLooseCoverage) { + llvm::errs() << "FP Instruction " << *op + << " has no execution logged (idx=" << idx << ")!\n"; + llvm_unreachable("Unexecuted instruction found; set " + "-fpopt-loose-coverage " + "to suppress this error\n"); + } + if (FPOptPrint) + llvm::errs() << "Sensitivity of " << *op + << " not found in the log; using 0 instead\n"; + } + } + } + } + } + + if (FPOptAggressiveDCE) { + SmallSet critical; + for (auto &BB : F) { + for (auto &I : BB) { + if (auto *fcmp = dyn_cast(&I)) { + critical.insert(fcmp->getOperand(0)); + critical.insert(fcmp->getOperand(1)); + } + } + } + + SmallVector worklist(critical.begin(), critical.end()); + while 
(!worklist.empty()) { + Value *V = worklist.pop_back_val(); + + if (auto *inst = dyn_cast(V)) { + auto operands = isa(inst) ? cast(inst)->args() + : inst->operands(); + for (auto &op : operands) { + if (op->getType()->isFloatingPointTy() && + critical.insert(op).second) { + worklist.push_back(op); + } + } + + if (auto *phi = dyn_cast(inst)) { + for (unsigned i = 0; i < phi->getNumIncomingValues(); ++i) { + Value *incoming = phi->getIncomingValue(i); + if (incoming->getType()->isFloatingPointTy() && + critical.insert(incoming).second) { + worklist.push_back(incoming); + } + } + } + + if (auto *load = dyn_cast(inst)) { + Value *ptr = load->getPointerOperand(); + for (auto *user : ptr->users()) { + if (auto *store = dyn_cast(user)) { + Value *storedVal = store->getValueOperand(); + if (storedVal->getType()->isFloatingPointTy() && + critical.insert(storedVal).second) { + worklist.push_back(storedVal); + } + } + } + } + } + } + + if (FPOptPrint) { + llvm::errs() << "Critical values:\n"; + for (auto *value : critical) { + llvm::errs() << "\t" << *value << "\n"; + } + } + + auto subgraphIt = subgraphs.begin(); + while (subgraphIt != subgraphs.end()) { + Subgraph &subgraph = *subgraphIt; + SmallVector toRemove; + + for (auto *op : subgraph.operations) { + auto node = valueToNodeMap[op]; + if (node->grad == 0. 
&& !critical.count(op)) { + if (FPOptPrint) + llvm::errs() << "Aggressive DCE: eliminating zero-gradient " + << "non-critical instruction: " << *op << "\n"; + toRemove.push_back(op); + } + } + + for (auto *op : toRemove) { + if (!op->use_empty()) { + op->replaceAllUsesWith(UndefValue::get(op->getType())); + } + + valueToNodeMap.erase(op); + subgraph.operations.remove(op); + subgraph.outputs.remove(op); + + op->eraseFromParent(); + } + + if (subgraph.outputs.empty()) { + if (FPOptPrint) + llvm::errs() << "Removing empty subgraph\n"; + subgraphIt = subgraphs.erase(subgraphIt); + } else { + ++subgraphIt; + } + } + } + + if (FPOptPrint && FPOptAggressiveDCE) { + llvm::errs() << "FPOpt: After aggressive DCE, have " << subgraphs.size() + << " subgraphs in " << F.getName() << "\n"; + } + + splitSubgraphs(subgraphs); + + if (FPOptPrint) { + llvm::errs() << "FPOpt: After splitting, have " << subgraphs.size() + << " subgraphs in " << F.getName() << "\n"; + } + + if (FPOptPrint) { + llvm::errs() << "\n=== Function IR after Subgraph Splitting ===\n"; + + std::unordered_map instToSubgraphIdx; + for (size_t idx = 0; idx < subgraphs.size(); ++idx) { + for (auto *inst : subgraphs[idx].operations) { + instToSubgraphIdx[inst] = idx; + } + for (auto *inst : subgraphs[idx].outputs) { + if (instToSubgraphIdx.find(inst) == instToSubgraphIdx.end()) { + instToSubgraphIdx[inst] = idx; + } + } + } + + for (auto &BB : F) { + BB.printAsOperand(llvm::errs(), false); + llvm::errs() << ":\n"; + for (auto &I : BB) { + llvm::errs() << " "; + I.print(llvm::errs()); + + auto it = instToSubgraphIdx.find(&I); + if (it != instToSubgraphIdx.end()) { + llvm::errs() << " ; [SG" << it->second << "]"; + } + llvm::errs() << "\n"; + } + } + llvm::errs() << "=== End of Function IR ===\n\n"; + } + + // 1) Identify subgraphs of the computation which can be entirely represented + // in herbie-style arithmetic + // 2) Make the herbie FP-style expression by + // converting llvm instructions into herbie string 
(FPNode ....) + if (subgraphs.empty()) { + if (FPOptPrint) + llvm::errs() << "No subgraphs found\n"; + return false; + } + + SmallVector COs; + SmallVector CSs; + + int subgraphCounter = 0; + + for (auto &subgraph : subgraphs) { + assert(subgraph.inputs.size() > 0 && "No inputs found for subgraph"); + + bool skipHerbie = false; + if (FPOptMultiOutputPTOnly && subgraph.outputs.size() > 1) { + skipHerbie = true; + if (FPOptPrint) + llvm::errs() << "Skipping Herbie for subgraph with " + << subgraph.outputs.size() + << " outputs (fpopt-multi-output-pt-only is set)\n"; + } + + if (FPOptEnableHerbie && !skipHerbie) { + for (const auto &input : subgraph.inputs) { + auto node = valueToNodeMap[input]; + if (node->op == "__const") { + // Constants don't need a symbol + continue; + } + + if (!node->hasSymbol()) { + node->symbol = getNextSymbol(); + } + symbolToValueMap[node->symbol] = input; + if (FPOptPrint) + llvm::errs() << "assigning symbol: " << node->symbol << " to " + << *input << "\n"; + } + + std::vector herbieInputs; + std::vector newCOs; + + assert(subgraph.outputs.size() > 0 && "No outputs found for subgraph"); + for (auto &output : subgraph.outputs) { + // 3) run fancy opts + double grad = valueToNodeMap[output]->grad; + unsigned executions = valueToNodeMap[output]->executions; + + if (grad == 0.) 
{ + llvm::errs() << "Skipping zero gradient instruction: " << *output + << "\n"; + continue; + } + + std::string expr = valueToNodeMap[output]->toFullExpression( + valueToNodeMap, subgraph.inputs); + + if (expr.length() > FPOptMaxExprLength) { + llvm::errs() << "WARNING: Skipping Herbie optimization for " + << *output << " since expression length " + << expr.length() << " exceeds limit of " + << FPOptMaxExprLength << "\n"; + continue; + } + + // Skip trivial expressions with only one operation + auto parenCount = std::count(expr.begin(), expr.end(), '('); + assert(parenCount > 0); + if (parenCount == 1) { + if (FPOptPrint) + llvm::errs() << "Skipping Herbie for simple expression: " << expr + << "\n"; + continue; + } + + SmallSet args; + getUniqueArgs(expr, args); + + std::string properties = ":herbie-conversions ([binary64 binary32])"; + if (valueToNodeMap[output]->dtype == "f32") { + properties += " :precision binary32"; + } else if (valueToNodeMap[output]->dtype == "f64") { + properties += " :precision binary64"; + } else { + llvm_unreachable("Unexpected dtype"); + } + + std::string precondition = + getPrecondition(args, valueToNodeMap, symbolToValueMap); + properties += " :pre " + precondition; + + CandidateOutput CO(subgraph, output, expr, grad, executions, TTI); + properties += " :name \"" + std::to_string(newCOs.size()) + "\""; + + std::string argStr; + for (const auto &arg : args) { + if (!argStr.empty()) + argStr += " "; + argStr += arg; + } + + std::string herbieInput = + "(FPCore (" + argStr + ") " + properties + " " + expr + ")"; + if (FPOptPrint) + llvm::errs() << "Herbie input:\n" << herbieInput << "\n"; + + herbieInputs.push_back(herbieInput); + newCOs.push_back(CO); + } + + if (!herbieInputs.empty()) { + if (!improveViaHerbie(herbieInputs, newCOs, F.getParent(), TTI, + valueToNodeMap, symbolToValueMap, + subgraphCounter)) { + if (FPOptPrint) + llvm::errs() << "Failed to optimize expressions using Herbie!\n"; + } + + COs.insert(COs.end(), 
newCOs.begin(), newCOs.end()); + } + } + + if (FPOptEnablePT) { + // Sort `cs.operations` by the gradient and construct + // `PrecisionChange`s. + CandidateSubgraph CS(subgraph, TTI); + auto *o0 = subgraph.outputs[0]; + CS.executions = valueToNodeMap[o0]->executions; + + const SmallVector precTypes{ + PrecisionChangeType::FP32, + PrecisionChangeType::FP64, + }; + + const auto &PTFuncs = getPTFuncs(); + + // Check if we have a cached DP table + std::string cacheFilePath = FPOptCachePath + "/table.json"; + bool skipEvaluation = FPOptSolverType == "dp" && + !FPOptCachePath.empty() && + llvm::sys::fs::exists(cacheFilePath); + + SetVector operations; + for (auto *I : subgraph.operations) { + assert(isa(valueToNodeMap[I].get()) && + "Corrupted FPNode for original instructions"); + auto node = cast(valueToNodeMap[I].get()); + if (PTFuncs.count(node->op) != 0) { + operations.insert(node); + llvm::errs() << "FPOpt: PT Function identified: " << *I << "\n"; + } + } + + // Prioritize operations with low sensitivity scores + SmallVector sortedOps(operations.begin(), operations.end()); + llvm::sort(sortedOps, [](const auto &a, const auto &b) { + return a->sens < b->sens; + }); + + // Create PrecisionChanges for 0-10%, 0-20%, ..., up to 0-100% + size_t lastNumChanged = 0; + for (int percent = 10; percent <= 100; percent += 10) { + size_t numToChange = sortedOps.size() * percent / 100; + if (numToChange == 0 || numToChange == lastNumChanged) { + continue; + } + + lastNumChanged = numToChange; + + if (FPOptPrint && numToChange > 0) { + llvm::errs() << "Created PrecisionChange for " << percent + << "% of Funcs (" << numToChange << ")\n"; + double minSens = sortedOps[0]->sens; + double maxSens = sortedOps[numToChange - 1]->sens; + llvm::errs() << "Sensitivity score range: [" << minSens << ", " + << maxSens << "]\n"; + } + + for (auto prec : precTypes) { + PrecisionChangeType currentPrec = + getPrecisionChangeType(subgraph.outputs[0]->getType()); + if (prec == currentPrec) { + 
continue; + } + + std::string precStr = getPrecisionChangeTypeString(prec).str(); + std::string desc = + "Funcs 0% -- " + std::to_string(percent) + "% -> " + precStr; + + SetVector nodesToChange(sortedOps.begin(), + sortedOps.begin() + numToChange); + PrecisionChange change(nodesToChange, currentPrec, prec); + + SmallVector changes{std::move(change)}; + PTCandidate candidate{std::move(changes), desc}; + + if (!skipEvaluation) { + candidate.CompCost = getCompCost(subgraph, TTI, candidate); + } + + CS.candidates.push_back(std::move(candidate)); + } + } + + SetVector allOperations; + for (auto *I : subgraph.operations) { + assert(isa(valueToNodeMap[I].get()) && + "Corrupted FPNode for original instructions"); + auto node = cast(valueToNodeMap[I].get()); + allOperations.insert(node); + } + + // Prioritize operations with low sensitivity scores + SmallVector sortedAllOps(allOperations.begin(), + allOperations.end()); + llvm::sort(sortedAllOps, [](const auto &a, const auto &b) { + return a->sens < b->sens; + }); + + // Create PrecisionChanges for 0-10%, 0-20%, ..., up to 0-100% + lastNumChanged = 0; + for (int percent = 10; percent <= 100; percent += 10) { + size_t numToChange = sortedAllOps.size() * percent / 100; + if (numToChange == 0 || numToChange == lastNumChanged) { + continue; + } + + lastNumChanged = numToChange; + + if (FPOptPrint && numToChange > 0) { + llvm::errs() << "Created PrecisionChange for " << percent + << "% of all operations (" << numToChange << ")\n"; + double minSens = sortedAllOps[0]->sens; + double maxSens = sortedAllOps[numToChange - 1]->sens; + llvm::errs() << "Sensitivity score range: [" << minSens << ", " + << maxSens << "]\n"; + } + + for (auto prec : precTypes) { + PrecisionChangeType currentPrec = + getPrecisionChangeType(subgraph.outputs[0]->getType()); + if (prec == currentPrec) { + continue; + } + + std::string precStr = getPrecisionChangeTypeString(prec).str(); + std::string desc = + "All 0% -- " + std::to_string(percent) + "% -> " + 
precStr; + + SetVector nodesToChange( + sortedAllOps.begin(), sortedAllOps.begin() + numToChange); + PrecisionChange change(nodesToChange, currentPrec, prec); + + SmallVector changes{std::move(change)}; + PTCandidate candidate{std::move(changes), desc}; + + if (!skipEvaluation) { + candidate.CompCost = getCompCost(subgraph, TTI, candidate); + } + + CS.candidates.push_back(std::move(candidate)); + } + } + + if (!skipEvaluation) { + setUnifiedAccuracyCost(CS, valueToNodeMap, symbolToValueMap); + } + + CSs.push_back(std::move(CS)); + } + llvm::errs() << "##### Finished synthesizing candidates for " + << ++subgraphCounter << " of " << subgraphs.size() + << " subgraphs! #####\n"; + } + + // Perform rewrites + if (FPOptPrint) { + if (FPOptEnableHerbie) { + for (auto &CO : COs) { + llvm::errs() << "\n################################\n"; + llvm::errs() << "Initial AccuracyCost: " << CO.initialAccCost << "\n"; + llvm::errs() << "Initial ComputationCost: " << CO.initialCompCost + << "\n"; + llvm::errs() << "Initial HerbieCost: " << CO.initialHerbieCost << "\n"; + llvm::errs() << "Initial HerbieAccuracy: " << CO.initialHerbieAccuracy + << "\n"; + llvm::errs() << "Initial Expression: " << CO.expr << "\n"; + llvm::errs() << "Grad: " << CO.grad << "\n\n"; + llvm::errs() << "Candidates:\n"; + llvm::errs() << "Δ AccCost\t\tΔ " + "CompCost\t\tHerbieCost\t\tAccuracy\t\tExpression\n"; + llvm::errs() << "--------------------------------\n"; + for (size_t i = 0; i < CO.candidates.size(); ++i) { + auto &candidate = CO.candidates[i]; + llvm::errs() << CO.getAccCostDelta(i) << "\t\t" + << CO.getCompCostDelta(i) << "\t\t" + << candidate.herbieCost << "\t\t" + << candidate.herbieAccuracy << "\t\t" << candidate.expr + << "\n"; + } + llvm::errs() << "################################\n\n"; + } + } + if (FPOptEnablePT) { + for (auto &CS : CSs) { + llvm::errs() << "\n################################\n"; + llvm::errs() << "Initial AccuracyCost: " << CS.initialAccCost << "\n"; + llvm::errs() << 
"Initial ComputationCost: " << CS.initialCompCost + << "\n"; + llvm::errs() << "Candidates:\n"; + llvm::errs() << "Δ AccCost\t\tΔ CompCost\t\tDescription\n" + << "---------------------------\n"; + for (size_t i = 0; i < CS.candidates.size(); ++i) { + auto &candidate = CS.candidates[i]; + llvm::errs() << CS.getAccCostDelta(i) << "\t\t" + << CS.getCompCostDelta(i) << "\t\t" << candidate.desc + << "\n"; + } + llvm::errs() << "################################\n\n"; + } + } + } + + if (!FPOptApplyRewrites.empty()) { + // User-selected rewrites: parse IDs and apply the specified candidates. + // IDs are R{coIdx}_{candIdx} for rewrites, PT{csIdx}_{candIdx} for PT. + SmallVector ids; + StringRef(FPOptApplyRewrites) + .split(ids, ',', /*MaxSplit=*/-1, + /*KeepEmpty=*/false); + + // Track which CO/CS has been applied to enforce at-most-one constraint + SmallDenseSet appliedCOs, appliedCSs; + + for (auto id : ids) { + id = id.trim(); + if (id.starts_with("R")) { + // Parse R{coIdx}_{candIdx} + auto rest = id.drop_front(1); + auto [coStr, candStr] = rest.split('_'); + size_t coIdx, candIdx; + if (coStr.getAsInteger(10, coIdx) || + candStr.getAsInteger(10, candIdx)) { + llvm::errs() << "FPOpt: Invalid rewrite ID '" << id << "'\n"; + continue; + } + if (coIdx >= COs.size()) { + llvm::errs() << "FPOpt: CO index " << coIdx << " out of range (" + << COs.size() << " COs)\n"; + continue; + } + if (candIdx >= COs[coIdx].candidates.size()) { + llvm::errs() << "FPOpt: Candidate index " << candIdx + << " out of range for CO " << coIdx << " (" + << COs[coIdx].candidates.size() << " candidates)\n"; + continue; + } + if (!appliedCOs.insert(coIdx).second) { + llvm::errs() << "FPOpt: CO " << coIdx + << " already has a rewrite applied. 
Skipping " << id + << "\n"; + continue; + } + llvm::errs() << "FPOpt: Applying " << id << ": " << COs[coIdx].expr + << " -> " << COs[coIdx].candidates[candIdx].expr << "\n"; + COs[coIdx].apply(candIdx, valueToNodeMap, symbolToValueMap); + changed = true; + + } else if (id.starts_with("PT")) { + auto rest = id.drop_front(2); + auto [csStr, candStr] = rest.split('_'); + size_t csIdx, candIdx; + if (csStr.getAsInteger(10, csIdx) || + candStr.getAsInteger(10, candIdx)) { + llvm::errs() << "FPOpt: Invalid PT ID '" << id << "'\n"; + continue; + } + if (csIdx >= CSs.size()) { + llvm::errs() << "FPOpt: CS index " << csIdx << " out of range (" + << CSs.size() << " CSs)\n"; + continue; + } + if (candIdx >= CSs[csIdx].candidates.size()) { + llvm::errs() << "FPOpt: Candidate index " << candIdx + << " out of range for CS " << csIdx << " (" + << CSs[csIdx].candidates.size() << " candidates)\n"; + continue; + } + if (!appliedCSs.insert(csIdx).second) { + llvm::errs() << "FPOpt: CS " << csIdx + << " already has a PT applied. 
Skipping " << id << "\n"; + continue; + } + llvm::errs() << "FPOpt: Applying " << id << ": " + << CSs[csIdx].candidates[candIdx].desc << "\n"; + CSs[csIdx].apply(candIdx); + changed = true; + + } else { + llvm::errs() << "FPOpt: Unknown ID prefix in '" << id + << "' (expected R or PT)\n"; + } + } + } else if (!FPOptEnableSolver) { + if (FPOptEnableHerbie) { + for (auto &CO : COs) { + CO.apply(0, valueToNodeMap, symbolToValueMap); + changed = true; + } + } + } else { + if (FPOptSolverType == "greedy") { + changed = + accuracyGreedySolver(COs, CSs, valueToNodeMap, symbolToValueMap); + } else if (FPOptSolverType == "dp") { + changed = accuracyDPSolver(F, TTI, COs, CSs, valueToNodeMap, + symbolToValueMap, errorTol); + } else { + llvm::errs() << "FPOpt: Unknown solver type: " << FPOptSolverType << "\n"; + return false; + } + } + + llvm::errs() << "FPOpt: Finished optimizing " << F.getName() << "\n"; + + // Cleanup + if (changed) { + for (auto &subgraph : subgraphs) { + if (subgraph.outputs_rewritten != subgraph.outputs.size()) { + if (FPOptPrint) + llvm::errs() << "Skip erasing a subgraph: only rewrote " + << subgraph.outputs_rewritten << " of " + << subgraph.outputs.size() << " outputs\n"; + continue; // Intermediate operations cannot be erased safely + } + for (auto *I : subgraph.operations) { + if (FPOptPrint) + llvm::errs() << "Erasing: " << *I << "\n"; + if (!I->use_empty()) { + I->replaceAllUsesWith(UndefValue::get(I->getType())); + } + I->eraseFromParent(); + } + } + + llvm::errs() << "FPOpt: Finished cleaning up " << F.getName() << "\n"; + } + + runPoseidonFunctionSimplify(F, OptimizationLevel::O3); + + if (FPOptPrint) { + llvm::errs() << "FPOpt: Finished Optimization\n"; + F.print(llvm::errs()); + } + + return changed; +} + +namespace {} // namespace + +char FPOpt::ID = 0; + +FPOpt::FPOpt() : FunctionPass(ID) {} + +void FPOpt::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + FunctionPass::getAnalysisUsage(AU); +} + +bool 
FPOpt::runOnFunction(Function &F) { + auto &TTI = getAnalysis().getTTI(F); + return fpOptimize(F, TTI); +} + +static RegisterPass + X("fp-opt", "Run Enzyme/Poseidon Floating point optimizations"); + +FunctionPass *createFPOptPass() { return new FPOpt(); } + +#include +#include + +#include "llvm/IR/LegacyPassManager.h" + +extern "C" void AddFPOptPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createFPOptPass()); +} + +FPOptNewPM::Result FPOptNewPM::run(llvm::Module &M, + llvm::ModuleAnalysisManager &MAM) { + bool changed = false; + FunctionAnalysisManager &FAM = + MAM.getResult(M).getManager(); + for (auto &F : M) { + if (!F.isDeclaration()) { + const auto &TTI = FAM.getResult(F); + changed |= fpOptimize(F, TTI); + } + } + + return changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} +llvm::AnalysisKey FPOptNewPM::Key; diff --git a/enzyme/Enzyme/Poseidon/Poseidon.h b/enzyme/Enzyme/Poseidon/Poseidon.h new file mode 100644 index 000000000000..f018b3d23797 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/Poseidon.h @@ -0,0 +1,61 @@ +#ifndef ENZYME_POSEIDON_H +#define ENZYME_POSEIDON_H + +#include + +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/Utils/ValueMapper.h" + +using namespace llvm; + +extern "C" { +extern llvm::cl::opt FPProfileGenerate; +extern llvm::cl::opt FPProfileUse; +extern llvm::cl::opt FPOptPrint; +extern llvm::cl::opt FPOptEnableHerbie; +extern llvm::cl::opt FPOptEnablePT; +extern llvm::cl::opt FPOptEnableSolver; +extern llvm::cl::opt FPOptMaxExprDepth; +extern llvm::cl::opt FPOptMaxExprLength; +extern llvm::cl::opt FPOptReductionEval; +extern llvm::cl::opt FPOptCachePath; +extern llvm::cl::opt FPOptMultiOutputPTOnly; +} + +bool Poseidonable(const Value &V); +void setPoseidonMetadata(Function &F); +void preprocessForPoseidon(Function *F); +bool fpOptimize(Function &F, const TargetTransformInfo 
&TTI, + double errorTol = 0.0); + +class FPOpt final : public FunctionPass { +public: + static char ID; + FPOpt(); + + void getAnalysisUsage(AnalysisUsage &AU) const override; + bool runOnFunction(Function &F) override; +}; + +llvm::FunctionPass *createFPOptPass(); + +class FPOptNewPM final : public llvm::AnalysisInfoMixin { + friend struct llvm::AnalysisInfoMixin; + +private: + static llvm::AnalysisKey Key; + +public: + using Result = llvm::PreservedAnalyses; + FPOptNewPM() {} + + Result run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM); + + static bool isRequired() { return true; } +}; + +#endif // ENZYME_POSEIDON_H diff --git a/enzyme/Enzyme/Poseidon/PoseidonEvaluators.cpp b/enzyme/Enzyme/Poseidon/PoseidonEvaluators.cpp new file mode 100644 index 000000000000..5655aecad196 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonEvaluators.cpp @@ -0,0 +1,974 @@ +//=- PoseidonEvaluators.cpp - Expression evaluators for Poseidon ----------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the evaluator classes for floating-point expressions +// in the Poseidon optimization pass. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/raw_ostream.h" + +#include + +#include "PoseidonEvaluators.h" +#include "PoseidonPrecUtils.h" +#include "PoseidonTypes.h" + +using namespace llvm; + +extern "C" { +cl::opt FPOptStrictMode( + "fpopt-strict-mode", cl::init(false), cl::Hidden, + cl::desc( + "Discard all FPOpt candidates that produce NaN or inf outputs for any " + "input point that originally produced finite outputs")); +cl::opt FPOptGeoMeanEps( + "fpopt-geo-mean-eps", cl::init(0.0), cl::Hidden, + cl::desc("The offset used in the geometric mean " + "calculation; if = 0, zeros are replaced with ULPs")); +} + +FPEvaluator::FPEvaluator(PTCandidate *pt) { + if (pt) { + for (const auto &change : pt->changes) { + for (auto node : change.nodes) { + nodePrecisions[node] = change.newType; + } + } + } +} + +PrecisionChangeType FPEvaluator::getNodePrecision(const FPNode *node) const { + // If the node has a new precision from PT, use it + PrecisionChangeType precType; + + auto it = nodePrecisions.find(node); + if (it != nodePrecisions.end()) { + precType = it->second; + } else { + // Otherwise, use the node's original precision + if (node->dtype == "f32") { + precType = PrecisionChangeType::FP32; + } else if (node->dtype == "f64") { + precType = PrecisionChangeType::FP64; + } else { + llvm_unreachable( + ("Operator " + node->op + " has unexpected dtype: " + node->dtype) + .c_str()); + } + } + + if (precType != PrecisionChangeType::FP32 && + precType != PrecisionChangeType::FP64) { + llvm_unreachable("Unsupported FP precision"); + } + + return precType; +} + +void FPEvaluator::evaluateNode(const FPNode *node, + const MapVector &inputValues) { + if (cache.find(node) != cache.end()) + return; + + if (isa(node)) { + double constVal = node->getLowerBound(); + cache.emplace(node, constVal); + return; + } + + if (isa(node) && inputValues.count(cast(node)->value)) { + double inputValue = 
inputValues.lookup(cast(node)->value); + cache.emplace(node, inputValue); + return; + } + + if (node->op == "if") { + evaluateNode(node->operands[0].get(), inputValues); + double cond = getResult(node->operands[0].get()); + + if (cond == 1.0) { + evaluateNode(node->operands[1].get(), inputValues); + double then_val = getResult(node->operands[1].get()); + cache.emplace(node, then_val); + } else { + evaluateNode(node->operands[2].get(), inputValues); + double else_val = getResult(node->operands[2].get()); + cache.emplace(node, else_val); + } + return; + } else if (node->op == "and") { + evaluateNode(node->operands[0].get(), inputValues); + double op0 = getResult(node->operands[0].get()); + if (op0 != 1.0) { + cache.emplace(node, 0.0); + return; + } + evaluateNode(node->operands[1].get(), inputValues); + double op1 = getResult(node->operands[1].get()); + if (op1 != 1.0) { + cache.emplace(node, 0.0); + return; + } + cache.emplace(node, 1.0); + return; + } else if (node->op == "or") { + evaluateNode(node->operands[0].get(), inputValues); + double op0 = getResult(node->operands[0].get()); + if (op0 == 1.0) { + cache.emplace(node, 1.0); + return; + } + evaluateNode(node->operands[1].get(), inputValues); + double op1 = getResult(node->operands[1].get()); + if (op1 == 1.0) { + cache.emplace(node, 1.0); + return; + } + cache.emplace(node, 0.0); + return; + } else if (node->op == "not") { + evaluateNode(node->operands[0].get(), inputValues); + double op = getResult(node->operands[0].get()); + cache.emplace(node, (op == 1.0) ? 
0.0 : 1.0); + return; + } else if (node->op == "TRUE") { + cache.emplace(node, 1.0); + return; + } else if (node->op == "FALSE") { + cache.emplace(node, 0.0); + return; + } + + PrecisionChangeType nodePrec = getNodePrecision(node); + + for (const auto &operand : node->operands) { + evaluateNode(operand.get(), inputValues); + } + + double res = 0.0; + + auto evalUnary = [&](auto doubleFunc, auto floatFunc) -> double { + double op = getResult(node->operands[0].get()); + if (nodePrec == PrecisionChangeType::FP32) + return floatFunc(static_cast(op)); + else + return doubleFunc(op); + }; + + auto evalBinary = [&](auto doubleFunc, auto floatFunc) -> double { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + if (nodePrec == PrecisionChangeType::FP32) + return floatFunc(static_cast(op0), static_cast(op1)); + else + return doubleFunc(op0, op1); + }; + + auto evalTernary = [&](auto doubleFunc, auto floatFunc) -> double { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + double op2 = getResult(node->operands[2].get()); + if (nodePrec == PrecisionChangeType::FP32) + return floatFunc(static_cast(op0), static_cast(op1), + static_cast(op2)); + else + return doubleFunc(op0, op1, op2); + }; + + if (node->op == "neg") { + double op = getResult(node->operands[0].get()); + res = + (nodePrec == PrecisionChangeType::FP32) ? -static_cast(op) : -op; + } else if (node->op == "+") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) + static_cast(op1) + : op0 + op1; + } else if (node->op == "-") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? 
static_cast(op0) - static_cast(op1) + : op0 - op1; + } else if (node->op == "*") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) * static_cast(op1) + : op0 * op1; + } else if (node->op == "/") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + res = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) / static_cast(op1) + : op0 / op1; + } else if (node->op == "sin") { + res = evalUnary(static_cast(std::sin), + static_cast(sinf)); + } else if (node->op == "cos") { + res = evalUnary(static_cast(std::cos), + static_cast(cosf)); + } else if (node->op == "tan") { + res = evalUnary(static_cast(std::tan), + static_cast(tanf)); + } else if (node->op == "exp") { + res = evalUnary(static_cast(std::exp), + static_cast(expf)); + } else if (node->op == "expm1") { + res = evalUnary(static_cast(std::expm1), + static_cast(expm1f)); + } else if (node->op == "log") { + res = evalUnary(static_cast(std::log), + static_cast(logf)); + } else if (node->op == "log1p") { + res = evalUnary(static_cast(std::log1p), + static_cast(log1pf)); + } else if (node->op == "sqrt") { + res = evalUnary(static_cast(std::sqrt), + static_cast(sqrtf)); + } else if (node->op == "cbrt") { + res = evalUnary(static_cast(std::cbrt), + static_cast(cbrtf)); + } else if (node->op == "asin") { + res = evalUnary(static_cast(std::asin), + static_cast(asinf)); + } else if (node->op == "acos") { + res = evalUnary(static_cast(std::acos), + static_cast(acosf)); + } else if (node->op == "atan") { + res = evalUnary(static_cast(std::atan), + static_cast(atanf)); + } else if (node->op == "sinh") { + res = evalUnary(static_cast(std::sinh), + static_cast(sinhf)); + } else if (node->op == "cosh") { + res = evalUnary(static_cast(std::cosh), + static_cast(coshf)); + } else if (node->op == "tanh") { + res = evalUnary(static_cast(std::tanh), + 
static_cast(tanhf)); + } else if (node->op == "asinh") { + res = evalUnary(static_cast(std::asinh), + static_cast(asinhf)); + } else if (node->op == "acosh") { + res = evalUnary(static_cast(std::acosh), + static_cast(acoshf)); + } else if (node->op == "atanh") { + res = evalUnary(static_cast(std::atanh), + static_cast(atanhf)); + } else if (node->op == "ceil") { + res = evalUnary(static_cast(std::ceil), + static_cast(ceilf)); + } else if (node->op == "floor") { + res = evalUnary(static_cast(std::floor), + static_cast(floorf)); + } else if (node->op == "exp2") { + res = evalUnary(static_cast(std::exp2), + static_cast(exp2f)); + } else if (node->op == "log10") { + res = evalUnary(static_cast(std::log10), + static_cast(log10f)); + } else if (node->op == "log2") { + res = evalUnary(static_cast(std::log2), + static_cast(log2f)); + } else if (node->op == "rint") { + res = evalUnary(static_cast(std::rint), + static_cast(rintf)); + } else if (node->op == "round") { + res = evalUnary(static_cast(std::round), + static_cast(roundf)); + } else if (node->op == "trunc") { + res = evalUnary(static_cast(std::trunc), + static_cast(truncf)); + } else if (node->op == "pow") { + res = evalBinary(static_cast(std::pow), + static_cast(powf)); + } else if (node->op == "fabs") { + res = evalUnary(static_cast(std::fabs), + static_cast(fabsf)); + } else if (node->op == "hypot") { + res = evalBinary(static_cast(std::hypot), + static_cast(hypotf)); + } else if (node->op == "atan2") { + res = evalBinary(static_cast(std::atan2), + static_cast(atan2f)); + } else if (node->op == "copysign") { + res = evalBinary(static_cast(std::copysign), + static_cast(copysignf)); + } else if (node->op == "fmax") { + res = evalBinary(static_cast(std::fmax), + static_cast(fmaxf)); + } else if (node->op == "fmin") { + res = evalBinary(static_cast(std::fmin), + static_cast(fminf)); + } else if (node->op == "fdim") { + res = evalBinary(static_cast(std::fdim), + static_cast(fdimf)); + } else if (node->op == "fmod") { 
+ res = evalBinary(static_cast(std::fmod), + static_cast(fmodf)); + } else if (node->op == "remainder") { + res = evalBinary(static_cast(std::remainder), + static_cast(remainderf)); + } else if (node->op == "fma") { + res = evalTernary(static_cast(std::fma), + static_cast(fmaf)); + } else if (node->op == "lgamma") { + res = evalUnary(static_cast(std::lgamma), + static_cast(lgammaf)); + } else if (node->op == "tgamma") { + res = evalUnary(static_cast(std::tgamma), + static_cast(tgammaf)); + } else if (node->op == "==") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) == static_cast(op1) + : op0 == op1; + res = result ? 1.0 : 0.0; + } else if (node->op == "!=") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) != static_cast(op1) + : op0 != op1; + res = result ? 1.0 : 0.0; + } else if (node->op == "<") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) < static_cast(op1) + : op0 < op1; + res = result ? 1.0 : 0.0; + } else if (node->op == ">") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) > static_cast(op1) + : op0 > op1; + res = result ? 1.0 : 0.0; + } else if (node->op == "<=") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) <= static_cast(op1) + : op0 <= op1; + res = result ? 
1.0 : 0.0; + } else if (node->op == ">=") { + double op0 = getResult(node->operands[0].get()); + double op1 = getResult(node->operands[1].get()); + bool result = (nodePrec == PrecisionChangeType::FP32) + ? static_cast(op0) >= static_cast(op1) + : op0 >= op1; + res = result ? 1.0 : 0.0; + } else if (node->op == "PI") { + res = M_PI; + } else if (node->op == "E") { + res = M_E; + } else if (node->op == "INFINITY") { + res = INFINITY; + } else if (node->op == "NAN") { + res = NAN; + } else { + std::string msg = "FPEvaluator: Unexpected operator " + node->op; + llvm_unreachable(msg.c_str()); + } + + cache.emplace(node, res); +} + +double FPEvaluator::getResult(const FPNode *node) const { + auto it = cache.find(node); + assert(it != cache.end() && "Node not evaluated yet"); + return it->second; +} + +MPFREvaluator::CachedValue::CachedValue(unsigned prec) : prec(prec) { + mpfr_init2(value, prec); + mpfr_set_zero(value, 1); +} + +MPFREvaluator::CachedValue::CachedValue(CachedValue &&other) noexcept + : prec(other.prec) { + mpfr_init2(value, other.prec); + mpfr_swap(value, other.value); +} + +MPFREvaluator::CachedValue & +MPFREvaluator::CachedValue::operator=(CachedValue &&other) noexcept { + if (this != &other) { + mpfr_set_prec(value, other.prec); + prec = other.prec; + mpfr_swap(value, other.value); + } + return *this; +} + +MPFREvaluator::CachedValue::~CachedValue() { mpfr_clear(value); } + +MPFREvaluator::MPFREvaluator(unsigned prec, PTCandidate *pt) : prec(prec) { + if (pt) { + for (const auto &change : pt->changes) { + for (auto node : change.nodes) { + nodeToNewPrec[node] = getMPFRPrec(change.newType); + } + } + } +} + +unsigned MPFREvaluator::getNodePrecision(const FPNode *node, + bool groundTruth) const { + // If trying to evaluate the ground truth, use the current MPFR precision + if (groundTruth) + return prec; + + // If the node has a new precision for PT, use it + auto it = nodeToNewPrec.find(node); + if (it != nodeToNewPrec.end()) { + return it->second; + } 
+ + // Otherwise, use the original precision + return node->getMPFRPrec(); +} + +// Compute the expression with MPFR at `prec` precision +// recursively. When operand is a FPConst, use its lower +// bound. When operand is a FPLLValue, get its inputs from +// `inputs`. +void MPFREvaluator::evaluateNode(const FPNode *node, + const MapVector &inputValues, + bool groundTruth) { + if (cache.find(node) != cache.end()) + return; + + if (isa(node)) { + double constVal = node->getLowerBound(); + CachedValue cv(53); + mpfr_set_d(cv.value, constVal, MPFR_RNDN); + cache.emplace(node, CachedValue(std::move(cv))); + return; + } + + if (isa(node) && inputValues.count(cast(node)->value)) { + double inputValue = inputValues.lookup(cast(node)->value); + CachedValue cv(53); + mpfr_set_d(cv.value, inputValue, MPFR_RNDN); + cache.emplace(node, std::move(cv)); + return; + } + + if (node->op == "if") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &cond = getResult(node->operands[0].get()); + if (0 == mpfr_cmp_ui(cond, 1)) { + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &then_val = getResult(node->operands[1].get()); + cache.emplace(node, CachedValue(cache.at(node->operands[1].get()).prec)); + mpfr_set(cache.at(node).value, then_val, MPFR_RNDN); + } else { + evaluateNode(node->operands[2].get(), inputValues, groundTruth); + mpfr_t &else_val = getResult(node->operands[2].get()); + cache.emplace(node, CachedValue(cache.at(node->operands[2].get()).prec)); + mpfr_set(cache.at(node).value, else_val, MPFR_RNDN); + } + return; + } + + unsigned nodePrec = getNodePrecision(node, groundTruth); + cache.emplace(node, CachedValue(nodePrec)); + mpfr_t &res = cache.at(node).value; + + if (node->op == "neg") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_neg(res, op, MPFR_RNDN); + } else if (node->op == "+") { + 
evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_add(res, op0, op1, MPFR_RNDN); + } else if (node->op == "-") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_sub(res, op0, op1, MPFR_RNDN); + } else if (node->op == "*") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_mul(res, op0, op1, MPFR_RNDN); + } else if (node->op == "/") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_div(res, op0, op1, MPFR_RNDN); + } else if (node->op == "sin") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_sin(res, op, MPFR_RNDN); + } else if (node->op == "cos") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_cos(res, 
op, MPFR_RNDN); + } else if (node->op == "tan") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_tan(res, op, MPFR_RNDN); + } else if (node->op == "asin") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_asin(res, op, MPFR_RNDN); + } else if (node->op == "acos") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_acos(res, op, MPFR_RNDN); + } else if (node->op == "atan") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_atan(res, op, MPFR_RNDN); + } else if (node->op == "atan2") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_atan2(res, op0, op1, MPFR_RNDN); + } else if (node->op == "exp") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_exp(res, op, MPFR_RNDN); + } else if (node->op == "expm1") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_expm1(res, op, MPFR_RNDN); + } else if (node->op == "log") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + 
mpfr_log(res, op, MPFR_RNDN); + } else if (node->op == "log1p") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_log1p(res, op, MPFR_RNDN); + } else if (node->op == "sqrt") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_sqrt(res, op, MPFR_RNDN); + } else if (node->op == "cbrt") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_cbrt(res, op, MPFR_RNDN); + } else if (node->op == "pow") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_pow(res, op0, op1, MPFR_RNDN); + } else if (node->op == "fma") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + evaluateNode(node->operands[2].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_t &op2 = getResult(node->operands[2].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_prec_round(op2, nodePrec, MPFR_RNDN); + mpfr_fma(res, op0, op1, op2, MPFR_RNDN); + } else if (node->op == "fabs") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_abs(res, op, MPFR_RNDN); + } else if (node->op == "hypot") { + evaluateNode(node->operands[0].get(), 
inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_hypot(res, op0, op1, MPFR_RNDN); + } else if (node->op == "asinh") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_asinh(res, op, MPFR_RNDN); + } else if (node->op == "acosh") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_acosh(res, op, MPFR_RNDN); + } else if (node->op == "atanh") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_atanh(res, op, MPFR_RNDN); + } else if (node->op == "sinh") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_sinh(res, op, MPFR_RNDN); + } else if (node->op == "cosh") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_cosh(res, op, MPFR_RNDN); + } else if (node->op == "tanh") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_tanh(res, op, MPFR_RNDN); + } else if (node->op == "ceil") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_ceil(res, op); + } else if (node->op == "floor") { + 
evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_floor(res, op); + } else if (node->op == "erf") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_erf(res, op, MPFR_RNDN); + } else if (node->op == "exp2") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_exp2(res, op, MPFR_RNDN); + } else if (node->op == "log10") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_log10(res, op, MPFR_RNDN); + } else if (node->op == "log2") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_log2(res, op, MPFR_RNDN); + } else if (node->op == "rint") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_rint(res, op, MPFR_RNDN); + } else if (node->op == "round") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_round(res, op); + } else if (node->op == "trunc") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_prec_round(op, nodePrec, MPFR_RNDN); + mpfr_trunc(res, op); + } else if (node->op == "copysign") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); 
+ mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_copysign(res, op0, op1, MPFR_RNDN); + } else if (node->op == "fdim") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_dim(res, op0, op1, MPFR_RNDN); + } else if (node->op == "fmod") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_fmod(res, op0, op1, MPFR_RNDN); + } else if (node->op == "remainder") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_remainder(res, op0, op1, MPFR_RNDN); + } else if (node->op == "fmax") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_max(res, op0, op1, MPFR_RNDN); + } else if (node->op == "fmin") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = 
getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + mpfr_prec_round(op0, nodePrec, MPFR_RNDN); + mpfr_prec_round(op1, nodePrec, MPFR_RNDN); + mpfr_min(res, op0, op1, MPFR_RNDN); + } else if (node->op == "==") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 == mpfr_cmp(op0, op1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == "!=") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 != mpfr_cmp(op0, op1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == "<") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 > mpfr_cmp(op0, op1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == ">") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 < mpfr_cmp(op0, op1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == "<=") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 
>= mpfr_cmp(op0, op1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == ">=") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 <= mpfr_cmp(op0, op1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == "and") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 == mpfr_cmp_ui(op0, 1) && 0 == mpfr_cmp_ui(op1, 1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == "or") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + evaluateNode(node->operands[1].get(), inputValues, groundTruth); + mpfr_t &op0 = getResult(node->operands[0].get()); + mpfr_t &op1 = getResult(node->operands[1].get()); + if (0 == mpfr_cmp_ui(op0, 1) || 0 == mpfr_cmp_ui(op1, 1)) + mpfr_set_ui(res, 1, MPFR_RNDN); + else + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == "not") { + evaluateNode(node->operands[0].get(), inputValues, groundTruth); + mpfr_t &op = getResult(node->operands[0].get()); + mpfr_set_prec(res, nodePrec); + if (0 == mpfr_cmp_ui(op, 1)) + mpfr_set_ui(res, 0, MPFR_RNDN); + else + mpfr_set_ui(res, 1, MPFR_RNDN); + } else if (node->op == "TRUE") { + mpfr_set_ui(res, 1, MPFR_RNDN); + } else if (node->op == "FALSE") { + mpfr_set_ui(res, 0, MPFR_RNDN); + } else if (node->op == "PI") { + mpfr_const_pi(res, MPFR_RNDN); + } else if (node->op == "E") { + mpfr_const_euler(res, MPFR_RNDN); + } else if (node->op == "INFINITY") { + mpfr_set_inf(res, 1); + } else if (node->op == "NAN") { + mpfr_set_nan(res); + } else { + llvm::errs() << 
"MPFREvaluator: Unexpected operator '" << node->op << "'\n"; + llvm_unreachable("Unexpected operator encountered"); + } +} + +mpfr_t &MPFREvaluator::getResult(FPNode *node) { + assert(cache.count(node) > 0 && "MPFREvaluator: Unexpected unevaluated node"); + return cache.at(node).value; +} + +// Emulate computation using native floating-point types +void getFPValues(ArrayRef outputs, + const MapVector &inputValues, + SmallVectorImpl &results, PTCandidate *pt) { + assert(!outputs.empty()); + results.resize(outputs.size()); + + FPEvaluator evaluator(pt); + + for (const auto *output : outputs) { + evaluator.evaluateNode(output, inputValues); + } + + for (size_t i = 0; i < outputs.size(); ++i) { + results[i] = evaluator.getResult(outputs[i]); + } +} + +// If looking for ground truth, compute a "correct" answer with MPFR. +// For each sampled input configuration: +// 0. Ignore `FPNode.dtype`. +// 1. Compute the expression with MPFR at `prec` precision +// by calling `MPFRValueHelper`. When operand is a FPConst, use its +// lower bound. When operand is a FPLLValue, get its inputs from +// `inputs`. +// 2. Dynamically extend precisions +// until the first `groundTruthPrec` bits of significand don't change. +// Otherwise, compute the expression with MPFR at precisions specified within +// `FPNode`s or new precisions specified by `pt`. 
+void getMPFRValues(ArrayRef outputs, + const MapVector &inputValues, + SmallVectorImpl &results, bool groundTruth, + const unsigned groundTruthPrec, PTCandidate *pt) { + assert(!outputs.empty()); + results.resize(outputs.size()); + + if (!groundTruth) { + MPFREvaluator evaluator(0, pt); + // if (pt) { + // llvm::errs() << "getMPFRValues: PT candidate detected: " << pt->desc + // << "\n"; + // } else { + // llvm::errs() << "getMPFRValues: emulating original computation\n"; + // } + + for (const auto *output : outputs) { + evaluator.evaluateNode(output, inputValues, false); + } + for (size_t i = 0; i < outputs.size(); ++i) { + results[i] = mpfr_get_d(evaluator.getResult(outputs[i]), MPFR_RNDN); + } + return; + } + + unsigned curPrec = 64; + std::vector prevResExp(outputs.size(), 0); + std::vector prevResStr(outputs.size(), nullptr); + std::vector prevResSign(outputs.size(), 0); + std::vector converged(outputs.size(), false); + size_t numConverged = 0; + + while (true) { + MPFREvaluator evaluator(curPrec, nullptr); + + // llvm::errs() << "getMPFRValues: computing ground truth with precision " + // << curPrec << "\n"; + + for (const auto *output : outputs) { + evaluator.evaluateNode(output, inputValues, true); + } + + for (size_t i = 0; i < outputs.size(); ++i) { + if (converged[i]) + continue; + + mpfr_t &res = evaluator.getResult(outputs[i]); + int resSign = mpfr_sgn(res); + mpfr_exp_t resExp; + char *resStr = + mpfr_get_str(nullptr, &resExp, 2, groundTruthPrec, res, MPFR_RNDN); + + if (prevResStr[i] != nullptr && resSign == prevResSign[i] && + resExp == prevResExp[i] && strcmp(resStr, prevResStr[i]) == 0) { + converged[i] = true; + numConverged++; + mpfr_free_str(resStr); + mpfr_free_str(prevResStr[i]); + prevResStr[i] = nullptr; + continue; + } + + if (prevResStr[i]) { + mpfr_free_str(prevResStr[i]); + } + prevResStr[i] = resStr; + prevResExp[i] = resExp; + prevResSign[i] = resSign; + } + + if (numConverged == outputs.size()) { + for (size_t i = 0; i < 
outputs.size(); ++i) { + results[i] = mpfr_get_d(evaluator.getResult(outputs[i]), MPFR_RNDN); + } + break; + } + + curPrec *= 2; + + if (curPrec > FPOptMaxMPFRPrec) { + llvm::errs() << "getMPFRValues: MPFR precision limit reached for some " + "outputs, returning NaN\n"; + for (size_t i = 0; i < outputs.size(); ++i) { + if (!converged[i]) { + mpfr_free_str(prevResStr[i]); + results[i] = std::numeric_limits::quiet_NaN(); + } else { + results[i] = mpfr_get_d(evaluator.getResult(outputs[i]), MPFR_RNDN); + } + } + return; + } + } +} \ No newline at end of file diff --git a/enzyme/Enzyme/Poseidon/PoseidonEvaluators.h b/enzyme/Enzyme/Poseidon/PoseidonEvaluators.h new file mode 100644 index 000000000000..99258b76dc08 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonEvaluators.h @@ -0,0 +1,86 @@ +//=- PoseidonEvaluators.h - Expression evaluators for Poseidon ------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the evaluator classes for floating-point expressions +// in the Poseidon optimization pass. 
+// +//===----------------------------------------------------------------------===// + +#ifndef ENZYME_POSEIDON_EVALUATORS_H +#define ENZYME_POSEIDON_EVALUATORS_H + +#include "llvm/ADT/MapVector.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/CommandLine.h" +#include +#include + +#include "PoseidonPrecUtils.h" +#include "PoseidonTypes.h" + +using namespace llvm; + +extern "C" { +extern llvm::cl::opt FPOptStrictMode; +extern llvm::cl::opt FPOptGeoMeanEps; +} + +class FPEvaluator { +private: + std::unordered_map cache; + std::unordered_map nodePrecisions; + +public: + FPEvaluator(PTCandidate *pt = nullptr); + + PrecisionChangeType getNodePrecision(const FPNode *node) const; + void evaluateNode(const FPNode *node, + const MapVector &inputValues); + double getResult(const FPNode *node) const; +}; + +class MPFREvaluator { +private: + struct CachedValue { + mpfr_t value; + unsigned prec; + + CachedValue(unsigned prec); + CachedValue(const CachedValue &) = delete; + CachedValue &operator=(const CachedValue &) = delete; + CachedValue(CachedValue &&other) noexcept; + CachedValue &operator=(CachedValue &&other) noexcept; + virtual ~CachedValue(); + }; + + std::unordered_map cache; + unsigned prec; + std::unordered_map nodeToNewPrec; + +public: + MPFREvaluator(unsigned prec, PTCandidate *pt = nullptr); + virtual ~MPFREvaluator() = default; + + unsigned getNodePrecision(const FPNode *node, bool groundTruth) const; + void evaluateNode(const FPNode *node, + const MapVector &inputValues, + bool groundTruth); + mpfr_t &getResult(FPNode *node); +}; + +void getFPValues(ArrayRef outputs, + const MapVector &inputValues, + SmallVectorImpl &results, PTCandidate *pt = nullptr); + +void getMPFRValues(ArrayRef outputs, + const MapVector &inputValues, + SmallVectorImpl &results, bool groundTruth, + const unsigned groundTruthPrec = 53, + PTCandidate *pt = nullptr); + +#endif // ENZYME_POSEIDON_EVALUATORS_H \ No newline at end of file diff --git 
a/enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.cpp b/enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.cpp new file mode 100644 index 000000000000..a2516be3917d --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.cpp @@ -0,0 +1,895 @@ +//=- PoseidonHerbieUtils.cpp - Herbie integration utilities ---------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utilities for integrating with the Herbie tool for +// floating-point expression optimization. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/InstructionCost.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar/EarlyCSE.h" +#include "llvm/Transforms/Utils/Mem2Reg.h" + +#include "../Utils.h" +#include "Poseidon.h" +#include "PoseidonEvaluators.h" +#include "PoseidonHerbieUtils.h" +#include "PoseidonSolvers.h" +#include "PoseidonTypes.h" +#include "PoseidonUtils.h" + +#include +#include +#include +#include +#include + +using namespace llvm; + +extern "C" { +cl::opt HerbieNumThreads("herbie-num-threads", cl::init(8), cl::Hidden, + cl::desc("Number of threads Herbie uses")); +cl::opt HerbieTimeout("herbie-timeout", cl::init(9999), cl::Hidden, + cl::desc("Herbie's timeout to use for each " + "candidate expressions.")); +cl::opt + 
HerbieNumPoints("herbie-num-pts", cl::init(1024), cl::Hidden, + cl::desc("Number of input points Herbie uses to evaluate " + "candidate expressions.")); +cl::opt HerbieNumIters( + "herbie-num-iters", cl::init(6), cl::Hidden, + cl::desc("Number of times Herbie attempts to improve accuracy.")); +cl::opt HerbieNumEnodes( + "herbie-num-enodes", cl::init(8000), cl::Hidden, + cl::desc("Number of equivalence graph nodes to use when doing algebraic " + "reasoning in Herbie.")); +cl::opt HerbieDisableNumerics( + "herbie-disable-numerics", cl::init(false), cl::Hidden, + cl::desc("Disable Herbie rewrite rules that produce numerical shorthands " + "expm1, log1p, fma, and hypot")); +cl::opt HerbieDisableArithmetic( + "herbie-disable-arithmetic", cl::init(false), cl::Hidden, + cl::desc("Disable Herbie rewrite rules on basic arithmetic fasts.")); +cl::opt HerbieDisableFractions( + "herbie-disable-fractions", cl::init(false), cl::Hidden, + cl::desc("Disable Herbie rewrite rules on fraction arithmetic.")); +cl::opt + HerbieDisableTaylor("herbie-disable-taylor", cl::init(false), cl::Hidden, + cl::desc("Disable Herbie's series expansion")); +cl::opt HerbieDisableSetupSimplify( + "herbie-disable-setup-simplify", cl::init(false), cl::Hidden, + cl::desc("Stop Herbie from pre-simplifying expressions")); +cl::opt HerbieDisableGenSimplify( + "herbie-disable-gen-simplify", cl::init(false), cl::Hidden, + cl::desc("Stop Herbie from simplifying expressions " + "during the main improvement loop")); +cl::opt HerbieDisableRegime( + "herbie-disable-regime", cl::init(false), cl::Hidden, + cl::desc("Stop Herbie from branching between expressions candidates")); +cl::opt HerbieDisableBranchExpr( + "herbie-disable-branch-expr", cl::init(false), cl::Hidden, + cl::desc("Stop Herbie from branching on expressions")); +cl::opt HerbieDisableAvgError( + "herbie-disable-avg-error", cl::init(false), cl::Hidden, + cl::desc("Make Herbie choose the candidates with the least maximum error")); +} + +std::shared_ptr 
parseHerbieExpr( + const std::string &expr, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + // if (FPOptPrint) + // llvm::errs() << "Parsing: " << expr << "\n"; + std::string trimmedExpr = expr; + trimmedExpr.erase(0, trimmedExpr.find_first_not_of(" ")); + trimmedExpr.erase(trimmedExpr.find_last_not_of(" ") + 1); + + // Arguments + if (trimmedExpr.front() != '(' && trimmedExpr.front() != '#') { + if (auto node = valueToNodeMap[symbolToValueMap[trimmedExpr]]) { + return node; + } + } + + // Constants + static const std::regex constantPattern( + "^#s\\(literal\\s+([-+]?\\d+(/\\d+)?|[-+]?inf\\.0)\\s+(\\w+)\\)$"); + static const std::regex plainConstantPattern( + R"(^([-+]?(\d+(\.\d+)?)(/\d+)?|[-+]?inf\.0))"); + + { + std::smatch matches; + if (std::regex_match(trimmedExpr, matches, constantPattern)) { + std::string value = matches[1].str(); + std::string dtype = matches[3].str(); + if (dtype == "binary64") { + dtype = "f64"; + } else if (dtype == "binary32") { + dtype = "f32"; + } else { + std::string msg = + "Herbie expr parser: Unexpected constant dtype: " + dtype; + llvm_unreachable(msg.c_str()); + } + // if (FPOptPrint) + // llvm::errs() << "Herbie expr parser: Found __const " << value + // << " with dtype " << dtype << "\n"; + return std::make_shared(value, dtype); + } else if (std::regex_match(trimmedExpr, matches, plainConstantPattern)) { + std::string value = matches[1].str(); + std::string dtype = "f64"; // Assume f64 by default + return std::make_shared(value, dtype); + } + } + + if (trimmedExpr.substr(0, 9) == "#s(approx") { + if (trimmedExpr.back() != ')') { + llvm_unreachable(("Malformed approx expression: " + trimmedExpr).c_str()); + } + std::string inner = trimmedExpr.substr(9, trimmedExpr.size() - 9 - 1); + inner.erase(0, inner.find_first_not_of(" ")); + inner.erase(inner.find_last_not_of(" ") + 1); + + int depth = 0; + size_t splitPos = std::string::npos; + for (size_t i = 0; i < inner.size(); ++i) { + if (inner[i] 
== '(') + depth++; + else if (inner[i] == ')') + depth--; + else if (inner[i] == ' ' && depth == 0) { + splitPos = i; + break; + } + } + if (splitPos == std::string::npos) { + llvm_unreachable(("Malformed approx expression: " + trimmedExpr).c_str()); + } + std::string resultPart = inner.substr(splitPos + 1); + resultPart.erase(0, resultPart.find_first_not_of(" ")); + resultPart.erase(resultPart.find_last_not_of(" ") + 1); + return parseHerbieExpr(resultPart, valueToNodeMap, symbolToValueMap); + } + + if (trimmedExpr.front() != '(' || trimmedExpr.back() != ')') { + llvm::errs() << "Unexpected subexpression: " << trimmedExpr << "\n"; + assert(0 && "Failed to parse Herbie expression"); + } + + trimmedExpr = trimmedExpr.substr(1, trimmedExpr.size() - 2); + + auto endOp = trimmedExpr.find(' '); + std::string fullOp = trimmedExpr.substr(0, endOp); + + size_t pos = fullOp.find('.'); + std::string dtype; + std::string op; + if (pos != std::string::npos) { + op = fullOp.substr(0, pos); + dtype = fullOp.substr(pos + 1); + assert(dtype == "f64" || dtype == "f32"); + // llvm::errs() << "Herbie expr parser: Found operator " << op + // << " with dtype " << dtype << "\n"; + } else { + op = fullOp; + // llvm::errs() << "Herbie expr parser: Found operator " << op << "\n"; + } + + auto node = std::make_shared(op, dtype); + + int depth = 0; + auto start = trimmedExpr.find_first_not_of(" ", endOp); + std::string::size_type curr; + for (curr = start; curr < trimmedExpr.size(); ++curr) { + if (trimmedExpr[curr] == '(') + depth++; + if (trimmedExpr[curr] == ')') + depth--; + if (depth == 0 && trimmedExpr[curr] == ' ') { + node->addOperand(parseHerbieExpr(trimmedExpr.substr(start, curr - start), + valueToNodeMap, symbolToValueMap)); + start = curr + 1; + } + } + if (start < curr) { + node->addOperand(parseHerbieExpr(trimmedExpr.substr(start, curr - start), + valueToNodeMap, symbolToValueMap)); + } + + return node; +} + +bool improveViaHerbie( + const std::vector &inputExprs, + std::vector 
&COs, Module *M, + const TargetTransformInfo &TTI, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + int subgraphIdx) { + std::string Program = HERBIE_BINARY; + llvm::errs() << "random seed: " << std::to_string(FPOptRandomSeed) << "\n"; + + SmallVector BaseArgs = { + Program, "report", + "--seed", std::to_string(FPOptRandomSeed), + "--timeout", std::to_string(HerbieTimeout), + "--threads", std::to_string(HerbieNumThreads), + "--num-points", std::to_string(HerbieNumPoints), + "--num-iters", std::to_string(HerbieNumIters), + "--num-enodes", std::to_string(HerbieNumEnodes)}; + + BaseArgs.push_back("--disable"); + BaseArgs.push_back("generate:proofs"); + + if (HerbieDisableNumerics) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("rules:numerics"); + } + + if (HerbieDisableArithmetic) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("rules:arithmetic"); + } + + if (HerbieDisableFractions) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("rules:fractions"); + } + + if (HerbieDisableSetupSimplify) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("setup:simplify"); + } + + if (HerbieDisableGenSimplify) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("generate:simplify"); + } + + if (HerbieDisableTaylor) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("generate:taylor"); + } + + if (HerbieDisableRegime) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("reduce:regimes"); + } + + if (HerbieDisableBranchExpr) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("reduce:branch-expressions"); + } + + if (HerbieDisableAvgError) { + BaseArgs.push_back("--disable"); + BaseArgs.push_back("reduce:avg-error"); + } + + SmallVector> BaseArgsList; + BaseArgsList.push_back(BaseArgs); + + std::vector> seenExprs(COs.size()); + + bool success = false; + + auto processHerbieOutput = [&](const std::string &content, + bool skipEvaluation = false) -> bool { + Expected parsed = 
json::parse(content); + if (!parsed) { + llvm::errs() << "Failed to parse Herbie result!\n"; + return false; + } + + json::Object *obj = parsed->getAsObject(); + json::Array &tests = *obj->getArray("tests"); + + for (size_t testIndex = 0; testIndex < tests.size(); ++testIndex) { + auto &test = *tests[testIndex].getAsObject(); + + StringRef bestExpr = test.getString("output").value(); + if (bestExpr == "#f") { + continue; + } + + StringRef ID = test.getString("name").value(); + size_t index = std::stoul(ID.str()); + if (index >= COs.size()) { + llvm::errs() << "Invalid CO index: " << index << "\n"; + continue; + } + + CandidateOutput &CO = COs[index]; + auto &seenExprSet = seenExprs[index]; + + double bits = test.getNumber("bits").value(); + json::Array &costAccuracy = *test.getArray("cost-accuracy"); + + json::Array &initial = *costAccuracy[0].getAsArray(); + double initialCostVal = initial[0].getAsNumber().value(); + double initialAccuracy = 1.0 - initial[1].getAsNumber().value() / bits; + double initialCost = 1.0; + + CO.initialHerbieCost = initialCost; + CO.initialHerbieAccuracy = initialAccuracy; + + if (seenExprSet.count(bestExpr.str()) == 0) { + seenExprSet.insert(bestExpr.str()); + + json::Array &best = *costAccuracy[1].getAsArray(); + double bestCost = best[0].getAsNumber().value() / initialCostVal; + double bestAccuracy = 1.0 - best[1].getAsNumber().value() / bits; + + RewriteCandidate bestCandidate(bestCost, bestAccuracy, bestExpr.str()); + if (!skipEvaluation) { + bestCandidate.CompCost = getCompCost( + bestExpr.str(), M, TTI, valueToNodeMap, symbolToValueMap, + cast(CO.oldOutput)->getFastMathFlags()); + } + CO.candidates.push_back(bestCandidate); + } + + json::Array &alternatives = *costAccuracy[2].getAsArray(); + + // Handle alternatives + for (size_t j = 0; j < alternatives.size(); ++j) { + json::Array &entry = *alternatives[j].getAsArray(); + StringRef expr = entry[2].getAsString().value(); + + if (seenExprSet.count(expr.str()) != 0) { + continue; + 
} + seenExprSet.insert(expr.str()); + + double cost = entry[0].getAsNumber().value() / initialCostVal; + double accuracy = 1.0 - entry[1].getAsNumber().value() / bits; + + RewriteCandidate candidate(cost, accuracy, expr.str()); + if (!skipEvaluation) { + candidate.CompCost = + getCompCost(expr.str(), M, TTI, valueToNodeMap, symbolToValueMap, + cast(CO.oldOutput)->getFastMathFlags()); + } + CO.candidates.push_back(candidate); + } + + if (!skipEvaluation) { + setUnifiedAccuracyCost(CO, valueToNodeMap, symbolToValueMap); + } + } + return true; + }; + + for (size_t baseArgsIndex = 0; baseArgsIndex < BaseArgsList.size(); + ++baseArgsIndex) { + const auto &BaseArgs = BaseArgsList[baseArgsIndex]; + std::string content; + + // Try to get cached Herbie output first + std::string cacheFilePath; + bool cached = false; + + if (!FPOptCachePath.empty()) { + cacheFilePath = FPOptCachePath + "/cachedHerbieOutput_" + + std::to_string(subgraphIdx) + "_" + + std::to_string(baseArgsIndex) + ".txt"; + std::ifstream cacheFile(cacheFilePath); + if (cacheFile) { + content.assign((std::istreambuf_iterator(cacheFile)), + std::istreambuf_iterator()); + cacheFile.close(); + llvm::errs() << "Using cached Herbie output from " << cacheFilePath + << "\n"; + cached = true; + } + } + + // If we have cached output, process it directly + if (cached) { + llvm::errs() << "Herbie output: " << content << "\n"; + std::string dpCacheFilePath = FPOptCachePath + "/table.json"; + bool skipEvaluation = FPOptSolverType == "dp" && + !FPOptCachePath.empty() && + llvm::sys::fs::exists(dpCacheFilePath); + if (processHerbieOutput(content, skipEvaluation)) { + success = true; + } + continue; + } + + // No cached result, need to run Herbie + SmallString<32> tmpin, tmpout; + + if (llvm::sys::fs::createUniqueFile("herbie_input_%%%%%%%%%%%%%%%%", tmpin, + llvm::sys::fs::perms::owner_all)) { + llvm::errs() << "Failed to create a unique input file.\n"; + continue; + } + + if 
(llvm::sys::fs::createUniqueDirectory("herbie_output_%%%%%%%%%%%%%%%%", + tmpout)) { + llvm::errs() << "Failed to create a unique output directory.\n"; + if (auto EC = llvm::sys::fs::remove(tmpin)) + llvm::errs() << "Warning: Failed to remove temporary input file: " + << EC.message() << "\n"; + continue; + } + + std::ofstream input(tmpin.c_str()); + if (!input) { + llvm::errs() << "Failed to open input file.\n"; + if (auto EC = llvm::sys::fs::remove(tmpin)) + llvm::errs() << "Warning: Failed to remove temporary input file: " + << EC.message() << "\n"; + if (auto EC = llvm::sys::fs::remove_directories(tmpout)) + llvm::errs() << "Warning: Failed to remove temporary output directory: " + << EC.message() << "\n"; + continue; + } + for (const auto &expr : inputExprs) { + input << expr << "\n"; + } + input.close(); + + SmallVector Args; + Args.reserve(BaseArgs.size()); + for (const auto &arg : BaseArgs) { + Args.emplace_back(arg); + } + + Args.push_back(tmpin); + Args.push_back(tmpout); + + std::string ErrMsg; + bool ExecutionFailed = false; + + if (FPOptPrint) { + llvm::errs() << "Executing Herbie with arguments: "; + for (const auto &arg : Args) { + llvm::errs() << arg << " "; + } + llvm::errs() << "\n"; + } + + llvm::sys::ExecuteAndWait(Program, Args, /*Env=*/{}, + /*Redirects=*/{}, + /*SecondsToWait=*/0, /*MemoryLimit=*/0, &ErrMsg, + &ExecutionFailed); + + std::remove(tmpin.c_str()); + if (ExecutionFailed) { + llvm::errs() << "Execution failed: " << ErrMsg << "\n"; + if (auto EC = llvm::sys::fs::remove_directories(tmpout)) + llvm::errs() << "Warning: Failed to remove temporary output directory: " + << EC.message() << "\n"; + continue; + } + + std::ifstream output((tmpout + "/results.json").str()); + if (!output) { + llvm::errs() << "Failed to open output file.\n"; + if (auto EC = llvm::sys::fs::remove_directories(tmpout)) + llvm::errs() << "Warning: Failed to remove temporary output directory: " + << EC.message() << "\n"; + continue; + } + 
content.assign((std::istreambuf_iterator(output)), + std::istreambuf_iterator()); + output.close(); + if (auto EC = llvm::sys::fs::remove_directories(tmpout)) + llvm::errs() << "Warning: Failed to remove temporary output directory: " + << EC.message() << "\n"; + + llvm::errs() << "Herbie output: " << content << "\n"; + + // Save output to cache if needed + if (!FPOptCachePath.empty()) { + if (auto EC = llvm::sys::fs::create_directories(FPOptCachePath, true)) + llvm::errs() << "Warning: Could not create cache directory: " + << EC.message() << "\n"; + std::ofstream cacheFile(cacheFilePath); + if (!cacheFile) { + llvm_unreachable("Failed to open cache file for writing"); + } else { + cacheFile << content; + cacheFile.close(); + llvm::errs() << "Saved Herbie output to cache file " << cacheFilePath + << "\n"; + } + } + + // Process the output + if (processHerbieOutput(content, false)) { + success = true; + } + } + + return success; +} + +std::string getHerbieOperator(const Instruction &I) { + switch (I.getOpcode()) { + case Instruction::FNeg: + return "neg"; + case Instruction::FAdd: + return "+"; + case Instruction::FSub: + return "-"; + case Instruction::FMul: + return "*"; + case Instruction::FDiv: + return "/"; + case Instruction::Call: { + const CallInst *CI = dyn_cast(&I); + assert(CI && CI->getCalledFunction() && + "getHerbieOperator: Call without a function"); + + StringRef funcName = CI->getCalledFunction()->getName(); + + // LLVM intrinsics + if (startsWith(funcName, "llvm.")) { + std::regex regex("llvm\\.(\\w+)\\.?.*"); + std::smatch matches; + std::string nameStr = funcName.str(); + if (std::regex_search(nameStr, matches, regex) && matches.size() > 1) { + std::string intrinsic = matches[1]; + // Special case mappings + if (intrinsic == "fmuladd") + return "fma"; + if (intrinsic == "maxnum") + return "fmax"; + if (intrinsic == "minnum") + return "fmin"; + if (intrinsic == "powi") + return "pow"; + return intrinsic; + } + assert(0 && "getHerbieOperator: 
Unknown LLVM intrinsic"); + } + // libm functions + else { + std::string name = funcName.str(); + if (!name.empty() && name.back() == 'f') { + name.pop_back(); + } + return name; + } + } + default: + assert(0 && "getHerbieOperator: Unknown operator"); + } +} + +std::string getPrecondition( + const SmallSet &args, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap) { + std::string preconditions; + + for (const auto &arg : args) { + const auto node = valueToNodeMap.at(symbolToValueMap.at(arg)); + double lower = node->getLowerBound(); + double upper = node->getUpperBound(); + + if (upper - lower < 1e-10 && !std::isinf(lower) && !std::isinf(upper)) { + double midpoint = (lower + upper) / 2.0; + double tolerance = std::max(1e-10, std::abs(midpoint) * 1e-6); + lower = midpoint - tolerance; + upper = midpoint + tolerance; + } + + std::ostringstream lowerStr, upperStr; + lowerStr << std::setprecision(std::numeric_limits::max_digits10) + << std::scientific << lower; + upperStr << std::setprecision(std::numeric_limits::max_digits10) + << std::scientific << upper; + + preconditions += + " (<=" + + (std::isinf(lower) ? (lower > 0 ? " INFINITY" : " (- INFINITY)") + : (" " + lowerStr.str())) + + " " + arg + + (std::isinf(upper) ? (upper > 0 ? " INFINITY" : " (- INFINITY)") + : (" " + upperStr.str())) + + ")"; + } + + return preconditions.empty() ? 
"TRUE" : "(and" + preconditions + ")"; +} + +void setUnifiedAccuracyCost( + CandidateOutput &CO, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + + SmallVector, 4> sampledPoints; + getSampledPoints(CO.subgraph->inputs.getArrayRef(), valueToNodeMap, + symbolToValueMap, sampledPoints); + + SmallVector goldVals; + goldVals.resize(FPOptNumSamples); + + double origCost = 0.0; + if (FPOptReductionEval == "geomean") { + double sumLog = 0.0; + unsigned count = 0; + for (const auto &pair : enumerate(sampledPoints)) { + std::shared_ptr node = valueToNodeMap[CO.oldOutput]; + SmallVector results; + getMPFRValues({node.get()}, pair.value(), results, true, 53); + double goldVal = results[0]; + goldVals[pair.index()] = goldVal; + + getFPValues({node.get()}, pair.value(), results); + double realVal = results[0]; + double error = std::fabs(goldVal - realVal); + if (!std::isnan(error)) { + if (error == 0.0) { + if (FPOptGeoMeanEps == 0.0) + error = getOneULP(goldVal); + else + error += FPOptGeoMeanEps; + } else if (FPOptGeoMeanEps != 0.0) { + error += FPOptGeoMeanEps; + } + sumLog += std::log(error); + ++count; + } + } + assert(count != 0 && "No valid sample found for original expr"); + origCost = std::exp(sumLog / count); + } else if (FPOptReductionEval == "arithmean") { + double sum = 0.0; + unsigned count = 0; + for (const auto &pair : enumerate(sampledPoints)) { + std::shared_ptr node = valueToNodeMap[CO.oldOutput]; + SmallVector results; + getMPFRValues({node.get()}, pair.value(), results, true, 53); + double goldVal = results[0]; + goldVals[pair.index()] = goldVal; + + getFPValues({node.get()}, pair.value(), results); + double realVal = results[0]; + double error = std::fabs(goldVal - realVal); + if (!std::isnan(error)) { + sum += error; + ++count; + } + } + assert(count != 0 && "No valid sample found for original expr"); + origCost = sum / count; + } else if (FPOptReductionEval == "maxabs") { + double maxErr = 0.0; + for (const auto &pair : 
enumerate(sampledPoints)) { + std::shared_ptr node = valueToNodeMap[CO.oldOutput]; + SmallVector results; + getMPFRValues({node.get()}, pair.value(), results, true, 53); + double goldVal = results[0]; + goldVals[pair.index()] = goldVal; + + getFPValues({node.get()}, pair.value(), results); + double realVal = results[0]; + double error = std::fabs(goldVal - realVal); + if (!std::isnan(error)) + maxErr = std::max(maxErr, error); + } + origCost = maxErr; + } else { + llvm_unreachable("Unknown fpopt-reduction strategy"); + } + CO.initialAccCost = origCost * std::fabs(CO.grad); + if (std::isnan(CO.initialAccCost)) { + llvm::errs() << "Warning: NaN in initialAccCost computation:\n"; + llvm::errs() << " origCost = " << origCost << "\n"; + llvm::errs() << " CO.grad = " << CO.grad << "\n"; + llvm::errs() << " fabs(CO.grad) = " << std::fabs(CO.grad) << "\n"; + } + + SmallVector newCandidates; + for (auto &candidate : CO.candidates) { + bool discardCandidate = false; + double candCost = 0.0; + + std::shared_ptr parsedNode = + parseHerbieExpr(candidate.expr, valueToNodeMap, symbolToValueMap); + + if (FPOptReductionEval == "geomean") { + double sumLog = 0.0; + unsigned count = 0; + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getFPValues({parsedNode.get()}, pair.value(), results); + double realVal = results[0]; + double goldVal = goldVals[pair.index()]; + + if (FPOptStrictMode && (!std::isnan(goldVal)) && std::isnan(realVal)) { + discardCandidate = true; + break; + } + + double error = std::fabs(goldVal - realVal); + if (!std::isnan(error)) { + if (error == 0.0) { + if (FPOptGeoMeanEps == 0.0) + error = getOneULP(goldVal); + else + error += FPOptGeoMeanEps; + } else if (FPOptGeoMeanEps != 0.0) { + error += FPOptGeoMeanEps; + } + sumLog += std::log(error); + ++count; + } + } + if (!discardCandidate) { + if (count == 0) { + discardCandidate = true; + } else { + candCost = std::exp(sumLog / count); + } + } + } else if (FPOptReductionEval == 
"arithmean") { + double sum = 0.0; + unsigned count = 0; + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getFPValues({parsedNode.get()}, pair.value(), results); + double realVal = results[0]; + double goldVal = goldVals[pair.index()]; + + if (FPOptStrictMode && !std::isnan(goldVal) && std::isnan(realVal)) { + discardCandidate = true; + break; + } + + double error = std::fabs(goldVal - realVal); + if (!std::isnan(error)) { + sum += error; + ++count; + } + } + if (!discardCandidate) { + if (count == 0) { + discardCandidate = true; + } else { + candCost = sum / count; + } + } + } else if (FPOptReductionEval == "maxabs") { + double maxErr = 0.0; + bool hasValid = false; + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getFPValues({parsedNode.get()}, pair.value(), results); + double realVal = results[0]; + double goldVal = goldVals[pair.index()]; + + if (FPOptStrictMode && !std::isnan(goldVal) && std::isnan(realVal)) { + discardCandidate = true; + break; + } + + double error = std::fabs(goldVal - realVal); + if (!std::isnan(error)) { + hasValid = true; + maxErr = std::max(maxErr, error); + } + } + if (!discardCandidate) { + if (!hasValid) { + discardCandidate = true; + } else { + candCost = maxErr; + } + } + } else { + llvm_unreachable("Unknown fpopt-reduction strategy"); + } + + if (!discardCandidate) { + candidate.accuracyCost = candCost * std::fabs(CO.grad); + assert(!std::isnan(candidate.accuracyCost)); + newCandidates.push_back(std::move(candidate)); + } + } + CO.candidates = std::move(newCandidates); +} + +InstructionCost getCompCost( + const std::string &expr, Module *M, const TargetTransformInfo &TTI, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + const FastMathFlags &FMF) { + // llvm::errs() << "Evaluating cost of " << expr << "\n"; + SmallSet argStrSet; + getUniqueArgs(expr, argStrSet); + + SetVector args; + SmallVector argTypes; + SmallVector argNames; + for (const 
auto &argStr : argStrSet) { + Value *argValue = symbolToValueMap[argStr]; + args.insert(argValue); + argTypes.push_back(argValue->getType()); + argNames.push_back(argStr); + } + + auto parsedNode = parseHerbieExpr(expr, valueToNodeMap, symbolToValueMap); + + Type *ReturnType = nullptr; + for (Type *ArgTy : argTypes) { + if (ArgTy->isFloatingPointTy()) { + ReturnType = ArgTy; + break; + } + if (ArgTy->isVectorTy()) { + if (auto *VT = dyn_cast(ArgTy)) { + if (VT->getElementType()->isFloatingPointTy()) { + ReturnType = ArgTy; + break; + } + } + } + } + if (!ReturnType) { + if (parsedNode->dtype == "f32") + ReturnType = Type::getFloatTy(M->getContext()); + else if (parsedNode->dtype == "f16") + ReturnType = Type::getHalfTy(M->getContext()); + else + ReturnType = Type::getDoubleTy(M->getContext()); + } + + FunctionType *FT = FunctionType::get(ReturnType, argTypes, false); + Function *tempFunction = + Function::Create(FT, Function::InternalLinkage, "tempFunc", M); + + ValueToValueMapTy VMap; + Function::arg_iterator AI = tempFunction->arg_begin(); + for (const auto &argStr : argNames) { + VMap[symbolToValueMap[argStr]] = &*AI; + ++AI; + } + + BasicBlock *entry = + BasicBlock::Create(M->getContext(), "entry", tempFunction); + + IRBuilder<> builder(entry); + + builder.setFastMathFlags(FMF); + Value *RetVal = parsedNode->getLLValue(builder, &VMap); + assert(RetVal && "Parsed node did not produce a value"); + assert((RetVal->getType() == ReturnType) && + "Return value type mismatch with temp function return type"); + builder.CreateRet(RetVal); + + // llvm::errs() << "Temp function before optimizations:\n"; + // tempFunction->print(llvm::errs()); + + runPoseidonFunctionSimplify(*tempFunction, OptimizationLevel::O3); + + // llvm::errs() << "Temp function after optimizations:\n"; + // tempFunction->print(llvm::errs()); + + InstructionCost cost = getCompCost(tempFunction, TTI); + + tempFunction->eraseFromParent(); + return cost; +} \ No newline at end of file diff --git 
a/enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.h b/enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.h new file mode 100644 index 000000000000..5a31244ecdfd --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonHerbieUtils.h @@ -0,0 +1,77 @@ +//=- PoseidonHerbieUtils.h - Herbie integration utilities -----------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares utilities for integrating with the Herbie tool for +// floating-point expression optimization. +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYME_POSEIDON_HERBIE_UTILS_H +#define ENZYME_POSEIDON_HERBIE_UTILS_H + +#include "llvm/ADT/SmallSet.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/CommandLine.h" + +#include +#include +#include + +#include "PoseidonTypes.h" + +using namespace llvm; + +extern "C" { +extern llvm::cl::opt HerbieNumThreads; +extern llvm::cl::opt HerbieTimeout; +extern llvm::cl::opt HerbieNumPoints; +extern llvm::cl::opt HerbieNumIters; +extern llvm::cl::opt HerbieDisableNumerics; +extern llvm::cl::opt HerbieDisableArithmetic; +extern llvm::cl::opt HerbieDisableFractions; +extern llvm::cl::opt HerbieDisableTaylor; +extern llvm::cl::opt HerbieDisableSetupSimplify; +extern llvm::cl::opt HerbieDisableGenSimplify; +extern llvm::cl::opt HerbieDisableRegime; +extern llvm::cl::opt HerbieDisableBranchExpr; +extern llvm::cl::opt HerbieDisableAvgError; +} + +std::shared_ptr parseHerbieExpr( + const std::string &expr, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap); + +bool improveViaHerbie( + const std::vector &inputExprs, + std::vector &COs, Module *M, + const TargetTransformInfo &TTI, + 
std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + int subgraphIdx); + +std::string getHerbieOperator(const Instruction &I); + +std::string getPrecondition( + const SmallSet &args, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap); + +void setUnifiedAccuracyCost( + CandidateOutput &CO, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap); + +InstructionCost getCompCost( + const std::string &expr, Module *M, const TargetTransformInfo &TTI, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + const FastMathFlags &FMF); + +#endif // ENZYME_POSEIDON_HERBIE_UTILS_H diff --git a/enzyme/Enzyme/Poseidon/PoseidonPrecUtils.cpp b/enzyme/Enzyme/Poseidon/PoseidonPrecUtils.cpp new file mode 100644 index 000000000000..4d3c96744c2d --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonPrecUtils.cpp @@ -0,0 +1,711 @@ +//=- PoseidonPrecUtils.cpp - Precision change utilities for Poseidon ------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utilities for handling precision changes in the Poseidon +// optimization pass. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Transforms/Scalar/EarlyCSE.h" +#include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Mem2Reg.h" + +#include +#include +#include + +#include "../Utils.h" +#include "Poseidon.h" +#include "PoseidonEvaluators.h" +#include "PoseidonPrecUtils.h" +#include "PoseidonTypes.h" +#include "PoseidonUtils.h" + +using namespace llvm; + +extern "C" { +cl::opt FPOptShowPTDetails( + "fpopt-show-pt-details", cl::init(false), cl::Hidden, + cl::desc("Print details of precision tuning candidates along with the DP " + "table (highly verbose for large applications)")); +cl::opt + FPOptMaxMPFRPrec("fpopt-max-mpfr-prec", cl::init(1024), cl::Hidden, + cl::desc("Max precision for MPFR gold value computation")); +} + +unsigned getMPFRPrec(PrecisionChangeType type) { + switch (type) { + case PrecisionChangeType::BF16: + return 8; + case PrecisionChangeType::FP16: + return 11; + case PrecisionChangeType::FP32: + return 24; + case PrecisionChangeType::FP64: + return 53; + case PrecisionChangeType::FP80: + return 64; + case PrecisionChangeType::FP128: + return 113; + default: + llvm_unreachable("Unsupported FP precision"); + } +} + +Type *getLLVMFPType(PrecisionChangeType type, LLVMContext &context) { + switch (type) { + case PrecisionChangeType::BF16: + return Type::getBFloatTy(context); + case PrecisionChangeType::FP16: + return Type::getHalfTy(context); + case PrecisionChangeType::FP32: + return Type::getFloatTy(context); + case PrecisionChangeType::FP64: + return 
Type::getDoubleTy(context); + case PrecisionChangeType::FP80: + return Type::getX86_FP80Ty(context); + case PrecisionChangeType::FP128: + return Type::getFP128Ty(context); + default: + llvm_unreachable("Unsupported FP precision"); + } +} + +PrecisionChangeType getPrecisionChangeType(Type *type) { + if (type->isBFloatTy()) { // bfloat16; IEEE half is matched by the next branch + return PrecisionChangeType::BF16; + } else if (type->isHalfTy()) { + return PrecisionChangeType::FP16; + } else if (type->isFloatTy()) { + return PrecisionChangeType::FP32; + } else if (type->isDoubleTy()) { + return PrecisionChangeType::FP64; + } else if (type->isX86_FP80Ty()) { + return PrecisionChangeType::FP80; + } else if (type->isFP128Ty()) { + return PrecisionChangeType::FP128; + } else { + llvm_unreachable("Unsupported FP precision"); + } +} + +StringRef getPrecisionChangeTypeString(PrecisionChangeType type) { + switch (type) { + case PrecisionChangeType::BF16: + return "BF16"; + case PrecisionChangeType::FP16: + return "FP16"; + case PrecisionChangeType::FP32: + return "FP32"; + case PrecisionChangeType::FP64: + return "FP64"; + case PrecisionChangeType::FP80: + return "FP80"; + case PrecisionChangeType::FP128: + return "FP128"; + default: + return "Unknown PT type"; + } +} + +void changePrecision(Instruction *I, PrecisionChange &change, + MapVector &oldToNew) { + if (!Poseidonable(*I)) { + llvm_unreachable("Trying to tune an instruction is not Poseidonable"); + } + + IRBuilder<> Builder(I); + Builder.setFastMathFlags(I->getFastMathFlags()); + Type *newType = getLLVMFPType(change.newType, I->getContext()); + Value *newI = nullptr; + + if (isa(I) || isa(I)) { + SmallVector newOps; + for (auto &operand : I->operands()) { + Value *newOp = nullptr; + if (oldToNew.count(operand)) { + newOp = oldToNew[operand]; + } else if (operand->getType()->isIntegerTy()) { + newOp = operand; + } else { + IRBuilder<> OpBuilder(I); + OpBuilder.setFastMathFlags(I->getFastMathFlags()); + if (isa(operand)) { + newOp = + OpBuilder.CreateFPCast(operand, newType, 
"fpopt.const.fpcast"); + } else if (isa(operand) || isa(operand)) { + newOp = OpBuilder.CreateFPCast(operand, newType, "fpopt.fpcast"); + } else { + llvm_unreachable("Unsupported operand type"); + } + } + newOps.push_back(newOp); + } + newI = Builder.CreateNAryOp(I->getOpcode(), newOps); + } else if (auto *CI = dyn_cast(I)) { + SmallVector newArgs; + for (auto &arg : CI->args()) { + Value *newArg = nullptr; + if (oldToNew.count(arg)) { + newArg = oldToNew[arg]; + } else if (arg->getType()->isIntegerTy()) { + newArg = arg; + } else { + IRBuilder<> ArgBuilder(I); + ArgBuilder.setFastMathFlags(I->getFastMathFlags()); + if (isa(arg)) { + newArg = ArgBuilder.CreateFPCast(arg, newType, "fpopt.const.fpcast"); + } else if (isa(arg) || isa(arg)) { + newArg = ArgBuilder.CreateFPCast(arg, newType, "fpopt.fpcast"); + } else { + llvm_unreachable("Unsupported argument type"); + } + } + newArgs.push_back(newArg); + } + auto *calledFunc = CI->getCalledFunction(); + if (calledFunc && calledFunc->isIntrinsic()) { + Intrinsic::ID intrinsicID = calledFunc->getIntrinsicID(); + if (intrinsicID != Intrinsic::not_intrinsic) { + // Special cases for intrinsics with mixed types + if (intrinsicID == Intrinsic::powi) { + // powi + SmallVector overloadedTypes; + overloadedTypes.push_back(newType); + overloadedTypes.push_back(CI->getArgOperand(1)->getType()); + Function *newFunc = getIntrinsicDeclaration( + CI->getModule(), intrinsicID, overloadedTypes); + newI = Builder.CreateCall(newFunc, newArgs); + } else { + Function *newFunc = + getIntrinsicDeclaration(CI->getModule(), intrinsicID, newType); + newI = Builder.CreateCall(newFunc, newArgs); + } + } else { + llvm::errs() << "PT: Unknown intrinsic: " << *CI << "\n"; + llvm_unreachable("changePrecision: Unknown intrinsic call to change"); + } + } else { // plain or indirect call; calledFunc is null when the call is indirect + StringRef funcName = calledFunc ? calledFunc->getName() : StringRef(); + std::string newFuncName = getLibmFunctionForPrecision(funcName, newType); + + if (!newFuncName.empty()) { + Module *M = CI->getModule(); + 
SmallVector newArgTypes(newArgs.size(), newType); + + FunctionCallee newFuncCallee = M->getOrInsertFunction( + newFuncName, FunctionType::get(newType, newArgTypes, false)); + + if (Function *newFunc = dyn_cast(newFuncCallee.getCallee())) { + newI = Builder.CreateCall(newFunc, newArgs); + } else { + llvm::errs() << "PT: Failed to get " + << getPrecisionChangeTypeString(change.newType) + << " libm function for: " << *CI << "\n"; + llvm_unreachable("changePrecision: Failed to get libm function"); + } + } else { + llvm::errs() << "PT: Unknown function call: " << *CI << "\n"; + llvm_unreachable("changePrecision: Unknown function call to change"); + } + } + + } else { + llvm::errs() << "Unexpectedly Poseidonable instruction: " << *I << "\n"; + llvm_unreachable("Unexpectedly Poseidonable instruction"); + } + + oldToNew[I] = newI; +} + +// If `VMap` is passed, map `llvm::Value`s in `subgraph` to their cloned +// values and change outputs in VMap to new casted outputs. +void PTCandidate::apply(Subgraph &subgraph, ValueToValueMapTy *VMap) { + SetVector operations; + ValueToValueMapTy clonedToOriginal; // Maps cloned outputs to old outputs + if (VMap) { + for (auto *I : subgraph.operations) { + assert(VMap->count(I)); + if (Value *Mapped = VMap->lookup(I)) { + if (auto *MappedI = dyn_cast(Mapped)) { + operations.insert(MappedI); + clonedToOriginal[MappedI] = I; + } + } + // llvm::errs() << "Mapping back: " << *VMap->lookup(I) << " (in " + // << cast(VMap->lookup(I)) + // ->getParent() + // ->getParent() + // ->getName() + // << ") --> " << *I << " (in " + // << I->getParent()->getParent()->getName() << ")\n"; + } + } else { + operations = subgraph.operations; + } + + for (auto &change : changes) { + SmallPtrSet seen; + SmallVector todo; + MapVector oldToNew; + + SetVector instsToChange; + for (auto node : change.nodes) { + if (!node || !node->value) { + continue; + } + assert(isa(node->value)); + auto *I = cast(node->value); + if (VMap) { + assert(VMap->count(I)); + if (Value 
*Mapped = VMap->lookup(I)) { + if (auto *MappedI = dyn_cast(Mapped)) { + I = MappedI; + } else { + continue; + } + } else { + continue; + } + } + if (!operations.contains(I)) { + // Already erased by `CO.apply()`. + continue; + } + instsToChange.insert(I); + } + + SmallVector instsToChangeSorted; + topoSort(instsToChange, instsToChangeSorted); + + for (auto *I : instsToChangeSorted) { + changePrecision(I, change, oldToNew); + } + + // Restore the precisions of the last level of instructions to be changed. + // Clean up old instructions. + for (auto &[oldV, newV] : oldToNew) { + if (!isa(oldV)) { + continue; + } + + if (!instsToChange.contains(cast(oldV))) { + continue; + } + + SmallPtrSet users; + for (auto *user : oldV->users()) { + assert(isa(user) && + "PT: Unexpected non-instruction user of a changed instruction"); + if (!instsToChange.contains(cast(user))) { + users.insert(cast(user)); + } + } + + Value *casted = nullptr; + if (!users.empty()) { + IRBuilder<> builder(cast(oldV)->getParent(), + ++BasicBlock::iterator(cast(oldV))); + casted = builder.CreateFPCast( + newV, getLLVMFPType(change.oldType, builder.getContext())); + + if (VMap) { + assert(VMap->count(clonedToOriginal[oldV])); + (*VMap)[clonedToOriginal[oldV]] = casted; + } + } + + for (auto *user : users) { + user->replaceUsesOfWith(oldV, casted); + } + + // Assumes no external uses of the old value since all corresponding new + // values are already restored to original precision and used to replace + // uses of their old value. This is also advantageous to the solvers. 
+ for (auto *user : oldV->users()) { + assert(instsToChange.contains(cast(user)) && + "PT: Unexpected external user of a changed instruction"); + } + + if (!oldV->use_empty()) { + oldV->replaceAllUsesWith(UndefValue::get(oldV->getType())); + } + + cast(oldV)->eraseFromParent(); + + // The change is being materialized to the original subgraph + if (!VMap) + subgraph.operations.remove(cast(oldV)); + } + } +} + +void setUnifiedAccuracyCost( + CandidateSubgraph &CS, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + + SmallVector, 4> sampledPoints; + getSampledPoints(CS.subgraph->inputs.getArrayRef(), valueToNodeMap, + symbolToValueMap, sampledPoints); + + MapVector> goldVals; + for (auto *output : CS.subgraph->outputs) { + auto *node = valueToNodeMap[output].get(); + goldVals[node].resize(FPOptNumSamples); + CS.perOutputInitialAccCost[node] = 0.; + } + + SmallVector outputs; + for (auto *output : CS.subgraph->outputs) + outputs.push_back(valueToNodeMap[output].get()); + + if (FPOptReductionEval == "geomean") { + struct RunningAcc { + double sumLog = 0.0; + unsigned count = 0; + }; + std::unordered_map runAcc; + for (auto *node : outputs) + runAcc[node] = RunningAcc(); + + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getMPFRValues(outputs, pair.value(), results, true, 53); + for (const auto &[node, result] : zip(outputs, results)) + goldVals[node][pair.index()] = result; + + getFPValues(outputs, pair.value(), results); + for (const auto &[node, result] : zip(outputs, results)) { + double goldVal = goldVals[node][pair.index()]; + double error = std::fabs(goldVal - result); + if (!std::isnan(error)) { + if (error == 0.0) { + if (FPOptGeoMeanEps == 0.0) + error = getOneULP(goldVal); + else + error += FPOptGeoMeanEps; + } else if (FPOptGeoMeanEps != 0.0) { + error += FPOptGeoMeanEps; + } + runAcc[node].sumLog += std::log(error); + ++runAcc[node].count; + } + } + } + CS.initialAccCost = 0.0; + for (auto *node : 
outputs) { + RunningAcc &ra = runAcc[node]; + assert(ra.count != 0 && "No valid sample found for original subgraph"); + double red = std::exp(ra.sumLog / ra.count); + CS.perOutputInitialAccCost[node] = red * std::fabs(node->grad); + CS.initialAccCost += CS.perOutputInitialAccCost[node]; + } + } else if (FPOptReductionEval == "arithmean") { + struct RunningAccArith { + double sum = 0.0; + unsigned count = 0; + }; + std::unordered_map runAcc; + for (auto *node : outputs) + runAcc[node] = RunningAccArith(); + + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getMPFRValues(outputs, pair.value(), results, true, 53); + for (const auto &[node, result] : zip(outputs, results)) + goldVals[node][pair.index()] = result; + + getFPValues(outputs, pair.value(), results); + for (const auto &[node, result] : zip(outputs, results)) { + double goldVal = goldVals[node][pair.index()]; + double error = std::fabs(goldVal - result); + if (!std::isnan(error)) { + runAcc[node].sum += error; + ++runAcc[node].count; + } + } + } + CS.initialAccCost = 0.0; + for (auto *node : outputs) { + auto &ra = runAcc[node]; + assert(ra.count != 0 && "No valid sample found for original subgraph"); + double red = ra.sum / ra.count; + CS.perOutputInitialAccCost[node] = red * std::fabs(node->grad); + CS.initialAccCost += CS.perOutputInitialAccCost[node]; + } + } else if (FPOptReductionEval == "maxabs") { + std::unordered_map runAcc; + for (auto *node : outputs) + runAcc[node] = 0.0; + + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getMPFRValues(outputs, pair.value(), results, true, 53); + for (const auto &[node, result] : zip(outputs, results)) + goldVals[node][pair.index()] = result; + + getFPValues(outputs, pair.value(), results); + for (const auto &[node, result] : zip(outputs, results)) { + double goldVal = goldVals[node][pair.index()]; + double error = std::fabs(goldVal - result); + if (!std::isnan(error)) + runAcc[node] = 
std::max(runAcc[node], error); + } + } + CS.initialAccCost = 0.0; + for (auto *node : outputs) { + double red = runAcc[node]; + CS.perOutputInitialAccCost[node] = red * std::fabs(node->grad); + CS.initialAccCost += CS.perOutputInitialAccCost[node]; + } + } else { + llvm_unreachable("Unknown fpopt-reduction strategy"); + } + assert(!std::isnan(CS.initialAccCost)); + + SmallVector newCandidates; + for (auto &candidate : CS.candidates) { + bool discardCandidate = false; + if (FPOptReductionEval == "geomean") { + struct RunningAcc { + double sumLog = 0.0; + unsigned count = 0; + }; + std::unordered_map candAcc; + for (auto *node : outputs) + candAcc[node] = RunningAcc(); + + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getFPValues(outputs, pair.value(), results, &candidate); + for (const auto &[node, result] : zip(outputs, results)) { + double goldVal = goldVals[node][pair.index()]; + if (FPOptStrictMode && !std::isnan(goldVal) && std::isnan(result)) { + discardCandidate = true; + break; + } + double error = std::fabs(goldVal - result); + if (!std::isnan(error)) { + if (error == 0.0) { + if (FPOptGeoMeanEps == 0.0) + error = getOneULP(goldVal); + else + error += FPOptGeoMeanEps; + } else if (FPOptGeoMeanEps != 0.0) { + error += FPOptGeoMeanEps; + } + candAcc[node].sumLog += std::log(error); + ++candAcc[node].count; + } + } + if (discardCandidate) + break; + } + if (!discardCandidate) { + candidate.accuracyCost = 0.0; + for (auto *node : outputs) { + RunningAcc &ra = candAcc[node]; + assert(ra.count != 0 && + "No valid sample found for candidate subgraph"); + double red = std::exp(ra.sumLog / ra.count); + candidate.perOutputAccCost[node] = red * std::fabs(node->grad); + candidate.accuracyCost += candidate.perOutputAccCost[node]; + } + assert(!std::isnan(candidate.accuracyCost)); + newCandidates.push_back(std::move(candidate)); + } + } else if (FPOptReductionEval == "arithmean") { + struct RunningAccArith { + double sum = 0.0; + unsigned 
count = 0; + }; + std::unordered_map candAcc; + for (auto *node : outputs) + candAcc[node] = RunningAccArith(); + + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getFPValues(outputs, pair.value(), results, &candidate); + for (const auto &[node, result] : zip(outputs, results)) { + double goldVal = goldVals[node][pair.index()]; + if (FPOptStrictMode && !std::isnan(goldVal) && std::isnan(result)) { + discardCandidate = true; + break; + } + double error = std::fabs(goldVal - result); + if (!std::isnan(error)) { + candAcc[node].sum += error; + ++candAcc[node].count; + } + } + if (discardCandidate) + break; + } + if (!discardCandidate) { + candidate.accuracyCost = 0.0; + for (auto *node : outputs) { + auto &ra = candAcc[node]; + assert(ra.count != 0 && + "No valid sample found for candidate subgraph"); + double red = ra.sum / ra.count; + candidate.perOutputAccCost[node] = red * std::fabs(node->grad); + candidate.accuracyCost += candidate.perOutputAccCost[node]; + } + assert(!std::isnan(candidate.accuracyCost)); + newCandidates.push_back(std::move(candidate)); + } + } else if (FPOptReductionEval == "maxabs") { + std::unordered_map candAcc; + for (auto *node : outputs) + candAcc[node] = 0.0; + + for (const auto &pair : enumerate(sampledPoints)) { + SmallVector results; + getFPValues(outputs, pair.value(), results, &candidate); + for (const auto &[node, result] : zip(outputs, results)) { + double goldVal = goldVals[node][pair.index()]; + if (FPOptStrictMode && !std::isnan(goldVal) && std::isnan(result)) { + discardCandidate = true; + break; + } + double error = std::fabs(goldVal - result); + if (!std::isnan(error)) + candAcc[node] = std::max(candAcc[node], error); + } + if (discardCandidate) + break; + } + if (!discardCandidate) { + candidate.accuracyCost = 0.0; + for (auto *node : outputs) { + double red = candAcc[node]; + candidate.perOutputAccCost[node] = red * std::fabs(node->grad); + candidate.accuracyCost += 
candidate.perOutputAccCost[node]; + } + assert(!std::isnan(candidate.accuracyCost)); + newCandidates.push_back(std::move(candidate)); + } + } else { + llvm_unreachable("Unknown fpopt-reduction strategy"); + } + } + CS.candidates = std::move(newCandidates); +} + +InstructionCost getCompCost(Subgraph &subgraph, const TargetTransformInfo &TTI, + PTCandidate &pt) { + assert(!subgraph.outputs.empty()); + + InstructionCost cost = 0; + + Function *F = cast(subgraph.outputs[0])->getFunction(); + + ValueToValueMapTy VMap; + Function *FClone = CloneFunction(F, VMap); + FClone->setName(F->getName() + "_clone"); + + pt.apply(subgraph, &VMap); + + // llvm::errs() << "\n========================================\n"; + // llvm::errs() << "DEBUG PT Cost Measurement: " << pt.desc << "\n"; + // llvm::errs() << "========================================\n"; + // llvm::errs() << "Before optimization:\n"; + // FClone->print(llvm::errs()); + // llvm::errs() << "========================================\n"; + + SmallVector trackedInputs; + for (auto &input : subgraph.inputs) { + if (VMap.count(input)) + trackedInputs.emplace_back(VMap[input]); + } + SmallVector trackedOutputs; + for (auto &output : subgraph.outputs) { + if (VMap.count(output)) + trackedOutputs.emplace_back(VMap[output]); + } + + runPoseidonFunctionSimplify(*FClone, OptimizationLevel::O3); + + // llvm::errs() << "Cloned function AFTER precision changes and + // optimization:\n"; FClone->print(llvm::errs()); llvm::errs() << + // "========================================\n"; + + SmallPtrSet clonedInputs; + for (auto &VH : trackedInputs) { + if (!VH.pointsToAliveValue()) + continue; + Value *V = VH; + if (V) + clonedInputs.insert(V); + } + + SmallPtrSet clonedOutputs; + for (auto &VH : trackedOutputs) { + if (!VH.pointsToAliveValue()) + continue; + Value *V = VH; + if (V) + clonedOutputs.insert(V); + } + + SmallPtrSet seen; + SmallVector todo; + + todo.insert(todo.end(), clonedOutputs.begin(), clonedOutputs.end()); + while 
(!todo.empty()) { + auto cur = todo.pop_back_val(); + if (!seen.insert(cur).second) + continue; + + if (clonedInputs.contains(cur)) + continue; + + if (auto *I = dyn_cast(cur)) { + auto instCost = getInstructionCompCost(I, TTI); + + // if (I->getType()->isFPOrFPVectorTy() || + // (I->getOpcode() == Instruction::FCmp)) { + // Type *Ty = I->getType(); + // if (I->getOpcode() == Instruction::FCmp) + // Ty = I->getOperand(0)->getType(); + + // std::string precStr = "unknown"; + // if (Ty->isFloatTy()) precStr = "FP32"; + // else if (Ty->isDoubleTy()) precStr = "FP64"; + + // llvm::errs() << " [" << precStr << "] Cost=" << instCost + // << " : " << *I << "\n"; + // } + + cost += instCost; + + auto operands = + isa(I) ? cast(I)->args() : I->operands(); + for (auto &operand : operands) { + todo.push_back(operand); + } + } + } + // llvm::errs() << "Total PT cost for " << pt.desc << ": " << cost << "\n"; + // llvm::errs() << "========================================\n\n"; + + FClone->eraseFromParent(); + + return cost; +} \ No newline at end of file diff --git a/enzyme/Enzyme/Poseidon/PoseidonPrecUtils.h b/enzyme/Enzyme/Poseidon/PoseidonPrecUtils.h new file mode 100644 index 000000000000..909e0b2512c4 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonPrecUtils.h @@ -0,0 +1,85 @@ +//=- PoseidonPrecUtils.h - Precision change utilities for Poseidon --------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares utilities for handling precision changes in the Poseidon +// optimization pass. 
+// +//===----------------------------------------------------------------------===// + +#ifndef ENZYME_POSEIDON_PREC_UTILS_H +#define ENZYME_POSEIDON_PREC_UTILS_H + +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InstructionCost.h" +#include "llvm/Transforms/Utils/ValueMapper.h" + +#include +#include +#include + +using namespace llvm; + +extern "C" { +extern llvm::cl::opt FPOptShowPTDetails; +extern llvm::cl::opt FPOptMaxMPFRPrec; +} + +class FPNode; +class FPLLValue; +struct Subgraph; +class CandidateSubgraph; + +enum class PrecisionChangeType { BF16, FP16, FP32, FP64, FP80, FP128 }; +unsigned getMPFRPrec(PrecisionChangeType type); +Type *getLLVMFPType(PrecisionChangeType type, LLVMContext &context); +PrecisionChangeType getPrecisionChangeType(Type *type); +StringRef getPrecisionChangeTypeString(PrecisionChangeType type); + +struct PrecisionChange { + SetVector nodes; + PrecisionChangeType oldType; + PrecisionChangeType newType; + + explicit PrecisionChange(SetVector &nodes, + PrecisionChangeType oldType, + PrecisionChangeType newType) + : nodes(nodes), oldType(oldType), newType(newType) {} +}; + +struct PTCandidate { + SmallVector changes; + double accuracyCost = std::numeric_limits::quiet_NaN(); + InstructionCost CompCost = std::numeric_limits::max(); + std::string desc; + std::unordered_map perOutputAccCost; + std::unordered_map> errors; + + explicit PTCandidate(SmallVector changes, + const std::string &desc) + : changes(std::move(changes)), desc(desc) {} + + void apply(Subgraph &subgraph, ValueToValueMapTy *VMap = nullptr); +}; + +void changePrecision(Instruction *I, PrecisionChange &change, + MapVector &oldToNew); + +InstructionCost getCompCost(Subgraph &subgraph, const TargetTransformInfo &TTI, + PTCandidate &pt); + +void setUnifiedAccuracyCost( + CandidateSubgraph &CS, 
+ std::unordered_map> &valueToNodeMap,
+ std::unordered_map &symbolToValueMap);
+
+#endif // ENZYME_POSEIDON_PREC_UTILS_H
\ No newline at end of file
diff --git a/enzyme/Enzyme/Poseidon/PoseidonProfUtils.cpp b/enzyme/Enzyme/Poseidon/PoseidonProfUtils.cpp
new file mode 100644
index 000000000000..975c730bfe38
--- /dev/null
+++ b/enzyme/Enzyme/Poseidon/PoseidonProfUtils.cpp
@@ -0,0 +1,185 @@
+//=- PoseidonProfUtils.cpp - Profiling utilities for Poseidon
+//------------------=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements profiling-related utilities for the Poseidon
+// optimization pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include <cmath>
+#include
+#include
+
+#include "PoseidonProfUtils.h"
+#include "PoseidonUtils.h"
+
+using namespace llvm;
+
+extern "C" {
+cl::opt FPOptLooseCoverage(
+    "fpopt-loose-coverage", cl::init(false), cl::Hidden,
+    cl::desc("Allow unexecuted FP instructions in subgraph identification"));
+cl::opt
+    FPOptWidenRange("fpopt-widen-range", cl::init(1), cl::Hidden,
+                    cl::desc("Ablation study only: widen the range of input "
+                             "hypercube by this factor"));
+}
+
+// Reads the profile at `profilePath` into `profileMap`, which is cleared
+// first. If the file cannot be opened a warning is printed to llvm::errs() and
+// the map is left empty.
+//
+// An entry starts with a line holding only a decimal index and is followed by
+// seven lines in this fixed order: MinRes, MaxRes, SumValue, SumSens, SumGrad,
+// Exec, NumOperands. After those come NumOperands lines of the form
+// `Operand[i] = [min, max]`. Floating-point fields may also be spelled inf,
+// -inf, nan or -nan. A trailing '\r' is stripped from every line read.
+//
+// Entries whose seven header lines are missing or do not match are reported
+// on llvm::errs() and skipped. Operand lines that do not match, or whose index
+// is out of range, are consumed and leave that operand's bounds at 0.0.
+//
+// When FPOptWidenRange is not 1, each finite [min, max] interval is scaled
+// about its midpoint by that factor.
+void parseProfileFile(const std::string &profilePath,
+                      std::unordered_map &profileMap) {
+  profileMap.clear();
+  std::ifstream file(profilePath);
+  if (!file.is_open()) {
+    llvm::errs() << "Warning: Could not open profile file: " << profilePath
+                 << "\n";
+    return;
+  }
+
+  std::string line;
+  std::regex indexPattern(R"(^(\d+)$)");
+
+  // Field patterns are anchored at the start of the line. The captured value
+  // is a decimal or scientific-notation literal, or one of inf/-inf/nan/-nan.
+  std::regex minResPattern(
+      R"(^\s*MinRes\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))");
+  std::regex maxResPattern(
+      R"(^\s*MaxRes\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))");
+  std::regex sumValuePattern(
+      R"(^\s*SumValue\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))");
+  std::regex sumSensPattern(
+      R"(^\s*SumSens\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))");
+  std::regex sumGradPattern(
+      R"(^\s*SumGrad\s*=\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan))");
+  std::regex execPattern(R"(^\s*Exec\s*=\s*(\d+))");
+  std::regex numOperandsPattern(R"(^\s*NumOperands\s*=\s*(\d+))");
+  std::regex operandPattern(
+      R"(^\s*Operand\[(\d+)\]\s*=\s*\[([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan),\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?|inf|-inf|nan|-nan)\])");
+
+  while (std::getline(file, line)) {
+    if (!line.empty() && line.back() == '\r') {
+      line.pop_back();
+    }
+
+    // Only a line made up entirely of digits opens an entry; any other line
+    // at this level is ignored.
+    std::smatch match;
+    if (std::regex_match(line, match, indexPattern)) {
+      size_t idx = std::stoull(match[1]);
+      ProfileInfo info;
+
+      std::string minResLine, maxResLine, sumValueLine, sumSensLine,
+          sumGradLine, execLine, numOperandsLine;
+
+      if (std::getline(file, minResLine) && std::getline(file, maxResLine) &&
+          std::getline(file, sumValueLine) && std::getline(file, sumSensLine) &&
+          std::getline(file, sumGradLine) && std::getline(file, execLine) &&
+          std::getline(file, numOperandsLine)) {
+
+        auto stripCR = [](std::string &s) {
+          if (!s.empty() && s.back() == '\r')
+            s.pop_back();
+        };
+        stripCR(minResLine);
+        stripCR(maxResLine);
+        stripCR(sumValueLine);
+        stripCR(sumSensLine);
+        stripCR(sumGradLine);
+        stripCR(execLine);
+        stripCR(numOperandsLine);
+
+        std::smatch mMinRes, mMaxRes, mSumValue, mSumSens, mSumGrad, mExec,
+            mNumOperands;
+        if (std::regex_search(minResLine, mMinRes, minResPattern) &&
+            std::regex_search(maxResLine, mMaxRes, maxResPattern) &&
+            std::regex_search(sumValueLine, mSumValue, sumValuePattern) &&
+            std::regex_search(sumSensLine, mSumSens, sumSensPattern) &&
+            std::regex_search(sumGradLine, mSumGrad, sumGradPattern) &&
+            std::regex_search(execLine, mExec, execPattern) &&
+            std::regex_search(numOperandsLine, mNumOperands,
+                              numOperandsPattern)) {
+
+          info.minRes = stringToDouble(mMinRes[1]);
+          info.maxRes = stringToDouble(mMaxRes[1]);
+          info.sumValue = stringToDouble(mSumValue[1]);
+          info.sumSens = stringToDouble(mSumSens[1]);
+          info.sumGrad = stringToDouble(mSumGrad[1]);
+          info.exec = static_cast(std::stoul(mExec[1]));
+          unsigned numOperands =
+              static_cast(std::stoul(mNumOperands[1]));
+
+          info.minOperands.resize(numOperands, 0.0);
+          info.maxOperands.resize(numOperands, 0.0);
+
+          // Read up to numOperands further lines, whether or not they match.
+          for (unsigned i = 0; i < numOperands; ++i) {
+            if (std::getline(file, line)) {
+              if (!line.empty() && line.back() == '\r')
+                line.pop_back();
+
+              std::smatch operandMatch;
+              if (std::regex_search(line, operandMatch, operandPattern)) {
+                unsigned opIdx =
+                    static_cast(std::stoul(operandMatch[1]));
+                double minVal = stringToDouble(operandMatch[2]);
+                double maxVal = stringToDouble(operandMatch[3]);
+
+                if (opIdx < numOperands) {
+                  info.minOperands[opIdx] = minVal;
+                  info.maxOperands[opIdx] = maxVal;
+                }
+              }
+            }
+          }
+
+          if (FPOptWidenRange != 1.0) {
+            // Scale [lo, hi] about its midpoint. An interval with a non-finite
+            // bound is left as is: widening it would evaluate inf - inf (or
+            // propagate a NaN) and replace valid infinite bounds with NaN.
+            auto widen = [](double &lo, double &hi) {
+              if (!std::isfinite(lo) || !std::isfinite(hi))
+                return;
+              double center = (lo + hi) / 2.0;
+              double newHalfRange = (hi - lo) / 2.0 * FPOptWidenRange;
+              lo = center - newHalfRange;
+              hi = center + newHalfRange;
+            };
+
+            widen(info.minRes, info.maxRes);
+            for (size_t i = 0; i < info.minOperands.size(); ++i)
+              widen(info.minOperands[i], info.maxOperands[i]);
+          }
+
+          profileMap[idx] = info;
+        } else {
+          llvm::errs() << "Warning: Failed to parse profile fields for index "
+                       << idx << "\n";
+        }
+      } else {
+        llvm::errs() << "Warning: Incomplete profile entry for index " << idx
+                     << "\n";
+      }
+    }
+  }
+}
+diff --git 
a/enzyme/Enzyme/Poseidon/PoseidonProfUtils.h b/enzyme/Enzyme/Poseidon/PoseidonProfUtils.h
new file mode 100644
index 000000000000..6c0f45c12463
--- /dev/null
+++ b/enzyme/Enzyme/Poseidon/PoseidonProfUtils.h
@@ -0,0 +1,51 @@
+//=- PoseidonProfUtils.h - Profiling utilities for Poseidon
+//--------------------=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares profiling-related utilities for the Poseidon optimization
+// pass.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ENZYME_POSEIDON_PROF_UTILS_H
+#define ENZYME_POSEIDON_PROF_UTILS_H
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/CommandLine.h"
+
+#include
+#include
+
+using namespace llvm;
+
+extern "C" { // Both options are defined in PoseidonProfUtils.cpp.
+extern llvm::cl::opt FPOptLooseCoverage; // -fpopt-loose-coverage
+extern llvm::cl::opt FPOptWidenRange; // -fpopt-widen-range; 1 means "do not widen"
+}
+
+struct ProfileInfo { // One profile entry, as filled in by parseProfileFile.
+  double minRes; // "MinRes" field; initialised to numeric_limits max()
+  double maxRes; // "MaxRes" field; initialised to numeric_limits lowest()
+  double sumValue; // Sum of values (not abs)
+  double sumSens;  // Sum of sensitivity scores = |grad * value|
+  double sumGrad;  // Sum of gradients (not abs)
+  unsigned exec; // "Exec" field; presumably an execution count -- TODO confirm
+
+  SmallVector minOperands; // Lower bound per operand, from "Operand[i] = [min, max]"
+  SmallVector maxOperands; // Upper bound per operand, same index as minOperands
+
+  ProfileInfo() // Inverted min/max sentinels; sums and exec start at zero.
+      : minRes(std::numeric_limits::max()),
+        maxRes(std::numeric_limits::lowest()), sumValue(0.0),
+        sumSens(0.0), sumGrad(0.0), exec(0) {}
+};
+
+void parseProfileFile(const std::string &profilePath,
+                      std::unordered_map &profileMap); // Clears profileMap, then fills it from the file at profilePath.
+
+#endif // ENZYME_POSEIDON_PROF_UTILS_H
\ No newline at end of file
diff --git a/enzyme/Enzyme/Poseidon/PoseidonSolvers.cpp b/enzyme/Enzyme/Poseidon/PoseidonSolvers.cpp
new file mode 100644
index 000000000000..79d7cef7d484
--- /dev/null
+++ b/enzyme/Enzyme/Poseidon/PoseidonSolvers.cpp
@@ -0,0 +1,1442 @@
+//=- PoseidonSolvers.cpp - Solver utilities for Poseidon 
------------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements solver-related utilities for the Poseidon optimization +// pass. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/raw_ostream.h" + +#include "Poseidon.h" +#include "PoseidonSolvers.h" +#include "PoseidonTypes.h" + +#include "PoseidonHerbieUtils.h" +#include "PoseidonPrecUtils.h" +#include "PoseidonUtils.h" + +#include "llvm/IR/IRBuilder.h" +#include "llvm/Transforms/Utils/Cloning.h" + +#include + +using namespace llvm; + +extern "C" { +cl::opt FPOptSolverType("fpopt-solver-type", cl::init("dp"), + cl::Hidden, + cl::desc("Which solver to use; " + "either 'dp' or 'greedy'")); +cl::opt FPOptComputationCostBudget( + "fpopt-comp-cost-budget", cl::init(0L), cl::Hidden, + cl::desc("The maximum computation cost budget for the solver")); +cl::opt FPOptShowTable( + "fpopt-show-table", cl::init(false), cl::Hidden, + cl::desc( + "Print the full DP table (highly verbose for large applications)")); +cl::list FPOptShowTableCosts( + "fpopt-show-table-costs", cl::ZeroOrMore, cl::CommaSeparated, cl::Hidden, + cl::desc( + "Comma-separated list of computation costs for which to print DP table " + "entries. If provided, only specified computation costs are " + "printed. 
")); +cl::opt FPOptEarlyPrune( + "fpopt-early-prune", cl::init(true), cl::Hidden, + cl::desc("Prune dominated candidates in expression transformation phases")); +cl::opt FPOptCostDominanceThreshold( + "fpopt-cost-dom-thres", cl::init(0.0), cl::Hidden, + cl::desc("The threshold for cost dominance in DP solver")); +cl::opt FPOptAccuracyDominanceThreshold( + "fpopt-acc-dom-thres", cl::init(0.0), cl::Hidden, + cl::desc("The threshold for accuracy dominance in DP solver")); +cl::opt FPOptRefineDPTable( + "fpopt-refine-dp", cl::init(false), cl::Hidden, + cl::desc("After initial DP build, materialize each solution in a cloned\n" + "function, run O3, and recompute cost deltas. Only applies when\n" + "generating a new cache, not when loading from cache.")); +cl::opt FPOptReportPath( + "fpopt-report-path", cl::init(""), cl::Hidden, + cl::desc("Directory to write Poseidon optimization reports.\n" + "Emits .json (Pareto table with source locations),\n" + ".txt (human-readable), _rewrites.json\n" + "(curated per-rewrite analysis with IDs), and\n" + "validate_config.json + validate.py (validation script).")); +cl::opt FPOptApplyRewrites( + "fpopt-apply-rewrites", cl::init(""), cl::Hidden, + cl::desc("Comma-separated rewrite IDs to apply (e.g.\n" + "R0_1,R3_0,PT1_2). IDs are from the _rewrites.json\n" + "report. Bypasses the DP solver. 
At most one candidate\n" + "per expression (R) or subgraph (PT).")); +} + +#if LLVM_VERSION_MAJOR >= 21 +#define GET_INSTRUCTION_COST(cost) (cost.getValue()) +#else +#define GET_INSTRUCTION_COST(cost) (cost.getValue().value()) +#endif + +static json::Value jsonFloat(double v) { + if (std::isfinite(v)) + return json::Value(v); + return json::Value(nullptr); +} + +static json::Object getSourceLocationJSON(Value *V) { + json::Object loc; + if (auto *I = dyn_cast(V)) { + if (const auto &DL = I->getDebugLoc()) { + loc["file"] = DL->getFilename().str(); + loc["line"] = static_cast(DL.getLine()); + loc["col"] = static_cast(DL.getCol()); + } + } + return loc; +} + +static std::string getSourceLocationStr(Value *V) { + if (auto *I = dyn_cast(V)) { + if (const auto &DL = I->getDebugLoc()) { + return (DL->getFilename() + ":" + Twine(DL.getLine()) + ":" + + Twine(DL.getCol())) + .str(); + } + } + return ""; +} + +static std::string getValueStr(Value *V) { + std::string str; + raw_string_ostream OS(str); + V->print(OS); + return str; +} + +// Collect unique source locations from a set of LLVM instructions. 
+static json::Array getSourceLocationsJSON(ArrayRef insts) { + json::Array locs; + SmallSet seen; + for (auto *I : insts) { + auto locStr = getSourceLocationStr(I); + if (!locStr.empty() && seen.insert(locStr).second) { + locs.push_back(getSourceLocationJSON(I)); + } + } + return locs; +} + +static json::Object +buildStepJSON(const SolutionStep &step, + const std::map &costToAccuracyMap) { + json::Object stepObj; + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + stepObj["type"] = "rewrite"; + stepObj["original_expr"] = item->expr; + auto &cand = item->candidates[step.candidateIndex]; + stepObj["rewritten_expr"] = cand.expr; + stepObj["herbie_cost"] = jsonFloat(cand.herbieCost); + stepObj["herbie_accuracy"] = jsonFloat(cand.herbieAccuracy); + stepObj["initial_herbie_cost"] = jsonFloat(item->initialHerbieCost); + stepObj["initial_herbie_accuracy"] = + jsonFloat(item->initialHerbieAccuracy); + stepObj["computation_cost_delta"] = + GET_INSTRUCTION_COST(item->getCompCostDelta(step.candidateIndex)); + stepObj["accuracy_cost_delta"] = + jsonFloat(item->getAccCostDelta(step.candidateIndex)); + stepObj["gradient"] = jsonFloat(item->grad); + stepObj["executions"] = static_cast(item->executions); + + // Source locations from the output instruction and erasable insts + SmallVector insts; + if (auto *I = dyn_cast(item->oldOutput)) + insts.push_back(I); + for (auto *I : item->erasableInsts) + insts.push_back(I); + stepObj["source_locations"] = getSourceLocationsJSON(insts); + + // Affected LLVM IR instructions + json::Array affectedIR; + if (auto *I = dyn_cast(item->oldOutput)) + affectedIR.push_back(getValueStr(I)); + for (auto *I : item->erasableInsts) { + if (I != item->oldOutput) + affectedIR.push_back(getValueStr(I)); + } + stepObj["affected_instructions"] = std::move(affectedIR); + } else if constexpr (std::is_same_v) { + auto &cand = item->candidates[step.candidateIndex]; + stepObj["type"] = "precision_change"; + 
stepObj["description"] = cand.desc; + stepObj["candidate_index"] = + static_cast(step.candidateIndex); + + json::Array changes; + for (const auto &change : cand.changes) { + json::Object changeObj; + changeObj["from"] = getPrecisionChangeTypeString(change.oldType); + changeObj["to"] = getPrecisionChangeTypeString(change.newType); + changeObj["num_operations"] = + static_cast(change.nodes.size()); + + SmallVector insts; + json::Array nodeIR; + for (auto *node : change.nodes) { + if (auto *I = dyn_cast(node->value)) { + insts.push_back(I); + nodeIR.push_back(getValueStr(I)); + } + } + changeObj["source_locations"] = getSourceLocationsJSON(insts); + changeObj["affected_instructions"] = std::move(nodeIR); + changes.push_back(std::move(changeObj)); + } + stepObj["changes"] = std::move(changes); + } + }, + step.item); + return stepObj; +} + +static void writeTextReportStep(raw_ostream &OS, const SolutionStep &step, + unsigned indent) { + std::string pad(indent, ' '); + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + auto &cand = item->candidates[step.candidateIndex]; + OS << pad << "[Rewrite] " << item->expr << " --> " << cand.expr + << "\n"; + if (!std::isnan(item->initialHerbieAccuracy) && + !std::isnan(cand.herbieAccuracy)) + OS << pad << " Herbie accuracy: " << item->initialHerbieAccuracy + << " -> " << cand.herbieAccuracy << " bits\n"; + OS << pad << " Gradient: " << item->grad + << ", Executions: " << item->executions << "\n"; + + // Source locations + SmallSet seen; + auto printLoc = [&](Instruction *I) { + auto loc = getSourceLocationStr(I); + if (!loc.empty() && seen.insert(loc).second) + OS << pad << " Source: " << loc << "\n"; + }; + if (auto *I = dyn_cast(item->oldOutput)) + printLoc(I); + for (auto *I : item->erasableInsts) + printLoc(I); + + // Affected IR + OS << pad << " Affected IR:\n"; + if (auto *I = dyn_cast(item->oldOutput)) + OS << pad << " " << *I << "\n"; + for (auto *I : item->erasableInsts) { + if (I 
!= item->oldOutput) + OS << pad << " " << *I << "\n"; + } + } else if constexpr (std::is_same_v) { + auto &cand = item->candidates[step.candidateIndex]; + OS << pad << "[Precision] " << cand.desc << " (#" + << step.candidateIndex << ")\n"; + for (const auto &change : cand.changes) { + OS << pad << " " << getPrecisionChangeTypeString(change.oldType) + << " -> " << getPrecisionChangeTypeString(change.newType) + << " for " << change.nodes.size() << " operations\n"; + SmallSet seen; + for (auto *node : change.nodes) { + if (auto *I = dyn_cast(node->value)) { + auto loc = getSourceLocationStr(I); + if (!loc.empty() && seen.insert(loc).second) + OS << pad << " Source: " << loc << "\n"; + } + } + } + } + }, + step.item); +} + +static void +emitPoseidonReport(StringRef funcName, + const std::map &costToAccuracyMap, + const std::map> + &costToSolutionMap, + SmallVector &COs, + SmallVector &CSs) { + if (FPOptReportPath.empty()) + return; + + std::error_code EC; + if (!llvm::sys::fs::exists(FPOptReportPath)) { + EC = llvm::sys::fs::create_directories(FPOptReportPath); + if (EC) { + llvm::errs() << "Error creating report directory: " << EC.message() + << "\n"; + return; + } + } + + // Build JSON report + json::Object report; + report["function"] = funcName.str(); + report["num_pareto_points"] = static_cast(costToAccuracyMap.size()); + + if (!costToAccuracyMap.empty()) { + report["cost_range_min"] = + GET_INSTRUCTION_COST(costToAccuracyMap.begin()->first); + report["cost_range_max"] = + GET_INSTRUCTION_COST(costToAccuracyMap.rbegin()->first); + } + + // Candidate summary + json::Array coSummary; + for (const auto &CO : COs) { + json::Object co; + co["original_expr"] = CO.expr; + co["num_candidates"] = static_cast(CO.candidates.size()); + co["gradient"] = jsonFloat(CO.grad); + co["executions"] = static_cast(CO.executions); + co["initial_accuracy_cost"] = jsonFloat(CO.initialAccCost); + co["initial_computation_cost"] = GET_INSTRUCTION_COST(CO.initialCompCost); + if (auto *I = 
dyn_cast(CO.oldOutput)) { + auto loc = getSourceLocationJSON(I); + if (!loc.empty()) + co["source_location"] = std::move(loc); + } + coSummary.push_back(std::move(co)); + } + report["candidate_outputs"] = std::move(coSummary); + + json::Array csSummary; + for (const auto &CS : CSs) { + json::Object cs; + cs["num_candidates"] = static_cast(CS.candidates.size()); + cs["initial_accuracy_cost"] = jsonFloat(CS.initialAccCost); + cs["initial_computation_cost"] = GET_INSTRUCTION_COST(CS.initialCompCost); + csSummary.push_back(std::move(cs)); + } + report["candidate_subgraphs"] = std::move(csSummary); + + // Pareto table + json::Array paretoPoints; + for (const auto &pair : costToAccuracyMap) { + json::Object point; + point["computation_cost"] = GET_INSTRUCTION_COST(pair.first); + point["accuracy_cost"] = pair.second; + + auto it = costToSolutionMap.find(pair.first); + if (it != costToSolutionMap.end()) { + json::Array steps; + for (const auto &step : it->second) { + steps.push_back(buildStepJSON(step, costToAccuracyMap)); + } + point["steps"] = std::move(steps); + } + paretoPoints.push_back(std::move(point)); + } + report["pareto_points"] = std::move(paretoPoints); + + // Write JSON + std::string jsonFile = + (Twine(FPOptReportPath) + "/" + funcName + ".json").str(); + raw_fd_ostream jsonOut(jsonFile, EC, sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error writing JSON report: " << EC.message() << "\n"; + } else { + jsonOut << formatv("{0:2}", json::Value(std::move(report))) << "\n"; + llvm::errs() << "Poseidon JSON report written to " << jsonFile << "\n"; + } + + // Write human-readable text report + std::string textFile = + (Twine(FPOptReportPath) + "/" + funcName + ".txt").str(); + raw_fd_ostream textOut(textFile, EC, sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error writing text report: " << EC.message() << "\n"; + return; + } + + textOut << "=== Poseidon Report: " << funcName << " ===\n"; + textOut << "Pareto table: " << costToAccuracyMap.size() << " 
points"; + if (!costToAccuracyMap.empty()) { + textOut << ", cost range [" + << GET_INSTRUCTION_COST(costToAccuracyMap.begin()->first) << ", " + << GET_INSTRUCTION_COST(costToAccuracyMap.rbegin()->first) << "]"; + } + textOut << "\n"; + textOut << "Candidate outputs: " << COs.size() + << ", Candidate subgraphs: " << CSs.size() << "\n\n"; + + // Per-CO summary + for (size_t i = 0; i < COs.size(); ++i) { + textOut << "Expression #" << i << ": " << COs[i].expr << "\n"; + textOut << " Gradient: " << COs[i].grad + << ", Executions: " << COs[i].executions << "\n"; + if (auto *I = dyn_cast(COs[i].oldOutput)) { + auto loc = getSourceLocationStr(I); + if (!loc.empty()) + textOut << " Source: " << loc << "\n"; + } + textOut << " Candidates: " << COs[i].candidates.size() << "\n"; + for (size_t j = 0; j < COs[i].candidates.size(); ++j) { + auto &cand = COs[i].candidates[j]; + textOut << " [" << j << "] " << cand.expr; + if (!std::isnan(cand.herbieAccuracy)) + textOut << " (accuracy: " << cand.herbieAccuracy << " bits)"; + textOut << "\n"; + } + } + textOut << "\n"; + + // Pareto points + unsigned pointIdx = 0; + for (const auto &pair : costToAccuracyMap) { + textOut << "--- Pareto Point #" << pointIdx++ + << ": Cost=" << GET_INSTRUCTION_COST(pair.first) + << ", Accuracy=" << pair.second << " ---\n"; + auto it = costToSolutionMap.find(pair.first); + if (it != costToSolutionMap.end() && !it->second.empty()) { + for (const auto &step : it->second) { + writeTextReportStep(textOut, step, 2); + } + } else { + textOut << " (no changes)\n"; + } + textOut << "\n"; + } + + llvm::errs() << "Poseidon text report written to " << textFile << "\n"; + + // Emit validation config + script + { + std::string configFile = + (Twine(FPOptReportPath) + "/validate_config.json").str(); + { + json::Object cfg; + cfg["function"] = funcName.str(); + cfg["profile_path"] = FPProfileUse.getValue(); + cfg["cache_path"] = FPOptCachePath.getValue(); + json::Array budgetArr; + for (const auto &pair : 
costToAccuracyMap) + budgetArr.push_back(GET_INSTRUCTION_COST(pair.first)); + cfg["budgets"] = std::move(budgetArr); + json::Array accArr; + for (const auto &pair : costToAccuracyMap) + accArr.push_back(pair.second); + cfg["estimated_accuracy_costs"] = std::move(accArr); + + raw_fd_ostream cfgOut(configFile, EC, sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error writing validate config: " << EC.message() + << "\n"; + } else { + cfgOut << formatv("{0:2}", json::Value(std::move(cfg))) << "\n"; + } + } + } + + // --- Curated per-rewrite analysis --- + // Enumerate every individual rewrite candidate, categorize by its marginal + // impact, and rank tradeoffs by efficiency (benefit per unit cost). + { + json::Array rewrites; + for (size_t coIdx = 0; coIdx < COs.size(); ++coIdx) { + auto &CO = COs[coIdx]; + for (size_t ci = 0; ci < CO.candidates.size(); ++ci) { + auto &cand = CO.candidates[ci]; + double accDelta = CO.getAccCostDelta(ci); + auto compDelta = CO.getCompCostDelta(ci); + int64_t compDeltaVal = GET_INSTRUCTION_COST(compDelta); + + // Skip candidates with no usable data + if (std::isnan(accDelta)) + continue; + + // Categorize + // comp < 0 means faster, acc < 0 means more accurate + std::string category; + double efficiency = 0.0; + if (compDeltaVal <= 0 && accDelta <= 0) { + category = "free_win"; + // Rank by combined magnitude of improvement + efficiency = std::abs(accDelta) + std::abs((double)compDeltaVal); + } else if (compDeltaVal <= 0 && accDelta > 0) { + category = "speed_for_accuracy"; + // Efficiency: speed gained per unit of accuracy lost + efficiency = (accDelta > 1e-30) + ? std::abs((double)compDeltaVal) / accDelta + : std::abs((double)compDeltaVal); + } else if (compDeltaVal > 0 && accDelta <= 0) { + category = "accuracy_for_speed"; + // Efficiency: accuracy gained per unit of speed lost + efficiency = (compDeltaVal > 0) + ? 
std::abs(accDelta) / (double)compDeltaVal + : std::abs(accDelta); + } else { + continue; // Both worse — skip + } + + json::Object rw; + std::string id = "R" + std::to_string(coIdx) + "_" + std::to_string(ci); + rw["id"] = id; + rw["original_expr"] = CO.expr; + rw["rewritten_expr"] = cand.expr; + rw["category"] = category; + rw["efficiency"] = jsonFloat(efficiency); + rw["computation_cost_delta"] = compDeltaVal; + rw["accuracy_cost_delta"] = jsonFloat(accDelta); + rw["gradient"] = jsonFloat(CO.grad); + rw["executions"] = static_cast(CO.executions); + rw["herbie_accuracy"] = jsonFloat(cand.herbieAccuracy); + rw["initial_herbie_accuracy"] = jsonFloat(CO.initialHerbieAccuracy); + + // Source location + if (auto *I = dyn_cast(CO.oldOutput)) { + auto loc = getSourceLocationJSON(I); + if (!loc.empty()) + rw["source_location"] = std::move(loc); + } + + rewrites.push_back(std::move(rw)); + } + } + + // Also include precision tuning candidates + for (size_t csIdx = 0; csIdx < CSs.size(); ++csIdx) { + auto &CS = CSs[csIdx]; + for (size_t ci = 0; ci < CS.candidates.size(); ++ci) { + auto &pt = CS.candidates[ci]; + double accDelta = CS.getAccCostDelta(ci); + auto compDelta = CS.getCompCostDelta(ci); + int64_t compDeltaVal = GET_INSTRUCTION_COST(compDelta); + + if (std::isnan(accDelta)) + continue; + + std::string category; + double efficiency = 0.0; + if (compDeltaVal <= 0 && accDelta <= 0) { + category = "free_win"; + efficiency = std::abs(accDelta) + std::abs((double)compDeltaVal); + } else if (compDeltaVal <= 0 && accDelta > 0) { + category = "speed_for_accuracy"; + efficiency = (accDelta > 1e-30) + ? std::abs((double)compDeltaVal) / accDelta + : std::abs((double)compDeltaVal); + } else if (compDeltaVal > 0 && accDelta <= 0) { + category = "accuracy_for_speed"; + efficiency = (compDeltaVal > 0) + ? 
std::abs(accDelta) / (double)compDeltaVal + : std::abs(accDelta); + } else { + continue; + } + + json::Object rw; + std::string id = + "PT" + std::to_string(csIdx) + "_" + std::to_string(ci); + rw["id"] = id; + rw["type"] = "precision_change"; + rw["description"] = pt.desc; + rw["category"] = category; + rw["efficiency"] = jsonFloat(efficiency); + rw["computation_cost_delta"] = compDeltaVal; + rw["accuracy_cost_delta"] = jsonFloat(accDelta); + + // Collect affected instructions + deduplicated source locations + SmallVector ptInsts; + for (const auto &change : pt.changes) { + for (auto *node : change.nodes) { + if (auto *I = dyn_cast(node->value)) + ptInsts.push_back(I); + } + } + rw["source_locations"] = getSourceLocationsJSON(ptInsts); + json::Array affectedIR; + for (auto *I : ptInsts) + affectedIR.push_back(getValueStr(I)); + rw["affected_instructions"] = std::move(affectedIR); + + rewrites.push_back(std::move(rw)); + } + } + + // Sort: free_wins first, then by efficiency descending + auto rewriteVec = + SmallVector(rewrites.begin(), rewrites.end()); + llvm::sort(rewriteVec, [](const json::Value &a, const json::Value &b) { + auto *ao = a.getAsObject(); + auto *bo = b.getAsObject(); + StringRef aCat = ao->getString("category").value_or(""); + StringRef bCat = bo->getString("category").value_or(""); + auto catRank = [](StringRef c) -> int { + if (c == "free_win") + return 0; + if (c == "accuracy_for_speed") + return 1; + if (c == "speed_for_accuracy") + return 2; + return 3; + }; + int ar = catRank(aCat), br = catRank(bCat); + if (ar != br) + return ar < br; + double ae = ao->getNumber("efficiency").value_or(0); + double be = bo->getNumber("efficiency").value_or(0); + return ae > be; // higher efficiency first + }); + + json::Array sortedRewrites; + for (auto &v : rewriteVec) + sortedRewrites.push_back(std::move(v)); + + std::string rewritesFile = + (Twine(FPOptReportPath) + "/" + funcName + "_rewrites.json").str(); + raw_fd_ostream rwOut(rewritesFile, EC, 
sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error writing rewrites report: " << EC.message() << "\n"; + } else { + json::Object root; + root["function"] = funcName.str(); + root["total_rewrites"] = static_cast(sortedRewrites.size()); + root["rewrites"] = std::move(sortedRewrites); + rwOut << formatv("{0:2}", json::Value(std::move(root))) << "\n"; + llvm::errs() << "Poseidon curated rewrites written to " << rewritesFile + << "\n"; + } + } +} + +// Given the cost budget `FPOptComputationCostBudget`, we want to minimize the +// accuracy cost of the rewritten expressions. +bool accuracyGreedySolver( + SmallVector &COs, + SmallVector &CSs, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + bool changed = false; + llvm::errs() << "Starting accuracy greedy solver with computation budget: " + << FPOptComputationCostBudget << "\n"; + InstructionCost totalComputationCost = 0; + + SmallVector aoIndices; + for (size_t i = 0; i < COs.size(); ++i) { + aoIndices.push_back(i); + } + std::mt19937 g(FPOptRandomSeed); + std::shuffle(aoIndices.begin(), aoIndices.end(), g); + + for (size_t idx : aoIndices) { + auto &CO = COs[idx]; + int bestCandidateIndex = -1; + double bestAccuracyCost = std::numeric_limits::infinity(); + InstructionCost bestCandidateComputationCost; + + for (const auto &candidate : enumerate(CO.candidates)) { + size_t i = candidate.index(); + auto candCompCost = CO.getCompCostDelta(i); + auto candAccCost = CO.getAccCostDelta(i); + // llvm::errs() << "CO Candidate " << i << " for " << CO.expr + // << " has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << "\n"; + + if (totalComputationCost + candCompCost <= FPOptComputationCostBudget) { + if (candAccCost < bestAccuracyCost) { + // llvm::errs() << "CO Candidate " << i << " selected!\n"; + bestCandidateIndex = i; + bestAccuracyCost = candAccCost; + bestCandidateComputationCost = candCompCost; + } + } + } + + if (bestCandidateIndex != -1) { + 
CO.apply(bestCandidateIndex, valueToNodeMap, symbolToValueMap); + changed = true; + totalComputationCost += bestCandidateComputationCost; + if (FPOptPrint) { + llvm::errs() << "Greedy solver selected candidate " + << bestCandidateIndex << " for " << CO.expr + << " with accuracy cost: " << bestAccuracyCost + << " and computation cost: " + << bestCandidateComputationCost << "\n"; + } + } + } + + SmallVector accIndices; + for (size_t i = 0; i < CSs.size(); ++i) { + accIndices.push_back(i); + } + std::shuffle(accIndices.begin(), accIndices.end(), g); + + for (size_t idx : accIndices) { + auto &CS = CSs[idx]; + int bestCandidateIndex = -1; + double bestAccuracyCost = std::numeric_limits::infinity(); + InstructionCost bestCandidateComputationCost; + + for (const auto &candidate : enumerate(CS.candidates)) { + size_t i = candidate.index(); + auto candCompCost = CS.getCompCostDelta(i); + auto candAccCost = CS.getAccCostDelta(i); + // llvm::errs() << "CS Candidate " << i << " (" << candidate.value().desc + // << ") has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << "\n"; + + if (totalComputationCost + candCompCost <= FPOptComputationCostBudget) { + if (candAccCost < bestAccuracyCost) { + // llvm::errs() << "CS Candidate " << i << " selected!\n"; + bestCandidateIndex = i; + bestAccuracyCost = candAccCost; + bestCandidateComputationCost = candCompCost; + } + } + } + + if (bestCandidateIndex != -1) { + CS.apply(bestCandidateIndex); + changed = true; + totalComputationCost += bestCandidateComputationCost; + if (FPOptPrint) { + llvm::errs() << "Greedy solver selected candidate " + << bestCandidateIndex << " for " + << CS.candidates[bestCandidateIndex].desc + << " with accuracy cost: " << bestAccuracyCost + << " and computation cost: " + << bestCandidateComputationCost << "\n"; + } + } + } + + llvm::errs() << "Greedy solver finished with total computation cost: " + << totalComputationCost + << "; total allowance: " << 
FPOptComputationCostBudget << "\n"; + + return changed; +} + +bool accuracyDPSolver( + Function &F, const TargetTransformInfo &TTI, + SmallVector &COs, + SmallVector &CSs, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + double errorTol) { + bool changed = false; + llvm::errs() << "Starting accuracy DP solver with computation budget: " + << FPOptComputationCostBudget << "\n"; + if (errorTol > 0.0) { + llvm::errs() << "Absolute error tolerance: " << errorTol << "\n"; + } + + using CostMap = std::map; + using SolutionMap = std::map>; + + CostMap costToAccuracyMap; + SolutionMap costToSolutionMap; + CostMap newCostToAccuracyMap; + SolutionMap newCostToSolutionMap; + CostMap prunedCostToAccuracyMap; + SolutionMap prunedCostToSolutionMap; + + std::string cacheFilePath = FPOptCachePath + "/table.json"; + bool loadedFromCache = false; + + if (llvm::sys::fs::exists(cacheFilePath)) { + llvm::errs() << "Cache file found. Loading DP tables from cache.\n"; + loadedFromCache = true; + + llvm::ErrorOr> fileOrErr = + llvm::MemoryBuffer::getFile(cacheFilePath); + if (std::error_code ec = fileOrErr.getError()) { + llvm::errs() << "Error reading cache file: " << ec.message() << "\n"; + return changed; + } + llvm::StringRef buffer = fileOrErr.get()->getBuffer(); + llvm::Expected jsonOrErr = llvm::json::parse(buffer); + if (!jsonOrErr) { + llvm::errs() << "Error parsing JSON from cache file: " + << llvm::toString(jsonOrErr.takeError()) << "\n"; + return changed; + } + + llvm::json::Object *jsonObj = jsonOrErr->getAsObject(); + if (!jsonObj) { + llvm::errs() << "Invalid JSON format in cache file.\n"; + return changed; + } + + if (llvm::json::Object *costAccMap = + jsonObj->getObject("costToAccuracyMap")) { + for (auto &pair : *costAccMap) { + InstructionCost compCost(std::stoll(pair.first.str())); + double accCost = pair.second.getAsNumber().value(); + costToAccuracyMap[compCost] = accCost; + } + } else { + llvm_unreachable("Invalid costToAccuracyMap in 
cache file."); + } + + if (llvm::json::Object *costSolMap = + jsonObj->getObject("costToSolutionMap")) { + for (auto &pair : *costSolMap) { + InstructionCost compCost(std::stoll(pair.first.str())); + SmallVector solutionSteps; + + llvm::json::Array *stepsArray = pair.second.getAsArray(); + if (!stepsArray) { + llvm::errs() << "Invalid steps array in cache file.\n"; + return changed; + } + + for (llvm::json::Value &stepVal : *stepsArray) { + llvm::json::Object *stepObj = stepVal.getAsObject(); + if (!stepObj) { + llvm_unreachable("Invalid step object in cache file."); + } + + StringRef itemType = stepObj->getString("itemType").value(); + size_t candidateIndex = stepObj->getInteger("candidateIndex").value(); + size_t itemIndex = stepObj->getInteger("itemIndex").value(); + + if (itemType == "CO") { + if (itemIndex >= COs.size()) { + llvm_unreachable("Invalid CandidateOutput index in cache file."); + } + solutionSteps.emplace_back(&COs[itemIndex], candidateIndex); + } else if (itemType == "CS") { + if (itemIndex >= CSs.size()) { + llvm_unreachable( + "Invalid CandidateSubgraph index in cache file."); + } + solutionSteps.emplace_back(&CSs[itemIndex], candidateIndex); + } else { + llvm_unreachable("Invalid itemType in cache file."); + } + } + + costToSolutionMap[compCost] = solutionSteps; + } + } else { + llvm::errs() << "costToSolutionMap not found in cache file.\n"; + return changed; + } + + llvm::errs() << "Loaded DP tables from cache.\n"; + + } else { + llvm::errs() << "Cache file not found. Proceeding to solve DP.\n"; + + costToAccuracyMap[0] = 0; + costToSolutionMap[0] = {}; + + std::unordered_map aoPtrToIndex; + for (size_t i = 0; i < COs.size(); ++i) { + aoPtrToIndex[&COs[i]] = i; + } + std::unordered_map accPtrToIndex; + for (size_t i = 0; i < CSs.size(); ++i) { + accPtrToIndex[&CSs[i]] = i; + } + + int COCounter = 0; + + for (auto &CO : COs) { + // It is possible to apply zero candidate for an CO. 
+ // When no candidate is applied, the resulting accuracy cost + // and solution steps remain the same. + newCostToAccuracyMap = costToAccuracyMap; + newCostToSolutionMap = costToSolutionMap; + + for (const auto &pair : costToAccuracyMap) { + InstructionCost currCompCost = pair.first; + double currAccCost = pair.second; + + for (const auto &candidate : enumerate(CO.candidates)) { + size_t i = candidate.index(); + auto candCompCost = CO.getCompCostDelta(i); + auto candAccCost = CO.getAccCostDelta(i); + + // Don't apply a candidate that strictly makes things worse + if (candCompCost >= 0 && candAccCost >= 0.) { + continue; + } + + InstructionCost newCompCost = currCompCost + candCompCost; + double newAccCost = currAccCost + candAccCost; + + // if (FPOptPrint) + // llvm::errs() << "CO candidate " << i + // << " has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << + // "\n"; + + if (newCostToAccuracyMap.find(newCompCost) == + newCostToAccuracyMap.end() || + newCostToAccuracyMap[newCompCost] > newAccCost) { + newCostToAccuracyMap[newCompCost] = newAccCost; + newCostToSolutionMap[newCompCost] = costToSolutionMap[currCompCost]; + newCostToSolutionMap[newCompCost].emplace_back(&CO, i); + // if (FPOptPrint) + // llvm::errs() << "Updating accuracy map (CO candidate " << i + // << "): computation cost " << newCompCost + // << " -> accuracy cost " << newAccCost << "\n"; + } + } + } + + // TODO: Do not prune CO parts of the DP table since COs influence CSs + if (!FPOptEarlyPrune) { + costToAccuracyMap = newCostToAccuracyMap; + costToSolutionMap = newCostToSolutionMap; + + llvm::errs() << "##### Finished processing " << ++COCounter << " of " + << COs.size() << " COs #####\n"; + llvm::errs() << "Current DP table sizes: " << costToAccuracyMap.size() + << "\n"; + continue; + } + + for (const auto &l : newCostToAccuracyMap) { + InstructionCost currCompCost = l.first; + double currAccCost = l.second; + + bool dominated = false; + for (const auto &r 
: newCostToAccuracyMap) { + InstructionCost otherCompCost = r.first; + double otherAccCost = r.second; + + if (currCompCost - otherCompCost > + std::fabs(FPOptCostDominanceThreshold * + GET_INSTRUCTION_COST(otherCompCost)) && + currAccCost - otherAccCost >= + std::fabs(FPOptAccuracyDominanceThreshold * otherAccCost)) { + // if (FPOptPrint) + // llvm::errs() << "CO candidate with computation cost: " + // << currCompCost + // << " and accuracy cost: " << currAccCost + // << " is dominated by candidate with computation + // cost:" + // << otherCompCost + // << " and accuracy cost: " << otherAccCost << "\n"; + dominated = true; + break; + } + } + + if (!dominated) { + prunedCostToAccuracyMap[currCompCost] = currAccCost; + prunedCostToSolutionMap[currCompCost] = + newCostToSolutionMap[currCompCost]; + } + } + + costToAccuracyMap = prunedCostToAccuracyMap; + costToSolutionMap = prunedCostToSolutionMap; + prunedCostToAccuracyMap.clear(); + prunedCostToSolutionMap.clear(); + + llvm::errs() << "##### Finished processing " << ++COCounter << " of " + << COs.size() << " COs #####\n"; + llvm::errs() << "Current DP table sizes: " << costToAccuracyMap.size() + << "\n"; + } + + int CSCounter = 0; + + for (auto &CS : CSs) { + // It is possible to apply zero candidate for an CS. + // When no candidate is applied, the resulting accuracy cost + // and solution steps remain the same. 
+ newCostToAccuracyMap = costToAccuracyMap; + newCostToSolutionMap = costToSolutionMap; + + for (const auto &pair : costToAccuracyMap) { + InstructionCost currCompCost = pair.first; + double currAccCost = pair.second; + + for (const auto &candidate : enumerate(CS.candidates)) { + size_t i = candidate.index(); + auto candCompCost = + CS.getAdjustedCompCostDelta(i, costToSolutionMap[currCompCost]); + auto candAccCost = + CS.getAdjustedAccCostDelta(i, costToSolutionMap[currCompCost], + valueToNodeMap, symbolToValueMap); + + // Don't ever try to apply a strictly useless candidate + if (candCompCost >= 0 && candAccCost >= 0.) { + continue; + } + + InstructionCost newCompCost = currCompCost + candCompCost; + double newAccCost = currAccCost + candAccCost; + + // if (FPOptPrint) + // llvm::errs() << "CS candidate " << i << " (" + // << candidate.value().desc + // << ") has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << + // "\n"; + + if (newCostToAccuracyMap.find(newCompCost) == + newCostToAccuracyMap.end() || + newCostToAccuracyMap[newCompCost] > newAccCost) { + newCostToAccuracyMap[newCompCost] = newAccCost; + newCostToSolutionMap[newCompCost] = costToSolutionMap[currCompCost]; + newCostToSolutionMap[newCompCost].emplace_back(&CS, i); + // if (FPOptPrint) { + // llvm::errs() << "CS candidate " << i << " (" + // << candidate.value().desc + // << ") added; has accuracy cost: " << candAccCost + // << " and computation cost: " << candCompCost << + // "\n"; + // llvm::errs() << "Updating accuracy map (CS candidate " << i + // << "): computation cost " << newCompCost + // << " -> accuracy cost " << newAccCost << "\n"; + // } + } + } + } + + for (const auto &l : newCostToAccuracyMap) { + InstructionCost currCompCost = l.first; + double currAccCost = l.second; + + bool dominated = false; + for (const auto &r : newCostToAccuracyMap) { + InstructionCost otherCompCost = r.first; + double otherAccCost = r.second; + + if (currCompCost - 
otherCompCost > + std::fabs(FPOptCostDominanceThreshold * + GET_INSTRUCTION_COST(otherCompCost)) && + currAccCost - otherAccCost >= + std::fabs(FPOptAccuracyDominanceThreshold * otherAccCost)) { + // if (FPOptPrint) + // llvm::errs() << "CS candidate with computation cost: " + // << currCompCost + // << " and accuracy cost: " << currAccCost + // << " is dominated by candidate with computation + // cost:" + // << otherCompCost + // << " and accuracy cost: " << otherAccCost << "\n"; + dominated = true; + break; + } + } + + if (!dominated) { + prunedCostToAccuracyMap[currCompCost] = currAccCost; + prunedCostToSolutionMap[currCompCost] = + newCostToSolutionMap[currCompCost]; + } + } + + costToAccuracyMap = prunedCostToAccuracyMap; + costToSolutionMap = prunedCostToSolutionMap; + prunedCostToAccuracyMap.clear(); + prunedCostToSolutionMap.clear(); + + llvm::errs() << "##### Finished processing " << ++CSCounter << " of " + << CSs.size() << " CSs #####\n"; + llvm::errs() << "Current DP table sizes: " << costToAccuracyMap.size() + << "\n"; + } + + json::Object jsonObj; + + json::Object costAccMap; + for (const auto &pair : costToAccuracyMap) { + costAccMap[std::to_string(GET_INSTRUCTION_COST(pair.first))] = + pair.second; + } + jsonObj["costToAccuracyMap"] = std::move(costAccMap); + + json::Object costSolMap; + for (const auto &pair : costToSolutionMap) { + json::Array stepsArray; + for (const auto &step : pair.second) { + json::Object stepObj; + stepObj["candidateIndex"] = static_cast(step.candidateIndex); + + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + stepObj["itemType"] = "CO"; + size_t index = aoPtrToIndex[item]; + stepObj["itemIndex"] = static_cast(index); + } else if constexpr (std::is_same_v) { + stepObj["itemType"] = "CS"; + size_t index = accPtrToIndex[item]; + stepObj["itemIndex"] = static_cast(index); + } + }, + step.item); + stepsArray.push_back(std::move(stepObj)); + } + 
costSolMap[std::to_string(GET_INSTRUCTION_COST(pair.first))] = + std::move(stepsArray); + } + jsonObj["costToSolutionMap"] = std::move(costSolMap); + + if (!FPOptRefineDPTable) { + std::error_code EC; + llvm::raw_fd_ostream cacheFile(cacheFilePath, EC, llvm::sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error writing cache file: " << EC.message() << "\n"; + } else { + cacheFile << llvm::formatv("{0:2}", + llvm::json::Value(std::move(jsonObj))) + << "\n"; + cacheFile.close(); + llvm::errs() << "DP tables cached to file.\n"; + } + } else if (!COs.empty() || !CSs.empty()) { + ValueToValueMapTy BaseVMap; + Function *BaseClone = CloneFunction(&F, BaseVMap); + runPoseidonFunctionSimplify(*BaseClone, OptimizationLevel::O3); + InstructionCost BaseCost = getCompCost(BaseClone, TTI); + BaseClone->eraseFromParent(); + + using CostMap = std::map; + using SolutionMap = std::map>; + CostMap refinedCostToAccuracyMap; + SolutionMap refinedCostToSolutionMap; + + for (const auto &pair : costToSolutionMap) { + const SmallVector &steps = pair.second; + + ValueToValueMapTy VMap; + Function *FClone = CloneFunction(&F, VMap); + + for (const auto &step : steps) { + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + auto &CO = *item; + Instruction *oldI = cast(CO.oldOutput); + Instruction *clonedOldI = cast(VMap[oldI]); + IRBuilder<> builder(clonedOldI->getParent(), + ++BasicBlock::iterator(clonedOldI)); + builder.setFastMathFlags(clonedOldI->getFastMathFlags()); + auto parsedNode = + parseHerbieExpr(CO.candidates[step.candidateIndex].expr, + valueToNodeMap, symbolToValueMap); + Value *newVal = parsedNode->getLLValue(builder, &VMap); + clonedOldI->replaceAllUsesWith(newVal); + } else if constexpr (std::is_same_v) { + auto &CS = *item; + PTCandidate &pt = CS.candidates[step.candidateIndex]; + pt.apply(*CS.subgraph, &VMap); + } else { + llvm_unreachable("accuracyDPSolver refine: unexpected step"); + } + }, + step.item); + } + + 
runPoseidonFunctionSimplify(*FClone, OptimizationLevel::O3); + InstructionCost NewTotal = getCompCost(FClone, TTI); + InstructionCost NewDelta = NewTotal - BaseCost; + FClone->eraseFromParent(); + + double accCost = costToAccuracyMap[pair.first]; + auto it = refinedCostToAccuracyMap.find(NewDelta); + if (it == refinedCostToAccuracyMap.end() || it->second > accCost) { + refinedCostToAccuracyMap[NewDelta] = accCost; + refinedCostToSolutionMap[NewDelta] = steps; + } + } + + // One-pass domination pruning + std::map refinedPrunedAcc; + std::map> refinedPrunedSol; + for (const auto &l : refinedCostToAccuracyMap) { + InstructionCost currCompCost = l.first; + double currAccCost = l.second; + bool dominated = false; + for (const auto &r : refinedCostToAccuracyMap) { + InstructionCost otherCompCost = r.first; + double otherAccCost = r.second; + if (currCompCost - otherCompCost > + std::fabs(FPOptCostDominanceThreshold * + GET_INSTRUCTION_COST(otherCompCost)) && + currAccCost - otherAccCost >= + std::fabs(FPOptAccuracyDominanceThreshold * otherAccCost)) { + dominated = true; + break; + } + } + if (!dominated) { + refinedPrunedAcc[currCompCost] = currAccCost; + refinedPrunedSol[currCompCost] = + refinedCostToSolutionMap[currCompCost]; + } + } + + costToAccuracyMap = std::move(refinedPrunedAcc); + costToSolutionMap = std::move(refinedPrunedSol); + + json::Object jsonObj; + + json::Object costAccMap; + for (const auto &p : costToAccuracyMap) { + costAccMap[std::to_string(GET_INSTRUCTION_COST(p.first))] = p.second; + } + jsonObj["costToAccuracyMap"] = std::move(costAccMap); + + json::Object costSolMap; + std::unordered_map aoPtrToIndex; + for (size_t i = 0; i < COs.size(); ++i) + aoPtrToIndex[&COs[i]] = i; + std::unordered_map accPtrToIndex; + for (size_t i = 0; i < CSs.size(); ++i) + accPtrToIndex[&CSs[i]] = i; + + for (const auto &p : costToSolutionMap) { + json::Array stepsArray; + for (const auto &step : p.second) { + json::Object stepObj; + stepObj["candidateIndex"] = 
static_cast(step.candidateIndex); + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + stepObj["itemType"] = "CO"; + size_t index = aoPtrToIndex[item]; + stepObj["itemIndex"] = static_cast(index); + } else if constexpr (std::is_same_v) { + stepObj["itemType"] = "CS"; + size_t index = accPtrToIndex[item]; + stepObj["itemIndex"] = static_cast(index); + } + }, + step.item); + stepsArray.push_back(std::move(stepObj)); + } + costSolMap[std::to_string(GET_INSTRUCTION_COST(p.first))] = + std::move(stepsArray); + } + jsonObj["costToSolutionMap"] = std::move(costSolMap); + + std::error_code EC; + llvm::raw_fd_ostream cacheFile(cacheFilePath, EC, llvm::sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error writing refined cache file: " << EC.message() + << "\n"; + } else { + cacheFile << llvm::formatv("{0:2}", + llvm::json::Value(std::move(jsonObj))) + << "\n"; + cacheFile.close(); + llvm::errs() << "Refined DP tables cached to file.\n"; + } + } + } + + if (FPOptPrint) { + if (FPOptShowTable) { + llvm::errs() << "\n*** DP Table ***\n"; + for (const auto &pair : costToAccuracyMap) { + if (!FPOptShowTableCosts.empty()) { + bool shouldPrint = false; + for (auto selectedCost : FPOptShowTableCosts) + if (pair.first == selectedCost) { + shouldPrint = true; + break; + } + if (!shouldPrint) + continue; + } + + llvm::errs() << "Computation cost: " << pair.first + << ", Accuracy cost: " << pair.second << "\n"; + llvm::errs() << "\tSolution steps: \n"; + for (const auto &step : costToSolutionMap[pair.first]) { + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + llvm::errs() + << "\t\t" << item->expr << " --(" << step.candidateIndex + << ")-> " << item->candidates[step.candidateIndex].expr + << "\n"; + } else if constexpr (std::is_same_v) { + llvm::errs() << "\t\tCS: " + << item->candidates[step.candidateIndex].desc + << " (#" << step.candidateIndex << ")\n"; + if (FPOptShowPTDetails) { + auto 
&candidate = item->candidates[step.candidateIndex]; + for (const auto &change : candidate.changes) { + llvm::errs() + << "\t\t\tChanging from " + << getPrecisionChangeTypeString(change.oldType) + << " to " + << getPrecisionChangeTypeString(change.newType) + << ":\n"; + for (auto *val : change.nodes) { + llvm::errs() << "\t\t\t\t" << *val->value << "\n"; + } + } + } + } else { + llvm_unreachable( + "accuracyDPSolver: Unexpected type of solution step"); + } + }, + step.item); + } + } + llvm::errs() << "*** End of DP Table ***\n\n"; + } + } + + std::string budgetsFile = FPOptCachePath + "/budgets.txt"; + if (!llvm::sys::fs::exists(budgetsFile)) { + std::string budgetsStr; + for (const auto &pair : costToAccuracyMap) { + budgetsStr += std::to_string(GET_INSTRUCTION_COST(pair.first)) + ","; + } + + if (!budgetsStr.empty()) + budgetsStr.pop_back(); + + std::error_code EC; + llvm::raw_fd_ostream Out(budgetsFile, EC, llvm::sys::fs::OF_Text); + if (EC) { + llvm::errs() << "Error opening " << budgetsFile << ": " << EC.message() + << "\n"; + } else { + Out << budgetsStr; + } + } + + if (!loadedFromCache) + emitPoseidonReport(F.getName(), costToAccuracyMap, costToSolutionMap, COs, + CSs); + + llvm::errs() << "Critical computation cost range: [" + << costToAccuracyMap.begin()->first << ", " + << costToAccuracyMap.rbegin()->first << "]\n"; + + llvm::errs() << "DP table contains " << costToAccuracyMap.size() + << " entries.\n"; + + double totalCandidateCompositions = 1.0; + for (const auto &CO : COs) { + // +1 for the "do nothing" possibility + totalCandidateCompositions *= CO.candidates.size() + 1; + } + for (const auto &CS : CSs) { + totalCandidateCompositions *= CS.candidates.size() + 1; + } + llvm::errs() << "Total candidate compositions: " << totalCandidateCompositions + << "\n"; + + if (costToSolutionMap.find(0) != costToSolutionMap.end()) { + if (costToSolutionMap[0].empty()) { + llvm::errs() << "WARNING: No-op solution (utilized cost budget = 0) is " + "considered 
Pareto-optimal.\n"; + } + } + + double minAccCost = std::numeric_limits::infinity(); + InstructionCost bestCompCost = 0; + + if (errorTol > 0.0) { + InstructionCost minCompCost = std::numeric_limits::max(); + bool foundSolution = false; + + for (const auto &pair : costToAccuracyMap) { + InstructionCost compCost = pair.first; + double accCost = pair.second; + + if (accCost <= errorTol) { + if (compCost < minCompCost) { + minCompCost = compCost; + minAccCost = accCost; + bestCompCost = compCost; + foundSolution = true; + } + } + } + + if (!foundSolution) { + llvm::errs() << "No solution found that meets accuracy tolerance " + << errorTol << "!\n"; + llvm::errs() << "Best achievable accuracy in DP table: " + << costToAccuracyMap.begin()->second << "\n"; + return changed; + } + + llvm::errs() << "Found solution meeting accuracy tolerance " << errorTol + << "\n"; + llvm::errs() << "Accuracy cost achieved: " << minAccCost << "\n"; + llvm::errs() << "Computation cost required: " << bestCompCost << "\n"; + } else { + for (const auto &pair : costToAccuracyMap) { + InstructionCost compCost = pair.first; + double accCost = pair.second; + + if (compCost <= FPOptComputationCostBudget && accCost < minAccCost) { + minAccCost = accCost; + bestCompCost = compCost; + } + } + + if (minAccCost == std::numeric_limits::infinity()) { + llvm::errs() << "No solution found within the computation cost budget!\n"; + return changed; + } + + llvm::errs() << "Minimum accuracy cost within budget: " << minAccCost + << "\n"; + llvm::errs() << "Computation cost budget used: " << bestCompCost << "\n"; + } + + assert(costToSolutionMap.find(bestCompCost) != costToSolutionMap.end() && + "FPOpt DP solver: expected a solution!"); + + llvm::errs() << "\n!!! DP solver: Applying solution ... 
!!!\n"; + for (const auto &solution : costToSolutionMap[bestCompCost]) { + std::visit( + [&](auto *item) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + llvm::errs() << "Applying solution for " << item->expr << " --(" + << solution.candidateIndex << ")-> " + << item->candidates[solution.candidateIndex].expr + << "\n"; + item->apply(solution.candidateIndex, valueToNodeMap, + symbolToValueMap); + } else if constexpr (std::is_same_v) { + llvm::errs() << "Applying solution for CS: " + << item->candidates[solution.candidateIndex].desc + << " (#" << solution.candidateIndex << ")\n"; + item->apply(solution.candidateIndex); + } else { + llvm_unreachable( + "accuracyDPSolver: Unexpected type of solution step"); + } + }, + solution.item); + changed = true; + } + llvm::errs() << "!!! DP Solver: Solution applied !!!\n\n"; + + return changed; +} \ No newline at end of file diff --git a/enzyme/Enzyme/Poseidon/PoseidonSolvers.h b/enzyme/Enzyme/Poseidon/PoseidonSolvers.h new file mode 100644 index 000000000000..846c89d13b37 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonSolvers.h @@ -0,0 +1,55 @@ +//=- PoseidonSolvers.h - Solver utilities for Poseidon --------------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares solver-related utilities for the Poseidon optimization +// pass. 
+// +//===----------------------------------------------------------------------===// + +#ifndef ENZYME_POSEIDON_SOLVERS_H +#define ENZYME_POSEIDON_SOLVERS_H + +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/CommandLine.h" + +#include +#include + +#include "PoseidonTypes.h" + +using namespace llvm; + +extern "C" { +extern llvm::cl::opt FPOptSolverType; +extern llvm::cl::opt FPOptComputationCostBudget; +extern llvm::cl::opt FPOptShowTable; +extern llvm::cl::list FPOptShowTableCosts; +extern llvm::cl::opt FPOptEarlyPrune; +extern llvm::cl::opt FPOptCostDominanceThreshold; +extern llvm::cl::opt FPOptAccuracyDominanceThreshold; +extern llvm::cl::opt FPOptApplyRewrites; +} + +bool accuracyGreedySolver( + SmallVector &COs, + SmallVector &CSs, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap); + +bool accuracyDPSolver( + Function &F, const TargetTransformInfo &TTI, + SmallVector &COs, + SmallVector &CSs, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap, + double errorTol = 0.0); + +#endif // ENZYME_POSEIDON_SOLVERS_H \ No newline at end of file diff --git a/enzyme/Enzyme/Poseidon/PoseidonTypes.cpp b/enzyme/Enzyme/Poseidon/PoseidonTypes.cpp new file mode 100644 index 000000000000..93604b9e9d39 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonTypes.cpp @@ -0,0 +1,960 @@ +//=- PoseidonTypes.cpp - AST node implementations for Poseidon ------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AST node classes for representing floating-point +// expressions in the Poseidon optimization pass. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +#include +#include +#include + +#include "../Utils.h" +#include "Poseidon.h" +#include "PoseidonHerbieUtils.h" +#include "PoseidonPrecUtils.h" +#include "PoseidonTypes.h" +#include "PoseidonUtils.h" + +FPNode::NodeType FPNode::getType() const { return ntype; } + +void FPNode::addOperand(std::shared_ptr operand) { + operands.push_back(operand); +} + +bool FPNode::hasSymbol() const { + std::string msg = "Unexpected invocation of `hasSymbol` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); +} + +std::string FPNode::toFullExpression( + std::unordered_map> &valueToNodeMap, + const SetVector &subgraphInputs, unsigned depth) { + std::string msg = "Unexpected invocation of `toFullExpression` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); +} + +unsigned FPNode::getMPFRPrec() const { + if (dtype == "f16") + return 11; + if (dtype == "f32") + return 24; + if (dtype == "f64") + return 53; + std::string msg = + "getMPFRPrec: operator " + op + " has unknown dtype " + dtype; + llvm_unreachable(msg.c_str()); +} + +void FPNode::updateBounds(double lower, double upper) { + std::string msg = "Unexpected invocation of `updateBounds` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); +} + +double FPNode::getLowerBound() const { + std::string msg = "Unexpected invocation of `getLowerBound` on an " + "unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); +} + +double FPNode::getUpperBound() const { + std::string msg = "Unexpected invocation of `getUpperBound` on an " + 
"unmaterialized " + + op + " FPNode"; + llvm_unreachable(msg.c_str()); +} + +Value *FPNode::getLLValue(IRBuilder<> &builder, const ValueToValueMapTy *VMap) { + Module *M = builder.GetInsertBlock()->getModule(); + if (op == "if") { + Value *condValue = operands[0]->getLLValue(builder, VMap); + Value *trueValue = operands[1]->getLLValue(builder, VMap); + Value *falseValue = operands[2]->getLLValue(builder, VMap); + + return builder.CreateSelect(condValue, trueValue, falseValue, + "herbie.select"); + } + + SmallVector operandValues; + for (auto operand : operands) { + Value *val = operand->getLLValue(builder, VMap); + assert(val && "Operand produced a null value!"); + operandValues.push_back(val); + } + + static const std::unordered_map< + std::string, std::function &, Module *, + const SmallVectorImpl &)>> + opMap = { + {"neg", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &ops) + -> Value * { return b.CreateFNeg(ops[0], "herbie.neg"); }}, + {"+", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFAdd(ops[0], ops[1], "herbie.add"); + }}, + {"-", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFSub(ops[0], ops[1], "herbie.sub"); + }}, + {"*", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFMul(ops[0], ops[1], "herbie.mul"); + }}, + {"/", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFDiv(ops[0], ops[1], "herbie.div"); + }}, + {"fmin", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateBinaryIntrinsic(Intrinsic::minnum, ops[0], ops[1], + nullptr, "herbie.fmin"); + }}, + {"fmax", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateBinaryIntrinsic(Intrinsic::maxnum, ops[0], ops[1], + nullptr, "herbie.fmax"); + }}, + {"sin", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return 
b.CreateUnaryIntrinsic(Intrinsic::sin, ops[0], nullptr, + "herbie.sin"); + }}, + {"cos", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::cos, ops[0], nullptr, + "herbie.cos"); + }}, + {"tan", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateUnaryIntrinsic(Intrinsic::tan, ops[0], nullptr, + "herbie.tan"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "tan" : "tanf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee tanFunc = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(tanFunc, {ops[0]}, "herbie.tan"); +#endif + }}, + {"exp", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::exp, ops[0], nullptr, + "herbie.exp"); + }}, + {"expm1", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "expm1" : "expm1f"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.expm1"); + }}, + {"log", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::log, ops[0], nullptr, + "herbie.log"); + }}, + {"log1p", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? 
"log1p" : "log1pf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.log1p"); + }}, + {"sqrt", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::sqrt, ops[0], nullptr, + "herbie.sqrt"); + }}, + {"cbrt", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "cbrt" : "cbrtf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.cbrt"); + }}, + {"pow", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + // Use powi when possible + if (auto *CF = dyn_cast(ops[1])) { + double value = CF->getValueAPF().convertToDouble(); + if (value == std::floor(value) && value >= INT_MIN && + value <= INT_MAX) { + int exp = static_cast(value); + SmallVector overloadedTypes = { + ops[0]->getType(), Type::getInt32Ty(M->getContext())}; + Function *powiFunc = getIntrinsicDeclaration( + M, Intrinsic::powi, overloadedTypes); + Value *exponent = + ConstantInt::get(Type::getInt32Ty(M->getContext()), exp); + return b.CreateCall(powiFunc, {ops[0], exponent}, + "herbie.powi"); + } + } + + return b.CreateBinaryIntrinsic(Intrinsic::pow, ops[0], ops[1], + nullptr, "herbie.pow"); + }}, + {"fma", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateIntrinsic(Intrinsic::fma, {ops[0]->getType()}, + {ops[0], ops[1], ops[2]}, nullptr, + "herbie.fma"); + }}, + {"fabs", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::fabs, ops[0], nullptr, + "herbie.fabs"); + }}, + {"hypot", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName 
= Ty->isDoubleTy() ? "hypot" : "hypotf"; + FunctionType *FT = FunctionType::get(Ty, {Ty, Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0], ops[1]}, "herbie.hypot"); + }}, + {"asin", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateUnaryIntrinsic(Intrinsic::asin, ops[0], nullptr, + "herbie.asin"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "asin" : "asinf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.asin"); +#endif + }}, + {"acos", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateUnaryIntrinsic(Intrinsic::acos, ops[0], nullptr, + "herbie.acos"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "acos" : "acosf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.acos"); +#endif + }}, + {"atan", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateUnaryIntrinsic(Intrinsic::atan, ops[0], nullptr, + "herbie.atan"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "atan" : "atanf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.atan"); +#endif + }}, + {"atan2", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateBinaryIntrinsic(Intrinsic::atan2, ops[0], ops[1], + nullptr, "herbie.atan2"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? 
"atan2" : "atan2f"; + FunctionType *FT = FunctionType::get(Ty, {Ty, Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0], ops[1]}, "herbie.atan2"); +#endif + }}, + {"sinh", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateUnaryIntrinsic(Intrinsic::sinh, ops[0], nullptr, + "herbie.sinh"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "sinh" : "sinhf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.sinh"); +#endif + }}, + {"cosh", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateUnaryIntrinsic(Intrinsic::cosh, ops[0], nullptr, + "herbie.cosh"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "cosh" : "coshf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.cosh"); +#endif + }}, + {"tanh", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { +#if LLVM_VERSION_MAJOR > 16 + return b.CreateUnaryIntrinsic(Intrinsic::tanh, ops[0], nullptr, + "herbie.tanh"); +#else + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? 
"tanh" : "tanhf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.tanh"); +#endif + }}, + {"copysign", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateBinaryIntrinsic(Intrinsic::copysign, ops[0], ops[1], + nullptr, "herbie.copysign"); + }}, + {"rem", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFRem(ops[0], ops[1], "herbie.rem"); + }}, + {"ceil", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::ceil, ops[0], nullptr, + "herbie.ceil"); + }}, + {"floor", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::floor, ops[0], nullptr, + "herbie.floor"); + }}, + {"exp2", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::exp2, ops[0], nullptr, + "herbie.exp2"); + }}, + {"log10", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::log10, ops[0], nullptr, + "herbie.log10"); + }}, + {"log2", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::log2, ops[0], nullptr, + "herbie.log2"); + }}, + {"rint", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::rint, ops[0], nullptr, + "herbie.rint"); + }}, + {"round", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::round, ops[0], nullptr, + "herbie.round"); + }}, + {"trunc", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateUnaryIntrinsic(Intrinsic::trunc, ops[0], nullptr, + "herbie.trunc"); + }}, + {"fdim", + [](IRBuilder<> &b, Module 
*M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "fdim" : "fdimf"; + FunctionType *FT = FunctionType::get(Ty, {Ty, Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0], ops[1]}, "herbie.fdim"); + }}, + {"fmod", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "fmod" : "fmodf"; + FunctionType *FT = FunctionType::get(Ty, {Ty, Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0], ops[1]}, "herbie.fmod"); + }}, + {"remainder", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = + Ty->isDoubleTy() ? "remainder" : "remainderf"; + FunctionType *FT = FunctionType::get(Ty, {Ty, Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0], ops[1]}, "herbie.remainder"); + }}, + {"erf", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "erf" : "erff"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.erf"); + }}, + {"lgamma", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "lgamma" : "lgammaf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.lgamma"); + }}, + {"tgamma", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? 
"tgamma" : "tgammaf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.tgamma"); + }}, + {"asinh", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "asinh" : "asinhf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.asinh"); + }}, + {"acosh", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? "acosh" : "acoshf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.acosh"); + }}, + {"atanh", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + Type *Ty = ops[0]->getType(); + std::string funcName = Ty->isDoubleTy() ? 
"atanh" : "atanhf"; + FunctionType *FT = FunctionType::get(Ty, {Ty}, false); + FunctionCallee f = M->getOrInsertFunction(funcName, FT); + return b.CreateCall(f, {ops[0]}, "herbie.atanh"); + }}, + {"==", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFCmpOEQ(ops[0], ops[1], "herbie.eq"); + }}, + {"!=", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFCmpONE(ops[0], ops[1], "herbie.ne"); + }}, + {"<", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFCmpOLT(ops[0], ops[1], "herbie.lt"); + }}, + {">", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFCmpOGT(ops[0], ops[1], "herbie.gt"); + }}, + {"<=", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFCmpOLE(ops[0], ops[1], "herbie.le"); + }}, + {">=", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateFCmpOGE(ops[0], ops[1], "herbie.ge"); + }}, + {"and", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &ops) -> Value * { + return b.CreateAnd(ops[0], ops[1], "herbie.and"); + }}, + {"or", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &ops) + -> Value * { return b.CreateOr(ops[0], ops[1], "herbie.or"); }}, + {"not", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &ops) + -> Value * { return b.CreateNot(ops[0], "herbie.not"); }}, + {"TRUE", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &) + -> Value * { return ConstantInt::getTrue(b.getContext()); }}, + {"FALSE", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &) + -> Value * { return ConstantInt::getFalse(b.getContext()); }}, + {"PI", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &) + -> Value * { return ConstantFP::get(b.getDoubleTy(), M_PI); }}, + {"E", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &) + -> Value * { return ConstantFP::get(b.getDoubleTy(), M_E); }}, 
+ {"INFINITY", + [](IRBuilder<> &b, Module *M, + const SmallVectorImpl &) -> Value * { + return ConstantFP::getInfinity(b.getDoubleTy(), false); + }}, + {"NaN", + [](IRBuilder<> &b, Module *M, const SmallVectorImpl &) + -> Value * { return ConstantFP::getNaN(b.getDoubleTy()); }}, + }; + + auto it = opMap.find(op); + if (it != opMap.end()) + return it->second(builder, M, operandValues); + else { + std::string msg = "FPNode getLLValue: Unexpected operator " + op; + llvm_unreachable(msg.c_str()); + } +} + +bool FPLLValue::hasSymbol() const { return !symbol.empty(); } + +std::string FPLLValue::toFullExpression( + std::unordered_map> &valueToNodeMap, + const SetVector &subgraphInputs, unsigned depth) { + // Check if this value is an input to the current subgraph + if (subgraphInputs.contains(value)) { + assert(hasSymbol() && "FPLLValue has no symbol!"); + return symbol; + } else { + assert(!operands.empty() && "FPNode has no operands!"); + + if (depth > FPOptMaxExprDepth) { + std::string msg = "Expression depth exceeded maximum allowed depth of " + + std::to_string(FPOptMaxExprDepth) + " for " + op + + "; consider disabling loop unrolling"; + + llvm_unreachable(msg.c_str()); + } + + std::string expr = "(" + (op == "neg" ? 
"-" : op); + for (auto operand : operands) { + expr += " " + operand->toFullExpression(valueToNodeMap, subgraphInputs, + depth + 1); + } + expr += ")"; + return expr; + } +} + +void FPLLValue::updateBounds(double lower, double upper) { + lb = std::min(lb, lower); + ub = std::max(ub, upper); + if (FPOptPrint) + llvm::errs() << "Updated bounds for " << *value << ": [" << lb << ", " << ub + << "]\n"; +} + +double FPLLValue::getLowerBound() const { return lb; } +double FPLLValue::getUpperBound() const { return ub; } + +Value *FPLLValue::getLLValue(IRBuilder<> &builder, + const ValueToValueMapTy *VMap) { + if (VMap) { + assert(VMap->count(value) && "FPLLValue not found in passed-in VMap!"); + return VMap->lookup(value); + } + return value; +} + +bool FPLLValue::classof(const FPNode *N) { + return N->getType() == NodeType::LLValue; +} + +std::string FPConst::toFullExpression( + std::unordered_map> &valueToNodeMap, + const SetVector &subgraphInputs, unsigned depth) { + return strValue; +} + +bool FPConst::hasSymbol() const { + std::string msg = "Unexpected invocation of `hasSymbol` on an FPConst"; + llvm_unreachable(msg.c_str()); +} + +void FPConst::updateBounds(double lower, double upper) { return; } + +double FPConst::getLowerBound() const { + if (strValue == "+inf.0") { + return std::numeric_limits::infinity(); + } else if (strValue == "-inf.0") { + return -std::numeric_limits::infinity(); + } + + double constantValue; + size_t div = strValue.find('/'); + + if (div != std::string::npos) { + std::string numerator = strValue.substr(0, div); + std::string denominator = strValue.substr(div + 1); + double num = stringToDouble(numerator); + double denom = stringToDouble(denominator); + + constantValue = num / denom; + } else { + constantValue = stringToDouble(strValue); + } + + return constantValue; +} + +double FPConst::getUpperBound() const { return getLowerBound(); } + +Value *FPConst::getLLValue(IRBuilder<> &builder, + const ValueToValueMapTy *VMap) { + Type *Ty; + if 
(dtype == "f64") { + Ty = builder.getDoubleTy(); + } else if (dtype == "f32") { + Ty = builder.getFloatTy(); + } else { + std::string msg = "FPConst getValue: Unexpected dtype: " + dtype; + llvm_unreachable(msg.c_str()); + } + if (strValue == "+inf.0") { + return ConstantFP::getInfinity(Ty, false); + } else if (strValue == "-inf.0") { + return ConstantFP::getInfinity(Ty, true); + } + + double constantValue; + size_t div = strValue.find('/'); + + if (div != std::string::npos) { + std::string numerator = strValue.substr(0, div); + std::string denominator = strValue.substr(div + 1); + double num = stringToDouble(numerator); + double denom = stringToDouble(denominator); + + constantValue = num / denom; + } else { + constantValue = stringToDouble(strValue); + } + + // if (FPOptPrint) + // llvm::errs() << "Returning " << strValue << " as " << dtype + // << " constant: " << constantValue << "\n"; + return ConstantFP::get(Ty, constantValue); +} + +bool FPConst::classof(const FPNode *N) { + return N->getType() == NodeType::Const; +} + +void CandidateOutput::apply( + size_t candidateIndex, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + // 4) parse the output string solution from herbieland + // 5) convert into a solution in llvm vals/instructions + + // if (FPOptPrint) + // llvm::errs() << "Parsing Herbie output: " << herbieOutput << "\n"; + auto parsedNode = parseHerbieExpr(candidates[candidateIndex].expr, + valueToNodeMap, symbolToValueMap); + // if (FPOptPrint) + // llvm::errs() << "Parsed Herbie output: " + // << parsedNode->toFullExpression(valueToNodeMap) << "\n"; + + auto *oldInst = cast(oldOutput); + + IRBuilder<> builder(oldInst->getParent(), ++BasicBlock::iterator(oldInst)); + builder.setFastMathFlags(oldInst->getFastMathFlags()); + + // auto *F = cast(oldOutput)->getParent()->getParent(); + // llvm::errs() << "Before: " << *F << "\n"; + Value *newOutput = parsedNode->getLLValue(builder); + assert(newOutput && "Failed to get 
value from parsed node"); + + oldOutput->replaceAllUsesWith(newOutput); + symbolToValueMap[valueToNodeMap[oldOutput]->symbol] = newOutput; + valueToNodeMap[newOutput] = std::make_shared( + newOutput, "__no", valueToNodeMap[oldOutput]->dtype); + + for (auto *I : erasableInsts) { + if (!I->use_empty()) + I->replaceAllUsesWith(UndefValue::get(I->getType())); + I->eraseFromParent(); + subgraph->operations.remove(I); // Avoid a second removal + cast(valueToNodeMap[I].get())->value = nullptr; + } + + // llvm::errs() << "After: " << *F << "\n"; + + subgraph->outputs_rewritten++; +} + +// Cost delta of applying candidate `candidateIndex`: the candidate's CompCost +// minus the summed cost of the erasable instructions, times `executions`. +// Lower is better +InstructionCost CandidateOutput::getCompCostDelta(size_t candidateIndex) { + InstructionCost erasableCost = 0; + + for (auto *I : erasableInsts) { + erasableCost += getInstructionCompCost(I, *TTI); + } + + return (candidates[candidateIndex].CompCost - erasableCost) * executions; +} + +// Rebuilds `erasableInsts`. The instructions gathered by `collectExprInsts` +// for `oldOutput` are visited in reverse topological order, and one is marked +// erasable only when every user is an instruction already in `erasableInsts`. +void CandidateOutput::findErasableInstructions() { + SmallPtrSet visited; + SmallPtrSet exprInsts; + collectExprInsts(oldOutput, subgraph->inputs, exprInsts, visited); + visited.clear(); + + SetVector instsToProcess(exprInsts.begin(), exprInsts.end()); + + SmallVector instsToProcessSorted; + reverseTopoSort(instsToProcess, instsToProcessSorted); + + // `oldOutput` is trivially erasable + erasableInsts.clear(); + erasableInsts.insert(cast(oldOutput)); + + for (auto *I : instsToProcessSorted) { + if (erasableInsts.contains(I)) + continue; + + bool usedOutside = false; + for (auto user : I->users()) { + if (auto *userI = dyn_cast(user)) { + if (erasableInsts.contains(userI)) { + continue; + } + } + // If the user is not an instruction or the user instruction is not an + // erasable instruction, then the current instruction is not erasable + // llvm::errs() << "Can't erase " << *I << " because of " << *user << + // "\n"; + usedOutside = true; + break; + } + + if (!usedOutside) { + erasableInsts.insert(I); + } + } + + // llvm::errs() << "Erasable instructions:\n"; + // for (auto *I : 
erasableInsts) { + // llvm::errs() << *I << "\n"; + // } + // llvm::errs() << "End of erasable instructions\n"; +} + +bool CandidateSubgraph::CacheKey::operator==(const CacheKey &other) const { + return candidateIndex == other.candidateIndex && + CandidateOutputs == other.CandidateOutputs; +} + +std::size_t +CandidateSubgraph::CacheKeyHash::operator()(const CacheKey &key) const { + std::size_t seed = std::hash{}(key.candidateIndex); + for (const auto *ao : key.CandidateOutputs) { + seed ^= std::hash{}(ao) + 0x9e3779b9 + + (seed << 6) + (seed >> 2); + } + return seed; +} + +void CandidateSubgraph::apply(size_t candidateIndex) { + if (candidateIndex >= candidates.size()) { + llvm_unreachable("Invalid candidate index"); + } + + // Traverse all the instructions to be changed precisions in a + // topological order with respect to operand dependencies. Insert FP casts + // between llvm::Value inputs and first level of instructions to be changed. + // Restore precisions of the last level of instructions to be changed. 
+ candidates[candidateIndex].apply(*subgraph); +} + +// Lower is better +InstructionCost CandidateSubgraph::getCompCostDelta(size_t candidateIndex) { + // TODO: adjust this based on erased instructions + return (candidates[candidateIndex].CompCost - initialCompCost) * executions; +} + +// Lower is better +double CandidateSubgraph::getAccCostDelta(size_t candidateIndex) { + return candidates[candidateIndex].accuracyCost - initialAccCost; +} + +// Lower is better +double CandidateOutput::getAccCostDelta(size_t candidateIndex) { + return candidates[candidateIndex].accuracyCost - initialAccCost; +} + +// Computational-cost delta of candidate `candidateIndex` after discounting the +// CandidateOutput steps in `steps` that target this subgraph: their erasable +// instructions and old outputs are removed from a copy of the subgraph before +// costing. Results are memoized in `compCostDeltaCache`, keyed by the candidate +// index and the set of matching CandidateOutputs; returns 0 when no outputs +// remain in the adjusted copy. +InstructionCost CandidateSubgraph::getAdjustedCompCostDelta( + size_t candidateIndex, const SmallVectorImpl &steps) { + CandidateOutputSet CandidateOutputs; + for (const auto &step : steps) { + if (auto *ptr = std::get_if(&step.item)) { + if ((*ptr)->subgraph == subgraph) { + CandidateOutputs.insert(*ptr); + } + } + } + + CacheKey key{candidateIndex, CandidateOutputs}; + + auto cacheIt = compCostDeltaCache.find(key); + if (cacheIt != compCostDeltaCache.end()) { + return cacheIt->second; + } + + Subgraph newSubgraph = *this->subgraph; + + for (auto &step : steps) { + if (auto *ptr = std::get_if(&step.item)) { + const auto &CO = **ptr; + if (CO.subgraph == subgraph) { + // Eliminate erasable instructions from the adjusted CS + newSubgraph.operations.remove_if( + [&CO](Instruction *I) { return CO.erasableInsts.contains(I); }); + newSubgraph.outputs.remove(cast(CO.oldOutput)); + } + } + } + + // If all outputs are rewritten, then the adjusted CS is empty + if (newSubgraph.outputs.empty()) { + compCostDeltaCache[key] = 0; + return 0; + } + + InstructionCost initialCompCost = + getCompCost({newSubgraph.outputs.begin(), newSubgraph.outputs.end()}, + newSubgraph.inputs, TTI); + + InstructionCost candidateCompCost = + getCompCost(newSubgraph, TTI, candidates[candidateIndex]); + + InstructionCost adjustedCostDelta = + (candidateCompCost - initialCompCost) * executions; + // 
llvm::errs() << "Initial cost: " << initialCompCost << "\n"; + // llvm::errs() << "Candidate cost: " << candidateCompCost << "\n"; + // llvm::errs() << "Num executions: " << executions << "\n"; + // llvm::errs() << "Adjusted cost delta: " << adjustedCostDelta << "\n\n"; + + compCostDeltaCache[key] = adjustedCostDelta; + return adjustedCostDelta; +} + +double CandidateSubgraph::getAdjustedAccCostDelta( + size_t candidateIndex, SmallVectorImpl &steps, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap) { + CandidateOutputSet CandidateOutputs; + for (const auto &step : steps) { + if (auto *ptr = std::get_if(&step.item)) { + if ((*ptr)->subgraph == subgraph) { + CandidateOutputs.insert(*ptr); + } + } + } + + CacheKey key{candidateIndex, CandidateOutputs}; + + auto cacheIt = accCostDeltaCache.find(key); + if (cacheIt != accCostDeltaCache.end()) { + return cacheIt->second; + } + + double totalCandidateAccCost = 0.0; + double totalInitialAccCost = 0.0; + + // Collect erased output nodes + SmallPtrSet stepNodes; + for (const auto &step : steps) { + if (auto *ptr = std::get_if(&step.item)) { + const auto &CO = **ptr; + if (CO.subgraph == subgraph) { + auto it = valueToNodeMap.find(CO.oldOutput); + assert(it != valueToNodeMap.end() && it->second); + stepNodes.insert(it->second.get()); + } + } + } + + // Iterate over all output nodes and sum costs for nodes not erased + for (auto &[node, cost] : perOutputInitialAccCost) { + if (!stepNodes.count(node)) { + totalInitialAccCost += cost; + } + } + + for (auto &[node, cost] : candidates[candidateIndex].perOutputAccCost) { + if (!stepNodes.count(node)) { + totalCandidateAccCost += cost; + } + } + + double adjustedAccCostDelta = totalCandidateAccCost - totalInitialAccCost; + + accCostDeltaCache[key] = adjustedAccCostDelta; + return adjustedAccCostDelta; +} diff --git a/enzyme/Enzyme/Poseidon/PoseidonTypes.h b/enzyme/Enzyme/Poseidon/PoseidonTypes.h new file mode 100644 index 000000000000..def99cb727d5 --- 
/dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonTypes.h @@ -0,0 +1,252 @@ +//=- PoseidonTypes.h - AST node declarations for Poseidon -----------------=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the AST node classes for representing floating-point +// expressions in the Poseidon optimization pass. +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYME_POSEIDON_TYPES_H +#define ENZYME_POSEIDON_TYPES_H + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/InstructionCost.h" +#include "llvm/Transforms/Utils/ValueMapper.h" + +#include +#include +#include +#include +#include +#include + +#include "PoseidonUtils.h" + +#include "PoseidonPrecUtils.h" + +using namespace llvm; + +class FPNode { +public: + enum class NodeType { Node, LLValue, Const }; + +private: + const NodeType ntype; + +public: + std::string op; + std::string dtype; + std::string symbol; + SmallVector, 2> operands; + double sens = + std::numeric_limits::quiet_NaN(); // Sensitivity score, sum of + // |grad * value| + double grad = + std::numeric_limits::quiet_NaN(); // Sum of gradients (not abs) + unsigned executions = 0; + + explicit FPNode(const std::string &op, const std::string &dtype) + : ntype(NodeType::Node), op(op), dtype(dtype) {} + explicit FPNode(NodeType ntype, const std::string &op, + const std::string &dtype) + : ntype(ntype), op(op), dtype(dtype) {} + virtual ~FPNode() = default; + + NodeType getType() const; + void addOperand(std::shared_ptr operand); + virtual bool hasSymbol() const; + virtual std::string toFullExpression( 
+ std::unordered_map> &valueToNodeMap, + const SetVector &subgraphInputs, unsigned depth = 0); + unsigned getMPFRPrec() const; + virtual void updateBounds(double lower, double upper); + virtual double getLowerBound() const; + virtual double getUpperBound() const; + virtual Value *getLLValue(IRBuilder<> &builder, + const ValueToValueMapTy *VMap = nullptr); +}; + +class FPLLValue : public FPNode { +private: + double lb = std::numeric_limits::infinity(); + double ub = -std::numeric_limits::infinity(); + +public: + Value *value; + + explicit FPLLValue(Value *value, const std::string &op, + const std::string &dtype) + : FPNode(NodeType::LLValue, op, dtype), value(value) {} + + bool hasSymbol() const override; + std::string toFullExpression( + std::unordered_map> &valueToNodeMap, + const SetVector &subgraphInputs, unsigned depth = 0) override; + void updateBounds(double lower, double upper) override; + double getLowerBound() const override; + double getUpperBound() const override; + Value *getLLValue(IRBuilder<> &builder, + const ValueToValueMapTy *VMap = nullptr) override; + + static bool classof(const FPNode *N); +}; + +class FPConst : public FPNode { +private: + std::string strValue; + +public: + explicit FPConst(const std::string &strValue, const std::string &dtype) + : FPNode(NodeType::Const, "__const", dtype), strValue(strValue) {} + + std::string toFullExpression( + std::unordered_map> &valueToNodeMap, + const SetVector &subgraphInputs, unsigned depth = 0) override; + bool hasSymbol() const override; + void updateBounds(double lower, double upper) override; + double getLowerBound() const override; + double getUpperBound() const override; + Value *getLLValue(IRBuilder<> &builder, + const ValueToValueMapTy *VMap = nullptr) override; + + static bool classof(const FPNode *N); +}; + +struct Subgraph { + SetVector inputs; + SetVector outputs; + SetVector operations; + size_t outputs_rewritten = 0; + + Subgraph() = default; + explicit Subgraph(SetVector inputs, SetVector 
outputs, + SetVector operations) + : inputs(inputs), outputs(outputs), operations(operations) {} +}; + +struct SolutionStep; + +struct RewriteCandidate { + InstructionCost CompCost = std::numeric_limits::max(); + double herbieCost = std::numeric_limits::quiet_NaN(); + double herbieAccuracy = std::numeric_limits::quiet_NaN(); + double accuracyCost = std::numeric_limits::quiet_NaN(); + std::string expr; + + RewriteCandidate(double cost, double accuracy, std::string expression) + : herbieCost(cost), herbieAccuracy(accuracy), expr(expression) {} +}; + +class CandidateOutput { +public: + Subgraph *subgraph; + Value *oldOutput; + std::string expr; + double grad = std::numeric_limits::quiet_NaN(); + unsigned executions = 0; + const TargetTransformInfo *TTI = nullptr; + double initialAccCost = std::numeric_limits::quiet_NaN(); + InstructionCost initialCompCost = + std::numeric_limits::quiet_NaN(); + double initialHerbieCost = std::numeric_limits::quiet_NaN(); + double initialHerbieAccuracy = std::numeric_limits::quiet_NaN(); + SmallVector candidates; + SmallPtrSet erasableInsts; + + explicit CandidateOutput(Subgraph &subgraph, Value *oldOutput, + std::string expr, double grad, unsigned executions, + const TargetTransformInfo &TTI) + : subgraph(&subgraph), oldOutput(oldOutput), expr(expr), grad(grad), + executions(executions), TTI(&TTI) { + initialCompCost = getCompCost({oldOutput}, subgraph.inputs, TTI); + findErasableInstructions(); + } + + void + apply(size_t candidateIndex, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap); + InstructionCost getCompCostDelta(size_t candidateIndex); + double getAccCostDelta(size_t candidateIndex); + +private: + void findErasableInstructions(); +}; + +class CandidateSubgraph { +public: + Subgraph *subgraph; + const TargetTransformInfo &TTI; + double initialAccCost = std::numeric_limits::quiet_NaN(); + InstructionCost initialCompCost = + std::numeric_limits::quiet_NaN(); + unsigned executions = 0; + 
std::unordered_map perOutputInitialAccCost; + SmallVector candidates; + + using CandidateOutputSet = std::set; + struct CacheKey { + size_t candidateIndex; + CandidateOutputSet CandidateOutputs; + bool operator==(const CacheKey &other) const; + }; + + struct CacheKeyHash { + std::size_t operator()(const CacheKey &key) const; + }; + + std::unordered_map + compCostDeltaCache; + std::unordered_map accCostDeltaCache; + + explicit CandidateSubgraph(Subgraph &subgraph, const TargetTransformInfo &TTI) + : subgraph(&subgraph), TTI(TTI) { + initialCompCost = + getCompCost({subgraph.outputs.begin(), subgraph.outputs.end()}, + subgraph.inputs, TTI); + } + + void apply(size_t candidateIndex); + InstructionCost getCompCostDelta(size_t candidateIndex); + double getAccCostDelta(size_t candidateIndex); + InstructionCost + getAdjustedCompCostDelta(size_t candidateIndex, + const SmallVectorImpl &steps); + double getAdjustedAccCostDelta( + size_t candidateIndex, SmallVectorImpl &steps, + std::unordered_map> &valueToNodeMap, + std::unordered_map &symbolToValueMap); +}; + +struct SolutionStep { + std::variant item; + size_t candidateIndex; + + SolutionStep(CandidateOutput *ao_, size_t idx) + : item(ao_), candidateIndex(idx) {} + SolutionStep(CandidateSubgraph *acc_, size_t idx) + : item(acc_), candidateIndex(idx) {} +}; + +void getSampledPoints( + ArrayRef inputs, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap, + SmallVector, 4> &sampledPoints); + +void getSampledPoints( + const std::string &expr, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap, + SmallVector, 4> &sampledPoints); + +#endif // ENZYME_POSEIDON_TYPES_H \ No newline at end of file diff --git a/enzyme/Enzyme/Poseidon/PoseidonUtils.cpp b/enzyme/Enzyme/Poseidon/PoseidonUtils.cpp new file mode 100644 index 000000000000..343b6bcc61fb --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonUtils.cpp @@ -0,0 +1,1023 @@ +//=- PoseidonUtils.cpp - 
Utility functions for Poseidon optimization pass --=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements utility functions for the Poseidon floating-point +// optimization pass. +// +//===----------------------------------------------------------------------===// + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +#include "llvm/Analysis/TargetTransformInfo.h" + +#include "llvm/IR/Function.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Verifier.h" + +#include "llvm/Passes/PassBuilder.h" + +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InstructionCost.h" +#include "llvm/Support/raw_ostream.h" + +#include "llvm/Pass.h" + +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "PoseidonTypes.h" +#include "PoseidonUtils.h" + +using namespace llvm; + +extern "C" { +extern cl::opt FPOptPrint; +cl::opt + FPOptCostModelPath("fpopt-cost-model-path", cl::init(""), cl::Hidden, + cl::desc("Use a custom cost model in the FPOpt pass")); +cl::opt + FPOptNumSamples("fpopt-num-samples", cl::init(1024), cl::Hidden, + cl::desc("Number of sampled points for input hypercube")); +cl::opt + FPOptRandomSeed("fpopt-random-seed", cl::init(239778888), cl::Hidden, + cl::desc("The random seed used in the FPOpt pass")); +} + +void runPoseidonFunctionSimplify(Function &F, OptimizationLevel Level) { + + PassBuilder PB; 
+ LoopAnalysisManager LAM; + FunctionAnalysisManager FAM; + CGSCCAnalysisManager CGAM; + ModuleAnalysisManager MAM; + PB.registerModuleAnalyses(MAM); + PB.registerCGSCCAnalyses(CGAM); + PB.registerFunctionAnalyses(FAM); + PB.registerLoopAnalyses(LAM); + PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); + + if (verifyFunction(F, &llvm::errs())) { + llvm_unreachable("Poseidon intermediate function failed verification"); + } + + FunctionPassManager FPM = + PB.buildFunctionSimplificationPipeline(Level, ThinOrFullLTOPhase::None); + (void)FPM.run(F, FAM); +} + +static const std::unordered_set LibmFuncs = { + "sin", "cos", "tan", "asin", "acos", "atan", "atan2", + "sinh", "cosh", "tanh", "asinh", "acosh", "atanh", "exp", + "log", "sqrt", "cbrt", "pow", "powi", "fabs", "fma", + "hypot", "expm1", "log1p", "ceil", "floor", "erf", "exp2", + "lgamma", "log10", "log2", "rint", "round", "tgamma", "trunc", + "copysign", "fdim", "fmod", "remainder"}; + +double getOneULP(double value) { + assert(!std::isnan(value) && !std::isinf(value)); + + double next = std::nextafter(value, std::numeric_limits::infinity()); + double ulp = std::fabs(next - value); + + return ulp; +} + +// Maps a libm function name, optionally carrying an `f`/`l` precision suffix, +// to the variant matching `newType`: `<base>f` for float, `<base>` for double, +// and `<base>l` for fp128 / x86_fp80. Returns "" when the base name is not in +// `LibmFuncs` or `newType` is none of those types. +std::string getLibmFunctionForPrecision(StringRef funcName, Type *newType) { + std::string baseName = funcName.str(); + // Strip a trailing `f`/`l` only when the name is not already a known libm + // function: `erf` and `ceil` end in those letters and must stay intact. + // The emptiness check keeps `back()` from being called on an empty name. + if (!LibmFuncs.count(baseName) && !baseName.empty() && + (baseName.back() == 'f' || baseName.back() == 'l')) { + baseName.pop_back(); + } + + if (LibmFuncs.count(baseName)) { + if (newType->isFloatTy()) { + return baseName + "f"; + } else if (newType->isDoubleTy()) { + return baseName; + } else if (newType->isFP128Ty() || newType->isX86_FP80Ty()) { + return baseName + "l"; + } + } + + return ""; +} + +double stringToDouble(const std::string &str) { + char *end; + errno = 0; + double result = std::strtod(str.c_str(), &end); + + if (errno == ERANGE) { + if (result == HUGE_VAL) { + result = std::numeric_limits::infinity(); + } else if (result == -HUGE_VAL) { + result = -std::numeric_limits::infinity(); + } + } + + return result; // Denormalized values are 
fine +} + +void topoSort(const SetVector &insts, + SmallVectorImpl &instsSorted) { + SmallPtrSet visited; + SmallPtrSet onStack; + + std::function dfsVisit = [&](Instruction *I) { + if (visited.count(I)) + return; + visited.insert(I); + onStack.insert(I); + + auto operands = + isa(I) ? cast(I)->args() : I->operands(); + for (auto &op : operands) { + if (isa(op)) { + Instruction *oI = cast(op); + if (insts.contains(oI)) { + if (onStack.count(oI)) { + llvm_unreachable( + "topoSort: Cycle detected in instruction dependencies!"); + } + dfsVisit(oI); + } + } + } + + onStack.erase(I); + instsSorted.push_back(I); + }; + + for (auto *I : insts) { + if (!visited.count(I)) { + dfsVisit(I); + } + } +} + +void reverseTopoSort(const SetVector &insts, + SmallVectorImpl &instsSorted) { + topoSort(insts, instsSorted); + std::reverse(instsSorted.begin(), instsSorted.end()); +} + +void getUniqueArgs(const std::string &expr, SmallSet &args) { + std::regex argPattern("v\\d+"); + + std::sregex_iterator begin(expr.begin(), expr.end(), argPattern); + std::sregex_iterator end; + + while (begin != end) { + args.insert(begin->str()); + ++begin; + } +} + +const std::map, InstructionCost> & +getCostModel() { + static std::map, InstructionCost> + CostModel; + static bool Loaded = false; + if (!Loaded) { + std::ifstream CostFile(FPOptCostModelPath); + if (!CostFile.is_open()) { + std::string msg = + "Cost model file could not be opened: " + FPOptCostModelPath; + llvm_unreachable(msg.c_str()); + } + std::string Line; + while (std::getline(CostFile, Line)) { + std::istringstream SS(Line); + std::string OpcodeStr, PrecisionStr, CostStr; + if (!std::getline(SS, OpcodeStr, ',')) { + llvm_unreachable( + ("Unexpected line in custom cost model: " + Line).c_str()); + } + if (!std::getline(SS, PrecisionStr, ',')) { + llvm_unreachable( + ("Unexpected line in custom cost model: " + Line).c_str()); + } + if (!std::getline(SS, CostStr)) { + llvm_unreachable( + ("Unexpected line in custom cost model: " + 
Line).c_str()); + } + CostModel[{OpcodeStr, PrecisionStr}] = std::stoi(CostStr); + } + Loaded = true; + } + return CostModel; +} + +InstructionCost queryCostModel(const std::string &OpcodeName, + const std::string &PrecisionName) { + const auto &CostModel = getCostModel(); + auto Key = std::make_pair(OpcodeName, PrecisionName); + auto It = CostModel.find(Key); + if (It != CostModel.end()) + return It->second; + + std::string msg = "Custom cost model: entry not found for " + OpcodeName + + " @ " + PrecisionName; + llvm::errs() << msg << "\n"; + llvm_unreachable(msg.c_str()); +} + +InstructionCost getInstructionCompCost(const Instruction *I, + const TargetTransformInfo &TTI) { + if (!I->getType()->isFPOrFPVectorTy()) + return 0; + + if (!FPOptCostModelPath.empty()) { + std::string OpcodeName; + switch (I->getOpcode()) { + case Instruction::FNeg: + OpcodeName = "fneg"; + break; + case Instruction::FAdd: + OpcodeName = "fadd"; + break; + case Instruction::FSub: + OpcodeName = "fsub"; + break; + case Instruction::FMul: + OpcodeName = "fmul"; + break; + case Instruction::FDiv: + OpcodeName = "fdiv"; + break; + case Instruction::FCmp: + OpcodeName = "fcmp"; + break; + case Instruction::FPExt: + OpcodeName = "fpext"; + break; + case Instruction::FPTrunc: + OpcodeName = "fptrunc"; + break; + case Instruction::PHI: + case Instruction::Select: + case Instruction::Load: + return 0; + case Instruction::Call: { + auto *Call = cast(I); + if (auto *CalledFunc = Call->getCalledFunction()) { + if (CalledFunc->isIntrinsic()) { + switch (CalledFunc->getIntrinsicID()) { + case Intrinsic::sin: + OpcodeName = "sin"; + break; + case Intrinsic::cos: + OpcodeName = "cos"; + break; +#if LLVM_VERSION_MAJOR > 16 + case Intrinsic::tan: + OpcodeName = "tan"; + break; + case Intrinsic::asin: + OpcodeName = "asin"; + break; + case Intrinsic::acos: + OpcodeName = "acos"; + break; + case Intrinsic::atan: + OpcodeName = "atan"; + break; + case Intrinsic::atan2: + OpcodeName = "atan2"; + break; + case 
Intrinsic::sinh: + OpcodeName = "sinh"; + break; + case Intrinsic::cosh: + OpcodeName = "cosh"; + break; + case Intrinsic::tanh: + OpcodeName = "tanh"; + break; +#endif + case Intrinsic::exp: + OpcodeName = "exp"; + break; + case Intrinsic::log: + OpcodeName = "log"; + break; + case Intrinsic::sqrt: + OpcodeName = "sqrt"; + break; + case Intrinsic::fabs: + OpcodeName = "fabs"; + break; + case Intrinsic::fma: + OpcodeName = "fma"; + break; + case Intrinsic::pow: + OpcodeName = "pow"; + break; + case Intrinsic::powi: + OpcodeName = "powi"; + break; + case Intrinsic::fmuladd: + OpcodeName = "fmuladd"; + break; + case Intrinsic::maxnum: + OpcodeName = "maxnum"; + break; + case Intrinsic::minnum: + OpcodeName = "minnum"; + break; + case Intrinsic::ceil: + OpcodeName = "ceil"; + break; + case Intrinsic::floor: + OpcodeName = "floor"; + break; + case Intrinsic::exp2: + OpcodeName = "exp2"; + break; + case Intrinsic::log10: + OpcodeName = "log10"; + break; + case Intrinsic::log2: + OpcodeName = "log2"; + break; + case Intrinsic::rint: + OpcodeName = "rint"; + break; + case Intrinsic::round: + OpcodeName = "round"; + break; + case Intrinsic::trunc: + OpcodeName = "trunc"; + break; + case Intrinsic::copysign: + OpcodeName = "copysign"; + break; + default: { + std::string msg = "Custom cost model: unsupported intrinsic " + + CalledFunc->getName().str(); + llvm_unreachable(msg.c_str()); + } + } + } else { + std::string FuncName = CalledFunc->getName().str(); + // Strip a trailing `f`/`l` precision suffix only when the name is not + // already a known libm function: `erf` and `ceil` end in those + // letters and would otherwise be reported as unknown calls. + if (!LibmFuncs.count(FuncName) && !FuncName.empty() && + (FuncName.back() == 'f' || FuncName.back() == 'l')) + FuncName.pop_back(); + + if (LibmFuncs.count(FuncName)) + OpcodeName = FuncName; + else { + std::string msg = + "Custom cost model: unknown function call " + FuncName; + llvm_unreachable(msg.c_str()); + } + } + } else { + llvm_unreachable("Custom cost model: unknown function call"); + } + break; + } + default: { + llvm::errs() << "Problematic instruction: " << *I << "\n"; + std::string msg = "Custom cost model: unexpected opcode " + 
std::string(I->getOpcodeName()); + llvm_unreachable(msg.c_str()); + } + } + + std::string PrecisionName; + Type *Ty = I->getType(); + if (I->getOpcode() == Instruction::FCmp) + Ty = I->getOperand(0)->getType(); + + if (Ty->isBFloatTy()) + PrecisionName = "bf16"; + else if (Ty->isHalfTy()) + PrecisionName = "half"; + else if (Ty->isFloatTy()) + PrecisionName = "float"; + else if (Ty->isDoubleTy()) + PrecisionName = "double"; + else if (Ty->isX86_FP80Ty()) + PrecisionName = "fp80"; + else if (Ty->isFP128Ty()) + PrecisionName = "fp128"; + else { + std::string msg = "Custom cost model: unsupported precision type!"; + llvm_unreachable(msg.c_str()); + } + + // For FPExt/FPTrunc, update the opcode name to include conversion info. + if (I->getOpcode() == Instruction::FPExt || + I->getOpcode() == Instruction::FPTrunc) { + Type *SrcTy = I->getOperand(0)->getType(); + std::string SrcPrecisionName; + if (SrcTy->isBFloatTy()) + SrcPrecisionName = "bf16"; + else if (SrcTy->isHalfTy()) + SrcPrecisionName = "half"; + else if (SrcTy->isFloatTy()) + SrcPrecisionName = "float"; + else if (SrcTy->isDoubleTy()) + SrcPrecisionName = "double"; + else if (SrcTy->isX86_FP80Ty()) + SrcPrecisionName = "fp80"; + else if (SrcTy->isFP128Ty()) + SrcPrecisionName = "fp128"; + else { + std::string msg = "Custom cost model: unsupported precision type!"; + llvm_unreachable(msg.c_str()); + } + + OpcodeName += "_" + SrcPrecisionName + "_to_" + PrecisionName; + PrecisionName = SrcPrecisionName; + } + + return queryCostModel(OpcodeName, PrecisionName); + } else { + llvm::errs() << "WARNING: Custom cost model not found, using TTI cost!\n"; + return TTI.getInstructionCost(I, TargetTransformInfo::TCK_RecipThroughput); + } +} + +const std::unordered_set &getPTFuncs() { + static const std::unordered_set PTFuncs = []() { + if (FPOptCostModelPath.empty()) + return std::unordered_set{}; + std::unordered_set funcs; + for (const auto &func : LibmFuncs) { + InstructionCost costFP32 = queryCostModel(func, "float"); 
+ InstructionCost costFP64 = queryCostModel(func, "double"); + InstructionCost costFPTrunc = + queryCostModel("fptrunc_double_to_float", "double"); + InstructionCost costFPExt = + queryCostModel("fpext_float_to_double", "float"); + InstructionCost costFPCast = costFPTrunc + costFPExt; + if (costFP32 + costFPCast < costFP64) + funcs.insert(func); + } + return funcs; + }(); + return PTFuncs; +} + +InstructionCost computeMaxCost( + BasicBlock *BB, std::unordered_map &MaxCost, + std::unordered_set &Visited, const TargetTransformInfo &TTI) { + if (MaxCost.find(BB) != MaxCost.end()) + return MaxCost[BB]; + + if (!Visited.insert(BB).second) + return 0; + + InstructionCost BBCost = 0; + for (const Instruction &I : *BB) { + if (I.isTerminator()) + continue; + + auto instCost = getInstructionCompCost(&I, TTI); + + // if (FPOptPrint) + // llvm::errs() << "Cost of " << I << " is: " << instCost << "\n"; + + BBCost += instCost; + } + + InstructionCost succCost = 0; + + if (!succ_empty(BB)) { + InstructionCost maxSuccCost = 0; + for (BasicBlock *Succ : successors(BB)) { + InstructionCost succBBCost = computeMaxCost(Succ, MaxCost, Visited, TTI); + if (succBBCost > maxSuccCost) + maxSuccCost = succBBCost; + } + // llvm::errs() << "Max succ cost: " << maxSuccCost << "\n"; + succCost = maxSuccCost; + } + + InstructionCost totalCost = BBCost + succCost; + // llvm::errs() << "BB " << BB->getName() << " cost: " << totalCost << "\n"; + MaxCost[BB] = totalCost; + Visited.erase(BB); + return totalCost; +} + +void getSampledPoints( + ArrayRef inputs, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap, + SmallVector, 4> &sampledPoints) { + std::default_random_engine gen; + gen.seed(FPOptRandomSeed); + std::uniform_real_distribution<> dis; + + MapVector> hypercube; + for (const auto input : inputs) { + const auto node = valueToNodeMap.at(input); + + double lower = node->getLowerBound(); + double upper = node->getUpperBound(); + + hypercube.insert({input, 
{lower, upper}}); + } + + // llvm::errs() << "Hypercube:\n"; + // for (const auto &entry : hypercube) { + // Value *val = entry.first; + // double lower = entry.second[0]; + // double upper = entry.second[1]; + // llvm::errs() << valueToNodeMap.at(val)->symbol << ": [" << lower << ", " + // << upper << "]\n"; + // } + + // Sample `FPOptNumSamples` points from the hypercube. Store it in + // `sampledPoints`. + sampledPoints.clear(); + sampledPoints.resize(FPOptNumSamples); + for (size_t i = 0; i < FPOptNumSamples; ++i) { + MapVector point; + for (const auto &entry : hypercube) { + Value *val = entry.first; + double lower = entry.second[0]; + double upper = entry.second[1]; + double sample = dis(gen, decltype(dis)::param_type{lower, upper}); + point.insert({val, sample}); + } + sampledPoints[i] = point; + // llvm::errs() << "Sample " << i << ":\n"; + // for (const auto &entry : point) { + // llvm::errs() << valueToNodeMap.at(entry.first)->symbol << ": " + // << entry.second << "\n"; + // } + } +} + +void getSampledPoints( + const std::string &expr, + const std::unordered_map> &valueToNodeMap, + const std::unordered_map &symbolToValueMap, + SmallVector, 4> &sampledPoints) { + SmallSet argStrSet; + getUniqueArgs(expr, argStrSet); + + SmallVector inputs; + for (const auto &argStr : argStrSet) { + inputs.push_back(symbolToValueMap.at(argStr)); + } + + getSampledPoints(inputs, valueToNodeMap, symbolToValueMap, sampledPoints); +} + +InstructionCost getCompCost(Function *F, const TargetTransformInfo &TTI) { + std::unordered_map MaxCost; + std::unordered_set Visited; + + BasicBlock *EntryBB = &F->getEntryBlock(); + InstructionCost TotalCost = computeMaxCost(EntryBB, MaxCost, Visited, TTI); + // llvm::errs() << "Total cost: " << TotalCost << "\n"; + return TotalCost; +} + +// Sum up the cost of `output` and its FP operands recursively up to `inputs` +// (exclusive). 
+InstructionCost getCompCost(const SmallVector &outputs, + const SetVector &inputs, + const TargetTransformInfo &TTI) { + assert(!outputs.empty()); + SmallPtrSet seen; + SmallVector todo; + InstructionCost cost = 0; + + todo.insert(todo.end(), outputs.begin(), outputs.end()); + while (!todo.empty()) { + auto cur = todo.pop_back_val(); + if (!seen.insert(cur).second) + continue; + + if (inputs.contains(cur)) + continue; + + if (auto *I = dyn_cast(cur)) { + // TODO: unfair to ignore branches when calculating cost + auto instCost = getInstructionCompCost(I, TTI); + + // if (FPOptPrint) + // llvm::errs() << "Cost of " << *I << " is: " << instCost << "\n"; + + // Only add the cost of the instruction if it is not an input + cost += instCost; + + auto operands = + isa(I) ? cast(I)->args() : I->operands(); + for (auto &operand : operands) { + todo.push_back(operand); + } + } + } + + return cost; +} + +void collectExprInsts(Value *V, const SetVector &inputs, + SmallPtrSetImpl &exprInsts, + SmallPtrSetImpl &visited) { + if (!V || inputs.contains(V) || visited.contains(V)) { + return; + } + + visited.insert(V); + + if (auto *I = dyn_cast(V)) { + exprInsts.insert(I); + + auto operands = + isa(I) ? 
cast(I)->args() : I->operands(); + + for (auto &op : operands) { + collectExprInsts(op, inputs, exprInsts, visited); + } + } +} + +bool isExpansionBottleneck(Instruction *I, const Subgraph &subgraph) { + // Allow splitting at outputs if there are multiple outputs + if (subgraph.outputs.contains(I) && subgraph.outputs.size() <= 1) { + return false; + } + + // Criteria 1: Number of internal uses + unsigned internalUses = 0; + for (auto *U : I->users()) { + if (auto *UI = dyn_cast(U)) { + if (subgraph.operations.contains(UI)) { + ++internalUses; + } + } + } + + if (internalUses < FPOptMinUsesForSplit) { + return false; + } + + // Criteria 2: Number of upstream internal operations (complexity) + SetVector relevantTree; + SmallVector worklist; + worklist.push_back(I); + relevantTree.insert(I); + + while (!worklist.empty()) { + Instruction *current = worklist.pop_back_val(); + auto operands = isa(current) ? cast(current)->args() + : current->operands(); + for (auto &op : operands) { + if (subgraph.inputs.contains(op)) { + continue; + } + if (auto *OpI = dyn_cast(op)) { + if (subgraph.operations.contains(OpI) && !relevantTree.contains(OpI)) { + worklist.push_back(OpI); + relevantTree.insert(OpI); + } + } + } + } + + SetVector movedOps; + SetVector keptSubtreeRoots; + + for (auto *op : relevantTree) { + if (op == I) { + movedOps.insert(op); + continue; + } + + bool hasExternalUse = false; + for (auto *U : op->users()) { + if (auto *UI = dyn_cast(U)) { + if (!relevantTree.contains(UI)) { + hasExternalUse = true; + break; + } + } + } + + if (hasExternalUse) { + keptSubtreeRoots.insert(op); + } else { + movedOps.insert(op); + } + } + + for (auto *keptRoot : keptSubtreeRoots) { + SetVector keptUpstream; + SmallVector keptWorklist; + keptWorklist.push_back(keptRoot); + keptUpstream.insert(keptRoot); + + while (!keptWorklist.empty()) { + Instruction *current = keptWorklist.pop_back_val(); + auto operands = isa(current) ? 
cast(current)->args() + : current->operands(); + for (auto &op : operands) { + if (subgraph.inputs.contains(op)) { + continue; + } + if (auto *OpI = dyn_cast(op)) { + if (relevantTree.contains(OpI) && !keptUpstream.contains(OpI)) { + keptWorklist.push_back(OpI); + keptUpstream.insert(OpI); + movedOps.remove(OpI); + } + } + } + } + } + + bool isBottleneck = movedOps.size() >= FPOptMinOpsForSplit; + if (FPOptPrint && isBottleneck) { + llvm::errs() << "Bottleneck: " << *I << "\n"; + llvm::errs() << "Num of operations that would be moved: " << movedOps.size() + << " (>=" << FPOptMinOpsForSplit << ")\n"; + llvm::errs() << "Num of internal uses: " << internalUses + << " (>=" << FPOptMinUsesForSplit << ")\n"; + llvm::errs() << "Operations that would be moved:\n"; + for (auto *op : movedOps) { + llvm::errs() << "\t" << *op << "\n"; + } + } + return isBottleneck; +} + +SetVector +findReachedInputs(const SetVector &operations) { + SetVector reachedInputs; + + for (auto *I : operations) { + auto operands = + isa(I) ? cast(I)->args() : I->operands(); + for (auto &op : operands) { + if (auto *OpI = dyn_cast(op)) { + if (operations.contains(OpI)) { + continue; + } + } + + reachedInputs.insert(op); + } + } + + return reachedInputs; +} + +void splitSubgraphAtBottleneck(Subgraph ¤tSubgraph, + Instruction *bottleneck, Subgraph &newSubgraph, + Subgraph &remainingSubgraph) { + + // 1. Get all upstream ops from bottleneck (relevant tree) + SetVector relevantTree; + SmallVector worklist; + worklist.push_back(bottleneck); + relevantTree.insert(bottleneck); + + while (!worklist.empty()) { + auto current = worklist.pop_back_val(); + auto operands = isa(current) ? 
cast(current)->args() + : current->operands(); + for (auto &op : operands) { + if (currentSubgraph.inputs.contains(op)) { + continue; + } + if (auto *OpI = dyn_cast(op)) { + if (currentSubgraph.operations.contains(OpI) && + !relevantTree.contains(OpI)) { + worklist.push_back(OpI); + relevantTree.insert(OpI); + } + } + } + } + + // 2. Move ops that have external uses to the new subgraph + SetVector movedOps; + SetVector keptSubtreeRoots; + + for (auto op : relevantTree) { + // The bottleneck is always moved to the new subgraph + if (op == bottleneck) { + movedOps.insert(op); + continue; + } + + bool hasExternalUse = false; + for (auto U : op->users()) { + if (auto UI = dyn_cast(U)) { + if (!relevantTree.contains(UI)) { + hasExternalUse = true; + break; + } + } + } + + if (hasExternalUse) { + keptSubtreeRoots.insert(op); + } else { + movedOps.insert(op); + } + } + + // 3. Remove upstream ops of kept subtree roots from movedOps + for (auto keptRoot : keptSubtreeRoots) { + SetVector keptUpstream; + SmallVector keptWorklist; + keptWorklist.push_back(keptRoot); + keptUpstream.insert(keptRoot); + + while (!keptWorklist.empty()) { + Instruction *current = keptWorklist.pop_back_val(); + auto operands = isa(current) ? cast(current)->args() + : current->operands(); + for (auto &op : operands) { + if (currentSubgraph.inputs.contains(op)) { + continue; + } + if (auto *OpI = dyn_cast(op)) { + if (relevantTree.contains(OpI) && !keptUpstream.contains(OpI)) { + keptWorklist.push_back(OpI); + keptUpstream.insert(OpI); + movedOps.remove(OpI); + } + } + } + } + } + + // 4. Build new subgraph. + // `operations`: dependencies of `movedOps` + // `inputs`: reached inputs of `operations` + // `outputs`: the bottleneck instruction + newSubgraph.operations = movedOps; + newSubgraph.outputs.insert(bottleneck); + newSubgraph.inputs = findReachedInputs(newSubgraph.operations); + + // 5. Build remaining subgraph. 
+ // `operations`: unmoved ops of `currentSubgraph.operations` + // `inputs`: reached inputs of `operations` + // `outputs`: the outputs of the original subgraph (excluding bottleneck if it + // was an output) + for (auto I : currentSubgraph.operations) { + if (!movedOps.contains(I)) { + remainingSubgraph.operations.insert(I); + } + } + + remainingSubgraph.outputs = currentSubgraph.outputs; + if (currentSubgraph.outputs.contains(bottleneck)) { + remainingSubgraph.outputs.remove(bottleneck); + } + + remainingSubgraph.inputs = findReachedInputs(remainingSubgraph.operations); + assert(remainingSubgraph.inputs.contains(bottleneck)); +} + +void splitSubgraphs(SmallVectorImpl &subgraphs) { + + SmallVector resultSubgraphs; + SmallVector workQueue; + + for (const auto &subgraph : subgraphs) { + workQueue.push_back(subgraph); + } + + while (!workQueue.empty()) { + Subgraph currentSubgraph = workQueue.pop_back_val(); + + if (currentSubgraph.operations.size() <= 2) { + resultSubgraphs.push_back(currentSubgraph); + continue; + } + + // Sort all operations in topological order (inputs to outputs) + SmallVector sortedOps; + topoSort(currentSubgraph.operations, sortedOps); + + bool madeSplit = false; + for (auto *I : sortedOps) { + assert(currentSubgraph.operations.contains(I)); + + if (isExpansionBottleneck(I, currentSubgraph)) { + if (FPOptPrint) { + llvm::errs() << "Bottleneck: " << *I << "\n"; + llvm::errs() << "currentSubgraph.inputs (" + << currentSubgraph.inputs.size() << "): "; + for (auto *input : currentSubgraph.inputs) { + llvm::errs() << "\t" << *input << "\n"; + } + llvm::errs() << "\n"; + llvm::errs() << "currentSubgraph.operations (" + << currentSubgraph.operations.size() << "): "; + for (auto *op : currentSubgraph.operations) { + llvm::errs() << "\t" << *op << "\n"; + } + llvm::errs() << "\n"; + llvm::errs() << "currentSubgraph.outputs (" + << currentSubgraph.outputs.size() << "): "; + for (auto *output : currentSubgraph.outputs) { + llvm::errs() << "\t" << 
*output << "\n"; + } + llvm::errs() << "\n"; + } + + Subgraph newSubgraph, remainingSubgraph; + splitSubgraphAtBottleneck(currentSubgraph, I, newSubgraph, + remainingSubgraph); + + if (FPOptPrint) { + llvm::errs() << "=== Splitting subgraph at bottleneck: " << *I + << " ===\n"; + + llvm::errs() << " New subgraph:\n"; + llvm::errs() << " Inputs (" << newSubgraph.inputs.size() << "):\n"; + for (auto *input : newSubgraph.inputs) { + llvm::errs() << " " << *input << "\n"; + } + llvm::errs() << " Operations (" << newSubgraph.operations.size() + << "):\n"; + for (auto *op : newSubgraph.operations) { + llvm::errs() << " " << *op << "\n"; + } + llvm::errs() << " Outputs (" << newSubgraph.outputs.size() + << "):\n"; + for (auto *output : newSubgraph.outputs) { + llvm::errs() << " " << *output << "\n"; + } + + llvm::errs() << " Remaining subgraph:\n"; + + llvm::errs() << " Inputs (" << remainingSubgraph.inputs.size() + << "):\n"; + for (auto *input : remainingSubgraph.inputs) { + llvm::errs() << " " << *input << "\n"; + } + llvm::errs() << " Operations (" + << remainingSubgraph.operations.size() << "):\n"; + for (auto *op : remainingSubgraph.operations) { + llvm::errs() << " " << *op << "\n"; + } + llvm::errs() << " Outputs (" << remainingSubgraph.outputs.size() + << "):\n"; + for (auto *output : remainingSubgraph.outputs) { + llvm::errs() << " " << *output << "\n"; + } + } + + resultSubgraphs.push_back(newSubgraph); + currentSubgraph = remainingSubgraph; + madeSplit = true; + } + } + + if (madeSplit) { + workQueue.push_back(currentSubgraph); + } else { + resultSubgraphs.push_back(currentSubgraph); + } + } + + if (FPOptPrint) { + llvm::errs() << "=== Subgraph splitting complete ===\n"; + llvm::errs() << " Original subgraphs: " << subgraphs.size() << "\n"; + llvm::errs() << " Final subgraphs after splitting: " + << resultSubgraphs.size() << "\n"; + } + + subgraphs = std::move(resultSubgraphs); +} diff --git a/enzyme/Enzyme/Poseidon/PoseidonUtils.h 
b/enzyme/Enzyme/Poseidon/PoseidonUtils.h new file mode 100644 index 000000000000..4eae75e0f3ab --- /dev/null +++ b/enzyme/Enzyme/Poseidon/PoseidonUtils.h @@ -0,0 +1,83 @@ +//=- PoseidonUtils.h - Utility functions for Poseidon optimization pass ----=// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares utility functions for the Poseidon floating-point +// optimization pass. +// +//===----------------------------------------------------------------------===// + +#ifndef ENZYME_POSEIDON_UTILS_H +#define ENZYME_POSEIDON_UTILS_H + +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Type.h" +#include "llvm/Passes/OptimizationLevel.h" +#include "llvm/Support/InstructionCost.h" + +#include +#include +#include +#include + +using namespace llvm; + +extern "C" { +extern llvm::cl::opt FPOptCostModelPath; +extern llvm::cl::opt FPOptNumSamples; +extern llvm::cl::opt FPOptRandomSeed; +extern llvm::cl::opt FPOptMinUsesForSplit; +extern llvm::cl::opt FPOptMinOpsForSplit; +} + +struct Subgraph; +class FPNode; + +// Utility function declarations +double getOneULP(double value); +std::string getLibmFunctionForPrecision(StringRef funcName, Type *newType); +double stringToDouble(const std::string &str); +void topoSort(const SetVector &insts, + SmallVectorImpl &instsSorted); +void reverseTopoSort(const SetVector &insts, + SmallVectorImpl &instsSorted); + +void getUniqueArgs(const std::string &expr, SmallSet &args); + +const std::map, InstructionCost> & +getCostModel(); +InstructionCost queryCostModel(const std::string &OpcodeName, + const 
std::string &TypeName); +InstructionCost getInstructionCompCost(const Instruction *I, + const TargetTransformInfo &TTI); + +const std::unordered_set &getPTFuncs(); + +InstructionCost computeMaxCost( + BasicBlock *BB, std::unordered_map &MaxCost, + std::unordered_set &Visited, const TargetTransformInfo &TTI); + +InstructionCost getCompCost(Function *F, const TargetTransformInfo &TTI); + +InstructionCost getCompCost(const SmallVector &outputs, + const SetVector &inputs, + const TargetTransformInfo &TTI); + +void collectExprInsts(Value *V, const SetVector &inputs, + SmallPtrSetImpl &exprInsts, + SmallPtrSetImpl &visited); + +void splitSubgraphs(SmallVectorImpl &subgraphs); + +void runPoseidonFunctionSimplify(Function &F, OptimizationLevel Level); + +#endif // ENZYME_POSEIDON_UTILS_H diff --git a/enzyme/Enzyme/Poseidon/README.md b/enzyme/Enzyme/Poseidon/README.md new file mode 100644 index 000000000000..6b017f8da2be --- /dev/null +++ b/enzyme/Enzyme/Poseidon/README.md @@ -0,0 +1,382 @@ +# Poseidon + +Poseidon is a modular and extensible framework that fully automates advanced floating-point rewriting techniques for real-world applications within a production compiler. It operates as a PGO-like two-phase compiler that automatically extracts numerical context (e.g., value ranges, sensitivities) from small surrogate profiling runs. It synthesizes algebraic rewrites via [Herbie](https://herbie.uwplse.org/), generates precision tuning candidates, and uses a dynamic programming solver to find a Pareto frontier of optimized programs. + +For details, please read our paper [Thinking Fast and Correct: Automated Rewriting of Numerical Code through Compiler Augmentation](https://ece.is/assets/pdf/poseidon-cgo26.pdf) (CGO 2026). + +If you use Poseidon in an academic setting, please kindly cite: + +```bibtex +@inproceedings{poseidon, + author={Qian, Siyuan Brant and Sathia, Vimarsh and Ivanov, Ivan R. 
and H\"{u}ckelheim, Jan and Hovland, Paul and Moses, William S.}, + booktitle={2026 IEEE/ACM International Symposium on Code Generation and Optimization (CGO)}, + title={Thinking Fast and Correct: Automated Rewriting of Numerical Code through Compiler Augmentation}, + year={2026}, + pages={548-562}, + doi={10.1109/CGO68049.2026.11395228} +} +``` + +## Table of Contents + +- [Build](#build) +- [Two-Phase Pipeline](#two-phase-pipeline) +- [How to Apply Rewrites?](#how-to-apply-rewrites) +- [Reporting](#reporting) + - [Per-Rewrite Analysis](#per-rewrite-analysis) + - [Applying User-Selected Rewrites](#applying-user-selected-rewrites) +- [Optimized Program Validation](#optimized-program-validation) +- [Generating MPFR References with RAPTOR](#optional-generating-mpfr-references-with-raptor) +- [Command-Line Reference](#command-line-reference) + + +## Build + +We recommend building from source and generating a hardware-specific cost model for best results. The cost model estimates per-operation latencies on your machine and is used by the DP solver to estimate computation costs. + +### Prerequisites + +```bash +sudo apt install build-essential cmake ninja-build libmpfr-dev +pip install lit numpy matplotlib tqdm +``` + +Additionally, install [Racket](https://racket-lang.org/) and [Rust](https://www.rust-lang.org/tools/install). + +### Build LLVM + +```bash +cd llvm-project +mkdir build && cd build +cmake -G Ninja \ + -DLLVM_ENABLE_PROJECTS="clang" \ + -DLLVM_ENABLE_LLD=ON \ + -DLLVM_TARGETS_TO_BUILD="X86" \ + -DCMAKE_BUILD_TYPE=Release \ + ../llvm +ninja +cd ../.. +``` + +### Build Enzyme with Poseidon Enabled + +```bash +cd Enzyme +mkdir build && cd build +cmake -G Ninja ../enzyme/ \ + -DLLVM_DIR=<...>/llvm-project/build/lib/cmake/llvm \ + -DLLVM_EXTERNAL_LIT=$(which lit) \ + -DCMAKE_BUILD_TYPE=Release \ + -DENABLE_POSEIDON=ON \ + -DCMAKE_C_COMPILER=<...>/llvm-project/build/bin/clang \ + -DCMAKE_CXX_COMPILER=<...>/llvm-project/build/bin/clang++ +ninja +cd ../.. 
+``` + +Replace `<...>` with the appropriate path prefix. + +A preconfigured Docker image is also available; see the [CGO artifact repository](https://github.com/PRONTOLab/Poseidon) for details. + +### Generate the Cost Model + +The cost model (`cm.csv`) is hardware-specific. A benchmarking script is provided at `Poseidon/scripts/microbm.py`. To generate the cost model for your machine: + +```bash +python3 /Enzyme/Poseidon/scripts/microbm.py +cp results.csv cm.csv +``` + +Then pass `--fpopt-cost-model-path=cm.csv` during the optimization phase. An example cost model (generated on AMD Ryzen Threadripper PRO 7995WX) is provided at `Poseidon/scripts/cm_example.csv` for reference. Without a cost model, Poseidon falls back to LLVM's `TargetTransformInfo` estimates. + +## Two-Phase Pipeline + +See the [dquat benchmark](https://github.com/PRONTOLab/Poseidon/tree/main/dquat) for an end-to-end example. + +:warning: Please Note: Both phases we introduce below must use identical compiler flags (e.g., `-O3`, `-ffast-math`, `-march=native`, etc.) to ensure profile indices match between compilations. 
+ +### User Code Modification + +Annotate the function to optimize with `__enzyme_fp_optimize`: + +```cpp +template +return_type __enzyme_fp_optimize(void *, T...); +int enzyme_dup; + +__attribute__((noinline)) +void my_kernel(double *out, const double *in); + +void run() { + double out[N], out_grad[N], in[M], in_grad[M]; + // Gradients seed the sensitivity analysis (set to 1.0 for uniform weighting) + std::fill(out_grad, out_grad + N, 1.0); + __enzyme_fp_optimize((void *)my_kernel, + enzyme_dup, out, out_grad, + enzyme_dup, in, in_grad); +} +``` + +### Phase 1: Profiling + +```bash +clang++ mycode.cpp $CXXFLAGS \ + -fpass-plugin=ClangEnzyme-XX.so -Xclang -load -Xclang ClangEnzyme-XX.so \ + -mllvm --fpprofile-generate \ + -L$ENZYME_BUILD/Enzyme -lEnzymeFPProfile -lm -o mycode-prof + +ENZYME_FPPROFILE_DIR=./fpprofile ./mycode-prof +``` + +The instrumented binary records per-instruction value ranges, gradient sensitivities, and execution counts into a `.fpprofile` directory. A small surrogate input (e.g., 100 samples) is often sufficient. + +### Phase 2: Optimization + +```bash +clang++ mycode.cpp $CXXFLAGS \ + -fpass-plugin=ClangEnzyme-XX.so -Xclang -load -Xclang ClangEnzyme-XX.so \ + -mllvm --fpprofile-use=./fpprofile \ + -mllvm --fpopt-enable-herbie=1 \ + -mllvm --fpopt-enable-solver \ + -mllvm --fpopt-enable-pt \ + -mllvm --fpopt-comp-cost-budget=0 \ + -mllvm --fpopt-cache-path=./cache \ + -mllvm --fpopt-cost-model-path=cm.csv \ + -mllvm --fpopt-strict-mode \ + -lmpfr -lm -o mycode-opt +``` + +The first run invokes Herbie and the DP solver; results are cached in `--fpopt-cache-path`. Subsequent runs reuse the cache. The `--fpopt-comp-cost-budget` selects a point on the Pareto curve (0 = no-op baseline). + +## How to Apply Rewrites? + +Poseidon offers two ways to apply numerical rewrites: + +1. **Select an optimized program from the Pareto frontier.** Pass `--fpopt-comp-cost-budget=N` to pick a point computed by the DP solver. 
Each budget value corresponds to a different combination of rewrites and precision changes that the solver determined to be Pareto-optimal. The full set of available budgets is listed in `validate_config.json` (when using `--fpopt-report-path`) and `cache/budgets.txt`. + +2. **Obtain a custom optimized program.** Generate a [report](#reporting), review the individual rewrites in `_rewrites.json`, and pass the IDs of the ones you want via `--fpopt-apply-rewrites=R3_0,PT1_0,...`. This bypasses the DP solver and gives fine-grained control over exactly which rewrites are applied. See [Applying User-Selected Rewrites](#applying-user-selected-rewrites) for details. + +## Reporting + +By default, Poseidon silently applies the best solution within the given budget to the compiled binary. To understand *what* was changed and *why*, add `--fpopt-report-path=` (with `-g` for source locations): + +```bash +clang++ mycode.cpp $CXXFLAGS -g ... \ + -mllvm --fpopt-report-path=./report +``` + +### Report Contents + +| File | Description | +|------|-------------| +| `.json` | Full Pareto table with details | +| `.txt` | Plain-text version of the Pareto table | +| `_rewrites.json` | Detailed per-rewrite information: all rewrites categorized and ranked | + +Source locations (file, line, column) are populated when compiled with `-g`. Without `-g`, the symbolic expressions and affected IR are still available. + +:bulb: For complex programs/rewrites, it may be beneficial to feed the `.txt` report and the program source code to an LLM and ask it to explain what each rewrite does and why it improves numerical accuracy. + +### Understanding the Pareto Report + +The `.json` and `.txt` reports describe the full Pareto frontier. Each Pareto point represents one **optimized program** — a combination of rewrites and precision changes selected by the DP solver at a given computation cost budget. 
For example: + +``` +--- Pareto Point #5: Cost=-6863264, Accuracy=3.964900e-01 --- + [Rewrite] (* (sqrt (fma v6 v6 (fma v5 v5 (fma v4 v4 0)))) 0.25) + --> (*.f64 (sqrt.f64 ...) #s(literal 1/4 binary64)) + Herbie accuracy: 0.935 -> 0.999 bits + Source: dquat.cpp:87:18 + Source: dquat.cpp:56:11 + Affected IR: + %mul = fmul fast double %sqrt, 2.500000e-01 +``` + +### Per-Rewrite Analysis + +The `_rewrites.json` file lists every **individual** rewrite candidate, categorized by its estimated impact: + +- **`free_win`** — improves both accuracy and speed. Always beneficial. +- **`accuracy_for_speed`** — improves accuracy at the cost of extra computation. +- **`speed_for_accuracy`** — improves speed at the cost of some accuracy. + +Each entry includes an `efficiency` score that ranks tradeoffs: higher means more benefit per unit of cost. Entries are sorted by category (free wins first) then by efficiency descending. + +Each rewrite has a stable `id` (e.g., `R3_0`, `PT1_2`) that can be used with `--fpopt-apply-rewrites` (see below). + +Example: +```json +{ + "id": "R6_3", + "category": "speed_for_accuracy", + "efficiency": 4.602e+11, + "computation_cost_delta": -22632, + "accuracy_cost_delta": 4.918e-08, + "original_expr": "(* (/ 1 (sqrt ...)) v5)", + "rewritten_expr": "#s(approx ...)", + "source_location": {"file": "dquat.cpp", "line": 116, "col": 18} +} +``` + +### Applying User-Selected Rewrites + +A user can pick specific rewrites from `_rewrites.json` by their IDs (this bypasses the DP solver): + +```bash +clang++ mycode.cpp $CXXFLAGS ... \ + -mllvm --fpopt-apply-rewrites=R5_5,R6_3,R7_3,PT1_0 +``` + +With a few constraints: +- At most one rewrite per expression (e.g., `R5_0` and `R5_1` conflict) +- At most one precision change per subgraph (e.g., `PT1_0` and `PT1_2` conflict) + +Duplicates are detected and skipped with a warning. 
+ +## Optimized Program Validation + +To validate the actual accuracy and runtime of optimized programs, a reference validation script is provided at `Poseidon/scripts/validate.py`. Copy it to your report directory alongside `validate_config.json`, then run: + +```bash +cp /Enzyme/Poseidon/scripts/validate.py ./report/ +python3 report/validate.py mycode.cpp \ + --enzyme-plugin /path/to/ClangEnzyme-XX.so \ + --cxx /path/to/clang++ \ + --extra-flags "-O3 -ffast-math -march=native -fno-exceptions -lmpfr" \ + --extra-run-args "--num-tests 100 --seed 42" \ + --gold-path gold_mpfr.txt \ + --num-samples 10 \ + --num-runs 5 +``` + +### What it does + +1. **Loads a gold accuracy reference** from `--gold-path` (MPFR ground truth; see [RAPTOR section](#optional-generating-mpfr-references-with-raptor) below) +2. **Compiles the original** (unoptimized) binary, measures its runtime and accuracy against gold +3. **Uniformly samples** N budgets from the full Pareto table (87 points for dquat with `-ffast-math`) +4. **For each budget:** recompiles with Poseidon at that budget using the cached Herbie results + DP table, runs the resulting binary, captures output +5. **Computes** geomean and max relative error against gold for each variant +6. **Measures** median runtime over multiple runs +7. **Prints a summary table** showing the original program as baseline, then each Pareto variant with estimated vs. 
validated accuracy and speedup + +### Example output + +``` +======================================================================== + Budget Est.AccCost GeomErr MaxErr Runtime Speedup +------------------------------------------------------------------------ + ORIGINAL -- 6.3760e-11 1.1794e+03 0.000015 1.00x +------------------------------------------------------------------------ + -7064000 3.5985e+01 6.3990e-03 3.4521e+02 0.000013 1.10x + -6863264 3.9649e-01 5.9403e-09 1.1794e+03 0.000014 1.08x + 1096856 -1.1591e-16 6.8253e-11 1.1794e+03 0.000013 1.13x +======================================================================== +``` + +Here `ORIGINAL` is the unoptimized double-precision program. Negative budgets allow the solver to trade accuracy for speed; positive budgets allow extra computation to improve accuracy. The `MaxErr: 1179` in the original comes from catastrophic cancellation (`1 - cos(theta)` near zero), which Herbie's rewrites at positive budgets fix. + +## (Optional) Generating MPFR References with RAPTOR + +For accuracy validation we recommend computing reference results at high floating-point precision (e.g., MPFR-2048 bits). [RAPTOR](https://github.com/RIKEN-RCCS/RAPTOR) provides this capability by running floating-point operations through MPFR. For the latest build and usage instructions of RAPTOR, see the [RAPTOR repository](https://github.com/RIKEN-RCCS/RAPTOR). + +### Building RAPTOR + +```bash +git clone https://github.com/RIKEN-RCCS/RAPTOR +cd RAPTOR && mkdir build && cd build +cmake .. 
-DLLVM_DIR=/path/to/llvm -DCMAKE_BUILD_TYPE=Release +make -j +``` + +### Source preparation + +Add a `#ifdef POSEIDON_GOLD` guard that wraps the target function call with RAPTOR's MPFR truncation: + +```cpp +#ifdef POSEIDON_GOLD +template <typename fty> +__attribute__((nothrow)) fty *__raptor_truncate_mem_func(fty *, int, int, int, int); +__attribute__((nothrow)) extern double __raptor_truncate_mem_value(...); +__attribute__((nothrow)) extern double __raptor_expand_mem_value(...); +extern "C" double raptor_fprt_gc_mark_seen(double); +extern "C" void raptor_fprt_gc_doit(); +// RAPTOR API: (func_ptr, from_bits, type_selector, exponent, mantissa) +// type_selector: 0=IEEE, 1=MPFR +#define RAPTOR_FROM 64 +#define RAPTOR_TYPE 1 +#define RAPTOR_TO_E 64 +#define RAPTOR_TO_M 2048 +#define TRUNC_SELF(X) X = __raptor_truncate_mem_value(X, RAPTOR_FROM, RAPTOR_TYPE, RAPTOR_TO_E, RAPTOR_TO_M) +#define EXPAND_SELF(X) X = __raptor_expand_mem_value(X, RAPTOR_FROM, RAPTOR_TYPE, RAPTOR_TO_E, RAPTOR_TO_M) +#else +int enzyme_dup; +template <typename return_type, typename... T> +return_type __enzyme_fp_optimize(void *, T...); +#endif +``` + +In the main loop: + +```cpp +#ifdef POSEIDON_GOLD + for (int i = 0; i < num_inputs; i++) TRUNC_SELF(inputs[i]); + __raptor_truncate_mem_func(my_kernel, RAPTOR_FROM, RAPTOR_TYPE, + RAPTOR_TO_E, RAPTOR_TO_M)(outputs, inputs); + for (int i = 0; i < num_outputs; i++) EXPAND_SELF(outputs[i]); + for (int i = 0; i < num_inputs; i++) EXPAND_SELF(inputs[i]); + raptor_fprt_gc_doit(); +#else + __enzyme_fp_optimize<void>((void *)my_kernel, + enzyme_dup, outputs, outputs_grad, enzyme_dup, inputs, inputs_grad); +#endif +``` + +### Compiling and running + +```bash +clang++ mycode.cpp -O0 -fno-exceptions -DPOSEIDON_GOLD \ + -fpass-plugin=/path/to/RAPTOR/build/pass/ClangRaptor-XX.so \ + -Xclang -load -Xclang /path/to/RAPTOR/build/pass/ClangRaptor-XX.so \ + -L/path/to/RAPTOR/build/runtime -lRaptor-RT-XX \ + -lmpfr -lm -o gold.exe + +./gold.exe --output-path gold_mpfr.txt +``` + +Then pass `--gold-path gold_mpfr.txt` to 
`validate.py`. + +## Command-Line Reference + +### Profiling +| Flag | Default | Description | +|------|---------|-------------| +| `--fpprofile-generate` | false | Instrument for FP profiling | +| `--fpprofile-use=` | | Profile directory for optimization | + +### Herbie +| Flag | Default | Description | +|------|---------|-------------| +| `--fpopt-enable-herbie` | true | Use Herbie for algebraic rewrites | +| `--fpopt-cache-path` | "cache" | Cache directory for Herbie results and DP table | +| `--herbie-num-threads` | 8 | Herbie worker threads | +| `--herbie-timeout` | 120 | Per-expression Herbie timeout (seconds) | + +### Solver +| Flag | Default | Description | +|------|---------|-------------| +| `--fpopt-enable-solver` | true | Enable DP solver | +| `--fpopt-enable-pt` | true | Enable precision tuning candidates | +| `--fpopt-comp-cost-budget` | 0 | Computation cost budget (0 = no-op) | +| `--fpopt-cost-model-path` | | Hardware-specific cost model CSV | +| `--fpopt-strict-mode` | false | Discard candidates that produce NaN/inf | +| `--fpopt-num-samples` | 1024 | Number of MPFR evaluation samples | +| `--fpopt-early-prune` | true | Prune dominated candidates during DP | +| `--fpopt-loose-coverage` | false | Allow unexecuted instructions (suppress coverage errors) | + +### Reporting +| Flag | Default | Description | +|------|---------|-------------| +| `--fpopt-report-path` | | Output directory for JSON/text reports + validate.py | +| `--fpopt-apply-rewrites` | | Comma-separated rewrite IDs from `_rewrites.json` (bypasses DP solver) | +| `--fpopt-print` | false | Print debug info to stderr | +| `--fpopt-show-table` | false | Print full DP table to stderr | diff --git a/enzyme/Enzyme/Poseidon/scripts/cm_example.csv b/enzyme/Enzyme/Poseidon/scripts/cm_example.csv new file mode 100644 index 000000000000..5f1c42c6a1d4 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/scripts/cm_example.csv @@ -0,0 +1,98 @@ +fneg,float,75 +fadd,float,77 +fsub,float,78 +fmul,float,77 
+fdiv,float,77 +fcmp,float,62 +fpext_float_to_double,float,78 +fmuladd,float,78 +sin,float,421 +cos,float,418 +tan,float,877 +exp,float,458 +log,float,300 +sqrt,float,124 +expm1,float,499 +log1p,float,478 +cbrt,float,2777 +pow,float,545 +fabs,float,76 +fma,float,75 +maxnum,float,75 +minnum,float,76 +ceil,float,77 +floor,float,75 +exp2,float,458 +log10,float,493 +log2,float,301 +rint,float,298 +round,float,304 +trunc,float,76 +copysign,float,76 +fdim,float,300 +fmod,float,298 +asin,float,331 +acos,float,296 +atan,float,780 +atan2,float,2730 +sinh,float,481 +cosh,float,328 +tanh,float,344 +asinh,float,341 +acosh,float,340 +atanh,float,377 +hypot,float,1938 +erf,float,327 +lgamma,float,1678 +tgamma,float,3468 +remainder,float,3026 +powi,float,83 +fneg,double,82 +fadd,double,83 +fsub,double,83 +fmul,double,82 +fdiv,double,127 +fcmp,double,65 +fptrunc_double_to_float,double,83 +fmuladd,double,83 +sin,double,3715 +cos,double,3769 +tan,double,4499 +exp,double,748 +log,double,350 +sqrt,double,236 +expm1,double,327 +log1p,double,460 +cbrt,double,1047 +pow,double,1102 +fabs,double,83 +fma,double,83 +maxnum,double,83 +minnum,double,83 +ceil,double,83 +floor,double,83 +exp2,double,442 +log10,double,604 +log2,double,327 +rint,double,326 +round,double,328 +trunc,double,83 +copysign,double,83 +fdim,double,329 +fmod,double,328 +asin,double,576 +acos,double,566 +atan,double,822 +atan2,double,1178 +sinh,double,627 +cosh,double,546 +tanh,double,327 +asinh,double,493 +acosh,double,385 +atanh,double,330 +hypot,double,526 +erf,double,327 +lgamma,double,1981 +tgamma,double,2017 +remainder,double,1064 +powi,double,82 diff --git a/enzyme/Enzyme/Poseidon/scripts/microbm.py b/enzyme/Enzyme/Poseidon/scripts/microbm.py new file mode 100644 index 000000000000..d5847b7808d5 --- /dev/null +++ b/enzyme/Enzyme/Poseidon/scripts/microbm.py @@ -0,0 +1,610 @@ +import time +import csv +import struct +import random +import numpy as np +import llvmlite.binding as llvm +import ctypes + + +random.seed(42) + 
def get_zero_literal(precision):
    """Return the LLVM IR literal for positive zero of a floating-point type.

    Args:
        precision: Either one of this script's precision keys ("double",
            "float", "half", "bf16", "fp80", "fp128") or the matching LLVM
            type name ("bfloat", "x86_fp80"). Both spellings are accepted
            because callers in this file pass both: generate_cast_op_code()
            passes the precision key, generate_loop_code() passes the LLVM
            type name.

    Returns:
        "0.0" for double/float/half and for any unrecognised name; otherwise
        the type-prefixed hexadecimal form (0xR / 0xK / 0xL).
    """
    bf16_zero = "0xR0000"
    fp80_zero = "0xK00000000000000000000"
    fp128_zero = "0xL00000000000000000000000000000000"
    hex_zero_by_name = {
        # script precision keys
        "bf16": bf16_zero,
        "fp80": fp80_zero,
        "fp128": fp128_zero,
        # LLVM type names that differ from the precision key
        "bfloat": bf16_zero,
        "x86_fp80": fp80_zero,
    }
    return hex_zero_by_name.get(precision, "0.0")
def float_to_llvm_hex(f, precision):
    """Format a Python float as an LLVM IR hexadecimal floating-point constant.

    Args:
        f: The value to encode.
        precision: One of "double", "float", "half", "bf16", "fp80", "fp128".

    Returns:
        The literal as a string. "double" and "float" both use the 16-digit
        form built from the 64-bit double encoding (the float value is first
        rounded to float32). The other types use their prefixed forms
        (0xH, 0xR, 0xK, 0xL). Any other precision falls back to str(f).
    """
    if precision == "double":
        f_cast = np.float64(f)
        packed = struct.pack(">d", f_cast)
        [i] = struct.unpack(">Q", packed)
        return f"0x{i:016X}"
    elif precision == "float":
        # Round to float32 first, then print the double encoding of that value.
        f_cast = np.float32(f)
        packed = struct.pack(">d", f_cast)
        [i] = struct.unpack(">Q", packed)
        return f"0x{i:016X}"
    elif precision == "half":
        f_cast = np.float16(f)
        # Reinterpret the 16 bits as an integer. This is independent of host
        # byte order; unpacking tobytes() as big-endian swapped the two bytes
        # on little-endian machines.
        i = int(f_cast.view(np.uint16))
        return f"0xH{i:04X}"
    elif precision == "bf16":
        f_cast = np.float32(f)
        [bits] = struct.unpack(">I", struct.pack(">f", f_cast))
        # Keep the upper 16 bits of the float32 encoding (truncation).
        bf16_bits = bits >> 16
        return f"0xR{bf16_bits:04X}"
    elif precision == "fp80":
        f_cast = np.float64(f)
        fp80_bytes = float64_to_fp80_bytes(f_cast)
        return f"0xK{fp80_bytes.hex().upper()}"
    elif precision == "fp128":
        f_cast = np.float64(f)
        fp128_bytes = float64_to_fp128_bytes(f_cast)
        # Emit the low 64 bits before the high 64 bits.
        swapped = fp128_bytes[8:] + fp128_bytes[:8]
        return f"0xL{swapped.hex().upper()}"
    else:
        return str(f)
# Native LLVM floating-point instructions that the benchmark times.
# Each entry maps the CSV instruction name to:
#   llvm_instr   - the IR opcode (identical to the key for every entry),
#   num_operands - how many floating-point operands the opcode takes,
#   kind         - which generator handles it: "arithmetic", "compare" or "cast".
OP_INFO = {
    opcode: {"llvm_instr": opcode, "num_operands": arity, "kind": kind}
    for opcode, arity, kind in (
        ("fneg", 1, "arithmetic"),
        ("fadd", 2, "arithmetic"),
        ("fsub", 2, "arithmetic"),
        ("fmul", 2, "arithmetic"),
        ("fdiv", 2, "arithmetic"),
        ("fcmp", 2, "compare"),
        ("fptrunc", 1, "cast"),
        ("fpext", 1, "cast"),
    )
}
def generate_loop_code(llvm_type, iterations, body_instructions, final_acc_reg):
    """Wrap an unrolled instruction sequence in a counted-loop LLVM IR module.

    The generated `@main` keeps a loop counter and an accumulator in stack
    slots, runs `body_instructions` once per iteration, stores
    `final_acc_reg` back into the accumulator, and finally passes the
    accumulator to an empty `@use` function.

    Args:
        llvm_type: LLVM type name of the accumulator (e.g. "double").
        iterations: Loop trip count.
        body_instructions: IR text for the loop body. It may read
            `%acc_val0` (the accumulator loaded at the top of the body) and
            must define `final_acc_reg`.
        final_acc_reg: Name of the SSA register holding the new accumulator.

    Returns:
        The IR module as a string.
    """
    # get_zero_literal() is keyed by precision name ("bf16", "fp80", ...),
    # while this function receives the LLVM type name ("bfloat", "x86_fp80").
    # Map the type back to its precision key so those types get their
    # hexadecimal zero rather than the decimal fallback. If the name is not a
    # known LLVM type it is passed through unchanged.
    precision = next(
        (name for name, ty in precision_to_llvm_type.items() if ty == llvm_type),
        llvm_type,
    )
    zero_literal = get_zero_literal(precision)
    code = f"""
define i32 @main() optnone noinline {{
entry:
  %i = alloca i32
  %acc = alloca {llvm_type}
  store i32 0, i32* %i
  store {llvm_type} {zero_literal}, {llvm_type}* %acc
  br label %loop

loop:
  %i_val = load i32, i32* %i
  %cond = icmp slt i32 %i_val, {iterations}
  br i1 %cond, label %body, label %exit

body:
  %acc_val0 = load {llvm_type}, {llvm_type}* %acc
{body_instructions}
  store {llvm_type} {final_acc_reg}, {llvm_type}* %acc
  %i_next = add i32 %i_val, 1
  store i32 %i_next, i32* %i
  br label %loop

exit:
  %final_acc = load {llvm_type}, {llvm_type}* %acc
  call void @use({llvm_type} %final_acc)
  ret i32 0
}}

define void @use({llvm_type} %val) {{
  ret void
}}
"""
    return code
def generate_function_call_code(func_name, precision, iterations):
    """Generate LLVM IR that times calls to a math function from FUNC_INFO.

    The loop body contains `unrolled` calls with random constant operands;
    each result is added into a running accumulator.

    Args:
        func_name: Key into FUNC_INFO.
        precision: Key into precision_to_llvm_type.
        iterations: Loop trip count.

    Returns:
        The IR module (declaration of the callee, `@main`, `@use`) as a string.
    """
    func_info = FUNC_INFO[func_name]
    llvm_type = precision_to_llvm_type[precision]
    intrinsic_suffix = precision_to_intrinsic_suffix.get(precision, "")

    if func_info["intrinsic"]:
        fn = f"{func_info['intrinsic']}.{intrinsic_suffix}"
    else:
        # Non-intrinsic entries are called by their C math-library name. C
        # uses the bare name only for double; the float variant carries an
        # "f" suffix and the long double variant an "l" suffix. Declaring
        # e.g. `float @tan(float)` would presumably resolve to the
        # double-precision routine and call it with the wrong argument type.
        # Precisions without a standard C variant keep the bare name.
        libm_suffix = {"float": "f", "fp80": "l"}.get(precision, "")
        fn = f"{func_name}{libm_suffix}"

    num_operands = func_info["num_operands"]
    is_powi = func_name == "powi"

    body_lines = "  %acc_val0 = load " + llvm_type + ", " + llvm_type + "* %acc\n"

    for idx in range(unrolled):
        if is_powi:
            # powi takes one floating-point base and one i32 exponent.
            f_val = generate_random_fp(precision)
            i_val = random.randint(-10, 10)
            f_hex = float_to_llvm_hex(f_val, precision)
            call_args = f"{llvm_type} {f_hex}, i32 {i_val}"
        else:
            operands = []
            for _ in range(num_operands):
                f_val = generate_random_fp(precision)
                operands.append(float_to_llvm_hex(f_val, precision))
            call_args = ", ".join(f"{llvm_type} {operand}" for operand in operands)

        body_lines += f"  %result{idx} = call {FAST_MATH_FLAG} {llvm_type} @{fn}({call_args})\n"
        body_lines += f"  %acc_val{idx+1} = fadd {FAST_MATH_FLAG} {llvm_type} %acc_val{idx}, %result{idx}\n"

    if is_powi:
        decl = f"declare {llvm_type} @{fn}({llvm_type}, i32)"
    else:
        arg_types = ", ".join([llvm_type] * num_operands)
        decl = f"declare {llvm_type} @{fn}({arg_types})"

    code = f"""
{decl}
define i32 @main() optnone noinline {{
entry:
  %i = alloca i32
  %acc = alloca {llvm_type}
  store i32 0, i32* %i
  store {llvm_type} {get_zero_literal(precision)}, {llvm_type}* %acc
  br label %loop

loop:
  %i_val = load i32, i32* %i
  %cond = icmp slt i32 %i_val, {iterations}
  br i1 %cond, label %body, label %exit

body:
{body_lines}
  store {llvm_type} %acc_val{unrolled}, {llvm_type}* %acc
  %i_next = add i32 %i_val, 1
  store i32 %i_next, i32* %i
  br label %loop

exit:
  %final_acc = load {llvm_type}, {llvm_type}* %acc
  call void @use({llvm_type} %final_acc)
  ret i32 0
}}

define void @use({llvm_type} %val) {{
  ret void
}}
"""
    return code
OP_INFO[instr]["kind"] + if op_kind == "cast": + src_precision = precision + src_rank = precision_ranks.get(src_precision) + if src_rank is None: + continue + + if instr == "fptrunc": + dst_precisions = [ + p for p in precisions_ordered if p in precisions and precision_ranks[p] < src_rank + ] + else: + dst_precisions = [ + p for p in precisions_ordered if p in precisions and precision_ranks[p] > src_rank + ] + + for dst_precision in dst_precisions: + if (src_precision, dst_precision) in [ + ("half", "bf16"), + ("bf16", "half"), + ]: + continue + code = generate_cast_op_code(instr, src_precision, dst_precision, iterations) + name = f"{instr}_{src_precision}_to_{dst_precision}" + elapsed, _ = run_llvm_ir_jit(code) + adjusted = (elapsed - baseline_time) * AMPLIFIER + writer.writerow( + { + "instruction": name, + "precision": src_precision, + "cost": int(adjusted), + } + ) + else: + if op_kind == "arithmetic": + code = generate_arithmetic_op_code(instr, precision, iterations) + elif op_kind == "compare": + code = generate_compare_op_code(precision, iterations) + else: + code = "" + if code.strip(): + elapsed, _ = run_llvm_ir_jit(code) + adjusted = (elapsed - baseline_time) * AMPLIFIER + writer.writerow( + { + "instruction": instr, + "precision": precision, + "cost": int(adjusted), + } + ) + + for func in FUNC_INFO: + code = generate_function_call_code(func, precision, iterations) + if code.strip(): + elapsed, _ = run_llvm_ir_jit(code) + adjusted = (elapsed - baseline_time) * AMPLIFIER + writer.writerow( + { + "instruction": func, + "precision": precision, + "cost": int(adjusted), + } + ) + + print(f"Results in '{csv_file}'. Baseline: {baseline_time:.6f}s") diff --git a/enzyme/Enzyme/Poseidon/scripts/validate.py b/enzyme/Enzyme/Poseidon/scripts/validate.py new file mode 100644 index 000000000000..4da43218adab --- /dev/null +++ b/enzyme/Enzyme/Poseidon/scripts/validate.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python3 +""" +Poseidon validation reference script. 
def run(cmd, desc="", timeout=600, capture=False):
    """Run a command and optionally capture stdout.

    Args:
        cmd: Argument list passed to subprocess.run.
        desc: Short label used in progress and failure messages.
        timeout: Seconds to wait before giving up.
        capture: When true, stdout and stderr are captured instead of being
            inherited from this process.

    Returns:
        The captured stdout when `capture` is true, an empty string when it
        is false, or None if the command could not be started, timed out, or
        exited with a non-zero status.
    """
    label = " ".join(cmd) if len(cmd) < 8 else f"{cmd[0]} ... ({desc})"
    print(f" $ {label}")
    try:
        r = subprocess.run(cmd, capture_output=capture, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        print(f" TIMEOUT ({desc})")
        return None
    except OSError as exc:
        # Missing or non-executable binary: report it the same way as any
        # other failure so callers can keep treating None as "step failed".
        print(f" FAILED ({desc}): {exc}")
        return None

    if r.returncode != 0:
        print(f" FAILED ({desc}): exit {r.returncode}")
        if capture and r.stderr:
            # Show only the start of stderr to keep the log readable.
            print(r.stderr[:500])
        return None
    return r.stdout if capture else ""
def compute_error(gold_vals, test_vals):
    """Compute geomean and max relative error between two value lists.

    Values are compared position by position; if the lists differ in length
    only the common prefix is compared.

    Args:
        gold_vals: Reference values.
        test_vals: Values produced by the program under test.

    Returns:
        A `(geomean, max)` tuple of relative errors. Both are NaN when there
        is nothing to compare and 0.0 when every compared pair is equal. The
        geomean is taken over the pairs that differ. A pair whose relative
        error is undefined (a NaN on either side, or an infinity compared
        with a different value) counts as infinite error.
    """
    n = min(len(gold_vals), len(test_vals))
    if n == 0:
        return float("nan"), float("nan")

    log_sum = 0.0
    count = 0
    max_err = 0.0
    for g, t in zip(gold_vals, test_vals):
        if g == t:
            # Identical values (including 0 vs 0 and equal infinities)
            # contribute no error.
            continue
        denom = abs(g) if g != 0 else abs(t)
        rel = abs(g - t) / denom
        if math.isnan(rel):
            # max() and the `rel > 0` test below both ignore NaN, which would
            # make a NaN result look like a perfect match.
            rel = math.inf
        max_err = max(max_err, rel)
        if rel > 0:
            log_sum += math.log(rel)
            count += 1

    if count == 0:
        return 0.0, 0.0
    return math.exp(log_sum / count), max_err
def uniform_sample(lst, n):
    """Choose n evenly spaced positions in lst.

    Returns a sorted list of indices into `lst` (not the items themselves).
    When n is at least len(lst) every index is returned; otherwise the first
    and last index are always included for n >= 2, and n == 1 yields [0].
    """
    size = len(lst)
    if n >= size:
        return list(range(size))
    if n <= 0:
        return []
    if n == 1:
        return [0]

    stride = (size - 1) / (n - 1)
    chosen = {int(round(k * stride)) for k in range(n)}
    return sorted(chosen)
f"--fpprofile-use={config['profile_path']}", + "-mllvm", "--fpopt-enable-solver", + "-mllvm", "--fpopt-enable-herbie=1", + "-mllvm", "--fpopt-enable-pt", + "-mllvm", f"--fpopt-comp-cost-budget={budget}", + "-mllvm", f"--fpopt-cache-path={config['cache_path']}", + "-lmpfr", "-lm", + "-o", exe, + ] + if run(cmd, f"compile budget={budget}") is None: + return None + return exe + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + p = argparse.ArgumentParser( + description="Poseidon accuracy/runtime validation (reference script)") + p.add_argument("source", help="C/C++ source file") + p.add_argument("--enzyme-plugin", required=True, + help="Path to ClangEnzyme-XX.so") + p.add_argument("--cxx", default="clang++", + help="C++ compiler (default: clang++)") + p.add_argument("--extra-flags", default="-O3 -ffast-math -lm", + help="Compile flags (must match profiling flags)") + p.add_argument("--extra-run-args", default="", + help="Args passed to each executable at runtime") + p.add_argument("--num-samples", type=int, default=10, + help="Number of Pareto points to validate") + p.add_argument("--num-runs", type=int, default=5, + help="Runtime measurement repetitions per variant") + p.add_argument("--gold-path", default="", + help="Pre-computed gold reference (skip RAPTOR compilation)") + p.add_argument("--raptor-dir", default="", + help="RAPTOR build directory (for gold compilation)") + p.add_argument("--gold-source", default="", + help="Source file for gold compilation (if different)") + p.add_argument("--gold-flags", default="", + help="Extra flags for gold compilation (e.g. 
-DPOSEIDON_GOLD)") + args = p.parse_args() + + budgets = CONFIG["budgets"] + est_acc = CONFIG.get("estimated_accuracy_costs", []) + run_args = args.extra_run_args.split() if args.extra_run_args else [] + + out_dir = os.path.join(SCRIPT_DIR, "validation_output") + os.makedirs(out_dir, exist_ok=True) + + # --- Gold reference --- + if args.gold_path: + print(f"=== Using pre-computed gold reference: {args.gold_path} ===") + gold_text = open(args.gold_path).read() + else: + print("=== Step 1: Compiling MPFR gold reference with RAPTOR ===") + if not args.raptor_dir: + print("Error: --raptor-dir required for gold compilation.") + print("Either provide --gold-path or --raptor-dir.") + sys.exit(1) + gold_src = args.gold_source if args.gold_source else args.source + gold_exe = compile_gold(gold_src, args.cxx, args.raptor_dir, + args.extra_flags, args.gold_flags, out_dir) + if not gold_exe: + print("Gold compilation failed.") + sys.exit(1) + print("Running gold binary...") + gold_text = run_and_capture(gold_exe, run_args) + if gold_text is None: + print("Error: gold run failed.") + sys.exit(1) + + gold_vals = parse_output(gold_text) + print(f"Gold reference: {len(gold_vals)} values extracted.\n") + + # --- Original (unoptimized) baseline --- + print("=== Step 2: Measuring baseline (original) ===") + orig_exe = os.path.join(out_dir, "original.exe") + orig_cmd = [args.cxx, args.source] + args.extra_flags.split() + [ + f"-fpass-plugin={args.enzyme_plugin}", + "-Xclang", "-load", "-Xclang", args.enzyme_plugin, + "-lmpfr", "-lm", "-o", orig_exe, + ] + orig_geo_err = float("nan") + orig_max_err = float("nan") + orig_runtime = float("nan") + if run(orig_cmd, "compile original") is not None: + orig_runtime = measure_runtime(orig_exe, args.num_runs, run_args) + orig_text = run_and_capture(orig_exe, run_args) + if orig_text is not None: + orig_vals = parse_output(orig_text) + orig_geo_err, orig_max_err = compute_error(gold_vals, orig_vals) + print(f"Baseline runtime: 
{orig_runtime:.6f}s, " + f"geomean error: {orig_geo_err:.6e}, max error: {orig_max_err:.6e}\n") + else: + print("Warning: could not compile original.\n") + + # --- Sample and compile Pareto variants --- + sample_indices = uniform_sample(budgets, args.num_samples) + sampled_budgets = [budgets[i] for i in sample_indices] + sampled_est = [est_acc[i] if i < len(est_acc) else float("nan") + for i in sample_indices] + + print(f"=== Step 3: Compiling {len(sampled_budgets)} Pareto variants ===") + + results = [] + for budget, est in zip(sampled_budgets, sampled_est): + print(f"\n--- Budget: {budget} (estimated acc cost: {est:.6e}) ---") + exe = compile_variant(args.source, args.cxx, args.enzyme_plugin, + budget, CONFIG, args.extra_flags, out_dir) + if not exe: + results.append({"budget": budget, "error": "compile_failed"}) + continue + + test_text = run_and_capture(exe, run_args) + if test_text is None: + results.append({"budget": budget, "error": "run_failed"}) + continue + + test_vals = parse_output(test_text) + geo_err, max_err = compute_error(gold_vals, test_vals) + rt = measure_runtime(exe, args.num_runs, run_args) + speedup = orig_runtime / rt if rt > 0 and not math.isnan(orig_runtime) else float("nan") + + results.append({ + "budget": budget, + "estimated_accuracy_cost": est, + "geomean_relative_error": geo_err, + "max_relative_error": max_err, + "runtime": rt, + "speedup": speedup, + }) + print(f" geomean error: {geo_err:.6e}, max error: {max_err:.6e}, " + f"runtime: {rt:.6f}s, speedup: {speedup:.2f}x") + + # --- Summary --- + print("\n" + "=" * 72) + print(f"{'Budget':>10} {'Est.AccCost':>12} {'GeomErr':>12} {'MaxErr':>12} " + f"{'Runtime':>10} {'Speedup':>8}") + print("-" * 72) + if not math.isnan(orig_runtime): + print(f"{'ORIGINAL':>10} {'--':>12} " + f"{orig_geo_err:>12.4e} {orig_max_err:>12.4e} " + f"{orig_runtime:>10.6f} {'1.00x':>8}") + print("-" * 72) + for r in results: + if "error" in r and isinstance(r.get("error"), str): + print(f"{r['budget']:>10} 
#include <sys/stat.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

/// Running statistics for one profiled floating-point operation.
class ProfileInfo {
public:
  // Sentinels: any observed non-NaN value replaces them on the first update.
  double minRes = std::numeric_limits<double>::max();
  double maxRes = std::numeric_limits<double>::lowest();
  std::vector<double> minOperands;
  std::vector<double> maxOperands;
  double sumValue = 0.0;
  double sumSens = 0.0;
  double sumGrad = 0.0;
  unsigned exec = 0;

  /// Record one execution: bump the counter and fold the result and each
  /// operand into the running min/max (NaNs are skipped) and the value sum.
  ///
  /// @param value        result of the operation
  /// @param operands     array of at least `numOperands` operand values
  /// @param numOperands  number of entries readable through `operands`
  void updateValue(double value, const double *operands, size_t numOperands) {
    ++exec;

    // Grow whenever a call carries more operands than seen so far. Sizing
    // only on the first call would let a later, wider call index past the end.
    if (minOperands.size() < numOperands) {
      minOperands.resize(numOperands, std::numeric_limits<double>::max());
      maxOperands.resize(numOperands, std::numeric_limits<double>::lowest());
    }
    if (operands != nullptr) {
      for (size_t op = 0; op < numOperands; ++op) {
        const double x = operands[op];
        if (std::isnan(x))
          continue;
        minOperands[op] = std::min(minOperands[op], x);
        maxOperands[op] = std::max(maxOperands[op], x);
      }
    }

    if (!std::isnan(value)) {
      minRes = std::min(minRes, value);
      maxRes = std::max(maxRes, value);
      sumValue += value;
    }
  }

  /// Accumulate a gradient sample. Samples where either number is NaN or
  /// infinite are dropped so one bad point cannot poison the sums.
  void updateGradient(double value, double grad) {
    if (!std::isfinite(grad) || !std::isfinite(value))
      return;
    sumGrad += grad;
    sumSens += std::fabs(grad * value);
  }
};

/// Collects ProfileInfo records for one function, keyed by operation index,
/// and serialises them to `<dir>/<functionName>.fpprofile`.
class FPProfiler {
private:
  std::string functionName;
  std::unordered_map<size_t, ProfileInfo> profileInfo;
  static std::string dir; // output directory shared by all profilers

public:
  explicit FPProfiler(const std::string &funcName) : functionName(funcName) {}

  static void setOutputDir(const std::string &dir_) { dir = dir_; }

  std::string getOutputPath() const {
    return dir + "/" + functionName + ".fpprofile";
  }

  /// Record a result/operand sample for operation `idx`, creating its record
  /// on first use.
  void updateValue(size_t idx, double res, size_t numOperands,
                   const double *operands) {
    profileInfo[idx].updateValue(res, operands, numOperands);
  }

  /// Record a gradient sample for operation `idx`, creating its record on
  /// first use.
  void updateGradient(size_t idx, double value, double grad) {
    profileInfo[idx].updateGradient(value, grad);
  }

  /// Write every record to the output file. On failure to open the file a
  /// warning is printed to stderr and nothing is written.
  void write() const {
    const std::string outputPath = getOutputPath();

    // Best-effort creation of the (single-level) output directory. An
    // already-existing directory makes mkdir fail harmlessly; any real
    // problem surfaces as the open failure reported below.
    static_cast<void>(mkdir(dir.c_str(), 0755));

    std::ofstream out(outputPath);
    if (!out.is_open()) {
      std::cerr << "Warning: Could not open profile file: " << outputPath
                << '\n';
      return;
    }

    // max_digits10 guarantees doubles survive a text round trip.
    out << std::scientific
        << std::setprecision(std::numeric_limits<double>::max_digits10);

    // Emit in ascending index order so the file is reproducible run to run
    // (hash-map iteration order is unspecified).
    std::vector<std::pair<size_t, const ProfileInfo *>> ordered;
    ordered.reserve(profileInfo.size());
    for (const auto &entry : profileInfo)
      ordered.emplace_back(entry.first, &entry.second);
    std::sort(ordered.begin(), ordered.end(),
              [](const auto &a, const auto &b) { return a.first < b.first; });

    for (const auto &entry : ordered) {
      const ProfileInfo &info = *entry.second;
      out << entry.first << "\n";
      out << "\tMinRes = " << info.minRes << "\n";
      out << "\tMaxRes = " << info.maxRes << "\n";
      out << "\tSumValue = " << info.sumValue << "\n";
      out << "\tSumSens = " << info.sumSens << "\n";
      out << "\tSumGrad = " << info.sumGrad << "\n";
      out << "\tExec = " << info.exec << "\n";
      out << "\tNumOperands = " << info.minOperands.size() << "\n";
      for (size_t op = 0; op < info.minOperands.size(); ++op) {
        out << "\tOperand[" << op << "] = [" << info.minOperands[op] << ", "
            << info.maxOperands[op] << "]\n";
      }
    }

    out << "\n";
  }
};
+std::string FPProfiler::dir = "./fpprofile"; + +static std::unordered_map> + profilerRegistry; +static std::mutex registryMutex; + +static void writeAllProfilesAtExit() { + std::lock_guard lock(registryMutex); + for (auto &pair : profilerRegistry) { + pair.second->write(); + } + profilerRegistry.clear(); +} + +static int RegisterFPProfileRuntime() { + const char *envPath = getenv("ENZYME_FPPROFILE_DIR"); + if (envPath) { + FPProfiler::setOutputDir(envPath); + } else { + FPProfiler::setOutputDir("./fpprofile"); + } + + std::atexit(writeAllProfilesAtExit); + + return 0; +} + +extern "C" int ENZYME_FPPROFILE_RUNTIME_VAR = RegisterFPProfileRuntime(); + +extern "C" { + +void ProfilerWrite() { + std::lock_guard lock(registryMutex); + for (auto &pair : profilerRegistry) { + pair.second->write(); + } +} + +void enzymeLogGrad(const char *funcName, size_t idx, double value, + double grad) { + if (!funcName) + return; + + std::lock_guard lock(registryMutex); + + auto it = profilerRegistry.find(funcName); + if (it == profilerRegistry.end()) { + profilerRegistry[funcName] = std::make_unique(funcName); + it = profilerRegistry.find(funcName); + } + + it->second->updateGradient(idx, value, grad); +} + +void enzymeLogValue(const char *funcName, size_t idx, double res, + size_t numOperands, double *operands) { + if (!funcName) + return; + + std::lock_guard lock(registryMutex); + + auto it = profilerRegistry.find(funcName); + if (it == profilerRegistry.end()) { + profilerRegistry[funcName] = std::make_unique(funcName); + it = profilerRegistry.find(funcName); + } + + it->second->updateValue(idx, res, numOperands, operands); +} + +} // extern "C" \ No newline at end of file diff --git a/enzyme/test/Enzyme/CMakeLists.txt b/enzyme/test/Enzyme/CMakeLists.txt index 3617d3425d34..0865c61e474e 100644 --- a/enzyme/test/Enzyme/CMakeLists.txt +++ b/enzyme/test/Enzyme/CMakeLists.txt @@ -11,6 +11,10 @@ add_subdirectory(ProbProg) add_subdirectory(JLSimplify) add_subdirectory(SimpleGVN) 
+if(ENABLE_POSEIDON) + add_subdirectory(Poseidon) +endif() + # Run regression and unit tests add_lit_testsuite(check-enzyme "Running enzyme regression tests" ${CMAKE_CURRENT_BINARY_DIR} diff --git a/enzyme/test/Enzyme/ForwardError/add.ll b/enzyme/test/Enzyme/ForwardError/add.ll index 265759867f7a..d75b5646df65 100644 --- a/enzyme/test/Enzyme/ForwardError/add.ll +++ b/enzyme/test/Enzyme/ForwardError/add.ll @@ -18,7 +18,7 @@ entry: declare double @__enzyme_error_estimate(double (double, double)*, ...) -; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'") +; CHECK: define internal double @fwderrtester(double %x, double %"x'", double %y, double %"y'") ; CHECK-NEXT: entry: ; CHECK-NEXT: %[[i0:.+]] = fadd double %x, %y ; CHECK-NEXT: %[[i1:.+]] = fmul fast double %"x'", %x diff --git a/enzyme/test/Enzyme/ForwardError/cos.ll b/enzyme/test/Enzyme/ForwardError/cos.ll index 3d75b9115e2c..4036adb53916 100644 --- a/enzyme/test/Enzyme/ForwardError/cos.ll +++ b/enzyme/test/Enzyme/ForwardError/cos.ll @@ -24,7 +24,7 @@ declare double @llvm.sin.f64(double) declare double @__enzyme_error_estimate(double (double)*, ...) -; CHECK: define internal double @fwddiffetester(double %x, double %"x'") +; CHECK: define internal double @fwderrtester(double %x, double %"x'") ; CHECK-NEXT: entry: ; CHECK-NEXT: %[[i0:.+]] = tail call fast double @llvm.cos.f64(double %x) ; CHECK-NEXT: %[[i1:.+]] = fmul fast double %"x'", %x diff --git a/enzyme/test/Enzyme/ForwardError/div.ll b/enzyme/test/Enzyme/ForwardError/div.ll index ea8d9a6ab39b..87af99012ae3 100644 --- a/enzyme/test/Enzyme/ForwardError/div.ll +++ b/enzyme/test/Enzyme/ForwardError/div.ll @@ -18,7 +18,7 @@ entry: declare double @__enzyme_error_estimate(double (double, double)*, ...) 
-; CHECK: define internal double @fwddiffetester(double %x, double %"x'", double %y, double %"y'") +; CHECK: define internal double @fwderrtester(double %x, double %"x'", double %y, double %"y'") ; CHECK-NEXT: entry: ; CHECK-NEXT: %[[i0:.+]] = fdiv double %x, %y ; CHECK-NEXT: %[[i1:.+]] = fmul fast double %"x'", %x diff --git a/enzyme/test/Enzyme/Poseidon/CMakeLists.txt b/enzyme/test/Enzyme/Poseidon/CMakeLists.txt new file mode 100644 index 000000000000..6a3191ef451e --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/CMakeLists.txt @@ -0,0 +1,12 @@ +# Run regression and unit tests +add_lit_testsuite(check-poseidon "Running Poseidon regression tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${ENZYME_TEST_DEPS} + ARGS -v +) + +set_target_properties(check-poseidon PROPERTIES FOLDER "Tests") + +#add_lit_testsuites(ENZYME ${CMAKE_CURRENT_SOURCE_DIR} + # DEPENDS ${ENZYME_TEST_DEPS} +#) diff --git a/enzyme/test/Enzyme/Poseidon/Inputs/expm1div_profiles/preprocess_tester.fpprofile b/enzyme/test/Enzyme/Poseidon/Inputs/expm1div_profiles/preprocess_tester.fpprofile new file mode 100644 index 000000000000..da68eeac6e4c --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/Inputs/expm1div_profiles/preprocess_tester.fpprofile @@ -0,0 +1,29 @@ +0 + MinRes = 0.3679 + MaxRes = 7.389 + SumValue = 200.0 + SumSens = 80.0 + SumGrad = 40.0 + Exec = 100 + NumOperands = 1 + Operand[0] = [-1.0, 2.0] +1 + MinRes = -0.6321 + MaxRes = 6.389 + SumValue = 150.0 + SumSens = 60.0 + SumGrad = 30.0 + Exec = 100 + NumOperands = 2 + Operand[0] = [0.3679, 7.389] + Operand[1] = [-1.0, -1.0] +2 + MinRes = 0.5 + MaxRes = 3.195 + SumValue = 100.0 + SumSens = 50.0 + SumGrad = 25.0 + Exec = 100 + NumOperands = 2 + Operand[0] = [-0.6321, 6.389] + Operand[1] = [-1.0, 2.0] diff --git a/enzyme/test/Enzyme/Poseidon/Inputs/fma_opt_cache/cachedHerbieOutput_0_0.txt b/enzyme/test/Enzyme/Poseidon/Inputs/fma_opt_cache/cachedHerbieOutput_0_0.txt new file mode 100644 index 000000000000..1a281c4f0297 --- /dev/null +++ 
b/enzyme/test/Enzyme/Poseidon/Inputs/fma_opt_cache/cachedHerbieOutput_0_0.txt @@ -0,0 +1 @@ +{"branch":"release","commit":"2.2","date":1747796669,"flags":[],"iterations":6,"merged-cost-accuracy":[[1.0,0.5],[0.8,0.1],[[0.9,0.3,"(+.f64 (*.f64 v0 v2) (*.f64 v1 v2))"]]],"seed":239778888,"tests":[{"name":"0","input":"(FPCore (v0 v1 v2) (* (+ v0 v1) v2))","output":"(fma.f64 v0 v2 (*.f64 v1 v2))","bits":64,"cost-accuracy":[[320.0,0.5],[480.0,0.1],[[400.0,0.3,"(+.f64 (*.f64 v0 v2) (*.f64 v1 v2))"]]]}]} diff --git a/enzyme/test/Enzyme/Poseidon/Inputs/fma_opt_profiles/preprocess_tester.fpprofile b/enzyme/test/Enzyme/Poseidon/Inputs/fma_opt_profiles/preprocess_tester.fpprofile new file mode 100644 index 000000000000..bea69a1bb708 --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/Inputs/fma_opt_profiles/preprocess_tester.fpprofile @@ -0,0 +1,20 @@ +0 + MinRes = 3.0 + MaxRes = 15.0 + SumValue = 100.0 + SumSens = 50.0 + SumGrad = 25.0 + Exec = 100 + NumOperands = 2 + Operand[0] = [1.0, 5.0] + Operand[1] = [2.0, 10.0] +1 + MinRes = 6.0 + MaxRes = 150.0 + SumValue = 500.0 + SumSens = 100.0 + SumGrad = 50.0 + Exec = 100 + NumOperands = 2 + Operand[0] = [3.0, 15.0] + Operand[1] = [0.5, 10.0] diff --git a/enzyme/test/Enzyme/Poseidon/add_noopt.ll b/enzyme/test/Enzyme/Poseidon/add_noopt.ll new file mode 100644 index 000000000000..54d222723338 --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/add_noopt.ll @@ -0,0 +1,25 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -S | FileCheck %s -dump-input=always +; REQUIRES: poseidon + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = fadd fast double %x, %y + ret double %0 +} + +define double @test_profile(double %x, double %y) { +entry: + %0 = tail call double (double 
(double, double)*, ...) @__enzyme_fp_optimize(double (double, double)* nonnull @tester, double %x, double %y, metadata !"enzyme_err_tol", double 1.0e-6) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fp_optimize(double (double, double)*, ...) + +; CHECK: define double @test_profile(double %x, double %y) +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[fadd:.+]] = call double @tester(double %x, double %y) +; CHECK-NEXT: ret double %[[fadd]] +; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/Poseidon/add_prof.ll b/enzyme/test/Enzyme/Poseidon/add_prof.ll new file mode 100644 index 000000000000..bd0fde39c746 --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/add_prof.ll @@ -0,0 +1,39 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-generate -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-generate -S | FileCheck %s -dump-input=always +; REQUIRES: poseidon + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y) { +entry: + %0 = fadd fast double %x, %y + ret double %0 +} + +define double @test_profile(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fp_optimize(double (double, double)* nonnull @tester, double %x, double %y, metadata !"enzyme_err_tol", double 1.0e-6) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fp_optimize(double (double, double)*, ...) 
+ +; CHECK: @ENZYME_FPPROFILE_RUNTIME_VAR = external global i32 +; CHECK: @fpprofiled_preprocess_tester = private unnamed_addr constant [18 x i8] c"preprocess_tester\00", align 1 + +; CHECK: define internal { double, double } @instrtester(double %x, double %y, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[fadd:.+]] = fadd fast double %x, %y, !enzyme_active !0, !enzyme_fpprofile_idx !1 +; CHECK-NEXT: store double %x, ptr %[[alloca]], align 8 +; CHECK-NEXT: %[[gep:.+]] = getelementptr [2 x double], ptr %[[alloca]], i32 0, i32 1 +; CHECK-NEXT: store double %y, ptr %[[gep]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_tester, i64 0, double %[[fadd]], i32 2, ptr %[[alloca]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_tester, i64 0, double %[[fadd]], double %[[differet]]) +; CHECK-NEXT: %[[ins1:.+]] = insertvalue { double, double } undef, double %[[differet]], 0 +; CHECK-NEXT: %[[ins2:.+]] = insertvalue { double, double } %[[ins1]], double %[[differet]], 1 +; CHECK-NEXT: ret { double, double } %[[ins2]] +; CHECK-NEXT: } + +; CHECK: !0 = !{} +; CHECK: !1 = !{i64 0} \ No newline at end of file diff --git a/enzyme/test/Enzyme/Poseidon/expm1div_opt_herbie_live.ll b/enzyme/test/Enzyme/Poseidon/expm1div_opt_herbie_live.ll new file mode 100644 index 000000000000..4730f786c0ee --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/expm1div_opt_herbie_live.ll @@ -0,0 +1,29 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-use=%S/Inputs/expm1div_profiles -fpopt-enable-herbie=true -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path= -fpopt-print -S 2>&1 | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-use=%S/Inputs/expm1div_profiles -fpopt-enable-herbie=true -fpopt-enable-pt=false 
-fpopt-enable-solver=false -fpopt-cache-path= -fpopt-print -S 2>&1 | FileCheck %s +; REQUIRES: poseidon + +; Live Herbie on (exp(x) - 1) / x: verifies expm1 candidate is produced. + +declare double @llvm.exp.f64(double) + +define double @tester(double %x) { +entry: + %exp = call fast double @llvm.exp.f64(double %x) + %sub = fsub fast double %exp, 1.0 + %div = fdiv fast double %sub, %x + ret double %div +} + +define double @test_opt(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fp_optimize(double (double)* nonnull @tester, double %x, metadata !"enzyme_err_tol", double 0.5) + ret double %0 +} + +declare double @__enzyme_fp_optimize(double (double)*, ...) + +; CHECK: Candidates: +; CHECK: expm1 +; CHECK: Finished +; CHECK: define double @preprocess_tester(double %x) +; CHECK: ret double diff --git a/enzyme/test/Enzyme/Poseidon/fcmp_select_prof.ll b/enzyme/test/Enzyme/Poseidon/fcmp_select_prof.ll new file mode 100644 index 000000000000..80611080e5a2 --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/fcmp_select_prof.ll @@ -0,0 +1,188 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -fpprofile-generate -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -fpprofile-generate -S | FileCheck %s +; REQUIRES: poseidon + +; CHECK: @ENZYME_FPPROFILE_RUNTIME_VAR = external global i32 +; CHECK: @fpprofiled_preprocess_test_maxnum_zero = private unnamed_addr constant [28 x i8] c"preprocess_test_maxnum_zero\00", align 1 +; CHECK: @fpprofiled_preprocess_test_maxnum_zero_reversed = private unnamed_addr constant [37 x i8] c"preprocess_test_maxnum_zero_reversed\00", align 1 +; CHECK: @fpprofiled_preprocess_test_maxnum_general = private unnamed_addr constant [31 x i8] c"preprocess_test_maxnum_general\00", align 1 +; CHECK: @fpprofiled_preprocess_test_minnum_general = private unnamed_addr constant [31 x i8] 
c"preprocess_test_minnum_general\00", align 1 +; CHECK: @fpprofiled_preprocess_test_combined = private unnamed_addr constant [25 x i8] c"preprocess_test_combined\00", align 1 + +define double @test_maxnum_zero(double %x) { +entry: + %cmp = fcmp ogt double %x, 0.0 + %result = select i1 %cmp, double %x, double 0.0 + ret double %result +} + +; CHECK: define internal { double } @instrtest_maxnum_zero(double %x, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[maxnum:.+]] = call double @llvm.maxnum.f64(double %x, double 0.000000e+00), !enzyme_active !{{[0-9]+}}, !enzyme_fpprofile_idx !{{[0-9]+}} +; CHECK-NEXT: store double %x, ptr %[[alloca]], align 8 +; CHECK-NEXT: %[[gep:.+]] = getelementptr [2 x double], ptr %[[alloca]], i32 0, i32 1 +; CHECK-NEXT: store double 0.000000e+00, ptr %[[gep]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_test_maxnum_zero, i64 0, double %[[maxnum]], i32 2, ptr %[[alloca]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_test_maxnum_zero, i64 0, double %[[maxnum]], double %[[differet]]) +; CHECK-NEXT: %[[cmp:.+]] = fcmp fast olt double %x, 0.000000e+00 +; CHECK-NEXT: %[[sel:.+]] = select fast i1 %[[cmp]], double 0.000000e+00, double %[[differet]] +; CHECK-NEXT: %[[ins:.+]] = insertvalue { double } undef, double %[[sel]], 0 +; CHECK-NEXT: ret { double } %[[ins]] +; CHECK-NEXT: } + +define double @test_maxnum_zero_reversed(double %x) { +entry: + %cmp = fcmp olt double %x, 0.0 + %result = select i1 %cmp, double 0.0, double %x + ret double %result +} + +; CHECK: define internal { double } @instrtest_maxnum_zero_reversed(double %x, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[maxnum:.+]] = call double @llvm.maxnum.f64(double %x, double 0.000000e+00), !enzyme_active !{{[0-9]+}}, !enzyme_fpprofile_idx !{{[0-9]+}} +; 
CHECK-NEXT: store double %x, ptr %[[alloca]], align 8 +; CHECK-NEXT: %[[gep:.+]] = getelementptr [2 x double], ptr %[[alloca]], i32 0, i32 1 +; CHECK-NEXT: store double 0.000000e+00, ptr %[[gep]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_test_maxnum_zero_reversed, i64 0, double %[[maxnum]], i32 2, ptr %[[alloca]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_test_maxnum_zero_reversed, i64 0, double %[[maxnum]], double %[[differet]]) +; CHECK-NEXT: %[[cmp:.+]] = fcmp fast olt double %x, 0.000000e+00 +; CHECK-NEXT: %[[sel:.+]] = select fast i1 %[[cmp]], double 0.000000e+00, double %[[differet]] +; CHECK-NEXT: %[[ins:.+]] = insertvalue { double } undef, double %[[sel]], 0 +; CHECK-NEXT: ret { double } %[[ins]] +; CHECK-NEXT: } + +define double @test_maxnum_general(double %x, double %y) { +entry: + %cmp = fcmp ogt double %x, %y + %result = select i1 %cmp, double %x, double %y + ret double %result +} + +; CHECK: define internal { double, double } @instrtest_maxnum_general(double %x, double %y, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[maxnum:.+]] = call double @llvm.maxnum.f64(double %x, double %y), !enzyme_active !{{[0-9]+}}, !enzyme_fpprofile_idx !{{[0-9]+}} +; CHECK-NEXT: store double %x, ptr %[[alloca]], align 8 +; CHECK-NEXT: %[[gep:.+]] = getelementptr [2 x double], ptr %[[alloca]], i32 0, i32 1 +; CHECK-NEXT: store double %y, ptr %[[gep]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_test_maxnum_general, i64 0, double %[[maxnum]], i32 2, ptr %[[alloca]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_test_maxnum_general, i64 0, double %[[maxnum]], double %[[differet]]) +; CHECK-NEXT: %[[cmp:.+]] = fcmp fast olt double %x, %y +; CHECK-NEXT: %[[sel1:.+]] = select fast i1 %[[cmp]], double 0.000000e+00, double %[[differet]] +; CHECK-NEXT: %[[cmp2:.+]] = fcmp fast 
olt double %x, %y +; CHECK-NEXT: %[[sel2:.+]] = select fast i1 %[[cmp2]], double %[[differet]], double 0.000000e+00 +; CHECK-NEXT: %[[ins1:.+]] = insertvalue { double, double } undef, double %[[sel1]], 0 +; CHECK-NEXT: %[[ins2:.+]] = insertvalue { double, double } %[[ins1]], double %[[sel2]], 1 +; CHECK-NEXT: ret { double, double } %[[ins2]] +; CHECK-NEXT: } + +define double @test_minnum_general(double %x, double %y) { +entry: + %cmp = fcmp olt double %x, %y + %result = select i1 %cmp, double %x, double %y + ret double %result +} + +; CHECK: define internal { double, double } @instrtest_minnum_general(double %x, double %y, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[minnum:.+]] = call double @llvm.minnum.f64(double %x, double %y), !enzyme_active !{{[0-9]+}}, !enzyme_fpprofile_idx !{{[0-9]+}} +; CHECK-NEXT: store double %x, ptr %[[alloca]], align 8 +; CHECK-NEXT: %[[gep:.+]] = getelementptr [2 x double], ptr %[[alloca]], i32 0, i32 1 +; CHECK-NEXT: store double %y, ptr %[[gep]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_test_minnum_general, i64 0, double %[[minnum]], i32 2, ptr %[[alloca]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_test_minnum_general, i64 0, double %[[minnum]], double %[[differet]]) +; CHECK-NEXT: %[[cmp:.+]] = fcmp fast olt double %x, %y +; CHECK-NEXT: %[[sel1:.+]] = select fast i1 %[[cmp]], double %[[differet]], double 0.000000e+00 +; CHECK-NEXT: %[[cmp2:.+]] = fcmp fast olt double %x, %y +; CHECK-NEXT: %[[sel2:.+]] = select fast i1 %[[cmp2]], double 0.000000e+00, double %[[differet]] +; CHECK-NEXT: %[[ins1:.+]] = insertvalue { double, double } undef, double %[[sel1]], 0 +; CHECK-NEXT: %[[ins2:.+]] = insertvalue { double, double } %[[ins1]], double %[[sel2]], 1 +; CHECK-NEXT: ret { double, double } %[[ins2]] +; CHECK-NEXT: } + +define double @test_combined(double %x, double %y, double %z) { +entry: + %cmp1 = fcmp 
ogt double %x, 0.0 + %max_x = select i1 %cmp1, double %x, double 0.0 + %cmp2 = fcmp olt double %max_x, %y + %min_xy = select i1 %cmp2, double %max_x, double %y + %result = fadd fast double %min_xy, %z + ret double %result +} + +; CHECK: define internal { double, double, double } @instrtest_combined(double %x, double %y, double %z, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca0:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[alloca1:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[alloca2:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[maxnum:.+]] = call double @llvm.maxnum.f64(double %x, double 0.000000e+00), !enzyme_active !{{[0-9]+}}, !enzyme_fpprofile_idx !{{[0-9]+}} +; CHECK-NEXT: store double %x, ptr %[[alloca2]], align 8 +; CHECK-NEXT: %[[gep2:.+]] = getelementptr [2 x double], ptr %[[alloca2]], i32 0, i32 1 +; CHECK-NEXT: store double 0.000000e+00, ptr %[[gep2]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_test_combined, i64 0, double %[[maxnum]], i32 2, ptr %[[alloca2]]) +; CHECK-NEXT: %[[minnum:.+]] = call double @llvm.minnum.f64(double %[[maxnum]], double %y), !enzyme_active !{{[0-9]+}}, !enzyme_fpprofile_idx !{{[0-9]+}} +; CHECK-NEXT: store double %[[maxnum]], ptr %[[alloca1]], align 8 +; CHECK-NEXT: %[[gep1:.+]] = getelementptr [2 x double], ptr %[[alloca1]], i32 0, i32 1 +; CHECK-NEXT: store double %y, ptr %[[gep1]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_test_combined, i64 1, double %[[minnum]], i32 2, ptr %[[alloca1]]) +; CHECK-NEXT: %result = fadd fast double %[[minnum]], %z, !enzyme_active !{{[0-9]+}}, !enzyme_fpprofile_idx !{{[0-9]+}} +; CHECK-NEXT: store double %[[minnum]], ptr %[[alloca0]], align 8 +; CHECK-NEXT: %[[gep0:.+]] = getelementptr [2 x double], ptr %[[alloca0]], i32 0, i32 1 +; CHECK-NEXT: store double %z, ptr %[[gep0]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_test_combined, 
i64 2, double %result, i32 2, ptr %[[alloca0]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_test_combined, i64 2, double %result, double %[[differet]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_test_combined, i64 1, double %[[minnum]], double %[[differet]]) +; CHECK-NEXT: %[[cmp1:.+]] = fcmp fast olt double %[[maxnum]], %y +; CHECK-NEXT: %[[sel1:.+]] = select fast i1 %[[cmp1]], double %[[differet]], double 0.000000e+00 +; CHECK-NEXT: %[[cmp2:.+]] = fcmp fast olt double %[[maxnum]], %y +; CHECK-NEXT: %[[sel2:.+]] = select fast i1 %[[cmp2]], double 0.000000e+00, double %[[differet]] +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_test_combined, i64 0, double %[[maxnum]], double %[[sel1]]) +; CHECK-NEXT: %[[cmp3:.+]] = fcmp fast olt double %x, 0.000000e+00 +; CHECK-NEXT: %[[sel3:.+]] = select fast i1 %[[cmp3]], double 0.000000e+00, double %[[sel1]] +; CHECK-NEXT: %[[ins1:.+]] = insertvalue { double, double, double } undef, double %[[sel3]], 0 +; CHECK-NEXT: %[[ins2:.+]] = insertvalue { double, double, double } %[[ins1]], double %[[sel2]], 1 +; CHECK-NEXT: %[[ins3:.+]] = insertvalue { double, double, double } %[[ins2]], double %[[differet]], 2 +; CHECK-NEXT: ret { double, double, double } %[[ins3]] +; CHECK-NEXT: } + +define double @test_profile_maxnum_zero(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fp_optimize(double (double)* nonnull @test_maxnum_zero, double %x) + ret double %0 +} + +define double @test_profile_maxnum_zero_reversed(double %x) { +entry: + %0 = tail call double (double (double)*, ...) @__enzyme_fp_optimize(double (double)* nonnull @test_maxnum_zero_reversed, double %x) + ret double %0 +} + +define double @test_profile_maxnum(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) 
@__enzyme_fp_optimize(double (double, double)* nonnull @test_maxnum_general, double %x, double %y) + ret double %0 +} + + +define double @test_profile_minnum(double %x, double %y) { +entry: + %0 = tail call double (double (double, double)*, ...) @__enzyme_fp_optimize(double (double, double)* nonnull @test_minnum_general, double %x, double %y) + ret double %0 +} + +define double @test_profile_combined(double %x, double %y, double %z) { +entry: + %0 = tail call double (double (double, double, double)*, ...) @__enzyme_fp_optimize(double (double, double, double)* nonnull @test_combined, double %x, double %y, double %z) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fp_optimize(...) + +; CHECK: !0 = !{} +; CHECK: !1 = !{i64 0} +; CHECK: !2 = !{i64 1} +; CHECK: !3 = !{i64 2} \ No newline at end of file diff --git a/enzyme/test/Enzyme/Poseidon/fma_opt_cached.ll b/enzyme/test/Enzyme/Poseidon/fma_opt_cached.ll new file mode 100644 index 000000000000..2a17078ac389 --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/fma_opt_cached.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-use=%S/Inputs/fma_opt_profiles -fpopt-enable-herbie=true -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path=%S/Inputs/fma_opt_cache -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-use=%S/Inputs/fma_opt_profiles -fpopt-enable-herbie=true -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path=%S/Inputs/fma_opt_cache -S | FileCheck %s +; REQUIRES: poseidon + +; Cached Herbie rewrite: (x + y) * z --> fma(x, z, y * z) + +define double @tester(double %x, double %y, double %z) { +entry: + %add = fadd fast double %x, %y + %mul = fmul fast double %add, %z + ret double %mul +} + +define double @test_opt(double %x, double %y, double %z) { +entry: + %0 = 
tail call double (double (double, double, double)*, ...) @__enzyme_fp_optimize(double (double, double, double)* nonnull @tester, double %x, double %y, double %z, metadata !"enzyme_err_tol", double 0.5) + ret double %0 +} + +declare double @__enzyme_fp_optimize(double (double, double, double)*, ...) + +; CHECK: define double @test_opt(double %x, double %y, double %z) +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[call:.+]] = call double @preprocess_tester(double %x, double %y, double %z) +; CHECK-NEXT: ret double %[[call]] + +; CHECK: define double @preprocess_tester(double %x, double %y, double %z) +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[mul:.+]] = fmul fast double %z, %y +; CHECK-NEXT: %[[fma:.+]] = tail call fast double @llvm.fma.f64(double %x, double %z, double %[[mul]]) +; CHECK-NEXT: ret double %[[fma]] diff --git a/enzyme/test/Enzyme/Poseidon/fma_opt_noherbie.ll b/enzyme/test/Enzyme/Poseidon/fma_opt_noherbie.ll new file mode 100644 index 000000000000..d783ee9f48e0 --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/fma_opt_noherbie.ll @@ -0,0 +1,31 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-use=%S/Inputs/fma_opt_profiles -fpopt-enable-herbie=false -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path= -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-use=%S/Inputs/fma_opt_profiles -fpopt-enable-herbie=false -fpopt-enable-pt=false -fpopt-enable-solver=false -fpopt-cache-path= -S | FileCheck %s +; REQUIRES: poseidon + +; Opt phase with no Herbie/PT: __enzyme_fp_optimize lowers to preprocess_tester call. 
+ +define double @tester(double %x, double %y, double %z) { +entry: + %add = fadd fast double %x, %y + %mul = fmul fast double %add, %z + ret double %mul +} + +define double @test_opt(double %x, double %y, double %z) { +entry: + %0 = tail call double (double (double, double, double)*, ...) @__enzyme_fp_optimize(double (double, double, double)* nonnull @tester, double %x, double %y, double %z, metadata !"enzyme_err_tol", double 0.5) + ret double %0 +} + +declare double @__enzyme_fp_optimize(double (double, double, double)*, ...) + +; CHECK: define double @test_opt(double %x, double %y, double %z) +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[call:.+]] = call double @preprocess_tester(double %x, double %y, double %z) +; CHECK-NEXT: ret double %[[call]] + +; CHECK: define double @preprocess_tester(double %x, double %y, double %z) +; CHECK-NEXT: entry: +; CHECK: fadd fast double +; CHECK: fmul fast double +; CHECK: ret double diff --git a/enzyme/test/Enzyme/Poseidon/fma_prof.ll b/enzyme/test/Enzyme/Poseidon/fma_prof.ll new file mode 100644 index 000000000000..b3ead4bae6d1 --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/fma_prof.ll @@ -0,0 +1,45 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-generate -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-generate -S | FileCheck %s +; REQUIRES: poseidon + +; Function Attrs: noinline nounwind readnone uwtable +define double @tester(double %x, double %y, double %z) { +entry: + %0 = fmul fast double %x, %y + %1 = fadd fast double %0, %z + ret double %1 +} + +define double @test_profile(double %x, double %y, double %z) { +entry: + %0 = tail call double (double (double, double, double)*, ...) 
@__enzyme_fp_optimize(double (double, double, double)* nonnull @tester, double %x, double %y, double %z) + ret double %0 +} + +; Function Attrs: nounwind +declare double @__enzyme_fp_optimize(double (double, double, double)*, ...) + +; CHECK: @ENZYME_FPPROFILE_RUNTIME_VAR = external global i32 +; CHECK: @fpprofiled_preprocess_tester = private unnamed_addr constant [18 x i8] c"preprocess_tester\00", align 1 + +; CHECK: define internal { double, double, double } @instrtester(double %x, double %y, double %z, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca:.+]] = alloca [3 x double], align 8 +; CHECK-NEXT: %[[fmuladd:.+]] = call fast double @llvm.fmuladd.f64(double %x, double %y, double %z), !enzyme_active !0, !enzyme_fpprofile_idx !1 +; CHECK-NEXT: store double %x, ptr %[[alloca]], align 8 +; CHECK-NEXT: %[[gep1:.+]] = getelementptr [3 x double], ptr %[[alloca]], i32 0, i32 1 +; CHECK-NEXT: store double %y, ptr %[[gep1]], align 8 +; CHECK-NEXT: %[[gep2:.+]] = getelementptr [3 x double], ptr %[[alloca]], i32 0, i32 2 +; CHECK-NEXT: store double %z, ptr %[[gep2]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_tester, i64 0, double %[[fmuladd]], i32 3, ptr %[[alloca]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_tester, i64 0, double %[[fmuladd]], double %[[differet]]) +; CHECK-NEXT: %[[grad1:.+]] = fmul fast double %[[differet]], %y +; CHECK-NEXT: %[[grad2:.+]] = fmul fast double %[[differet]], %x +; CHECK-NEXT: %[[ins1:.+]] = insertvalue { double, double, double } undef, double %[[grad1]], 0 +; CHECK-NEXT: %[[ins2:.+]] = insertvalue { double, double, double } %[[ins1]], double %[[grad2]], 1 +; CHECK-NEXT: %[[ins3:.+]] = insertvalue { double, double, double } %[[ins2]], double %[[differet]], 2 +; CHECK-NEXT: ret { double, double, double } %[[ins3]] +; CHECK-NEXT: } + +; CHECK: !0 = !{} +; CHECK: !1 = !{i64 0} \ No newline at end of file diff --git 
a/enzyme/test/Enzyme/Poseidon/structret_prof.ll b/enzyme/test/Enzyme/Poseidon/structret_prof.ll new file mode 100644 index 000000000000..15b1dfbfe2ce --- /dev/null +++ b/enzyme/test/Enzyme/Poseidon/structret_prof.ll @@ -0,0 +1,45 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -fpprofile-generate -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -enzyme-preopt=false -fpprofile-generate -S | FileCheck %s +; REQUIRES: poseidon + +; Adapted from enzyme/test/Enzyme/ReverseMode/gradient-struct-ret.ll + +%struct.Gradients = type { double, double } + +; Function Attrs: noinline nounwind readnone uwtable +define dso_local double @muldd(double %x, double %y) { +entry: + %mul = fmul fast double %x, %y + ret double %mul +} + +define dso_local %struct.Gradients @test_profile(double %x, double %y) { +entry: + %call = call %struct.Gradients (i8*, ...) @__enzyme_fp_optimize(i8* bitcast (double (double, double)* @muldd to i8*), double %x, double %y) + ret %struct.Gradients %call +} + +; Function Attrs: nounwind +declare %struct.Gradients @__enzyme_fp_optimize(i8*, ...) 
+ +; CHECK: @ENZYME_FPPROFILE_RUNTIME_VAR = external global i32 +; CHECK: @fpprofiled_preprocess_muldd = private unnamed_addr constant [17 x i8] c"preprocess_muldd\00", align 1 + +; CHECK: define internal { double, double } @instrmuldd(double %x, double %y, double %[[differet:.+]]) #{{[0-9]+}} { +; CHECK-NEXT: entry: +; CHECK-NEXT: %[[alloca:.+]] = alloca [2 x double], align 8 +; CHECK-NEXT: %[[mul:.+]] = fmul fast double %x, %y, !enzyme_active !0, !enzyme_fpprofile_idx !1 +; CHECK-NEXT: store double %x, ptr %[[alloca]], align 8 +; CHECK-NEXT: %[[gep:.+]] = getelementptr [2 x double], ptr %[[alloca]], i32 0, i32 1 +; CHECK-NEXT: store double %y, ptr %[[gep]], align 8 +; CHECK-NEXT: call void @enzymeLogValue(ptr @fpprofiled_preprocess_muldd, i64 0, double %[[mul]], i32 2, ptr %[[alloca]]) +; CHECK-NEXT: call void @enzymeLogGrad(ptr @fpprofiled_preprocess_muldd, i64 0, double %[[mul]], double %[[differet]]) +; CHECK-NEXT: %[[grad1:.+]] = fmul fast double %[[differet]], %y +; CHECK-NEXT: %[[grad2:.+]] = fmul fast double %[[differet]], %x +; CHECK-NEXT: %[[ins1:.+]] = insertvalue { double, double } undef, double %[[grad1]], 0 +; CHECK-NEXT: %[[ins2:.+]] = insertvalue { double, double } %[[ins1]], double %[[grad2]], 1 +; CHECK-NEXT: ret { double, double } %[[ins2]] +; CHECK-NEXT: } + +; CHECK: !0 = !{} +; CHECK: !1 = !{i64 0} \ No newline at end of file diff --git a/enzyme/test/Integration/CMakeLists.txt b/enzyme/test/Integration/CMakeLists.txt index f76a45cfda30..e2ddaf268a5d 100644 --- a/enzyme/test/Integration/CMakeLists.txt +++ b/enzyme/test/Integration/CMakeLists.txt @@ -6,6 +6,9 @@ add_subdirectory(ReverseMode) add_subdirectory(BatchMode) add_subdirectory(Sparse) add_subdirectory(Truncate) +if(ENABLE_POSEIDON) + add_subdirectory(Poseidon) +endif() # Run regression and unit tests add_lit_testsuite(check-enzyme-integration "Running enzyme integration tests" diff --git a/enzyme/test/Integration/ForwardError/binops.c b/enzyme/test/Integration/ForwardError/binops.c 
index 2770060575ce..8c814b4c4bcd 100644 --- a/enzyme/test/Integration/ForwardError/binops.c +++ b/enzyme/test/Integration/ForwardError/binops.c @@ -11,24 +11,13 @@ double fabs(double); extern double __enzyme_error_estimate(void *, ...); -int errorLogCount = 0; - -void enzymeLogError(double res, double err, const char *opcodeName, - const char *calleeName, const char *moduleName, - const char *functionName, const char *blockName) { - ++errorLogCount; - printf("Res = %e, Error = %e, Op = %s, Callee = %s, Module = %s, Function = " - "%s, BasicBlock = %s\n", - res, err, opcodeName, calleeName, moduleName, functionName, blockName); -} - // An example from https://dl.acm.org/doi/10.1145/3371128 double fun(double x) { double v1 = cos(x); double v2 = 1 - v1; double v3 = x * x; double v4 = v2 / v3; - double v5 = sin(v4); // Inactive -- logger is not invoked. + double v5 = sin(v4); printf("v1 = %.18e, v2 = %.18e, v3 = %.18e, v4 = %.18e, v5 = %.18e\n", v1, v2, v3, v4, v5); @@ -42,5 +31,4 @@ int main() { printf("res = %.18e, abs error = %.18e, rel error = %.18e\n", res, error, fabs(error / res)); APPROX_EQ(error, 2.2222222222e-2, 1e-4); - TEST_EQ(errorLogCount, 4); } diff --git a/enzyme/test/Integration/Poseidon/CMakeLists.txt b/enzyme/test/Integration/Poseidon/CMakeLists.txt new file mode 100644 index 000000000000..6f5dcfa1dde7 --- /dev/null +++ b/enzyme/test/Integration/Poseidon/CMakeLists.txt @@ -0,0 +1,8 @@ +# Run regression and unit tests +add_lit_testsuite(check-poseidon-integration "Running Poseidon integration tests" + ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${ENZYME_TEST_DEPS} ClangEnzyme-${LLVM_VERSION_MAJOR} + ARGS -v +) + +set_target_properties(check-poseidon-integration PROPERTIES FOLDER "Tests") diff --git a/enzyme/test/Integration/Poseidon/prof_add.c b/enzyme/test/Integration/Poseidon/prof_add.c new file mode 100644 index 000000000000..242201509c30 --- /dev/null +++ b/enzyme/test/Integration/Poseidon/prof_add.c @@ -0,0 +1,32 @@ +// RUN: %clang -O0 %s -S -emit-llvm 
-o %t.ll +// RUN: %opt %t.ll %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify,simplifycfg)" -enzyme-preopt=false -fpprofile-generate -S -o %t.opt.ll +// RUN: %clang -O0 %t.opt.ll -c -o %t.o +// RUN: %clang++ %t.o %FPProfileLib -lstdc++ -lm -o %t.exe +// RUN: rm -rf %t.profiles && ENZYME_FPPROFILE_DIR=%t.profiles %t.exe +// RUN: cat %t.profiles/preprocess_tester.fpprofile | FileCheck %s +// REQUIRES: poseidon + +#include <stdio.h> + +extern double __enzyme_fp_optimize(void *, ...); + +double tester(double x, double y) { + return x + y; +} + +int main() { + double res = __enzyme_fp_optimize((void *)tester, 3.0, 4.0); + printf("result = %f\n", res); + + res = __enzyme_fp_optimize((void *)tester, 1.0, 2.0); + printf("result = %f\n", res); + + return 0; +} + +// CHECK: MinRes = 3.{{[0-9e+]+}} +// CHECK: MaxRes = 7.{{[0-9e+]+}} +// CHECK: Exec = 2 +// CHECK: NumOperands = 2 +// CHECK: Operand[0] = [1.{{[0-9e+]+}}, 3.{{[0-9e+]+}}] +// CHECK: Operand[1] = [2.{{[0-9e+]+}}, 4.{{[0-9e+]+}}] diff --git a/enzyme/test/lit.site.cfg.py.in b/enzyme/test/lit.site.cfg.py.in index 0cc5e6f28f38..678a454b6090 100644 --- a/enzyme/test/lit.site.cfg.py.in +++ b/enzyme/test/lit.site.cfg.py.in @@ -118,6 +118,10 @@ config.substitutions.append(('%newLoadClangEnzyme', newPM)) config.substitutions.append(('%hasMPFR', has_mpfr)) +if "@ENABLE_POSEIDON@" == "1": + config.available_features.add('poseidon') + config.substitutions.append(('%FPProfileLib', '@ENZYME_BINARY_DIR@/Enzyme/libEnzymeFPProfile.a')) + # Let the main config do the real work.
cfgfile = "@ENZYME_SOURCE_DIR@/test/lit.cfg.py" if len("@ENZYME_SOURCE_DIR@") == 0: diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index a7ace51e2bed..2b3fdd3d7b5c 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -2177,6 +2177,103 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, else os << " gutils->eraseIfUnused(" << origName << ");\n"; + if (intrinsic != MLIRDerivatives) { + os << "#ifdef ENABLE_POSEIDON\n"; + os << " if (gutils->profiled && " + "Poseidonable(" + << origName << ")) {\n" + << " if (auto md = " << origName + << ".getMetadata(\"enzyme_fpprofile_idx\")) {\n" + << " size_t instIdx = " + "cast<ConstantInt>(cast<ConstantAsMetadata>(md->getOperand(0))->" + "getValue())->getZExtValue();\n" + << " Type *PtrTy = PointerType::getUnqual(" << origName + << ".getContext());\n" + << " Type *DoubleTy = Type::getDoubleTy(" << origName + << ".getContext());\n" + << " Type *SizeTy = Type::getInt64Ty(" << origName + << ".getContext());\n" + << " Type *Int32Ty = Type::getInt32Ty(" << origName + << ".getContext());\n" + << " FunctionType *LogValueFT = " + "FunctionType::get(Type::getVoidTy(" + << origName << ".getContext()),\n" + << " {PtrTy, SizeTy, DoubleTy, Int32Ty, PtrTy}, false);\n" + << " FunctionCallee logFunc = " << origName + << ".getModule()->getOrInsertFunction(\"enzymeLogValue\", " + "LogValueFT);\n" + << " IRBuilder<> BuilderZ(&" << origName << ");\n" + << " getForwardBuilder(BuilderZ);\n" + << " std::string funcName = gutils->oldFunc->getName().str();\n" + << " GlobalVariable *gv = " + "gutils->oldFunc->getParent()->getNamedGlobal(\"fpprofiled_\" + " + "funcName);\n" + << " if (!gv)\n" + << " gv = BuilderZ.CreateGlobalString(funcName, " + "\"fpprofiled_\" + funcName);\n" + << " Value *funcNamePtr = " + "BuilderZ.CreateInBoundsGEP(gv->getValueType(), gv, " + "{BuilderZ.getInt32(0), BuilderZ.getInt32(0)});\n" + << " Value *origValue = " +
"BuilderZ.CreateFPExt(gutils->getNewFromOriginal(&" + << origName << "),\n" + << " Type::getDoubleTy(" << origName << ".getContext()));\n" + << " unsigned numOperands = isa<CallInst>(" << origName + << ") ?\n" + << " cast<CallInst>(" << origName + << ").arg_size() : " << origName << ".getNumOperands();\n" + << " Value *numOperandsValue = ConstantInt::get(\n" + << " Type::getInt32Ty(" << origName + << ".getContext()), numOperands);\n" + << " auto operands = isa<CallInst>(" << origName << ") ?\n" + << " cast<CallInst>(" << origName << ").args() : " << origName + << ".operands();\n" + << " ArrayType *operandArrayType = ArrayType::get(\n" + << " Type::getDoubleTy(" << origName + << ".getContext()), numOperands);\n" + << " Value *operandArrayValue = " + "IRBuilder<>(gutils->inversionAllocs).\n" + << " CreateAlloca(operandArrayType);\n" + << " for (auto operand : enumerate(operands)) {\n" + << " Value *origOp = " + "gutils->getNewFromOriginal(operand.value());\n" + << " Value *operandValue = nullptr;\n" + << " if (origOp->getType()->isFloatingPointTy()) {\n" + << " operandValue = BuilderZ.CreateFPExt(\n" + << " origOp, Type::getDoubleTy(" << origName + << ".getContext()));\n" + << " } else if (origOp->getType()->isIntegerTy()) {\n" + << " operandValue = BuilderZ.CreateSIToFP(\n" + << " origOp, Type::getDoubleTy(" << origName + << ".getContext()));\n" + << " } else {\n" + << " llvm_unreachable(\"Unsupported operand type\");\n" + << " }\n" + << " Value *ptr = BuilderZ.CreateGEP(\n" + << " operandArrayType, operandArrayValue,\n" + << " {ConstantInt::get(Type::getInt32Ty(" << origName + << ".getContext()), 0),\n" + << " ConstantInt::get(Type::getInt32Ty(" << origName + << ".getContext()), operand.index())});\n" + << " BuilderZ.CreateStore(operandValue, ptr);\n" + << " }\n" + << " Value *operandPtrValue = BuilderZ.CreateGEP(\n" + << " operandArrayType, operandArrayValue,\n" + << " {ConstantInt::get(Type::getInt32Ty(" << origName + << ".getContext()), 0),\n" + << " ConstantInt::get(Type::getInt32Ty(" << origName +
<< ".getContext()), 0)});\n" + << " CallInst *logCallInst = BuilderZ.CreateCall(\n" + << " logFunc, {funcNamePtr, ConstantInt::get(SizeTy, " + "instIdx), " + "origValue, numOperandsValue, operandPtrValue});\n" + << " logCallInst->setDebugLoc(gutils->getNewFromOriginal(" + << origName << ".getDebugLoc()));\n" + << " }\n" + << " }\n"; + os << "#endif\n"; + } + if (intrinsic == MLIRDerivatives) { os << " if (gutils->isConstantInstruction(op))\n"; os << " return success();\n"; @@ -2350,8 +2447,6 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } os << " }\n"; - // forward error TODO: `ForwardFromSummedReverse` behavior - // also for custom derivatives. if (intrinsic != MLIRDerivatives) { os << " case DerivativeMode::ForwardModeError: {\n"; os << " IRBuilder<> Builder2(&" << origName << ");\n"; @@ -2465,113 +2560,10 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } // Perform the max with 1 ulp - // error TODO os << " res = Builder2.CreateMaxNum(get1ULP(Builder2, " "gutils->getNewFromOriginal(&" << origName << ")), res);\n"; - os << " assert(res);\n"; - - // Insert logging function call (optional) - os << " Function *logFunc = " << origName - << ".getModule()->getFunction(\"enzymeLogError\");\n"; - os << " if (logFunc) {\n" - << " std::string moduleName = " << origName - << ".getModule()->getModuleIdentifier() ;\n" - << " std::string functionName = " << origName - << ".getFunction()->getName().str();\n" - << " std::string blockName = " << origName - << ".getParent()->getName().str();\n" - << " int funcIdx = -1, blockIdx = -1, instIdx = -1;\n" - << " auto funcIt = std::find_if(" << origName - << ".getModule()->begin(), " << origName - << ".getModule()->end(),\n" - " [&](const auto& func) { return &func == " - << origName - << ".getFunction(); });\n" - " if (funcIt != " - << origName - << ".getModule()->end()) {\n" - " funcIdx = " - "std::distance(" - << origName << ".getModule()->begin(), funcIt);\n" - <<
" }\n" - << " auto blockIt = std::find_if(" << origName - << ".getFunction()->begin(), " << origName - << ".getFunction()->end(),\n" - " [&](const auto& block) { return &block == " - << origName - << ".getParent(); });\n" - " if (blockIt != " - << origName - << ".getFunction()->end()) {\n" - " blockIdx = std::distance(" - << origName << ".getFunction()->begin(), blockIt);\n" - << " }\n" - << " auto instIt = std::find_if(" << origName - << ".getParent()->begin(), " << origName - << ".getParent()->end(),\n" - " [&](const auto& curr) { return &curr == &" - << origName - << "; });\n" - " if (instIt != " - << origName - << ".getParent()->end()) {\n" - " instIdx = std::distance(" - << origName << ".getParent()->begin(), instIt);\n" - << " }\n" - << " Value *origValue = " - "Builder2.CreateFPExt(gutils->getNewFromOriginal(&" - << origName << "), Type::getDoubleTy(" << origName - << ".getContext()));\n" - << " Value *errValue = Builder2.CreateFPExt(res, " - "Type::getDoubleTy(" - << origName << ".getContext()));\n" - << " std::string opcodeName = " << origName - << ".getOpcodeName();\n" - << " std::string calleeName = \"\";\n" - << " if (auto CI = dyn_cast<CallInst>(&" << origName - << ")) {\n" - << " if (Function *fn = CI->getCalledFunction()) {\n" - << " calleeName = fn->getName();\n" - << " } else {\n" - << " calleeName = \"\";\n" - << " }\n" - << " }\n" - << "#if LLVM_VERSION_MAJOR >= 17\n" - << " Value *moduleNameValue = " - "Builder2.CreateGlobalString(moduleName);\n" - << " Value *functionNameValue = " - "Builder2.CreateGlobalString(functionName + \" (\" +" - "std::to_string(funcIdx) + \")\");\n" - << " Value *blockNameValue = " - "Builder2.CreateGlobalString(blockName + \" (\" +" - "std::to_string(blockIdx) + \")\");\n" - << " Value *opcodeNameValue = " - "Builder2.CreateGlobalString(opcodeName + \" (\" " - "+std::to_string(instIdx) + \")\");\n" - << " Value *calleeNameValue = " - "Builder2.CreateGlobalString(calleeName);\n" - << "#else\n" - << " Value *moduleNameValue = " -
"Builder2.CreateGlobalStringPtr(moduleName);\n" - << " Value *functionNameValue = " - "Builder2.CreateGlobalStringPtr(functionName + \" (\" +" - "std::to_string(funcIdx) + \")\");\n" - << " Value *blockNameValue = " - "Builder2.CreateGlobalStringPtr(blockName + \" (\" +" - "std::to_string(blockIdx) + \")\");\n" - << " Value *opcodeNameValue = " - "Builder2.CreateGlobalStringPtr(opcodeName + \" (\" " - "+std::to_string(instIdx) + \")\");\n" - << " Value *calleeNameValue = " - "Builder2.CreateGlobalStringPtr(calleeName);\n" - << "#endif\n" - << " Builder2.CreateCall(logFunc, {origValue, " - "errValue, opcodeNameValue, calleeNameValue, moduleNameValue, " - "functionNameValue, blockNameValue});\n" - << " }\n"; - os << " setDiffe(&" << origName << ", res, Builder2);\n"; os << " break;\n"; os << " }\n"; @@ -2583,6 +2575,62 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " IRBuilder<> Builder2(&" << origName << ");\n"; os << " getReverseBuilder(Builder2);\n"; os << " Value *dif = nullptr;\n"; + + // Set `dif` immediately for profiled instructions + os << "#ifdef ENABLE_POSEIDON\n" + << " if (gutils->profiled && " + "Poseidonable(" + << origName << ")) {\n" + << " if (auto md = " << origName + << ".getMetadata(\"enzyme_fpprofile_idx\")) {\n" + << " size_t instIdx = " + "cast<ConstantInt>(cast<ConstantAsMetadata>(md->getOperand(0))->" + "getValue())->getZExtValue();\n" + << " dif = diffe(&" << origName << ", Builder2);\n" + << " setDiffe(&" << origName + << ", Constant::getNullValue(gutils->getShadowType(" << origName + << ".getType())), Builder2);\n" + << " Type *PtrTy = PointerType::getUnqual(" << origName + << ".getContext());\n" + << " Type *SizeTy = Type::getInt64Ty(" << origName + << ".getContext());\n" + << " Type *DoubleTy = Type::getDoubleTy(" << origName + << ".getContext());\n" + << " FunctionType *LogGradFT = " + "FunctionType::get(Type::getVoidTy(" + << origName << ".getContext()),\n" + << " {PtrTy, SizeTy, DoubleTy, DoubleTy}, false);\n" + <<
FunctionCallee logFunc = " << origName + << ".getModule()->getOrInsertFunction(\"enzymeLogGrad\", LogGradFT);\n" + << " std::string funcName = " + "gutils->oldFunc->getName().str();\n" + << " GlobalVariable *gv = " + "gutils->oldFunc->getParent()->getNamedGlobal(\"fpprofiled_\" + " + "funcName);\n" + << " if (!gv)\n" + << " gv = Builder2.CreateGlobalString(funcName, " + "\"fpprofiled_\" + funcName);\n" + << " Value *funcNamePtr = " + "Builder2.CreateInBoundsGEP(gv->getValueType(), gv, " + "{Builder2.getInt32(0), Builder2.getInt32(0)});\n" + << " Value *primalInst = " + "gutils->lookupM(gutils->getNewFromOriginal(&" + << origName << "), Builder2);\n" + << " Value *primalDouble = " + "Builder2.CreateFPExt(primalInst, " + "Type::getDoubleTy(" + << origName << ".getContext()));\n" + << " Value *gradDouble = Builder2.CreateFPExt(dif, " + "Type::getDoubleTy(" + << origName << ".getContext()));\n" + << " CallInst *logCallInst = Builder2.CreateCall(logFunc, " + "{funcNamePtr, ConstantInt::get(SizeTy, instIdx), primalDouble, " + "gradDouble});\n" + << " logCallInst->setDebugLoc(gutils->getNewFromOriginal(" + << origName << ".getDebugLoc()));\n" + << " }\n" + << " }\n" + << "#endif\n"; } else { os << "};\n"; emitMLIRReverse(os, pattern, tree, intrinsic, origName, argOps);