Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,8 @@ std::optional<Value> getStored(Operation *op) {
return storeOp.getValue();
} else if (auto storeOp = dyn_cast<memref::StoreOp>(op)) {
return storeOp.getValue();
} else if (auto pushOp = dyn_cast<enzyme::PushOp>(op)) {
return pushOp.getValue();
}
return std::nullopt;
}
Expand Down
38 changes: 8 additions & 30 deletions enzyme/Enzyme/MLIR/Analysis/DataFlowAliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ using namespace mlir;
using namespace mlir::dataflow;

static bool isPointerLike(Type type) {
return isa<MemRefType, LLVM::LLVMPointerType>(type);
return isa<MemRefType, LLVM::LLVMPointerType, enzyme::CacheType>(type);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -121,29 +121,6 @@ ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate,
[&](DistinctAttr dest, AliasClassSet::State state) {
assert(state == AliasClassSet::State::Defined &&
"unknown must have been handled above");
#ifndef NDEBUG
if (replace) {
auto it = map.find(dest);
if (it != map.end()) {
// Check that we are updating to a state that's >= in the
// lattice.
// TODO: consider a stricter check that we only replace unknown
// values or a value with itself, currently blocked by memalign.
AliasClassSet valuesCopy(values);
(void)valuesCopy.join(it->getSecond());
values.print(llvm::errs());
llvm::errs() << "\n";
it->getSecond().print(llvm::errs());
llvm::errs() << "\n";
valuesCopy.print(llvm::errs());
llvm::errs() << "\n";
assert(valuesCopy == values &&
"attempting to replace a pointsTo entry with an alias class "
"set that is ordered _before_ the existing one -> "
"non-monotonous update ");
}
}
#endif // NDEBUG
return joinPotentiallyMissing(dest, values);
});
}
Expand Down Expand Up @@ -278,7 +255,7 @@ LogicalResult enzyme::PointsToPointerAnalysis::visitOperation(
// fixpoint and bail.
auto memory = dyn_cast<MemoryEffectOpInterface>(op);
if (!memory) {
if (isNoOp(op))
if (isNoOp(op) || isMemoryEffectFree(op))
return success();
propagateIfChanged(after, after->markAllPointToUnknown());
return success();
Expand Down Expand Up @@ -557,7 +534,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
std::optional<LLVM::ModRefInfo> otherModRef =
getFunctionOtherModRef(callee);

SmallVector<int> pointerLikeOperands;
SmallVector<unsigned> pointerLikeOperands;
for (auto &&[i, operand] : llvm::enumerate(call.getArgOperands())) {
if (isPointerLike(operand.getType()))
pointerLikeOperands.push_back(i);
Expand All @@ -575,7 +552,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
// unknown alias sets into any writable pointer.
(void)functionMayCapture.markUnknown();
} else {
for (int pointerAsData : pointerLikeOperands) {
for (unsigned pointerAsData : pointerLikeOperands) {
// If not captured, it cannot be stored in anything.
if ((pointerAsData < numArguments &&
!!callee.getArgAttr(pointerAsData,
Expand All @@ -593,7 +570,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
AliasClassSet writableClasses = AliasClassSet::getUndefined();
AliasClassSet nonWritableOperandClasses = AliasClassSet::getUndefined();
ChangeResult changed = ChangeResult::NoChange;
for (int pointerOperand : pointerLikeOperands) {
for (unsigned pointerOperand : pointerLikeOperands) {
auto *destClasses = getOrCreateFor<AliasClassLattice>(
getProgramPointAfter(call), call.getArgOperands()[pointerOperand]);

Expand Down Expand Up @@ -696,7 +673,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
continue;
}

for (int operandNo : pointerLikeOperands) {
for (unsigned operandNo : pointerLikeOperands) {
const auto *srcClasses = getOrCreateFor<AliasClassLattice>(
getProgramPointAfter(call), call.getArgOperands()[operandNo]);
if (mayReadArg(callee, operandNo, argModRef)) {
Expand Down Expand Up @@ -840,7 +817,8 @@ static bool isAliasTransferFullyDescribedByMemoryEffects(Operation *op) {
}
}
}
return isa<memref::LoadOp, memref::StoreOp, LLVM::LoadOp, LLVM::StoreOp>(op);
return isa<memref::LoadOp, memref::StoreOp, LLVM::LoadOp, LLVM::StoreOp,
enzyme::PushOp, enzyme::PopOp>(op);
}

void enzyme::AliasAnalysis::transfer(
Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -279,23 +279,23 @@ def PushOp : Enzyme_Op<"push", [
"cache", "value",
"::llvm::cast<enzyme::CacheType>($_self).getType()">]> {
let summary = "Push value to cache";
let arguments = (ins AnyType : $cache, AnyType : $value);
let arguments = (ins Arg<AnyType, "the cache to push to", [MemWrite]>:$cache, AnyType:$value);
}

def PopOp : Enzyme_Op<"pop", [
TypesMatchWith<"type of 'output' matches element type of 'cache'",
"cache", "output",
"::llvm::cast<enzyme::CacheType>($_self).getType()">]> {
let summary = "Retrieve information for the reverse mode pass.";
let arguments = (ins AnyType : $cache);
let arguments = (ins Arg<AnyType, "the cache to pop from", [MemRead, MemWrite]>:$cache);
let results = (outs AnyType:$output);
}

def InitOp : Enzyme_Op<"init",
[DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
let summary = "Create enzyme.gradient and enzyme.cache";
let arguments = (ins );
let results = (outs AnyType);
let results = (outs Res<AnyType, "", [MemAlloc<DefaultResource, 0, FullEffect>]>);
}

def Cache : Enzyme_Type<"Cache"> {
Expand Down
82 changes: 77 additions & 5 deletions enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
#include <cassert>
#include <deque>

#include "Analysis/DataFlowAliasAnalysis.h"
#include "mlir/Analysis/DataFlow/Utils.h"

#include "llvm/ADT/MapVector.h"

using namespace mlir;
Expand Down Expand Up @@ -263,11 +266,76 @@ static inline void bfs(const Graph &G, const llvm::SetVector<Value> &Sources,
}
}

struct OverwriteAnalyzer {
static OverwriteAnalyzer analyzeFunc(FunctionOpInterface funcOp) {
return OverwriteAnalyzer(funcOp);
}

bool isPtrPotentiallyModified(Value ptr) const {
// If the alias analysis failed, conservatively assume all pointers may
// be modified
if (!valid)
return true;

// Check if the pointer's alias classes intersect the modified alias classes
auto *ptrClass = solver.lookupState<AliasClassLattice>(ptr);
return !ptrClass->alias(modified).isNo();
}

private:
DataFlowSolver solver;
// The set of all alias classes that are potentially modified in the function
AliasClassLattice modified;
bool valid = true;

OverwriteAnalyzer(FunctionOpInterface funcOp)
: solver(DataFlowConfig().setInterprocedural(false)), modified(nullptr) {
dataflow::loadBaselineAnalyses(solver);
solver.load<enzyme::AliasAnalysis>(funcOp.getContext(), /*relative=*/false);
solver.load<enzyme::PointsToPointerAnalysis>();
if (failed(solver.initializeAndRun(funcOp))) {
assert(false && "dataflow analysis failed");
valid = false;
} else {
funcOp.walk([&](MemoryEffectOpInterface memory) {
SmallVector<MemoryEffects::EffectInstance> effects;
memory.getEffects(effects);
for (const auto &effect : effects) {
if (isa<MemoryEffects::Write>(effect.getEffect())) {
Value val = effect.getValue();
if (val) {
(void)modified.join(*solver.lookupState<AliasClassLattice>(val));
} else {
(void)modified.markUnknown();
}
}
}
});
}
}
};

bool isLoadMovable(const OverwriteAnalyzer &analyzer, Operation *op) {
if (!hasSingleEffect<MemoryEffects::Read>(op)) {
return false;
}
auto memory = cast<MemoryEffectOpInterface>(op);
SmallVector<MemoryEffects::EffectInstance> effects;
memory.getEffects(effects);
assert(effects.size() == 1 &&
isa<MemoryEffects::Read>(effects.front().getEffect()));
Value ptr = effects.front().getValue();

// The load can be re-done if the pointer's contents are never modified
// by the function.
return !analyzer.isPtrPotentiallyModified(ptr);
}

// Whether or not an operation can be moved from the forward region to the
// reverse region or vice-versa.
static inline bool isMovable(Operation *op) {
static inline bool isMovable(const OverwriteAnalyzer &analyzer, Operation *op) {
return op->getNumRegions() == 0 && op->getBlock()->getTerminator() != op &&
mlir::isPure(op);
(mlir::isPure(op) || isLoadMovable(analyzer, op));
}

// Given a graph `G`, construct a new graph `G2`, where all paths must terminate
Expand Down Expand Up @@ -487,6 +555,8 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse,
}

Graph G;
auto overwriteAnalyzer = OverwriteAnalyzer::analyzeFunc(
forward->getParent()->getParentOfType<FunctionOpInterface>());

LLVM_DEBUG(llvm::dbgs() << "trying min/cut\n");
LLVM_DEBUG(
Expand Down Expand Up @@ -518,7 +588,7 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse,
}

Operation *owner = todo.getDefiningOp();
if (!owner || !isMovable(owner)) {
if (!owner || !isMovable(overwriteAnalyzer, owner)) {
roots.insert(todo);
continue;
}
Expand All @@ -544,7 +614,8 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse,

bool isRequired = false;
for (auto user : poped.getUsers()) {
if (user->getBlock() != reverse || !isMovable(user)) {
if (user->getBlock() != reverse ||
!isMovable(overwriteAnalyzer, user)) {
G[info.pushedValue()].insert(Node(user));
Required.insert(user);
isRequired = true;
Expand All @@ -567,7 +638,8 @@ void mlir::enzyme::minCutCache(Block *forward, Block *reverse,

bool isRequired = false;
for (auto user : todo.getUsers()) {
if (user->getBlock() != reverse || !isMovable(user)) {
if (user->getBlock() != reverse ||
!isMovable(overwriteAnalyzer, user)) {
G[todo].insert(Node(user));
Required.insert(user);
isRequired = true;
Expand Down
58 changes: 27 additions & 31 deletions enzyme/test/MLIR/ReverseMode/scf_parallel.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: %eopt %s --pass-pipeline="builtin.module(enzyme,canonicalize,remove-unnecessary-enzyme-ops)" | FileCheck %s
func.func @foo(%x: memref<?xf32>, %y: memref<?xf32>) {
// RUN: %eopt %s --pass-pipeline="builtin.module(enzyme,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math)" | FileCheck %s
func.func @foo(%x: memref<?xf32> {llvm.noalias}, %y: memref<?xf32> {llvm.noalias}) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c4 = arith.constant 4 : index
Expand All @@ -12,39 +12,35 @@ func.func @foo(%x: memref<?xf32>, %y: memref<?xf32>) {
return
}

func.func @dfoo(%x: memref<?xf32>, %dx: memref<?xf32>, %y: memref<?xf32>, %dy: memref<?xf32>) {
func.func @dfoo(%x: memref<?xf32> {llvm.noalias}, %dx: memref<?xf32> {llvm.noalias}, %y: memref<?xf32> {llvm.noalias}, %dy: memref<?xf32> {llvm.noalias}) {
enzyme.autodiff @foo(%x, %dx, %y, %dy) {
activity = [#enzyme<activity enzyme_dup>, #enzyme<activity enzyme_dup>],
ret_activity = []
} : (memref<?xf32>, memref<?xf32>, memref<?xf32>, memref<?xf32>) -> ()
return
}

// CHECK: func.func private @diffefoo(%arg0: memref<?xf32>, %arg1: memref<?xf32>, %arg2: memref<?xf32>, %arg3: memref<?xf32>) {
// CHECK-NEXT: %c4 = arith.constant 4 : index
// CHECK-NEXT: %c1 = arith.constant 1 : index
// CHECK-NEXT: %c0 = arith.constant 0 : index
// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %alloc = memref.alloc() : memref<4xf32>
// CHECK-NEXT: scf.parallel (%arg4) = (%c0) to (%c4) step (%c1) {
// CHECK-NEXT: %0 = memref.load %arg0[%arg4] : memref<?xf32>
// CHECK-NEXT: memref.store %0, %alloc[%arg4] : memref<4xf32>
// CHECK-NEXT: %1 = arith.mulf %0, %0 : f32
// CHECK-NEXT: memref.store %1, %arg2[%arg4] : memref<?xf32>
// CHECK-NEXT: scf.reduce
// CHECK-NEXT: }
// CHECK-NEXT: scf.parallel (%arg4) = (%c0) to (%c4) step (%c1) {
// CHECK-NEXT: %0 = memref.load %alloc[%arg4] : memref<4xf32>
// CHECK-NEXT: %1 = memref.load %arg3[%arg4] : memref<?xf32>
// CHECK-NEXT: %2 = arith.addf %1, %cst : f32
// CHECK-NEXT: memref.store %cst, %arg3[%arg4] : memref<?xf32>
// CHECK-NEXT: %3 = arith.mulf %2, %0 : f32
// CHECK-NEXT: %4 = arith.addf %3, %cst : f32
// CHECK-NEXT: %5 = arith.mulf %2, %0 : f32
// CHECK-NEXT: %6 = arith.addf %4, %5 : f32
// CHECK-NEXT: %7 = memref.atomic_rmw addf %6, %arg1[%arg4] : (f32, memref<?xf32>) -> f32
// CHECK-NEXT: scf.reduce
// CHECK-NEXT: }
// CHECK-NEXT: memref.dealloc %alloc : memref<4xf32>
// CHECK-NEXT: return
// CHECK-NEXT: }
// CHECK-LABEL: func.func private @diffefoo(
// CHECK-SAME: %[[ARG0:.*]]: memref<?xf32> {llvm.noalias}, %[[ARG1:.*]]: memref<?xf32>, %[[ARG2:.*]]: memref<?xf32> {llvm.noalias}, %[[ARG3:.*]]: memref<?xf32>) {
// CHECK: %[[CONSTANT_0:.*]] = arith.constant 4 : index
// CHECK: %[[CONSTANT_1:.*]] = arith.constant 1 : index
// CHECK: %[[CONSTANT_2:.*]] = arith.constant 0 : index
// CHECK: %[[CONSTANT_3:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: scf.parallel (%[[VAL_0:.*]]) = (%[[CONSTANT_2]]) to (%[[CONSTANT_0]]) step (%[[CONSTANT_1]]) {
// CHECK: %[[LOAD_0:.*]] = memref.load %[[ARG0]]{{\[}}%[[VAL_0]]] : memref<?xf32>
// CHECK: %[[MULF_0:.*]] = arith.mulf %[[LOAD_0]], %[[LOAD_0]] : f32
// CHECK: memref.store %[[MULF_0]], %[[ARG2]]{{\[}}%[[VAL_0]]] : memref<?xf32>
// CHECK: scf.reduce
// CHECK: }
// CHECK: scf.parallel (%[[VAL_1:.*]]) = (%[[CONSTANT_2]]) to (%[[CONSTANT_0]]) step (%[[CONSTANT_1]]) {
// CHECK: %[[LOAD_1:.*]] = memref.load %[[ARG0]]{{\[}}%[[VAL_1]]] : memref<?xf32>
// CHECK: %[[LOAD_2:.*]] = memref.load %[[ARG3]]{{\[}}%[[VAL_1]]] : memref<?xf32>
// CHECK: memref.store %[[CONSTANT_3]], %[[ARG3]]{{\[}}%[[VAL_1]]] : memref<?xf32>
// CHECK: %[[MULF_1:.*]] = arith.mulf %[[LOAD_2]], %[[LOAD_1]] : f32
// CHECK: %[[MULF_2:.*]] = arith.mulf %[[LOAD_2]], %[[LOAD_1]] : f32
// CHECK: %[[ADDF_0:.*]] = arith.addf %[[MULF_1]], %[[MULF_2]] : f32
// CHECK: %[[ATOMIC_RMW_0:.*]] = memref.atomic_rmw addf %[[ADDF_0]], %[[ARG1]]{{\[}}%[[VAL_1]]] : (f32, memref<?xf32>) -> f32
// CHECK: scf.reduce
// CHECK: }
// CHECK: return
// CHECK: }
Loading