diff --git a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
index 77304fa04432..50e4f7b14180 100644
--- a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
+++ b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
@@ -477,6 +477,8 @@ std::optional<Value> getStored(Operation *op) {
     return storeOp.getValue();
   } else if (auto storeOp = dyn_cast<LLVM::StoreOp>(op)) {
     return storeOp.getValue();
+  } else if (auto pushOp = dyn_cast<enzyme::PushOp>(op)) {
+    return pushOp.getValue();
   }
   return std::nullopt;
 }
diff --git a/enzyme/Enzyme/MLIR/Analysis/DataFlowAliasAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/DataFlowAliasAnalysis.cpp
index be936215c8c3..021b3a8bdf72 100644
--- a/enzyme/Enzyme/MLIR/Analysis/DataFlowAliasAnalysis.cpp
+++ b/enzyme/Enzyme/MLIR/Analysis/DataFlowAliasAnalysis.cpp
@@ -48,7 +48,7 @@ using namespace mlir;
 using namespace mlir::dataflow;
 
 static bool isPointerLike(Type type) {
-  return isa<MemRefType, LLVM::LLVMPointerType>(type);
+  return isa<MemRefType, LLVM::LLVMPointerType, enzyme::CacheType>(type);
 }
 
 //===----------------------------------------------------------------------===//
@@ -121,29 +121,6 @@ ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate,
       [&](DistinctAttr dest, AliasClassSet::State state) {
         assert(state == AliasClassSet::State::Defined &&
                "unknown must have been handled above");
-#ifndef NDEBUG
-        if (replace) {
-          auto it = map.find(dest);
-          if (it != map.end()) {
-            // Check that we are updating to a state that's >= in the
-            // lattice.
-            // TODO: consider a stricter check that we only replace unknown
-            // values or a value with itself, currently blocked by memalign.
-            AliasClassSet valuesCopy(values);
-            (void)valuesCopy.join(it->getSecond());
-            values.print(llvm::errs());
-            llvm::errs() << "\n";
-            it->getSecond().print(llvm::errs());
-            llvm::errs() << "\n";
-            valuesCopy.print(llvm::errs());
-            llvm::errs() << "\n";
-            assert(valuesCopy == values &&
-                   "attempting to replace a pointsTo entry with an alias class "
-                   "set that is ordered _before_ the existing one -> "
-                   "non-monotonous update ");
-          }
-        }
-#endif // NDEBUG
         return joinPotentiallyMissing(dest, values);
       });
 }
