// Patch summary (leading text before the "diff --git" header; not part of any hunk).
//
// Target: enzyme/Enzyme/GradientUtils.cpp
//
// What the patch adds:
//   * an include of "llvm/ADT/Statistic.h";
//   * a DEBUG_TYPE definition ("enzyme") placed ahead of the STATISTIC
//     declarations;
//   * two counters, NumValuesCached and NumValuesRecomputed;
//   * in GradientUtils::computeMinCache():
//       - NumValuesCached is incremented where `oneneed || shadowOneNeed`
//         holds and knownRecomputeHeuristic[&I] is set to false;
//       - inside the loop over Intermediates, NumValuesCached is incremented
//         when MinReq.count(V) is non-zero, otherwise NumValuesRecomputed is.
//
// NOTE(review): NumValuesCached is bumped at two sites. Whether a value counted
// at the first site can also appear in Intermediates cannot be determined from
// these hunks -- confirm there is no double counting.
// NOTE(review): after this patch MinReq.count(V) is evaluated three times per
// loop iteration; consider a local if the surrounding code is touched again.
// NOTE(review): the context line containing `dyn_cast(V)` has no template
// argument; it was presumably lost when this patch text was extracted. The
// context must match the tree exactly -- regenerate the hunk from the real file.
// NOTE(review): this patch was recovered from a single collapsed line. The
// indentation and the positions of blank context lines below are reconstructed
// so that the hunk line counts add up; verify them against the tree (or apply
// with whitespace tolerance) before relying on this file.
diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp
index 7a98e2fb02d1..2d9cc828584c 100644
--- a/enzyme/Enzyme/GradientUtils.cpp
+++ b/enzyme/Enzyme/GradientUtils.cpp
@@ -51,6 +51,7 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/StringMap.h"
 
 #include "llvm/Support/AMDGPUMetadata.h"
@@ -58,6 +59,12 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/TimeProfiler.h"
 
+#define DEBUG_TYPE "enzyme"
+
+STATISTIC(NumValuesCached, "Number of values cached for the reverse pass");
+STATISTIC(NumValuesRecomputed,
+          "Number of values recomputed instead of cached");
+
 #if LLVM_VERSION_MAJOR >= 14
 #define addAttribute addAttributeAtIndex
 #define hasAttribute hasAttributeAtIndex
@@ -8492,6 +8499,7 @@ void GradientUtils::computeMinCache() {
 
         if (oneneed || shadowOneNeed) {
           knownRecomputeHeuristic[&I] = false;
+          ++NumValuesCached;
 
           CountTrackedPointers T(I.getType());
           assert(!T.derived);
@@ -8569,6 +8577,10 @@ void GradientUtils::computeMinCache() {
 
     for (auto V : Intermediates) {
       knownRecomputeHeuristic[V] = !MinReq.count(V);
+      if (MinReq.count(V))
+        ++NumValuesCached;
+      else
+        ++NumValuesRecomputed;
       if (!MinReq.count(V) && NeedGraph.count(V)) {
         if (auto CI = dyn_cast(V))
           if (getFuncNameFromCall(CI) == "julia.call")