Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions enzyme/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,24 @@ td_library(
],
)

td_library(
name = "ImpulseDialectTdFiles",
srcs = [
"Enzyme/MLIR/Dialect/Impulse/Dialect.td",
"Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.td",
],
includes = ["."],
deps = [
":EnzymeDialectTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:ControlFlowInterfacesTdFiles",
"@llvm-project//mlir:FunctionInterfacesTdFiles",
"@llvm-project//mlir:LoopLikeInterfaceTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
"@llvm-project//mlir:ViewLikeInterfaceTdFiles",
],
)

cc_library(
name = "LLVMExtDialect",
srcs = [
Expand Down Expand Up @@ -621,6 +639,111 @@ gentbl_cc_library(
deps = [":EnzymeDialectTdFiles"],
)

gentbl_cc_library(
name = "ImpulseOpsIncGen",
strip_include_prefix = "Enzyme/MLIR",
tbl_outs = [
(
["-gen-op-decls"],
"Enzyme/MLIR/Dialect/Impulse/ImpulseOps.h.inc",
),
(
["-gen-op-defs"],
"Enzyme/MLIR/Dialect/Impulse/ImpulseOps.cpp.inc",
),
(
[
"-gen-dialect-decls",
"-dialect=impulse",
],
"Enzyme/MLIR/Dialect/Impulse/ImpulseOpsDialect.h.inc",
),
(
[
"-gen-dialect-defs",
"-dialect=impulse",
],
"Enzyme/MLIR/Dialect/Impulse/ImpulseOpsDialect.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td",
deps = [":ImpulseDialectTdFiles"],
)

gentbl_cc_library(
name = "ImpulseTypesIncGen",
strip_include_prefix = "Enzyme/MLIR",
tbl_outs = [
(
["-gen-typedef-decls"],
"Enzyme/MLIR/Dialect/Impulse/ImpulseOpsTypes.h.inc",
),
(
["-gen-typedef-defs"],
"Enzyme/MLIR/Dialect/Impulse/ImpulseOpsTypes.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td",
deps = [":ImpulseDialectTdFiles"],
)

gentbl_cc_library(
name = "ImpulseEnumsIncGen",
strip_include_prefix = "Enzyme/MLIR",
tbl_outs = [
(
["-gen-enum-decls"],
"Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.h.inc",
),
(
["-gen-enum-defs"],
"Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Enzyme/MLIR/Dialect/Impulse/ImpulseEnums.td",
deps = [":ImpulseDialectTdFiles"],
)

gentbl_cc_library(
name = "ImpulseAttributesIncGen",
strip_include_prefix = "Enzyme/MLIR",
tbl_outs = [
(
[
"-gen-attrdef-decls",
"-attrdefs-dialect=impulse",
],
"Enzyme/MLIR/Dialect/Impulse/ImpulseAttributes.h.inc",
),
(
[
"-gen-attrdef-defs",
"-attrdefs-dialect=impulse",
],
"Enzyme/MLIR/Dialect/Impulse/ImpulseAttributes.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td",
deps = [":ImpulseDialectTdFiles"],
)

gentbl_cc_library(
name = "ImpulseAttributeInterfacesIncGen",
tbl_outs = [
(
["-gen-attr-interface-decls"],
"Enzyme/MLIR/Dialect/Impulse/ImpulseAttributeInterfaces.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Enzyme/MLIR/Dialect/Impulse/ImpulseOps.td",
deps = [":ImpulseDialectTdFiles"],
)

gentbl_cc_library(
name = "EnzymeTypeInterfacesIncGen",
strip_include_prefix = "Enzyme",
Expand Down Expand Up @@ -820,17 +943,20 @@ gentbl_cc_library(
deps = [":ImplementationsCommonTdFiles"],
)


cc_library(
name = "EnzymeMLIR",
srcs = glob([
"Enzyme/MLIR/Dialect/*.cpp",
"Enzyme/MLIR/Dialect/Impulse/*.cpp",
"Enzyme/MLIR/Passes/*.cpp",
"Enzyme/MLIR/Interfaces/*.cpp",
"Enzyme/MLIR/Analysis/*.cpp",
"Enzyme/MLIR/Implementations/*.cpp",
]),
hdrs = glob([
"Enzyme/MLIR/Dialect/*.h",
"Enzyme/MLIR/Dialect/Impulse/*.h",
"Enzyme/MLIR/Passes/*.h",
"Enzyme/MLIR/Interfaces/*.h",
"Enzyme/MLIR/Analysis/*.h",
Expand All @@ -852,6 +978,11 @@ cc_library(
":EnzymePassesIncGen",
":EnzymeTypeInterfacesIncGen",
":EnzymeTypesIncGen",
":ImpulseAttributeInterfacesIncGen",
":ImpulseAttributesIncGen",
":ImpulseEnumsIncGen",
":ImpulseOpsIncGen",
":ImpulseTypesIncGen",
":LLVMExtDialect",
":affine-derivatives",
":arith-derivatives",
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ MLIRIR
MLIRMemRefDialect
)

add_subdirectory(Impulse)
add_subdirectory(LLVMExt)
90 changes: 0 additions & 90 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,94 +33,4 @@ def Activity : I32EnumAttr<"Activity",

def ActivityAttr : EnumAttr<Enzyme_Dialect, Activity, "activity">;

def RngDistribution : I32EnumAttr<"RngDistribution",
"Random number distribution type",
[
I32EnumAttrCase<"UNIFORM", 0>,
I32EnumAttrCase<"NORMAL", 1>,
I32EnumAttrCase<"MULTINORMAL", 2>,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::enzyme";
}

def RngDistributionAttr : EnumAttr<Enzyme_Dialect, RngDistribution, "rng_distribution">;

def Transpose : I32EnumAttr<"Transpose",
"Transpose mode for matrix operations",
[
I32EnumAttrCase<"NO_TRANSPOSE", 0>,
I32EnumAttrCase<"TRANSPOSE", 1>,
I32EnumAttrCase<"ADJOINT", 2>,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::enzyme";
}

def TransposeAttr : EnumAttr<Enzyme_Dialect, Transpose, "transpose">;

def SupportKind : I32EnumAttr<"SupportKind",
"Domain where distribution density is non-zero",
[
I32EnumAttrCase<"REAL", 0>,
I32EnumAttrCase<"POSITIVE", 1>,
I32EnumAttrCase<"UNIT_INTERVAL", 2>,
I32EnumAttrCase<"INTERVAL", 3>,
I32EnumAttrCase<"GREATER_THAN", 4>,
I32EnumAttrCase<"LESS_THAN", 5>,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::enzyme";
}

def SupportKindAttr : EnumAttr<Enzyme_Dialect, SupportKind, "support_kind">;

def SupportAttr : Enzyme_Attr<"Support", "support"> {
let summary = "Distribution support specification for constraint transforms";
let description = [{
Distribution support specification for constraint transforms.
}];
let parameters = (ins
"SupportKind":$kind,
OptionalParameter<"FloatAttr">:$lower_bound,
OptionalParameter<"FloatAttr">:$upper_bound
);
let assemblyFormat = "`<` $kind (`,` `lower` `=` $lower_bound^)? (`,` `upper` `=` $upper_bound^)? `>`";
}

def HMCConfigAttr : Enzyme_Attr<"HMCConfig", "hmc_config"> {
let summary = "Configuration for HMC inference";
let description = [{
HMC-specific parameters for fixed-trajectory Hamiltonian Monte Carlo.
- trajectory_length: Total trajectory length for leapfrog integration.
- adapt_step_size (default: true): Whether to adapt step size during warmup using dual averaging.
- adapt_mass_matrix (default: true): Whether to adapt the mass matrix during warmup using Welford covariance.
}];
let parameters = (ins
"FloatAttr":$trajectory_length,
DefaultValuedParameter<"bool", "true">:$adapt_step_size,
DefaultValuedParameter<"bool", "true">:$adapt_mass_matrix
);
let assemblyFormat = "`<` `trajectory_length` `=` $trajectory_length (`,` `adapt_step_size` `=` $adapt_step_size^)? (`,` `adapt_mass_matrix` `=` $adapt_mass_matrix^)? `>`";
}

def NUTSConfigAttr : Enzyme_Attr<"NUTSConfig", "nuts_config"> {
let summary = "Configuration for NUTS inference";
let description = [{
NUTS-specific parameters for No-U-Turn Sampler with adaptive path length.
- max_tree_depth: Maximum tree depth, controls max leapfrog steps (2^max_tree_depth).
- max_delta_energy: Threshold for trajectory divergence (rejected if energy change exceeds this).
Default: 1000.0
- adapt_step_size (default: true): Whether to adapt step size during warmup using dual averaging.
- adapt_mass_matrix (default: true): Whether to adapt the mass matrix during warmup using Welford covariance.
}];
let parameters = (ins
DefaultValuedParameter<"int64_t", "10">:$max_tree_depth,
OptionalParameter<"FloatAttr">:$max_delta_energy,
DefaultValuedParameter<"bool", "true">:$adapt_step_size,
DefaultValuedParameter<"bool", "true">:$adapt_mass_matrix
);
let assemblyFormat = "`<` (`max_tree_depth` `=` $max_tree_depth^)? (`,` `max_delta_energy` `=` $max_delta_energy^)? (`,` `adapt_step_size` `=` $adapt_step_size^)? (`,` `adapt_mass_matrix` `=` $adapt_mass_matrix^)? `>`";
}

#endif // ENZYME_ENUMS
Loading
Loading