@@ -278,7 +255,7 @@
   // fixpoint and bail.
   auto memory = dyn_cast<MemoryEffectOpInterface>(op);
   if (!memory) {
-    if (isNoOp(op))
+    if (isNoOp(op) || isMemoryEffectFree(op))
       return success();
     propagateIfChanged(after, after->markAllPointToUnknown());
     return success();
@@ -557,7 +534,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
   std::optional<LLVM::ModRefInfo> otherModRef =
       getFunctionOtherModRef(callee);
 
-  SmallVector<int> pointerLikeOperands;
+  SmallVector<unsigned> pointerLikeOperands;
   for (auto &&[i, operand] : llvm::enumerate(call.getArgOperands())) {
     if (isPointerLike(operand.getType()))
       pointerLikeOperands.push_back(i);
@@ -575,7 +552,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
       // unknown alias sets into any writable pointer.
       (void)functionMayCapture.markUnknown();
     } else {
-      for (int pointerAsData : pointerLikeOperands) {
+      for (unsigned pointerAsData : pointerLikeOperands) {
         // If not captured, it cannot be stored in anything.
         if ((pointerAsData < numArguments &&
              !!callee.getArgAttr(pointerAsData,
@@ -593,7 +570,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
   AliasClassSet writableClasses = AliasClassSet::getUndefined();
   AliasClassSet nonWritableOperandClasses = AliasClassSet::getUndefined();
   ChangeResult changed = ChangeResult::NoChange;
-  for (int pointerOperand : pointerLikeOperands) {
+  for (unsigned pointerOperand : pointerLikeOperands) {
     auto *destClasses = getOrCreateFor<AliasClassLattice>(
         getProgramPointAfter(call), call.getArgOperands()[pointerOperand]);
@@ -696,7 +673,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
       continue;
     }
 
-    for (int operandNo : pointerLikeOperands) {
+    for (unsigned operandNo : pointerLikeOperands) {
       const auto *srcClasses = getOrCreateFor<AliasClassLattice>(
           getProgramPointAfter(call), call.getArgOperands()[operandNo]);
       if (mayReadArg(callee, operandNo, argModRef)) {
@@ -840,7 +817,8 @@ static bool isAliasTransferFullyDescribedByMemoryEffects(Operation *op) {
     }
   }
-  return isa<memref::LoadOp, memref::StoreOp, LLVM::LoadOp, LLVM::StoreOp>(op);
+  return isa<memref::LoadOp, memref::StoreOp, LLVM::LoadOp, LLVM::StoreOp,
+             enzyme::PushOp, enzyme::PopOp>(op);
 }
 
 void enzyme::AliasAnalysis::transfer(
diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
index 983580cb79a8..4688dd417ad3 100644
--- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
+++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
@@ -279,7 +279,7 @@ def PushOp : Enzyme_Op<"push", [
    TypesMatchWith<"cache type matches value type",
    "cache", "value",
    "::llvm::cast<CacheType>($_self).getType()">]> {
   let summary = "Push value to cache";
-  let arguments = (ins AnyType : $cache, AnyType : $value);
+  let arguments = (ins Arg<AnyType, "", [MemWrite]>:$cache, AnyType:$value);
 }
@@ -287,7 +287,7 @@ def PopOp : Enzyme_Op<"pop", [
    TypesMatchWith<"cache type matches output type",
    "cache", "output",
    "::llvm::cast<CacheType>($_self).getType()">]> {
   let summary = "Retrieve information for the reverse mode pass.";
-  let arguments = (ins AnyType : $cache);
+  let arguments = (ins Arg<AnyType, "", [MemRead]>:$cache);
   let results = (outs AnyType:$output);
 }
@@ -295,7 +295,7 @@ def InitOp : Enzyme_Op<"init", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Create enzyme.gradient and enzyme.cache";
   let arguments = (ins );
-  let results = (outs AnyType);
+  let results = (outs Res<AnyType, "", [MemAlloc]>);
 }
 
 def Cache : Enzyme_Type<"Cache"> {
diff --git a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
index 542c494ab79d..bb00a51879d3 100644
--- a/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
+++ b/enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
@@ -16,6 +16,9 @@
 #include
 #include
 
+#include "Analysis/DataFlowAliasAnalysis.h"
+#include "mlir/Analysis/DataFlow/Utils.h"
+
 #include "llvm/ADT/MapVector.h"
 
 using namespace mlir;
@@ -263,11 +266,76 @@ static inline void bfs(const Graph &G, const llvm::SetVector<Value> &Sources,
     }
   }
 }
 
+struct OverwriteAnalyzer {
+  static OverwriteAnalyzer analyzeFunc(FunctionOpInterface funcOp) {
+    return OverwriteAnalyzer(funcOp);
+  }
+
+  bool isPtrPotentiallyModified(Value ptr) const {
+    // If the alias analysis failed, conservatively assume all pointers may
+    // be modified
+    if (!valid)
+      return true;
+
+    // Check if the pointer's alias classes intersect the modified alias classes
+    auto *ptrClass = solver.lookupState<AliasClassLattice>(ptr);
+    return !ptrClass->alias(modified).isNo();
+  }
+
+private:
+  DataFlowSolver solver;
+  // The set of all alias classes that are potentially modified in the function
+  AliasClassLattice modified;
+  bool valid = true;
+
+  OverwriteAnalyzer(FunctionOpInterface funcOp)
+      : solver(DataFlowConfig().setInterprocedural(false)), modified(nullptr) {
+    dataflow::loadBaselineAnalyses(solver);
+    solver.load<AliasAnalysis>(funcOp.getContext(), /*relative=*/false);
+    solver.load<PointsToPointerAnalysis>();
+    if (failed(solver.initializeAndRun(funcOp))) {
+      assert(false && "dataflow analysis failed");
+      valid = false;
+    } else {
+      funcOp.walk([&](MemoryEffectOpInterface memory) {
+        SmallVector<MemoryEffects::EffectInstance> effects;
+        memory.getEffects(effects);
+        for (const auto &effect : effects) {
+          if (isa<MemoryEffects::Write>(effect.getEffect())) {
+            Value val = effect.getValue();
+            if (val) {
+              (void)modified.join(*solver.lookupState<AliasClassLattice>(val));
+            } else {
+              (void)modified.markUnknown();
+            }
+          }
+        }
+      });
+    }
+  }
+};
+
+bool isLoadMovable(const OverwriteAnalyzer &analyzer, Operation *op) {
+  if (!hasSingleEffect<MemoryEffects::Read>(op)) {
+    return false;
+  }
+  auto memory = cast<MemoryEffectOpInterface>(op);
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  memory.getEffects(effects);
+  assert(effects.size() == 1 &&
+         isa<MemoryEffects::Read>(effects.front().getEffect()));
+  Value ptr = effects.front().getValue();
+
+  // The load can be re-done if the pointer's contents are never modified
+  // by the function.
+  return !analyzer.isPtrPotentiallyModified(ptr);
+}
+
 // Whether or not an operation can be moved from the forward region to the
 // reverse region or vice-versa.
-static inline bool isMovable(Operation *op) {
+static inline bool isMovable(const OverwriteAnalyzer &analyzer, Operation *op) {
   return op->getNumRegions() == 0 && op->getBlock()->getTerminator() != op &&
-         mlir::isPure(op);
+         (mlir::isPure(op) || isLoadMovable(analyzer, op));
 }
 
 // Given a graph `G`, construct a new graph `G2`, where all paths must terminate
@@ -487,6 +555,8 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse,
   }
 
   Graph G;
+  auto overwriteAnalyzer = OverwriteAnalyzer::analyzeFunc(
+      forward->getParent()->getParentOfType<FunctionOpInterface>());
 
   LLVM_DEBUG(llvm::dbgs() << "trying min/cut\n");
   LLVM_DEBUG(
@@ -518,7 +588,7 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse,
     }
 
     Operation *owner = todo.getDefiningOp();
-    if (!owner || !isMovable(owner)) {
+    if (!owner || !isMovable(overwriteAnalyzer, owner)) {
       roots.insert(todo);
       continue;
     }
@@ -544,7 +614,8 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse,
 
       bool isRequired = false;
       for (auto user : poped.getUsers()) {
-        if (user->getBlock() != reverse || !isMovable(user)) {
+        if (user->getBlock() != reverse ||
+            !isMovable(overwriteAnalyzer, user)) {
           G[info.pushedValue()].insert(Node(user));
           Required.insert(user);
           isRequired = true;
@@ -567,7 +638,8 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse,
 
       bool isRequired = false;
      for (auto user : todo.getUsers()) {
-        if (user->getBlock() != reverse || !isMovable(user)) {
+        if (user->getBlock() != reverse ||
+            !isMovable(overwriteAnalyzer, user)) {
           G[todo].insert(Node(user));
           Required.insert(user);
           isRequired = true;
diff --git a/enzyme/test/MLIR/ReverseMode/scf_parallel.mlir b/enzyme/test/MLIR/ReverseMode/scf_parallel.mlir
index 32abf43f76b5..ff39c387d234 100644
--- a/enzyme/test/MLIR/ReverseMode/scf_parallel.mlir
+++ b/enzyme/test/MLIR/ReverseMode/scf_parallel.mlir
@@ -1,5 +1,5 @@
-// RUN: %eopt %s --pass-pipeline="builtin.module(enzyme,canonicalize,remove-unnecessary-enzyme-ops)" | FileCheck %s
-func.func @foo(%x: memref<?xf32>, %y: memref<?xf32>) {
+// RUN: %eopt %s --pass-pipeline="builtin.module(enzyme,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math)" | FileCheck %s
+func.func @foo(%x: memref<?xf32> {llvm.noalias}, %y: memref<?xf32> {llvm.noalias}) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
@@ -12,7 +12,7 @@
   return
 }
 
-func.func @dfoo(%x: memref<?xf32>, %dx: memref<?xf32>, %y: memref<?xf32>, %dy: memref<?xf32>) {
+func.func @dfoo(%x: memref<?xf32> {llvm.noalias}, %dx: memref<?xf32> {llvm.noalias}, %y: memref<?xf32> {llvm.noalias}, %dy: memref<?xf32> {llvm.noalias}) {
   enzyme.autodiff @foo(%x, %dx, %y, %dy) {
     activity = [#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>],
     ret_activity = []
@@ -20,31 +20,27 @@
   return
 }
 
-// CHECK: func.func private @diffefoo(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>, %arg3: memref<?xf32>) {
-// CHECK-NEXT:   %c4 = arith.constant 4 : index
-// CHECK-NEXT:   %c1 = arith.constant 1 : index
-// CHECK-NEXT:   %c0 = arith.constant 0 : index
-// CHECK-NEXT:   %cst = arith.constant 0.000000e+00 : f32
-// CHECK-NEXT:   %alloc = memref.alloc() : memref<4xf32>
-// CHECK-NEXT:   scf.parallel (%arg4) = (%c0) to (%c4) step (%c1) {
-// CHECK-NEXT:     %0 = memref.load %arg0[%arg4] : memref<?xf32>
-// CHECK-NEXT:     memref.store %0, %alloc[%arg4] : memref<4xf32>
-// CHECK-NEXT:     %1 = arith.mulf %0, %0 : f32
-// CHECK-NEXT:     memref.store %1, %arg2[%arg4] : memref<?xf32>
-// CHECK-NEXT:     scf.reduce
-// CHECK-NEXT:   }
-// CHECK-NEXT:   scf.parallel (%arg4) = (%c0) to (%c4) step (%c1) {
-// CHECK-NEXT:     %0 = memref.load %alloc[%arg4] : memref<4xf32>
-// CHECK-NEXT:     %1 = memref.load %arg3[%arg4] : memref<?xf32>
-// CHECK-NEXT:     %2 = arith.addf %1, %cst : f32
-// CHECK-NEXT:     memref.store %cst, %arg3[%arg4] : memref<?xf32>
-// CHECK-NEXT:     %3 = arith.mulf %2, %0 : f32
-// CHECK-NEXT:     %4 = arith.addf %3, %cst : f32
-// CHECK-NEXT:     %5 = arith.mulf %2, %0 : f32
-// CHECK-NEXT:     %6 = arith.addf %4, %5 : f32
-// CHECK-NEXT:     %7 = memref.atomic_rmw addf %6, %arg1[%arg4] : (f32, memref<?xf32>) -> f32
-// CHECK-NEXT:     scf.reduce
-// CHECK-NEXT:   }
-// CHECK-NEXT:   memref.dealloc %alloc : memref<4xf32>
-// CHECK-NEXT:   return
-// CHECK-NEXT: }
+// CHECK-LABEL: func.func private @diffefoo(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32> {llvm.noalias}, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf32> {llvm.noalias}, %[[ARG3:.*]]: memref<?xf32>) {
+// CHECK: %[[CONSTANT_0:.*]] = arith.constant 4 : index
+// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : index
+// CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : index
+// CHECK: %[[CONSTANT_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: scf.parallel (%[[VAL_0:.*]]) = (%[[CONSTANT_2]]) to (%[[CONSTANT_0]]) step (%[[CONSTANT_1]]) {
+// CHECK: %[[LOAD_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[VAL_0]]] : memref<?xf32>
+// CHECK: %[[MULF_0:.*]] = arith.mulf %[[LOAD_0]], %[[LOAD_0]] : f32
+// CHECK: memref.store %[[MULF_0]], %[[ARG2]]{{\[}}%[[VAL_0]]] : memref<?xf32>
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: scf.parallel (%[[VAL_1:.*]]) = (%[[CONSTANT_2]]) to (%[[CONSTANT_0]]) step (%[[CONSTANT_1]]) {
+// CHECK: %[[LOAD_1:.*]] = memref.load %[[ARG0]]{{\[}}%[[VAL_1]]] : memref<?xf32>
+// CHECK: %[[LOAD_2:.*]] = memref.load %[[ARG3]]{{\[}}%[[VAL_1]]] : memref<?xf32>
+// CHECK: memref.store %[[CONSTANT_3]], %[[ARG3]]{{\[}}%[[VAL_1]]] : memref<?xf32>
+// CHECK: %[[MULF_1:.*]] = arith.mulf %[[LOAD_2]], %[[LOAD_1]] : f32
+// CHECK: %[[MULF_2:.*]] = arith.mulf %[[LOAD_2]], %[[LOAD_1]] : f32
+// CHECK: %[[ADDF_0:.*]] = arith.addf %[[MULF_1]], %[[MULF_2]] : f32
+// CHECK: %[[ATOMIC_RMW_0:.*]] = memref.atomic_rmw addf %[[ADDF_0]], %[[ARG1]]{{\[}}%[[VAL_1]]] : (f32, memref<?xf32>) -> f32
+// CHECK: scf.reduce
+// CHECK: }
+// CHECK: return
+// CHECK: }