diff --git a/enzyme/.bazelrc b/enzyme/.bazelrc index 3c3d089fa2cc..66bfe941132e 100644 --- a/enzyme/.bazelrc +++ b/enzyme/.bazelrc @@ -21,4 +21,3 @@ build --define=use_fast_cpp_protos=true build --define=allow_oversize_protos=true build -c opt - diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index f88d212e8c76..1803eb3167d4 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -349,13 +349,31 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( : regionBranchOp.getSuccessorInputs(successor); // Need to know which of the arguments are being forwarded to from - // operands. + // operands. An operand needs a shadow — and the ForOp needs a matching + // shadow result — whenever EITHER its iter arg OR its corresponding op + // result is active. Using only iter arg activity misses the + // constant-accumulator case (constant init arg that produces an active + // result because the loop body accumulates active values into it). + // Using only result activity misses the case where an iter arg is active + // but its result is not (e.g. pointer-typed iter args used for address + // arithmetic whose final values are unused downstream). + // forceAugmentedReturns uses only iter arg activity, so for positions + // where the result is active but the iter arg is constant, the second + // overload inserts the missing shadow block arg after takeBody. for (auto &&[i, regionValue, operand] : llvm::enumerate(targetValues, operandRange)) { - if (gutils->isConstantValue(regionValue)) + bool iterArgActive = !gutils->isConstantValue(regionValue); + bool resultActive = i < op->getNumResults() && + !gutils->isConstantValue(op->getResult(i)); + if (!iterArgActive && !resultActive) continue; operandPositionsToShadow.insert(operandRange.getBeginOperandIndex() + i); - if (successor.isParent()) + // Add the corresponding result to resultPositionsToShadow if the iter + // arg is active: forceAugmentedReturns will have inserted a shadow + // block arg for it, so the ForOp needs a matching shadow result. + // Active results (regardless of iter arg activity) are covered by the + // loop below. + if (successor.isParent() || (iterArgActive && i < op->getNumResults())) resultPositionsToShadow.insert(i); } } @@ -423,6 +441,47 @@ LogicalResult mlir::enzyme::detail::controlFlowForwardHandler( replacementRegion.takeBody(region); } + // forceAugmentedReturns inserts shadow block args only for iter args that + // are themselves active. When an iter arg is constant but its corresponding + // op result is active (e.g. a zero accumulator that accumulates active + // values across iterations), the first overload still adds that position to + // both operandPositionsToShadow and resultPositionsToShadow (union + // criterion), so replacement has the right number of results. However, the + // body block is missing the shadow block arg that the replacement's + // iter_arg slot expects. Insert it here, after takeBody has placed the + // cloned body into replacement. + // + // We also register the mapping in invertedPointers so that invertPointerM, + // which checks invertedPointers before isConstantValue, returns the shadow + // block arg instead of zero when body ops reference this iter arg. + if (auto rbIface = dyn_cast(op)) { + SmallVector entrySuccessors; + rbIface.getEntrySuccessorRegions( + SmallVector(op->getNumOperands(), Attribute()), + entrySuccessors); + for (const RegionSuccessor &successor : entrySuccessors) { + if (successor.isParent()) + continue; + ValueRange successorInputs = rbIface.getSuccessorInputs(successor); + for (auto [i, iterArg] : llvm::enumerate(successorInputs)) { + if (!resultPositionsToShadow.count(i)) + continue; + if (!gutils->isConstantValue(iterArg)) + continue; + // iterArg is constant but position i needs a shadow result. + // Insert the missing shadow block arg right after iterArg's clone. + auto clonedIterArg = + cast(gutils->getNewFromOriginal(iterArg)); + Block *block = clonedIterArg.getParentBlock(); + Value shadowArg = block->insertArgument( + clonedIterArg.getArgNumber() + 1, + gutils->getShadowType(clonedIterArg.getType()), + clonedIterArg.getLoc()); + gutils->invertedPointers.map(iterArg, shadowArg); + } + } + } + // Inject the mapping for the new results into GradientUtil's shadow // table. SmallVector reps; diff --git a/enzyme/test/MLIR/ForwardMode/for3.mlir b/enzyme/test/MLIR/ForwardMode/for3.mlir new file mode 100644 index 000000000000..ea1d347f0250 --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/for3.mlir @@ -0,0 +1,45 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +// Test that a constant iter arg whose corresponding ForOp result is active +// (the "constant accumulator" pattern) is correctly differentiated. +// The iter arg %acc is initialized from a constant zero and is therefore +// marked constant by activity analysis, but the ForOp result is active +// because active values (%x) are accumulated into it through the body. +// The differentiated ForOp must have a shadow iter arg (also zero-initialized) +// that accumulates the tangent dx on each iteration. + +module { + func.func @square(%x : f64) -> f64 { + %zero = arith.constant 0.0 : f64 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %r = scf.for %i = %c0 to %c10 step %c1 iter_args(%acc = %zero) -> (f64) { + %n = arith.addf %acc, %x : f64 + scf.yield %n : f64 + } + return %r : f64 + } + func.func @dsq(%x : f64, %dx : f64) -> f64 { + %r = enzyme.fwddiff @square(%x, %dx) { activity=[#enzyme], ret_activity=[#enzyme] } : (f64, f64) -> (f64) + return %r : f64 + } +} + +// The differentiated ForOp must have TWO iter args: the primal accumulator +// (init = 0.0) and its shadow (init = 0.0, since the original init is a +// constant). On each iteration the shadow accumulates dx (= %arg1). + +// CHECK: func.func private @fwddiffesquare(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { +// CHECK-DAG: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index +// CHECK-NEXT: %[[r:.+]]:2 = scf.for %{{.+}} = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[acc:.+]] = %[[cst_0]], %[[sacc:.+]] = %[[cst]]) -> (f64, f64) { +// CHECK-NEXT: %[[sn:.+]] = arith.addf %[[sacc]], %[[arg1]] : f64 +// CHECK-NEXT: %[[n:.+]] = arith.addf %[[acc]], %[[arg0]] : f64 +// CHECK-NEXT: scf.yield %[[n]], %[[sn]] : f64, f64 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[r]]#1 : f64 +// CHECK-NEXT: }